Skip to main content

murk_obs/
flatbuf.rs

1//! Binary serialization for observation specs.
2//!
3//! Provides round-trip serialization of [`ObsSpec`] using a compact
4//! binary format. The format uses a "MOBS" file identifier and
5//! version 1.
6//!
7//! Wire format:
8//! ```text
9//! [4 bytes] magic "MOBS"
10//! [2 bytes] version (little-endian u16)
11//! [2 bytes] n_entries (little-endian u16)
12//! [n_entries × entry]
13//! ```
14//!
15//! Each entry:
16//! ```text
17//! [4 bytes] field_id (LE u32)
18//! [1 byte]  region_type
19//! [2 bytes] n_region_params (LE u16)
20//! [n × 4 bytes] region_params (LE i32 each)
21//! [1 byte]  transform_type
22//! [8 bytes] normalize_min (LE f64, if Normalize)
23//! [8 bytes] normalize_max (LE f64, if Normalize)
24//! [1 byte]  dtype
25//! [1 byte]  pool_kernel (0=None)
26//! [4 bytes] pool_kernel_size (LE u32, if pool_kernel != 0)
27//! [4 bytes] pool_stride (LE u32, if pool_kernel != 0)
28//! ```
29
30use crate::spec::{ObsDtype, ObsEntry, ObsRegion, ObsSpec, ObsTransform, PoolConfig, PoolKernel};
31use murk_core::error::ObsError;
32use murk_core::FieldId;
33use murk_space::RegionSpec;
34use smallvec::SmallVec;
35
/// File identifier written as the first four bytes of every serialized spec.
const MAGIC: &[u8; 4] = b"MOBS";
/// Format version written by `serialize`; `deserialize` rejects anything larger.
const VERSION: u16 = 1;

// One-byte `region_type` tags.
/// Fixed region covering everything (no parameters).
const REGION_ALL: u8 = 0;
/// Fixed disk: centre coordinates followed by the radius.
const REGION_DISK: u8 = 1;
/// Fixed rectangle: `min` coordinates followed by `max` coordinates.
const REGION_RECT: u8 = 2;
/// Fixed neighbourhood: centre coordinates followed by the depth.
const REGION_NEIGHBOURS: u8 = 3;
/// Explicit coordinate list: `ndim`, `n_coords`, then the flattened coordinates.
const REGION_COORDS: u8 = 4;
/// Agent-relative disk: a single radius parameter.
const REGION_AGENT_DISK: u8 = 5;
/// Agent-relative rectangle: one half-extent per parameter.
const REGION_AGENT_RECT: u8 = 6;

// One-byte `transform_type` tags.
/// No transform; nothing follows the tag.
const TRANSFORM_IDENTITY: u8 = 0;
/// Normalize transform; the tag is followed by `min` and `max` as LE `f64`.
const TRANSFORM_NORMALIZE: u8 = 1;

