1use std::collections::HashMap;
22
23use crate::error::CodecError;
24use crate::value::{ErasedWireValue, WireCodec, WireTypeId, WireValue};
25
26type EncodeFn =
27 Box<dyn Fn(&dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> + Send + Sync + 'static>;
28type DecodeFn =
29 Box<dyn Fn(&[u8]) -> Result<Box<dyn ErasedWireValue>, CodecError> + Send + Sync + 'static>;
30
31pub trait FlatbuffersWire: WireValue + Sized {
42 fn flatbuffers_encode(&self) -> Result<Vec<u8>, CodecError>;
44
45 fn flatbuffers_decode(bytes: &[u8]) -> Result<Self, CodecError>;
47}
48
49#[derive(Default)]
52pub struct FlatbuffersCodec {
53 encoders: HashMap<WireTypeId, EncodeFn>,
54 decoders: HashMap<WireTypeId, DecodeFn>,
55}
56
57impl FlatbuffersCodec {
58 #[must_use]
60 pub fn new() -> Self {
61 Self::default()
62 }
63
64 pub fn register<T>(&mut self) -> &mut Self
66 where
67 T: FlatbuffersWire + 'static,
68 {
69 let id = T::wire_type_id();
70 self.encoders.insert(
71 id,
72 Box::new(
73 move |v: &dyn ErasedWireValue| -> Result<Vec<u8>, CodecError> {
74 let concrete = v
75 .as_any()
76 .downcast_ref::<T>()
77 .ok_or(CodecError::TypeMismatch { expected: id })?;
78 concrete.flatbuffers_encode()
79 },
80 ),
81 );
82 self.decoders.insert(
83 id,
84 Box::new(
85 move |bytes: &[u8]| -> Result<Box<dyn ErasedWireValue>, CodecError> {
86 let value = T::flatbuffers_decode(bytes)?;
87 Ok(Box::new(value))
88 },
89 ),
90 );
91 self
92 }
93
94 #[must_use]
96 pub fn registered_type_count(&self) -> usize {
97 self.encoders.len()
98 }
99}
100
101impl WireCodec for FlatbuffersCodec {
102 fn content_type(&self) -> &'static str {
103 "application/octet-stream;schema=flatbuffers"
104 }
105
106 fn encode(&self, value: &dyn ErasedWireValue) -> Result<Vec<u8>, CodecError> {
107 let id = value.type_id();
108 let encoder = self
109 .encoders
110 .get(&id)
111 .ok_or(CodecError::UnknownTypeId(id))?;
112 encoder(value)
113 }
114
115 fn decode(
116 &self,
117 type_id: WireTypeId,
118 bytes: &[u8],
119 ) -> Result<Box<dyn ErasedWireValue>, CodecError> {
120 let decoder = self
121 .decoders
122 .get(&type_id)
123 .ok_or(CodecError::UnknownTypeId(type_id))?;
124 decoder(bytes)
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use ::flatbuffers::{FlatBufferBuilder, Vector, WIPOffset};

    /// Test message with one field of each shape the codec must handle: a string, a
    /// scalar and a byte vector.
    #[derive(Debug, Eq, PartialEq)]
    struct Sample {
        name: String,
        seq: u32,
        payload: Vec<u8>,
    }

    impl WireValue for Sample {
        fn wire_type_id() -> WireTypeId {
            WireTypeId::new("test.flatbuffers.Sample")
        }
    }

    // Byte offsets of each field's entry inside the vtable. The first u16 of the vtable is
    // its own byte length (read as `vt_size` below); field entries start at byte 4.
    const VT_NAME: u16 = 4;
    const VT_SEQ: u16 = 6;
    const VT_PAYLOAD: u16 = 8;

    /// Returns `base + delta`, or a decode error instead of overflowing `usize`.
    fn advance(base: usize, delta: usize) -> Result<usize, CodecError> {
        base.checked_add(delta)
            .ok_or_else(|| CodecError::decode_failure("flatbuffers: offset overflow"))
    }

    /// Borrows `len` bytes of `buf` starting at `at`, failing with `what` if out of bounds.
    fn slice_at<'a>(
        buf: &'a [u8],
        at: usize,
        len: usize,
        what: &'static str,
    ) -> Result<&'a [u8], CodecError> {
        at.checked_add(len)
            .and_then(|end| buf.get(at..end))
            .ok_or_else(|| CodecError::decode_failure(what))
    }

    fn read_u16(buf: &[u8], at: usize) -> Result<u16, CodecError> {
        slice_at(buf, at, 2, "flatbuffers: u16 read out of bounds")
            .map(|s| u16::from_le_bytes([s[0], s[1]]))
    }

    fn read_u32(buf: &[u8], at: usize) -> Result<u32, CodecError> {
        slice_at(buf, at, 4, "flatbuffers: u32 read out of bounds")
            .map(|s| u32::from_le_bytes([s[0], s[1], s[2], s[3]]))
    }

    fn read_i32(buf: &[u8], at: usize) -> Result<i32, CodecError> {
        slice_at(buf, at, 4, "flatbuffers: i32 read out of bounds")
            .map(|s| i32::from_le_bytes([s[0], s[1], s[2], s[3]]))
    }

    /// Follows the u32 offset stored at `at`, which is relative to `at` itself.
    fn follow_uoffset(buf: &[u8], at: usize) -> Result<usize, CodecError> {
        advance(at, read_u32(buf, at)? as usize)
    }

    /// Reads a u32 length prefix at `at` and borrows that many bytes that follow it.
    fn read_prefixed<'a>(
        buf: &'a [u8],
        at: usize,
        what: &'static str,
    ) -> Result<&'a [u8], CodecError> {
        let len = read_u32(buf, at)? as usize;
        slice_at(buf, advance(at, 4)?, len, what)
    }

    impl FlatbuffersWire for Sample {
        fn flatbuffers_encode(&self) -> Result<Vec<u8>, CodecError> {
            let mut b = FlatBufferBuilder::new();
            // Out-of-line data must be created before the table that references it.
            let name_off = b.create_string(&self.name);
            let payload_off = b.create_vector(&self.payload);
            let table = b.start_table();
            b.push_slot::<WIPOffset<&str>>(VT_NAME, name_off, WIPOffset::new(0));
            // `seq == 0` equals the default, so push_slot omits it; the decoder maps a
            // missing slot back to 0.
            b.push_slot::<u32>(VT_SEQ, self.seq, 0);
            b.push_slot::<WIPOffset<Vector<'_, u8>>>(VT_PAYLOAD, payload_off, WIPOffset::new(0));
            let root_off = b.end_table(table);
            b.finish_minimal(root_off);
            Ok(b.finished_data().to_vec())
        }

        fn flatbuffers_decode(bytes: &[u8]) -> Result<Self, CodecError> {
            // The buffer starts with the offset of the root table.
            let root_off = read_u32(bytes, 0)? as usize;
            // The table starts with a signed offset: vtable position = table - soffset.
            let vt_off_signed = read_i32(bytes, root_off)?;
            let root_i64 = i64::try_from(root_off).map_err(|_| {
                CodecError::decode_failure("flatbuffers: root offset overflows i64")
            })?;
            let vtable_pos_i64 = root_i64 - i64::from(vt_off_signed);
            let vtable_pos = usize::try_from(vtable_pos_i64).map_err(|_| {
                CodecError::decode_failure("flatbuffers: vtable position underflow")
            })?;
            let vt_size = usize::from(read_u16(bytes, vtable_pos)?);

            // Absolute position of a field, or `None` when the vtable is too short to hold
            // the slot or stores 0 for it (field absent).
            let read_slot = |slot: u16| -> Result<Option<usize>, CodecError> {
                let slot = usize::from(slot);
                if slot + 2 > vt_size {
                    return Ok(None);
                }
                match read_u16(bytes, advance(vtable_pos, slot)?)? {
                    0 => Ok(None),
                    raw => advance(root_off, usize::from(raw)).map(Some),
                }
            };

            let name = match read_slot(VT_NAME)? {
                Some(field_pos) => {
                    let str_pos = follow_uoffset(bytes, field_pos)?;
                    let body =
                        read_prefixed(bytes, str_pos, "flatbuffers: string body out of bounds")?;
                    std::str::from_utf8(body)
                        .map_err(CodecError::decode_failure)?
                        .to_owned()
                }
                None => String::new(),
            };

            let seq = match read_slot(VT_SEQ)? {
                Some(field_pos) => read_u32(bytes, field_pos)?,
                None => 0,
            };

            let payload = match read_slot(VT_PAYLOAD)? {
                Some(field_pos) => {
                    let vec_pos = follow_uoffset(bytes, field_pos)?;
                    read_prefixed(bytes, vec_pos, "flatbuffers: vector body out of bounds")?
                        .to_vec()
                }
                None => Vec::new(),
            };

            Ok(Sample { name, seq, payload })
        }
    }

    fn fixture() -> Sample {
        Sample {
            name: "delta".into(),
            seq: 1024,
            payload: vec![0x10, 0x20, 0x30, 0x40, 0x50],
        }
    }

    #[test]
    fn round_trip_recovers_value() {
        let mut codec = FlatbuffersCodec::new();
        codec.register::<Sample>();
        let v = fixture();
        let bytes = codec.encode(&v).expect("encode");
        let back = codec
            .decode(Sample::wire_type_id(), &bytes)
            .expect("decode");
        let back = back.as_any().downcast_ref::<Sample>().expect("downcast");
        assert_eq!(back, &v);
    }

    #[test]
    fn default_and_empty_fields_round_trip() {
        // seq == 0 is elided by the writer, so this exercises the missing-slot path; the
        // empty string and vector exercise zero-length bodies.
        let mut codec = FlatbuffersCodec::new();
        codec.register::<Sample>();
        let v = Sample {
            name: String::new(),
            seq: 0,
            payload: Vec::new(),
        };
        let bytes = codec.encode(&v).expect("encode");
        let back = codec
            .decode(Sample::wire_type_id(), &bytes)
            .expect("decode");
        assert_eq!(back.as_any().downcast_ref::<Sample>(), Some(&v));
    }

    #[test]
    fn idempotent_encode_is_byte_equal() {
        let mut codec = FlatbuffersCodec::new();
        codec.register::<Sample>();
        let v = fixture();
        let a = codec.encode(&v).expect("encode 1");
        let b = codec.encode(&v).expect("encode 2");
        assert_eq!(a, b);
        let back = codec.decode(Sample::wire_type_id(), &a).expect("decode");
        let c = codec.encode(back.as_ref()).expect("encode 3");
        assert_eq!(a, c);
    }

    #[test]
    fn unregistered_type_returns_unknown_type_id_on_encode() {
        let codec = FlatbuffersCodec::new();
        let v = fixture();
        let err = codec.encode(&v).expect_err("expected unknown type");
        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
    }

    #[test]
    fn unregistered_type_returns_unknown_type_id_on_decode() {
        let codec = FlatbuffersCodec::new();
        let err = codec
            .decode(Sample::wire_type_id(), b"")
            .expect_err("expected unknown type");
        assert!(matches!(err, CodecError::UnknownTypeId(id) if id == Sample::wire_type_id()));
    }

    #[test]
    fn malformed_bytes_yield_decode_failure() {
        let mut codec = FlatbuffersCodec::new();
        codec.register::<Sample>();
        let err = codec
            .decode(Sample::wire_type_id(), &[0x01, 0x02])
            .expect_err("expected decode failure");
        assert!(matches!(err, CodecError::Decode(_)));
    }

    #[test]
    fn truncated_buffers_never_panic() {
        let mut codec = FlatbuffersCodec::new();
        codec.register::<Sample>();
        let bytes = codec.encode(&fixture()).expect("encode");
        for cut in 0..bytes.len() {
            // Every prefix must either decode or fail with an error; an out-of-bounds
            // panic fails the test.
            let _ = codec.decode(Sample::wire_type_id(), &bytes[..cut]);
        }
    }
}