1use indexmap::IndexMap;
11
12use murk_core::error::ObsError;
13use murk_core::{Coord, FieldId, SnapshotAccess, TickId, WorldGenerationId};
14use murk_space::Space;
15
16use crate::geometry::GridGeometry;
17use crate::metadata::ObsMetadata;
18use crate::pool::pool_2d_into;
19use crate::spec::{ObsDtype, ObsRegion, ObsSpec, ObsTransform, PoolConfig};
20
/// Soft coverage limit for fixed-region entries: a region whose
/// `valid_ratio()` falls below this value is reported with a warning on
/// stderr while the plan is compiled.
const COVERAGE_WARN_THRESHOLD: f64 = 0.5;

/// Hard coverage limit for fixed-region entries: a region whose
/// `valid_ratio()` falls below this value makes compilation fail with
/// `ObsError::InvalidComposition`.
const COVERAGE_ERROR_THRESHOLD: f64 = 0.35;
26
27#[derive(Debug)]
29pub struct ObsPlanResult {
30 pub plan: ObsPlan,
32 pub output_len: usize,
34 pub entry_shapes: Vec<Vec<usize>>,
36 pub mask_len: usize,
38}
39
40#[derive(Debug)]
48pub struct ObsPlan {
49 strategy: PlanStrategy,
50 output_len: usize,
52 mask_len: usize,
54 compiled_generation: Option<WorldGenerationId>,
56}
57
/// A single precomputed copy for a fixed region: read
/// `field_data[field_data_idx]` and store the transformed value at
/// `tensor_idx` inside the entry's output slice.
#[derive(Debug, Clone)]
struct GatherOp {
    /// Position in the field's canonical-ordering data.
    field_data_idx: usize,
    /// Destination position within the entry's output slice.
    tensor_idx: usize,
}
69
70#[derive(Debug)]
72struct CompiledEntry {
73 field_id: FieldId,
74 transform: ObsTransform,
75 #[allow(dead_code)]
76 dtype: ObsDtype,
77 output_offset: usize,
79 mask_offset: usize,
81 element_count: usize,
83 gather_ops: Vec<GatherOp>,
85 valid_mask: Vec<u8>,
87 valid_count: usize,
89 #[allow(dead_code)]
91 valid_ratio: f64,
92}
93
94#[derive(Debug, Clone)]
100struct TemplateOp {
101 relative: Coord,
103 tensor_idx: usize,
105 stride_offset: isize,
108 in_disk: bool,
111}
112
113#[derive(Debug)]
119struct AgentCompiledEntry {
120 field_id: FieldId,
121 pool: Option<PoolConfig>,
122 transform: ObsTransform,
123 #[allow(dead_code)]
124 dtype: ObsDtype,
125 output_offset: usize,
127 mask_offset: usize,
129 element_count: usize,
131 pre_pool_element_count: usize,
133 pre_pool_shape: Vec<usize>,
135 active_ops: Vec<TemplateOp>,
137 radius: u32,
139}
140
141#[derive(Debug)]
143struct StandardPlanData {
144 fixed_entries: Vec<CompiledEntry>,
146 agent_entries: Vec<AgentCompiledEntry>,
148 geometry: Option<GridGeometry>,
150}
151
152#[derive(Debug)]
154struct SimplePlanData {
155 entries: Vec<CompiledEntry>,
156 total_valid: usize,
157 total_elements: usize,
158}
159
160#[derive(Debug)]
162enum PlanStrategy {
163 Simple(SimplePlanData),
165 Standard(StandardPlanData),
167}
168
169impl ObsPlan {
170 pub fn compile(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
177 if spec.entries.is_empty() {
178 return Err(ObsError::InvalidObsSpec {
179 reason: "ObsSpec has no entries".into(),
180 });
181 }
182
183 for (i, entry) in spec.entries.iter().enumerate() {
185 if let ObsTransform::Normalize { min, max } = &entry.transform {
186 if !min.is_finite() || !max.is_finite() {
187 return Err(ObsError::InvalidObsSpec {
188 reason: format!(
189 "entry {i}: Normalize min/max must be finite, got min={min}, max={max}"
190 ),
191 });
192 }
193 if min > max {
194 return Err(ObsError::InvalidObsSpec {
195 reason: format!("entry {i}: Normalize min ({min}) must be <= max ({max})"),
196 });
197 }
198 }
199 }
200
201 let has_agent = spec.entries.iter().any(|e| {
202 matches!(
203 e.region,
204 ObsRegion::AgentDisk { .. } | ObsRegion::AgentRect { .. }
205 )
206 });
207
208 if has_agent {
209 Self::compile_standard(spec, space)
210 } else {
211 Self::compile_simple(spec, space)
212 }
213 }
214
215 fn compile_simple(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
217 let canonical = space.canonical_ordering();
218 let coord_to_field_idx: IndexMap<Coord, usize> = canonical
219 .into_iter()
220 .enumerate()
221 .map(|(idx, coord)| (coord, idx))
222 .collect();
223
224 let mut entries = Vec::with_capacity(spec.entries.len());
225 let mut total_valid = 0usize;
226 let mut total_elements = 0usize;
227 let mut output_offset = 0usize;
228 let mut mask_offset = 0usize;
229 let mut entry_shapes = Vec::with_capacity(spec.entries.len());
230
231 for (i, entry) in spec.entries.iter().enumerate() {
232 let fixed_region = match &entry.region {
233 ObsRegion::Fixed(spec) => spec,
234 ObsRegion::AgentDisk { .. } | ObsRegion::AgentRect { .. } => {
235 return Err(ObsError::InvalidObsSpec {
236 reason: format!("entry {i}: agent-relative region in Simple plan"),
237 });
238 }
239 };
240 if entry.pool.is_some() {
241 return Err(ObsError::InvalidObsSpec {
242 reason: format!(
243 "entry {i}: pooling requires a Standard plan (use agent-relative region)"
244 ),
245 });
246 }
247
248 let mut region_plan =
249 space
250 .compile_region(fixed_region)
251 .map_err(|e| ObsError::InvalidObsSpec {
252 reason: format!("entry {i}: region compile failed: {e}"),
253 })?;
254
255 let ratio = region_plan.valid_ratio();
256 if ratio < COVERAGE_ERROR_THRESHOLD {
257 return Err(ObsError::InvalidComposition {
258 reason: format!(
259 "entry {i}: valid_ratio {ratio:.3} < {COVERAGE_ERROR_THRESHOLD}"
260 ),
261 });
262 }
263 if ratio < COVERAGE_WARN_THRESHOLD {
264 eprintln!(
265 "murk-obs: warning: entry {i} valid_ratio {ratio:.3} < {COVERAGE_WARN_THRESHOLD}"
266 );
267 }
268
269 let mut gather_ops = Vec::with_capacity(region_plan.coords().len());
270 for (coord_idx, coord) in region_plan.coords().iter().enumerate() {
271 let field_data_idx =
272 *coord_to_field_idx
273 .get(coord)
274 .ok_or_else(|| ObsError::InvalidObsSpec {
275 reason: format!("entry {i}: coord {coord:?} not in canonical ordering"),
276 })?;
277 let tensor_idx = region_plan.tensor_indices()[coord_idx];
278 gather_ops.push(GatherOp {
279 field_data_idx,
280 tensor_idx,
281 });
282 }
283
284 let element_count = region_plan.bounding_shape().total_elements();
285 let shape = match region_plan.bounding_shape() {
286 murk_space::BoundingShape::Rect(dims) => dims.clone(),
287 };
288 entry_shapes.push(shape);
289
290 let valid_mask = region_plan.take_valid_mask();
291 let valid_count = valid_mask.iter().filter(|&&v| v == 1).count();
292 entries.push(CompiledEntry {
293 field_id: entry.field_id,
294 transform: entry.transform.clone(),
295 dtype: entry.dtype,
296 output_offset,
297 mask_offset,
298 element_count,
299 gather_ops,
300 valid_mask,
301 valid_count,
302 valid_ratio: ratio,
303 });
304 total_valid += valid_count;
305 total_elements += element_count;
306
307 output_offset += element_count;
308 mask_offset += element_count;
309 }
310
311 let plan = ObsPlan {
312 strategy: PlanStrategy::Simple(SimplePlanData {
313 entries,
314 total_valid,
315 total_elements,
316 }),
317 output_len: output_offset,
318 mask_len: mask_offset,
319 compiled_generation: None,
320 };
321
322 Ok(ObsPlanResult {
323 output_len: plan.output_len,
324 mask_len: plan.mask_len,
325 entry_shapes,
326 plan,
327 })
328 }
329
330 fn compile_standard(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
335 let canonical = space.canonical_ordering();
336 let coord_to_field_idx: IndexMap<Coord, usize> = canonical
337 .into_iter()
338 .enumerate()
339 .map(|(idx, coord)| (coord, idx))
340 .collect();
341
342 let geometry = GridGeometry::from_space(space);
343 let ndim = space.ndim();
344
345 let mut fixed_entries = Vec::new();
346 let mut agent_entries = Vec::new();
347 let mut output_offset = 0usize;
348 let mut mask_offset = 0usize;
349 let mut entry_shapes = Vec::new();
350
351 for (i, entry) in spec.entries.iter().enumerate() {
352 match &entry.region {
353 ObsRegion::Fixed(region_spec) => {
354 if entry.pool.is_some() {
355 return Err(ObsError::InvalidObsSpec {
356 reason: format!("entry {i}: pooling on Fixed regions not supported"),
357 });
358 }
359
360 let mut region_plan = space.compile_region(region_spec).map_err(|e| {
361 ObsError::InvalidObsSpec {
362 reason: format!("entry {i}: region compile failed: {e}"),
363 }
364 })?;
365
366 let ratio = region_plan.valid_ratio();
367 if ratio < COVERAGE_ERROR_THRESHOLD {
368 return Err(ObsError::InvalidComposition {
369 reason: format!(
370 "entry {i}: valid_ratio {ratio:.3} < {COVERAGE_ERROR_THRESHOLD}"
371 ),
372 });
373 }
374
375 let mut gather_ops = Vec::with_capacity(region_plan.coords().len());
376 for (coord_idx, coord) in region_plan.coords().iter().enumerate() {
377 let field_data_idx = *coord_to_field_idx.get(coord).ok_or_else(|| {
378 ObsError::InvalidObsSpec {
379 reason: format!(
380 "entry {i}: coord {coord:?} not in canonical ordering"
381 ),
382 }
383 })?;
384 let tensor_idx = region_plan.tensor_indices()[coord_idx];
385 gather_ops.push(GatherOp {
386 field_data_idx,
387 tensor_idx,
388 });
389 }
390
391 let element_count = region_plan.bounding_shape().total_elements();
392 let shape = match region_plan.bounding_shape() {
393 murk_space::BoundingShape::Rect(dims) => dims.clone(),
394 };
395 entry_shapes.push(shape);
396
397 let valid_mask = region_plan.take_valid_mask();
398 let valid_count = valid_mask.iter().filter(|&&v| v == 1).count();
399 fixed_entries.push(CompiledEntry {
400 field_id: entry.field_id,
401 transform: entry.transform.clone(),
402 dtype: entry.dtype,
403 output_offset,
404 mask_offset,
405 element_count,
406 gather_ops,
407 valid_mask,
408 valid_count,
409 valid_ratio: ratio,
410 });
411
412 output_offset += element_count;
413 mask_offset += element_count;
414 }
415
416 ObsRegion::AgentDisk { radius } => {
417 let half_ext: smallvec::SmallVec<[u32; 4]> =
418 (0..ndim).map(|_| *radius).collect();
419 let (ae, shape) = Self::compile_agent_entry(
420 i,
421 entry,
422 &half_ext,
423 *radius,
424 &geometry,
425 Some(*radius),
426 output_offset,
427 mask_offset,
428 )?;
429 entry_shapes.push(shape);
430 output_offset += ae.element_count;
431 mask_offset += ae.element_count;
432 agent_entries.push(ae);
433 }
434
435 ObsRegion::AgentRect { half_extent } => {
436 let radius = *half_extent.iter().max().unwrap_or(&0);
437 let (ae, shape) = Self::compile_agent_entry(
438 i,
439 entry,
440 half_extent,
441 radius,
442 &geometry,
443 None,
444 output_offset,
445 mask_offset,
446 )?;
447 entry_shapes.push(shape);
448 output_offset += ae.element_count;
449 mask_offset += ae.element_count;
450 agent_entries.push(ae);
451 }
452 }
453 }
454
455 let plan = ObsPlan {
456 strategy: PlanStrategy::Standard(StandardPlanData {
457 fixed_entries,
458 agent_entries,
459 geometry,
460 }),
461 output_len: output_offset,
462 mask_len: mask_offset,
463 compiled_generation: None,
464 };
465
466 Ok(ObsPlanResult {
467 output_len: plan.output_len,
468 mask_len: plan.mask_len,
469 entry_shapes,
470 plan,
471 })
472 }
473
474 #[allow(clippy::too_many_arguments)]
479 fn compile_agent_entry(
480 entry_idx: usize,
481 entry: &crate::spec::ObsEntry,
482 half_extent: &[u32],
483 radius: u32,
484 geometry: &Option<GridGeometry>,
485 disk_radius: Option<u32>,
486 output_offset: usize,
487 mask_offset: usize,
488 ) -> Result<(AgentCompiledEntry, Vec<usize>), ObsError> {
489 let pre_pool_shape: Vec<usize> =
490 half_extent.iter().map(|&he| 2 * he as usize + 1).collect();
491 let pre_pool_element_count: usize = pre_pool_shape.iter().product();
492
493 let template_ops = generate_template_ops(half_extent, geometry, disk_radius);
494 let active_ops = template_ops
495 .iter()
496 .filter(|op| op.in_disk)
497 .cloned()
498 .collect();
499
500 let (element_count, output_shape) = if let Some(pool) = &entry.pool {
501 if pre_pool_shape.len() != 2 {
502 return Err(ObsError::InvalidObsSpec {
503 reason: format!(
504 "entry {entry_idx}: pooling requires 2D region, got {}D",
505 pre_pool_shape.len()
506 ),
507 });
508 }
509 let h = pre_pool_shape[0];
510 let w = pre_pool_shape[1];
511 let ks = pool.kernel_size;
512 let stride = pool.stride;
513 if ks == 0 || stride == 0 {
514 return Err(ObsError::InvalidObsSpec {
515 reason: format!("entry {entry_idx}: pool kernel_size and stride must be > 0"),
516 });
517 }
518 let out_h = if h >= ks { (h - ks) / stride + 1 } else { 0 };
519 let out_w = if w >= ks { (w - ks) / stride + 1 } else { 0 };
520 if out_h == 0 || out_w == 0 {
521 return Err(ObsError::InvalidObsSpec {
522 reason: format!(
523 "entry {entry_idx}: pool produces empty output \
524 (region [{h},{w}], kernel_size {ks}, stride {stride})"
525 ),
526 });
527 }
528 (out_h * out_w, vec![out_h, out_w])
529 } else {
530 (pre_pool_element_count, pre_pool_shape.clone())
531 };
532
533 Ok((
534 AgentCompiledEntry {
535 field_id: entry.field_id,
536 pool: entry.pool.clone(),
537 transform: entry.transform.clone(),
538 dtype: entry.dtype,
539 output_offset,
540 mask_offset,
541 element_count,
542 pre_pool_element_count,
543 pre_pool_shape,
544 active_ops,
545 radius,
546 },
547 output_shape,
548 ))
549 }
550
551 pub fn compile_bound(
556 spec: &ObsSpec,
557 space: &dyn Space,
558 generation: WorldGenerationId,
559 ) -> Result<ObsPlanResult, ObsError> {
560 let mut result = Self::compile(spec, space)?;
561 result.plan.compiled_generation = Some(generation);
562 Ok(result)
563 }
564
565 pub fn output_len(&self) -> usize {
567 self.output_len
568 }
569
570 pub fn mask_len(&self) -> usize {
572 self.mask_len
573 }
574
575 pub fn compiled_generation(&self) -> Option<WorldGenerationId> {
577 self.compiled_generation
578 }
579
580 pub fn execute(
599 &self,
600 snapshot: &dyn SnapshotAccess,
601 engine_tick: Option<TickId>,
602 output: &mut [f32],
603 mask: &mut [u8],
604 ) -> Result<ObsMetadata, ObsError> {
605 let simple = match &self.strategy {
606 PlanStrategy::Simple(data) => data,
607 PlanStrategy::Standard(_) => {
608 return Err(ObsError::ExecutionFailed {
609 reason: "Standard plan requires execute_agents(), not execute()".into(),
610 });
611 }
612 };
613
614 if output.len() < self.output_len {
615 return Err(ObsError::ExecutionFailed {
616 reason: format!(
617 "output buffer too small: {} < {}",
618 output.len(),
619 self.output_len
620 ),
621 });
622 }
623 if mask.len() < self.mask_len {
624 return Err(ObsError::ExecutionFailed {
625 reason: format!("mask buffer too small: {} < {}", mask.len(), self.mask_len),
626 });
627 }
628
629 if let Some(compiled_gen) = self.compiled_generation {
631 let snapshot_gen = snapshot.world_generation_id();
632 if compiled_gen != snapshot_gen {
633 return Err(ObsError::PlanInvalidated {
634 reason: format!(
635 "plan compiled for generation {}, snapshot is generation {}",
636 compiled_gen.0, snapshot_gen.0
637 ),
638 });
639 }
640 }
641
642 Self::execute_simple_entries(&simple.entries, snapshot, output, mask)?;
643
644 let coverage = if simple.total_elements == 0 {
645 0.0
646 } else {
647 simple.total_valid as f64 / simple.total_elements as f64
648 };
649
650 let age_ticks = match engine_tick {
651 Some(tick) => tick.0.saturating_sub(snapshot.tick_id().0),
652 None => 0,
653 };
654
655 Ok(ObsMetadata {
656 tick_id: snapshot.tick_id(),
657 age_ticks,
658 coverage,
659 world_generation_id: snapshot.world_generation_id(),
660 parameter_version: snapshot.parameter_version(),
661 })
662 }
663
664 pub fn execute_batch(
672 &self,
673 snapshots: &[&dyn SnapshotAccess],
674 engine_tick: Option<TickId>,
675 output: &mut [f32],
676 mask: &mut [u8],
677 ) -> Result<Vec<ObsMetadata>, ObsError> {
678 let simple = match &self.strategy {
679 PlanStrategy::Simple(data) => data,
680 PlanStrategy::Standard(_) => {
681 return Err(ObsError::ExecutionFailed {
682 reason: "Standard plan requires execute_agents(), not execute_batch()".into(),
683 });
684 }
685 };
686
687 let batch_size = snapshots.len();
688 let expected_out = batch_size * self.output_len;
689 let expected_mask = batch_size * self.mask_len;
690
691 if output.len() < expected_out {
692 return Err(ObsError::ExecutionFailed {
693 reason: format!(
694 "batch output buffer too small: {} < {}",
695 output.len(),
696 expected_out
697 ),
698 });
699 }
700 if mask.len() < expected_mask {
701 return Err(ObsError::ExecutionFailed {
702 reason: format!(
703 "batch mask buffer too small: {} < {}",
704 mask.len(),
705 expected_mask
706 ),
707 });
708 }
709
710 let coverage = if simple.total_elements == 0 {
711 0.0
712 } else {
713 simple.total_valid as f64 / simple.total_elements as f64
714 };
715
716 let mut metadata = Vec::with_capacity(batch_size);
717 for (i, snap) in snapshots.iter().enumerate() {
718 if let Some(compiled_gen) = self.compiled_generation {
719 let snapshot_gen = snap.world_generation_id();
720 if compiled_gen != snapshot_gen {
721 return Err(ObsError::PlanInvalidated {
722 reason: format!(
723 "plan compiled for generation {}, snapshot is generation {}",
724 compiled_gen.0, snapshot_gen.0
725 ),
726 });
727 }
728 }
729
730 let out_start = i * self.output_len;
731 let mask_start = i * self.mask_len;
732 let out_slice = &mut output[out_start..out_start + self.output_len];
733 let mask_slice = &mut mask[mask_start..mask_start + self.mask_len];
734 Self::execute_simple_entries(&simple.entries, *snap, out_slice, mask_slice)?;
735
736 let age_ticks = match engine_tick {
737 Some(tick) => tick.0.saturating_sub(snap.tick_id().0),
738 None => 0,
739 };
740 metadata.push(ObsMetadata {
741 tick_id: snap.tick_id(),
742 age_ticks,
743 coverage,
744 world_generation_id: snap.world_generation_id(),
745 parameter_version: snap.parameter_version(),
746 });
747 }
748 Ok(metadata)
749 }
750
751 pub fn execute_agents(
762 &self,
763 snapshot: &dyn SnapshotAccess,
764 space: &dyn Space,
765 agent_centers: &[Coord],
766 engine_tick: Option<TickId>,
767 output: &mut [f32],
768 mask: &mut [u8],
769 ) -> Result<Vec<ObsMetadata>, ObsError> {
770 let standard = match &self.strategy {
771 PlanStrategy::Standard(data) => data,
772 PlanStrategy::Simple(_) => {
773 return Err(ObsError::ExecutionFailed {
774 reason: "execute_agents requires a Standard plan \
775 (spec must contain agent-relative entries)"
776 .into(),
777 });
778 }
779 };
780
781 let n_agents = agent_centers.len();
782 let expected_out = n_agents * self.output_len;
783 let expected_mask = n_agents * self.mask_len;
784
785 if output.len() < expected_out {
786 return Err(ObsError::ExecutionFailed {
787 reason: format!(
788 "output buffer too small: {} < {}",
789 output.len(),
790 expected_out
791 ),
792 });
793 }
794 if mask.len() < expected_mask {
795 return Err(ObsError::ExecutionFailed {
796 reason: format!("mask buffer too small: {} < {}", mask.len(), expected_mask),
797 });
798 }
799
800 let expected_ndim = space.ndim();
802 for (i, center) in agent_centers.iter().enumerate() {
803 if center.len() != expected_ndim {
804 return Err(ObsError::ExecutionFailed {
805 reason: format!(
806 "agent_centers[{i}] has {} dimensions, but space requires {expected_ndim}",
807 center.len()
808 ),
809 });
810 }
811 }
812
813 if let Some(compiled_gen) = self.compiled_generation {
815 let snapshot_gen = snapshot.world_generation_id();
816 if compiled_gen != snapshot_gen {
817 return Err(ObsError::PlanInvalidated {
818 reason: format!(
819 "plan compiled for generation {}, snapshot is generation {}",
820 compiled_gen.0, snapshot_gen.0
821 ),
822 });
823 }
824 }
825
826 let mut fixed_field_data = Vec::with_capacity(standard.fixed_entries.len());
830 for entry in &standard.fixed_entries {
831 let data =
832 snapshot
833 .read_field(entry.field_id)
834 .ok_or_else(|| ObsError::ExecutionFailed {
835 reason: format!("field {:?} not in snapshot", entry.field_id),
836 })?;
837 fixed_field_data.push(data);
838 }
839 let mut agent_field_data = Vec::with_capacity(standard.agent_entries.len());
840 for entry in &standard.agent_entries {
841 let data =
842 snapshot
843 .read_field(entry.field_id)
844 .ok_or_else(|| ObsError::ExecutionFailed {
845 reason: format!("field {:?} not in snapshot", entry.field_id),
846 })?;
847 agent_field_data.push(data);
848 }
849
850 let has_fixed = !standard.fixed_entries.is_empty();
854 let mut fixed_out_scratch = if has_fixed {
855 vec![0.0f32; self.output_len]
856 } else {
857 Vec::new()
858 };
859 let mut fixed_mask_scratch = if has_fixed {
860 vec![0u8; self.mask_len]
861 } else {
862 Vec::new()
863 };
864 let mut fixed_valid = 0usize;
865 let mut fixed_elements = 0usize;
866
867 for (entry, field_data) in standard
868 .fixed_entries
869 .iter()
870 .zip(fixed_field_data.iter().copied())
871 {
872 let out_slice = &mut fixed_out_scratch
873 [entry.output_offset..entry.output_offset + entry.element_count];
874 let mask_slice =
875 &mut fixed_mask_scratch[entry.mask_offset..entry.mask_offset + entry.element_count];
876
877 mask_slice.copy_from_slice(&entry.valid_mask);
878 for op in &entry.gather_ops {
879 let raw = *field_data.get(op.field_data_idx).ok_or_else(|| {
880 ObsError::ExecutionFailed {
881 reason: format!(
882 "field {:?} has {} elements but gather requires index {}",
883 entry.field_id,
884 field_data.len(),
885 op.field_data_idx,
886 ),
887 }
888 })?;
889 out_slice[op.tensor_idx] = apply_transform(raw, &entry.transform);
890 }
891
892 fixed_valid += entry.valid_count;
893 fixed_elements += entry.element_count;
894 }
895
896 let max_pool_scratch = standard
900 .agent_entries
901 .iter()
902 .filter(|e| e.pool.is_some())
903 .map(|e| e.pre_pool_element_count)
904 .max()
905 .unwrap_or(0);
906 let max_pool_output = standard
907 .agent_entries
908 .iter()
909 .filter(|e| e.pool.is_some())
910 .map(|e| e.element_count)
911 .max()
912 .unwrap_or(0);
913 let mut pool_scratch = vec![0.0f32; max_pool_scratch];
914 let mut pool_scratch_mask = vec![0u8; max_pool_scratch];
915 let mut pooled_scratch = vec![0.0f32; max_pool_output];
916 let mut pooled_scratch_mask = vec![0u8; max_pool_output];
917
918 let mut metadata = Vec::with_capacity(n_agents);
919
920 for (agent_i, center) in agent_centers.iter().enumerate() {
921 let out_start = agent_i * self.output_len;
922 let mask_start = agent_i * self.mask_len;
923 let agent_output = &mut output[out_start..out_start + self.output_len];
924 let agent_mask = &mut mask[mask_start..mask_start + self.mask_len];
925
926 agent_output.fill(0.0);
929 agent_mask.fill(0);
930 if has_fixed {
931 for entry in &standard.fixed_entries {
932 let out_range = entry.output_offset..entry.output_offset + entry.element_count;
933 let mask_range = entry.mask_offset..entry.mask_offset + entry.element_count;
934 agent_output[out_range.clone()].copy_from_slice(&fixed_out_scratch[out_range]);
935 agent_mask[mask_range.clone()].copy_from_slice(&fixed_mask_scratch[mask_range]);
936 }
937 }
938
939 let mut total_valid = fixed_valid;
940 let mut total_elements = fixed_elements;
941
942 for (entry, field_data) in standard
944 .agent_entries
945 .iter()
946 .zip(agent_field_data.iter().copied())
947 {
948 let use_fast_path = standard
952 .geometry
953 .as_ref()
954 .map(|geo| !geo.all_wrap && geo.is_interior(center, entry.radius))
955 .unwrap_or(false);
956
957 if entry.pool.is_some() {
959 pool_scratch[..entry.pre_pool_element_count].fill(0.0);
960 pool_scratch_mask[..entry.pre_pool_element_count].fill(0);
961 }
962
963 let valid = execute_agent_entry(
964 entry,
965 center,
966 field_data,
967 &standard.geometry,
968 space,
969 use_fast_path,
970 agent_output,
971 agent_mask,
972 &mut pool_scratch,
973 &mut pool_scratch_mask,
974 &mut pooled_scratch,
975 &mut pooled_scratch_mask,
976 );
977
978 total_valid += valid;
979 total_elements += entry.element_count;
980 }
981
982 let coverage = if total_elements == 0 {
983 0.0
984 } else {
985 total_valid as f64 / total_elements as f64
986 };
987
988 let age_ticks = match engine_tick {
989 Some(tick) => tick.0.saturating_sub(snapshot.tick_id().0),
990 None => 0,
991 };
992
993 metadata.push(ObsMetadata {
994 tick_id: snapshot.tick_id(),
995 age_ticks,
996 coverage,
997 world_generation_id: snapshot.world_generation_id(),
998 parameter_version: snapshot.parameter_version(),
999 });
1000 }
1001
1002 Ok(metadata)
1003 }
1004
1005 pub fn is_standard(&self) -> bool {
1007 matches!(self.strategy, PlanStrategy::Standard(_))
1008 }
1009
1010 fn execute_simple_entries(
1012 entries: &[CompiledEntry],
1013 snapshot: &dyn SnapshotAccess,
1014 output: &mut [f32],
1015 mask: &mut [u8],
1016 ) -> Result<(), ObsError> {
1017 for entry in entries {
1018 let field_data =
1019 snapshot
1020 .read_field(entry.field_id)
1021 .ok_or_else(|| ObsError::ExecutionFailed {
1022 reason: format!("field {:?} not in snapshot", entry.field_id),
1023 })?;
1024
1025 let out_slice =
1026 &mut output[entry.output_offset..entry.output_offset + entry.element_count];
1027 let mask_slice = &mut mask[entry.mask_offset..entry.mask_offset + entry.element_count];
1028
1029 out_slice.fill(0.0);
1031 mask_slice.copy_from_slice(&entry.valid_mask);
1032
1033 for op in &entry.gather_ops {
1035 let raw = *field_data.get(op.field_data_idx).ok_or_else(|| {
1036 ObsError::ExecutionFailed {
1037 reason: format!(
1038 "field {:?} has {} elements but gather requires index {}",
1039 entry.field_id,
1040 field_data.len(),
1041 op.field_data_idx,
1042 ),
1043 }
1044 })?;
1045 out_slice[op.tensor_idx] = apply_transform(raw, &entry.transform);
1046 }
1047 }
1048 Ok(())
1049 }
1050}
1051
1052#[allow(clippy::too_many_arguments)]
1060fn execute_agent_entry(
1061 entry: &AgentCompiledEntry,
1062 center: &Coord,
1063 field_data: &[f32],
1064 geometry: &Option<GridGeometry>,
1065 space: &dyn Space,
1066 use_fast_path: bool,
1067 agent_output: &mut [f32],
1068 agent_mask: &mut [u8],
1069 pool_scratch: &mut [f32],
1070 pool_scratch_mask: &mut [u8],
1071 pooled_scratch: &mut [f32],
1072 pooled_scratch_mask: &mut [u8],
1073) -> usize {
1074 if entry.pool.is_some() {
1075 execute_agent_entry_pooled(
1076 entry,
1077 center,
1078 field_data,
1079 geometry,
1080 space,
1081 use_fast_path,
1082 agent_output,
1083 agent_mask,
1084 &mut pool_scratch[..entry.pre_pool_element_count],
1085 &mut pool_scratch_mask[..entry.pre_pool_element_count],
1086 &mut pooled_scratch[..entry.element_count],
1087 &mut pooled_scratch_mask[..entry.element_count],
1088 )
1089 } else {
1090 execute_agent_entry_direct(
1091 entry,
1092 center,
1093 field_data,
1094 geometry,
1095 space,
1096 use_fast_path,
1097 agent_output,
1098 agent_mask,
1099 )
1100 }
1101}
1102
1103#[allow(clippy::too_many_arguments)]
1105fn execute_agent_entry_direct(
1106 entry: &AgentCompiledEntry,
1107 center: &Coord,
1108 field_data: &[f32],
1109 geometry: &Option<GridGeometry>,
1110 space: &dyn Space,
1111 use_fast_path: bool,
1112 agent_output: &mut [f32],
1113 agent_mask: &mut [u8],
1114) -> usize {
1115 let out_slice =
1116 &mut agent_output[entry.output_offset..entry.output_offset + entry.element_count];
1117 let mask_slice = &mut agent_mask[entry.mask_offset..entry.mask_offset + entry.element_count];
1118
1119 if use_fast_path {
1120 let geo = geometry.as_ref().unwrap();
1122 let base_rank = geo.canonical_rank(center) as isize;
1123 let mut valid = 0;
1124 for op in &entry.active_ops {
1125 let field_idx = (base_rank + op.stride_offset) as usize;
1126 if let Some(&val) = field_data.get(field_idx) {
1127 out_slice[op.tensor_idx] = apply_transform(val, &entry.transform);
1128 mask_slice[op.tensor_idx] = 1;
1129 valid += 1;
1130 }
1131 }
1132 valid
1133 } else {
1134 let mut valid = 0;
1136 for op in &entry.active_ops {
1137 let field_idx = resolve_field_index(center, &op.relative, geometry, space);
1138 if let Some(idx) = field_idx {
1139 if idx < field_data.len() {
1140 out_slice[op.tensor_idx] = apply_transform(field_data[idx], &entry.transform);
1141 mask_slice[op.tensor_idx] = 1;
1142 valid += 1;
1143 }
1144 }
1145 }
1146 valid
1147 }
1148}
1149
1150#[allow(clippy::too_many_arguments)]
1157fn execute_agent_entry_pooled(
1158 entry: &AgentCompiledEntry,
1159 center: &Coord,
1160 field_data: &[f32],
1161 geometry: &Option<GridGeometry>,
1162 space: &dyn Space,
1163 use_fast_path: bool,
1164 agent_output: &mut [f32],
1165 agent_mask: &mut [u8],
1166 scratch: &mut [f32],
1167 scratch_mask: &mut [u8],
1168 pooled: &mut [f32],
1169 pooled_mask: &mut [u8],
1170) -> usize {
1171 if use_fast_path {
1172 let geo = geometry.as_ref().unwrap();
1173 let base_rank = geo.canonical_rank(center) as isize;
1174 for op in &entry.active_ops {
1175 let field_idx = (base_rank + op.stride_offset) as usize;
1176 if let Some(&val) = field_data.get(field_idx) {
1177 scratch[op.tensor_idx] = val;
1178 scratch_mask[op.tensor_idx] = 1;
1179 }
1180 }
1181 } else {
1182 for op in &entry.active_ops {
1183 let field_idx = resolve_field_index(center, &op.relative, geometry, space);
1184 if let Some(idx) = field_idx {
1185 if idx < field_data.len() {
1186 scratch[op.tensor_idx] = field_data[idx];
1187 scratch_mask[op.tensor_idx] = 1;
1188 }
1189 }
1190 }
1191 }
1192
1193 let pool_config = entry.pool.as_ref().unwrap();
1194 let (out_h, out_w) = pool_2d_into(
1195 scratch,
1196 scratch_mask,
1197 &entry.pre_pool_shape,
1198 pool_config,
1199 pooled,
1200 pooled_mask,
1201 );
1202
1203 let out_slice =
1204 &mut agent_output[entry.output_offset..entry.output_offset + entry.element_count];
1205 let mask_slice = &mut agent_mask[entry.mask_offset..entry.mask_offset + entry.element_count];
1206
1207 let n = (out_h * out_w).min(entry.element_count);
1208 for i in 0..n {
1209 out_slice[i] = apply_transform(pooled[i], &entry.transform);
1210 }
1211 mask_slice[..n].copy_from_slice(&pooled_mask[..n]);
1212
1213 pooled_mask[..n].iter().filter(|&&v| v == 1).count()
1214}
1215
1216fn generate_template_ops(
1228 half_extent: &[u32],
1229 geometry: &Option<GridGeometry>,
1230 disk_radius: Option<u32>,
1231) -> Vec<TemplateOp> {
1232 let ndim = half_extent.len();
1233 let shape: Vec<usize> = half_extent.iter().map(|&he| 2 * he as usize + 1).collect();
1234 let total: usize = shape.iter().product();
1235
1236 let strides = geometry.as_ref().map(|g| g.coord_strides.as_slice());
1237
1238 let mut ops = Vec::with_capacity(total);
1239
1240 for tensor_idx in 0..total {
1241 let mut relative = Coord::new();
1243 let mut remaining = tensor_idx;
1244 for d in (0..ndim).rev() {
1246 let coord = (remaining % shape[d]) as i32 - half_extent[d] as i32;
1247 relative.push(coord);
1248 remaining /= shape[d];
1249 }
1250 relative.reverse();
1251
1252 let stride_offset = strides
1253 .map(|s| {
1254 relative
1255 .iter()
1256 .zip(s.iter())
1257 .map(|(&r, &s)| r as isize * s as isize)
1258 .sum::<isize>()
1259 })
1260 .unwrap_or(0);
1261
1262 let in_disk = match disk_radius {
1263 Some(r) => match geometry {
1264 Some(geo) => geo.graph_distance(&relative) <= r,
1265 None => true, },
1267 None => true, };
1269
1270 ops.push(TemplateOp {
1271 relative,
1272 tensor_idx,
1273 stride_offset,
1274 in_disk,
1275 });
1276 }
1277
1278 ops
1279}
1280
1281fn resolve_field_index(
1288 center: &Coord,
1289 relative: &Coord,
1290 geometry: &Option<GridGeometry>,
1291 space: &dyn Space,
1292) -> Option<usize> {
1293 if let Some(geo) = geometry {
1294 if geo.all_wrap {
1295 let wrapped: Coord = center
1297 .iter()
1298 .zip(relative.iter())
1299 .zip(geo.coord_dims.iter())
1300 .map(|((&c, &r), &d)| {
1301 let d = d as i32;
1302 ((c + r) % d + d) % d
1303 })
1304 .collect();
1305 Some(geo.canonical_rank(&wrapped))
1306 } else {
1307 let abs_coord: Coord = center
1308 .iter()
1309 .zip(relative.iter())
1310 .map(|(&c, &r)| c + r)
1311 .collect();
1312 let abs_slice: &[i32] = &abs_coord;
1313 if geo.in_bounds(abs_slice) {
1314 Some(geo.canonical_rank(abs_slice))
1315 } else {
1316 None
1317 }
1318 }
1319 } else {
1320 let abs_coord: Coord = center
1321 .iter()
1322 .zip(relative.iter())
1323 .map(|(&c, &r)| c + r)
1324 .collect();
1325 space.canonical_rank(&abs_coord)
1326 }
1327}
1328
1329fn apply_transform(raw: f32, transform: &ObsTransform) -> f32 {
1331 match transform {
1332 ObsTransform::Identity => raw,
1333 ObsTransform::Normalize { min, max } => {
1334 let range = max - min;
1335 if range == 0.0 {
1336 0.0
1337 } else {
1338 let normalized = (raw as f64 - min) / range;
1339 normalized.clamp(0.0, 1.0) as f32
1340 }
1341 }
1342 }
1343}
1344
1345#[cfg(test)]
1346mod tests {
1347 use super::*;
1348 use crate::spec::{
1349 ObsDtype, ObsEntry, ObsRegion, ObsSpec, ObsTransform, PoolConfig, PoolKernel,
1350 };
1351 use murk_core::{FieldId, ParameterVersion, TickId, WorldGenerationId};
1352 use murk_space::{EdgeBehavior, Hex2D, RegionSpec, Square4, Square8};
1353 use murk_test_utils::MockSnapshot;
1354
1355 fn square4_space() -> Square4 {
1356 Square4::new(3, 3, EdgeBehavior::Absorb).unwrap()
1357 }
1358
1359 fn snapshot_with_field(field: FieldId, data: Vec<f32>) -> MockSnapshot {
1360 let mut snap = MockSnapshot::new(TickId(5), WorldGenerationId(1), ParameterVersion(0));
1361 snap.set_field(field, data);
1362 snap
1363 }
1364
#[test]
/// An `ObsSpec` with no entries is rejected at compile time.
fn compile_empty_spec_errors() {
    let empty = ObsSpec { entries: Vec::new() };
    let space = square4_space();
    let err = ObsPlan::compile(&empty, &space).unwrap_err();
    assert!(matches!(err, ObsError::InvalidObsSpec { .. }));
}
1374
#[test]
/// `RegionSpec::All` on a 3x3 grid yields one [3, 3] tensor of 9 values and 9 mask bytes.
fn compile_all_region_square4() {
    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let space = square4_space();

    let compiled = ObsPlan::compile(&spec, &space).unwrap();
    assert_eq!(compiled.output_len, 9);
    assert_eq!(compiled.mask_len, 9);
    assert_eq!(compiled.entry_shapes, vec![vec![3, 3]]);
}
1392
#[test]
/// A fixed `Rect` region compiles to a tensor spanning its inclusive bounds.
fn compile_rect_region() {
    let space = Square4::new(5, 5, EdgeBehavior::Absorb).unwrap();
    let rect = RegionSpec::Rect {
        min: smallvec::smallvec![1, 1],
        max: smallvec::smallvec![2, 3],
    };
    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(rect),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };

    let compiled = ObsPlan::compile(&spec, &space).unwrap();
    // Axis 0 spans 1..=2 and axis 1 spans 1..=3, giving a 2x3 window of 6 cells.
    assert_eq!(compiled.output_len, 6);
    assert_eq!(compiled.entry_shapes, vec![vec![2, 3]]);
}
1413
#[test]
/// Two full-grid entries are laid out back to back: 2 x 9 values and mask bytes.
fn compile_two_entries_offsets() {
    let space = square4_space();
    let entries = (0..2)
        .map(|id| ObsEntry {
            field_id: FieldId(id),
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        })
        .collect();
    let spec = ObsSpec { entries };

    let compiled = ObsPlan::compile(&spec, &space).unwrap();
    assert_eq!(compiled.output_len, 18);
    assert_eq!(compiled.mask_len, 18);
}
1439
#[test]
/// A coordinate outside the 3x3 grid makes the fixed region invalid at compile time.
fn compile_invalid_region_errors() {
    let space = square4_space();
    let off_grid = RegionSpec::Coords(vec![smallvec::smallvec![99, 99]]);
    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(off_grid),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };

    let err = ObsPlan::compile(&spec, &space).unwrap_err();
    assert!(matches!(err, ObsError::InvalidObsSpec { .. }));
}
1455
#[test]
/// Identity over the whole grid copies the field verbatim and fills in metadata
/// from the snapshot.
fn execute_identity_all_region() {
    let space = square4_space();
    // Cell i holds i + 1, so the expected output is exactly these values.
    let values: Vec<f32> = (1..=9).map(|v| v as f32).collect();
    let snap = snapshot_with_field(FieldId(0), values.clone());

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();
    let plan = &compiled.plan;

    let (mut output, mut mask) = (vec![0.0f32; compiled.output_len], vec![0u8; compiled.mask_len]);
    let meta = plan.execute(&snap, None, &mut output, &mut mask).unwrap();

    assert_eq!(output, values);
    assert_eq!(mask, vec![1u8; 9]);
    // Metadata mirrors the snapshot produced by `snapshot_with_field`.
    assert_eq!(meta.tick_id, TickId(5));
    assert_eq!(meta.coverage, 1.0);
    assert_eq!(meta.world_generation_id, WorldGenerationId(1));
    assert_eq!(meta.parameter_version, ParameterVersion(0));
    assert_eq!(meta.age_ticks, 0);
}
1494
#[test]
/// `Normalize { min: 0, max: 8 }` maps the value i in cell i to i / 8.
fn execute_normalize_transform() {
    let space = square4_space();
    let snap = snapshot_with_field(FieldId(0), (0..9).map(|x| x as f32).collect());

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Normalize { min: 0.0, max: 8.0 },
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    let mut output = vec![0.0f32; compiled.output_len];
    let mut mask = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute(&snap, None, &mut output, &mut mask)
        .unwrap();

    output.iter().enumerate().for_each(|(i, &v)| {
        let expected = i as f32 / 8.0;
        assert!(
            (v - expected).abs() < 1e-6,
            "output[{i}] = {v}, expected {expected}"
        );
    });
}
1529
#[test]
/// Field values outside the normalize range are clamped into [0, 1].
fn execute_normalize_clamps_out_of_range() {
    let space = square4_space();
    // Values run from -20 to 20 in steps of 5, well outside [0, 10] on both sides.
    let snap = snapshot_with_field(FieldId(0), (-4..5).map(|x| x as f32 * 5.0).collect());

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Normalize {
            min: 0.0,
            max: 10.0,
        },
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    let mut output = vec![0.0f32; compiled.output_len];
    let mut mask = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute(&snap, None, &mut output, &mut mask)
        .unwrap();

    for v in output.iter().copied() {
        assert!((0.0..=1.0).contains(&v), "value {v} out of [0,1] range");
    }
}
1562
#[test]
/// A degenerate normalize range (min == max) writes 0.0 everywhere.
fn execute_normalize_zero_range() {
    let space = square4_space();
    let snap = snapshot_with_field(FieldId(0), vec![5.0f32; 9]);

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Normalize { min: 5.0, max: 5.0 },
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    // Pre-fill with -1 so the assertion proves every slot was overwritten.
    let mut output = vec![-1.0f32; compiled.output_len];
    let mut mask = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute(&snap, None, &mut output, &mut mask)
        .unwrap();

    assert!(output.iter().all(|&v| v == 0.0));
}
1590
#[test]
/// A 2x2 fixed rectangle in a 4x4 grid gathers exactly the four expected cells.
fn execute_rect_subregion_correct_values() {
    let space = Square4::new(4, 4, EdgeBehavior::Absorb).unwrap();
    let snap = snapshot_with_field(FieldId(0), (1..=16).map(|x| x as f32).collect());

    let rect = RegionSpec::Rect {
        min: smallvec::smallvec![1, 1],
        max: smallvec::smallvec![2, 2],
    };
    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(rect),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();
    assert_eq!(compiled.output_len, 4);

    let mut output = vec![0.0f32; compiled.output_len];
    let mut mask = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute(&snap, None, &mut output, &mut mask)
        .unwrap();

    assert_eq!(output, vec![6.0, 7.0, 10.0, 11.0]);
    assert_eq!(mask, vec![1, 1, 1, 1]);
}
1624
#[test]
/// Two entries on different fields are written to consecutive output slices.
fn execute_two_fields() {
    let space = square4_space();
    let first: Vec<f32> = (1..=9).map(|x| x as f32).collect();
    let second: Vec<f32> = (10..=18).map(|x| x as f32).collect();
    let mut snap = MockSnapshot::new(TickId(1), WorldGenerationId(1), ParameterVersion(0));
    snap.set_field(FieldId(0), first.clone());
    snap.set_field(FieldId(1), second.clone());

    let full_grid = |id| ObsEntry {
        field_id: FieldId(id),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![full_grid(0), full_grid(1)],
    };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();
    assert_eq!(compiled.output_len, 18);

    let mut output = vec![0.0f32; compiled.output_len];
    let mut mask = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute(&snap, None, &mut output, &mut mask)
        .unwrap();

    let (head, tail) = output.split_at(9);
    assert_eq!(head, &first[..]);
    assert_eq!(tail, &second[..]);
}
1668
#[test]
/// Executing against a snapshot that lacks the requested field reports `ExecutionFailed`.
fn execute_missing_field_errors() {
    let space = square4_space();
    let empty_snap = MockSnapshot::new(TickId(1), WorldGenerationId(1), ParameterVersion(0));

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    let mut output = vec![0.0f32; compiled.output_len];
    let mut mask = vec![0u8; compiled.mask_len];
    let outcome = compiled.plan.execute(&empty_snap, None, &mut output, &mut mask);
    assert!(matches!(outcome.unwrap_err(), ObsError::ExecutionFailed { .. }));
}
1693
#[test]
/// An output buffer smaller than `output_len` is rejected with `ExecutionFailed`.
fn execute_buffer_too_small_errors() {
    let space = square4_space();
    let snap = snapshot_with_field(FieldId(0), vec![0.0; 9]);

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    // The plan needs 9 output slots; only 4 are provided.
    let mut undersized = vec![0.0f32; 4];
    let mut mask = vec![0u8; compiled.mask_len];
    let outcome = compiled.plan.execute(&snap, None, &mut undersized, &mut mask);
    assert!(matches!(outcome.unwrap_err(), ObsError::ExecutionFailed { .. }));
}
1719
#[test]
/// A full `All` region on a square grid reports coverage of exactly 1.0.
fn valid_ratio_one_for_square_all() {
    let space = square4_space();
    let snap = snapshot_with_field(FieldId(0), vec![1.0; 9]);

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    let mut output = vec![0.0f32; compiled.output_len];
    let mut mask = vec![0u8; compiled.mask_len];
    let coverage = compiled
        .plan
        .execute(&snap, None, &mut output, &mut mask)
        .unwrap()
        .coverage;

    assert_eq!(coverage, 1.0);
}
1748
#[test]
/// A plan bound to generation 99 refuses a snapshot from generation 1.
fn plan_invalidated_on_generation_mismatch() {
    let space = square4_space();
    let snap = snapshot_with_field(FieldId(0), vec![1.0; 9]);

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let stale = ObsPlan::compile_bound(&spec, &space, WorldGenerationId(99)).unwrap();

    let mut output = vec![0.0f32; stale.output_len];
    let mut mask = vec![0u8; stale.mask_len];
    let outcome = stale.plan.execute(&snap, None, &mut output, &mut mask);
    assert!(matches!(outcome.unwrap_err(), ObsError::PlanInvalidated { .. }));
}
1777
#[test]
/// A plan bound to the snapshot's own generation executes successfully.
fn generation_match_succeeds() {
    let space = square4_space();
    let snap = snapshot_with_field(FieldId(0), vec![1.0; 9]);

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let bound = ObsPlan::compile_bound(&spec, &space, WorldGenerationId(1)).unwrap();

    let mut output = vec![0.0f32; bound.output_len];
    let mut mask = vec![0u8; bound.mask_len];
    bound.plan.execute(&snap, None, &mut output, &mut mask).unwrap();
}
1802
#[test]
/// A plan compiled without a generation binding executes against any snapshot.
fn unbound_plan_ignores_generation() {
    let space = square4_space();
    let snap = snapshot_with_field(FieldId(0), vec![1.0; 9]);

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let unbound = ObsPlan::compile(&spec, &space).unwrap();

    let mut output = vec![0.0f32; unbound.output_len];
    let mut mask = vec![0u8; unbound.mask_len];
    unbound.plan.execute(&snap, None, &mut output, &mut mask).unwrap();
}
1828
#[test]
/// Execution metadata echoes the snapshot's tick, generation and parameter
/// version. Age is 0 and coverage is 1.0 for a full-grid read.
fn metadata_fields_populated() {
    let space = square4_space();
    let mut snap = MockSnapshot::new(TickId(42), WorldGenerationId(7), ParameterVersion(3));
    snap.set_field(FieldId(0), vec![1.0; 9]);

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    let mut output = vec![0.0f32; compiled.output_len];
    let mut mask = vec![0u8; compiled.mask_len];
    let meta = compiled
        .plan
        .execute(&snap, None, &mut output, &mut mask)
        .unwrap();

    assert_eq!(meta.tick_id, TickId(42));
    assert_eq!(meta.age_ticks, 0);
    assert_eq!(meta.coverage, 1.0);
    assert_eq!(meta.world_generation_id, WorldGenerationId(7));
    assert_eq!(meta.parameter_version, ParameterVersion(3));
}
1862
#[test]
/// A batch of one snapshot yields exactly the output, mask and metadata of a
/// single `execute` call on that snapshot.
fn execute_batch_n1_matches_execute() {
    let space = square4_space();
    let data: Vec<f32> = (1..=9).map(|x| x as f32).collect();
    // `data` is not used again, so move it instead of cloning.
    let snap = snapshot_with_field(FieldId(0), data);

    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let result = ObsPlan::compile(&spec, &space).unwrap();

    // Reference: single-snapshot execution.
    let mut out_single = vec![0.0f32; result.output_len];
    let mut mask_single = vec![0u8; result.mask_len];
    let meta_single = result
        .plan
        .execute(&snap, None, &mut out_single, &mut mask_single)
        .unwrap();

    // Same snapshot through the batch API with N = 1.
    let mut out_batch = vec![0.0f32; result.output_len];
    let mut mask_batch = vec![0u8; result.mask_len];
    let snap_ref: &dyn SnapshotAccess = &snap;
    let meta_batch = result
        .plan
        .execute_batch(&[snap_ref], None, &mut out_batch, &mut mask_batch)
        .unwrap();

    assert_eq!(out_single, out_batch);
    assert_eq!(mask_single, mask_batch);
    assert_eq!(meta_single, meta_batch[0]);
}
1903
#[test]
/// Each snapshot in a batch fills its own `output_len`-sized slice, in order.
fn execute_batch_multiple_snapshots() {
    let space = square4_space();
    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    let ones = snapshot_with_field(FieldId(0), vec![1.0; 9]);
    let twos = snapshot_with_field(FieldId(0), vec![2.0; 9]);
    let batch: Vec<&dyn SnapshotAccess> = vec![&ones, &twos];

    let mut output = vec![0.0f32; compiled.output_len * 2];
    let mut mask = vec![0u8; compiled.mask_len * 2];
    let metas = compiled
        .plan
        .execute_batch(&batch, None, &mut output, &mut mask)
        .unwrap();

    assert_eq!(metas.len(), 2);
    let (from_ones, from_twos) = output.split_at(9);
    assert!(from_ones.iter().all(|&v| v == 1.0));
    assert!(from_twos.iter().all(|&v| v == 2.0));
}
1933
#[test]
/// A batch whose output buffer only fits one of two snapshots is rejected.
fn execute_batch_buffer_too_small() {
    let space = square4_space();
    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    let snap = snapshot_with_field(FieldId(0), vec![1.0; 9]);
    let batch: Vec<&dyn SnapshotAccess> = vec![&snap, &snap];
    // Two snapshots need 18 output slots; only 9 are provided (the mask is sized correctly).
    let mut output = vec![0.0f32; 9];
    let mut mask = vec![0u8; 18];

    let outcome = compiled.plan.execute_batch(&batch, None, &mut output, &mut mask);
    assert!(matches!(outcome.unwrap_err(), ObsError::ExecutionFailed { .. }));
}
1958
#[test]
/// A field buffer shorter than the grid (4 values for 9 cells) yields
/// `ExecutionFailed` instead of an out-of-bounds panic.
fn short_field_buffer_returns_error_not_panic() {
    let space = square4_space();
    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    let truncated = snapshot_with_field(FieldId(0), vec![1.0; 4]);
    let mut output = vec![0.0f32; compiled.output_len];
    let mut mask = vec![0u8; compiled.mask_len];
    let outcome = compiled.plan.execute(&truncated, None, &mut output, &mut mask);
    assert!(matches!(outcome.unwrap_err(), ObsError::ExecutionFailed { .. }));
}
1985
#[test]
/// An agent-relative region makes the compiled plan a standard plan. Half-extent 2
/// gives a 5x5 window.
fn standard_plan_detected_from_agent_region() {
    let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentRect {
            half_extent: smallvec::smallvec![2, 2],
        },
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };

    let compiled = ObsPlan::compile(&spec, &space).unwrap();
    assert!(compiled.plan.is_standard());
    assert_eq!(compiled.output_len, 25);
    assert_eq!(compiled.entry_shapes, vec![vec![5, 5]]);
}
2008
#[test]
/// Calling plain `execute` on a plan with agent-relative entries reports
/// `ExecutionFailed`.
fn execute_on_standard_plan_errors() {
    let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
    let snap = snapshot_with_field(FieldId(0), (0..100).map(|x| x as f32).collect());

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentDisk { radius: 2 },
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    let mut output = vec![0.0f32; compiled.output_len];
    let mut mask = vec![0u8; compiled.mask_len];
    let outcome = compiled.plan.execute(&snap, None, &mut output, &mut mask);
    assert!(matches!(outcome.unwrap_err(), ObsError::ExecutionFailed { .. }));
}
2034
#[test]
/// For an agent well inside an absorbing grid, the per-agent `AgentRect` path
/// must produce exactly the same values and mask as an equivalent fixed `Rect`
/// region centred on the same cell.
fn interior_boundary_equivalence() {
    let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
    let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
    let snap = snapshot_with_field(FieldId(0), data);

    let radius = 3u32;
    let center: Coord = smallvec::smallvec![10, 10];

    // Agent-relative window of half-extent `radius` around `center`.
    let standard_spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::AgentRect {
                half_extent: smallvec::smallvec![radius, radius],
            },
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let std_result = ObsPlan::compile(&standard_spec, &space).unwrap();
    let mut std_output = vec![0.0f32; std_result.output_len];
    let mut std_mask = vec![0u8; std_result.mask_len];
    std_result
        .plan
        .execute_agents(
            &snap,
            &space,
            std::slice::from_ref(&center),
            None,
            &mut std_output,
            &mut std_mask,
        )
        .unwrap();

    // The same window written as a fixed, inclusive rectangle.
    let r = radius as i32;
    let simple_spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::Rect {
                min: smallvec::smallvec![10 - r, 10 - r],
                max: smallvec::smallvec![10 + r, 10 + r],
            }),
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let simple_result = ObsPlan::compile(&simple_spec, &space).unwrap();
    let mut simple_output = vec![0.0f32; simple_result.output_len];
    let mut simple_mask = vec![0u8; simple_result.mask_len];
    simple_result
        .plan
        .execute(&snap, None, &mut simple_output, &mut simple_mask)
        .unwrap();

    assert_eq!(std_result.output_len, simple_result.output_len);
    assert_eq!(std_output, simple_output);
    assert_eq!(std_mask, simple_mask);
}
2100
#[test]
/// An agent at the origin of an absorbing grid sees only the on-grid 3x3 part
/// of its 5x5 window. Every other slot is masked out and zero-filled.
fn boundary_agent_gets_padding() {
    let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
    // Offset values by +1 so a real cell is never confused with 0.0 padding.
    let snap = snapshot_with_field(FieldId(0), (0..100).map(|x| x as f32 + 1.0).collect());

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentRect {
            half_extent: smallvec::smallvec![2, 2],
        },
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    let origin: Coord = smallvec::smallvec![0, 0];
    let mut output = vec![0.0f32; compiled.output_len];
    let mut mask = vec![0u8; compiled.mask_len];
    let metas = compiled
        .plan
        .execute_agents(&snap, &space, &[origin], None, &mut output, &mut mask)
        .unwrap();

    let on_grid = mask.iter().filter(|&&v| v == 1).count();
    assert_eq!(on_grid, 9);
    assert!((metas[0].coverage - 9.0 / 25.0).abs() < 1e-6);

    // Corner slot 0 falls off the grid.
    assert_eq!(mask[0], 0);
    assert_eq!(output[0], 0.0);

    // Slot 12 is the window centre, i.e. the agent's own cell.
    assert_eq!(mask[12], 1);
    assert_eq!(output[12], 1.0);
}
2146
#[test]
/// A radius-2 disk on a hex grid covers 19 of the 25 slots in its bounding
/// tensor. The other six are masked out and zero-filled.
fn hex_foveation_interior() {
    let space = Hex2D::new(20, 20).unwrap();
    let snap = snapshot_with_field(FieldId(0), (0..400).map(|x| x as f32).collect());

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentDisk { radius: 2 },
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();
    assert_eq!(compiled.output_len, 25);

    let center: Coord = smallvec::smallvec![10, 10];
    let mut output = vec![0.0f32; compiled.output_len];
    let mut mask = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute_agents(&snap, &space, &[center], None, &mut output, &mask_dummy_guard(&mut mask))
        .unwrap();

    // 1 + 6 + 12 cells within hex distance 2.
    let inside = mask.iter().filter(|&&v| v == 1).count();
    assert_eq!(inside, 19);

    for idx in [0, 1, 5, 19, 23, 24] {
        assert_eq!(mask[idx], 0, "tensor_idx {idx} should be outside hex disk");
        assert_eq!(output[idx], 0.0, "tensor_idx {idx} should be zero-padded");
    }

    assert_eq!(output[12], 210.0);
    assert_eq!(output[17], 211.0);
}
2202
#[test]
/// On a wrapping grid, even an agent at the origin sees a fully valid window.
fn wrap_space_all_interior() {
    let space = Square4::new(10, 10, EdgeBehavior::Wrap).unwrap();
    let snap = snapshot_with_field(FieldId(0), (0..100).map(|x| x as f32).collect());

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentRect {
            half_extent: smallvec::smallvec![2, 2],
        },
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    let origin: Coord = smallvec::smallvec![0, 0];
    let mut output = vec![0.0f32; compiled.output_len];
    let mut mask = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute_agents(&snap, &space, &[origin], None, &mut output, &mut mask)
        .unwrap();

    assert!(mask.iter().all(|&v| v == 1));
    // Centre slot is the origin cell, whose value is 0.
    assert_eq!(output[12], 0.0);
}
2236
#[test]
/// Several agents are observed in one call. Each agent gets its own contiguous
/// `output_len`-sized slice of the output and mask buffers.
fn execute_agents_multiple() {
    let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
    let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
    let snap = snapshot_with_field(FieldId(0), data);

    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::AgentRect {
                half_extent: smallvec::smallvec![1, 1],
            },
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let result = ObsPlan::compile(&spec, &space).unwrap();
    assert_eq!(result.output_len, 9);

    // One interior agent, and one whose first coordinate is 0 (on the grid boundary).
    let centers = vec![smallvec::smallvec![5, 5], smallvec::smallvec![0, 5]];
    let n = centers.len();
    let mut output = vec![0.0f32; result.output_len * n];
    let mut mask = vec![0u8; result.mask_len * n];
    let metas = result
        .plan
        .execute_agents(&snap, &space, &centers, None, &mut output, &mut mask)
        .unwrap();

    assert_eq!(metas.len(), 2);

    // Agent 0 is fully interior: every slot is valid and the centre holds 55.
    assert!(mask[..9].iter().all(|&v| v == 1));
    assert_eq!(output[4], 55.0);

    // Agent 1's window overhangs the boundary, so only 6 of its 9 slots are valid.
    let agent1_mask = &mask[9..18];
    let valid_count: usize = agent1_mask.iter().filter(|&&v| v == 1).count();
    assert_eq!(valid_count, 6);
}
2281
#[test]
/// Transforms apply on the per-agent path too: the centre value 55 normalised
/// over [0, 99] gives 55/99.
fn execute_agents_with_normalize() {
    let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
    let snap = snapshot_with_field(FieldId(0), (0..100).map(|x| x as f32).collect());

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentRect {
            half_extent: smallvec::smallvec![1, 1],
        },
        pool: None,
        transform: ObsTransform::Normalize {
            min: 0.0,
            max: 99.0,
        },
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    let center: Coord = smallvec::smallvec![5, 5];
    let mut output = vec![0.0f32; compiled.output_len];
    let mut mask = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
        .unwrap();

    let expected = (55.0_f64 / 99.0) as f32;
    assert!((output[4] - expected).abs() < 1e-5);
}
2316
#[test]
/// A 7x7 agent window pooled by a 2x2 mean with stride 2 shrinks to a 3x3 tensor.
fn execute_agents_with_pooling() {
    let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
    let snap = snapshot_with_field(FieldId(0), (0..400).map(|x| x as f32).collect());

    let mean_2x2 = PoolConfig {
        kernel: PoolKernel::Mean,
        kernel_size: 2,
        stride: 2,
    };
    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentRect {
            half_extent: smallvec::smallvec![3, 3],
        },
        pool: Some(mean_2x2),
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();
    assert_eq!(compiled.output_len, 9);
    assert_eq!(compiled.entry_shapes, vec![vec![3, 3]]);

    let center: Coord = smallvec::smallvec![10, 10];
    let mut output = vec![0.0f32; compiled.output_len];
    let mut mask = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
        .unwrap();

    assert!(mask.iter().all(|&v| v == 1));
    // First pooled cell: mean of the window's top-left 2x2 block.
    assert!((output[0] - 157.5).abs() < 1e-4);
}
2362
#[test]
/// A spec mixing a fixed full-grid entry with an agent window is a standard plan.
/// The fixed block fills the first 100 slots and the 3x3 agent window follows.
fn mixed_fixed_and_agent_entries() {
    let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
    let snap = snapshot_with_field(FieldId(0), (0..100).map(|x| x as f32).collect());

    let whole_grid = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let around_agent = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentRect {
            half_extent: smallvec::smallvec![1, 1],
        },
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![whole_grid, around_agent],
    };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();
    assert!(compiled.plan.is_standard());
    assert_eq!(compiled.output_len, 109);

    let center: Coord = smallvec::smallvec![5, 5];
    let mut output = vec![0.0f32; compiled.output_len];
    let mut mask = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
        .unwrap();

    let grid_values: Vec<f32> = (0..100).map(|x| x as f32).collect();
    assert_eq!(&output[..100], &grid_values[..]);
    assert!(mask[..100].iter().all(|&v| v == 1));

    // Slot 4 of the agent window (its centre) sits right after the fixed block.
    assert_eq!(output[100 + 4], 55.0);
}
2412
#[test]
/// A 1-D agent centre on a 2-D space is rejected with an error that mentions dimensions.
fn wrong_dimensionality_returns_error() {
    let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
    let snap = snapshot_with_field(FieldId(0), (0..100).map(|x| x as f32).collect());

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentDisk { radius: 1 },
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    let bad_center: Coord = smallvec::smallvec![5];
    let mut output = vec![0.0f32; compiled.output_len];
    let mut mask = vec![0u8; compiled.mask_len];
    let outcome =
        compiled
            .plan
            .execute_agents(&snap, &space, &[bad_center], None, &mut output, &mut mask);

    assert!(outcome.is_err());
    let msg = outcome.unwrap_err().to_string();
    assert!(
        msg.contains("dimensions"),
        "error should mention dimensions: {msg}"
    );
}
2445
#[test]
/// On a four-connected grid, `AgentDisk` uses Manhattan distance. Radius 2 keeps
/// 13 of the 25 bounding-box slots and drops the four corners.
fn agent_disk_square4_filters_corners() {
    let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
    let snap = snapshot_with_field(FieldId(0), (0..400).map(|x| x as f32).collect());

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentDisk { radius: 2 },
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();
    assert_eq!(compiled.output_len, 25);

    let center: Coord = smallvec::smallvec![10, 10];
    let mut output = vec![0.0f32; 25];
    let mut mask = vec![0u8; 25];
    compiled
        .plan
        .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
        .unwrap();

    // 1 + 4 + 8 cells within Manhattan distance 2.
    let inside = mask.iter().filter(|&&v| v == 1).count();
    assert_eq!(
        inside, 13,
        "Manhattan disk radius=2 should have 13 cells"
    );

    for idx in [0, 4, 20, 24] {
        assert_eq!(
            mask[idx], 0,
            "corner tensor_idx {idx} should be outside disk"
        );
    }

    assert_eq!(output[12], 210.0);
    assert_eq!(mask[12], 1);
}
2504
#[test]
/// `AgentRect` applies no distance filter: an interior 5x5 window is fully valid.
fn agent_rect_no_disk_filtering() {
    let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
    let snap = snapshot_with_field(FieldId(0), (0..400).map(|x| x as f32).collect());

    let entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentRect {
            half_extent: smallvec::smallvec![2, 2],
        },
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec { entries: vec![entry] };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    let center: Coord = smallvec::smallvec![10, 10];
    let (mut output, mut mask) = (vec![0.0f32; 25], vec![0u8; 25]);
    compiled
        .plan
        .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
        .unwrap();

    assert!(mask.iter().all(|&v| v == 1));
}
2536
2537 #[test]
2538 fn agent_disk_square8_chebyshev() {
2539 let space = Square8::new(10, 10, EdgeBehavior::Absorb).unwrap();
2542 let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2543 let snap = snapshot_with_field(FieldId(0), data);
2544
2545 let spec = ObsSpec {
2546 entries: vec![ObsEntry {
2547 field_id: FieldId(0),
2548 region: ObsRegion::AgentDisk { radius: 1 },
2549 pool: None,
2550 transform: ObsTransform::Identity,
2551 dtype: ObsDtype::F32,
2552 }],
2553 };
2554 let result = ObsPlan::compile(&spec, &space).unwrap();
2555 assert_eq!(result.output_len, 9);
2556
2557 let center: Coord = smallvec::smallvec![5, 5];
2558 let mut output = vec![0.0f32; 9];
2559 let mut mask = vec![0u8; 9];
2560 result
2561 .plan
2562 .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2563 .unwrap();
2564
2565 let valid_count = mask.iter().filter(|&&v| v == 1).count();
2567 assert_eq!(valid_count, 9, "Chebyshev disk radius=1 = full 3x3");
2568 }
2569
2570 #[test]
2571 fn compile_rejects_inverted_normalize_range() {
2572 let space = square4_space();
2573 let spec = ObsSpec {
2574 entries: vec![ObsEntry {
2575 field_id: FieldId(0),
2576 region: ObsRegion::Fixed(RegionSpec::All),
2577 pool: None,
2578 transform: ObsTransform::Normalize {
2579 min: 10.0,
2580 max: 5.0,
2581 },
2582 dtype: ObsDtype::F32,
2583 }],
2584 };
2585 let err = ObsPlan::compile(&spec, &space).unwrap_err();
2586 assert!(matches!(err, ObsError::InvalidObsSpec { .. }));
2587 }
2588
2589 #[test]
2590 fn compile_rejects_nan_normalize() {
2591 let space = square4_space();
2592 let spec = ObsSpec {
2593 entries: vec![ObsEntry {
2594 field_id: FieldId(0),
2595 region: ObsRegion::Fixed(RegionSpec::All),
2596 pool: None,
2597 transform: ObsTransform::Normalize {
2598 min: f64::NAN,
2599 max: 1.0,
2600 },
2601 dtype: ObsDtype::F32,
2602 }],
2603 };
2604 assert!(ObsPlan::compile(&spec, &space).is_err());
2605 }
2606}