Skip to main content

murk_engine/
batched.rs

1//! Batched simulation engine for vectorized RL training.
2//!
3//! [`BatchedEngine`] owns N [`LockstepWorld`]s and steps them all in a
4//! single call, eliminating per-world FFI overhead. Observation extraction
5//! uses [`ObsPlan::execute_batch()`] to fill a contiguous output buffer
6//! across all worlds.
7//!
8//! # Design
9//!
10//! The hot path is `step_and_observe`: step all worlds sequentially, then
11//! extract observations in batch. The GIL is released once at the Python
12//! layer, covering the entire operation. This reduces 2N GIL cycles (the
13//! current `MurkVecEnv` approach) to exactly 1.
14//!
15//! Parallelism (rayon) is deferred to v2. The GIL elimination alone is
16//! the dominant win; adding `par_iter_mut` later is a 3-line change.
17
18use murk_core::command::Command;
19use murk_core::error::ObsError;
20use murk_core::id::TickId;
21use murk_core::traits::SnapshotAccess;
22use murk_obs::metadata::ObsMetadata;
23use murk_obs::plan::ObsPlan;
24use murk_obs::spec::ObsSpec;
25
26use crate::config::{ConfigError, WorldConfig};
27use crate::lockstep::LockstepWorld;
28use crate::metrics::StepMetrics;
29use crate::tick::TickError;
30
31// ── Error type ──────────────────────────────────────────────────
32
33/// Error from a batched operation, annotated with the failing world index.
34#[derive(Debug, PartialEq)]
35pub enum BatchError {
36    /// A world's `step_sync()` failed.
37    Step {
38        /// Index of the world that failed (0-based).
39        world_index: usize,
40        /// The underlying tick error.
41        error: TickError,
42    },
43    /// Observation extraction failed.
44    Observe(ObsError),
45    /// Configuration error during construction or reset.
46    Config(ConfigError),
47    /// World index out of bounds.
48    InvalidIndex {
49        /// The requested index.
50        world_index: usize,
51        /// Total number of worlds.
52        num_worlds: usize,
53    },
54    /// No observation plan was compiled (called observe without obs_spec).
55    NoObsPlan,
56    /// Batch-level argument validation failed.
57    InvalidArgument {
58        /// Human-readable description of what's wrong.
59        reason: String,
60    },
61}
62
63impl std::fmt::Display for BatchError {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        match self {
66            BatchError::Step { world_index, error } => {
67                write!(f, "world {world_index}: step failed: {error:?}")
68            }
69            BatchError::Observe(e) => write!(f, "observe failed: {e:?}"),
70            BatchError::Config(e) => write!(f, "config error: {e:?}"),
71            BatchError::InvalidIndex {
72                world_index,
73                num_worlds,
74            } => write!(
75                f,
76                "world index {world_index} out of range (num_worlds={num_worlds})"
77            ),
78            BatchError::NoObsPlan => write!(f, "no observation plan compiled"),
79            BatchError::InvalidArgument { reason } => {
80                write!(f, "invalid argument: {reason}")
81            }
82        }
83    }
84}
85
86impl std::error::Error for BatchError {}
87
88// ── Result type ─────────────────────────────────────────────────
89
90/// Result of stepping a batch of worlds.
91pub struct BatchResult {
92    /// Per-world tick IDs after stepping.
93    pub tick_ids: Vec<TickId>,
94    /// Per-world step metrics.
95    pub metrics: Vec<StepMetrics>,
96}
97
98// ── BatchedEngine ───────────────────────────────────────────────
99
100/// Batched simulation engine owning N lockstep worlds.
101///
102/// Created from N [`WorldConfig`]s with an optional [`ObsSpec`].
103/// All worlds must share the same space topology (validated at
104/// construction).
105///
106/// The primary interface is [`step_and_observe()`](Self::step_and_observe):
107/// step all worlds, then extract observations into a contiguous buffer
108/// using [`ObsPlan::execute_batch()`].
109pub struct BatchedEngine {
110    worlds: Vec<LockstepWorld>,
111    obs_plan: Option<ObsPlan>,
112    obs_output_len: usize,
113    obs_mask_len: usize,
114}
115
116impl BatchedEngine {
117    /// Create a batched engine from N world configs.
118    ///
119    /// If `obs_spec` is provided, compiles an [`ObsPlan`] from the first
120    /// world's space. All worlds must have the same `cell_count`
121    /// (defensive check).
122    ///
123    /// # Errors
124    ///
125    /// Returns [`BatchError::Config`] if any world fails to construct,
126    /// or [`BatchError::Observe`] if the obs plan fails to compile.
127    pub fn new(configs: Vec<WorldConfig>, obs_spec: Option<&ObsSpec>) -> Result<Self, BatchError> {
128        if configs.is_empty() {
129            return Err(BatchError::InvalidArgument {
130                reason: "BatchedEngine requires at least one world config".into(),
131            });
132        }
133
134        let mut worlds = Vec::with_capacity(configs.len());
135        for config in configs {
136            let world = LockstepWorld::new(config).map_err(BatchError::Config)?;
137            worlds.push(world);
138        }
139
140        // Validate all worlds share the same space topology.
141        // topology_eq checks TypeId, dimensions, and behavioral parameters
142        // (e.g. EdgeBehavior) so that spaces like Line1D(10, Absorb) and
143        // Line1D(10, Wrap) are correctly rejected.
144        let ref_space = worlds[0].space();
145        for (i, world) in worlds.iter().enumerate().skip(1) {
146            if !ref_space.topology_eq(world.space()) {
147                return Err(BatchError::InvalidArgument {
148                    reason: format!(
149                        "world 0 and world {i} have incompatible space topologies; \
150                         all worlds in a batch must use the same topology"
151                    ),
152                });
153            }
154        }
155
156        // Compile obs plan if spec provided.
157        let (obs_plan, obs_output_len, obs_mask_len) = match obs_spec {
158            Some(spec) => {
159                let result =
160                    ObsPlan::compile(spec, worlds[0].space()).map_err(BatchError::Observe)?;
161
162                // Validate all worlds have matching field schemas for observed fields.
163                // ObsPlan::compile only takes a Space (not a snapshot), so field
164                // existence isn't checked until execute(). Catching mismatches here
165                // prevents late observation failures after worlds have been stepped.
166                let ref_snap = worlds[0].snapshot();
167                for entry in &spec.entries {
168                    let fid = entry.field_id;
169                    let ref_len = ref_snap.read_field(fid).map(|d| d.len());
170                    for (i, world) in worlds.iter().enumerate().skip(1) {
171                        let snap = world.snapshot();
172                        let other_len = snap.read_field(fid).map(|d| d.len());
173                        if other_len != ref_len {
174                            return Err(BatchError::InvalidArgument {
175                                reason: format!(
176                                    "world {i} field {fid:?}: {} elements, \
177                                     world 0 has {} elements; \
178                                     all worlds must share the same field schema",
179                                    other_len
180                                        .map(|n| n.to_string())
181                                        .unwrap_or_else(|| "missing".into()),
182                                    ref_len
183                                        .map(|n| n.to_string())
184                                        .unwrap_or_else(|| "missing".into()),
185                                ),
186                            });
187                        }
188                    }
189                    // Absolute check: the observed field must actually exist
190                    // in the reference world. Without this, None == None passes
191                    // the cross-world comparison when the field is missing from
192                    // all worlds, deferring the error to observe time.
193                    if ref_len.is_none() {
194                        return Err(BatchError::InvalidArgument {
195                            reason: format!(
196                                "obs spec references {fid:?} which is missing from world 0; \
197                                 every observed field must exist in all worlds",
198                            ),
199                        });
200                    }
201                }
202
203                (Some(result.plan), result.output_len, result.mask_len)
204            }
205            None => (None, 0, 0),
206        };
207
208        Ok(BatchedEngine {
209            worlds,
210            obs_plan,
211            obs_output_len,
212            obs_mask_len,
213        })
214    }
215
216    /// Step all worlds and extract observations in one call.
217    ///
218    /// `commands` must have exactly `num_worlds()` entries.
219    /// `output` must have at least `num_worlds() * obs_output_len()` elements.
220    /// `mask` must have at least `num_worlds() * obs_mask_len()` bytes.
221    ///
222    /// Returns per-world tick IDs and metrics.
223    pub fn step_and_observe(
224        &mut self,
225        commands: &[Vec<Command>],
226        output: &mut [f32],
227        mask: &mut [u8],
228    ) -> Result<BatchResult, BatchError> {
229        // Pre-flight: validate observation preconditions before mutating
230        // world state. Without this, a late observe failure (no obs plan,
231        // buffer too small) would leave worlds stepped but observations
232        // unextracted — making the error non-atomic.
233        self.validate_observe_buffers(output, mask)?;
234
235        let result = self.step_all(commands)?;
236
237        // Observe phase: borrow worlds immutably for snapshot collection.
238        self.observe_all_inner(output, mask)?;
239
240        Ok(result)
241    }
242
243    /// Step all worlds without observation extraction.
244    pub fn step_all(&mut self, commands: &[Vec<Command>]) -> Result<BatchResult, BatchError> {
245        let n = self.worlds.len();
246        if commands.len() != n {
247            return Err(BatchError::InvalidArgument {
248                reason: format!("commands has {} entries, expected {n}", commands.len()),
249            });
250        }
251
252        let mut tick_ids = Vec::with_capacity(n);
253        let mut metrics = Vec::with_capacity(n);
254
255        for (idx, world) in self.worlds.iter_mut().enumerate() {
256            let result = world
257                .step_sync(commands[idx].clone())
258                .map_err(|e| BatchError::Step {
259                    world_index: idx,
260                    error: e,
261                })?;
262            tick_ids.push(result.snapshot.tick_id());
263            metrics.push(result.metrics);
264        }
265
266        Ok(BatchResult { tick_ids, metrics })
267    }
268
269    /// Extract observations from all worlds without stepping.
270    ///
271    /// Used after `reset_all()` to get initial observations.
272    pub fn observe_all(
273        &self,
274        output: &mut [f32],
275        mask: &mut [u8],
276    ) -> Result<Vec<ObsMetadata>, BatchError> {
277        self.observe_all_inner(output, mask)
278    }
279
280    /// Internal observation extraction shared by step_and_observe and observe_all.
281    fn observe_all_inner(
282        &self,
283        output: &mut [f32],
284        mask: &mut [u8],
285    ) -> Result<Vec<ObsMetadata>, BatchError> {
286        let plan = self.obs_plan.as_ref().ok_or(BatchError::NoObsPlan)?;
287
288        let snapshots: Vec<_> = self.worlds.iter().map(|w| w.snapshot()).collect();
289        let snap_refs: Vec<&dyn SnapshotAccess> =
290            snapshots.iter().map(|s| s as &dyn SnapshotAccess).collect();
291
292        plan.execute_batch(&snap_refs, None, output, mask)
293            .map_err(BatchError::Observe)
294    }
295
296    /// Validate that observation preconditions are met (plan exists, buffers
297    /// large enough) without performing any mutation. Called by
298    /// `step_and_observe` before `step_all` to guarantee atomicity.
299    fn validate_observe_buffers(&self, output: &[f32], mask: &[u8]) -> Result<(), BatchError> {
300        let plan = self.obs_plan.as_ref().ok_or(BatchError::NoObsPlan)?;
301        if plan.is_standard() {
302            return Err(BatchError::InvalidArgument {
303                reason: "obs spec uses agent-relative regions (AgentDisk/AgentRect), \
304                         which are unsupported in batched step_and_observe"
305                    .into(),
306            });
307        }
308        let n = self.worlds.len();
309        let expected_out = n * self.obs_output_len;
310        let expected_mask = n * self.obs_mask_len;
311        if output.len() < expected_out {
312            return Err(BatchError::InvalidArgument {
313                reason: format!("output buffer too small: {} < {expected_out}", output.len()),
314            });
315        }
316        if mask.len() < expected_mask {
317            return Err(BatchError::InvalidArgument {
318                reason: format!("mask buffer too small: {} < {expected_mask}", mask.len()),
319            });
320        }
321        Ok(())
322    }
323
324    /// Reset a single world by index.
325    pub fn reset_world(&mut self, idx: usize, seed: u64) -> Result<(), BatchError> {
326        let n = self.worlds.len();
327        let world = self.worlds.get_mut(idx).ok_or(BatchError::InvalidIndex {
328            world_index: idx,
329            num_worlds: n,
330        })?;
331        world.reset(seed).map_err(BatchError::Config)?;
332        Ok(())
333    }
334
335    /// Reset all worlds with per-world seeds.
336    pub fn reset_all(&mut self, seeds: &[u64]) -> Result<(), BatchError> {
337        let n = self.worlds.len();
338        if seeds.len() != n {
339            return Err(BatchError::InvalidArgument {
340                reason: format!("seeds has {} entries, expected {n}", seeds.len()),
341            });
342        }
343        for (idx, world) in self.worlds.iter_mut().enumerate() {
344            world.reset(seeds[idx]).map_err(BatchError::Config)?;
345        }
346        Ok(())
347    }
348
349    /// Number of worlds in the batch.
350    pub fn num_worlds(&self) -> usize {
351        self.worlds.len()
352    }
353
354    /// Per-world observation output length (f32 elements).
355    pub fn obs_output_len(&self) -> usize {
356        self.obs_output_len
357    }
358
359    /// Per-world observation mask length (bytes).
360    pub fn obs_mask_len(&self) -> usize {
361        self.obs_mask_len
362    }
363
364    /// Current tick ID of a specific world.
365    pub fn world_tick(&self, idx: usize) -> Option<TickId> {
366        self.worlds.get(idx).map(|w| w.current_tick())
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use murk_core::id::FieldId;
    use murk_core::traits::FieldReader;
    use murk_obs::spec::{ObsDtype, ObsEntry, ObsRegion, ObsTransform};
    use murk_space::{EdgeBehavior, Line1D, RegionSpec, Square4};
    use murk_test_utils::ConstPropagator;

    use crate::config::BackoffConfig;

    /// Per-tick scalar field with no units or bounds.
    fn scalar_field(name: &str) -> murk_core::FieldDef {
        murk_core::FieldDef {
            name: name.to_string(),
            field_type: murk_core::FieldType::Scalar,
            mutability: murk_core::FieldMutability::PerTick,
            units: None,
            bounds: None,
            boundary_behavior: murk_core::BoundaryBehavior::Clamp,
        }
    }

    /// Baseline config: `Line1D(10, Absorb)` with a single scalar field
    /// ("energy", FieldId(0)) written by a `ConstPropagator` with `value`.
    ///
    /// Variants below are derived with struct-update syntax, so each test
    /// spells out only the fields it changes.
    fn make_config(seed: u64, value: f32) -> WorldConfig {
        WorldConfig {
            space: Box::new(Line1D::new(10, EdgeBehavior::Absorb).unwrap()),
            fields: vec![scalar_field("energy")],
            propagators: vec![Box::new(ConstPropagator::new("const", FieldId(0), value))],
            dt: 0.1,
            seed,
            ring_buffer_size: 8,
            max_ingress_queue: 1024,
            tick_rate_hz: None,
            backoff: BackoffConfig::default(),
        }
    }

    /// Same as `make_config`, but on a 4x4 `Square4` grid.
    fn make_grid_config(seed: u64, value: f32) -> WorldConfig {
        WorldConfig {
            space: Box::new(Square4::new(4, 4, EdgeBehavior::Absorb).unwrap()),
            ..make_config(seed, value)
        }
    }

    fn obs_spec_all_field0() -> ObsSpec {
        ObsSpec {
            entries: vec![ObsEntry {
                field_id: FieldId(0),
                region: ObsRegion::Fixed(RegionSpec::All),
                pool: None,
                transform: ObsTransform::Identity,
                dtype: ObsDtype::F32,
            }],
        }
    }

    /// Extract the error from a failed `BatchedEngine::new`, panicking with
    /// "expected error for {what}" if construction succeeded. `BatchedEngine`
    /// has no `Debug` impl, so `Result::unwrap_err` is not available.
    fn expect_new_err(result: Result<BatchedEngine, BatchError>, what: &str) -> BatchError {
        match result {
            Err(e) => e,
            Ok(_) => panic!("expected error for {what}"),
        }
    }

    // ── Construction tests ────────────────────────────────────

    #[test]
    fn new_single_world() {
        let configs = vec![make_config(42, 1.0)];
        let engine = BatchedEngine::new(configs, None).unwrap();
        assert_eq!(engine.num_worlds(), 1);
        assert_eq!(engine.obs_output_len(), 0);
        assert_eq!(engine.obs_mask_len(), 0);
        assert_eq!(engine.world_tick(1), None);
    }

    #[test]
    fn new_four_worlds() {
        let configs: Vec<_> = (0..4).map(|i| make_config(i, 1.0)).collect();
        let engine = BatchedEngine::new(configs, None).unwrap();
        assert_eq!(engine.num_worlds(), 4);
    }

    #[test]
    fn new_zero_worlds_is_error() {
        let result = BatchedEngine::new(vec![], None);
        assert!(matches!(result, Err(BatchError::InvalidArgument { .. })));
    }

    #[test]
    fn new_with_obs_spec() {
        let configs = vec![make_config(42, 1.0)];
        let spec = obs_spec_all_field0();
        let engine = BatchedEngine::new(configs, Some(&spec)).unwrap();
        assert_eq!(engine.obs_output_len(), 10); // Line1D(10) → 10 cells
        assert_eq!(engine.obs_mask_len(), 10);
    }

    // ── Determinism test ──────────────────────────────────────

    #[test]
    fn batch_matches_independent_worlds() {
        let spec = obs_spec_all_field0();

        // Batched: 2 worlds
        let configs = vec![make_config(42, 42.0), make_config(99, 42.0)];
        let mut batched = BatchedEngine::new(configs, Some(&spec)).unwrap();
        let n = batched.num_worlds();
        let out_len = n * batched.obs_output_len();
        let mask_len = n * batched.obs_mask_len();
        let mut batch_output = vec![0.0f32; out_len];
        let mut batch_mask = vec![0u8; mask_len];

        let commands = vec![vec![], vec![]];
        batched
            .step_and_observe(&commands, &mut batch_output, &mut batch_mask)
            .unwrap();

        // Independent: 2 separate worlds
        let mut w0 = LockstepWorld::new(make_config(42, 42.0)).unwrap();
        let mut w1 = LockstepWorld::new(make_config(99, 42.0)).unwrap();
        let r0 = w0.step_sync(vec![]).unwrap();
        let r1 = w1.step_sync(vec![]).unwrap();

        let d0 = r0.snapshot.read(FieldId(0)).unwrap();
        let d1 = r1.snapshot.read(FieldId(0)).unwrap();

        // Batch output should be [world0_obs | world1_obs]
        assert_eq!(&batch_output[..10], d0);
        assert_eq!(&batch_output[10..20], d1);
    }

    // ── Observation correctness ───────────────────────────────

    #[test]
    fn observation_filled_with_const_value() {
        let spec = obs_spec_all_field0();
        let configs = vec![
            make_config(1, 42.0),
            make_config(2, 42.0),
            make_config(3, 42.0),
        ];
        let mut engine = BatchedEngine::new(configs, Some(&spec)).unwrap();

        let commands = vec![vec![], vec![], vec![]];
        let n = engine.num_worlds();
        let mut output = vec![0.0f32; n * engine.obs_output_len()];
        let mut mask = vec![0u8; n * engine.obs_mask_len()];
        let result = engine
            .step_and_observe(&commands, &mut output, &mut mask)
            .unwrap();

        // One tick ID and one metrics entry per world, all at tick 1.
        assert_eq!(result.tick_ids.len(), 3);
        assert_eq!(result.metrics.len(), 3);
        assert!(result.tick_ids.iter().all(|t| *t == TickId(1)));

        // All cells should be 42.0 for all worlds.
        assert!(output.iter().all(|&v| v == 42.0));
        assert!(mask.iter().all(|&m| m == 1));
    }

    // ── Reset tests ───────────────────────────────────────────

    #[test]
    fn reset_single_world_preserves_others() {
        let configs: Vec<_> = (0..4).map(|i| make_config(i, 1.0)).collect();
        let mut engine = BatchedEngine::new(configs, None).unwrap();

        // Step all once.
        let commands = vec![vec![]; 4];
        engine.step_all(&commands).unwrap();
        assert_eq!(engine.world_tick(0), Some(TickId(1)));
        assert_eq!(engine.world_tick(3), Some(TickId(1)));

        // Reset only world 0.
        engine.reset_world(0, 999).unwrap();
        assert_eq!(engine.world_tick(0), Some(TickId(0)));
        assert_eq!(engine.world_tick(1), Some(TickId(1)));
        assert_eq!(engine.world_tick(2), Some(TickId(1)));
        assert_eq!(engine.world_tick(3), Some(TickId(1)));
    }

    #[test]
    fn reset_all_resets_to_tick_zero() {
        let configs: Vec<_> = (0..3).map(|i| make_config(i, 1.0)).collect();
        let mut engine = BatchedEngine::new(configs, None).unwrap();

        // Step all twice.
        let commands = vec![vec![]; 3];
        engine.step_all(&commands).unwrap();
        engine.step_all(&commands).unwrap();

        engine.reset_all(&[10, 20, 30]).unwrap();
        for i in 0..3 {
            assert_eq!(engine.world_tick(i), Some(TickId(0)));
        }
    }

    #[test]
    fn reset_all_wrong_seed_count_returns_error() {
        let configs: Vec<_> = (0..2).map(|i| make_config(i, 1.0)).collect();
        let mut engine = BatchedEngine::new(configs, None).unwrap();
        engine.step_all(&[vec![], vec![]]).unwrap();

        let result = engine.reset_all(&[7]); // 1 seed for 2 worlds
        assert!(matches!(result, Err(BatchError::InvalidArgument { .. })));

        // The count is validated before any reset, so no world was touched.
        assert_eq!(engine.world_tick(0), Some(TickId(1)));
        assert_eq!(engine.world_tick(1), Some(TickId(1)));
    }

    // ── Error isolation ───────────────────────────────────────

    #[test]
    fn invalid_world_index_returns_error() {
        let configs = vec![make_config(0, 1.0)];
        let mut engine = BatchedEngine::new(configs, None).unwrap();

        let result = engine.reset_world(5, 0);
        assert!(matches!(
            result,
            Err(BatchError::InvalidIndex {
                world_index: 5,
                num_worlds: 1
            })
        ));
    }

    #[test]
    fn wrong_command_count_returns_error() {
        let configs = vec![make_config(0, 1.0), make_config(1, 1.0)];
        let mut engine = BatchedEngine::new(configs, None).unwrap();

        let result = engine.step_all(&[vec![]]); // 1 entry for 2 worlds
        assert!(matches!(result, Err(BatchError::InvalidArgument { .. })));

        // Rejected before stepping.
        assert_eq!(engine.world_tick(0), Some(TickId(0)));
        assert_eq!(engine.world_tick(1), Some(TickId(0)));
    }

    #[test]
    fn observe_without_plan_returns_error() {
        let configs = vec![make_config(0, 1.0)];
        let engine = BatchedEngine::new(configs, None).unwrap();

        let mut output = vec![0.0f32; 10];
        let mut mask = vec![0u8; 10];
        let result = engine.observe_all(&mut output, &mut mask);
        assert!(matches!(result, Err(BatchError::NoObsPlan)));
    }

    // ── Observe after reset ───────────────────────────────────

    #[test]
    fn observe_all_after_reset() {
        let spec = obs_spec_all_field0();
        let configs = vec![make_config(1, 42.0), make_config(2, 42.0)];
        let mut engine = BatchedEngine::new(configs, Some(&spec)).unwrap();

        // Step once to populate data.
        let commands = vec![vec![], vec![]];
        let n = engine.num_worlds();
        let mut output = vec![0.0f32; n * engine.obs_output_len()];
        let mut mask = vec![0u8; n * engine.obs_mask_len()];
        engine
            .step_and_observe(&commands, &mut output, &mut mask)
            .unwrap();

        // Reset all and observe the initial state.
        engine.reset_all(&[10, 20]).unwrap();
        let meta = engine.observe_all(&mut output, &mut mask).unwrap();
        assert_eq!(meta.len(), 2);
        assert_eq!(meta[0].tick_id, TickId(0));
        assert_eq!(meta[1].tick_id, TickId(0));
    }

    // ── Topology validation ──────────────────────────────────

    #[test]
    fn mixed_space_types_rejected() {
        use murk_space::Ring1D;

        // Line1D(10) and Ring1D(10): same ndim, same cell_count, different type.
        let line_config = make_config(1, 1.0);
        let ring_config = WorldConfig {
            space: Box::new(Ring1D::new(10).unwrap()),
            ..make_config(2, 1.0)
        };

        let result = BatchedEngine::new(vec![line_config, ring_config], None);
        let msg = format!("{}", expect_new_err(result, "mixed space types"));
        assert!(msg.contains("incompatible space topologies"), "got: {msg}");
    }

    #[test]
    fn mixed_edge_behaviors_rejected() {
        // Line1D(10, Absorb) and Line1D(10, Wrap): same TypeId, ndim, cell_count,
        // but different edge behavior — must be rejected.
        let absorb_config = make_config(1, 1.0);
        let wrap_config = WorldConfig {
            space: Box::new(Line1D::new(10, EdgeBehavior::Wrap).unwrap()),
            ..make_config(2, 1.0)
        };

        let result = BatchedEngine::new(vec![absorb_config, wrap_config], None);
        let msg = format!("{}", expect_new_err(result, "mixed edge behaviors"));
        assert!(msg.contains("incompatible space topologies"), "got: {msg}");
    }

    // ── Atomic step_and_observe ──────────────────────────────

    #[test]
    fn step_and_observe_no_plan_does_not_step() {
        // Without an obs plan, step_and_observe should fail *before*
        // advancing any world state.
        let configs = vec![make_config(0, 1.0), make_config(1, 1.0)];
        let mut engine = BatchedEngine::new(configs, None).unwrap();

        let commands = vec![vec![], vec![]];
        let mut output = vec![0.0f32; 20];
        let mut mask = vec![0u8; 20];
        let result = engine.step_and_observe(&commands, &mut output, &mut mask);
        assert!(matches!(result, Err(BatchError::NoObsPlan)));

        // Worlds must still be at tick 0 — no mutation occurred.
        assert_eq!(engine.world_tick(0), Some(TickId(0)));
        assert_eq!(engine.world_tick(1), Some(TickId(0)));
    }

    #[test]
    fn step_and_observe_small_buffer_does_not_step() {
        // Buffer too small should fail before advancing world state.
        let spec = obs_spec_all_field0();
        let configs = vec![make_config(0, 1.0), make_config(1, 1.0)];
        let mut engine = BatchedEngine::new(configs, Some(&spec)).unwrap();

        let commands = vec![vec![], vec![]];
        let mut output = vec![0.0f32; 5]; // need 20, only 5
        let mut mask = vec![0u8; 20];
        let result = engine.step_and_observe(&commands, &mut output, &mut mask);
        assert!(matches!(result, Err(BatchError::InvalidArgument { .. })));

        // Worlds must still be at tick 0.
        assert_eq!(engine.world_tick(0), Some(TickId(0)));
        assert_eq!(engine.world_tick(1), Some(TickId(0)));
    }

    #[test]
    fn step_and_observe_agent_relative_plan_does_not_step() {
        let spec = ObsSpec {
            entries: vec![ObsEntry {
                field_id: FieldId(0),
                region: ObsRegion::AgentRect {
                    half_extent: smallvec::smallvec![1, 1],
                },
                pool: None,
                transform: ObsTransform::Identity,
                dtype: ObsDtype::F32,
            }],
        };
        let configs = vec![make_grid_config(0, 1.0), make_grid_config(1, 1.0)];
        let mut engine = BatchedEngine::new(configs, Some(&spec)).unwrap();
        let n = engine.num_worlds();
        let mut output = vec![0.0f32; n * engine.obs_output_len()];
        let mut mask = vec![0u8; n * engine.obs_mask_len()];

        let result = engine.step_and_observe(&[vec![], vec![]], &mut output, &mut mask);
        match result {
            Err(BatchError::InvalidArgument { reason }) => {
                assert!(
                    reason.contains("AgentDisk/AgentRect"),
                    "unexpected reason: {reason}"
                );
            }
            _ => panic!("expected InvalidArgument for agent-relative plan"),
        }

        assert_eq!(engine.world_tick(0), Some(TickId(0)));
        assert_eq!(engine.world_tick(1), Some(TickId(0)));
    }

    // ── Field schema validation ─────────────────────────────

    #[test]
    fn obs_spec_referencing_missing_field_rejected() {
        // Both worlds lack FieldId(1), but obs spec references it.
        // Construction must fail — not silently pass and blow up at observe time.
        let spec = ObsSpec {
            entries: vec![
                ObsEntry {
                    field_id: FieldId(0),
                    region: ObsRegion::Fixed(RegionSpec::All),
                    pool: None,
                    transform: ObsTransform::Identity,
                    dtype: ObsDtype::F32,
                },
                ObsEntry {
                    field_id: FieldId(1), // missing in both worlds
                    region: ObsRegion::Fixed(RegionSpec::All),
                    pool: None,
                    transform: ObsTransform::Identity,
                    dtype: ObsDtype::F32,
                },
            ],
        };

        // Both worlds only have FieldId(0)
        let configs = vec![make_config(1, 1.0), make_config(2, 1.0)];
        let result = BatchedEngine::new(configs, Some(&spec));
        let msg = format!(
            "{}",
            expect_new_err(
                result,
                "obs spec referencing field missing from all worlds"
            )
        );
        assert!(
            msg.contains("missing"),
            "error should mention missing field, got: {msg}"
        );
    }

    #[test]
    fn obs_spec_referencing_missing_field_single_world_rejected() {
        // Single world lacks FieldId(1), obs spec references it.
        // The cross-world loop is skipped (only 1 world), so the
        // ref_len check must still catch this.
        let spec = ObsSpec {
            entries: vec![ObsEntry {
                field_id: FieldId(1), // missing
                region: ObsRegion::Fixed(RegionSpec::All),
                pool: None,
                transform: ObsTransform::Identity,
                dtype: ObsDtype::F32,
            }],
        };

        let configs = vec![make_config(1, 1.0)]; // only has FieldId(0)
        let result = BatchedEngine::new(configs, Some(&spec));
        assert!(
            matches!(result, Err(BatchError::InvalidArgument { .. })),
            "expected error for obs spec referencing field missing from single world"
        );
    }

    #[test]
    fn mismatched_field_schemas_rejected() {
        // World 0 has 2 fields, world 1 has only 1. Obs spec references
        // FieldId(1) which is missing in world 1. Construction must fail.
        let spec = ObsSpec {
            entries: vec![
                ObsEntry {
                    field_id: FieldId(0),
                    region: ObsRegion::Fixed(RegionSpec::All),
                    pool: None,
                    transform: ObsTransform::Identity,
                    dtype: ObsDtype::F32,
                },
                ObsEntry {
                    field_id: FieldId(1),
                    region: ObsRegion::Fixed(RegionSpec::All),
                    pool: None,
                    transform: ObsTransform::Identity,
                    dtype: ObsDtype::F32,
                },
            ],
        };

        // World 0: has 2 fields (FieldId(0) and FieldId(1))
        let config_two_fields = WorldConfig {
            fields: vec![scalar_field("energy"), scalar_field("temp")],
            ..make_config(1, 1.0)
        };

        // World 1: has only 1 field (FieldId(0)), missing FieldId(1)
        let config_one_field = make_config(2, 1.0);

        let result = BatchedEngine::new(vec![config_two_fields, config_one_field], Some(&spec));
        let msg = format!("{}", expect_new_err(result, "mismatched field schemas"));
        assert!(
            msg.contains("field") && msg.contains("missing"),
            "error should mention missing field, got: {msg}"
        );
    }
}