1use murk_core::command::Command;
19use murk_core::error::ObsError;
20use murk_core::id::TickId;
21use murk_core::traits::SnapshotAccess;
22use murk_obs::metadata::ObsMetadata;
23use murk_obs::plan::ObsPlan;
24use murk_obs::spec::ObsSpec;
25
26use crate::config::{ConfigError, WorldConfig};
27use crate::lockstep::LockstepWorld;
28use crate::metrics::StepMetrics;
29use crate::tick::TickError;
30
31#[derive(Debug, PartialEq)]
35pub enum BatchError {
36 Step {
38 world_index: usize,
40 error: TickError,
42 },
43 Observe(ObsError),
45 Config(ConfigError),
47 InvalidIndex {
49 world_index: usize,
51 num_worlds: usize,
53 },
54 NoObsPlan,
56 InvalidArgument {
58 reason: String,
60 },
61}
62
63impl std::fmt::Display for BatchError {
64 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65 match self {
66 BatchError::Step { world_index, error } => {
67 write!(f, "world {world_index}: step failed: {error:?}")
68 }
69 BatchError::Observe(e) => write!(f, "observe failed: {e:?}"),
70 BatchError::Config(e) => write!(f, "config error: {e:?}"),
71 BatchError::InvalidIndex {
72 world_index,
73 num_worlds,
74 } => write!(
75 f,
76 "world index {world_index} out of range (num_worlds={num_worlds})"
77 ),
78 BatchError::NoObsPlan => write!(f, "no observation plan compiled"),
79 BatchError::InvalidArgument { reason } => {
80 write!(f, "invalid argument: {reason}")
81 }
82 }
83 }
84}
85
// Marker impl built on the derived `Debug` and the `Display` impl above.
// `source()` keeps its default (`None`), so wrapped errors are not chained.
impl std::error::Error for BatchError {}
87
88pub struct BatchResult {
92 pub tick_ids: Vec<TickId>,
94 pub metrics: Vec<StepMetrics>,
96}
97
98pub struct BatchedEngine {
110 worlds: Vec<LockstepWorld>,
111 obs_plan: Option<ObsPlan>,
112 obs_output_len: usize,
113 obs_mask_len: usize,
114}
115
116impl BatchedEngine {
117 pub fn new(configs: Vec<WorldConfig>, obs_spec: Option<&ObsSpec>) -> Result<Self, BatchError> {
128 if configs.is_empty() {
129 return Err(BatchError::InvalidArgument {
130 reason: "BatchedEngine requires at least one world config".into(),
131 });
132 }
133
134 let mut worlds = Vec::with_capacity(configs.len());
135 for config in configs {
136 let world = LockstepWorld::new(config).map_err(BatchError::Config)?;
137 worlds.push(world);
138 }
139
140 let ref_space = worlds[0].space();
145 for (i, world) in worlds.iter().enumerate().skip(1) {
146 if !ref_space.topology_eq(world.space()) {
147 return Err(BatchError::InvalidArgument {
148 reason: format!(
149 "world 0 and world {i} have incompatible space topologies; \
150 all worlds in a batch must use the same topology"
151 ),
152 });
153 }
154 }
155
156 let (obs_plan, obs_output_len, obs_mask_len) = match obs_spec {
158 Some(spec) => {
159 let result =
160 ObsPlan::compile(spec, worlds[0].space()).map_err(BatchError::Observe)?;
161
162 let ref_snap = worlds[0].snapshot();
167 for entry in &spec.entries {
168 let fid = entry.field_id;
169 let ref_len = ref_snap.read_field(fid).map(|d| d.len());
170 for (i, world) in worlds.iter().enumerate().skip(1) {
171 let snap = world.snapshot();
172 let other_len = snap.read_field(fid).map(|d| d.len());
173 if other_len != ref_len {
174 return Err(BatchError::InvalidArgument {
175 reason: format!(
176 "world {i} field {fid:?}: {} elements, \
177 world 0 has {} elements; \
178 all worlds must share the same field schema",
179 other_len
180 .map(|n| n.to_string())
181 .unwrap_or_else(|| "missing".into()),
182 ref_len
183 .map(|n| n.to_string())
184 .unwrap_or_else(|| "missing".into()),
185 ),
186 });
187 }
188 }
189 }
190
191 (Some(result.plan), result.output_len, result.mask_len)
192 }
193 None => (None, 0, 0),
194 };
195
196 Ok(BatchedEngine {
197 worlds,
198 obs_plan,
199 obs_output_len,
200 obs_mask_len,
201 })
202 }
203
204 pub fn step_and_observe(
212 &mut self,
213 commands: &[Vec<Command>],
214 output: &mut [f32],
215 mask: &mut [u8],
216 ) -> Result<BatchResult, BatchError> {
217 self.validate_observe_buffers(output, mask)?;
222
223 let result = self.step_all(commands)?;
224
225 self.observe_all_inner(output, mask)?;
227
228 Ok(result)
229 }
230
231 pub fn step_all(&mut self, commands: &[Vec<Command>]) -> Result<BatchResult, BatchError> {
233 let n = self.worlds.len();
234 if commands.len() != n {
235 return Err(BatchError::InvalidArgument {
236 reason: format!("commands has {} entries, expected {n}", commands.len()),
237 });
238 }
239
240 let mut tick_ids = Vec::with_capacity(n);
241 let mut metrics = Vec::with_capacity(n);
242
243 for (idx, world) in self.worlds.iter_mut().enumerate() {
244 let result = world
245 .step_sync(commands[idx].clone())
246 .map_err(|e| BatchError::Step {
247 world_index: idx,
248 error: e,
249 })?;
250 tick_ids.push(result.snapshot.tick_id());
251 metrics.push(result.metrics);
252 }
253
254 Ok(BatchResult { tick_ids, metrics })
255 }
256
257 pub fn observe_all(
261 &self,
262 output: &mut [f32],
263 mask: &mut [u8],
264 ) -> Result<Vec<ObsMetadata>, BatchError> {
265 self.observe_all_inner(output, mask)
266 }
267
268 fn observe_all_inner(
270 &self,
271 output: &mut [f32],
272 mask: &mut [u8],
273 ) -> Result<Vec<ObsMetadata>, BatchError> {
274 let plan = self.obs_plan.as_ref().ok_or(BatchError::NoObsPlan)?;
275
276 let snapshots: Vec<_> = self.worlds.iter().map(|w| w.snapshot()).collect();
277 let snap_refs: Vec<&dyn SnapshotAccess> =
278 snapshots.iter().map(|s| s as &dyn SnapshotAccess).collect();
279
280 plan.execute_batch(&snap_refs, None, output, mask)
281 .map_err(BatchError::Observe)
282 }
283
284 fn validate_observe_buffers(&self, output: &[f32], mask: &[u8]) -> Result<(), BatchError> {
288 if self.obs_plan.is_none() {
289 return Err(BatchError::NoObsPlan);
290 }
291 let n = self.worlds.len();
292 let expected_out = n * self.obs_output_len;
293 let expected_mask = n * self.obs_mask_len;
294 if output.len() < expected_out {
295 return Err(BatchError::InvalidArgument {
296 reason: format!("output buffer too small: {} < {expected_out}", output.len()),
297 });
298 }
299 if mask.len() < expected_mask {
300 return Err(BatchError::InvalidArgument {
301 reason: format!("mask buffer too small: {} < {expected_mask}", mask.len()),
302 });
303 }
304 Ok(())
305 }
306
307 pub fn reset_world(&mut self, idx: usize, seed: u64) -> Result<(), BatchError> {
309 let n = self.worlds.len();
310 let world = self.worlds.get_mut(idx).ok_or(BatchError::InvalidIndex {
311 world_index: idx,
312 num_worlds: n,
313 })?;
314 world.reset(seed).map_err(BatchError::Config)?;
315 Ok(())
316 }
317
318 pub fn reset_all(&mut self, seeds: &[u64]) -> Result<(), BatchError> {
320 let n = self.worlds.len();
321 if seeds.len() != n {
322 return Err(BatchError::InvalidArgument {
323 reason: format!("seeds has {} entries, expected {n}", seeds.len()),
324 });
325 }
326 for (idx, world) in self.worlds.iter_mut().enumerate() {
327 world.reset(seeds[idx]).map_err(BatchError::Config)?;
328 }
329 Ok(())
330 }
331
332 pub fn num_worlds(&self) -> usize {
334 self.worlds.len()
335 }
336
337 pub fn obs_output_len(&self) -> usize {
339 self.obs_output_len
340 }
341
342 pub fn obs_mask_len(&self) -> usize {
344 self.obs_mask_len
345 }
346
347 pub fn world_tick(&self, idx: usize) -> Option<TickId> {
349 self.worlds.get(idx).map(|w| w.current_tick())
350 }
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use murk_core::id::FieldId;
    use murk_core::traits::FieldReader;
    use murk_obs::spec::{ObsDtype, ObsEntry, ObsRegion, ObsTransform};
    use murk_space::{EdgeBehavior, Line1D, RegionSpec};
    use murk_test_utils::ConstPropagator;

    use crate::config::BackoffConfig;

    /// A per-tick scalar field definition named `name`.
    fn scalar_field(name: &str) -> murk_core::FieldDef {
        murk_core::FieldDef {
            name: name.to_string(),
            field_type: murk_core::FieldType::Scalar,
            mutability: murk_core::FieldMutability::PerTick,
            units: None,
            bounds: None,
            boundary_behavior: murk_core::BoundaryBehavior::Clamp,
        }
    }

    /// Baseline config with these pieces:
    /// - a 10-cell absorbing `Line1D`;
    /// - one `energy` field;
    /// - a `ConstPropagator` writing `value` into field 0.
    ///
    /// Tests that need a variation override fields with struct-update syntax.
    fn make_config(seed: u64, value: f32) -> WorldConfig {
        WorldConfig {
            space: Box::new(Line1D::new(10, EdgeBehavior::Absorb).unwrap()),
            fields: vec![scalar_field("energy")],
            propagators: vec![Box::new(ConstPropagator::new("const", FieldId(0), value))],
            dt: 0.1,
            seed,
            ring_buffer_size: 8,
            max_ingress_queue: 1024,
            tick_rate_hz: None,
            backoff: BackoffConfig::default(),
        }
    }

    /// Observation entry reading every cell of `field_id` as raw `f32`.
    fn all_cells_entry(field_id: FieldId) -> ObsEntry {
        ObsEntry {
            field_id,
            region: ObsRegion::Fixed(RegionSpec::All),
            pool: None,
            transform: ObsTransform::Identity,
            dtype: ObsDtype::F32,
        }
    }

    fn obs_spec_all_field0() -> ObsSpec {
        ObsSpec {
            entries: vec![all_cells_entry(FieldId(0))],
        }
    }

    /// Zeroed output/mask buffers sized exactly for one batched observation.
    fn obs_buffers(engine: &BatchedEngine) -> (Vec<f32>, Vec<u8>) {
        let n = engine.num_worlds();
        (
            vec![0.0f32; n * engine.obs_output_len()],
            vec![0u8; n * engine.obs_mask_len()],
        )
    }

    #[test]
    fn new_single_world() {
        let configs = vec![make_config(42, 1.0)];
        let engine = BatchedEngine::new(configs, None).unwrap();
        assert_eq!(engine.num_worlds(), 1);
        assert_eq!(engine.obs_output_len(), 0);
        assert_eq!(engine.obs_mask_len(), 0);
    }

    #[test]
    fn new_four_worlds() {
        let configs: Vec<_> = (0..4).map(|i| make_config(i, 1.0)).collect();
        let engine = BatchedEngine::new(configs, None).unwrap();
        assert_eq!(engine.num_worlds(), 4);
    }

    #[test]
    fn new_zero_worlds_is_error() {
        let result = BatchedEngine::new(vec![], None);
        assert!(matches!(result, Err(BatchError::InvalidArgument { .. })));
    }

    #[test]
    fn new_with_obs_spec() {
        let configs = vec![make_config(42, 1.0)];
        let spec = obs_spec_all_field0();
        let engine = BatchedEngine::new(configs, Some(&spec)).unwrap();
        assert_eq!(engine.obs_output_len(), 10);
        assert_eq!(engine.obs_mask_len(), 10);
    }

    #[test]
    fn batch_matches_independent_worlds() {
        let spec = obs_spec_all_field0();

        let configs = vec![make_config(42, 42.0), make_config(99, 42.0)];
        let mut batched = BatchedEngine::new(configs, Some(&spec)).unwrap();
        let (mut batch_output, mut batch_mask) = obs_buffers(&batched);

        let commands = vec![vec![], vec![]];
        batched
            .step_and_observe(&commands, &mut batch_output, &mut batch_mask)
            .unwrap();

        let mut w0 = LockstepWorld::new(make_config(42, 42.0)).unwrap();
        let mut w1 = LockstepWorld::new(make_config(99, 42.0)).unwrap();
        let r0 = w0.step_sync(vec![]).unwrap();
        let r1 = w1.step_sync(vec![]).unwrap();

        let d0 = r0.snapshot.read(FieldId(0)).unwrap();
        let d1 = r1.snapshot.read(FieldId(0)).unwrap();

        assert_eq!(&batch_output[..10], d0);
        assert_eq!(&batch_output[10..20], d1);
    }

    #[test]
    fn observation_filled_with_const_value() {
        let spec = obs_spec_all_field0();
        let configs = vec![
            make_config(1, 42.0),
            make_config(2, 42.0),
            make_config(3, 42.0),
        ];
        let mut engine = BatchedEngine::new(configs, Some(&spec)).unwrap();

        let commands = vec![vec![], vec![], vec![]];
        let (mut output, mut mask) = obs_buffers(&engine);
        engine
            .step_and_observe(&commands, &mut output, &mut mask)
            .unwrap();

        assert!(output.iter().all(|&v| v == 42.0));
        assert!(mask.iter().all(|&m| m == 1));
    }

    #[test]
    fn reset_single_world_preserves_others() {
        let configs: Vec<_> = (0..4).map(|i| make_config(i, 1.0)).collect();
        let mut engine = BatchedEngine::new(configs, None).unwrap();

        let commands = vec![vec![]; 4];
        engine.step_all(&commands).unwrap();
        assert_eq!(engine.world_tick(0), Some(TickId(1)));
        assert_eq!(engine.world_tick(3), Some(TickId(1)));

        engine.reset_world(0, 999).unwrap();
        assert_eq!(engine.world_tick(0), Some(TickId(0)));
        assert_eq!(engine.world_tick(1), Some(TickId(1)));
        assert_eq!(engine.world_tick(2), Some(TickId(1)));
        assert_eq!(engine.world_tick(3), Some(TickId(1)));
    }

    #[test]
    fn reset_all_resets_to_tick_zero() {
        let configs: Vec<_> = (0..3).map(|i| make_config(i, 1.0)).collect();
        let mut engine = BatchedEngine::new(configs, None).unwrap();

        let commands = vec![vec![]; 3];
        engine.step_all(&commands).unwrap();
        engine.step_all(&commands).unwrap();

        engine.reset_all(&[10, 20, 30]).unwrap();
        for i in 0..3 {
            assert_eq!(engine.world_tick(i), Some(TickId(0)));
        }
    }

    #[test]
    fn invalid_world_index_returns_error() {
        let configs = vec![make_config(0, 1.0)];
        let mut engine = BatchedEngine::new(configs, None).unwrap();

        let result = engine.reset_world(5, 0);
        assert!(matches!(result, Err(BatchError::InvalidIndex { .. })));
    }

    #[test]
    fn wrong_command_count_returns_error() {
        let configs = vec![make_config(0, 1.0), make_config(1, 1.0)];
        let mut engine = BatchedEngine::new(configs, None).unwrap();

        let result = engine.step_all(&[vec![]]);
        assert!(matches!(result, Err(BatchError::InvalidArgument { .. })));

        // The length check happens before any world is stepped.
        assert_eq!(engine.world_tick(0), Some(TickId(0)));
        assert_eq!(engine.world_tick(1), Some(TickId(0)));
    }

    #[test]
    fn observe_without_plan_returns_error() {
        let configs = vec![make_config(0, 1.0)];
        let engine = BatchedEngine::new(configs, None).unwrap();

        let mut output = vec![0.0f32; 10];
        let mut mask = vec![0u8; 10];
        let result = engine.observe_all(&mut output, &mut mask);
        assert!(matches!(result, Err(BatchError::NoObsPlan)));
    }

    #[test]
    fn observe_all_after_reset() {
        let spec = obs_spec_all_field0();
        let configs = vec![make_config(1, 42.0), make_config(2, 42.0)];
        let mut engine = BatchedEngine::new(configs, Some(&spec)).unwrap();

        let commands = vec![vec![], vec![]];
        let (mut output, mut mask) = obs_buffers(&engine);
        engine
            .step_and_observe(&commands, &mut output, &mut mask)
            .unwrap();

        engine.reset_all(&[10, 20]).unwrap();
        let meta = engine.observe_all(&mut output, &mut mask).unwrap();
        assert_eq!(meta.len(), 2);
        assert_eq!(meta[0].tick_id, TickId(0));
        assert_eq!(meta[1].tick_id, TickId(0));
    }

    #[test]
    fn mixed_space_types_rejected() {
        use murk_space::Ring1D;

        let line_config = make_config(1, 1.0);
        let ring_config = WorldConfig {
            space: Box::new(Ring1D::new(10).unwrap()),
            ..make_config(2, 1.0)
        };

        match BatchedEngine::new(vec![line_config, ring_config], None) {
            Err(e) => {
                let msg = format!("{e}");
                assert!(msg.contains("incompatible space topologies"), "got: {msg}");
            }
            Ok(_) => panic!("expected error for mixed space types"),
        }
    }

    #[test]
    fn mixed_edge_behaviors_rejected() {
        let absorb_config = make_config(1, 1.0);
        let wrap_config = WorldConfig {
            space: Box::new(Line1D::new(10, EdgeBehavior::Wrap).unwrap()),
            ..make_config(2, 1.0)
        };

        let result = BatchedEngine::new(vec![absorb_config, wrap_config], None);
        assert!(result.is_err(), "expected error for mixed edge behaviors");
    }

    #[test]
    fn step_and_observe_no_plan_does_not_step() {
        let configs = vec![make_config(0, 1.0), make_config(1, 1.0)];
        let mut engine = BatchedEngine::new(configs, None).unwrap();

        let commands = vec![vec![], vec![]];
        let mut output = vec![0.0f32; 20];
        let mut mask = vec![0u8; 20];
        let result = engine.step_and_observe(&commands, &mut output, &mut mask);
        assert!(matches!(result, Err(BatchError::NoObsPlan)));

        assert_eq!(engine.world_tick(0), Some(TickId(0)));
        assert_eq!(engine.world_tick(1), Some(TickId(0)));
    }

    #[test]
    fn step_and_observe_small_buffer_does_not_step() {
        let spec = obs_spec_all_field0();
        let configs = vec![make_config(0, 1.0), make_config(1, 1.0)];
        let mut engine = BatchedEngine::new(configs, Some(&spec)).unwrap();

        let commands = vec![vec![], vec![]];
        // Two worlds x 10 cells need 20 output slots; 5 is too small.
        let mut output = vec![0.0f32; 5];
        let mut mask = vec![0u8; 20];
        let result = engine.step_and_observe(&commands, &mut output, &mut mask);
        assert!(matches!(result, Err(BatchError::InvalidArgument { .. })));

        assert_eq!(engine.world_tick(0), Some(TickId(0)));
        assert_eq!(engine.world_tick(1), Some(TickId(0)));
    }

    #[test]
    fn mismatched_field_schemas_rejected() {
        let spec = ObsSpec {
            entries: vec![all_cells_entry(FieldId(0)), all_cells_entry(FieldId(1))],
        };

        let config_two_fields = WorldConfig {
            fields: vec![scalar_field("energy"), scalar_field("temp")],
            ..make_config(1, 1.0)
        };
        let config_one_field = make_config(2, 1.0);

        let result = BatchedEngine::new(vec![config_two_fields, config_one_field], Some(&spec));
        match result {
            Err(e) => {
                let msg = format!("{e}");
                assert!(
                    msg.contains("field") && msg.contains("missing"),
                    "error should mention missing field, got: {msg}"
                );
            }
            Ok(_) => panic!("expected error for mismatched field schemas"),
        }
    }
}