Skip to main content

dyn_encoding/codec/
flatbuffers.rs

1//! `application/octet-stream;schema=flatbuffers` codec backed by the
2//! `flatbuffers` runtime.
3//!
4//! FlatBuffers is a schema-first format: the bytes on the wire are a
5//! laid-out table whose vtable encodes every field's offset. There is
6//! no general-purpose `Serialize` shape that the runtime can drive,
7//! so each registered type carries its own pair of conversion
8//! functions, exposed through the [`FlatbuffersWire`] trait. The
9//! codec is otherwise a sibling of [`crate::JsonCodec`] /
10//! [`crate::CborCodec`] / [`crate::ProtobufCodec`]: per-type
11//! registration, dispatch by [`WireTypeId`], and a single
12//! [`crate::CodecError`] for both encode and decode failures.
13//!
14//! Avoiding `flatc` keeps the build hermetic. The downside is that
15//! callers must hand-roll their `flatbuffers_encode` /
16//! `flatbuffers_decode` bodies (or generate them out-of-band and
17//! plug them into a [`FlatbuffersWire`] impl). The
18//! `dyniak` crate is expected to host the schema set and the
19//! generated code; this codec simply dispatches.
20
21use std::collections::HashMap;
22
23use crate::error::CodecError;
24use crate::value::{ErasedWireValue, WireCodec, WireTypeId, WireValue};
25
26type EncodeFn =
27    Box<dyn Fn(&dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> + Send + Sync + 'static>;
28type DecodeFn =
29    Box<dyn Fn(&[u8]) -> Result<Box<dyn ErasedWireValue>, CodecError> + Send + Sync + 'static>;
30
31/// Per-message-type FlatBuffers encode/decode contract.
32///
33/// The trait deliberately does not bound `T` on any
34/// `flatbuffers::Follow` / `flatbuffers::Push` shape: those traits
35/// model the on-wire view of a flatbuffer table, not the owned Rust
36/// representation that travels through the codec API. Implementors
37/// build a [`flatbuffers::FlatBufferBuilder`] inside
38/// `flatbuffers_encode` and finish it; `flatbuffers_decode` reads the
39/// fields back into an owned `Self` (the codec returns
40/// `Box<dyn ErasedWireValue>`, which is necessarily owned).
41pub trait FlatbuffersWire: WireValue + Sized {
42    /// Serialise `self` into a finished, root-prefixed flatbuffer.
43    fn flatbuffers_encode(&self) -> Result<Vec<u8>, CodecError>;
44
45    /// Parse a finished, root-prefixed flatbuffer back into `Self`.
46    fn flatbuffers_decode(bytes: &[u8]) -> Result<Self, CodecError>;
47}
48
49/// Codec that serialises [`WireValue`] types as FlatBuffers via
50/// per-type [`FlatbuffersWire`] implementations.
51#[derive(Default)]
52pub struct FlatbuffersCodec {
53    encoders: HashMap<WireTypeId, EncodeFn>,
54    decoders: HashMap<WireTypeId, DecodeFn>,
55}
56
57impl FlatbuffersCodec {
58    /// Construct an empty FlatBuffers codec with no registered types.
59    #[must_use]
60    pub fn new() -> Self {
61        Self::default()
62    }
63
64    /// Register a [`WireValue`] type with the codec.
65    pub fn register<T>(&mut self) -> &mut Self
66    where
67        T: FlatbuffersWire + 'static,
68    {
69        let id = T::wire_type_id();
70        self.encoders.insert(
71            id,
72            Box::new(
73                move |v: &dyn ErasedWireValue| -> Result<Vec<u8>, CodecError> {
74                    let concrete = v
75                        .as_any()
76                        .downcast_ref::<T>()
77                        .ok_or(CodecError::TypeMismatch { expected: id })?;
78                    concrete.flatbuffers_encode()
79                },
80            ),
81        );
82        self.decoders.insert(
83            id,
84            Box::new(
85                move |bytes: &[u8]| -> Result<Box<dyn ErasedWireValue>, CodecError> {
86                    let value = T::flatbuffers_decode(bytes)?;
87                    Ok(Box::new(value))
88                },
89            ),
90        );
91        self
92    }
93
94    /// Number of message types registered with this codec.
95    #[must_use]
96    pub fn registered_type_count(&self) -> usize {
97        self.encoders.len()
98    }
99}
100
101impl WireCodec for FlatbuffersCodec {
102    fn content_type(&self) -> &'static str {
103        "application/octet-stream;schema=flatbuffers"
104    }
105
106    fn encode(&self, value: &dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> {
107        let id = value.type_id();
108        let encoder = self
109            .encoders
110            .get(&id)
111            .ok_or(CodecError::UnknownTypeId(id))?;
112        encoder(value)
113    }
114
115    fn decode(
116        &self,
117        type_id: WireTypeId,
118        bytes: &[u8],
119    ) -> Result<Box<dyn ErasedWireValue>, CodecError> {
120        let decoder = self
121            .decoders
122            .get(&type_id)
123            .ok_or(CodecError::UnknownTypeId(type_id))?;
124        decoder(bytes)
125    }
126}
127
128#[cfg(test)]
129mod tests {
130    use super::*;
131    use ::flatbuffers::{FlatBufferBuilder, Vector, WIPOffset};
132
133    /// Hand-rolled flatbuffer fixture. The schema, written out, is:
134    ///
135    /// ```text
136    /// table Sample {
137    ///   name:    string;   // field 0, vtable slot 4
138    ///   seq:     uint32;   // field 1, vtable slot 6
139    ///   payload: [ubyte];  // field 2, vtable slot 8
140    /// }
141    /// root_type Sample;
142    /// ```
143    ///
144    /// Encoder uses [`FlatBufferBuilder`] from the runtime. Decoder
145    /// is a hand-written safe parser over the well-defined
146    /// FlatBuffers wire format -- the runtime's typed reader path
147    /// requires `unsafe { Table::new }` because FlatBuffers does not
148    /// validate untrusted bytes by default; we sidestep that here so
149    /// the codec module stays under the crate-wide
150    /// `forbid(unsafe_code)`.
151    #[derive(Debug, Eq, PartialEq)]
152    struct Sample {
153        name: String,
154        seq: u32,
155        payload: Vec<u8>,
156    }
157
158    impl WireValue for Sample {
159        fn wire_type_id() -> WireTypeId {
160            WireTypeId::new("test.flatbuffers.Sample")
161        }
162    }
163
164    const VT_NAME: u16 = 4;
165    const VT_SEQ: u16 = 6;
166    const VT_PAYLOAD: u16 = 8;
167
168    fn read_u16(buf: &[u8], at: usize) -> Result<u16, CodecError> {
169        buf.get(at..at + 2)
170            .ok_or_else(|| CodecError::decode_failure("flatbuffers: u16 read out of bounds"))
171            .map(|s| u16::from_le_bytes([s[0], s[1]]))
172    }
173
174    fn read_u32(buf: &[u8], at: usize) -> Result<u32, CodecError> {
175        buf.get(at..at + 4)
176            .ok_or_else(|| CodecError::decode_failure("flatbuffers: u32 read out of bounds"))
177            .map(|s| u32::from_le_bytes([s[0], s[1], s[2], s[3]]))
178    }
179
180    impl FlatbuffersWire for Sample {
181        fn flatbuffers_encode(&self) -> Result<Vec<u8>, CodecError> {
182            let mut b = FlatBufferBuilder::new();
183            let name_off = b.create_string(&self.name);
184            let payload_off = b.create_vector(&self.payload);
185            let table = b.start_table();
186            b.push_slot::<WIPOffset<&str>>(VT_NAME, name_off, WIPOffset::new(0));
187            b.push_slot::<u32>(VT_SEQ, self.seq, 0);
188            b.push_slot::<WIPOffset<Vector<'_, u8>>>(VT_PAYLOAD, payload_off, WIPOffset::new(0));
189            let root_off = b.end_table(table);
190            b.finish_minimal(root_off);
191            Ok(b.finished_data().to_vec())
192        }
193
194        fn flatbuffers_decode(bytes: &[u8]) -> Result<Self, CodecError> {
195            let root_off = read_u32(bytes, 0)? as usize;
196            let vt_off_signed = i32::from_le_bytes([
197                *bytes.get(root_off).ok_or_else(|| {
198                    CodecError::decode_failure("flatbuffers: vtable offset out of bounds")
199                })?,
200                *bytes.get(root_off + 1).ok_or_else(|| {
201                    CodecError::decode_failure("flatbuffers: vtable offset out of bounds")
202                })?,
203                *bytes.get(root_off + 2).ok_or_else(|| {
204                    CodecError::decode_failure("flatbuffers: vtable offset out of bounds")
205                })?,
206                *bytes.get(root_off + 3).ok_or_else(|| {
207                    CodecError::decode_failure("flatbuffers: vtable offset out of bounds")
208                })?,
209            ]);
210            // The vtable lives at `root_off - vt_off_signed`. Promote
211            // both operands to i64 to dodge the sign-vs-width
212            // pitfalls of casting `usize` to `isize`.
213            let root_i64 = i64::try_from(root_off).map_err(|_| {
214                CodecError::decode_failure("flatbuffers: root offset overflows i64")
215            })?;
216            let vtable_pos_i64 = root_i64 - i64::from(vt_off_signed);
217            let vtable_pos = usize::try_from(vtable_pos_i64).map_err(|_| {
218                CodecError::decode_failure("flatbuffers: vtable position underflow")
219            })?;
220            let vt_size = read_u16(bytes, vtable_pos)? as usize;
221
222            let read_slot = |slot: u16| -> Result<Option<usize>, CodecError> {
223                let slot = slot as usize;
224                if slot + 2 > vt_size {
225                    return Ok(None);
226                }
227                let raw = read_u16(bytes, vtable_pos + slot)?;
228                if raw == 0 {
229                    Ok(None)
230                } else {
231                    Ok(Some(root_off + raw as usize))
232                }
233            };
234
235            let name = match read_slot(VT_NAME)? {
236                Some(field_pos) => {
237                    let str_pos = field_pos + read_u32(bytes, field_pos)? as usize;
238                    let len = read_u32(bytes, str_pos)? as usize;
239                    let body = bytes.get(str_pos + 4..str_pos + 4 + len).ok_or_else(|| {
240                        CodecError::decode_failure("flatbuffers: string body out of bounds")
241                    })?;
242                    std::str::from_utf8(body)
243                        .map_err(CodecError::decode_failure)?
244                        .to_owned()
245                }
246                None => String::new(),
247            };
248
249            let seq = match read_slot(VT_SEQ)? {
250                Some(field_pos) => read_u32(bytes, field_pos)?,
251                None => 0,
252            };
253
254            let payload = match read_slot(VT_PAYLOAD)? {
255                Some(field_pos) => {
256                    let vec_pos = field_pos + read_u32(bytes, field_pos)? as usize;
257                    let len = read_u32(bytes, vec_pos)? as usize;
258                    bytes
259                        .get(vec_pos + 4..vec_pos + 4 + len)
260                        .ok_or_else(|| {
261                            CodecError::decode_failure("flatbuffers: vector body out of bounds")
262                        })?
263                        .to_vec()
264                }
265                None => Vec::new(),
266            };
267
268            Ok(Sample { name, seq, payload })
269        }
270    }
271
272    fn fixture() -> Sample {
273        Sample {
274            name: "delta".into(),
275            seq: 1024,
276            payload: vec![0x10, 0x20, 0x30, 0x40, 0x50],
277        }
278    }
279
280    #[test]
281    fn round_trip_recovers_value() {
282        let mut codec = FlatbuffersCodec::new();
283        codec.register::<Sample>();
284        let v = fixture();
285        let bytes = codec.encode(&v).expect("encode");
286        let back = codec
287            .decode(Sample::wire_type_id(), &bytes)
288            .expect("decode");
289        let back = back.as_any().downcast_ref::<Sample>().expect("downcast");
290        assert_eq!(back, &v);
291    }
292
293    #[test]
294    fn idempotent_encode_is_byte_equal() {
295        let mut codec = FlatbuffersCodec::new();
296        codec.register::<Sample>();
297        let v = fixture();
298        let a = codec.encode(&v).expect("encode 1");
299        let b = codec.encode(&v).expect("encode 2");
300        assert_eq!(a, b);
301        let back = codec.decode(Sample::wire_type_id(), &a).expect("decode");
302        let c = codec.encode(back.as_ref()).expect("encode 3");
303        assert_eq!(a, c);
304    }
305
306    #[test]
307    fn unregistered_type_returns_unknown_type_id_on_encode() {
308        let codec = FlatbuffersCodec::new();
309        let v = fixture();
310        let err = codec.encode(&v).expect_err("expected unknown type");
311        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
312    }
313
314    #[test]
315    fn unregistered_type_returns_unknown_type_id_on_decode() {
316        let codec = FlatbuffersCodec::new();
317        let err = codec
318            .decode(Sample::wire_type_id(), b"")
319            .expect_err("expected unknown type");
320        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
321    }
322
323    #[test]
324    fn malformed_bytes_yield_decode_failure() {
325        let mut codec = FlatbuffersCodec::new();
326        codec.register::<Sample>();
327        // Two bytes is not enough to even hold the root uoffset.
328        let err = codec
329            .decode(Sample::wire_type_id(), &[0x01, 0x02])
330            .expect_err("expected decode failure");
331        assert!(matches!(err, CodecError::Decode(_)));
332    }
333}