Skip to main content

murk_obs/
plan.rs

//! Observation plan compilation and execution.
//!
//! [`ObsPlan`] is compiled from an [`ObsSpec`] and a [`Space`], producing a
//! reusable gather plan that can be executed against any [`SnapshotAccess`]
//! implementor. The *Simple* plan class (every region `Fixed`) uses a
//! branch-free flat gather: for each entry, iterate pre-computed
//! `(field_data_index, tensor_index)` pairs, read the field value, optionally
//! transform it, and write it into the caller-allocated buffer. The *Standard*
//! plan class (any agent-relative region) instead stores per-entry templates
//! that are resolved around each agent's center at execute time.
9
10use indexmap::IndexMap;
11
12use murk_core::error::ObsError;
13use murk_core::{Coord, FieldId, SnapshotAccess, TickId, WorldGenerationId};
14use murk_space::Space;
15
16use crate::geometry::GridGeometry;
17use crate::metadata::ObsMetadata;
18use crate::pool::pool_2d;
19use crate::spec::{ObsDtype, ObsRegion, ObsSpec, ObsTransform, PoolConfig};
20
/// Coverage below which plan compilation prints a warning to stderr.
///
/// The region is still accepted; this only flags suspiciously sparse
/// observations (a `valid_ratio` under one half).
const COVERAGE_WARN_THRESHOLD: f64 = 0.5;

