// Source path: dyn_encoding/codec/capnp.rs

//! `application/capnproto` codec backed by the `capnp` runtime.
//!
//! Cap'n Proto is, like FlatBuffers, schema-first: messages are
//! framed as a sequence of segments and the typed accessors are
//! produced by `capnpc` from a `.capnp` schema. The codec deliberately
//! does not depend on `capnpc`; instead, each registered type
//! provides its own conversion through the [`CapnpWire`] trait. The
//! `dyniak` crate is expected to host the schema set and the
//! generated readers/builders; this codec dispatches on the
//! [`WireTypeId`] and routes each call through that type's trait
//! implementation.

use std::collections::HashMap;
use std::fmt;

use crate::error::CodecError;
use crate::value::{ErasedWireValue, WireCodec, WireTypeId, WireValue};

17type EncodeFn =
18    Box<dyn Fn(&dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> + Send + Sync + 'static>;
19type DecodeFn =
20    Box<dyn Fn(&[u8]) -> Result<Box<dyn ErasedWireValue>, CodecError> + Send + Sync + 'static>;
21
22/// Per-message-type Cap'n Proto encode/decode contract.
23///
24/// Implementors typically build a `capnp::message::Builder` inside
25/// `capnp_encode`, populate it through the type's generated builder
26/// API, and write it out via `capnp::serialize::write_message`.
27/// `capnp_decode` parses the inverse path through
28/// `capnp::serialize::read_message`.
29pub trait CapnpWire: WireValue + Sized {
30    /// Serialise `self` into a fully-framed Cap'n Proto message.
31    fn capnp_encode(&self) -> Result<Vec<u8>, CodecError>;
32
33    /// Parse a fully-framed Cap'n Proto message back into `Self`.
34    fn capnp_decode(bytes: &[u8]) -> Result<Self, CodecError>;
35}
36
37/// Codec that serialises [`WireValue`] types as Cap'n Proto messages
38/// via per-type [`CapnpWire`] implementations.
39#[derive(Default)]
40pub struct CapnpCodec {
41    encoders: HashMap<WireTypeId, EncodeFn>,
42    decoders: HashMap<WireTypeId, DecodeFn>,
43}
44
45impl CapnpCodec {
46    /// Construct an empty Cap'n Proto codec with no registered types.
47    #[must_use]
48    pub fn new() -> Self {
49        Self::default()
50    }
51
52    /// Register a [`WireValue`] type with the codec.
53    pub fn register<T>(&mut self) -> &mut Self
54    where
55        T: CapnpWire + 'static,
56    {
57        let id = T::wire_type_id();
58        self.encoders.insert(
59            id,
60            Box::new(
61                move |v: &dyn ErasedWireValue| -> Result<Vec<u8>, CodecError> {
62                    let concrete = v
63                        .as_any()
64                        .downcast_ref::<T>()
65                        .ok_or(CodecError::TypeMismatch { expected: id })?;
66                    concrete.capnp_encode()
67                },
68            ),
69        );
70        self.decoders.insert(
71            id,
72            Box::new(
73                move |bytes: &[u8]| -> Result<Box<dyn ErasedWireValue>, CodecError> {
74                    let value = T::capnp_decode(bytes)?;
75                    Ok(Box::new(value))
76                },
77            ),
78        );
79        self
80    }
81
82    /// Number of message types registered with this codec.
83    #[must_use]
84    pub fn registered_type_count(&self) -> usize {
85        self.encoders.len()
86    }
87}
88
89impl WireCodec for CapnpCodec {
90    fn content_type(&self) -> &'static str {
91        "application/capnproto"
92    }
93
94    fn encode(&self, value: &dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> {
95        let id = value.type_id();
96        let encoder = self
97            .encoders
98            .get(&id)
99            .ok_or(CodecError::UnknownTypeId(id))?;
100        encoder(value)
101    }
102
103    fn decode(
104        &self,
105        type_id: WireTypeId,
106        bytes: &[u8],
107    ) -> Result<Box<dyn ErasedWireValue>, CodecError> {
108        let decoder = self
109            .decoders
110            .get(&type_id)
111            .ok_or(CodecError::UnknownTypeId(type_id))?;
112        decoder(bytes)
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use ::capnp::message::{Builder, HeapAllocator, ReaderOptions};
    use ::capnp::serialize;

    /// Hand-rolled Cap'n Proto fixture that needs no `capnpc` output.
    ///
    /// The message root is a `List(Data)` (a homogeneous list of byte blobs)
    /// with exactly three entries:
    ///
    /// ```text
    /// 0: name     UTF-8 bytes of `name`
    /// 1: seq      `seq` as a 4-byte little-endian blob
    /// 2: payload  raw `payload` bytes
    /// ```
    ///
    /// Only the runtime-side API is used (`Builder::initn_root`,
    /// `data_list::{Builder, Reader}`, `serialize::{read_message,
    /// write_message}`). A real consumer would replace this with a typed
    /// `.capnp` schema; the codec dispatch stays the same.
    #[derive(Debug, Eq, PartialEq)]
    struct Sample {
        name: String,
        seq: u32,
        payload: Vec<u8>,
    }

    impl WireValue for Sample {
        fn wire_type_id() -> WireTypeId {
            WireTypeId::new("test.capnp.Sample")
        }
    }

    /// Frames `entries` as a Cap'n Proto message whose root is a
    /// `List(Data)` holding each entry in order.
    fn encode_data_list(entries: &[&[u8]]) -> Result<Vec<u8>, CodecError> {
        // Fixtures hold only a handful of entries, so the cast cannot truncate.
        let len = entries.len() as u32;
        let mut msg = Builder::new(HeapAllocator::new());
        {
            let mut root: ::capnp::data_list::Builder<'_> = msg.initn_root(len);
            for (index, entry) in (0u32..).zip(entries.iter().copied()) {
                root.set(index, entry);
            }
        }
        let mut out = Vec::new();
        serialize::write_message(&mut out, &msg).map_err(CodecError::encode_failure)?;
        Ok(out)
    }

    impl CapnpWire for Sample {
        fn capnp_encode(&self) -> Result<Vec<u8>, CodecError> {
            let seq_le = self.seq.to_le_bytes();
            encode_data_list(&[self.name.as_bytes(), &seq_le[..], self.payload.as_slice()])
        }

        fn capnp_decode(bytes: &[u8]) -> Result<Self, CodecError> {
            let reader = serialize::read_message(bytes, ReaderOptions::new())
                .map_err(CodecError::decode_failure)?;
            let root: ::capnp::data_list::Reader<'_> =
                reader.get_root().map_err(CodecError::decode_failure)?;
            if root.len() != 3 {
                return Err(CodecError::decode_failure(format!(
                    "capnp: expected 3-entry data_list, got {}",
                    root.len()
                )));
            }
            let name_bytes = root.get(0).map_err(CodecError::decode_failure)?;
            let name = std::str::from_utf8(name_bytes)
                .map_err(CodecError::decode_failure)?
                .to_owned();
            let seq_bytes = root.get(1).map_err(CodecError::decode_failure)?;
            // A fixed-length slice pattern checks the length and extracts the
            // bytes in one step, with no panicking index expressions.
            let seq = match *seq_bytes {
                [b0, b1, b2, b3] => u32::from_le_bytes([b0, b1, b2, b3]),
                _ => {
                    return Err(CodecError::decode_failure(format!(
                        "capnp: expected 4-byte seq blob, got {}",
                        seq_bytes.len()
                    )));
                }
            };
            let payload = root.get(2).map_err(CodecError::decode_failure)?.to_vec();
            Ok(Sample { name, seq, payload })
        }
    }

    fn fixture() -> Sample {
        Sample {
            name: "epsilon".into(),
            seq: 65_537,
            payload: vec![0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff],
        }
    }

    /// Frames `entries` and asserts that decoding them as `Sample` fails
    /// with `CodecError::Decode`.
    fn assert_rejected(entries: &[&[u8]]) {
        let mut codec = CapnpCodec::new();
        codec.register::<Sample>();
        let bytes = encode_data_list(entries).expect("frame fixture");
        let err = codec
            .decode(Sample::wire_type_id(), &bytes)
            .expect_err("expected decode failure");
        assert!(matches!(err, CodecError::Decode(_)));
    }

    #[test]
    fn round_trip_recovers_value() {
        let mut codec = CapnpCodec::new();
        codec.register::<Sample>();
        let v = fixture();
        let bytes = codec.encode(&v).expect("encode");
        let back = codec
            .decode(Sample::wire_type_id(), &bytes)
            .expect("decode");
        let back = back.as_any().downcast_ref::<Sample>().expect("downcast");
        assert_eq!(back, &v);
    }

    #[test]
    fn idempotent_encode_is_byte_equal() {
        let mut codec = CapnpCodec::new();
        codec.register::<Sample>();
        let v = fixture();
        let a = codec.encode(&v).expect("encode 1");
        let b = codec.encode(&v).expect("encode 2");
        assert_eq!(a, b);
        let back = codec.decode(Sample::wire_type_id(), &a).expect("decode");
        let c = codec.encode(back.as_ref()).expect("encode 3");
        assert_eq!(a, c);
    }

    #[test]
    fn re_registering_replaces_existing_entry() {
        let mut codec = CapnpCodec::new();
        assert_eq!(codec.registered_type_count(), 0);
        codec.register::<Sample>().register::<Sample>();
        assert_eq!(codec.registered_type_count(), 1);
    }

    #[test]
    fn unregistered_type_returns_unknown_type_id_on_encode() {
        let codec = CapnpCodec::new();
        let v = fixture();
        let err = codec.encode(&v).expect_err("expected unknown type");
        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
    }

    #[test]
    fn unregistered_type_returns_unknown_type_id_on_decode() {
        let codec = CapnpCodec::new();
        let err = codec
            .decode(Sample::wire_type_id(), b"")
            .expect_err("expected unknown type");
        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
    }

    #[test]
    fn malformed_bytes_yield_decode_failure() {
        let mut codec = CapnpCodec::new();
        codec.register::<Sample>();
        // The stream framing header alone is 8 bytes (segment count plus the
        // first segment's length), so two bytes cannot be a valid message and
        // `read_message` rejects them.
        let err = codec
            .decode(Sample::wire_type_id(), &[0xff, 0xff])
            .expect_err("expected decode failure");
        assert!(matches!(err, CodecError::Decode(_)));
    }

    #[test]
    fn wrong_entry_count_yields_decode_failure() {
        assert_rejected(&["epsilon".as_bytes(), &[1u8, 0, 0, 0][..]]);
    }

    #[test]
    fn short_seq_blob_yields_decode_failure() {
        assert_rejected(&["epsilon".as_bytes(), &[1u8, 0, 0][..], &[0xaau8][..]]);
    }

    #[test]
    fn invalid_utf8_name_yields_decode_failure() {
        assert_rejected(&[&[0xffu8, 0xfe][..], &[1u8, 0, 0, 0][..], &[0xaau8][..]]);
    }
}