Skip to main content

cu29_runtime/
curuntime.rs

//! CuRuntime is the heart of what Copper runs on the robot.
2//! It is exposed to the user via the `copper_runtime` macro injecting it as a field in their application struct.
3//!
4
5use crate::config::{ComponentConfig, CuDirection, DEFAULT_KEYFRAME_INTERVAL, Node};
6use crate::config::{CuConfig, CuGraph, NodeId, RuntimeConfig};
7use crate::copperlist::{CopperList, CopperListState, CuListZeroedInit, CuListsManager};
8use crate::cutask::{BincodeAdapter, Freezable};
9use crate::monitoring::{CuMonitor, build_monitor_topology};
10use crate::resource::ResourceManager;
11use cu29_clock::{ClockProvider, CuTime, RobotClock};
12use cu29_traits::CuResult;
13use cu29_traits::WriteStream;
14use cu29_traits::{CopperListTuple, CuError};
15
16#[cfg(target_os = "none")]
17#[allow(unused_imports)]
18use cu29_log::{ANONYMOUS, CuLogEntry, CuLogLevel};
19#[cfg(target_os = "none")]
20#[allow(unused_imports)]
21use cu29_log_derive::info;
22#[cfg(target_os = "none")]
23#[allow(unused_imports)]
24use cu29_log_runtime::log;
25#[cfg(all(target_os = "none", debug_assertions))]
26#[allow(unused_imports)]
27use cu29_log_runtime::log_debug_mode;
28#[cfg(target_os = "none")]
29#[allow(unused_imports)]
30use cu29_value::to_value;
31
32use alloc::boxed::Box;
33use alloc::collections::{BTreeSet, VecDeque};
34use alloc::format;
35use alloc::string::{String, ToString};
36use alloc::vec::Vec;
37use bincode::enc::EncoderImpl;
38use bincode::enc::write::{SizeWriter, SliceWriter};
39use bincode::error::EncodeError;
40use bincode::{Decode, Encode};
41use core::fmt::Result as FmtResult;
42use core::fmt::{Debug, Formatter};
43
44#[cfg(feature = "std")]
45use cu29_log_runtime::LoggerRuntime;
46#[cfg(feature = "std")]
47use cu29_unifiedlog::UnifiedLoggerWrite;
48#[cfg(feature = "std")]
49use std::sync::{Arc, Mutex};
50
/// Just a simple struct to hold the various bits needed to run a Copper application.
#[cfg(feature = "std")]
pub struct CopperContext {
    /// Shared, lock-protected handle to the unified logger output.
    pub unified_logger: Arc<Mutex<UnifiedLoggerWrite>>,
    /// The structured-logging runtime instance owned by the application.
    pub logger_runtime: LoggerRuntime,
    /// The robot clock used to timestamp everything in this application.
    pub clock: RobotClock,
}
58
/// Manages the lifecycle of the copper lists and logging.
pub struct CopperListsManager<P: CopperListTuple + Default, const NBCL: usize> {
    /// Pool of the copper lists currently in flight (capacity is NBCL).
    pub inner: CuListsManager<P, NBCL>,
    /// Logger for the copper lists (messages between tasks); `None` disables serialization.
    pub logger: Option<Box<dyn WriteStream<CopperList<P>>>>,
    /// Last encoded size returned by logger.log
    pub last_encoded_bytes: u64,
}
67
impl<P: CopperListTuple + Default, const NBCL: usize> CopperListsManager<P, NBCL> {
    /// Called when the copper list `culistid` finished its wave of execution.
    ///
    /// Marks that list `DoneProcessing`, then frees (and logs, if a logger is
    /// attached) the leading run of `DoneProcessing` lists and pops them off
    /// the manager.
    ///
    /// NOTE(review): assumes `inner.iter_mut()` yields lists oldest-first and
    /// `inner.pop()` removes from that same end — confirm against CuListsManager.
    pub fn end_of_processing(&mut self, culistid: u32) -> CuResult<()> {
        // `is_top` stays true only while we are still inside the leading run of
        // done lists; the first not-done list flips it off for the rest.
        let mut is_top = true;
        let mut nb_done = 0;
        for cl in self.inner.iter_mut() {
            if cl.id == culistid && cl.get_state() == CopperListState::Processing {
                cl.change_state(CopperListState::DoneProcessing);
            }
            if is_top && cl.get_state() == CopperListState::DoneProcessing {
                if let Some(logger) = &mut self.logger {
                    cl.change_state(CopperListState::BeingSerialized);
                    logger.log(cl)?;
                    self.last_encoded_bytes = logger.last_log_bytes().unwrap_or(0) as u64;
                }
                cl.change_state(CopperListState::Free);
                nb_done += 1;
            } else {
                is_top = false;
            }
        }
        // Release every list we just freed from the manager.
        for _ in 0..nb_done {
            let _ = self.inner.pop();
        }
        Ok(())
    }

    /// Number of copper lists still available, i.e. capacity NBCL minus in-flight lists.
    pub fn available_copper_lists(&self) -> usize {
        NBCL - self.inner.len()
    }
}
98
/// Manages the frozen tasks state and logging.
pub struct KeyFramesManager {
    /// Where the serialized tasks are stored following the wave of execution of a CL.
    inner: KeyFrame,

    /// Optional override for the timestamp to stamp the next keyframe (used by deterministic replay).
    forced_timestamp: Option<CuTime>,

    /// If set, reuse the current keyframe verbatim (e.g., during replay) instead of re-freezing state.
    locked: bool,

    /// Logger for the state of the tasks (frozen tasks); `None` disables keyframe capture entirely.
    logger: Option<Box<dyn WriteStream<KeyFrame>>>,

    /// Capture a keyframe only every `keyframe_interval` copper lists.
    keyframe_interval: u32,

