Skip to main content

dyn_encoding/codec/
bson.rs

1//! `application/bson` codec backed by the `bson` crate.
2//!
3//! BSON is the wire format MongoDB uses; the upstream Rust crate
4//! exposes a serde-compatible API. The codec is shaped exactly like
5//! [`crate::JsonCodec`] / [`crate::CborCodec`]: registered types are
6//! bounded on `Serialize + DeserializeOwned`, and dispatch happens by
7//! [`WireTypeId`].
8//!
9//! BSON's wire model is document-oriented: the top-level value must
10//! be a BSON document (a struct or map in serde terms). Registering
11//! a type whose serde representation is not a struct or map will
12//! still compile, but encoding such a value will surface a
13//! [`CodecError::Encode`] from the underlying serializer at runtime.
14
15use std::collections::HashMap;
16
17use serde::de::DeserializeOwned;
18use serde::Serialize;
19
20use crate::error::CodecError;
21use crate::value::{ErasedWireValue, WireCodec, WireTypeId, WireValue};
22
23type EncodeFn =
24    Box<dyn Fn(&dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> + Send + Sync + 'static>;
25type DecodeFn =
26    Box<dyn Fn(&[u8]) -> Result<Box<dyn ErasedWireValue>, CodecError> + Send + Sync + 'static>;
27
28/// Codec that serialises [`WireValue`] types as BSON via the
29/// upstream `bson` crate's serde integration.
30#[derive(Default)]
31pub struct BsonCodec {
32    encoders: HashMap<WireTypeId, EncodeFn>,
33    decoders: HashMap<WireTypeId, DecodeFn>,
34}
35
36impl BsonCodec {
37    /// Construct an empty BSON codec with no registered types.
38    #[must_use]
39    pub fn new() -> Self {
40        Self::default()
41    }
42
43    /// Register a [`WireValue`] type with the codec.
44    pub fn register<T>(&mut self) -> &mut Self
45    where
46        T: WireValue + Serialize + DeserializeOwned,
47    {
48        let id = T::wire_type_id();
49        self.encoders.insert(
50            id,
51            Box::new(
52                move |v: &dyn ErasedWireValue| -> Result<Vec<u8>, CodecError> {
53                    let concrete = v
54                        .as_any()
55                        .downcast_ref::<T>()
56                        .ok_or(CodecError::TypeMismatch { expected: id })?;
57                    bson::serialize_to_vec(concrete).map_err(CodecError::encode_failure)
58                },
59            ),
60        );
61        self.decoders.insert(
62            id,
63            Box::new(
64                move |bytes: &[u8]| -> Result<Box<dyn ErasedWireValue>, CodecError> {
65                    let value: T =
66                        bson::deserialize_from_slice(bytes).map_err(CodecError::decode_failure)?;
67                    Ok(Box::new(value))
68                },
69            ),
70        );
71        self
72    }
73
74    /// Number of message types registered with this codec.
75    #[must_use]
76    pub fn registered_type_count(&self) -> usize {
77        self.encoders.len()
78    }
79}
80
81impl WireCodec for BsonCodec {
82    fn content_type(&self) -> &'static str {
83        "application/bson"
84    }
85
86    fn encode(&self, value: &dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> {
87        let id = value.type_id();
88        let encoder = self
89            .encoders
90            .get(&id)
91            .ok_or(CodecError::UnknownTypeId(id))?;
92        encoder(value)
93    }
94
95    fn decode(
96        &self,
97        type_id: WireTypeId,
98        bytes: &[u8],
99    ) -> Result<Box<dyn ErasedWireValue>, CodecError> {
100        let decoder = self
101            .decoders
102            .get(&type_id)
103            .ok_or(CodecError::UnknownTypeId(type_id))?;
104        decoder(bytes)
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Document-shaped message mixing a string, an integer and a byte vector.
    #[derive(Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
    struct Sample {
        name: String,
        seq: u32,
        payload: Vec<u8>,
    }

    impl WireValue for Sample {
        fn wire_type_id() -> WireTypeId {
            WireTypeId::new("test.bson.Sample")
        }
    }

    /// Fixed value shared by the tests below.
    fn fixture() -> Sample {
        Sample {
            name: String::from("eta"),
            seq: 8_192,
            payload: vec![0xa5, 0x5a, 0xa5, 0x5a],
        }
    }

    /// Codec with [`Sample`] already registered.
    fn sample_codec() -> BsonCodec {
        let mut codec = BsonCodec::new();
        codec.register::<Sample>();
        codec
    }

    #[test]
    fn round_trip_recovers_value() {
        let codec = sample_codec();
        let original = fixture();

        let encoded = codec.encode(&original).expect("encode");
        let decoded = codec
            .decode(Sample::wire_type_id(), &encoded)
            .expect("decode");

        let recovered = decoded
            .as_any()
            .downcast_ref::<Sample>()
            .expect("downcast");
        assert_eq!(recovered, &original);
    }

    #[test]
    fn idempotent_encode_is_byte_equal() {
        let codec = sample_codec();
        let value = fixture();

        let first = codec.encode(&value).expect("encode 1");
        let second = codec.encode(&value).expect("encode 2");
        assert_eq!(first, second);

        // Re-encoding a decoded copy must reproduce the original bytes.
        let decoded = codec
            .decode(Sample::wire_type_id(), &first)
            .expect("decode");
        let third = codec.encode(decoded.as_ref()).expect("encode 3");
        assert_eq!(first, third);
    }

    #[test]
    fn unregistered_type_returns_unknown_type_id_on_encode() {
        let codec = BsonCodec::new();
        let err = codec
            .encode(&fixture())
            .expect_err("expected unknown type");
        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
    }

    #[test]
    fn unregistered_type_returns_unknown_type_id_on_decode() {
        let codec = BsonCodec::new();
        let err = codec
            .decode(Sample::wire_type_id(), b"")
            .expect_err("expected unknown type");
        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
    }

    #[test]
    fn malformed_bytes_yield_decode_failure() {
        // A BSON document begins with a 4-byte length prefix, so three bytes
        // cannot even hold the header.
        let truncated = [0x01_u8, 0x02, 0x03];
        let codec = sample_codec();
        let err = codec
            .decode(Sample::wire_type_id(), &truncated)
            .expect_err("expected decode failure");
        assert!(matches!(err, CodecError::Decode(_)));
    }
}