Skip to main content

murk_obs/
plan.rs

//! Observation plan compilation and execution.
//!
//! [`ObsPlan`] is compiled from an [`ObsSpec`] + [`Space`], producing
//! a reusable gather plan that can be executed against any
//! [`SnapshotAccess`] implementor. The "Simple plan class" (all `Fixed`
//! regions) uses a branch-free flat gather: for each entry, iterate
//! pre-computed `(field_data_index, tensor_index)` pairs, read the field
//! value, optionally transform it, and write to the caller-allocated buffer.
//! The "Standard plan class" (any agent-relative region) instead stores
//! per-entry templates that are resolved against each agent's center at
//! execute time.
9
10use indexmap::IndexMap;
11
12use murk_core::error::ObsError;
13use murk_core::{Coord, FieldId, SnapshotAccess, TickId, WorldGenerationId};
14use murk_space::Space;
15
16use crate::geometry::GridGeometry;
17use crate::metadata::ObsMetadata;
18use crate::pool::pool_2d_into;
19use crate::spec::{ObsDtype, ObsRegion, ObsSpec, ObsTransform, PoolConfig};
20
/// Minimum acceptable `valid_ratio` for a fixed region without a warning:
/// below this value compilation still succeeds but reports on stderr.
const COVERAGE_WARN_THRESHOLD: f64 = 0.50;

