Skip to main content

murk_engine/
batched.rs

1//! Batched simulation engine for vectorized RL training.
2//!
3//! [`BatchedEngine`] owns N [`LockstepWorld`]s and steps them all in a
4//! single call, eliminating per-world FFI overhead. Observation extraction
5//! uses [`ObsPlan::execute_batch()`] to fill a contiguous output buffer
6//! across all worlds.
7//!
8//! # Design
9//!
10//! The hot path is `step_and_observe`: step all worlds sequentially, then
11//! extract observations in batch. The GIL is released once at the Python
12//! layer, covering the entire operation. This reduces 2N GIL cycles (the
13//! current `MurkVecEnv` approach) to exactly 1.
14//!
15//! Parallelism (rayon) is deferred to v2. The GIL elimination alone is
16//! the dominant win; adding `par_iter_mut` later is a 3-line change.
17
18use murk_core::command::Command;
19use murk_core::error::ObsError;
20use murk_core::id::TickId;
21use murk_core::traits::SnapshotAccess;
22use murk_obs::metadata::ObsMetadata;
23use murk_obs::plan::ObsPlan;
24use murk_obs::spec::ObsSpec;
25
26use crate::config::{ConfigError, WorldConfig};
27use crate::lockstep::LockstepWorld;
28use crate::metrics::StepMetrics;
29use crate::tick::TickError;
30
31// ── Error type ──────────────────────────────────────────────────
32
33/// Error from a batched operation, annotated with the failing world index.
34#[derive(Debug, PartialEq)]
35pub enum BatchError {
36    /// A world's `step_sync()` failed.
37    Step {
38        /// Index of the world that failed (0-based).
39        world_index: usize,
40        /// The underlying tick error.
41        error: TickError,
42    },
43    /// Observation extraction failed.
44    Observe(ObsError),
45    /// Configuration error during construction or reset.
46    Config(ConfigError),
47    /// World index out of bounds.
48    InvalidIndex {
49        /// The requested index.
50        world_index: usize,
51        /// Total number of worlds.
52        num_worlds: usize,
53    },
54    /// No observation plan was compiled (called observe without obs_spec).
55    NoObsPlan,
56    /// Batch-level argument validation failed.
57    InvalidArgument {
58        /// Human-readable description of what's wrong.
59        reason: String,
60    },
61}
62
63impl std::fmt::Display for BatchError {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        match self {
66            BatchError::Step { world_index, error } => {
67                write!(f, "world {world_index}: step failed: {error:?}")
68            }
69            BatchError::Observe(e) => write!(f, "observe failed: {e:?}"),
70            BatchError::Config(e) => write!(f, "config error: {e:?}"),
71            BatchError::InvalidIndex {
72                world_index,
73                num_worlds,
74            } => write!(
75                f,
76                "world index {world_index} out of range (num_worlds={num_worlds})"
77            ),
78            BatchError::NoObsPlan => write!(f, "no observation plan compiled"),
79            BatchError::InvalidArgument { reason } => {
80                write!(f, "invalid argument: {reason}")
81            }
82        }
83    }
84}
85
86impl std::error::Error for BatchError {}
87
88// ── Result type ─────────────────────────────────────────────────
89
90/// Result of stepping a batch of worlds.
91pub struct BatchResult {
92    /// Per-world tick IDs after stepping.
93    pub tick_ids: Vec<TickId>,
94    /// Per-world step metrics.
95    pub metrics: Vec<StepMetrics>,
96}
97
98// ── BatchedEngine ───────────────────────────────────────────────
99
100/// Batched simulation engine owning N lockstep worlds.
101///
102/// Created from N [`WorldConfig`]s with an optional [`ObsSpec`].
103/// All worlds must share the same space topology (validated at
104/// construction).
105///
106/// The primary interface is [`step_and_observe()`](Self::step_and_observe):
107/// step all worlds, then extract observations into a contiguous buffer
108/// using [`ObsPlan::execute_batch()`].
109pub struct BatchedEngine {
110    worlds: Vec<LockstepWorld>,
111    obs_plan: Option<ObsPlan>,
112    obs_output_len: usize,
113    obs_mask_len: usize,
114}
115
116impl BatchedEngine {
117    /// Create a batched engine from N world configs.
118    ///
119    /// If `obs_spec` is provided, compiles an [`ObsPlan`] from the first
120    /// world's space. All worlds must have the same `cell_count`
121    /// (defensive check).
122    ///
123    /// # Errors
124    ///
125    /// Returns [`BatchError::Config`] if any world fails to construct,
126    /// or [`BatchError::Observe`] if the obs plan fails to compile.
127    pub fn new(configs: Vec<WorldConfig>, obs_spec: Option<&ObsSpec>) -> Result<Self, BatchError> {
128        if configs.is_empty() {
129            return Err(BatchError::InvalidArgument {
130                reason: "BatchedEngine requires at least one world config".into(),
131            });
132        }
133
134        let mut worlds = Vec::with_capacity(configs.len());
135        for config in configs {
136            let world = LockstepWorld::new(config).map_err(BatchError::Config)?;
137            worlds.push(world);
138        }
139
140        // Validate all worlds share the same space topology.
141        // topology_eq checks TypeId, dimensions, and behavioral parameters
142        // (e.g. EdgeBehavior) so that spaces like Line1D(10, Absorb) and
143        // Line1D(10, Wrap) are correctly rejected.
144        let ref_space = worlds[0].space();
145        for (i, world) in worlds.iter().enumerate().skip(1) {
146            if !ref_space.topology_eq(world.space()) {
147                return Err(BatchError::InvalidArgument {
148                    reason: format!(
149                        "world 0 and world {i} have incompatible space topologies; \
150                         all worlds in a batch must use the same topology"
151                    ),
152                });
153            }
154        }
155
156        // Compile obs plan if spec provided.
157        let (obs_plan, obs_output_len, obs_mask_len) = match obs_spec {
158            Some(spec) => {
159                let result =
160                    ObsPlan::compile(spec, worlds[0].space()).map_err(BatchError::Observe)?;
161
162                // Validate all worlds have matching field schemas for observed fields.
163                // ObsPlan::compile only takes a Space (not a snapshot), so field
164                // existence isn't checked until execute(). Catching mismatches here
165                // prevents late observation failures after worlds have been stepped.
166                let ref_snap = worlds[0].snapshot();
167                for entry in &spec.entries {
168                    let fid = entry.field_id;
169                    let ref_len = ref_snap.read_field(fid).map(|d| d.len());
170                    for (i, world) in worlds.iter().enumerate().skip(1) {
171                        let snap = world.snapshot();
172                        let other_len = snap.read_field(fid).map(|d| d.len());
173                        if other_len != ref_len {
174                            return Err(BatchError::InvalidArgument {
175                                reason: format!(
176                                    "world {i} field {fid:?}: {} elements, \
177                                     world 0 has {} elements; \
178                                     all worlds must share the same field schema",
179                                    other_len
180                                        .map(|n| n.to_string())
181                                        .unwrap_or_else(|| "missing".into()),
182                                    ref_len
183                                        .map(|n| n.to_string())
184                                        .unwrap_or_else(|| "missing".into()),
185                                ),
186                            });
187                        }
188                    }
189                }
190
191                (Some(result.plan), result.output_len, result.mask_len)
192            }
193            None => (None, 0, 0),
194        };
195
196        Ok(BatchedEngine {
197            worlds,
198            obs_plan,
199            obs_output_len,
200            obs_mask_len,
201        })
202    }
203
204    /// Step all worlds and extract observations in one call.
205    ///
206    /// `commands` must have exactly `num_worlds()` entries.
207    /// `output` must have at least `num_worlds() * obs_output_len()` elements.
208    /// `mask` must have at least `num_worlds() * obs_mask_len()` bytes.
209    ///
210    /// Returns per-world tick IDs and metrics.
211    pub fn step_and_observe(
212        &mut self,
213        commands: &[Vec<Command>],
214        output: &mut [f32],
215        mask: &mut [u8],
216    ) -> Result<BatchResult, BatchError> {
217        // Pre-flight: validate observation preconditions before mutating
218        // world state. Without this, a late observe failure (no obs plan,
219        // buffer too small) would leave worlds stepped but observations
220        // unextracted — making the error non-atomic.
221        self.validate_observe_buffers(output, mask)?;
222
223        let result = self.step_all(commands)?;
224
225        // Observe phase: borrow worlds immutably for snapshot collection.
226        self.observe_all_inner(output, mask)?;
227
228        Ok(result)
229    }
230
231    /// Step all worlds without observation extraction.
232    pub fn step_all(&mut self, commands: &[Vec<Command>]) -> Result<BatchResult, BatchError> {
233        let n = self.worlds.len();
234        if commands.len() != n {
235            return Err(BatchError::InvalidArgument {
236                reason: format!("commands has {} entries, expected {n}", commands.len()),
237            });
238        }
239
240        let mut tick_ids = Vec::with_capacity(n);
241        let mut metrics = Vec::with_capacity(n);
242
243        for (idx, world) in self.worlds.iter_mut().enumerate() {
244            let result = world
245                .step_sync(commands[idx].clone())
246                .map_err(|e| BatchError::Step {
247                    world_index: idx,
248                    error: e,
249                })?;
250            tick_ids.push(result.snapshot.tick_id());
251            metrics.push(result.metrics);
252        }
253
254        Ok(BatchResult { tick_ids, metrics })
255    }
256
257    /// Extract observations from all worlds without stepping.
258    ///
259    /// Used after `reset_all()` to get initial observations.
260    pub fn observe_all(
261        &self,
262        output: &mut [f32],
263        mask: &mut [u8],
264    ) -> Result<Vec<ObsMetadata>, BatchError> {
265        self.observe_all_inner(output, mask)
266    }
267
268    /// Internal observation extraction shared by step_and_observe and observe_all.
269    fn observe_all_inner(
270        &self,
271        output: &mut [f32],
272        mask: &mut [u8],
273    ) -> Result<Vec<ObsMetadata>, BatchError> {
274        let plan = self.obs_plan.as_ref().ok_or(BatchError::NoObsPlan)?;
275
276        let snapshots: Vec<_> = self.worlds.iter().map(|w| w.snapshot()).collect();
277        let snap_refs: Vec<&dyn SnapshotAccess> =
278            snapshots.iter().map(|s| s as &dyn SnapshotAccess).collect();
279
280        plan.execute_batch(&snap_refs, None, output, mask)
281            .map_err(BatchError::Observe)
282    }
283
284    /// Validate that observation preconditions are met (plan exists, buffers
285    /// large enough) without performing any mutation. Called by
286    /// `step_and_observe` before `step_all` to guarantee atomicity.
287    fn validate_observe_buffers(&self, output: &[f32], mask: &[u8]) -> Result<(), BatchError> {
288        let plan = self.obs_plan.as_ref().ok_or(BatchError::NoObsPlan)?;
289        if plan.is_standard() {
290            return Err(BatchError::InvalidArgument {
291                reason: "obs spec uses agent-relative regions (AgentDisk/AgentRect), \
292                         which are unsupported in batched step_and_observe"
293                    .into(),
294            });
295        }
296        let n = self.worlds.len();
297        let expected_out = n * self.obs_output_len;
298        let expected_mask = n * self.obs_mask_len;
299        if output.len() < expected_out {
300            return Err(BatchError::InvalidArgument {
301                reason: format!("output buffer too small: {} < {expected_out}", output.len()),
302            });
303        }
304        if mask.len() < expected_mask {
305            return Err(BatchError::InvalidArgument {
306                reason: format!("mask buffer too small: {} < {expected_mask}", mask.len()),
307            });
308        }
309        Ok(())
310    }
311
312    /// Reset a single world by index.
313    pub fn reset_world(&mut self, idx: usize, seed: u64) -> Result<(), BatchError> {
314        let n = self.worlds.len();
315        let world = self.worlds.get_mut(idx).ok_or(BatchError::InvalidIndex {
316            world_index: idx,
317            num_worlds: n,
318        })?;
319        world.reset(seed).map_err(BatchError::Config)?;
320        Ok(())
321    }
322
323    /// Reset all worlds with per-world seeds.
324    pub fn reset_all(&mut self, seeds: &[u64]) -> Result<(), BatchError> {
325        let n = self.worlds.len();
326        if seeds.len() != n {
327            return Err(BatchError::InvalidArgument {
328                reason: format!("seeds has {} entries, expected {n}", seeds.len()),
329            });
330        }
331        for (idx, world) in self.worlds.iter_mut().enumerate() {
332            world.reset(seeds[idx]).map_err(BatchError::Config)?;
333        }
334        Ok(())
335    }
336
337    /// Number of worlds in the batch.
338    pub fn num_worlds(&self) -> usize {
339        self.worlds.len()
340    }
341
342    /// Per-world observation output length (f32 elements).
343    pub fn obs_output_len(&self) -> usize {
344        self.obs_output_len
345    }
346
347    /// Per-world observation mask length (bytes).
348    pub fn obs_mask_len(&self) -> usize {
349        self.obs_mask_len
350    }
351
352    /// Current tick ID of a specific world.
353    pub fn world_tick(&self, idx: usize) -> Option<TickId> {
354        self.worlds.get(idx).map(|w| w.current_tick())
355    }
356}
357
358#[cfg(test)]
359mod tests {
360    use super::*;
361    use murk_core::id::FieldId;
362    use murk_core::traits::FieldReader;
363    use murk_obs::spec::{ObsDtype, ObsEntry, ObsRegion, ObsTransform};
364    use murk_space::{EdgeBehavior, Line1D, RegionSpec, Square4};
365    use murk_test_utils::ConstPropagator;
366
367    use crate::config::BackoffConfig;
368
369    fn scalar_field(name: &str) -> murk_core::FieldDef {
370        murk_core::FieldDef {
371            name: name.to_string(),
372            field_type: murk_core::FieldType::Scalar,
373            mutability: murk_core::FieldMutability::PerTick,
374            units: None,
375            bounds: None,
376            boundary_behavior: murk_core::BoundaryBehavior::Clamp,
377        }
378    }
379
380    fn make_config(seed: u64, value: f32) -> WorldConfig {
381        WorldConfig {
382            space: Box::new(Line1D::new(10, EdgeBehavior::Absorb).unwrap()),
383            fields: vec![scalar_field("energy")],
384            propagators: vec![Box::new(ConstPropagator::new("const", FieldId(0), value))],
385            dt: 0.1,
386            seed,
387            ring_buffer_size: 8,
388            max_ingress_queue: 1024,
389            tick_rate_hz: None,
390            backoff: BackoffConfig::default(),
391        }
392    }
393
394    fn make_grid_config(seed: u64, value: f32) -> WorldConfig {
395        WorldConfig {
396            space: Box::new(Square4::new(4, 4, EdgeBehavior::Absorb).unwrap()),
397            fields: vec![scalar_field("energy")],
398            propagators: vec![Box::new(ConstPropagator::new("const", FieldId(0), value))],
399            dt: 0.1,
400            seed,
401            ring_buffer_size: 8,
402            max_ingress_queue: 1024,
403            tick_rate_hz: None,
404            backoff: BackoffConfig::default(),
405        }
406    }
407
408    fn obs_spec_all_field0() -> ObsSpec {
409        ObsSpec {
410            entries: vec![ObsEntry {
411                field_id: FieldId(0),
412                region: ObsRegion::Fixed(RegionSpec::All),
413                pool: None,
414                transform: ObsTransform::Identity,
415                dtype: ObsDtype::F32,
416            }],
417        }
418    }
419
420    // ── Construction tests ────────────────────────────────────
421
422    #[test]
423    fn new_single_world() {
424        let configs = vec![make_config(42, 1.0)];
425        let engine = BatchedEngine::new(configs, None).unwrap();
426        assert_eq!(engine.num_worlds(), 1);
427        assert_eq!(engine.obs_output_len(), 0);
428        assert_eq!(engine.obs_mask_len(), 0);
429    }
430
431    #[test]
432    fn new_four_worlds() {
433        let configs: Vec<_> = (0..4).map(|i| make_config(i, 1.0)).collect();
434        let engine = BatchedEngine::new(configs, None).unwrap();
435        assert_eq!(engine.num_worlds(), 4);
436    }
437
438    #[test]
439    fn new_zero_worlds_is_error() {
440        let result = BatchedEngine::new(vec![], None);
441        assert!(result.is_err());
442    }
443
444    #[test]
445    fn new_with_obs_spec() {
446        let configs = vec![make_config(42, 1.0)];
447        let spec = obs_spec_all_field0();
448        let engine = BatchedEngine::new(configs, Some(&spec)).unwrap();
449        assert_eq!(engine.obs_output_len(), 10); // Line1D(10) → 10 cells
450        assert_eq!(engine.obs_mask_len(), 10);
451    }
452
453    // ── Determinism test ──────────────────────────────────────
454
455    #[test]
456    fn batch_matches_independent_worlds() {
457        let spec = obs_spec_all_field0();
458
459        // Batched: 2 worlds
460        let configs = vec![make_config(42, 42.0), make_config(99, 42.0)];
461        let mut batched = BatchedEngine::new(configs, Some(&spec)).unwrap();
462        let n = batched.num_worlds();
463        let out_len = n * batched.obs_output_len();
464        let mask_len = n * batched.obs_mask_len();
465        let mut batch_output = vec![0.0f32; out_len];
466        let mut batch_mask = vec![0u8; mask_len];
467
468        let commands = vec![vec![], vec![]];
469        batched
470            .step_and_observe(&commands, &mut batch_output, &mut batch_mask)
471            .unwrap();
472
473        // Independent: 2 separate worlds
474        let mut w0 = LockstepWorld::new(make_config(42, 42.0)).unwrap();
475        let mut w1 = LockstepWorld::new(make_config(99, 42.0)).unwrap();
476        let r0 = w0.step_sync(vec![]).unwrap();
477        let r1 = w1.step_sync(vec![]).unwrap();
478
479        let d0 = r0.snapshot.read(FieldId(0)).unwrap();
480        let d1 = r1.snapshot.read(FieldId(0)).unwrap();
481
482        // Batch output should be [world0_obs | world1_obs]
483        assert_eq!(&batch_output[..10], d0);
484        assert_eq!(&batch_output[10..20], d1);
485    }
486
487    // ── Observation correctness ───────────────────────────────
488
489    #[test]
490    fn observation_filled_with_const_value() {
491        let spec = obs_spec_all_field0();
492        let configs = vec![
493            make_config(1, 42.0),
494            make_config(2, 42.0),
495            make_config(3, 42.0),
496        ];
497        let mut engine = BatchedEngine::new(configs, Some(&spec)).unwrap();
498
499        let commands = vec![vec![], vec![], vec![]];
500        let n = engine.num_worlds();
501        let mut output = vec![0.0f32; n * engine.obs_output_len()];
502        let mut mask = vec![0u8; n * engine.obs_mask_len()];
503        engine
504            .step_and_observe(&commands, &mut output, &mut mask)
505            .unwrap();
506
507        // All cells should be 42.0 for all worlds.
508        assert!(output.iter().all(|&v| v == 42.0));
509        assert!(mask.iter().all(|&m| m == 1));
510    }
511
512    // ── Reset tests ───────────────────────────────────────────
513
514    #[test]
515    fn reset_single_world_preserves_others() {
516        let configs: Vec<_> = (0..4).map(|i| make_config(i, 1.0)).collect();
517        let mut engine = BatchedEngine::new(configs, None).unwrap();
518
519        // Step all once.
520        let commands = vec![vec![]; 4];
521        engine.step_all(&commands).unwrap();
522        assert_eq!(engine.world_tick(0), Some(TickId(1)));
523        assert_eq!(engine.world_tick(3), Some(TickId(1)));
524
525        // Reset only world 0.
526        engine.reset_world(0, 999).unwrap();
527        assert_eq!(engine.world_tick(0), Some(TickId(0)));
528        assert_eq!(engine.world_tick(1), Some(TickId(1)));
529        assert_eq!(engine.world_tick(2), Some(TickId(1)));
530        assert_eq!(engine.world_tick(3), Some(TickId(1)));
531    }
532
533    #[test]
534    fn reset_all_resets_to_tick_zero() {
535        let configs: Vec<_> = (0..3).map(|i| make_config(i, 1.0)).collect();
536        let mut engine = BatchedEngine::new(configs, None).unwrap();
537
538        // Step all twice.
539        let commands = vec![vec![]; 3];
540        engine.step_all(&commands).unwrap();
541        engine.step_all(&commands).unwrap();
542
543        engine.reset_all(&[10, 20, 30]).unwrap();
544        for i in 0..3 {
545            assert_eq!(engine.world_tick(i), Some(TickId(0)));
546        }
547    }
548
549    // ── Error isolation ───────────────────────────────────────
550
551    #[test]
552    fn invalid_world_index_returns_error() {
553        let configs = vec![make_config(0, 1.0)];
554        let mut engine = BatchedEngine::new(configs, None).unwrap();
555
556        let result = engine.reset_world(5, 0);
557        assert!(matches!(result, Err(BatchError::InvalidIndex { .. })));
558    }
559
560    #[test]
561    fn wrong_command_count_returns_error() {
562        let configs = vec![make_config(0, 1.0), make_config(1, 1.0)];
563        let mut engine = BatchedEngine::new(configs, None).unwrap();
564
565        let result = engine.step_all(&[vec![]]); // 1 entry for 2 worlds
566        assert!(result.is_err());
567    }
568
569    #[test]
570    fn observe_without_plan_returns_error() {
571        let configs = vec![make_config(0, 1.0)];
572        let engine = BatchedEngine::new(configs, None).unwrap();
573
574        let mut output = vec![0.0f32; 10];
575        let mut mask = vec![0u8; 10];
576        let result = engine.observe_all(&mut output, &mut mask);
577        assert!(matches!(result, Err(BatchError::NoObsPlan)));
578    }
579
580    // ── Observe after reset ───────────────────────────────────
581
582    #[test]
583    fn observe_all_after_reset() {
584        let spec = obs_spec_all_field0();
585        let configs = vec![make_config(1, 42.0), make_config(2, 42.0)];
586        let mut engine = BatchedEngine::new(configs, Some(&spec)).unwrap();
587
588        // Step once to populate data.
589        let commands = vec![vec![], vec![]];
590        let n = engine.num_worlds();
591        let mut output = vec![0.0f32; n * engine.obs_output_len()];
592        let mut mask = vec![0u8; n * engine.obs_mask_len()];
593        engine
594            .step_and_observe(&commands, &mut output, &mut mask)
595            .unwrap();
596
597        // Reset all and observe (initial state is zeroed).
598        engine.reset_all(&[10, 20]).unwrap();
599        let meta = engine.observe_all(&mut output, &mut mask).unwrap();
600        assert_eq!(meta.len(), 2);
601        assert_eq!(meta[0].tick_id, TickId(0));
602        assert_eq!(meta[1].tick_id, TickId(0));
603    }
604
605    // ── Topology validation ──────────────────────────────────
606
607    #[test]
608    fn mixed_space_types_rejected() {
609        use murk_space::Ring1D;
610
611        // Line1D(10) and Ring1D(10): same ndim, same cell_count, different type.
612        let line_config = WorldConfig {
613            space: Box::new(Line1D::new(10, EdgeBehavior::Absorb).unwrap()),
614            fields: vec![scalar_field("energy")],
615            propagators: vec![Box::new(ConstPropagator::new("const", FieldId(0), 1.0))],
616            dt: 0.1,
617            seed: 1,
618            ring_buffer_size: 8,
619            max_ingress_queue: 1024,
620            tick_rate_hz: None,
621            backoff: BackoffConfig::default(),
622        };
623        let ring_config = WorldConfig {
624            space: Box::new(Ring1D::new(10).unwrap()),
625            fields: vec![scalar_field("energy")],
626            propagators: vec![Box::new(ConstPropagator::new("const", FieldId(0), 1.0))],
627            dt: 0.1,
628            seed: 2,
629            ring_buffer_size: 8,
630            max_ingress_queue: 1024,
631            tick_rate_hz: None,
632            backoff: BackoffConfig::default(),
633        };
634
635        let result = BatchedEngine::new(vec![line_config, ring_config], None);
636        match result {
637            Err(e) => {
638                let msg = format!("{e}");
639                assert!(msg.contains("incompatible space topologies"), "got: {msg}");
640            }
641            Ok(_) => panic!("expected error for mixed space types"),
642        }
643    }
644
645    #[test]
646    fn mixed_edge_behaviors_rejected() {
647        // Line1D(10, Absorb) and Line1D(10, Wrap): same TypeId, ndim, cell_count,
648        // but different edge behavior — must be rejected.
649        let absorb_config = WorldConfig {
650            space: Box::new(Line1D::new(10, EdgeBehavior::Absorb).unwrap()),
651            fields: vec![scalar_field("energy")],
652            propagators: vec![Box::new(ConstPropagator::new("const", FieldId(0), 1.0))],
653            dt: 0.1,
654            seed: 1,
655            ring_buffer_size: 8,
656            max_ingress_queue: 1024,
657            tick_rate_hz: None,
658            backoff: BackoffConfig::default(),
659        };
660        let wrap_config = WorldConfig {
661            space: Box::new(Line1D::new(10, EdgeBehavior::Wrap).unwrap()),
662            fields: vec![scalar_field("energy")],
663            propagators: vec![Box::new(ConstPropagator::new("const", FieldId(0), 1.0))],
664            dt: 0.1,
665            seed: 2,
666            ring_buffer_size: 8,
667            max_ingress_queue: 1024,
668            tick_rate_hz: None,
669            backoff: BackoffConfig::default(),
670        };
671
672        let result = BatchedEngine::new(vec![absorb_config, wrap_config], None);
673        assert!(result.is_err(), "expected error for mixed edge behaviors");
674    }
675
676    // ── Atomic step_and_observe ──────────────────────────────
677
678    #[test]
679    fn step_and_observe_no_plan_does_not_step() {
680        // Without an obs plan, step_and_observe should fail *before*
681        // advancing any world state.
682        let configs = vec![make_config(0, 1.0), make_config(1, 1.0)];
683        let mut engine = BatchedEngine::new(configs, None).unwrap();
684
685        let commands = vec![vec![], vec![]];
686        let mut output = vec![0.0f32; 20];
687        let mut mask = vec![0u8; 20];
688        let result = engine.step_and_observe(&commands, &mut output, &mut mask);
689        assert!(matches!(result, Err(BatchError::NoObsPlan)));
690
691        // Worlds must still be at tick 0 — no mutation occurred.
692        assert_eq!(engine.world_tick(0), Some(TickId(0)));
693        assert_eq!(engine.world_tick(1), Some(TickId(0)));
694    }
695
696    #[test]
697    fn step_and_observe_small_buffer_does_not_step() {
698        // Buffer too small should fail before advancing world state.
699        let spec = obs_spec_all_field0();
700        let configs = vec![make_config(0, 1.0), make_config(1, 1.0)];
701        let mut engine = BatchedEngine::new(configs, Some(&spec)).unwrap();
702
703        let commands = vec![vec![], vec![]];
704        let mut output = vec![0.0f32; 5]; // need 20, only 5
705        let mut mask = vec![0u8; 20];
706        let result = engine.step_and_observe(&commands, &mut output, &mut mask);
707        assert!(result.is_err());
708
709        // Worlds must still be at tick 0.
710        assert_eq!(engine.world_tick(0), Some(TickId(0)));
711        assert_eq!(engine.world_tick(1), Some(TickId(0)));
712    }
713
714    #[test]
715    fn step_and_observe_agent_relative_plan_does_not_step() {
716        let spec = ObsSpec {
717            entries: vec![ObsEntry {
718                field_id: FieldId(0),
719                region: ObsRegion::AgentRect {
720                    half_extent: smallvec::smallvec![1, 1],
721                },
722                pool: None,
723                transform: ObsTransform::Identity,
724                dtype: ObsDtype::F32,
725            }],
726        };
727        let configs = vec![make_grid_config(0, 1.0), make_grid_config(1, 1.0)];
728        let mut engine = BatchedEngine::new(configs, Some(&spec)).unwrap();
729        let n = engine.num_worlds();
730        let mut output = vec![0.0f32; n * engine.obs_output_len()];
731        let mut mask = vec![0u8; n * engine.obs_mask_len()];
732
733        let result = engine.step_and_observe(&[vec![], vec![]], &mut output, &mut mask);
734        match result {
735            Err(BatchError::InvalidArgument { reason }) => {
736                assert!(
737                    reason.contains("AgentDisk/AgentRect"),
738                    "unexpected reason: {reason}"
739                );
740            }
741            _ => panic!("expected InvalidArgument for agent-relative plan"),
742        }
743
744        assert_eq!(engine.world_tick(0), Some(TickId(0)));
745        assert_eq!(engine.world_tick(1), Some(TickId(0)));
746    }
747
748    // ── Field schema validation ─────────────────────────────
749
750    #[test]
751    fn mismatched_field_schemas_rejected() {
752        // World 0 has 2 fields, world 1 has only 1. Obs spec references
753        // FieldId(1) which is missing in world 1. Construction must fail.
754        let spec = ObsSpec {
755            entries: vec![
756                ObsEntry {
757                    field_id: FieldId(0),
758                    region: ObsRegion::Fixed(RegionSpec::All),
759                    pool: None,
760                    transform: ObsTransform::Identity,
761                    dtype: ObsDtype::F32,
762                },
763                ObsEntry {
764                    field_id: FieldId(1),
765                    region: ObsRegion::Fixed(RegionSpec::All),
766                    pool: None,
767                    transform: ObsTransform::Identity,
768                    dtype: ObsDtype::F32,
769                },
770            ],
771        };
772
773        // World 0: has 2 fields (FieldId(0) and FieldId(1))
774        let config_two_fields = WorldConfig {
775            space: Box::new(Line1D::new(10, EdgeBehavior::Absorb).unwrap()),
776            fields: vec![scalar_field("energy"), scalar_field("temp")],
777            propagators: vec![Box::new(ConstPropagator::new("const", FieldId(0), 1.0))],
778            dt: 0.1,
779            seed: 1,
780            ring_buffer_size: 8,
781            max_ingress_queue: 1024,
782            tick_rate_hz: None,
783            backoff: BackoffConfig::default(),
784        };
785
786        // World 1: has only 1 field (FieldId(0)), missing FieldId(1)
787        let config_one_field = WorldConfig {
788            space: Box::new(Line1D::new(10, EdgeBehavior::Absorb).unwrap()),
789            fields: vec![scalar_field("energy")],
790            propagators: vec![Box::new(ConstPropagator::new("const", FieldId(0), 1.0))],
791            dt: 0.1,
792            seed: 2,
793            ring_buffer_size: 8,
794            max_ingress_queue: 1024,
795            tick_rate_hz: None,
796            backoff: BackoffConfig::default(),
797        };
798
799        let result = BatchedEngine::new(vec![config_two_fields, config_one_field], Some(&spec));
800        match result {
801            Err(e) => {
802                let msg = format!("{e}");
803                assert!(
804                    msg.contains("field") && msg.contains("missing"),
805                    "error should mention missing field, got: {msg}"
806                );
807            }
808            Ok(_) => panic!("expected error for mismatched field schemas"),
809        }
810    }
811}