use base64::Engine;
use buffa::encoding::{Tag, WireType, encode_varint};
use bytes::{Buf, BufMut, Bytes};
use crate::error::{ConnectError, ErrorDetail};
pub(crate) fn encode(err: &ConnectError) -> Bytes {
let mut buf = Vec::new();
Tag::new(1, WireType::Varint).encode(&mut buf);
encode_varint(err.code.grpc_code() as u64, &mut buf);
if let Some(ref message) = err.message {
write_bytes_field(&mut buf, 2, message.as_bytes());
}
for detail in &err.details {
let any_bytes = encode_any(&detail.type_url, &detail.value);
write_bytes_field(&mut buf, 3, &any_bytes);
}
Bytes::from(buf)
}
/// Extracts the `details` entries (field 3, each a `google.protobuf.Any`) from a serialized
/// `google.rpc.Status` message.
///
/// Decoding is best-effort:
/// - parsing stops at the first malformed or truncated field, or at a group wire type, and the
///   details collected up to that point are returned;
/// - a field-3 entry whose `Any` payload fails to decode is skipped;
/// - all other fields (code, message, unknown fields) are skipped.
pub(crate) fn decode_details(data: &[u8]) -> Vec<ErrorDetail> {
    let mut details = Vec::new();
    let mut buf = data;
    while buf.has_remaining() {
        let Ok(tag) = Tag::decode(&mut buf) else {
            break;
        };
        match tag.wire_type() {
            WireType::Varint => {
                if buffa::encoding::decode_varint(&mut buf).is_err() {
                    break;
                }
            }
            WireType::LengthDelimited => {
                // Checked conversion instead of `as usize`. On 32-bit targets (e.g. wasm32) an
                // `as` cast would truncate an oversized length, letting it slip past the bounds
                // check below and desynchronising the parse.
                let Some(len) = buffa::encoding::decode_varint(&mut buf)
                    .ok()
                    .and_then(|n| usize::try_from(n).ok())
                else {
                    break;
                };
                // `get` performs the bounds check and the slicing in one step.
                let Some(field_data) = buf.get(..len) else {
                    break;
                };
                if tag.field_number() == 3
                    && let Some(detail) = decode_any(field_data)
                {
                    details.push(detail);
                }
                buf.advance(len);
            }
            WireType::Fixed64 => {
                if buf.remaining() < 8 {
                    break;
                }
                buf.advance(8);
            }
            WireType::Fixed32 => {
                if buf.remaining() < 4 {
                    break;
                }
                buf.advance(4);
            }
            // Group wire types cannot be skipped without parsing them; stop here.
            _ => break,
        }
    }
    details
}
/// Serializes a `google.protobuf.Any` message.
///
/// Field 1 (`type_url`) is always written. Field 2 (`value`) is written with the bytes obtained
/// by base64-decoding `value`. Unpadded standard-alphabet input is tried first, then padded.
/// If `value` is `None`, or is not valid base64 in either form, field 2 is omitted entirely.
fn encode_any(type_url: &str, value: &Option<String>) -> Vec<u8> {
    let mut out = Vec::new();
    write_bytes_field(&mut out, 1, type_url.as_bytes());

    let raw_value = value.as_deref().and_then(|text| {
        base64::engine::general_purpose::STANDARD_NO_PAD
            .decode(text)
            .or_else(|_| base64::engine::general_purpose::STANDARD.decode(text))
            .ok()
    });
    if let Some(raw) = raw_value {
        write_bytes_field(&mut out, 2, &raw);
    }

    out
}
/// Parses a serialized `google.protobuf.Any` into an [`ErrorDetail`].
///
/// Fields:
/// - field 1 (`type_url`) must be present and valid UTF-8;
/// - field 2 (`value`) is optional. Its raw bytes are re-encoded as unpadded standard base64, and
///   an absent value becomes an empty string.
///
/// If a field occurs more than once, the last occurrence wins. Unknown fields of the varint,
/// fixed-64, fixed-32 and length-delimited wire types are skipped.
///
/// Returns `None` when the message is malformed: a bad tag or varint, a field that runs past the
/// end of `data`, a group wire type, a non-UTF-8 `type_url`, or no `type_url` at all. A truncated
/// message is rejected rather than yielding a partially-populated detail.
fn decode_any(data: &[u8]) -> Option<ErrorDetail> {
    let mut type_url = None;
    let mut value = None;
    let mut buf = data;
    while buf.has_remaining() {
        let tag = Tag::decode(&mut buf).ok()?;
        match tag.wire_type() {
            WireType::LengthDelimited => {
                // Checked conversion: `as usize` would truncate oversized lengths on 32-bit
                // targets and let them pass the bounds check.
                let len = usize::try_from(buffa::encoding::decode_varint(&mut buf).ok()?).ok()?;
                let field_data = buf.get(..len)?;
                match tag.field_number() {
                    1 => type_url = Some(std::str::from_utf8(field_data).ok()?.to_owned()),
                    2 => value = Some(field_data.to_vec()),
                    _ => {}
                }
                buf.advance(len);
            }
            WireType::Varint => {
                buffa::encoding::decode_varint(&mut buf).ok()?;
            }
            // Fixed-width unknown fields must be skipped, not treated as end of message.
            WireType::Fixed64 => {
                if buf.remaining() < 8 {
                    return None;
                }
                buf.advance(8);
            }
            WireType::Fixed32 => {
                if buf.remaining() < 4 {
                    return None;
                }
                buf.advance(4);
            }
            // Group wire types cannot be skipped without parsing them; treat as malformed.
            _ => return None,
        }
    }
    Some(ErrorDetail {
        type_url: type_url?,
        value: Some(
            base64::engine::general_purpose::STANDARD_NO_PAD.encode(value.unwrap_or_default()),
        ),
        debug: None,
    })
}
/// Appends one length-delimited protobuf field to `buf`.
///
/// The field is written as: the tag (`field_number` with the length-delimited wire type), the
/// byte length of `data` as a varint, then `data` itself.
fn write_bytes_field(buf: &mut Vec<u8>, field_number: u32, data: &[u8]) {
    let tag = Tag::new(field_number, WireType::LengthDelimited);
    let byte_len = data.len() as u64;
    tag.encode(buf);
    encode_varint(byte_len, buf);
    buf.put_slice(data);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ErrorCode;

    /// A status without details must decode to an empty detail list.
    #[test]
    fn test_encode_decode_roundtrip() {
        let encoded = encode(&ConnectError::new(ErrorCode::Internal, "test error"));
        assert!(decode_details(&encoded).is_empty());
    }

    /// One detail survives encode -> decode with its type URL and raw bytes intact.
    #[test]
    fn test_encode_decode_with_details() {
        use base64::Engine;
        const TYPE_URL: &str = "type.googleapis.com/test.Detail";
        const RAW: &[u8] = b"\x01\x02\x03";
        let b64 = base64::engine::general_purpose::STANDARD_NO_PAD;

        let err = ConnectError::new(ErrorCode::NotFound, "not found").with_detail(ErrorDetail {
            type_url: TYPE_URL.to_string(),
            value: Some(b64.encode(RAW)),
            debug: None,
        });
        let details = decode_details(&encode(&err));

        assert_eq!(details.len(), 1);
        assert_eq!(details[0].type_url, TYPE_URL);
        let value = details[0].value.as_ref().unwrap();
        assert_eq!(b64.decode(value).unwrap(), RAW);
    }

    #[test]
    fn test_decode_empty() {
        assert!(decode_details(&[]).is_empty());
    }

    /// Code and message fields are skipped; only field 3 yields details.
    #[test]
    fn test_decode_skips_non_details_fields() {
        let status_without_details: [u8; 7] = [
            0x08, 13, // field 1 (varint): code 13
            0x12, 3, b'e', b'r', b'r', // field 2 (length-delimited): "err"
        ];
        assert!(decode_details(&status_without_details).is_empty());
    }

    #[test]
    fn test_encode_includes_code_and_message() {
        let encoded = encode(&ConnectError::new(ErrorCode::Unavailable, "overloaded"));
        assert!(encoded.len() > 2);
        assert_eq!(encoded[0], 0x08); // tag: field 1, varint wire type
        assert_eq!(encoded[1], 14); // gRPC UNAVAILABLE
    }

    /// Truncated input must not panic and must yield no details.
    #[test]
    fn test_decode_truncated() {
        // Field-3 tag with no length at all.
        assert!(decode_details(&[0x1A]).is_empty());
        // Field-3 tag followed by an unterminated length varint.
        assert!(decode_details(&[0x1A, 0x80]).is_empty());
    }
}