Skip to main content

murk_obs/
plan.rs

//! Observation plan compilation and execution.
//!
//! [`ObsPlan`] is compiled from an [`ObsSpec`] + [`Space`], producing
//! a reusable gather plan that can be executed against any
//! [`SnapshotAccess`] implementor. The "Simple" plan class (every region
//! `Fixed`) uses a branch-free flat gather: for each entry, iterate
//! pre-computed `(field_data_index, tensor_index)` pairs, read the field
//! value, optionally transform it, and write to the caller-allocated buffer.
//! The "Standard" plan class (any agent-relative region) instead stores
//! per-entry templates that are instantiated for each agent at execution
//! time.
9
10use indexmap::IndexMap;
11
12use murk_core::error::ObsError;
13use murk_core::{Coord, FieldId, SnapshotAccess, TickId, WorldGenerationId};
14use murk_space::Space;
15
16use crate::geometry::GridGeometry;
17use crate::metadata::ObsMetadata;
18use crate::pool::pool_2d_into;
19use crate::spec::{ObsDtype, ObsRegion, ObsSpec, ObsTransform, PoolConfig};
20
/// Coverage threshold: warn if valid_ratio < this.
///
/// Compared against the `valid_ratio()` of each compiled `Fixed` region.
/// With the current values, ratios in
/// `[COVERAGE_ERROR_THRESHOLD, COVERAGE_WARN_THRESHOLD)` still compile, and
/// the Simple-plan compiler reports them on stderr.
const COVERAGE_WARN_THRESHOLD: f64 = 0.5;

