1use indexmap::IndexMap;
11
12use murk_core::error::ObsError;
13use murk_core::{Coord, FieldId, SnapshotAccess, TickId, WorldGenerationId};
14use murk_space::Space;
15
16use crate::geometry::GridGeometry;
17use crate::metadata::ObsMetadata;
18use crate::pool::pool_2d_into;
19use crate::spec::{ObsDtype, ObsRegion, ObsSpec, ObsTransform, PoolConfig};
20
/// Valid-ratio threshold for low-coverage warnings on fixed observation regions.
/// Regions whose valid ratio falls below this value are considered sparse.
const COVERAGE_WARN_THRESHOLD: f64 = 0.50;

/// Valid-ratio threshold below which compiling a fixed observation region is
/// rejected with `ObsError::InvalidComposition`.
const COVERAGE_ERROR_THRESHOLD: f64 = 0.35;
26
/// Output of `ObsPlan::compile` / `ObsPlan::compile_bound`: the compiled plan
/// together with the buffer sizes and per-entry shapes a caller needs to
/// allocate and interpret observation tensors.
#[derive(Debug)]
pub struct ObsPlanResult {
    /// The compiled plan.
    pub plan: ObsPlan,
    /// Number of `f32` elements one observation writes (equal to `plan.output_len()`).
    pub output_len: usize,
    /// Output shape of each spec entry, in spec order. Fixed entries report
    /// their region's bounding shape; agent entries report `2 * h + 1` per
    /// axis, or the post-pool `[out_h, out_w]` when pooling is configured.
    pub entry_shapes: Vec<Vec<usize>>,
    /// Number of mask bytes one observation writes (equal to `plan.mask_len()`).
    pub mask_len: usize,
}
39
/// A compiled observation plan.
///
/// A plan built only from `Fixed` regions uses the Simple strategy and is run
/// with `execute` / `execute_batch`. A plan containing any agent-relative
/// entry uses the Standard strategy and is run with `execute_agents`.
#[derive(Debug)]
pub struct ObsPlan {
    /// Execution strategy together with its precomputed per-entry data.
    strategy: PlanStrategy,
    /// Number of `f32` elements written per observation (sum of entry element counts).
    output_len: usize,
    /// Number of mask bytes written per observation (sum of entry element counts).
    mask_len: usize,
    /// World generation the plan is bound to (set by `compile_bound`). When
    /// `Some`, execution rejects snapshots from any other generation with
    /// `ObsError::PlanInvalidated`.
    compiled_generation: Option<WorldGenerationId>,
}
57
/// One precomputed copy for a fixed-region entry: read
/// `field_data[field_data_idx]`, transform it, and store it at position
/// `tensor_idx` of the entry's output slice.
#[derive(Debug, Clone)]
struct GatherOp {
    /// Index into the field's data: the coordinate's position in the space's
    /// canonical ordering.
    field_data_idx: usize,
    /// Position within the entry's output slice.
    tensor_idx: usize,
}
69
/// A compiled `Fixed`-region entry: where it lives in the per-observation
/// buffers and the precomputed operations that fill it.
#[derive(Debug)]
struct CompiledEntry {
    /// Field read from the snapshot.
    field_id: FieldId,
    /// Transform applied to every gathered value.
    transform: ObsTransform,
    // Kept from the spec; not read by any execution path in this file.
    #[allow(dead_code)]
    dtype: ObsDtype,
    /// Start of this entry's slice in the per-observation output buffer.
    output_offset: usize,
    /// Start of this entry's slice in the per-observation mask buffer.
    mask_offset: usize,
    /// Element count of the region's bounding shape (length of both slices).
    element_count: usize,
    /// Field-to-tensor copies, one per coordinate reported by the region plan.
    gather_ops: Vec<GatherOp>,
    /// Validity mask copied verbatim into the mask slice on every execution.
    valid_mask: Vec<u8>,
    /// Number of bytes in `valid_mask` equal to 1 (used for coverage metadata).
    valid_count: usize,
    // Region valid ratio recorded at compile time; currently not read.
    #[allow(dead_code)]
    valid_ratio: f64,
}
93
/// One cell of an agent-relative template, precomputed at compile time.
#[derive(Debug, Clone)]
struct TemplateOp {
    /// Offset of the cell from the agent's center, one component per axis.
    relative: Coord,
    /// Row-major position of the cell in the pre-pool template tensor.
    tensor_idx: usize,
    /// Flat offset added to the center's canonical rank on the fast path
    /// (sum of `relative[d] * coord_strides[d]`; 0 when there is no grid geometry).
    stride_offset: isize,
    /// Whether the cell lies inside the disk (`graph_distance <= radius`).
    /// Always `true` for rectangles or when no grid geometry is available.
    /// Only cells with `in_disk == true` are kept as active ops.
    in_disk: bool,
}
112
/// A compiled agent-relative entry (`AgentDisk` or `AgentRect`), executed once
/// per agent center.
#[derive(Debug)]
struct AgentCompiledEntry {
    /// Field read from the snapshot.
    field_id: FieldId,
    /// Optional 2D pooling applied to the gathered template before output.
    pool: Option<PoolConfig>,
    /// Transform applied to every output value (after pooling, if any).
    transform: ObsTransform,
    // Kept from the spec; not read by any execution path in this file.
    #[allow(dead_code)]
    dtype: ObsDtype,
    /// Start of this entry's slice in the per-agent output buffer.
    output_offset: usize,
    /// Start of this entry's slice in the per-agent mask buffer.
    mask_offset: usize,
    /// Number of output elements (post-pool when pooling is configured).
    element_count: usize,
    /// Number of elements in the template before pooling (product of `pre_pool_shape`).
    pre_pool_element_count: usize,
    /// Template shape before pooling: `2 * half_extent + 1` per axis.
    pre_pool_shape: Vec<usize>,
    /// Template cells that are gathered (those with `in_disk == true`).
    active_ops: Vec<TemplateOp>,
    /// Largest half-extent; used to test whether an agent is far enough from
    /// the edges to take the stride-offset fast path.
    radius: u32,
}
140
/// Data for a Standard plan: a mix of fixed and agent-relative entries.
#[derive(Debug)]
struct StandardPlanData {
    /// Fixed-region entries; gathered once per `execute_agents` call and
    /// copied into every agent's output.
    fixed_entries: Vec<CompiledEntry>,
    /// Agent-relative entries, gathered per agent center.
    agent_entries: Vec<AgentCompiledEntry>,
    /// Grid geometry from `GridGeometry::from_space`, if the space provides one.
    /// When `None`, execution always resolves coordinates through the `Space`.
    geometry: Option<GridGeometry>,
}
151
/// Data for a Simple plan: fixed-region entries only.
#[derive(Debug)]
struct SimplePlanData {
    /// Compiled entries, laid out back to back in the output buffer.
    entries: Vec<CompiledEntry>,
    /// Sum of `valid_count` over all entries (numerator of reported coverage).
    total_valid: usize,
    /// Sum of `element_count` over all entries (denominator of reported coverage).
    total_elements: usize,
}
159
/// Which execution strategy an `ObsPlan` uses.
#[derive(Debug)]
enum PlanStrategy {
    /// Fixed regions only; run with `execute` / `execute_batch`.
    Simple(SimplePlanData),
    /// Contains agent-relative entries; run with `execute_agents`.
    Standard(StandardPlanData),
}
168
169impl ObsPlan {
170 pub fn compile(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
177 if spec.entries.is_empty() {
178 return Err(ObsError::InvalidObsSpec {
179 reason: "ObsSpec has no entries".into(),
180 });
181 }
182
183 for (i, entry) in spec.entries.iter().enumerate() {
185 if let ObsTransform::Normalize { min, max } = &entry.transform {
186 if !min.is_finite() || !max.is_finite() {
187 return Err(ObsError::InvalidObsSpec {
188 reason: format!(
189 "entry {i}: Normalize min/max must be finite, got min={min}, max={max}"
190 ),
191 });
192 }
193 if min > max {
194 return Err(ObsError::InvalidObsSpec {
195 reason: format!("entry {i}: Normalize min ({min}) must be <= max ({max})"),
196 });
197 }
198 }
199 }
200
201 let has_agent = spec.entries.iter().any(|e| {
202 matches!(
203 e.region,
204 ObsRegion::AgentDisk { .. } | ObsRegion::AgentRect { .. }
205 )
206 });
207
208 if has_agent {
209 Self::compile_standard(spec, space)
210 } else {
211 Self::compile_simple(spec, space)
212 }
213 }
214
215 fn compile_simple(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
217 let canonical = space.canonical_ordering();
218 let coord_to_field_idx: IndexMap<Coord, usize> = canonical
219 .into_iter()
220 .enumerate()
221 .map(|(idx, coord)| (coord, idx))
222 .collect();
223
224 let mut entries = Vec::with_capacity(spec.entries.len());
225 let mut total_valid = 0usize;
226 let mut total_elements = 0usize;
227 let mut output_offset = 0usize;
228 let mut mask_offset = 0usize;
229 let mut entry_shapes = Vec::with_capacity(spec.entries.len());
230
231 for (i, entry) in spec.entries.iter().enumerate() {
232 let fixed_region = match &entry.region {
233 ObsRegion::Fixed(spec) => spec,
234 ObsRegion::AgentDisk { .. } | ObsRegion::AgentRect { .. } => {
235 return Err(ObsError::InvalidObsSpec {
236 reason: format!("entry {i}: agent-relative region in Simple plan"),
237 });
238 }
239 };
240 if entry.pool.is_some() {
241 return Err(ObsError::InvalidObsSpec {
242 reason: format!(
243 "entry {i}: pooling requires a Standard plan (use agent-relative region)"
244 ),
245 });
246 }
247
248 let mut region_plan =
249 space
250 .compile_region(fixed_region)
251 .map_err(|e| ObsError::InvalidObsSpec {
252 reason: format!("entry {i}: region compile failed: {e}"),
253 })?;
254
255 let ratio = region_plan.valid_ratio();
256 if ratio < COVERAGE_ERROR_THRESHOLD {
257 return Err(ObsError::InvalidComposition {
258 reason: format!(
259 "entry {i}: valid_ratio {ratio:.3} < {COVERAGE_ERROR_THRESHOLD}"
260 ),
261 });
262 }
263 if ratio < COVERAGE_WARN_THRESHOLD {
264 eprintln!(
265 "murk-obs: warning: entry {i} valid_ratio {ratio:.3} < {COVERAGE_WARN_THRESHOLD}"
266 );
267 }
268
269 let mut gather_ops = Vec::with_capacity(region_plan.coords().len());
270 for (coord_idx, coord) in region_plan.coords().iter().enumerate() {
271 let field_data_idx =
272 *coord_to_field_idx
273 .get(coord)
274 .ok_or_else(|| ObsError::InvalidObsSpec {
275 reason: format!("entry {i}: coord {coord:?} not in canonical ordering"),
276 })?;
277 let tensor_idx = region_plan.tensor_indices()[coord_idx];
278 gather_ops.push(GatherOp {
279 field_data_idx,
280 tensor_idx,
281 });
282 }
283
284 let element_count = region_plan.bounding_shape().total_elements();
285 let shape = match region_plan.bounding_shape() {
286 murk_space::BoundingShape::Rect(dims) => dims.clone(),
287 };
288 entry_shapes.push(shape);
289
290 let valid_mask = region_plan.take_valid_mask();
291 let valid_count = valid_mask.iter().filter(|&&v| v == 1).count();
292 entries.push(CompiledEntry {
293 field_id: entry.field_id,
294 transform: entry.transform.clone(),
295 dtype: entry.dtype,
296 output_offset,
297 mask_offset,
298 element_count,
299 gather_ops,
300 valid_mask,
301 valid_count,
302 valid_ratio: ratio,
303 });
304 total_valid += valid_count;
305 total_elements += element_count;
306
307 output_offset += element_count;
308 mask_offset += element_count;
309 }
310
311 let plan = ObsPlan {
312 strategy: PlanStrategy::Simple(SimplePlanData {
313 entries,
314 total_valid,
315 total_elements,
316 }),
317 output_len: output_offset,
318 mask_len: mask_offset,
319 compiled_generation: None,
320 };
321
322 Ok(ObsPlanResult {
323 output_len: plan.output_len,
324 mask_len: plan.mask_len,
325 entry_shapes,
326 plan,
327 })
328 }
329
330 fn compile_standard(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
335 let canonical = space.canonical_ordering();
336 let coord_to_field_idx: IndexMap<Coord, usize> = canonical
337 .into_iter()
338 .enumerate()
339 .map(|(idx, coord)| (coord, idx))
340 .collect();
341
342 let geometry = GridGeometry::from_space(space);
343 let ndim = space.ndim();
344
345 let mut fixed_entries = Vec::new();
346 let mut agent_entries = Vec::new();
347 let mut output_offset = 0usize;
348 let mut mask_offset = 0usize;
349 let mut entry_shapes = Vec::new();
350
351 for (i, entry) in spec.entries.iter().enumerate() {
352 match &entry.region {
353 ObsRegion::Fixed(region_spec) => {
354 if entry.pool.is_some() {
355 return Err(ObsError::InvalidObsSpec {
356 reason: format!("entry {i}: pooling on Fixed regions not supported"),
357 });
358 }
359
360 let mut region_plan = space.compile_region(region_spec).map_err(|e| {
361 ObsError::InvalidObsSpec {
362 reason: format!("entry {i}: region compile failed: {e}"),
363 }
364 })?;
365
366 let ratio = region_plan.valid_ratio();
367 if ratio < COVERAGE_ERROR_THRESHOLD {
368 return Err(ObsError::InvalidComposition {
369 reason: format!(
370 "entry {i}: valid_ratio {ratio:.3} < {COVERAGE_ERROR_THRESHOLD}"
371 ),
372 });
373 }
374
375 let mut gather_ops = Vec::with_capacity(region_plan.coords().len());
376 for (coord_idx, coord) in region_plan.coords().iter().enumerate() {
377 let field_data_idx = *coord_to_field_idx.get(coord).ok_or_else(|| {
378 ObsError::InvalidObsSpec {
379 reason: format!(
380 "entry {i}: coord {coord:?} not in canonical ordering"
381 ),
382 }
383 })?;
384 let tensor_idx = region_plan.tensor_indices()[coord_idx];
385 gather_ops.push(GatherOp {
386 field_data_idx,
387 tensor_idx,
388 });
389 }
390
391 let element_count = region_plan.bounding_shape().total_elements();
392 let shape = match region_plan.bounding_shape() {
393 murk_space::BoundingShape::Rect(dims) => dims.clone(),
394 };
395 entry_shapes.push(shape);
396
397 let valid_mask = region_plan.take_valid_mask();
398 let valid_count = valid_mask.iter().filter(|&&v| v == 1).count();
399 fixed_entries.push(CompiledEntry {
400 field_id: entry.field_id,
401 transform: entry.transform.clone(),
402 dtype: entry.dtype,
403 output_offset,
404 mask_offset,
405 element_count,
406 gather_ops,
407 valid_mask,
408 valid_count,
409 valid_ratio: ratio,
410 });
411
412 output_offset += element_count;
413 mask_offset += element_count;
414 }
415
416 ObsRegion::AgentDisk { radius } => {
417 let half_ext: smallvec::SmallVec<[u32; 4]> =
418 (0..ndim).map(|_| *radius).collect();
419 let (ae, shape) = Self::compile_agent_entry(
420 i,
421 entry,
422 &half_ext,
423 *radius,
424 &geometry,
425 Some(*radius),
426 output_offset,
427 mask_offset,
428 )?;
429 entry_shapes.push(shape);
430 output_offset += ae.element_count;
431 mask_offset += ae.element_count;
432 agent_entries.push(ae);
433 }
434
435 ObsRegion::AgentRect { half_extent } => {
436 if half_extent.len() != ndim {
437 return Err(ObsError::InvalidObsSpec {
438 reason: format!(
439 "entry {i}: AgentRect half_extent has {} dims, but space requires {ndim}",
440 half_extent.len()
441 ),
442 });
443 }
444 let radius = *half_extent.iter().max().unwrap_or(&0);
445 let (ae, shape) = Self::compile_agent_entry(
446 i,
447 entry,
448 half_extent,
449 radius,
450 &geometry,
451 None,
452 output_offset,
453 mask_offset,
454 )?;
455 entry_shapes.push(shape);
456 output_offset += ae.element_count;
457 mask_offset += ae.element_count;
458 agent_entries.push(ae);
459 }
460 }
461 }
462
463 let plan = ObsPlan {
464 strategy: PlanStrategy::Standard(StandardPlanData {
465 fixed_entries,
466 agent_entries,
467 geometry,
468 }),
469 output_len: output_offset,
470 mask_len: mask_offset,
471 compiled_generation: None,
472 };
473
474 Ok(ObsPlanResult {
475 output_len: plan.output_len,
476 mask_len: plan.mask_len,
477 entry_shapes,
478 plan,
479 })
480 }
481
482 #[allow(clippy::too_many_arguments)]
487 fn compile_agent_entry(
488 entry_idx: usize,
489 entry: &crate::spec::ObsEntry,
490 half_extent: &[u32],
491 radius: u32,
492 geometry: &Option<GridGeometry>,
493 disk_radius: Option<u32>,
494 output_offset: usize,
495 mask_offset: usize,
496 ) -> Result<(AgentCompiledEntry, Vec<usize>), ObsError> {
497 let pre_pool_shape: Vec<usize> =
498 half_extent.iter().map(|&he| 2 * he as usize + 1).collect();
499 let pre_pool_element_count: usize = pre_pool_shape.iter().product();
500
501 let template_ops = generate_template_ops(half_extent, geometry, disk_radius)?;
502 let active_ops = template_ops
503 .iter()
504 .filter(|op| op.in_disk)
505 .cloned()
506 .collect();
507
508 let (element_count, output_shape) = if let Some(pool) = &entry.pool {
509 if pre_pool_shape.len() != 2 {
510 return Err(ObsError::InvalidObsSpec {
511 reason: format!(
512 "entry {entry_idx}: pooling requires 2D region, got {}D",
513 pre_pool_shape.len()
514 ),
515 });
516 }
517 let h = pre_pool_shape[0];
518 let w = pre_pool_shape[1];
519 let ks = pool.kernel_size;
520 let stride = pool.stride;
521 if ks == 0 || stride == 0 {
522 return Err(ObsError::InvalidObsSpec {
523 reason: format!("entry {entry_idx}: pool kernel_size and stride must be > 0"),
524 });
525 }
526 let out_h = if h >= ks { (h - ks) / stride + 1 } else { 0 };
527 let out_w = if w >= ks { (w - ks) / stride + 1 } else { 0 };
528 if out_h == 0 || out_w == 0 {
529 return Err(ObsError::InvalidObsSpec {
530 reason: format!(
531 "entry {entry_idx}: pool produces empty output \
532 (region [{h},{w}], kernel_size {ks}, stride {stride})"
533 ),
534 });
535 }
536 (out_h * out_w, vec![out_h, out_w])
537 } else {
538 (pre_pool_element_count, pre_pool_shape.clone())
539 };
540
541 Ok((
542 AgentCompiledEntry {
543 field_id: entry.field_id,
544 pool: entry.pool.clone(),
545 transform: entry.transform.clone(),
546 dtype: entry.dtype,
547 output_offset,
548 mask_offset,
549 element_count,
550 pre_pool_element_count,
551 pre_pool_shape,
552 active_ops,
553 radius,
554 },
555 output_shape,
556 ))
557 }
558
559 pub fn compile_bound(
564 spec: &ObsSpec,
565 space: &dyn Space,
566 generation: WorldGenerationId,
567 ) -> Result<ObsPlanResult, ObsError> {
568 let mut result = Self::compile(spec, space)?;
569 result.plan.compiled_generation = Some(generation);
570 Ok(result)
571 }
572
573 pub fn output_len(&self) -> usize {
575 self.output_len
576 }
577
578 pub fn mask_len(&self) -> usize {
580 self.mask_len
581 }
582
583 pub fn compiled_generation(&self) -> Option<WorldGenerationId> {
585 self.compiled_generation
586 }
587
588 pub fn execute(
607 &self,
608 snapshot: &dyn SnapshotAccess,
609 engine_tick: Option<TickId>,
610 output: &mut [f32],
611 mask: &mut [u8],
612 ) -> Result<ObsMetadata, ObsError> {
613 let simple = match &self.strategy {
614 PlanStrategy::Simple(data) => data,
615 PlanStrategy::Standard(_) => {
616 return Err(ObsError::ExecutionFailed {
617 reason: "Standard plan requires execute_agents(), not execute()".into(),
618 });
619 }
620 };
621
622 if output.len() < self.output_len {
623 return Err(ObsError::ExecutionFailed {
624 reason: format!(
625 "output buffer too small: {} < {}",
626 output.len(),
627 self.output_len
628 ),
629 });
630 }
631 if mask.len() < self.mask_len {
632 return Err(ObsError::ExecutionFailed {
633 reason: format!("mask buffer too small: {} < {}", mask.len(), self.mask_len),
634 });
635 }
636
637 if let Some(compiled_gen) = self.compiled_generation {
639 let snapshot_gen = snapshot.world_generation_id();
640 if compiled_gen != snapshot_gen {
641 return Err(ObsError::PlanInvalidated {
642 reason: format!(
643 "plan compiled for generation {}, snapshot is generation {}",
644 compiled_gen.0, snapshot_gen.0
645 ),
646 });
647 }
648 }
649
650 Self::execute_simple_entries(&simple.entries, snapshot, output, mask)?;
651
652 let coverage = if simple.total_elements == 0 {
653 0.0
654 } else {
655 simple.total_valid as f64 / simple.total_elements as f64
656 };
657
658 let age_ticks = match engine_tick {
659 Some(tick) => tick.0.saturating_sub(snapshot.tick_id().0),
660 None => 0,
661 };
662
663 Ok(ObsMetadata {
664 tick_id: snapshot.tick_id(),
665 age_ticks,
666 coverage,
667 world_generation_id: snapshot.world_generation_id(),
668 parameter_version: snapshot.parameter_version(),
669 })
670 }
671
672 pub fn execute_batch(
684 &self,
685 snapshots: &[&dyn SnapshotAccess],
686 engine_tick: Option<TickId>,
687 output: &mut [f32],
688 mask: &mut [u8],
689 ) -> Result<Vec<ObsMetadata>, ObsError> {
690 let simple = match &self.strategy {
691 PlanStrategy::Simple(data) => data,
692 PlanStrategy::Standard(_) => {
693 return Err(ObsError::ExecutionFailed {
694 reason: "Standard plan requires execute_agents(), not execute_batch()".into(),
695 });
696 }
697 };
698
699 let batch_size = snapshots.len();
700 let expected_out = batch_size * self.output_len;
701 let expected_mask = batch_size * self.mask_len;
702
703 if output.len() < expected_out {
704 return Err(ObsError::ExecutionFailed {
705 reason: format!(
706 "batch output buffer too small: {} < {}",
707 output.len(),
708 expected_out
709 ),
710 });
711 }
712 if mask.len() < expected_mask {
713 return Err(ObsError::ExecutionFailed {
714 reason: format!(
715 "batch mask buffer too small: {} < {}",
716 mask.len(),
717 expected_mask
718 ),
719 });
720 }
721
722 let coverage = if simple.total_elements == 0 {
723 0.0
724 } else {
725 simple.total_valid as f64 / simple.total_elements as f64
726 };
727
728 let mut metadata = Vec::with_capacity(batch_size);
729 for (i, snap) in snapshots.iter().enumerate() {
730 if let Some(compiled_gen) = self.compiled_generation {
731 let snapshot_gen = snap.world_generation_id();
732 if compiled_gen != snapshot_gen {
733 return Err(ObsError::PlanInvalidated {
734 reason: format!(
735 "plan compiled for generation {}, snapshot is generation {}",
736 compiled_gen.0, snapshot_gen.0
737 ),
738 });
739 }
740 }
741
742 let out_start = i * self.output_len;
743 let mask_start = i * self.mask_len;
744 let out_slice = &mut output[out_start..out_start + self.output_len];
745 let mask_slice = &mut mask[mask_start..mask_start + self.mask_len];
746 Self::execute_simple_entries(&simple.entries, *snap, out_slice, mask_slice)?;
747
748 let age_ticks = match engine_tick {
749 Some(tick) => tick.0.saturating_sub(snap.tick_id().0),
750 None => 0,
751 };
752 metadata.push(ObsMetadata {
753 tick_id: snap.tick_id(),
754 age_ticks,
755 coverage,
756 world_generation_id: snap.world_generation_id(),
757 parameter_version: snap.parameter_version(),
758 });
759 }
760 Ok(metadata)
761 }
762
763 pub fn execute_agents(
774 &self,
775 snapshot: &dyn SnapshotAccess,
776 space: &dyn Space,
777 agent_centers: &[Coord],
778 engine_tick: Option<TickId>,
779 output: &mut [f32],
780 mask: &mut [u8],
781 ) -> Result<Vec<ObsMetadata>, ObsError> {
782 let standard = match &self.strategy {
783 PlanStrategy::Standard(data) => data,
784 PlanStrategy::Simple(_) => {
785 return Err(ObsError::ExecutionFailed {
786 reason: "execute_agents requires a Standard plan \
787 (spec must contain agent-relative entries)"
788 .into(),
789 });
790 }
791 };
792
793 let n_agents = agent_centers.len();
794 let expected_out = n_agents * self.output_len;
795 let expected_mask = n_agents * self.mask_len;
796
797 if output.len() < expected_out {
798 return Err(ObsError::ExecutionFailed {
799 reason: format!(
800 "output buffer too small: {} < {}",
801 output.len(),
802 expected_out
803 ),
804 });
805 }
806 if mask.len() < expected_mask {
807 return Err(ObsError::ExecutionFailed {
808 reason: format!("mask buffer too small: {} < {}", mask.len(), expected_mask),
809 });
810 }
811
812 let expected_ndim = space.ndim();
814 for (i, center) in agent_centers.iter().enumerate() {
815 if center.len() != expected_ndim {
816 return Err(ObsError::ExecutionFailed {
817 reason: format!(
818 "agent_centers[{i}] has {} dimensions, but space requires {expected_ndim}",
819 center.len()
820 ),
821 });
822 }
823 }
824
825 if let Some(compiled_gen) = self.compiled_generation {
827 let snapshot_gen = snapshot.world_generation_id();
828 if compiled_gen != snapshot_gen {
829 return Err(ObsError::PlanInvalidated {
830 reason: format!(
831 "plan compiled for generation {}, snapshot is generation {}",
832 compiled_gen.0, snapshot_gen.0
833 ),
834 });
835 }
836 }
837
838 let mut fixed_field_data = Vec::with_capacity(standard.fixed_entries.len());
842 for entry in &standard.fixed_entries {
843 let data =
844 snapshot
845 .read_field(entry.field_id)
846 .ok_or_else(|| ObsError::ExecutionFailed {
847 reason: format!("field {:?} not in snapshot", entry.field_id),
848 })?;
849 fixed_field_data.push(data);
850 }
851 let mut agent_field_data = Vec::with_capacity(standard.agent_entries.len());
852 for entry in &standard.agent_entries {
853 let data =
854 snapshot
855 .read_field(entry.field_id)
856 .ok_or_else(|| ObsError::ExecutionFailed {
857 reason: format!("field {:?} not in snapshot", entry.field_id),
858 })?;
859 agent_field_data.push(data);
860 }
861
862 let has_fixed = !standard.fixed_entries.is_empty();
866 let mut fixed_out_scratch = if has_fixed {
867 vec![0.0f32; self.output_len]
868 } else {
869 Vec::new()
870 };
871 let mut fixed_mask_scratch = if has_fixed {
872 vec![0u8; self.mask_len]
873 } else {
874 Vec::new()
875 };
876 let mut fixed_valid = 0usize;
877 let mut fixed_elements = 0usize;
878
879 for (entry, field_data) in standard
880 .fixed_entries
881 .iter()
882 .zip(fixed_field_data.iter().copied())
883 {
884 let out_slice = &mut fixed_out_scratch
885 [entry.output_offset..entry.output_offset + entry.element_count];
886 let mask_slice =
887 &mut fixed_mask_scratch[entry.mask_offset..entry.mask_offset + entry.element_count];
888
889 mask_slice.copy_from_slice(&entry.valid_mask);
890 for op in &entry.gather_ops {
891 let raw = *field_data.get(op.field_data_idx).ok_or_else(|| {
892 ObsError::ExecutionFailed {
893 reason: format!(
894 "field {:?} has {} elements but gather requires index {}",
895 entry.field_id,
896 field_data.len(),
897 op.field_data_idx,
898 ),
899 }
900 })?;
901 out_slice[op.tensor_idx] = apply_transform(raw, &entry.transform);
902 }
903
904 fixed_valid += entry.valid_count;
905 fixed_elements += entry.element_count;
906 }
907
908 let max_pool_scratch = standard
912 .agent_entries
913 .iter()
914 .filter(|e| e.pool.is_some())
915 .map(|e| e.pre_pool_element_count)
916 .max()
917 .unwrap_or(0);
918 let max_pool_output = standard
919 .agent_entries
920 .iter()
921 .filter(|e| e.pool.is_some())
922 .map(|e| e.element_count)
923 .max()
924 .unwrap_or(0);
925 let mut pool_scratch = vec![0.0f32; max_pool_scratch];
926 let mut pool_scratch_mask = vec![0u8; max_pool_scratch];
927 let mut pooled_scratch = vec![0.0f32; max_pool_output];
928 let mut pooled_scratch_mask = vec![0u8; max_pool_output];
929
930 let mut metadata = Vec::with_capacity(n_agents);
931
932 for (agent_i, center) in agent_centers.iter().enumerate() {
933 let out_start = agent_i * self.output_len;
934 let mask_start = agent_i * self.mask_len;
935 let agent_output = &mut output[out_start..out_start + self.output_len];
936 let agent_mask = &mut mask[mask_start..mask_start + self.mask_len];
937
938 agent_output.fill(0.0);
941 agent_mask.fill(0);
942 if has_fixed {
943 for entry in &standard.fixed_entries {
944 let out_range = entry.output_offset..entry.output_offset + entry.element_count;
945 let mask_range = entry.mask_offset..entry.mask_offset + entry.element_count;
946 agent_output[out_range.clone()].copy_from_slice(&fixed_out_scratch[out_range]);
947 agent_mask[mask_range.clone()].copy_from_slice(&fixed_mask_scratch[mask_range]);
948 }
949 }
950
951 let mut total_valid = fixed_valid;
952 let mut total_elements = fixed_elements;
953
954 for (entry, field_data) in standard
956 .agent_entries
957 .iter()
958 .zip(agent_field_data.iter().copied())
959 {
960 let use_fast_path = standard
964 .geometry
965 .as_ref()
966 .map(|geo| !geo.all_wrap && geo.is_interior(center, entry.radius))
967 .unwrap_or(false);
968
969 if entry.pool.is_some() {
971 pool_scratch[..entry.pre_pool_element_count].fill(0.0);
972 pool_scratch_mask[..entry.pre_pool_element_count].fill(0);
973 }
974
975 let valid = execute_agent_entry(
976 entry,
977 center,
978 field_data,
979 &standard.geometry,
980 space,
981 use_fast_path,
982 agent_output,
983 agent_mask,
984 &mut pool_scratch,
985 &mut pool_scratch_mask,
986 &mut pooled_scratch,
987 &mut pooled_scratch_mask,
988 )?;
989
990 total_valid += valid;
991 total_elements += entry.element_count;
992 }
993
994 let coverage = if total_elements == 0 {
995 0.0
996 } else {
997 total_valid as f64 / total_elements as f64
998 };
999
1000 let age_ticks = match engine_tick {
1001 Some(tick) => tick.0.saturating_sub(snapshot.tick_id().0),
1002 None => 0,
1003 };
1004
1005 metadata.push(ObsMetadata {
1006 tick_id: snapshot.tick_id(),
1007 age_ticks,
1008 coverage,
1009 world_generation_id: snapshot.world_generation_id(),
1010 parameter_version: snapshot.parameter_version(),
1011 });
1012 }
1013
1014 Ok(metadata)
1015 }
1016
1017 pub fn is_standard(&self) -> bool {
1019 matches!(self.strategy, PlanStrategy::Standard(_))
1020 }
1021
1022 fn execute_simple_entries(
1024 entries: &[CompiledEntry],
1025 snapshot: &dyn SnapshotAccess,
1026 output: &mut [f32],
1027 mask: &mut [u8],
1028 ) -> Result<(), ObsError> {
1029 for entry in entries {
1030 let field_data =
1031 snapshot
1032 .read_field(entry.field_id)
1033 .ok_or_else(|| ObsError::ExecutionFailed {
1034 reason: format!("field {:?} not in snapshot", entry.field_id),
1035 })?;
1036
1037 let out_slice =
1038 &mut output[entry.output_offset..entry.output_offset + entry.element_count];
1039 let mask_slice = &mut mask[entry.mask_offset..entry.mask_offset + entry.element_count];
1040
1041 out_slice.fill(0.0);
1043 mask_slice.copy_from_slice(&entry.valid_mask);
1044
1045 for op in &entry.gather_ops {
1047 let raw = *field_data.get(op.field_data_idx).ok_or_else(|| {
1048 ObsError::ExecutionFailed {
1049 reason: format!(
1050 "field {:?} has {} elements but gather requires index {}",
1051 entry.field_id,
1052 field_data.len(),
1053 op.field_data_idx,
1054 ),
1055 }
1056 })?;
1057 out_slice[op.tensor_idx] = apply_transform(raw, &entry.transform);
1058 }
1059 }
1060 Ok(())
1061 }
1062}
1063
1064#[allow(clippy::too_many_arguments)]
1072fn execute_agent_entry(
1073 entry: &AgentCompiledEntry,
1074 center: &Coord,
1075 field_data: &[f32],
1076 geometry: &Option<GridGeometry>,
1077 space: &dyn Space,
1078 use_fast_path: bool,
1079 agent_output: &mut [f32],
1080 agent_mask: &mut [u8],
1081 pool_scratch: &mut [f32],
1082 pool_scratch_mask: &mut [u8],
1083 pooled_scratch: &mut [f32],
1084 pooled_scratch_mask: &mut [u8],
1085) -> Result<usize, ObsError> {
1086 if entry.pool.is_some() {
1087 execute_agent_entry_pooled(
1088 entry,
1089 center,
1090 field_data,
1091 geometry,
1092 space,
1093 use_fast_path,
1094 agent_output,
1095 agent_mask,
1096 &mut pool_scratch[..entry.pre_pool_element_count],
1097 &mut pool_scratch_mask[..entry.pre_pool_element_count],
1098 &mut pooled_scratch[..entry.element_count],
1099 &mut pooled_scratch_mask[..entry.element_count],
1100 )
1101 } else {
1102 Ok(execute_agent_entry_direct(
1103 entry,
1104 center,
1105 field_data,
1106 geometry,
1107 space,
1108 use_fast_path,
1109 agent_output,
1110 agent_mask,
1111 ))
1112 }
1113}
1114
1115#[allow(clippy::too_many_arguments)]
1117fn execute_agent_entry_direct(
1118 entry: &AgentCompiledEntry,
1119 center: &Coord,
1120 field_data: &[f32],
1121 geometry: &Option<GridGeometry>,
1122 space: &dyn Space,
1123 use_fast_path: bool,
1124 agent_output: &mut [f32],
1125 agent_mask: &mut [u8],
1126) -> usize {
1127 let out_slice =
1128 &mut agent_output[entry.output_offset..entry.output_offset + entry.element_count];
1129 let mask_slice = &mut agent_mask[entry.mask_offset..entry.mask_offset + entry.element_count];
1130
1131 if use_fast_path {
1132 let geo = geometry.as_ref().unwrap();
1134 let base_rank = geo.canonical_rank(center) as isize;
1135 let mut valid = 0;
1136 for op in &entry.active_ops {
1137 let field_idx = (base_rank + op.stride_offset) as usize;
1138 if let Some(&val) = field_data.get(field_idx) {
1139 out_slice[op.tensor_idx] = apply_transform(val, &entry.transform);
1140 mask_slice[op.tensor_idx] = 1;
1141 valid += 1;
1142 }
1143 }
1144 valid
1145 } else {
1146 let mut valid = 0;
1148 for op in &entry.active_ops {
1149 let field_idx = resolve_field_index(center, &op.relative, geometry, space);
1150 if let Some(idx) = field_idx {
1151 if idx < field_data.len() {
1152 out_slice[op.tensor_idx] = apply_transform(field_data[idx], &entry.transform);
1153 mask_slice[op.tensor_idx] = 1;
1154 valid += 1;
1155 }
1156 }
1157 }
1158 valid
1159 }
1160}
1161
1162#[allow(clippy::too_many_arguments)]
1169fn execute_agent_entry_pooled(
1170 entry: &AgentCompiledEntry,
1171 center: &Coord,
1172 field_data: &[f32],
1173 geometry: &Option<GridGeometry>,
1174 space: &dyn Space,
1175 use_fast_path: bool,
1176 agent_output: &mut [f32],
1177 agent_mask: &mut [u8],
1178 scratch: &mut [f32],
1179 scratch_mask: &mut [u8],
1180 pooled: &mut [f32],
1181 pooled_mask: &mut [u8],
1182) -> Result<usize, ObsError> {
1183 if use_fast_path {
1184 let geo = geometry.as_ref().unwrap();
1185 let base_rank = geo.canonical_rank(center) as isize;
1186 for op in &entry.active_ops {
1187 let field_idx = (base_rank + op.stride_offset) as usize;
1188 if let Some(&val) = field_data.get(field_idx) {
1189 scratch[op.tensor_idx] = val;
1190 scratch_mask[op.tensor_idx] = 1;
1191 }
1192 }
1193 } else {
1194 for op in &entry.active_ops {
1195 let field_idx = resolve_field_index(center, &op.relative, geometry, space);
1196 if let Some(idx) = field_idx {
1197 if idx < field_data.len() {
1198 scratch[op.tensor_idx] = field_data[idx];
1199 scratch_mask[op.tensor_idx] = 1;
1200 }
1201 }
1202 }
1203 }
1204
1205 let pool_config = entry.pool.as_ref().unwrap();
1206 let (out_h, out_w) = pool_2d_into(
1207 scratch,
1208 scratch_mask,
1209 &entry.pre_pool_shape,
1210 pool_config,
1211 pooled,
1212 pooled_mask,
1213 )?;
1214
1215 let out_slice =
1216 &mut agent_output[entry.output_offset..entry.output_offset + entry.element_count];
1217 let mask_slice = &mut agent_mask[entry.mask_offset..entry.mask_offset + entry.element_count];
1218
1219 let n = (out_h * out_w).min(entry.element_count);
1220 for i in 0..n {
1221 out_slice[i] = apply_transform(pooled[i], &entry.transform);
1222 }
1223 mask_slice[..n].copy_from_slice(&pooled_mask[..n]);
1224
1225 Ok(pooled_mask[..n].iter().filter(|&&v| v == 1).count())
1226}
1227
1228fn generate_template_ops(
1240 half_extent: &[u32],
1241 geometry: &Option<GridGeometry>,
1242 disk_radius: Option<u32>,
1243) -> Result<Vec<TemplateOp>, ObsError> {
1244 let ndim = half_extent.len();
1245 let shape: Vec<usize> = half_extent.iter().map(|&he| 2 * he as usize + 1).collect();
1246 let total: usize = shape.iter().product();
1247
1248 let strides = geometry.as_ref().map(|g| g.coord_strides.as_slice());
1249
1250 let mut ops = Vec::with_capacity(total);
1251
1252 for tensor_idx in 0..total {
1253 let mut relative = Coord::new();
1255 let mut remaining = tensor_idx;
1256 for d in (0..ndim).rev() {
1258 let coord = (remaining % shape[d]) as i32 - half_extent[d] as i32;
1259 relative.push(coord);
1260 remaining /= shape[d];
1261 }
1262 relative.reverse();
1263
1264 let stride_offset = strides
1265 .map(|s| {
1266 relative
1267 .iter()
1268 .zip(s.iter())
1269 .map(|(&r, &s)| r as isize * s as isize)
1270 .sum::<isize>()
1271 })
1272 .unwrap_or(0);
1273
1274 let in_disk = match disk_radius {
1275 Some(r) => match geometry {
1276 Some(geo) => geo.graph_distance(&relative)? <= r,
1277 None => true, },
1279 None => true, };
1281
1282 ops.push(TemplateOp {
1283 relative,
1284 tensor_idx,
1285 stride_offset,
1286 in_disk,
1287 });
1288 }
1289
1290 Ok(ops)
1291}
1292
1293fn resolve_field_index(
1300 center: &Coord,
1301 relative: &Coord,
1302 geometry: &Option<GridGeometry>,
1303 space: &dyn Space,
1304) -> Option<usize> {
1305 if let Some(geo) = geometry {
1306 if geo.all_wrap {
1307 let wrapped: Coord = center
1309 .iter()
1310 .zip(relative.iter())
1311 .zip(geo.coord_dims.iter())
1312 .map(|((&c, &r), &d)| {
1313 let d = d as i32;
1314 ((c + r) % d + d) % d
1315 })
1316 .collect();
1317 Some(geo.canonical_rank(&wrapped))
1318 } else {
1319 let abs_coord: Coord = center
1320 .iter()
1321 .zip(relative.iter())
1322 .map(|(&c, &r)| c + r)
1323 .collect();
1324 let abs_slice: &[i32] = &abs_coord;
1325 if geo.in_bounds(abs_slice) {
1326 Some(geo.canonical_rank(abs_slice))
1327 } else {
1328 None
1329 }
1330 }
1331 } else {
1332 let abs_coord: Coord = center
1333 .iter()
1334 .zip(relative.iter())
1335 .map(|(&c, &r)| c + r)
1336 .collect();
1337 space.canonical_rank(&abs_coord)
1338 }
1339}
1340
1341fn apply_transform(raw: f32, transform: &ObsTransform) -> f32 {
1343 match transform {
1344 ObsTransform::Identity => raw,
1345 ObsTransform::Normalize { min, max } => {
1346 let range = max - min;
1347 if range == 0.0 {
1348 0.0
1349 } else {
1350 let normalized = (raw as f64 - min) / range;
1351 normalized.clamp(0.0, 1.0) as f32
1352 }
1353 }
1354 }
1355}
1356
1357#[cfg(test)]
1358mod tests {
1359 use super::*;
1360 use crate::spec::{
1361 ObsDtype, ObsEntry, ObsRegion, ObsSpec, ObsTransform, PoolConfig, PoolKernel,
1362 };
1363 use murk_core::{FieldId, ParameterVersion, TickId, WorldGenerationId};
1364 use murk_space::{EdgeBehavior, Hex2D, RegionSpec, Square4, Square8};
1365 use murk_test_utils::MockSnapshot;
1366
1367 fn square4_space() -> Square4 {
1368 Square4::new(3, 3, EdgeBehavior::Absorb).unwrap()
1369 }
1370
1371 fn snapshot_with_field(field: FieldId, data: Vec<f32>) -> MockSnapshot {
1372 let mut snap = MockSnapshot::new(TickId(5), WorldGenerationId(1), ParameterVersion(0));
1373 snap.set_field(field, data);
1374 snap
1375 }
1376
1377 #[test]
1380 fn compile_empty_spec_errors() {
1381 let space = square4_space();
1382 let spec = ObsSpec { entries: vec![] };
1383 let err = ObsPlan::compile(&spec, &space).unwrap_err();
1384 assert!(matches!(err, ObsError::InvalidObsSpec { .. }));
1385 }
1386
1387 #[test]
1388 fn compile_all_region_square4() {
1389 let space = square4_space();
1390 let spec = ObsSpec {
1391 entries: vec![ObsEntry {
1392 field_id: FieldId(0),
1393 region: ObsRegion::Fixed(RegionSpec::All),
1394 pool: None,
1395 transform: ObsTransform::Identity,
1396 dtype: ObsDtype::F32,
1397 }],
1398 };
1399 let result = ObsPlan::compile(&spec, &space).unwrap();
1400 assert_eq!(result.output_len, 9); assert_eq!(result.mask_len, 9);
1402 assert_eq!(result.entry_shapes, vec![vec![3, 3]]);
1403 }
1404
1405 #[test]
1406 fn compile_rect_region() {
1407 let space = Square4::new(5, 5, EdgeBehavior::Absorb).unwrap();
1408 let spec = ObsSpec {
1409 entries: vec![ObsEntry {
1410 field_id: FieldId(0),
1411 region: ObsRegion::Fixed(RegionSpec::Rect {
1412 min: smallvec::smallvec![1, 1],
1413 max: smallvec::smallvec![2, 3],
1414 }),
1415 pool: None,
1416 transform: ObsTransform::Identity,
1417 dtype: ObsDtype::F32,
1418 }],
1419 };
1420 let result = ObsPlan::compile(&spec, &space).unwrap();
1421 assert_eq!(result.output_len, 6);
1423 assert_eq!(result.entry_shapes, vec![vec![2, 3]]);
1424 }
1425
1426 #[test]
1427 fn compile_two_entries_offsets() {
1428 let space = square4_space();
1429 let spec = ObsSpec {
1430 entries: vec![
1431 ObsEntry {
1432 field_id: FieldId(0),
1433 region: ObsRegion::Fixed(RegionSpec::All),
1434 pool: None,
1435 transform: ObsTransform::Identity,
1436 dtype: ObsDtype::F32,
1437 },
1438 ObsEntry {
1439 field_id: FieldId(1),
1440 region: ObsRegion::Fixed(RegionSpec::All),
1441 pool: None,
1442 transform: ObsTransform::Identity,
1443 dtype: ObsDtype::F32,
1444 },
1445 ],
1446 };
1447 let result = ObsPlan::compile(&spec, &space).unwrap();
1448 assert_eq!(result.output_len, 18); assert_eq!(result.mask_len, 18);
1450 }
1451
1452 #[test]
1453 fn compile_invalid_region_errors() {
1454 let space = square4_space();
1455 let spec = ObsSpec {
1456 entries: vec![ObsEntry {
1457 field_id: FieldId(0),
1458 region: ObsRegion::Fixed(RegionSpec::Coords(vec![smallvec::smallvec![99, 99]])),
1459 pool: None,
1460 transform: ObsTransform::Identity,
1461 dtype: ObsDtype::F32,
1462 }],
1463 };
1464 let err = ObsPlan::compile(&spec, &space).unwrap_err();
1465 assert!(matches!(err, ObsError::InvalidObsSpec { .. }));
1466 }
1467
1468 #[test]
1469 fn compile_agent_rect_wrong_ndim_errors() {
1470 let space = Square4::new(5, 5, EdgeBehavior::Absorb).unwrap(); let spec = ObsSpec {
1472 entries: vec![ObsEntry {
1473 field_id: FieldId(0),
1474 region: ObsRegion::AgentRect {
1475 half_extent: smallvec::smallvec![1], },
1477 pool: None,
1478 transform: ObsTransform::Identity,
1479 dtype: ObsDtype::F32,
1480 }],
1481 };
1482 let err = ObsPlan::compile(&spec, &space).unwrap_err();
1483 assert!(matches!(err, ObsError::InvalidObsSpec { .. }));
1484 }
1485
1486 #[test]
1487 fn compile_agent_rect_correct_ndim_ok() {
1488 let space = Square4::new(5, 5, EdgeBehavior::Absorb).unwrap(); let spec = ObsSpec {
1490 entries: vec![ObsEntry {
1491 field_id: FieldId(0),
1492 region: ObsRegion::AgentRect {
1493 half_extent: smallvec::smallvec![1, 2], },
1495 pool: None,
1496 transform: ObsTransform::Identity,
1497 dtype: ObsDtype::F32,
1498 }],
1499 };
1500 assert!(ObsPlan::compile(&spec, &space).is_ok());
1501 }
1502
1503 #[test]
1506 fn execute_identity_all_region() {
1507 let space = square4_space();
1508 let data: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1511 let snap = snapshot_with_field(FieldId(0), data);
1512
1513 let spec = ObsSpec {
1514 entries: vec![ObsEntry {
1515 field_id: FieldId(0),
1516 region: ObsRegion::Fixed(RegionSpec::All),
1517 pool: None,
1518 transform: ObsTransform::Identity,
1519 dtype: ObsDtype::F32,
1520 }],
1521 };
1522 let result = ObsPlan::compile(&spec, &space).unwrap();
1523
1524 let mut output = vec![0.0f32; result.output_len];
1525 let mut mask = vec![0u8; result.mask_len];
1526 let meta = result
1527 .plan
1528 .execute(&snap, None, &mut output, &mut mask)
1529 .unwrap();
1530
1531 let expected: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1533 assert_eq!(output, expected);
1534 assert_eq!(mask, vec![1u8; 9]);
1535 assert_eq!(meta.tick_id, TickId(5));
1536 assert_eq!(meta.coverage, 1.0);
1537 assert_eq!(meta.world_generation_id, WorldGenerationId(1));
1538 assert_eq!(meta.parameter_version, ParameterVersion(0));
1539 assert_eq!(meta.age_ticks, 0);
1540 }
1541
1542 #[test]
1543 fn execute_normalize_transform() {
1544 let space = square4_space();
1545 let data: Vec<f32> = (0..9).map(|x| x as f32).collect();
1547 let snap = snapshot_with_field(FieldId(0), data);
1548
1549 let spec = ObsSpec {
1550 entries: vec![ObsEntry {
1551 field_id: FieldId(0),
1552 region: ObsRegion::Fixed(RegionSpec::All),
1553 pool: None,
1554 transform: ObsTransform::Normalize { min: 0.0, max: 8.0 },
1555 dtype: ObsDtype::F32,
1556 }],
1557 };
1558 let result = ObsPlan::compile(&spec, &space).unwrap();
1559
1560 let mut output = vec![0.0f32; result.output_len];
1561 let mut mask = vec![0u8; result.mask_len];
1562 result
1563 .plan
1564 .execute(&snap, None, &mut output, &mut mask)
1565 .unwrap();
1566
1567 for (i, &v) in output.iter().enumerate() {
1569 let expected = i as f32 / 8.0;
1570 assert!(
1571 (v - expected).abs() < 1e-6,
1572 "output[{i}] = {v}, expected {expected}"
1573 );
1574 }
1575 }
1576
1577 #[test]
1578 fn execute_normalize_clamps_out_of_range() {
1579 let space = square4_space();
1580 let data: Vec<f32> = (-4..5).map(|x| x as f32 * 5.0).collect();
1582 let snap = snapshot_with_field(FieldId(0), data);
1583
1584 let spec = ObsSpec {
1585 entries: vec![ObsEntry {
1586 field_id: FieldId(0),
1587 region: ObsRegion::Fixed(RegionSpec::All),
1588 pool: None,
1589 transform: ObsTransform::Normalize {
1590 min: 0.0,
1591 max: 10.0,
1592 },
1593 dtype: ObsDtype::F32,
1594 }],
1595 };
1596 let result = ObsPlan::compile(&spec, &space).unwrap();
1597
1598 let mut output = vec![0.0f32; result.output_len];
1599 let mut mask = vec![0u8; result.mask_len];
1600 result
1601 .plan
1602 .execute(&snap, None, &mut output, &mut mask)
1603 .unwrap();
1604
1605 for &v in &output {
1606 assert!((0.0..=1.0).contains(&v), "value {v} out of [0,1] range");
1607 }
1608 }
1609
1610 #[test]
1611 fn execute_normalize_zero_range() {
1612 let space = square4_space();
1613 let data = vec![5.0f32; 9];
1614 let snap = snapshot_with_field(FieldId(0), data);
1615
1616 let spec = ObsSpec {
1617 entries: vec![ObsEntry {
1618 field_id: FieldId(0),
1619 region: ObsRegion::Fixed(RegionSpec::All),
1620 pool: None,
1621 transform: ObsTransform::Normalize { min: 5.0, max: 5.0 },
1622 dtype: ObsDtype::F32,
1623 }],
1624 };
1625 let result = ObsPlan::compile(&spec, &space).unwrap();
1626
1627 let mut output = vec![-1.0f32; result.output_len];
1628 let mut mask = vec![0u8; result.mask_len];
1629 result
1630 .plan
1631 .execute(&snap, None, &mut output, &mut mask)
1632 .unwrap();
1633
1634 assert!(output.iter().all(|&v| v == 0.0));
1636 }
1637
1638 #[test]
1639 fn execute_rect_subregion_correct_values() {
1640 let space = Square4::new(4, 4, EdgeBehavior::Absorb).unwrap();
1641 let data: Vec<f32> = (1..=16).map(|x| x as f32).collect();
1643 let snap = snapshot_with_field(FieldId(0), data);
1644
1645 let spec = ObsSpec {
1646 entries: vec![ObsEntry {
1647 field_id: FieldId(0),
1648 region: ObsRegion::Fixed(RegionSpec::Rect {
1649 min: smallvec::smallvec![1, 1],
1650 max: smallvec::smallvec![2, 2],
1651 }),
1652 pool: None,
1653 transform: ObsTransform::Identity,
1654 dtype: ObsDtype::F32,
1655 }],
1656 };
1657 let result = ObsPlan::compile(&spec, &space).unwrap();
1658 assert_eq!(result.output_len, 4); let mut output = vec![0.0f32; result.output_len];
1661 let mut mask = vec![0u8; result.mask_len];
1662 result
1663 .plan
1664 .execute(&snap, None, &mut output, &mut mask)
1665 .unwrap();
1666
1667 assert_eq!(output, vec![6.0, 7.0, 10.0, 11.0]);
1669 assert_eq!(mask, vec![1, 1, 1, 1]);
1670 }
1671
1672 #[test]
1673 fn execute_two_fields() {
1674 let space = square4_space();
1675 let data_a: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1676 let data_b: Vec<f32> = (10..=18).map(|x| x as f32).collect();
1677 let mut snap = MockSnapshot::new(TickId(1), WorldGenerationId(1), ParameterVersion(0));
1678 snap.set_field(FieldId(0), data_a);
1679 snap.set_field(FieldId(1), data_b);
1680
1681 let spec = ObsSpec {
1682 entries: vec![
1683 ObsEntry {
1684 field_id: FieldId(0),
1685 region: ObsRegion::Fixed(RegionSpec::All),
1686 pool: None,
1687 transform: ObsTransform::Identity,
1688 dtype: ObsDtype::F32,
1689 },
1690 ObsEntry {
1691 field_id: FieldId(1),
1692 region: ObsRegion::Fixed(RegionSpec::All),
1693 pool: None,
1694 transform: ObsTransform::Identity,
1695 dtype: ObsDtype::F32,
1696 },
1697 ],
1698 };
1699 let result = ObsPlan::compile(&spec, &space).unwrap();
1700 assert_eq!(result.output_len, 18);
1701
1702 let mut output = vec![0.0f32; result.output_len];
1703 let mut mask = vec![0u8; result.mask_len];
1704 result
1705 .plan
1706 .execute(&snap, None, &mut output, &mut mask)
1707 .unwrap();
1708
1709 let expected_a: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1711 let expected_b: Vec<f32> = (10..=18).map(|x| x as f32).collect();
1712 assert_eq!(&output[..9], &expected_a);
1713 assert_eq!(&output[9..], &expected_b);
1714 }
1715
1716 #[test]
1717 fn execute_missing_field_errors() {
1718 let space = square4_space();
1719 let snap = MockSnapshot::new(TickId(1), WorldGenerationId(1), ParameterVersion(0));
1720
1721 let spec = ObsSpec {
1722 entries: vec![ObsEntry {
1723 field_id: FieldId(0),
1724 region: ObsRegion::Fixed(RegionSpec::All),
1725 pool: None,
1726 transform: ObsTransform::Identity,
1727 dtype: ObsDtype::F32,
1728 }],
1729 };
1730 let result = ObsPlan::compile(&spec, &space).unwrap();
1731
1732 let mut output = vec![0.0f32; result.output_len];
1733 let mut mask = vec![0u8; result.mask_len];
1734 let err = result
1735 .plan
1736 .execute(&snap, None, &mut output, &mut mask)
1737 .unwrap_err();
1738 assert!(matches!(err, ObsError::ExecutionFailed { .. }));
1739 }
1740
1741 #[test]
1742 fn execute_buffer_too_small_errors() {
1743 let space = square4_space();
1744 let data: Vec<f32> = vec![0.0; 9];
1745 let snap = snapshot_with_field(FieldId(0), data);
1746
1747 let spec = ObsSpec {
1748 entries: vec![ObsEntry {
1749 field_id: FieldId(0),
1750 region: ObsRegion::Fixed(RegionSpec::All),
1751 pool: None,
1752 transform: ObsTransform::Identity,
1753 dtype: ObsDtype::F32,
1754 }],
1755 };
1756 let result = ObsPlan::compile(&spec, &space).unwrap();
1757
1758 let mut output = vec![0.0f32; 4]; let mut mask = vec![0u8; result.mask_len];
1760 let err = result
1761 .plan
1762 .execute(&snap, None, &mut output, &mut mask)
1763 .unwrap_err();
1764 assert!(matches!(err, ObsError::ExecutionFailed { .. }));
1765 }
1766
1767 #[test]
1770 fn valid_ratio_one_for_square_all() {
1771 let space = square4_space();
1772 let data: Vec<f32> = vec![1.0; 9];
1773 let snap = snapshot_with_field(FieldId(0), data);
1774
1775 let spec = ObsSpec {
1776 entries: vec![ObsEntry {
1777 field_id: FieldId(0),
1778 region: ObsRegion::Fixed(RegionSpec::All),
1779 pool: None,
1780 transform: ObsTransform::Identity,
1781 dtype: ObsDtype::F32,
1782 }],
1783 };
1784 let result = ObsPlan::compile(&spec, &space).unwrap();
1785
1786 let mut output = vec![0.0f32; result.output_len];
1787 let mut mask = vec![0u8; result.mask_len];
1788 let meta = result
1789 .plan
1790 .execute(&snap, None, &mut output, &mut mask)
1791 .unwrap();
1792
1793 assert_eq!(meta.coverage, 1.0);
1794 }
1795
1796 #[test]
1799 fn plan_invalidated_on_generation_mismatch() {
1800 let space = square4_space();
1801 let data: Vec<f32> = vec![1.0; 9];
1802 let snap = snapshot_with_field(FieldId(0), data);
1803
1804 let spec = ObsSpec {
1805 entries: vec![ObsEntry {
1806 field_id: FieldId(0),
1807 region: ObsRegion::Fixed(RegionSpec::All),
1808 pool: None,
1809 transform: ObsTransform::Identity,
1810 dtype: ObsDtype::F32,
1811 }],
1812 };
1813 let result = ObsPlan::compile_bound(&spec, &space, WorldGenerationId(99)).unwrap();
1815
1816 let mut output = vec![0.0f32; result.output_len];
1817 let mut mask = vec![0u8; result.mask_len];
1818 let err = result
1819 .plan
1820 .execute(&snap, None, &mut output, &mut mask)
1821 .unwrap_err();
1822 assert!(matches!(err, ObsError::PlanInvalidated { .. }));
1823 }
1824
1825 #[test]
1826 fn generation_match_succeeds() {
1827 let space = square4_space();
1828 let data: Vec<f32> = vec![1.0; 9];
1829 let snap = snapshot_with_field(FieldId(0), data);
1830
1831 let spec = ObsSpec {
1832 entries: vec![ObsEntry {
1833 field_id: FieldId(0),
1834 region: ObsRegion::Fixed(RegionSpec::All),
1835 pool: None,
1836 transform: ObsTransform::Identity,
1837 dtype: ObsDtype::F32,
1838 }],
1839 };
1840 let result = ObsPlan::compile_bound(&spec, &space, WorldGenerationId(1)).unwrap();
1841
1842 let mut output = vec![0.0f32; result.output_len];
1843 let mut mask = vec![0u8; result.mask_len];
1844 result
1845 .plan
1846 .execute(&snap, None, &mut output, &mut mask)
1847 .unwrap();
1848 }
1849
1850 #[test]
1851 fn unbound_plan_ignores_generation() {
1852 let space = square4_space();
1853 let data: Vec<f32> = vec![1.0; 9];
1854 let snap = snapshot_with_field(FieldId(0), data);
1855
1856 let spec = ObsSpec {
1857 entries: vec![ObsEntry {
1858 field_id: FieldId(0),
1859 region: ObsRegion::Fixed(RegionSpec::All),
1860 pool: None,
1861 transform: ObsTransform::Identity,
1862 dtype: ObsDtype::F32,
1863 }],
1864 };
1865 let result = ObsPlan::compile(&spec, &space).unwrap();
1867
1868 let mut output = vec![0.0f32; result.output_len];
1869 let mut mask = vec![0u8; result.mask_len];
1870 result
1871 .plan
1872 .execute(&snap, None, &mut output, &mut mask)
1873 .unwrap();
1874 }
1875
1876 #[test]
1879 fn metadata_fields_populated() {
1880 let space = square4_space();
1881 let data: Vec<f32> = vec![1.0; 9];
1882 let mut snap = MockSnapshot::new(TickId(42), WorldGenerationId(7), ParameterVersion(3));
1883 snap.set_field(FieldId(0), data);
1884
1885 let spec = ObsSpec {
1886 entries: vec![ObsEntry {
1887 field_id: FieldId(0),
1888 region: ObsRegion::Fixed(RegionSpec::All),
1889 pool: None,
1890 transform: ObsTransform::Identity,
1891 dtype: ObsDtype::F32,
1892 }],
1893 };
1894 let result = ObsPlan::compile(&spec, &space).unwrap();
1895
1896 let mut output = vec![0.0f32; result.output_len];
1897 let mut mask = vec![0u8; result.mask_len];
1898 let meta = result
1899 .plan
1900 .execute(&snap, None, &mut output, &mut mask)
1901 .unwrap();
1902
1903 assert_eq!(meta.tick_id, TickId(42));
1904 assert_eq!(meta.age_ticks, 0);
1905 assert_eq!(meta.coverage, 1.0);
1906 assert_eq!(meta.world_generation_id, WorldGenerationId(7));
1907 assert_eq!(meta.parameter_version, ParameterVersion(3));
1908 }
1909
1910 #[test]
1913 fn execute_batch_n1_matches_execute() {
1914 let space = square4_space();
1915 let data: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1916 let snap = snapshot_with_field(FieldId(0), data.clone());
1917
1918 let spec = ObsSpec {
1919 entries: vec![ObsEntry {
1920 field_id: FieldId(0),
1921 region: ObsRegion::Fixed(RegionSpec::All),
1922 pool: None,
1923 transform: ObsTransform::Identity,
1924 dtype: ObsDtype::F32,
1925 }],
1926 };
1927 let result = ObsPlan::compile(&spec, &space).unwrap();
1928
1929 let mut out_single = vec![0.0f32; result.output_len];
1931 let mut mask_single = vec![0u8; result.mask_len];
1932 let meta_single = result
1933 .plan
1934 .execute(&snap, None, &mut out_single, &mut mask_single)
1935 .unwrap();
1936
1937 let mut out_batch = vec![0.0f32; result.output_len];
1939 let mut mask_batch = vec![0u8; result.mask_len];
1940 let snap_ref: &dyn SnapshotAccess = &snap;
1941 let meta_batch = result
1942 .plan
1943 .execute_batch(&[snap_ref], None, &mut out_batch, &mut mask_batch)
1944 .unwrap();
1945
1946 assert_eq!(out_single, out_batch);
1947 assert_eq!(mask_single, mask_batch);
1948 assert_eq!(meta_single, meta_batch[0]);
1949 }
1950
1951 #[test]
1952 fn execute_batch_multiple_snapshots() {
1953 let space = square4_space();
1954 let spec = ObsSpec {
1955 entries: vec![ObsEntry {
1956 field_id: FieldId(0),
1957 region: ObsRegion::Fixed(RegionSpec::All),
1958 pool: None,
1959 transform: ObsTransform::Identity,
1960 dtype: ObsDtype::F32,
1961 }],
1962 };
1963 let result = ObsPlan::compile(&spec, &space).unwrap();
1964
1965 let snap_a = snapshot_with_field(FieldId(0), vec![1.0; 9]);
1966 let snap_b = snapshot_with_field(FieldId(0), vec![2.0; 9]);
1967
1968 let snaps: Vec<&dyn SnapshotAccess> = vec![&snap_a, &snap_b];
1969 let mut output = vec![0.0f32; result.output_len * 2];
1970 let mut mask = vec![0u8; result.mask_len * 2];
1971 let metas = result
1972 .plan
1973 .execute_batch(&snaps, None, &mut output, &mut mask)
1974 .unwrap();
1975
1976 assert_eq!(metas.len(), 2);
1977 assert!(output[..9].iter().all(|&v| v == 1.0));
1978 assert!(output[9..].iter().all(|&v| v == 2.0));
1979 }
1980
1981 #[test]
1982 fn execute_batch_buffer_too_small() {
1983 let space = square4_space();
1984 let spec = ObsSpec {
1985 entries: vec![ObsEntry {
1986 field_id: FieldId(0),
1987 region: ObsRegion::Fixed(RegionSpec::All),
1988 pool: None,
1989 transform: ObsTransform::Identity,
1990 dtype: ObsDtype::F32,
1991 }],
1992 };
1993 let result = ObsPlan::compile(&spec, &space).unwrap();
1994
1995 let snap = snapshot_with_field(FieldId(0), vec![1.0; 9]);
1996 let snaps: Vec<&dyn SnapshotAccess> = vec![&snap, &snap];
1997 let mut output = vec![0.0f32; 9]; let mut mask = vec![0u8; 18];
1999 let err = result
2000 .plan
2001 .execute_batch(&snaps, None, &mut output, &mut mask)
2002 .unwrap_err();
2003 assert!(matches!(err, ObsError::ExecutionFailed { .. }));
2004 }
2005
2006 #[test]
2007 fn execute_batch_rejects_mismatched_generation() {
2008 let space = square4_space();
2009 let spec = ObsSpec {
2010 entries: vec![ObsEntry {
2011 field_id: FieldId(0),
2012 region: ObsRegion::Fixed(RegionSpec::All),
2013 pool: None,
2014 transform: ObsTransform::Identity,
2015 dtype: ObsDtype::F32,
2016 }],
2017 };
2018 let result = ObsPlan::compile_bound(&spec, &space, WorldGenerationId(5)).unwrap();
2020
2021 let snap = snapshot_with_field(FieldId(0), vec![1.0; 9]);
2023 let snaps: Vec<&dyn SnapshotAccess> = vec![&snap];
2024 let mut output = vec![0.0f32; result.output_len];
2025 let mut mask = vec![0u8; result.mask_len];
2026 let err = result
2027 .plan
2028 .execute_batch(&snaps, None, &mut output, &mut mask)
2029 .unwrap_err();
2030 assert!(matches!(err, ObsError::PlanInvalidated { .. }));
2031 }
2032
2033 #[test]
2036 fn short_field_buffer_returns_error_not_panic() {
2037 let space = square4_space(); let spec = ObsSpec {
2039 entries: vec![ObsEntry {
2040 field_id: FieldId(0),
2041 region: ObsRegion::Fixed(RegionSpec::All),
2042 pool: None,
2043 transform: ObsTransform::Identity,
2044 dtype: ObsDtype::F32,
2045 }],
2046 };
2047 let result = ObsPlan::compile(&spec, &space).unwrap();
2048
2049 let snap = snapshot_with_field(FieldId(0), vec![1.0; 4]);
2051 let mut output = vec![0.0f32; result.output_len];
2052 let mut mask = vec![0u8; result.mask_len];
2053 let err = result
2054 .plan
2055 .execute(&snap, None, &mut output, &mut mask)
2056 .unwrap_err();
2057 assert!(matches!(err, ObsError::ExecutionFailed { .. }));
2058 }
2059
2060 #[test]
2063 fn standard_plan_detected_from_agent_region() {
2064 let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2065 let spec = ObsSpec {
2066 entries: vec![ObsEntry {
2067 field_id: FieldId(0),
2068 region: ObsRegion::AgentRect {
2069 half_extent: smallvec::smallvec![2, 2],
2070 },
2071 pool: None,
2072 transform: ObsTransform::Identity,
2073 dtype: ObsDtype::F32,
2074 }],
2075 };
2076 let result = ObsPlan::compile(&spec, &space).unwrap();
2077 assert!(result.plan.is_standard());
2078 assert_eq!(result.output_len, 25);
2080 assert_eq!(result.entry_shapes, vec![vec![5, 5]]);
2081 }
2082
2083 #[test]
2084 fn execute_on_standard_plan_errors() {
2085 let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2086 let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2087 let snap = snapshot_with_field(FieldId(0), data);
2088
2089 let spec = ObsSpec {
2090 entries: vec![ObsEntry {
2091 field_id: FieldId(0),
2092 region: ObsRegion::AgentDisk { radius: 2 },
2093 pool: None,
2094 transform: ObsTransform::Identity,
2095 dtype: ObsDtype::F32,
2096 }],
2097 };
2098 let result = ObsPlan::compile(&spec, &space).unwrap();
2099
2100 let mut output = vec![0.0f32; result.output_len];
2101 let mut mask = vec![0u8; result.mask_len];
2102 let err = result
2103 .plan
2104 .execute(&snap, None, &mut output, &mut mask)
2105 .unwrap_err();
2106 assert!(matches!(err, ObsError::ExecutionFailed { .. }));
2107 }
2108
2109 #[test]
2110 fn interior_boundary_equivalence() {
2111 let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
2114 let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
2115 let snap = snapshot_with_field(FieldId(0), data);
2116
2117 let radius = 3u32;
2118 let center: Coord = smallvec::smallvec![10, 10]; let standard_spec = ObsSpec {
2122 entries: vec![ObsEntry {
2123 field_id: FieldId(0),
2124 region: ObsRegion::AgentRect {
2125 half_extent: smallvec::smallvec![radius, radius],
2126 },
2127 pool: None,
2128 transform: ObsTransform::Identity,
2129 dtype: ObsDtype::F32,
2130 }],
2131 };
2132 let std_result = ObsPlan::compile(&standard_spec, &space).unwrap();
2133 let mut std_output = vec![0.0f32; std_result.output_len];
2134 let mut std_mask = vec![0u8; std_result.mask_len];
2135 std_result
2136 .plan
2137 .execute_agents(
2138 &snap,
2139 &space,
2140 std::slice::from_ref(¢er),
2141 None,
2142 &mut std_output,
2143 &mut std_mask,
2144 )
2145 .unwrap();
2146
2147 let r = radius as i32;
2149 let simple_spec = ObsSpec {
2150 entries: vec![ObsEntry {
2151 field_id: FieldId(0),
2152 region: ObsRegion::Fixed(RegionSpec::Rect {
2153 min: smallvec::smallvec![10 - r, 10 - r],
2154 max: smallvec::smallvec![10 + r, 10 + r],
2155 }),
2156 pool: None,
2157 transform: ObsTransform::Identity,
2158 dtype: ObsDtype::F32,
2159 }],
2160 };
2161 let simple_result = ObsPlan::compile(&simple_spec, &space).unwrap();
2162 let mut simple_output = vec![0.0f32; simple_result.output_len];
2163 let mut simple_mask = vec![0u8; simple_result.mask_len];
2164 simple_result
2165 .plan
2166 .execute(&snap, None, &mut simple_output, &mut simple_mask)
2167 .unwrap();
2168
2169 assert_eq!(std_result.output_len, simple_result.output_len);
2171 assert_eq!(std_output, simple_output);
2172 assert_eq!(std_mask, simple_mask);
2173 }
2174
2175 #[test]
2176 fn boundary_agent_gets_padding() {
2177 let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2179 let data: Vec<f32> = (0..100).map(|x| x as f32 + 1.0).collect();
2180 let snap = snapshot_with_field(FieldId(0), data);
2181
2182 let spec = ObsSpec {
2183 entries: vec![ObsEntry {
2184 field_id: FieldId(0),
2185 region: ObsRegion::AgentRect {
2186 half_extent: smallvec::smallvec![2, 2],
2187 },
2188 pool: None,
2189 transform: ObsTransform::Identity,
2190 dtype: ObsDtype::F32,
2191 }],
2192 };
2193 let result = ObsPlan::compile(&spec, &space).unwrap();
2194 let center: Coord = smallvec::smallvec![0, 0];
2195 let mut output = vec![0.0f32; result.output_len];
2196 let mut mask = vec![0u8; result.mask_len];
2197 let metas = result
2198 .plan
2199 .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2200 .unwrap();
2201
2202 let valid_count: usize = mask.iter().filter(|&&v| v == 1).count();
2205 assert_eq!(valid_count, 9);
2206
2207 assert!((metas[0].coverage - 9.0 / 25.0).abs() < 1e-6);
2209
2210 assert_eq!(mask[0], 0); assert_eq!(output[0], 0.0);
2214
2215 assert_eq!(mask[12], 1);
2218 assert_eq!(output[12], 1.0);
2219 }
2220
2221 #[test]
2222 fn hex_foveation_interior() {
2223 let space = Hex2D::new(20, 20).unwrap(); let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
2226 let snap = snapshot_with_field(FieldId(0), data);
2227
2228 let spec = ObsSpec {
2229 entries: vec![ObsEntry {
2230 field_id: FieldId(0),
2231 region: ObsRegion::AgentDisk { radius: 2 },
2232 pool: None,
2233 transform: ObsTransform::Identity,
2234 dtype: ObsDtype::F32,
2235 }],
2236 };
2237 let result = ObsPlan::compile(&spec, &space).unwrap();
2238 assert_eq!(result.output_len, 25); let center: Coord = smallvec::smallvec![10, 10];
2242 let mut output = vec![0.0f32; result.output_len];
2243 let mut mask = vec![0u8; result.mask_len];
2244 result
2245 .plan
2246 .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2247 .unwrap();
2248
2249 let valid_count = mask.iter().filter(|&&v| v == 1).count();
2253 assert_eq!(valid_count, 19);
2254
2255 for &idx in &[0, 1, 5, 19, 23, 24] {
2263 assert_eq!(mask[idx], 0, "tensor_idx {idx} should be outside hex disk");
2264 assert_eq!(output[idx], 0.0, "tensor_idx {idx} should be zero-padded");
2265 }
2266
2267 assert_eq!(output[12], 210.0);
2270
2271 assert_eq!(output[17], 211.0);
2275 }
2276
2277 #[test]
2278 fn wrap_space_all_interior() {
2279 let space = Square4::new(10, 10, EdgeBehavior::Wrap).unwrap();
2281 let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2282 let snap = snapshot_with_field(FieldId(0), data);
2283
2284 let spec = ObsSpec {
2285 entries: vec![ObsEntry {
2286 field_id: FieldId(0),
2287 region: ObsRegion::AgentRect {
2288 half_extent: smallvec::smallvec![2, 2],
2289 },
2290 pool: None,
2291 transform: ObsTransform::Identity,
2292 dtype: ObsDtype::F32,
2293 }],
2294 };
2295 let result = ObsPlan::compile(&spec, &space).unwrap();
2296
2297 let center: Coord = smallvec::smallvec![0, 0];
2299 let mut output = vec![0.0f32; result.output_len];
2300 let mut mask = vec![0u8; result.mask_len];
2301 result
2302 .plan
2303 .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2304 .unwrap();
2305
2306 assert!(mask.iter().all(|&v| v == 1));
2308 assert_eq!(output[12], 0.0); }
2310
2311 #[test]
2312 fn execute_agents_multiple() {
2313 let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2314 let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2315 let snap = snapshot_with_field(FieldId(0), data);
2316
2317 let spec = ObsSpec {
2318 entries: vec![ObsEntry {
2319 field_id: FieldId(0),
2320 region: ObsRegion::AgentRect {
2321 half_extent: smallvec::smallvec![1, 1],
2322 },
2323 pool: None,
2324 transform: ObsTransform::Identity,
2325 dtype: ObsDtype::F32,
2326 }],
2327 };
2328 let result = ObsPlan::compile(&spec, &space).unwrap();
2329 assert_eq!(result.output_len, 9); let centers = vec![
2333 smallvec::smallvec![5, 5], smallvec::smallvec![0, 5], ];
2336 let n = centers.len();
2337 let mut output = vec![0.0f32; result.output_len * n];
2338 let mut mask = vec![0u8; result.mask_len * n];
2339 let metas = result
2340 .plan
2341 .execute_agents(&snap, &space, ¢ers, None, &mut output, &mut mask)
2342 .unwrap();
2343
2344 assert_eq!(metas.len(), 2);
2345
2346 assert!(mask[..9].iter().all(|&v| v == 1));
2348 assert_eq!(output[4], 55.0); let agent1_mask = &mask[9..18];
2352 let valid_count: usize = agent1_mask.iter().filter(|&&v| v == 1).count();
2353 assert_eq!(valid_count, 6); }
2355
2356 #[test]
2357 fn execute_agents_with_normalize() {
2358 let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2359 let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2360 let snap = snapshot_with_field(FieldId(0), data);
2361
2362 let spec = ObsSpec {
2363 entries: vec![ObsEntry {
2364 field_id: FieldId(0),
2365 region: ObsRegion::AgentRect {
2366 half_extent: smallvec::smallvec![1, 1],
2367 },
2368 pool: None,
2369 transform: ObsTransform::Normalize {
2370 min: 0.0,
2371 max: 99.0,
2372 },
2373 dtype: ObsDtype::F32,
2374 }],
2375 };
2376 let result = ObsPlan::compile(&spec, &space).unwrap();
2377
2378 let center: Coord = smallvec::smallvec![5, 5];
2379 let mut output = vec![0.0f32; result.output_len];
2380 let mut mask = vec![0u8; result.mask_len];
2381 result
2382 .plan
2383 .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2384 .unwrap();
2385
2386 let expected = 55.0 / 99.0;
2388 assert!((output[4] - expected as f32).abs() < 1e-5);
2389 }
2390
2391 #[test]
2392 fn execute_agents_with_pooling() {
2393 let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
2394 let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
2395 let snap = snapshot_with_field(FieldId(0), data);
2396
2397 let spec = ObsSpec {
2400 entries: vec![ObsEntry {
2401 field_id: FieldId(0),
2402 region: ObsRegion::AgentRect {
2403 half_extent: smallvec::smallvec![3, 3],
2404 },
2405 pool: Some(PoolConfig {
2406 kernel: PoolKernel::Mean,
2407 kernel_size: 2,
2408 stride: 2,
2409 }),
2410 transform: ObsTransform::Identity,
2411 dtype: ObsDtype::F32,
2412 }],
2413 };
2414 let result = ObsPlan::compile(&spec, &space).unwrap();
2415 assert_eq!(result.output_len, 9); assert_eq!(result.entry_shapes, vec![vec![3, 3]]);
2417
2418 let center: Coord = smallvec::smallvec![10, 10];
2420 let mut output = vec![0.0f32; result.output_len];
2421 let mut mask = vec![0u8; result.mask_len];
2422 result
2423 .plan
2424 .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2425 .unwrap();
2426
2427 assert!(mask.iter().all(|&v| v == 1));
2429
2430 assert!((output[0] - 157.5).abs() < 1e-4);
2435 }
2436
2437 #[test]
2438 fn mixed_fixed_and_agent_entries() {
2439 let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2440 let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2441 let snap = snapshot_with_field(FieldId(0), data);
2442
2443 let spec = ObsSpec {
2444 entries: vec![
2445 ObsEntry {
2447 field_id: FieldId(0),
2448 region: ObsRegion::Fixed(RegionSpec::All),
2449 pool: None,
2450 transform: ObsTransform::Identity,
2451 dtype: ObsDtype::F32,
2452 },
2453 ObsEntry {
2455 field_id: FieldId(0),
2456 region: ObsRegion::AgentRect {
2457 half_extent: smallvec::smallvec![1, 1],
2458 },
2459 pool: None,
2460 transform: ObsTransform::Identity,
2461 dtype: ObsDtype::F32,
2462 },
2463 ],
2464 };
2465 let result = ObsPlan::compile(&spec, &space).unwrap();
2466 assert!(result.plan.is_standard());
2467 assert_eq!(result.output_len, 109); let center: Coord = smallvec::smallvec![5, 5];
2470 let mut output = vec![0.0f32; result.output_len];
2471 let mut mask = vec![0u8; result.mask_len];
2472 result
2473 .plan
2474 .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2475 .unwrap();
2476
2477 let expected: Vec<f32> = (0..100).map(|x| x as f32).collect();
2479 assert_eq!(&output[..100], &expected[..]);
2480 assert!(mask[..100].iter().all(|&v| v == 1));
2481
2482 assert_eq!(output[100 + 4], 55.0);
2485 }
2486
/// Executing with an agent centre whose rank differs from the space's
/// (a 1-D centre in a 2-D grid) returns an error that mentions dimensions.
#[test]
fn wrong_dimensionality_returns_error() {
    // 2-D space...
    let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
    let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
    let snap = snapshot_with_field(FieldId(0), data);

    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::AgentDisk { radius: 1 },
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let result = ObsPlan::compile(&spec, &space).unwrap();

    // ...but a 1-D agent centre.
    let bad_center: Coord = smallvec::smallvec![5];
    let mut output = vec![0.0f32; result.output_len];
    let mut mask = vec![0u8; result.mask_len];

    // `unwrap_err` both asserts failure and yields the error in one step.
    let err = result
        .plan
        .execute_agents(&snap, &space, &[bad_center], None, &mut output, &mut mask)
        .unwrap_err();
    let msg = err.to_string();
    assert!(
        msg.contains("dimensions"),
        "error should mention dimensions: {msg}"
    );
}
2519
/// On a Square4 grid an `AgentDisk` of radius 2 is a Manhattan disk. Its 5x5
/// bounding box is emitted in full, but only cells with `|dr| + |dc| <= 2`
/// are marked valid.
#[test]
fn agent_disk_square4_filters_corners() {
    // Large enough grid that the whole disk around (10, 10) is in bounds,
    // so masking comes only from the disk shape, not from the edges.
    let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
    let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
    let snap = snapshot_with_field(FieldId(0), data);

    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::AgentDisk { radius: 2 },
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let result = ObsPlan::compile(&spec, &space).unwrap();
    // Bounding box is (2 * 2 + 1)^2 = 25 cells.
    assert_eq!(result.output_len, 25);

    let center: Coord = smallvec::smallvec![10, 10];
    let mut output = vec![0.0f32; 25];
    let mut mask = vec![0u8; 25];
    result
        .plan
        .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
        .unwrap();

    // Rows of the diamond hold 1 + 3 + 5 + 3 + 1 = 13 cells.
    let valid_count = mask.iter().filter(|&&v| v == 1).count();
    assert_eq!(
        valid_count, 13,
        "Manhattan disk radius=2 should have 13 cells"
    );

    for &idx in &[0, 4, 20, 24] {
        assert_eq!(
            mask[idx], 0,
            "corner tensor_idx {idx} should be outside disk"
        );
    }

    // Exact pattern: every cell of the 5x5 box is valid iff it lies inside
    // the Manhattan disk. The predicate is symmetric in the two axes, so this
    // holds regardless of which axis the tensor index treats as major.
    for (idx, &m) in mask.iter().enumerate() {
        let dr = (idx / 5) as i32 - 2;
        let dc = (idx % 5) as i32 - 2;
        let expected = u8::from(dr.abs() + dc.abs() <= 2);
        assert_eq!(
            m, expected,
            "tensor_idx {idx} (offset {dr}, {dc}) has wrong disk membership"
        );
    }

    // Centre of the box (tensor index 12) is cell (10, 10) = 10 * 20 + 10.
    assert_eq!(output[12], 210.0);
    assert_eq!(mask[12], 1);
}
2578
/// An `AgentRect` is not disk-filtered: with half-extent 2 every cell of its
/// 5x5 window is valid when the window lies fully inside the grid.
#[test]
fn agent_rect_no_disk_filtering() {
    let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
    let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
    let snap = snapshot_with_field(FieldId(0), data);

    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::AgentRect {
                half_extent: smallvec::smallvec![2, 2],
            },
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let result = ObsPlan::compile(&spec, &space).unwrap();
    // Pin the plan size so the hard-coded buffer sizes below are justified.
    assert_eq!(result.output_len, 25);

    let center: Coord = smallvec::smallvec![10, 10];
    let mut output = vec![0.0f32; 25];
    let mut mask = vec![0u8; 25];
    result
        .plan
        .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
        .unwrap();

    // All 25 cells are valid, including the corners a disk would drop.
    assert!(mask.iter().all(|&v| v == 1));
    // Centre of the window (tensor index 12) is cell (10, 10) = 10 * 20 + 10.
    assert_eq!(output[12], 210.0);
}
2610
/// On a Square8 grid a radius-1 `AgentDisk` covers the whole 3x3
/// neighbourhood, so none of its bounding-box cells is masked out.
#[test]
fn agent_disk_square8_chebyshev() {
    let grid = Square8::new(10, 10, EdgeBehavior::Absorb).unwrap();
    let values: Vec<f32> = (0..100).map(|x| x as f32).collect();
    let snapshot = snapshot_with_field(FieldId(0), values);

    let disk_entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentDisk { radius: 1 },
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![disk_entry],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();
    assert_eq!(compiled.output_len, 9);

    let agent: Coord = smallvec::smallvec![5, 5];
    let mut obs = vec![0.0f32; 9];
    let mut valid = vec![0u8; 9];
    compiled
        .plan
        .execute_agents(&snapshot, &grid, &[agent], None, &mut obs, &mut valid)
        .unwrap();

    // Count valid cells in the 3x3 box; all of them should be inside the disk.
    let in_disk = valid.iter().copied().filter(|&flag| flag == 1).count();
    assert_eq!(in_disk, 9, "Chebyshev disk radius=1 = full 3x3");
}
2643
/// A `Normalize` transform whose `min` exceeds its `max` is rejected by
/// `ObsPlan::compile` with `ObsError::InvalidObsSpec`.
#[test]
fn compile_rejects_inverted_normalize_range() {
    let space = square4_space();
    let inverted = ObsTransform::Normalize {
        min: 10.0,
        max: 5.0,
    };
    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: inverted,
            dtype: ObsDtype::F32,
        }],
    };
    let outcome = ObsPlan::compile(&spec, &space);
    assert!(matches!(outcome, Err(ObsError::InvalidObsSpec { .. })));
}
2662
/// A NaN bound in a `Normalize` transform makes `ObsPlan::compile` fail.
#[test]
fn compile_rejects_nan_normalize() {
    let space = square4_space();
    let nan_entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Normalize {
            min: f64::NAN,
            max: 1.0,
        },
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![nan_entry],
    };
    let compiled = ObsPlan::compile(&spec, &space);
    assert!(compiled.is_err());
}
2680
/// When a later snapshot in a batch carries a world generation the bound plan
/// rejects, `execute_batch` returns `ObsError::PlanInvalidated`. The output
/// slot of the earlier, accepted snapshot has still been written.
#[test]
fn execute_batch_partial_write_on_generation_mismatch() {
    let space = square4_space();
    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    // Plan bound to generation 1.
    let result = ObsPlan::compile_bound(&spec, &space, WorldGenerationId(1)).unwrap();

    // First snapshot comes from the shared helper and is expected to be
    // accepted (its output is checked below). The second uses generation 99.
    let good = snapshot_with_field(FieldId(0), vec![7.0; 9]);
    let mut stale = MockSnapshot::new(TickId(5), WorldGenerationId(99), ParameterVersion(0));
    stale.set_field(FieldId(0), vec![0.0; 9]);

    let batch: Vec<&dyn SnapshotAccess> = vec![&good, &stale];
    let slots = 2;
    let mut output = vec![0.0f32; result.output_len * slots];
    let mut mask = vec![0u8; result.mask_len * slots];

    let err = result
        .plan
        .execute_batch(&batch, None, &mut output, &mut mask)
        .unwrap_err();
    assert!(matches!(err, ObsError::PlanInvalidated { .. }));

    // The first slot keeps the values written before the error surfaced.
    let first_slot = &output[..9];
    assert!(
        first_slot.iter().all(|&v| v == 7.0),
        "first snapshot should be written despite batch error"
    );
}
2719}