Skip to main content

murk_obs/
plan.rs

//! Observation plan compilation and execution.
//!
//! [`ObsPlan`] is compiled from an [`ObsSpec`] + [`Space`], producing
//! a reusable gather plan that can be executed against any
//! [`SnapshotAccess`] implementor. The "Simple" plan class (every region is
//! `Fixed`) uses a branch-free flat gather: for each entry, iterate
//! pre-computed `(field_data_index, tensor_index)` pairs, read the field
//! value, optionally transform it, and write to the caller-allocated buffer.
//! The "Standard" plan class (at least one agent-relative region) instead
//! compiles per-entry templates that are instantiated for each agent center
//! at execute time.
9
10use indexmap::IndexMap;
11
12use murk_core::error::ObsError;
13use murk_core::{Coord, FieldId, SnapshotAccess, TickId, WorldGenerationId};
14use murk_space::Space;
15
16use crate::geometry::GridGeometry;
17use crate::metadata::ObsMetadata;
18use crate::pool::pool_2d;
19use crate::spec::{ObsDtype, ObsRegion, ObsSpec, ObsTransform, PoolConfig};
20
/// Coverage ratio below which plan compilation may print a low-coverage
/// warning to stderr; the entry is still compiled.
const COVERAGE_WARN_THRESHOLD: f64 = 0.50;

