dyn_encoding/codec/
bson.rs1use std::collections::HashMap;
16
17use serde::de::DeserializeOwned;
18use serde::Serialize;
19
20use crate::error::CodecError;
21use crate::value::{ErasedWireValue, WireCodec, WireTypeId, WireValue};
22
23type EncodeFn =
24 Box<dyn Fn(&dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> + Send + Sync + 'static>;
25type DecodeFn =
26 Box<dyn Fn(&[u8]) -> Result<Box<dyn ErasedWireValue>, CodecError> + Send + Sync + 'static>;
27
28#[derive(Default)]
31pub struct BsonCodec {
32 encoders: HashMap<WireTypeId, EncodeFn>,
33 decoders: HashMap<WireTypeId, DecodeFn>,
34}
35
36impl BsonCodec {
37 #[must_use]
39 pub fn new() -> Self {
40 Self::default()
41 }
42
43 pub fn register<T>(&mut self) -> &mut Self
45 where
46 T: WireValue + Serialize + DeserializeOwned,
47 {
48 let id = T::wire_type_id();
49 self.encoders.insert(
50 id,
51 Box::new(
52 move |v: &dyn ErasedWireValue| -> Result<Vec<u8>, CodecError> {
53 let concrete = v
54 .as_any()
55 .downcast_ref::<T>()
56 .ok_or(CodecError::TypeMismatch { expected: id })?;
57 bson::serialize_to_vec(concrete).map_err(CodecError::encode_failure)
58 },
59 ),
60 );
61 self.decoders.insert(
62 id,
63 Box::new(
64 move |bytes: &[u8]| -> Result<Box<dyn ErasedWireValue>, CodecError> {
65 let value: T =
66 bson::deserialize_from_slice(bytes).map_err(CodecError::decode_failure)?;
67 Ok(Box::new(value))
68 },
69 ),
70 );
71 self
72 }
73
74 #[must_use]
76 pub fn registered_type_count(&self) -> usize {
77 self.encoders.len()
78 }
79}
80
81impl WireCodec for BsonCodec {
82 fn content_type(&self) -> &'static str {
83 "application/bson"
84 }
85
86 fn encode(&self, value: &dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> {
87 let id = value.type_id();
88 let encoder = self
89 .encoders
90 .get(&id)
91 .ok_or(CodecError::UnknownTypeId(id))?;
92 encoder(value)
93 }
94
95 fn decode(
96 &self,
97 type_id: WireTypeId,
98 bytes: &[u8],
99 ) -> Result<Box<dyn ErasedWireValue>, CodecError> {
100 let decoder = self
101 .decoders
102 .get(&type_id)
103 .ok_or(CodecError::UnknownTypeId(type_id))?;
104 decoder(bytes)
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Minimal serde-enabled payload used as the registered wire type.
    #[derive(Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
    struct Sample {
        name: String,
        seq: u32,
        payload: Vec<u8>,
    }

    impl WireValue for Sample {
        fn wire_type_id() -> WireTypeId {
            WireTypeId::new("test.bson.Sample")
        }
    }

    /// A `Sample` with every field set to a non-default value.
    fn fixture() -> Sample {
        Sample {
            name: "eta".into(),
            seq: 8_192,
            payload: vec![0xa5, 0x5a, 0xa5, 0x5a],
        }
    }

    #[test]
    fn round_trip_recovers_value() {
        let mut codec = BsonCodec::new();
        codec.register::<Sample>();
        let v = fixture();
        let bytes = codec.encode(&v).expect("encode");
        let back = codec
            .decode(Sample::wire_type_id(), &bytes)
            .expect("decode");
        let back = back.as_any().downcast_ref::<Sample>().expect("downcast");
        assert_eq!(back, &v);
    }

    #[test]
    fn idempotent_encode_is_byte_equal() {
        let mut codec = BsonCodec::new();
        codec.register::<Sample>();
        let v = fixture();
        let a = codec.encode(&v).expect("encode 1");
        let b = codec.encode(&v).expect("encode 2");
        assert_eq!(a, b);
        let back = codec.decode(Sample::wire_type_id(), &a).expect("decode");
        let c = codec.encode(back.as_ref()).expect("encode 3");
        assert_eq!(a, c);
    }

    #[test]
    fn unregistered_type_returns_unknown_type_id_on_encode() {
        let codec = BsonCodec::new();
        let v = fixture();
        let err = codec.encode(&v).expect_err("expected unknown type");
        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
    }

    #[test]
    fn unregistered_type_returns_unknown_type_id_on_decode() {
        let codec = BsonCodec::new();
        let err = codec
            .decode(Sample::wire_type_id(), b"")
            .expect_err("expected unknown type");
        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
    }

    #[test]
    fn malformed_bytes_yield_decode_failure() {
        let mut codec = BsonCodec::new();
        codec.register::<Sample>();
        let err = codec
            .decode(Sample::wire_type_id(), &[0x01, 0x02, 0x03])
            .expect_err("expected decode failure");
        assert!(matches!(err, CodecError::Decode(_)));
    }

    #[test]
    fn registered_type_count_tracks_distinct_registrations() {
        let mut codec = BsonCodec::new();
        assert_eq!(codec.registered_type_count(), 0);
        // `register` returns `&mut Self`, so calls can be chained. Registering
        // the same wire type id again replaces the entry instead of adding one.
        codec.register::<Sample>().register::<Sample>();
        assert_eq!(codec.registered_type_count(), 1);
    }
}
187}