    /// Bytes written by the last keyframe log (reset to 0 on non-keyframe CLs).
    pub last_encoded_bytes: u64,
}
119
120impl KeyFramesManager {
121    fn is_keyframe(&self, culistid: u32) -> bool {
122        self.logger.is_some() && culistid.is_multiple_of(self.keyframe_interval)
123    }
124
125    pub fn reset(&mut self, culistid: u32, clock: &RobotClock) {
126        if self.is_keyframe(culistid) {
127            // If a recorded keyframe was preloaded for this CL, keep it as-is.
128            if self.locked && self.inner.culistid == culistid {
129                return;
130            }
131            let ts = self.forced_timestamp.take().unwrap_or_else(|| clock.now());
132            self.inner.reset(culistid, ts);
133            self.locked = false;
134        }
135    }
136
137    /// Force the timestamp of the next keyframe to a given value.
138    #[cfg(feature = "std")]
139    pub fn set_forced_timestamp(&mut self, ts: CuTime) {
140        self.forced_timestamp = Some(ts);
141    }
142
143    pub fn freeze_task(&mut self, culistid: u32, task: &impl Freezable) -> CuResult<usize> {
144        if self.is_keyframe(culistid) {
145            if self.locked {
146                // We are replaying a recorded keyframe verbatim; don't mutate it.
147                return Ok(0);
148            }
149            if self.inner.culistid != culistid {
150                return Err(CuError::from(format!(
151                    "Freezing task for culistid {} but current keyframe is {}",
152                    culistid, self.inner.culistid
153                )));
154            }
155            self.inner
156                .add_frozen_task(task)
157                .map_err(|e| CuError::from(format!("Failed to serialize task: {e}")))
158        } else {
159            Ok(0)
160        }
161    }
162
163    /// Generic helper to freeze any `Freezable` state (task or bridge) into the current keyframe.
164    pub fn freeze_any(&mut self, culistid: u32, item: &impl Freezable) -> CuResult<usize> {
165        self.freeze_task(culistid, item)
166    }
167
168    pub fn end_of_processing(&mut self, culistid: u32) -> CuResult<()> {
169        if self.is_keyframe(culistid) {
170            let logger = self.logger.as_mut().unwrap();
171            logger.log(&self.inner)?;
172            self.last_encoded_bytes = logger.last_log_bytes().unwrap_or(0) as u64;
173            // Clear the lock so the next CL can rebuild normally unless re-locked.
174            self.locked = false;
175            Ok(())
176        } else {
177            // Not a keyframe for this CL; ensure we don't carry stale sizes forward.
178            self.last_encoded_bytes = 0;
179            Ok(())
180        }
181    }
182
183    /// Preload a recorded keyframe so it is logged verbatim on the matching CL.
184    #[cfg(feature = "std")]
185    pub fn lock_keyframe(&mut self, keyframe: &KeyFrame) {
186        self.inner = keyframe.clone();
187        self.forced_timestamp = Some(keyframe.timestamp);
188        self.locked = true;
189    }
190}
191
/// This is the main structure that will be injected as a member of the Application struct.
/// CT is the tuple of all the tasks in order of execution.
/// CB is the tuple of all the instantiated bridges.
/// P is the payload type of the copper list, representing the input/output messages for all the tasks.
/// NBCL is the number of copper lists that can be in flight at the same time.
pub struct CuRuntime<CT, CB, P: CopperListTuple, M: CuMonitor, const NBCL: usize> {
    /// The base clock the runtime will be using to record time.
    pub clock: RobotClock, // TODO: remove public at some point

    /// The tuple of all the tasks in order of execution.
    pub tasks: CT,

    /// Tuple of all instantiated bridges.
    pub bridges: CB,

    /// Resource registry kept alive for tasks borrowing shared handles.
    pub resources: ResourceManager,

    /// The runtime monitoring.
    pub monitor: M,

    /// The logger for the copper lists (messages between tasks)
    pub copperlists_manager: CopperListsManager<P, NBCL>,

    /// The logger for the state of the tasks (frozen tasks)
    pub keyframes_manager: KeyFramesManager,

    /// The runtime configuration controlling the behavior of the run loop
    pub runtime_config: RuntimeConfig,
}
220
221/// To be able to share the clock we make the runtime a clock provider.
222impl<CT, CB, P: CopperListTuple + CuListZeroedInit + Default, M: CuMonitor, const NBCL: usize>
223    ClockProvider for CuRuntime<CT, CB, P, M, NBCL>
224{
225    fn get_clock(&self) -> RobotClock {
226        self.clock.clone()
227    }
228}
229
/// A KeyFrame is recording a snapshot of the tasks state before a given copperlist.
/// It is a double encapsulation: this one recording the culistid and another even in
/// bincode in the serialized_tasks.
#[derive(Clone, Encode, Decode)]
pub struct KeyFrame {
    /// The id of the copper list that this keyframe is associated with (recorded before the copperlist).
    pub culistid: u32,
    /// The timestamp when the keyframe was created, using the robot clock.
    pub timestamp: CuTime,
    /// The bincode representation of the tuple of all the tasks, appended task by
    /// task as they are executed (see `add_frozen_task`).
    pub serialized_tasks: Vec<u8>,
}
242
243impl KeyFrame {
244    fn new() -> Self {
245        KeyFrame {
246            culistid: 0,
247            timestamp: CuTime::default(),
248            serialized_tasks: Vec::new(),
249        }
250    }
251
252    /// This is to be able to avoid reallocations
253    fn reset(&mut self, culistid: u32, timestamp: CuTime) {
254        self.culistid = culistid;
255        self.timestamp = timestamp;
256        self.serialized_tasks.clear();
257    }
258
259    /// We need to be able to accumulate tasks to the serialization as they are executed after the step.
260    fn add_frozen_task(&mut self, task: &impl Freezable) -> Result<usize, EncodeError> {
261        let cfg = bincode::config::standard();
262        let mut sizer = EncoderImpl::<_, _>::new(SizeWriter::default(), cfg);
263        BincodeAdapter(task).encode(&mut sizer)?;
264        let need = sizer.into_writer().bytes_written as usize;
265
266        let start = self.serialized_tasks.len();
267        self.serialized_tasks.resize(start + need, 0);
268        let mut enc =
269            EncoderImpl::<_, _>::new(SliceWriter::new(&mut self.serialized_tasks[start..]), cfg);
270        BincodeAdapter(task).encode(&mut enc)?;
271        Ok(need)
272    }
273}
274
impl<
    CT,
    CB,
    P: CopperListTuple + CuListZeroedInit + Default + 'static,
    M: CuMonitor,
    const NBCL: usize,
