Skip to main content

cougr_core/scheduler/
mod.rs

1mod types;
2
3pub use types::{ScheduleError, ScheduleStage, SystemConfig};
4
5use crate::simple_world::SimpleWorld;
6use crate::system::{AppSystem, SimpleSystem, SystemContext, SystemSpec, WorldSystem};
7use alloc::boxed::Box;
8use alloc::string::{String, ToString};
9use alloc::vec::Vec;
10use soroban_sdk::Env;
11
12struct SimpleSystemEntry {
13    name: String,
14    config: SystemConfig,
15    system: Box<dyn SimpleSystem>,
16}
17
18/// Scheduler for the Soroban-first `SimpleWorld` runtime.
19///
20/// Systems can be grouped into stages and ordered relative to each other using
21/// `SystemConfig::before()` and `SystemConfig::after()`. Each system receives
22/// a deferred `CommandQueue` through `SystemContext`; queued commands are
23/// applied after the system finishes.
24///
25/// # Example
26/// ```
27/// use cougr_core::scheduler::{ScheduleStage, SimpleScheduler, SystemConfig};
28/// use cougr_core::simple_world::SimpleWorld;
29/// use soroban_sdk::{symbol_short, Bytes, Env};
30///
31/// fn physics_system(world: &mut SimpleWorld, env: &Env) {
32///     let entity = world.spawn_entity();
33///     world.add_component(entity, symbol_short!("physics"), Bytes::new(env));
34/// }
35/// fn scoring_system(world: &mut SimpleWorld, env: &Env) {
36///     let entity = world.spawn_entity();
37///     world.add_component(entity, symbol_short!("scoring"), Bytes::new(env));
38/// }
39///
40/// let env = Env::default();
41/// let mut world = SimpleWorld::new(&env);
42/// let mut scheduler = SimpleScheduler::new();
43/// scheduler.add_system("physics", physics_system);
44/// scheduler.add_system_with_config(
45///     "scoring",
46///     scoring_system,
47///     SystemConfig::new()
48///         .in_stage(ScheduleStage::Update)
49///         .after("physics"),
50/// );
51/// scheduler.run_all(&mut world, &env).unwrap();
52/// assert_eq!(scheduler.system_count(), 2);
53/// ```
54pub struct SimpleScheduler {
55    systems: Vec<SimpleSystemEntry>,
56}
57
58/// Group of one or more runtime systems that can be registered together.
59pub trait SystemGroup {
60    fn register(self, scheduler: &mut SimpleScheduler);
61    fn register_in_stage(self, scheduler: &mut SimpleScheduler, stage: ScheduleStage);
62}
63
64impl<S> SystemGroup for SystemSpec<S>
65where
66    S: AppSystem + 'static,
67{
68    fn register(self, scheduler: &mut SimpleScheduler) {
69        scheduler.add_system_spec(self);
70    }
71
72    fn register_in_stage(self, scheduler: &mut SimpleScheduler, stage: ScheduleStage) {
73        scheduler.add_system_spec(self.in_stage(stage));
74    }
75}
76
77impl<A: SystemGroup, B: SystemGroup> SystemGroup for (A, B) {
78    fn register(self, scheduler: &mut SimpleScheduler) {
79        self.0.register(scheduler);
80        self.1.register(scheduler);
81    }
82
83    fn register_in_stage(self, scheduler: &mut SimpleScheduler, stage: ScheduleStage) {
84        self.0.register_in_stage(scheduler, stage);
85        self.1.register_in_stage(scheduler, stage);
86    }
87}
88
89impl<A: SystemGroup, B: SystemGroup, C: SystemGroup> SystemGroup for (A, B, C) {
90    fn register(self, scheduler: &mut SimpleScheduler) {
91        self.0.register(scheduler);
92        self.1.register(scheduler);
93        self.2.register(scheduler);
94    }
95
96    fn register_in_stage(self, scheduler: &mut SimpleScheduler, stage: ScheduleStage) {
97        self.0.register_in_stage(scheduler, stage);
98        self.1.register_in_stage(scheduler, stage);
99        self.2.register_in_stage(scheduler, stage);
100    }
101}
102
103impl SimpleScheduler {
104    pub fn new() -> Self {
105        Self {
106            systems: Vec::new(),
107        }
108    }
109
110    /// Add a world/env system to the `Update` stage.
111    pub fn add_system<F>(&mut self, name: &'static str, system: F)
112    where
113        F: FnMut(&mut SimpleWorld, &Env) + 'static,
114    {
115        self.add_system_with_config(name, system, SystemConfig::default());
116    }
117
118    /// Add a world/env system with explicit scheduling rules.
119    pub fn add_system_with_config<F>(&mut self, name: &'static str, system: F, config: SystemConfig)
120    where
121        F: FnMut(&mut SimpleWorld, &Env) + 'static,
122    {
123        self.push_entry(name, config, Box::new(WorldSystem::new(system)));
124    }
125
126    /// Add a world/env system directly to a specific stage.
127    pub fn add_system_in_stage<F>(&mut self, stage: ScheduleStage, name: &'static str, system: F)
128    where
129        F: FnMut(&mut SimpleWorld, &Env) + 'static,
130    {
131        self.add_system_with_config(name, system, SystemConfig::new().in_stage(stage));
132    }
133
134    /// Add a context-aware system to the `Update` stage.
135    pub fn add_context_system<F>(&mut self, name: &'static str, system: F)
136    where
137        F: for<'w, 'e, 'c> FnMut(&mut SystemContext<'w, 'e, 'c>) + 'static,
138    {
139        self.add_context_system_with_config(name, system, SystemConfig::default());
140    }
141
142    /// Add a context-aware system with explicit scheduling rules.
143    pub fn add_context_system_with_config<F>(
144        &mut self,
145        name: &'static str,
146        system: F,
147        config: SystemConfig,
148    ) where
149        F: for<'w, 'e, 'c> FnMut(&mut SystemContext<'w, 'e, 'c>) + 'static,
150    {
151        self.push_entry(
152            name,
153            config,
154            Box::new(crate::system::ContextSystem::new(system)),
155        );
156    }
157
158    /// Add a context-aware system directly to a specific stage.
159    pub fn add_context_system_in_stage<F>(
160        &mut self,
161        stage: ScheduleStage,
162        name: &'static str,
163        system: F,
164    ) where
165        F: for<'w, 'e, 'c> FnMut(&mut SystemContext<'w, 'e, 'c>) + 'static,
166    {
167        self.add_context_system_with_config(name, system, SystemConfig::new().in_stage(stage));
168    }
169
170    /// Add any pre-built runtime system to the default `Update` stage.
171    pub fn add_simple_system<S>(&mut self, name: &'static str, system: S)
172    where
173        S: AppSystem + 'static,
174    {
175        self.add_simple_system_with_config(name, system, SystemConfig::default());
176    }
177
178    /// Add any pre-built runtime system with explicit scheduling rules.
179    pub fn add_simple_system_with_config<S>(
180        &mut self,
181        name: &'static str,
182        system: S,
183        config: SystemConfig,
184    ) where
185        S: AppSystem + 'static,
186    {
187        self.push_entry(name, config, Box::new(system));
188    }
189
190    /// Add any pre-built runtime system directly to a specific stage.
191    pub fn add_simple_system_in_stage<S>(
192        &mut self,
193        stage: ScheduleStage,
194        name: &'static str,
195        system: S,
196    ) where
197        S: AppSystem + 'static,
198    {
199        self.add_simple_system_with_config(name, system, SystemConfig::new().in_stage(stage));
200    }
201
202    /// Add a declarative runtime system spec.
203    pub fn add_system_spec<S>(&mut self, spec: SystemSpec<S>)
204    where
205        S: AppSystem + 'static,
206    {
207        let (name, config, system) = spec.into_parts();
208        self.push_entry(name, config, Box::new(system));
209    }
210
211    /// Add one or more runtime systems using the modern declarative API.
212    pub fn add_systems<G>(&mut self, group: G)
213    where
214        G: SystemGroup,
215    {
216        group.register(self);
217    }
218
219    /// Add one or more runtime systems while forcing them into a stage.
220    pub fn add_systems_in_stage<G>(&mut self, stage: ScheduleStage, group: G)
221    where
222        G: SystemGroup,
223    {
224        group.register_in_stage(self, stage);
225    }
226
227    /// Update the scheduling config for an already-registered system.
228    pub fn configure_system(
229        &mut self,
230        name: &str,
231        config: SystemConfig,
232    ) -> Result<(), ScheduleError> {
233        for entry in &mut self.systems {
234            if entry.name == name {
235                entry.config = config;
236                return Ok(());
237            }
238        }
239
240        Err(ScheduleError::MissingDependency {
241            system: name.to_string(),
242            dependency: name.to_string(),
243        })
244    }
245
246    fn push_entry(
247        &mut self,
248        name: &'static str,
249        config: SystemConfig,
250        system: Box<dyn SimpleSystem>,
251    ) {
252        self.systems.push(SimpleSystemEntry {
253            name: name.to_string(),
254            config,
255            system,
256        });
257    }
258
259    /// Validate and execute the full schedule.
260    pub fn run_all(&mut self, world: &mut SimpleWorld, env: &Env) -> Result<(), ScheduleError> {
261        for stage in ScheduleStage::ordered() {
262            self.run_stage(stage, world, env)?;
263        }
264        Ok(())
265    }
266
267    /// Validate and execute only a single stage.
268    pub fn run_stage(
269        &mut self,
270        stage: ScheduleStage,
271        world: &mut SimpleWorld,
272        env: &Env,
273    ) -> Result<(), ScheduleError> {
274        let plan = self.execution_plan_for_stage(stage)?;
275        for index in plan {
276            let mut commands = crate::commands::CommandQueue::new();
277            let entry = &mut self.systems[index];
278            let mut context = SystemContext::new(world, env, &mut commands);
279            entry.system.run(&mut context);
280            commands.apply(world);
281        }
282        Ok(())
283    }
284
285    fn execution_plan_for_stage(&self, stage: ScheduleStage) -> Result<Vec<usize>, ScheduleError> {
286        self.validate_unique_names()?;
287        let mut stage_indexes = Vec::new();
288        for index in 0..self.systems.len() {
289            if self.systems[index].config.stage() == stage {
290                stage_indexes.push(index);
291            }
292        }
293
294        self.validate_stage_dependencies(stage, &stage_indexes)?;
295        self.topological_order(stage, &stage_indexes)
296    }
297
298    fn validate_unique_names(&self) -> Result<(), ScheduleError> {
299        for left in 0..self.systems.len() {
300            for right in (left + 1)..self.systems.len() {
301                if self.systems[left].name == self.systems[right].name {
302                    return Err(ScheduleError::DuplicateSystem(
303                        self.systems[left].name.clone(),
304                    ));
305                }
306            }
307        }
308        Ok(())
309    }
310
311    fn validate_stage_dependencies(
312        &self,
313        stage: ScheduleStage,
314        stage_indexes: &[usize],
315    ) -> Result<(), ScheduleError> {
316        for &index in stage_indexes {
317            let entry = &self.systems[index];
318
319            for dependency in entry.config.after_dependencies() {
320                let dependency_index = self.find_system_index(dependency).ok_or_else(|| {
321                    ScheduleError::MissingDependency {
322                        system: entry.name.clone(),
323                        dependency: dependency.clone(),
324                    }
325                })?;
326                let dependency_stage = self.systems[dependency_index].config.stage();
327                if dependency_stage != stage {
328                    return Err(ScheduleError::CrossStageDependency {
329                        system: entry.name.clone(),
330                        dependency: dependency.clone(),
331                        system_stage: stage,
332                        dependency_stage,
333                    });
334                }
335            }
336
337            for dependency in entry.config.before_dependencies() {
338                let dependency_index = self.find_system_index(dependency).ok_or_else(|| {
339                    ScheduleError::MissingDependency {
340                        system: entry.name.clone(),
341                        dependency: dependency.clone(),
342                    }
343                })?;
344                let dependency_stage = self.systems[dependency_index].config.stage();
345                if dependency_stage != stage {
346                    return Err(ScheduleError::CrossStageDependency {
347                        system: entry.name.clone(),
348                        dependency: dependency.clone(),
349                        system_stage: stage,
350                        dependency_stage,
351                    });
352                }
353            }
354
355            for set in entry.config.after_set_dependencies() {
356                if !self.stage_has_set(stage_indexes, set) {
357                    return Err(ScheduleError::MissingSet {
358                        system: entry.name.clone(),
359                        set: set.clone(),
360                    });
361                }
362            }
363
364            for set in entry.config.before_set_dependencies() {
365                if !self.stage_has_set(stage_indexes, set) {
366                    return Err(ScheduleError::MissingSet {
367                        system: entry.name.clone(),
368                        set: set.clone(),
369                    });
370                }
371            }
372        }
373
374        Ok(())
375    }
376
377    fn topological_order(
378        &self,
379        stage: ScheduleStage,
380        stage_indexes: &[usize],
381    ) -> Result<Vec<usize>, ScheduleError> {
382        let mut remaining = Vec::new();
383        let mut indegree = Vec::new();
384        let mut outgoing: Vec<Vec<usize>> = Vec::new();
385
386        for &system_index in stage_indexes {
387            remaining.push(system_index);
388            indegree.push(0usize);
389            outgoing.push(Vec::new());
390        }
391
392        for i in 0..stage_indexes.len() {
393            let source_index = stage_indexes[i];
394            let source = &self.systems[source_index];
395
396            for dependency in source.config.before_dependencies() {
397                let target_local = self.find_stage_local_index(stage_indexes, dependency);
398                if let Some(target_local) = target_local {
399                    outgoing[i].push(target_local);
400                    indegree[target_local] += 1;
401                }
402            }
403
404            for dependency in source.config.after_dependencies() {
405                let dependency_local = self.find_stage_local_index(stage_indexes, dependency);
406                if let Some(dependency_local) = dependency_local {
407                    outgoing[dependency_local].push(i);
408                    indegree[i] += 1;
409                }
410            }
411
412            for set in source.config.before_set_dependencies() {
413                for target_local in self.find_stage_set_members(stage_indexes, set) {
414                    if target_local != i {
415                        outgoing[i].push(target_local);
416                        indegree[target_local] += 1;
417                    }
418                }
419            }
420
421            for set in source.config.after_set_dependencies() {
422                for dependency_local in self.find_stage_set_members(stage_indexes, set) {
423                    if dependency_local != i {
424                        outgoing[dependency_local].push(i);
425                        indegree[i] += 1;
426                    }
427                }
428            }
429        }
430
431        let mut queue = Vec::new();
432        for (i, degree) in indegree.iter().enumerate() {
433            if *degree == 0 {
434                queue.push(i);
435            }
436        }
437
438        let mut ordered = Vec::new();
439        while !queue.is_empty() {
440            let local_index = queue.remove(0);
441            ordered.push(stage_indexes[local_index]);
442
443            for &target_local in outgoing[local_index].iter() {
444                indegree[target_local] -= 1;
445                if indegree[target_local] == 0 {
446                    queue.push(target_local);
447                }
448            }
449        }
450
451        if ordered.len() != stage_indexes.len() {
452            let mut names = Vec::new();
453            for &index in stage_indexes {
454                names.push(self.systems[index].name.clone());
455            }
456            return Err(ScheduleError::DependencyCycle {
457                stage,
458                systems: names,
459            });
460        }
461
462        Ok(ordered)
463    }
464
465    fn find_system_index(&self, name: &str) -> Option<usize> {
466        (0..self.systems.len()).find(|&index| self.systems[index].name == name)
467    }
468
469    fn find_stage_local_index(&self, stage_indexes: &[usize], name: &str) -> Option<usize> {
470        for (local_index, system_index) in stage_indexes.iter().enumerate() {
471            if self.systems[*system_index].name == name {
472                return Some(local_index);
473            }
474        }
475        None
476    }
477
478    fn stage_has_set(&self, stage_indexes: &[usize], set: &str) -> bool {
479        for &system_index in stage_indexes {
480            if self.systems[system_index].config.set_name() == Some(set) {
481                return true;
482            }
483        }
484        false
485    }
486
487    fn find_stage_set_members(&self, stage_indexes: &[usize], set: &str) -> Vec<usize> {
488        let mut members = Vec::new();
489        for (local_index, system_index) in stage_indexes.iter().enumerate() {
490            if self.systems[*system_index].config.set_name() == Some(set) {
491                members.push(local_index);
492            }
493        }
494        members
495    }
496
497    /// Returns the number of registered systems.
498    pub fn system_count(&self) -> usize {
499        self.systems.len()
500    }
501
502    /// Returns the system names in registration order.
503    pub fn system_names(&self) -> Vec<&str> {
504        self.systems
505            .iter()
506            .map(|entry| entry.name.as_str())
507            .collect()
508    }
509
510    /// Returns the system names assigned to a given stage in execution order.
511    pub fn stage_system_names(&self, stage: ScheduleStage) -> Result<Vec<&str>, ScheduleError> {
512        let mut names = Vec::new();
513        for index in self.execution_plan_for_stage(stage)? {
514            names.push(self.systems[index].name.as_str());
515        }
516        Ok(names)
517    }
518}
519
520impl Default for SimpleScheduler {
521    fn default() -> Self {
522        Self::new()
523    }
524}
525
#[cfg(test)]
mod tests {
    use super::*;
    use crate::system::{named_context_system, named_system};
    use soroban_sdk::{symbol_short, Bytes, Env};

    #[test]
    fn test_simple_scheduler_empty() {
        let mut scheduler = SimpleScheduler::new();
        let env = Env::default();
        let mut world = SimpleWorld::new(&env);
        scheduler.run_all(&mut world, &env).unwrap();
        assert_eq!(scheduler.system_count(), 0);
    }

    // Spawns an entity tagged with `sys_a` = [0xAA].
    fn test_system_a(world: &mut SimpleWorld, env: &Env) {
        let e1 = world.spawn_entity();
        let data = Bytes::from_array(env, &[0xAA]);
        world.add_component(e1, symbol_short!("sys_a"), data);
    }

    // Spawns an entity tagged with `sys_b` = [0xBB].
    fn test_system_b(world: &mut SimpleWorld, env: &Env) {
        let e2 = world.spawn_entity();
        let data = Bytes::from_array(env, &[0xBB]);
        world.add_component(e2, symbol_short!("sys_b"), data);
    }

    #[test]
    fn test_simple_scheduler_execution_order() {
        let mut scheduler = SimpleScheduler::new();
        scheduler.add_system("system_a", test_system_a);
        scheduler.add_system("system_b", test_system_b);
        assert_eq!(scheduler.system_count(), 2);

        let env = Env::default();
        let mut world = SimpleWorld::new(&env);
        scheduler.run_all(&mut world, &env).unwrap();

        assert!(world.has_component(1, &symbol_short!("sys_a")));
        assert!(world.has_component(2, &symbol_short!("sys_b")));
    }

    #[test]
    fn test_simple_scheduler_names() {
        let mut scheduler = SimpleScheduler::new();
        scheduler.add_system("physics", test_system_a);
        scheduler.add_system("scoring", test_system_b);

        let names = scheduler.system_names();
        assert_eq!(names, alloc::vec!["physics", "scoring"]);
    }

    #[test]
    fn test_stage_ordering_and_dependencies() {
        let mut scheduler = SimpleScheduler::new();
        scheduler.add_system_with_config(
            "cleanup",
            test_system_b,
            SystemConfig::new().in_stage(ScheduleStage::Cleanup),
        );
        scheduler.add_system_with_config(
            "physics",
            test_system_a,
            SystemConfig::new()
                .in_stage(ScheduleStage::Update)
                .before("scoring"),
        );
        scheduler.add_system_with_config(
            "scoring",
            test_system_b,
            SystemConfig::new()
                .in_stage(ScheduleStage::Update)
                .after("physics"),
        );

        assert_eq!(
            scheduler.stage_system_names(ScheduleStage::Update).unwrap(),
            alloc::vec!["physics", "scoring"]
        );
        assert_eq!(
            scheduler
                .stage_system_names(ScheduleStage::Cleanup)
                .unwrap(),
            alloc::vec!["cleanup"]
        );
    }

    #[test]
    fn test_context_system_applies_deferred_commands() {
        let env = Env::default();
        let mut world = SimpleWorld::new(&env);
        let mut scheduler = SimpleScheduler::new();
        scheduler.add_context_system("spawn_marker", |context| {
            context.commands().spawn();
        });

        scheduler.run_all(&mut world, &env).unwrap();
        assert_eq!(world.next_entity_id, 2);
    }

    #[test]
    fn test_detects_cross_stage_dependency() {
        let mut scheduler = SimpleScheduler::new();
        scheduler.add_system_with_config(
            "startup",
            test_system_a,
            SystemConfig::new().in_stage(ScheduleStage::Startup),
        );
        scheduler.add_system_with_config(
            "update",
            test_system_b,
            SystemConfig::new()
                .in_stage(ScheduleStage::Update)
                .after("startup"),
        );

        let err = scheduler
            .stage_system_names(ScheduleStage::Update)
            .unwrap_err();
        assert!(matches!(err, ScheduleError::CrossStageDependency { .. }));
    }

    #[test]
    fn test_set_ordering() {
        let mut scheduler = SimpleScheduler::new();
        scheduler.add_system_with_config(
            "apply_input",
            test_system_a,
            SystemConfig::new()
                .in_stage(ScheduleStage::Update)
                .in_set("input"),
        );
        scheduler.add_system_with_config(
            "physics",
            test_system_b,
            SystemConfig::new()
                .in_stage(ScheduleStage::Update)
                .in_set("simulation")
                .after_set("input"),
        );

        assert_eq!(
            scheduler.stage_system_names(ScheduleStage::Update).unwrap(),
            alloc::vec!["apply_input", "physics"]
        );
    }

    #[test]
    fn test_missing_set_dependency() {
        let mut scheduler = SimpleScheduler::new();
        scheduler.add_system_with_config(
            "physics",
            test_system_a,
            SystemConfig::new()
                .in_stage(ScheduleStage::Update)
                .after_set("missing"),
        );

        let err = scheduler
            .stage_system_names(ScheduleStage::Update)
            .unwrap_err();
        assert!(matches!(err, ScheduleError::MissingSet { .. }));
    }

    #[test]
    fn test_grouped_system_registration() {
        let env = Env::default();
        let mut world = SimpleWorld::new(&env);
        let mut scheduler = SimpleScheduler::new();
        scheduler.add_systems((
            named_system("spawn_a", test_system_a).in_stage(ScheduleStage::Startup),
            named_system("spawn_b", test_system_b).in_stage(ScheduleStage::Update),
        ));

        assert_eq!(
            scheduler
                .stage_system_names(ScheduleStage::Startup)
                .unwrap(),
            alloc::vec!["spawn_a"]
        );
        assert_eq!(
            scheduler.stage_system_names(ScheduleStage::Update).unwrap(),
            alloc::vec!["spawn_b"]
        );

        scheduler.run_all(&mut world, &env).unwrap();
        assert!(world.has_component(1, &symbol_short!("sys_a")));
        assert!(world.has_component(2, &symbol_short!("sys_b")));
    }

    #[test]
    fn test_simple_scheduler_execution_order_with_context_and_world_systems() {
        let env = Env::default();
        let mut world = SimpleWorld::new(&env);
        let mut scheduler = SimpleScheduler::new();

        scheduler.add_systems((
            named_system("spawn_entity", |world: &mut SimpleWorld, env: &Env| {
                let entity = world.spawn_entity();
                world.add_component(entity, symbol_short!("seed"), Bytes::from_array(env, &[1]));
            })
            .in_stage(ScheduleStage::Update),
            named_context_system("mark_seed", |context| {
                let entities = context
                    .world()
                    .get_entities_with_component(&symbol_short!("seed"), context.env());
                let env = context.env().clone();
                for i in 0..entities.len() {
                    let entity = entities.get(i).unwrap();
                    context.commands().add_component(
                        entity,
                        symbol_short!("grown"),
                        Bytes::from_array(&env, &[1]),
                    );
                }
            })
            .in_stage(ScheduleStage::Update)
            .after("spawn_entity"),
        ));

        scheduler.run_all(&mut world, &env).unwrap();
        assert!(world.has_component(1, &symbol_short!("grown")));
    }

    // Two systems that each insist on running after the other cannot be ordered.
    #[test]
    fn test_detects_dependency_cycle() {
        let mut scheduler = SimpleScheduler::new();
        scheduler.add_system_with_config(
            "first",
            test_system_a,
            SystemConfig::new()
                .in_stage(ScheduleStage::Update)
                .after("second"),
        );
        scheduler.add_system_with_config(
            "second",
            test_system_b,
            SystemConfig::new()
                .in_stage(ScheduleStage::Update)
                .after("first"),
        );

        let err = scheduler
            .stage_system_names(ScheduleStage::Update)
            .unwrap_err();
        assert!(matches!(err, ScheduleError::DependencyCycle { .. }));
    }

    // Duplicate names are accepted at registration but rejected when planning.
    #[test]
    fn test_detects_duplicate_system_names() {
        let mut scheduler = SimpleScheduler::new();
        scheduler.add_system("dup", test_system_a);
        scheduler.add_system("dup", test_system_b);

        let err = scheduler
            .stage_system_names(ScheduleStage::Update)
            .unwrap_err();
        assert!(matches!(err, ScheduleError::DuplicateSystem(ref name) if name == "dup"));
    }

    #[test]
    fn test_detects_missing_dependency() {
        let mut scheduler = SimpleScheduler::new();
        scheduler.add_system_with_config(
            "physics",
            test_system_a,
            SystemConfig::new()
                .in_stage(ScheduleStage::Update)
                .after("ghost"),
        );

        let err = scheduler
            .stage_system_names(ScheduleStage::Update)
            .unwrap_err();
        assert!(matches!(err, ScheduleError::MissingDependency { .. }));
    }

    #[test]
    fn test_configure_system_moves_stage_and_rejects_unknown_names() {
        let mut scheduler = SimpleScheduler::new();
        scheduler.add_system("physics", test_system_a);
        scheduler
            .configure_system(
                "physics",
                SystemConfig::new().in_stage(ScheduleStage::Cleanup),
            )
            .unwrap();

        assert_eq!(
            scheduler
                .stage_system_names(ScheduleStage::Cleanup)
                .unwrap(),
            alloc::vec!["physics"]
        );
        assert!(scheduler
            .stage_system_names(ScheduleStage::Update)
            .unwrap()
            .is_empty());
        assert!(scheduler
            .configure_system("missing", SystemConfig::new())
            .is_err());
    }
}