use prost::Message;
use crate::codec::Codec;
use crate::error::KnafehError;
/// Codec that wraps arbitrary bytes in a single-field protobuf envelope
/// (the private `Payload` message, field 1) on encode and unwraps it on decode.
#[derive(Debug, Clone, Default)]
pub struct ProtobufCodec;

impl ProtobufCodec {
    /// Creates a codec instance; the type carries no state, so this is the
    /// same value `ProtobufCodec::default()` produces.
    pub fn new() -> ProtobufCodec {
        ProtobufCodec
    }
}
impl Codec for ProtobufCodec {
fn encode(&self, value: &[u8]) -> Result<Vec<u8>, KnafehError> {
let payload = Payload {
data: value.to_vec(),
};
Ok(payload.encode_to_vec())
}
fn decode(&self, data: &[u8]) -> Result<Vec<u8>, KnafehError> {
let payload = Payload::decode(data)
.map_err(|e| KnafehError::Codec(format!("protobuf decode error: {e}")))?;
if payload.data.is_empty() && !data.is_empty() {
return Err(KnafehError::Codec(
"protobuf decode error: mismatched or unknown payload envelope".to_string(),
));
}
Ok(payload.data)
}
fn content_type(&self) -> &str {
"application/grpc+proto"
}
fn name(&self) -> &str {
"protobuf"
}
}
/// Pass-through codec: bytes are copied as-is in both directions, with no
/// envelope added or removed.
#[derive(Debug, Clone, Default)]
pub struct RawProtobufCodec;

impl RawProtobufCodec {
    /// Creates a codec instance; the type carries no state, so this is the
    /// same value `RawProtobufCodec::default()` produces.
    pub fn new() -> RawProtobufCodec {
        RawProtobufCodec
    }
}
impl Codec for RawProtobufCodec {
    /// Returns an owned copy of `value` with no framing applied; never fails.
    fn encode(&self, value: &[u8]) -> Result<Vec<u8>, KnafehError> {
        let copied = Vec::from(value);
        Ok(copied)
    }

    /// Returns an owned copy of `data` with no framing removed; never fails.
    fn decode(&self, data: &[u8]) -> Result<Vec<u8>, KnafehError> {
        let copied = data.to_owned();
        Ok(copied)
    }

    /// MIME type advertised for this codec's wire format.
    fn content_type(&self) -> &str {
        "application/grpc+proto"
    }

    /// Stable identifier for this codec.
    fn name(&self) -> &str {
        "raw-protobuf"
    }
}
/// Single-field protobuf envelope that `ProtobufCodec::encode` serializes and
/// `ProtobufCodec` decoding parses; the wrapped bytes live in field 1.
#[derive(Clone, PartialEq, Message)]
struct Payload {
/// Wrapped bytes, encoded as protobuf `bytes` with tag 1. As a plain
/// (non-optional) field, an empty value is omitted from the encoded output.
#[prost(bytes = "vec", tag = "1")]
data: Vec<u8>,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_protobuf_codec_roundtrip() {
        let codec = ProtobufCodec::new();
        let original: &[u8] = b"hello protobuf";
        let wire = codec.encode(original).expect("encode should succeed");
        // The envelope adds framing, so the wire bytes must differ from the input.
        assert_ne!(wire.as_slice(), original);
        let restored = codec.decode(&wire).expect("decode should succeed");
        assert_eq!(restored.as_slice(), original);
    }

    #[test]
    fn test_protobuf_codec_empty() {
        let codec = ProtobufCodec::new();
        let wire = codec.encode(b"").expect("encode should succeed");
        let restored = codec.decode(&wire).expect("decode should succeed");
        assert!(restored.is_empty());
    }

    #[test]
    fn test_protobuf_codec_large() {
        let codec = ProtobufCodec::new();
        let original = vec![0xAB_u8; 100_000];
        let restored = codec
            .encode(&original)
            .and_then(|wire| codec.decode(&wire))
            .expect("roundtrip should succeed");
        assert_eq!(restored, original);
    }

    #[test]
    fn test_protobuf_codec_invalid_decode() {
        // Two continuation bytes with no terminating byte: a truncated varint.
        let outcome = ProtobufCodec::new().decode(&[0x80, 0x80]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_protobuf_codec_mismatched_message() {
        // Field 2 (length-delimited) carrying "foo"; the envelope only defines field 1.
        let foreign = [0x12, 0x03, b'f', b'o', b'o'];
        assert!(ProtobufCodec::new().decode(&foreign).is_err());
    }

    #[test]
    fn test_raw_protobuf_codec_passthrough() {
        let codec = RawProtobufCodec::new();
        let original: &[u8] = b"raw bytes";
        let wire = codec.encode(original).expect("encode should succeed");
        assert_eq!(wire.as_slice(), original);
        let restored = codec.decode(&wire).expect("decode should succeed");
        assert_eq!(restored.as_slice(), original);
    }
}