dyn_encoding/codec/
protobuf.rs1use std::collections::HashMap;
11
12use prost::Message;
13
14use crate::error::CodecError;
15use crate::value::{ErasedWireValue, WireCodec, WireTypeId, WireValue};
16
17type EncodeFn =
18 Box<dyn Fn(&dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> + Send + Sync + 'static>;
19type DecodeFn =
20 Box<dyn Fn(&[u8]) -> Result<Box<dyn ErasedWireValue>, CodecError> + Send + Sync + 'static>;
21
22#[derive(Default)]
25pub struct ProtobufCodec {
26 encoders: HashMap<WireTypeId, EncodeFn>,
27 decoders: HashMap<WireTypeId, DecodeFn>,
28}
29
30impl ProtobufCodec {
31 #[must_use]
33 pub fn new() -> Self {
34 Self::default()
35 }
36
37 pub fn register<T>(&mut self) -> &mut Self
41 where
42 T: WireValue + Message + Default,
43 {
44 let id = T::wire_type_id();
45 self.encoders.insert(
46 id,
47 Box::new(
48 move |v: &dyn ErasedWireValue| -> Result<Vec<u8>, CodecError> {
49 let concrete = v
50 .as_any()
51 .downcast_ref::<T>()
52 .ok_or(CodecError::TypeMismatch { expected: id })?;
53 Ok(concrete.encode_to_vec())
54 },
55 ),
56 );
57 self.decoders.insert(
58 id,
59 Box::new(
60 move |bytes: &[u8]| -> Result<Box<dyn ErasedWireValue>, CodecError> {
61 let value = T::decode(bytes).map_err(CodecError::decode_failure)?;
62 Ok(Box::new(value))
63 },
64 ),
65 );
66 self
67 }
68
69 #[must_use]
71 pub fn registered_type_count(&self) -> usize {
72 self.encoders.len()
73 }
74}
75
76impl WireCodec for ProtobufCodec {
77 fn content_type(&self) -> &'static str {
78 "application/x-protobuf"
79 }
80
81 fn encode(&self, value: &dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> {
82 let id = value.type_id();
83 let encoder = self
84 .encoders
85 .get(&id)
86 .ok_or(CodecError::UnknownTypeId(id))?;
87 encoder(value)
88 }
89
90 fn decode(
91 &self,
92 type_id: WireTypeId,
93 bytes: &[u8],
94 ) -> Result<Box<dyn ErasedWireValue>, CodecError> {
95 let decoder = self
96 .decoders
97 .get(&type_id)
98 .ok_or(CodecError::UnknownTypeId(type_id))?;
99 decoder(bytes)
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use prost::Message;

    /// Small message exercising a string, a varint and a bytes field.
    #[derive(Clone, Eq, PartialEq, Message)]
    struct Sample {
        #[prost(string, tag = "1")]
        name: String,
        #[prost(uint32, tag = "2")]
        seq: u32,
        #[prost(bytes = "vec", tag = "3")]
        payload: Vec<u8>,
    }

    impl WireValue for Sample {
        fn wire_type_id() -> WireTypeId {
            WireTypeId::new("test.protobuf.Sample")
        }
    }

    /// A different Rust type that deliberately reuses `Sample`'s wire type id.
    ///
    /// Used to drive the encoder's downcast-failure path and registry overwrites.
    #[derive(Clone, Eq, PartialEq, Message)]
    struct Impostor {
        #[prost(string, tag = "1")]
        name: String,
    }

    impl WireValue for Impostor {
        fn wire_type_id() -> WireTypeId {
            Sample::wire_type_id()
        }
    }

    fn fixture() -> Sample {
        Sample {
            name: "gamma".into(),
            seq: 99,
            payload: vec![0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef],
        }
    }

    #[test]
    fn round_trip_recovers_value() {
        let mut codec = ProtobufCodec::new();
        codec.register::<Sample>();
        let v = fixture();
        let bytes = codec.encode(&v).expect("encode");
        let back = codec
            .decode(Sample::wire_type_id(), &bytes)
            .expect("decode");
        let back = back.as_any().downcast_ref::<Sample>().expect("downcast");
        assert_eq!(back, &v);
    }

    #[test]
    fn default_message_round_trips_through_empty_payload() {
        let mut codec = ProtobufCodec::new();
        codec.register::<Sample>();
        let v = Sample::default();
        let bytes = codec.encode(&v).expect("encode");
        // proto3 scalar fields at their default value are not written at all.
        assert!(bytes.is_empty());
        let back = codec
            .decode(Sample::wire_type_id(), &bytes)
            .expect("decode");
        let back = back.as_any().downcast_ref::<Sample>().expect("downcast");
        assert_eq!(back, &v);
    }

    #[test]
    fn idempotent_encode_is_byte_equal() {
        let mut codec = ProtobufCodec::new();
        codec.register::<Sample>();
        let v = fixture();
        let a = codec.encode(&v).expect("encode 1");
        let b = codec.encode(&v).expect("encode 2");
        assert_eq!(a, b);
        let back = codec.decode(Sample::wire_type_id(), &a).expect("decode");
        let c = codec.encode(back.as_ref()).expect("encode 3");
        assert_eq!(a, c);
    }

    #[test]
    fn produced_bytes_match_prost_native_encoding() {
        let mut codec = ProtobufCodec::new();
        codec.register::<Sample>();
        let v = fixture();
        let through_codec = codec.encode(&v).expect("encode");
        let direct = v.encode_to_vec();
        assert_eq!(through_codec, direct);
    }

    #[test]
    fn unregistered_type_returns_unknown_type_id() {
        let codec = ProtobufCodec::new();
        let v = fixture();
        let err = codec.encode(&v).expect_err("expected unknown type");
        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
    }

    #[test]
    fn decode_of_unregistered_type_returns_unknown_type_id() {
        let codec = ProtobufCodec::new();
        let err = codec
            .decode(Sample::wire_type_id(), &[])
            .expect_err("expected unknown type");
        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
    }

    #[test]
    fn encoding_a_foreign_type_under_a_registered_id_is_a_type_mismatch() {
        let mut codec = ProtobufCodec::new();
        codec.register::<Sample>();
        let impostor = Impostor {
            name: "delta".into(),
        };
        let err = codec.encode(&impostor).expect_err("expected type mismatch");
        assert!(matches!(
            err,
            CodecError::TypeMismatch { expected } if expected == Sample::wire_type_id()
        ));
    }

    #[test]
    fn registered_type_count_counts_distinct_ids() {
        let mut codec = ProtobufCodec::new();
        assert_eq!(codec.registered_type_count(), 0);
        codec.register::<Sample>();
        assert_eq!(codec.registered_type_count(), 1);
        // Impostor shares Sample's wire id, so it replaces the entry instead of adding one.
        codec.register::<Impostor>();
        assert_eq!(codec.registered_type_count(), 1);
    }

    #[test]
    fn malformed_bytes_yield_decode_failure() {
        let mut codec = ProtobufCodec::new();
        codec.register::<Sample>();
        // 0x0a is the key for field 1, wire type 2 (length-delimited). 0xff is a length
        // varint with its continuation bit set and no following byte, so the input is
        // truncated.
        let err = codec
            .decode(Sample::wire_type_id(), &[0x0a, 0xff])
            .expect_err("expected decode failure");
        assert!(matches!(err, CodecError::Decode(_)));
    }
}