Skip to main content

dyn_encoding/codec/
cbor.rs

1//! `application/cbor` codec backed by `ciborium`.
2
3use std::collections::HashMap;
4
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7
8use crate::error::CodecError;
9use crate::value::{ErasedWireValue, WireCodec, WireTypeId, WireValue};
10
11type EncodeFn =
12    Box<dyn Fn(&dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> + Send + Sync + 'static>;
13type DecodeFn =
14    Box<dyn Fn(&[u8]) -> Result<Box<dyn ErasedWireValue>, CodecError> + Send + Sync + 'static>;
15
16/// Codec that serialises [`WireValue`] types as CBOR via `ciborium`.
17///
18/// Mirrors the registration pattern of [`crate::JsonCodec`]; types
19/// must be registered through [`Self::register`] before they can be
20/// encoded or decoded.
21#[derive(Default)]
22pub struct CborCodec {
23    encoders: HashMap<WireTypeId, EncodeFn>,
24    decoders: HashMap<WireTypeId, DecodeFn>,
25}
26
27impl CborCodec {
28    /// Construct an empty CBOR codec with no registered types.
29    #[must_use]
30    pub fn new() -> Self {
31        Self::default()
32    }
33
34    /// Register a [`WireValue`] type with the codec.
35    pub fn register<T>(&mut self) -> &mut Self
36    where
37        T: WireValue + Serialize + DeserializeOwned,
38    {
39        let id = T::wire_type_id();
40        self.encoders.insert(
41            id,
42            Box::new(
43                move |v: &dyn ErasedWireValue| -> Result<Vec<u8>, CodecError> {
44                    let concrete = v
45                        .as_any()
46                        .downcast_ref::<T>()
47                        .ok_or(CodecError::TypeMismatch { expected: id })?;
48                    let mut buf = Vec::new();
49                    ciborium::into_writer(concrete, &mut buf)
50                        .map_err(CodecError::encode_failure)?;
51                    Ok(buf)
52                },
53            ),
54        );
55        self.decoders.insert(
56            id,
57            Box::new(
58                move |bytes: &[u8]| -> Result<Box<dyn ErasedWireValue>, CodecError> {
59                    let value: T =
60                        ciborium::from_reader(bytes).map_err(CodecError::decode_failure)?;
61                    Ok(Box::new(value))
62                },
63            ),
64        );
65        self
66    }
67
68    /// Number of message types registered with this codec.
69    #[must_use]
70    pub fn registered_type_count(&self) -> usize {
71        self.encoders.len()
72    }
73}
74
75impl WireCodec for CborCodec {
76    fn content_type(&self) -> &'static str {
77        "application/cbor"
78    }
79
80    fn encode(&self, value: &dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> {
81        let id = value.type_id();
82        let encoder = self
83            .encoders
84            .get(&id)
85            .ok_or(CodecError::UnknownTypeId(id))?;
86        encoder(value)
87    }
88
89    fn decode(
90        &self,
91        type_id: WireTypeId,
92        bytes: &[u8],
93    ) -> Result<Box<dyn ErasedWireValue>, CodecError> {
94        let decoder = self
95            .decoders
96            .get(&type_id)
97            .ok_or(CodecError::UnknownTypeId(type_id))?;
98        decoder(bytes)
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Small message type covering a string, an integer and a byte vector.
    #[derive(Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
    struct Sample {
        name: String,
        seq: u32,
        payload: Vec<u8>,
    }

    impl WireValue for Sample {
        fn wire_type_id() -> WireTypeId {
            WireTypeId::new("test.cbor.Sample")
        }
    }

    /// Canonical `Sample` value shared by the tests below.
    fn fixture() -> Sample {
        Sample {
            name: String::from("beta"),
            seq: 7,
            payload: vec![0xde, 0xad, 0xbe, 0xef],
        }
    }

    /// A codec that already has `Sample` registered.
    fn codec_with_sample() -> CborCodec {
        let mut codec = CborCodec::new();
        codec.register::<Sample>();
        codec
    }

    #[test]
    fn round_trip_recovers_value() {
        let codec = codec_with_sample();
        let original = fixture();
        let encoded = codec.encode(&original).expect("encode");
        let decoded = codec
            .decode(Sample::wire_type_id(), &encoded)
            .expect("decode");
        let recovered = decoded
            .as_any()
            .downcast_ref::<Sample>()
            .expect("downcast");
        assert_eq!(recovered, &original);
    }

    #[test]
    fn idempotent_encode_is_byte_equal() {
        let codec = codec_with_sample();
        let value = fixture();
        let first = codec.encode(&value).expect("encode 1");
        let second = codec.encode(&value).expect("encode 2");
        assert_eq!(first, second);
        // Re-encoding a decoded value must reproduce the same bytes.
        let decoded = codec
            .decode(Sample::wire_type_id(), &first)
            .expect("decode");
        let third = codec.encode(decoded.as_ref()).expect("encode 3");
        assert_eq!(first, third);
    }

    #[test]
    fn unregistered_type_returns_unknown_type_id_on_encode() {
        let err = CborCodec::new()
            .encode(&fixture())
            .expect_err("expected unknown type");
        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
    }

    #[test]
    fn unregistered_type_returns_unknown_type_id_on_decode() {
        let err = CborCodec::new()
            .decode(Sample::wire_type_id(), b"")
            .expect_err("expected unknown type");
        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
    }

    #[test]
    fn malformed_bytes_yield_decode_failure() {
        let codec = codec_with_sample();
        // A lone 0xff is the CBOR "break" stop code, which cannot start a
        // top-level data item.
        let err = codec
            .decode(Sample::wire_type_id(), &[0xff])
            .expect_err("expected decode failure");
        assert!(matches!(err, CodecError::Decode(_)));
    }
}