Skip to main content

dyn_encoding/codec/
protobuf.rs

//! `application/x-protobuf` codec backed by `prost`.
//!
//! Unlike the JSON and CBOR codecs, the protobuf codec does not rely
//! on `serde`. Each registered message type provides its own
//! `prost::Message` impl (typically generated by `prost-build` from a
//! `.proto` file, or hand-derived for small fixtures). The encoder
//! calls `Message::encode_to_vec`; the decoder calls
//! `Message::decode`.
9
10use std::collections::HashMap;
11
12use prost::Message;
13
14use crate::error::CodecError;
15use crate::value::{ErasedWireValue, WireCodec, WireTypeId, WireValue};
16
17type EncodeFn =
18    Box<dyn Fn(&dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> + Send + Sync + 'static>;
19type DecodeFn =
20    Box<dyn Fn(&[u8]) -> Result<Box<dyn ErasedWireValue>, CodecError> + Send + Sync + 'static>;
21
22/// Codec that serialises [`WireValue`] types using protobuf's
23/// length-undelimited wire format via `prost`.
24#[derive(Default)]
25pub struct ProtobufCodec {
26    encoders: HashMap<WireTypeId, EncodeFn>,
27    decoders: HashMap<WireTypeId, DecodeFn>,
28}
29
30impl ProtobufCodec {
31    /// Construct an empty protobuf codec with no registered types.
32    #[must_use]
33    pub fn new() -> Self {
34        Self::default()
35    }
36
37    /// Register a [`WireValue`] type with the codec. The type must
38    /// implement `prost::Message + Default` (the standard bound for
39    /// `prost`-generated message types).
40    pub fn register<T>(&mut self) -> &mut Self
41    where
42        T: WireValue + Message + Default,
43    {
44        let id = T::wire_type_id();
45        self.encoders.insert(
46            id,
47            Box::new(
48                move |v: &dyn ErasedWireValue| -> Result<Vec<u8>, CodecError> {
49                    let concrete = v
50                        .as_any()
51                        .downcast_ref::<T>()
52                        .ok_or(CodecError::TypeMismatch { expected: id })?;
53                    Ok(concrete.encode_to_vec())
54                },
55            ),
56        );
57        self.decoders.insert(
58            id,
59            Box::new(
60                move |bytes: &[u8]| -> Result<Box<dyn ErasedWireValue>, CodecError> {
61                    let value = T::decode(bytes).map_err(CodecError::decode_failure)?;
62                    Ok(Box::new(value))
63                },
64            ),
65        );
66        self
67    }
68
69    /// Number of message types registered with this codec.
70    #[must_use]
71    pub fn registered_type_count(&self) -> usize {
72        self.encoders.len()
73    }
74}
75
76impl WireCodec for ProtobufCodec {
77    fn content_type(&self) -> &'static str {
78        "application/x-protobuf"
79    }
80
81    fn encode(&self, value: &dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> {
82        let id = value.type_id();
83        let encoder = self
84            .encoders
85            .get(&id)
86            .ok_or(CodecError::UnknownTypeId(id))?;
87        encoder(value)
88    }
89
90    fn decode(
91        &self,
92        type_id: WireTypeId,
93        bytes: &[u8],
94    ) -> Result<Box<dyn ErasedWireValue>, CodecError> {
95        let decoder = self
96            .decoders
97            .get(&type_id)
98            .ok_or(CodecError::UnknownTypeId(type_id))?;
99        decoder(bytes)
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use prost::Message;

    /// Hand-derived protobuf message used purely as a test fixture.
    /// Field tags are stable so the binary layout is deterministic
    /// across test runs. The shape is deliberately tiny and covers three
    /// distinct field flavours: a string (tag 1, length-delimited), an
    /// unsigned integer (tag 2, varint), and raw bytes (tag 3,
    /// length-delimited).
    ///
    /// `Default` (and `Debug`) are generated by `prost::Message`; deriving
    /// them again would conflict.
    #[derive(Clone, Eq, PartialEq, Message)]
    struct Sample {
        #[prost(string, tag = "1")]
        name: String,
        #[prost(uint32, tag = "2")]
        seq: u32,
        #[prost(bytes = "vec", tag = "3")]
        payload: Vec<u8>,
    }

    impl WireValue for Sample {
        fn wire_type_id() -> WireTypeId {
            WireTypeId::new("test.protobuf.Sample")
        }
    }

    fn fixture() -> Sample {
        Sample {
            name: "gamma".into(),
            seq: 99,
            payload: vec![0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef],
        }
    }

    #[test]
    fn round_trip_recovers_value() {
        let mut codec = ProtobufCodec::new();
        codec.register::<Sample>();
        let v = fixture();
        let bytes = codec.encode(&v).expect("encode");
        let back = codec
            .decode(Sample::wire_type_id(), &bytes)
            .expect("decode");
        let back = back.as_any().downcast_ref::<Sample>().expect("downcast");
        assert_eq!(back, &v);
    }

    // An all-default message encodes to zero bytes (proto3 omits default
    // scalars), and zero bytes must decode back to the default message.
    #[test]
    fn default_message_round_trips_through_empty_payload() {
        let mut codec = ProtobufCodec::new();
        codec.register::<Sample>();
        let empty = Sample::default();
        let bytes = codec.encode(&empty).expect("encode");
        assert!(bytes.is_empty());
        let back = codec
            .decode(Sample::wire_type_id(), &bytes)
            .expect("decode");
        assert_eq!(back.as_any().downcast_ref::<Sample>(), Some(&empty));
    }

    #[test]
    fn idempotent_encode_is_byte_equal() {
        let mut codec = ProtobufCodec::new();
        codec.register::<Sample>();
        let v = fixture();
        let a = codec.encode(&v).expect("encode 1");
        let b = codec.encode(&v).expect("encode 2");
        assert_eq!(a, b);
        let back = codec.decode(Sample::wire_type_id(), &a).expect("decode");
        let c = codec.encode(back.as_ref()).expect("encode 3");
        assert_eq!(a, c);
    }

    #[test]
    fn produced_bytes_match_prost_native_encoding() {
        let mut codec = ProtobufCodec::new();
        codec.register::<Sample>();
        let v = fixture();
        let through_codec = codec.encode(&v).expect("encode");
        let direct = v.encode_to_vec();
        assert_eq!(through_codec, direct);
    }

    #[test]
    fn unregistered_type_returns_unknown_type_id() {
        let codec = ProtobufCodec::new();
        let v = fixture();
        let err = codec.encode(&v).expect_err("expected unknown type");
        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
    }

    // The decode path must report an unregistered id the same way the encode
    // path does, even when the bytes themselves are well-formed.
    #[test]
    fn decoding_unregistered_type_returns_unknown_type_id() {
        let codec = ProtobufCodec::new();
        let bytes = fixture().encode_to_vec();
        let err = codec
            .decode(Sample::wire_type_id(), &bytes)
            .expect_err("expected unknown type");
        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
    }

    // Re-registering a type replaces its entry rather than adding a second one.
    #[test]
    fn registering_same_type_twice_keeps_single_entry() {
        let mut codec = ProtobufCodec::new();
        assert_eq!(codec.registered_type_count(), 0);
        codec.register::<Sample>().register::<Sample>();
        assert_eq!(codec.registered_type_count(), 1);
    }

    #[test]
    fn malformed_bytes_yield_decode_failure() {
        let mut codec = ProtobufCodec::new();
        codec.register::<Sample>();
        let cases: [&[u8]; 2] = [
            // Field tag 1 (wire type 2 = length-delimited) declaring a
            // 5-byte payload, but only one payload byte follows: the
            // declared length runs past the end of the buffer.
            &[0x0a, 0x05, b'x'],
            // Field tag 1 (wire type 2) followed by a length varint whose
            // continuation bit (0x80) is set with no following byte: the
            // varint itself is truncated.
            &[0x0a, 0xff],
        ];
        for bytes in cases.iter() {
            let err = codec
                .decode(Sample::wire_type_id(), bytes)
                .expect_err("expected decode failure");
            assert!(
                matches!(err, CodecError::Decode(_)),
                "unexpected error kind for bytes {:02x?}",
                bytes
            );
        }
    }
}