// noether_engine/planner.rs

use crate::lagrange::CompositionNode;
use noether_core::effects::Effect;
use noether_core::stage::StageId;
use noether_store::StageStore;
use serde::{Deserialize, Serialize};
7#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
8pub enum ExecutionMode {
9    Inline,
10    Process,
11    Remote,
12}
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct ExecutionStep {
16    pub step_index: usize,
17    pub stage_id: StageId,
18    pub mode: ExecutionMode,
19    pub depends_on: Vec<usize>,
20}
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct CostSummary {
24    pub total_time_ms_p50: Option<u64>,
25    pub total_tokens_est: Option<u64>,
26    pub total_memory_mb_peak: Option<u64>,
27}
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct ExecutionPlan {
31    pub steps: Vec<ExecutionStep>,
32    pub cost: CostSummary,
33    pub parallel_groups: Vec<Vec<usize>>,
34}
36/// Flatten a composition AST into a linear execution plan.
37pub fn plan_graph(node: &CompositionNode, store: &(impl StageStore + ?Sized)) -> ExecutionPlan {
38    let mut steps = Vec::new();
39    let mut parallel_groups = Vec::new();
40    flatten_node(node, &mut steps, &mut parallel_groups, store, &[]);
41
42    let cost = estimate_cost(&steps, store);
43
44    ExecutionPlan {
45        steps,
46        cost,
47        parallel_groups,
48    }
49}
51/// Returns the indices of the steps added for this node.
52fn flatten_node(
53    node: &CompositionNode,
54    steps: &mut Vec<ExecutionStep>,
55    parallel_groups: &mut Vec<Vec<usize>>,
56    store: &(impl StageStore + ?Sized),
57    depends_on: &[usize],
58) -> Vec<usize> {
59    match node {
60        CompositionNode::Stage { id, .. } => {
61            let idx = steps.len();
62            steps.push(ExecutionStep {
63                step_index: idx,
64                stage_id: id.clone(),
65                mode: ExecutionMode::Inline,
66                depends_on: depends_on.to_vec(),
67            });
68            vec![idx]
69        }
70        CompositionNode::Const { .. } => {
71            // Const nodes produce no execution step — they are resolved inline
72            // in the runner without touching the store.
73            depends_on.to_vec()
74        }
75        CompositionNode::RemoteStage { .. } => {
76            // RemoteStage nodes produce no local execution step.
77            // Native runner handles these via reqwest; browser runtime via fetch().
78            depends_on.to_vec()
79        }
80        CompositionNode::Sequential { stages } => {
81            let mut prev_indices = depends_on.to_vec();
82
83            let start_step = steps.len();
84            for stage in stages {
85                prev_indices = flatten_node(stage, steps, parallel_groups, store, &prev_indices);
86            }
87            let end_step = steps.len();
88
89            // After flattening, check whether ALL direct children are Stage nodes
90            // and all are Pure. If so, add them as a parallel group hint.
91            let all_direct_pure_stages = stages.iter().all(|s| {
92                if let CompositionNode::Stage { id, .. } = s {
93                    store
94                        .get(id)
95                        .ok()
96                        .flatten()
97                        .map(|st| st.signature.effects.contains(&Effect::Pure))
98                        .unwrap_or(false)
99                } else {
100                    false
101                }
102            });
103
104            if all_direct_pure_stages && stages.len() > 1 {
105                let group: Vec<usize> = (start_step..end_step).collect();
106                if group.len() > 1 {
107                    parallel_groups.push(group);
108                }
109            }
110
111            prev_indices
112        }
113        CompositionNode::Parallel { branches } => {
114            let mut group = Vec::new();
115            let mut all_outputs = Vec::new();
116            for node in branches.values() {
117                let outputs = flatten_node(node, steps, parallel_groups, store, depends_on);
118                // The first step of each branch is in the parallel group
119                if let Some(&first) = outputs.first() {
120                    group.push(first);
121                }
122                all_outputs.extend(outputs);
123            }
124            if group.len() > 1 {
125                parallel_groups.push(group);
126            }
127            all_outputs
128        }
129        CompositionNode::Branch {
130            predicate,
131            if_true,
132            if_false,
133        } => {
134            let pred_out = flatten_node(predicate, steps, parallel_groups, store, depends_on);
135            let true_out = flatten_node(if_true, steps, parallel_groups, store, &pred_out);
136            let false_out = flatten_node(if_false, steps, parallel_groups, store, &pred_out);
137            let mut combined = true_out;
138            combined.extend(false_out);
139            combined
140        }
141        CompositionNode::Fanout { source, targets } => {
142            let source_out = flatten_node(source, steps, parallel_groups, store, depends_on);
143            let mut group = Vec::new();
144            let mut all_outputs = Vec::new();
145            for target in targets {
146                let outputs = flatten_node(target, steps, parallel_groups, store, &source_out);
147                if let Some(&first) = outputs.first() {
148                    group.push(first);
149                }
150                all_outputs.extend(outputs);
151            }
152            if group.len() > 1 {
153                parallel_groups.push(group);
154            }
155            all_outputs
156        }
157        CompositionNode::Merge { sources, target } => {
158            let mut all_source_outputs = Vec::new();
159            let mut group = Vec::new();
160            for src in sources {
161                let outputs = flatten_node(src, steps, parallel_groups, store, depends_on);
162                if let Some(&first) = outputs.first() {
163                    group.push(first);
164                }
165                all_source_outputs.extend(outputs);
166            }
167            if group.len() > 1 {
168                parallel_groups.push(group);
169            }
170            flatten_node(target, steps, parallel_groups, store, &all_source_outputs)
171        }
172        CompositionNode::Retry { stage, .. } => {
173            flatten_node(stage, steps, parallel_groups, store, depends_on)
174        }
175        CompositionNode::Let { bindings, body } => {
176            // Bindings run concurrently against the outer input; body then
177            // sequentially after every binding completes.
178            let mut group = Vec::new();
179            let mut binding_outputs = Vec::new();
180            for node in bindings.values() {
181                let outputs = flatten_node(node, steps, parallel_groups, store, depends_on);
182                if let Some(&first) = outputs.first() {
183                    group.push(first);
184                }
185                binding_outputs.extend(outputs);
186            }
187            if group.len() > 1 {
188                parallel_groups.push(group);
189            }
190            // The body sees an input that depends on every binding output and
191            // the outer input.
192            let mut body_deps = depends_on.to_vec();
193            body_deps.extend(binding_outputs);
194            flatten_node(body, steps, parallel_groups, store, &body_deps)
195        }
196    }
197}
199fn estimate_cost(steps: &[ExecutionStep], store: &(impl StageStore + ?Sized)) -> CostSummary {
200    let mut total_time: u64 = 0;
201    let mut total_tokens: u64 = 0;
202    let mut max_memory: u64 = 0;
203
204    for step in steps {
205        if let Ok(Some(stage)) = store.get(&step.stage_id) {
206            if let Some(t) = stage.cost.time_ms_p50 {
207                total_time += t;
208            }
209            if let Some(t) = stage.cost.tokens_est {
210                total_tokens += t;
211            }
212            if let Some(m) = stage.cost.memory_mb {
213                max_memory = max_memory.max(m);
214            }
215        }
216    }
217
218    CostSummary {
219        total_time_ms_p50: if total_time > 0 {
220            Some(total_time)
221        } else {
222            None
223        },
224        total_tokens_est: if total_tokens > 0 {
225            Some(total_tokens)
226        } else {
227            None
228        },
229        total_memory_mb_peak: if max_memory > 0 {
230            Some(max_memory)
231        } else {
232            None
233        },
234    }
235}
#[cfg(test)]
mod tests {
    use super::*;
    use noether_store::MemoryStore;
    use std::collections::BTreeMap;