/// Coverage below which plan compilation rejects a region outright with
/// `ObsError::InvalidComposition`.
const COVERAGE_ERROR_THRESHOLD: f64 = 0.35;
26
27/// Result of compiling an [`ObsSpec`].
28#[derive(Debug)]
29pub struct ObsPlanResult {
30    /// The compiled plan, ready for execution.
31    pub plan: ObsPlan,
32    /// Total number of f32 elements in the output tensor.
33    pub output_len: usize,
34    /// Shape per entry (each entry's region bounding shape dimensions).
35    pub entry_shapes: Vec<Vec<usize>>,
36    /// Length of the validity mask in bytes.
37    pub mask_len: usize,
38}
39
40/// Compiled observation plan: either Simple or Standard class.
41///
42/// **Simple** (all `Fixed` regions): pre-computed gather indices, branch-free
43/// loop, zero spatial computation at runtime. Use [`execute`](Self::execute).
44///
45/// **Standard** (any agent-relative region): template-based gather with
46/// interior/boundary dispatch. Use [`execute_agents`](Self::execute_agents).
47#[derive(Debug)]
48pub struct ObsPlan {
49    strategy: PlanStrategy,
50    /// Total output elements across all entries (per agent for Standard).
51    output_len: usize,
52    /// Total mask bytes across all entries (per agent for Standard).
53    mask_len: usize,
54    /// Generation at compile time (for PLAN_INVALIDATED detection).
55    compiled_generation: Option<WorldGenerationId>,
56}
57
/// One pre-computed copy for the Simple gather loop.
///
/// Execution reads `field_data[field_data_idx]`, applies the entry's
/// transform, and stores the result at `output[tensor_idx]` within the
/// entry's output slice.
#[derive(Debug, Clone)]
struct GatherOp {
    /// Source position in the flat field array (canonical ordering).
    field_data_idx: usize,
    /// Destination position within this entry's output slice.
    tensor_idx: usize,
}
69
70/// A single compiled entry ready for gather execution.
71#[derive(Debug)]
72struct CompiledEntry {
73    field_id: FieldId,
74    transform: ObsTransform,
75    #[allow(dead_code)]
76    dtype: ObsDtype,
77    /// Offset into the output buffer where this entry starts.
78    output_offset: usize,
79    /// Offset into the validity mask where this entry starts.
80    mask_offset: usize,
81    /// Number of elements this entry contributes to the output.
82    element_count: usize,
83    /// Pre-computed gather operations (one per valid cell in the region).
84    gather_ops: Vec<GatherOp>,
85    /// Pre-computed validity mask for this entry's bounding box.
86    valid_mask: Vec<u8>,
87    /// Valid ratio for this entry's region.
88    #[allow(dead_code)]
89    valid_ratio: f64,
90}
91
92/// Relative offset from agent center for template-based gather.
93///
94/// At compile time, the bounding box of an agent-centered region is
95/// decomposed into `TemplateOp`s. At execute time, the agent center
96/// is resolved and each op is applied: `field_data[base_rank + stride_offset]`.
97#[derive(Debug, Clone)]
98struct TemplateOp {
99    /// Offset from center per coordinate axis.
100    relative: Coord,
101    /// Position in the gather bounding-box tensor (row-major).
102    tensor_idx: usize,
103    /// Precomputed `sum(relative[i] * strides[i])` for interior fast path.
104    /// Zero if no `GridGeometry` is available (fallback path only).
105    stride_offset: isize,
106    /// Whether this cell is within the disk region (always true for AgentRect).
107    /// For AgentDisk, cells outside the graph-distance radius are excluded.
108    in_disk: bool,
109}
110
111/// Compiled agent-relative entry for the Standard plan class.
112///
113/// Stores template data that is instantiated per-agent at execute time.
114/// The bounding box shape comes from the region (e.g., `[2r+1, 2r+1]` for
115/// `AgentDisk`/`AgentRect`), and may be reduced by pooling.
116#[derive(Debug)]
117struct AgentCompiledEntry {
118    field_id: FieldId,
119    pool: Option<PoolConfig>,
120    transform: ObsTransform,
121    #[allow(dead_code)]
122    dtype: ObsDtype,
123    /// Offset into the per-agent output buffer.
124    output_offset: usize,
125    /// Offset into the per-agent mask buffer.
126    mask_offset: usize,
127    /// Post-pool output elements (written to output).
128    element_count: usize,
129    /// Pre-pool bounding-box elements (gather buffer size).
130    pre_pool_element_count: usize,
131    /// Shape of the pre-pool bounding box (e.g., `[7, 7]`).
132    pre_pool_shape: Vec<usize>,
133    /// Template operations (one per cell in bounding box).
134    template_ops: Vec<TemplateOp>,
135    /// Radius for `is_interior` check.
136    radius: u32,
137}
138
139/// Data for the Standard plan class (agent-centered foveation + pooling).
140#[derive(Debug)]
141struct StandardPlanData {
142    /// Entries with `ObsRegion::Fixed` (same output for all agents).
143    fixed_entries: Vec<CompiledEntry>,
144    /// Entries with agent-relative regions (resolved per-agent).
145    agent_entries: Vec<AgentCompiledEntry>,
146    /// Grid geometry for interior/boundary dispatch (`None` → all slow path).
147    geometry: Option<GridGeometry>,
148}
149
150/// Internal plan strategy: Simple (all-fixed) or Standard (agent-centered).
151#[derive(Debug)]
152enum PlanStrategy {
153    /// All entries are `ObsRegion::Fixed`: pre-computed gather indices.
154    Simple(Vec<CompiledEntry>),
155    /// At least one entry is agent-relative: template-based gather.
156    Standard(StandardPlanData),
157}
158
159impl ObsPlan {
160    /// Compile an [`ObsSpec`] against a [`Space`].
161    ///
162    /// Detects whether the spec contains agent-relative regions and
163    /// dispatches to the appropriate plan class:
164    /// - All `Fixed` → **Simple** (pre-computed gather)
165    /// - Any `AgentDisk`/`AgentRect` → **Standard** (template-based)
166    pub fn compile(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
167        if spec.entries.is_empty() {
168            return Err(ObsError::InvalidObsSpec {
169                reason: "ObsSpec has no entries".into(),
170            });
171        }
172
173        let has_agent = spec.entries.iter().any(|e| {
174            matches!(
175                e.region,
176                ObsRegion::AgentDisk { .. } | ObsRegion::AgentRect { .. }
177            )
178        });
179
180        if has_agent {
181            Self::compile_standard(spec, space)
182        } else {
183            Self::compile_simple(spec, space)
184        }
185    }
186
187    /// Compile a Simple plan (all `Fixed` regions, no agent-relative entries).
188    fn compile_simple(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
189        let canonical = space.canonical_ordering();
190        let coord_to_field_idx: IndexMap<Coord, usize> = canonical
191            .into_iter()
192            .enumerate()
193            .map(|(idx, coord)| (coord, idx))
194            .collect();
195
196        let mut entries = Vec::with_capacity(spec.entries.len());
197        let mut output_offset = 0usize;
198        let mut mask_offset = 0usize;
199        let mut entry_shapes = Vec::with_capacity(spec.entries.len());
200
201        for (i, entry) in spec.entries.iter().enumerate() {
202            let fixed_region = match &entry.region {
203                ObsRegion::Fixed(spec) => spec,
204                ObsRegion::AgentDisk { .. } | ObsRegion::AgentRect { .. } => {
205                    return Err(ObsError::InvalidObsSpec {
206                        reason: format!("entry {i}: agent-relative region in Simple plan"),
207                    });
208                }
209            };
210            if entry.pool.is_some() {
211                return Err(ObsError::InvalidObsSpec {
212                    reason: format!(
213                        "entry {i}: pooling requires a Standard plan (use agent-relative region)"
214                    ),
215                });
216            }
217
218            let region_plan =
219                space
220                    .compile_region(fixed_region)
221                    .map_err(|e| ObsError::InvalidObsSpec {
222                        reason: format!("entry {i}: region compile failed: {e}"),
223                    })?;
224
225            let ratio = region_plan.valid_ratio();
226            if ratio < COVERAGE_ERROR_THRESHOLD {
227                return Err(ObsError::InvalidComposition {
228                    reason: format!(
229                        "entry {i}: valid_ratio {ratio:.3} < {COVERAGE_ERROR_THRESHOLD}"
230                    ),
231                });
232            }
233            if ratio < COVERAGE_WARN_THRESHOLD {
234                eprintln!(
235                    "murk-obs: warning: entry {i} valid_ratio {ratio:.3} < {COVERAGE_WARN_THRESHOLD}"
236                );
237            }
238
239            let mut gather_ops = Vec::with_capacity(region_plan.coords.len());
240            for (coord_idx, coord) in region_plan.coords.iter().enumerate() {
241                let field_data_idx =
242                    *coord_to_field_idx
243                        .get(coord)
244                        .ok_or_else(|| ObsError::InvalidObsSpec {
245                            reason: format!("entry {i}: coord {coord:?} not in canonical ordering"),
246                        })?;
247                let tensor_idx = region_plan.tensor_indices[coord_idx];
248                gather_ops.push(GatherOp {
249                    field_data_idx,
250                    tensor_idx,
251                });
252            }
253
254            let element_count = region_plan.bounding_shape.total_elements();
255            let shape = match &region_plan.bounding_shape {
256                murk_space::BoundingShape::Rect(dims) => dims.clone(),
257            };
258            entry_shapes.push(shape);
259
260            entries.push(CompiledEntry {
261                field_id: entry.field_id,
262                transform: entry.transform.clone(),
263                dtype: entry.dtype,
264                output_offset,
265                mask_offset,
266                element_count,
267                gather_ops,
268                valid_mask: region_plan.valid_mask,
269                valid_ratio: ratio,
270            });
271
272            output_offset += element_count;
273            mask_offset += element_count;
274        }
275
276        let plan = ObsPlan {
277            strategy: PlanStrategy::Simple(entries),
278            output_len: output_offset,
279            mask_len: mask_offset,
280            compiled_generation: None,
281        };
282
283        Ok(ObsPlanResult {
284            output_len: plan.output_len,
285            mask_len: plan.mask_len,
286            entry_shapes,
287            plan,
288        })
289    }
290
291    /// Compile a Standard plan (has agent-relative entries).
292    ///
293    /// Fixed entries are compiled with pre-computed gather (same for all agents).
294    /// Agent entries are compiled as templates (resolved per-agent at execute time).
295    fn compile_standard(spec: &ObsSpec, space: &dyn Space) -> Result<ObsPlanResult, ObsError> {
296        let canonical = space.canonical_ordering();
297        let coord_to_field_idx: IndexMap<Coord, usize> = canonical
298            .into_iter()
299            .enumerate()
300            .map(|(idx, coord)| (coord, idx))
301            .collect();
302
303        let geometry = GridGeometry::from_space(space);
304        let ndim = space.ndim();
305
306        let mut fixed_entries = Vec::new();
307        let mut agent_entries = Vec::new();
308        let mut output_offset = 0usize;
309        let mut mask_offset = 0usize;
310        let mut entry_shapes = Vec::new();
311
312        for (i, entry) in spec.entries.iter().enumerate() {
313            match &entry.region {
314                ObsRegion::Fixed(region_spec) => {
315                    if entry.pool.is_some() {
316                        return Err(ObsError::InvalidObsSpec {
317                            reason: format!("entry {i}: pooling on Fixed regions not supported"),
318                        });
319                    }
320
321                    let region_plan = space.compile_region(region_spec).map_err(|e| {
322                        ObsError::InvalidObsSpec {
323                            reason: format!("entry {i}: region compile failed: {e}"),
324                        }
325                    })?;
326
327                    let ratio = region_plan.valid_ratio();
328                    if ratio < COVERAGE_ERROR_THRESHOLD {
329                        return Err(ObsError::InvalidComposition {
330                            reason: format!(
331                                "entry {i}: valid_ratio {ratio:.3} < {COVERAGE_ERROR_THRESHOLD}"
332                            ),
333                        });
334                    }
335
336                    let mut gather_ops = Vec::with_capacity(region_plan.coords.len());
337                    for (coord_idx, coord) in region_plan.coords.iter().enumerate() {
338                        let field_data_idx = *coord_to_field_idx.get(coord).ok_or_else(|| {
339                            ObsError::InvalidObsSpec {
340                                reason: format!(
341                                    "entry {i}: coord {coord:?} not in canonical ordering"
342                                ),
343                            }
344                        })?;
345                        let tensor_idx = region_plan.tensor_indices[coord_idx];
346                        gather_ops.push(GatherOp {
347                            field_data_idx,
348                            tensor_idx,
349                        });
350                    }
351
352                    let element_count = region_plan.bounding_shape.total_elements();
353                    let shape = match &region_plan.bounding_shape {
354                        murk_space::BoundingShape::Rect(dims) => dims.clone(),
355                    };
356                    entry_shapes.push(shape);
357
358                    fixed_entries.push(CompiledEntry {
359                        field_id: entry.field_id,
360                        transform: entry.transform.clone(),
361                        dtype: entry.dtype,
362                        output_offset,
363                        mask_offset,
364                        element_count,
365                        gather_ops,
366                        valid_mask: region_plan.valid_mask,
367                        valid_ratio: ratio,
368                    });
369
370                    output_offset += element_count;
371                    mask_offset += element_count;
372                }
373
374                ObsRegion::AgentDisk { radius } => {
375                    let half_ext: smallvec::SmallVec<[u32; 4]> =
376                        (0..ndim).map(|_| *radius).collect();
377                    let (ae, shape) = Self::compile_agent_entry(
378                        i,
379                        entry,
380                        &half_ext,
381                        *radius,
382                        &geometry,
383                        Some(*radius),
384                        output_offset,
385                        mask_offset,
386                    )?;
387                    entry_shapes.push(shape);
388                    output_offset += ae.element_count;
389                    mask_offset += ae.element_count;
390                    agent_entries.push(ae);
391                }
392
393                ObsRegion::AgentRect { half_extent } => {
394                    let radius = *half_extent.iter().max().unwrap_or(&0);
395                    let (ae, shape) = Self::compile_agent_entry(
396                        i,
397                        entry,
398                        half_extent,
399                        radius,
400                        &geometry,
401                        None,
402                        output_offset,
403                        mask_offset,
404                    )?;
405                    entry_shapes.push(shape);
406                    output_offset += ae.element_count;
407                    mask_offset += ae.element_count;
408                    agent_entries.push(ae);
409                }
410            }
411        }
412
413        let plan = ObsPlan {
414            strategy: PlanStrategy::Standard(StandardPlanData {
415                fixed_entries,
416                agent_entries,
417                geometry,
418            }),
419            output_len: output_offset,
420            mask_len: mask_offset,
421            compiled_generation: None,
422        };
423
424        Ok(ObsPlanResult {
425            output_len: plan.output_len,
426            mask_len: plan.mask_len,
427            entry_shapes,
428            plan,
429        })
430    }
431
432    /// Compile a single agent-relative entry into a template.
433    ///
434    /// `disk_radius`: if `Some(r)`, template ops outside graph-distance `r`
435    /// are marked `in_disk = false` (for `AgentDisk`). `None` for `AgentRect`.
436    #[allow(clippy::too_many_arguments)]
437    fn compile_agent_entry(
438        entry_idx: usize,
439        entry: &crate::spec::ObsEntry,
440        half_extent: &[u32],
441        radius: u32,
442        geometry: &Option<GridGeometry>,
443        disk_radius: Option<u32>,
444        output_offset: usize,
445        mask_offset: usize,
446    ) -> Result<(AgentCompiledEntry, Vec<usize>), ObsError> {
447        let pre_pool_shape: Vec<usize> =
448            half_extent.iter().map(|&he| 2 * he as usize + 1).collect();
449        let pre_pool_element_count: usize = pre_pool_shape.iter().product();
450
451        let template_ops = generate_template_ops(half_extent, geometry, disk_radius);
452
453        let (element_count, output_shape) = if let Some(pool) = &entry.pool {
454            if pre_pool_shape.len() != 2 {
455                return Err(ObsError::InvalidObsSpec {
456                    reason: format!(
457                        "entry {entry_idx}: pooling requires 2D region, got {}D",
458                        pre_pool_shape.len()
459                    ),
460                });
461            }
462            let h = pre_pool_shape[0];
463            let w = pre_pool_shape[1];
464            let ks = pool.kernel_size;
465            let stride = pool.stride;
466            if ks == 0 || stride == 0 {
467                return Err(ObsError::InvalidObsSpec {
468                    reason: format!("entry {entry_idx}: pool kernel_size and stride must be > 0"),
469                });
470            }
471            let out_h = if h >= ks { (h - ks) / stride + 1 } else { 0 };
472            let out_w = if w >= ks { (w - ks) / stride + 1 } else { 0 };
473            if out_h == 0 || out_w == 0 {
474                return Err(ObsError::InvalidObsSpec {
475                    reason: format!(
476                        "entry {entry_idx}: pool produces empty output \
477                         (region [{h},{w}], kernel_size {ks}, stride {stride})"
478                    ),
479                });
480            }
481            (out_h * out_w, vec![out_h, out_w])
482        } else {
483            (pre_pool_element_count, pre_pool_shape.clone())
484        };
485
486        Ok((
487            AgentCompiledEntry {
488                field_id: entry.field_id,
489                pool: entry.pool.clone(),
490                transform: entry.transform.clone(),
491                dtype: entry.dtype,
492                output_offset,
493                mask_offset,
494                element_count,
495                pre_pool_element_count,
496                pre_pool_shape,
497                template_ops,
498                radius,
499            },
500            output_shape,
501        ))
502    }
503
504    /// Compile with generation binding for PLAN_INVALIDATED detection.
505    ///
506    /// Same as [`compile`](Self::compile) but records the snapshot's
507    /// `world_generation_id` for later validation in [`ObsPlan::execute`].
508    pub fn compile_bound(
509        spec: &ObsSpec,
510        space: &dyn Space,
511        generation: WorldGenerationId,
512    ) -> Result<ObsPlanResult, ObsError> {
513        let mut result = Self::compile(spec, space)?;
514        result.plan.compiled_generation = Some(generation);
515        Ok(result)
516    }
517
518    /// Total number of f32 elements in the output tensor.
519    pub fn output_len(&self) -> usize {
520        self.output_len
521    }
522
523    /// Total number of bytes in the validity mask.
524    pub fn mask_len(&self) -> usize {
525        self.mask_len
526    }
527
528    /// The generation this plan was compiled against, if bound.
529    pub fn compiled_generation(&self) -> Option<WorldGenerationId> {
530        self.compiled_generation
531    }
532
533    /// Execute the observation plan against a snapshot.
534    ///
535    /// Fills `output` with gathered and transformed field values, and
536    /// `mask` with validity flags (1 = valid, 0 = padding). Both
537    /// buffers must be pre-allocated to [`output_len`](Self::output_len)
538    /// and [`mask_len`](Self::mask_len) respectively.
539    ///
540    /// `engine_tick` is the current engine tick for computing
541    /// [`ObsMetadata::age_ticks`]. Pass `None` in Lockstep mode
542    /// (age is always 0). In RealtimeAsync mode, pass the current
543    /// engine tick so age reflects snapshot staleness.
544    ///
545    /// Returns [`ObsMetadata`] on success.
546    ///
547    /// # Errors
548    ///
549    /// - [`ObsError::PlanInvalidated`] if bound and generation mismatches.
550    /// - [`ObsError::ExecutionFailed`] if a field is missing from the snapshot.
551    pub fn execute(
552        &self,
553        snapshot: &dyn SnapshotAccess,
554        engine_tick: Option<TickId>,
555        output: &mut [f32],
556        mask: &mut [u8],
557    ) -> Result<ObsMetadata, ObsError> {
558        let entries = match &self.strategy {
559            PlanStrategy::Simple(entries) => entries,
560            PlanStrategy::Standard(_) => {
561                return Err(ObsError::ExecutionFailed {
562                    reason: "Standard plan requires execute_agents(), not execute()".into(),
563                });
564            }
565        };
566
567        if output.len() < self.output_len {
568            return Err(ObsError::ExecutionFailed {
569                reason: format!(
570                    "output buffer too small: {} < {}",
571                    output.len(),
572                    self.output_len
573                ),
574            });
575        }
576        if mask.len() < self.mask_len {
577            return Err(ObsError::ExecutionFailed {
578                reason: format!("mask buffer too small: {} < {}", mask.len(), self.mask_len),
579            });
580        }
581
582        // Generation check (PLAN_INVALIDATED).
583        if let Some(compiled_gen) = self.compiled_generation {
584            let snapshot_gen = snapshot.world_generation_id();
585            if compiled_gen != snapshot_gen {
586                return Err(ObsError::PlanInvalidated {
587                    reason: format!(
588                        "plan compiled for generation {}, snapshot is generation {}",
589                        compiled_gen.0, snapshot_gen.0
590                    ),
591                });
592            }
593        }
594
595        let mut total_valid = 0usize;
596        let mut total_elements = 0usize;
597
598        for entry in entries {
599            let field_data =
600                snapshot
601                    .read_field(entry.field_id)
602                    .ok_or_else(|| ObsError::ExecutionFailed {
603                        reason: format!("field {:?} not in snapshot", entry.field_id),
604                    })?;
605
606            let out_slice =
607                &mut output[entry.output_offset..entry.output_offset + entry.element_count];
608            let mask_slice = &mut mask[entry.mask_offset..entry.mask_offset + entry.element_count];
609
610            // Initialize to zero/padding.
611            out_slice.fill(0.0);
612            mask_slice.copy_from_slice(&entry.valid_mask);
613
614            // Branch-free gather: pre-computed (field_data_idx, tensor_idx) pairs.
615            for op in &entry.gather_ops {
616                let raw = *field_data.get(op.field_data_idx).ok_or_else(|| {
617                    ObsError::ExecutionFailed {
618                        reason: format!(
619                            "field {:?} has {} elements but gather requires index {}",
620                            entry.field_id,
621                            field_data.len(),
622                            op.field_data_idx,
623                        ),
624                    }
625                })?;
626                out_slice[op.tensor_idx] = apply_transform(raw, &entry.transform);
627            }
628
629            total_valid += entry.valid_mask.iter().filter(|&&v| v == 1).count();
630            total_elements += entry.element_count;
631        }
632
633        let coverage = if total_elements == 0 {
634            0.0
635        } else {
636            total_valid as f64 / total_elements as f64
637        };
638
639        let age_ticks = match engine_tick {
640            Some(tick) => tick.0.saturating_sub(snapshot.tick_id().0),
641            None => 0,
642        };
643
644        Ok(ObsMetadata {
645            tick_id: snapshot.tick_id(),
646            age_ticks,
647            coverage,
648            world_generation_id: snapshot.world_generation_id(),
649            parameter_version: snapshot.parameter_version(),
650        })
651    }
652
653    /// Execute the plan for a batch of `N` identical environments.
654    ///
655    /// Each snapshot in the batch fills `output_len()` elements in the
656    /// output buffer, starting at `batch_idx * output_len()`. Same for
657    /// masks. This is the primary interface for vectorized RL training.
658    ///
659    /// Returns one [`ObsMetadata`] per snapshot.
660    pub fn execute_batch(
661        &self,
662        snapshots: &[&dyn SnapshotAccess],
663        engine_tick: Option<TickId>,
664        output: &mut [f32],
665        mask: &mut [u8],
666    ) -> Result<Vec<ObsMetadata>, ObsError> {
667        // execute_batch only works with Simple plans.
668        if matches!(self.strategy, PlanStrategy::Standard(_)) {
669            return Err(ObsError::ExecutionFailed {
670                reason: "Standard plan requires execute_agents(), not execute_batch()".into(),
671            });
672        }
673
674        let batch_size = snapshots.len();
675        let expected_out = batch_size * self.output_len;
676        let expected_mask = batch_size * self.mask_len;
677
678        if output.len() < expected_out {
679            return Err(ObsError::ExecutionFailed {
680                reason: format!(
681                    "batch output buffer too small: {} < {}",
682                    output.len(),
683                    expected_out
684                ),
685            });
686        }
687        if mask.len() < expected_mask {
688            return Err(ObsError::ExecutionFailed {
689                reason: format!(
690                    "batch mask buffer too small: {} < {}",
691                    mask.len(),
692                    expected_mask
693                ),
694            });
695        }
696
697        let mut metadata = Vec::with_capacity(batch_size);
698        for (i, snap) in snapshots.iter().enumerate() {
699            let out_start = i * self.output_len;
700            let mask_start = i * self.mask_len;
701            let out_slice = &mut output[out_start..out_start + self.output_len];
702            let mask_slice = &mut mask[mask_start..mask_start + self.mask_len];
703            let meta = self.execute(*snap, engine_tick, out_slice, mask_slice)?;
704            metadata.push(meta);
705        }
706        Ok(metadata)
707    }
708
709    /// Execute the Standard plan for `N` agents in one environment.
710    ///
711    /// Each agent gets `output_len()` elements starting at
712    /// `agent_idx * output_len()`. Fixed entries produce the same
713    /// output for all agents; agent-relative entries are resolved
714    /// per-agent using interior/boundary dispatch.
715    ///
716    /// Interior agents (~49% for 20×20 grid, radius 3) use a branchless
717    /// fast path with stride arithmetic. Boundary agents fall back to
718    /// per-cell bounds checking.
719    pub fn execute_agents(
720        &self,
721        snapshot: &dyn SnapshotAccess,
722        space: &dyn Space,
723        agent_centers: &[Coord],
724        engine_tick: Option<TickId>,
725        output: &mut [f32],
726        mask: &mut [u8],
727    ) -> Result<Vec<ObsMetadata>, ObsError> {
728        let standard = match &self.strategy {
729            PlanStrategy::Standard(data) => data,
730            PlanStrategy::Simple(_) => {
731                return Err(ObsError::ExecutionFailed {
732                    reason: "execute_agents requires a Standard plan \
733                             (spec must contain agent-relative entries)"
734                        .into(),
735                });
736            }
737        };
738
739        let n_agents = agent_centers.len();
740        let expected_out = n_agents * self.output_len;
741        let expected_mask = n_agents * self.mask_len;
742
743        if output.len() < expected_out {
744            return Err(ObsError::ExecutionFailed {
745                reason: format!(
746                    "output buffer too small: {} < {}",
747                    output.len(),
748                    expected_out
749                ),
750            });
751        }
752        if mask.len() < expected_mask {
753            return Err(ObsError::ExecutionFailed {
754                reason: format!("mask buffer too small: {} < {}", mask.len(), expected_mask),
755            });
756        }
757
758        // Validate agent center dimensionality.
759        let expected_ndim = space.ndim();
760        for (i, center) in agent_centers.iter().enumerate() {
761            if center.len() != expected_ndim {
762                return Err(ObsError::ExecutionFailed {
763                    reason: format!(
764                        "agent_centers[{i}] has {} dimensions, but space requires {expected_ndim}",
765                        center.len()
766                    ),
767                });
768            }
769        }
770
771        // Generation check.
772        if let Some(compiled_gen) = self.compiled_generation {
773            let snapshot_gen = snapshot.world_generation_id();
774            if compiled_gen != snapshot_gen {
775                return Err(ObsError::PlanInvalidated {
776                    reason: format!(
777                        "plan compiled for generation {}, snapshot is generation {}",
778                        compiled_gen.0, snapshot_gen.0
779                    ),
780                });
781            }
782        }
783
784        // Pre-read all field data (shared borrows, valid for duration).
785        let mut field_data_map: IndexMap<FieldId, &[f32]> = IndexMap::new();
786        for entry in &standard.fixed_entries {
787            if !field_data_map.contains_key(&entry.field_id) {
788                let data = snapshot.read_field(entry.field_id).ok_or_else(|| {
789                    ObsError::ExecutionFailed {
790                        reason: format!("field {:?} not in snapshot", entry.field_id),
791                    }
792                })?;
793                field_data_map.insert(entry.field_id, data);
794            }
795        }
796        for entry in &standard.agent_entries {
797            if !field_data_map.contains_key(&entry.field_id) {
798                let data = snapshot.read_field(entry.field_id).ok_or_else(|| {
799                    ObsError::ExecutionFailed {
800                        reason: format!("field {:?} not in snapshot", entry.field_id),
801                    }
802                })?;
803                field_data_map.insert(entry.field_id, data);
804            }
805        }
806
807        let mut metadata = Vec::with_capacity(n_agents);
808
809        for (agent_i, center) in agent_centers.iter().enumerate() {
810            let out_start = agent_i * self.output_len;
811            let mask_start = agent_i * self.mask_len;
812            let agent_output = &mut output[out_start..out_start + self.output_len];
813            let agent_mask = &mut mask[mask_start..mask_start + self.mask_len];
814
815            agent_output.fill(0.0);
816            agent_mask.fill(0);
817
818            let mut total_valid = 0usize;
819            let mut total_elements = 0usize;
820
821            // ── Fixed entries (same for all agents) ──────────────
822            for entry in &standard.fixed_entries {
823                let field_data = field_data_map[&entry.field_id];
824                let out_slice = &mut agent_output
825                    [entry.output_offset..entry.output_offset + entry.element_count];
826                let mask_slice =
827                    &mut agent_mask[entry.mask_offset..entry.mask_offset + entry.element_count];
828
829                mask_slice.copy_from_slice(&entry.valid_mask);
830                for op in &entry.gather_ops {
831                    let raw = *field_data.get(op.field_data_idx).ok_or_else(|| {
832                        ObsError::ExecutionFailed {
833                            reason: format!(
834                                "field {:?} has {} elements but gather requires index {}",
835                                entry.field_id,
836                                field_data.len(),
837                                op.field_data_idx,
838                            ),
839                        }
840                    })?;
841                    out_slice[op.tensor_idx] = apply_transform(raw, &entry.transform);
842                }
843
844                total_valid += entry.valid_mask.iter().filter(|&&v| v == 1).count();
845                total_elements += entry.element_count;
846            }
847
848            // ── Agent-relative entries ───────────────────────────
849            for entry in &standard.agent_entries {
850                let field_data = field_data_map[&entry.field_id];
851
852                // Fast path: stride arithmetic works only for non-wrapping
853                // grids where all cells in the bounding box are in-bounds.
854                // Torus (all_wrap) requires modular arithmetic → slow path.
855                let use_fast_path = standard
856                    .geometry
857                    .as_ref()
858                    .map(|geo| !geo.all_wrap && geo.is_interior(center, entry.radius))
859                    .unwrap_or(false);
860
861                let valid = execute_agent_entry(
862                    entry,
863                    center,
864                    field_data,
865                    &standard.geometry,
866                    space,
867                    use_fast_path,
868                    agent_output,
869                    agent_mask,
870                );
871
872                total_valid += valid;
873                total_elements += entry.element_count;
874            }
875
876            let coverage = if total_elements == 0 {
877                0.0
878            } else {
879                total_valid as f64 / total_elements as f64
880            };
881
882            let age_ticks = match engine_tick {
883                Some(tick) => tick.0.saturating_sub(snapshot.tick_id().0),
884                None => 0,
885            };
886
887            metadata.push(ObsMetadata {
888                tick_id: snapshot.tick_id(),
889                age_ticks,
890                coverage,
891                world_generation_id: snapshot.world_generation_id(),
892                parameter_version: snapshot.parameter_version(),
893            });
894        }
895
896        Ok(metadata)
897    }
898
899    /// Whether this plan requires `execute_agents` (Standard) or `execute` (Simple).
900    pub fn is_standard(&self) -> bool {
901        matches!(self.strategy, PlanStrategy::Standard(_))
902    }
903}
904
905/// Execute a single agent-relative entry for one agent.
906///
907/// Returns the number of valid cells written.
908#[allow(clippy::too_many_arguments)]
909fn execute_agent_entry(
910    entry: &AgentCompiledEntry,
911    center: &Coord,
912    field_data: &[f32],
913    geometry: &Option<GridGeometry>,
914    space: &dyn Space,
915    use_fast_path: bool,
916    agent_output: &mut [f32],
917    agent_mask: &mut [u8],
918) -> usize {
919    if entry.pool.is_some() {
920        execute_agent_entry_pooled(
921            entry,
922            center,
923            field_data,
924            geometry,
925            space,
926            use_fast_path,
927            agent_output,
928            agent_mask,
929        )
930    } else {
931        execute_agent_entry_direct(
932            entry,
933            center,
934            field_data,
935            geometry,
936            space,
937            use_fast_path,
938            agent_output,
939            agent_mask,
940        )
941    }
942}
943
944/// Direct gather (no pooling): gather + transform → output.
945#[allow(clippy::too_many_arguments)]
946fn execute_agent_entry_direct(
947    entry: &AgentCompiledEntry,
948    center: &Coord,
949    field_data: &[f32],
950    geometry: &Option<GridGeometry>,
951    space: &dyn Space,
952    use_fast_path: bool,
953    agent_output: &mut [f32],
954    agent_mask: &mut [u8],
955) -> usize {
956    let out_slice =
957        &mut agent_output[entry.output_offset..entry.output_offset + entry.element_count];
958    let mask_slice = &mut agent_mask[entry.mask_offset..entry.mask_offset + entry.element_count];
959
960    if use_fast_path {
961        // FAST PATH: all cells in-bounds, branchless stride arithmetic.
962        let geo = geometry.as_ref().unwrap();
963        let base_rank = geo.canonical_rank(center) as isize;
964        let mut valid = 0;
965        for op in &entry.template_ops {
966            if !op.in_disk {
967                continue;
968            }
969            let field_idx = (base_rank + op.stride_offset) as usize;
970            out_slice[op.tensor_idx] = apply_transform(field_data[field_idx], &entry.transform);
971            mask_slice[op.tensor_idx] = 1;
972            valid += 1;
973        }
974        valid
975    } else {
976        // SLOW PATH: bounds-check each offset (or modular wrap for torus).
977        let mut valid = 0;
978        for op in &entry.template_ops {
979            if !op.in_disk {
980                continue;
981            }
982            let field_idx = resolve_field_index(center, &op.relative, geometry, space);
983            if let Some(idx) = field_idx {
984                if idx < field_data.len() {
985                    out_slice[op.tensor_idx] = apply_transform(field_data[idx], &entry.transform);
986                    mask_slice[op.tensor_idx] = 1;
987                    valid += 1;
988                }
989            }
990        }
991        valid
992    }
993}
994
995/// Pooled gather: gather → scratch → pool → transform → output.
996#[allow(clippy::too_many_arguments)]
997fn execute_agent_entry_pooled(
998    entry: &AgentCompiledEntry,
999    center: &Coord,
1000    field_data: &[f32],
1001    geometry: &Option<GridGeometry>,
1002    space: &dyn Space,
1003    use_fast_path: bool,
1004    agent_output: &mut [f32],
1005    agent_mask: &mut [u8],
1006) -> usize {
1007    let mut scratch = vec![0.0f32; entry.pre_pool_element_count];
1008    let mut scratch_mask = vec![0u8; entry.pre_pool_element_count];
1009
1010    if use_fast_path {
1011        let geo = geometry.as_ref().unwrap();
1012        let base_rank = geo.canonical_rank(center) as isize;
1013        for op in &entry.template_ops {
1014            if !op.in_disk {
1015                continue;
1016            }
1017            let field_idx = (base_rank + op.stride_offset) as usize;
1018            scratch[op.tensor_idx] = field_data[field_idx];
1019            scratch_mask[op.tensor_idx] = 1;
1020        }
1021    } else {
1022        for op in &entry.template_ops {
1023            if !op.in_disk {
1024                continue;
1025            }
1026            let field_idx = resolve_field_index(center, &op.relative, geometry, space);
1027            if let Some(idx) = field_idx {
1028                if idx < field_data.len() {
1029                    scratch[op.tensor_idx] = field_data[idx];
1030                    scratch_mask[op.tensor_idx] = 1;
1031                }
1032            }
1033        }
1034    }
1035
1036    let pool_config = entry.pool.as_ref().unwrap();
1037    let (pooled, pooled_mask, _) =
1038        pool_2d(&scratch, &scratch_mask, &entry.pre_pool_shape, pool_config);
1039
1040    let out_slice =
1041        &mut agent_output[entry.output_offset..entry.output_offset + entry.element_count];
1042    let mask_slice = &mut agent_mask[entry.mask_offset..entry.mask_offset + entry.element_count];
1043
1044    let n = pooled.len().min(entry.element_count);
1045    for i in 0..n {
1046        out_slice[i] = apply_transform(pooled[i], &entry.transform);
1047    }
1048    mask_slice[..n].copy_from_slice(&pooled_mask[..n]);
1049
1050    pooled_mask[..n].iter().filter(|&&v| v == 1).count()
1051}
1052
1053/// Generate template operations for a rectangular bounding box.
1054///
1055/// `half_extent[d]` is the half-size per dimension. The bounding box is
1056/// `(2*he[0]+1) × (2*he[1]+1) × ...` in row-major order.
1057///
1058/// If `strides` is provided (from `GridGeometry`), each op gets a precomputed
1059/// `stride_offset` for the interior fast path.
1060///
1061/// If `disk_radius` is `Some(r)`, cells with graph distance > `r` are marked
1062/// `in_disk = false`. The `geometry` is required to compute graph distance.
1063/// When `geometry` is `None`, all cells are treated as in-disk (conservative).
1064fn generate_template_ops(
1065    half_extent: &[u32],
1066    geometry: &Option<GridGeometry>,
1067    disk_radius: Option<u32>,
1068) -> Vec<TemplateOp> {
1069    let ndim = half_extent.len();
1070    let shape: Vec<usize> = half_extent.iter().map(|&he| 2 * he as usize + 1).collect();
1071    let total: usize = shape.iter().product();
1072
1073    let strides = geometry.as_ref().map(|g| g.coord_strides.as_slice());
1074
1075    let mut ops = Vec::with_capacity(total);
1076
1077    for tensor_idx in 0..total {
1078        // Decompose tensor_idx into n-d relative coords (row-major).
1079        let mut relative = Coord::new();
1080        let mut remaining = tensor_idx;
1081        // Build in reverse order, then reverse.
1082        for d in (0..ndim).rev() {
1083            let coord = (remaining % shape[d]) as i32 - half_extent[d] as i32;
1084            relative.push(coord);
1085            remaining /= shape[d];
1086        }
1087        relative.reverse();
1088
1089        let stride_offset = strides
1090            .map(|s| {
1091                relative
1092                    .iter()
1093                    .zip(s.iter())
1094                    .map(|(&r, &s)| r as isize * s as isize)
1095                    .sum::<isize>()
1096            })
1097            .unwrap_or(0);
1098
1099        let in_disk = match disk_radius {
1100            Some(r) => match geometry {
1101                Some(geo) => geo.graph_distance(&relative) <= r,
1102                None => true, // no geometry → conservative (include all)
1103            },
1104            None => true, // AgentRect → all cells valid
1105        };
1106
1107        ops.push(TemplateOp {
1108            relative,
1109            tensor_idx,
1110            stride_offset,
1111            in_disk,
1112        });
1113    }
1114
1115    ops
1116}
1117
1118/// Resolve the field data index for an absolute coordinate.
1119///
1120/// Handles three cases:
1121/// 1. Torus (all_wrap): modular wrap, always in-bounds.
1122/// 2. Grid with geometry: bounds-check then stride arithmetic.
1123/// 3. No geometry: fall back to `space.canonical_rank()`.
1124fn resolve_field_index(
1125    center: &Coord,
1126    relative: &Coord,
1127    geometry: &Option<GridGeometry>,
1128    space: &dyn Space,
1129) -> Option<usize> {
1130    if let Some(geo) = geometry {
1131        if geo.all_wrap {
1132            // Torus: wrap coordinates with modular arithmetic.
1133            let wrapped: Coord = center
1134                .iter()
1135                .zip(relative.iter())
1136                .zip(geo.coord_dims.iter())
1137                .map(|((&c, &r), &d)| {
1138                    let d = d as i32;
1139                    ((c + r) % d + d) % d
1140                })
1141                .collect();
1142            Some(geo.canonical_rank(&wrapped))
1143        } else {
1144            let abs_coord: Coord = center
1145                .iter()
1146                .zip(relative.iter())
1147                .map(|(&c, &r)| c + r)
1148                .collect();
1149            let abs_slice: &[i32] = &abs_coord;
1150            if geo.in_bounds(abs_slice) {
1151                Some(geo.canonical_rank(abs_slice))
1152            } else {
1153                None
1154            }
1155        }
1156    } else {
1157        let abs_coord: Coord = center
1158            .iter()
1159            .zip(relative.iter())
1160            .map(|(&c, &r)| c + r)
1161            .collect();
1162        space.canonical_rank(&abs_coord)
1163    }
1164}
1165
1166/// Apply a transform to a raw field value.
1167fn apply_transform(raw: f32, transform: &ObsTransform) -> f32 {
1168    match transform {
1169        ObsTransform::Identity => raw,
1170        ObsTransform::Normalize { min, max } => {
1171            let range = max - min;
1172            if range == 0.0 {
1173                0.0
1174            } else {
1175                let normalized = (raw as f64 - min) / range;
1176                normalized.clamp(0.0, 1.0) as f32
1177            }
1178        }
1179    }
1180}
1181
1182#[cfg(test)]
1183mod tests {
1184    use super::*;
1185    use crate::spec::{
1186        ObsDtype, ObsEntry, ObsRegion, ObsSpec, ObsTransform, PoolConfig, PoolKernel,
1187    };
1188    use murk_core::{FieldId, ParameterVersion, TickId, WorldGenerationId};
1189    use murk_space::{EdgeBehavior, Hex2D, RegionSpec, Square4, Square8};
1190    use murk_test_utils::MockSnapshot;
1191
1192    fn square4_space() -> Square4 {
1193        Square4::new(3, 3, EdgeBehavior::Absorb).unwrap()
1194    }
1195
1196    fn snapshot_with_field(field: FieldId, data: Vec<f32>) -> MockSnapshot {
1197        let mut snap = MockSnapshot::new(TickId(5), WorldGenerationId(1), ParameterVersion(0));
1198        snap.set_field(field, data);
1199        snap
1200    }
1201
1202    // ── Compilation tests ────────────────────────────────────
1203
1204    #[test]
1205    fn compile_empty_spec_errors() {
1206        let space = square4_space();
1207        let spec = ObsSpec { entries: vec![] };
1208        let err = ObsPlan::compile(&spec, &space).unwrap_err();
1209        assert!(matches!(err, ObsError::InvalidObsSpec { .. }));
1210    }
1211
1212    #[test]
1213    fn compile_all_region_square4() {
1214        let space = square4_space();
1215        let spec = ObsSpec {
1216            entries: vec![ObsEntry {
1217                field_id: FieldId(0),
1218                region: ObsRegion::Fixed(RegionSpec::All),
1219                pool: None,
1220                transform: ObsTransform::Identity,
1221                dtype: ObsDtype::F32,
1222            }],
1223        };
1224        let result = ObsPlan::compile(&spec, &space).unwrap();
1225        assert_eq!(result.output_len, 9); // 3x3
1226        assert_eq!(result.mask_len, 9);
1227        assert_eq!(result.entry_shapes, vec![vec![3, 3]]);
1228    }
1229
1230    #[test]
1231    fn compile_rect_region() {
1232        let space = Square4::new(5, 5, EdgeBehavior::Absorb).unwrap();
1233        let spec = ObsSpec {
1234            entries: vec![ObsEntry {
1235                field_id: FieldId(0),
1236                region: ObsRegion::Fixed(RegionSpec::Rect {
1237                    min: smallvec::smallvec![1, 1],
1238                    max: smallvec::smallvec![2, 3],
1239                }),
1240                pool: None,
1241                transform: ObsTransform::Identity,
1242                dtype: ObsDtype::F32,
1243            }],
1244        };
1245        let result = ObsPlan::compile(&spec, &space).unwrap();
1246        // 2 rows x 3 cols = 6 cells
1247        assert_eq!(result.output_len, 6);
1248        assert_eq!(result.entry_shapes, vec![vec![2, 3]]);
1249    }
1250
1251    #[test]
1252    fn compile_two_entries_offsets() {
1253        let space = square4_space();
1254        let spec = ObsSpec {
1255            entries: vec![
1256                ObsEntry {
1257                    field_id: FieldId(0),
1258                    region: ObsRegion::Fixed(RegionSpec::All),
1259                    pool: None,
1260                    transform: ObsTransform::Identity,
1261                    dtype: ObsDtype::F32,
1262                },
1263                ObsEntry {
1264                    field_id: FieldId(1),
1265                    region: ObsRegion::Fixed(RegionSpec::All),
1266                    pool: None,
1267                    transform: ObsTransform::Identity,
1268                    dtype: ObsDtype::F32,
1269                },
1270            ],
1271        };
1272        let result = ObsPlan::compile(&spec, &space).unwrap();
1273        assert_eq!(result.output_len, 18); // 9 + 9
1274        assert_eq!(result.mask_len, 18);
1275    }
1276
1277    #[test]
1278    fn compile_invalid_region_errors() {
1279        let space = square4_space();
1280        let spec = ObsSpec {
1281            entries: vec![ObsEntry {
1282                field_id: FieldId(0),
1283                region: ObsRegion::Fixed(RegionSpec::Coords(vec![smallvec::smallvec![99, 99]])),
1284                pool: None,
1285                transform: ObsTransform::Identity,
1286                dtype: ObsDtype::F32,
1287            }],
1288        };
1289        let err = ObsPlan::compile(&spec, &space).unwrap_err();
1290        assert!(matches!(err, ObsError::InvalidObsSpec { .. }));
1291    }
1292
1293    // ── Execution tests ──────────────────────────────────────
1294
1295    #[test]
1296    fn execute_identity_all_region() {
1297        let space = square4_space();
1298        // Field data in canonical (row-major) order for 3x3:
1299        // (0,0)=1, (0,1)=2, (0,2)=3, (1,0)=4, ..., (2,2)=9
1300        let data: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1301        let snap = snapshot_with_field(FieldId(0), data);
1302
1303        let spec = ObsSpec {
1304            entries: vec![ObsEntry {
1305                field_id: FieldId(0),
1306                region: ObsRegion::Fixed(RegionSpec::All),
1307                pool: None,
1308                transform: ObsTransform::Identity,
1309                dtype: ObsDtype::F32,
1310            }],
1311        };
1312        let result = ObsPlan::compile(&spec, &space).unwrap();
1313
1314        let mut output = vec![0.0f32; result.output_len];
1315        let mut mask = vec![0u8; result.mask_len];
1316        let meta = result
1317            .plan
1318            .execute(&snap, None, &mut output, &mut mask)
1319            .unwrap();
1320
1321        // Output should match field data in canonical order.
1322        let expected: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1323        assert_eq!(output, expected);
1324        assert_eq!(mask, vec![1u8; 9]);
1325        assert_eq!(meta.tick_id, TickId(5));
1326        assert_eq!(meta.coverage, 1.0);
1327        assert_eq!(meta.world_generation_id, WorldGenerationId(1));
1328        assert_eq!(meta.parameter_version, ParameterVersion(0));
1329        assert_eq!(meta.age_ticks, 0);
1330    }
1331
1332    #[test]
1333    fn execute_normalize_transform() {
1334        let space = square4_space();
1335        // Values 0..8 mapped to [0,1] with min=0, max=8.
1336        let data: Vec<f32> = (0..9).map(|x| x as f32).collect();
1337        let snap = snapshot_with_field(FieldId(0), data);
1338
1339        let spec = ObsSpec {
1340            entries: vec![ObsEntry {
1341                field_id: FieldId(0),
1342                region: ObsRegion::Fixed(RegionSpec::All),
1343                pool: None,
1344                transform: ObsTransform::Normalize { min: 0.0, max: 8.0 },
1345                dtype: ObsDtype::F32,
1346            }],
1347        };
1348        let result = ObsPlan::compile(&spec, &space).unwrap();
1349
1350        let mut output = vec![0.0f32; result.output_len];
1351        let mut mask = vec![0u8; result.mask_len];
1352        result
1353            .plan
1354            .execute(&snap, None, &mut output, &mut mask)
1355            .unwrap();
1356
1357        // Each value x should be x/8.
1358        for (i, &v) in output.iter().enumerate() {
1359            let expected = i as f32 / 8.0;
1360            assert!(
1361                (v - expected).abs() < 1e-6,
1362                "output[{i}] = {v}, expected {expected}"
1363            );
1364        }
1365    }
1366
1367    #[test]
1368    fn execute_normalize_clamps_out_of_range() {
1369        let space = square4_space();
1370        // Values -5, 0, 5, 10, 15 etc.
1371        let data: Vec<f32> = (-4..5).map(|x| x as f32 * 5.0).collect();
1372        let snap = snapshot_with_field(FieldId(0), data);
1373
1374        let spec = ObsSpec {
1375            entries: vec![ObsEntry {
1376                field_id: FieldId(0),
1377                region: ObsRegion::Fixed(RegionSpec::All),
1378                pool: None,
1379                transform: ObsTransform::Normalize {
1380                    min: 0.0,
1381                    max: 10.0,
1382                },
1383                dtype: ObsDtype::F32,
1384            }],
1385        };
1386        let result = ObsPlan::compile(&spec, &space).unwrap();
1387
1388        let mut output = vec![0.0f32; result.output_len];
1389        let mut mask = vec![0u8; result.mask_len];
1390        result
1391            .plan
1392            .execute(&snap, None, &mut output, &mut mask)
1393            .unwrap();
1394
1395        for &v in &output {
1396            assert!((0.0..=1.0).contains(&v), "value {v} out of [0,1] range");
1397        }
1398    }
1399
1400    #[test]
1401    fn execute_normalize_zero_range() {
1402        let space = square4_space();
1403        let data = vec![5.0f32; 9];
1404        let snap = snapshot_with_field(FieldId(0), data);
1405
1406        let spec = ObsSpec {
1407            entries: vec![ObsEntry {
1408                field_id: FieldId(0),
1409                region: ObsRegion::Fixed(RegionSpec::All),
1410                pool: None,
1411                transform: ObsTransform::Normalize { min: 5.0, max: 5.0 },
1412                dtype: ObsDtype::F32,
1413            }],
1414        };
1415        let result = ObsPlan::compile(&spec, &space).unwrap();
1416
1417        let mut output = vec![-1.0f32; result.output_len];
1418        let mut mask = vec![0u8; result.mask_len];
1419        result
1420            .plan
1421            .execute(&snap, None, &mut output, &mut mask)
1422            .unwrap();
1423
1424        // Zero range → all outputs 0.0.
1425        assert!(output.iter().all(|&v| v == 0.0));
1426    }
1427
1428    #[test]
1429    fn execute_rect_subregion_correct_values() {
1430        let space = Square4::new(4, 4, EdgeBehavior::Absorb).unwrap();
1431        // 4x4 field: value = row * 4 + col + 1.
1432        let data: Vec<f32> = (1..=16).map(|x| x as f32).collect();
1433        let snap = snapshot_with_field(FieldId(0), data);
1434
1435        let spec = ObsSpec {
1436            entries: vec![ObsEntry {
1437                field_id: FieldId(0),
1438                region: ObsRegion::Fixed(RegionSpec::Rect {
1439                    min: smallvec::smallvec![1, 1],
1440                    max: smallvec::smallvec![2, 2],
1441                }),
1442                pool: None,
1443                transform: ObsTransform::Identity,
1444                dtype: ObsDtype::F32,
1445            }],
1446        };
1447        let result = ObsPlan::compile(&spec, &space).unwrap();
1448        assert_eq!(result.output_len, 4); // 2x2
1449
1450        let mut output = vec![0.0f32; result.output_len];
1451        let mut mask = vec![0u8; result.mask_len];
1452        result
1453            .plan
1454            .execute(&snap, None, &mut output, &mut mask)
1455            .unwrap();
1456
1457        // Rect covers (1,1)=6, (1,2)=7, (2,1)=10, (2,2)=11
1458        assert_eq!(output, vec![6.0, 7.0, 10.0, 11.0]);
1459        assert_eq!(mask, vec![1, 1, 1, 1]);
1460    }
1461
1462    #[test]
1463    fn execute_two_fields() {
1464        let space = square4_space();
1465        let data_a: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1466        let data_b: Vec<f32> = (10..=18).map(|x| x as f32).collect();
1467        let mut snap = MockSnapshot::new(TickId(1), WorldGenerationId(1), ParameterVersion(0));
1468        snap.set_field(FieldId(0), data_a);
1469        snap.set_field(FieldId(1), data_b);
1470
1471        let spec = ObsSpec {
1472            entries: vec![
1473                ObsEntry {
1474                    field_id: FieldId(0),
1475                    region: ObsRegion::Fixed(RegionSpec::All),
1476                    pool: None,
1477                    transform: ObsTransform::Identity,
1478                    dtype: ObsDtype::F32,
1479                },
1480                ObsEntry {
1481                    field_id: FieldId(1),
1482                    region: ObsRegion::Fixed(RegionSpec::All),
1483                    pool: None,
1484                    transform: ObsTransform::Identity,
1485                    dtype: ObsDtype::F32,
1486                },
1487            ],
1488        };
1489        let result = ObsPlan::compile(&spec, &space).unwrap();
1490        assert_eq!(result.output_len, 18);
1491
1492        let mut output = vec![0.0f32; result.output_len];
1493        let mut mask = vec![0u8; result.mask_len];
1494        result
1495            .plan
1496            .execute(&snap, None, &mut output, &mut mask)
1497            .unwrap();
1498
1499        // First 9: field 0, next 9: field 1.
1500        let expected_a: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1501        let expected_b: Vec<f32> = (10..=18).map(|x| x as f32).collect();
1502        assert_eq!(&output[..9], &expected_a);
1503        assert_eq!(&output[9..], &expected_b);
1504    }
1505
1506    #[test]
1507    fn execute_missing_field_errors() {
1508        let space = square4_space();
1509        let snap = MockSnapshot::new(TickId(1), WorldGenerationId(1), ParameterVersion(0));
1510
1511        let spec = ObsSpec {
1512            entries: vec![ObsEntry {
1513                field_id: FieldId(0),
1514                region: ObsRegion::Fixed(RegionSpec::All),
1515                pool: None,
1516                transform: ObsTransform::Identity,
1517                dtype: ObsDtype::F32,
1518            }],
1519        };
1520        let result = ObsPlan::compile(&spec, &space).unwrap();
1521
1522        let mut output = vec![0.0f32; result.output_len];
1523        let mut mask = vec![0u8; result.mask_len];
1524        let err = result
1525            .plan
1526            .execute(&snap, None, &mut output, &mut mask)
1527            .unwrap_err();
1528        assert!(matches!(err, ObsError::ExecutionFailed { .. }));
1529    }
1530
1531    #[test]
1532    fn execute_buffer_too_small_errors() {
1533        let space = square4_space();
1534        let data: Vec<f32> = vec![0.0; 9];
1535        let snap = snapshot_with_field(FieldId(0), data);
1536
1537        let spec = ObsSpec {
1538            entries: vec![ObsEntry {
1539                field_id: FieldId(0),
1540                region: ObsRegion::Fixed(RegionSpec::All),
1541                pool: None,
1542                transform: ObsTransform::Identity,
1543                dtype: ObsDtype::F32,
1544            }],
1545        };
1546        let result = ObsPlan::compile(&spec, &space).unwrap();
1547
1548        let mut output = vec![0.0f32; 4]; // too small
1549        let mut mask = vec![0u8; result.mask_len];
1550        let err = result
1551            .plan
1552            .execute(&snap, None, &mut output, &mut mask)
1553            .unwrap_err();
1554        assert!(matches!(err, ObsError::ExecutionFailed { .. }));
1555    }
1556
1557    // ── Validity / coverage tests ────────────────────────────
1558
1559    #[test]
1560    fn valid_ratio_one_for_square_all() {
1561        let space = square4_space();
1562        let data: Vec<f32> = vec![1.0; 9];
1563        let snap = snapshot_with_field(FieldId(0), data);
1564
1565        let spec = ObsSpec {
1566            entries: vec![ObsEntry {
1567                field_id: FieldId(0),
1568                region: ObsRegion::Fixed(RegionSpec::All),
1569                pool: None,
1570                transform: ObsTransform::Identity,
1571                dtype: ObsDtype::F32,
1572            }],
1573        };
1574        let result = ObsPlan::compile(&spec, &space).unwrap();
1575
1576        let mut output = vec![0.0f32; result.output_len];
1577        let mut mask = vec![0u8; result.mask_len];
1578        let meta = result
1579            .plan
1580            .execute(&snap, None, &mut output, &mut mask)
1581            .unwrap();
1582
1583        assert_eq!(meta.coverage, 1.0);
1584    }
1585
1586    // ── Generation binding tests ─────────────────────────────
1587
1588    #[test]
1589    fn plan_invalidated_on_generation_mismatch() {
1590        let space = square4_space();
1591        let data: Vec<f32> = vec![1.0; 9];
1592        let snap = snapshot_with_field(FieldId(0), data);
1593
1594        let spec = ObsSpec {
1595            entries: vec![ObsEntry {
1596                field_id: FieldId(0),
1597                region: ObsRegion::Fixed(RegionSpec::All),
1598                pool: None,
1599                transform: ObsTransform::Identity,
1600                dtype: ObsDtype::F32,
1601            }],
1602        };
1603        // Compile bound to generation 99, but snapshot is generation 1.
1604        let result = ObsPlan::compile_bound(&spec, &space, WorldGenerationId(99)).unwrap();
1605
1606        let mut output = vec![0.0f32; result.output_len];
1607        let mut mask = vec![0u8; result.mask_len];
1608        let err = result
1609            .plan
1610            .execute(&snap, None, &mut output, &mut mask)
1611            .unwrap_err();
1612        assert!(matches!(err, ObsError::PlanInvalidated { .. }));
1613    }
1614
1615    #[test]
1616    fn generation_match_succeeds() {
1617        let space = square4_space();
1618        let data: Vec<f32> = vec![1.0; 9];
1619        let snap = snapshot_with_field(FieldId(0), data);
1620
1621        let spec = ObsSpec {
1622            entries: vec![ObsEntry {
1623                field_id: FieldId(0),
1624                region: ObsRegion::Fixed(RegionSpec::All),
1625                pool: None,
1626                transform: ObsTransform::Identity,
1627                dtype: ObsDtype::F32,
1628            }],
1629        };
1630        let result = ObsPlan::compile_bound(&spec, &space, WorldGenerationId(1)).unwrap();
1631
1632        let mut output = vec![0.0f32; result.output_len];
1633        let mut mask = vec![0u8; result.mask_len];
1634        result
1635            .plan
1636            .execute(&snap, None, &mut output, &mut mask)
1637            .unwrap();
1638    }
1639
1640    #[test]
1641    fn unbound_plan_ignores_generation() {
1642        let space = square4_space();
1643        let data: Vec<f32> = vec![1.0; 9];
1644        let snap = snapshot_with_field(FieldId(0), data);
1645
1646        let spec = ObsSpec {
1647            entries: vec![ObsEntry {
1648                field_id: FieldId(0),
1649                region: ObsRegion::Fixed(RegionSpec::All),
1650                pool: None,
1651                transform: ObsTransform::Identity,
1652                dtype: ObsDtype::F32,
1653            }],
1654        };
1655        // Unbound plan — no generation check.
1656        let result = ObsPlan::compile(&spec, &space).unwrap();
1657
1658        let mut output = vec![0.0f32; result.output_len];
1659        let mut mask = vec![0u8; result.mask_len];
1660        result
1661            .plan
1662            .execute(&snap, None, &mut output, &mut mask)
1663            .unwrap();
1664    }
1665
1666    // ── Metadata tests ───────────────────────────────────────
1667
1668    #[test]
1669    fn metadata_fields_populated() {
1670        let space = square4_space();
1671        let data: Vec<f32> = vec![1.0; 9];
1672        let mut snap = MockSnapshot::new(TickId(42), WorldGenerationId(7), ParameterVersion(3));
1673        snap.set_field(FieldId(0), data);
1674
1675        let spec = ObsSpec {
1676            entries: vec![ObsEntry {
1677                field_id: FieldId(0),
1678                region: ObsRegion::Fixed(RegionSpec::All),
1679                pool: None,
1680                transform: ObsTransform::Identity,
1681                dtype: ObsDtype::F32,
1682            }],
1683        };
1684        let result = ObsPlan::compile(&spec, &space).unwrap();
1685
1686        let mut output = vec![0.0f32; result.output_len];
1687        let mut mask = vec![0u8; result.mask_len];
1688        let meta = result
1689            .plan
1690            .execute(&snap, None, &mut output, &mut mask)
1691            .unwrap();
1692
1693        assert_eq!(meta.tick_id, TickId(42));
1694        assert_eq!(meta.age_ticks, 0);
1695        assert_eq!(meta.coverage, 1.0);
1696        assert_eq!(meta.world_generation_id, WorldGenerationId(7));
1697        assert_eq!(meta.parameter_version, ParameterVersion(3));
1698    }
1699
1700    // ── Batch execution tests ────────────────────────────────
1701
1702    #[test]
1703    fn execute_batch_n1_matches_execute() {
1704        let space = square4_space();
1705        let data: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1706        let snap = snapshot_with_field(FieldId(0), data.clone());
1707
1708        let spec = ObsSpec {
1709            entries: vec![ObsEntry {
1710                field_id: FieldId(0),
1711                region: ObsRegion::Fixed(RegionSpec::All),
1712                pool: None,
1713                transform: ObsTransform::Identity,
1714                dtype: ObsDtype::F32,
1715            }],
1716        };
1717        let result = ObsPlan::compile(&spec, &space).unwrap();
1718
1719        // Single execute.
1720        let mut out_single = vec![0.0f32; result.output_len];
1721        let mut mask_single = vec![0u8; result.mask_len];
1722        let meta_single = result
1723            .plan
1724            .execute(&snap, None, &mut out_single, &mut mask_single)
1725            .unwrap();
1726
1727        // Batch N=1.
1728        let mut out_batch = vec![0.0f32; result.output_len];
1729        let mut mask_batch = vec![0u8; result.mask_len];
1730        let snap_ref: &dyn SnapshotAccess = &snap;
1731        let meta_batch = result
1732            .plan
1733            .execute_batch(&[snap_ref], None, &mut out_batch, &mut mask_batch)
1734            .unwrap();
1735
1736        assert_eq!(out_single, out_batch);
1737        assert_eq!(mask_single, mask_batch);
1738        assert_eq!(meta_single, meta_batch[0]);
1739    }
1740
1741    #[test]
1742    fn execute_batch_multiple_snapshots() {
1743        let space = square4_space();
1744        let spec = ObsSpec {
1745            entries: vec![ObsEntry {
1746                field_id: FieldId(0),
1747                region: ObsRegion::Fixed(RegionSpec::All),
1748                pool: None,
1749                transform: ObsTransform::Identity,
1750                dtype: ObsDtype::F32,
1751            }],
1752        };
1753        let result = ObsPlan::compile(&spec, &space).unwrap();
1754
1755        let snap_a = snapshot_with_field(FieldId(0), vec![1.0; 9]);
1756        let snap_b = snapshot_with_field(FieldId(0), vec![2.0; 9]);
1757
1758        let snaps: Vec<&dyn SnapshotAccess> = vec![&snap_a, &snap_b];
1759        let mut output = vec![0.0f32; result.output_len * 2];
1760        let mut mask = vec![0u8; result.mask_len * 2];
1761        let metas = result
1762            .plan
1763            .execute_batch(&snaps, None, &mut output, &mut mask)
1764            .unwrap();
1765
1766        assert_eq!(metas.len(), 2);
1767        assert!(output[..9].iter().all(|&v| v == 1.0));
1768        assert!(output[9..].iter().all(|&v| v == 2.0));
1769    }
1770
1771    #[test]
1772    fn execute_batch_buffer_too_small() {
1773        let space = square4_space();
1774        let spec = ObsSpec {
1775            entries: vec![ObsEntry {
1776                field_id: FieldId(0),
1777                region: ObsRegion::Fixed(RegionSpec::All),
1778                pool: None,
1779                transform: ObsTransform::Identity,
1780                dtype: ObsDtype::F32,
1781            }],
1782        };
1783        let result = ObsPlan::compile(&spec, &space).unwrap();
1784
1785        let snap = snapshot_with_field(FieldId(0), vec![1.0; 9]);
1786        let snaps: Vec<&dyn SnapshotAccess> = vec![&snap, &snap];
1787        let mut output = vec![0.0f32; 9]; // need 18
1788        let mut mask = vec![0u8; 18];
1789        let err = result
1790            .plan
1791            .execute_batch(&snaps, None, &mut output, &mut mask)
1792            .unwrap_err();
1793        assert!(matches!(err, ObsError::ExecutionFailed { .. }));
1794    }
1795
1796    // ── Field length mismatch tests ──────────────────────────
1797
1798    #[test]
1799    fn short_field_buffer_returns_error_not_panic() {
1800        let space = square4_space(); // 3x3 = 9 cells
1801        let spec = ObsSpec {
1802            entries: vec![ObsEntry {
1803                field_id: FieldId(0),
1804                region: ObsRegion::Fixed(RegionSpec::All),
1805                pool: None,
1806                transform: ObsTransform::Identity,
1807                dtype: ObsDtype::F32,
1808            }],
1809        };
1810        let result = ObsPlan::compile(&spec, &space).unwrap();
1811
1812        // Snapshot field has only 4 elements, but plan expects 9.
1813        let snap = snapshot_with_field(FieldId(0), vec![1.0; 4]);
1814        let mut output = vec![0.0f32; result.output_len];
1815        let mut mask = vec![0u8; result.mask_len];
1816        let err = result
1817            .plan
1818            .execute(&snap, None, &mut output, &mut mask)
1819            .unwrap_err();
1820        assert!(matches!(err, ObsError::ExecutionFailed { .. }));
1821    }
1822
1823    // ── Standard plan (agent-centered) tests ─────────────────
1824
1825    #[test]
1826    fn standard_plan_detected_from_agent_region() {
1827        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
1828        let spec = ObsSpec {
1829            entries: vec![ObsEntry {
1830                field_id: FieldId(0),
1831                region: ObsRegion::AgentRect {
1832                    half_extent: smallvec::smallvec![2, 2],
1833                },
1834                pool: None,
1835                transform: ObsTransform::Identity,
1836                dtype: ObsDtype::F32,
1837            }],
1838        };
1839        let result = ObsPlan::compile(&spec, &space).unwrap();
1840        assert!(result.plan.is_standard());
1841        // 5x5 = 25 elements
1842        assert_eq!(result.output_len, 25);
1843        assert_eq!(result.entry_shapes, vec![vec![5, 5]]);
1844    }
1845
1846    #[test]
1847    fn execute_on_standard_plan_errors() {
1848        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
1849        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
1850        let snap = snapshot_with_field(FieldId(0), data);
1851
1852        let spec = ObsSpec {
1853            entries: vec![ObsEntry {
1854                field_id: FieldId(0),
1855                region: ObsRegion::AgentDisk { radius: 2 },
1856                pool: None,
1857                transform: ObsTransform::Identity,
1858                dtype: ObsDtype::F32,
1859            }],
1860        };
1861        let result = ObsPlan::compile(&spec, &space).unwrap();
1862
1863        let mut output = vec![0.0f32; result.output_len];
1864        let mut mask = vec![0u8; result.mask_len];
1865        let err = result
1866            .plan
1867            .execute(&snap, None, &mut output, &mut mask)
1868            .unwrap_err();
1869        assert!(matches!(err, ObsError::ExecutionFailed { .. }));
1870    }
1871
1872    #[test]
1873    fn interior_boundary_equivalence() {
1874        // An INTERIOR agent using Standard plan should produce identical
1875        // output to a Simple plan with an explicit Rect at the same position.
1876        let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
1877        let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
1878        let snap = snapshot_with_field(FieldId(0), data);
1879
1880        let radius = 3u32;
1881        let center: Coord = smallvec::smallvec![10, 10]; // interior
1882
1883        // Standard plan: AgentRect centered on agent.
1884        let standard_spec = ObsSpec {
1885            entries: vec![ObsEntry {
1886                field_id: FieldId(0),
1887                region: ObsRegion::AgentRect {
1888                    half_extent: smallvec::smallvec![radius, radius],
1889                },
1890                pool: None,
1891                transform: ObsTransform::Identity,
1892                dtype: ObsDtype::F32,
1893            }],
1894        };
1895        let std_result = ObsPlan::compile(&standard_spec, &space).unwrap();
1896        let mut std_output = vec![0.0f32; std_result.output_len];
1897        let mut std_mask = vec![0u8; std_result.mask_len];
1898        std_result
1899            .plan
1900            .execute_agents(
1901                &snap,
1902                &space,
1903                std::slice::from_ref(&center),
1904                None,
1905                &mut std_output,
1906                &mut std_mask,
1907            )
1908            .unwrap();
1909
1910        // Simple plan: explicit Rect covering the same area.
1911        let r = radius as i32;
1912        let simple_spec = ObsSpec {
1913            entries: vec![ObsEntry {
1914                field_id: FieldId(0),
1915                region: ObsRegion::Fixed(RegionSpec::Rect {
1916                    min: smallvec::smallvec![10 - r, 10 - r],
1917                    max: smallvec::smallvec![10 + r, 10 + r],
1918                }),
1919                pool: None,
1920                transform: ObsTransform::Identity,
1921                dtype: ObsDtype::F32,
1922            }],
1923        };
1924        let simple_result = ObsPlan::compile(&simple_spec, &space).unwrap();
1925        let mut simple_output = vec![0.0f32; simple_result.output_len];
1926        let mut simple_mask = vec![0u8; simple_result.mask_len];
1927        simple_result
1928            .plan
1929            .execute(&snap, None, &mut simple_output, &mut simple_mask)
1930            .unwrap();
1931
1932        // Same shape, same values.
1933        assert_eq!(std_result.output_len, simple_result.output_len);
1934        assert_eq!(std_output, simple_output);
1935        assert_eq!(std_mask, simple_mask);
1936    }
1937
1938    #[test]
1939    fn boundary_agent_gets_padding() {
1940        // Agent at corner (0,0) with radius 2: many cells out-of-bounds.
1941        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
1942        let data: Vec<f32> = (0..100).map(|x| x as f32 + 1.0).collect();
1943        let snap = snapshot_with_field(FieldId(0), data);
1944
1945        let spec = ObsSpec {
1946            entries: vec![ObsEntry {
1947                field_id: FieldId(0),
1948                region: ObsRegion::AgentRect {
1949                    half_extent: smallvec::smallvec![2, 2],
1950                },
1951                pool: None,
1952                transform: ObsTransform::Identity,
1953                dtype: ObsDtype::F32,
1954            }],
1955        };
1956        let result = ObsPlan::compile(&spec, &space).unwrap();
1957        let center: Coord = smallvec::smallvec![0, 0];
1958        let mut output = vec![0.0f32; result.output_len];
1959        let mut mask = vec![0u8; result.mask_len];
1960        let metas = result
1961            .plan
1962            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
1963            .unwrap();
1964
1965        // 5x5 = 25 cells total. Agent at (0,0) with radius 2:
1966        // Only cells with row in [0,2] and col in [0,2] are valid (3x3 = 9).
1967        let valid_count: usize = mask.iter().filter(|&&v| v == 1).count();
1968        assert_eq!(valid_count, 9);
1969
1970        // Coverage should be 9/25
1971        assert!((metas[0].coverage - 9.0 / 25.0).abs() < 1e-6);
1972
1973        // Check that the top-left corner of the bounding box (relative [-2,-2])
1974        // is padding (mask=0, value=0).
1975        assert_eq!(mask[0], 0); // relative (-2,-2) → absolute (-2,-2) → out of bounds
1976        assert_eq!(output[0], 0.0);
1977
1978        // The cell at relative (0,0) is at tensor_idx = 2*5+2 = 12
1979        // Absolute (0,0) → field value = 1.0
1980        assert_eq!(mask[12], 1);
1981        assert_eq!(output[12], 1.0);
1982    }
1983
1984    #[test]
1985    fn hex_foveation_interior() {
1986        // Test agent-centered observation on Hex2D grid.
1987        let space = Hex2D::new(20, 20).unwrap(); // 20 rows, 20 cols
1988        let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
1989        let snap = snapshot_with_field(FieldId(0), data);
1990
1991        let spec = ObsSpec {
1992            entries: vec![ObsEntry {
1993                field_id: FieldId(0),
1994                region: ObsRegion::AgentDisk { radius: 2 },
1995                pool: None,
1996                transform: ObsTransform::Identity,
1997                dtype: ObsDtype::F32,
1998            }],
1999        };
2000        let result = ObsPlan::compile(&spec, &space).unwrap();
2001        assert_eq!(result.output_len, 25); // 5x5 bounding box (tensor shape)
2002
2003        // Interior agent: q=10, r=10
2004        let center: Coord = smallvec::smallvec![10, 10];
2005        let mut output = vec![0.0f32; result.output_len];
2006        let mut mask = vec![0u8; result.mask_len];
2007        result
2008            .plan
2009            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2010            .unwrap();
2011
2012        // Hex disk of radius 2: 19 of 25 cells are within hex distance.
2013        // The 6 corners of the 5x5 bounding box exceed hex distance 2.
2014        // Hex distance = max(|dq|, |dr|, |dq+dr|) for axial coordinates.
2015        let valid_count = mask.iter().filter(|&&v| v == 1).count();
2016        assert_eq!(valid_count, 19);
2017
2018        // Corners that should be masked out (distance > 2):
2019        // tensor_idx 0: dq=-2,dr=-2 → max(2,2,4)=4
2020        // tensor_idx 1: dq=-2,dr=-1 → max(2,1,3)=3
2021        // tensor_idx 5: dq=-1,dr=-2 → max(1,2,3)=3
2022        // tensor_idx 19: dq=+1,dr=+2 → max(1,2,3)=3
2023        // tensor_idx 23: dq=+2,dr=+1 → max(2,1,3)=3
2024        // tensor_idx 24: dq=+2,dr=+2 → max(2,2,4)=4
2025        for &idx in &[0, 1, 5, 19, 23, 24] {
2026            assert_eq!(mask[idx], 0, "tensor_idx {idx} should be outside hex disk");
2027            assert_eq!(output[idx], 0.0, "tensor_idx {idx} should be zero-padded");
2028        }
2029
2030        // Center cell is at tensor_idx = 2*5+2 = 12 (relative [0,0]).
2031        // Hex2D canonical_rank([q,r]) = r*cols + q = 10*20 + 10 = 210
2032        assert_eq!(output[12], 210.0);
2033
2034        // Cell at relative [1, 0] (dq=+1, dr=0) → absolute [11, 10]
2035        // rank = 10*20 + 11 = 211
2036        // In row-major bounding box: dim0_idx=3, dim1_idx=2 → tensor_idx = 3*5+2 = 17
2037        assert_eq!(output[17], 211.0);
2038    }
2039
2040    #[test]
2041    fn wrap_space_all_interior() {
2042        // Wrapped (torus) space: all agents are interior.
2043        let space = Square4::new(10, 10, EdgeBehavior::Wrap).unwrap();
2044        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2045        let snap = snapshot_with_field(FieldId(0), data);
2046
2047        let spec = ObsSpec {
2048            entries: vec![ObsEntry {
2049                field_id: FieldId(0),
2050                region: ObsRegion::AgentRect {
2051                    half_extent: smallvec::smallvec![2, 2],
2052                },
2053                pool: None,
2054                transform: ObsTransform::Identity,
2055                dtype: ObsDtype::F32,
2056            }],
2057        };
2058        let result = ObsPlan::compile(&spec, &space).unwrap();
2059
2060        // Agent at corner (0,0) — still interior on torus.
2061        let center: Coord = smallvec::smallvec![0, 0];
2062        let mut output = vec![0.0f32; result.output_len];
2063        let mut mask = vec![0u8; result.mask_len];
2064        result
2065            .plan
2066            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2067            .unwrap();
2068
2069        // All 25 cells valid (torus wraps).
2070        assert!(mask.iter().all(|&v| v == 1));
2071        assert_eq!(output[12], 0.0); // center (0,0) → rank 0
2072    }
2073
2074    #[test]
2075    fn execute_agents_multiple() {
2076        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2077        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2078        let snap = snapshot_with_field(FieldId(0), data);
2079
2080        let spec = ObsSpec {
2081            entries: vec![ObsEntry {
2082                field_id: FieldId(0),
2083                region: ObsRegion::AgentRect {
2084                    half_extent: smallvec::smallvec![1, 1],
2085                },
2086                pool: None,
2087                transform: ObsTransform::Identity,
2088                dtype: ObsDtype::F32,
2089            }],
2090        };
2091        let result = ObsPlan::compile(&spec, &space).unwrap();
2092        assert_eq!(result.output_len, 9); // 3x3
2093
2094        // Two agents: one interior, one at edge.
2095        let centers = vec![
2096            smallvec::smallvec![5, 5], // interior
2097            smallvec::smallvec![0, 5], // top edge
2098        ];
2099        let n = centers.len();
2100        let mut output = vec![0.0f32; result.output_len * n];
2101        let mut mask = vec![0u8; result.mask_len * n];
2102        let metas = result
2103            .plan
2104            .execute_agents(&snap, &space, &centers, None, &mut output, &mut mask)
2105            .unwrap();
2106
2107        assert_eq!(metas.len(), 2);
2108
2109        // Agent 0 (interior): all 9 cells valid, center = (5,5) → rank 55
2110        assert!(mask[..9].iter().all(|&v| v == 1));
2111        assert_eq!(output[4], 55.0); // center at tensor_idx = 1*3+1 = 4
2112
2113        // Agent 1 (top edge): row -1 is out of bounds → 3 cells masked
2114        let agent1_mask = &mask[9..18];
2115        let valid_count: usize = agent1_mask.iter().filter(|&&v| v == 1).count();
2116        assert_eq!(valid_count, 6); // 2 rows in-bounds × 3 cols
2117    }
2118
2119    #[test]
2120    fn execute_agents_with_normalize() {
2121        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2122        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2123        let snap = snapshot_with_field(FieldId(0), data);
2124
2125        let spec = ObsSpec {
2126            entries: vec![ObsEntry {
2127                field_id: FieldId(0),
2128                region: ObsRegion::AgentRect {
2129                    half_extent: smallvec::smallvec![1, 1],
2130                },
2131                pool: None,
2132                transform: ObsTransform::Normalize {
2133                    min: 0.0,
2134                    max: 99.0,
2135                },
2136                dtype: ObsDtype::F32,
2137            }],
2138        };
2139        let result = ObsPlan::compile(&spec, &space).unwrap();
2140
2141        let center: Coord = smallvec::smallvec![5, 5];
2142        let mut output = vec![0.0f32; result.output_len];
2143        let mut mask = vec![0u8; result.mask_len];
2144        result
2145            .plan
2146            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2147            .unwrap();
2148
2149        // Center (5,5) rank=55, normalized = 55/99 ≈ 0.5556
2150        let expected = 55.0 / 99.0;
2151        assert!((output[4] - expected as f32).abs() < 1e-5);
2152    }
2153
2154    #[test]
2155    fn execute_agents_with_pooling() {
2156        let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
2157        let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
2158        let snap = snapshot_with_field(FieldId(0), data);
2159
2160        // AgentRect with half_extent=3 → 7x7 bounding box.
2161        // Mean pool 2x2 stride 2 → floor((7-2)/2)+1 = 3 per dim → 3x3 = 9 output.
2162        let spec = ObsSpec {
2163            entries: vec![ObsEntry {
2164                field_id: FieldId(0),
2165                region: ObsRegion::AgentRect {
2166                    half_extent: smallvec::smallvec![3, 3],
2167                },
2168                pool: Some(PoolConfig {
2169                    kernel: PoolKernel::Mean,
2170                    kernel_size: 2,
2171                    stride: 2,
2172                }),
2173                transform: ObsTransform::Identity,
2174                dtype: ObsDtype::F32,
2175            }],
2176        };
2177        let result = ObsPlan::compile(&spec, &space).unwrap();
2178        assert_eq!(result.output_len, 9); // 3x3
2179        assert_eq!(result.entry_shapes, vec![vec![3, 3]]);
2180
2181        // Interior agent at (10, 10): all cells valid.
2182        let center: Coord = smallvec::smallvec![10, 10];
2183        let mut output = vec![0.0f32; result.output_len];
2184        let mut mask = vec![0u8; result.mask_len];
2185        result
2186            .plan
2187            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2188            .unwrap();
2189
2190        // All pooled cells should be valid.
2191        assert!(mask.iter().all(|&v| v == 1));
2192
2193        // Verify first pooled cell: mean of top-left 2x2 of the 7x7 gather.
2194        // Gather bounding box starts at (10-3, 10-3) = (7, 7).
2195        // Top-left 2x2: (7,7)=147, (7,8)=148, (8,7)=167, (8,8)=168
2196        // Mean = (147+148+167+168)/4 = 157.5
2197        assert!((output[0] - 157.5).abs() < 1e-4);
2198    }
2199
2200    #[test]
2201    fn mixed_fixed_and_agent_entries() {
2202        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2203        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2204        let snap = snapshot_with_field(FieldId(0), data);
2205
2206        let spec = ObsSpec {
2207            entries: vec![
2208                // Fixed entry: full grid (100 elements).
2209                ObsEntry {
2210                    field_id: FieldId(0),
2211                    region: ObsRegion::Fixed(RegionSpec::All),
2212                    pool: None,
2213                    transform: ObsTransform::Identity,
2214                    dtype: ObsDtype::F32,
2215                },
2216                // Agent entry: 3x3 rect around agent.
2217                ObsEntry {
2218                    field_id: FieldId(0),
2219                    region: ObsRegion::AgentRect {
2220                        half_extent: smallvec::smallvec![1, 1],
2221                    },
2222                    pool: None,
2223                    transform: ObsTransform::Identity,
2224                    dtype: ObsDtype::F32,
2225                },
2226            ],
2227        };
2228        let result = ObsPlan::compile(&spec, &space).unwrap();
2229        assert!(result.plan.is_standard());
2230        assert_eq!(result.output_len, 109); // 100 + 9
2231
2232        let center: Coord = smallvec::smallvec![5, 5];
2233        let mut output = vec![0.0f32; result.output_len];
2234        let mut mask = vec![0u8; result.mask_len];
2235        result
2236            .plan
2237            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2238            .unwrap();
2239
2240        // Fixed entry: first 100 elements match field data.
2241        let expected: Vec<f32> = (0..100).map(|x| x as f32).collect();
2242        assert_eq!(&output[..100], &expected[..]);
2243        assert!(mask[..100].iter().all(|&v| v == 1));
2244
2245        // Agent entry: 3x3 centered on (5,5). Center at tensor_idx = 1*3+1 = 4.
2246        // rank(5,5) = 55
2247        assert_eq!(output[100 + 4], 55.0);
2248    }
2249
2250    #[test]
2251    fn wrong_dimensionality_returns_error() {
2252        // 2D space but 1D agent center → should error, not panic.
2253        let space = Square4::new(10, 10, EdgeBehavior::Absorb).unwrap();
2254        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2255        let snap = snapshot_with_field(FieldId(0), data);
2256
2257        let spec = ObsSpec {
2258            entries: vec![ObsEntry {
2259                field_id: FieldId(0),
2260                region: ObsRegion::AgentDisk { radius: 1 },
2261                pool: None,
2262                transform: ObsTransform::Identity,
2263                dtype: ObsDtype::F32,
2264            }],
2265        };
2266        let result = ObsPlan::compile(&spec, &space).unwrap();
2267
2268        let bad_center: Coord = smallvec::smallvec![5]; // 1D, not 2D
2269        let mut output = vec![0.0f32; result.output_len];
2270        let mut mask = vec![0u8; result.mask_len];
2271        let err =
2272            result
2273                .plan
2274                .execute_agents(&snap, &space, &[bad_center], None, &mut output, &mut mask);
2275        assert!(err.is_err());
2276        let msg = format!("{}", err.unwrap_err());
2277        assert!(
2278            msg.contains("dimensions"),
2279            "error should mention dimensions: {msg}"
2280        );
2281    }
2282
2283    #[test]
2284    fn agent_disk_square4_filters_corners() {
2285        // On a 4-connected grid, AgentDisk radius=2 should use Manhattan distance.
2286        // Bounding box is 5x5 = 25, but Manhattan disk has 13 cells (diamond shape).
2287        let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
2288        let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
2289        let snap = snapshot_with_field(FieldId(0), data);
2290
2291        let spec = ObsSpec {
2292            entries: vec![ObsEntry {
2293                field_id: FieldId(0),
2294                region: ObsRegion::AgentDisk { radius: 2 },
2295                pool: None,
2296                transform: ObsTransform::Identity,
2297                dtype: ObsDtype::F32,
2298            }],
2299        };
2300        let result = ObsPlan::compile(&spec, &space).unwrap();
2301        assert_eq!(result.output_len, 25); // tensor shape is still 5x5
2302
2303        // Interior agent at (10, 10).
2304        let center: Coord = smallvec::smallvec![10, 10];
2305        let mut output = vec![0.0f32; 25];
2306        let mut mask = vec![0u8; 25];
2307        result
2308            .plan
2309            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2310            .unwrap();
2311
2312        // Manhattan distance disk of radius 2 on a 5x5 bounding box:
2313        //   . . X . .    (row -2: only center col)
2314        //   . X X X .    (row -1: 3 cells)
2315        //   X X X X X    (row  0: 5 cells)
2316        //   . X X X .    (row +1: 3 cells)
2317        //   . . X . .    (row +2: only center col)
2318        // Total: 1 + 3 + 5 + 3 + 1 = 13 cells
2319        let valid_count = mask.iter().filter(|&&v| v == 1).count();
2320        assert_eq!(
2321            valid_count, 13,
2322            "Manhattan disk radius=2 should have 13 cells"
2323        );
2324
2325        // Corners should be masked out: (dr,dc) where |dr|+|dc| > 2
2326        // tensor_idx 0: dr=-2,dc=-2 → dist=4 → OUT
2327        // tensor_idx 4: dr=-2,dc=+2 → dist=4 → OUT
2328        // tensor_idx 20: dr=+2,dc=-2 → dist=4 → OUT
2329        // tensor_idx 24: dr=+2,dc=+2 → dist=4 → OUT
2330        for &idx in &[0, 4, 20, 24] {
2331            assert_eq!(
2332                mask[idx], 0,
2333                "corner tensor_idx {idx} should be outside disk"
2334            );
2335        }
2336
2337        // Center cell: tensor_idx = 2*5+2 = 12, absolute = row 10 * 20 + col 10 = 210
2338        assert_eq!(output[12], 210.0);
2339        assert_eq!(mask[12], 1);
2340    }
2341
2342    #[test]
2343    fn agent_rect_no_disk_filtering() {
2344        // AgentRect should NOT filter any cells — full rectangle is valid.
2345        let space = Square4::new(20, 20, EdgeBehavior::Absorb).unwrap();
2346        let data: Vec<f32> = (0..400).map(|x| x as f32).collect();
2347        let snap = snapshot_with_field(FieldId(0), data);
2348
2349        let spec = ObsSpec {
2350            entries: vec![ObsEntry {
2351                field_id: FieldId(0),
2352                region: ObsRegion::AgentRect {
2353                    half_extent: smallvec::smallvec![2, 2],
2354                },
2355                pool: None,
2356                transform: ObsTransform::Identity,
2357                dtype: ObsDtype::F32,
2358            }],
2359        };
2360        let result = ObsPlan::compile(&spec, &space).unwrap();
2361
2362        let center: Coord = smallvec::smallvec![10, 10];
2363        let mut output = vec![0.0f32; 25];
2364        let mut mask = vec![0u8; 25];
2365        result
2366            .plan
2367            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2368            .unwrap();
2369
2370        // All 25 cells should be valid for AgentRect (no disk filtering).
2371        assert!(mask.iter().all(|&v| v == 1));
2372    }
2373
2374    #[test]
2375    fn agent_disk_square8_chebyshev() {
2376        // On an 8-connected grid, AgentDisk radius=1 uses Chebyshev distance.
2377        // Bounding box is 3x3 = 9, Chebyshev disk radius=1 = full 3x3 → 9 cells.
2378        let space = Square8::new(10, 10, EdgeBehavior::Absorb).unwrap();
2379        let data: Vec<f32> = (0..100).map(|x| x as f32).collect();
2380        let snap = snapshot_with_field(FieldId(0), data);
2381
2382        let spec = ObsSpec {
2383            entries: vec![ObsEntry {
2384                field_id: FieldId(0),
2385                region: ObsRegion::AgentDisk { radius: 1 },
2386                pool: None,
2387                transform: ObsTransform::Identity,
2388                dtype: ObsDtype::F32,
2389            }],
2390        };
2391        let result = ObsPlan::compile(&spec, &space).unwrap();
2392        assert_eq!(result.output_len, 9);
2393
2394        let center: Coord = smallvec::smallvec![5, 5];
2395        let mut output = vec![0.0f32; 9];
2396        let mut mask = vec![0u8; 9];
2397        result
2398            .plan
2399            .execute_agents(&snap, &space, &[center], None, &mut output, &mut mask)
2400            .unwrap();
2401
2402        // Chebyshev distance <= 1 covers full 3x3 = 9 cells (all corners included).
2403        let valid_count = mask.iter().filter(|&&v| v == 1).count();
2404        assert_eq!(valid_count, 9, "Chebyshev disk radius=1 = full 3x3");
2405    }
2406}