> CuRuntime<CT, CB, P, M, NBCL>
{
    // FIXME(gbin): this became REALLY ugly with no-std
    /// Builds a runtime (std variant): instantiates the resources via
    /// `resources_instanciator`, then delegates to `new_with_resources`.
    #[allow(clippy::too_many_arguments)]
    #[cfg(feature = "std")]
    pub fn new(
        clock: RobotClock,
        config: &CuConfig,
        mission: Option<&str>,
        resources_instanciator: impl Fn(&CuConfig) -> CuResult<ResourceManager>,
        tasks_instanciator: impl for<'c> Fn(
            Vec<Option<&'c ComponentConfig>>,
            &mut ResourceManager,
        ) -> CuResult<CT>,
        monitor_instanciator: impl Fn(&CuConfig) -> M,
        bridges_instanciator: impl Fn(&CuConfig, &mut ResourceManager) -> CuResult<CB>,
        copperlists_logger: impl WriteStream<CopperList<P>> + 'static,
        keyframes_logger: impl WriteStream<KeyFrame> + 'static,
    ) -> CuResult<Self> {
        let resources = resources_instanciator(config)?;
        Self::new_with_resources(
            clock,
            config,
            mission,
            resources,
            tasks_instanciator,
            monitor_instanciator,
            bridges_instanciator,
            copperlists_logger,
            keyframes_logger,
        )
    }

    /// Builds a runtime from pre-instantiated resources (std variant).
    ///
    /// Instantiates tasks, monitor (with its topology when it can be built)
    /// and bridges from `config`, then wires the copperlists/keyframes loggers
    /// according to the config's logging section.
    #[allow(clippy::too_many_arguments)]
    #[cfg(feature = "std")]
    pub fn new_with_resources(
        clock: RobotClock,
        config: &CuConfig,
        mission: Option<&str>,
        mut resources: ResourceManager,
        tasks_instanciator: impl for<'c> Fn(
            Vec<Option<&'c ComponentConfig>>,
            &mut ResourceManager,
        ) -> CuResult<CT>,
        monitor_instanciator: impl Fn(&CuConfig) -> M,
        bridges_instanciator: impl Fn(&CuConfig, &mut ResourceManager) -> CuResult<CB>,
        copperlists_logger: impl WriteStream<CopperList<P>> + 'static,
        keyframes_logger: impl WriteStream<KeyFrame> + 'static,
    ) -> CuResult<Self> {
        let graph = config.get_graph(mission)?;
        // One Option<&ComponentConfig> per node, in graph node order.
        let all_instances_configs: Vec<Option<&ComponentConfig>> = graph
            .get_all_nodes()
            .iter()
            .map(|(_, node)| node.get_instance_config())
            .collect();

        let tasks = tasks_instanciator(all_instances_configs, &mut resources)?;
        let mut monitor = monitor_instanciator(config);
        // Topology is best-effort: a failure to build it is not fatal.
        if let Ok(topology) = build_monitor_topology(config, mission) {
            monitor.set_topology(topology);
        }
        let bridges = bridges_instanciator(config, &mut resources)?;

        // Three cases: logging enabled (explicit), explicitly disabled, or
        // absent from the config (defaults to enabled).
        let (copperlists_logger, keyframes_logger, keyframe_interval) = match &config.logging {
            Some(logging_config) if logging_config.enable_task_logging => (
                Some(Box::new(copperlists_logger) as Box<dyn WriteStream<CopperList<P>>>),
                Some(Box::new(keyframes_logger) as Box<dyn WriteStream<KeyFrame>>),
                logging_config.keyframe_interval.unwrap(), // it is set to a default at parsing time
            ),
            Some(_) => (None, None, 0), // explicit no enable logging
            None => (
                // default
                Some(Box::new(copperlists_logger) as Box<dyn WriteStream<CopperList<P>>>),
                Some(Box::new(keyframes_logger) as Box<dyn WriteStream<KeyFrame>>),
                DEFAULT_KEYFRAME_INTERVAL,
            ),
        };

        let copperlists_manager = CopperListsManager {
            inner: CuListsManager::new(),
            logger: copperlists_logger,
            last_encoded_bytes: 0,
        };
        #[cfg(target_os = "none")]
        {
            // On bare metal, report the static memory footprint of the CL pool.
            let cl_size = core::mem::size_of::<CopperList<P>>();
            let total_bytes = cl_size.saturating_mul(NBCL);
            info!(
                "CuRuntime::new: copperlists count={} cl_size={} total_bytes={}",
                NBCL, cl_size, total_bytes
            );
        }

        let keyframes_manager = KeyFramesManager {
            inner: KeyFrame::new(),
            logger: keyframes_logger,
            keyframe_interval,
            last_encoded_bytes: 0,
            forced_timestamp: None,
            locked: false,
        };

        let runtime_config = config.runtime.clone().unwrap_or_default();

        let runtime = Self {
            tasks,
            bridges,
            resources,
            monitor,
            clock,
            copperlists_manager,
            keyframes_manager,
            runtime_config,
        };

        Ok(runtime)
    }

    /// Builds a runtime (no-std variant): same as the std `new` but with
    /// progress logging for bare-metal targets.
    #[allow(clippy::too_many_arguments)]
    #[cfg(not(feature = "std"))]
    pub fn new(
        clock: RobotClock,
        config: &CuConfig,
        mission: Option<&str>,
        resources_instanciator: impl Fn(&CuConfig) -> CuResult<ResourceManager>,
        tasks_instanciator: impl for<'c> Fn(
            Vec<Option<&'c ComponentConfig>>,
            &mut ResourceManager,
        ) -> CuResult<CT>,
        monitor_instanciator: impl Fn(&CuConfig) -> M,
        bridges_instanciator: impl Fn(&CuConfig, &mut ResourceManager) -> CuResult<CB>,
        copperlists_logger: impl WriteStream<CopperList<P>> + 'static,
        keyframes_logger: impl WriteStream<KeyFrame> + 'static,
    ) -> CuResult<Self> {
        #[cfg(target_os = "none")]
        info!("CuRuntime::new: resources instanciator");
        let resources = resources_instanciator(config)?;
        Self::new_with_resources(
            clock,
            config,
            mission,
            resources,
            tasks_instanciator,
            monitor_instanciator,
            bridges_instanciator,
            copperlists_logger,
            keyframes_logger,
        )
    }

    /// Builds a runtime from pre-instantiated resources (no-std variant).
    /// Mirrors the std `new_with_resources`, with step-by-step `info!`
    /// breadcrumbs on bare-metal targets to help diagnose early boot failures.
    #[allow(clippy::too_many_arguments)]
    #[cfg(not(feature = "std"))]
    pub fn new_with_resources(
        clock: RobotClock,
        config: &CuConfig,
        mission: Option<&str>,
        mut resources: ResourceManager,
        tasks_instanciator: impl for<'c> Fn(
            Vec<Option<&'c ComponentConfig>>,
            &mut ResourceManager,
        ) -> CuResult<CT>,
        monitor_instanciator: impl Fn(&CuConfig) -> M,
        bridges_instanciator: impl Fn(&CuConfig, &mut ResourceManager) -> CuResult<CB>,
        copperlists_logger: impl WriteStream<CopperList<P>> + 'static,
        keyframes_logger: impl WriteStream<KeyFrame> + 'static,
    ) -> CuResult<Self> {
        #[cfg(target_os = "none")]
        info!("CuRuntime::new: get graph");
        let graph = config.get_graph(mission)?;
        #[cfg(target_os = "none")]
        info!("CuRuntime::new: graph ok");
        // One Option<&ComponentConfig> per node, in graph node order.
        let all_instances_configs: Vec<Option<&ComponentConfig>> = graph
            .get_all_nodes()
            .iter()
            .map(|(_, node)| node.get_instance_config())
            .collect();

        #[cfg(target_os = "none")]
        info!("CuRuntime::new: tasks instanciator");
        let tasks = tasks_instanciator(all_instances_configs, &mut resources)?;

        #[cfg(target_os = "none")]
        info!("CuRuntime::new: monitor instanciator");
        let mut monitor = monitor_instanciator(config);
        #[cfg(target_os = "none")]
        info!("CuRuntime::new: monitor instanciator ok");
        #[cfg(target_os = "none")]
        info!("CuRuntime::new: build monitor topology");
        // Topology is best-effort: a failure to build it is not fatal.
        if let Ok(topology) = build_monitor_topology(config, mission) {
            #[cfg(target_os = "none")]
            info!("CuRuntime::new: monitor topology ok");
            monitor.set_topology(topology);
            #[cfg(target_os = "none")]
            info!("CuRuntime::new: monitor topology set");
        }
        #[cfg(target_os = "none")]
        info!("CuRuntime::new: bridges instanciator");
        let bridges = bridges_instanciator(config, &mut resources)?;

        // Three cases: logging enabled (explicit), explicitly disabled, or
        // absent from the config (defaults to enabled).
        let (copperlists_logger, keyframes_logger, keyframe_interval) = match &config.logging {
            Some(logging_config) if logging_config.enable_task_logging => (
                Some(Box::new(copperlists_logger) as Box<dyn WriteStream<CopperList<P>>>),
                Some(Box::new(keyframes_logger) as Box<dyn WriteStream<KeyFrame>>),
                logging_config.keyframe_interval.unwrap(), // it is set to a default at parsing time
            ),
            Some(_) => (None, None, 0), // explicit no enable logging
            None => (
                // default
                Some(Box::new(copperlists_logger) as Box<dyn WriteStream<CopperList<P>>>),
                Some(Box::new(keyframes_logger) as Box<dyn WriteStream<KeyFrame>>),
                DEFAULT_KEYFRAME_INTERVAL,
            ),
        };

        let copperlists_manager = CopperListsManager {
            inner: CuListsManager::new(),
            logger: copperlists_logger,
            last_encoded_bytes: 0,
        };
        #[cfg(target_os = "none")]
        {
            // On bare metal, report the static memory footprint of the CL pool.
            let cl_size = core::mem::size_of::<CopperList<P>>();
            let total_bytes = cl_size.saturating_mul(NBCL);
            info!(
                "CuRuntime::new: copperlists count={} cl_size={} total_bytes={}",
                NBCL, cl_size, total_bytes
            );
        }

        let keyframes_manager = KeyFramesManager {
            inner: KeyFrame::new(),
            logger: keyframes_logger,
            keyframe_interval,
            last_encoded_bytes: 0,
            forced_timestamp: None,
            locked: false,
        };

        let runtime_config = config.runtime.clone().unwrap_or_default();

        let runtime = Self {
            tasks,
            bridges,
            resources,
            monitor,
            clock,
            copperlists_manager,
            keyframes_manager,
            runtime_config,
        };

        Ok(runtime)
    }
}
535
/// Copper tasks can be of 3 types:
/// - Source: only producing output messages (usually used for drivers)
/// - Regular: processing input messages and producing output messages, more like compute nodes.
/// - Sink: only consuming input messages (usually used for actuators)
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum CuTaskType {
    /// No incoming connections in the graph; only produces output.
    Source,
    /// Has both incoming and outgoing connections.
    Regular,
    /// No outgoing connections in the graph; only consumes input.
    Sink,
}
546
/// Describes where a task's output lives in the copper list and which message types it emits.
#[derive(Debug, Clone)]
pub struct CuOutputPack {
    /// Index of the output slot inside the copper list.
    pub culist_index: u32,
    /// Message type names produced by this task; a consumer's `src_port` is a
    /// position into this vector.
    pub msg_types: Vec<String>,
}
552
/// Describes one input connection of a task within the execution plan.
#[derive(Debug, Clone)]
pub struct CuInputMsg {
    /// Index in the copper list where the producing task stores its output.
    pub culist_index: u32,
    /// Message type name carried by this connection.
    pub msg_type: String,
    /// Position of `msg_type` within the producer's output pack.
    pub src_port: usize,
    /// Index of the config graph edge this input corresponds to (inputs are sorted by it).
    pub edge_id: usize,
}
560
/// This structure represents a step in the execution plan.
pub struct CuExecutionStep {
    /// NodeId: node id of the task to execute
    pub node_id: NodeId,
    /// Node: the config node instance for this task
    pub node: Node,
    /// CuTaskType: type of the task (Source / Regular / Sink)
    pub task_type: CuTaskType,