    /// Builds a bare `Stage` node with the given id and no config.
    fn stage(id: &str) -> CompositionNode {
        CompositionNode::Stage {
            id: StageId(id.into()),
            pinning: crate::lagrange::Pinning::Signature,
            config: None,
        }
    }

    #[test]
    fn plan_single_stage() {
        let store = MemoryStore::new();
        let plan = plan_graph(&stage("a"), &store);
        assert_eq!(plan.steps.len(), 1);
        assert_eq!(plan.steps[0].step_index, 0);
        assert_eq!(plan.steps[0].stage_id, StageId("a".into()));
        assert_eq!(plan.steps[0].mode, ExecutionMode::Inline);
        assert!(plan.steps[0].depends_on.is_empty());
        assert!(plan.parallel_groups.is_empty());
    }

    #[test]
    fn plan_sequential_has_dependencies() {
        let store = MemoryStore::new();
        let node = CompositionNode::Sequential {
            stages: vec![stage("a"), stage("b"), stage("c")],
        };
        let plan = plan_graph(&node, &store);
        assert_eq!(plan.steps.len(), 3);
        assert!(plan.steps[0].depends_on.is_empty());
        assert_eq!(plan.steps[1].depends_on, vec![0]);
        assert_eq!(plan.steps[2].depends_on, vec![1]);
    }

    #[test]
    fn plan_parallel_creates_group() {
        let store = MemoryStore::new();
        let node = CompositionNode::Parallel {
            branches: BTreeMap::from([("a".into(), stage("s1")), ("b".into(), stage("s2"))]),
        };
        let plan = plan_graph(&node, &store);
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.parallel_groups.len(), 1);
        assert_eq!(plan.parallel_groups[0], vec![0, 1]);
        // Both branches start from the (empty) outer dependency set.
        assert!(plan.steps[0].depends_on.is_empty());
        assert!(plan.steps[1].depends_on.is_empty());
    }

    #[test]
    fn plan_sequential_with_parallel() {
        let store = MemoryStore::new();
        let node = CompositionNode::Sequential {
            stages: vec![
                stage("input"),
                CompositionNode::Parallel {
                    branches: BTreeMap::from([
                        ("a".into(), stage("s1")),
                        ("b".into(), stage("s2")),
                    ]),
                },
                stage("output"),
            ],
        };
        let plan = plan_graph(&node, &store);
        // Steps: input, s1, s2, output.
        assert_eq!(plan.steps.len(), 4);
        // s1 and s2 depend on input (step 0).
        assert!(plan.steps[1].depends_on.contains(&0));
        assert!(plan.steps[2].depends_on.contains(&0));
        // output depends on both s1 and s2.
        assert!(plan.steps[3].depends_on.contains(&1));
        assert!(plan.steps[3].depends_on.contains(&2));
        // The two branches form the only parallel group.
        assert_eq!(plan.parallel_groups, vec![vec![1, 2]]);
    }
}