Skip to main content

dyn_encoding/codec/
json.rs

1//! `application/json` codec backed by `serde_json`.
2
3use std::collections::HashMap;
4
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7
8use crate::error::CodecError;
9use crate::value::{ErasedWireValue, WireCodec, WireTypeId, WireValue};
10
11type EncodeFn =
12    Box<dyn Fn(&dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> + Send + Sync + 'static>;
13type DecodeFn =
14    Box<dyn Fn(&[u8]) -> Result<Box<dyn ErasedWireValue>, CodecError> + Send + Sync + 'static>;
15
16/// Codec that serialises [`WireValue`] types as JSON.
17///
18/// Types must be registered through [`Self::register`] before they
19/// can be encoded or decoded. Registration installs the
20/// `serde_json::to_vec` / `serde_json::from_slice` paths under the
21/// type's [`WireTypeId`].
22#[derive(Default)]
23pub struct JsonCodec {
24    encoders: HashMap<WireTypeId, EncodeFn>,
25    decoders: HashMap<WireTypeId, DecodeFn>,
26}
27
28impl JsonCodec {
29    /// Construct an empty JSON codec with no registered types.
30    #[must_use]
31    pub fn new() -> Self {
32        Self::default()
33    }
34
35    /// Register a [`WireValue`] type with the codec. Once registered,
36    /// values of `T` can be encoded through this codec and bytes can
37    /// be decoded back into `T` via [`WireCodec::decode`].
38    pub fn register<T>(&mut self) -> &mut Self
39    where
40        T: WireValue + Serialize + DeserializeOwned,
41    {
42        let id = T::wire_type_id();
43        self.encoders.insert(
44            id,
45            Box::new(
46                move |v: &dyn ErasedWireValue| -> Result<Vec<u8>, CodecError> {
47                    let concrete = v
48                        .as_any()
49                        .downcast_ref::<T>()
50                        .ok_or(CodecError::TypeMismatch { expected: id })?;
51                    serde_json::to_vec(concrete).map_err(CodecError::encode_failure)
52                },
53            ),
54        );
55        self.decoders.insert(
56            id,
57            Box::new(
58                move |bytes: &[u8]| -> Result<Box<dyn ErasedWireValue>, CodecError> {
59                    let value: T =
60                        serde_json::from_slice(bytes).map_err(CodecError::decode_failure)?;
61                    Ok(Box::new(value))
62                },
63            ),
64        );
65        self
66    }
67
68    /// Number of message types registered with this codec.
69    #[must_use]
70    pub fn registered_type_count(&self) -> usize {
71        self.encoders.len()
72    }
73}
74
75impl WireCodec for JsonCodec {
76    fn content_type(&self) -> &'static str {
77        "application/json"
78    }
79
80    fn encode(&self, value: &dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> {
81        let id = value.type_id();
82        let encoder = self
83            .encoders
84            .get(&id)
85            .ok_or(CodecError::UnknownTypeId(id))?;
86        encoder(value)
87    }
88
89    fn decode(
90        &self,
91        type_id: WireTypeId,
92        bytes: &[u8],
93    ) -> Result<Box<dyn ErasedWireValue>, CodecError> {
94        let decoder = self
95            .decoders
96            .get(&type_id)
97            .ok_or(CodecError::UnknownTypeId(type_id))?;
98        decoder(bytes)
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
    struct Sample {
        name: String,
        seq: u32,
        payload: Vec<u8>,
    }

    impl WireValue for Sample {
        fn wire_type_id() -> WireTypeId {
            WireTypeId::new("test.json.Sample")
        }
    }

    #[derive(Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
    struct Other {
        x: i32,
    }

    impl WireValue for Other {
        fn wire_type_id() -> WireTypeId {
            WireTypeId::new("test.json.Other")
        }
    }

    fn fixture() -> Sample {
        Sample {
            name: "alpha".into(),
            seq: 42,
            payload: vec![0, 1, 2, 0xfe, 0xff],
        }
    }

    #[test]
    fn round_trip_recovers_value() {
        let mut codec = JsonCodec::new();
        codec.register::<Sample>();
        let v = fixture();
        let bytes = codec.encode(&v).expect("encode");
        let back = codec
            .decode(Sample::wire_type_id(), &bytes)
            .expect("decode");
        let back = back.as_any().downcast_ref::<Sample>().expect("downcast");
        assert_eq!(back, &v);
    }

    #[test]
    fn idempotent_encode_is_byte_equal() {
        let mut codec = JsonCodec::new();
        codec.register::<Sample>();
        let v = fixture();
        let a = codec.encode(&v).expect("encode 1");
        let b = codec.encode(&v).expect("encode 2");
        assert_eq!(a, b);
        // Round-trip then re-encode is also byte-equal.
        let back = codec.decode(Sample::wire_type_id(), &a).expect("decode");
        let c = codec.encode(back.as_ref()).expect("encode 3");
        assert_eq!(a, c);
    }

    #[test]
    fn content_type_is_application_json() {
        let codec = JsonCodec::new();
        assert_eq!(codec.content_type(), "application/json");
    }

    #[test]
    fn unregistered_type_returns_unknown_type_id_on_encode() {
        let codec = JsonCodec::new();
        let v = fixture();
        let err = codec.encode(&v).expect_err("expected unknown type");
        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
    }

    #[test]
    fn unregistered_type_returns_unknown_type_id_on_decode() {
        let codec = JsonCodec::new();
        let err = codec
            .decode(Sample::wire_type_id(), b"{}")
            .expect_err("expected unknown type");
        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
    }

    #[test]
    fn registering_one_type_does_not_enable_another() {
        // Registration is keyed per type id: `Other` being known must not
        // let `Sample` through.
        let mut codec = JsonCodec::new();
        codec.register::<Other>();
        let err = codec.encode(&fixture()).expect_err("expected unknown type");
        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
    }

    #[test]
    fn malformed_bytes_yield_decode_failure() {
        let mut codec = JsonCodec::new();
        codec.register::<Sample>();
        let err = codec
            .decode(Sample::wire_type_id(), b"{not valid json")
            .expect_err("expected decode failure");
        assert!(matches!(err, CodecError::Decode(_)));
    }

    #[test]
    fn well_formed_json_with_wrong_shape_yields_decode_failure() {
        // Syntactically valid JSON that lacks `Sample`'s fields must still be
        // rejected by the decoder rather than producing a default value.
        let mut codec = JsonCodec::new();
        codec.register::<Sample>();
        let err = codec
            .decode(Sample::wire_type_id(), br#"{"x": 1}"#)
            .expect_err("expected decode failure");
        assert!(matches!(err, CodecError::Decode(_)));
    }

    #[test]
    fn registered_type_count_tracks_registrations() {
        let mut codec = JsonCodec::new();
        assert_eq!(codec.registered_type_count(), 0);
        codec.register::<Sample>();
        assert_eq!(codec.registered_type_count(), 1);
        codec.register::<Other>();
        assert_eq!(codec.registered_type_count(), 2);
    }

    #[test]
    fn re_registering_a_type_does_not_double_count() {
        let mut codec = JsonCodec::new();
        codec.register::<Sample>().register::<Sample>();
        assert_eq!(codec.registered_type_count(), 1);
        // The replacement entry is still fully functional.
        let v = fixture();
        let bytes = codec.encode(&v).expect("encode");
        let back = codec
            .decode(Sample::wire_type_id(), &bytes)
            .expect("decode");
        assert_eq!(back.as_any().downcast_ref::<Sample>(), Some(&v));
    }
}