Skip to main content

murk_engine/
batched.rs

1//! Batched simulation engine for vectorized RL training.
2//!
3//! [`BatchedEngine`] owns N [`LockstepWorld`]s and steps them all in a
4//! single call, eliminating per-world FFI overhead. Observation extraction
5//! uses [`ObsPlan::execute_batch()`] to fill a contiguous output buffer
6//! across all worlds.
7//!
8//! # Design
9//!
10//! The hot path is `step_and_observe`: step all worlds sequentially, then
11//! extract observations in batch. The GIL is released once at the Python
12//! layer, covering the entire operation. This reduces 2N GIL cycles (the
13//! current `MurkVecEnv` approach) to exactly 1.
14//!
15//! Parallelism (rayon) is deferred to v2. The GIL elimination alone is
16//! the dominant win; adding `par_iter_mut` later is a 3-line change.
17
18use murk_core::command::Command;
19use murk_core::error::ObsError;
20use murk_core::id::TickId;
21use murk_core::traits::SnapshotAccess;
22use murk_obs::metadata::ObsMetadata;
23use murk_obs::plan::ObsPlan;
24use murk_obs::spec::ObsSpec;
25
26use crate::config::{ConfigError, WorldConfig};
27use crate::lockstep::LockstepWorld;
28use crate::metrics::StepMetrics;
29use crate::tick::TickError;
30
31// ── Error type ──────────────────────────────────────────────────
32
33/// Error from a batched operation, annotated with the failing world index.
34#[derive(Debug, PartialEq)]
35pub enum BatchError {
36    /// A world's `step_sync()` failed.
37    Step {
38        /// Index of the world that failed (0-based).
39        world_index: usize,
40        /// The underlying tick error.
41        error: TickError,
42    },
43    /// Observation extraction failed.
44    Observe(ObsError),
45    /// Configuration error during construction or reset.
46    Config(ConfigError),
47    /// World index out of bounds.
48    InvalidIndex {
49        /// The requested index.
50        world_index: usize,
51        /// Total number of worlds.
52        num_worlds: usize,
53    },
54    /// No observation plan was compiled (called observe without obs_spec).
55    NoObsPlan,
56    /// Batch-level argument validation failed.
57    InvalidArgument {
58        /// Human-readable description of what's wrong.
59        reason: String,
60    },
61}
62
63impl std::fmt::Display for BatchError {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        match self {
66            BatchError::Step { world_index, error } => {
67                write!(f, "world {world_index}: step failed: {error:?}")
68            }
69            BatchError::Observe(e) => write!(f, "observe failed: {e:?}"),
70            BatchError::Config(e) => write!(f, "config error: {e:?}"),
71            BatchError::InvalidIndex {
72                world_index,
73                num_worlds,
74            } => write!(
75                f,
76                "world index {world_index} out of range (num_worlds={num_worlds})"
77            ),
78            BatchError::NoObsPlan => write!(f, "no observation plan compiled"),
79            BatchError::InvalidArgument { reason } => {
80                write!(f, "invalid argument: {reason}")
81            }
82        }
83    }
84}
85
86impl std::error::Error for BatchError {}
87
88// ── Result type ─────────────────────────────────────────────────
89
90/// Result of stepping a batch of worlds.
91pub struct BatchResult {
92    /// Per-world tick IDs after stepping.
93    pub tick_ids: Vec<TickId>,
94    /// Per-world step metrics.
95    pub metrics: Vec<StepMetrics>,
96}
97
98// ── BatchedEngine ───────────────────────────────────────────────
99
100/// Batched simulation engine owning N lockstep worlds.
101///
102/// Created from N [`WorldConfig`]s with an optional [`ObsSpec`].
103/// All worlds must share the same space topology (validated at
104/// construction).
105///
106/// The primary interface is [`step_and_observe()`](Self::step_and_observe):
107/// step all worlds, then extract observations into a contiguous buffer
108/// using [`ObsPlan::execute_batch()`].
109pub struct BatchedEngine {
110    worlds: Vec<LockstepWorld>,
111    obs_plan: Option<ObsPlan>,
112    obs_output_len: usize,
113    obs_mask_len: usize,
114}
115
116impl BatchedEngine {
117    /// Create a batched engine from N world configs.
118    ///
119    /// If `obs_spec` is provided, compiles an [`ObsPlan`] from the first
120    /// world's space. All worlds must have the same `cell_count`
121    /// (defensive check).
122    ///
123    /// # Errors
124    ///
125    /// Returns [`BatchError::Config`] if any world fails to construct,
126    /// or [`BatchError::Observe`] if the obs plan fails to compile.
127    pub fn new(configs: Vec<WorldConfig>, obs_spec: Option<&ObsSpec>) -> Result<Self, BatchError> {
128        if configs.is_empty() {
129            return Err(BatchError::InvalidArgument {
130                reason: "BatchedEngine requires at least one world config".into(),
131            });
132        }
133
134        let mut worlds = Vec::with_capacity(configs.len());
135        for config in configs {
136            let world = LockstepWorld::new(config).map_err(BatchError::Config)?;
137            worlds.push(world);
138        }
139
140        // Validate all worlds share the same space topology.
141        // topology_eq checks TypeId, dimensions, and behavioral parameters
142        // (e.g. EdgeBehavior) so that spaces like Line1D(10, Absorb) and
143        // Line1D(10, Wrap) are correctly rejected.
144        let ref_space = worlds[0].space();
145        for (i, world) in worlds.iter().enumerate().skip(1) {
146            if !ref_space.topology_eq(world.space()) {
147                return Err(BatchError::InvalidArgument {
148                    reason: format!(
149                        "world 0 and world {i} have incompatible space topologies; \
150                         all worlds in a batch must use the same topology"
151                    ),
152                });
153            }
154        }
155
156        // Compile obs plan if spec provided.
157        let (obs_plan, obs_output_len, obs_mask_len) = match obs_spec {
158            Some(spec) => {
159                let result =
160                    ObsPlan::compile(spec, worlds[0].space()).map_err(BatchError::Observe)?;
161
162                // Validate all worlds have matching field schemas for observed fields.
163                // ObsPlan::compile only takes a Space (not a snapshot), so field
164                // existence isn't checked until execute(). Catching mismatches here
165                // prevents late observation failures after worlds have been stepped.
166                let ref_snap = worlds[0].snapshot();
167                for entry in &spec.entries {
168                    let fid = entry.field_id;
169                    let ref_len = ref_snap.read_field(fid).map(|d| d.len());
170                    for (i, world) in worlds.iter().enumerate().skip(1) {
171                        let snap = world.snapshot();
172                        let other_len = snap.read_field(fid).map(|d| d.len());
173                        if other_len != ref_len {
174                            return Err(BatchError::InvalidArgument {
175                                reason: format!(
176                                    "world {i} field {fid:?}: {} elements, \
177                                     world 0 has {} elements; \
178                                     all worlds must share the same field schema",
179                                    other_len
180                                        .map(|n| n.to_string())
181                                        .unwrap_or_else(|| "missing".into()),
182                                    ref_len
183                                        .map(|n| n.to_string())
184                                        .unwrap_or_else(|| "missing".into()),
185                                ),
186                            });
187                        }
188                    }
189                }
190
191                (Some(result.plan), result.output_len, result.mask_len)
192            }
193            None => (None, 0, 0),
194        };
195
196        Ok(BatchedEngine {
197            worlds,
198            obs_plan,
199            obs_output_len,
200            obs_mask_len,
201        })
202    }
203
204    /// Step all worlds and extract observations in one call.
205    ///
206    /// `commands` must have exactly `num_worlds()` entries.
207    /// `output` must have at least `num_worlds() * obs_output_len()` elements.
208    /// `mask` must have at least `num_worlds() * obs_mask_len()` bytes.
209    ///
210    /// Returns per-world tick IDs and metrics.
211    pub fn step_and_observe(
212        &mut self,
213        commands: &[Vec<Command>],
214        output: &mut [f32],
215        mask: &mut [u8],
216    ) -> Result<BatchResult, BatchError> {
217        // Pre-flight: validate observation preconditions before mutating
218        // world state. Without this, a late observe failure (no obs plan,
219        // buffer too small) would leave worlds stepped but observations
220        // unextracted — making the error non-atomic.
221        self.validate_observe_buffers(output, mask)?;
222
223        let result = self.step_all(commands)?;
224
225        // Observe phase: borrow worlds immutably for snapshot collection.
226        self.observe_all_inner(output, mask)?;
227
228        Ok(result)
229    }
230
231    /// Step all worlds without observation extraction.
232    pub fn step_all(&mut self, commands: &[Vec<Command>]) -> Result<BatchResult, BatchError> {
233        let n = self.worlds.len();
234        if commands.len() != n {
235            return Err(BatchError::InvalidArgument {
236                reason: format!("commands has {} entries, expected {n}", commands.len()),
237            });
238        }
239
240        let mut tick_ids = Vec::with_capacity(n);
241        let mut metrics = Vec::with_capacity(n);
242
243        for (idx, world) in self.worlds.iter_mut().enumerate() {
244            let result = world
245                .step_sync(commands[idx].clone())
246                .map_err(|e| BatchError::Step {
247                    world_index: idx,
248                    error: e,
249                })?;
250            tick_ids.push(result.snapshot.tick_id());
251            metrics.push(result.metrics);
252        }
253
254        Ok(BatchResult { tick_ids, metrics })
255    }
256
257    /// Extract observations from all worlds without stepping.
258    ///
259    /// Used after `reset_all()` to get initial observations.
260    pub fn observe_all(
261        &self,
262        output: &mut [f32],
263        mask: &mut [u8],
264    ) -> Result<Vec<ObsMetadata>, BatchError> {
265        self.observe_all_inner(output, mask)
266    }
267
268    /// Internal observation extraction shared by step_and_observe and observe_all.
269    fn observe_all_inner(
270        &self,
271        output: &mut [f32],
272        mask: &mut [u8],
273    ) -> Result<Vec<ObsMetadata>, BatchError> {
274        let plan = self.obs_plan.as_ref().ok_or(BatchError::NoObsPlan)?;
275
276        let snapshots: Vec<_> = self.worlds.iter().map(|w| w.snapshot()).collect();
277        let snap_refs: Vec<&dyn SnapshotAccess> =
278            snapshots.iter().map(|s| s as &dyn SnapshotAccess).collect();
279
280        plan.execute_batch(&snap_refs, None, output, mask)
281            .map_err(BatchError::Observe)
282    }
283
284    /// Validate that observation preconditions are met (plan exists, buffers
285    /// large enough) without performing any mutation. Called by
286    /// `step_and_observe` before `step_all` to guarantee atomicity.
287    fn validate_observe_buffers(&self, output: &[f32], mask: &[u8]) -> Result<(), BatchError> {
288        if self.obs_plan.is_none() {
289            return Err(BatchError::NoObsPlan);
290        }
291        let n = self.worlds.len();
292        let expected_out = n * self.obs_output_len;
293        let expected_mask = n * self.obs_mask_len;
294        if output.len() < expected_out {
295            return Err(BatchError::InvalidArgument {
296                reason: format!("output buffer too small: {} < {expected_out}", output.len()),
297            });
298        }
299        if mask.len() < expected_mask {
300            return Err(BatchError::InvalidArgument {
301                reason: format!("mask buffer too small: {} < {expected_mask}", mask.len()),
302            });
303        }
304        Ok(())
305    }
306
307    /// Reset a single world by index.
308    pub fn reset_world(&mut self, idx: usize, seed: u64) -> Result<(), BatchError> {
309        let n = self.worlds.len();
310        let world = self.worlds.get_mut(idx).ok_or(BatchError::InvalidIndex {
311            world_index: idx,
312            num_worlds: n,
313        })?;
314        world.reset(seed).map_err(BatchError::Config)?;
315        Ok(())
316    }
317
318    /// Reset all worlds with per-world seeds.
319    pub fn reset_all(&mut self, seeds: &[u64]) -> Result<(), BatchError> {
320        let n = self.worlds.len();
321        if seeds.len() != n {
322            return Err(BatchError::InvalidArgument {
323                reason: format!("seeds has {} entries, expected {n}", seeds.len()),
324            });
325        }
326        for (idx, world) in self.worlds.iter_mut().enumerate() {
327            world.reset(seeds[idx]).map_err(BatchError::Config)?;
328        }
329        Ok(())
330    }
331
332    /// Number of worlds in the batch.
333    pub fn num_worlds(&self) -> usize {
334        self.worlds.len()
335    }
336
337    /// Per-world observation output length (f32 elements).
338    pub fn obs_output_len(&self) -> usize {
339        self.obs_output_len
340    }
341
342    /// Per-world observation mask length (bytes).
343    pub fn obs_mask_len(&self) -> usize {
344        self.obs_mask_len
345    }
346
347    /// Current tick ID of a specific world.
348    pub fn world_tick(&self, idx: usize) -> Option<TickId> {
349        self.worlds.get(idx).map(|w| w.current_tick())
350    }
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use murk_core::id::FieldId;
    use murk_core::traits::FieldReader;
    use murk_obs::spec::{ObsDtype, ObsEntry, ObsRegion, ObsTransform};
    use murk_space::{EdgeBehavior, Line1D, RegionSpec};
    use murk_test_utils::ConstPropagator;

    use crate::config::BackoffConfig;

    /// Scalar, per-tick field definition with no units or bounds.
    fn scalar_field(name: &str) -> murk_core::FieldDef {
        murk_core::FieldDef {
            name: name.to_string(),
            field_type: murk_core::FieldType::Scalar,
            mutability: murk_core::FieldMutability::PerTick,
            units: None,
            bounds: None,
            boundary_behavior: murk_core::BoundaryBehavior::Clamp,
        }
    }

    /// Baseline config: Line1D(10, Absorb), one "energy" field, and a
    /// ConstPropagator writing `value` into FieldId(0).
    ///
    /// Tests needing a different space or field list start from this and
    /// override a single field, so new `WorldConfig` fields only need adding here.
    fn make_config(seed: u64, value: f32) -> WorldConfig {
        WorldConfig {
            space: Box::new(Line1D::new(10, EdgeBehavior::Absorb).unwrap()),
            fields: vec![scalar_field("energy")],
            propagators: vec![Box::new(ConstPropagator::new("const", FieldId(0), value))],
            dt: 0.1,
            seed,
            ring_buffer_size: 8,
            max_ingress_queue: 1024,
            tick_rate_hz: None,
            backoff: BackoffConfig::default(),
        }
    }

    /// Observation entry covering every cell of `field_id`, untransformed, as f32.
    fn all_cells_entry(field_id: FieldId) -> ObsEntry {
        ObsEntry {
            field_id,
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }
    }

    fn obs_spec_all_field0() -> ObsSpec {
        ObsSpec {
            entries: vec![all_cells_entry(FieldId(0))],
        }
    }

    // ── Construction tests ────────────────────────────────────

    #[test]
    fn new_single_world() {
        let configs = vec![make_config(42, 1.0)];
        let engine = BatchedEngine::new(configs, None).unwrap();
        assert_eq!(engine.num_worlds(), 1);
        assert_eq!(engine.obs_output_len(), 0);
        assert_eq!(engine.obs_mask_len(), 0);
    }

    #[test]
    fn new_four_worlds() {
        let configs: Vec<_> = (0..4).map(|i| make_config(i, 1.0)).collect();
        let engine = BatchedEngine::new(configs, None).unwrap();
        assert_eq!(engine.num_worlds(), 4);
    }

    #[test]
    fn new_zero_worlds_is_error() {
        let result = BatchedEngine::new(vec![], None);
        assert!(matches!(result, Err(BatchError::InvalidArgument { .. })));
    }

    #[test]
    fn new_with_obs_spec() {
        let configs = vec![make_config(42, 1.0)];
        let spec = obs_spec_all_field0();
        let engine = BatchedEngine::new(configs, Some(&spec)).unwrap();
        assert_eq!(engine.obs_output_len(), 10); // Line1D(10) → 10 cells
        assert_eq!(engine.obs_mask_len(), 10);
    }

    // ── Determinism test ──────────────────────────────────────

    #[test]
    fn batch_matches_independent_worlds() {
        let spec = obs_spec_all_field0();

        // Batched: 2 worlds
        let configs = vec![make_config(42, 42.0), make_config(99, 42.0)];
        let mut batched = BatchedEngine::new(configs, Some(&spec)).unwrap();
        let n = batched.num_worlds();
        let out_len = n * batched.obs_output_len();
        let mask_len = n * batched.obs_mask_len();
        let mut batch_output = vec![0.0f32; out_len];
        let mut batch_mask = vec![0u8; mask_len];

        let commands = vec![vec![], vec![]];
        batched
            .step_and_observe(&commands, &mut batch_output, &mut batch_mask)
            .unwrap();

        // Independent: 2 separate worlds
        let mut w0 = LockstepWorld::new(make_config(42, 42.0)).unwrap();
        let mut w1 = LockstepWorld::new(make_config(99, 42.0)).unwrap();
        let r0 = w0.step_sync(vec![]).unwrap();
        let r1 = w1.step_sync(vec![]).unwrap();

        let d0 = r0.snapshot.read(FieldId(0)).unwrap();
        let d1 = r1.snapshot.read(FieldId(0)).unwrap();

        // Batch output should be [world0_obs | world1_obs]
        assert_eq!(&batch_output[..10], d0);
        assert_eq!(&batch_output[10..20], d1);
    }

    // ── Observation correctness ───────────────────────────────

    #[test]
    fn observation_filled_with_const_value() {
        let spec = obs_spec_all_field0();
        let configs = vec![
            make_config(1, 42.0),
            make_config(2, 42.0),
            make_config(3, 42.0),
        ];
        let mut engine = BatchedEngine::new(configs, Some(&spec)).unwrap();

        let commands = vec![vec![], vec![], vec![]];
        let n = engine.num_worlds();
        let mut output = vec![0.0f32; n * engine.obs_output_len()];
        let mut mask = vec![0u8; n * engine.obs_mask_len()];
        engine
            .step_and_observe(&commands, &mut output, &mut mask)
            .unwrap();

        // All cells should be 42.0 for all worlds.
        assert!(output.iter().all(|&v| v == 42.0));
        assert!(mask.iter().all(|&m| m == 1));
    }

    // ── Reset tests ───────────────────────────────────────────

    #[test]
    fn reset_single_world_preserves_others() {
        let configs: Vec<_> = (0..4).map(|i| make_config(i, 1.0)).collect();
        let mut engine = BatchedEngine::new(configs, None).unwrap();

        // Step all once.
        let commands = vec![vec![]; 4];
        engine.step_all(&commands).unwrap();
        assert_eq!(engine.world_tick(0), Some(TickId(1)));
        assert_eq!(engine.world_tick(3), Some(TickId(1)));

        // Reset only world 0.
        engine.reset_world(0, 999).unwrap();
        assert_eq!(engine.world_tick(0), Some(TickId(0)));
        assert_eq!(engine.world_tick(1), Some(TickId(1)));
        assert_eq!(engine.world_tick(2), Some(TickId(1)));
        assert_eq!(engine.world_tick(3), Some(TickId(1)));
    }

    #[test]
    fn reset_all_resets_to_tick_zero() {
        let configs: Vec<_> = (0..3).map(|i| make_config(i, 1.0)).collect();
        let mut engine = BatchedEngine::new(configs, None).unwrap();

        // Step all twice.
        let commands = vec![vec![]; 3];
        engine.step_all(&commands).unwrap();
        engine.step_all(&commands).unwrap();

        engine.reset_all(&[10, 20, 30]).unwrap();
        for i in 0..3 {
            assert_eq!(engine.world_tick(i), Some(TickId(0)));
        }
    }

    // ── Error isolation ───────────────────────────────────────

    #[test]
    fn invalid_world_index_returns_error() {
        let configs = vec![make_config(0, 1.0)];
        let mut engine = BatchedEngine::new(configs, None).unwrap();

        let result = engine.reset_world(5, 0);
        assert!(matches!(result, Err(BatchError::InvalidIndex { .. })));
    }

    #[test]
    fn wrong_command_count_returns_error() {
        let configs = vec![make_config(0, 1.0), make_config(1, 1.0)];
        let mut engine = BatchedEngine::new(configs, None).unwrap();

        let result = engine.step_all(&[vec![]]); // 1 entry for 2 worlds
        assert!(matches!(result, Err(BatchError::InvalidArgument { .. })));

        // The count is validated before any world is stepped.
        assert_eq!(engine.world_tick(0), Some(TickId(0)));
        assert_eq!(engine.world_tick(1), Some(TickId(0)));
    }

    #[test]
    fn observe_without_plan_returns_error() {
        let configs = vec![make_config(0, 1.0)];
        let engine = BatchedEngine::new(configs, None).unwrap();

        let mut output = vec![0.0f32; 10];
        let mut mask = vec![0u8; 10];
        let result = engine.observe_all(&mut output, &mut mask);
        assert!(matches!(result, Err(BatchError::NoObsPlan)));
    }

    // ── Observe after reset ───────────────────────────────────

    #[test]
    fn observe_all_after_reset() {
        let spec = obs_spec_all_field0();
        let configs = vec![make_config(1, 42.0), make_config(2, 42.0)];
        let mut engine = BatchedEngine::new(configs, Some(&spec)).unwrap();

        // Step once to populate data.
        let commands = vec![vec![], vec![]];
        let n = engine.num_worlds();
        let mut output = vec![0.0f32; n * engine.obs_output_len()];
        let mut mask = vec![0u8; n * engine.obs_mask_len()];
        engine
            .step_and_observe(&commands, &mut output, &mut mask)
            .unwrap();

        // Reset all and observe (initial state is zeroed).
        engine.reset_all(&[10, 20]).unwrap();
        let meta = engine.observe_all(&mut output, &mut mask).unwrap();
        assert_eq!(meta.len(), 2);
        assert_eq!(meta[0].tick_id, TickId(0));
        assert_eq!(meta[1].tick_id, TickId(0));
    }

    // ── Topology validation ──────────────────────────────────

    #[test]
    fn mixed_space_types_rejected() {
        use murk_space::Ring1D;

        // Line1D(10) and Ring1D(10): same ndim, same cell_count, different type.
        let line_config = make_config(1, 1.0);
        let mut ring_config = make_config(2, 1.0);
        ring_config.space = Box::new(Ring1D::new(10).unwrap());

        let result = BatchedEngine::new(vec![line_config, ring_config], None);
        match result {
            Err(e) => {
                let msg = format!("{e}");
                assert!(msg.contains("incompatible space topologies"), "got: {msg}");
            }
            Ok(_) => panic!("expected error for mixed space types"),
        }
    }

    #[test]
    fn mixed_edge_behaviors_rejected() {
        // Line1D(10, Absorb) and Line1D(10, Wrap): same TypeId, ndim, cell_count,
        // but different edge behavior — must be rejected.
        let absorb_config = make_config(1, 1.0);
        let mut wrap_config = make_config(2, 1.0);
        wrap_config.space = Box::new(Line1D::new(10, EdgeBehavior::Wrap).unwrap());

        let result = BatchedEngine::new(vec![absorb_config, wrap_config], None);
        assert!(
            matches!(result, Err(BatchError::InvalidArgument { .. })),
            "expected error for mixed edge behaviors"
        );
    }

    // ── Atomic step_and_observe ──────────────────────────────

    #[test]
    fn step_and_observe_no_plan_does_not_step() {
        // Without an obs plan, step_and_observe should fail *before*
        // advancing any world state.
        let configs = vec![make_config(0, 1.0), make_config(1, 1.0)];
        let mut engine = BatchedEngine::new(configs, None).unwrap();

        let commands = vec![vec![], vec![]];
        let mut output = vec![0.0f32; 20];
        let mut mask = vec![0u8; 20];
        let result = engine.step_and_observe(&commands, &mut output, &mut mask);
        assert!(matches!(result, Err(BatchError::NoObsPlan)));

        // Worlds must still be at tick 0 — no mutation occurred.
        assert_eq!(engine.world_tick(0), Some(TickId(0)));
        assert_eq!(engine.world_tick(1), Some(TickId(0)));
    }

    #[test]
    fn step_and_observe_small_buffer_does_not_step() {
        // Buffer too small should fail before advancing world state.
        let spec = obs_spec_all_field0();
        let configs = vec![make_config(0, 1.0), make_config(1, 1.0)];
        let mut engine = BatchedEngine::new(configs, Some(&spec)).unwrap();

        let commands = vec![vec![], vec![]];
        let mut output = vec![0.0f32; 5]; // need 20, only 5
        let mut mask = vec![0u8; 20];
        let result = engine.step_and_observe(&commands, &mut output, &mut mask);
        assert!(matches!(result, Err(BatchError::InvalidArgument { .. })));

        // Worlds must still be at tick 0.
        assert_eq!(engine.world_tick(0), Some(TickId(0)));
        assert_eq!(engine.world_tick(1), Some(TickId(0)));
    }

    // ── Field schema validation ─────────────────────────────

    #[test]
    fn mismatched_field_schemas_rejected() {
        // World 0 has 2 fields, world 1 has only 1. Obs spec references
        // FieldId(1) which is missing in world 1. Construction must fail.
        let spec = ObsSpec {
            entries: vec![all_cells_entry(FieldId(0)), all_cells_entry(FieldId(1))],
        };

        // World 0: has 2 fields (FieldId(0) and FieldId(1))
        let mut config_two_fields = make_config(1, 1.0);
        config_two_fields.fields = vec![scalar_field("energy"), scalar_field("temp")];

        // World 1: has only 1 field (FieldId(0)), missing FieldId(1)
        let config_one_field = make_config(2, 1.0);

        let result = BatchedEngine::new(vec![config_two_fields, config_one_field], Some(&spec));
        match result {
            Err(e) => {
                let msg = format!("{e}");
                assert!(
                    msg.contains("field") && msg.contains("missing"),
                    "error should mention missing field, got: {msg}"
                );
            }
            Ok(_) => panic!("expected error for mismatched field schemas"),
        }
    }
}