// dyn_encoding/codec/cbor.rs
use std::collections::HashMap;
use std::fmt;

use serde::de::DeserializeOwned;
use serde::Serialize;

use crate::error::CodecError;
use crate::value::{ErasedWireValue, WireCodec, WireTypeId, WireValue};

11type EncodeFn =
12 Box<dyn Fn(&dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> + Send + Sync + 'static>;
13type DecodeFn =
14 Box<dyn Fn(&[u8]) -> Result<Box<dyn ErasedWireValue>, CodecError> + Send + Sync + 'static>;
15
16#[derive(Default)]
22pub struct CborCodec {
23 encoders: HashMap<WireTypeId, EncodeFn>,
24 decoders: HashMap<WireTypeId, DecodeFn>,
25}
26
27impl CborCodec {
28 #[must_use]
30 pub fn new() -> Self {
31 Self::default()
32 }
33
34 pub fn register<T>(&mut self) -> &mut Self
36 where
37 T: WireValue + Serialize + DeserializeOwned,
38 {
39 let id = T::wire_type_id();
40 self.encoders.insert(
41 id,
42 Box::new(
43 move |v: &dyn ErasedWireValue| -> Result<Vec<u8>, CodecError> {
44 let concrete = v
45 .as_any()
46 .downcast_ref::<T>()
47 .ok_or(CodecError::TypeMismatch { expected: id })?;
48 let mut buf = Vec::new();
49 ciborium::into_writer(concrete, &mut buf)
50 .map_err(CodecError::encode_failure)?;
51 Ok(buf)
52 },
53 ),
54 );
55 self.decoders.insert(
56 id,
57 Box::new(
58 move |bytes: &[u8]| -> Result<Box<dyn ErasedWireValue>, CodecError> {
59 let value: T =
60 ciborium::from_reader(bytes).map_err(CodecError::decode_failure)?;
61 Ok(Box::new(value))
62 },
63 ),
64 );
65 self
66 }
67
68 #[must_use]
70 pub fn registered_type_count(&self) -> usize {
71 self.encoders.len()
72 }
73}
74
75impl WireCodec for CborCodec {
76 fn content_type(&self) -> &'static str {
77 "application/cbor"
78 }
79
80 fn encode(&self, value: &dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> {
81 let id = value.type_id();
82 let encoder = self
83 .encoders
84 .get(&id)
85 .ok_or(CodecError::UnknownTypeId(id))?;
86 encoder(value)
87 }
88
89 fn decode(
90 &self,
91 type_id: WireTypeId,
92 bytes: &[u8],
93 ) -> Result<Box<dyn ErasedWireValue>, CodecError> {
94 let decoder = self
95 .decoders
96 .get(&type_id)
97 .ok_or(CodecError::UnknownTypeId(type_id))?;
98 decoder(bytes)
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Minimal serde-capable message used as the registered test type.
    #[derive(Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
    struct Sample {
        name: String,
        seq: u32,
        payload: Vec<u8>,
    }

    impl WireValue for Sample {
        fn wire_type_id() -> WireTypeId {
            WireTypeId::new("test.cbor.Sample")
        }
    }

    /// Builds the canonical `Sample` that every test works with.
    fn fixture() -> Sample {
        Sample {
            name: String::from("beta"),
            seq: 7,
            payload: vec![0xde, 0xad, 0xbe, 0xef],
        }
    }

    /// Returns a codec that already has `Sample` registered.
    fn codec_with_sample() -> CborCodec {
        let mut codec = CborCodec::new();
        codec.register::<Sample>();
        codec
    }

    #[test]
    fn round_trip_recovers_value() {
        let codec = codec_with_sample();
        let original = fixture();

        let encoded = codec.encode(&original).expect("encode");
        let decoded = codec
            .decode(Sample::wire_type_id(), &encoded)
            .expect("decode");
        let recovered = decoded
            .as_any()
            .downcast_ref::<Sample>()
            .expect("downcast");

        assert_eq!(recovered, &original);
    }

    #[test]
    fn idempotent_encode_is_byte_equal() {
        let codec = codec_with_sample();
        let original = fixture();

        // Encoding the same value twice must give identical bytes...
        let first = codec.encode(&original).expect("encode 1");
        let second = codec.encode(&original).expect("encode 2");
        assert_eq!(first, second);

        // ...and so must re-encoding the value recovered from those bytes.
        let decoded = codec
            .decode(Sample::wire_type_id(), &first)
            .expect("decode");
        let reencoded = codec.encode(decoded.as_ref()).expect("encode 3");
        assert_eq!(first, reencoded);
    }

    #[test]
    fn unregistered_type_returns_unknown_type_id_on_encode() {
        let empty = CborCodec::new();
        let err = empty
            .encode(&fixture())
            .expect_err("expected unknown type");
        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
    }

    #[test]
    fn unregistered_type_returns_unknown_type_id_on_decode() {
        let empty = CborCodec::new();
        let err = empty
            .decode(Sample::wire_type_id(), b"")
            .expect_err("expected unknown type");
        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
    }

    #[test]
    fn malformed_bytes_yield_decode_failure() {
        let codec = codec_with_sample();
        // A lone 0xff is the CBOR "break" stop code, which cannot begin a data item.
        let malformed: &[u8] = &[0xff];
        let err = codec
            .decode(Sample::wire_type_id(), malformed)
            .expect_err("expected decode failure");
        assert!(matches!(err, CodecError::Decode(_)));
    }
}