1use indexmap::IndexMap;
11
12use murk_core::error::ObsError;
13use murk_core::{Coord, FieldId, SnapshotAccess, TickId, WorldGenerationId};
14use murk_space::Space;
15
16use crate::geometry::GridGeometry;
17use crate::metadata::ObsMetadata;
18use crate::pool::pool_2d;
19use crate::spec::{ObsDtype, ObsRegion, ObsSpec, ObsTransform, PoolConfig};
20
/// Valid-ratio level below which compiling a fixed region may print a
/// non-fatal coverage warning to stderr.
const COVERAGE_WARN_THRESHOLD: f64 = 0.5;

/// Valid-ratio level below which a fixed region is rejected at compile time
/// with `ObsError::InvalidComposition`.
const COVERAGE_ERROR_THRESHOLD: f64 = 0.35;
26
/// Result of compiling an observation spec: the executable plan plus the
/// buffer layout callers need to allocate output and mask storage.
#[derive(Debug)]
pub struct ObsPlanResult {
    /// The compiled plan.
    pub plan: ObsPlan,
    /// Number of `f32` values one observation writes (sum of all entries'
    /// element counts).
    pub output_len: usize,
    /// Output shape of each entry, in spec order (post-pooling for pooled
    /// agent entries).
    pub entry_shapes: Vec<Vec<usize>>,
    /// Number of mask bytes one observation writes.
    pub mask_len: usize,
}
39
/// A compiled observation plan.
///
/// Simple plans (fixed regions only) run via `execute` / `execute_batch`;
/// Standard plans (containing agent-relative entries) run via
/// `execute_agents`.
#[derive(Debug)]
pub struct ObsPlan {
    strategy: PlanStrategy,
    /// Number of `f32` values written per observation.
    output_len: usize,
    /// Number of mask bytes written per observation.
    mask_len: usize,
    /// Set by `compile_bound`; when present, execution rejects snapshots
    /// whose world generation differs.
    compiled_generation: Option<WorldGenerationId>,
}
57
/// One precomputed copy from field storage into an entry's output slice.
#[derive(Debug, Clone)]
struct GatherOp {
    /// Index into the field's data (the coordinate's canonical-ordering rank).
    field_data_idx: usize,
    /// Index into the entry's output/mask slice.
    tensor_idx: usize,
}
69
/// A compiled fixed-region entry: a static gather list plus a static mask.
#[derive(Debug)]
struct CompiledEntry {
    field_id: FieldId,
    transform: ObsTransform,
    // Not read by the execution paths in this file (they always write f32).
    #[allow(dead_code)]
    dtype: ObsDtype,
    /// Start of this entry within one observation's output buffer.
    output_offset: usize,
    /// Start of this entry within one observation's mask buffer.
    mask_offset: usize,
    /// Number of elements in the region's bounding shape.
    element_count: usize,
    /// Field-to-tensor copies for every valid cell of the region.
    gather_ops: Vec<GatherOp>,
    /// Per-element validity from the compiled region (copied verbatim into
    /// the mask at execution time).
    valid_mask: Vec<u8>,
    // Recorded at compile time; not read by the execution paths in this file.
    #[allow(dead_code)]
    valid_ratio: f64,
}
91
/// One cell of an agent-relative window, precomputed at compile time.
#[derive(Debug, Clone)]
struct TemplateOp {
    /// Offset of this cell from the agent center, one component per dimension.
    relative: Coord,
    /// Row-major flat index of this cell within the (pre-pool) window.
    tensor_idx: usize,
    /// Sum of `relative[d] * coord_strides[d]`; added to the center's
    /// canonical rank on the interior fast path. Zero when no grid geometry
    /// is available.
    stride_offset: isize,
    /// `false` for cells outside the disk radius; such cells are never read
    /// and stay masked out.
    in_disk: bool,
}
110
/// A compiled agent-relative entry (`AgentDisk` / `AgentRect`), resolved
/// against each agent center at execution time.
#[derive(Debug)]
struct AgentCompiledEntry {
    field_id: FieldId,
    /// Optional 2D pooling applied to the gathered window.
    pool: Option<PoolConfig>,
    transform: ObsTransform,
    // Not read by the execution paths in this file (they always write f32).
    #[allow(dead_code)]
    dtype: ObsDtype,
    /// Start of this entry within one agent's output block.
    output_offset: usize,
    /// Start of this entry within one agent's mask block.
    mask_offset: usize,
    /// Elements written to the output (post-pooling when pooled).
    element_count: usize,
    /// Elements in the gathered window before pooling.
    pre_pool_element_count: usize,
    /// Window shape before pooling: `2 * half_extent + 1` per dimension.
    pre_pool_shape: Vec<usize>,
    /// One op per window cell, in row-major order.
    template_ops: Vec<TemplateOp>,
    /// Largest half-extent; passed to `GridGeometry::is_interior` to decide
    /// whether the fast path applies.
    radius: u32,
}
138
/// Data for a Standard plan: fixed entries plus agent-relative entries.
#[derive(Debug)]
struct StandardPlanData {
    /// Fixed-region entries (identical for every agent).
    fixed_entries: Vec<CompiledEntry>,
    /// Agent-relative entries, resolved per agent center.
    agent_entries: Vec<AgentCompiledEntry>,
    /// Result of `GridGeometry::from_space`; when `None`, cell lookups go
    /// through `Space::canonical_rank`.
    geometry: Option<GridGeometry>,
}
149
/// Execution strategy selected at compile time.
#[derive(Debug)]
enum PlanStrategy {
    /// Spec has only fixed regions; executed per snapshot.
    Simple(Vec<CompiledEntry>),
    /// Spec has at least one agent-relative region; executed per agent.
    Standard(StandardPlanData),
}
158
159impl ObsPlan {
160 pub fn compile(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
167 if spec.entries.is_empty() {
168 return Err(ObsError::InvalidObsSpec {
169 reason: "ObsSpec has no entries".into(),
170 });
171 }
172
173 let has_agent = spec.entries.iter().any(|e| {
174 matches!(
175 e.region,
176 ObsRegion::AgentDisk { .. } | ObsRegion::AgentRect { .. }
177 )
178 });
179
180 if has_agent {
181 Self::compile_standard(spec, space)
182 } else {
183 Self::compile_simple(spec, space)
184 }
185 }
186
187 fn compile_simple(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
189 let canonical = space.canonical_ordering();
190 let coord_to_field_idx: IndexMap<Coord, usize> = canonical
191 .into_iter()
192 .enumerate()
193 .map(|(idx, coord)| (coord, idx))
194 .collect();
195
196 let mut entries = Vec::with_capacity(spec.entries.len());
197 let mut output_offset = 0usize;
198 let mut mask_offset = 0usize;
199 let mut entry_shapes = Vec::with_capacity(spec.entries.len());
200
201 for (i, entry) in spec.entries.iter().enumerate() {
202 let fixed_region = match &entry.region {
203 ObsRegion::Fixed(spec) => spec,
204 ObsRegion::AgentDisk { .. } | ObsRegion::AgentRect { .. } => {
205 return Err(ObsError::InvalidObsSpec {
206 reason: format!("entry {i}: agent-relative region in Simple plan"),
207 });
208 }
209 };
210 if entry.pool.is_some() {
211 return Err(ObsError::InvalidObsSpec {
212 reason: format!(
213 "entry {i}: pooling requires a Standard plan (use agent-relative region)"
214 ),
215 });
216 }
217
218 let region_plan =
219 space
220 .compile_region(fixed_region)
221 .map_err(|e| ObsError::InvalidObsSpec {
222 reason: format!("entry {i}: region compile failed: {e}"),
223 })?;
224
225 let ratio = region_plan.valid_ratio();
226 if ratio < COVERAGE_ERROR_THRESHOLD {
227 return Err(ObsError::InvalidComposition {
228 reason: format!(
229 "entry {i}: valid_ratio {ratio:.3} < {COVERAGE_ERROR_THRESHOLD}"
230 ),
231 });
232 }
233 if ratio < COVERAGE_WARN_THRESHOLD {
234 eprintln!(
235 "murk-obs: warning: entry {i} valid_ratio {ratio:.3} < {COVERAGE_WARN_THRESHOLD}"
236 );
237 }
238
239 let mut gather_ops = Vec::with_capacity(region_plan.coords.len());
240 for (coord_idx, coord) in region_plan.coords.iter().enumerate() {
241 let field_data_idx =
242 *coord_to_field_idx
243 .get(coord)
244 .ok_or_else(|| ObsError::InvalidObsSpec {
245 reason: format!("entry {i}: coord {coord:?} not in canonical ordering"),
246 })?;
247 let tensor_idx = region_plan.tensor_indices[coord_idx];
248 gather_ops.push(GatherOp {
249 field_data_idx,
250 tensor_idx,
251 });
252 }
253
254 let element_count = region_plan.bounding_shape.total_elements();
255 let shape = match ®ion_plan.bounding_shape {
256 murk_space::BoundingShape::Rect(dims) => dims.clone(),
257 };
258 entry_shapes.push(shape);
259
260 entries.push(CompiledEntry {
261 field_id: entry.field_id,
262 transform: entry.transform.clone(),
263 dtype: entry.dtype,
264 output_offset,
265 mask_offset,
266 element_count,
267 gather_ops,
268 valid_mask: region_plan.valid_mask,
269 valid_ratio: ratio,
270 });
271
272 output_offset += element_count;
273 mask_offset += element_count;
274 }
275
276 let plan = ObsPlan {
277 strategy: PlanStrategy::Simple(entries),
278 output_len: output_offset,
279 mask_len: mask_offset,
280 compiled_generation: None,
281 };
282
283 Ok(ObsPlanResult {
284 output_len: plan.output_len,
285 mask_len: plan.mask_len,
286 entry_shapes,
287 plan,
288 })
289 }
290
291 fn compile_standard(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
296 let canonical = space.canonical_ordering();
297 let coord_to_field_idx: IndexMap<Coord, usize> = canonical
298 .into_iter()
299 .enumerate()
300 .map(|(idx, coord)| (coord, idx))
301 .collect();
302
303 let geometry = GridGeometry::from_space(space);
304 let ndim = space.ndim();
305
306 let mut fixed_entries = Vec::new();
307 let mut agent_entries = Vec::new();
308 let mut output_offset = 0usize;
309 let mut mask_offset = 0usize;
310 let mut entry_shapes = Vec::new();
311
312 for (i, entry) in spec.entries.iter().enumerate() {
313 match &entry.region {
314 ObsRegion::Fixed(region_spec) => {
315 if entry.pool.is_some() {
316 return Err(ObsError::InvalidObsSpec {
317 reason: format!("entry {i}: pooling on Fixed regions not supported"),
318 });
319 }
320
321 let region_plan = space.compile_region(region_spec).map_err(|e| {
322 ObsError::InvalidObsSpec {
323 reason: format!("entry {i}: region compile failed: {e}"),
324 }
325 })?;
326
327 let ratio = region_plan.valid_ratio();
328 if ratio < COVERAGE_ERROR_THRESHOLD {
329 return Err(ObsError::InvalidComposition {
330 reason: format!(
331 "entry {i}: valid_ratio {ratio:.3} < {COVERAGE_ERROR_THRESHOLD}"
332 ),
333 });
334 }
335
336 let mut gather_ops = Vec::with_capacity(region_plan.coords.len());
337 for (coord_idx, coord) in region_plan.coords.iter().enumerate() {
338 let field_data_idx = *coord_to_field_idx.get(coord).ok_or_else(|| {
339 ObsError::InvalidObsSpec {
340 reason: format!(
341 "entry {i}: coord {coord:?} not in canonical ordering"
342 ),
343 }
344 })?;
345 let tensor_idx = region_plan.tensor_indices[coord_idx];
346 gather_ops.push(GatherOp {
347 field_data_idx,
348 tensor_idx,
349 });
350 }
351
352 let element_count = region_plan.bounding_shape.total_elements();
353 let shape = match ®ion_plan.bounding_shape {
354 murk_space::BoundingShape::Rect(dims) => dims.clone(),
355 };
356 entry_shapes.push(shape);
357
358 fixed_entries.push(CompiledEntry {
359 field_id: entry.field_id,
360 transform: entry.transform.clone(),
361 dtype: entry.dtype,
362 output_offset,
363 mask_offset,
364 element_count,
365 gather_ops,
366 valid_mask: region_plan.valid_mask,
367 valid_ratio: ratio,
368 });
369
370 output_offset += element_count;
371 mask_offset += element_count;
372 }
373
374 ObsRegion::AgentDisk { radius } => {
375 let half_ext: smallvec::SmallVec<[u32; 4]> =
376 (0..ndim).map(|_| *radius).collect();
377 let (ae, shape) = Self::compile_agent_entry(
378 i,
379 entry,
380 &half_ext,
381 *radius,
382 &geometry,
383 Some(*radius),
384 output_offset,
385 mask_offset,
386 )?;
387 entry_shapes.push(shape);
388 output_offset += ae.element_count;
389 mask_offset += ae.element_count;
390 agent_entries.push(ae);
391 }
392
393 ObsRegion::AgentRect { half_extent } => {
394 let radius = *half_extent.iter().max().unwrap_or(&0);
395 let (ae, shape) = Self::compile_agent_entry(
396 i,
397 entry,
398 half_extent,
399 radius,
400 &geometry,
401 None,
402 output_offset,
403 mask_offset,
404 )?;
405 entry_shapes.push(shape);
406 output_offset += ae.element_count;
407 mask_offset += ae.element_count;
408 agent_entries.push(ae);
409 }
410 }
411 }
412
413 let plan = ObsPlan {
414 strategy: PlanStrategy::Standard(StandardPlanData {
415 fixed_entries,
416 agent_entries,
417 geometry,
418 }),
419 output_len: output_offset,
420 mask_len: mask_offset,
421 compiled_generation: None,
422 };
423
424 Ok(ObsPlanResult {
425 output_len: plan.output_len,
426 mask_len: plan.mask_len,
427 entry_shapes,
428 plan,
429 })
430 }
431
432 #[allow(clippy::too_many_arguments)]
437 fn compile_agent_entry(
438 entry_idx: usize,
439 entry: &crate::spec::ObsEntry,
440 half_extent: &[u32],
441 radius: u32,
442 geometry: &Option<GridGeometry>,
443 disk_radius: Option<u32>,
444 output_offset: usize,
445 mask_offset: usize,
446 ) -> Result<(AgentCompiledEntry, Vec<usize>), ObsError> {
447 let pre_pool_shape: Vec<usize> =
448 half_extent.iter().map(|&he| 2 * he as usize + 1).collect();
449 let pre_pool_element_count: usize = pre_pool_shape.iter().product();
450
451 let template_ops = generate_template_ops(half_extent, geometry, disk_radius);
452
453 let (element_count, output_shape) = if let Some(pool) = &entry.pool {
454 if pre_pool_shape.len() != 2 {
455 return Err(ObsError::InvalidObsSpec {
456 reason: format!(
457 "entry {entry_idx}: pooling requires 2D region, got {}D",
458 pre_pool_shape.len()
459 ),
460 });
461 }
462 let h = pre_pool_shape[0];
463 let w = pre_pool_shape[1];
464 let ks = pool.kernel_size;
465 let stride = pool.stride;
466 if ks == 0 || stride == 0 {
467 return Err(ObsError::InvalidObsSpec {
468 reason: format!("entry {entry_idx}: pool kernel_size and stride must be > 0"),
469 });
470 }
471 let out_h = if h >= ks { (h - ks) / stride + 1 } else { 0 };
472 let out_w = if w >= ks { (w - ks) / stride + 1 } else { 0 };
473 if out_h == 0 || out_w == 0 {
474 return Err(ObsError::InvalidObsSpec {
475 reason: format!(
476 "entry {entry_idx}: pool produces empty output \
477 (region [{h},{w}], kernel_size {ks}, stride {stride})"
478 ),
479 });
480 }
481 (out_h * out_w, vec![out_h, out_w])
482 } else {
483 (pre_pool_element_count, pre_pool_shape.clone())
484 };
485
486 Ok((
487 AgentCompiledEntry {
488 field_id: entry.field_id,
489 pool: entry.pool.clone(),
490 transform: entry.transform.clone(),
491 dtype: entry.dtype,
492 output_offset,
493 mask_offset,
494 element_count,
495 pre_pool_element_count,
496 pre_pool_shape,
497 template_ops,
498 radius,
499 },
500 output_shape,
501 ))
502 }
503
504 pub fn compile_bound(
509 spec: &ObsSpec,
510 space: &dyn Space,
511 generation: WorldGenerationId,
512 ) -> Result<ObsPlanResult, ObsError> {
513 let mut result = Self::compile(spec, space)?;
514 result.plan.compiled_generation = Some(generation);
515 Ok(result)
516 }
517
518 pub fn output_len(&self) -> usize {
520 self.output_len
521 }
522
523 pub fn mask_len(&self) -> usize {
525 self.mask_len
526 }
527
528 pub fn compiled_generation(&self) -> Option<WorldGenerationId> {
530 self.compiled_generation
531 }
532
533 pub fn execute(
552 &self,
553 snapshot: &dyn SnapshotAccess,
554 engine_tick: Option<TickId>,
555 output: &mut [f32],
556 mask: &mut [u8],
557 ) -> Result<ObsMetadata, ObsError> {
558 let entries = match &self.strategy {
559 PlanStrategy::Simple(entries) => entries,
560 PlanStrategy::Standard(_) => {
561 return Err(ObsError::ExecutionFailed {
562 reason: "Standard plan requires execute_agents(), not execute()".into(),
563 });
564 }
565 };
566
567 if output.len() < self.output_len {
568 return Err(ObsError::ExecutionFailed {
569 reason: format!(
570 "output buffer too small: {} < {}",
571 output.len(),
572 self.output_len
573 ),
574 });
575 }
576 if mask.len() < self.mask_len {
577 return Err(ObsError::ExecutionFailed {
578 reason: format!("mask buffer too small: {} < {}", mask.len(), self.mask_len),
579 });
580 }
581
582 if let Some(compiled_gen) = self.compiled_generation {
584 let snapshot_gen = snapshot.world_generation_id();
585 if compiled_gen != snapshot_gen {
586 return Err(ObsError::PlanInvalidated {
587 reason: format!(
588 "plan compiled for generation {}, snapshot is generation {}",
589 compiled_gen.0, snapshot_gen.0
590 ),
591 });
592 }
593 }
594
595 let mut total_valid = 0usize;
596 let mut total_elements = 0usize;
597
598 for entry in entries {
599 let field_data =
600 snapshot
601 .read_field(entry.field_id)
602 .ok_or_else(|| ObsError::ExecutionFailed {
603 reason: format!("field {:?} not in snapshot", entry.field_id),
604 })?;
605
606 let out_slice =
607 &mut output[entry.output_offset..entry.output_offset + entry.element_count];
608 let mask_slice = &mut mask[entry.mask_offset..entry.mask_offset + entry.element_count];
609
610 out_slice.fill(0.0);
612 mask_slice.copy_from_slice(&entry.valid_mask);
613
614 for op in &entry.gather_ops {
616 let raw = *field_data.get(op.field_data_idx).ok_or_else(|| {
617 ObsError::ExecutionFailed {
618 reason: format!(
619 "field {:?} has {} elements but gather requires index {}",
620 entry.field_id,
621 field_data.len(),
622 op.field_data_idx,
623 ),
624 }
625 })?;
626 out_slice[op.tensor_idx] = apply_transform(raw, &entry.transform);
627 }
628
629 total_valid += entry.valid_mask.iter().filter(|&&v| v == 1).count();
630 total_elements += entry.element_count;
631 }
632
633 let coverage = if total_elements == 0 {
634 0.0
635 } else {
636 total_valid as f64 / total_elements as f64
637 };
638
639 let age_ticks = match engine_tick {
640 Some(tick) => tick.0.saturating_sub(snapshot.tick_id().0),
641 None => 0,
642 };
643
644 Ok(ObsMetadata {
645 tick_id: snapshot.tick_id(),
646 age_ticks,
647 coverage,
648 world_generation_id: snapshot.world_generation_id(),
649 parameter_version: snapshot.parameter_version(),
650 })
651 }
652
653 pub fn execute_batch(
661 &self,
662 snapshots: &[&dyn SnapshotAccess],
663 engine_tick: Option<TickId>,
664 output: &mut [f32],
665 mask: &mut [u8],
666 ) -> Result<Vec<ObsMetadata>, ObsError> {
667 if matches!(self.strategy, PlanStrategy::Standard(_)) {
669 return Err(ObsError::ExecutionFailed {
670 reason: "Standard plan requires execute_agents(), not execute_batch()".into(),
671 });
672 }
673
674 let batch_size = snapshots.len();
675 let expected_out = batch_size * self.output_len;
676 let expected_mask = batch_size * self.mask_len;
677
678 if output.len() < expected_out {
679 return Err(ObsError::ExecutionFailed {
680 reason: format!(
681 "batch output buffer too small: {} < {}",
682 output.len(),
683 expected_out
684 ),
685 });
686 }
687 if mask.len() < expected_mask {
688 return Err(ObsError::ExecutionFailed {
689 reason: format!(
690 "batch mask buffer too small: {} < {}",
691 mask.len(),
692 expected_mask
693 ),
694 });
695 }
696
697 let mut metadata = Vec::with_capacity(batch_size);
698 for (i, snap) in snapshots.iter().enumerate() {
699 let out_start = i * self.output_len;
700 let mask_start = i * self.mask_len;
701 let out_slice = &mut output[out_start..out_start + self.output_len];
702 let mask_slice = &mut mask[mask_start..mask_start + self.mask_len];
703 let meta = self.execute(*snap, engine_tick, out_slice, mask_slice)?;
704 metadata.push(meta);
705 }
706 Ok(metadata)
707 }
708
709 pub fn execute_agents(
720 &self,
721 snapshot: &dyn SnapshotAccess,
722 space: &dyn Space,
723 agent_centers: &[Coord],
724 engine_tick: Option<TickId>,
725 output: &mut [f32],
726 mask: &mut [u8],
727 ) -> Result<Vec<ObsMetadata>, ObsError> {
728 let standard = match &self.strategy {
729 PlanStrategy::Standard(data) => data,
730 PlanStrategy::Simple(_) => {
731 return Err(ObsError::ExecutionFailed {
732 reason: "execute_agents requires a Standard plan \
733 (spec must contain agent-relative entries)"
734 .into(),
735 });
736 }
737 };
738
739 let n_agents = agent_centers.len();
740 let expected_out = n_agents * self.output_len;
741 let expected_mask = n_agents * self.mask_len;
742
743 if output.len() < expected_out {
744 return Err(ObsError::ExecutionFailed {
745 reason: format!(
746 "output buffer too small: {} < {}",
747 output.len(),
748 expected_out
749 ),
750 });
751 }
752 if mask.len() < expected_mask {
753 return Err(ObsError::ExecutionFailed {
754 reason: format!("mask buffer too small: {} < {}", mask.len(), expected_mask),
755 });
756 }
757
758 let expected_ndim = space.ndim();
760 for (i, center) in agent_centers.iter().enumerate() {
761 if center.len() != expected_ndim {
762 return Err(ObsError::ExecutionFailed {
763 reason: format!(
764 "agent_centers[{i}] has {} dimensions, but space requires {expected_ndim}",
765 center.len()
766 ),
767 });
768 }
769 }
770
771 if let Some(compiled_gen) = self.compiled_generation {
773 let snapshot_gen = snapshot.world_generation_id();
774 if compiled_gen != snapshot_gen {
775 return Err(ObsError::PlanInvalidated {
776 reason: format!(
777 "plan compiled for generation {}, snapshot is generation {}",
778 compiled_gen.0, snapshot_gen.0
779 ),
780 });
781 }
782 }
783
784 let mut field_data_map: IndexMap<FieldId, &[f32]> = IndexMap::new();
786 for entry in &standard.fixed_entries {
787 if !field_data_map.contains_key(&entry.field_id) {
788 let data = snapshot.read_field(entry.field_id).ok_or_else(|| {
789 ObsError::ExecutionFailed {
790 reason: format!("field {:?} not in snapshot", entry.field_id),
791 }
792 })?;
793 field_data_map.insert(entry.field_id, data);
794 }
795 }
796 for entry in &standard.agent_entries {
797 if !field_data_map.contains_key(&entry.field_id) {
798 let data = snapshot.read_field(entry.field_id).ok_or_else(|| {
799 ObsError::ExecutionFailed {
800 reason: format!("field {:?} not in snapshot", entry.field_id),
801 }
802 })?;
803 field_data_map.insert(entry.field_id, data);
804 }
805 }
806
807 let mut metadata = Vec::with_capacity(n_agents);
808
809 for (agent_i, center) in agent_centers.iter().enumerate() {
810 let out_start = agent_i * self.output_len;
811 let mask_start = agent_i * self.mask_len;
812 let agent_output = &mut output[out_start..out_start + self.output_len];
813 let agent_mask = &mut mask[mask_start..mask_start + self.mask_len];
814
815 agent_output.fill(0.0);
816 agent_mask.fill(0);
817
818 let mut total_valid = 0usize;
819 let mut total_elements = 0usize;
820
821 for entry in &standard.fixed_entries {
823 let field_data = field_data_map[&entry.field_id];
824 let out_slice = &mut agent_output
825 [entry.output_offset..entry.output_offset + entry.element_count];
826 let mask_slice =
827 &mut agent_mask[entry.mask_offset..entry.mask_offset + entry.element_count];
828
829 mask_slice.copy_from_slice(&entry.valid_mask);
830 for op in &entry.gather_ops {
831 let raw = *field_data.get(op.field_data_idx).ok_or_else(|| {
832 ObsError::ExecutionFailed {
833 reason: format!(
834 "field {:?} has {} elements but gather requires index {}",
835 entry.field_id,
836 field_data.len(),
837 op.field_data_idx,
838 ),
839 }
840 })?;
841 out_slice[op.tensor_idx] = apply_transform(raw, &entry.transform);
842 }
843
844 total_valid += entry.valid_mask.iter().filter(|&&v| v == 1).count();
845 total_elements += entry.element_count;
846 }
847
848 for entry in &standard.agent_entries {
850 let field_data = field_data_map[&entry.field_id];
851
852 let use_fast_path = standard
856 .geometry
857 .as_ref()
858 .map(|geo| !geo.all_wrap && geo.is_interior(center, entry.radius))
859 .unwrap_or(false);
860
861 let valid = execute_agent_entry(
862 entry,
863 center,
864 field_data,
865 &standard.geometry,
866 space,
867 use_fast_path,
868 agent_output,
869 agent_mask,
870 );
871
872 total_valid += valid;
873 total_elements += entry.element_count;
874 }
875
876 let coverage = if total_elements == 0 {
877 0.0
878 } else {
879 total_valid as f64 / total_elements as f64
880 };
881
882 let age_ticks = match engine_tick {
883 Some(tick) => tick.0.saturating_sub(snapshot.tick_id().0),
884 None => 0,
885 };
886
887 metadata.push(ObsMetadata {
888 tick_id: snapshot.tick_id(),
889 age_ticks,
890 coverage,
891 world_generation_id: snapshot.world_generation_id(),
892 parameter_version: snapshot.parameter_version(),
893 });
894 }
895
896 Ok(metadata)
897 }
898
899 pub fn is_standard(&self) -> bool {
901 matches!(self.strategy, PlanStrategy::Standard(_))
902 }
903}
904
905#[allow(clippy::too_many_arguments)]
909fn execute_agent_entry(
910 entry: &AgentCompiledEntry,
911 center: &Coord,
912 field_data: &[f32],
913 geometry: &Option<GridGeometry>,
914 space: &dyn Space,
915 use_fast_path: bool,
916 agent_output: &mut [f32],
917 agent_mask: &mut [u8],
918) -> usize {
919 if entry.pool.is_some() {
920 execute_agent_entry_pooled(
921 entry,
922 center,
923 field_data,
924 geometry,
925 space,
926 use_fast_path,
927 agent_output,
928 agent_mask,
929 )
930 } else {
931 execute_agent_entry_direct(
932 entry,
933 center,
934 field_data,
935 geometry,
936 space,
937 use_fast_path,
938 agent_output,
939 agent_mask,
940 )
941 }
942}
943
944#[allow(clippy::too_many_arguments)]
946fn execute_agent_entry_direct(
947 entry: &AgentCompiledEntry,
948 center: &Coord,
949 field_data: &[f32],
950 geometry: &Option<GridGeometry>,
951 space: &dyn Space,
952 use_fast_path: bool,
953 agent_output: &mut [f32],
954 agent_mask: &mut [u8],
955) -> usize {
956 let out_slice =
957 &mut agent_output[entry.output_offset..entry.output_offset + entry.element_count];
958 let mask_slice = &mut agent_mask[entry.mask_offset..entry.mask_offset + entry.element_count];
959
960 if use_fast_path {
961 let geo = geometry.as_ref().unwrap();
963 let base_rank = geo.canonical_rank(center) as isize;
964 let mut valid = 0;
965 for op in &entry.template_ops {
966 if !op.in_disk {
967 continue;
968 }
969 let field_idx = (base_rank + op.stride_offset) as usize;
970 out_slice[op.tensor_idx] = apply_transform(field_data[field_idx], &entry.transform);
971 mask_slice[op.tensor_idx] = 1;
972 valid += 1;
973 }
974 valid
975 } else {
976 let mut valid = 0;
978 for op in &entry.template_ops {
979 if !op.in_disk {
980 continue;
981 }
982 let field_idx = resolve_field_index(center, &op.relative, geometry, space);
983 if let Some(idx) = field_idx {
984 if idx < field_data.len() {
985 out_slice[op.tensor_idx] = apply_transform(field_data[idx], &entry.transform);
986 mask_slice[op.tensor_idx] = 1;
987 valid += 1;
988 }
989 }
990 }
991 valid
992 }
993}
994
995#[allow(clippy::too_many_arguments)]
997fn execute_agent_entry_pooled(
998 entry: &AgentCompiledEntry,
999 center: &Coord,
1000 field_data: &[f32],
1001 geometry: &Option<GridGeometry>,
1002 space: &dyn Space,
1003 use_fast_path: bool,
1004 agent_output: &mut [f32],
1005 agent_mask: &mut [u8],
1006) -> usize {
1007 let mut scratch = vec![0.0f32; entry.pre_pool_element_count];
1008 let mut scratch_mask = vec![0u8; entry.pre_pool_element_count];
1009
1010 if use_fast_path {
1011 let geo = geometry.as_ref().unwrap();
1012 let base_rank = geo.canonical_rank(center) as isize;
1013 for op in &entry.template_ops {
1014 if !op.in_disk {
1015 continue;
1016 }
1017 let field_idx = (base_rank + op.stride_offset) as usize;
1018 scratch[op.tensor_idx] = field_data[field_idx];
1019 scratch_mask[op.tensor_idx] = 1;
1020 }
1021 } else {
1022 for op in &entry.template_ops {
1023 if !op.in_disk {
1024 continue;
1025 }
1026 let field_idx = resolve_field_index(center, &op.relative, geometry, space);
1027 if let Some(idx) = field_idx {
1028 if idx < field_data.len() {
1029 scratch[op.tensor_idx] = field_data[idx];
1030 scratch_mask[op.tensor_idx] = 1;
1031 }
1032 }
1033 }
1034 }
1035
1036 let pool_config = entry.pool.as_ref().unwrap();
1037 let (pooled, pooled_mask, _) =
1038 pool_2d(&scratch, &scratch_mask, &entry.pre_pool_shape, pool_config);
1039
1040 let out_slice =
1041 &mut agent_output[entry.output_offset..entry.output_offset + entry.element_count];
1042 let mask_slice = &mut agent_mask[entry.mask_offset..entry.mask_offset + entry.element_count];
1043
1044 let n = pooled.len().min(entry.element_count);
1045 for i in 0..n {
1046 out_slice[i] = apply_transform(pooled[i], &entry.transform);
1047 }
1048 mask_slice[..n].copy_from_slice(&pooled_mask[..n]);
1049
1050 pooled_mask[..n].iter().filter(|&&v| v == 1).count()
1051}
1052
1053fn generate_template_ops(
1065 half_extent: &[u32],
1066 geometry: &Option<GridGeometry>,
1067 disk_radius: Option<u32>,
1068) -> Vec<TemplateOp> {
1069 let ndim = half_extent.len();
1070 let shape: Vec<usize> = half_extent.iter().map(|&he| 2 * he as usize + 1).collect();
1071 let total: usize = shape.iter().product();
1072
1073 let strides = geometry.as_ref().map(|g| g.coord_strides.as_slice());
1074
1075 let mut ops = Vec::with_capacity(total);
1076
1077 for tensor_idx in 0..total {
1078 let mut relative = Coord::new();
1080 let mut remaining = tensor_idx;
1081 for d in (0..ndim).rev() {
1083 let coord = (remaining % shape[d]) as i32 - half_extent[d] as i32;
1084 relative.push(coord);
1085 remaining /= shape[d];
1086 }
1087 relative.reverse();
1088
1089 let stride_offset = strides
1090 .map(|s| {
1091 relative
1092 .iter()
1093 .zip(s.iter())
1094 .map(|(&r, &s)| r as isize * s as isize)
1095 .sum::<isize>()
1096 })
1097 .unwrap_or(0);
1098
1099 let in_disk = match disk_radius {
1100 Some(r) => match geometry {
1101 Some(geo) => geo.graph_distance(&relative) <= r,
1102 None => true, },
1104 None => true, };
1106
1107 ops.push(TemplateOp {
1108 relative,
1109 tensor_idx,
1110 stride_offset,
1111 in_disk,
1112 });
1113 }
1114
1115 ops
1116}
1117
1118fn resolve_field_index(
1125 center: &Coord,
1126 relative: &Coord,
1127 geometry: &Option<GridGeometry>,
1128 space: &dyn Space,
1129) -> Option<usize> {
1130 if let Some(geo) = geometry {
1131 if geo.all_wrap {
1132 let wrapped: Coord = center
1134 .iter()
1135 .zip(relative.iter())
1136 .zip(geo.coord_dims.iter())
1137 .map(|((&c, &r), &d)| {
1138 let d = d as i32;
1139 ((c + r) % d + d) % d
1140 })
1141 .collect();
1142 Some(geo.canonical_rank(&wrapped))
1143 } else {
1144 let abs_coord: Coord = center
1145 .iter()
1146 .zip(relative.iter())
1147 .map(|(&c, &r)| c + r)
1148 .collect();
1149 let abs_slice: &[i32] = &abs_coord;
1150 if geo.in_bounds(abs_slice) {
1151 Some(geo.canonical_rank(abs_slice))
1152 } else {
1153 None
1154 }
1155 }
1156 } else {
1157 let abs_coord: Coord = center
1158 .iter()
1159 .zip(relative.iter())
1160 .map(|(&c, &r)| c + r)
1161 .collect();
1162 space.canonical_rank(&abs_coord)
1163 }
1164}
1165
1166fn apply_transform(raw: f32, transform: &ObsTransform) -> f32 {
1168 match transform {
1169 ObsTransform::Identity => raw,
1170 ObsTransform::Normalize { min, max } => {
1171 let range = max - min;
1172 if range == 0.0 {
1173 0.0
1174 } else {
1175 let normalized = (raw as f64 - min) / range;
1176 normalized.clamp(0.0, 1.0) as f32
1177 }
1178 }
1179 }
1180}
1181
1182#[cfg(test)]
1183mod tests {
1184 use super::*;
1185 use crate::spec::{
1186 ObsDtype, ObsEntry, ObsRegion, ObsSpec, ObsTransform, PoolConfig, PoolKernel,
1187 };
1188 use murk_core::{FieldId, ParameterVersion, TickId, WorldGenerationId};
1189 use murk_space::{EdgeBehavior, Hex2D, RegionSpec, Square4, Square8};
1190 use murk_test_utils::MockSnapshot;
1191
1192 fn square4_space() -> Square4 {
1193 Square4::new(3, 3, EdgeBehavior::Absorb).unwrap()
1194 }
1195
1196 fn snapshot_with_field(field: FieldId, data: Vec<f32>) -> MockSnapshot {
1197 let mut snap = MockSnapshot::new(TickId(5), WorldGenerationId(1), ParameterVersion(0));
1198 snap.set_field(field, data);
1199 snap
1200 }
1201
1202 #[test]
1205 fn compile_empty_spec_errors() {
1206 let space = square4_space();
1207 let spec = ObsSpec { entries: vec![] };
1208 let err = ObsPlan::compile(&spec, &space).unwrap_err();
1209 assert!(matches!(err, ObsError::InvalidObsSpec { .. }));
1210 }
1211
1212 #[test]
1213 fn compile_all_region_square4() {
1214 let space = square4_space();
1215 let spec = ObsSpec {
1216 entries: vec![ObsEntry {
1217 field_id: FieldId(0),
1218 region: ObsRegion::Fixed(RegionSpec::All),
1219 pool: None,
1220 transform: ObsTransform::Identity,
1221 dtype: ObsDtype::F32,
1222 }],
1223 };
1224 let result = ObsPlan::compile(&spec, &space).unwrap();
1225 assert_eq!(result.output_len, 9); assert_eq!(result.mask_len, 9);
1227 assert_eq!(result.entry_shapes, vec![vec![3, 3]]);
1228 }
1229
1230 #[test]
1231 fn compile_rect_region() {
1232 let space = Square4::new(5, 5, EdgeBehavior::Absorb).unwrap();
1233 let spec = ObsSpec {
1234 entries: vec![ObsEntry {
1235 field_id: FieldId(0),
1236 region: ObsRegion::Fixed(RegionSpec::Rect {
1237 min: smallvec::smallvec![1, 1],
1238 max: smallvec::smallvec![2, 3],
1239 }),
1240 pool: None,
1241 transform: ObsTransform::Identity,
1242 dtype: ObsDtype::F32,
1243 }],
1244 };
1245 let result = ObsPlan::compile(&spec, &space).unwrap();
1246 assert_eq!(result.output_len, 6);
1248 assert_eq!(result.entry_shapes, vec![vec![2, 3]]);
1249 }
1250
1251 #[test]
1252 fn compile_two_entries_offsets() {
1253 let space = square4_space();
1254 let spec = ObsSpec {
1255 entries: vec![
1256 ObsEntry {
1257 field_id: FieldId(0),
1258 region: ObsRegion::Fixed(RegionSpec::All),
1259 pool: None,
1260 transform: ObsTransform::Identity,
1261 dtype: ObsDtype::F32,
1262 },
1263 ObsEntry {
1264 field_id: FieldId(1),
1265 region: ObsRegion::Fixed(RegionSpec::All),
1266 pool: None,
1267 transform: ObsTransform::Identity,
1268 dtype: ObsDtype::F32,
1269 },
1270 ],
1271 };
1272 let result = ObsPlan::compile(&spec, &space).unwrap();
1273 assert_eq!(result.output_len, 18); assert_eq!(result.mask_len, 18);
1275 }
1276
1277 #[test]
1278 fn compile_invalid_region_errors() {
1279 let space = square4_space();
1280 let spec = ObsSpec {
1281 entries: vec![ObsEntry {
1282 field_id: FieldId(0),
1283 region: ObsRegion::Fixed(RegionSpec::Coords(vec![smallvec::smallvec![99, 99]])),
1284 pool: None,
1285 transform: ObsTransform::Identity,
1286 dtype: ObsDtype::F32,
1287 }],
1288 };
1289 let err = ObsPlan::compile(&spec, &space).unwrap_err();
1290 assert!(matches!(err, ObsError::InvalidObsSpec { .. }));
1291 }
1292
1293 #[test]
1296 fn execute_identity_all_region() {
1297 let space = square4_space();
1298 let data: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1301 let snap = snapshot_with_field(FieldId(0), data);
1302
1303 let spec = ObsSpec {
1304 entries: vec![ObsEntry {
1305 field_id: FieldId(0),
1306 region: ObsRegion::Fixed(RegionSpec::All),
1307 pool: None,
1308 transform: ObsTransform::Identity,
1309 dtype: ObsDtype::F32,
1310 }],
1311 };
1312 let result = ObsPlan::compile(&spec, &space).unwrap();
1313
1314 let mut output = vec![0.0f32; result.output_len];
1315 let mut mask = vec![0u8; result.mask_len];
1316 let meta = result
1317 .plan
1318 .execute(&snap, None, &mut output, &mut mask)
1319 .unwrap();
1320
1321 let expected: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1323 assert_eq!(output, expected);
1324 assert_eq!(mask, vec![1u8; 9]);
1325 assert_eq!(meta.tick_id, TickId(5));
1326 assert_eq!(meta.coverage, 1.0);
1327 assert_eq!(meta.world_generation_id, WorldGenerationId(1));
1328 assert_eq!(meta.parameter_version, ParameterVersion(0));
1329 assert_eq!(meta.age_ticks, 0);
1330 }
1331
/// `Normalize { min: 0, max: 8 }` applied to the ramp `0..9` must map the
/// cell at index `i` to `i / 8`.
#[test]
fn execute_normalize_transform() {
    let grid = square4_space();
    let ramp: Vec<f32> = (0..9).map(|x| x as f32).collect();
    let snapshot = snapshot_with_field(FieldId(0), ramp);

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Normalize { min: 0.0, max: 8.0 },
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();

    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute(&snapshot, None, &mut obs, &mut valid)
        .unwrap();

    for (i, v) in obs.iter().copied().enumerate() {
        let expected = i as f32 / 8.0;
        assert!(
            (v - expected).abs() < 1e-6,
            "output[{i}] = {v}, expected {expected}"
        );
    }
}
1366
/// Values outside `[min, max]` must be clamped so normalized output stays in
/// `[0, 1]`, while in-range values still map linearly.
#[test]
fn execute_normalize_clamps_out_of_range() {
    let space = square4_space();
    // -20, -15, -10, -5, 0, 5, 10, 15, 20: four values below `min`, one at
    // `min`, one strictly inside, one at `max`, and two above `max`.
    let data: Vec<f32> = (-4..5).map(|x| x as f32 * 5.0).collect();
    let snap = snapshot_with_field(FieldId(0), data);

    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: ObsTransform::Normalize {
                min: 0.0,
                max: 10.0,
            },
            dtype: ObsDtype::F32,
        }],
    };
    let result = ObsPlan::compile(&spec, &space).unwrap();

    let mut output = vec![0.0f32; result.output_len];
    let mut mask = vec![0u8; result.mask_len];
    result
        .plan
        .execute(&snap, None, &mut output, &mut mask)
        .unwrap();

    for &v in &output {
        assert!((0.0..=1.0).contains(&v), "value {v} out of [0,1] range");
    }

    // Pin the exact mapping: the range check alone would also accept a
    // degenerate implementation (e.g. a constant output).
    let expected = [0.0f32, 0.0, 0.0, 0.0, 0.0, 0.5, 1.0, 1.0, 1.0];
    assert_eq!(output.len(), expected.len());
    for (i, (&v, &e)) in output.iter().zip(expected.iter()).enumerate() {
        assert!((v - e).abs() < 1e-6, "output[{i}] = {v}, expected {e}");
    }
}
1399
/// A degenerate `Normalize` range (`min == max`) must yield exactly `0.0` for
/// every cell (a NaN would fail the equality check).
#[test]
fn execute_normalize_zero_range() {
    let grid = square4_space();
    let snapshot = snapshot_with_field(FieldId(0), vec![5.0f32; 9]);

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Normalize { min: 5.0, max: 5.0 },
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();

    // Pre-fill with a sentinel so a plan that never writes would be caught.
    let mut obs = vec![-1.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute(&snapshot, None, &mut obs, &mut valid)
        .unwrap();

    assert!(obs.iter().all(|&value| value == 0.0));
}
1427
/// A fixed `Rect` from `[1, 1]` to `[2, 2]` on a 4x4 grid holding `1..=16`
/// gathers exactly the four inner cells `6, 7, 10, 11`, all valid.
#[test]
fn execute_rect_subregion_correct_values() {
    let grid = Square4::new(4, 4, EdgeBehavior::Absorb).unwrap();
    let values: Vec<f32> = (1..=16).map(|x| x as f32).collect();
    let snapshot = snapshot_with_field(FieldId(0), values);

    let inner = RegionSpec::Rect {
        min: smallvec::smallvec![1, 1],
        max: smallvec::smallvec![2, 2],
    };
    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(inner),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();
    assert_eq!(compiled.output_len, 4);

    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute(&snapshot, None, &mut obs, &mut valid)
        .unwrap();

    assert_eq!(obs, vec![6.0, 7.0, 10.0, 11.0]);
    assert_eq!(valid, vec![1, 1, 1, 1]);
}
1461
/// Two whole-grid entries over different fields are laid out back to back:
/// field 0 fills `output[..9]` and field 1 fills `output[9..]`.
#[test]
fn execute_two_fields() {
    let grid = square4_space();
    let first_values: Vec<f32> = (1..=9).map(|x| x as f32).collect();
    let second_values: Vec<f32> = (10..=18).map(|x| x as f32).collect();

    let mut snapshot = MockSnapshot::new(TickId(1), WorldGenerationId(1), ParameterVersion(0));
    snapshot.set_field(FieldId(0), first_values.clone());
    snapshot.set_field(FieldId(1), second_values.clone());

    // Every entry in this spec observes a whole field unchanged.
    let whole_field = |id: FieldId| ObsEntry {
        field_id: id,
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![whole_field(FieldId(0)), whole_field(FieldId(1))],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();
    assert_eq!(compiled.output_len, 18);

    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute(&snapshot, None, &mut obs, &mut valid)
        .unwrap();

    let (head, tail) = obs.split_at(9);
    assert_eq!(head, &first_values);
    assert_eq!(tail, &second_values);
}
1505
/// Executing against a snapshot that lacks the requested field must fail
/// with `ObsError::ExecutionFailed`.
#[test]
fn execute_missing_field_errors() {
    let grid = square4_space();
    // Note: no field is ever set on this snapshot.
    let empty_snapshot = MockSnapshot::new(TickId(1), WorldGenerationId(1), ParameterVersion(0));

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();

    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    let exec_err = compiled
        .plan
        .execute(&empty_snapshot, None, &mut obs, &mut valid)
        .unwrap_err();
    assert!(matches!(exec_err, ObsError::ExecutionFailed { .. }));
}
1530
/// An output buffer shorter than `output_len` (4 slots for a 9-cell plan)
/// must be rejected with `ObsError::ExecutionFailed`.
#[test]
fn execute_buffer_too_small_errors() {
    let grid = square4_space();
    let snapshot = snapshot_with_field(FieldId(0), vec![0.0f32; 9]);

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();

    // Deliberately undersized output; the mask is correctly sized.
    let mut short_obs = vec![0.0f32; 4];
    let mut valid = vec![0u8; compiled.mask_len];
    let exec_err = compiled
        .plan
        .execute(&snapshot, None, &mut short_obs, &mut valid)
        .unwrap_err();
    assert!(matches!(exec_err, ObsError::ExecutionFailed { .. }));
}
1556
/// Observing the whole square grid leaves no padded cells, so the reported
/// coverage must be exactly `1.0`.
#[test]
fn valid_ratio_one_for_square_all() {
    let grid = square4_space();
    let snapshot = snapshot_with_field(FieldId(0), vec![1.0f32; 9]);

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();

    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    let meta = compiled
        .plan
        .execute(&snapshot, None, &mut obs, &mut valid)
        .unwrap();

    assert_eq!(meta.coverage, 1.0);
}
1585
/// A plan bound to generation 99 must refuse a snapshot from generation 1
/// with `ObsError::PlanInvalidated`.
#[test]
fn plan_invalidated_on_generation_mismatch() {
    let grid = square4_space();
    let snapshot = snapshot_with_field(FieldId(0), vec![1.0f32; 9]);

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };
    let stale_generation = WorldGenerationId(99);
    let compiled = ObsPlan::compile_bound(&spec, &grid, stale_generation).unwrap();

    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    let exec_err = compiled
        .plan
        .execute(&snapshot, None, &mut obs, &mut valid)
        .unwrap_err();
    assert!(matches!(exec_err, ObsError::PlanInvalidated { .. }));
}
1614
/// A plan bound to the same generation as the snapshot (1) executes cleanly.
#[test]
fn generation_match_succeeds() {
    let grid = square4_space();
    let snapshot = snapshot_with_field(FieldId(0), vec![1.0f32; 9]);

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };
    let compiled = ObsPlan::compile_bound(&spec, &grid, WorldGenerationId(1)).unwrap();
    let plan = &compiled.plan;

    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    plan.execute(&snapshot, None, &mut obs, &mut valid).unwrap();
}
1639
/// A plan compiled with `ObsPlan::compile` (no generation binding) must
/// execute against snapshots from any world generation.
///
/// Two generations are exercised. A generation-1 snapshot alone could not
/// distinguish an unbound plan from one accidentally bound to generation 1.
#[test]
fn unbound_plan_ignores_generation() {
    let space = square4_space();

    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let result = ObsPlan::compile(&spec, &space).unwrap();

    // `snapshot_with_field` snapshots report `WorldGenerationId(1)`.
    let snap_gen1 = snapshot_with_field(FieldId(0), vec![1.0; 9]);
    // A generation that no plan in this test is bound to.
    let mut snap_other = MockSnapshot::new(TickId(1), WorldGenerationId(12345), ParameterVersion(0));
    snap_other.set_field(FieldId(0), vec![1.0; 9]);

    let mut output = vec![0.0f32; result.output_len];
    let mut mask = vec![0u8; result.mask_len];
    result
        .plan
        .execute(&snap_gen1, None, &mut output, &mut mask)
        .unwrap();
    result
        .plan
        .execute(&snap_other, None, &mut output, &mut mask)
        .unwrap();
}
1665
/// The returned metadata echoes the snapshot's tick, generation and
/// parameter version, with zero age and full coverage.
#[test]
fn metadata_fields_populated() {
    let grid = square4_space();
    let mut snapshot = MockSnapshot::new(TickId(42), WorldGenerationId(7), ParameterVersion(3));
    snapshot.set_field(FieldId(0), vec![1.0f32; 9]);

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();

    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    let meta = compiled
        .plan
        .execute(&snapshot, None, &mut obs, &mut valid)
        .unwrap();

    assert_eq!(meta.tick_id, TickId(42));
    assert_eq!(meta.age_ticks, 0);
    assert_eq!(meta.coverage, 1.0);
    assert_eq!(meta.world_generation_id, WorldGenerationId(7));
    assert_eq!(meta.parameter_version, ParameterVersion(3));
}
1699
/// `execute_batch` over a single snapshot must produce exactly the same
/// observation, mask and metadata as a plain `execute` call.
#[test]
fn execute_batch_n1_matches_execute() {
    let grid = square4_space();
    let values: Vec<f32> = (1..=9).map(|x| x as f32).collect();
    let snapshot = snapshot_with_field(FieldId(0), values);

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();
    let plan = &compiled.plan;

    // Reference: the single-snapshot path.
    let mut single_obs = vec![0.0f32; compiled.output_len];
    let mut single_valid = vec![0u8; compiled.mask_len];
    let single_meta = plan
        .execute(&snapshot, None, &mut single_obs, &mut single_valid)
        .unwrap();

    // Same snapshot through the batch path with N = 1.
    let batch: [&dyn SnapshotAccess; 1] = [&snapshot];
    let mut batch_obs = vec![0.0f32; compiled.output_len];
    let mut batch_valid = vec![0u8; compiled.mask_len];
    let batch_metas = plan
        .execute_batch(&batch, None, &mut batch_obs, &mut batch_valid)
        .unwrap();

    assert_eq!(single_obs, batch_obs);
    assert_eq!(single_valid, batch_valid);
    assert_eq!(single_meta, batch_metas[0]);
}
1740
/// Batching two snapshots writes each one into its own `output_len`-sized
/// slot, in input order, and returns one metadata record per snapshot.
#[test]
fn execute_batch_multiple_snapshots() {
    let grid = square4_space();
    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();

    let ones = snapshot_with_field(FieldId(0), vec![1.0; 9]);
    let twos = snapshot_with_field(FieldId(0), vec![2.0; 9]);
    let batch: [&dyn SnapshotAccess; 2] = [&ones, &twos];

    let mut obs = vec![0.0f32; compiled.output_len * 2];
    let mut valid = vec![0u8; compiled.mask_len * 2];
    let metas = compiled
        .plan
        .execute_batch(&batch, None, &mut obs, &mut valid)
        .unwrap();

    assert_eq!(metas.len(), 2);
    assert!(obs[..9].iter().all(|&x| x == 1.0));
    assert!(obs[9..].iter().all(|&x| x == 2.0));
}
1770
/// A batch of two snapshots with an output buffer sized for only one must be
/// rejected with `ObsError::ExecutionFailed`.
#[test]
fn execute_batch_buffer_too_small() {
    let grid = square4_space();
    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();

    let snapshot = snapshot_with_field(FieldId(0), vec![1.0; 9]);
    let batch: [&dyn SnapshotAccess; 2] = [&snapshot, &snapshot];

    // Output holds one observation (9); the mask is sized for two (18).
    let mut short_obs = vec![0.0f32; 9];
    let mut valid = vec![0u8; 18];
    let exec_err = compiled
        .plan
        .execute_batch(&batch, None, &mut short_obs, &mut valid)
        .unwrap_err();
    assert!(matches!(exec_err, ObsError::ExecutionFailed { .. }));
}
1795
/// A snapshot whose field buffer is shorter than the plan expects (4 values
/// for a 9-cell plan) must produce `ObsError::ExecutionFailed` instead of an
/// out-of-bounds panic.
#[test]
fn short_field_buffer_returns_error_not_panic() {
    let grid = square4_space();
    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();

    let truncated = snapshot_with_field(FieldId(0), vec![1.0; 4]);
    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    let exec_err = compiled
        .plan
        .execute(&truncated, None, &mut obs, &mut valid)
        .unwrap_err();
    assert!(matches!(exec_err, ObsError::ExecutionFailed { .. }));
}
1822
/// An agent-relative region forces the standard plan strategy; a half-extent
/// of `[2, 2]` yields a single 5x5 entry (25 outputs).
#[test]
fn standard_plan_detected_from_agent_region() {
    let grid = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
    let window = ObsRegion::AgentRect {
        half_extent: smallvec::smallvec![2, 2],
    };
    let entry = ObsEntry {
        field_id: FieldId(0),
        region: window,
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };

    let compiled = ObsPlan::compile(&spec, &grid).unwrap();
    assert!(compiled.plan.is_standard());
    assert_eq!(compiled.output_len, 25);
    assert_eq!(compiled.entry_shapes, vec![vec![5, 5]]);
}
1845
/// Calling the fixed-only `execute` on a plan that contains agent-relative
/// entries must fail with `ObsError::ExecutionFailed`.
#[test]
fn execute_on_standard_plan_errors() {
    let grid = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
    let values: Vec<f32> = (0..100).map(|x| x as f32).collect();
    let snapshot = snapshot_with_field(FieldId(0), values);

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentDisk { radius: 2 },
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();

    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    let exec_err = compiled
        .plan
        .execute(&snapshot, None, &mut obs, &mut valid)
        .unwrap_err();
    assert!(matches!(exec_err, ObsError::ExecutionFailed { .. }));
}
1871
/// For an agent well inside an absorbing 20x20 grid, an `AgentRect` view with
/// half-extent `radius` must match a fixed `Rect` covering the same cells:
/// same length, same values, same mask.
#[test]
fn interior_boundary_equivalence() {
    let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
    let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
    let snap = snapshot_with_field(FieldId(0), data);

    let radius = 3u32;
    let center: Coord = smallvec::smallvec![10, 10];

    // Agent-relative view, executed through the standard (agent) path.
    let standard_spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::AgentRect {
                half_extent: smallvec::smallvec![radius, radius],
            },
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let std_result = ObsPlan::compile(&standard_spec, &space).unwrap();
    let mut std_output = vec![0.0f32; std_result.output_len];
    let mut std_mask = vec![0u8; std_result.mask_len];
    std_result
        .plan
        .execute_agents(
            &snap,
            &space,
            std::slice::from_ref(&center),
            None,
            &mut std_output,
            &mut std_mask,
        )
        .unwrap();

    // The equivalent fixed rectangle centred on the same cell, executed
    // through the simple path.
    let r = radius as i32;
    let simple_spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::Rect {
                min: smallvec::smallvec![10 - r, 10 - r],
                max: smallvec::smallvec![10 + r, 10 + r],
            }),
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let simple_result = ObsPlan::compile(&simple_spec, &space).unwrap();
    let mut simple_output = vec![0.0f32; simple_result.output_len];
    let mut simple_mask = vec![0u8; simple_result.mask_len];
    simple_result
        .plan
        .execute(&snap, None, &mut simple_output, &mut simple_mask)
        .unwrap();

    assert_eq!(std_result.output_len, simple_result.output_len);
    assert_eq!(std_output, simple_output);
    assert_eq!(std_mask, simple_mask);
}
1937
/// An agent at the origin of an absorbing 10x10 grid sees a 5x5 window in
/// which only the 3x3 in-bounds quadrant is valid. Padded cells are masked
/// and zero-filled, and coverage reports 9/25.
#[test]
fn boundary_agent_gets_padding() {
    let grid = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
    // Offset by one so a real cell can never be confused with zero padding.
    let values: Vec<f32> = (0..100).map(|x| x as f32 + 1.0).collect();
    let snapshot = snapshot_with_field(FieldId(0), values);

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentRect {
            half_extent: smallvec::smallvec![2, 2],
        },
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();

    let origin: Coord = smallvec::smallvec![0, 0];
    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    let metas = compiled
        .plan
        .execute_agents(&snapshot, &grid, &[origin], None, &mut obs, &mut valid)
        .unwrap();

    let in_bounds: usize = valid.iter().filter(|&&flag| flag == 1).count();
    assert_eq!(in_bounds, 9);

    assert!((metas[0].coverage - 9.0 / 25.0).abs() < 1e-6);

    // Tensor index 0 is the top-left corner of the window, off the grid.
    assert_eq!(valid[0], 0);
    assert_eq!(obs[0], 0.0);

    // Tensor index 12 is the window centre, i.e. the origin cell itself.
    assert_eq!(valid[12], 1);
    assert_eq!(obs[12], 1.0);
}
1983
/// On a 20x20 hex grid, an interior `AgentDisk` of radius 2 occupies a 5x5
/// tensor of which 19 cells are valid. The listed out-of-disk indices are
/// masked and zero-padded, and the centre reads the agent's own cell.
#[test]
fn hex_foveation_interior() {
    let grid = Hex2D::new(20, 20).unwrap();
    let values: Vec<f32> = (0..400).map(|x| x as f32).collect();
    let snapshot = snapshot_with_field(FieldId(0), values);

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentDisk { radius: 2 },
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();
    assert_eq!(compiled.output_len, 25);

    let center: Coord = smallvec::smallvec![10, 10];
    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute_agents(&snapshot, &grid, &[center], None, &mut obs, &mut valid)
        .unwrap();

    let in_disk = valid.iter().filter(|&&flag| flag == 1).count();
    assert_eq!(in_disk, 19);

    for idx in [0usize, 1, 5, 19, 23, 24] {
        assert_eq!(valid[idx], 0, "tensor_idx {idx} should be outside hex disk");
        assert_eq!(obs[idx], 0.0, "tensor_idx {idx} should be zero-padded");
    }

    assert_eq!(obs[12], 210.0);
    assert_eq!(obs[17], 211.0);
}
2039
/// With `EdgeBehavior::Wrap`, an agent at the origin still gets a fully
/// valid 5x5 window, because out-of-range offsets wrap around instead of
/// being padded.
#[test]
fn wrap_space_all_interior() {
    let grid = Square4::new(10, 10, EdgeBehavior::Wrap).unwrap();
    let values: Vec<f32> = (0..100).map(|x| x as f32).collect();
    let snapshot = snapshot_with_field(FieldId(0), values);

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentRect {
            half_extent: smallvec::smallvec![2, 2],
        },
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();

    let origin: Coord = smallvec::smallvec![0, 0];
    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute_agents(&snapshot, &grid, &[origin], None, &mut obs, &mut valid)
        .unwrap();

    assert!(valid.iter().all(|&flag| flag == 1));
    // The window centre is the origin cell, whose value is 0.
    assert_eq!(obs[12], 0.0);
}
2073
/// Executing several agents writes one `output_len`-sized slot per agent, in
/// order. An interior agent sees a fully valid 3x3 window; an agent on the
/// grid edge gets padding where its window leaves the grid.
#[test]
fn execute_agents_multiple() {
    let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
    let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
    let snap = snapshot_with_field(FieldId(0), data);

    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::AgentRect {
                half_extent: smallvec::smallvec![1, 1],
            },
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let result = ObsPlan::compile(&spec, &space).unwrap();
    assert_eq!(result.output_len, 9);

    // One interior agent, and one whose first coordinate sits on the edge.
    let centers: Vec<Coord> = vec![smallvec::smallvec![5, 5], smallvec::smallvec![0, 5]];
    let n = centers.len();
    let mut output = vec![0.0f32; result.output_len * n];
    let mut mask = vec![0u8; result.mask_len * n];
    let metas = result
        .plan
        .execute_agents(&snap, &space, &centers, None, &mut output, &mut mask)
        .unwrap();

    assert_eq!(metas.len(), 2);

    // Agent 0 is fully interior; its window centre (tensor index 4) is (5, 5).
    assert!(mask[..9].iter().all(|&v| v == 1));
    assert_eq!(output[4], 55.0);

    // Agent 1: the three window cells at first coordinate -1 are off-grid,
    // leaving 6 valid cells.
    let agent1_mask = &mask[9..18];
    let valid_count: usize = agent1_mask.iter().filter(|&&v| v == 1).count();
    assert_eq!(valid_count, 6);
}
2118
/// Transforms apply on the agent path too: with `Normalize { 0, 99 }` the
/// window centre at (5, 5) must read `55 / 99`.
#[test]
fn execute_agents_with_normalize() {
    let grid = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
    let values: Vec<f32> = (0..100).map(|x| x as f32).collect();
    let snapshot = snapshot_with_field(FieldId(0), values);

    let scale = ObsTransform::Normalize {
        min: 0.0,
        max: 99.0,
    };
    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentRect {
            half_extent: smallvec::smallvec![1, 1],
        },
        pool: None,
        transform: scale,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();

    let center: Coord = smallvec::smallvec![5, 5];
    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute_agents(&snapshot, &grid, &[center], None, &mut obs, &mut valid)
        .unwrap();

    let expected = 55.0 / 99.0;
    assert!((obs[4] - expected as f32).abs() < 1e-5);
}
2153
/// Mean-pooling a 7x7 agent window with a 2x2 kernel and stride 2 yields a
/// 3x3 output. The first pooled cell averages the window's top-left 2x2
/// block (147, 148, 167, 168), giving 157.5.
#[test]
fn execute_agents_with_pooling() {
    let grid = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
    let values: Vec<f32> = (0..400).map(|x| x as f32).collect();
    let snapshot = snapshot_with_field(FieldId(0), values);

    let pooling = PoolConfig {
        kernel: PoolKernel::Mean,
        kernel_size: 2,
        stride: 2,
    };
    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentRect {
            half_extent: smallvec::smallvec![3, 3],
        },
        pool: Some(pooling),
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();
    assert_eq!(compiled.output_len, 9);
    assert_eq!(compiled.entry_shapes, vec![vec![3, 3]]);

    let center: Coord = smallvec::smallvec![10, 10];
    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute_agents(&snapshot, &grid, &[center], None, &mut obs, &mut valid)
        .unwrap();

    assert!(valid.iter().all(|&flag| flag == 1));
    assert!((obs[0] - 157.5).abs() < 1e-4);
}
2199
/// A spec mixing a whole-grid fixed entry with an agent window compiles to a
/// standard plan. The 100 fixed cells come first, followed by the agent's 9.
#[test]
fn mixed_fixed_and_agent_entries() {
    let grid = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
    let values: Vec<f32> = (0..100).map(|x| x as f32).collect();
    let snapshot = snapshot_with_field(FieldId(0), values);

    // Global view of the whole field.
    let global = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    // Local 3x3 view around the agent.
    let local = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentRect {
            half_extent: smallvec::smallvec![1, 1],
        },
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![global, local],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();
    assert!(compiled.plan.is_standard());
    assert_eq!(compiled.output_len, 109);

    let center: Coord = smallvec::smallvec![5, 5];
    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute_agents(&snapshot, &grid, &[center], None, &mut obs, &mut valid)
        .unwrap();

    let whole_field: Vec<f32> = (0..100).map(|x| x as f32).collect();
    assert_eq!(&obs[..100], &whole_field[..]);
    assert!(valid[..100].iter().all(|&flag| flag == 1));

    // Centre of the agent window (offset 4 within the trailing 9 cells).
    assert_eq!(obs[100 + 4], 55.0);
}
2249
/// Passing a 1-D agent centre to a plan compiled for a 2-D grid must return
/// an error whose message mentions "dimensions".
#[test]
fn wrong_dimensionality_returns_error() {
    let grid = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
    let values: Vec<f32> = (0..100).map(|x| x as f32).collect();
    let snapshot = snapshot_with_field(FieldId(0), values);

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentDisk { radius: 1 },
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();

    let bad_center: Coord = smallvec::smallvec![5];
    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    let outcome = compiled.plan.execute_agents(
        &snapshot,
        &grid,
        &[bad_center],
        None,
        &mut obs,
        &mut valid,
    );
    assert!(outcome.is_err());

    let msg = outcome.unwrap_err().to_string();
    assert!(
        msg.contains("dimensions"),
        "error should mention dimensions: {msg}"
    );
}
2282
/// On a 4-connected square grid, `AgentDisk { radius: 2 }` keeps 13 of the
/// 25 cells in its 5x5 tensor valid. The four tensor corners are masked out
/// and the centre holds the agent's own cell.
#[test]
fn agent_disk_square4_filters_corners() {
    let grid = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
    let values: Vec<f32> = (0..400).map(|x| x as f32).collect();
    let snapshot = snapshot_with_field(FieldId(0), values);

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentDisk { radius: 2 },
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();
    assert_eq!(compiled.output_len, 25);

    let center: Coord = smallvec::smallvec![10, 10];
    let mut obs = vec![0.0f32; 25];
    let mut valid = vec![0u8; 25];
    compiled
        .plan
        .execute_agents(&snapshot, &grid, &[center], None, &mut obs, &mut valid)
        .unwrap();

    let in_disk = valid.iter().filter(|&&flag| flag == 1).count();
    assert_eq!(
        in_disk, 13,
        "Manhattan disk radius=2 should have 13 cells"
    );

    for idx in [0usize, 4, 20, 24] {
        assert_eq!(
            valid[idx], 0,
            "corner tensor_idx {idx} should be outside disk"
        );
    }

    assert_eq!(obs[12], 210.0);
    assert_eq!(valid[12], 1);
}
2341
/// Unlike `AgentDisk`, an interior `AgentRect` of half-extent 2 keeps every
/// cell of its 5x5 window valid; no corner filtering is applied.
#[test]
fn agent_rect_no_disk_filtering() {
    let grid = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
    let values: Vec<f32> = (0..400).map(|x| x as f32).collect();
    let snapshot = snapshot_with_field(FieldId(0), values);

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentRect {
            half_extent: smallvec::smallvec![2, 2],
        },
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();

    let center: Coord = smallvec::smallvec![10, 10];
    let mut obs = vec![0.0f32; 25];
    let mut valid = vec![0u8; 25];
    compiled
        .plan
        .execute_agents(&snapshot, &grid, &[center], None, &mut obs, &mut valid)
        .unwrap();

    assert!(valid.iter().all(|&flag| flag == 1));
}
2373
/// On an 8-connected square grid, `AgentDisk { radius: 1 }` covers the full
/// 3x3 neighbourhood, so all 9 tensor cells are valid.
#[test]
fn agent_disk_square8_chebyshev() {
    let grid = Square8::new(10, 10, EdgeBehavior::Absorb).unwrap();
    let values: Vec<f32> = (0..100).map(|x| x as f32).collect();
    let snapshot = snapshot_with_field(FieldId(0), values);

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentDisk { radius: 1 },
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![entry],
    };
    let compiled = ObsPlan::compile(&spec, &grid).unwrap();
    assert_eq!(compiled.output_len, 9);

    let center: Coord = smallvec::smallvec![5, 5];
    let mut obs = vec![0.0f32; 9];
    let mut valid = vec![0u8; 9];
    compiled
        .plan
        .execute_agents(&snapshot, &grid, &[center], None, &mut obs, &mut valid)
        .unwrap();

    let in_disk = valid.iter().filter(|&&flag| flag == 1).count();
    assert_eq!(in_disk, 9, "Chebyshev disk radius=1 = full 3x3");
}
2406}