1use indexmap::IndexMap;
11
12use murk_core::error::ObsError;
13use murk_core::{Coord, FieldId, SnapshotAccess, TickId, WorldGenerationId};
14use murk_space::Space;
15
16use crate::geometry::GridGeometry;
17use crate::metadata::ObsMetadata;
18use crate::pool::pool_2d;
19use crate::spec::{ObsDtype, ObsRegion, ObsSpec, ObsTransform, PoolConfig};
20
/// Valid-ratio floor for fixed regions: compiling a region whose fraction of
/// valid cells is below this fails with `ObsError::InvalidComposition`.
const COVERAGE_ERROR_THRESHOLD: f64 = 0.35;

/// Valid-ratio level below which compiling a fixed region still succeeds but
/// may print a warning to stderr.
const COVERAGE_WARN_THRESHOLD: f64 = 0.50;
26
/// Result of compiling an `ObsSpec`: the executable plan plus the buffer
/// layout callers need to allocate output and mask buffers.
#[derive(Debug)]
pub struct ObsPlanResult {
    /// The compiled, executable plan.
    pub plan: ObsPlan,
    /// Number of `f32` output elements one observation occupies
    /// (the sum of every entry's element count).
    pub output_len: usize,
    /// Output shape of each entry, in spec order. For pooled agent entries
    /// this is the post-pool `[out_h, out_w]` shape.
    pub entry_shapes: Vec<Vec<usize>>,
    /// Number of `u8` mask elements one observation occupies.
    pub mask_len: usize,
}
39
/// A compiled observation plan that gathers field data from snapshots into a
/// flat `f32` output buffer and a parallel `u8` validity mask.
#[derive(Debug)]
pub struct ObsPlan {
    /// Simple (fixed regions only) or Standard (agent-relative) execution data.
    strategy: PlanStrategy,
    /// Output elements per observation.
    output_len: usize,
    /// Mask elements per observation.
    mask_len: usize,
    /// When `Some`, execution rejects snapshots whose world generation differs.
    compiled_generation: Option<WorldGenerationId>,
}
57
/// One precomputed copy from a field's data into a fixed entry's output.
#[derive(Debug, Clone)]
struct GatherOp {
    /// Index into the field's data slice (the space's canonical ordering).
    field_data_idx: usize,
    /// Destination index within the entry's output/mask slice.
    tensor_idx: usize,
}
69
/// A fixed-region entry compiled into a list of gather operations.
#[derive(Debug)]
struct CompiledEntry {
    /// Field read from the snapshot.
    field_id: FieldId,
    /// Transform applied to every gathered value.
    transform: ObsTransform,
    // Carried over from the spec but not read by the execution paths in this
    // file, hence the allow.
    #[allow(dead_code)]
    dtype: ObsDtype,
    /// Start of this entry's range in the per-observation output buffer.
    output_offset: usize,
    /// Start of this entry's range in the per-observation mask buffer.
    mask_offset: usize,
    /// Length of this entry's output (and mask) range.
    element_count: usize,
    /// Copies performed at execution time.
    gather_ops: Vec<GatherOp>,
    /// Precomputed validity mask, copied verbatim into the mask buffer.
    valid_mask: Vec<u8>,
    // Region valid ratio recorded at compile time; not read by the execution
    // paths in this file, hence the allow.
    #[allow(dead_code)]
    valid_ratio: f64,
}
91
/// One cell of an agent-relative template.
#[derive(Debug, Clone)]
struct TemplateOp {
    /// Offset from the agent's center, one component per template dimension.
    relative: Coord,
    /// Row-major index of this cell within the (pre-pool) template.
    tensor_idx: usize,
    /// Precomputed `sum(relative[d] * coord_strides[d])`. It is added to the
    /// center's canonical rank on the interior fast path (0 without geometry).
    stride_offset: isize,
    /// `false` for disk cells beyond the graph-distance radius; such cells are
    /// skipped at execution time.
    in_disk: bool,
}
110
/// An agent-relative entry (`AgentDisk` / `AgentRect`) compiled into a
/// template that is re-anchored at each agent's center during execution.
#[derive(Debug)]
struct AgentCompiledEntry {
    /// Field read from the snapshot.
    field_id: FieldId,
    /// Optional 2D pooling applied to the template before output.
    pool: Option<PoolConfig>,
    /// Transform applied to each output value.
    transform: ObsTransform,
    // Carried over from the spec but not read by the execution paths in this
    // file, hence the allow.
    #[allow(dead_code)]
    dtype: ObsDtype,
    /// Start of this entry's range in the per-agent output buffer.
    output_offset: usize,
    /// Start of this entry's range in the per-agent mask buffer.
    mask_offset: usize,
    /// Output elements for this entry (post-pool when pooling is configured).
    element_count: usize,
    /// Template elements before pooling (product of `pre_pool_shape`).
    pre_pool_element_count: usize,
    /// Template shape before pooling: `2 * half_extent + 1` per dimension.
    pre_pool_shape: Vec<usize>,
    /// One op per template cell, in row-major order.
    template_ops: Vec<TemplateOp>,
    /// Radius used for the interior fast-path test. For rectangles this is the
    /// largest half-extent.
    radius: u32,
}
138
/// Compiled data for a Standard plan (one containing agent-relative entries).
#[derive(Debug)]
struct StandardPlanData {
    /// Fixed-region entries. They are gathered once per `execute_agents` call
    /// and copied into every agent's output.
    fixed_entries: Vec<CompiledEntry>,
    /// Agent-relative entries, evaluated per agent.
    agent_entries: Vec<AgentCompiledEntry>,
    /// Result of `GridGeometry::from_space`. `None` disables the interior fast
    /// path and stride offsets.
    geometry: Option<GridGeometry>,
}
149
/// Execution strategy selected at compile time.
#[derive(Debug)]
enum PlanStrategy {
    /// Fixed regions only; executed with `execute` / `execute_batch`.
    Simple(Vec<CompiledEntry>),
    /// Contains agent-relative entries; executed with `execute_agents`.
    Standard(StandardPlanData),
}
158
159impl ObsPlan {
160 pub fn compile(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
167 if spec.entries.is_empty() {
168 return Err(ObsError::InvalidObsSpec {
169 reason: "ObsSpec has no entries".into(),
170 });
171 }
172
173 for (i, entry) in spec.entries.iter().enumerate() {
175 if let ObsTransform::Normalize { min, max } = &entry.transform {
176 if !min.is_finite() || !max.is_finite() {
177 return Err(ObsError::InvalidObsSpec {
178 reason: format!(
179 "entry {i}: Normalize min/max must be finite, got min={min}, max={max}"
180 ),
181 });
182 }
183 if min > max {
184 return Err(ObsError::InvalidObsSpec {
185 reason: format!("entry {i}: Normalize min ({min}) must be <= max ({max})"),
186 });
187 }
188 }
189 }
190
191 let has_agent = spec.entries.iter().any(|e| {
192 matches!(
193 e.region,
194 ObsRegion::AgentDisk { .. } | ObsRegion::AgentRect { .. }
195 )
196 });
197
198 if has_agent {
199 Self::compile_standard(spec, space)
200 } else {
201 Self::compile_simple(spec, space)
202 }
203 }
204
205 fn compile_simple(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
207 let canonical = space.canonical_ordering();
208 let coord_to_field_idx: IndexMap<Coord, usize> = canonical
209 .into_iter()
210 .enumerate()
211 .map(|(idx, coord)| (coord, idx))
212 .collect();
213
214 let mut entries = Vec::with_capacity(spec.entries.len());
215 let mut output_offset = 0usize;
216 let mut mask_offset = 0usize;
217 let mut entry_shapes = Vec::with_capacity(spec.entries.len());
218
219 for (i, entry) in spec.entries.iter().enumerate() {
220 let fixed_region = match &entry.region {
221 ObsRegion::Fixed(spec) => spec,
222 ObsRegion::AgentDisk { .. } | ObsRegion::AgentRect { .. } => {
223 return Err(ObsError::InvalidObsSpec {
224 reason: format!("entry {i}: agent-relative region in Simple plan"),
225 });
226 }
227 };
228 if entry.pool.is_some() {
229 return Err(ObsError::InvalidObsSpec {
230 reason: format!(
231 "entry {i}: pooling requires a Standard plan (use agent-relative region)"
232 ),
233 });
234 }
235
236 let mut region_plan =
237 space
238 .compile_region(fixed_region)
239 .map_err(|e| ObsError::InvalidObsSpec {
240 reason: format!("entry {i}: region compile failed: {e}"),
241 })?;
242
243 let ratio = region_plan.valid_ratio();
244 if ratio < COVERAGE_ERROR_THRESHOLD {
245 return Err(ObsError::InvalidComposition {
246 reason: format!(
247 "entry {i}: valid_ratio {ratio:.3} < {COVERAGE_ERROR_THRESHOLD}"
248 ),
249 });
250 }
251 if ratio < COVERAGE_WARN_THRESHOLD {
252 eprintln!(
253 "murk-obs: warning: entry {i} valid_ratio {ratio:.3} < {COVERAGE_WARN_THRESHOLD}"
254 );
255 }
256
257 let mut gather_ops = Vec::with_capacity(region_plan.coords().len());
258 for (coord_idx, coord) in region_plan.coords().iter().enumerate() {
259 let field_data_idx =
260 *coord_to_field_idx
261 .get(coord)
262 .ok_or_else(|| ObsError::InvalidObsSpec {
263 reason: format!("entry {i}: coord {coord:?} not in canonical ordering"),
264 })?;
265 let tensor_idx = region_plan.tensor_indices()[coord_idx];
266 gather_ops.push(GatherOp {
267 field_data_idx,
268 tensor_idx,
269 });
270 }
271
272 let element_count = region_plan.bounding_shape().total_elements();
273 let shape = match region_plan.bounding_shape() {
274 murk_space::BoundingShape::Rect(dims) => dims.clone(),
275 };
276 entry_shapes.push(shape);
277
278 entries.push(CompiledEntry {
279 field_id: entry.field_id,
280 transform: entry.transform.clone(),
281 dtype: entry.dtype,
282 output_offset,
283 mask_offset,
284 element_count,
285 gather_ops,
286 valid_mask: region_plan.take_valid_mask(),
287 valid_ratio: ratio,
288 });
289
290 output_offset += element_count;
291 mask_offset += element_count;
292 }
293
294 let plan = ObsPlan {
295 strategy: PlanStrategy::Simple(entries),
296 output_len: output_offset,
297 mask_len: mask_offset,
298 compiled_generation: None,
299 };
300
301 Ok(ObsPlanResult {
302 output_len: plan.output_len,
303 mask_len: plan.mask_len,
304 entry_shapes,
305 plan,
306 })
307 }
308
309 fn compile_standard(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
314 let canonical = space.canonical_ordering();
315 let coord_to_field_idx: IndexMap<Coord, usize> = canonical
316 .into_iter()
317 .enumerate()
318 .map(|(idx, coord)| (coord, idx))
319 .collect();
320
321 let geometry = GridGeometry::from_space(space);
322 let ndim = space.ndim();
323
324 let mut fixed_entries = Vec::new();
325 let mut agent_entries = Vec::new();
326 let mut output_offset = 0usize;
327 let mut mask_offset = 0usize;
328 let mut entry_shapes = Vec::new();
329
330 for (i, entry) in spec.entries.iter().enumerate() {
331 match &entry.region {
332 ObsRegion::Fixed(region_spec) => {
333 if entry.pool.is_some() {
334 return Err(ObsError::InvalidObsSpec {
335 reason: format!("entry {i}: pooling on Fixed regions not supported"),
336 });
337 }
338
339 let mut region_plan = space.compile_region(region_spec).map_err(|e| {
340 ObsError::InvalidObsSpec {
341 reason: format!("entry {i}: region compile failed: {e}"),
342 }
343 })?;
344
345 let ratio = region_plan.valid_ratio();
346 if ratio < COVERAGE_ERROR_THRESHOLD {
347 return Err(ObsError::InvalidComposition {
348 reason: format!(
349 "entry {i}: valid_ratio {ratio:.3} < {COVERAGE_ERROR_THRESHOLD}"
350 ),
351 });
352 }
353
354 let mut gather_ops = Vec::with_capacity(region_plan.coords().len());
355 for (coord_idx, coord) in region_plan.coords().iter().enumerate() {
356 let field_data_idx = *coord_to_field_idx.get(coord).ok_or_else(|| {
357 ObsError::InvalidObsSpec {
358 reason: format!(
359 "entry {i}: coord {coord:?} not in canonical ordering"
360 ),
361 }
362 })?;
363 let tensor_idx = region_plan.tensor_indices()[coord_idx];
364 gather_ops.push(GatherOp {
365 field_data_idx,
366 tensor_idx,
367 });
368 }
369
370 let element_count = region_plan.bounding_shape().total_elements();
371 let shape = match region_plan.bounding_shape() {
372 murk_space::BoundingShape::Rect(dims) => dims.clone(),
373 };
374 entry_shapes.push(shape);
375
376 fixed_entries.push(CompiledEntry {
377 field_id: entry.field_id,
378 transform: entry.transform.clone(),
379 dtype: entry.dtype,
380 output_offset,
381 mask_offset,
382 element_count,
383 gather_ops,
384 valid_mask: region_plan.take_valid_mask(),
385 valid_ratio: ratio,
386 });
387
388 output_offset += element_count;
389 mask_offset += element_count;
390 }
391
392 ObsRegion::AgentDisk { radius } => {
393 let half_ext: smallvec::SmallVec<[u32; 4]> =
394 (0..ndim).map(|_| *radius).collect();
395 let (ae, shape) = Self::compile_agent_entry(
396 i,
397 entry,
398 &half_ext,
399 *radius,
400 &geometry,
401 Some(*radius),
402 output_offset,
403 mask_offset,
404 )?;
405 entry_shapes.push(shape);
406 output_offset += ae.element_count;
407 mask_offset += ae.element_count;
408 agent_entries.push(ae);
409 }
410
411 ObsRegion::AgentRect { half_extent } => {
412 let radius = *half_extent.iter().max().unwrap_or(&0);
413 let (ae, shape) = Self::compile_agent_entry(
414 i,
415 entry,
416 half_extent,
417 radius,
418 &geometry,
419 None,
420 output_offset,
421 mask_offset,
422 )?;
423 entry_shapes.push(shape);
424 output_offset += ae.element_count;
425 mask_offset += ae.element_count;
426 agent_entries.push(ae);
427 }
428 }
429 }
430
431 let plan = ObsPlan {
432 strategy: PlanStrategy::Standard(StandardPlanData {
433 fixed_entries,
434 agent_entries,
435 geometry,
436 }),
437 output_len: output_offset,
438 mask_len: mask_offset,
439 compiled_generation: None,
440 };
441
442 Ok(ObsPlanResult {
443 output_len: plan.output_len,
444 mask_len: plan.mask_len,
445 entry_shapes,
446 plan,
447 })
448 }
449
450 #[allow(clippy::too_many_arguments)]
455 fn compile_agent_entry(
456 entry_idx: usize,
457 entry: &crate::spec::ObsEntry,
458 half_extent: &[u32],
459 radius: u32,
460 geometry: &Option<GridGeometry>,
461 disk_radius: Option<u32>,
462 output_offset: usize,
463 mask_offset: usize,
464 ) -> Result<(AgentCompiledEntry, Vec<usize>), ObsError> {
465 let pre_pool_shape: Vec<usize> =
466 half_extent.iter().map(|&he| 2 * he as usize + 1).collect();
467 let pre_pool_element_count: usize = pre_pool_shape.iter().product();
468
469 let template_ops = generate_template_ops(half_extent, geometry, disk_radius);
470
471 let (element_count, output_shape) = if let Some(pool) = &entry.pool {
472 if pre_pool_shape.len() != 2 {
473 return Err(ObsError::InvalidObsSpec {
474 reason: format!(
475 "entry {entry_idx}: pooling requires 2D region, got {}D",
476 pre_pool_shape.len()
477 ),
478 });
479 }
480 let h = pre_pool_shape[0];
481 let w = pre_pool_shape[1];
482 let ks = pool.kernel_size;
483 let stride = pool.stride;
484 if ks == 0 || stride == 0 {
485 return Err(ObsError::InvalidObsSpec {
486 reason: format!("entry {entry_idx}: pool kernel_size and stride must be > 0"),
487 });
488 }
489 let out_h = if h >= ks { (h - ks) / stride + 1 } else { 0 };
490 let out_w = if w >= ks { (w - ks) / stride + 1 } else { 0 };
491 if out_h == 0 || out_w == 0 {
492 return Err(ObsError::InvalidObsSpec {
493 reason: format!(
494 "entry {entry_idx}: pool produces empty output \
495 (region [{h},{w}], kernel_size {ks}, stride {stride})"
496 ),
497 });
498 }
499 (out_h * out_w, vec![out_h, out_w])
500 } else {
501 (pre_pool_element_count, pre_pool_shape.clone())
502 };
503
504 Ok((
505 AgentCompiledEntry {
506 field_id: entry.field_id,
507 pool: entry.pool.clone(),
508 transform: entry.transform.clone(),
509 dtype: entry.dtype,
510 output_offset,
511 mask_offset,
512 element_count,
513 pre_pool_element_count,
514 pre_pool_shape,
515 template_ops,
516 radius,
517 },
518 output_shape,
519 ))
520 }
521
522 pub fn compile_bound(
527 spec: &ObsSpec,
528 space: &dyn Space,
529 generation: WorldGenerationId,
530 ) -> Result<ObsPlanResult, ObsError> {
531 let mut result = Self::compile(spec, space)?;
532 result.plan.compiled_generation = Some(generation);
533 Ok(result)
534 }
535
536 pub fn output_len(&self) -> usize {
538 self.output_len
539 }
540
541 pub fn mask_len(&self) -> usize {
543 self.mask_len
544 }
545
546 pub fn compiled_generation(&self) -> Option<WorldGenerationId> {
548 self.compiled_generation
549 }
550
551 pub fn execute(
570 &self,
571 snapshot: &dyn SnapshotAccess,
572 engine_tick: Option<TickId>,
573 output: &mut [f32],
574 mask: &mut [u8],
575 ) -> Result<ObsMetadata, ObsError> {
576 let entries = match &self.strategy {
577 PlanStrategy::Simple(entries) => entries,
578 PlanStrategy::Standard(_) => {
579 return Err(ObsError::ExecutionFailed {
580 reason: "Standard plan requires execute_agents(), not execute()".into(),
581 });
582 }
583 };
584
585 if output.len() < self.output_len {
586 return Err(ObsError::ExecutionFailed {
587 reason: format!(
588 "output buffer too small: {} < {}",
589 output.len(),
590 self.output_len
591 ),
592 });
593 }
594 if mask.len() < self.mask_len {
595 return Err(ObsError::ExecutionFailed {
596 reason: format!("mask buffer too small: {} < {}", mask.len(), self.mask_len),
597 });
598 }
599
600 if let Some(compiled_gen) = self.compiled_generation {
602 let snapshot_gen = snapshot.world_generation_id();
603 if compiled_gen != snapshot_gen {
604 return Err(ObsError::PlanInvalidated {
605 reason: format!(
606 "plan compiled for generation {}, snapshot is generation {}",
607 compiled_gen.0, snapshot_gen.0
608 ),
609 });
610 }
611 }
612
613 let mut total_valid = 0usize;
614 let mut total_elements = 0usize;
615
616 for entry in entries {
617 let field_data =
618 snapshot
619 .read_field(entry.field_id)
620 .ok_or_else(|| ObsError::ExecutionFailed {
621 reason: format!("field {:?} not in snapshot", entry.field_id),
622 })?;
623
624 let out_slice =
625 &mut output[entry.output_offset..entry.output_offset + entry.element_count];
626 let mask_slice = &mut mask[entry.mask_offset..entry.mask_offset + entry.element_count];
627
628 out_slice.fill(0.0);
630 mask_slice.copy_from_slice(&entry.valid_mask);
631
632 for op in &entry.gather_ops {
634 let raw = *field_data.get(op.field_data_idx).ok_or_else(|| {
635 ObsError::ExecutionFailed {
636 reason: format!(
637 "field {:?} has {} elements but gather requires index {}",
638 entry.field_id,
639 field_data.len(),
640 op.field_data_idx,
641 ),
642 }
643 })?;
644 out_slice[op.tensor_idx] = apply_transform(raw, &entry.transform);
645 }
646
647 total_valid += entry.valid_mask.iter().filter(|&&v| v == 1).count();
648 total_elements += entry.element_count;
649 }
650
651 let coverage = if total_elements == 0 {
652 0.0
653 } else {
654 total_valid as f64 / total_elements as f64
655 };
656
657 let age_ticks = match engine_tick {
658 Some(tick) => tick.0.saturating_sub(snapshot.tick_id().0),
659 None => 0,
660 };
661
662 Ok(ObsMetadata {
663 tick_id: snapshot.tick_id(),
664 age_ticks,
665 coverage,
666 world_generation_id: snapshot.world_generation_id(),
667 parameter_version: snapshot.parameter_version(),
668 })
669 }
670
671 pub fn execute_batch(
679 &self,
680 snapshots: &[&dyn SnapshotAccess],
681 engine_tick: Option<TickId>,
682 output: &mut [f32],
683 mask: &mut [u8],
684 ) -> Result<Vec<ObsMetadata>, ObsError> {
685 if matches!(self.strategy, PlanStrategy::Standard(_)) {
687 return Err(ObsError::ExecutionFailed {
688 reason: "Standard plan requires execute_agents(), not execute_batch()".into(),
689 });
690 }
691
692 let batch_size = snapshots.len();
693 let expected_out = batch_size * self.output_len;
694 let expected_mask = batch_size * self.mask_len;
695
696 if output.len() < expected_out {
697 return Err(ObsError::ExecutionFailed {
698 reason: format!(
699 "batch output buffer too small: {} < {}",
700 output.len(),
701 expected_out
702 ),
703 });
704 }
705 if mask.len() < expected_mask {
706 return Err(ObsError::ExecutionFailed {
707 reason: format!(
708 "batch mask buffer too small: {} < {}",
709 mask.len(),
710 expected_mask
711 ),
712 });
713 }
714
715 let mut metadata = Vec::with_capacity(batch_size);
716 for (i, snap) in snapshots.iter().enumerate() {
717 let out_start = i * self.output_len;
718 let mask_start = i * self.mask_len;
719 let out_slice = &mut output[out_start..out_start + self.output_len];
720 let mask_slice = &mut mask[mask_start..mask_start + self.mask_len];
721 let meta = self.execute(*snap, engine_tick, out_slice, mask_slice)?;
722 metadata.push(meta);
723 }
724 Ok(metadata)
725 }
726
727 pub fn execute_agents(
738 &self,
739 snapshot: &dyn SnapshotAccess,
740 space: &dyn Space,
741 agent_centers: &[Coord],
742 engine_tick: Option<TickId>,
743 output: &mut [f32],
744 mask: &mut [u8],
745 ) -> Result<Vec<ObsMetadata>, ObsError> {
746 let standard = match &self.strategy {
747 PlanStrategy::Standard(data) => data,
748 PlanStrategy::Simple(_) => {
749 return Err(ObsError::ExecutionFailed {
750 reason: "execute_agents requires a Standard plan \
751 (spec must contain agent-relative entries)"
752 .into(),
753 });
754 }
755 };
756
757 let n_agents = agent_centers.len();
758 let expected_out = n_agents * self.output_len;
759 let expected_mask = n_agents * self.mask_len;
760
761 if output.len() < expected_out {
762 return Err(ObsError::ExecutionFailed {
763 reason: format!(
764 "output buffer too small: {} < {}",
765 output.len(),
766 expected_out
767 ),
768 });
769 }
770 if mask.len() < expected_mask {
771 return Err(ObsError::ExecutionFailed {
772 reason: format!("mask buffer too small: {} < {}", mask.len(), expected_mask),
773 });
774 }
775
776 let expected_ndim = space.ndim();
778 for (i, center) in agent_centers.iter().enumerate() {
779 if center.len() != expected_ndim {
780 return Err(ObsError::ExecutionFailed {
781 reason: format!(
782 "agent_centers[{i}] has {} dimensions, but space requires {expected_ndim}",
783 center.len()
784 ),
785 });
786 }
787 }
788
789 if let Some(compiled_gen) = self.compiled_generation {
791 let snapshot_gen = snapshot.world_generation_id();
792 if compiled_gen != snapshot_gen {
793 return Err(ObsError::PlanInvalidated {
794 reason: format!(
795 "plan compiled for generation {}, snapshot is generation {}",
796 compiled_gen.0, snapshot_gen.0
797 ),
798 });
799 }
800 }
801
802 let mut field_data_map: IndexMap<FieldId, &[f32]> = IndexMap::new();
804 for entry in &standard.fixed_entries {
805 if !field_data_map.contains_key(&entry.field_id) {
806 let data = snapshot.read_field(entry.field_id).ok_or_else(|| {
807 ObsError::ExecutionFailed {
808 reason: format!("field {:?} not in snapshot", entry.field_id),
809 }
810 })?;
811 field_data_map.insert(entry.field_id, data);
812 }
813 }
814 for entry in &standard.agent_entries {
815 if !field_data_map.contains_key(&entry.field_id) {
816 let data = snapshot.read_field(entry.field_id).ok_or_else(|| {
817 ObsError::ExecutionFailed {
818 reason: format!("field {:?} not in snapshot", entry.field_id),
819 }
820 })?;
821 field_data_map.insert(entry.field_id, data);
822 }
823 }
824
825 let mut fixed_out_scratch = vec![0.0f32; self.output_len];
829 let mut fixed_mask_scratch = vec![0u8; self.mask_len];
830 let mut fixed_valid = 0usize;
831 let mut fixed_elements = 0usize;
832
833 for entry in &standard.fixed_entries {
834 let field_data = field_data_map[&entry.field_id];
835 let out_slice = &mut fixed_out_scratch
836 [entry.output_offset..entry.output_offset + entry.element_count];
837 let mask_slice =
838 &mut fixed_mask_scratch[entry.mask_offset..entry.mask_offset + entry.element_count];
839
840 mask_slice.copy_from_slice(&entry.valid_mask);
841 for op in &entry.gather_ops {
842 let raw = *field_data.get(op.field_data_idx).ok_or_else(|| {
843 ObsError::ExecutionFailed {
844 reason: format!(
845 "field {:?} has {} elements but gather requires index {}",
846 entry.field_id,
847 field_data.len(),
848 op.field_data_idx,
849 ),
850 }
851 })?;
852 out_slice[op.tensor_idx] = apply_transform(raw, &entry.transform);
853 }
854
855 fixed_valid += entry.valid_mask.iter().filter(|&&v| v == 1).count();
856 fixed_elements += entry.element_count;
857 }
858
859 let max_pool_scratch = standard
863 .agent_entries
864 .iter()
865 .filter(|e| e.pool.is_some())
866 .map(|e| e.pre_pool_element_count)
867 .max()
868 .unwrap_or(0);
869 let mut pool_scratch = vec![0.0f32; max_pool_scratch];
870 let mut pool_scratch_mask = vec![0u8; max_pool_scratch];
871
872 let mut metadata = Vec::with_capacity(n_agents);
873
874 for (agent_i, center) in agent_centers.iter().enumerate() {
875 let out_start = agent_i * self.output_len;
876 let mask_start = agent_i * self.mask_len;
877 let agent_output = &mut output[out_start..out_start + self.output_len];
878 let agent_mask = &mut mask[mask_start..mask_start + self.mask_len];
879
880 agent_output[..self.output_len].copy_from_slice(&fixed_out_scratch);
882 agent_mask[..self.mask_len].copy_from_slice(&fixed_mask_scratch);
883
884 let mut total_valid = fixed_valid;
885 let mut total_elements = fixed_elements;
886
887 for entry in &standard.agent_entries {
889 let field_data = field_data_map[&entry.field_id];
890
891 let use_fast_path = standard
895 .geometry
896 .as_ref()
897 .map(|geo| !geo.all_wrap && geo.is_interior(center, entry.radius))
898 .unwrap_or(false);
899
900 if entry.pool.is_some() {
902 pool_scratch[..entry.pre_pool_element_count].fill(0.0);
903 pool_scratch_mask[..entry.pre_pool_element_count].fill(0);
904 }
905
906 let valid = execute_agent_entry(
907 entry,
908 center,
909 field_data,
910 &standard.geometry,
911 space,
912 use_fast_path,
913 agent_output,
914 agent_mask,
915 &mut pool_scratch,
916 &mut pool_scratch_mask,
917 );
918
919 total_valid += valid;
920 total_elements += entry.element_count;
921 }
922
923 let coverage = if total_elements == 0 {
924 0.0
925 } else {
926 total_valid as f64 / total_elements as f64
927 };
928
929 let age_ticks = match engine_tick {
930 Some(tick) => tick.0.saturating_sub(snapshot.tick_id().0),
931 None => 0,
932 };
933
934 metadata.push(ObsMetadata {
935 tick_id: snapshot.tick_id(),
936 age_ticks,
937 coverage,
938 world_generation_id: snapshot.world_generation_id(),
939 parameter_version: snapshot.parameter_version(),
940 });
941 }
942
943 Ok(metadata)
944 }
945
946 pub fn is_standard(&self) -> bool {
948 matches!(self.strategy, PlanStrategy::Standard(_))
949 }
950}
951
952#[allow(clippy::too_many_arguments)]
960fn execute_agent_entry(
961 entry: &AgentCompiledEntry,
962 center: &Coord,
963 field_data: &[f32],
964 geometry: &Option<GridGeometry>,
965 space: &dyn Space,
966 use_fast_path: bool,
967 agent_output: &mut [f32],
968 agent_mask: &mut [u8],
969 pool_scratch: &mut [f32],
970 pool_scratch_mask: &mut [u8],
971) -> usize {
972 if entry.pool.is_some() {
973 execute_agent_entry_pooled(
974 entry,
975 center,
976 field_data,
977 geometry,
978 space,
979 use_fast_path,
980 agent_output,
981 agent_mask,
982 &mut pool_scratch[..entry.pre_pool_element_count],
983 &mut pool_scratch_mask[..entry.pre_pool_element_count],
984 )
985 } else {
986 execute_agent_entry_direct(
987 entry,
988 center,
989 field_data,
990 geometry,
991 space,
992 use_fast_path,
993 agent_output,
994 agent_mask,
995 )
996 }
997}
998
999#[allow(clippy::too_many_arguments)]
1001fn execute_agent_entry_direct(
1002 entry: &AgentCompiledEntry,
1003 center: &Coord,
1004 field_data: &[f32],
1005 geometry: &Option<GridGeometry>,
1006 space: &dyn Space,
1007 use_fast_path: bool,
1008 agent_output: &mut [f32],
1009 agent_mask: &mut [u8],
1010) -> usize {
1011 let out_slice =
1012 &mut agent_output[entry.output_offset..entry.output_offset + entry.element_count];
1013 let mask_slice = &mut agent_mask[entry.mask_offset..entry.mask_offset + entry.element_count];
1014
1015 if use_fast_path {
1016 let geo = geometry.as_ref().unwrap();
1018 let base_rank = geo.canonical_rank(center) as isize;
1019 let mut valid = 0;
1020 for op in &entry.template_ops {
1021 if !op.in_disk {
1022 continue;
1023 }
1024 let field_idx = (base_rank + op.stride_offset) as usize;
1025 if let Some(&val) = field_data.get(field_idx) {
1026 out_slice[op.tensor_idx] = apply_transform(val, &entry.transform);
1027 mask_slice[op.tensor_idx] = 1;
1028 valid += 1;
1029 }
1030 }
1031 valid
1032 } else {
1033 let mut valid = 0;
1035 for op in &entry.template_ops {
1036 if !op.in_disk {
1037 continue;
1038 }
1039 let field_idx = resolve_field_index(center, &op.relative, geometry, space);
1040 if let Some(idx) = field_idx {
1041 if idx < field_data.len() {
1042 out_slice[op.tensor_idx] = apply_transform(field_data[idx], &entry.transform);
1043 mask_slice[op.tensor_idx] = 1;
1044 valid += 1;
1045 }
1046 }
1047 }
1048 valid
1049 }
1050}
1051
1052#[allow(clippy::too_many_arguments)]
1059fn execute_agent_entry_pooled(
1060 entry: &AgentCompiledEntry,
1061 center: &Coord,
1062 field_data: &[f32],
1063 geometry: &Option<GridGeometry>,
1064 space: &dyn Space,
1065 use_fast_path: bool,
1066 agent_output: &mut [f32],
1067 agent_mask: &mut [u8],
1068 scratch: &mut [f32],
1069 scratch_mask: &mut [u8],
1070) -> usize {
1071 if use_fast_path {
1072 let geo = geometry.as_ref().unwrap();
1073 let base_rank = geo.canonical_rank(center) as isize;
1074 for op in &entry.template_ops {
1075 if !op.in_disk {
1076 continue;
1077 }
1078 let field_idx = (base_rank + op.stride_offset) as usize;
1079 if let Some(&val) = field_data.get(field_idx) {
1080 scratch[op.tensor_idx] = val;
1081 scratch_mask[op.tensor_idx] = 1;
1082 }
1083 }
1084 } else {
1085 for op in &entry.template_ops {
1086 if !op.in_disk {
1087 continue;
1088 }
1089 let field_idx = resolve_field_index(center, &op.relative, geometry, space);
1090 if let Some(idx) = field_idx {
1091 if idx < field_data.len() {
1092 scratch[op.tensor_idx] = field_data[idx];
1093 scratch_mask[op.tensor_idx] = 1;
1094 }
1095 }
1096 }
1097 }
1098
1099 let pool_config = entry.pool.as_ref().unwrap();
1100 let (pooled, pooled_mask, _) =
1101 pool_2d(scratch, scratch_mask, &entry.pre_pool_shape, pool_config);
1102
1103 let out_slice =
1104 &mut agent_output[entry.output_offset..entry.output_offset + entry.element_count];
1105 let mask_slice = &mut agent_mask[entry.mask_offset..entry.mask_offset + entry.element_count];
1106
1107 let n = pooled.len().min(entry.element_count);
1108 for i in 0..n {
1109 out_slice[i] = apply_transform(pooled[i], &entry.transform);
1110 }
1111 mask_slice[..n].copy_from_slice(&pooled_mask[..n]);
1112
1113 pooled_mask[..n].iter().filter(|&&v| v == 1).count()
1114}
1115
1116fn generate_template_ops(
1128 half_extent: &[u32],
1129 geometry: &Option<GridGeometry>,
1130 disk_radius: Option<u32>,
1131) -> Vec<TemplateOp> {
1132 let ndim = half_extent.len();
1133 let shape: Vec<usize> = half_extent.iter().map(|&he| 2 * he as usize + 1).collect();
1134 let total: usize = shape.iter().product();
1135
1136 let strides = geometry.as_ref().map(|g| g.coord_strides.as_slice());
1137
1138 let mut ops = Vec::with_capacity(total);
1139
1140 for tensor_idx in 0..total {
1141 let mut relative = Coord::new();
1143 let mut remaining = tensor_idx;
1144 for d in (0..ndim).rev() {
1146 let coord = (remaining % shape[d]) as i32 - half_extent[d] as i32;
1147 relative.push(coord);
1148 remaining /= shape[d];
1149 }
1150 relative.reverse();
1151
1152 let stride_offset = strides
1153 .map(|s| {
1154 relative
1155 .iter()
1156 .zip(s.iter())
1157 .map(|(&r, &s)| r as isize * s as isize)
1158 .sum::<isize>()
1159 })
1160 .unwrap_or(0);
1161
1162 let in_disk = match disk_radius {
1163 Some(r) => match geometry {
1164 Some(geo) => geo.graph_distance(&relative) <= r,
1165 None => true, },
1167 None => true, };
1169
1170 ops.push(TemplateOp {
1171 relative,
1172 tensor_idx,
1173 stride_offset,
1174 in_disk,
1175 });
1176 }
1177
1178 ops
1179}
1180
1181fn resolve_field_index(
1188 center: &Coord,
1189 relative: &Coord,
1190 geometry: &Option<GridGeometry>,
1191 space: &dyn Space,
1192) -> Option<usize> {
1193 if let Some(geo) = geometry {
1194 if geo.all_wrap {
1195 let wrapped: Coord = center
1197 .iter()
1198 .zip(relative.iter())
1199 .zip(geo.coord_dims.iter())
1200 .map(|((&c, &r), &d)| {
1201 let d = d as i32;
1202 ((c + r) % d + d) % d
1203 })
1204 .collect();
1205 Some(geo.canonical_rank(&wrapped))
1206 } else {
1207 let abs_coord: Coord = center
1208 .iter()
1209 .zip(relative.iter())
1210 .map(|(&c, &r)| c + r)
1211 .collect();
1212 let abs_slice: &[i32] = &abs_coord;
1213 if geo.in_bounds(abs_slice) {
1214 Some(geo.canonical_rank(abs_slice))
1215 } else {
1216 None
1217 }
1218 }
1219 } else {
1220 let abs_coord: Coord = center
1221 .iter()
1222 .zip(relative.iter())
1223 .map(|(&c, &r)| c + r)
1224 .collect();
1225 space.canonical_rank(&abs_coord)
1226 }
1227}
1228
1229fn apply_transform(raw: f32, transform: &ObsTransform) -> f32 {
1231 match transform {
1232 ObsTransform::Identity => raw,
1233 ObsTransform::Normalize { min, max } => {
1234 let range = max - min;
1235 if range == 0.0 {
1236 0.0
1237 } else {
1238 let normalized = (raw as f64 - min) / range;
1239 normalized.clamp(0.0, 1.0) as f32
1240 }
1241 }
1242 }
1243}
1244
1245#[cfg(test)]
1246mod tests {
1247 use super::*;
1248 use crate::spec::{
1249 ObsDtype, ObsEntry, ObsRegion, ObsSpec, ObsTransform, PoolConfig, PoolKernel,
1250 };
1251 use murk_core::{FieldId, ParameterVersion, TickId, WorldGenerationId};
1252 use murk_space::{EdgeBehavior, Hex2D, RegionSpec, Square4, Square8};
1253 use murk_test_utils::MockSnapshot;
1254
1255 fn square4_space() -> Square4 {
1256 Square4::new(3, 3, EdgeBehavior::Absorb).unwrap()
1257 }
1258
1259 fn snapshot_with_field(field: FieldId, data: Vec<f32>) -> MockSnapshot {
1260 let mut snap = MockSnapshot::new(TickId(5), WorldGenerationId(1), ParameterVersion(0));
1261 snap.set_field(field, data);
1262 snap
1263 }
1264
1265 #[test]
1268 fn compile_empty_spec_errors() {
1269 let space = square4_space();
1270 let spec = ObsSpec { entries: vec![] };
1271 let err = ObsPlan::compile(&spec, &space).unwrap_err();
1272 assert!(matches!(err, ObsError::InvalidObsSpec { .. }));
1273 }
1274
1275 #[test]
1276 fn compile_all_region_square4() {
1277 let space = square4_space();
1278 let spec = ObsSpec {
1279 entries: vec![ObsEntry {
1280 field_id: FieldId(0),
1281 region: ObsRegion::Fixed(RegionSpec::All),
1282 pool: None,
1283 transform: ObsTransform::Identity,
1284 dtype: ObsDtype::F32,
1285 }],
1286 };
1287 let result = ObsPlan::compile(&spec, &space).unwrap();
1288 assert_eq!(result.output_len, 9); assert_eq!(result.mask_len, 9);
1290 assert_eq!(result.entry_shapes, vec![vec![3, 3]]);
1291 }
1292
1293 #[test]
1294 fn compile_rect_region() {
1295 let space = Square4::new(5, 5, EdgeBehavior::Absorb).unwrap();
1296 let spec = ObsSpec {
1297 entries: vec![ObsEntry {
1298 field_id: FieldId(0),
1299 region: ObsRegion::Fixed(RegionSpec::Rect {
1300 min: smallvec::smallvec![1, 1],
1301 max: smallvec::smallvec![2, 3],
1302 }),
1303 pool: None,
1304 transform: ObsTransform::Identity,
1305 dtype: ObsDtype::F32,
1306 }],
1307 };
1308 let result = ObsPlan::compile(&spec, &space).unwrap();
1309 assert_eq!(result.output_len, 6);
1311 assert_eq!(result.entry_shapes, vec![vec![2, 3]]);
1312 }
1313
1314 #[test]
1315 fn compile_two_entries_offsets() {
1316 let space = square4_space();
1317 let spec = ObsSpec {
1318 entries: vec![
1319 ObsEntry {
1320 field_id: FieldId(0),
1321 region: ObsRegion::Fixed(RegionSpec::All),
1322 pool: None,
1323 transform: ObsTransform::Identity,
1324 dtype: ObsDtype::F32,
1325 },
1326 ObsEntry {
1327 field_id: FieldId(1),
1328 region: ObsRegion::Fixed(RegionSpec::All),
1329 pool: None,
1330 transform: ObsTransform::Identity,
1331 dtype: ObsDtype::F32,
1332 },
1333 ],
1334 };
1335 let result = ObsPlan::compile(&spec, &space).unwrap();
1336 assert_eq!(result.output_len, 18); assert_eq!(result.mask_len, 18);
1338 }
1339
/// A fixed `Coords` region naming a cell far outside the 3x3 test grid must
/// be rejected at compile time with `InvalidObsSpec`.
#[test]
fn compile_invalid_region_errors() {
    let off_grid = RegionSpec::Coords(vec![smallvec::smallvec![99, 99]]);
    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(off_grid),
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let space = square4_space();

    let outcome = ObsPlan::compile(&spec, &space);
    let failure = outcome.unwrap_err();
    assert!(matches!(failure, ObsError::InvalidObsSpec { .. }));
}
1355
/// The identity transform over the whole grid copies the field verbatim and
/// marks every cell valid. The returned metadata must mirror the snapshot
/// built by `snapshot_with_field`: tick 5, generation 1, parameter version 0,
/// age 0, and full coverage.
#[test]
fn execute_identity_all_region() {
    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let space = square4_space();
    let values: Vec<f32> = (1..=9).map(|x| x as f32).collect();
    // Keep an independent copy to compare against after the snapshot takes ownership.
    let snap = snapshot_with_field(FieldId(0), values.clone());

    let compiled = ObsPlan::compile(&spec, &space).unwrap();
    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    let meta = compiled
        .plan
        .execute(&snap, None, &mut obs, &mut valid)
        .unwrap();

    assert_eq!(obs, values);
    assert_eq!(valid, vec![1u8; 9]);
    assert_eq!(meta.tick_id, TickId(5));
    assert_eq!(meta.coverage, 1.0);
    assert_eq!(meta.world_generation_id, WorldGenerationId(1));
    assert_eq!(meta.parameter_version, ParameterVersion(0));
    assert_eq!(meta.age_ticks, 0);
}
1394
/// `Normalize { min: 0.0, max: 8.0 }` applied to the values 0..9 must yield
/// `i / 8` at element `i`, within 1e-6.
#[test]
fn execute_normalize_transform() {
    let space = square4_space();
    let raw: Vec<f32> = (0..9).map(|x| x as f32).collect();
    let snap = snapshot_with_field(FieldId(0), raw);

    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: ObsTransform::Normalize { min: 0.0, max: 8.0 },
            dtype: ObsDtype::F32,
        }],
    };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute(&snap, None, &mut obs, &mut valid)
        .unwrap();

    for (i, v) in obs.iter().copied().enumerate() {
        let expected = i as f32 / 8.0;
        assert!(
            (v - expected).abs() < 1e-6,
            "output[{i}] = {v}, expected {expected}"
        );
    }
}
1429
/// Inputs from -20 to 20 normalised against the range [0, 10] must all land
/// inside [0, 1]. Values outside the configured range are clamped, not
/// extrapolated.
#[test]
fn execute_normalize_clamps_out_of_range() {
    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: ObsTransform::Normalize {
                min: 0.0,
                max: 10.0,
            },
            dtype: ObsDtype::F32,
        }],
    };
    let space = square4_space();
    // -20, -15, ..., 20: some below min, some above max.
    let spread: Vec<f32> = (-4..=4).map(|step| (step * 5) as f32).collect();
    let snap = snapshot_with_field(FieldId(0), spread);

    let compiled = ObsPlan::compile(&spec, &space).unwrap();
    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute(&snap, None, &mut obs, &mut valid)
        .unwrap();

    obs.iter()
        .for_each(|&v| assert!((0.0..=1.0).contains(&v), "value {v} out of [0,1] range"));
}
1462
/// A degenerate normalise range (`min == max`) must not divide by zero.
/// Every element must come back as exactly 0.0, overwriting the -1.0
/// sentinel the buffer starts with.
#[test]
fn execute_normalize_zero_range() {
    let space = square4_space();
    let snap = snapshot_with_field(FieldId(0), vec![5.0f32; 9]);

    let degenerate = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Normalize { min: 5.0, max: 5.0 },
        dtype: ObsDtype::F32,
    };
    let compiled = ObsPlan::compile(
        &ObsSpec {
            entries: vec![degenerate],
        },
        &space,
    )
    .unwrap();

    // Sentinel fill: proves execute actually wrote every slot.
    let mut obs = vec![-1.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute(&snap, None, &mut obs, &mut valid)
        .unwrap();

    assert!(obs.iter().all(|v| *v == 0.0));
}
1490
/// Gathering the 2x2 rectangle (1, 1)..=(2, 2) from a 4x4 grid holding
/// 1..=16 must yield 6, 7, 10, 11, all marked valid.
#[test]
fn execute_rect_subregion_correct_values() {
    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::Rect {
                min: smallvec::smallvec![1, 1],
                max: smallvec::smallvec![2, 2],
            }),
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let space = Square4::new(4, 4, EdgeBehavior::Absorb).unwrap();
    let cells: Vec<f32> = (1..=16).map(|x| x as f32).collect();
    let snap = snapshot_with_field(FieldId(0), cells);

    let compiled = ObsPlan::compile(&spec, &space).unwrap();
    assert_eq!(compiled.output_len, 4);

    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute(&snap, None, &mut obs, &mut valid)
        .unwrap();

    assert_eq!(obs, vec![6.0, 7.0, 10.0, 11.0]);
    assert_eq!(valid, vec![1, 1, 1, 1]);
}
1524
/// Two whole-grid entries reading different fields write their values into
/// consecutive, non-overlapping halves of the output buffer, in spec order.
#[test]
fn execute_two_fields() {
    let space = square4_space();
    let first_field: Vec<f32> = (1..=9).map(|x| x as f32).collect();
    let second_field: Vec<f32> = (10..=18).map(|x| x as f32).collect();

    let mut snap = MockSnapshot::new(TickId(1), WorldGenerationId(1), ParameterVersion(0));
    snap.set_field(FieldId(0), first_field);
    snap.set_field(FieldId(1), second_field);

    let whole_grid = |field_id| ObsEntry {
        field_id,
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![whole_grid(FieldId(0)), whole_grid(FieldId(1))],
    };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();
    assert_eq!(compiled.output_len, 18);

    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute(&snap, None, &mut obs, &mut valid)
        .unwrap();

    let want_first: Vec<f32> = (1..=9).map(|x| x as f32).collect();
    let want_second: Vec<f32> = (10..=18).map(|x| x as f32).collect();
    let (head, tail) = obs.split_at(9);
    assert_eq!(head, want_first.as_slice());
    assert_eq!(tail, want_second.as_slice());
}
1568
/// Executing against a snapshot that lacks the requested field must return
/// `ExecutionFailed` rather than panic or emit garbage.
#[test]
fn execute_missing_field_errors() {
    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let space = square4_space();
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    // No `set_field` call: FieldId(0) is absent from this snapshot.
    let empty_snap = MockSnapshot::new(TickId(1), WorldGenerationId(1), ParameterVersion(0));

    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    let failure = compiled
        .plan
        .execute(&empty_snap, None, &mut obs, &mut valid)
        .unwrap_err();
    assert!(matches!(failure, ObsError::ExecutionFailed { .. }));
}
1593
/// An output buffer shorter than `output_len` (4 slots for a 9-element plan)
/// must be reported as `ExecutionFailed` instead of writing out of bounds.
#[test]
fn execute_buffer_too_small_errors() {
    let space = square4_space();
    let snap = snapshot_with_field(FieldId(0), vec![0.0f32; 9]);

    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    let mut undersized = vec![0.0f32; 4];
    let mut valid = vec![0u8; compiled.mask_len];
    let failure = compiled
        .plan
        .execute(&snap, None, &mut undersized, &mut valid)
        .unwrap_err();
    assert!(matches!(failure, ObsError::ExecutionFailed { .. }));
}
1619
/// Every cell of a square grid lies in bounds for the `All` region, so the
/// reported coverage must be exactly 1.0.
#[test]
fn valid_ratio_one_for_square_all() {
    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let space = square4_space();
    let snap = snapshot_with_field(FieldId(0), vec![1.0f32; 9]);

    let compiled = ObsPlan::compile(&spec, &space).unwrap();
    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    let meta = compiled
        .plan
        .execute(&snap, None, &mut obs, &mut valid)
        .unwrap();

    assert_eq!(meta.coverage, 1.0);
}
1648
/// A plan bound to world generation 99 must refuse a snapshot from
/// generation 1 (the generation `snapshot_with_field` produces) with
/// `PlanInvalidated`.
#[test]
fn plan_invalidated_on_generation_mismatch() {
    let space = square4_space();
    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let stale_generation = WorldGenerationId(99);
    let bound = ObsPlan::compile_bound(&spec, &space, stale_generation).unwrap();

    let snap = snapshot_with_field(FieldId(0), vec![1.0f32; 9]);
    let mut obs = vec![0.0f32; bound.output_len];
    let mut valid = vec![0u8; bound.mask_len];
    let failure = bound
        .plan
        .execute(&snap, None, &mut obs, &mut valid)
        .unwrap_err();
    assert!(matches!(failure, ObsError::PlanInvalidated { .. }));
}
1677
/// A plan bound to generation 1 executes successfully against a snapshot
/// from that same generation.
#[test]
fn generation_match_succeeds() {
    let space = square4_space();
    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let bound = ObsPlan::compile_bound(&spec, &space, WorldGenerationId(1)).unwrap();

    let snap = snapshot_with_field(FieldId(0), vec![1.0f32; 9]);
    let mut obs = vec![0.0f32; bound.output_len];
    let mut valid = vec![0u8; bound.mask_len];
    bound
        .plan
        .execute(&snap, None, &mut obs, &mut valid)
        .unwrap();
}
1702
/// A plan compiled with `compile` (no generation binding) must execute
/// against a snapshot from any world generation.
///
/// The snapshot deliberately uses generation 99 rather than the default 1
/// produced by `snapshot_with_field`. With the default, this test would be
/// indistinguishable from every other successful-execute test and would not
/// actually exercise the "ignores generation" path. The returned metadata
/// must still report the snapshot's own generation.
#[test]
fn unbound_plan_ignores_generation() {
    let space = square4_space();
    let data: Vec<f32> = vec![1.0; 9];
    let mut snap = MockSnapshot::new(TickId(1), WorldGenerationId(99), ParameterVersion(0));
    snap.set_field(FieldId(0), data);

    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let result = ObsPlan::compile(&spec, &space).unwrap();

    let mut output = vec![0.0f32; result.output_len];
    let mut mask = vec![0u8; result.mask_len];
    let meta = result
        .plan
        .execute(&snap, None, &mut output, &mut mask)
        .unwrap();

    assert_eq!(meta.world_generation_id, WorldGenerationId(99));
}
1728
/// The tick, generation and parameter version of the snapshot are forwarded
/// into the returned metadata. For a fresh snapshot the age is 0 and
/// whole-grid coverage is 1.0.
#[test]
fn metadata_fields_populated() {
    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let space = square4_space();
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    let mut snap = MockSnapshot::new(TickId(42), WorldGenerationId(7), ParameterVersion(3));
    snap.set_field(FieldId(0), vec![1.0f32; 9]);

    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    let meta = compiled
        .plan
        .execute(&snap, None, &mut obs, &mut valid)
        .unwrap();

    assert_eq!(meta.tick_id, TickId(42));
    assert_eq!(meta.age_ticks, 0);
    assert_eq!(meta.coverage, 1.0);
    assert_eq!(meta.world_generation_id, WorldGenerationId(7));
    assert_eq!(meta.parameter_version, ParameterVersion(3));
}
1762
/// A batch of one snapshot must produce exactly the same output, mask and
/// metadata as a single `execute` call on that snapshot.
#[test]
fn execute_batch_n1_matches_execute() {
    let space = square4_space();
    let values: Vec<f32> = (1..=9).map(|x| x as f32).collect();
    let snap = snapshot_with_field(FieldId(0), values);

    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();
    let plan = &compiled.plan;

    // Path 1: single-snapshot execute.
    let mut solo_obs = vec![0.0f32; compiled.output_len];
    let mut solo_valid = vec![0u8; compiled.mask_len];
    let solo_meta = plan
        .execute(&snap, None, &mut solo_obs, &mut solo_valid)
        .unwrap();

    // Path 2: batched execute with N = 1.
    let mut batch_obs = vec![0.0f32; compiled.output_len];
    let mut batch_valid = vec![0u8; compiled.mask_len];
    let as_dyn: &dyn SnapshotAccess = &snap;
    let batch_metas = plan
        .execute_batch(&[as_dyn], None, &mut batch_obs, &mut batch_valid)
        .unwrap();

    assert_eq!(solo_obs, batch_obs);
    assert_eq!(solo_valid, batch_valid);
    assert_eq!(solo_meta, batch_metas[0]);
}
1803
/// Batched execution writes one contiguous `output_len` slice (and one
/// `mask_len` slice) per snapshot, in input order.
///
/// The per-snapshot slices are derived from `output_len` rather than a
/// hard-coded 9, so the test stays correct if the fixture grid changes.
/// The mask is checked too: for a whole-grid region on a square grid, every
/// cell of every batch item must be marked valid.
#[test]
fn execute_batch_multiple_snapshots() {
    let space = square4_space();
    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let result = ObsPlan::compile(&spec, &space).unwrap();

    let snap_a = snapshot_with_field(FieldId(0), vec![1.0; 9]);
    let snap_b = snapshot_with_field(FieldId(0), vec![2.0; 9]);

    let snaps: Vec<&dyn SnapshotAccess> = vec![&snap_a, &snap_b];
    let stride = result.output_len;
    let mut output = vec![0.0f32; stride * snaps.len()];
    let mut mask = vec![0u8; result.mask_len * snaps.len()];
    let metas = result
        .plan
        .execute_batch(&snaps, None, &mut output, &mut mask)
        .unwrap();

    assert_eq!(metas.len(), 2);
    let (first, second) = output.split_at(stride);
    assert!(first.iter().all(|&v| v == 1.0));
    assert!(second.iter().all(|&v| v == 2.0));
    assert!(mask.iter().all(|&m| m == 1));
}
1833
/// A batch of two snapshots needs 2 * 9 output slots. Supplying only 9 must
/// fail with `ExecutionFailed`, even when the mask buffer is large enough.
#[test]
fn execute_batch_buffer_too_small() {
    let space = square4_space();
    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    let snap = snapshot_with_field(FieldId(0), vec![1.0f32; 9]);
    let pair: Vec<&dyn SnapshotAccess> = vec![&snap, &snap];

    let mut half_sized = vec![0.0f32; 9];
    let mut valid = vec![0u8; 18];
    let failure = compiled
        .plan
        .execute_batch(&pair, None, &mut half_sized, &mut valid)
        .unwrap_err();
    assert!(matches!(failure, ObsError::ExecutionFailed { .. }));
}
1858
/// A snapshot whose field holds fewer values (4) than the compiled plan
/// gathers (9) must produce `ExecutionFailed` instead of an out-of-bounds
/// panic.
#[test]
fn short_field_buffer_returns_error_not_panic() {
    let truncated = snapshot_with_field(FieldId(0), vec![1.0f32; 4]);

    let space = square4_space();
    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    let failure = compiled
        .plan
        .execute(&truncated, None, &mut obs, &mut valid)
        .unwrap_err();
    assert!(matches!(failure, ObsError::ExecutionFailed { .. }));
}
1885
/// An agent-relative region makes the compiler pick the Standard strategy.
/// A half-extent of (2, 2) yields a 5x5 window, i.e. 25 output elements.
#[test]
fn standard_plan_detected_from_agent_region() {
    let window = ObsRegion::AgentRect {
        half_extent: smallvec::smallvec![2, 2],
    };
    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: window,
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();

    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    assert!(compiled.plan.is_standard());
    assert_eq!(compiled.output_len, 25);
    assert_eq!(compiled.entry_shapes, vec![vec![5, 5]]);
}
1908
/// Calling the fixed-region `execute` on a plan that contains an
/// agent-relative entry (here an `AgentDisk`) must fail with
/// `ExecutionFailed`. Such plans are driven through `execute_agents`
/// elsewhere in this suite.
#[test]
fn execute_on_standard_plan_errors() {
    let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
    let cells: Vec<f32> = (0..100).map(|x| x as f32).collect();
    let snap = snapshot_with_field(FieldId(0), cells);

    let disk_entry = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentDisk { radius: 2 },
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let compiled = ObsPlan::compile(
        &ObsSpec {
            entries: vec![disk_entry],
        },
        &space,
    )
    .unwrap();

    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    let failure = compiled
        .plan
        .execute(&snap, None, &mut obs, &mut valid)
        .unwrap_err();
    assert!(matches!(failure, ObsError::ExecutionFailed { .. }));
}
1934
/// For an agent well inside the grid, the Standard (agent-relative) plan must
/// produce exactly the same output and mask as a Simple plan over the
/// equivalent fixed inclusive rectangle.
///
/// Fix: the original passed `std::slice::from_ref(¢er)`, a mis-encoded
/// `&center` (an HTML-entity corruption) that does not compile. The fixed
/// rectangle is now also derived from `center`, so the two specs cannot
/// drift apart if the centre is moved.
#[test]
fn interior_boundary_equivalence() {
    let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
    let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
    let snap = snapshot_with_field(FieldId(0), data);

    let radius = 3u32;
    let center: Coord = smallvec::smallvec![10, 10];

    // Agent-relative window of half-extent `radius` around `center`.
    let standard_spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::AgentRect {
                half_extent: smallvec::smallvec![radius, radius],
            },
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let std_result = ObsPlan::compile(&standard_spec, &space).unwrap();
    let mut std_output = vec![0.0f32; std_result.output_len];
    let mut std_mask = vec![0u8; std_result.mask_len];
    std_result
        .plan
        .execute_agents(
            &snap,
            &space,
            std::slice::from_ref(&center),
            None,
            &mut std_output,
            &mut std_mask,
        )
        .unwrap();

    // The same window expressed as a fixed, inclusive rectangle.
    // `radius` is a small literal, so the u32 -> i32 cast cannot truncate.
    let r = radius as i32;
    let simple_spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::Rect {
                min: smallvec::smallvec![center[0] - r, center[1] - r],
                max: smallvec::smallvec![center[0] + r, center[1] + r],
            }),
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let simple_result = ObsPlan::compile(&simple_spec, &space).unwrap();
    let mut simple_output = vec![0.0f32; simple_result.output_len];
    let mut simple_mask = vec![0u8; simple_result.mask_len];
    simple_result
        .plan
        .execute(&snap, None, &mut simple_output, &mut simple_mask)
        .unwrap();

    assert_eq!(std_result.output_len, simple_result.output_len);
    assert_eq!(std_output, simple_output);
    assert_eq!(std_mask, simple_mask);
}
2000
/// An agent at the corner (0, 0) of an absorbing 10x10 grid sees only the
/// 3x3 part of its 5x5 window that lies on the grid:
/// - the off-grid slots are zero-padded with mask 0;
/// - coverage is 9/25;
/// - the window centre (slot 12) holds the agent's own cell value 1.0.
#[test]
fn boundary_agent_gets_padding() {
    let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
    // Values start at 1.0 so a zero can only come from padding.
    let cells: Vec<f32> = (1..=100).map(|x| x as f32).collect();
    let snap = snapshot_with_field(FieldId(0), cells);

    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::AgentRect {
                half_extent: smallvec::smallvec![2, 2],
            },
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    let corner: Coord = smallvec::smallvec![0, 0];
    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    let metas = compiled
        .plan
        .execute_agents(&snap, &space, &[corner], None, &mut obs, &mut valid)
        .unwrap();

    let in_bounds = valid.iter().filter(|&&m| m == 1).count();
    assert_eq!(in_bounds, 9);

    assert!((metas[0].coverage - 9.0 / 25.0).abs() < 1e-6);

    // Top-left slot of the window falls off the grid.
    assert_eq!(valid[0], 0);
    assert_eq!(obs[0], 0.0);

    // Window centre is the agent's own cell.
    assert_eq!(valid[12], 1);
    assert_eq!(obs[12], 1.0);
}
2046
/// A radius-2 `AgentDisk` on a hex grid compiles to a 5x5 (25-slot) bounding
/// tensor, of which only the 19 cells of the hex disk are valid.
/// - The six listed corner slots must be masked out and zero-padded.
/// - Slots 12 and 17 must read the expected grid values.
#[test]
fn hex_foveation_interior() {
    let space = Hex2D::new(20, 20).unwrap();
    let cells: Vec<f32> = (0..400).map(|x| x as f32).collect();
    let snap = snapshot_with_field(FieldId(0), cells);

    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::AgentDisk { radius: 2 },
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();
    assert_eq!(compiled.output_len, 25);

    let center: Coord = smallvec::smallvec![10, 10];
    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute_agents(&snap, &space, &[center], None, &mut obs, &mut valid)
        .unwrap();

    assert_eq!(valid.iter().filter(|&&m| m == 1).count(), 19);

    for idx in [0, 1, 5, 19, 23, 24] {
        assert_eq!(valid[idx], 0, "tensor_idx {idx} should be outside hex disk");
        assert_eq!(obs[idx], 0.0, "tensor_idx {idx} should be zero-padded");
    }

    assert_eq!(obs[12], 210.0);
    assert_eq!(obs[17], 211.0);
}
2102
/// On a wrapping grid no window can fall off the edge. Even an agent at
/// (0, 0) gets every slot marked valid, and its centre slot reads cell 0.
#[test]
fn wrap_space_all_interior() {
    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::AgentRect {
                half_extent: smallvec::smallvec![2, 2],
            },
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let space = Square4::new(10, 10, EdgeBehavior::Wrap).unwrap();
    let cells: Vec<f32> = (0..100).map(|x| x as f32).collect();
    let snap = snapshot_with_field(FieldId(0), cells);

    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    let origin: Coord = smallvec::smallvec![0, 0];
    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute_agents(&snap, &space, &[origin], None, &mut obs, &mut valid)
        .unwrap();

    assert!(valid.iter().all(|&m| m == 1));
    assert_eq!(obs[12], 0.0);
}
2136
/// Two agents are executed in one call; each gets its own contiguous
/// `output_len` / `mask_len` slice, in input order.
/// - The interior agent at (5, 5) sees a fully valid 3x3 window centred on
///   value 55.
/// - The edge agent at (0, 5) loses the off-grid row, leaving 6 valid cells.
///
/// Fix: the original passed `¢ers`, a mis-encoded `&centers` (an
/// HTML-entity corruption) that does not compile.
#[test]
fn execute_agents_multiple() {
    let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
    let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
    let snap = snapshot_with_field(FieldId(0), data);

    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::AgentRect {
                half_extent: smallvec::smallvec![1, 1],
            },
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let result = ObsPlan::compile(&spec, &space).unwrap();
    assert_eq!(result.output_len, 9);

    let centers: Vec<Coord> = vec![
        smallvec::smallvec![5, 5], // interior
        smallvec::smallvec![0, 5], // on the edge: one window row is off-grid
    ];
    let n = centers.len();
    let mut output = vec![0.0f32; result.output_len * n];
    let mut mask = vec![0u8; result.mask_len * n];
    let metas = result
        .plan
        .execute_agents(&snap, &space, &centers, None, &mut output, &mut mask)
        .unwrap();

    assert_eq!(metas.len(), 2);

    // Agent 0: fully interior.
    assert!(mask[..9].iter().all(|&v| v == 1));
    assert_eq!(output[4], 55.0);

    // Agent 1: 6 of 9 cells remain on the grid.
    let agent1_mask = &mask[9..18];
    let valid_count: usize = agent1_mask.iter().filter(|&&v| v == 1).count();
    assert_eq!(valid_count, 6);
}
2181
/// The transform is applied on the agent path too. With
/// `Normalize { 0.0, 99.0 }`, the centre of an agent at (5, 5) (raw value 55)
/// must become 55/99.
#[test]
fn execute_agents_with_normalize() {
    let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
    let cells: Vec<f32> = (0..100).map(|x| x as f32).collect();
    let snap = snapshot_with_field(FieldId(0), cells);

    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::AgentRect {
                half_extent: smallvec::smallvec![1, 1],
            },
            pool: None,
            transform: ObsTransform::Normalize {
                min: 0.0,
                max: 99.0,
            },
            dtype: ObsDtype::F32,
        }],
    };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    let agent: Coord = smallvec::smallvec![5, 5];
    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute_agents(&snap, &space, &[agent], None, &mut obs, &mut valid)
        .unwrap();

    // Computed in f64 and then narrowed.
    let expected = (55.0f64 / 99.0) as f32;
    assert!((obs[4] - expected).abs() < 1e-5);
}
2216
/// A 7x7 agent window (half-extent 3) mean-pooled with a 2x2 kernel and
/// stride 2 compiles to a 3x3 output.
/// - For an agent at (10, 10) on a 20x20 grid of 0..400, every pooled cell
///   is valid.
/// - The first pooled cell averages 147, 148, 167 and 168, giving 157.5.
#[test]
fn execute_agents_with_pooling() {
    let pooling = PoolConfig {
        kernel: PoolKernel::Mean,
        kernel_size: 2,
        stride: 2,
    };
    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::AgentRect {
                half_extent: smallvec::smallvec![3, 3],
            },
            pool: Some(pooling),
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
    let cells: Vec<f32> = (0..400).map(|x| x as f32).collect();
    let snap = snapshot_with_field(FieldId(0), cells);

    let compiled = ObsPlan::compile(&spec, &space).unwrap();
    assert_eq!(compiled.output_len, 9);
    assert_eq!(compiled.entry_shapes, vec![vec![3, 3]]);

    let agent: Coord = smallvec::smallvec![10, 10];
    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute_agents(&snap, &space, &[agent], None, &mut obs, &mut valid)
        .unwrap();

    assert!(valid.iter().all(|&m| m == 1));
    assert!((obs[0] - 157.5).abs() < 1e-4);
}
2262
/// A spec mixing a whole-grid fixed entry (100 cells) with a 3x3 agent window
/// compiles to a Standard plan with 100 + 9 = 109 outputs.
/// - The fixed block comes first and is a verbatim, fully valid copy of the
///   grid.
/// - The agent block follows and is centred on value 55 for an agent at
///   (5, 5).
#[test]
fn mixed_fixed_and_agent_entries() {
    let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
    let cells: Vec<f32> = (0..100).map(|x| x as f32).collect();
    let snap = snapshot_with_field(FieldId(0), cells);

    let whole_grid = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::Fixed(RegionSpec::All),
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let agent_window = ObsEntry {
        field_id: FieldId(0),
        region: ObsRegion::AgentRect {
            half_extent: smallvec::smallvec![1, 1],
        },
        pool: None,
        transform: ObsTransform::Identity,
        dtype: ObsDtype::F32,
    };
    let spec = ObsSpec {
        entries: vec![whole_grid, agent_window],
    };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();
    assert!(compiled.plan.is_standard());
    assert_eq!(compiled.output_len, 109);

    let agent: Coord = smallvec::smallvec![5, 5];
    let mut obs = vec![0.0f32; compiled.output_len];
    let mut valid = vec![0u8; compiled.mask_len];
    compiled
        .plan
        .execute_agents(&snap, &space, &[agent], None, &mut obs, &mut valid)
        .unwrap();

    let grid_copy: Vec<f32> = (0..100).map(|x| x as f32).collect();
    let (fixed_part, agent_part) = obs.split_at(100);
    assert_eq!(fixed_part, &grid_copy[..]);
    assert!(valid[..100].iter().all(|&m| m == 1));
    assert_eq!(agent_part[4], 55.0);
}
2312
/// Passing a 1-D agent centre to a plan over a 2-D space must return an
/// error, not panic, and the error text must mention "dimensions".
///
/// Cleanup: the original named the `Result` `err`, checked `is_err()` and then
/// called `unwrap_err()` separately, and stringified with `format!("{}", ..)`.
/// `expect_err` plus `to_string()` expresses the same checks directly.
#[test]
fn wrong_dimensionality_returns_error() {
    let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
    let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
    let snap = snapshot_with_field(FieldId(0), data);

    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::AgentDisk { radius: 1 },
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let result = ObsPlan::compile(&spec, &space).unwrap();

    // One coordinate for a two-dimensional space.
    let bad_center: Coord = smallvec::smallvec![5];
    let mut output = vec![0.0f32; result.output_len];
    let mut mask = vec![0u8; result.mask_len];
    let err = result
        .plan
        .execute_agents(&snap, &space, &[bad_center], None, &mut output, &mut mask)
        .expect_err("a 1-D center on a 2-D space must be rejected");
    let msg = err.to_string();
    assert!(
        msg.contains("dimensions"),
        "error should mention dimensions: {msg}"
    );
}
2345
/// On a 4-connected square grid, a radius-2 `AgentDisk` uses Manhattan
/// distance.
/// - Only 13 of the 25 bounding-box slots are valid.
/// - The four box corners must be masked out.
/// - The centre must read cell (10, 10) = 210.
#[test]
fn agent_disk_square4_filters_corners() {
    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::AgentDisk { radius: 2 },
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
    let cells: Vec<f32> = (0..400).map(|x| x as f32).collect();
    let snap = snapshot_with_field(FieldId(0), cells);

    let compiled = ObsPlan::compile(&spec, &space).unwrap();
    assert_eq!(compiled.output_len, 25);

    let agent: Coord = smallvec::smallvec![10, 10];
    let mut obs = vec![0.0f32; 25];
    let mut valid = vec![0u8; 25];
    compiled
        .plan
        .execute_agents(&snap, &space, &[agent], None, &mut obs, &mut valid)
        .unwrap();

    let in_disk = valid.iter().filter(|&&m| m == 1).count();
    assert_eq!(
        in_disk, 13,
        "Manhattan disk radius=2 should have 13 cells"
    );

    for idx in [0, 4, 20, 24] {
        assert_eq!(
            valid[idx], 0,
            "corner tensor_idx {idx} should be outside disk"
        );
    }

    assert_eq!(obs[12], 210.0);
    assert_eq!(valid[12], 1);
}
2404
/// Unlike `AgentDisk`, an `AgentRect` applies no distance filter. An interior
/// 5x5 window keeps all 25 slots valid, corners included.
#[test]
fn agent_rect_no_disk_filtering() {
    let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
    let cells: Vec<f32> = (0..400).map(|x| x as f32).collect();
    let snap = snapshot_with_field(FieldId(0), cells);

    let rect = ObsRegion::AgentRect {
        half_extent: smallvec::smallvec![2, 2],
    };
    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: rect,
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let compiled = ObsPlan::compile(&spec, &space).unwrap();

    let agent: Coord = smallvec::smallvec![10, 10];
    let mut obs = vec![0.0f32; 25];
    let mut valid = vec![0u8; 25];
    compiled
        .plan
        .execute_agents(&snap, &space, &[agent], None, &mut obs, &mut valid)
        .unwrap();

    assert!(valid.iter().all(|&m| m == 1));
}
2436
/// On an 8-connected square grid the disk uses Chebyshev distance, so a
/// radius-1 `AgentDisk` covers the full 3x3 window with nothing filtered.
#[test]
fn agent_disk_square8_chebyshev() {
    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::AgentDisk { radius: 1 },
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }],
    };
    let space = Square8::new(10, 10, EdgeBehavior::Absorb).unwrap();
    let cells: Vec<f32> = (0..100).map(|x| x as f32).collect();
    let snap = snapshot_with_field(FieldId(0), cells);

    let compiled = ObsPlan::compile(&spec, &space).unwrap();
    assert_eq!(compiled.output_len, 9);

    let agent: Coord = smallvec::smallvec![5, 5];
    let mut obs = vec![0.0f32; 9];
    let mut valid = vec![0u8; 9];
    compiled
        .plan
        .execute_agents(&snap, &space, &[agent], None, &mut obs, &mut valid)
        .unwrap();

    let in_disk = valid.iter().filter(|&&m| m == 1).count();
    assert_eq!(in_disk, 9, "Chebyshev disk radius=1 = full 3x3");
}
2469
/// A normalise range with `min > max` (10.0 > 5.0) is rejected at compile
/// time with `InvalidObsSpec`.
#[test]
fn compile_rejects_inverted_normalize_range() {
    let inverted = ObsTransform::Normalize {
        min: 10.0,
        max: 5.0,
    };
    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: inverted,
            dtype: ObsDtype::F32,
        }],
    };
    let space = square4_space();

    let failure = ObsPlan::compile(&spec, &space).unwrap_err();
    assert!(matches!(failure, ObsError::InvalidObsSpec { .. }));
}
2488
/// A NaN bound in a normalise transform must make compilation fail.
#[test]
fn compile_rejects_nan_normalize() {
    let with_nan = ObsTransform::Normalize {
        min: f64::NAN,
        max: 1.0,
    };
    let spec = ObsSpec {
        entries: vec![ObsEntry {
            field_id: FieldId(0),
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: with_nan,
            dtype: ObsDtype::F32,
        }],
    };
    let space = square4_space();

    let outcome = ObsPlan::compile(&spec, &space);
    assert!(outcome.is_err());
}
2506}