1use crate::spec::{ObsDtype, ObsEntry, ObsRegion, ObsSpec, ObsTransform, PoolConfig, PoolKernel};
31use murk_core::error::ObsError;
32use murk_core::FieldId;
33use murk_space::RegionSpec;
34use smallvec::SmallVec;
35
/// Signature that opens every serialized observation spec.
const MAGIC: &[u8; 4] = &[b'M', b'O', b'B', b'S'];
/// Format version written by `serialize`; `deserialize` rejects anything newer.
const VERSION: u16 = 1;

// Region tags: first byte of each entry's region section.
const REGION_ALL: u8 = 0;
const REGION_DISK: u8 = 1;
const REGION_RECT: u8 = 2;
const REGION_NEIGHBOURS: u8 = 3;
const REGION_COORDS: u8 = 4;
const REGION_AGENT_DISK: u8 = 5;
const REGION_AGENT_RECT: u8 = 6;

// Transform tags.
const TRANSFORM_IDENTITY: u8 = 0;
const TRANSFORM_NORMALIZE: u8 = 1;

// Pool tags; `POOL_NONE` means the entry has no pooling stage.
const POOL_NONE: u8 = 0;
const POOL_MEAN: u8 = 1;
const POOL_MAX: u8 = 2;
const POOL_MIN: u8 = 3;
const POOL_SUM: u8 = 4;
58
59pub fn serialize(spec: &ObsSpec) -> Vec<u8> {
61 let mut buf = Vec::with_capacity(128);
62
63 buf.extend_from_slice(MAGIC);
65 buf.extend_from_slice(&VERSION.to_le_bytes());
66 buf.extend_from_slice(&(spec.entries.len() as u16).to_le_bytes());
67
68 for entry in &spec.entries {
69 write_entry(&mut buf, entry);
70 }
71
72 buf
73}
74
75fn write_entry(buf: &mut Vec<u8>, entry: &ObsEntry) {
76 buf.extend_from_slice(&entry.field_id.0.to_le_bytes());
77
78 let (region_tag, region_params) = encode_region(&entry.region);
80 buf.push(region_tag);
81 buf.extend_from_slice(&(region_params.len() as u16).to_le_bytes());
82 for &p in ®ion_params {
83 buf.extend_from_slice(&p.to_le_bytes());
84 }
85
86 match &entry.transform {
88 ObsTransform::Identity => {
89 buf.push(TRANSFORM_IDENTITY);
90 }
91 ObsTransform::Normalize { min, max } => {
92 buf.push(TRANSFORM_NORMALIZE);
93 buf.extend_from_slice(&min.to_le_bytes());
94 buf.extend_from_slice(&max.to_le_bytes());
95 }
96 }
97
98 match entry.dtype {
100 ObsDtype::F32 => buf.push(0),
101 }
102
103 match &entry.pool {
105 None => buf.push(POOL_NONE),
106 Some(cfg) => {
107 let tag = match cfg.kernel {
108 PoolKernel::Mean => POOL_MEAN,
109 PoolKernel::Max => POOL_MAX,
110 PoolKernel::Min => POOL_MIN,
111 PoolKernel::Sum => POOL_SUM,
112 };
113 buf.push(tag);
114 buf.extend_from_slice(&(cfg.kernel_size as u32).to_le_bytes());
115 buf.extend_from_slice(&(cfg.stride as u32).to_le_bytes());
116 }
117 }
118}
119
120fn encode_region(region: &ObsRegion) -> (u8, Vec<i32>) {
121 match region {
122 ObsRegion::Fixed(RegionSpec::All) => (REGION_ALL, vec![]),
123 ObsRegion::Fixed(RegionSpec::Disk { center, radius }) => {
124 let mut params: Vec<i32> = center.iter().copied().collect();
125 params.push(*radius as i32);
126 (REGION_DISK, params)
127 }
128 ObsRegion::Fixed(RegionSpec::Rect { min, max }) => {
129 let mut params: Vec<i32> = min.iter().copied().collect();
130 params.extend(max.iter().copied());
131 (REGION_RECT, params)
132 }
133 ObsRegion::Fixed(RegionSpec::Neighbours { center, depth }) => {
134 let mut params: Vec<i32> = center.iter().copied().collect();
135 params.push(*depth as i32);
136 (REGION_NEIGHBOURS, params)
137 }
138 ObsRegion::Fixed(RegionSpec::Coords(coords)) => {
139 let ndim = coords.first().map(|c| c.len()).unwrap_or(0) as i32;
140 let n_coords = coords.len() as i32;
141 let mut params = vec![ndim, n_coords];
142 for c in coords {
143 params.extend(c.iter().copied());
144 }
145 (REGION_COORDS, params)
146 }
147 ObsRegion::AgentDisk { radius } => (REGION_AGENT_DISK, vec![*radius as i32]),
148 ObsRegion::AgentRect { half_extent } => {
149 let params: Vec<i32> = half_extent.iter().map(|&h| h as i32).collect();
150 (REGION_AGENT_RECT, params)
151 }
152 }
153}
154
155pub fn deserialize(bytes: &[u8]) -> Result<ObsSpec, ObsError> {
157 let mut r = Reader::new(bytes);
158
159 let magic = r.read_bytes(4)?;
161 if magic != MAGIC {
162 return Err(ObsError::InvalidObsSpec {
163 reason: format!(
164 "invalid magic: expected 'MOBS', got '{}'",
165 String::from_utf8_lossy(magic)
166 ),
167 });
168 }
169
170 let version = r.read_u16()?;
172 if version > VERSION {
173 return Err(ObsError::InvalidObsSpec {
174 reason: format!("unsupported version {version}, max supported is {VERSION}"),
175 });
176 }
177
178 let n_entries = r.read_u16()? as usize;
180 let mut entries = Vec::with_capacity(n_entries);
181 for i in 0..n_entries {
182 entries.push(read_entry(&mut r, i)?);
183 }
184
185 Ok(ObsSpec { entries })
186}
187
188fn read_entry(r: &mut Reader<'_>, idx: usize) -> Result<ObsEntry, ObsError> {
189 let field_id = FieldId(r.read_u32().map_err(|e| truncated(idx, &e))?);
190
191 let region_tag = r.read_u8().map_err(|e| truncated(idx, &e))?;
193 let n_params = r.read_u16().map_err(|e| truncated(idx, &e))? as usize;
194 let mut region_params = Vec::with_capacity(n_params);
195 for _ in 0..n_params {
196 region_params.push(r.read_i32().map_err(|e| truncated(idx, &e))?);
197 }
198 let region = decode_region(region_tag, ®ion_params, idx)?;
199
200 let transform_tag = r.read_u8().map_err(|e| truncated(idx, &e))?;
202 let transform = match transform_tag {
203 TRANSFORM_IDENTITY => ObsTransform::Identity,
204 TRANSFORM_NORMALIZE => {
205 let min = r.read_f64().map_err(|e| truncated(idx, &e))?;
206 let max = r.read_f64().map_err(|e| truncated(idx, &e))?;
207 ObsTransform::Normalize { min, max }
208 }
209 other => {
210 return Err(ObsError::InvalidObsSpec {
211 reason: format!("entry {idx}: unknown transform type {other}"),
212 });
213 }
214 };
215
216 let dtype_val = r.read_u8().map_err(|e| truncated(idx, &e))?;
218 let dtype = match dtype_val {
219 0 => ObsDtype::F32,
220 other => {
221 return Err(ObsError::InvalidObsSpec {
222 reason: format!("entry {idx}: unknown dtype {other}"),
223 });
224 }
225 };
226
227 let pool_tag = r.read_u8().map_err(|e| truncated(idx, &e))?;
229 let pool = if pool_tag == POOL_NONE {
230 None
231 } else {
232 let kernel = match pool_tag {
233 POOL_MEAN => PoolKernel::Mean,
234 POOL_MAX => PoolKernel::Max,
235 POOL_MIN => PoolKernel::Min,
236 POOL_SUM => PoolKernel::Sum,
237 other => {
238 return Err(ObsError::InvalidObsSpec {
239 reason: format!("entry {idx}: unknown pool kernel {other}"),
240 });
241 }
242 };
243 let kernel_size = r.read_u32().map_err(|e| truncated(idx, &e))? as usize;
244 let stride = r.read_u32().map_err(|e| truncated(idx, &e))? as usize;
245 if kernel_size == 0 || stride == 0 {
246 return Err(ObsError::InvalidObsSpec {
247 reason: format!("entry {idx}: pool kernel_size and stride must be > 0"),
248 });
249 }
250 Some(PoolConfig {
251 kernel,
252 kernel_size,
253 stride,
254 })
255 };
256
257 Ok(ObsEntry {
258 field_id,
259 region,
260 pool,
261 transform,
262 dtype,
263 })
264}
265
266fn decode_region(tag: u8, params: &[i32], idx: usize) -> Result<ObsRegion, ObsError> {
267 match tag {
268 REGION_ALL => Ok(ObsRegion::Fixed(RegionSpec::All)),
269 REGION_DISK => {
270 if params.len() < 2 {
271 return Err(ObsError::InvalidObsSpec {
272 reason: format!("entry {idx}: Disk region needs at least 2 params"),
273 });
274 }
275 let ndim = params.len() - 1;
276 let center: SmallVec<[i32; 4]> = params[..ndim].iter().copied().collect();
277 let radius = params[ndim] as u32;
278 Ok(ObsRegion::Fixed(RegionSpec::Disk { center, radius }))
279 }
280 REGION_RECT => {
281 if params.is_empty() || !params.len().is_multiple_of(2) {
282 return Err(ObsError::InvalidObsSpec {
283 reason: format!("entry {idx}: Rect region needs even number of params"),
284 });
285 }
286 let ndim = params.len() / 2;
287 let min: SmallVec<[i32; 4]> = params[..ndim].iter().copied().collect();
288 let max: SmallVec<[i32; 4]> = params[ndim..].iter().copied().collect();
289 Ok(ObsRegion::Fixed(RegionSpec::Rect { min, max }))
290 }
291 REGION_NEIGHBOURS => {
292 if params.len() < 2 {
293 return Err(ObsError::InvalidObsSpec {
294 reason: format!("entry {idx}: Neighbours region needs at least 2 params"),
295 });
296 }
297 let ndim = params.len() - 1;
298 let center: SmallVec<[i32; 4]> = params[..ndim].iter().copied().collect();
299 let depth = params[ndim] as u32;
300 Ok(ObsRegion::Fixed(RegionSpec::Neighbours { center, depth }))
301 }
302 REGION_COORDS => {
303 if params.len() < 2 {
304 return Err(ObsError::InvalidObsSpec {
305 reason: format!("entry {idx}: Coords region needs ndim + n_coords header"),
306 });
307 }
308 let ndim = params[0] as usize;
309 let n_coords = params[1] as usize;
310 let data = ¶ms[2..];
311 if ndim == 0 || data.len() != ndim * n_coords {
312 return Err(ObsError::InvalidObsSpec {
313 reason: format!(
314 "entry {idx}: Coords expected {} values, got {}",
315 ndim * n_coords,
316 data.len()
317 ),
318 });
319 }
320 let coords: Vec<SmallVec<[i32; 4]>> = data
321 .chunks(ndim)
322 .map(|chunk| chunk.iter().copied().collect())
323 .collect();
324 Ok(ObsRegion::Fixed(RegionSpec::Coords(coords)))
325 }
326 REGION_AGENT_DISK => {
327 if params.len() != 1 {
328 return Err(ObsError::InvalidObsSpec {
329 reason: format!("entry {idx}: AgentDisk needs exactly 1 param (radius)"),
330 });
331 }
332 Ok(ObsRegion::AgentDisk {
333 radius: params[0] as u32,
334 })
335 }
336 REGION_AGENT_RECT => {
337 if params.is_empty() {
338 return Err(ObsError::InvalidObsSpec {
339 reason: format!("entry {idx}: AgentRect needs at least 1 param"),
340 });
341 }
342 let half_extent: SmallVec<[u32; 4]> = params.iter().map(|&p| p as u32).collect();
343 Ok(ObsRegion::AgentRect { half_extent })
344 }
345 other => Err(ObsError::InvalidObsSpec {
346 reason: format!("entry {idx}: unknown region type {other}"),
347 }),
348 }
349}
350
351fn truncated(idx: usize, _inner: &ObsError) -> ObsError {
352 ObsError::InvalidObsSpec {
353 reason: format!("entry {idx}: truncated data"),
354 }
355}
356
357struct Reader<'a> {
359 data: &'a [u8],
360 pos: usize,
361}
362
363impl<'a> Reader<'a> {
364 fn new(data: &'a [u8]) -> Self {
365 Self { data, pos: 0 }
366 }
367
368 fn read_bytes(&mut self, n: usize) -> Result<&'a [u8], ObsError> {
369 if self.pos + n > self.data.len() {
370 return Err(ObsError::InvalidObsSpec {
371 reason: "unexpected end of data".into(),
372 });
373 }
374 let slice = &self.data[self.pos..self.pos + n];
375 self.pos += n;
376 Ok(slice)
377 }
378
379 fn read_u8(&mut self) -> Result<u8, ObsError> {
380 Ok(self.read_bytes(1)?[0])
381 }
382
383 fn read_u16(&mut self) -> Result<u16, ObsError> {
384 let b = self.read_bytes(2)?;
385 Ok(u16::from_le_bytes([b[0], b[1]]))
386 }
387
388 fn read_u32(&mut self) -> Result<u32, ObsError> {
389 let b = self.read_bytes(4)?;
390 Ok(u32::from_le_bytes([b[0], b[1], b[2], b[3]]))
391 }
392
393 fn read_i32(&mut self) -> Result<i32, ObsError> {
394 let b = self.read_bytes(4)?;
395 Ok(i32::from_le_bytes([b[0], b[1], b[2], b[3]]))
396 }
397
398 fn read_f64(&mut self) -> Result<f64, ObsError> {
399 let b = self.read_bytes(8)?;
400 Ok(f64::from_le_bytes([
401 b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7],
402 ]))
403 }
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use smallvec::smallvec;

    fn round_trip(spec: &ObsSpec) -> ObsSpec {
        let bytes = serialize(spec);
        deserialize(&bytes).unwrap()
    }

    /// One `All` / `Identity` / `F32` entry without pooling. It serializes to
    /// an 8-byte header plus a 10-byte entry; the offsets below index into it.
    fn minimal_spec() -> ObsSpec {
        ObsSpec {
            entries: vec![ObsEntry {
                field_id: FieldId(0),
                region: ObsRegion::Fixed(RegionSpec::All),
                pool: None,
                transform: ObsTransform::Identity,
                dtype: ObsDtype::F32,
            }],
        }
    }

    // Byte offsets inside `serialize(&minimal_spec())`:
    // magic 0..4, version 4..6, entry count 6..8, field id 8..12,
    // region tag 12, param count 13..15, transform 15, dtype 16, pool 17.
    const REGION_TAG_OFFSET: usize = 12;
    const TRANSFORM_TAG_OFFSET: usize = 15;
    const DTYPE_OFFSET: usize = 16;
    const POOL_TAG_OFFSET: usize = 17;

    /// Deserializes `bytes`, expecting failure, and returns the error's debug text.
    fn error_text(bytes: &[u8]) -> String {
        format!("{:?}", deserialize(bytes).unwrap_err())
    }

    #[test]
    fn round_trip_all_identity() {
        let spec = ObsSpec {
            entries: vec![ObsEntry {
                field_id: FieldId(0),
                region: ObsRegion::Fixed(RegionSpec::All),
                pool: None,
                transform: ObsTransform::Identity,
                dtype: ObsDtype::F32,
            }],
        };
        assert_eq!(round_trip(&spec), spec);
    }

    #[test]
    fn round_trip_normalize() {
        let spec = ObsSpec {
            entries: vec![ObsEntry {
                field_id: FieldId(7),
                region: ObsRegion::Fixed(RegionSpec::All),
                pool: None,
                transform: ObsTransform::Normalize {
                    min: -1.0,
                    max: 5.0,
                },
                dtype: ObsDtype::F32,
            }],
        };
        assert_eq!(round_trip(&spec), spec);
    }

    #[test]
    fn round_trip_disk_region() {
        let spec = ObsSpec {
            entries: vec![ObsEntry {
                field_id: FieldId(0),
                region: ObsRegion::Fixed(RegionSpec::Disk {
                    center: smallvec![3, 4],
                    radius: 5,
                }),
                pool: None,
                transform: ObsTransform::Identity,
                dtype: ObsDtype::F32,
            }],
        };
        assert_eq!(round_trip(&spec), spec);
    }

    #[test]
    fn round_trip_rect_region() {
        let spec = ObsSpec {
            entries: vec![ObsEntry {
                field_id: FieldId(0),
                region: ObsRegion::Fixed(RegionSpec::Rect {
                    min: smallvec![1, 2],
                    max: smallvec![5, 8],
                }),
                pool: None,
                transform: ObsTransform::Identity,
                dtype: ObsDtype::F32,
            }],
        };
        assert_eq!(round_trip(&spec), spec);
    }

    #[test]
    fn round_trip_neighbours_region() {
        let spec = ObsSpec {
            entries: vec![ObsEntry {
                field_id: FieldId(0),
                region: ObsRegion::Fixed(RegionSpec::Neighbours {
                    center: smallvec![3, 4],
                    depth: 2,
                }),
                pool: None,
                transform: ObsTransform::Identity,
                dtype: ObsDtype::F32,
            }],
        };
        assert_eq!(round_trip(&spec), spec);
    }

    #[test]
    fn round_trip_coords_region() {
        let spec = ObsSpec {
            entries: vec![ObsEntry {
                field_id: FieldId(0),
                region: ObsRegion::Fixed(RegionSpec::Coords(vec![
                    smallvec![0, 0],
                    smallvec![1, 2],
                    smallvec![3, 4],
                ])),
                pool: None,
                transform: ObsTransform::Identity,
                dtype: ObsDtype::F32,
            }],
        };
        assert_eq!(round_trip(&spec), spec);
    }

    #[test]
    fn round_trip_agent_disk() {
        let spec = ObsSpec {
            entries: vec![ObsEntry {
                field_id: FieldId(0),
                region: ObsRegion::AgentDisk { radius: 3 },
                pool: None,
                transform: ObsTransform::Identity,
                dtype: ObsDtype::F32,
            }],
        };
        assert_eq!(round_trip(&spec), spec);
    }

    #[test]
    fn round_trip_agent_rect() {
        let spec = ObsSpec {
            entries: vec![ObsEntry {
                field_id: FieldId(0),
                region: ObsRegion::AgentRect {
                    half_extent: smallvec![3, 4],
                },
                pool: None,
                transform: ObsTransform::Identity,
                dtype: ObsDtype::F32,
            }],
        };
        assert_eq!(round_trip(&spec), spec);
    }

    #[test]
    fn round_trip_with_pool() {
        let spec = ObsSpec {
            entries: vec![ObsEntry {
                field_id: FieldId(2),
                region: ObsRegion::AgentDisk { radius: 5 },
                pool: Some(PoolConfig {
                    kernel: PoolKernel::Mean,
                    kernel_size: 2,
                    stride: 2,
                }),
                transform: ObsTransform::Normalize {
                    min: 0.0,
                    max: 100.0,
                },
                dtype: ObsDtype::F32,
            }],
        };
        assert_eq!(round_trip(&spec), spec);
    }

    #[test]
    fn round_trip_multiple_entries() {
        let spec = ObsSpec {
            entries: vec![
                ObsEntry {
                    field_id: FieldId(0),
                    region: ObsRegion::Fixed(RegionSpec::All),
                    pool: None,
                    transform: ObsTransform::Identity,
                    dtype: ObsDtype::F32,
                },
                ObsEntry {
                    field_id: FieldId(1),
                    region: ObsRegion::AgentDisk { radius: 3 },
                    pool: Some(PoolConfig {
                        kernel: PoolKernel::Max,
                        kernel_size: 3,
                        stride: 1,
                    }),
                    transform: ObsTransform::Normalize {
                        min: -5.0,
                        max: 5.0,
                    },
                    dtype: ObsDtype::F32,
                },
            ],
        };
        assert_eq!(round_trip(&spec), spec);
    }

    #[test]
    fn round_trip_all_pool_kernels() {
        for kernel in [
            PoolKernel::Mean,
            PoolKernel::Max,
            PoolKernel::Min,
            PoolKernel::Sum,
        ] {
            let spec = ObsSpec {
                entries: vec![ObsEntry {
                    field_id: FieldId(0),
                    region: ObsRegion::Fixed(RegionSpec::All),
                    pool: Some(PoolConfig {
                        kernel,
                        kernel_size: 2,
                        stride: 1,
                    }),
                    transform: ObsTransform::Identity,
                    dtype: ObsDtype::F32,
                }],
            };
            assert_eq!(round_trip(&spec), spec, "failed for kernel {kernel:?}");
        }
    }

    #[test]
    fn version_rejection() {
        let spec = ObsSpec {
            entries: vec![ObsEntry {
                field_id: FieldId(0),
                region: ObsRegion::Fixed(RegionSpec::All),
                pool: None,
                transform: ObsTransform::Identity,
                dtype: ObsDtype::F32,
            }],
        };
        let mut bytes = serialize(&spec);
        // Overwrite the little-endian u16 version with 99.
        bytes[4] = 99;
        bytes[5] = 0;
        let err = deserialize(&bytes).unwrap_err();
        assert!(format!("{err:?}").contains("unsupported version"));
    }

    #[test]
    fn truncated_bytes_error() {
        assert!(deserialize(&[]).is_err());
        assert!(deserialize(&[0, 0]).is_err());
        assert!(deserialize(b"MOBS").is_err());
    }

    #[test]
    fn invalid_magic_error() {
        let bytes = b"BAD!\x01\x00\x00\x00";
        let err = deserialize(bytes).unwrap_err();
        assert!(format!("{err:?}").contains("invalid magic"));
    }

    #[test]
    fn minimal_spec_layout() {
        assert_eq!(serialize(&minimal_spec()).len(), 18);
    }

    #[test]
    fn unknown_region_tag_error() {
        let mut bytes = serialize(&minimal_spec());
        bytes[REGION_TAG_OFFSET] = 200;
        assert!(error_text(&bytes).contains("unknown region type"));
    }

    #[test]
    fn unknown_transform_error() {
        let mut bytes = serialize(&minimal_spec());
        bytes[TRANSFORM_TAG_OFFSET] = 9;
        assert!(error_text(&bytes).contains("unknown transform type"));
    }

    #[test]
    fn unknown_dtype_error() {
        let mut bytes = serialize(&minimal_spec());
        bytes[DTYPE_OFFSET] = 7;
        assert!(error_text(&bytes).contains("unknown dtype"));
    }

    #[test]
    fn unknown_pool_kernel_error() {
        let mut bytes = serialize(&minimal_spec());
        bytes[POOL_TAG_OFFSET] = 9;
        assert!(error_text(&bytes).contains("unknown pool kernel"));
    }

    #[test]
    fn zero_pool_stride_error() {
        let mut spec = minimal_spec();
        spec.entries[0].pool = Some(PoolConfig {
            kernel: PoolKernel::Sum,
            kernel_size: 2,
            stride: 1,
        });
        let mut bytes = serialize(&spec);
        // Pool tag, then u32 kernel_size, then u32 stride.
        let stride_at = POOL_TAG_OFFSET + 1 + 4;
        bytes[stride_at..stride_at + 4].copy_from_slice(&0u32.to_le_bytes());
        assert!(error_text(&bytes).contains("must be > 0"));
    }

    #[test]
    fn truncated_entry_error() {
        let mut bytes = serialize(&minimal_spec());
        bytes.pop();
        assert!(error_text(&bytes).contains("truncated data"));
    }
}