/// Coverage ratio below which plan compilation rejects the entry with
/// `ObsError::InvalidComposition`. Sits below the warning threshold.
const COVERAGE_ERROR_THRESHOLD: f64 = 0.350;
26
/// Result of compiling an [`ObsSpec`].
///
/// `output_len` and `mask_len` are copied from `plan` at compile time, so they
/// equal [`ObsPlan::output_len`] and [`ObsPlan::mask_len`]. For Standard plans
/// they are per-agent sizes.
#[derive(Debug)]
pub struct ObsPlanResult {
    /// The compiled plan, ready for execution.
    pub plan: ObsPlan,
    /// Total number of f32 elements in the output tensor.
    pub output_len: usize,
    /// Shape per entry (each entry's region bounding shape dimensions), in
    /// the same order as `ObsSpec::entries`. For agent-relative entries with
    /// pooling, this is the post-pool shape.
    pub entry_shapes: Vec<Vec<usize>>,
    /// Length of the validity mask in bytes.
    pub mask_len: usize,
}
39
/// Compiled observation plan: either Simple or Standard class.
///
/// **Simple** (all `Fixed` regions): pre-computed gather indices, branch-free
/// loop, zero spatial computation at runtime. Use [`execute`](Self::execute)
/// or [`execute_batch`](Self::execute_batch).
///
/// **Standard** (any agent-relative region): template-based gather with
/// interior/boundary dispatch. Use [`execute_agents`](Self::execute_agents).
///
/// Calling the execute method of the other class returns
/// [`ObsError::ExecutionFailed`] instead of producing output.
#[derive(Debug)]
pub struct ObsPlan {
    /// Plan class plus its compiled entries.
    strategy: PlanStrategy,
    /// Total output elements across all entries (per agent for Standard).
    output_len: usize,
    /// Total mask bytes across all entries (per agent for Standard).
    mask_len: usize,
    /// Generation at compile time (for PLAN_INVALIDATED detection).
    /// `None` after `compile`; set to `Some` only by `compile_bound`.
    compiled_generation: Option<WorldGenerationId>,
}
57
/// Pre-computed gather instruction for a single cell.
///
/// At execution time, we read `field_data[field_data_idx]` and write the
/// (transformed) value to `output[tensor_idx]`, where `output` is the owning
/// entry's slice of the caller's output buffer.
///
/// The struct holds only two plain indices. It is therefore `Copy` (cheap to
/// pass and iterate by value) and `Eq`, so compiled gather lists can be
/// compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct GatherOp {
    /// Index into the flat field data array (canonical ordering).
    field_data_idx: usize,
    /// Index into the output slice for this entry.
    tensor_idx: usize,
}
69
/// A single compiled `Fixed`-region entry ready for gather execution.
///
/// Simple plans consist solely of these. Standard plans hold them as their
/// `fixed_entries`, whose output is identical for every agent.
#[derive(Debug)]
struct CompiledEntry {
    /// Field read from the snapshot for this entry.
    field_id: FieldId,
    /// Transform applied to every gathered value.
    transform: ObsTransform,
    // NOTE(review): the `allow(dead_code)` indicates `dtype` is stored at
    // compile time but not read anywhere yet. Confirm, then either consume
    // it or drop the field together with the allow.
    #[allow(dead_code)]
    dtype: ObsDtype,
    /// Offset into the output buffer where this entry starts.
    output_offset: usize,
    /// Offset into the validity mask where this entry starts.
    mask_offset: usize,
    /// Number of elements this entry contributes to the output.
    element_count: usize,
    /// Pre-computed gather operations (one per valid cell in the region).
    gather_ops: Vec<GatherOp>,
    /// Pre-computed validity mask for this entry's bounding box.
    /// Copied verbatim into the caller's mask slice on each execution, so it
    /// must contain exactly `element_count` bytes (`copy_from_slice` panics
    /// otherwise).
    valid_mask: Vec<u8>,
    /// Valid ratio for this entry's region.
    // NOTE(review): as with `dtype`, the allow suggests this is recorded but
    // never read. Confirm whether it is meant for diagnostics.
    #[allow(dead_code)]
    valid_ratio: f64,
}
91
/// Relative offset from agent center for template-based gather.
///
/// At compile time, the bounding box of an agent-centered region is
/// decomposed into `TemplateOp`s (by `generate_template_ops`, from the entry's
/// half extents). At execute time, the agent center is resolved and each op is
/// applied: `field_data[base_rank + stride_offset]`.
#[derive(Debug, Clone)]
struct TemplateOp {
    /// Offset from center per coordinate axis.
    relative: Coord,
    /// Position in the gather bounding-box tensor (row-major).
    tensor_idx: usize,
    /// Precomputed `sum(relative[i] * strides[i])` for interior fast path.
    /// Zero if no `GridGeometry` is available (fallback path only).
    stride_offset: isize,
    /// Whether this cell is within the disk region (always true for AgentRect).
    /// For AgentDisk, cells outside the graph-distance radius are excluded.
    in_disk: bool,
}
110
/// Compiled agent-relative entry for the Standard plan class.
///
/// Stores template data that is instantiated per-agent at execute time.
/// The bounding box shape comes from the region (e.g., `[2r+1, 2r+1]` for
/// `AgentDisk`/`AgentRect`), and may be reduced by pooling.
#[derive(Debug)]
struct AgentCompiledEntry {
    /// Field read from the snapshot for this entry.
    field_id: FieldId,
    /// Optional pooling that reduces the pre-pool bounding box to the
    /// `element_count`-sized output (compilation only accepts it for 2D regions).
    pool: Option<PoolConfig>,
    /// Transform applied to gathered values.
    transform: ObsTransform,
    // NOTE(review): the `allow(dead_code)` indicates `dtype` is stored but not
    // read yet. Confirm and either use it or remove it.
    #[allow(dead_code)]
    dtype: ObsDtype,
    /// Offset into the per-agent output buffer.
    output_offset: usize,
    /// Offset into the per-agent mask buffer.
    mask_offset: usize,
    /// Post-pool output elements (written to output).
    element_count: usize,
    /// Pre-pool bounding-box elements (gather buffer size).
    pre_pool_element_count: usize,
    /// Shape of the pre-pool bounding box (e.g., `[7, 7]`).
    pre_pool_shape: Vec<usize>,
    /// Template operations (one per cell in bounding box).
    template_ops: Vec<TemplateOp>,
    /// Radius for `is_interior` check.
    /// For `AgentDisk` this is the disk radius. For `AgentRect` it is the
    /// largest component of `half_extent`.
    radius: u32,
}
138
/// Data for the Standard plan class (agent-centered foveation + pooling).
///
/// Fixed and agent entries share a single offset space. Their
/// `output_offset`/`mask_offset` values locate each entry within one agent's
/// slice of the output and mask buffers, in spec order.
#[derive(Debug)]
struct StandardPlanData {
    /// Entries with `ObsRegion::Fixed` (same output for all agents).
    fixed_entries: Vec<CompiledEntry>,
    /// Entries with agent-relative regions (resolved per-agent).
    agent_entries: Vec<AgentCompiledEntry>,
    /// Grid geometry for interior/boundary dispatch (`None` → all slow path).
    geometry: Option<GridGeometry>,
}
149
/// Internal plan strategy: Simple (all-fixed) or Standard (agent-centered).
///
/// Selected by `ObsPlan::compile` depending on whether any entry uses an
/// `AgentDisk`/`AgentRect` region.
#[derive(Debug)]
enum PlanStrategy {
    /// All entries are `ObsRegion::Fixed`: pre-computed gather indices.
    Simple(Vec<CompiledEntry>),
    /// At least one entry is agent-relative: template-based gather.
    Standard(StandardPlanData),
}
158
159impl ObsPlan {
160    /// Compile an [`ObsSpec`] against a [`Space`].
161    ///
162    /// Detects whether the spec contains agent-relative regions and
163    /// dispatches to the appropriate plan class:
164    /// - All `Fixed` → **Simple** (pre-computed gather)
165    /// - Any `AgentDisk`/`AgentRect` → **Standard** (template-based)
166    pub fn compile(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
167        if spec.entries.is_empty() {
168            return Err(ObsError::InvalidObsSpec {
169                reason: "ObsSpec has no entries".into(),
170            });
171        }
172
173        // Validate transform parameters.
174        for (i, entry) in spec.entries.iter().enumerate() {
175            if let ObsTransform::Normalize { min, max } = &entry.transform {
176                if !min.is_finite() || !max.is_finite() {
177                    return Err(ObsError::InvalidObsSpec {
178                        reason: format!(
179                            "entry {i}: Normalize min/max must be finite, got min={min}, max={max}"
180                        ),
181                    });
182                }
183                if min > max {
184                    return Err(ObsError::InvalidObsSpec {
185                        reason: format!("entry {i}: Normalize min ({min}) must be <= max ({max})"),
186                    });
187                }
188            }
189        }
190
191        let has_agent = spec.entries.iter().any(|e| {
192            matches!(
193                e.region,
194                ObsRegion::AgentDisk { .. } | ObsRegion::AgentRect { .. }
195            )
196        });
197
198        if has_agent {
199            Self::compile_standard(spec, space)
200        } else {
201            Self::compile_simple(spec, space)
202        }
203    }
204
205    /// Compile a Simple plan (all `Fixed` regions, no agent-relative entries).
206    fn compile_simple(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
207        let canonical = space.canonical_ordering();
208        let coord_to_field_idx: IndexMap<Coord, usize> = canonical
209            .into_iter()
210            .enumerate()
211            .map(|(idx, coord)| (coord, idx))
212            .collect();
213
214        let mut entries = Vec::with_capacity(spec.entries.len());
215        let mut output_offset = 0usize;
216        let mut mask_offset = 0usize;
217        let mut entry_shapes = Vec::with_capacity(spec.entries.len());
218
219        for (i, entry) in spec.entries.iter().enumerate() {
220            let fixed_region = match &entry.region {
221                ObsRegion::Fixed(spec) => spec,
222                ObsRegion::AgentDisk { .. } | ObsRegion::AgentRect { .. } => {
223                    return Err(ObsError::InvalidObsSpec {
224                        reason: format!("entry {i}: agent-relative region in Simple plan"),
225                    });
226                }
227            };
228            if entry.pool.is_some() {
229                return Err(ObsError::InvalidObsSpec {
230                    reason: format!(
231                        "entry {i}: pooling requires a Standard plan (use agent-relative region)"
232                    ),
233                });
234            }
235
236            let mut region_plan =
237                space
238                    .compile_region(fixed_region)
239                    .map_err(|e| ObsError::InvalidObsSpec {
240                        reason: format!("entry {i}: region compile failed: {e}"),
241                    })?;
242
243            let ratio = region_plan.valid_ratio();
244            if ratio < COVERAGE_ERROR_THRESHOLD {
245                return Err(ObsError::InvalidComposition {
246                    reason: format!(
247                        "entry {i}: valid_ratio {ratio:.3} < {COVERAGE_ERROR_THRESHOLD}"
248                    ),
249                });
250            }
251            if ratio < COVERAGE_WARN_THRESHOLD {
252                eprintln!(
253                    "murk-obs: warning: entry {i} valid_ratio {ratio:.3} < {COVERAGE_WARN_THRESHOLD}"
254                );
255            }
256
257            let mut gather_ops = Vec::with_capacity(region_plan.coords().len());
258            for (coord_idx, coord) in region_plan.coords().iter().enumerate() {
259                let field_data_idx =
260                    *coord_to_field_idx
261                        .get(coord)
262                        .ok_or_else(|| ObsError::InvalidObsSpec {
263                            reason: format!("entry {i}: coord {coord:?} not in canonical ordering"),
264                        })?;
265                let tensor_idx = region_plan.tensor_indices()[coord_idx];
266                gather_ops.push(GatherOp {
267                    field_data_idx,
268                    tensor_idx,
269                });
270            }
271
272            let element_count = region_plan.bounding_shape().total_elements();
273            let shape = match region_plan.bounding_shape() {
274                murk_space::BoundingShape::Rect(dims) => dims.clone(),
275            };
276            entry_shapes.push(shape);
277
278            entries.push(CompiledEntry {
279                field_id: entry.field_id,
280                transform: entry.transform.clone(),
281                dtype: entry.dtype,
282                output_offset,
283                mask_offset,
284                element_count,
285                gather_ops,
286                valid_mask: region_plan.take_valid_mask(),
287                valid_ratio: ratio,
288            });
289
290            output_offset += element_count;
291            mask_offset += element_count;
292        }
293
294        let plan = ObsPlan {
295            strategy: PlanStrategy::Simple(entries),
296            output_len: output_offset,
297            mask_len: mask_offset,
298            compiled_generation: None,
299        };
300
301        Ok(ObsPlanResult {
302            output_len: plan.output_len,
303            mask_len: plan.mask_len,
304            entry_shapes,
305            plan,
306        })
307    }
308
309    /// Compile a Standard plan (has agent-relative entries).
310    ///
311    /// Fixed entries are compiled with pre-computed gather (same for all agents).
312    /// Agent entries are compiled as templates (resolved per-agent at execute time).
313    fn compile_standard(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
314        let canonical = space.canonical_ordering();
315        let coord_to_field_idx: IndexMap<Coord, usize> = canonical
316            .into_iter()
317            .enumerate()
318            .map(|(idx, coord)| (coord, idx))
319            .collect();
320
321        let geometry = GridGeometry::from_space(space);
322        let ndim = space.ndim();
323
324        let mut fixed_entries = Vec::new();
325        let mut agent_entries = Vec::new();
326        let mut output_offset = 0usize;
327        let mut mask_offset = 0usize;
328        let mut entry_shapes = Vec::new();
329
330        for (i, entry) in spec.entries.iter().enumerate() {
331            match &entry.region {
332                ObsRegion::Fixed(region_spec) => {
333                    if entry.pool.is_some() {
334                        return Err(ObsError::InvalidObsSpec {
335                            reason: format!("entry {i}: pooling on Fixed regions not supported"),
336                        });
337                    }
338
339                    let mut region_plan = space.compile_region(region_spec).map_err(|e| {
340                        ObsError::InvalidObsSpec {
341                            reason: format!("entry {i}: region compile failed: {e}"),
342                        }
343                    })?;
344
345                    let ratio = region_plan.valid_ratio();
346                    if ratio < COVERAGE_ERROR_THRESHOLD {
347                        return Err(ObsError::InvalidComposition {
348                            reason: format!(
349                                "entry {i}: valid_ratio {ratio:.3} < {COVERAGE_ERROR_THRESHOLD}"
350                            ),
351                        });
352                    }
353
354                    let mut gather_ops = Vec::with_capacity(region_plan.coords().len());
355                    for (coord_idx, coord) in region_plan.coords().iter().enumerate() {
356                        let field_data_idx = *coord_to_field_idx.get(coord).ok_or_else(|| {
357                            ObsError::InvalidObsSpec {
358                                reason: format!(
359                                    "entry {i}: coord {coord:?} not in canonical ordering"
360                                ),
361                            }
362                        })?;
363                        let tensor_idx = region_plan.tensor_indices()[coord_idx];
364                        gather_ops.push(GatherOp {
365                            field_data_idx,
366                            tensor_idx,
367                        });
368                    }
369
370                    let element_count = region_plan.bounding_shape().total_elements();
371                    let shape = match region_plan.bounding_shape() {
372                        murk_space::BoundingShape::Rect(dims) => dims.clone(),
373                    };
374                    entry_shapes.push(shape);
375
376                    fixed_entries.push(CompiledEntry {
377                        field_id: entry.field_id,
378                        transform: entry.transform.clone(),
379                        dtype: entry.dtype,
380                        output_offset,
381                        mask_offset,
382                        element_count,
383                        gather_ops,
384                        valid_mask: region_plan.take_valid_mask(),
385                        valid_ratio: ratio,
386                    });
387
388                    output_offset += element_count;
389                    mask_offset += element_count;
390                }
391
392                ObsRegion::AgentDisk { radius } => {
393                    let half_ext: smallvec::SmallVec<[u32; 4]> =
394                        (0..ndim).map(|_| *radius).collect();
395                    let (ae, shape) = Self::compile_agent_entry(
396                        i,
397                        entry,
398                        &half_ext,
399                        *radius,
400                        &geometry,
401                        Some(*radius),
402                        output_offset,
403                        mask_offset,
404                    )?;
405                    entry_shapes.push(shape);
406                    output_offset += ae.element_count;
407                    mask_offset += ae.element_count;
408                    agent_entries.push(ae);
409                }
410
411                ObsRegion::AgentRect { half_extent } => {
412                    let radius = *half_extent.iter().max().unwrap_or(&0);
413                    let (ae, shape) = Self::compile_agent_entry(
414                        i,
415                        entry,
416                        half_extent,
417                        radius,
418                        &geometry,
419                        None,
420                        output_offset,
421                        mask_offset,
422                    )?;
423                    entry_shapes.push(shape);
424                    output_offset += ae.element_count;
425                    mask_offset += ae.element_count;
426                    agent_entries.push(ae);
427                }
428            }
429        }
430
431        let plan = ObsPlan {
432            strategy: PlanStrategy::Standard(StandardPlanData {
433                fixed_entries,
434                agent_entries,
435                geometry,
436            }),
437            output_len: output_offset,
438            mask_len: mask_offset,
439            compiled_generation: None,
440        };
441
442        Ok(ObsPlanResult {
443            output_len: plan.output_len,
444            mask_len: plan.mask_len,
445            entry_shapes,
446            plan,
447        })
448    }
449
450    /// Compile a single agent-relative entry into a template.
451    ///
452    /// `disk_radius`: if `Some(r)`, template ops outside graph-distance `r`
453    /// are marked `in_disk = false` (for `AgentDisk`). `None` for `AgentRect`.
454    #[allow(clippy::too_many_arguments)]
455    fn compile_agent_entry(
456        entry_idx: usize,
457        entry: &crate::spec::ObsEntry,
458        half_extent: &[u32],
459        radius: u32,
460        geometry: &Option<GridGeometry>,
461        disk_radius: Option<u32>,
462        output_offset: usize,
463        mask_offset: usize,
464    ) -> Result<(AgentCompiledEntry, Vec<usize>), ObsError> {
465        let pre_pool_shape: Vec<usize> =
466            half_extent.iter().map(|&he| 2 * he as usize + 1).collect();
467        let pre_pool_element_count: usize = pre_pool_shape.iter().product();
468
469        let template_ops = generate_template_ops(half_extent, geometry, disk_radius);
470
471        let (element_count, output_shape) = if let Some(pool) = &entry.pool {
472            if pre_pool_shape.len() != 2 {
473                return Err(ObsError::InvalidObsSpec {
474                    reason: format!(
475                        "entry {entry_idx}: pooling requires 2D region, got {}D",
476                        pre_pool_shape.len()
477                    ),
478                });
479            }
480            let h = pre_pool_shape[0];
481            let w = pre_pool_shape[1];
482            let ks = pool.kernel_size;
483            let stride = pool.stride;
484            if ks == 0 || stride == 0 {
485                return Err(ObsError::InvalidObsSpec {
486                    reason: format!("entry {entry_idx}: pool kernel_size and stride must be > 0"),
487                });
488            }
489            let out_h = if h >= ks { (h - ks) / stride + 1 } else { 0 };
490            let out_w = if w >= ks { (w - ks) / stride + 1 } else { 0 };
491            if out_h == 0 || out_w == 0 {
492                return Err(ObsError::InvalidObsSpec {
493                    reason: format!(
494                        "entry {entry_idx}: pool produces empty output \
495                         (region [{h},{w}], kernel_size {ks}, stride {stride})"
496                    ),
497                });
498            }
499            (out_h * out_w, vec![out_h, out_w])
500        } else {
501            (pre_pool_element_count, pre_pool_shape.clone())
502        };
503
504        Ok((
505            AgentCompiledEntry {
506                field_id: entry.field_id,
507                pool: entry.pool.clone(),
508                transform: entry.transform.clone(),
509                dtype: entry.dtype,
510                output_offset,
511                mask_offset,
512                element_count,
513                pre_pool_element_count,
514                pre_pool_shape,
515                template_ops,
516                radius,
517            },
518            output_shape,
519        ))
520    }
521
522    /// Compile with generation binding for PLAN_INVALIDATED detection.
523    ///
524    /// Same as [`compile`](Self::compile) but records the snapshot's
525    /// `world_generation_id` for later validation in [`ObsPlan::execute`].
526    pub fn compile_bound(
527        spec: &ObsSpec,
528        space: &dyn Space,
529        generation: WorldGenerationId,
530    ) -> Result<ObsPlanResult, ObsError> {
531        let mut result = Self::compile(spec, space)?;
532        result.plan.compiled_generation = Some(generation);
533        Ok(result)
534    }
535
536    /// Total number of f32 elements in the output tensor.
537    pub fn output_len(&self) -> usize {
538        self.output_len
539    }
540
541    /// Total number of bytes in the validity mask.
542    pub fn mask_len(&self) -> usize {
543        self.mask_len
544    }
545
546    /// The generation this plan was compiled against, if bound.
547    pub fn compiled_generation(&self) -> Option<WorldGenerationId> {
548        self.compiled_generation
549    }
550
551    /// Execute the observation plan against a snapshot.
552    ///
553    /// Fills `output` with gathered and transformed field values, and
554    /// `mask` with validity flags (1 = valid, 0 = padding). Both
555    /// buffers must be pre-allocated to [`output_len`](Self::output_len)
556    /// and [`mask_len`](Self::mask_len) respectively.
557    ///
558    /// `engine_tick` is the current engine tick for computing
559    /// [`ObsMetadata::age_ticks`]. Pass `None` in Lockstep mode
560    /// (age is always 0). In RealtimeAsync mode, pass the current
561    /// engine tick so age reflects snapshot staleness.
562    ///
563    /// Returns [`ObsMetadata`] on success.
564    ///
565    /// # Errors
566    ///
567    /// - [`ObsError::PlanInvalidated`] if bound and generation mismatches.
568    /// - [`ObsError::ExecutionFailed`] if a field is missing from the snapshot.
569    pub fn execute(
570        &self,
571        snapshot: &dyn SnapshotAccess,
572        engine_tick: Option<TickId>,
573        output: &mut [f32],
574        mask: &mut [u8],
575    ) -> Result<ObsMetadata, ObsError> {
576        let entries = match &self.strategy {
577            PlanStrategy::Simple(entries) => entries,
578            PlanStrategy::Standard(_) => {
579                return Err(ObsError::ExecutionFailed {
580                    reason: "Standard plan requires execute_agents(), not execute()".into(),
581                });
582            }
583        };
584
585        if output.len() < self.output_len {
586            return Err(ObsError::ExecutionFailed {
587                reason: format!(
588                    "output buffer too small: {} < {}",
589                    output.len(),
590                    self.output_len
591                ),
592            });
593        }
594        if mask.len() < self.mask_len {
595            return Err(ObsError::ExecutionFailed {
596                reason: format!("mask buffer too small: {} < {}", mask.len(), self.mask_len),
597            });
598        }
599
600        // Generation check (PLAN_INVALIDATED).
601        if let Some(compiled_gen) = self.compiled_generation {
602            let snapshot_gen = snapshot.world_generation_id();
603            if compiled_gen != snapshot_gen {
604                return Err(ObsError::PlanInvalidated {
605                    reason: format!(
606                        "plan compiled for generation {}, snapshot is generation {}",
607                        compiled_gen.0, snapshot_gen.0
608                    ),
609                });
610            }
611        }
612
613        let mut total_valid = 0usize;
614        let mut total_elements = 0usize;
615
616        for entry in entries {
617            let field_data =
618                snapshot
619                    .read_field(entry.field_id)
620                    .ok_or_else(|| ObsError::ExecutionFailed {
621                        reason: format!("field {:?} not in snapshot", entry.field_id),
622                    })?;
623
624            let out_slice =
625                &mut output[entry.output_offset..entry.output_offset + entry.element_count];
626            let mask_slice = &mut mask[entry.mask_offset..entry.mask_offset + entry.element_count];
627
628            // Initialize to zero/padding.
629            out_slice.fill(0.0);
630            mask_slice.copy_from_slice(&entry.valid_mask);
631
632            // Branch-free gather: pre-computed (field_data_idx, tensor_idx) pairs.
633            for op in &entry.gather_ops {
634                let raw = *field_data.get(op.field_data_idx).ok_or_else(|| {
635                    ObsError::ExecutionFailed {
636                        reason: format!(
637                            "field {:?} has {} elements but gather requires index {}",
638                            entry.field_id,
639                            field_data.len(),
640                            op.field_data_idx,
641                        ),
642                    }
643                })?;
644                out_slice[op.tensor_idx] = apply_transform(raw, &entry.transform);
645            }
646
647            total_valid += entry.valid_mask.iter().filter(|&&v| v == 1).count();
648            total_elements += entry.element_count;
649        }
650
651        let coverage = if total_elements == 0 {
652            0.0
653        } else {
654            total_valid as f64 / total_elements as f64
655        };
656
657        let age_ticks = match engine_tick {
658            Some(tick) => tick.0.saturating_sub(snapshot.tick_id().0),
659            None => 0,
660        };
661
662        Ok(ObsMetadata {
663            tick_id: snapshot.tick_id(),
664            age_ticks,
665            coverage,
666            world_generation_id: snapshot.world_generation_id(),
667            parameter_version: snapshot.parameter_version(),
668        })
669    }
670
671    /// Execute the plan for a batch of `N` identical environments.
672    ///
673    /// Each snapshot in the batch fills `output_len()` elements in the
674    /// output buffer, starting at `batch_idx * output_len()`. Same for
675    /// masks. This is the primary interface for vectorized RL training.
676    ///
677    /// Returns one [`ObsMetadata`] per snapshot.
678    pub fn execute_batch(
679        &self,
680        snapshots: &[&dyn SnapshotAccess],
681        engine_tick: Option<TickId>,
682        output: &mut [f32],
683        mask: &mut [u8],
684    ) -> Result<Vec<ObsMetadata>, ObsError> {
685        // execute_batch only works with Simple plans.
686        if matches!(self.strategy, PlanStrategy::Standard(_)) {
687            return Err(ObsError::ExecutionFailed {
688                reason: "Standard plan requires execute_agents(), not execute_batch()".into(),
689            });
690        }
691
692        let batch_size = snapshots.len();
693        let expected_out = batch_size * self.output_len;
694        let expected_mask = batch_size * self.mask_len;
695
696        if output.len() < expected_out {
697            return Err(ObsError::ExecutionFailed {
698                reason: format!(
699                    "batch output buffer too small: {} < {}",
700                    output.len(),
701                    expected_out
702                ),
703            });
704        }
705        if mask.len() < expected_mask {
706            return Err(ObsError::ExecutionFailed {
707                reason: format!(
708                    "batch mask buffer too small: {} < {}",
709                    mask.len(),
710                    expected_mask
711                ),
712            });
713        }
714
715        let mut metadata = Vec::with_capacity(batch_size);
716        for (i, snap) in snapshots.iter().enumerate() {
717            let out_start = i * self.output_len;
718            let mask_start = i * self.mask_len;
719            let out_slice = &mut output[out_start..out_start + self.output_len];
720            let mask_slice = &mut mask[mask_start..mask_start + self.mask_len];
721            let meta = self.execute(*snap, engine_tick, out_slice, mask_slice)?;
722            metadata.push(meta);
723        }
724        Ok(metadata)
725    }
726
727    /// Execute the Standard plan for `N` agents in one environment.
728    ///
729    /// Each agent gets `output_len()` elements starting at
730    /// `agent_idx * output_len()`. Fixed entries produce the same
731    /// output for all agents; agent-relative entries are resolved
732    /// per-agent using interior/boundary dispatch.
733    ///
734    /// Interior agents (~49% for 20×20 grid, radius 3) use a branchless
735    /// fast path with stride arithmetic. Boundary agents fall back to
736    /// per-cell bounds checking.
737    pub fn execute_agents(
738        &self,
739        snapshot: &dyn SnapshotAccess,
740        space: &dyn Space,
741        agent_centers: &[Coord],
742        engine_tick: Option<TickId>,
743        output: &mut [f32],
744        mask: &mut [u8],
745    ) -> Result<Vec<ObsMetadata>, ObsError> {
746        let standard = match &self.strategy {
747            PlanStrategy::Standard(data) => data,
748            PlanStrategy::Simple(_) => {
749                return Err(ObsError::ExecutionFailed {
750                    reason: "execute_agents requires a Standard plan \
751                             (spec must contain agent-relative entries)"
752                        .into(),
753                });
754            }
755        };
756
757        let n_agents = agent_centers.len();
758        let expected_out = n_agents * self.output_len;
759        let expected_mask = n_agents * self.mask_len;
760
761        if output.len() < expected_out {
762            return Err(ObsError::ExecutionFailed {
763                reason: format!(
764                    "output buffer too small: {} < {}",
765                    output.len(),
766                    expected_out
767                ),
768            });
769        }
770        if mask.len() < expected_mask {
771            return Err(ObsError::ExecutionFailed {
772                reason: format!("mask buffer too small: {} < {}", mask.len(), expected_mask),
773            });
774        }
775
776        // Validate agent center dimensionality.
777        let expected_ndim = space.ndim();
778        for (i, center) in agent_centers.iter().enumerate() {
779            if center.len() != expected_ndim {
780                return Err(ObsError::ExecutionFailed {
781                    reason: format!(
782                        "agent_centers[{i}] has {} dimensions, but space requires {expected_ndim}",
783                        center.len()
784                    ),
785                });
786            }
787        }
788
789        // Generation check.
790        if let Some(compiled_gen) = self.compiled_generation {
791            let snapshot_gen = snapshot.world_generation_id();
792            if compiled_gen != snapshot_gen {
793                return Err(ObsError::PlanInvalidated {
794                    reason: format!(
795                        "plan compiled for generation {}, snapshot is generation {}",
796                        compiled_gen.0, snapshot_gen.0
797                    ),
798                });
799            }
800        }
801
802        // Pre-read all field data (shared borrows, valid for duration).
803        let mut field_data_map: IndexMap<FieldId, &[f32]> = IndexMap::new();
804        for entry in &standard.fixed_entries {
805            if !field_data_map.contains_key(&entry.field_id) {
806                let data = snapshot.read_field(entry.field_id).ok_or_else(|| {
807                    ObsError::ExecutionFailed {
808                        reason: format!("field {:?} not in snapshot", entry.field_id),
809                    }
810                })?;
811                field_data_map.insert(entry.field_id, data);
812            }
813        }
814        for entry in &standard.agent_entries {
815            if !field_data_map.contains_key(&entry.field_id) {
816                let data = snapshot.read_field(entry.field_id).ok_or_else(|| {
817                    ObsError::ExecutionFailed {
818                        reason: format!("field {:?} not in snapshot", entry.field_id),
819                    }
820                })?;
821                field_data_map.insert(entry.field_id, data);
822            }
823        }
824
825        // ── Compute fixed entries ONCE (identical for all agents) ──
826        // Allocate scratch buffers for the fixed-entry output and mask, gather
827        // into them, then memcpy per agent. This avoids redundant N*M gathers.
828        let mut fixed_out_scratch = vec![0.0f32; self.output_len];
829        let mut fixed_mask_scratch = vec![0u8; self.mask_len];
830        let mut fixed_valid = 0usize;
831        let mut fixed_elements = 0usize;
832
833        for entry in &standard.fixed_entries {
834            let field_data = field_data_map[&entry.field_id];
835            let out_slice = &mut fixed_out_scratch
836                [entry.output_offset..entry.output_offset + entry.element_count];
837            let mask_slice =
838                &mut fixed_mask_scratch[entry.mask_offset..entry.mask_offset + entry.element_count];
839
840            mask_slice.copy_from_slice(&entry.valid_mask);
841            for op in &entry.gather_ops {
842                let raw = *field_data.get(op.field_data_idx).ok_or_else(|| {
843                    ObsError::ExecutionFailed {
844                        reason: format!(
845                            "field {:?} has {} elements but gather requires index {}",
846                            entry.field_id,
847                            field_data.len(),
848                            op.field_data_idx,
849                        ),
850                    }
851                })?;
852                out_slice[op.tensor_idx] = apply_transform(raw, &entry.transform);
853            }
854
855            fixed_valid += entry.valid_mask.iter().filter(|&&v| v == 1).count();
856            fixed_elements += entry.element_count;
857        }
858
859        // ── Pre-allocate pooling scratch buffers ──────────────────
860        // Find the maximum pre_pool_element_count across all pooled entries
861        // so we can allocate once and reuse across agents (sequential processing).
862        let max_pool_scratch = standard
863            .agent_entries
864            .iter()
865            .filter(|e| e.pool.is_some())
866            .map(|e| e.pre_pool_element_count)
867            .max()
868            .unwrap_or(0);
869        let mut pool_scratch = vec![0.0f32; max_pool_scratch];
870        let mut pool_scratch_mask = vec![0u8; max_pool_scratch];
871
872        let mut metadata = Vec::with_capacity(n_agents);
873
874        for (agent_i, center) in agent_centers.iter().enumerate() {
875            let out_start = agent_i * self.output_len;
876            let mask_start = agent_i * self.mask_len;
877            let agent_output = &mut output[out_start..out_start + self.output_len];
878            let agent_mask = &mut mask[mask_start..mask_start + self.mask_len];
879
880            // Stamp fixed entries from pre-computed scratch (memcpy, not re-gather).
881            agent_output[..self.output_len].copy_from_slice(&fixed_out_scratch);
882            agent_mask[..self.mask_len].copy_from_slice(&fixed_mask_scratch);
883
884            let mut total_valid = fixed_valid;
885            let mut total_elements = fixed_elements;
886
887            // ── Agent-relative entries ───────────────────────────
888            for entry in &standard.agent_entries {
889                let field_data = field_data_map[&entry.field_id];
890
891                // Fast path: stride arithmetic works only for non-wrapping
892                // grids where all cells in the bounding box are in-bounds.
893                // Torus (all_wrap) requires modular arithmetic → slow path.
894                let use_fast_path = standard
895                    .geometry
896                    .as_ref()
897                    .map(|geo| !geo.all_wrap && geo.is_interior(center, entry.radius))
898                    .unwrap_or(false);
899
900                // Zero the pooling scratch region for this entry before reuse.
901                if entry.pool.is_some() {
902                    pool_scratch[..entry.pre_pool_element_count].fill(0.0);
903                    pool_scratch_mask[..entry.pre_pool_element_count].fill(0);
904                }
905
906                let valid = execute_agent_entry(
907                    entry,
908                    center,
909                    field_data,
910                    &standard.geometry,
911                    space,
912                    use_fast_path,
913                    agent_output,
914                    agent_mask,
915                    &mut pool_scratch,
916                    &mut pool_scratch_mask,
917                );
918
919                total_valid += valid;
920                total_elements += entry.element_count;
921            }
922
923            let coverage = if total_elements == 0 {
924                0.0
925            } else {
926                total_valid as f64 / total_elements as f64
927            };
928
929            let age_ticks = match engine_tick {
930                Some(tick) => tick.0.saturating_sub(snapshot.tick_id().0),
931                None => 0,
932            };
933
934            metadata.push(ObsMetadata {
935                tick_id: snapshot.tick_id(),
936                age_ticks,
937                coverage,
938                world_generation_id: snapshot.world_generation_id(),
939                parameter_version: snapshot.parameter_version(),
940            });
941        }
942
943        Ok(metadata)
944    }
945
946    /// Whether this plan requires `execute_agents` (Standard) or `execute` (Simple).
947    pub fn is_standard(&self) -> bool {
948        matches!(self.strategy, PlanStrategy::Standard(_))
949    }
950}
951
952/// Execute a single agent-relative entry for one agent.
953///
954/// For pooled entries, `pool_scratch` and `pool_scratch_mask` must be
955/// provided with sufficient capacity (zeroed by the caller). For
956/// non-pooled entries these are ignored.
957///
958/// Returns the number of valid cells written.
959#[allow(clippy::too_many_arguments)]
960fn execute_agent_entry(
961    entry: &AgentCompiledEntry,
962    center: &Coord,
963    field_data: &[f32],
964    geometry: &Option<GridGeometry>,
965    space: &dyn Space,
966    use_fast_path: bool,
967    agent_output: &mut [f32],
968    agent_mask: &mut [u8],
969    pool_scratch: &mut [f32],
970    pool_scratch_mask: &mut [u8],
971) -> usize {
972    if entry.pool.is_some() {
973        execute_agent_entry_pooled(
974            entry,
975            center,
976            field_data,
977            geometry,
978            space,
979            use_fast_path,
980            agent_output,
981            agent_mask,
982            &mut pool_scratch[..entry.pre_pool_element_count],
983            &mut pool_scratch_mask[..entry.pre_pool_element_count],
984        )
985    } else {
986        execute_agent_entry_direct(
987            entry,
988            center,
989            field_data,
990            geometry,
991            space,
992            use_fast_path,
993            agent_output,
994            agent_mask,
995        )
996    }
997}
998
999/// Direct gather (no pooling): gather + transform → output.
1000#[allow(clippy::too_many_arguments)]
1001fn execute_agent_entry_direct(
1002    entry: &AgentCompiledEntry,
1003    center: &Coord,
1004    field_data: &[f32],
1005    geometry: &Option<GridGeometry>,
1006    space: &dyn Space,
1007    use_fast_path: bool,
1008    agent_output: &mut [f32],
1009    agent_mask: &mut [u8],
1010) -> usize {
1011    let out_slice =
1012        &mut agent_output[entry.output_offset..entry.output_offset + entry.element_count];
1013    let mask_slice = &mut agent_mask[entry.mask_offset..entry.mask_offset + entry.element_count];
1014
1015    if use_fast_path {
1016        // FAST PATH: all cells in-bounds, branchless stride arithmetic.
1017        let geo = geometry.as_ref().unwrap();
1018        let base_rank = geo.canonical_rank(center) as isize;
1019        let mut valid = 0;
1020        for op in &entry.template_ops {
1021            if !op.in_disk {
1022                continue;
1023            }
1024            let field_idx = (base_rank + op.stride_offset) as usize;
1025            if let Some(&val) = field_data.get(field_idx) {
1026                out_slice[op.tensor_idx] = apply_transform(val, &entry.transform);
1027                mask_slice[op.tensor_idx] = 1;
1028                valid += 1;
1029            }
1030        }
1031        valid
1032    } else {
1033        // SLOW PATH: bounds-check each offset (or modular wrap for torus).
1034        let mut valid = 0;
1035        for op in &entry.template_ops {
1036            if !op.in_disk {
1037                continue;
1038            }
1039            let field_idx = resolve_field_index(center, &op.relative, geometry, space);
1040            if let Some(idx) = field_idx {
1041                if idx < field_data.len() {
1042                    out_slice[op.tensor_idx] = apply_transform(field_data[idx], &entry.transform);
1043                    mask_slice[op.tensor_idx] = 1;
1044                    valid += 1;
1045                }
1046            }
1047        }
1048        valid
1049    }
1050}
1051
1052/// Pooled gather: gather → scratch → pool → transform → output.
1053///
1054/// `scratch` and `scratch_mask` are caller-provided buffers that must be
1055/// at least `entry.pre_pool_element_count` long. They are zeroed by the
1056/// caller before each invocation. This avoids per-agent heap allocation
1057/// when processing many agents sequentially (bug #83).
1058#[allow(clippy::too_many_arguments)]
1059fn execute_agent_entry_pooled(
1060    entry: &AgentCompiledEntry,
1061    center: &Coord,
1062    field_data: &[f32],
1063    geometry: &Option<GridGeometry>,
1064    space: &dyn Space,
1065    use_fast_path: bool,
1066    agent_output: &mut [f32],
1067    agent_mask: &mut [u8],
1068    scratch: &mut [f32],
1069    scratch_mask: &mut [u8],
1070) -> usize {
1071    if use_fast_path {
1072        let geo = geometry.as_ref().unwrap();
1073        let base_rank = geo.canonical_rank(center) as isize;
1074        for op in &entry.template_ops {
1075            if !op.in_disk {
1076                continue;
1077            }
1078            let field_idx = (base_rank + op.stride_offset) as usize;
1079            if let Some(&val) = field_data.get(field_idx) {
1080                scratch[op.tensor_idx] = val;
1081                scratch_mask[op.tensor_idx] = 1;
1082            }
1083        }
1084    } else {
1085        for op in &entry.template_ops {
1086            if !op.in_disk {
1087                continue;
1088            }
1089            let field_idx = resolve_field_index(center, &op.relative, geometry, space);
1090            if let Some(idx) = field_idx {
1091                if idx < field_data.len() {
1092                    scratch[op.tensor_idx] = field_data[idx];
1093                    scratch_mask[op.tensor_idx] = 1;
1094                }
1095            }
1096        }
1097    }
1098
1099    let pool_config = entry.pool.as_ref().unwrap();
1100    let (pooled, pooled_mask, _) =
1101        pool_2d(scratch, scratch_mask, &entry.pre_pool_shape, pool_config);
1102
1103    let out_slice =
1104        &mut agent_output[entry.output_offset..entry.output_offset + entry.element_count];
1105    let mask_slice = &mut agent_mask[entry.mask_offset..entry.mask_offset + entry.element_count];
1106
1107    let n = pooled.len().min(entry.element_count);
1108    for i in 0..n {
1109        out_slice[i] = apply_transform(pooled[i], &entry.transform);
1110    }
1111    mask_slice[..n].copy_from_slice(&pooled_mask[..n]);
1112
1113    pooled_mask[..n].iter().filter(|&&v| v == 1).count()
1114}
1115
1116/// Generate template operations for a rectangular bounding box.
1117///
1118/// `half_extent[d]` is the half-size per dimension. The bounding box is
1119/// `(2*he[0]+1) × (2*he[1]+1) × ...` in row-major order.
1120///
1121/// If `strides` is provided (from `GridGeometry`), each op gets a precomputed
1122/// `stride_offset` for the interior fast path.
1123///
1124/// If `disk_radius` is `Some(r)`, cells with graph distance > `r` are marked
1125/// `in_disk = false`. The `geometry` is required to compute graph distance.
1126/// When `geometry` is `None`, all cells are treated as in-disk (conservative).
1127fn generate_template_ops(
1128    half_extent: &[u32],
1129    geometry: &Option<GridGeometry>,
1130    disk_radius: Option<u32>,
1131) -> Vec<TemplateOp> {
1132    let ndim = half_extent.len();
1133    let shape: Vec<usize> = half_extent.iter().map(|&he| 2 * he as usize + 1).collect();
1134    let total: usize = shape.iter().product();
1135
1136    let strides = geometry.as_ref().map(|g| g.coord_strides.as_slice());
1137
1138    let mut ops = Vec::with_capacity(total);
1139
1140    for tensor_idx in 0..total {
1141        // Decompose tensor_idx into n-d relative coords (row-major).
1142        let mut relative = Coord::new();
1143        let mut remaining = tensor_idx;
1144        // Build in reverse order, then reverse.
1145        for d in (0..ndim).rev() {
1146            let coord = (remaining % shape[d]) as i32 - half_extent[d] as i32;
1147            relative.push(coord);
1148            remaining /= shape[d];
1149        }
1150        relative.reverse();
1151
1152        let stride_offset = strides
1153            .map(|s| {
1154                relative
1155                    .iter()
1156                    .zip(s.iter())
1157                    .map(|(&r, &s)| r as isize * s as isize)
1158                    .sum::<isize>()
1159            })
1160            .unwrap_or(0);
1161
1162        let in_disk = match disk_radius {
1163            Some(r) => match geometry {
1164                Some(geo) => geo.graph_distance(&relative) <= r,
1165                None => true, // no geometry → conservative (include all)
1166            },
1167            None => true, // AgentRect → all cells valid
1168        };
1169
1170        ops.push(TemplateOp {
1171            relative,
1172            tensor_idx,
1173            stride_offset,
1174            in_disk,
1175        });
1176    }
1177
1178    ops
1179}
1180
1181/// Resolve the field data index for an absolute coordinate.
1182///
1183/// Handles three cases:
1184/// 1. Torus (all_wrap): modular wrap, always in-bounds.
1185/// 2. Grid with geometry: bounds-check then stride arithmetic.
1186/// 3. No geometry: fall back to `space.canonical_rank()`.
1187fn resolve_field_index(
1188    center: &Coord,
1189    relative: &Coord,
1190    geometry: &Option<GridGeometry>,
1191    space: &dyn Space,
1192) -> Option<usize> {
1193    if let Some(geo) = geometry {
1194        if geo.all_wrap {
1195            // Torus: wrap coordinates with modular arithmetic.
1196            let wrapped: Coord = center
1197                .iter()
1198                .zip(relative.iter())
1199                .zip(geo.coord_dims.iter())
1200                .map(|((&c, &r), &d)| {
1201                    let d = d as i32;
1202                    ((c + r) % d + d) % d
1203                })
1204                .collect();
1205            Some(geo.canonical_rank(&wrapped))
1206        } else {
1207            let abs_coord: Coord = center
1208                .iter()
1209                .zip(relative.iter())
1210                .map(|(&c, &r)| c + r)
1211                .collect();
1212            let abs_slice: &[i32] = &abs_coord;
1213            if geo.in_bounds(abs_slice) {
1214                Some(geo.canonical_rank(abs_slice))
1215            } else {
1216                None
1217            }
1218        }
1219    } else {
1220        let abs_coord: Coord = center
1221            .iter()
1222            .zip(relative.iter())
1223            .map(|(&c, &r)| c + r)
1224            .collect();
1225        space.canonical_rank(&abs_coord)
1226    }
1227}
1228
1229/// Apply a transform to a raw field value.
1230fn apply_transform(raw: f32, transform: &ObsTransform) -> f32 {
1231    match transform {
1232        ObsTransform::Identity => raw,
1233        ObsTransform::Normalize { min, max } => {
1234            let range = max - min;
1235            if range == 0.0 {
1236                0.0
1237            } else {
1238                let normalized = (raw as f64 - min) / range;
1239                normalized.clamp(0.0, 1.0) as f32
1240            }
1241        }
1242    }
1243}
1244
1245#[cfg(test)]
1246mod tests {
1247    use super::*;
1248    use crate::spec::{
1249        ObsDtype, ObsEntry, ObsRegion, ObsSpec, ObsTransform, PoolConfig, PoolKernel,
1250    };
1251    use murk_core::{FieldId, ParameterVersion, TickId, WorldGenerationId};
1252    use murk_space::{EdgeBehavior, Hex2D, RegionSpec, Square4, Square8};
1253    use murk_test_utils::MockSnapshot;
1254
1255    fn square4_space() -> Square4 {
1256        Square4::new(3, 3, EdgeBehavior::Absorb).unwrap()
1257    }
1258
1259    fn snapshot_with_field(field: FieldId, data: Vec<f32>) -> MockSnapshot {
1260        let mut snap = MockSnapshot::new(TickId(5), WorldGenerationId(1), ParameterVersion(0));
1261        snap.set_field(field, data);
1262        snap
1263    }
1264
1265    // ── Compilation tests ────────────────────────────────────
1266
1267    #[test]
1268    fn compile_empty_spec_errors() {
1269        let space = square4_space();
1270        let spec = ObsSpec { entries: vec![] };
1271        let err = ObsPlan::compile(&spec, &space).unwrap_err();
1272        assert!(matches!(err, ObsError::InvalidObsSpec { .. }));
1273    }
1274
1275    #[test]
1276    fn compile_all_region_square4() {
1277        let space = square4_space();
1278        let spec = ObsSpec {
1279            entries: vec![ObsEntry {
1280                field_id: FieldId(0),
1281                region: ObsRegion::Fixed(RegionSpec::All),
1282                pool: None,
1283                transform: ObsTransform::Identity,
1284                dtype: ObsDtype::F32,
1285            }],
1286        };
1287        let result = ObsPlan::compile(&spec, &space).unwrap();
1288        assert_eq!(result.output_len, 9); // 3x3
1289        assert_eq!(result.mask_len, 9);
1290        assert_eq!(result.entry_shapes, vec![vec![3, 3]]);
1291    }
1292
1293    #[test]
1294    fn compile_rect_region() {
1295        let space = Square4::new(5, 5, EdgeBehavior::Absorb).unwrap();
1296        let spec = ObsSpec {
1297            entries: vec![ObsEntry {
1298                field_id: FieldId(0),
1299                region: ObsRegion::Fixed(RegionSpec::Rect {
1300                    min: smallvec::smallvec![1, 1],
1301                    max: smallvec::smallvec![2, 3],
1302                }),
1303                pool: None,
1304                transform: ObsTransform::Identity,
1305                dtype: ObsDtype::F32,
1306            }],
1307        };
1308        let result = ObsPlan::compile(&spec, &space).unwrap();
1309        // 2 rows x 3 cols = 6 cells
1310        assert_eq!(result.output_len, 6);
1311        assert_eq!(result.entry_shapes, vec![vec![2, 3]]);
1312    }
1313
1314    #[test]
1315    fn compile_two_entries_offsets() {
1316        let space = square4_space();
1317        let spec = ObsSpec {
1318            entries: vec![
1319                ObsEntry {
1320                    field_id: FieldId(0),
1321                    region: ObsRegion::Fixed(RegionSpec::All),
1322                    pool: None,
1323                    transform: ObsTransform::Identity,
1324                    dtype: ObsDtype::F32,
1325                },
1326                ObsEntry {
1327                    field_id: FieldId(1),
1328                    region: ObsRegion::Fixed(RegionSpec::All),
1329                    pool: None,
1330                    transform: ObsTransform::Identity,
1331                    dtype: ObsDtype::F32,
1332                },
1333            ],
1334        };
1335        let result = ObsPlan::compile(&spec, &space).unwrap();
1336        assert_eq!(result.output_len, 18); // 9 + 9
1337        assert_eq!(result.mask_len, 18);
1338    }
1339
1340    #[test]
1341    fn compile_invalid_region_errors() {
1342        let space = square4_space();
1343        let spec = ObsSpec {
1344            entries: vec![ObsEntry {
1345                field_id: FieldId(0),
1346                region: ObsRegion::Fixed(RegionSpec::Coords(vec![smallvec::smallvec![99, 99]])),
1347                pool: None,
1348                transform: ObsTransform::Identity,
1349                dtype: ObsDtype::F32,
1350            }],
1351        };
1352        let err = ObsPlan::compile(&spec, &space).unwrap_err();
1353        assert!(matches!(err, ObsError::InvalidObsSpec { .. }));
1354    }
1355
1356    // ── Execution tests ──────────────────────────────────────
1357
1358    #[test]
1359    fn execute_identity_all_region() {
1360        let space = square4_space();
1361        // Field data in canonical (row-major) order for 3x3:
1362        // (0,0)=1, (0,1)=2, (0,2)=3, (1,0)=4, ..., (2,2)=9
1363        let data: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1364        let snap = snapshot_with_field(FieldId(0), data);
1365
1366        let spec = ObsSpec {
1367            entries: vec![ObsEntry {
1368                field_id: FieldId(0),
1369                region: ObsRegion::Fixed(RegionSpec::All),
1370                pool: None,
1371                transform: ObsTransform::Identity,
1372                dtype: ObsDtype::F32,
1373            }],
1374        };
1375        let result = ObsPlan::compile(&spec, &space).unwrap();
1376
1377        let mut output = vec![0.0f32; result.output_len];
1378        let mut mask = vec![0u8; result.mask_len];
1379        let meta = result
1380            .plan
1381            .execute(&snap, None, &mut output, &mut mask)
1382            .unwrap();
1383
1384        // Output should match field data in canonical order.
1385        let expected: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1386        assert_eq!(output, expected);
1387        assert_eq!(mask, vec![1u8; 9]);
1388        assert_eq!(meta.tick_id, TickId(5));
1389        assert_eq!(meta.coverage, 1.0);
1390        assert_eq!(meta.world_generation_id, WorldGenerationId(1));
1391        assert_eq!(meta.parameter_version, ParameterVersion(0));
1392        assert_eq!(meta.age_ticks, 0);
1393    }
1394
1395    #[test]
1396    fn execute_normalize_transform() {
1397        let space = square4_space();
1398        // Values 0..8 mapped to [0,1] with min=0, max=8.
1399        let data: Vec<f32> = (0..9).map(|x| x as f32).collect();
1400        let snap = snapshot_with_field(FieldId(0), data);
1401
1402        let spec = ObsSpec {
1403            entries: vec![ObsEntry {
1404                field_id: FieldId(0),
1405                region: ObsRegion::Fixed(RegionSpec::All),
1406                pool: None,
1407                transform: ObsTransform::Normalize { min: 0.0, max: 8.0 },
1408                dtype: ObsDtype::F32,
1409            }],
1410        };
1411        let result = ObsPlan::compile(&spec, &space).unwrap();
1412
1413        let mut output = vec![0.0f32; result.output_len];
1414        let mut mask = vec![0u8; result.mask_len];
1415        result
1416            .plan
1417            .execute(&snap, None, &mut output, &mut mask)
1418            .unwrap();
1419
1420        // Each value x should be x/8.
1421        for (i, &v) in output.iter().enumerate() {
1422            let expected = i as f32 / 8.0;
1423            assert!(
1424                (v - expected).abs() < 1e-6,
1425                "output[{i}] = {v}, expected {expected}"
1426            );
1427        }
1428    }
1429
1430    #[test]
1431    fn execute_normalize_clamps_out_of_range() {
1432        let space = square4_space();
1433        // Values -5, 0, 5, 10, 15 etc.
1434        let data: Vec<f32> = (-4..5).map(|x| x as f32 * 5.0).collect();
1435        let snap = snapshot_with_field(FieldId(0), data);
1436
1437        let spec = ObsSpec {
1438            entries: vec![ObsEntry {
1439                field_id: FieldId(0),
1440                region: ObsRegion::Fixed(RegionSpec::All),
1441                pool: None,
1442                transform: ObsTransform::Normalize {
1443                    min: 0.0,
1444                    max: 10.0,
1445                },
1446                dtype: ObsDtype::F32,
1447            }],
1448        };
1449        let result = ObsPlan::compile(&spec, &space).unwrap();
1450
1451        let mut output = vec![0.0f32; result.output_len];
1452        let mut mask = vec![0u8; result.mask_len];
1453        result
1454            .plan
1455            .execute(&snap, None, &mut output, &mut mask)
1456            .unwrap();
1457
1458        for &v in &output {
1459            assert!((0.0..=1.0).contains(&v), "value {v} out of [0,1] range");
1460        }
1461    }
1462
1463    #[test]
1464    fn execute_normalize_zero_range() {
1465        let space = square4_space();
1466        let data = vec![5.0f32; 9];
1467        let snap = snapshot_with_field(FieldId(0), data);
1468
1469        let spec = ObsSpec {
1470            entries: vec![ObsEntry {
1471                field_id: FieldId(0),
1472                region: ObsRegion::Fixed(RegionSpec::All),
1473                pool: None,
1474                transform: ObsTransform::Normalize { min: 5.0, max: 5.0 },
1475                dtype: ObsDtype::F32,
1476            }],
1477        };
1478        let result = ObsPlan::compile(&spec, &space).unwrap();
1479
1480        let mut output = vec![-1.0f32; result.output_len];
1481        let mut mask = vec![0u8; result.mask_len];
1482        result
1483            .plan
1484            .execute(&snap, None, &mut output, &mut mask)
1485            .unwrap();
1486
1487        // Zero range → all outputs 0.0.
1488        assert!(output.iter().all(|&v| v == 0.0));
1489    }
1490
1491    #[test]
1492    fn execute_rect_subregion_correct_values() {
1493        let space = Square4::new(4, 4, EdgeBehavior::Absorb).unwrap();
1494        // 4x4 field: value = row * 4 + col + 1.
1495        let data: Vec<f32> = (1..=16).map(|x| x as f32).collect();
1496        let snap = snapshot_with_field(FieldId(0), data);
1497
1498        let spec = ObsSpec {
1499            entries: vec![ObsEntry {
1500                field_id: FieldId(0),
1501                region: ObsRegion::Fixed(RegionSpec::Rect {
1502                    min: smallvec::smallvec![1, 1],
1503                    max: smallvec::smallvec![2, 2],
1504                }),
1505                pool: None,
1506                transform: ObsTransform::Identity,
1507                dtype: ObsDtype::F32,
1508            }],
1509        };
1510        let result = ObsPlan::compile(&spec, &space).unwrap();
1511        assert_eq!(result.output_len, 4); // 2x2
1512
1513        let mut output = vec![0.0f32; result.output_len];
1514        let mut mask = vec![0u8; result.mask_len];
1515        result
1516            .plan
1517            .execute(&snap, None, &mut output, &mut mask)
1518            .unwrap();
1519
1520        // Rect covers (1,1)=6, (1,2)=7, (2,1)=10, (2,2)=11
1521        assert_eq!(output, vec![6.0, 7.0, 10.0, 11.0]);
1522        assert_eq!(mask, vec![1, 1, 1, 1]);
1523    }
1524
1525    #[test]
1526    fn execute_two_fields() {
1527        let space = square4_space();
1528        let data_a: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1529        let data_b: Vec<f32> = (10..=18).map(|x| x as f32).collect();
1530        let mut snap = MockSnapshot::new(TickId(1), WorldGenerationId(1), ParameterVersion(0));
1531        snap.set_field(FieldId(0), data_a);
1532        snap.set_field(FieldId(1), data_b);
1533
1534        let spec = ObsSpec {
1535            entries: vec![
1536                ObsEntry {
1537                    field_id: FieldId(0),
1538                    region: ObsRegion::Fixed(RegionSpec::All),
1539                    pool: None,
1540                    transform: ObsTransform::Identity,
1541                    dtype: ObsDtype::F32,
1542                },
1543                ObsEntry {
1544                    field_id: FieldId(1),
1545                    region: ObsRegion::Fixed(RegionSpec::All),
1546                    pool: None,
1547                    transform: ObsTransform::Identity,
1548                    dtype: ObsDtype::F32,
1549                },
1550            ],
1551        };
1552        let result = ObsPlan::compile(&spec, &space).unwrap();
1553        assert_eq!(result.output_len, 18);
1554
1555        let mut output = vec![0.0f32; result.output_len];
1556        let mut mask = vec![0u8; result.mask_len];
1557        result
1558            .plan
1559            .execute(&snap, None, &mut output, &mut mask)
1560            .unwrap();
1561
1562        // First 9: field 0, next 9: field 1.
1563        let expected_a: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1564        let expected_b: Vec<f32> = (10..=18).map(|x| x as f32).collect();
1565        assert_eq!(&output[..9], &expected_a);
1566        assert_eq!(&output[9..], &expected_b);
1567    }
1568
1569    #[test]
1570    fn execute_missing_field_errors() {
1571        let space = square4_space();
1572        let snap = MockSnapshot::new(TickId(1), WorldGenerationId(1), ParameterVersion(0));
1573
1574        let spec = ObsSpec {
1575            entries: vec![ObsEntry {
1576                field_id: FieldId(0),
1577                region: ObsRegion::Fixed(RegionSpec::All),
1578                pool: None,
1579                transform: ObsTransform::Identity,
1580                dtype: ObsDtype::F32,
1581            }],
1582        };
1583        let result = ObsPlan::compile(&spec, &space).unwrap();
1584
1585        let mut output = vec![0.0f32; result.output_len];
1586        let mut mask = vec![0u8; result.mask_len];
1587        let err = result
1588            .plan
1589            .execute(&snap, None, &mut output, &mut mask)
1590            .unwrap_err();
1591        assert!(matches!(err, ObsError::ExecutionFailed { .. }));
1592    }
1593
1594    #[test]
1595    fn execute_buffer_too_small_errors() {
1596        let space = square4_space();
1597        let data: Vec<f32> = vec![0.0; 9];
1598        let snap = snapshot_with_field(FieldId(0), data);
1599
1600        let spec = ObsSpec {
1601            entries: vec![ObsEntry {
1602                field_id: FieldId(0),
1603                region: ObsRegion::Fixed(RegionSpec::All),
1604                pool: None,
1605                transform: ObsTransform::Identity,
1606                dtype: ObsDtype::F32,
1607            }],
1608        };
1609        let result = ObsPlan::compile(&spec, &space).unwrap();
1610
1611        let mut output = vec![0.0f32; 4]; // too small
1612        let mut mask = vec![0u8; result.mask_len];
1613        let err = result
1614            .plan
1615            .execute(&snap, None, &mut output, &mut mask)
1616            .unwrap_err();
1617        assert!(matches!(err, ObsError::ExecutionFailed { .. }));
1618    }
1619
1620    // ── Validity / coverage tests ────────────────────────────
1621
1622    #[test]
1623    fn valid_ratio_one_for_square_all() {
1624        let space = square4_space();
1625        let data: Vec<f32> = vec![1.0; 9];
1626        let snap = snapshot_with_field(FieldId(0), data);
1627
1628        let spec = ObsSpec {
1629            entries: vec![ObsEntry {
1630                field_id: FieldId(0),
1631                region: ObsRegion::Fixed(RegionSpec::All),
1632                pool: None,
1633                transform: ObsTransform::Identity,
1634                dtype: ObsDtype::F32,
1635            }],
1636        };
1637        let result = ObsPlan::compile(&spec, &space).unwrap();
1638
1639        let mut output = vec![0.0f32; result.output_len];
1640        let mut mask = vec![0u8; result.mask_len];
1641        let meta = result
1642            .plan
1643            .execute(&snap, None, &mut output, &mut mask)
1644            .unwrap();
1645
1646        assert_eq!(meta.coverage, 1.0);
1647    }
1648
1649    // ── Generation binding tests ─────────────────────────────
1650
1651    #[test]
1652    fn plan_invalidated_on_generation_mismatch() {
1653        let space = square4_space();
1654        let data: Vec<f32> = vec![1.0; 9];
1655        let snap = snapshot_with_field(FieldId(0), data);
1656
1657        let spec = ObsSpec {
1658            entries: vec![ObsEntry {
1659                field_id: FieldId(0),
1660                region: ObsRegion::Fixed(RegionSpec::All),
1661                pool: None,
1662                transform: ObsTransform::Identity,
1663                dtype: ObsDtype::F32,
1664            }],
1665        };
1666        // Compile bound to generation 99, but snapshot is generation 1.
1667        let result = ObsPlan::compile_bound(&spec, &space, WorldGenerationId(99)).unwrap();
1668
1669        let mut output = vec![0.0f32; result.output_len];
1670        let mut mask = vec![0u8; result.mask_len];
1671        let err = result
1672            .plan
1673            .execute(&snap, None, &mut output, &mut mask)
1674            .unwrap_err();
1675        assert!(matches!(err, ObsError::PlanInvalidated { .. }));
1676    }
1677
1678    #[test]
1679    fn generation_match_succeeds() {
1680        let space = square4_space();
1681        let data: Vec<f32> = vec![1.0; 9];
1682        let snap = snapshot_with_field(FieldId(0), data);
1683
1684        let spec = ObsSpec {
1685            entries: vec![ObsEntry {
1686                field_id: FieldId(0),
1687                region: ObsRegion::Fixed(RegionSpec::All),
1688                pool: None,
1689                transform: ObsTransform::Identity,
1690                dtype: ObsDtype::F32,
1691            }],
1692        };
1693        let result = ObsPlan::compile_bound(&spec, &space, WorldGenerationId(1)).unwrap();
1694
1695        let mut output = vec![0.0f32; result.output_len];
1696        let mut mask = vec![0u8; result.mask_len];
1697        result
1698            .plan
1699            .execute(&snap, None, &mut output, &mut mask)
1700            .unwrap();
1701    }
1702
1703    #[test]
1704    fn unbound_plan_ignores_generation() {
1705        let space = square4_space();
1706        let data: Vec<f32> = vec![1.0; 9];
1707        let snap = snapshot_with_field(FieldId(0), data);
1708
1709        let spec = ObsSpec {
1710            entries: vec![ObsEntry {
1711                field_id: FieldId(0),
1712                region: ObsRegion::Fixed(RegionSpec::All),
1713                pool: None,
1714                transform: ObsTransform::Identity,
1715                dtype: ObsDtype::F32,
1716            }],
1717        };
1718        // Unbound plan — no generation check.
1719        let result = ObsPlan::compile(&spec, &space).unwrap();
1720
1721        let mut output = vec![0.0f32; result.output_len];
1722        let mut mask = vec![0u8; result.mask_len];
1723        result
1724            .plan
1725            .execute(&snap, None, &mut output, &mut mask)
1726            .unwrap();
1727    }
1728
1729    // ── Metadata tests ───────────────────────────────────────
1730
1731    #[test]
1732    fn metadata_fields_populated() {
1733        let space = square4_space();
1734        let data: Vec<f32> = vec![1.0; 9];
1735        let mut snap = MockSnapshot::new(TickId(42), WorldGenerationId(7), ParameterVersion(3));
1736        snap.set_field(FieldId(0), data);
1737
1738        let spec = ObsSpec {
1739            entries: vec![ObsEntry {
1740                field_id: FieldId(0),
1741                region: ObsRegion::Fixed(RegionSpec::All),
1742                pool: None,
1743                transform: ObsTransform::Identity,
1744                dtype: ObsDtype::F32,
1745            }],
1746        };
1747        let result = ObsPlan::compile(&spec, &space).unwrap();
1748
1749        let mut output = vec![0.0f32; result.output_len];
1750        let mut mask = vec![0u8; result.mask_len];
1751        let meta = result
1752            .plan
1753            .execute(&snap, None, &mut output, &mut mask)
1754            .unwrap();
1755
1756        assert_eq!(meta.tick_id, TickId(42));
1757        assert_eq!(meta.age_ticks, 0);
1758        assert_eq!(meta.coverage, 1.0);
1759        assert_eq!(meta.world_generation_id, WorldGenerationId(7));
1760        assert_eq!(meta.parameter_version, ParameterVersion(3));
1761    }
1762
1763    // ── Batch execution tests ────────────────────────────────
1764
1765    #[test]
1766    fn execute_batch_n1_matches_execute() {
1767        let space = square4_space();
1768        let data: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1769        let snap = snapshot_with_field(FieldId(0), data.clone());
1770
1771        let spec = ObsSpec {
1772            entries: vec![ObsEntry {
1773                field_id: FieldId(0),
1774                region: ObsRegion::Fixed(RegionSpec::All),
1775                pool: None,
1776                transform: ObsTransform::Identity,
1777                dtype: ObsDtype::F32,
1778            }],
1779        };
1780        let result = ObsPlan::compile(&spec, &space).unwrap();
1781
1782        // Single execute.
1783        let mut out_single = vec![0.0f32; result.output_len];
1784        let mut mask_single = vec![0u8; result.mask_len];
1785        let meta_single = result
1786            .plan
1787            .execute(&snap, None, &mut out_single, &mut mask_single)
1788            .unwrap();
1789
1790        // Batch N=1.
1791        let mut out_batch = vec![0.0f32; result.output_len];
1792        let mut mask_batch = vec![0u8; result.mask_len];
1793        let snap_ref: &dyn SnapshotAccess = &snap;
1794        let meta_batch = result
1795            .plan
1796            .execute_batch(&[snap_ref], None, &mut out_batch, &mut mask_batch)
1797            .unwrap();
1798
1799        assert_eq!(out_single, out_batch);
1800        assert_eq!(mask_single, mask_batch);
1801        assert_eq!(meta_single, meta_batch[0]);
1802    }
1803
1804    #[test]
1805    fn execute_batch_multiple_snapshots() {
1806        let space = square4_space();
1807        let spec = ObsSpec {
1808            entries: vec![ObsEntry {
1809                field_id: FieldId(0),
1810                region: ObsRegion::Fixed(RegionSpec::All),
1811                pool: None,
1812                transform: ObsTransform::Identity,
1813                dtype: ObsDtype::F32,
1814            }],
1815        };
1816        let result = ObsPlan::compile(&spec, &space).unwrap();
1817
1818        let snap_a = snapshot_with_field(FieldId(0), vec![1.0; 9]);
1819        let snap_b = snapshot_with_field(FieldId(0), vec![2.0; 9]);
1820
1821        let snaps: Vec<&dyn SnapshotAccess> = vec![&snap_a, &snap_b];
1822        let mut output = vec![0.0f32; result.output_len * 2];
1823        let mut mask = vec![0u8; result.mask_len * 2];
1824        let metas = result
1825            .plan
1826            .execute_batch(&snaps, None, &mut output, &mut mask)
1827            .unwrap();
1828
1829        assert_eq!(metas.len(), 2);
1830        assert!(output[..9].iter().all(|&v| v == 1.0));
1831        assert!(output[9..].iter().all(|&v| v == 2.0));
1832    }
1833
1834    #[test]
1835    fn execute_batch_buffer_too_small() {
1836        let space = square4_space();
1837        let spec = ObsSpec {
1838            entries: vec![ObsEntry {
1839                field_id: FieldId(0),
1840                region: ObsRegion::Fixed(RegionSpec::All),
1841                pool: None,
1842                transform: ObsTransform::Identity,
1843                dtype: ObsDtype::F32,
1844            }],
1845        };
1846        let result = ObsPlan::compile(&spec, &space).unwrap();
1847
1848        let snap = snapshot_with_field(FieldId(0), vec![1.0; 9]);
1849        let snaps: Vec<&dyn SnapshotAccess> = vec![&snap, &snap];
1850        let mut output = vec![0.0f32; 9]; // need 18
1851        let mut mask = vec![0u8; 18];
1852        let err = result
1853            .plan
1854            .execute_batch(&snaps, None, &mut output, &mut mask)
1855            .unwrap_err();
1856        assert!(matches!(err, ObsError::ExecutionFailed { .. }));
1857    }
1858
1859    // ── Field length mismatch tests ──────────────────────────
1860
1861    #[test]
1862    fn short_field_buffer_returns_error_not_panic() {
1863        let space = square4_space(); // 3x3 = 9 cells
1864        let spec = ObsSpec {
1865            entries: vec![ObsEntry {
1866                field_id: FieldId(0),
1867                region: ObsRegion::Fixed(RegionSpec::All),
1868                pool: None,
1869                transform: ObsTransform::Identity,
1870                dtype: ObsDtype::F32,
1871            }],
1872        };
1873        let result = ObsPlan::compile(&spec, &space).unwrap();
1874
1875        // Snapshot field has only 4 elements, but plan expects 9.
1876        let snap = snapshot_with_field(FieldId(0), vec![1.0; 4]);
1877        let mut output = vec![0.0f32; result.output_len];
1878        let mut mask = vec![0u8; result.mask_len];
1879        let err = result
1880            .plan
1881            .execute(&snap, None, &mut output, &mut mask)
1882            .unwrap_err();
1883        assert!(matches!(err, ObsError::ExecutionFailed { .. }));
1884    }
1885
1886    // ── Standard plan (agent-centered) tests ─────────────────
1887
1888    #[test]
1889    fn standard_plan_detected_from_agent_region() {
1890        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
1891        let spec = ObsSpec {
1892            entries: vec![ObsEntry {
1893                field_id: FieldId(0),
1894                region: ObsRegion::AgentRect {
1895                    half_extent: smallvec::smallvec![2, 2],
1896                },
1897                pool: None,
1898                transform: ObsTransform::Identity,
1899                dtype: ObsDtype::F32,
1900            }],
1901        };
1902        let result = ObsPlan::compile(&spec, &space).unwrap();
1903        assert!(result.plan.is_standard());
1904        // 5x5 = 25 elements
1905        assert_eq!(result.output_len, 25);
1906        assert_eq!(result.entry_shapes, vec![vec![5, 5]]);
1907    }
1908
1909    #[test]
1910    fn execute_on_standard_plan_errors() {
1911        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
1912        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
1913        let snap = snapshot_with_field(FieldId(0), data);
1914
1915        let spec = ObsSpec {
1916            entries: vec![ObsEntry {
1917                field_id: FieldId(0),
1918                region: ObsRegion::AgentDisk { radius: 2 },
1919                pool: None,
1920                transform: ObsTransform::Identity,
1921                dtype: ObsDtype::F32,
1922            }],
1923        };
1924        let result = ObsPlan::compile(&spec, &space).unwrap();
1925
1926        let mut output = vec![0.0f32; result.output_len];
1927        let mut mask = vec![0u8; result.mask_len];
1928        let err = result
1929            .plan
1930            .execute(&snap, None, &mut output, &mut mask)
1931            .unwrap_err();
1932        assert!(matches!(err, ObsError::ExecutionFailed { .. }));
1933    }
1934
1935    #[test]
1936    fn interior_boundary_equivalence() {
1937        // An INTERIOR agent using Standard plan should produce identical
1938        // output to a Simple plan with an explicit Rect at the same position.
1939        let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
1940        let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
1941        let snap = snapshot_with_field(FieldId(0), data);
1942
1943        let radius = 3u32;
1944        let center: Coord = smallvec::smallvec![10, 10]; // interior
1945
1946        // Standard plan: AgentRect centered on agent.
1947        let standard_spec = ObsSpec {
1948            entries: vec![ObsEntry {
1949                field_id: FieldId(0),
1950                region: ObsRegion::AgentRect {
1951                    half_extent: smallvec::smallvec![radius, radius],
1952                },
1953                pool: None,
1954                transform: ObsTransform::Identity,
1955                dtype: ObsDtype::F32,
1956            }],
1957        };
1958        let std_result = ObsPlan::compile(&standard_spec, &space).unwrap();
1959        let mut std_output = vec![0.0f32; std_result.output_len];
1960        let mut std_mask = vec![0u8; std_result.mask_len];
1961        std_result
1962            .plan
1963            .execute_agents(
1964                &snap,
1965                &space,
1966                std::slice::from_ref(&center),
1967                None,
1968                &mut std_output,
1969                &mut std_mask,
1970            )
1971            .unwrap();
1972
1973        // Simple plan: explicit Rect covering the same area.
1974        let r = radius as i32;
1975        let simple_spec = ObsSpec {
1976            entries: vec![ObsEntry {
1977                field_id: FieldId(0),
1978                region: ObsRegion::Fixed(RegionSpec::Rect {
1979                    min: smallvec::smallvec![10 - r, 10 - r],
1980                    max: smallvec::smallvec![10 + r, 10 + r],
1981                }),
1982                pool: None,
1983                transform: ObsTransform::Identity,
1984                dtype: ObsDtype::F32,
1985            }],
1986        };
1987        let simple_result = ObsPlan::compile(&simple_spec, &space).unwrap();
1988        let mut simple_output = vec![0.0f32; simple_result.output_len];
1989        let mut simple_mask = vec![0u8; simple_result.mask_len];
1990        simple_result
1991            .plan
1992            .execute(&snap, None, &mut simple_output, &mut simple_mask)
1993            .unwrap();
1994
1995        // Same shape, same values.
1996        assert_eq!(std_result.output_len, simple_result.output_len);
1997        assert_eq!(std_output, simple_output);
1998        assert_eq!(std_mask, simple_mask);
1999    }
2000
2001    #[test]
2002    fn boundary_agent_gets_padding() {
2003        // Agent at corner (0,0) with radius 2: many cells out-of-bounds.
2004        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2005        let data: Vec<f32> = (0..100).map(|x| x as f32 + 1.0).collect();
2006        let snap = snapshot_with_field(FieldId(0), data);
2007
2008        let spec = ObsSpec {
2009            entries: vec![ObsEntry {
2010                field_id: FieldId(0),
2011                region: ObsRegion::AgentRect {
2012                    half_extent: smallvec::smallvec![2, 2],
2013                },
2014                pool: None,
2015                transform: ObsTransform::Identity,
2016                dtype: ObsDtype::F32,
2017            }],
2018        };
2019        let result = ObsPlan::compile(&spec, &space).unwrap();
2020        let center: Coord = smallvec::smallvec![0, 0];
2021        let mut output = vec![0.0f32; result.output_len];
2022        let mut mask = vec![0u8; result.mask_len];
2023        let metas = result
2024            .plan
2025            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2026            .unwrap();
2027
2028        // 5x5 = 25 cells total. Agent at (0,0) with radius 2:
2029        // Only cells with row in [0,2] and col in [0,2] are valid (3x3 = 9).
2030        let valid_count: usize = mask.iter().filter(|&&v| v == 1).count();
2031        assert_eq!(valid_count, 9);
2032
2033        // Coverage should be 9/25
2034        assert!((metas[0].coverage - 9.0 / 25.0).abs() < 1e-6);
2035
2036        // Check that the top-left corner of the bounding box (relative [-2,-2])
2037        // is padding (mask=0, value=0).
2038        assert_eq!(mask[0], 0); // relative (-2,-2) → absolute (-2,-2) → out of bounds
2039        assert_eq!(output[0], 0.0);
2040
2041        // The cell at relative (0,0) is at tensor_idx = 2*5+2 = 12
2042        // Absolute (0,0) → field value = 1.0
2043        assert_eq!(mask[12], 1);
2044        assert_eq!(output[12], 1.0);
2045    }
2046
2047    #[test]
2048    fn hex_foveation_interior() {
2049        // Test agent-centered observation on Hex2D grid.
2050        let space = Hex2D::new(20, 20).unwrap(); // 20 rows, 20 cols
2051        let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
2052        let snap = snapshot_with_field(FieldId(0), data);
2053
2054        let spec = ObsSpec {
2055            entries: vec![ObsEntry {
2056                field_id: FieldId(0),
2057                region: ObsRegion::AgentDisk { radius: 2 },
2058                pool: None,
2059                transform: ObsTransform::Identity,
2060                dtype: ObsDtype::F32,
2061            }],
2062        };
2063        let result = ObsPlan::compile(&spec, &space).unwrap();
2064        assert_eq!(result.output_len, 25); // 5x5 bounding box (tensor shape)
2065
2066        // Interior agent: q=10, r=10
2067        let center: Coord = smallvec::smallvec![10, 10];
2068        let mut output = vec![0.0f32; result.output_len];
2069        let mut mask = vec![0u8; result.mask_len];
2070        result
2071            .plan
2072            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2073            .unwrap();
2074
2075        // Hex disk of radius 2: 19 of 25 cells are within hex distance.
2076        // The 6 corners of the 5x5 bounding box exceed hex distance 2.
2077        // Hex distance = max(|dq|, |dr|, |dq+dr|) for axial coordinates.
2078        let valid_count = mask.iter().filter(|&&v| v == 1).count();
2079        assert_eq!(valid_count, 19);
2080
2081        // Corners that should be masked out (distance > 2):
2082        // tensor_idx 0: dq=-2,dr=-2 → max(2,2,4)=4
2083        // tensor_idx 1: dq=-2,dr=-1 → max(2,1,3)=3
2084        // tensor_idx 5: dq=-1,dr=-2 → max(1,2,3)=3
2085        // tensor_idx 19: dq=+1,dr=+2 → max(1,2,3)=3
2086        // tensor_idx 23: dq=+2,dr=+1 → max(2,1,3)=3
2087        // tensor_idx 24: dq=+2,dr=+2 → max(2,2,4)=4
2088        for &idx in &[0, 1, 5, 19, 23, 24] {
2089            assert_eq!(mask[idx], 0, "tensor_idx {idx} should be outside hex disk");
2090            assert_eq!(output[idx], 0.0, "tensor_idx {idx} should be zero-padded");
2091        }
2092
2093        // Center cell is at tensor_idx = 2*5+2 = 12 (relative [0,0]).
2094        // Hex2D canonical_rank([q,r]) = r*cols + q = 10*20 + 10 = 210
2095        assert_eq!(output[12], 210.0);
2096
2097        // Cell at relative [1, 0] (dq=+1, dr=0) → absolute [11, 10]
2098        // rank = 10*20 + 11 = 211
2099        // In row-major bounding box: dim0_idx=3, dim1_idx=2 → tensor_idx = 3*5+2 = 17
2100        assert_eq!(output[17], 211.0);
2101    }
2102
2103    #[test]
2104    fn wrap_space_all_interior() {
2105        // Wrapped (torus) space: all agents are interior.
2106        let space = Square4::new(10, 10, EdgeBehavior::Wrap).unwrap();
2107        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2108        let snap = snapshot_with_field(FieldId(0), data);
2109
2110        let spec = ObsSpec {
2111            entries: vec![ObsEntry {
2112                field_id: FieldId(0),
2113                region: ObsRegion::AgentRect {
2114                    half_extent: smallvec::smallvec![2, 2],
2115                },
2116                pool: None,
2117                transform: ObsTransform::Identity,
2118                dtype: ObsDtype::F32,
2119            }],
2120        };
2121        let result = ObsPlan::compile(&spec, &space).unwrap();
2122
2123        // Agent at corner (0,0) — still interior on torus.
2124        let center: Coord = smallvec::smallvec![0, 0];
2125        let mut output = vec![0.0f32; result.output_len];
2126        let mut mask = vec![0u8; result.mask_len];
2127        result
2128            .plan
2129            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2130            .unwrap();
2131
2132        // All 25 cells valid (torus wraps).
2133        assert!(mask.iter().all(|&v| v == 1));
2134        assert_eq!(output[12], 0.0); // center (0,0) → rank 0
2135    }
2136
2137    #[test]
2138    fn execute_agents_multiple() {
2139        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2140        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2141        let snap = snapshot_with_field(FieldId(0), data);
2142
2143        let spec = ObsSpec {
2144            entries: vec![ObsEntry {
2145                field_id: FieldId(0),
2146                region: ObsRegion::AgentRect {
2147                    half_extent: smallvec::smallvec![1, 1],
2148                },
2149                pool: None,
2150                transform: ObsTransform::Identity,
2151                dtype: ObsDtype::F32,
2152            }],
2153        };
2154        let result = ObsPlan::compile(&spec, &space).unwrap();
2155        assert_eq!(result.output_len, 9); // 3x3
2156
2157        // Two agents: one interior, one at edge.
2158        let centers = vec![
2159            smallvec::smallvec![5, 5], // interior
2160            smallvec::smallvec![0, 5], // top edge
2161        ];
2162        let n = centers.len();
2163        let mut output = vec![0.0f32; result.output_len * n];
2164        let mut mask = vec![0u8; result.mask_len * n];
2165        let metas = result
2166            .plan
2167            .execute_agents(&snap, &space, &centers, None, &mut output, &mut mask)
2168            .unwrap();
2169
2170        assert_eq!(metas.len(), 2);
2171
2172        // Agent 0 (interior): all 9 cells valid, center = (5,5) → rank 55
2173        assert!(mask[..9].iter().all(|&v| v == 1));
2174        assert_eq!(output[4], 55.0); // center at tensor_idx = 1*3+1 = 4
2175
2176        // Agent 1 (top edge): row -1 is out of bounds → 3 cells masked
2177        let agent1_mask = &mask[9..18];
2178        let valid_count: usize = agent1_mask.iter().filter(|&&v| v == 1).count();
2179        assert_eq!(valid_count, 6); // 2 rows in-bounds × 3 cols
2180    }
2181
2182    #[test]
2183    fn execute_agents_with_normalize() {
2184        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2185        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2186        let snap = snapshot_with_field(FieldId(0), data);
2187
2188        let spec = ObsSpec {
2189            entries: vec![ObsEntry {
2190                field_id: FieldId(0),
2191                region: ObsRegion::AgentRect {
2192                    half_extent: smallvec::smallvec![1, 1],
2193                },
2194                pool: None,
2195                transform: ObsTransform::Normalize {
2196                    min: 0.0,
2197                    max: 99.0,
2198                },
2199                dtype: ObsDtype::F32,
2200            }],
2201        };
2202        let result = ObsPlan::compile(&spec, &space).unwrap();
2203
2204        let center: Coord = smallvec::smallvec![5, 5];
2205        let mut output = vec![0.0f32; result.output_len];
2206        let mut mask = vec![0u8; result.mask_len];
2207        result
2208            .plan
2209            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2210            .unwrap();
2211
2212        // Center (5,5) rank=55, normalized = 55/99 ≈ 0.5556
2213        let expected = 55.0 / 99.0;
2214        assert!((output[4] - expected as f32).abs() < 1e-5);
2215    }
2216
2217    #[test]
2218    fn execute_agents_with_pooling() {
2219        let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
2220        let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
2221        let snap = snapshot_with_field(FieldId(0), data);
2222
2223        // AgentRect with half_extent=3 → 7x7 bounding box.
2224        // Mean pool 2x2 stride 2 → floor((7-2)/2)+1 = 3 per dim → 3x3 = 9 output.
2225        let spec = ObsSpec {
2226            entries: vec![ObsEntry {
2227                field_id: FieldId(0),
2228                region: ObsRegion::AgentRect {
2229                    half_extent: smallvec::smallvec![3, 3],
2230                },
2231                pool: Some(PoolConfig {
2232                    kernel: PoolKernel::Mean,
2233                    kernel_size: 2,
2234                    stride: 2,
2235                }),
2236                transform: ObsTransform::Identity,
2237                dtype: ObsDtype::F32,
2238            }],
2239        };
2240        let result = ObsPlan::compile(&spec, &space).unwrap();
2241        assert_eq!(result.output_len, 9); // 3x3
2242        assert_eq!(result.entry_shapes, vec![vec![3, 3]]);
2243
2244        // Interior agent at (10, 10): all cells valid.
2245        let center: Coord = smallvec::smallvec![10, 10];
2246        let mut output = vec![0.0f32; result.output_len];
2247        let mut mask = vec![0u8; result.mask_len];
2248        result
2249            .plan
2250            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2251            .unwrap();
2252
2253        // All pooled cells should be valid.
2254        assert!(mask.iter().all(|&v| v == 1));
2255
2256        // Verify first pooled cell: mean of top-left 2x2 of the 7x7 gather.
2257        // Gather bounding box starts at (10-3, 10-3) = (7, 7).
2258        // Top-left 2x2: (7,7)=147, (7,8)=148, (8,7)=167, (8,8)=168
2259        // Mean = (147+148+167+168)/4 = 157.5
2260        assert!((output[0] - 157.5).abs() < 1e-4);
2261    }
2262
2263    #[test]
2264    fn mixed_fixed_and_agent_entries() {
2265        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2266        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2267        let snap = snapshot_with_field(FieldId(0), data);
2268
2269        let spec = ObsSpec {
2270            entries: vec![
2271                // Fixed entry: full grid (100 elements).
2272                ObsEntry {
2273                    field_id: FieldId(0),
2274                    region: ObsRegion::Fixed(RegionSpec::All),
2275                    pool: None,
2276                    transform: ObsTransform::Identity,
2277                    dtype: ObsDtype::F32,
2278                },
2279                // Agent entry: 3x3 rect around agent.
2280                ObsEntry {
2281                    field_id: FieldId(0),
2282                    region: ObsRegion::AgentRect {
2283                        half_extent: smallvec::smallvec![1, 1],
2284                    },
2285                    pool: None,
2286                    transform: ObsTransform::Identity,
2287                    dtype: ObsDtype::F32,
2288                },
2289            ],
2290        };
2291        let result = ObsPlan::compile(&spec, &space).unwrap();
2292        assert!(result.plan.is_standard());
2293        assert_eq!(result.output_len, 109); // 100 + 9
2294
2295        let center: Coord = smallvec::smallvec![5, 5];
2296        let mut output = vec![0.0f32; result.output_len];
2297        let mut mask = vec![0u8; result.mask_len];
2298        result
2299            .plan
2300            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2301            .unwrap();
2302
2303        // Fixed entry: first 100 elements match field data.
2304        let expected: Vec<f32> = (0..100).map(|x| x as f32).collect();
2305        assert_eq!(&output[..100], &expected[..]);
2306        assert!(mask[..100].iter().all(|&v| v == 1));
2307
2308        // Agent entry: 3x3 centered on (5,5). Center at tensor_idx = 1*3+1 = 4.
2309        // rank(5,5) = 55
2310        assert_eq!(output[100 + 4], 55.0);
2311    }
2312
2313    #[test]
2314    fn wrong_dimensionality_returns_error() {
2315        // 2D space but 1D agent center → should error, not panic.
2316        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2317        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2318        let snap = snapshot_with_field(FieldId(0), data);
2319
2320        let spec = ObsSpec {
2321            entries: vec![ObsEntry {
2322                field_id: FieldId(0),
2323                region: ObsRegion::AgentDisk { radius: 1 },
2324                pool: None,
2325                transform: ObsTransform::Identity,
2326                dtype: ObsDtype::F32,
2327            }],
2328        };
2329        let result = ObsPlan::compile(&spec, &space).unwrap();
2330
2331        let bad_center: Coord = smallvec::smallvec![5]; // 1D, not 2D
2332        let mut output = vec![0.0f32; result.output_len];
2333        let mut mask = vec![0u8; result.mask_len];
2334        let err =
2335            result
2336                .plan
2337                .execute_agents(&snap, &space, &[bad_center], None, &mut output, &mut mask);
2338        assert!(err.is_err());
2339        let msg = format!("{}", err.unwrap_err());
2340        assert!(
2341            msg.contains("dimensions"),
2342            "error should mention dimensions: {msg}"
2343        );
2344    }
2345
2346    #[test]
2347    fn agent_disk_square4_filters_corners() {
2348        // On a 4-connected grid, AgentDisk radius=2 should use Manhattan distance.
2349        // Bounding box is 5x5 = 25, but Manhattan disk has 13 cells (diamond shape).
2350        let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
2351        let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
2352        let snap = snapshot_with_field(FieldId(0), data);
2353
2354        let spec = ObsSpec {
2355            entries: vec![ObsEntry {
2356                field_id: FieldId(0),
2357                region: ObsRegion::AgentDisk { radius: 2 },
2358                pool: None,
2359                transform: ObsTransform::Identity,
2360                dtype: ObsDtype::F32,
2361            }],
2362        };
2363        let result = ObsPlan::compile(&spec, &space).unwrap();
2364        assert_eq!(result.output_len, 25); // tensor shape is still 5x5
2365
2366        // Interior agent at (10, 10).
2367        let center: Coord = smallvec::smallvec![10, 10];
2368        let mut output = vec![0.0f32; 25];
2369        let mut mask = vec![0u8; 25];
2370        result
2371            .plan
2372            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2373            .unwrap();
2374
2375        // Manhattan distance disk of radius 2 on a 5x5 bounding box:
2376        //   . . X . .    (row -2: only center col)
2377        //   . X X X .    (row -1: 3 cells)
2378        //   X X X X X    (row  0: 5 cells)
2379        //   . X X X .    (row +1: 3 cells)
2380        //   . . X . .    (row +2: only center col)
2381        // Total: 1 + 3 + 5 + 3 + 1 = 13 cells
2382        let valid_count = mask.iter().filter(|&&v| v == 1).count();
2383        assert_eq!(
2384            valid_count, 13,
2385            "Manhattan disk radius=2 should have 13 cells"
2386        );
2387
2388        // Corners should be masked out: (dr,dc) where |dr|+|dc| > 2
2389        // tensor_idx 0: dr=-2,dc=-2 → dist=4 → OUT
2390        // tensor_idx 4: dr=-2,dc=+2 → dist=4 → OUT
2391        // tensor_idx 20: dr=+2,dc=-2 → dist=4 → OUT
2392        // tensor_idx 24: dr=+2,dc=+2 → dist=4 → OUT
2393        for &idx in &[0, 4, 20, 24] {
2394            assert_eq!(
2395                mask[idx], 0,
2396                "corner tensor_idx {idx} should be outside disk"
2397            );
2398        }
2399
2400        // Center cell: tensor_idx = 2*5+2 = 12, absolute = row 10 * 20 + col 10 = 210
2401        assert_eq!(output[12], 210.0);
2402        assert_eq!(mask[12], 1);
2403    }
2404
2405    #[test]
2406    fn agent_rect_no_disk_filtering() {
2407        // AgentRect should NOT filter any cells — full rectangle is valid.
2408        let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
2409        let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
2410        let snap = snapshot_with_field(FieldId(0), data);
2411
2412        let spec = ObsSpec {
2413            entries: vec![ObsEntry {
2414                field_id: FieldId(0),
2415                region: ObsRegion::AgentRect {
2416                    half_extent: smallvec::smallvec![2, 2],
2417                },
2418                pool: None,
2419                transform: ObsTransform::Identity,
2420                dtype: ObsDtype::F32,
2421            }],
2422        };
2423        let result = ObsPlan::compile(&spec, &space).unwrap();
2424
2425        let center: Coord = smallvec::smallvec![10, 10];
2426        let mut output = vec![0.0f32; 25];
2427        let mut mask = vec![0u8; 25];
2428        result
2429            .plan
2430            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2431            .unwrap();
2432
2433        // All 25 cells should be valid for AgentRect (no disk filtering).
2434        assert!(mask.iter().all(|&v| v == 1));
2435    }
2436
2437    #[test]
2438    fn agent_disk_square8_chebyshev() {
2439        // On an 8-connected grid, AgentDisk radius=1 uses Chebyshev distance.
2440        // Bounding box is 3x3 = 9, Chebyshev disk radius=1 = full 3x3 → 9 cells.
2441        let space = Square8::new(10, 10, EdgeBehavior::Absorb).unwrap();
2442        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2443        let snap = snapshot_with_field(FieldId(0), data);
2444
2445        let spec = ObsSpec {
2446            entries: vec![ObsEntry {
2447                field_id: FieldId(0),
2448                region: ObsRegion::AgentDisk { radius: 1 },
2449                pool: None,
2450                transform: ObsTransform::Identity,
2451                dtype: ObsDtype::F32,
2452            }],
2453        };
2454        let result = ObsPlan::compile(&spec, &space).unwrap();
2455        assert_eq!(result.output_len, 9);
2456
2457        let center: Coord = smallvec::smallvec![5, 5];
2458        let mut output = vec![0.0f32; 9];
2459        let mut mask = vec![0u8; 9];
2460        result
2461            .plan
2462            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2463            .unwrap();
2464
2465        // Chebyshev distance <= 1 covers full 3x3 = 9 cells (all corners included).
2466        let valid_count = mask.iter().filter(|&&v| v == 1).count();
2467        assert_eq!(valid_count, 9, "Chebyshev disk radius=1 = full 3x3");
2468    }
2469
2470    #[test]
2471    fn compile_rejects_inverted_normalize_range() {
2472        let space = square4_space();
2473        let spec = ObsSpec {
2474            entries: vec![ObsEntry {
2475                field_id: FieldId(0),
2476                region: ObsRegion::Fixed(RegionSpec::All),
2477                pool: None,
2478                transform: ObsTransform::Normalize {
2479                    min: 10.0,
2480                    max: 5.0,
2481                },
2482                dtype: ObsDtype::F32,
2483            }],
2484        };
2485        let err = ObsPlan::compile(&spec, &space).unwrap_err();
2486        assert!(matches!(err, ObsError::InvalidObsSpec { .. }));
2487    }
2488
2489    #[test]
2490    fn compile_rejects_nan_normalize() {
2491        let space = square4_space();
2492        let spec = ObsSpec {
2493            entries: vec![ObsEntry {
2494                field_id: FieldId(0),
2495                region: ObsRegion::Fixed(RegionSpec::All),
2496                pool: None,
2497                transform: ObsTransform::Normalize {
2498                    min: f64::NAN,
2499                    max: 1.0,
2500                },
2501                dtype: ObsDtype::F32,
2502            }],
2503        };
2504        assert!(ObsPlan::compile(&spec, &space).is_err());
2505    }
2506}