/// Coverage threshold: error if valid_ratio < this.
///
/// A `Fixed` region whose `valid_ratio()` falls below this value makes
/// compilation fail with `ObsError::InvalidComposition`. Keep this below
/// [`COVERAGE_WARN_THRESHOLD`]; otherwise the warning band is empty.
const COVERAGE_ERROR_THRESHOLD: f64 = 0.35;
26
/// Result of compiling an [`ObsSpec`].
///
/// `output_len` and `mask_len` mirror [`ObsPlan::output_len`] and
/// [`ObsPlan::mask_len`]. For Standard plans (agent-relative regions) they
/// are per-agent sizes: [`ObsPlan::execute_agents`] expects buffers of
/// `n_agents * output_len` and `n_agents * mask_len`.
#[derive(Debug)]
pub struct ObsPlanResult {
    /// The compiled plan, ready for execution.
    pub plan: ObsPlan,
    /// Total number of f32 elements in the output tensor (per agent for
    /// Standard plans).
    pub output_len: usize,
    /// Shape per entry (each entry's region bounding shape dimensions), in
    /// spec order. For pooled agent-relative entries this is the post-pool
    /// shape.
    pub entry_shapes: Vec<Vec<usize>>,
    /// Length of the validity mask in bytes (per agent for Standard plans).
    pub mask_len: usize,
}
39
/// Compiled observation plan: either Simple or Standard class.
///
/// **Simple** (all `Fixed` regions): pre-computed gather indices, branch-free
/// loop, zero spatial computation at runtime. Use [`execute`](Self::execute)
/// or [`execute_batch`](Self::execute_batch).
///
/// **Standard** (any agent-relative region): template-based gather with
/// interior/boundary dispatch. Use [`execute_agents`](Self::execute_agents).
///
/// Calling an entry point that belongs to the other class returns
/// `ObsError::ExecutionFailed`.
#[derive(Debug)]
pub struct ObsPlan {
    /// Which plan class was compiled, together with its class-specific data.
    strategy: PlanStrategy,
    /// Total output elements across all entries (per agent for Standard).
    output_len: usize,
    /// Total mask bytes across all entries (per agent for Standard).
    mask_len: usize,
    /// Generation at compile time (for PLAN_INVALIDATED detection).
    /// `None` unless the plan was built with
    /// [`compile_bound`](Self::compile_bound).
    compiled_generation: Option<WorldGenerationId>,
}
57
/// Pre-computed gather instruction for a single cell.
///
/// At execution time, we read `field_data[field_data_idx]` and write
/// the (transformed) value to `output[tensor_idx]`.
#[derive(Debug, Clone)]
struct GatherOp {
    /// Index into the flat field data array (canonical ordering): the
    /// position of the cell's coordinate in `Space::canonical_ordering()`.
    field_data_idx: usize,
    /// Index into the output slice for this entry, taken from the region
    /// plan's `tensor_indices()`.
    tensor_idx: usize,
}
69
/// A single compiled entry ready for gather execution.
///
/// Produced for every `ObsRegion::Fixed` entry, in both Simple and Standard
/// plans. The entry's output span and mask span start at `output_offset` and
/// `mask_offset` respectively, and each is `element_count` long.
#[derive(Debug)]
struct CompiledEntry {
    /// Field read by this entry.
    field_id: FieldId,
    /// Transform applied to each gathered value.
    transform: ObsTransform,
    /// Output dtype requested by the spec.
    // NOTE(review): copied from the spec but apparently not read anywhere
    // yet, hence `allow(dead_code)`.
    #[allow(dead_code)]
    dtype: ObsDtype,
    /// Offset into the output buffer where this entry starts.
    output_offset: usize,
    /// Offset into the validity mask where this entry starts.
    mask_offset: usize,
    /// Number of elements this entry contributes to the output (the region's
    /// bounding-shape element count).
    element_count: usize,
    /// Pre-computed gather operations, one per coordinate reported by the
    /// compiled region plan.
    gather_ops: Vec<GatherOp>,
    /// Pre-computed validity mask for this entry's bounding box.
    valid_mask: Vec<u8>,
    /// Number of bytes equal to `1` in `valid_mask`.
    valid_count: usize,
    /// Valid ratio for this entry's region, as reported at compile time.
    // Kept for diagnostics; `allow(dead_code)` because nothing reads it yet.
    #[allow(dead_code)]
    valid_ratio: f64,
}
93
/// Relative offset from agent center for template-based gather.
///
/// At compile time, the bounding box of an agent-centered region is
/// decomposed into `TemplateOp`s. At execute time, the agent center
/// is resolved and each op is applied: `field_data[base_rank + stride_offset]`.
///
/// Only ops with `in_disk == true` are kept in
/// `AgentCompiledEntry::active_ops`.
#[derive(Debug, Clone)]
struct TemplateOp {
    /// Offset from center per coordinate axis.
    relative: Coord,
    /// Position in the gather bounding-box tensor (row-major).
    tensor_idx: usize,
    /// Precomputed `sum(relative[i] * strides[i])` for interior fast path.
    /// Zero if no `GridGeometry` is available (fallback path only).
    stride_offset: isize,
    /// Whether this cell is within the disk region (always true for AgentRect).
    /// For AgentDisk, cells outside the graph-distance radius are excluded.
    in_disk: bool,
}
112
/// Compiled agent-relative entry for the Standard plan class.
///
/// Stores template data that is instantiated per-agent at execute time.
/// The bounding box shape comes from the region (e.g., `[2r+1, 2r+1]` for a
/// 2D `AgentDisk`, `[2h0+1, 2h1+1]` for a 2D `AgentRect`), and may be reduced
/// by pooling.
#[derive(Debug)]
struct AgentCompiledEntry {
    /// Field read by this entry.
    field_id: FieldId,
    /// Optional pooling applied to the gathered bounding box (validated to be
    /// 2D at compile time).
    pool: Option<PoolConfig>,
    /// Transform applied to each gathered value.
    transform: ObsTransform,
    /// Output dtype requested by the spec.
    // NOTE(review): copied from the spec but apparently not read anywhere
    // yet, hence `allow(dead_code)`.
    #[allow(dead_code)]
    dtype: ObsDtype,
    /// Offset into the per-agent output buffer.
    output_offset: usize,
    /// Offset into the per-agent mask buffer.
    mask_offset: usize,
    /// Post-pool output elements (written to output). Equals
    /// `pre_pool_element_count` when `pool` is `None`.
    element_count: usize,
    /// Pre-pool bounding-box elements (gather buffer size).
    pre_pool_element_count: usize,
    /// Shape of the pre-pool bounding box (e.g., `[7, 7]`).
    pre_pool_shape: Vec<usize>,
    /// Active template operations (in-disk only) for runtime gather.
    active_ops: Vec<TemplateOp>,
    /// Radius for `is_interior` check: the disk radius for `AgentDisk`, the
    /// largest half-extent for `AgentRect`.
    radius: u32,
}
140
/// Data for the Standard plan class (agent-centered foveation + pooling).
///
/// Fixed and agent-relative entries share one per-agent output layout; each
/// entry's `output_offset`/`mask_offset` records its position in spec order.
#[derive(Debug)]
struct StandardPlanData {
    /// Entries with `ObsRegion::Fixed` (same output for all agents).
    fixed_entries: Vec<CompiledEntry>,
    /// Entries with agent-relative regions (resolved per-agent).
    agent_entries: Vec<AgentCompiledEntry>,
    /// Grid geometry for interior/boundary dispatch (`None` → all slow path).
    geometry: Option<GridGeometry>,
}
151
/// Data for the Simple plan class (all entries fixed, branch-free gather).
#[derive(Debug)]
struct SimplePlanData {
    /// Compiled entries, in spec order.
    entries: Vec<CompiledEntry>,
    /// Sum of `valid_count` over all entries.
    total_valid: usize,
    /// Sum of `element_count` over all entries. `total_valid / total_elements`
    /// is reported as `ObsMetadata::coverage` (0.0 when this is zero).
    total_elements: usize,
}
159
/// Internal plan strategy: Simple (all-fixed) or Standard (agent-centered).
///
/// Chosen by `ObsPlan::compile` from the spec's regions; it determines which
/// `execute*` entry point the plan accepts.
#[derive(Debug)]
enum PlanStrategy {
    /// All entries are `ObsRegion::Fixed`: pre-computed gather indices.
    Simple(SimplePlanData),
    /// At least one entry is agent-relative: template-based gather.
    Standard(StandardPlanData),
}
168
169impl ObsPlan {
170    /// Compile an [`ObsSpec`] against a [`Space`].
171    ///
172    /// Detects whether the spec contains agent-relative regions and
173    /// dispatches to the appropriate plan class:
174    /// - All `Fixed` → **Simple** (pre-computed gather)
175    /// - Any `AgentDisk`/`AgentRect` → **Standard** (template-based)
176    pub fn compile(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
177        if spec.entries.is_empty() {
178            return Err(ObsError::InvalidObsSpec {
179                reason: "ObsSpec has no entries".into(),
180            });
181        }
182
183        // Validate transform parameters.
184        for (i, entry) in spec.entries.iter().enumerate() {
185            if let ObsTransform::Normalize { min, max } = &entry.transform {
186                if !min.is_finite() || !max.is_finite() {
187                    return Err(ObsError::InvalidObsSpec {
188                        reason: format!(
189                            "entry {i}: Normalize min/max must be finite, got min={min}, max={max}"
190                        ),
191                    });
192                }
193                if min > max {
194                    return Err(ObsError::InvalidObsSpec {
195                        reason: format!("entry {i}: Normalize min ({min}) must be <= max ({max})"),
196                    });
197                }
198            }
199        }
200
201        let has_agent = spec.entries.iter().any(|e| {
202            matches!(
203                e.region,
204                ObsRegion::AgentDisk { .. } | ObsRegion::AgentRect { .. }
205            )
206        });
207
208        if has_agent {
209            Self::compile_standard(spec, space)
210        } else {
211            Self::compile_simple(spec, space)
212        }
213    }
214
215    /// Compile a Simple plan (all `Fixed` regions, no agent-relative entries).
216    fn compile_simple(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
217        let canonical = space.canonical_ordering();
218        let coord_to_field_idx: IndexMap<Coord, usize> = canonical
219            .into_iter()
220            .enumerate()
221            .map(|(idx, coord)| (coord, idx))
222            .collect();
223
224        let mut entries = Vec::with_capacity(spec.entries.len());
225        let mut total_valid = 0usize;
226        let mut total_elements = 0usize;
227        let mut output_offset = 0usize;
228        let mut mask_offset = 0usize;
229        let mut entry_shapes = Vec::with_capacity(spec.entries.len());
230
231        for (i, entry) in spec.entries.iter().enumerate() {
232            let fixed_region = match &entry.region {
233                ObsRegion::Fixed(spec) => spec,
234                ObsRegion::AgentDisk { .. } | ObsRegion::AgentRect { .. } => {
235                    return Err(ObsError::InvalidObsSpec {
236                        reason: format!("entry {i}: agent-relative region in Simple plan"),
237                    });
238                }
239            };
240            if entry.pool.is_some() {
241                return Err(ObsError::InvalidObsSpec {
242                    reason: format!(
243                        "entry {i}: pooling requires a Standard plan (use agent-relative region)"
244                    ),
245                });
246            }
247
248            let mut region_plan =
249                space
250                    .compile_region(fixed_region)
251                    .map_err(|e| ObsError::InvalidObsSpec {
252                        reason: format!("entry {i}: region compile failed: {e}"),
253                    })?;
254
255            let ratio = region_plan.valid_ratio();
256            if ratio < COVERAGE_ERROR_THRESHOLD {
257                return Err(ObsError::InvalidComposition {
258                    reason: format!(
259                        "entry {i}: valid_ratio {ratio:.3} < {COVERAGE_ERROR_THRESHOLD}"
260                    ),
261                });
262            }
263            if ratio < COVERAGE_WARN_THRESHOLD {
264                eprintln!(
265                    "murk-obs: warning: entry {i} valid_ratio {ratio:.3} < {COVERAGE_WARN_THRESHOLD}"
266                );
267            }
268
269            let mut gather_ops = Vec::with_capacity(region_plan.coords().len());
270            for (coord_idx, coord) in region_plan.coords().iter().enumerate() {
271                let field_data_idx =
272                    *coord_to_field_idx
273                        .get(coord)
274                        .ok_or_else(|| ObsError::InvalidObsSpec {
275                            reason: format!("entry {i}: coord {coord:?} not in canonical ordering"),
276                        })?;
277                let tensor_idx = region_plan.tensor_indices()[coord_idx];
278                gather_ops.push(GatherOp {
279                    field_data_idx,
280                    tensor_idx,
281                });
282            }
283
284            let element_count = region_plan.bounding_shape().total_elements();
285            let shape = match region_plan.bounding_shape() {
286                murk_space::BoundingShape::Rect(dims) => dims.clone(),
287            };
288            entry_shapes.push(shape);
289
290            let valid_mask = region_plan.take_valid_mask();
291            let valid_count = valid_mask.iter().filter(|&&v| v == 1).count();
292            entries.push(CompiledEntry {
293                field_id: entry.field_id,
294                transform: entry.transform.clone(),
295                dtype: entry.dtype,
296                output_offset,
297                mask_offset,
298                element_count,
299                gather_ops,
300                valid_mask,
301                valid_count,
302                valid_ratio: ratio,
303            });
304            total_valid += valid_count;
305            total_elements += element_count;
306
307            output_offset += element_count;
308            mask_offset += element_count;
309        }
310
311        let plan = ObsPlan {
312            strategy: PlanStrategy::Simple(SimplePlanData {
313                entries,
314                total_valid,
315                total_elements,
316            }),
317            output_len: output_offset,
318            mask_len: mask_offset,
319            compiled_generation: None,
320        };
321
322        Ok(ObsPlanResult {
323            output_len: plan.output_len,
324            mask_len: plan.mask_len,
325            entry_shapes,
326            plan,
327        })
328    }
329
330    /// Compile a Standard plan (has agent-relative entries).
331    ///
332    /// Fixed entries are compiled with pre-computed gather (same for all agents).
333    /// Agent entries are compiled as templates (resolved per-agent at execute time).
334    fn compile_standard(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
335        let canonical = space.canonical_ordering();
336        let coord_to_field_idx: IndexMap<Coord, usize> = canonical
337            .into_iter()
338            .enumerate()
339            .map(|(idx, coord)| (coord, idx))
340            .collect();
341
342        let geometry = GridGeometry::from_space(space);
343        let ndim = space.ndim();
344
345        let mut fixed_entries = Vec::new();
346        let mut agent_entries = Vec::new();
347        let mut output_offset = 0usize;
348        let mut mask_offset = 0usize;
349        let mut entry_shapes = Vec::new();
350
351        for (i, entry) in spec.entries.iter().enumerate() {
352            match &entry.region {
353                ObsRegion::Fixed(region_spec) => {
354                    if entry.pool.is_some() {
355                        return Err(ObsError::InvalidObsSpec {
356                            reason: format!("entry {i}: pooling on Fixed regions not supported"),
357                        });
358                    }
359
360                    let mut region_plan = space.compile_region(region_spec).map_err(|e| {
361                        ObsError::InvalidObsSpec {
362                            reason: format!("entry {i}: region compile failed: {e}"),
363                        }
364                    })?;
365
366                    let ratio = region_plan.valid_ratio();
367                    if ratio < COVERAGE_ERROR_THRESHOLD {
368                        return Err(ObsError::InvalidComposition {
369                            reason: format!(
370                                "entry {i}: valid_ratio {ratio:.3} < {COVERAGE_ERROR_THRESHOLD}"
371                            ),
372                        });
373                    }
374
375                    let mut gather_ops = Vec::with_capacity(region_plan.coords().len());
376                    for (coord_idx, coord) in region_plan.coords().iter().enumerate() {
377                        let field_data_idx = *coord_to_field_idx.get(coord).ok_or_else(|| {
378                            ObsError::InvalidObsSpec {
379                                reason: format!(
380                                    "entry {i}: coord {coord:?} not in canonical ordering"
381                                ),
382                            }
383                        })?;
384                        let tensor_idx = region_plan.tensor_indices()[coord_idx];
385                        gather_ops.push(GatherOp {
386                            field_data_idx,
387                            tensor_idx,
388                        });
389                    }
390
391                    let element_count = region_plan.bounding_shape().total_elements();
392                    let shape = match region_plan.bounding_shape() {
393                        murk_space::BoundingShape::Rect(dims) => dims.clone(),
394                    };
395                    entry_shapes.push(shape);
396
397                    let valid_mask = region_plan.take_valid_mask();
398                    let valid_count = valid_mask.iter().filter(|&&v| v == 1).count();
399                    fixed_entries.push(CompiledEntry {
400                        field_id: entry.field_id,
401                        transform: entry.transform.clone(),
402                        dtype: entry.dtype,
403                        output_offset,
404                        mask_offset,
405                        element_count,
406                        gather_ops,
407                        valid_mask,
408                        valid_count,
409                        valid_ratio: ratio,
410                    });
411
412                    output_offset += element_count;
413                    mask_offset += element_count;
414                }
415
416                ObsRegion::AgentDisk { radius } => {
417                    let half_ext: smallvec::SmallVec<[u32; 4]> =
418                        (0..ndim).map(|_| *radius).collect();
419                    let (ae, shape) = Self::compile_agent_entry(
420                        i,
421                        entry,
422                        &half_ext,
423                        *radius,
424                        &geometry,
425                        Some(*radius),
426                        output_offset,
427                        mask_offset,
428                    )?;
429                    entry_shapes.push(shape);
430                    output_offset += ae.element_count;
431                    mask_offset += ae.element_count;
432                    agent_entries.push(ae);
433                }
434
435                ObsRegion::AgentRect { half_extent } => {
436                    let radius = *half_extent.iter().max().unwrap_or(&0);
437                    let (ae, shape) = Self::compile_agent_entry(
438                        i,
439                        entry,
440                        half_extent,
441                        radius,
442                        &geometry,
443                        None,
444                        output_offset,
445                        mask_offset,
446                    )?;
447                    entry_shapes.push(shape);
448                    output_offset += ae.element_count;
449                    mask_offset += ae.element_count;
450                    agent_entries.push(ae);
451                }
452            }
453        }
454
455        let plan = ObsPlan {
456            strategy: PlanStrategy::Standard(StandardPlanData {
457                fixed_entries,
458                agent_entries,
459                geometry,
460            }),
461            output_len: output_offset,
462            mask_len: mask_offset,
463            compiled_generation: None,
464        };
465
466        Ok(ObsPlanResult {
467            output_len: plan.output_len,
468            mask_len: plan.mask_len,
469            entry_shapes,
470            plan,
471        })
472    }
473
474    /// Compile a single agent-relative entry into a template.
475    ///
476    /// `disk_radius`: if `Some(r)`, template ops outside graph-distance `r`
477    /// are marked `in_disk = false` (for `AgentDisk`). `None` for `AgentRect`.
478    #[allow(clippy::too_many_arguments)]
479    fn compile_agent_entry(
480        entry_idx: usize,
481        entry: &crate::spec::ObsEntry,
482        half_extent: &[u32],
483        radius: u32,
484        geometry: &Option<GridGeometry>,
485        disk_radius: Option<u32>,
486        output_offset: usize,
487        mask_offset: usize,
488    ) -> Result<(AgentCompiledEntry, Vec<usize>), ObsError> {
489        let pre_pool_shape: Vec<usize> =
490            half_extent.iter().map(|&he| 2 * he as usize + 1).collect();
491        let pre_pool_element_count: usize = pre_pool_shape.iter().product();
492
493        let template_ops = generate_template_ops(half_extent, geometry, disk_radius);
494        let active_ops = template_ops
495            .iter()
496            .filter(|op| op.in_disk)
497            .cloned()
498            .collect();
499
500        let (element_count, output_shape) = if let Some(pool) = &entry.pool {
501            if pre_pool_shape.len() != 2 {
502                return Err(ObsError::InvalidObsSpec {
503                    reason: format!(
504                        "entry {entry_idx}: pooling requires 2D region, got {}D",
505                        pre_pool_shape.len()
506                    ),
507                });
508            }
509            let h = pre_pool_shape[0];
510            let w = pre_pool_shape[1];
511            let ks = pool.kernel_size;
512            let stride = pool.stride;
513            if ks == 0 || stride == 0 {
514                return Err(ObsError::InvalidObsSpec {
515                    reason: format!("entry {entry_idx}: pool kernel_size and stride must be > 0"),
516                });
517            }
518            let out_h = if h >= ks { (h - ks) / stride + 1 } else { 0 };
519            let out_w = if w >= ks { (w - ks) / stride + 1 } else { 0 };
520            if out_h == 0 || out_w == 0 {
521                return Err(ObsError::InvalidObsSpec {
522                    reason: format!(
523                        "entry {entry_idx}: pool produces empty output \
524                         (region [{h},{w}], kernel_size {ks}, stride {stride})"
525                    ),
526                });
527            }
528            (out_h * out_w, vec![out_h, out_w])
529        } else {
530            (pre_pool_element_count, pre_pool_shape.clone())
531        };
532
533        Ok((
534            AgentCompiledEntry {
535                field_id: entry.field_id,
536                pool: entry.pool.clone(),
537                transform: entry.transform.clone(),
538                dtype: entry.dtype,
539                output_offset,
540                mask_offset,
541                element_count,
542                pre_pool_element_count,
543                pre_pool_shape,
544                active_ops,
545                radius,
546            },
547            output_shape,
548        ))
549    }
550
551    /// Compile with generation binding for PLAN_INVALIDATED detection.
552    ///
553    /// Same as [`compile`](Self::compile) but records the snapshot's
554    /// `world_generation_id` for later validation in [`ObsPlan::execute`].
555    pub fn compile_bound(
556        spec: &ObsSpec,
557        space: &dyn Space,
558        generation: WorldGenerationId,
559    ) -> Result<ObsPlanResult, ObsError> {
560        let mut result = Self::compile(spec, space)?;
561        result.plan.compiled_generation = Some(generation);
562        Ok(result)
563    }
564
    /// Total number of f32 elements in the output tensor.
    ///
    /// This is the size for one observation. `execute_batch` needs
    /// `batch_size * output_len()` elements, and for Standard plans
    /// `execute_agents` needs `n_agents * output_len()`.
    pub fn output_len(&self) -> usize {
        self.output_len
    }

    /// Total number of bytes in the validity mask.
    ///
    /// Like [`output_len`](Self::output_len), this is per observation
    /// (per snapshot in a batch, per agent for Standard plans).
    pub fn mask_len(&self) -> usize {
        self.mask_len
    }

    /// The generation this plan was compiled against, if bound.
    ///
    /// `Some` only for plans built with [`compile_bound`](Self::compile_bound).
    pub fn compiled_generation(&self) -> Option<WorldGenerationId> {
        self.compiled_generation
    }
579
580    /// Execute the observation plan against a snapshot.
581    ///
582    /// Fills `output` with gathered and transformed field values, and
583    /// `mask` with validity flags (1 = valid, 0 = padding). Both
584    /// buffers must be pre-allocated to [`output_len`](Self::output_len)
585    /// and [`mask_len`](Self::mask_len) respectively.
586    ///
587    /// `engine_tick` is the current engine tick for computing
588    /// [`ObsMetadata::age_ticks`]. Pass `None` in Lockstep mode
589    /// (age is always 0). In RealtimeAsync mode, pass the current
590    /// engine tick so age reflects snapshot staleness.
591    ///
592    /// Returns [`ObsMetadata`] on success.
593    ///
594    /// # Errors
595    ///
596    /// - [`ObsError::PlanInvalidated`] if bound and generation mismatches.
597    /// - [`ObsError::ExecutionFailed`] if a field is missing from the snapshot.
598    pub fn execute(
599        &self,
600        snapshot: &dyn SnapshotAccess,
601        engine_tick: Option<TickId>,
602        output: &mut [f32],
603        mask: &mut [u8],
604    ) -> Result<ObsMetadata, ObsError> {
605        let simple = match &self.strategy {
606            PlanStrategy::Simple(data) => data,
607            PlanStrategy::Standard(_) => {
608                return Err(ObsError::ExecutionFailed {
609                    reason: "Standard plan requires execute_agents(), not execute()".into(),
610                });
611            }
612        };
613
614        if output.len() < self.output_len {
615            return Err(ObsError::ExecutionFailed {
616                reason: format!(
617                    "output buffer too small: {} < {}",
618                    output.len(),
619                    self.output_len
620                ),
621            });
622        }
623        if mask.len() < self.mask_len {
624            return Err(ObsError::ExecutionFailed {
625                reason: format!("mask buffer too small: {} < {}", mask.len(), self.mask_len),
626            });
627        }
628
629        // Generation check (PLAN_INVALIDATED).
630        if let Some(compiled_gen) = self.compiled_generation {
631            let snapshot_gen = snapshot.world_generation_id();
632            if compiled_gen != snapshot_gen {
633                return Err(ObsError::PlanInvalidated {
634                    reason: format!(
635                        "plan compiled for generation {}, snapshot is generation {}",
636                        compiled_gen.0, snapshot_gen.0
637                    ),
638                });
639            }
640        }
641
642        Self::execute_simple_entries(&simple.entries, snapshot, output, mask)?;
643
644        let coverage = if simple.total_elements == 0 {
645            0.0
646        } else {
647            simple.total_valid as f64 / simple.total_elements as f64
648        };
649
650        let age_ticks = match engine_tick {
651            Some(tick) => tick.0.saturating_sub(snapshot.tick_id().0),
652            None => 0,
653        };
654
655        Ok(ObsMetadata {
656            tick_id: snapshot.tick_id(),
657            age_ticks,
658            coverage,
659            world_generation_id: snapshot.world_generation_id(),
660            parameter_version: snapshot.parameter_version(),
661        })
662    }
663
664    /// Execute the plan for a batch of `N` identical environments.
665    ///
666    /// Each snapshot in the batch fills `output_len()` elements in the
667    /// output buffer, starting at `batch_idx * output_len()`. Same for
668    /// masks. This is the primary interface for vectorized RL training.
669    ///
670    /// Returns one [`ObsMetadata`] per snapshot.
671    pub fn execute_batch(
672        &self,
673        snapshots: &[&dyn SnapshotAccess],
674        engine_tick: Option<TickId>,
675        output: &mut [f32],
676        mask: &mut [u8],
677    ) -> Result<Vec<ObsMetadata>, ObsError> {
678        let simple = match &self.strategy {
679            PlanStrategy::Simple(data) => data,
680            PlanStrategy::Standard(_) => {
681                return Err(ObsError::ExecutionFailed {
682                    reason: "Standard plan requires execute_agents(), not execute_batch()".into(),
683                });
684            }
685        };
686
687        let batch_size = snapshots.len();
688        let expected_out = batch_size * self.output_len;
689        let expected_mask = batch_size * self.mask_len;
690
691        if output.len() < expected_out {
692            return Err(ObsError::ExecutionFailed {
693                reason: format!(
694                    "batch output buffer too small: {} < {}",
695                    output.len(),
696                    expected_out
697                ),
698            });
699        }
700        if mask.len() < expected_mask {
701            return Err(ObsError::ExecutionFailed {
702                reason: format!(
703                    "batch mask buffer too small: {} < {}",
704                    mask.len(),
705                    expected_mask
706                ),
707            });
708        }
709
710        let coverage = if simple.total_elements == 0 {
711            0.0
712        } else {
713            simple.total_valid as f64 / simple.total_elements as f64
714        };
715
716        let mut metadata = Vec::with_capacity(batch_size);
717        for (i, snap) in snapshots.iter().enumerate() {
718            if let Some(compiled_gen) = self.compiled_generation {
719                let snapshot_gen = snap.world_generation_id();
720                if compiled_gen != snapshot_gen {
721                    return Err(ObsError::PlanInvalidated {
722                        reason: format!(
723                            "plan compiled for generation {}, snapshot is generation {}",
724                            compiled_gen.0, snapshot_gen.0
725                        ),
726                    });
727                }
728            }
729
730            let out_start = i * self.output_len;
731            let mask_start = i * self.mask_len;
732            let out_slice = &mut output[out_start..out_start + self.output_len];
733            let mask_slice = &mut mask[mask_start..mask_start + self.mask_len];
734            Self::execute_simple_entries(&simple.entries, *snap, out_slice, mask_slice)?;
735
736            let age_ticks = match engine_tick {
737                Some(tick) => tick.0.saturating_sub(snap.tick_id().0),
738                None => 0,
739            };
740            metadata.push(ObsMetadata {
741                tick_id: snap.tick_id(),
742                age_ticks,
743                coverage,
744                world_generation_id: snap.world_generation_id(),
745                parameter_version: snap.parameter_version(),
746            });
747        }
748        Ok(metadata)
749    }
750
751    /// Execute the Standard plan for `N` agents in one environment.
752    ///
753    /// Each agent gets `output_len()` elements starting at
754    /// `agent_idx * output_len()`. Fixed entries produce the same
755    /// output for all agents; agent-relative entries are resolved
756    /// per-agent using interior/boundary dispatch.
757    ///
758    /// Interior agents (~49% for 20×20 grid, radius 3) use a branchless
759    /// fast path with stride arithmetic. Boundary agents fall back to
760    /// per-cell bounds checking.
761    pub fn execute_agents(
762        &self,
763        snapshot: &dyn SnapshotAccess,
764        space: &dyn Space,
765        agent_centers: &[Coord],
766        engine_tick: Option<TickId>,
767        output: &mut [f32],
768        mask: &mut [u8],
769    ) -> Result<Vec<ObsMetadata>, ObsError> {
770        let standard = match &self.strategy {
771            PlanStrategy::Standard(data) => data,
772            PlanStrategy::Simple(_) => {
773                return Err(ObsError::ExecutionFailed {
774                    reason: "execute_agents requires a Standard plan \
775                             (spec must contain agent-relative entries)"
776                        .into(),
777                });
778            }
779        };
780
781        let n_agents = agent_centers.len();
782        let expected_out = n_agents * self.output_len;
783        let expected_mask = n_agents * self.mask_len;
784
785        if output.len() < expected_out {
786            return Err(ObsError::ExecutionFailed {
787                reason: format!(
788                    "output buffer too small: {} < {}",
789                    output.len(),
790                    expected_out
791                ),
792            });
793        }
794        if mask.len() < expected_mask {
795            return Err(ObsError::ExecutionFailed {
796                reason: format!("mask buffer too small: {} < {}", mask.len(), expected_mask),
797            });
798        }
799
800        // Validate agent center dimensionality.
801        let expected_ndim = space.ndim();
802        for (i, center) in agent_centers.iter().enumerate() {
803            if center.len() != expected_ndim {
804                return Err(ObsError::ExecutionFailed {
805                    reason: format!(
806                        "agent_centers[{i}] has {} dimensions, but space requires {expected_ndim}",
807                        center.len()
808                    ),
809                });
810            }
811        }
812
813        // Generation check.
814        if let Some(compiled_gen) = self.compiled_generation {
815            let snapshot_gen = snapshot.world_generation_id();
816            if compiled_gen != snapshot_gen {
817                return Err(ObsError::PlanInvalidated {
818                    reason: format!(
819                        "plan compiled for generation {}, snapshot is generation {}",
820                        compiled_gen.0, snapshot_gen.0
821                    ),
822                });
823            }
824        }
825
826        // Pre-read all field data (shared borrows, valid for duration).
827        // Keep vectors aligned with compiled entry vectors to avoid
828        // map lookups on the hot path.
829        let mut fixed_field_data = Vec::with_capacity(standard.fixed_entries.len());
830        for entry in &standard.fixed_entries {
831            let data =
832                snapshot
833                    .read_field(entry.field_id)
834                    .ok_or_else(|| ObsError::ExecutionFailed {
835                        reason: format!("field {:?} not in snapshot", entry.field_id),
836                    })?;
837            fixed_field_data.push(data);
838        }
839        let mut agent_field_data = Vec::with_capacity(standard.agent_entries.len());
840        for entry in &standard.agent_entries {
841            let data =
842                snapshot
843                    .read_field(entry.field_id)
844                    .ok_or_else(|| ObsError::ExecutionFailed {
845                        reason: format!("field {:?} not in snapshot", entry.field_id),
846                    })?;
847            agent_field_data.push(data);
848        }
849
850        // ── Compute fixed entries ONCE (identical for all agents) ──
851        // Allocate scratch buffers for the fixed-entry output and mask, gather
852        // into them, then memcpy per agent. This avoids redundant N*M gathers.
853        let has_fixed = !standard.fixed_entries.is_empty();
854        let mut fixed_out_scratch = if has_fixed {
855            vec![0.0f32; self.output_len]
856        } else {
857            Vec::new()
858        };
859        let mut fixed_mask_scratch = if has_fixed {
860            vec![0u8; self.mask_len]
861        } else {
862            Vec::new()
863        };
864        let mut fixed_valid = 0usize;
865        let mut fixed_elements = 0usize;
866
867        for (entry, field_data) in standard
868            .fixed_entries
869            .iter()
870            .zip(fixed_field_data.iter().copied())
871        {
872            let out_slice = &mut fixed_out_scratch
873                [entry.output_offset..entry.output_offset + entry.element_count];
874            let mask_slice =
875                &mut fixed_mask_scratch[entry.mask_offset..entry.mask_offset + entry.element_count];
876
877            mask_slice.copy_from_slice(&entry.valid_mask);
878            for op in &entry.gather_ops {
879                let raw = *field_data.get(op.field_data_idx).ok_or_else(|| {
880                    ObsError::ExecutionFailed {
881                        reason: format!(
882                            "field {:?} has {} elements but gather requires index {}",
883                            entry.field_id,
884                            field_data.len(),
885                            op.field_data_idx,
886                        ),
887                    }
888                })?;
889                out_slice[op.tensor_idx] = apply_transform(raw, &entry.transform);
890            }
891
892            fixed_valid += entry.valid_count;
893            fixed_elements += entry.element_count;
894        }
895
896        // ── Pre-allocate pooling scratch buffers ──────────────────
897        // Find the maximum pre_pool_element_count across all pooled entries
898        // so we can allocate once and reuse across agents (sequential processing).
899        let max_pool_scratch = standard
900            .agent_entries
901            .iter()
902            .filter(|e| e.pool.is_some())
903            .map(|e| e.pre_pool_element_count)
904            .max()
905            .unwrap_or(0);
906        let max_pool_output = standard
907            .agent_entries
908            .iter()
909            .filter(|e| e.pool.is_some())
910            .map(|e| e.element_count)
911            .max()
912            .unwrap_or(0);
913        let mut pool_scratch = vec![0.0f32; max_pool_scratch];
914        let mut pool_scratch_mask = vec![0u8; max_pool_scratch];
915        let mut pooled_scratch = vec![0.0f32; max_pool_output];
916        let mut pooled_scratch_mask = vec![0u8; max_pool_output];
917
918        let mut metadata = Vec::with_capacity(n_agents);
919
920        for (agent_i, center) in agent_centers.iter().enumerate() {
921            let out_start = agent_i * self.output_len;
922            let mask_start = agent_i * self.mask_len;
923            let agent_output = &mut output[out_start..out_start + self.output_len];
924            let agent_mask = &mut mask[mask_start..mask_start + self.mask_len];
925
926            // Initialize per-agent slices, then stamp fixed entries from
927            // pre-computed scratch (no re-gather).
928            agent_output.fill(0.0);
929            agent_mask.fill(0);
930            if has_fixed {
931                for entry in &standard.fixed_entries {
932                    let out_range = entry.output_offset..entry.output_offset + entry.element_count;
933                    let mask_range = entry.mask_offset..entry.mask_offset + entry.element_count;
934                    agent_output[out_range.clone()].copy_from_slice(&fixed_out_scratch[out_range]);
935                    agent_mask[mask_range.clone()].copy_from_slice(&fixed_mask_scratch[mask_range]);
936                }
937            }
938
939            let mut total_valid = fixed_valid;
940            let mut total_elements = fixed_elements;
941
942            // ── Agent-relative entries ───────────────────────────
943            for (entry, field_data) in standard
944                .agent_entries
945                .iter()
946                .zip(agent_field_data.iter().copied())
947            {
948                // Fast path: stride arithmetic works only for non-wrapping
949                // grids where all cells in the bounding box are in-bounds.
950                // Torus (all_wrap) requires modular arithmetic → slow path.
951                let use_fast_path = standard
952                    .geometry
953                    .as_ref()
954                    .map(|geo| !geo.all_wrap && geo.is_interior(center, entry.radius))
955                    .unwrap_or(false);
956
957                // Zero the pooling scratch region for this entry before reuse.
958                if entry.pool.is_some() {
959                    pool_scratch[..entry.pre_pool_element_count].fill(0.0);
960                    pool_scratch_mask[..entry.pre_pool_element_count].fill(0);
961                }
962
963                let valid = execute_agent_entry(
964                    entry,
965                    center,
966                    field_data,
967                    &standard.geometry,
968                    space,
969                    use_fast_path,
970                    agent_output,
971                    agent_mask,
972                    &mut pool_scratch,
973                    &mut pool_scratch_mask,
974                    &mut pooled_scratch,
975                    &mut pooled_scratch_mask,
976                );
977
978                total_valid += valid;
979                total_elements += entry.element_count;
980            }
981
982            let coverage = if total_elements == 0 {
983                0.0
984            } else {
985                total_valid as f64 / total_elements as f64
986            };
987
988            let age_ticks = match engine_tick {
989                Some(tick) => tick.0.saturating_sub(snapshot.tick_id().0),
990                None => 0,
991            };
992
993            metadata.push(ObsMetadata {
994                tick_id: snapshot.tick_id(),
995                age_ticks,
996                coverage,
997                world_generation_id: snapshot.world_generation_id(),
998                parameter_version: snapshot.parameter_version(),
999            });
1000        }
1001
1002        Ok(metadata)
1003    }
1004
1005    /// Whether this plan requires `execute_agents` (Standard) or `execute` (Simple).
1006    pub fn is_standard(&self) -> bool {
1007        matches!(self.strategy, PlanStrategy::Standard(_))
1008    }
1009
1010    /// Execute pre-compiled Simple plan entries into caller-provided buffers.
1011    fn execute_simple_entries(
1012        entries: &[CompiledEntry],
1013        snapshot: &dyn SnapshotAccess,
1014        output: &mut [f32],
1015        mask: &mut [u8],
1016    ) -> Result<(), ObsError> {
1017        for entry in entries {
1018            let field_data =
1019                snapshot
1020                    .read_field(entry.field_id)
1021                    .ok_or_else(|| ObsError::ExecutionFailed {
1022                        reason: format!("field {:?} not in snapshot", entry.field_id),
1023                    })?;
1024
1025            let out_slice =
1026                &mut output[entry.output_offset..entry.output_offset + entry.element_count];
1027            let mask_slice = &mut mask[entry.mask_offset..entry.mask_offset + entry.element_count];
1028
1029            // Initialize to zero/padding.
1030            out_slice.fill(0.0);
1031            mask_slice.copy_from_slice(&entry.valid_mask);
1032
1033            // Branch-free gather: pre-computed (field_data_idx, tensor_idx) pairs.
1034            for op in &entry.gather_ops {
1035                let raw = *field_data.get(op.field_data_idx).ok_or_else(|| {
1036                    ObsError::ExecutionFailed {
1037                        reason: format!(
1038                            "field {:?} has {} elements but gather requires index {}",
1039                            entry.field_id,
1040                            field_data.len(),
1041                            op.field_data_idx,
1042                        ),
1043                    }
1044                })?;
1045                out_slice[op.tensor_idx] = apply_transform(raw, &entry.transform);
1046            }
1047        }
1048        Ok(())
1049    }
1050}
1051
1052/// Execute a single agent-relative entry for one agent.
1053///
1054/// For pooled entries, `pool_scratch` and `pool_scratch_mask` must be
1055/// provided with sufficient capacity (zeroed by the caller). For
1056/// non-pooled entries these are ignored.
1057///
1058/// Returns the number of valid cells written.
1059#[allow(clippy::too_many_arguments)]
1060fn execute_agent_entry(
1061    entry: &AgentCompiledEntry,
1062    center: &Coord,
1063    field_data: &[f32],
1064    geometry: &Option<GridGeometry>,
1065    space: &dyn Space,
1066    use_fast_path: bool,
1067    agent_output: &mut [f32],
1068    agent_mask: &mut [u8],
1069    pool_scratch: &mut [f32],
1070    pool_scratch_mask: &mut [u8],
1071    pooled_scratch: &mut [f32],
1072    pooled_scratch_mask: &mut [u8],
1073) -> usize {
1074    if entry.pool.is_some() {
1075        execute_agent_entry_pooled(
1076            entry,
1077            center,
1078            field_data,
1079            geometry,
1080            space,
1081            use_fast_path,
1082            agent_output,
1083            agent_mask,
1084            &mut pool_scratch[..entry.pre_pool_element_count],
1085            &mut pool_scratch_mask[..entry.pre_pool_element_count],
1086            &mut pooled_scratch[..entry.element_count],
1087            &mut pooled_scratch_mask[..entry.element_count],
1088        )
1089    } else {
1090        execute_agent_entry_direct(
1091            entry,
1092            center,
1093            field_data,
1094            geometry,
1095            space,
1096            use_fast_path,
1097            agent_output,
1098            agent_mask,
1099        )
1100    }
1101}
1102
1103/// Direct gather (no pooling): gather + transform → output.
1104#[allow(clippy::too_many_arguments)]
1105fn execute_agent_entry_direct(
1106    entry: &AgentCompiledEntry,
1107    center: &Coord,
1108    field_data: &[f32],
1109    geometry: &Option<GridGeometry>,
1110    space: &dyn Space,
1111    use_fast_path: bool,
1112    agent_output: &mut [f32],
1113    agent_mask: &mut [u8],
1114) -> usize {
1115    let out_slice =
1116        &mut agent_output[entry.output_offset..entry.output_offset + entry.element_count];
1117    let mask_slice = &mut agent_mask[entry.mask_offset..entry.mask_offset + entry.element_count];
1118
1119    if use_fast_path {
1120        // FAST PATH: all cells in-bounds, branchless stride arithmetic.
1121        let geo = geometry.as_ref().unwrap();
1122        let base_rank = geo.canonical_rank(center) as isize;
1123        let mut valid = 0;
1124        for op in &entry.active_ops {
1125            let field_idx = (base_rank + op.stride_offset) as usize;
1126            if let Some(&val) = field_data.get(field_idx) {
1127                out_slice[op.tensor_idx] = apply_transform(val, &entry.transform);
1128                mask_slice[op.tensor_idx] = 1;
1129                valid += 1;
1130            }
1131        }
1132        valid
1133    } else {
1134        // SLOW PATH: bounds-check each offset (or modular wrap for torus).
1135        let mut valid = 0;
1136        for op in &entry.active_ops {
1137            let field_idx = resolve_field_index(center, &op.relative, geometry, space);
1138            if let Some(idx) = field_idx {
1139                if idx < field_data.len() {
1140                    out_slice[op.tensor_idx] = apply_transform(field_data[idx], &entry.transform);
1141                    mask_slice[op.tensor_idx] = 1;
1142                    valid += 1;
1143                }
1144            }
1145        }
1146        valid
1147    }
1148}
1149
1150/// Pooled gather: gather → scratch → pool → transform → output.
1151///
1152/// `scratch` and `scratch_mask` are caller-provided buffers that must be
1153/// at least `entry.pre_pool_element_count` long. They are zeroed by the
1154/// caller before each invocation. This avoids per-agent heap allocation
1155/// when processing many agents sequentially (bug #83).
1156#[allow(clippy::too_many_arguments)]
1157fn execute_agent_entry_pooled(
1158    entry: &AgentCompiledEntry,
1159    center: &Coord,
1160    field_data: &[f32],
1161    geometry: &Option<GridGeometry>,
1162    space: &dyn Space,
1163    use_fast_path: bool,
1164    agent_output: &mut [f32],
1165    agent_mask: &mut [u8],
1166    scratch: &mut [f32],
1167    scratch_mask: &mut [u8],
1168    pooled: &mut [f32],
1169    pooled_mask: &mut [u8],
1170) -> usize {
1171    if use_fast_path {
1172        let geo = geometry.as_ref().unwrap();
1173        let base_rank = geo.canonical_rank(center) as isize;
1174        for op in &entry.active_ops {
1175            let field_idx = (base_rank + op.stride_offset) as usize;
1176            if let Some(&val) = field_data.get(field_idx) {
1177                scratch[op.tensor_idx] = val;
1178                scratch_mask[op.tensor_idx] = 1;
1179            }
1180        }
1181    } else {
1182        for op in &entry.active_ops {
1183            let field_idx = resolve_field_index(center, &op.relative, geometry, space);
1184            if let Some(idx) = field_idx {
1185                if idx < field_data.len() {
1186                    scratch[op.tensor_idx] = field_data[idx];
1187                    scratch_mask[op.tensor_idx] = 1;
1188                }
1189            }
1190        }
1191    }
1192
1193    let pool_config = entry.pool.as_ref().unwrap();
1194    let (out_h, out_w) = pool_2d_into(
1195        scratch,
1196        scratch_mask,
1197        &entry.pre_pool_shape,
1198        pool_config,
1199        pooled,
1200        pooled_mask,
1201    );
1202
1203    let out_slice =
1204        &mut agent_output[entry.output_offset..entry.output_offset + entry.element_count];
1205    let mask_slice = &mut agent_mask[entry.mask_offset..entry.mask_offset + entry.element_count];
1206
1207    let n = (out_h * out_w).min(entry.element_count);
1208    for i in 0..n {
1209        out_slice[i] = apply_transform(pooled[i], &entry.transform);
1210    }
1211    mask_slice[..n].copy_from_slice(&pooled_mask[..n]);
1212
1213    pooled_mask[..n].iter().filter(|&&v| v == 1).count()
1214}
1215
1216/// Generate template operations for a rectangular bounding box.
1217///
1218/// `half_extent[d]` is the half-size per dimension. The bounding box is
1219/// `(2*he[0]+1) × (2*he[1]+1) × ...` in row-major order.
1220///
1221/// If `strides` is provided (from `GridGeometry`), each op gets a precomputed
1222/// `stride_offset` for the interior fast path.
1223///
1224/// If `disk_radius` is `Some(r)`, cells with graph distance > `r` are marked
1225/// `in_disk = false`. The `geometry` is required to compute graph distance.
1226/// When `geometry` is `None`, all cells are treated as in-disk (conservative).
1227fn generate_template_ops(
1228    half_extent: &[u32],
1229    geometry: &Option<GridGeometry>,
1230    disk_radius: Option<u32>,
1231) -> Vec<TemplateOp> {
1232    let ndim = half_extent.len();
1233    let shape: Vec<usize> = half_extent.iter().map(|&he| 2 * he as usize + 1).collect();
1234    let total: usize = shape.iter().product();
1235
1236    let strides = geometry.as_ref().map(|g| g.coord_strides.as_slice());
1237
1238    let mut ops = Vec::with_capacity(total);
1239
1240    for tensor_idx in 0..total {
1241        // Decompose tensor_idx into n-d relative coords (row-major).
1242        let mut relative = Coord::new();
1243        let mut remaining = tensor_idx;
1244        // Build in reverse order, then reverse.
1245        for d in (0..ndim).rev() {
1246            let coord = (remaining % shape[d]) as i32 - half_extent[d] as i32;
1247            relative.push(coord);
1248            remaining /= shape[d];
1249        }
1250        relative.reverse();
1251
1252        let stride_offset = strides
1253            .map(|s| {
1254                relative
1255                    .iter()
1256                    .zip(s.iter())
1257                    .map(|(&r, &s)| r as isize * s as isize)
1258                    .sum::<isize>()
1259            })
1260            .unwrap_or(0);
1261
1262        let in_disk = match disk_radius {
1263            Some(r) => match geometry {
1264                Some(geo) => geo.graph_distance(&relative) <= r,
1265                None => true, // no geometry → conservative (include all)
1266            },
1267            None => true, // AgentRect → all cells valid
1268        };
1269
1270        ops.push(TemplateOp {
1271            relative,
1272            tensor_idx,
1273            stride_offset,
1274            in_disk,
1275        });
1276    }
1277
1278    ops
1279}
1280
1281/// Resolve the field data index for an absolute coordinate.
1282///
1283/// Handles three cases:
1284/// 1. Torus (all_wrap): modular wrap, always in-bounds.
1285/// 2. Grid with geometry: bounds-check then stride arithmetic.
1286/// 3. No geometry: fall back to `space.canonical_rank()`.
1287fn resolve_field_index(
1288    center: &Coord,
1289    relative: &Coord,
1290    geometry: &Option<GridGeometry>,
1291    space: &dyn Space,
1292) -> Option<usize> {
1293    if let Some(geo) = geometry {
1294        if geo.all_wrap {
1295            // Torus: wrap coordinates with modular arithmetic.
1296            let wrapped: Coord = center
1297                .iter()
1298                .zip(relative.iter())
1299                .zip(geo.coord_dims.iter())
1300                .map(|((&c, &r), &d)| {
1301                    let d = d as i32;
1302                    ((c + r) % d + d) % d
1303                })
1304                .collect();
1305            Some(geo.canonical_rank(&wrapped))
1306        } else {
1307            let abs_coord: Coord = center
1308                .iter()
1309                .zip(relative.iter())
1310                .map(|(&c, &r)| c + r)
1311                .collect();
1312            let abs_slice: &[i32] = &abs_coord;
1313            if geo.in_bounds(abs_slice) {
1314                Some(geo.canonical_rank(abs_slice))
1315            } else {
1316                None
1317            }
1318        }
1319    } else {
1320        let abs_coord: Coord = center
1321            .iter()
1322            .zip(relative.iter())
1323            .map(|(&c, &r)| c + r)
1324            .collect();
1325        space.canonical_rank(&abs_coord)
1326    }
1327}
1328
1329/// Apply a transform to a raw field value.
1330fn apply_transform(raw: f32, transform: &ObsTransform) -> f32 {
1331    match transform {
1332        ObsTransform::Identity => raw,
1333        ObsTransform::Normalize { min, max } => {
1334            let range = max - min;
1335            if range == 0.0 {
1336                0.0
1337            } else {
1338                let normalized = (raw as f64 - min) / range;
1339                normalized.clamp(0.0, 1.0) as f32
1340            }
1341        }
1342    }
1343}
1344
1345#[cfg(test)]
1346mod tests {
1347    use super::*;
1348    use crate::spec::{
1349        ObsDtype, ObsEntry, ObsRegion, ObsSpec, ObsTransform, PoolConfig, PoolKernel,
1350    };
1351    use murk_core::{FieldId, ParameterVersion, TickId, WorldGenerationId};
1352    use murk_space::{EdgeBehavior, Hex2D, RegionSpec, Square4, Square8};
1353    use murk_test_utils::MockSnapshot;
1354
1355    fn square4_space() -> Square4 {
1356        Square4::new(3, 3, EdgeBehavior::Absorb).unwrap()
1357    }
1358
1359    fn snapshot_with_field(field: FieldId, data: Vec<f32>) -> MockSnapshot {
1360        let mut snap = MockSnapshot::new(TickId(5), WorldGenerationId(1), ParameterVersion(0));
1361        snap.set_field(field, data);
1362        snap
1363    }
1364
1365    // ── Compilation tests ────────────────────────────────────
1366
1367    #[test]
1368    fn compile_empty_spec_errors() {
1369        let space = square4_space();
1370        let spec = ObsSpec { entries: vec![] };
1371        let err = ObsPlan::compile(&spec, &space).unwrap_err();
1372        assert!(matches!(err, ObsError::InvalidObsSpec { .. }));
1373    }
1374
1375    #[test]
1376    fn compile_all_region_square4() {
1377        let space = square4_space();
1378        let spec = ObsSpec {
1379            entries: vec![ObsEntry {
1380                field_id: FieldId(0),
1381                region: ObsRegion::Fixed(RegionSpec::All),
1382                pool: None,
1383                transform: ObsTransform::Identity,
1384                dtype: ObsDtype::F32,
1385            }],
1386        };
1387        let result = ObsPlan::compile(&spec, &space).unwrap();
1388        assert_eq!(result.output_len, 9); // 3x3
1389        assert_eq!(result.mask_len, 9);
1390        assert_eq!(result.entry_shapes, vec![vec![3, 3]]);
1391    }
1392
1393    #[test]
1394    fn compile_rect_region() {
1395        let space = Square4::new(5, 5, EdgeBehavior::Absorb).unwrap();
1396        let spec = ObsSpec {
1397            entries: vec![ObsEntry {
1398                field_id: FieldId(0),
1399                region: ObsRegion::Fixed(RegionSpec::Rect {
1400                    min: smallvec::smallvec![1, 1],
1401                    max: smallvec::smallvec![2, 3],
1402                }),
1403                pool: None,
1404                transform: ObsTransform::Identity,
1405                dtype: ObsDtype::F32,
1406            }],
1407        };
1408        let result = ObsPlan::compile(&spec, &space).unwrap();
1409        // 2 rows x 3 cols = 6 cells
1410        assert_eq!(result.output_len, 6);
1411        assert_eq!(result.entry_shapes, vec![vec![2, 3]]);
1412    }
1413
1414    #[test]
1415    fn compile_two_entries_offsets() {
1416        let space = square4_space();
1417        let spec = ObsSpec {
1418            entries: vec![
1419                ObsEntry {
1420                    field_id: FieldId(0),
1421                    region: ObsRegion::Fixed(RegionSpec::All),
1422                    pool: None,
1423                    transform: ObsTransform::Identity,
1424                    dtype: ObsDtype::F32,
1425                },
1426                ObsEntry {
1427                    field_id: FieldId(1),
1428                    region: ObsRegion::Fixed(RegionSpec::All),
1429                    pool: None,
1430                    transform: ObsTransform::Identity,
1431                    dtype: ObsDtype::F32,
1432                },
1433            ],
1434        };
1435        let result = ObsPlan::compile(&spec, &space).unwrap();
1436        assert_eq!(result.output_len, 18); // 9 + 9
1437        assert_eq!(result.mask_len, 18);
1438    }
1439
1440    #[test]
1441    fn compile_invalid_region_errors() {
1442        let space = square4_space();
1443        let spec = ObsSpec {
1444            entries: vec![ObsEntry {
1445                field_id: FieldId(0),
1446                region: ObsRegion::Fixed(RegionSpec::Coords(vec![smallvec::smallvec![99, 99]])),
1447                pool: None,
1448                transform: ObsTransform::Identity,
1449                dtype: ObsDtype::F32,
1450            }],
1451        };
1452        let err = ObsPlan::compile(&spec, &space).unwrap_err();
1453        assert!(matches!(err, ObsError::InvalidObsSpec { .. }));
1454    }
1455
1456    // ── Execution tests ──────────────────────────────────────
1457
1458    #[test]
1459    fn execute_identity_all_region() {
1460        let space = square4_space();
1461        // Field data in canonical (row-major) order for 3x3:
1462        // (0,0)=1, (0,1)=2, (0,2)=3, (1,0)=4, ..., (2,2)=9
1463        let data: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1464        let snap = snapshot_with_field(FieldId(0), data);
1465
1466        let spec = ObsSpec {
1467            entries: vec![ObsEntry {
1468                field_id: FieldId(0),
1469                region: ObsRegion::Fixed(RegionSpec::All),
1470                pool: None,
1471                transform: ObsTransform::Identity,
1472                dtype: ObsDtype::F32,
1473            }],
1474        };
1475        let result = ObsPlan::compile(&spec, &space).unwrap();
1476
1477        let mut output = vec![0.0f32; result.output_len];
1478        let mut mask = vec![0u8; result.mask_len];
1479        let meta = result
1480            .plan
1481            .execute(&snap, None, &mut output, &mut mask)
1482            .unwrap();
1483
1484        // Output should match field data in canonical order.
1485        let expected: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1486        assert_eq!(output, expected);
1487        assert_eq!(mask, vec![1u8; 9]);
1488        assert_eq!(meta.tick_id, TickId(5));
1489        assert_eq!(meta.coverage, 1.0);
1490        assert_eq!(meta.world_generation_id, WorldGenerationId(1));
1491        assert_eq!(meta.parameter_version, ParameterVersion(0));
1492        assert_eq!(meta.age_ticks, 0);
1493    }
1494
1495    #[test]
1496    fn execute_normalize_transform() {
1497        let space = square4_space();
1498        // Values 0..8 mapped to [0,1] with min=0, max=8.
1499        let data: Vec<f32> = (0..9).map(|x| x as f32).collect();
1500        let snap = snapshot_with_field(FieldId(0), data);
1501
1502        let spec = ObsSpec {
1503            entries: vec![ObsEntry {
1504                field_id: FieldId(0),
1505                region: ObsRegion::Fixed(RegionSpec::All),
1506                pool: None,
1507                transform: ObsTransform::Normalize { min: 0.0, max: 8.0 },
1508                dtype: ObsDtype::F32,
1509            }],
1510        };
1511        let result = ObsPlan::compile(&spec, &space).unwrap();
1512
1513        let mut output = vec![0.0f32; result.output_len];
1514        let mut mask = vec![0u8; result.mask_len];
1515        result
1516            .plan
1517            .execute(&snap, None, &mut output, &mut mask)
1518            .unwrap();
1519
1520        // Each value x should be x/8.
1521        for (i, &v) in output.iter().enumerate() {
1522            let expected = i as f32 / 8.0;
1523            assert!(
1524                (v - expected).abs() < 1e-6,
1525                "output[{i}] = {v}, expected {expected}"
1526            );
1527        }
1528    }
1529
1530    #[test]
1531    fn execute_normalize_clamps_out_of_range() {
1532        let space = square4_space();
1533        // Values -5, 0, 5, 10, 15 etc.
1534        let data: Vec<f32> = (-4..5).map(|x| x as f32 * 5.0).collect();
1535        let snap = snapshot_with_field(FieldId(0), data);
1536
1537        let spec = ObsSpec {
1538            entries: vec![ObsEntry {
1539                field_id: FieldId(0),
1540                region: ObsRegion::Fixed(RegionSpec::All),
1541                pool: None,
1542                transform: ObsTransform::Normalize {
1543                    min: 0.0,
1544                    max: 10.0,
1545                },
1546                dtype: ObsDtype::F32,
1547            }],
1548        };
1549        let result = ObsPlan::compile(&spec, &space).unwrap();
1550
1551        let mut output = vec![0.0f32; result.output_len];
1552        let mut mask = vec![0u8; result.mask_len];
1553        result
1554            .plan
1555            .execute(&snap, None, &mut output, &mut mask)
1556            .unwrap();
1557
1558        for &v in &output {
1559            assert!((0.0..=1.0).contains(&v), "value {v} out of [0,1] range");
1560        }
1561    }
1562
1563    #[test]
1564    fn execute_normalize_zero_range() {
1565        let space = square4_space();
1566        let data = vec![5.0f32; 9];
1567        let snap = snapshot_with_field(FieldId(0), data);
1568
1569        let spec = ObsSpec {
1570            entries: vec![ObsEntry {
1571                field_id: FieldId(0),
1572                region: ObsRegion::Fixed(RegionSpec::All),
1573                pool: None,
1574                transform: ObsTransform::Normalize { min: 5.0, max: 5.0 },
1575                dtype: ObsDtype::F32,
1576            }],
1577        };
1578        let result = ObsPlan::compile(&spec, &space).unwrap();
1579
1580        let mut output = vec![-1.0f32; result.output_len];
1581        let mut mask = vec![0u8; result.mask_len];
1582        result
1583            .plan
1584            .execute(&snap, None, &mut output, &mut mask)
1585            .unwrap();
1586
1587        // Zero range → all outputs 0.0.
1588        assert!(output.iter().all(|&v| v == 0.0));
1589    }
1590
1591    #[test]
1592    fn execute_rect_subregion_correct_values() {
1593        let space = Square4::new(4, 4, EdgeBehavior::Absorb).unwrap();
1594        // 4x4 field: value = row * 4 + col + 1.
1595        let data: Vec<f32> = (1..=16).map(|x| x as f32).collect();
1596        let snap = snapshot_with_field(FieldId(0), data);
1597
1598        let spec = ObsSpec {
1599            entries: vec![ObsEntry {
1600                field_id: FieldId(0),
1601                region: ObsRegion::Fixed(RegionSpec::Rect {
1602                    min: smallvec::smallvec![1, 1],
1603                    max: smallvec::smallvec![2, 2],
1604                }),
1605                pool: None,
1606                transform: ObsTransform::Identity,
1607                dtype: ObsDtype::F32,
1608            }],
1609        };
1610        let result = ObsPlan::compile(&spec, &space).unwrap();
1611        assert_eq!(result.output_len, 4); // 2x2
1612
1613        let mut output = vec![0.0f32; result.output_len];
1614        let mut mask = vec![0u8; result.mask_len];
1615        result
1616            .plan
1617            .execute(&snap, None, &mut output, &mut mask)
1618            .unwrap();
1619
1620        // Rect covers (1,1)=6, (1,2)=7, (2,1)=10, (2,2)=11
1621        assert_eq!(output, vec![6.0, 7.0, 10.0, 11.0]);
1622        assert_eq!(mask, vec![1, 1, 1, 1]);
1623    }
1624
1625    #[test]
1626    fn execute_two_fields() {
1627        let space = square4_space();
1628        let data_a: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1629        let data_b: Vec<f32> = (10..=18).map(|x| x as f32).collect();
1630        let mut snap = MockSnapshot::new(TickId(1), WorldGenerationId(1), ParameterVersion(0));
1631        snap.set_field(FieldId(0), data_a);
1632        snap.set_field(FieldId(1), data_b);
1633
1634        let spec = ObsSpec {
1635            entries: vec![
1636                ObsEntry {
1637                    field_id: FieldId(0),
1638                    region: ObsRegion::Fixed(RegionSpec::All),
1639                    pool: None,
1640                    transform: ObsTransform::Identity,
1641                    dtype: ObsDtype::F32,
1642                },
1643                ObsEntry {
1644                    field_id: FieldId(1),
1645                    region: ObsRegion::Fixed(RegionSpec::All),
1646                    pool: None,
1647                    transform: ObsTransform::Identity,
1648                    dtype: ObsDtype::F32,
1649                },
1650            ],
1651        };
1652        let result = ObsPlan::compile(&spec, &space).unwrap();
1653        assert_eq!(result.output_len, 18);
1654
1655        let mut output = vec![0.0f32; result.output_len];
1656        let mut mask = vec![0u8; result.mask_len];
1657        result
1658            .plan
1659            .execute(&snap, None, &mut output, &mut mask)
1660            .unwrap();
1661
1662        // First 9: field 0, next 9: field 1.
1663        let expected_a: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1664        let expected_b: Vec<f32> = (10..=18).map(|x| x as f32).collect();
1665        assert_eq!(&output[..9], &expected_a);
1666        assert_eq!(&output[9..], &expected_b);
1667    }
1668
1669    #[test]
1670    fn execute_missing_field_errors() {
1671        let space = square4_space();
1672        let snap = MockSnapshot::new(TickId(1), WorldGenerationId(1), ParameterVersion(0));
1673
1674        let spec = ObsSpec {
1675            entries: vec![ObsEntry {
1676                field_id: FieldId(0),
1677                region: ObsRegion::Fixed(RegionSpec::All),
1678                pool: None,
1679                transform: ObsTransform::Identity,
1680                dtype: ObsDtype::F32,
1681            }],
1682        };
1683        let result = ObsPlan::compile(&spec, &space).unwrap();
1684
1685        let mut output = vec![0.0f32; result.output_len];
1686        let mut mask = vec![0u8; result.mask_len];
1687        let err = result
1688            .plan
1689            .execute(&snap, None, &mut output, &mut mask)
1690            .unwrap_err();
1691        assert!(matches!(err, ObsError::ExecutionFailed { .. }));
1692    }
1693
1694    #[test]
1695    fn execute_buffer_too_small_errors() {
1696        let space = square4_space();
1697        let data: Vec<f32> = vec![0.0; 9];
1698        let snap = snapshot_with_field(FieldId(0), data);
1699
1700        let spec = ObsSpec {
1701            entries: vec![ObsEntry {
1702                field_id: FieldId(0),
1703                region: ObsRegion::Fixed(RegionSpec::All),
1704                pool: None,
1705                transform: ObsTransform::Identity,
1706                dtype: ObsDtype::F32,
1707            }],
1708        };
1709        let result = ObsPlan::compile(&spec, &space).unwrap();
1710
1711        let mut output = vec![0.0f32; 4]; // too small
1712        let mut mask = vec![0u8; result.mask_len];
1713        let err = result
1714            .plan
1715            .execute(&snap, None, &mut output, &mut mask)
1716            .unwrap_err();
1717        assert!(matches!(err, ObsError::ExecutionFailed { .. }));
1718    }
1719
1720    // ── Validity / coverage tests ────────────────────────────
1721
1722    #[test]
1723    fn valid_ratio_one_for_square_all() {
1724        let space = square4_space();
1725        let data: Vec<f32> = vec![1.0; 9];
1726        let snap = snapshot_with_field(FieldId(0), data);
1727
1728        let spec = ObsSpec {
1729            entries: vec![ObsEntry {
1730                field_id: FieldId(0),
1731                region: ObsRegion::Fixed(RegionSpec::All),
1732                pool: None,
1733                transform: ObsTransform::Identity,
1734                dtype: ObsDtype::F32,
1735            }],
1736        };
1737        let result = ObsPlan::compile(&spec, &space).unwrap();
1738
1739        let mut output = vec![0.0f32; result.output_len];
1740        let mut mask = vec![0u8; result.mask_len];
1741        let meta = result
1742            .plan
1743            .execute(&snap, None, &mut output, &mut mask)
1744            .unwrap();
1745
1746        assert_eq!(meta.coverage, 1.0);
1747    }
1748
1749    // ── Generation binding tests ─────────────────────────────
1750
1751    #[test]
1752    fn plan_invalidated_on_generation_mismatch() {
1753        let space = square4_space();
1754        let data: Vec<f32> = vec![1.0; 9];
1755        let snap = snapshot_with_field(FieldId(0), data);
1756
1757        let spec = ObsSpec {
1758            entries: vec![ObsEntry {
1759                field_id: FieldId(0),
1760                region: ObsRegion::Fixed(RegionSpec::All),
1761                pool: None,
1762                transform: ObsTransform::Identity,
1763                dtype: ObsDtype::F32,
1764            }],
1765        };
1766        // Compile bound to generation 99, but snapshot is generation 1.
1767        let result = ObsPlan::compile_bound(&spec, &space, WorldGenerationId(99)).unwrap();
1768
1769        let mut output = vec![0.0f32; result.output_len];
1770        let mut mask = vec![0u8; result.mask_len];
1771        let err = result
1772            .plan
1773            .execute(&snap, None, &mut output, &mut mask)
1774            .unwrap_err();
1775        assert!(matches!(err, ObsError::PlanInvalidated { .. }));
1776    }
1777
1778    #[test]
1779    fn generation_match_succeeds() {
1780        let space = square4_space();
1781        let data: Vec<f32> = vec![1.0; 9];
1782        let snap = snapshot_with_field(FieldId(0), data);
1783
1784        let spec = ObsSpec {
1785            entries: vec![ObsEntry {
1786                field_id: FieldId(0),
1787                region: ObsRegion::Fixed(RegionSpec::All),
1788                pool: None,
1789                transform: ObsTransform::Identity,
1790                dtype: ObsDtype::F32,
1791            }],
1792        };
1793        let result = ObsPlan::compile_bound(&spec, &space, WorldGenerationId(1)).unwrap();
1794
1795        let mut output = vec![0.0f32; result.output_len];
1796        let mut mask = vec![0u8; result.mask_len];
1797        result
1798            .plan
1799            .execute(&snap, None, &mut output, &mut mask)
1800            .unwrap();
1801    }
1802
1803    #[test]
1804    fn unbound_plan_ignores_generation() {
1805        let space = square4_space();
1806        let data: Vec<f32> = vec![1.0; 9];
1807        let snap = snapshot_with_field(FieldId(0), data);
1808
1809        let spec = ObsSpec {
1810            entries: vec![ObsEntry {
1811                field_id: FieldId(0),
1812                region: ObsRegion::Fixed(RegionSpec::All),
1813                pool: None,
1814                transform: ObsTransform::Identity,
1815                dtype: ObsDtype::F32,
1816            }],
1817        };
1818        // Unbound plan — no generation check.
1819        let result = ObsPlan::compile(&spec, &space).unwrap();
1820
1821        let mut output = vec![0.0f32; result.output_len];
1822        let mut mask = vec![0u8; result.mask_len];
1823        result
1824            .plan
1825            .execute(&snap, None, &mut output, &mut mask)
1826            .unwrap();
1827    }
1828
1829    // ── Metadata tests ───────────────────────────────────────
1830
1831    #[test]
1832    fn metadata_fields_populated() {
1833        let space = square4_space();
1834        let data: Vec<f32> = vec![1.0; 9];
1835        let mut snap = MockSnapshot::new(TickId(42), WorldGenerationId(7), ParameterVersion(3));
1836        snap.set_field(FieldId(0), data);
1837
1838        let spec = ObsSpec {
1839            entries: vec![ObsEntry {
1840                field_id: FieldId(0),
1841                region: ObsRegion::Fixed(RegionSpec::All),
1842                pool: None,
1843                transform: ObsTransform::Identity,
1844                dtype: ObsDtype::F32,
1845            }],
1846        };
1847        let result = ObsPlan::compile(&spec, &space).unwrap();
1848
1849        let mut output = vec![0.0f32; result.output_len];
1850        let mut mask = vec![0u8; result.mask_len];
1851        let meta = result
1852            .plan
1853            .execute(&snap, None, &mut output, &mut mask)
1854            .unwrap();
1855
1856        assert_eq!(meta.tick_id, TickId(42));
1857        assert_eq!(meta.age_ticks, 0);
1858        assert_eq!(meta.coverage, 1.0);
1859        assert_eq!(meta.world_generation_id, WorldGenerationId(7));
1860        assert_eq!(meta.parameter_version, ParameterVersion(3));
1861    }
1862
1863    // ── Batch execution tests ────────────────────────────────
1864
1865    #[test]
1866    fn execute_batch_n1_matches_execute() {
1867        let space = square4_space();
1868        let data: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1869        let snap = snapshot_with_field(FieldId(0), data.clone());
1870
1871        let spec = ObsSpec {
1872            entries: vec![ObsEntry {
1873                field_id: FieldId(0),
1874                region: ObsRegion::Fixed(RegionSpec::All),
1875                pool: None,
1876                transform: ObsTransform::Identity,
1877                dtype: ObsDtype::F32,
1878            }],
1879        };
1880        let result = ObsPlan::compile(&spec, &space).unwrap();
1881
1882        // Single execute.
1883        let mut out_single = vec![0.0f32; result.output_len];
1884        let mut mask_single = vec![0u8; result.mask_len];
1885        let meta_single = result
1886            .plan
1887            .execute(&snap, None, &mut out_single, &mut mask_single)
1888            .unwrap();
1889
1890        // Batch N=1.
1891        let mut out_batch = vec![0.0f32; result.output_len];
1892        let mut mask_batch = vec![0u8; result.mask_len];
1893        let snap_ref: &dyn SnapshotAccess = &snap;
1894        let meta_batch = result
1895            .plan
1896            .execute_batch(&[snap_ref], None, &mut out_batch, &mut mask_batch)
1897            .unwrap();
1898
1899        assert_eq!(out_single, out_batch);
1900        assert_eq!(mask_single, mask_batch);
1901        assert_eq!(meta_single, meta_batch[0]);
1902    }
1903
1904    #[test]
1905    fn execute_batch_multiple_snapshots() {
1906        let space = square4_space();
1907        let spec = ObsSpec {
1908            entries: vec![ObsEntry {
1909                field_id: FieldId(0),
1910                region: ObsRegion::Fixed(RegionSpec::All),
1911                pool: None,
1912                transform: ObsTransform::Identity,
1913                dtype: ObsDtype::F32,
1914            }],
1915        };
1916        let result = ObsPlan::compile(&spec, &space).unwrap();
1917
1918        let snap_a = snapshot_with_field(FieldId(0), vec![1.0; 9]);
1919        let snap_b = snapshot_with_field(FieldId(0), vec![2.0; 9]);
1920
1921        let snaps: Vec<&dyn SnapshotAccess> = vec![&snap_a, &snap_b];
1922        let mut output = vec![0.0f32; result.output_len * 2];
1923        let mut mask = vec![0u8; result.mask_len * 2];
1924        let metas = result
1925            .plan
1926            .execute_batch(&snaps, None, &mut output, &mut mask)
1927            .unwrap();
1928
1929        assert_eq!(metas.len(), 2);
1930        assert!(output[..9].iter().all(|&v| v == 1.0));
1931        assert!(output[9..].iter().all(|&v| v == 2.0));
1932    }
1933
1934    #[test]
1935    fn execute_batch_buffer_too_small() {
1936        let space = square4_space();
1937        let spec = ObsSpec {
1938            entries: vec![ObsEntry {
1939                field_id: FieldId(0),
1940                region: ObsRegion::Fixed(RegionSpec::All),
1941                pool: None,
1942                transform: ObsTransform::Identity,
1943                dtype: ObsDtype::F32,
1944            }],
1945        };
1946        let result = ObsPlan::compile(&spec, &space).unwrap();
1947
1948        let snap = snapshot_with_field(FieldId(0), vec![1.0; 9]);
1949        let snaps: Vec<&dyn SnapshotAccess> = vec![&snap, &snap];
1950        let mut output = vec![0.0f32; 9]; // need 18
1951        let mut mask = vec![0u8; 18];
1952        let err = result
1953            .plan
1954            .execute_batch(&snaps, None, &mut output, &mut mask)
1955            .unwrap_err();
1956        assert!(matches!(err, ObsError::ExecutionFailed { .. }));
1957    }
1958
1959    // ── Field length mismatch tests ──────────────────────────
1960
1961    #[test]
1962    fn short_field_buffer_returns_error_not_panic() {
1963        let space = square4_space(); // 3x3 = 9 cells
1964        let spec = ObsSpec {
1965            entries: vec![ObsEntry {
1966                field_id: FieldId(0),
1967                region: ObsRegion::Fixed(RegionSpec::All),
1968                pool: None,
1969                transform: ObsTransform::Identity,
1970                dtype: ObsDtype::F32,
1971            }],
1972        };
1973        let result = ObsPlan::compile(&spec, &space).unwrap();
1974
1975        // Snapshot field has only 4 elements, but plan expects 9.
1976        let snap = snapshot_with_field(FieldId(0), vec![1.0; 4]);
1977        let mut output = vec![0.0f32; result.output_len];
1978        let mut mask = vec![0u8; result.mask_len];
1979        let err = result
1980            .plan
1981            .execute(&snap, None, &mut output, &mut mask)
1982            .unwrap_err();
1983        assert!(matches!(err, ObsError::ExecutionFailed { .. }));
1984    }
1985
1986    // ── Standard plan (agent-centered) tests ─────────────────
1987
1988    #[test]
1989    fn standard_plan_detected_from_agent_region() {
1990        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
1991        let spec = ObsSpec {
1992            entries: vec![ObsEntry {
1993                field_id: FieldId(0),
1994                region: ObsRegion::AgentRect {
1995                    half_extent: smallvec::smallvec![2, 2],
1996                },
1997                pool: None,
1998                transform: ObsTransform::Identity,
1999                dtype: ObsDtype::F32,
2000            }],
2001        };
2002        let result = ObsPlan::compile(&spec, &space).unwrap();
2003        assert!(result.plan.is_standard());
2004        // 5x5 = 25 elements
2005        assert_eq!(result.output_len, 25);
2006        assert_eq!(result.entry_shapes, vec![vec![5, 5]]);
2007    }
2008
2009    #[test]
2010    fn execute_on_standard_plan_errors() {
2011        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2012        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2013        let snap = snapshot_with_field(FieldId(0), data);
2014
2015        let spec = ObsSpec {
2016            entries: vec![ObsEntry {
2017                field_id: FieldId(0),
2018                region: ObsRegion::AgentDisk { radius: 2 },
2019                pool: None,
2020                transform: ObsTransform::Identity,
2021                dtype: ObsDtype::F32,
2022            }],
2023        };
2024        let result = ObsPlan::compile(&spec, &space).unwrap();
2025
2026        let mut output = vec![0.0f32; result.output_len];
2027        let mut mask = vec![0u8; result.mask_len];
2028        let err = result
2029            .plan
2030            .execute(&snap, None, &mut output, &mut mask)
2031            .unwrap_err();
2032        assert!(matches!(err, ObsError::ExecutionFailed { .. }));
2033    }
2034
2035    #[test]
2036    fn interior_boundary_equivalence() {
2037        // An INTERIOR agent using Standard plan should produce identical
2038        // output to a Simple plan with an explicit Rect at the same position.
2039        let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
2040        let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
2041        let snap = snapshot_with_field(FieldId(0), data);
2042
2043        let radius = 3u32;
2044        let center: Coord = smallvec::smallvec![10, 10]; // interior
2045
2046        // Standard plan: AgentRect centered on agent.
2047        let standard_spec = ObsSpec {
2048            entries: vec![ObsEntry {
2049                field_id: FieldId(0),
2050                region: ObsRegion::AgentRect {
2051                    half_extent: smallvec::smallvec![radius, radius],
2052                },
2053                pool: None,
2054                transform: ObsTransform::Identity,
2055                dtype: ObsDtype::F32,
2056            }],
2057        };
2058        let std_result = ObsPlan::compile(&standard_spec, &space).unwrap();
2059        let mut std_output = vec![0.0f32; std_result.output_len];
2060        let mut std_mask = vec![0u8; std_result.mask_len];
2061        std_result
2062            .plan
2063            .execute_agents(
2064                &snap,
2065                &space,
2066                std::slice::from_ref(&center),
2067                None,
2068                &mut std_output,
2069                &mut std_mask,
2070            )
2071            .unwrap();
2072
2073        // Simple plan: explicit Rect covering the same area.
2074        let r = radius as i32;
2075        let simple_spec = ObsSpec {
2076            entries: vec![ObsEntry {
2077                field_id: FieldId(0),
2078                region: ObsRegion::Fixed(RegionSpec::Rect {
2079                    min: smallvec::smallvec![10 - r, 10 - r],
2080                    max: smallvec::smallvec![10 + r, 10 + r],
2081                }),
2082                pool: None,
2083                transform: ObsTransform::Identity,
2084                dtype: ObsDtype::F32,
2085            }],
2086        };
2087        let simple_result = ObsPlan::compile(&simple_spec, &space).unwrap();
2088        let mut simple_output = vec![0.0f32; simple_result.output_len];
2089        let mut simple_mask = vec![0u8; simple_result.mask_len];
2090        simple_result
2091            .plan
2092            .execute(&snap, None, &mut simple_output, &mut simple_mask)
2093            .unwrap();
2094
2095        // Same shape, same values.
2096        assert_eq!(std_result.output_len, simple_result.output_len);
2097        assert_eq!(std_output, simple_output);
2098        assert_eq!(std_mask, simple_mask);
2099    }
2100
2101    #[test]
2102    fn boundary_agent_gets_padding() {
2103        // Agent at corner (0,0) with radius 2: many cells out-of-bounds.
2104        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2105        let data: Vec<f32> = (0..100).map(|x| x as f32 + 1.0).collect();
2106        let snap = snapshot_with_field(FieldId(0), data);
2107
2108        let spec = ObsSpec {
2109            entries: vec![ObsEntry {
2110                field_id: FieldId(0),
2111                region: ObsRegion::AgentRect {
2112                    half_extent: smallvec::smallvec![2, 2],
2113                },
2114                pool: None,
2115                transform: ObsTransform::Identity,
2116                dtype: ObsDtype::F32,
2117            }],
2118        };
2119        let result = ObsPlan::compile(&spec, &space).unwrap();
2120        let center: Coord = smallvec::smallvec![0, 0];
2121        let mut output = vec![0.0f32; result.output_len];
2122        let mut mask = vec![0u8; result.mask_len];
2123        let metas = result
2124            .plan
2125            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2126            .unwrap();
2127
2128        // 5x5 = 25 cells total. Agent at (0,0) with radius 2:
2129        // Only cells with row in [0,2] and col in [0,2] are valid (3x3 = 9).
2130        let valid_count: usize = mask.iter().filter(|&&v| v == 1).count();
2131        assert_eq!(valid_count, 9);
2132
2133        // Coverage should be 9/25
2134        assert!((metas[0].coverage - 9.0 / 25.0).abs() < 1e-6);
2135
2136        // Check that the top-left corner of the bounding box (relative [-2,-2])
2137        // is padding (mask=0, value=0).
2138        assert_eq!(mask[0], 0); // relative (-2,-2) → absolute (-2,-2) → out of bounds
2139        assert_eq!(output[0], 0.0);
2140
2141        // The cell at relative (0,0) is at tensor_idx = 2*5+2 = 12
2142        // Absolute (0,0) → field value = 1.0
2143        assert_eq!(mask[12], 1);
2144        assert_eq!(output[12], 1.0);
2145    }
2146
2147    #[test]
2148    fn hex_foveation_interior() {
2149        // Test agent-centered observation on Hex2D grid.
2150        let space = Hex2D::new(20, 20).unwrap(); // 20 rows, 20 cols
2151        let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
2152        let snap = snapshot_with_field(FieldId(0), data);
2153
2154        let spec = ObsSpec {
2155            entries: vec![ObsEntry {
2156                field_id: FieldId(0),
2157                region: ObsRegion::AgentDisk { radius: 2 },
2158                pool: None,
2159                transform: ObsTransform::Identity,
2160                dtype: ObsDtype::F32,
2161            }],
2162        };
2163        let result = ObsPlan::compile(&spec, &space).unwrap();
2164        assert_eq!(result.output_len, 25); // 5x5 bounding box (tensor shape)
2165
2166        // Interior agent: q=10, r=10
2167        let center: Coord = smallvec::smallvec![10, 10];
2168        let mut output = vec![0.0f32; result.output_len];
2169        let mut mask = vec![0u8; result.mask_len];
2170        result
2171            .plan
2172            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2173            .unwrap();
2174
2175        // Hex disk of radius 2: 19 of 25 cells are within hex distance.
2176        // The 6 corners of the 5x5 bounding box exceed hex distance 2.
2177        // Hex distance = max(|dq|, |dr|, |dq+dr|) for axial coordinates.
2178        let valid_count = mask.iter().filter(|&&v| v == 1).count();
2179        assert_eq!(valid_count, 19);
2180
2181        // Corners that should be masked out (distance > 2):
2182        // tensor_idx 0: dq=-2,dr=-2 → max(2,2,4)=4
2183        // tensor_idx 1: dq=-2,dr=-1 → max(2,1,3)=3
2184        // tensor_idx 5: dq=-1,dr=-2 → max(1,2,3)=3
2185        // tensor_idx 19: dq=+1,dr=+2 → max(1,2,3)=3
2186        // tensor_idx 23: dq=+2,dr=+1 → max(2,1,3)=3
2187        // tensor_idx 24: dq=+2,dr=+2 → max(2,2,4)=4
2188        for &idx in &[0, 1, 5, 19, 23, 24] {
2189            assert_eq!(mask[idx], 0, "tensor_idx {idx} should be outside hex disk");
2190            assert_eq!(output[idx], 0.0, "tensor_idx {idx} should be zero-padded");
2191        }
2192
2193        // Center cell is at tensor_idx = 2*5+2 = 12 (relative [0,0]).
2194        // Hex2D canonical_rank([q,r]) = r*cols + q = 10*20 + 10 = 210
2195        assert_eq!(output[12], 210.0);
2196
2197        // Cell at relative [1, 0] (dq=+1, dr=0) → absolute [11, 10]
2198        // rank = 10*20 + 11 = 211
2199        // In row-major bounding box: dim0_idx=3, dim1_idx=2 → tensor_idx = 3*5+2 = 17
2200        assert_eq!(output[17], 211.0);
2201    }
2202
2203    #[test]
2204    fn wrap_space_all_interior() {
2205        // Wrapped (torus) space: all agents are interior.
2206        let space = Square4::new(10, 10, EdgeBehavior::Wrap).unwrap();
2207        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2208        let snap = snapshot_with_field(FieldId(0), data);
2209
2210        let spec = ObsSpec {
2211            entries: vec![ObsEntry {
2212                field_id: FieldId(0),
2213                region: ObsRegion::AgentRect {
2214                    half_extent: smallvec::smallvec![2, 2],
2215                },
2216                pool: None,
2217                transform: ObsTransform::Identity,
2218                dtype: ObsDtype::F32,
2219            }],
2220        };
2221        let result = ObsPlan::compile(&spec, &space).unwrap();
2222
2223        // Agent at corner (0,0) — still interior on torus.
2224        let center: Coord = smallvec::smallvec![0, 0];
2225        let mut output = vec![0.0f32; result.output_len];
2226        let mut mask = vec![0u8; result.mask_len];
2227        result
2228            .plan
2229            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2230            .unwrap();
2231
2232        // All 25 cells valid (torus wraps).
2233        assert!(mask.iter().all(|&v| v == 1));
2234        assert_eq!(output[12], 0.0); // center (0,0) → rank 0
2235    }
2236
2237    #[test]
2238    fn execute_agents_multiple() {
2239        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2240        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2241        let snap = snapshot_with_field(FieldId(0), data);
2242
2243        let spec = ObsSpec {
2244            entries: vec![ObsEntry {
2245                field_id: FieldId(0),
2246                region: ObsRegion::AgentRect {
2247                    half_extent: smallvec::smallvec![1, 1],
2248                },
2249                pool: None,
2250                transform: ObsTransform::Identity,
2251                dtype: ObsDtype::F32,
2252            }],
2253        };
2254        let result = ObsPlan::compile(&spec, &space).unwrap();
2255        assert_eq!(result.output_len, 9); // 3x3
2256
2257        // Two agents: one interior, one at edge.
2258        let centers = vec![
2259            smallvec::smallvec![5, 5], // interior
2260            smallvec::smallvec![0, 5], // top edge
2261        ];
2262        let n = centers.len();
2263        let mut output = vec![0.0f32; result.output_len * n];
2264        let mut mask = vec![0u8; result.mask_len * n];
2265        let metas = result
2266            .plan
2267            .execute_agents(&snap, &space, &centers, None, &mut output, &mut mask)
2268            .unwrap();
2269
2270        assert_eq!(metas.len(), 2);
2271
2272        // Agent 0 (interior): all 9 cells valid, center = (5,5) → rank 55
2273        assert!(mask[..9].iter().all(|&v| v == 1));
2274        assert_eq!(output[4], 55.0); // center at tensor_idx = 1*3+1 = 4
2275
2276        // Agent 1 (top edge): row -1 is out of bounds → 3 cells masked
2277        let agent1_mask = &mask[9..18];
2278        let valid_count: usize = agent1_mask.iter().filter(|&&v| v == 1).count();
2279        assert_eq!(valid_count, 6); // 2 rows in-bounds × 3 cols
2280    }
2281
2282    #[test]
2283    fn execute_agents_with_normalize() {
2284        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2285        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2286        let snap = snapshot_with_field(FieldId(0), data);
2287
2288        let spec = ObsSpec {
2289            entries: vec![ObsEntry {
2290                field_id: FieldId(0),
2291                region: ObsRegion::AgentRect {
2292                    half_extent: smallvec::smallvec![1, 1],
2293                },
2294                pool: None,
2295                transform: ObsTransform::Normalize {
2296                    min: 0.0,
2297                    max: 99.0,
2298                },
2299                dtype: ObsDtype::F32,
2300            }],
2301        };
2302        let result = ObsPlan::compile(&spec, &space).unwrap();
2303
2304        let center: Coord = smallvec::smallvec![5, 5];
2305        let mut output = vec![0.0f32; result.output_len];
2306        let mut mask = vec![0u8; result.mask_len];
2307        result
2308            .plan
2309            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2310            .unwrap();
2311
2312        // Center (5,5) rank=55, normalized = 55/99 ≈ 0.5556
2313        let expected = 55.0 / 99.0;
2314        assert!((output[4] - expected as f32).abs() < 1e-5);
2315    }
2316
2317    #[test]
2318    fn execute_agents_with_pooling() {
2319        let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
2320        let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
2321        let snap = snapshot_with_field(FieldId(0), data);
2322
2323        // AgentRect with half_extent=3 → 7x7 bounding box.
2324        // Mean pool 2x2 stride 2 → floor((7-2)/2)+1 = 3 per dim → 3x3 = 9 output.
2325        let spec = ObsSpec {
2326            entries: vec![ObsEntry {
2327                field_id: FieldId(0),
2328                region: ObsRegion::AgentRect {
2329                    half_extent: smallvec::smallvec![3, 3],
2330                },
2331                pool: Some(PoolConfig {
2332                    kernel: PoolKernel::Mean,
2333                    kernel_size: 2,
2334                    stride: 2,
2335                }),
2336                transform: ObsTransform::Identity,
2337                dtype: ObsDtype::F32,
2338            }],
2339        };
2340        let result = ObsPlan::compile(&spec, &space).unwrap();
2341        assert_eq!(result.output_len, 9); // 3x3
2342        assert_eq!(result.entry_shapes, vec![vec![3, 3]]);
2343
2344        // Interior agent at (10, 10): all cells valid.
2345        let center: Coord = smallvec::smallvec![10, 10];
2346        let mut output = vec![0.0f32; result.output_len];
2347        let mut mask = vec![0u8; result.mask_len];
2348        result
2349            .plan
2350            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2351            .unwrap();
2352
2353        // All pooled cells should be valid.
2354        assert!(mask.iter().all(|&v| v == 1));
2355
2356        // Verify first pooled cell: mean of top-left 2x2 of the 7x7 gather.
2357        // Gather bounding box starts at (10-3, 10-3) = (7, 7).
2358        // Top-left 2x2: (7,7)=147, (7,8)=148, (8,7)=167, (8,8)=168
2359        // Mean = (147+148+167+168)/4 = 157.5
2360        assert!((output[0] - 157.5).abs() < 1e-4);
2361    }
2362
2363    #[test]
2364    fn mixed_fixed_and_agent_entries() {
2365        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2366        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2367        let snap = snapshot_with_field(FieldId(0), data);
2368
2369        let spec = ObsSpec {
2370            entries: vec![
2371                // Fixed entry: full grid (100 elements).
2372                ObsEntry {
2373                    field_id: FieldId(0),
2374                    region: ObsRegion::Fixed(RegionSpec::All),
2375                    pool: None,
2376                    transform: ObsTransform::Identity,
2377                    dtype: ObsDtype::F32,
2378                },
2379                // Agent entry: 3x3 rect around agent.
2380                ObsEntry {
2381                    field_id: FieldId(0),
2382                    region: ObsRegion::AgentRect {
2383                        half_extent: smallvec::smallvec![1, 1],
2384                    },
2385                    pool: None,
2386                    transform: ObsTransform::Identity,
2387                    dtype: ObsDtype::F32,
2388                },
2389            ],
2390        };
2391        let result = ObsPlan::compile(&spec, &space).unwrap();
2392        assert!(result.plan.is_standard());
2393        assert_eq!(result.output_len, 109); // 100 + 9
2394
2395        let center: Coord = smallvec::smallvec![5, 5];
2396        let mut output = vec![0.0f32; result.output_len];
2397        let mut mask = vec![0u8; result.mask_len];
2398        result
2399            .plan
2400            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2401            .unwrap();
2402
2403        // Fixed entry: first 100 elements match field data.
2404        let expected: Vec<f32> = (0..100).map(|x| x as f32).collect();
2405        assert_eq!(&output[..100], &expected[..]);
2406        assert!(mask[..100].iter().all(|&v| v == 1));
2407
2408        // Agent entry: 3x3 centered on (5,5). Center at tensor_idx = 1*3+1 = 4.
2409        // rank(5,5) = 55
2410        assert_eq!(output[100 + 4], 55.0);
2411    }
2412
2413    #[test]
2414    fn wrong_dimensionality_returns_error() {
2415        // 2D space but 1D agent center → should error, not panic.
2416        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2417        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2418        let snap = snapshot_with_field(FieldId(0), data);
2419
2420        let spec = ObsSpec {
2421            entries: vec![ObsEntry {
2422                field_id: FieldId(0),
2423                region: ObsRegion::AgentDisk { radius: 1 },
2424                pool: None,
2425                transform: ObsTransform::Identity,
2426                dtype: ObsDtype::F32,
2427            }],
2428        };
2429        let result = ObsPlan::compile(&spec, &space).unwrap();
2430
2431        let bad_center: Coord = smallvec::smallvec![5]; // 1D, not 2D
2432        let mut output = vec![0.0f32; result.output_len];
2433        let mut mask = vec![0u8; result.mask_len];
2434        let err =
2435            result
2436                .plan
2437                .execute_agents(&snap, &space, &[bad_center], None, &mut output, &mut mask);
2438        assert!(err.is_err());
2439        let msg = format!("{}", err.unwrap_err());
2440        assert!(
2441            msg.contains("dimensions"),
2442            "error should mention dimensions: {msg}"
2443        );
2444    }
2445
2446    #[test]
2447    fn agent_disk_square4_filters_corners() {
2448        // On a 4-connected grid, AgentDisk radius=2 should use Manhattan distance.
2449        // Bounding box is 5x5 = 25, but Manhattan disk has 13 cells (diamond shape).
2450        let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
2451        let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
2452        let snap = snapshot_with_field(FieldId(0), data);
2453
2454        let spec = ObsSpec {
2455            entries: vec![ObsEntry {
2456                field_id: FieldId(0),
2457                region: ObsRegion::AgentDisk { radius: 2 },
2458                pool: None,
2459                transform: ObsTransform::Identity,
2460                dtype: ObsDtype::F32,
2461            }],
2462        };
2463        let result = ObsPlan::compile(&spec, &space).unwrap();
2464        assert_eq!(result.output_len, 25); // tensor shape is still 5x5
2465
2466        // Interior agent at (10, 10).
2467        let center: Coord = smallvec::smallvec![10, 10];
2468        let mut output = vec![0.0f32; 25];
2469        let mut mask = vec![0u8; 25];
2470        result
2471            .plan
2472            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2473            .unwrap();
2474
2475        // Manhattan distance disk of radius 2 on a 5x5 bounding box:
2476        //   . . X . .    (row -2: only center col)
2477        //   . X X X .    (row -1: 3 cells)
2478        //   X X X X X    (row  0: 5 cells)
2479        //   . X X X .    (row +1: 3 cells)
2480        //   . . X . .    (row +2: only center col)
2481        // Total: 1 + 3 + 5 + 3 + 1 = 13 cells
2482        let valid_count = mask.iter().filter(|&&v| v == 1).count();
2483        assert_eq!(
2484            valid_count, 13,
2485            "Manhattan disk radius=2 should have 13 cells"
2486        );
2487
2488        // Corners should be masked out: (dr,dc) where |dr|+|dc| > 2
2489        // tensor_idx 0: dr=-2,dc=-2 → dist=4 → OUT
2490        // tensor_idx 4: dr=-2,dc=+2 → dist=4 → OUT
2491        // tensor_idx 20: dr=+2,dc=-2 → dist=4 → OUT
2492        // tensor_idx 24: dr=+2,dc=+2 → dist=4 → OUT
2493        for &idx in &[0, 4, 20, 24] {
2494            assert_eq!(
2495                mask[idx], 0,
2496                "corner tensor_idx {idx} should be outside disk"
2497            );
2498        }
2499
2500        // Center cell: tensor_idx = 2*5+2 = 12, absolute = row 10 * 20 + col 10 = 210
2501        assert_eq!(output[12], 210.0);
2502        assert_eq!(mask[12], 1);
2503    }
2504
2505    #[test]
2506    fn agent_rect_no_disk_filtering() {
2507        // AgentRect should NOT filter any cells — full rectangle is valid.
2508        let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
2509        let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
2510        let snap = snapshot_with_field(FieldId(0), data);
2511
2512        let spec = ObsSpec {
2513            entries: vec![ObsEntry {
2514                field_id: FieldId(0),
2515                region: ObsRegion::AgentRect {
2516                    half_extent: smallvec::smallvec![2, 2],
2517                },
2518                pool: None,
2519                transform: ObsTransform::Identity,
2520                dtype: ObsDtype::F32,
2521            }],
2522        };
2523        let result = ObsPlan::compile(&spec, &space).unwrap();
2524
2525        let center: Coord = smallvec::smallvec![10, 10];
2526        let mut output = vec![0.0f32; 25];
2527        let mut mask = vec![0u8; 25];
2528        result
2529            .plan
2530            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2531            .unwrap();
2532
2533        // All 25 cells should be valid for AgentRect (no disk filtering).
2534        assert!(mask.iter().all(|&v| v == 1));
2535    }
2536
2537    #[test]
2538    fn agent_disk_square8_chebyshev() {
2539        // On an 8-connected grid, AgentDisk radius=1 uses Chebyshev distance.
2540        // Bounding box is 3x3 = 9, Chebyshev disk radius=1 = full 3x3 → 9 cells.
2541        let space = Square8::new(10, 10, EdgeBehavior::Absorb).unwrap();
2542        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2543        let snap = snapshot_with_field(FieldId(0), data);
2544
2545        let spec = ObsSpec {
2546            entries: vec![ObsEntry {
2547                field_id: FieldId(0),
2548                region: ObsRegion::AgentDisk { radius: 1 },
2549                pool: None,
2550                transform: ObsTransform::Identity,
2551                dtype: ObsDtype::F32,
2552            }],
2553        };
2554        let result = ObsPlan::compile(&spec, &space).unwrap();
2555        assert_eq!(result.output_len, 9);
2556
2557        let center: Coord = smallvec::smallvec![5, 5];
2558        let mut output = vec![0.0f32; 9];
2559        let mut mask = vec![0u8; 9];
2560        result
2561            .plan
2562            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2563            .unwrap();
2564
2565        // Chebyshev distance <= 1 covers full 3x3 = 9 cells (all corners included).
2566        let valid_count = mask.iter().filter(|&&v| v == 1).count();
2567        assert_eq!(valid_count, 9, "Chebyshev disk radius=1 = full 3x3");
2568    }
2569
2570    #[test]
2571    fn compile_rejects_inverted_normalize_range() {
2572        let space = square4_space();
2573        let spec = ObsSpec {
2574            entries: vec![ObsEntry {
2575                field_id: FieldId(0),
2576                region: ObsRegion::Fixed(RegionSpec::All),
2577                pool: None,
2578                transform: ObsTransform::Normalize {
2579                    min: 10.0,
2580                    max: 5.0,
2581                },
2582                dtype: ObsDtype::F32,
2583            }],
2584        };
2585        let err = ObsPlan::compile(&spec, &space).unwrap_err();
2586        assert!(matches!(err, ObsError::InvalidObsSpec { .. }));
2587    }
2588
2589    #[test]
2590    fn compile_rejects_nan_normalize() {
2591        let space = square4_space();
2592        let spec = ObsSpec {
2593            entries: vec![ObsEntry {
2594                field_id: FieldId(0),
2595                region: ObsRegion::Fixed(RegionSpec::All),
2596                pool: None,
2597                transform: ObsTransform::Normalize {
2598                    min: f64::NAN,
2599                    max: 1.0,
2600                },
2601                dtype: ObsDtype::F32,
2602            }],
2603        };
2604        assert!(ObsPlan::compile(&spec, &space).is_err());
2605    }
2606}