// One-byte `pool_kernel` tags. Any non-zero tag is followed by kernel size and stride.
/// No pooling configured.
const POOL_NONE: u8 = 0;
/// Mean pooling.
const POOL_MEAN: u8 = 1;
/// Max pooling.
const POOL_MAX: u8 = 2;
/// Min pooling.
const POOL_MIN: u8 = 3;
/// Sum pooling.
const POOL_SUM: u8 = 4;
58
59/// Serialize an [`ObsSpec`] to binary bytes.
60pub fn serialize(spec: &ObsSpec) -> Vec<u8> {
61    let mut buf = Vec::with_capacity(128);
62
63    // Header
64    buf.extend_from_slice(MAGIC);
65    buf.extend_from_slice(&VERSION.to_le_bytes());
66    buf.extend_from_slice(&(spec.entries.len() as u16).to_le_bytes());
67
68    for entry in &spec.entries {
69        write_entry(&mut buf, entry);
70    }
71
72    buf
73}
74
75fn write_entry(buf: &mut Vec<u8>, entry: &ObsEntry) {
76    buf.extend_from_slice(&entry.field_id.0.to_le_bytes());
77
78    // Region
79    let (region_tag, region_params) = encode_region(&entry.region);
80    buf.push(region_tag);
81    buf.extend_from_slice(&(region_params.len() as u16).to_le_bytes());
82    for &p in &region_params {
83        buf.extend_from_slice(&p.to_le_bytes());
84    }
85
86    // Transform
87    match &entry.transform {
88        ObsTransform::Identity => {
89            buf.push(TRANSFORM_IDENTITY);
90        }
91        ObsTransform::Normalize { min, max } => {
92            buf.push(TRANSFORM_NORMALIZE);
93            buf.extend_from_slice(&min.to_le_bytes());
94            buf.extend_from_slice(&max.to_le_bytes());
95        }
96    }
97
98    // Dtype
99    match entry.dtype {
100        ObsDtype::F32 => buf.push(0),
101    }
102
103    // Pool
104    match &entry.pool {
105        None => buf.push(POOL_NONE),
106        Some(cfg) => {
107            let tag = match cfg.kernel {
108                PoolKernel::Mean => POOL_MEAN,
109                PoolKernel::Max => POOL_MAX,
110                PoolKernel::Min => POOL_MIN,
111                PoolKernel::Sum => POOL_SUM,
112            };
113            buf.push(tag);
114            buf.extend_from_slice(&(cfg.kernel_size as u32).to_le_bytes());
115            buf.extend_from_slice(&(cfg.stride as u32).to_le_bytes());
116        }
117    }
118}
119
120fn encode_region(region: &ObsRegion) -> (u8, Vec<i32>) {
121    match region {
122        ObsRegion::Fixed(RegionSpec::All) => (REGION_ALL, vec![]),
123        ObsRegion::Fixed(RegionSpec::Disk { center, radius }) => {
124            let mut params: Vec<i32> = center.iter().copied().collect();
125            params.push(*radius as i32);
126            (REGION_DISK, params)
127        }
128        ObsRegion::Fixed(RegionSpec::Rect { min, max }) => {
129            let mut params: Vec<i32> = min.iter().copied().collect();
130            params.extend(max.iter().copied());
131            (REGION_RECT, params)
132        }
133        ObsRegion::Fixed(RegionSpec::Neighbours { center, depth }) => {
134            let mut params: Vec<i32> = center.iter().copied().collect();
135            params.push(*depth as i32);
136            (REGION_NEIGHBOURS, params)
137        }
138        ObsRegion::Fixed(RegionSpec::Coords(coords)) => {
139            let ndim = coords.first().map(|c| c.len()).unwrap_or(0) as i32;
140            let n_coords = coords.len() as i32;
141            let mut params = vec![ndim, n_coords];
142            for c in coords {
143                params.extend(c.iter().copied());
144            }
145            (REGION_COORDS, params)
146        }
147        ObsRegion::AgentDisk { radius } => (REGION_AGENT_DISK, vec![*radius as i32]),
148        ObsRegion::AgentRect { half_extent } => {
149            let params: Vec<i32> = half_extent.iter().map(|&h| h as i32).collect();
150            (REGION_AGENT_RECT, params)
151        }
152    }
153}
154
155/// Deserialize an [`ObsSpec`] from binary bytes.
156pub fn deserialize(bytes: &[u8]) -> Result<ObsSpec, ObsError> {
157    let mut r = Reader::new(bytes);
158
159    // Magic
160    let magic = r.read_bytes(4)?;
161    if magic != MAGIC {
162        return Err(ObsError::InvalidObsSpec {
163            reason: format!(
164                "invalid magic: expected 'MOBS', got '{}'",
165                String::from_utf8_lossy(magic)
166            ),
167        });
168    }
169
170    // Version
171    let version = r.read_u16()?;
172    if version > VERSION {
173        return Err(ObsError::InvalidObsSpec {
174            reason: format!("unsupported version {version}, max supported is {VERSION}"),
175        });
176    }
177
178    // Entries
179    let n_entries = r.read_u16()? as usize;
180    let mut entries = Vec::with_capacity(n_entries);
181    for i in 0..n_entries {
182        entries.push(read_entry(&mut r, i)?);
183    }
184
185    Ok(ObsSpec { entries })
186}
187
188fn read_entry(r: &mut Reader<'_>, idx: usize) -> Result<ObsEntry, ObsError> {
189    let field_id = FieldId(r.read_u32().map_err(|e| truncated(idx, &e))?);
190
191    // Region
192    let region_tag = r.read_u8().map_err(|e| truncated(idx, &e))?;
193    let n_params = r.read_u16().map_err(|e| truncated(idx, &e))? as usize;
194    let mut region_params = Vec::with_capacity(n_params);
195    for _ in 0..n_params {
196        region_params.push(r.read_i32().map_err(|e| truncated(idx, &e))?);
197    }
198    let region = decode_region(region_tag, &region_params, idx)?;
199
200    // Transform
201    let transform_tag = r.read_u8().map_err(|e| truncated(idx, &e))?;
202    let transform = match transform_tag {
203        TRANSFORM_IDENTITY => ObsTransform::Identity,
204        TRANSFORM_NORMALIZE => {
205            let min = r.read_f64().map_err(|e| truncated(idx, &e))?;
206            let max = r.read_f64().map_err(|e| truncated(idx, &e))?;
207            ObsTransform::Normalize { min, max }
208        }
209        other => {
210            return Err(ObsError::InvalidObsSpec {
211                reason: format!("entry {idx}: unknown transform type {other}"),
212            });
213        }
214    };
215
216    // Dtype
217    let dtype_val = r.read_u8().map_err(|e| truncated(idx, &e))?;
218    let dtype = match dtype_val {
219        0 => ObsDtype::F32,
220        other => {
221            return Err(ObsError::InvalidObsSpec {
222                reason: format!("entry {idx}: unknown dtype {other}"),
223            });
224        }
225    };
226
227    // Pool
228    let pool_tag = r.read_u8().map_err(|e| truncated(idx, &e))?;
229    let pool = if pool_tag == POOL_NONE {
230        None
231    } else {
232        let kernel = match pool_tag {
233            POOL_MEAN => PoolKernel::Mean,
234            POOL_MAX => PoolKernel::Max,
235            POOL_MIN => PoolKernel::Min,
236            POOL_SUM => PoolKernel::Sum,
237            other => {
238                return Err(ObsError::InvalidObsSpec {
239                    reason: format!("entry {idx}: unknown pool kernel {other}"),
240                });
241            }
242        };
243        let kernel_size = r.read_u32().map_err(|e| truncated(idx, &e))? as usize;
244        let stride = r.read_u32().map_err(|e| truncated(idx, &e))? as usize;
245        if kernel_size == 0 || stride == 0 {
246            return Err(ObsError::InvalidObsSpec {
247                reason: format!("entry {idx}: pool kernel_size and stride must be > 0"),
248            });
249        }
250        Some(PoolConfig {
251            kernel,
252            kernel_size,
253            stride,
254        })
255    };
256
257    Ok(ObsEntry {
258        field_id,
259        region,
260        pool,
261        transform,
262        dtype,
263    })
264}
265
266fn decode_region(tag: u8, params: &[i32], idx: usize) -> Result<ObsRegion, ObsError> {
267    match tag {
268        REGION_ALL => Ok(ObsRegion::Fixed(RegionSpec::All)),
269        REGION_DISK => {
270            if params.len() < 2 {
271                return Err(ObsError::InvalidObsSpec {
272                    reason: format!("entry {idx}: Disk region needs at least 2 params"),
273                });
274            }
275            let ndim = params.len() - 1;
276            let center: SmallVec<[i32; 4]> = params[..ndim].iter().copied().collect();
277            let radius = params[ndim] as u32;
278            Ok(ObsRegion::Fixed(RegionSpec::Disk { center, radius }))
279        }
280        REGION_RECT => {
281            if params.is_empty() || !params.len().is_multiple_of(2) {
282                return Err(ObsError::InvalidObsSpec {
283                    reason: format!("entry {idx}: Rect region needs even number of params"),
284                });
285            }
286            let ndim = params.len() / 2;
287            let min: SmallVec<[i32; 4]> = params[..ndim].iter().copied().collect();
288            let max: SmallVec<[i32; 4]> = params[ndim..].iter().copied().collect();
289            Ok(ObsRegion::Fixed(RegionSpec::Rect { min, max }))
290        }
291        REGION_NEIGHBOURS => {
292            if params.len() < 2 {
293                return Err(ObsError::InvalidObsSpec {
294                    reason: format!("entry {idx}: Neighbours region needs at least 2 params"),
295                });
296            }
297            let ndim = params.len() - 1;
298            let center: SmallVec<[i32; 4]> = params[..ndim].iter().copied().collect();
299            let depth = params[ndim] as u32;
300            Ok(ObsRegion::Fixed(RegionSpec::Neighbours { center, depth }))
301        }
302        REGION_COORDS => {
303            if params.len() < 2 {
304                return Err(ObsError::InvalidObsSpec {
305                    reason: format!("entry {idx}: Coords region needs ndim + n_coords header"),
306                });
307            }
308            let ndim = params[0] as usize;
309            let n_coords = params[1] as usize;
310            let data = &params[2..];
311            if ndim == 0 || data.len() != ndim * n_coords {
312                return Err(ObsError::InvalidObsSpec {
313                    reason: format!(
314                        "entry {idx}: Coords expected {} values, got {}",
315                        ndim * n_coords,
316                        data.len()
317                    ),
318                });
319            }
320            let coords: Vec<SmallVec<[i32; 4]>> = data
321                .chunks(ndim)
322                .map(|chunk| chunk.iter().copied().collect())
323                .collect();
324            Ok(ObsRegion::Fixed(RegionSpec::Coords(coords)))
325        }
326        REGION_AGENT_DISK => {
327            if params.len() != 1 {
328                return Err(ObsError::InvalidObsSpec {
329                    reason: format!("entry {idx}: AgentDisk needs exactly 1 param (radius)"),
330                });
331            }
332            Ok(ObsRegion::AgentDisk {
333                radius: params[0] as u32,
334            })
335        }
336        REGION_AGENT_RECT => {
337            if params.is_empty() {
338                return Err(ObsError::InvalidObsSpec {
339                    reason: format!("entry {idx}: AgentRect needs at least 1 param"),
340                });
341            }
342            let half_extent: SmallVec<[u32; 4]> = params.iter().map(|&p| p as u32).collect();
343            Ok(ObsRegion::AgentRect { half_extent })
344        }
345        other => Err(ObsError::InvalidObsSpec {
346            reason: format!("entry {idx}: unknown region type {other}"),
347        }),
348    }
349}
350
351fn truncated(idx: usize, _inner: &ObsError) -> ObsError {
352    ObsError::InvalidObsSpec {
353        reason: format!("entry {idx}: truncated data"),
354    }
355}
356
357/// Simple cursor reader for safe byte parsing.
358struct Reader<'a> {
359    data: &'a [u8],
360    pos: usize,
361}
362
363impl<'a> Reader<'a> {
364    fn new(data: &'a [u8]) -> Self {
365        Self { data, pos: 0 }
366    }
367
368    fn read_bytes(&mut self, n: usize) -> Result<&'a [u8], ObsError> {
369        if self.pos + n > self.data.len() {
370            return Err(ObsError::InvalidObsSpec {
371                reason: "unexpected end of data".into(),
372            });
373        }
374        let slice = &self.data[self.pos..self.pos + n];
375        self.pos += n;
376        Ok(slice)
377    }
378
379    fn read_u8(&mut self) -> Result<u8, ObsError> {
380        Ok(self.read_bytes(1)?[0])
381    }
382
383    fn read_u16(&mut self) -> Result<u16, ObsError> {
384        let b = self.read_bytes(2)?;
385        Ok(u16::from_le_bytes([b[0], b[1]]))
386    }
387
388    fn read_u32(&mut self) -> Result<u32, ObsError> {
389        let b = self.read_bytes(4)?;
390        Ok(u32::from_le_bytes([b[0], b[1], b[2], b[3]]))
391    }
392
393    fn read_i32(&mut self) -> Result<i32, ObsError> {
394        let b = self.read_bytes(4)?;
395        Ok(i32::from_le_bytes([b[0], b[1], b[2], b[3]]))
396    }
397
398    fn read_f64(&mut self) -> Result<f64, ObsError> {
399        let b = self.read_bytes(8)?;
400        Ok(f64::from_le_bytes([
401            b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7],
402        ]))
403    }
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use smallvec::smallvec;

    /// Builds an `F32` entry from the parts that vary between tests.
    fn entry(
        field: u32,
        region: ObsRegion,
        pool: Option<PoolConfig>,
        transform: ObsTransform,
    ) -> ObsEntry {
        ObsEntry {
            field_id: FieldId(field),
            region,
            pool,
            transform,
            dtype: ObsDtype::F32,
        }
    }

    /// Entry on field 0 with no pooling and the identity transform.
    fn plain(region: ObsRegion) -> ObsEntry {
        entry(0, region, None, ObsTransform::Identity)
    }

    /// The fixed "everything" region.
    fn whole_space() -> ObsRegion {
        ObsRegion::Fixed(RegionSpec::All)
    }

    /// Serializes a spec made of `entries`, decodes it, and checks equality.
    fn assert_round_trips(entries: Vec<ObsEntry>) {
        let spec = ObsSpec { entries };
        let decoded = deserialize(&serialize(&spec)).unwrap();
        assert_eq!(decoded, spec);
    }

    #[test]
    fn round_trip_all_identity() {
        assert_round_trips(vec![plain(whole_space())]);
    }

    #[test]
    fn round_trip_normalize() {
        let normalize = ObsTransform::Normalize {
            min: -1.0,
            max: 5.0,
        };
        assert_round_trips(vec![entry(7, whole_space(), None, normalize)]);
    }

    #[test]
    fn round_trip_disk_region() {
        let disk = RegionSpec::Disk {
            center: smallvec![3, 4],
            radius: 5,
        };
        assert_round_trips(vec![plain(ObsRegion::Fixed(disk))]);
    }

    #[test]
    fn round_trip_rect_region() {
        let rect = RegionSpec::Rect {
            min: smallvec![1, 2],
            max: smallvec![5, 8],
        };
        assert_round_trips(vec![plain(ObsRegion::Fixed(rect))]);
    }

    #[test]
    fn round_trip_neighbours_region() {
        let neighbours = RegionSpec::Neighbours {
            center: smallvec![3, 4],
            depth: 2,
        };
        assert_round_trips(vec![plain(ObsRegion::Fixed(neighbours))]);
    }

    #[test]
    fn round_trip_coords_region() {
        let coords = RegionSpec::Coords(vec![
            smallvec![0, 0],
            smallvec![1, 2],
            smallvec![3, 4],
        ]);
        assert_round_trips(vec![plain(ObsRegion::Fixed(coords))]);
    }

    #[test]
    fn round_trip_agent_disk() {
        assert_round_trips(vec![plain(ObsRegion::AgentDisk { radius: 3 })]);
    }

    #[test]
    fn round_trip_agent_rect() {
        let region = ObsRegion::AgentRect {
            half_extent: smallvec![3, 4],
        };
        assert_round_trips(vec![plain(region)]);
    }

    #[test]
    fn round_trip_with_pool() {
        let pool = PoolConfig {
            kernel: PoolKernel::Mean,
            kernel_size: 2,
            stride: 2,
        };
        let normalize = ObsTransform::Normalize {
            min: 0.0,
            max: 100.0,
        };
        assert_round_trips(vec![entry(
            2,
            ObsRegion::AgentDisk { radius: 5 },
            Some(pool),
            normalize,
        )]);
    }

    #[test]
    fn round_trip_multiple_entries() {
        let pooled = entry(
            1,
            ObsRegion::AgentDisk { radius: 3 },
            Some(PoolConfig {
                kernel: PoolKernel::Max,
                kernel_size: 3,
                stride: 1,
            }),
            ObsTransform::Normalize {
                min: -5.0,
                max: 5.0,
            },
        );
        assert_round_trips(vec![plain(whole_space()), pooled]);
    }

    #[test]
    fn round_trip_all_pool_kernels() {
        let kernels = [
            PoolKernel::Mean,
            PoolKernel::Max,
            PoolKernel::Min,
            PoolKernel::Sum,
        ];
        for kernel in kernels {
            let pool = PoolConfig {
                kernel,
                kernel_size: 2,
                stride: 1,
            };
            let spec = ObsSpec {
                entries: vec![entry(0, whole_space(), Some(pool), ObsTransform::Identity)],
            };
            let decoded = deserialize(&serialize(&spec)).unwrap();
            assert_eq!(decoded, spec, "failed for kernel {kernel:?}");
        }
    }

    #[test]
    fn version_rejection() {
        let spec = ObsSpec {
            entries: vec![plain(whole_space())],
        };
        let mut bytes = serialize(&spec);
        // Overwrite the little-endian version field (bytes 4..6) with 99.
        bytes[4..6].copy_from_slice(&99u16.to_le_bytes());
        let err = deserialize(&bytes).unwrap_err();
        assert!(format!("{err:?}").contains("unsupported version"));
    }

    #[test]
    fn truncated_bytes_error() {
        let inputs: [&[u8]; 3] = [&[], &[0, 0], b"MOBS"];
        for input in inputs {
            assert!(deserialize(input).is_err());
        }
    }

    #[test]
    fn invalid_magic_error() {
        let err = deserialize(b"BAD!\x01\x00\x00\x00").unwrap_err();
        assert!(format!("{err:?}").contains("invalid magic"));
    }
}