dyn_encoding/codec/
bebop.rs1use std::collections::HashMap;
13
14use crate::error::CodecError;
15use crate::value::{ErasedWireValue, WireCodec, WireTypeId, WireValue};
16
17type EncodeFn =
18 Box<dyn Fn(&dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> + Send + Sync + 'static>;
19type DecodeFn =
20 Box<dyn Fn(&[u8]) -> Result<Box<dyn ErasedWireValue>, CodecError> + Send + Sync + 'static>;
21
22pub trait BebopWire: WireValue + Sized {
29 fn bebop_encode(&self) -> Result<Vec<u8>, CodecError>;
31
32 fn bebop_decode(bytes: &[u8]) -> Result<Self, CodecError>;
34}
35
36#[derive(Default)]
39pub struct BebopCodec {
40 encoders: HashMap<WireTypeId, EncodeFn>,
41 decoders: HashMap<WireTypeId, DecodeFn>,
42}
43
44impl BebopCodec {
45 #[must_use]
47 pub fn new() -> Self {
48 Self::default()
49 }
50
51 pub fn register<T>(&mut self) -> &mut Self
53 where
54 T: BebopWire + 'static,
55 {
56 let id = T::wire_type_id();
57 self.encoders.insert(
58 id,
59 Box::new(
60 move |v: &dyn ErasedWireValue| -> Result<Vec<u8>, CodecError> {
61 let concrete = v
62 .as_any()
63 .downcast_ref::<T>()
64 .ok_or(CodecError::TypeMismatch { expected: id })?;
65 concrete.bebop_encode()
66 },
67 ),
68 );
69 self.decoders.insert(
70 id,
71 Box::new(
72 move |bytes: &[u8]| -> Result<Box<dyn ErasedWireValue>, CodecError> {
73 let value = T::bebop_decode(bytes)?;
74 Ok(Box::new(value))
75 },
76 ),
77 );
78 self
79 }
80
81 #[must_use]
83 pub fn registered_type_count(&self) -> usize {
84 self.encoders.len()
85 }
86}
87
88impl WireCodec for BebopCodec {
89 fn content_type(&self) -> &'static str {
90 "application/x-bebop"
91 }
92
93 fn encode(&self, value: &dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> {
94 let id = value.type_id();
95 let encoder = self
96 .encoders
97 .get(&id)
98 .ok_or(CodecError::UnknownTypeId(id))?;
99 encoder(value)
100 }
101
102 fn decode(
103 &self,
104 type_id: WireTypeId,
105 bytes: &[u8],
106 ) -> Result<Box<dyn ErasedWireValue>, CodecError> {
107 let decoder = self
108 .decoders
109 .get(&type_id)
110 .ok_or(CodecError::UnknownTypeId(type_id))?;
111 decoder(bytes)
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use ::bebop::SubRecord;

    /// Three-field record used to exercise the codec end to end.
    #[derive(Debug, Eq, PartialEq)]
    struct Sample {
        name: String,
        seq: u32,
        payload: Vec<u8>,
    }

    impl WireValue for Sample {
        fn wire_type_id() -> WireTypeId {
            WireTypeId::new("test.bebop.Sample")
        }
    }

    /// A distinct Rust type that deliberately claims `Sample`'s wire id, used to
    /// drive the encoder's downcast-failure (`TypeMismatch`) branch.
    #[derive(Debug, Eq, PartialEq)]
    struct Imposter;

    impl WireValue for Imposter {
        fn wire_type_id() -> WireTypeId {
            Sample::wire_type_id()
        }
    }

    /// Reads one chained Bebop field starting at `*off` and advances the cursor
    /// past the bytes it consumed.
    ///
    /// Uses `get` rather than slice indexing so a cursor beyond the buffer is
    /// reported as a decode error instead of panicking.
    fn read_field<'raw, F>(bytes: &'raw [u8], off: &mut usize) -> Result<F, CodecError>
    where
        F: SubRecord<'raw>,
    {
        let rest = bytes
            .get(*off..)
            .ok_or_else(|| CodecError::decode_failure("bebop: cursor out of bounds"))?;
        let (consumed, field) =
            F::_deserialize_chained(rest).map_err(CodecError::decode_failure)?;
        *off = (*off)
            .checked_add(consumed)
            .ok_or_else(|| CodecError::decode_failure("bebop: cursor overflow"))?;
        Ok(field)
    }

    impl BebopWire for Sample {
        fn bebop_encode(&self) -> Result<Vec<u8>, CodecError> {
            let mut buf = Vec::with_capacity(
                self.name.serialized_size()
                    + self.seq.serialized_size()
                    + self.payload.serialized_size(),
            );
            self.name
                ._serialize_chained(&mut buf)
                .map_err(CodecError::encode_failure)?;
            self.seq
                ._serialize_chained(&mut buf)
                .map_err(CodecError::encode_failure)?;
            self.payload
                ._serialize_chained(&mut buf)
                .map_err(CodecError::encode_failure)?;
            Ok(buf)
        }

        fn bebop_decode(bytes: &[u8]) -> Result<Self, CodecError> {
            // Fields are read back in the same order `bebop_encode` wrote them.
            let mut off = 0usize;
            let name: String = read_field(bytes, &mut off)?;
            let seq: u32 = read_field(bytes, &mut off)?;
            let payload: Vec<u8> = read_field(bytes, &mut off)?;
            Ok(Sample { name, seq, payload })
        }
    }

    /// A `Sample` with non-trivial values in every field.
    fn fixture() -> Sample {
        Sample {
            name: "zeta".into(),
            seq: 0xdead_beef,
            payload: vec![0x00, 0x11, 0x22, 0x33],
        }
    }

    #[test]
    fn round_trip_recovers_value() {
        let mut codec = BebopCodec::new();
        codec.register::<Sample>();
        let v = fixture();
        let bytes = codec.encode(&v).expect("encode");
        let back = codec
            .decode(Sample::wire_type_id(), &bytes)
            .expect("decode");
        let back = back.as_any().downcast_ref::<Sample>().expect("downcast");
        assert_eq!(back, &v);
    }

    #[test]
    fn idempotent_encode_is_byte_equal() {
        let mut codec = BebopCodec::new();
        codec.register::<Sample>();
        let v = fixture();
        let a = codec.encode(&v).expect("encode 1");
        let b = codec.encode(&v).expect("encode 2");
        assert_eq!(a, b);
        let back = codec.decode(Sample::wire_type_id(), &a).expect("decode");
        let c = codec.encode(back.as_ref()).expect("encode 3");
        assert_eq!(a, c);
    }

    #[test]
    fn unregistered_type_returns_unknown_type_id_on_encode() {
        let codec = BebopCodec::new();
        let v = fixture();
        let err = codec.encode(&v).expect_err("expected unknown type");
        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
    }

    #[test]
    fn unregistered_type_returns_unknown_type_id_on_decode() {
        let codec = BebopCodec::new();
        let err = codec
            .decode(Sample::wire_type_id(), b"")
            .expect_err("expected unknown type");
        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
    }

    #[test]
    fn malformed_bytes_yield_decode_failure() {
        let mut codec = BebopCodec::new();
        codec.register::<Sample>();
        let err = codec
            .decode(Sample::wire_type_id(), &[0x01, 0x02, 0x03])
            .expect_err("expected decode failure");
        assert!(matches!(err, CodecError::Decode(_)));
    }

    #[test]
    fn wire_id_collision_yields_type_mismatch_on_encode() {
        // The registry lookup succeeds (same wire id) but the downcast to
        // `Sample` fails, which must surface as `TypeMismatch`.
        let mut codec = BebopCodec::new();
        codec.register::<Sample>();
        let err = codec
            .encode(&Imposter)
            .expect_err("expected type mismatch");
        assert!(matches!(
            err,
            CodecError::TypeMismatch { expected, .. } if expected == Sample::wire_type_id()
        ));
    }
}