    /// the indices in the copper list of the input messages and their types
    pub input_msg_indices_types: Vec<CuInputMsg>,

    /// the index in the copper list of the output message and its type
    pub output_msg_pack: Option<CuOutputPack>,
}
576
577impl Debug for CuExecutionStep {
578    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
579        f.write_str(format!("   CuExecutionStep: Node Id: {}\n", self.node_id).as_str())?;
580        f.write_str(format!("                  task_type: {:?}\n", self.node.get_type()).as_str())?;
581        f.write_str(format!("                       task: {:?}\n", self.task_type).as_str())?;
582        f.write_str(
583            format!(
584                "              input_msg_types: {:?}\n",
585                self.input_msg_indices_types
586            )
587            .as_str(),
588        )?;
589        f.write_str(format!("       output_msg_pack: {:?}\n", self.output_msg_pack).as_str())?;
590        Ok(())
591    }
592}
593
/// This structure represents a loop in the execution plan.
/// It is used to represent a sequence of Execution units (loop or steps) that are executed
/// multiple times.
/// if loop_count is None, the loop is infinite.
pub struct CuExecutionLoop {
    /// The units (steps or nested loops) executed on each iteration.
    pub steps: Vec<CuExecutionUnit>,
    /// Number of iterations; `None` means the loop is infinite.
    pub loop_count: Option<u32>,
}
602
603impl Debug for CuExecutionLoop {
604    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
605        f.write_str("CuExecutionLoop:\n")?;
606        for step in &self.steps {
607            match step {
608                CuExecutionUnit::Step(step) => {
609                    step.fmt(f)?;
610                }
611                CuExecutionUnit::Loop(l) => {
612                    l.fmt(f)?;
613                }
614            }
615        }
616
617        f.write_str(format!("   count: {:?}", self.loop_count).as_str())?;
618        Ok(())
619    }
620}
621
/// One unit of the execution plan: either a single step or a nested loop.
#[derive(Debug)]
pub enum CuExecutionUnit {
    Step(CuExecutionStep),
    Loop(CuExecutionLoop),
}
628
629fn find_output_pack_from_nodeid(
630    node_id: NodeId,
631    steps: &Vec<CuExecutionUnit>,
632) -> Option<CuOutputPack> {
633    for step in steps {
634        match step {
635            CuExecutionUnit::Loop(loop_unit) => {
636                if let Some(output_pack) = find_output_pack_from_nodeid(node_id, &loop_unit.steps) {
637                    return Some(output_pack);
638                }
639            }
640            CuExecutionUnit::Step(step) => {
641                if step.node_id == node_id {
642                    return step.output_msg_pack.clone();
643                }
644            }
645        }
646    }
647    None
648}
649
650pub fn find_task_type_for_id(graph: &CuGraph, node_id: NodeId) -> CuTaskType {
651    if graph.incoming_neighbor_count(node_id) == 0 {
652        CuTaskType::Source
653    } else if graph.outgoing_neighbor_count(node_id) == 0 {
654        CuTaskType::Sink
655    } else {
656        CuTaskType::Regular
657    }
658}
659
660/// The connection id used here is the index of the config graph edge that equates to the wanted
661/// connection.
662fn sort_inputs_by_cnx_id(input_msg_indices_types: &mut [CuInputMsg]) {
663    input_msg_indices_types.sort_by_key(|input| input.edge_id);
664}
665
666fn collect_output_msg_types(graph: &CuGraph, node_id: NodeId) -> Vec<String> {
667    let mut edge_ids = graph.get_src_edges(node_id).unwrap_or_default();
668    edge_ids.sort();
669
670    let mut msg_types = Vec::new();
671    let mut seen = Vec::new();
672    for edge_id in edge_ids {
673        if let Some(edge) = graph.edge(edge_id) {
674            if seen.iter().any(|msg| msg == &edge.msg) {
675                continue;
676            }
677            seen.push(edge.msg.clone());
678            msg_types.push(edge.msg.clone());
679        }
680    }
681    msg_types
682}
683/// Explores a subbranch and build the partial plan out of it.
684fn plan_tasks_tree_branch(
685    graph: &CuGraph,
686    mut next_culist_output_index: u32,
687    starting_point: NodeId,
688    plan: &mut Vec<CuExecutionUnit>,
689) -> (u32, bool) {
690    #[cfg(all(feature = "std", feature = "macro_debug"))]
691    eprintln!("-- starting branch from node {starting_point}");
692
693    let mut handled = false;
694
695    for id in graph.bfs_nodes(starting_point) {
696        let node_ref = graph.get_node(id).unwrap();
697        #[cfg(all(feature = "std", feature = "macro_debug"))]
698        eprintln!("  Visiting node: {node_ref:?}");
699
700        let mut input_msg_indices_types: Vec<CuInputMsg> = Vec::new();
701        let output_msg_pack: Option<CuOutputPack>;
702        let task_type = find_task_type_for_id(graph, id);
703
704        match task_type {
705            CuTaskType::Source => {
706                #[cfg(all(feature = "std", feature = "macro_debug"))]
707                eprintln!("    → Source node, assign output index {next_culist_output_index}");
708                let msg_types = collect_output_msg_types(graph, id);
709                if msg_types.is_empty() {
710                    panic!(
711                        "Source node '{}' has no outgoing connections",
712                        node_ref.get_id()
713                    );
714                }
715                output_msg_pack = Some(CuOutputPack {
716                    culist_index: next_culist_output_index,
717                    msg_types,
718                });
719                next_culist_output_index += 1;
720            }
721            CuTaskType::Sink => {
722                let mut edge_ids = graph.get_dst_edges(id).unwrap_or_default();
723                edge_ids.sort();
724                #[cfg(all(feature = "std", feature = "macro_debug"))]
725                eprintln!("    → Sink with incoming edges: {edge_ids:?}");
726                for edge_id in edge_ids {
727                    let edge = graph
728                        .edge(edge_id)
729                        .unwrap_or_else(|| panic!("Missing edge {edge_id} for node {id}"));
730                    let pid = graph
731                        .get_node_id_by_name(edge.src.as_str())
732                        .unwrap_or_else(|| {
733                            panic!("Missing source node '{}' for edge {edge_id}", edge.src)
734                        });
735                    let output_pack = find_output_pack_from_nodeid(pid, plan);
736                    if let Some(output_pack) = output_pack {
737                        #[cfg(all(feature = "std", feature = "macro_debug"))]
738                        eprintln!("      ✓ Input from {pid} ready: {output_pack:?}");
739                        let msg_type = edge.msg.as_str();
740                        let src_port = output_pack
741                            .msg_types
742                            .iter()
743                            .position(|msg| msg == msg_type)
744                            .unwrap_or_else(|| {
745                                panic!(
746                                    "Missing output port for message type '{msg_type}' on node {pid}"
747                                )
748                            });
749                        input_msg_indices_types.push(CuInputMsg {
750                            culist_index: output_pack.culist_index,
751                            msg_type: msg_type.to_string(),
752                            src_port,
753                            edge_id,
754                        });
755                    } else {
756                        #[cfg(all(feature = "std", feature = "macro_debug"))]
757                        eprintln!("      ✗ Input from {pid} not ready, returning");
758                        return (next_culist_output_index, handled);
759                    }
760                }
761                output_msg_pack = Some(CuOutputPack {
762                    culist_index: next_culist_output_index,
763                    msg_types: Vec::from(["()".to_string()]),
764                });
765                next_culist_output_index += 1;
766            }
767            CuTaskType::Regular => {
768                let mut edge_ids = graph.get_dst_edges(id).unwrap_or_default();
769                edge_ids.sort();
770                #[cfg(all(feature = "std", feature = "macro_debug"))]
771                eprintln!("    → Regular task with incoming edges: {edge_ids:?}");
772                for edge_id in edge_ids {
773                    let edge = graph
774                        .edge(edge_id)
775                        .unwrap_or_else(|| panic!("Missing edge {edge_id} for node {id}"));
776                    let pid = graph
777                        .get_node_id_by_name(edge.src.as_str())
778                        .unwrap_or_else(|| {
779                            panic!("Missing source node '{}' for edge {edge_id}", edge.src)
780                        });
781                    let output_pack = find_output_pack_from_nodeid(pid, plan);
782                    if let Some(output_pack) = output_pack {
783                        #[cfg(all(feature = "std", feature = "macro_debug"))]
784                        eprintln!("      ✓ Input from {pid} ready: {output_pack:?}");
785                        let msg_type = edge.msg.as_str();
786                        let src_port = output_pack
787                            .msg_types
788                            .iter()
789                            .position(|msg| msg == msg_type)
790                            .unwrap_or_else(|| {
791                                panic!(
792                                    "Missing output port for message type '{msg_type}' on node {pid}"
793                                )
794                            });
795                        input_msg_indices_types.push(CuInputMsg {
796                            culist_index: output_pack.culist_index,
797                            msg_type: msg_type.to_string(),
798                            src_port,
799                            edge_id,
800                        });
801                    } else {
802                        #[cfg(all(feature = "std", feature = "macro_debug"))]
803                        eprintln!("      ✗ Input from {pid} not ready, returning");
804                        return (next_culist_output_index, handled);
805                    }
806                }
807                let msg_types = collect_output_msg_types(graph, id);
808                if msg_types.is_empty() {
809                    panic!(
810                        "Regular node '{}' has no outgoing connections",
811                        node_ref.get_id()
812                    );
813                }
814                output_msg_pack = Some(CuOutputPack {
815                    culist_index: next_culist_output_index,
816                    msg_types,
817                });
818                next_culist_output_index += 1;
819            }
820        }
821
822        sort_inputs_by_cnx_id(&mut input_msg_indices_types);
823
824        if let Some(pos) = plan
825            .iter()
826            .position(|step| matches!(step, CuExecutionUnit::Step(s) if s.node_id == id))
827        {
828            #[cfg(all(feature = "std", feature = "macro_debug"))]
829            eprintln!("    → Already in plan, modifying existing step");
830            let mut step = plan.remove(pos);
831            if let CuExecutionUnit::Step(ref mut s) = step {
832                s.input_msg_indices_types = input_msg_indices_types;
833            }
834            plan.push(step);
835        } else {
836            #[cfg(all(feature = "std", feature = "macro_debug"))]
837            eprintln!("    → New step added to plan");
838            let step = CuExecutionStep {
839                node_id: id,
840                node: node_ref.clone(),
841                task_type,
842                input_msg_indices_types,
843                output_msg_pack,
844            };
845            plan.push(CuExecutionUnit::Step(step));
846        }
847
848        handled = true;
849    }
850
851    #[cfg(all(feature = "std", feature = "macro_debug"))]
852    eprintln!("-- finished branch from node {starting_point} with handled={handled}");
853    (next_culist_output_index, handled)
854}
855
856/// This is the main heuristics to compute an execution plan at compilation time.
857/// TODO(gbin): Make that heuristic pluggable.
858pub fn compute_runtime_plan(graph: &CuGraph) -> CuResult<CuExecutionLoop> {
859    #[cfg(all(feature = "std", feature = "macro_debug"))]
860    eprintln!("[runtime plan]");
861    let mut plan = Vec::new();
862    let mut next_culist_output_index = 0u32;
863
864    let mut queue: VecDeque<NodeId> = graph
865        .node_ids()
866        .into_iter()
867        .filter(|&node_id| find_task_type_for_id(graph, node_id) == CuTaskType::Source)
868        .collect();
869
870    #[cfg(all(feature = "std", feature = "macro_debug"))]
871    eprintln!("Initial source nodes: {queue:?}");
872
873    while let Some(start_node) = queue.pop_front() {
874        #[cfg(all(feature = "std", feature = "macro_debug"))]
875        eprintln!("→ Starting BFS from source {start_node}");
876        for node_id in graph.bfs_nodes(start_node) {
877            let already_in_plan = plan
878                .iter()
879                .any(|unit| matches!(unit, CuExecutionUnit::Step(s) if s.node_id == node_id));
880            if already_in_plan {
881                #[cfg(all(feature = "std", feature = "macro_debug"))]
882                eprintln!("    → Node {node_id} already planned, skipping");
883                continue;
884            }
885
886            #[cfg(all(feature = "std", feature = "macro_debug"))]
887            eprintln!("    Planning from node {node_id}");
888            let (new_index, handled) =
889                plan_tasks_tree_branch(graph, next_culist_output_index, node_id, &mut plan);
890            next_culist_output_index = new_index;
891
892            if !handled {
893                #[cfg(all(feature = "std", feature = "macro_debug"))]
894                eprintln!("    ✗ Node {node_id} was not handled, skipping enqueue of neighbors");
895                continue;
896            }
897
898            #[cfg(all(feature = "std", feature = "macro_debug"))]
899            eprintln!("    ✓ Node {node_id} handled successfully, enqueueing neighbors");
900            for neighbor in graph.get_neighbor_ids(node_id, CuDirection::Outgoing) {
901                #[cfg(all(feature = "std", feature = "macro_debug"))]
902                eprintln!("      → Enqueueing neighbor {neighbor}");
903                queue.push_back(neighbor);
904            }
905        }
906    }
907
908    let mut planned_nodes = BTreeSet::new();
909    for unit in &plan {
910        if let CuExecutionUnit::Step(step) = unit {
911            planned_nodes.insert(step.node_id);
912        }
913    }
914
915    let mut missing = Vec::new();
916    for node_id in graph.node_ids() {
917        if !planned_nodes.contains(&node_id) {
918            if let Some(node) = graph.get_node(node_id) {
919                missing.push(node.get_id().to_string());
920            } else {
921                missing.push(format!("node_id_{node_id}"));
922            }
923        }
924    }
925
926    if !missing.is_empty() {
927        missing.sort();
928        return Err(CuError::from(format!(
929            "Execution plan could not include all nodes. Missing: {}. Check for loopback or missing source connections.",
930            missing.join(", ")
931        )));
932    }
933
934    Ok(CuExecutionLoop {
935        steps: plan,
936        loop_count: None,
937    })
938}
939
// Unit tests for the runtime execution plan and the copperlist lifecycle.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Node;
    use crate::cutask::CuSinkTask;
    use crate::cutask::{CuSrcTask, Freezable};
    use crate::monitoring::NoMonitor;
    use crate::reflect::Reflect;
    use bincode::Encode;
    use cu29_traits::{ErasedCuStampedData, ErasedCuStampedDataSet, MatchingTasks};
    use serde_derive::{Deserialize, Serialize};

    /// Minimal source task: emits a unit `()` output and does nothing else.
    #[derive(Reflect)]
    pub struct TestSource {}

    impl Freezable for TestSource {}

    impl CuSrcTask for TestSource {
        type Resources<'r> = ();
        type Output<'m> = ();
        fn new(_config: Option<&ComponentConfig>, _resources: Self::Resources<'_>) -> CuResult<Self>
        where
            Self: Sized,
        {
            Ok(Self {})
        }

        fn process(
            &mut self,
            _clock: &RobotClock,
            _empty_msg: &mut Self::Output<'_>,
        ) -> CuResult<()> {
            Ok(())
        }
    }

    /// Minimal sink task: consumes a unit `()` input and does nothing else.
    #[derive(Reflect)]
    pub struct TestSink {}

    impl Freezable for TestSink {}

    impl CuSinkTask for TestSink {
        type Resources<'r> = ();
        type Input<'m> = ();

        fn new(_config: Option<&ComponentConfig>, _resources: Self::Resources<'_>) -> CuResult<Self>
        where
            Self: Sized,
        {
            Ok(Self {})
        }

        fn process(&mut self, _clock: &RobotClock, _input: &Self::Input<'_>) -> CuResult<()> {
            Ok(())
        }
    }

    // Those should be generated by the derive macro
    type Tasks = (TestSource, TestSink);

    /// Empty message set matching the `Tasks` tuple above.
    #[derive(Debug, Encode, Decode, Serialize, Deserialize, Default)]
    struct Msgs(());

    impl ErasedCuStampedDataSet for Msgs {
        fn cumsgs(&self) -> Vec<&dyn ErasedCuStampedData> {
            Vec::new()
        }
    }

    impl MatchingTasks for Msgs {
        fn get_all_task_ids() -> &'static [&'static str] {
            &[]
        }
    }

    impl CuListZeroedInit for Msgs {
        fn init_zeroed(&mut self) {}
    }

    /// Builds the (source, sink) task tuple from their per-instance configs.
    /// The body is identical for std and no_std builds, so a single un-gated
    /// definition serves both (this was previously duplicated under
    /// `#[cfg(feature = "std")]` / `#[cfg(not(feature = "std"))]`).
    fn tasks_instanciator(
        all_instances_configs: Vec<Option<&ComponentConfig>>,
        _resources: &mut ResourceManager,
    ) -> CuResult<Tasks> {
        Ok((
            TestSource::new(all_instances_configs[0], ())?,
            TestSink::new(all_instances_configs[1], ())?,
        ))
    }

    /// Monitor factory for tests: a monitor that does nothing.
    fn monitor_instanciator(_config: &CuConfig) -> NoMonitor {
        NoMonitor {}
    }

    /// Bridge factory for tests: no bridges to set up.
    fn bridges_instanciator(_config: &CuConfig, _resources: &mut ResourceManager) -> CuResult<()> {
        Ok(())
    }

    /// Resource factory for tests: an empty resource manager.
    fn resources_instanciator(_config: &CuConfig) -> CuResult<ResourceManager> {
        Ok(ResourceManager::new(&[]))
    }

    /// A `WriteStream` that discards everything, for tests that don't care about logging.
    #[derive(Debug)]
    struct FakeWriter {}

    impl<E: Encode> WriteStream<E> for FakeWriter {
        fn log(&mut self, _obj: &E) -> CuResult<()> {
            Ok(())
        }
    }

    /// A 2-node (source → sink) config must produce a valid runtime.
    #[test]
    fn test_runtime_instantiation() {
        let mut config = CuConfig::default();
        let graph = config.get_graph_mut(None).unwrap();
        graph.add_node(Node::new("a", "TestSource")).unwrap();
        graph.add_node(Node::new("b", "TestSink")).unwrap();
        graph.connect(0, 1, "()").unwrap();
        let runtime = CuRuntime::<Tasks, (), Msgs, NoMonitor, 2>::new(
            RobotClock::default(),
            &config,
            None,
            resources_instanciator,
            tasks_instanciator,
            monitor_instanciator,
            bridges_instanciator,
            FakeWriter {},
            FakeWriter {},
        );
        assert!(runtime.is_ok());
    }

    /// Exercises create / end_of_processing on a 2-slot copperlist pool,
    /// including out-of-order frees (only freeing the top of the stack
    /// releases slots).
    #[test]
    fn test_copperlists_manager_lifecycle() {
        let mut config = CuConfig::default();
        let graph = config.get_graph_mut(None).unwrap();
        graph.add_node(Node::new("a", "TestSource")).unwrap();
        graph.add_node(Node::new("b", "TestSink")).unwrap();
        graph.connect(0, 1, "()").unwrap();

        let mut runtime = CuRuntime::<Tasks, (), Msgs, NoMonitor, 2>::new(
            RobotClock::default(),
            &config,
            None,
            resources_instanciator,
            tasks_instanciator,
            monitor_instanciator,
            bridges_instanciator,
            FakeWriter {},
            FakeWriter {},
        )
        .unwrap();

        // Now emulates the generated runtime
        {
            let copperlists = &mut runtime.copperlists_manager;
            let culist0 = copperlists
                .inner
                .create()
                .expect("Ran out of space for copper lists");
            // FIXME: error handling.
            let id = culist0.id;
            assert_eq!(id, 0);
            culist0.change_state(CopperListState::Processing);
            assert_eq!(copperlists.available_copper_lists(), 1);
        }

        {
            let copperlists = &mut runtime.copperlists_manager;
            let culist1 = copperlists
                .inner
                .create()
                .expect("Ran out of space for copper lists"); // FIXME: error handling.
            let id = culist1.id;
            assert_eq!(id, 1);
            culist1.change_state(CopperListState::Processing);
            assert_eq!(copperlists.available_copper_lists(), 0);
        }

        {
            let copperlists = &mut runtime.copperlists_manager;
            let culist2 = copperlists.inner.create();
            assert!(culist2.is_none());
            assert_eq!(copperlists.available_copper_lists(), 0);
            // Free in order, should let the top of the stack be serialized and freed.
            let _ = copperlists.end_of_processing(1);
            assert_eq!(copperlists.available_copper_lists(), 1);
        }

        // Re-add a CL
        {
            let copperlists = &mut runtime.copperlists_manager;
            let culist2 = copperlists
                .inner
                .create()
                .expect("Ran out of space for copper lists"); // FIXME: error handling.
            let id = culist2.id;
            assert_eq!(id, 2);
            culist2.change_state(CopperListState::Processing);
            assert_eq!(copperlists.available_copper_lists(), 0);
            // Free out of order, the #0 first
            let _ = copperlists.end_of_processing(0);
            // Should not free up the top of the stack
            assert_eq!(copperlists.available_copper_lists(), 0);

            // Free up the top of the stack
            let _ = copperlists.end_of_processing(2);
            // This should free up 2 CLs

            assert_eq!(copperlists.available_copper_lists(), 2);
        }
    }

    /// Input order at a merging node follows edge creation order (edge id),
    /// not node id order.
    #[test]
    fn test_runtime_task_input_order() {
        let mut config = CuConfig::default();
        let graph = config.get_graph_mut(None).unwrap();
        let src1_id = graph.add_node(Node::new("a", "Source1")).unwrap();
        let src2_id = graph.add_node(Node::new("b", "Source2")).unwrap();
        let sink_id = graph.add_node(Node::new("c", "Sink")).unwrap();

        assert_eq!(src1_id, 0);
        assert_eq!(src2_id, 1);

        // note that the source2 connection is before the source1
        let src1_type = "src1_type";
        let src2_type = "src2_type";
        graph.connect(src2_id, sink_id, src2_type).unwrap();
        graph.connect(src1_id, sink_id, src1_type).unwrap();

        let src1_edge_id = *graph.get_src_edges(src1_id).unwrap().first().unwrap();
        let src2_edge_id = *graph.get_src_edges(src2_id).unwrap().first().unwrap();
        // the edge id depends on the order the connection is created, not
        // on the node id, and that is what determines the input order
        assert_eq!(src1_edge_id, 1);
        assert_eq!(src2_edge_id, 0);

        let runtime = compute_runtime_plan(graph).unwrap();
        let sink_step = runtime
            .steps
            .iter()
            .find_map(|step| match step {
                CuExecutionUnit::Step(step) if step.node_id == sink_id => Some(step),
                _ => None,
            })
            .unwrap();

        // since the src2 connection was added before src1 connection, the src2 type should be
        // first
        assert_eq!(sink_step.input_msg_indices_types[0].msg_type, src2_type);
        assert_eq!(sink_step.input_msg_indices_types[1].msg_type, src1_type);
    }

    /// Output ports are deduplicated by message type and keep first-seen
    /// order; consumers of the same type share the same source port.
    #[test]
    fn test_runtime_output_ports_unique_ordered() {
        let mut config = CuConfig::default();
        let graph = config.get_graph_mut(None).unwrap();
        let src_id = graph.add_node(Node::new("src", "Source")).unwrap();
        let dst_a_id = graph.add_node(Node::new("dst_a", "SinkA")).unwrap();
        let dst_b_id = graph.add_node(Node::new("dst_b", "SinkB")).unwrap();
        let dst_a2_id = graph.add_node(Node::new("dst_a2", "SinkA2")).unwrap();
        let dst_c_id = graph.add_node(Node::new("dst_c", "SinkC")).unwrap();

        graph.connect(src_id, dst_a_id, "msg::A").unwrap();
        graph.connect(src_id, dst_b_id, "msg::B").unwrap();
        graph.connect(src_id, dst_a2_id, "msg::A").unwrap();
        graph.connect(src_id, dst_c_id, "msg::C").unwrap();

        let runtime = compute_runtime_plan(graph).unwrap();
        let src_step = runtime
            .steps
            .iter()
            .find_map(|step| match step {
                CuExecutionUnit::Step(step) if step.node_id == src_id => Some(step),
                _ => None,
            })
            .unwrap();

        let output_pack = src_step.output_msg_pack.as_ref().unwrap();
        assert_eq!(output_pack.msg_types, vec!["msg::A", "msg::B", "msg::C"]);

        let dst_a_step = runtime
            .steps
            .iter()
            .find_map(|step| match step {
                CuExecutionUnit::Step(step) if step.node_id == dst_a_id => Some(step),
                _ => None,
            })
            .unwrap();
        let dst_b_step = runtime
            .steps
            .iter()
            .find_map(|step| match step {
                CuExecutionUnit::Step(step) if step.node_id == dst_b_id => Some(step),
                _ => None,
            })
            .unwrap();
        let dst_a2_step = runtime
            .steps
            .iter()
            .find_map(|step| match step {
                CuExecutionUnit::Step(step) if step.node_id == dst_a2_id => Some(step),
                _ => None,
            })
            .unwrap();
        let dst_c_step = runtime
            .steps
            .iter()
            .find_map(|step| match step {
                CuExecutionUnit::Step(step) if step.node_id == dst_c_id => Some(step),
                _ => None,
            })
            .unwrap();

        assert_eq!(dst_a_step.input_msg_indices_types[0].src_port, 0);
        assert_eq!(dst_b_step.input_msg_indices_types[0].src_port, 1);
        assert_eq!(dst_a2_step.input_msg_indices_types[0].src_port, 0);
        assert_eq!(dst_c_step.input_msg_indices_types[0].src_port, 2);
    }

    /// Fan-out of a single message type collapses to one output port.
    #[test]
    fn test_runtime_output_ports_fanout_single() {
        let mut config = CuConfig::default();
        let graph = config.get_graph_mut(None).unwrap();
        let src_id = graph.add_node(Node::new("src", "Source")).unwrap();
        let dst_a_id = graph.add_node(Node::new("dst_a", "SinkA")).unwrap();
        let dst_b_id = graph.add_node(Node::new("dst_b", "SinkB")).unwrap();

        graph.connect(src_id, dst_a_id, "i32").unwrap();
        graph.connect(src_id, dst_b_id, "i32").unwrap();

        let runtime = compute_runtime_plan(graph).unwrap();
        let src_step = runtime
            .steps
            .iter()
            .find_map(|step| match step {
                CuExecutionUnit::Step(step) if step.node_id == src_id => Some(step),
                _ => None,
            })
            .unwrap();

        let output_pack = src_step.output_msg_pack.as_ref().unwrap();
        assert_eq!(output_pack.msg_types, vec!["i32"]);
    }

    /// Diamond topology, direct edge connected first — regression test for a
    /// scheduler ordering bug.
    #[test]
    fn test_runtime_plan_diamond_case1() {
        // more complex topology that tripped the scheduler
        let mut config = CuConfig::default();
        let graph = config.get_graph_mut(None).unwrap();
        let cam0_id = graph
            .add_node(Node::new("cam0", "tasks::IntegerSrcTask"))
            .unwrap();
        let inf0_id = graph
            .add_node(Node::new("inf0", "tasks::Integer2FloatTask"))
            .unwrap();
        let broadcast_id = graph
            .add_node(Node::new("broadcast", "tasks::MergingSinkTask"))
            .unwrap();

        // case 1 order
        graph.connect(cam0_id, broadcast_id, "i32").unwrap();
        graph.connect(cam0_id, inf0_id, "i32").unwrap();
        graph.connect(inf0_id, broadcast_id, "f32").unwrap();

        let edge_cam0_to_broadcast = *graph.get_src_edges(cam0_id).unwrap().first().unwrap();
        let edge_cam0_to_inf0 = graph.get_src_edges(cam0_id).unwrap()[1];

        assert_eq!(edge_cam0_to_inf0, 0);
        assert_eq!(edge_cam0_to_broadcast, 1);

        let runtime = compute_runtime_plan(graph).unwrap();
        let broadcast_step = runtime
            .steps
            .iter()
            .find_map(|step| match step {
                CuExecutionUnit::Step(step) if step.node_id == broadcast_id => Some(step),
                _ => None,
            })
            .unwrap();

        assert_eq!(broadcast_step.input_msg_indices_types[0].msg_type, "i32");
        assert_eq!(broadcast_step.input_msg_indices_types[1].msg_type, "f32");
    }

    /// Diamond topology, intermediate edge connected first — regression test
    /// for the scheduler ordering bug, variation 2.
    #[test]
    fn test_runtime_plan_diamond_case2() {
        // more complex topology that tripped the scheduler variation 2
        let mut config = CuConfig::default();
        let graph = config.get_graph_mut(None).unwrap();
        let cam0_id = graph
            .add_node(Node::new("cam0", "tasks::IntegerSrcTask"))
            .unwrap();
        let inf0_id = graph
            .add_node(Node::new("inf0", "tasks::Integer2FloatTask"))
            .unwrap();
        let broadcast_id = graph
            .add_node(Node::new("broadcast", "tasks::MergingSinkTask"))
            .unwrap();

        // case 2 order
        graph.connect(cam0_id, inf0_id, "i32").unwrap();
        graph.connect(cam0_id, broadcast_id, "i32").unwrap();
        graph.connect(inf0_id, broadcast_id, "f32").unwrap();

        let edge_cam0_to_inf0 = *graph.get_src_edges(cam0_id).unwrap().first().unwrap();
        let edge_cam0_to_broadcast = graph.get_src_edges(cam0_id).unwrap()[1];

        assert_eq!(edge_cam0_to_broadcast, 0);
        assert_eq!(edge_cam0_to_inf0, 1);

        let runtime = compute_runtime_plan(graph).unwrap();
        let broadcast_step = runtime
            .steps
            .iter()
            .find_map(|step| match step {
                CuExecutionUnit::Step(step) if step.node_id == broadcast_id => Some(step),
                _ => None,
            })
            .unwrap();

        assert_eq!(broadcast_step.input_msg_indices_types[0].msg_type, "i32");
        assert_eq!(broadcast_step.input_msg_indices_types[1].msg_type, "f32");
    }
}