/// Minimum `valid_ratio` for a fixed region to compile at all: below this
/// value compilation fails with `ObsError::InvalidComposition`.
const COVERAGE_ERROR_THRESHOLD: f64 = 0.35;
26
27/// Result of compiling an [`ObsSpec`].
28#[derive(Debug)]
29pub struct ObsPlanResult {
30    /// The compiled plan, ready for execution.
31    pub plan: ObsPlan,
32    /// Total number of f32 elements in the output tensor.
33    pub output_len: usize,
34    /// Shape per entry (each entry's region bounding shape dimensions).
35    pub entry_shapes: Vec<Vec<usize>>,
36    /// Length of the validity mask in bytes.
37    pub mask_len: usize,
38}
39
40/// Compiled observation plan: either Simple or Standard class.
41///
42/// **Simple** (all `Fixed` regions): pre-computed gather indices, branch-free
43/// loop, zero spatial computation at runtime. Use [`execute`](Self::execute).
44///
45/// **Standard** (any agent-relative region): template-based gather with
46/// interior/boundary dispatch. Use [`execute_agents`](Self::execute_agents).
47#[derive(Debug)]
48pub struct ObsPlan {
49    strategy: PlanStrategy,
50    /// Total output elements across all entries (per agent for Standard).
51    output_len: usize,
52    /// Total mask bytes across all entries (per agent for Standard).
53    mask_len: usize,
54    /// Generation at compile time (for PLAN_INVALIDATED detection).
55    compiled_generation: Option<WorldGenerationId>,
56}
57
58/// Pre-computed gather instruction for a single cell.
59///
60/// At execution time, we read `field_data[field_data_idx]` and write
61/// the (transformed) value to `output[tensor_idx]`.
62#[derive(Debug, Clone)]
63struct GatherOp {
64    /// Index into the flat field data array (canonical ordering).
65    field_data_idx: usize,
66    /// Index into the output slice for this entry.
67    tensor_idx: usize,
68}
69
70/// A single compiled entry ready for gather execution.
71#[derive(Debug)]
72struct CompiledEntry {
73    field_id: FieldId,
74    transform: ObsTransform,
75    #[allow(dead_code)]
76    dtype: ObsDtype,
77    /// Offset into the output buffer where this entry starts.
78    output_offset: usize,
79    /// Offset into the validity mask where this entry starts.
80    mask_offset: usize,
81    /// Number of elements this entry contributes to the output.
82    element_count: usize,
83    /// Pre-computed gather operations (one per valid cell in the region).
84    gather_ops: Vec<GatherOp>,
85    /// Pre-computed validity mask for this entry's bounding box.
86    valid_mask: Vec<u8>,
87    /// Number of valid cells in `valid_mask`.
88    valid_count: usize,
89    /// Valid ratio for this entry's region.
90    #[allow(dead_code)]
91    valid_ratio: f64,
92}
93
94/// Relative offset from agent center for template-based gather.
95///
96/// At compile time, the bounding box of an agent-centered region is
97/// decomposed into `TemplateOp`s. At execute time, the agent center
98/// is resolved and each op is applied: `field_data[base_rank + stride_offset]`.
99#[derive(Debug, Clone)]
100struct TemplateOp {
101    /// Offset from center per coordinate axis.
102    relative: Coord,
103    /// Position in the gather bounding-box tensor (row-major).
104    tensor_idx: usize,
105    /// Precomputed `sum(relative[i] * strides[i])` for interior fast path.
106    /// Zero if no `GridGeometry` is available (fallback path only).
107    stride_offset: isize,
108    /// Whether this cell is within the disk region (always true for AgentRect).
109    /// For AgentDisk, cells outside the graph-distance radius are excluded.
110    in_disk: bool,
111}
112
113/// Compiled agent-relative entry for the Standard plan class.
114///
115/// Stores template data that is instantiated per-agent at execute time.
116/// The bounding box shape comes from the region (e.g., `[2r+1, 2r+1]` for
117/// `AgentDisk`/`AgentRect`), and may be reduced by pooling.
118#[derive(Debug)]
119struct AgentCompiledEntry {
120    field_id: FieldId,
121    pool: Option<PoolConfig>,
122    transform: ObsTransform,
123    #[allow(dead_code)]
124    dtype: ObsDtype,
125    /// Offset into the per-agent output buffer.
126    output_offset: usize,
127    /// Offset into the per-agent mask buffer.
128    mask_offset: usize,
129    /// Post-pool output elements (written to output).
130    element_count: usize,
131    /// Pre-pool bounding-box elements (gather buffer size).
132    pre_pool_element_count: usize,
133    /// Shape of the pre-pool bounding box (e.g., `[7, 7]`).
134    pre_pool_shape: Vec<usize>,
135    /// Active template operations (in-disk only) for runtime gather.
136    active_ops: Vec<TemplateOp>,
137    /// Radius for `is_interior` check.
138    radius: u32,
139}
140
141/// Data for the Standard plan class (agent-centered foveation + pooling).
142#[derive(Debug)]
143struct StandardPlanData {
144    /// Entries with `ObsRegion::Fixed` (same output for all agents).
145    fixed_entries: Vec<CompiledEntry>,
146    /// Entries with agent-relative regions (resolved per-agent).
147    agent_entries: Vec<AgentCompiledEntry>,
148    /// Grid geometry for interior/boundary dispatch (`None` → all slow path).
149    geometry: Option<GridGeometry>,
150}
151
152/// Data for the Simple plan class (all entries fixed, branch-free gather).
153#[derive(Debug)]
154struct SimplePlanData {
155    entries: Vec<CompiledEntry>,
156    total_valid: usize,
157    total_elements: usize,
158}
159
160/// Internal plan strategy: Simple (all-fixed) or Standard (agent-centered).
161#[derive(Debug)]
162enum PlanStrategy {
163    /// All entries are `ObsRegion::Fixed`: pre-computed gather indices.
164    Simple(SimplePlanData),
165    /// At least one entry is agent-relative: template-based gather.
166    Standard(StandardPlanData),
167}
168
169impl ObsPlan {
170    /// Compile an [`ObsSpec`] against a [`Space`].
171    ///
172    /// Detects whether the spec contains agent-relative regions and
173    /// dispatches to the appropriate plan class:
174    /// - All `Fixed` → **Simple** (pre-computed gather)
175    /// - Any `AgentDisk`/`AgentRect` → **Standard** (template-based)
176    pub fn compile(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
177        if spec.entries.is_empty() {
178            return Err(ObsError::InvalidObsSpec {
179                reason: "ObsSpec has no entries".into(),
180            });
181        }
182
183        // Validate transform parameters.
184        for (i, entry) in spec.entries.iter().enumerate() {
185            if let ObsTransform::Normalize { min, max } = &entry.transform {
186                if !min.is_finite() || !max.is_finite() {
187                    return Err(ObsError::InvalidObsSpec {
188                        reason: format!(
189                            "entry {i}: Normalize min/max must be finite, got min={min}, max={max}"
190                        ),
191                    });
192                }
193                if min > max {
194                    return Err(ObsError::InvalidObsSpec {
195                        reason: format!("entry {i}: Normalize min ({min}) must be <= max ({max})"),
196                    });
197                }
198            }
199        }
200
201        let has_agent = spec.entries.iter().any(|e| {
202            matches!(
203                e.region,
204                ObsRegion::AgentDisk { .. } | ObsRegion::AgentRect { .. }
205            )
206        });
207
208        if has_agent {
209            Self::compile_standard(spec, space)
210        } else {
211            Self::compile_simple(spec, space)
212        }
213    }
214
215    /// Compile a Simple plan (all `Fixed` regions, no agent-relative entries).
216    fn compile_simple(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
217        let canonical = space.canonical_ordering();
218        let coord_to_field_idx: IndexMap<Coord, usize> = canonical
219            .into_iter()
220            .enumerate()
221            .map(|(idx, coord)| (coord, idx))
222            .collect();
223
224        let mut entries = Vec::with_capacity(spec.entries.len());
225        let mut total_valid = 0usize;
226        let mut total_elements = 0usize;
227        let mut output_offset = 0usize;
228        let mut mask_offset = 0usize;
229        let mut entry_shapes = Vec::with_capacity(spec.entries.len());
230
231        for (i, entry) in spec.entries.iter().enumerate() {
232            let fixed_region = match &entry.region {
233                ObsRegion::Fixed(spec) => spec,
234                ObsRegion::AgentDisk { .. } | ObsRegion::AgentRect { .. } => {
235                    return Err(ObsError::InvalidObsSpec {
236                        reason: format!("entry {i}: agent-relative region in Simple plan"),
237                    });
238                }
239            };
240            if entry.pool.is_some() {
241                return Err(ObsError::InvalidObsSpec {
242                    reason: format!(
243                        "entry {i}: pooling requires a Standard plan (use agent-relative region)"
244                    ),
245                });
246            }
247
248            let mut region_plan =
249                space
250                    .compile_region(fixed_region)
251                    .map_err(|e| ObsError::InvalidObsSpec {
252                        reason: format!("entry {i}: region compile failed: {e}"),
253                    })?;
254
255            let ratio = region_plan.valid_ratio();
256            if ratio < COVERAGE_ERROR_THRESHOLD {
257                return Err(ObsError::InvalidComposition {
258                    reason: format!(
259                        "entry {i}: valid_ratio {ratio:.3} < {COVERAGE_ERROR_THRESHOLD}"
260                    ),
261                });
262            }
263            if ratio < COVERAGE_WARN_THRESHOLD {
264                eprintln!(
265                    "murk-obs: warning: entry {i} valid_ratio {ratio:.3} < {COVERAGE_WARN_THRESHOLD}"
266                );
267            }
268
269            let mut gather_ops = Vec::with_capacity(region_plan.coords().len());
270            for (coord_idx, coord) in region_plan.coords().iter().enumerate() {
271                let field_data_idx =
272                    *coord_to_field_idx
273                        .get(coord)
274                        .ok_or_else(|| ObsError::InvalidObsSpec {
275                            reason: format!("entry {i}: coord {coord:?} not in canonical ordering"),
276                        })?;
277                let tensor_idx = region_plan.tensor_indices()[coord_idx];
278                gather_ops.push(GatherOp {
279                    field_data_idx,
280                    tensor_idx,
281                });
282            }
283
284            let element_count = region_plan.bounding_shape().total_elements();
285            let shape = match region_plan.bounding_shape() {
286                murk_space::BoundingShape::Rect(dims) => dims.clone(),
287            };
288            entry_shapes.push(shape);
289
290            let valid_mask = region_plan.take_valid_mask();
291            let valid_count = valid_mask.iter().filter(|&&v| v == 1).count();
292            entries.push(CompiledEntry {
293                field_id: entry.field_id,
294                transform: entry.transform.clone(),
295                dtype: entry.dtype,
296                output_offset,
297                mask_offset,
298                element_count,
299                gather_ops,
300                valid_mask,
301                valid_count,
302                valid_ratio: ratio,
303            });
304            total_valid += valid_count;
305            total_elements += element_count;
306
307            output_offset += element_count;
308            mask_offset += element_count;
309        }
310
311        let plan = ObsPlan {
312            strategy: PlanStrategy::Simple(SimplePlanData {
313                entries,
314                total_valid,
315                total_elements,
316            }),
317            output_len: output_offset,
318            mask_len: mask_offset,
319            compiled_generation: None,
320        };
321
322        Ok(ObsPlanResult {
323            output_len: plan.output_len,
324            mask_len: plan.mask_len,
325            entry_shapes,
326            plan,
327        })
328    }
329
330    /// Compile a Standard plan (has agent-relative entries).
331    ///
332    /// Fixed entries are compiled with pre-computed gather (same for all agents).
333    /// Agent entries are compiled as templates (resolved per-agent at execute time).
334    fn compile_standard(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
335        let canonical = space.canonical_ordering();
336        let coord_to_field_idx: IndexMap<Coord, usize> = canonical
337            .into_iter()
338            .enumerate()
339            .map(|(idx, coord)| (coord, idx))
340            .collect();
341
342        let geometry = GridGeometry::from_space(space);
343        let ndim = space.ndim();
344
345        let mut fixed_entries = Vec::new();
346        let mut agent_entries = Vec::new();
347        let mut output_offset = 0usize;
348        let mut mask_offset = 0usize;
349        let mut entry_shapes = Vec::new();
350
351        for (i, entry) in spec.entries.iter().enumerate() {
352            match &entry.region {
353                ObsRegion::Fixed(region_spec) => {
354                    if entry.pool.is_some() {
355                        return Err(ObsError::InvalidObsSpec {
356                            reason: format!("entry {i}: pooling on Fixed regions not supported"),
357                        });
358                    }
359
360                    let mut region_plan = space.compile_region(region_spec).map_err(|e| {
361                        ObsError::InvalidObsSpec {
362                            reason: format!("entry {i}: region compile failed: {e}"),
363                        }
364                    })?;
365
366                    let ratio = region_plan.valid_ratio();
367                    if ratio < COVERAGE_ERROR_THRESHOLD {
368                        return Err(ObsError::InvalidComposition {
369                            reason: format!(
370                                "entry {i}: valid_ratio {ratio:.3} < {COVERAGE_ERROR_THRESHOLD}"
371                            ),
372                        });
373                    }
374
375                    let mut gather_ops = Vec::with_capacity(region_plan.coords().len());
376                    for (coord_idx, coord) in region_plan.coords().iter().enumerate() {
377                        let field_data_idx = *coord_to_field_idx.get(coord).ok_or_else(|| {
378                            ObsError::InvalidObsSpec {
379                                reason: format!(
380                                    "entry {i}: coord {coord:?} not in canonical ordering"
381                                ),
382                            }
383                        })?;
384                        let tensor_idx = region_plan.tensor_indices()[coord_idx];
385                        gather_ops.push(GatherOp {
386                            field_data_idx,
387                            tensor_idx,
388                        });
389                    }
390
391                    let element_count = region_plan.bounding_shape().total_elements();
392                    let shape = match region_plan.bounding_shape() {
393                        murk_space::BoundingShape::Rect(dims) => dims.clone(),
394                    };
395                    entry_shapes.push(shape);
396
397                    let valid_mask = region_plan.take_valid_mask();
398                    let valid_count = valid_mask.iter().filter(|&&v| v == 1).count();
399                    fixed_entries.push(CompiledEntry {
400                        field_id: entry.field_id,
401                        transform: entry.transform.clone(),
402                        dtype: entry.dtype,
403                        output_offset,
404                        mask_offset,
405                        element_count,
406                        gather_ops,
407                        valid_mask,
408                        valid_count,
409                        valid_ratio: ratio,
410                    });
411
412                    output_offset += element_count;
413                    mask_offset += element_count;
414                }
415
416                ObsRegion::AgentDisk { radius } => {
417                    let half_ext: smallvec::SmallVec<[u32; 4]> =
418                        (0..ndim).map(|_| *radius).collect();
419                    let (ae, shape) = Self::compile_agent_entry(
420                        i,
421                        entry,
422                        &half_ext,
423                        *radius,
424                        &geometry,
425                        Some(*radius),
426                        output_offset,
427                        mask_offset,
428                    )?;
429                    entry_shapes.push(shape);
430                    output_offset += ae.element_count;
431                    mask_offset += ae.element_count;
432                    agent_entries.push(ae);
433                }
434
435                ObsRegion::AgentRect { half_extent } => {
436                    if half_extent.len() != ndim {
437                        return Err(ObsError::InvalidObsSpec {
438                            reason: format!(
439                                "entry {i}: AgentRect half_extent has {} dims, but space requires {ndim}",
440                                half_extent.len()
441                            ),
442                        });
443                    }
444                    let radius = *half_extent.iter().max().unwrap_or(&0);
445                    let (ae, shape) = Self::compile_agent_entry(
446                        i,
447                        entry,
448                        half_extent,
449                        radius,
450                        &geometry,
451                        None,
452                        output_offset,
453                        mask_offset,
454                    )?;
455                    entry_shapes.push(shape);
456                    output_offset += ae.element_count;
457                    mask_offset += ae.element_count;
458                    agent_entries.push(ae);
459                }
460            }
461        }
462
463        let plan = ObsPlan {
464            strategy: PlanStrategy::Standard(StandardPlanData {
465                fixed_entries,
466                agent_entries,
467                geometry,
468            }),
469            output_len: output_offset,
470            mask_len: mask_offset,
471            compiled_generation: None,
472        };
473
474        Ok(ObsPlanResult {
475            output_len: plan.output_len,
476            mask_len: plan.mask_len,
477            entry_shapes,
478            plan,
479        })
480    }
481
482    /// Compile a single agent-relative entry into a template.
483    ///
484    /// `disk_radius`: if `Some(r)`, template ops outside graph-distance `r`
485    /// are marked `in_disk = false` (for `AgentDisk`). `None` for `AgentRect`.
486    #[allow(clippy::too_many_arguments)]
487    fn compile_agent_entry(
488        entry_idx: usize,
489        entry: &crate::spec::ObsEntry,
490        half_extent: &[u32],
491        radius: u32,
492        geometry: &Option<GridGeometry>,
493        disk_radius: Option<u32>,
494        output_offset: usize,
495        mask_offset: usize,
496    ) -> Result<(AgentCompiledEntry, Vec<usize>), ObsError> {
497        let pre_pool_shape: Vec<usize> =
498            half_extent.iter().map(|&he| 2 * he as usize + 1).collect();
499        let pre_pool_element_count: usize = pre_pool_shape.iter().product();
500
501        let template_ops = generate_template_ops(half_extent, geometry, disk_radius)?;
502        let active_ops = template_ops
503            .iter()
504            .filter(|op| op.in_disk)
505            .cloned()
506            .collect();
507
508        let (element_count, output_shape) = if let Some(pool) = &entry.pool {
509            if pre_pool_shape.len() != 2 {
510                return Err(ObsError::InvalidObsSpec {
511                    reason: format!(
512                        "entry {entry_idx}: pooling requires 2D region, got {}D",
513                        pre_pool_shape.len()
514                    ),
515                });
516            }
517            let h = pre_pool_shape[0];
518            let w = pre_pool_shape[1];
519            let ks = pool.kernel_size;
520            let stride = pool.stride;
521            if ks == 0 || stride == 0 {
522                return Err(ObsError::InvalidObsSpec {
523                    reason: format!("entry {entry_idx}: pool kernel_size and stride must be > 0"),
524                });
525            }
526            let out_h = if h >= ks { (h - ks) / stride + 1 } else { 0 };
527            let out_w = if w >= ks { (w - ks) / stride + 1 } else { 0 };
528            if out_h == 0 || out_w == 0 {
529                return Err(ObsError::InvalidObsSpec {
530                    reason: format!(
531                        "entry {entry_idx}: pool produces empty output \
532                         (region [{h},{w}], kernel_size {ks}, stride {stride})"
533                    ),
534                });
535            }
536            (out_h * out_w, vec![out_h, out_w])
537        } else {
538            (pre_pool_element_count, pre_pool_shape.clone())
539        };
540
541        Ok((
542            AgentCompiledEntry {
543                field_id: entry.field_id,
544                pool: entry.pool.clone(),
545                transform: entry.transform.clone(),
546                dtype: entry.dtype,
547                output_offset,
548                mask_offset,
549                element_count,
550                pre_pool_element_count,
551                pre_pool_shape,
552                active_ops,
553                radius,
554            },
555            output_shape,
556        ))
557    }
558
559    /// Compile with generation binding for PLAN_INVALIDATED detection.
560    ///
561    /// Same as [`compile`](Self::compile) but records the snapshot's
562    /// `world_generation_id` for later validation in [`ObsPlan::execute`].
563    pub fn compile_bound(
564        spec: &ObsSpec,
565        space: &dyn Space,
566        generation: WorldGenerationId,
567    ) -> Result<ObsPlanResult, ObsError> {
568        let mut result = Self::compile(spec, space)?;
569        result.plan.compiled_generation = Some(generation);
570        Ok(result)
571    }
572
573    /// Total number of f32 elements in the output tensor.
574    pub fn output_len(&self) -> usize {
575        self.output_len
576    }
577
578    /// Total number of bytes in the validity mask.
579    pub fn mask_len(&self) -> usize {
580        self.mask_len
581    }
582
583    /// The generation this plan was compiled against, if bound.
584    pub fn compiled_generation(&self) -> Option<WorldGenerationId> {
585        self.compiled_generation
586    }
587
588    /// Execute the observation plan against a snapshot.
589    ///
590    /// Fills `output` with gathered and transformed field values, and
591    /// `mask` with validity flags (1 = valid, 0 = padding). Both
592    /// buffers must be pre-allocated to [`output_len`](Self::output_len)
593    /// and [`mask_len`](Self::mask_len) respectively.
594    ///
595    /// `engine_tick` is the current engine tick for computing
596    /// [`ObsMetadata::age_ticks`]. Pass `None` in Lockstep mode
597    /// (age is always 0). In RealtimeAsync mode, pass the current
598    /// engine tick so age reflects snapshot staleness.
599    ///
600    /// Returns [`ObsMetadata`] on success.
601    ///
602    /// # Errors
603    ///
604    /// - [`ObsError::PlanInvalidated`] if bound and generation mismatches.
605    /// - [`ObsError::ExecutionFailed`] if a field is missing from the snapshot.
606    pub fn execute(
607        &self,
608        snapshot: &dyn SnapshotAccess,
609        engine_tick: Option<TickId>,
610        output: &mut [f32],
611        mask: &mut [u8],
612    ) -> Result<ObsMetadata, ObsError> {
613        let simple = match &self.strategy {
614            PlanStrategy::Simple(data) => data,
615            PlanStrategy::Standard(_) => {
616                return Err(ObsError::ExecutionFailed {
617                    reason: "Standard plan requires execute_agents(), not execute()".into(),
618                });
619            }
620        };
621
622        if output.len() < self.output_len {
623            return Err(ObsError::ExecutionFailed {
624                reason: format!(
625                    "output buffer too small: {} < {}",
626                    output.len(),
627                    self.output_len
628                ),
629            });
630        }
631        if mask.len() < self.mask_len {
632            return Err(ObsError::ExecutionFailed {
633                reason: format!("mask buffer too small: {} < {}", mask.len(), self.mask_len),
634            });
635        }
636
637        // Generation check (PLAN_INVALIDATED).
638        if let Some(compiled_gen) = self.compiled_generation {
639            let snapshot_gen = snapshot.world_generation_id();
640            if compiled_gen != snapshot_gen {
641                return Err(ObsError::PlanInvalidated {
642                    reason: format!(
643                        "plan compiled for generation {}, snapshot is generation {}",
644                        compiled_gen.0, snapshot_gen.0
645                    ),
646                });
647            }
648        }
649
650        Self::execute_simple_entries(&simple.entries, snapshot, output, mask)?;
651
652        let coverage = if simple.total_elements == 0 {
653            0.0
654        } else {
655            simple.total_valid as f64 / simple.total_elements as f64
656        };
657
658        let age_ticks = match engine_tick {
659            Some(tick) => tick.0.saturating_sub(snapshot.tick_id().0),
660            None => 0,
661        };
662
663        Ok(ObsMetadata {
664            tick_id: snapshot.tick_id(),
665            age_ticks,
666            coverage,
667            world_generation_id: snapshot.world_generation_id(),
668            parameter_version: snapshot.parameter_version(),
669        })
670    }
671
672    /// Execute the plan for a batch of `N` identical environments.
673    ///
674    /// Each snapshot in the batch fills `output_len()` elements in the
675    /// output buffer, starting at `batch_idx * output_len()`. Same for
676    /// masks. This is the primary interface for vectorized RL training.
677    ///
678    /// **Partial writes on error:** If generation validation fails on
679    /// snapshot `N > 0`, output slots `0..N-1` are already written.
680    /// Callers should treat the entire batch as invalid on error.
681    ///
682    /// Returns one [`ObsMetadata`] per snapshot.
683    pub fn execute_batch(
684        &self,
685        snapshots: &[&dyn SnapshotAccess],
686        engine_tick: Option<TickId>,
687        output: &mut [f32],
688        mask: &mut [u8],
689    ) -> Result<Vec<ObsMetadata>, ObsError> {
690        let simple = match &self.strategy {
691            PlanStrategy::Simple(data) => data,
692            PlanStrategy::Standard(_) => {
693                return Err(ObsError::ExecutionFailed {
694                    reason: "Standard plan requires execute_agents(), not execute_batch()".into(),
695                });
696            }
697        };
698
699        let batch_size = snapshots.len();
700        let expected_out = batch_size * self.output_len;
701        let expected_mask = batch_size * self.mask_len;
702
703        if output.len() < expected_out {
704            return Err(ObsError::ExecutionFailed {
705                reason: format!(
706                    "batch output buffer too small: {} < {}",
707                    output.len(),
708                    expected_out
709                ),
710            });
711        }
712        if mask.len() < expected_mask {
713            return Err(ObsError::ExecutionFailed {
714                reason: format!(
715                    "batch mask buffer too small: {} < {}",
716                    mask.len(),
717                    expected_mask
718                ),
719            });
720        }
721
722        let coverage = if simple.total_elements == 0 {
723            0.0
724        } else {
725            simple.total_valid as f64 / simple.total_elements as f64
726        };
727
728        let mut metadata = Vec::with_capacity(batch_size);
729        for (i, snap) in snapshots.iter().enumerate() {
730            if let Some(compiled_gen) = self.compiled_generation {
731                let snapshot_gen = snap.world_generation_id();
732                if compiled_gen != snapshot_gen {
733                    return Err(ObsError::PlanInvalidated {
734                        reason: format!(
735                            "plan compiled for generation {}, snapshot is generation {}",
736                            compiled_gen.0, snapshot_gen.0
737                        ),
738                    });
739                }
740            }
741
742            let out_start = i * self.output_len;
743            let mask_start = i * self.mask_len;
744            let out_slice = &mut output[out_start..out_start + self.output_len];
745            let mask_slice = &mut mask[mask_start..mask_start + self.mask_len];
746            Self::execute_simple_entries(&simple.entries, *snap, out_slice, mask_slice)?;
747
748            let age_ticks = match engine_tick {
749                Some(tick) => tick.0.saturating_sub(snap.tick_id().0),
750                None => 0,
751            };
752            metadata.push(ObsMetadata {
753                tick_id: snap.tick_id(),
754                age_ticks,
755                coverage,
756                world_generation_id: snap.world_generation_id(),
757                parameter_version: snap.parameter_version(),
758            });
759        }
760        Ok(metadata)
761    }
762
763    /// Execute the Standard plan for `N` agents in one environment.
764    ///
765    /// Each agent gets `output_len()` elements starting at
766    /// `agent_idx * output_len()`. Fixed entries produce the same
767    /// output for all agents; agent-relative entries are resolved
768    /// per-agent using interior/boundary dispatch.
769    ///
770    /// Interior agents (~49% for 20×20 grid, radius 3) use a branchless
771    /// fast path with stride arithmetic. Boundary agents fall back to
772    /// per-cell bounds checking.
773    pub fn execute_agents(
774        &self,
775        snapshot: &dyn SnapshotAccess,
776        space: &dyn Space,
777        agent_centers: &[Coord],
778        engine_tick: Option<TickId>,
779        output: &mut [f32],
780        mask: &mut [u8],
781    ) -> Result<Vec<ObsMetadata>, ObsError> {
782        let standard = match &self.strategy {
783            PlanStrategy::Standard(data) => data,
784            PlanStrategy::Simple(_) => {
785                return Err(ObsError::ExecutionFailed {
786                    reason: "execute_agents requires a Standard plan \
787                             (spec must contain agent-relative entries)"
788                        .into(),
789                });
790            }
791        };
792
793        let n_agents = agent_centers.len();
794        let expected_out = n_agents * self.output_len;
795        let expected_mask = n_agents * self.mask_len;
796
797        if output.len() < expected_out {
798            return Err(ObsError::ExecutionFailed {
799                reason: format!(
800                    "output buffer too small: {} < {}",
801                    output.len(),
802                    expected_out
803                ),
804            });
805        }
806        if mask.len() < expected_mask {
807            return Err(ObsError::ExecutionFailed {
808                reason: format!("mask buffer too small: {} < {}", mask.len(), expected_mask),
809            });
810        }
811
812        // Validate agent center dimensionality.
813        let expected_ndim = space.ndim();
814        for (i, center) in agent_centers.iter().enumerate() {
815            if center.len() != expected_ndim {
816                return Err(ObsError::ExecutionFailed {
817                    reason: format!(
818                        "agent_centers[{i}] has {} dimensions, but space requires {expected_ndim}",
819                        center.len()
820                    ),
821                });
822            }
823        }
824
825        // Generation check.
826        if let Some(compiled_gen) = self.compiled_generation {
827            let snapshot_gen = snapshot.world_generation_id();
828            if compiled_gen != snapshot_gen {
829                return Err(ObsError::PlanInvalidated {
830                    reason: format!(
831                        "plan compiled for generation {}, snapshot is generation {}",
832                        compiled_gen.0, snapshot_gen.0
833                    ),
834                });
835            }
836        }
837
838        // Pre-read all field data (shared borrows, valid for duration).
839        // Keep vectors aligned with compiled entry vectors to avoid
840        // map lookups on the hot path.
841        let mut fixed_field_data = Vec::with_capacity(standard.fixed_entries.len());
842        for entry in &standard.fixed_entries {
843            let data =
844                snapshot
845                    .read_field(entry.field_id)
846                    .ok_or_else(|| ObsError::ExecutionFailed {
847                        reason: format!("field {:?} not in snapshot", entry.field_id),
848                    })?;
849            fixed_field_data.push(data);
850        }
851        let mut agent_field_data = Vec::with_capacity(standard.agent_entries.len());
852        for entry in &standard.agent_entries {
853            let data =
854                snapshot
855                    .read_field(entry.field_id)
856                    .ok_or_else(|| ObsError::ExecutionFailed {
857                        reason: format!("field {:?} not in snapshot", entry.field_id),
858                    })?;
859            agent_field_data.push(data);
860        }
861
862        // ── Compute fixed entries ONCE (identical for all agents) ──
863        // Allocate scratch buffers for the fixed-entry output and mask, gather
864        // into them, then memcpy per agent. This avoids redundant N*M gathers.
865        let has_fixed = !standard.fixed_entries.is_empty();
866        let mut fixed_out_scratch = if has_fixed {
867            vec![0.0f32; self.output_len]
868        } else {
869            Vec::new()
870        };
871        let mut fixed_mask_scratch = if has_fixed {
872            vec![0u8; self.mask_len]
873        } else {
874            Vec::new()
875        };
876        let mut fixed_valid = 0usize;
877        let mut fixed_elements = 0usize;
878
879        for (entry, field_data) in standard
880            .fixed_entries
881            .iter()
882            .zip(fixed_field_data.iter().copied())
883        {
884            let out_slice = &mut fixed_out_scratch
885                [entry.output_offset..entry.output_offset + entry.element_count];
886            let mask_slice =
887                &mut fixed_mask_scratch[entry.mask_offset..entry.mask_offset + entry.element_count];
888
889            mask_slice.copy_from_slice(&entry.valid_mask);
890            for op in &entry.gather_ops {
891                let raw = *field_data.get(op.field_data_idx).ok_or_else(|| {
892                    ObsError::ExecutionFailed {
893                        reason: format!(
894                            "field {:?} has {} elements but gather requires index {}",
895                            entry.field_id,
896                            field_data.len(),
897                            op.field_data_idx,
898                        ),
899                    }
900                })?;
901                out_slice[op.tensor_idx] = apply_transform(raw, &entry.transform);
902            }
903
904            fixed_valid += entry.valid_count;
905            fixed_elements += entry.element_count;
906        }
907
908        // ── Pre-allocate pooling scratch buffers ──────────────────
909        // Find the maximum pre_pool_element_count across all pooled entries
910        // so we can allocate once and reuse across agents (sequential processing).
911        let max_pool_scratch = standard
912            .agent_entries
913            .iter()
914            .filter(|e| e.pool.is_some())
915            .map(|e| e.pre_pool_element_count)
916            .max()
917            .unwrap_or(0);
918        let max_pool_output = standard
919            .agent_entries
920            .iter()
921            .filter(|e| e.pool.is_some())
922            .map(|e| e.element_count)
923            .max()
924            .unwrap_or(0);
925        let mut pool_scratch = vec![0.0f32; max_pool_scratch];
926        let mut pool_scratch_mask = vec![0u8; max_pool_scratch];
927        let mut pooled_scratch = vec![0.0f32; max_pool_output];
928        let mut pooled_scratch_mask = vec![0u8; max_pool_output];
929
930        let mut metadata = Vec::with_capacity(n_agents);
931
932        for (agent_i, center) in agent_centers.iter().enumerate() {
933            let out_start = agent_i * self.output_len;
934            let mask_start = agent_i * self.mask_len;
935            let agent_output = &mut output[out_start..out_start + self.output_len];
936            let agent_mask = &mut mask[mask_start..mask_start + self.mask_len];
937
938            // Initialize per-agent slices, then stamp fixed entries from
939            // pre-computed scratch (no re-gather).
940            agent_output.fill(0.0);
941            agent_mask.fill(0);
942            if has_fixed {
943                for entry in &standard.fixed_entries {
944                    let out_range = entry.output_offset..entry.output_offset + entry.element_count;
945                    let mask_range = entry.mask_offset..entry.mask_offset + entry.element_count;
946                    agent_output[out_range.clone()].copy_from_slice(&fixed_out_scratch[out_range]);
947                    agent_mask[mask_range.clone()].copy_from_slice(&fixed_mask_scratch[mask_range]);
948                }
949            }
950
951            let mut total_valid = fixed_valid;
952            let mut total_elements = fixed_elements;
953
954            // ── Agent-relative entries ───────────────────────────
955            for (entry, field_data) in standard
956                .agent_entries
957                .iter()
958                .zip(agent_field_data.iter().copied())
959            {
960                // Fast path: stride arithmetic works only for non-wrapping
961                // grids where all cells in the bounding box are in-bounds.
962                // Torus (all_wrap) requires modular arithmetic → slow path.
963                let use_fast_path = standard
964                    .geometry
965                    .as_ref()
966                    .map(|geo| !geo.all_wrap && geo.is_interior(center, entry.radius))
967                    .unwrap_or(false);
968
969                // Zero the pooling scratch region for this entry before reuse.
970                if entry.pool.is_some() {
971                    pool_scratch[..entry.pre_pool_element_count].fill(0.0);
972                    pool_scratch_mask[..entry.pre_pool_element_count].fill(0);
973                }
974
975                let valid = execute_agent_entry(
976                    entry,
977                    center,
978                    field_data,
979                    &standard.geometry,
980                    space,
981                    use_fast_path,
982                    agent_output,
983                    agent_mask,
984                    &mut pool_scratch,
985                    &mut pool_scratch_mask,
986                    &mut pooled_scratch,
987                    &mut pooled_scratch_mask,
988                )?;
989
990                total_valid += valid;
991                total_elements += entry.element_count;
992            }
993
994            let coverage = if total_elements == 0 {
995                0.0
996            } else {
997                total_valid as f64 / total_elements as f64
998            };
999
1000            let age_ticks = match engine_tick {
1001                Some(tick) => tick.0.saturating_sub(snapshot.tick_id().0),
1002                None => 0,
1003            };
1004
1005            metadata.push(ObsMetadata {
1006                tick_id: snapshot.tick_id(),
1007                age_ticks,
1008                coverage,
1009                world_generation_id: snapshot.world_generation_id(),
1010                parameter_version: snapshot.parameter_version(),
1011            });
1012        }
1013
1014        Ok(metadata)
1015    }
1016
1017    /// Whether this plan requires `execute_agents` (Standard) or `execute` (Simple).
1018    pub fn is_standard(&self) -> bool {
1019        matches!(self.strategy, PlanStrategy::Standard(_))
1020    }
1021
1022    /// Execute pre-compiled Simple plan entries into caller-provided buffers.
1023    fn execute_simple_entries(
1024        entries: &[CompiledEntry],
1025        snapshot: &dyn SnapshotAccess,
1026        output: &mut [f32],
1027        mask: &mut [u8],
1028    ) -> Result<(), ObsError> {
1029        for entry in entries {
1030            let field_data =
1031                snapshot
1032                    .read_field(entry.field_id)
1033                    .ok_or_else(|| ObsError::ExecutionFailed {
1034                        reason: format!("field {:?} not in snapshot", entry.field_id),
1035                    })?;
1036
1037            let out_slice =
1038                &mut output[entry.output_offset..entry.output_offset + entry.element_count];
1039            let mask_slice = &mut mask[entry.mask_offset..entry.mask_offset + entry.element_count];
1040
1041            // Initialize to zero/padding.
1042            out_slice.fill(0.0);
1043            mask_slice.copy_from_slice(&entry.valid_mask);
1044
1045            // Branch-free gather: pre-computed (field_data_idx, tensor_idx) pairs.
1046            for op in &entry.gather_ops {
1047                let raw = *field_data.get(op.field_data_idx).ok_or_else(|| {
1048                    ObsError::ExecutionFailed {
1049                        reason: format!(
1050                            "field {:?} has {} elements but gather requires index {}",
1051                            entry.field_id,
1052                            field_data.len(),
1053                            op.field_data_idx,
1054                        ),
1055                    }
1056                })?;
1057                out_slice[op.tensor_idx] = apply_transform(raw, &entry.transform);
1058            }
1059        }
1060        Ok(())
1061    }
1062}
1063
1064/// Execute a single agent-relative entry for one agent.
1065///
1066/// For pooled entries, `pool_scratch` and `pool_scratch_mask` must be
1067/// provided with sufficient capacity (zeroed by the caller). For
1068/// non-pooled entries these are ignored.
1069///
1070/// Returns the number of valid cells written.
1071#[allow(clippy::too_many_arguments)]
1072fn execute_agent_entry(
1073    entry: &AgentCompiledEntry,
1074    center: &Coord,
1075    field_data: &[f32],
1076    geometry: &Option<GridGeometry>,
1077    space: &dyn Space,
1078    use_fast_path: bool,
1079    agent_output: &mut [f32],
1080    agent_mask: &mut [u8],
1081    pool_scratch: &mut [f32],
1082    pool_scratch_mask: &mut [u8],
1083    pooled_scratch: &mut [f32],
1084    pooled_scratch_mask: &mut [u8],
1085) -> Result<usize, ObsError> {
1086    if entry.pool.is_some() {
1087        execute_agent_entry_pooled(
1088            entry,
1089            center,
1090            field_data,
1091            geometry,
1092            space,
1093            use_fast_path,
1094            agent_output,
1095            agent_mask,
1096            &mut pool_scratch[..entry.pre_pool_element_count],
1097            &mut pool_scratch_mask[..entry.pre_pool_element_count],
1098            &mut pooled_scratch[..entry.element_count],
1099            &mut pooled_scratch_mask[..entry.element_count],
1100        )
1101    } else {
1102        Ok(execute_agent_entry_direct(
1103            entry,
1104            center,
1105            field_data,
1106            geometry,
1107            space,
1108            use_fast_path,
1109            agent_output,
1110            agent_mask,
1111        ))
1112    }
1113}
1114
1115/// Direct gather (no pooling): gather + transform → output.
1116#[allow(clippy::too_many_arguments)]
1117fn execute_agent_entry_direct(
1118    entry: &AgentCompiledEntry,
1119    center: &Coord,
1120    field_data: &[f32],
1121    geometry: &Option<GridGeometry>,
1122    space: &dyn Space,
1123    use_fast_path: bool,
1124    agent_output: &mut [f32],
1125    agent_mask: &mut [u8],
1126) -> usize {
1127    let out_slice =
1128        &mut agent_output[entry.output_offset..entry.output_offset + entry.element_count];
1129    let mask_slice = &mut agent_mask[entry.mask_offset..entry.mask_offset + entry.element_count];
1130
1131    if use_fast_path {
1132        // FAST PATH: all cells in-bounds, branchless stride arithmetic.
1133        let geo = geometry.as_ref().unwrap();
1134        let base_rank = geo.canonical_rank(center) as isize;
1135        let mut valid = 0;
1136        for op in &entry.active_ops {
1137            let field_idx = (base_rank + op.stride_offset) as usize;
1138            if let Some(&val) = field_data.get(field_idx) {
1139                out_slice[op.tensor_idx] = apply_transform(val, &entry.transform);
1140                mask_slice[op.tensor_idx] = 1;
1141                valid += 1;
1142            }
1143        }
1144        valid
1145    } else {
1146        // SLOW PATH: bounds-check each offset (or modular wrap for torus).
1147        let mut valid = 0;
1148        for op in &entry.active_ops {
1149            let field_idx = resolve_field_index(center, &op.relative, geometry, space);
1150            if let Some(idx) = field_idx {
1151                if idx < field_data.len() {
1152                    out_slice[op.tensor_idx] = apply_transform(field_data[idx], &entry.transform);
1153                    mask_slice[op.tensor_idx] = 1;
1154                    valid += 1;
1155                }
1156            }
1157        }
1158        valid
1159    }
1160}
1161
1162/// Pooled gather: gather → scratch → pool → transform → output.
1163///
1164/// `scratch` and `scratch_mask` are caller-provided buffers that must be
1165/// at least `entry.pre_pool_element_count` long. They are zeroed by the
1166/// caller before each invocation. This avoids per-agent heap allocation
1167/// when processing many agents sequentially (bug #83).
1168#[allow(clippy::too_many_arguments)]
1169fn execute_agent_entry_pooled(
1170    entry: &AgentCompiledEntry,
1171    center: &Coord,
1172    field_data: &[f32],
1173    geometry: &Option<GridGeometry>,
1174    space: &dyn Space,
1175    use_fast_path: bool,
1176    agent_output: &mut [f32],
1177    agent_mask: &mut [u8],
1178    scratch: &mut [f32],
1179    scratch_mask: &mut [u8],
1180    pooled: &mut [f32],
1181    pooled_mask: &mut [u8],
1182) -> Result<usize, ObsError> {
1183    if use_fast_path {
1184        let geo = geometry.as_ref().unwrap();
1185        let base_rank = geo.canonical_rank(center) as isize;
1186        for op in &entry.active_ops {
1187            let field_idx = (base_rank + op.stride_offset) as usize;
1188            if let Some(&val) = field_data.get(field_idx) {
1189                scratch[op.tensor_idx] = val;
1190                scratch_mask[op.tensor_idx] = 1;
1191            }
1192        }
1193    } else {
1194        for op in &entry.active_ops {
1195            let field_idx = resolve_field_index(center, &op.relative, geometry, space);
1196            if let Some(idx) = field_idx {
1197                if idx < field_data.len() {
1198                    scratch[op.tensor_idx] = field_data[idx];
1199                    scratch_mask[op.tensor_idx] = 1;
1200                }
1201            }
1202        }
1203    }
1204
1205    let pool_config = entry.pool.as_ref().unwrap();
1206    let (out_h, out_w) = pool_2d_into(
1207        scratch,
1208        scratch_mask,
1209        &entry.pre_pool_shape,
1210        pool_config,
1211        pooled,
1212        pooled_mask,
1213    )?;
1214
1215    let out_slice =
1216        &mut agent_output[entry.output_offset..entry.output_offset + entry.element_count];
1217    let mask_slice = &mut agent_mask[entry.mask_offset..entry.mask_offset + entry.element_count];
1218
1219    let n = (out_h * out_w).min(entry.element_count);
1220    for i in 0..n {
1221        out_slice[i] = apply_transform(pooled[i], &entry.transform);
1222    }
1223    mask_slice[..n].copy_from_slice(&pooled_mask[..n]);
1224
1225    Ok(pooled_mask[..n].iter().filter(|&&v| v == 1).count())
1226}
1227
1228/// Generate template operations for a rectangular bounding box.
1229///
1230/// `half_extent[d]` is the half-size per dimension. The bounding box is
1231/// `(2*he[0]+1) × (2*he[1]+1) × ...` in row-major order.
1232///
1233/// If `strides` is provided (from `GridGeometry`), each op gets a precomputed
1234/// `stride_offset` for the interior fast path.
1235///
1236/// If `disk_radius` is `Some(r)`, cells with graph distance > `r` are marked
1237/// `in_disk = false`. The `geometry` is required to compute graph distance.
1238/// When `geometry` is `None`, all cells are treated as in-disk (conservative).
1239fn generate_template_ops(
1240    half_extent: &[u32],
1241    geometry: &Option<GridGeometry>,
1242    disk_radius: Option<u32>,
1243) -> Result<Vec<TemplateOp>, ObsError> {
1244    let ndim = half_extent.len();
1245    let shape: Vec<usize> = half_extent.iter().map(|&he| 2 * he as usize + 1).collect();
1246    let total: usize = shape.iter().product();
1247
1248    let strides = geometry.as_ref().map(|g| g.coord_strides.as_slice());
1249
1250    let mut ops = Vec::with_capacity(total);
1251
1252    for tensor_idx in 0..total {
1253        // Decompose tensor_idx into n-d relative coords (row-major).
1254        let mut relative = Coord::new();
1255        let mut remaining = tensor_idx;
1256        // Build in reverse order, then reverse.
1257        for d in (0..ndim).rev() {
1258            let coord = (remaining % shape[d]) as i32 - half_extent[d] as i32;
1259            relative.push(coord);
1260            remaining /= shape[d];
1261        }
1262        relative.reverse();
1263
1264        let stride_offset = strides
1265            .map(|s| {
1266                relative
1267                    .iter()
1268                    .zip(s.iter())
1269                    .map(|(&r, &s)| r as isize * s as isize)
1270                    .sum::<isize>()
1271            })
1272            .unwrap_or(0);
1273
1274        let in_disk = match disk_radius {
1275            Some(r) => match geometry {
1276                Some(geo) => geo.graph_distance(&relative)? <= r,
1277                None => true, // no geometry → conservative (include all)
1278            },
1279            None => true, // AgentRect → all cells valid
1280        };
1281
1282        ops.push(TemplateOp {
1283            relative,
1284            tensor_idx,
1285            stride_offset,
1286            in_disk,
1287        });
1288    }
1289
1290    Ok(ops)
1291}
1292
1293/// Resolve the field data index for an absolute coordinate.
1294///
1295/// Handles three cases:
1296/// 1. Torus (all_wrap): modular wrap, always in-bounds.
1297/// 2. Grid with geometry: bounds-check then stride arithmetic.
1298/// 3. No geometry: fall back to `space.canonical_rank()`.
1299fn resolve_field_index(
1300    center: &Coord,
1301    relative: &Coord,
1302    geometry: &Option<GridGeometry>,
1303    space: &dyn Space,
1304) -> Option<usize> {
1305    if let Some(geo) = geometry {
1306        if geo.all_wrap {
1307            // Torus: wrap coordinates with modular arithmetic.
1308            let wrapped: Coord = center
1309                .iter()
1310                .zip(relative.iter())
1311                .zip(geo.coord_dims.iter())
1312                .map(|((&c, &r), &d)| {
1313                    let d = d as i32;
1314                    ((c + r) % d + d) % d
1315                })
1316                .collect();
1317            Some(geo.canonical_rank(&wrapped))
1318        } else {
1319            let abs_coord: Coord = center
1320                .iter()
1321                .zip(relative.iter())
1322                .map(|(&c, &r)| c + r)
1323                .collect();
1324            let abs_slice: &[i32] = &abs_coord;
1325            if geo.in_bounds(abs_slice) {
1326                Some(geo.canonical_rank(abs_slice))
1327            } else {
1328                None
1329            }
1330        }
1331    } else {
1332        let abs_coord: Coord = center
1333            .iter()
1334            .zip(relative.iter())
1335            .map(|(&c, &r)| c + r)
1336            .collect();
1337        space.canonical_rank(&abs_coord)
1338    }
1339}
1340
1341/// Apply a transform to a raw field value.
1342fn apply_transform(raw: f32, transform: &ObsTransform) -> f32 {
1343    match transform {
1344        ObsTransform::Identity => raw,
1345        ObsTransform::Normalize { min, max } => {
1346            let range = max - min;
1347            if range == 0.0 {
1348                0.0
1349            } else {
1350                let normalized = (raw as f64 - min) / range;
1351                normalized.clamp(0.0, 1.0) as f32
1352            }
1353        }
1354    }
1355}
1356
1357#[cfg(test)]
1358mod tests {
1359    use super::*;
1360    use crate::spec::{
1361        ObsDtype, ObsEntry, ObsRegion, ObsSpec, ObsTransform, PoolConfig, PoolKernel,
1362    };
1363    use murk_core::{FieldId, ParameterVersion, TickId, WorldGenerationId};
1364    use murk_space::{EdgeBehavior, Hex2D, RegionSpec, Square4, Square8};
1365    use murk_test_utils::MockSnapshot;
1366
1367    fn square4_space() -> Square4 {
1368        Square4::new(3, 3, EdgeBehavior::Absorb).unwrap()
1369    }
1370
1371    fn snapshot_with_field(field: FieldId, data: Vec<f32>) -> MockSnapshot {
1372        let mut snap = MockSnapshot::new(TickId(5), WorldGenerationId(1), ParameterVersion(0));
1373        snap.set_field(field, data);
1374        snap
1375    }
1376
1377    // ── Compilation tests ────────────────────────────────────
1378
1379    #[test]
1380    fn compile_empty_spec_errors() {
1381        let space = square4_space();
1382        let spec = ObsSpec { entries: vec![] };
1383        let err = ObsPlan::compile(&spec, &space).unwrap_err();
1384        assert!(matches!(err, ObsError::InvalidObsSpec { .. }));
1385    }
1386
1387    #[test]
1388    fn compile_all_region_square4() {
1389        let space = square4_space();
1390        let spec = ObsSpec {
1391            entries: vec![ObsEntry {
1392                field_id: FieldId(0),
1393                region: ObsRegion::Fixed(RegionSpec::All),
1394                pool: None,
1395                transform: ObsTransform::Identity,
1396                dtype: ObsDtype::F32,
1397            }],
1398        };
1399        let result = ObsPlan::compile(&spec, &space).unwrap();
1400        assert_eq!(result.output_len, 9); // 3x3
1401        assert_eq!(result.mask_len, 9);
1402        assert_eq!(result.entry_shapes, vec![vec![3, 3]]);
1403    }
1404
1405    #[test]
1406    fn compile_rect_region() {
1407        let space = Square4::new(5, 5, EdgeBehavior::Absorb).unwrap();
1408        let spec = ObsSpec {
1409            entries: vec![ObsEntry {
1410                field_id: FieldId(0),
1411                region: ObsRegion::Fixed(RegionSpec::Rect {
1412                    min: smallvec::smallvec![1, 1],
1413                    max: smallvec::smallvec![2, 3],
1414                }),
1415                pool: None,
1416                transform: ObsTransform::Identity,
1417                dtype: ObsDtype::F32,
1418            }],
1419        };
1420        let result = ObsPlan::compile(&spec, &space).unwrap();
1421        // 2 rows x 3 cols = 6 cells
1422        assert_eq!(result.output_len, 6);
1423        assert_eq!(result.entry_shapes, vec![vec![2, 3]]);
1424    }
1425
1426    #[test]
1427    fn compile_two_entries_offsets() {
1428        let space = square4_space();
1429        let spec = ObsSpec {
1430            entries: vec![
1431                ObsEntry {
1432                    field_id: FieldId(0),
1433                    region: ObsRegion::Fixed(RegionSpec::All),
1434                    pool: None,
1435                    transform: ObsTransform::Identity,
1436                    dtype: ObsDtype::F32,
1437                },
1438                ObsEntry {
1439                    field_id: FieldId(1),
1440                    region: ObsRegion::Fixed(RegionSpec::All),
1441                    pool: None,
1442                    transform: ObsTransform::Identity,
1443                    dtype: ObsDtype::F32,
1444                },
1445            ],
1446        };
1447        let result = ObsPlan::compile(&spec, &space).unwrap();
1448        assert_eq!(result.output_len, 18); // 9 + 9
1449        assert_eq!(result.mask_len, 18);
1450    }
1451
1452    #[test]
1453    fn compile_invalid_region_errors() {
1454        let space = square4_space();
1455        let spec = ObsSpec {
1456            entries: vec![ObsEntry {
1457                field_id: FieldId(0),
1458                region: ObsRegion::Fixed(RegionSpec::Coords(vec![smallvec::smallvec![99, 99]])),
1459                pool: None,
1460                transform: ObsTransform::Identity,
1461                dtype: ObsDtype::F32,
1462            }],
1463        };
1464        let err = ObsPlan::compile(&spec, &space).unwrap_err();
1465        assert!(matches!(err, ObsError::InvalidObsSpec { .. }));
1466    }
1467
1468    #[test]
1469    fn compile_agent_rect_wrong_ndim_errors() {
1470        let space = Square4::new(5, 5, EdgeBehavior::Absorb).unwrap(); // 2D
1471        let spec = ObsSpec {
1472            entries: vec![ObsEntry {
1473                field_id: FieldId(0),
1474                region: ObsRegion::AgentRect {
1475                    half_extent: smallvec::smallvec![1], // 1D on 2D space
1476                },
1477                pool: None,
1478                transform: ObsTransform::Identity,
1479                dtype: ObsDtype::F32,
1480            }],
1481        };
1482        let err = ObsPlan::compile(&spec, &space).unwrap_err();
1483        assert!(matches!(err, ObsError::InvalidObsSpec { .. }));
1484    }
1485
1486    #[test]
1487    fn compile_agent_rect_correct_ndim_ok() {
1488        let space = Square4::new(5, 5, EdgeBehavior::Absorb).unwrap(); // 2D
1489        let spec = ObsSpec {
1490            entries: vec![ObsEntry {
1491                field_id: FieldId(0),
1492                region: ObsRegion::AgentRect {
1493                    half_extent: smallvec::smallvec![1, 2], // 2D on 2D space
1494                },
1495                pool: None,
1496                transform: ObsTransform::Identity,
1497                dtype: ObsDtype::F32,
1498            }],
1499        };
1500        assert!(ObsPlan::compile(&spec, &space).is_ok());
1501    }
1502
1503    // ── Execution tests ──────────────────────────────────────
1504
1505    #[test]
1506    fn execute_identity_all_region() {
1507        let space = square4_space();
1508        // Field data in canonical (row-major) order for 3x3:
1509        // (0,0)=1, (0,1)=2, (0,2)=3, (1,0)=4, ..., (2,2)=9
1510        let data: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1511        let snap = snapshot_with_field(FieldId(0), data);
1512
1513        let spec = ObsSpec {
1514            entries: vec![ObsEntry {
1515                field_id: FieldId(0),
1516                region: ObsRegion::Fixed(RegionSpec::All),
1517                pool: None,
1518                transform: ObsTransform::Identity,
1519                dtype: ObsDtype::F32,
1520            }],
1521        };
1522        let result = ObsPlan::compile(&spec, &space).unwrap();
1523
1524        let mut output = vec![0.0f32; result.output_len];
1525        let mut mask = vec![0u8; result.mask_len];
1526        let meta = result
1527            .plan
1528            .execute(&snap, None, &mut output, &mut mask)
1529            .unwrap();
1530
1531        // Output should match field data in canonical order.
1532        let expected: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1533        assert_eq!(output, expected);
1534        assert_eq!(mask, vec![1u8; 9]);
1535        assert_eq!(meta.tick_id, TickId(5));
1536        assert_eq!(meta.coverage, 1.0);
1537        assert_eq!(meta.world_generation_id, WorldGenerationId(1));
1538        assert_eq!(meta.parameter_version, ParameterVersion(0));
1539        assert_eq!(meta.age_ticks, 0);
1540    }
1541
1542    #[test]
1543    fn execute_normalize_transform() {
1544        let space = square4_space();
1545        // Values 0..8 mapped to [0,1] with min=0, max=8.
1546        let data: Vec<f32> = (0..9).map(|x| x as f32).collect();
1547        let snap = snapshot_with_field(FieldId(0), data);
1548
1549        let spec = ObsSpec {
1550            entries: vec![ObsEntry {
1551                field_id: FieldId(0),
1552                region: ObsRegion::Fixed(RegionSpec::All),
1553                pool: None,
1554                transform: ObsTransform::Normalize { min: 0.0, max: 8.0 },
1555                dtype: ObsDtype::F32,
1556            }],
1557        };
1558        let result = ObsPlan::compile(&spec, &space).unwrap();
1559
1560        let mut output = vec![0.0f32; result.output_len];
1561        let mut mask = vec![0u8; result.mask_len];
1562        result
1563            .plan
1564            .execute(&snap, None, &mut output, &mut mask)
1565            .unwrap();
1566
1567        // Each value x should be x/8.
1568        for (i, &v) in output.iter().enumerate() {
1569            let expected = i as f32 / 8.0;
1570            assert!(
1571                (v - expected).abs() < 1e-6,
1572                "output[{i}] = {v}, expected {expected}"
1573            );
1574        }
1575    }
1576
1577    #[test]
1578    fn execute_normalize_clamps_out_of_range() {
1579        let space = square4_space();
1580        // Values -5, 0, 5, 10, 15 etc.
1581        let data: Vec<f32> = (-4..5).map(|x| x as f32 * 5.0).collect();
1582        let snap = snapshot_with_field(FieldId(0), data);
1583
1584        let spec = ObsSpec {
1585            entries: vec![ObsEntry {
1586                field_id: FieldId(0),
1587                region: ObsRegion::Fixed(RegionSpec::All),
1588                pool: None,
1589                transform: ObsTransform::Normalize {
1590                    min: 0.0,
1591                    max: 10.0,
1592                },
1593                dtype: ObsDtype::F32,
1594            }],
1595        };
1596        let result = ObsPlan::compile(&spec, &space).unwrap();
1597
1598        let mut output = vec![0.0f32; result.output_len];
1599        let mut mask = vec![0u8; result.mask_len];
1600        result
1601            .plan
1602            .execute(&snap, None, &mut output, &mut mask)
1603            .unwrap();
1604
1605        for &v in &output {
1606            assert!((0.0..=1.0).contains(&v), "value {v} out of [0,1] range");
1607        }
1608    }
1609
1610    #[test]
1611    fn execute_normalize_zero_range() {
1612        let space = square4_space();
1613        let data = vec![5.0f32; 9];
1614        let snap = snapshot_with_field(FieldId(0), data);
1615
1616        let spec = ObsSpec {
1617            entries: vec![ObsEntry {
1618                field_id: FieldId(0),
1619                region: ObsRegion::Fixed(RegionSpec::All),
1620                pool: None,
1621                transform: ObsTransform::Normalize { min: 5.0, max: 5.0 },
1622                dtype: ObsDtype::F32,
1623            }],
1624        };
1625        let result = ObsPlan::compile(&spec, &space).unwrap();
1626
1627        let mut output = vec![-1.0f32; result.output_len];
1628        let mut mask = vec![0u8; result.mask_len];
1629        result
1630            .plan
1631            .execute(&snap, None, &mut output, &mut mask)
1632            .unwrap();
1633
1634        // Zero range → all outputs 0.0.
1635        assert!(output.iter().all(|&v| v == 0.0));
1636    }
1637
1638    #[test]
1639    fn execute_rect_subregion_correct_values() {
1640        let space = Square4::new(4, 4, EdgeBehavior::Absorb).unwrap();
1641        // 4x4 field: value = row * 4 + col + 1.
1642        let data: Vec<f32> = (1..=16).map(|x| x as f32).collect();
1643        let snap = snapshot_with_field(FieldId(0), data);
1644
1645        let spec = ObsSpec {
1646            entries: vec![ObsEntry {
1647                field_id: FieldId(0),
1648                region: ObsRegion::Fixed(RegionSpec::Rect {
1649                    min: smallvec::smallvec![1, 1],
1650                    max: smallvec::smallvec![2, 2],
1651                }),
1652                pool: None,
1653                transform: ObsTransform::Identity,
1654                dtype: ObsDtype::F32,
1655            }],
1656        };
1657        let result = ObsPlan::compile(&spec, &space).unwrap();
1658        assert_eq!(result.output_len, 4); // 2x2
1659
1660        let mut output = vec![0.0f32; result.output_len];
1661        let mut mask = vec![0u8; result.mask_len];
1662        result
1663            .plan
1664            .execute(&snap, None, &mut output, &mut mask)
1665            .unwrap();
1666
1667        // Rect covers (1,1)=6, (1,2)=7, (2,1)=10, (2,2)=11
1668        assert_eq!(output, vec![6.0, 7.0, 10.0, 11.0]);
1669        assert_eq!(mask, vec![1, 1, 1, 1]);
1670    }
1671
1672    #[test]
1673    fn execute_two_fields() {
1674        let space = square4_space();
1675        let data_a: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1676        let data_b: Vec<f32> = (10..=18).map(|x| x as f32).collect();
1677        let mut snap = MockSnapshot::new(TickId(1), WorldGenerationId(1), ParameterVersion(0));
1678        snap.set_field(FieldId(0), data_a);
1679        snap.set_field(FieldId(1), data_b);
1680
1681        let spec = ObsSpec {
1682            entries: vec![
1683                ObsEntry {
1684                    field_id: FieldId(0),
1685                    region: ObsRegion::Fixed(RegionSpec::All),
1686                    pool: None,
1687                    transform: ObsTransform::Identity,
1688                    dtype: ObsDtype::F32,
1689                },
1690                ObsEntry {
1691                    field_id: FieldId(1),
1692                    region: ObsRegion::Fixed(RegionSpec::All),
1693                    pool: None,
1694                    transform: ObsTransform::Identity,
1695                    dtype: ObsDtype::F32,
1696                },
1697            ],
1698        };
1699        let result = ObsPlan::compile(&spec, &space).unwrap();
1700        assert_eq!(result.output_len, 18);
1701
1702        let mut output = vec![0.0f32; result.output_len];
1703        let mut mask = vec![0u8; result.mask_len];
1704        result
1705            .plan
1706            .execute(&snap, None, &mut output, &mut mask)
1707            .unwrap();
1708
1709        // First 9: field 0, next 9: field 1.
1710        let expected_a: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1711        let expected_b: Vec<f32> = (10..=18).map(|x| x as f32).collect();
1712        assert_eq!(&output[..9], &expected_a);
1713        assert_eq!(&output[9..], &expected_b);
1714    }
1715
1716    #[test]
1717    fn execute_missing_field_errors() {
1718        let space = square4_space();
1719        let snap = MockSnapshot::new(TickId(1), WorldGenerationId(1), ParameterVersion(0));
1720
1721        let spec = ObsSpec {
1722            entries: vec![ObsEntry {
1723                field_id: FieldId(0),
1724                region: ObsRegion::Fixed(RegionSpec::All),
1725                pool: None,
1726                transform: ObsTransform::Identity,
1727                dtype: ObsDtype::F32,
1728            }],
1729        };
1730        let result = ObsPlan::compile(&spec, &space).unwrap();
1731
1732        let mut output = vec![0.0f32; result.output_len];
1733        let mut mask = vec![0u8; result.mask_len];
1734        let err = result
1735            .plan
1736            .execute(&snap, None, &mut output, &mut mask)
1737            .unwrap_err();
1738        assert!(matches!(err, ObsError::ExecutionFailed { .. }));
1739    }
1740
1741    #[test]
1742    fn execute_buffer_too_small_errors() {
1743        let space = square4_space();
1744        let data: Vec<f32> = vec![0.0; 9];
1745        let snap = snapshot_with_field(FieldId(0), data);
1746
1747        let spec = ObsSpec {
1748            entries: vec![ObsEntry {
1749                field_id: FieldId(0),
1750                region: ObsRegion::Fixed(RegionSpec::All),
1751                pool: None,
1752                transform: ObsTransform::Identity,
1753                dtype: ObsDtype::F32,
1754            }],
1755        };
1756        let result = ObsPlan::compile(&spec, &space).unwrap();
1757
1758        let mut output = vec![0.0f32; 4]; // too small
1759        let mut mask = vec![0u8; result.mask_len];
1760        let err = result
1761            .plan
1762            .execute(&snap, None, &mut output, &mut mask)
1763            .unwrap_err();
1764        assert!(matches!(err, ObsError::ExecutionFailed { .. }));
1765    }
1766
1767    // ── Validity / coverage tests ────────────────────────────
1768
1769    #[test]
1770    fn valid_ratio_one_for_square_all() {
1771        let space = square4_space();
1772        let data: Vec<f32> = vec![1.0; 9];
1773        let snap = snapshot_with_field(FieldId(0), data);
1774
1775        let spec = ObsSpec {
1776            entries: vec![ObsEntry {
1777                field_id: FieldId(0),
1778                region: ObsRegion::Fixed(RegionSpec::All),
1779                pool: None,
1780                transform: ObsTransform::Identity,
1781                dtype: ObsDtype::F32,
1782            }],
1783        };
1784        let result = ObsPlan::compile(&spec, &space).unwrap();
1785
1786        let mut output = vec![0.0f32; result.output_len];
1787        let mut mask = vec![0u8; result.mask_len];
1788        let meta = result
1789            .plan
1790            .execute(&snap, None, &mut output, &mut mask)
1791            .unwrap();
1792
1793        assert_eq!(meta.coverage, 1.0);
1794    }
1795
1796    // ── Generation binding tests ─────────────────────────────
1797
1798    #[test]
1799    fn plan_invalidated_on_generation_mismatch() {
1800        let space = square4_space();
1801        let data: Vec<f32> = vec![1.0; 9];
1802        let snap = snapshot_with_field(FieldId(0), data);
1803
1804        let spec = ObsSpec {
1805            entries: vec![ObsEntry {
1806                field_id: FieldId(0),
1807                region: ObsRegion::Fixed(RegionSpec::All),
1808                pool: None,
1809                transform: ObsTransform::Identity,
1810                dtype: ObsDtype::F32,
1811            }],
1812        };
1813        // Compile bound to generation 99, but snapshot is generation 1.
1814        let result = ObsPlan::compile_bound(&spec, &space, WorldGenerationId(99)).unwrap();
1815
1816        let mut output = vec![0.0f32; result.output_len];
1817        let mut mask = vec![0u8; result.mask_len];
1818        let err = result
1819            .plan
1820            .execute(&snap, None, &mut output, &mut mask)
1821            .unwrap_err();
1822        assert!(matches!(err, ObsError::PlanInvalidated { .. }));
1823    }
1824
1825    #[test]
1826    fn generation_match_succeeds() {
1827        let space = square4_space();
1828        let data: Vec<f32> = vec![1.0; 9];
1829        let snap = snapshot_with_field(FieldId(0), data);
1830
1831        let spec = ObsSpec {
1832            entries: vec![ObsEntry {
1833                field_id: FieldId(0),
1834                region: ObsRegion::Fixed(RegionSpec::All),
1835                pool: None,
1836                transform: ObsTransform::Identity,
1837                dtype: ObsDtype::F32,
1838            }],
1839        };
1840        let result = ObsPlan::compile_bound(&spec, &space, WorldGenerationId(1)).unwrap();
1841
1842        let mut output = vec![0.0f32; result.output_len];
1843        let mut mask = vec![0u8; result.mask_len];
1844        result
1845            .plan
1846            .execute(&snap, None, &mut output, &mut mask)
1847            .unwrap();
1848    }
1849
1850    #[test]
1851    fn unbound_plan_ignores_generation() {
1852        let space = square4_space();
1853        let data: Vec<f32> = vec![1.0; 9];
1854        let snap = snapshot_with_field(FieldId(0), data);
1855
1856        let spec = ObsSpec {
1857            entries: vec![ObsEntry {
1858                field_id: FieldId(0),
1859                region: ObsRegion::Fixed(RegionSpec::All),
1860                pool: None,
1861                transform: ObsTransform::Identity,
1862                dtype: ObsDtype::F32,
1863            }],
1864        };
1865        // Unbound plan — no generation check.
1866        let result = ObsPlan::compile(&spec, &space).unwrap();
1867
1868        let mut output = vec![0.0f32; result.output_len];
1869        let mut mask = vec![0u8; result.mask_len];
1870        result
1871            .plan
1872            .execute(&snap, None, &mut output, &mut mask)
1873            .unwrap();
1874    }
1875
1876    // ── Metadata tests ───────────────────────────────────────
1877
1878    #[test]
1879    fn metadata_fields_populated() {
1880        let space = square4_space();
1881        let data: Vec<f32> = vec![1.0; 9];
1882        let mut snap = MockSnapshot::new(TickId(42), WorldGenerationId(7), ParameterVersion(3));
1883        snap.set_field(FieldId(0), data);
1884
1885        let spec = ObsSpec {
1886            entries: vec![ObsEntry {
1887                field_id: FieldId(0),
1888                region: ObsRegion::Fixed(RegionSpec::All),
1889                pool: None,
1890                transform: ObsTransform::Identity,
1891                dtype: ObsDtype::F32,
1892            }],
1893        };
1894        let result = ObsPlan::compile(&spec, &space).unwrap();
1895
1896        let mut output = vec![0.0f32; result.output_len];
1897        let mut mask = vec![0u8; result.mask_len];
1898        let meta = result
1899            .plan
1900            .execute(&snap, None, &mut output, &mut mask)
1901            .unwrap();
1902
1903        assert_eq!(meta.tick_id, TickId(42));
1904        assert_eq!(meta.age_ticks, 0);
1905        assert_eq!(meta.coverage, 1.0);
1906        assert_eq!(meta.world_generation_id, WorldGenerationId(7));
1907        assert_eq!(meta.parameter_version, ParameterVersion(3));
1908    }
1909
1910    // ── Batch execution tests ────────────────────────────────
1911
1912    #[test]
1913    fn execute_batch_n1_matches_execute() {
1914        let space = square4_space();
1915        let data: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1916        let snap = snapshot_with_field(FieldId(0), data.clone());
1917
1918        let spec = ObsSpec {
1919            entries: vec![ObsEntry {
1920                field_id: FieldId(0),
1921                region: ObsRegion::Fixed(RegionSpec::All),
1922                pool: None,
1923                transform: ObsTransform::Identity,
1924                dtype: ObsDtype::F32,
1925            }],
1926        };
1927        let result = ObsPlan::compile(&spec, &space).unwrap();
1928
1929        // Single execute.
1930        let mut out_single = vec![0.0f32; result.output_len];
1931        let mut mask_single = vec![0u8; result.mask_len];
1932        let meta_single = result
1933            .plan
1934            .execute(&snap, None, &mut out_single, &mut mask_single)
1935            .unwrap();
1936
1937        // Batch N=1.
1938        let mut out_batch = vec![0.0f32; result.output_len];
1939        let mut mask_batch = vec![0u8; result.mask_len];
1940        let snap_ref: &dyn SnapshotAccess = &snap;
1941        let meta_batch = result
1942            .plan
1943            .execute_batch(&[snap_ref], None, &mut out_batch, &mut mask_batch)
1944            .unwrap();
1945
1946        assert_eq!(out_single, out_batch);
1947        assert_eq!(mask_single, mask_batch);
1948        assert_eq!(meta_single, meta_batch[0]);
1949    }
1950
1951    #[test]
1952    fn execute_batch_multiple_snapshots() {
1953        let space = square4_space();
1954        let spec = ObsSpec {
1955            entries: vec![ObsEntry {
1956                field_id: FieldId(0),
1957                region: ObsRegion::Fixed(RegionSpec::All),
1958                pool: None,
1959                transform: ObsTransform::Identity,
1960                dtype: ObsDtype::F32,
1961            }],
1962        };
1963        let result = ObsPlan::compile(&spec, &space).unwrap();
1964
1965        let snap_a = snapshot_with_field(FieldId(0), vec![1.0; 9]);
1966        let snap_b = snapshot_with_field(FieldId(0), vec![2.0; 9]);
1967
1968        let snaps: Vec<&dyn SnapshotAccess> = vec![&snap_a, &snap_b];
1969        let mut output = vec![0.0f32; result.output_len * 2];
1970        let mut mask = vec![0u8; result.mask_len * 2];
1971        let metas = result
1972            .plan
1973            .execute_batch(&snaps, None, &mut output, &mut mask)
1974            .unwrap();
1975
1976        assert_eq!(metas.len(), 2);
1977        assert!(output[..9].iter().all(|&v| v == 1.0));
1978        assert!(output[9..].iter().all(|&v| v == 2.0));
1979    }
1980
1981    #[test]
1982    fn execute_batch_buffer_too_small() {
1983        let space = square4_space();
1984        let spec = ObsSpec {
1985            entries: vec![ObsEntry {
1986                field_id: FieldId(0),
1987                region: ObsRegion::Fixed(RegionSpec::All),
1988                pool: None,
1989                transform: ObsTransform::Identity,
1990                dtype: ObsDtype::F32,
1991            }],
1992        };
1993        let result = ObsPlan::compile(&spec, &space).unwrap();
1994
1995        let snap = snapshot_with_field(FieldId(0), vec![1.0; 9]);
1996        let snaps: Vec<&dyn SnapshotAccess> = vec![&snap, &snap];
1997        let mut output = vec![0.0f32; 9]; // need 18
1998        let mut mask = vec![0u8; 18];
1999        let err = result
2000            .plan
2001            .execute_batch(&snaps, None, &mut output, &mut mask)
2002            .unwrap_err();
2003        assert!(matches!(err, ObsError::ExecutionFailed { .. }));
2004    }
2005
2006    #[test]
2007    fn execute_batch_rejects_mismatched_generation() {
2008        let space = square4_space();
2009        let spec = ObsSpec {
2010            entries: vec![ObsEntry {
2011                field_id: FieldId(0),
2012                region: ObsRegion::Fixed(RegionSpec::All),
2013                pool: None,
2014                transform: ObsTransform::Identity,
2015                dtype: ObsDtype::F32,
2016            }],
2017        };
2018        // Compile bound to generation 5.
2019        let result = ObsPlan::compile_bound(&spec, &space, WorldGenerationId(5)).unwrap();
2020
2021        // Snapshot has generation 1 (default from snapshot_with_field).
2022        let snap = snapshot_with_field(FieldId(0), vec![1.0; 9]);
2023        let snaps: Vec<&dyn SnapshotAccess> = vec![&snap];
2024        let mut output = vec![0.0f32; result.output_len];
2025        let mut mask = vec![0u8; result.mask_len];
2026        let err = result
2027            .plan
2028            .execute_batch(&snaps, None, &mut output, &mut mask)
2029            .unwrap_err();
2030        assert!(matches!(err, ObsError::PlanInvalidated { .. }));
2031    }
2032
2033    // ── Field length mismatch tests ──────────────────────────
2034
2035    #[test]
2036    fn short_field_buffer_returns_error_not_panic() {
2037        let space = square4_space(); // 3x3 = 9 cells
2038        let spec = ObsSpec {
2039            entries: vec![ObsEntry {
2040                field_id: FieldId(0),
2041                region: ObsRegion::Fixed(RegionSpec::All),
2042                pool: None,
2043                transform: ObsTransform::Identity,
2044                dtype: ObsDtype::F32,
2045            }],
2046        };
2047        let result = ObsPlan::compile(&spec, &space).unwrap();
2048
2049        // Snapshot field has only 4 elements, but plan expects 9.
2050        let snap = snapshot_with_field(FieldId(0), vec![1.0; 4]);
2051        let mut output = vec![0.0f32; result.output_len];
2052        let mut mask = vec![0u8; result.mask_len];
2053        let err = result
2054            .plan
2055            .execute(&snap, None, &mut output, &mut mask)
2056            .unwrap_err();
2057        assert!(matches!(err, ObsError::ExecutionFailed { .. }));
2058    }
2059
2060    // ── Standard plan (agent-centered) tests ─────────────────
2061
2062    #[test]
2063    fn standard_plan_detected_from_agent_region() {
2064        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2065        let spec = ObsSpec {
2066            entries: vec![ObsEntry {
2067                field_id: FieldId(0),
2068                region: ObsRegion::AgentRect {
2069                    half_extent: smallvec::smallvec![2, 2],
2070                },
2071                pool: None,
2072                transform: ObsTransform::Identity,
2073                dtype: ObsDtype::F32,
2074            }],
2075        };
2076        let result = ObsPlan::compile(&spec, &space).unwrap();
2077        assert!(result.plan.is_standard());
2078        // 5x5 = 25 elements
2079        assert_eq!(result.output_len, 25);
2080        assert_eq!(result.entry_shapes, vec![vec![5, 5]]);
2081    }
2082
2083    #[test]
2084    fn execute_on_standard_plan_errors() {
2085        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2086        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2087        let snap = snapshot_with_field(FieldId(0), data);
2088
2089        let spec = ObsSpec {
2090            entries: vec![ObsEntry {
2091                field_id: FieldId(0),
2092                region: ObsRegion::AgentDisk { radius: 2 },
2093                pool: None,
2094                transform: ObsTransform::Identity,
2095                dtype: ObsDtype::F32,
2096            }],
2097        };
2098        let result = ObsPlan::compile(&spec, &space).unwrap();
2099
2100        let mut output = vec![0.0f32; result.output_len];
2101        let mut mask = vec![0u8; result.mask_len];
2102        let err = result
2103            .plan
2104            .execute(&snap, None, &mut output, &mut mask)
2105            .unwrap_err();
2106        assert!(matches!(err, ObsError::ExecutionFailed { .. }));
2107    }
2108
2109    #[test]
2110    fn interior_boundary_equivalence() {
2111        // An INTERIOR agent using Standard plan should produce identical
2112        // output to a Simple plan with an explicit Rect at the same position.
2113        let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
2114        let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
2115        let snap = snapshot_with_field(FieldId(0), data);
2116
2117        let radius = 3u32;
2118        let center: Coord = smallvec::smallvec![10, 10]; // interior
2119
2120        // Standard plan: AgentRect centered on agent.
2121        let standard_spec = ObsSpec {
2122            entries: vec![ObsEntry {
2123                field_id: FieldId(0),
2124                region: ObsRegion::AgentRect {
2125                    half_extent: smallvec::smallvec![radius, radius],
2126                },
2127                pool: None,
2128                transform: ObsTransform::Identity,
2129                dtype: ObsDtype::F32,
2130            }],
2131        };
2132        let std_result = ObsPlan::compile(&standard_spec, &space).unwrap();
2133        let mut std_output = vec![0.0f32; std_result.output_len];
2134        let mut std_mask = vec![0u8; std_result.mask_len];
2135        std_result
2136            .plan
2137            .execute_agents(
2138                &snap,
2139                &space,
2140                std::slice::from_ref(&center),
2141                None,
2142                &mut std_output,
2143                &mut std_mask,
2144            )
2145            .unwrap();
2146
2147        // Simple plan: explicit Rect covering the same area.
2148        let r = radius as i32;
2149        let simple_spec = ObsSpec {
2150            entries: vec![ObsEntry {
2151                field_id: FieldId(0),
2152                region: ObsRegion::Fixed(RegionSpec::Rect {
2153                    min: smallvec::smallvec![10 - r, 10 - r],
2154                    max: smallvec::smallvec![10 + r, 10 + r],
2155                }),
2156                pool: None,
2157                transform: ObsTransform::Identity,
2158                dtype: ObsDtype::F32,
2159            }],
2160        };
2161        let simple_result = ObsPlan::compile(&simple_spec, &space).unwrap();
2162        let mut simple_output = vec![0.0f32; simple_result.output_len];
2163        let mut simple_mask = vec![0u8; simple_result.mask_len];
2164        simple_result
2165            .plan
2166            .execute(&snap, None, &mut simple_output, &mut simple_mask)
2167            .unwrap();
2168
2169        // Same shape, same values.
2170        assert_eq!(std_result.output_len, simple_result.output_len);
2171        assert_eq!(std_output, simple_output);
2172        assert_eq!(std_mask, simple_mask);
2173    }
2174
2175    #[test]
2176    fn boundary_agent_gets_padding() {
2177        // Agent at corner (0,0) with radius 2: many cells out-of-bounds.
2178        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2179        let data: Vec<f32> = (0..100).map(|x| x as f32 + 1.0).collect();
2180        let snap = snapshot_with_field(FieldId(0), data);
2181
2182        let spec = ObsSpec {
2183            entries: vec![ObsEntry {
2184                field_id: FieldId(0),
2185                region: ObsRegion::AgentRect {
2186                    half_extent: smallvec::smallvec![2, 2],
2187                },
2188                pool: None,
2189                transform: ObsTransform::Identity,
2190                dtype: ObsDtype::F32,
2191            }],
2192        };
2193        let result = ObsPlan::compile(&spec, &space).unwrap();
2194        let center: Coord = smallvec::smallvec![0, 0];
2195        let mut output = vec![0.0f32; result.output_len];
2196        let mut mask = vec![0u8; result.mask_len];
2197        let metas = result
2198            .plan
2199            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2200            .unwrap();
2201
2202        // 5x5 = 25 cells total. Agent at (0,0) with radius 2:
2203        // Only cells with row in [0,2] and col in [0,2] are valid (3x3 = 9).
2204        let valid_count: usize = mask.iter().filter(|&&v| v == 1).count();
2205        assert_eq!(valid_count, 9);
2206
2207        // Coverage should be 9/25
2208        assert!((metas[0].coverage - 9.0 / 25.0).abs() < 1e-6);
2209
2210        // Check that the top-left corner of the bounding box (relative [-2,-2])
2211        // is padding (mask=0, value=0).
2212        assert_eq!(mask[0], 0); // relative (-2,-2) → absolute (-2,-2) → out of bounds
2213        assert_eq!(output[0], 0.0);
2214
2215        // The cell at relative (0,0) is at tensor_idx = 2*5+2 = 12
2216        // Absolute (0,0) → field value = 1.0
2217        assert_eq!(mask[12], 1);
2218        assert_eq!(output[12], 1.0);
2219    }
2220
2221    #[test]
2222    fn hex_foveation_interior() {
2223        // Test agent-centered observation on Hex2D grid.
2224        let space = Hex2D::new(20, 20).unwrap(); // 20 rows, 20 cols
2225        let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
2226        let snap = snapshot_with_field(FieldId(0), data);
2227
2228        let spec = ObsSpec {
2229            entries: vec![ObsEntry {
2230                field_id: FieldId(0),
2231                region: ObsRegion::AgentDisk { radius: 2 },
2232                pool: None,
2233                transform: ObsTransform::Identity,
2234                dtype: ObsDtype::F32,
2235            }],
2236        };
2237        let result = ObsPlan::compile(&spec, &space).unwrap();
2238        assert_eq!(result.output_len, 25); // 5x5 bounding box (tensor shape)
2239
2240        // Interior agent: q=10, r=10
2241        let center: Coord = smallvec::smallvec![10, 10];
2242        let mut output = vec![0.0f32; result.output_len];
2243        let mut mask = vec![0u8; result.mask_len];
2244        result
2245            .plan
2246            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2247            .unwrap();
2248
2249        // Hex disk of radius 2: 19 of 25 cells are within hex distance.
2250        // The 6 corners of the 5x5 bounding box exceed hex distance 2.
2251        // Hex distance = max(|dq|, |dr|, |dq+dr|) for axial coordinates.
2252        let valid_count = mask.iter().filter(|&&v| v == 1).count();
2253        assert_eq!(valid_count, 19);
2254
2255        // Corners that should be masked out (distance > 2):
2256        // tensor_idx 0: dq=-2,dr=-2 → max(2,2,4)=4
2257        // tensor_idx 1: dq=-2,dr=-1 → max(2,1,3)=3
2258        // tensor_idx 5: dq=-1,dr=-2 → max(1,2,3)=3
2259        // tensor_idx 19: dq=+1,dr=+2 → max(1,2,3)=3
2260        // tensor_idx 23: dq=+2,dr=+1 → max(2,1,3)=3
2261        // tensor_idx 24: dq=+2,dr=+2 → max(2,2,4)=4
2262        for &idx in &[0, 1, 5, 19, 23, 24] {
2263            assert_eq!(mask[idx], 0, "tensor_idx {idx} should be outside hex disk");
2264            assert_eq!(output[idx], 0.0, "tensor_idx {idx} should be zero-padded");
2265        }
2266
2267        // Center cell is at tensor_idx = 2*5+2 = 12 (relative [0,0]).
2268        // Hex2D canonical_rank([q,r]) = r*cols + q = 10*20 + 10 = 210
2269        assert_eq!(output[12], 210.0);
2270
2271        // Cell at relative [1, 0] (dq=+1, dr=0) → absolute [11, 10]
2272        // rank = 10*20 + 11 = 211
2273        // In row-major bounding box: dim0_idx=3, dim1_idx=2 → tensor_idx = 3*5+2 = 17
2274        assert_eq!(output[17], 211.0);
2275    }
2276
2277    #[test]
2278    fn wrap_space_all_interior() {
2279        // Wrapped (torus) space: all agents are interior.
2280        let space = Square4::new(10, 10, EdgeBehavior::Wrap).unwrap();
2281        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2282        let snap = snapshot_with_field(FieldId(0), data);
2283
2284        let spec = ObsSpec {
2285            entries: vec![ObsEntry {
2286                field_id: FieldId(0),
2287                region: ObsRegion::AgentRect {
2288                    half_extent: smallvec::smallvec![2, 2],
2289                },
2290                pool: None,
2291                transform: ObsTransform::Identity,
2292                dtype: ObsDtype::F32,
2293            }],
2294        };
2295        let result = ObsPlan::compile(&spec, &space).unwrap();
2296
2297        // Agent at corner (0,0) — still interior on torus.
2298        let center: Coord = smallvec::smallvec![0, 0];
2299        let mut output = vec![0.0f32; result.output_len];
2300        let mut mask = vec![0u8; result.mask_len];
2301        result
2302            .plan
2303            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2304            .unwrap();
2305
2306        // All 25 cells valid (torus wraps).
2307        assert!(mask.iter().all(|&v| v == 1));
2308        assert_eq!(output[12], 0.0); // center (0,0) → rank 0
2309    }
2310
2311    #[test]
2312    fn execute_agents_multiple() {
2313        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2314        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2315        let snap = snapshot_with_field(FieldId(0), data);
2316
2317        let spec = ObsSpec {
2318            entries: vec![ObsEntry {
2319                field_id: FieldId(0),
2320                region: ObsRegion::AgentRect {
2321                    half_extent: smallvec::smallvec![1, 1],
2322                },
2323                pool: None,
2324                transform: ObsTransform::Identity,
2325                dtype: ObsDtype::F32,
2326            }],
2327        };
2328        let result = ObsPlan::compile(&spec, &space).unwrap();
2329        assert_eq!(result.output_len, 9); // 3x3
2330
2331        // Two agents: one interior, one at edge.
2332        let centers = vec![
2333            smallvec::smallvec![5, 5], // interior
2334            smallvec::smallvec![0, 5], // top edge
2335        ];
2336        let n = centers.len();
2337        let mut output = vec![0.0f32; result.output_len * n];
2338        let mut mask = vec![0u8; result.mask_len * n];
2339        let metas = result
2340            .plan
2341            .execute_agents(&snap, &space, &centers, None, &mut output, &mut mask)
2342            .unwrap();
2343
2344        assert_eq!(metas.len(), 2);
2345
2346        // Agent 0 (interior): all 9 cells valid, center = (5,5) → rank 55
2347        assert!(mask[..9].iter().all(|&v| v == 1));
2348        assert_eq!(output[4], 55.0); // center at tensor_idx = 1*3+1 = 4
2349
2350        // Agent 1 (top edge): row -1 is out of bounds → 3 cells masked
2351        let agent1_mask = &mask[9..18];
2352        let valid_count: usize = agent1_mask.iter().filter(|&&v| v == 1).count();
2353        assert_eq!(valid_count, 6); // 2 rows in-bounds × 3 cols
2354    }
2355
2356    #[test]
2357    fn execute_agents_with_normalize() {
2358        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2359        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2360        let snap = snapshot_with_field(FieldId(0), data);
2361
2362        let spec = ObsSpec {
2363            entries: vec![ObsEntry {
2364                field_id: FieldId(0),
2365                region: ObsRegion::AgentRect {
2366                    half_extent: smallvec::smallvec![1, 1],
2367                },
2368                pool: None,
2369                transform: ObsTransform::Normalize {
2370                    min: 0.0,
2371                    max: 99.0,
2372                },
2373                dtype: ObsDtype::F32,
2374            }],
2375        };
2376        let result = ObsPlan::compile(&spec, &space).unwrap();
2377
2378        let center: Coord = smallvec::smallvec![5, 5];
2379        let mut output = vec![0.0f32; result.output_len];
2380        let mut mask = vec![0u8; result.mask_len];
2381        result
2382            .plan
2383            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2384            .unwrap();
2385
2386        // Center (5,5) rank=55, normalized = 55/99 ≈ 0.5556
2387        let expected = 55.0 / 99.0;
2388        assert!((output[4] - expected as f32).abs() < 1e-5);
2389    }
2390
2391    #[test]
2392    fn execute_agents_with_pooling() {
2393        let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
2394        let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
2395        let snap = snapshot_with_field(FieldId(0), data);
2396
2397        // AgentRect with half_extent=3 → 7x7 bounding box.
2398        // Mean pool 2x2 stride 2 → floor((7-2)/2)+1 = 3 per dim → 3x3 = 9 output.
2399        let spec = ObsSpec {
2400            entries: vec![ObsEntry {
2401                field_id: FieldId(0),
2402                region: ObsRegion::AgentRect {
2403                    half_extent: smallvec::smallvec![3, 3],
2404                },
2405                pool: Some(PoolConfig {
2406                    kernel: PoolKernel::Mean,
2407                    kernel_size: 2,
2408                    stride: 2,
2409                }),
2410                transform: ObsTransform::Identity,
2411                dtype: ObsDtype::F32,
2412            }],
2413        };
2414        let result = ObsPlan::compile(&spec, &space).unwrap();
2415        assert_eq!(result.output_len, 9); // 3x3
2416        assert_eq!(result.entry_shapes, vec![vec![3, 3]]);
2417
2418        // Interior agent at (10, 10): all cells valid.
2419        let center: Coord = smallvec::smallvec![10, 10];
2420        let mut output = vec![0.0f32; result.output_len];
2421        let mut mask = vec![0u8; result.mask_len];
2422        result
2423            .plan
2424            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2425            .unwrap();
2426
2427        // All pooled cells should be valid.
2428        assert!(mask.iter().all(|&v| v == 1));
2429
2430        // Verify first pooled cell: mean of top-left 2x2 of the 7x7 gather.
2431        // Gather bounding box starts at (10-3, 10-3) = (7, 7).
2432        // Top-left 2x2: (7,7)=147, (7,8)=148, (8,7)=167, (8,8)=168
2433        // Mean = (147+148+167+168)/4 = 157.5
2434        assert!((output[0] - 157.5).abs() < 1e-4);
2435    }
2436
2437    #[test]
2438    fn mixed_fixed_and_agent_entries() {
2439        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2440        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2441        let snap = snapshot_with_field(FieldId(0), data);
2442
2443        let spec = ObsSpec {
2444            entries: vec![
2445                // Fixed entry: full grid (100 elements).
2446                ObsEntry {
2447                    field_id: FieldId(0),
2448                    region: ObsRegion::Fixed(RegionSpec::All),
2449                    pool: None,
2450                    transform: ObsTransform::Identity,
2451                    dtype: ObsDtype::F32,
2452                },
2453                // Agent entry: 3x3 rect around agent.
2454                ObsEntry {
2455                    field_id: FieldId(0),
2456                    region: ObsRegion::AgentRect {
2457                        half_extent: smallvec::smallvec![1, 1],
2458                    },
2459                    pool: None,
2460                    transform: ObsTransform::Identity,
2461                    dtype: ObsDtype::F32,
2462                },
2463            ],
2464        };
2465        let result = ObsPlan::compile(&spec, &space).unwrap();
2466        assert!(result.plan.is_standard());
2467        assert_eq!(result.output_len, 109); // 100 + 9
2468
2469        let center: Coord = smallvec::smallvec![5, 5];
2470        let mut output = vec![0.0f32; result.output_len];
2471        let mut mask = vec![0u8; result.mask_len];
2472        result
2473            .plan
2474            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2475            .unwrap();
2476
2477        // Fixed entry: first 100 elements match field data.
2478        let expected: Vec<f32> = (0..100).map(|x| x as f32).collect();
2479        assert_eq!(&output[..100], &expected[..]);
2480        assert!(mask[..100].iter().all(|&v| v == 1));
2481
2482        // Agent entry: 3x3 centered on (5,5). Center at tensor_idx = 1*3+1 = 4.
2483        // rank(5,5) = 55
2484        assert_eq!(output[100 + 4], 55.0);
2485    }
2486
2487    #[test]
2488    fn wrong_dimensionality_returns_error() {
2489        // 2D space but 1D agent center → should error, not panic.
2490        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2491        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2492        let snap = snapshot_with_field(FieldId(0), data);
2493
2494        let spec = ObsSpec {
2495            entries: vec![ObsEntry {
2496                field_id: FieldId(0),
2497                region: ObsRegion::AgentDisk { radius: 1 },
2498                pool: None,
2499                transform: ObsTransform::Identity,
2500                dtype: ObsDtype::F32,
2501            }],
2502        };
2503        let result = ObsPlan::compile(&spec, &space).unwrap();
2504
2505        let bad_center: Coord = smallvec::smallvec![5]; // 1D, not 2D
2506        let mut output = vec![0.0f32; result.output_len];
2507        let mut mask = vec![0u8; result.mask_len];
2508        let err =
2509            result
2510                .plan
2511                .execute_agents(&snap, &space, &[bad_center], None, &mut output, &mut mask);
2512        assert!(err.is_err());
2513        let msg = format!("{}", err.unwrap_err());
2514        assert!(
2515            msg.contains("dimensions"),
2516            "error should mention dimensions: {msg}"
2517        );
2518    }
2519
2520    #[test]
2521    fn agent_disk_square4_filters_corners() {
2522        // On a 4-connected grid, AgentDisk radius=2 should use Manhattan distance.
2523        // Bounding box is 5x5 = 25, but Manhattan disk has 13 cells (diamond shape).
2524        let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
2525        let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
2526        let snap = snapshot_with_field(FieldId(0), data);
2527
2528        let spec = ObsSpec {
2529            entries: vec![ObsEntry {
2530                field_id: FieldId(0),
2531                region: ObsRegion::AgentDisk { radius: 2 },
2532                pool: None,
2533                transform: ObsTransform::Identity,
2534                dtype: ObsDtype::F32,
2535            }],
2536        };
2537        let result = ObsPlan::compile(&spec, &space).unwrap();
2538        assert_eq!(result.output_len, 25); // tensor shape is still 5x5
2539
2540        // Interior agent at (10, 10).
2541        let center: Coord = smallvec::smallvec![10, 10];
2542        let mut output = vec![0.0f32; 25];
2543        let mut mask = vec![0u8; 25];
2544        result
2545            .plan
2546            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2547            .unwrap();
2548
2549        // Manhattan distance disk of radius 2 on a 5x5 bounding box:
2550        //   . . X . .    (row -2: only center col)
2551        //   . X X X .    (row -1: 3 cells)
2552        //   X X X X X    (row  0: 5 cells)
2553        //   . X X X .    (row +1: 3 cells)
2554        //   . . X . .    (row +2: only center col)
2555        // Total: 1 + 3 + 5 + 3 + 1 = 13 cells
2556        let valid_count = mask.iter().filter(|&&v| v == 1).count();
2557        assert_eq!(
2558            valid_count, 13,
2559            "Manhattan disk radius=2 should have 13 cells"
2560        );
2561
2562        // Corners should be masked out: (dr,dc) where |dr|+|dc| > 2
2563        // tensor_idx 0: dr=-2,dc=-2 → dist=4 → OUT
2564        // tensor_idx 4: dr=-2,dc=+2 → dist=4 → OUT
2565        // tensor_idx 20: dr=+2,dc=-2 → dist=4 → OUT
2566        // tensor_idx 24: dr=+2,dc=+2 → dist=4 → OUT
2567        for &idx in &[0, 4, 20, 24] {
2568            assert_eq!(
2569                mask[idx], 0,
2570                "corner tensor_idx {idx} should be outside disk"
2571            );
2572        }
2573
2574        // Center cell: tensor_idx = 2*5+2 = 12, absolute = row 10 * 20 + col 10 = 210
2575        assert_eq!(output[12], 210.0);
2576        assert_eq!(mask[12], 1);
2577    }
2578
2579    #[test]
2580    fn agent_rect_no_disk_filtering() {
2581        // AgentRect should NOT filter any cells — full rectangle is valid.
2582        let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
2583        let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
2584        let snap = snapshot_with_field(FieldId(0), data);
2585
2586        let spec = ObsSpec {
2587            entries: vec![ObsEntry {
2588                field_id: FieldId(0),
2589                region: ObsRegion::AgentRect {
2590                    half_extent: smallvec::smallvec![2, 2],
2591                },
2592                pool: None,
2593                transform: ObsTransform::Identity,
2594                dtype: ObsDtype::F32,
2595            }],
2596        };
2597        let result = ObsPlan::compile(&spec, &space).unwrap();
2598
2599        let center: Coord = smallvec::smallvec![10, 10];
2600        let mut output = vec![0.0f32; 25];
2601        let mut mask = vec![0u8; 25];
2602        result
2603            .plan
2604            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2605            .unwrap();
2606
2607        // All 25 cells should be valid for AgentRect (no disk filtering).
2608        assert!(mask.iter().all(|&v| v == 1));
2609    }
2610
2611    #[test]
2612    fn agent_disk_square8_chebyshev() {
2613        // On an 8-connected grid, AgentDisk radius=1 uses Chebyshev distance.
2614        // Bounding box is 3x3 = 9, Chebyshev disk radius=1 = full 3x3 → 9 cells.
2615        let space = Square8::new(10, 10, EdgeBehavior::Absorb).unwrap();
2616        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2617        let snap = snapshot_with_field(FieldId(0), data);
2618
2619        let spec = ObsSpec {
2620            entries: vec![ObsEntry {
2621                field_id: FieldId(0),
2622                region: ObsRegion::AgentDisk { radius: 1 },
2623                pool: None,
2624                transform: ObsTransform::Identity,
2625                dtype: ObsDtype::F32,
2626            }],
2627        };
2628        let result = ObsPlan::compile(&spec, &space).unwrap();
2629        assert_eq!(result.output_len, 9);
2630
2631        let center: Coord = smallvec::smallvec![5, 5];
2632        let mut output = vec![0.0f32; 9];
2633        let mut mask = vec![0u8; 9];
2634        result
2635            .plan
2636            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2637            .unwrap();
2638
2639        // Chebyshev distance <= 1 covers full 3x3 = 9 cells (all corners included).
2640        let valid_count = mask.iter().filter(|&&v| v == 1).count();
2641        assert_eq!(valid_count, 9, "Chebyshev disk radius=1 = full 3x3");
2642    }
2643
2644    #[test]
2645    fn compile_rejects_inverted_normalize_range() {
2646        let space = square4_space();
2647        let spec = ObsSpec {
2648            entries: vec![ObsEntry {
2649                field_id: FieldId(0),
2650                region: ObsRegion::Fixed(RegionSpec::All),
2651                pool: None,
2652                transform: ObsTransform::Normalize {
2653                    min: 10.0,
2654                    max: 5.0,
2655                },
2656                dtype: ObsDtype::F32,
2657            }],
2658        };
2659        let err = ObsPlan::compile(&spec, &space).unwrap_err();
2660        assert!(matches!(err, ObsError::InvalidObsSpec { .. }));
2661    }
2662
2663    #[test]
2664    fn compile_rejects_nan_normalize() {
2665        let space = square4_space();
2666        let spec = ObsSpec {
2667            entries: vec![ObsEntry {
2668                field_id: FieldId(0),
2669                region: ObsRegion::Fixed(RegionSpec::All),
2670                pool: None,
2671                transform: ObsTransform::Normalize {
2672                    min: f64::NAN,
2673                    max: 1.0,
2674                },
2675                dtype: ObsDtype::F32,
2676            }],
2677        };
2678        assert!(ObsPlan::compile(&spec, &space).is_err());
2679    }
2680
2681    #[test]
2682    fn execute_batch_partial_write_on_generation_mismatch() {
2683        let space = square4_space();
2684        let spec = ObsSpec {
2685            entries: vec![ObsEntry {
2686                field_id: FieldId(0),
2687                region: ObsRegion::Fixed(RegionSpec::All),
2688                pool: None,
2689                transform: ObsTransform::Identity,
2690                dtype: ObsDtype::F32,
2691            }],
2692        };
2693        // Compile bound to generation 1.
2694        let result = ObsPlan::compile_bound(&spec, &space, WorldGenerationId(1)).unwrap();
2695
2696        // First snapshot: generation 1 (matches), data = 7.0
2697        let snap_ok = snapshot_with_field(FieldId(0), vec![7.0; 9]);
2698        // Second snapshot: generation 99 (mismatch)
2699        let mut snap_bad = MockSnapshot::new(TickId(5), WorldGenerationId(99), ParameterVersion(0));
2700        snap_bad.set_field(FieldId(0), vec![0.0; 9]);
2701
2702        let snaps: Vec<&dyn SnapshotAccess> = vec![&snap_ok, &snap_bad];
2703        let mut output = vec![0.0f32; result.output_len * 2];
2704        let mut mask = vec![0u8; result.mask_len * 2];
2705
2706        // Should fail on second snapshot.
2707        let err = result
2708            .plan
2709            .execute_batch(&snaps, None, &mut output, &mut mask)
2710            .unwrap_err();
2711        assert!(matches!(err, ObsError::PlanInvalidated { .. }));
2712
2713        // First snapshot's data was already written (partial write).
2714        assert!(
2715            output[..9].iter().all(|&v| v == 7.0),
2716            "first snapshot should be written despite batch error"
2717        );
2718    }
2719}