1use crate::lagrange::CompositionNode;
2use noether_core::effects::Effect;
3use noether_core::stage::StageId;
4use noether_store::StageStore;
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
8pub enum ExecutionMode {
9 Inline,
10 Process,
11 Remote,
12}
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct ExecutionStep {
16 pub step_index: usize,
17 pub stage_id: StageId,
18 pub mode: ExecutionMode,
19 pub depends_on: Vec<usize>,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct CostSummary {
24 pub total_time_ms_p50: Option<u64>,
25 pub total_tokens_est: Option<u64>,
26 pub total_memory_mb_peak: Option<u64>,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct ExecutionPlan {
31 pub steps: Vec<ExecutionStep>,
32 pub cost: CostSummary,
33 pub parallel_groups: Vec<Vec<usize>>,
34}
35
36pub fn plan_graph(node: &CompositionNode, store: &(impl StageStore + ?Sized)) -> ExecutionPlan {
38 let mut steps = Vec::new();
39 let mut parallel_groups = Vec::new();
40 flatten_node(node, &mut steps, &mut parallel_groups, store, &[]);
41
42 let cost = estimate_cost(&steps, store);
43
44 ExecutionPlan {
45 steps,
46 cost,
47 parallel_groups,
48 }
49}
50
51fn flatten_node(
53 node: &CompositionNode,
54 steps: &mut Vec<ExecutionStep>,
55 parallel_groups: &mut Vec<Vec<usize>>,
56 store: &(impl StageStore + ?Sized),
57 depends_on: &[usize],
58) -> Vec<usize> {
59 match node {
60 CompositionNode::Stage { id, .. } => {
61 let idx = steps.len();
62 steps.push(ExecutionStep {
63 step_index: idx,
64 stage_id: id.clone(),
65 mode: ExecutionMode::Inline,
66 depends_on: depends_on.to_vec(),
67 });
68 vec![idx]
69 }
70 CompositionNode::Const { .. } => {
71 depends_on.to_vec()
74 }
75 CompositionNode::RemoteStage { .. } => {
76 depends_on.to_vec()
79 }
80 CompositionNode::Sequential { stages } => {
81 let mut prev_indices = depends_on.to_vec();
82
83 let start_step = steps.len();
84 for stage in stages {
85 prev_indices = flatten_node(stage, steps, parallel_groups, store, &prev_indices);
86 }
87 let end_step = steps.len();
88
89 let all_direct_pure_stages = stages.iter().all(|s| {
92 if let CompositionNode::Stage { id, .. } = s {
93 store
94 .get(id)
95 .ok()
96 .flatten()
97 .map(|st| st.signature.effects.contains(&Effect::Pure))
98 .unwrap_or(false)
99 } else {
100 false
101 }
102 });
103
104 if all_direct_pure_stages && stages.len() > 1 {
105 let group: Vec<usize> = (start_step..end_step).collect();
106 if group.len() > 1 {
107 parallel_groups.push(group);
108 }
109 }
110
111 prev_indices
112 }
113 CompositionNode::Parallel { branches } => {
114 let mut group = Vec::new();
115 let mut all_outputs = Vec::new();
116 for node in branches.values() {
117 let outputs = flatten_node(node, steps, parallel_groups, store, depends_on);
118 if let Some(&first) = outputs.first() {
120 group.push(first);
121 }
122 all_outputs.extend(outputs);
123 }
124 if group.len() > 1 {
125 parallel_groups.push(group);
126 }
127 all_outputs
128 }
129 CompositionNode::Branch {
130 predicate,
131 if_true,
132 if_false,
133 } => {
134 let pred_out = flatten_node(predicate, steps, parallel_groups, store, depends_on);
135 let true_out = flatten_node(if_true, steps, parallel_groups, store, &pred_out);
136 let false_out = flatten_node(if_false, steps, parallel_groups, store, &pred_out);
137 let mut combined = true_out;
138 combined.extend(false_out);
139 combined
140 }
141 CompositionNode::Fanout { source, targets } => {
142 let source_out = flatten_node(source, steps, parallel_groups, store, depends_on);
143 let mut group = Vec::new();
144 let mut all_outputs = Vec::new();
145 for target in targets {
146 let outputs = flatten_node(target, steps, parallel_groups, store, &source_out);
147 if let Some(&first) = outputs.first() {
148 group.push(first);
149 }
150 all_outputs.extend(outputs);
151 }
152 if group.len() > 1 {
153 parallel_groups.push(group);
154 }
155 all_outputs
156 }
157 CompositionNode::Merge { sources, target } => {
158 let mut all_source_outputs = Vec::new();
159 let mut group = Vec::new();
160 for src in sources {
161 let outputs = flatten_node(src, steps, parallel_groups, store, depends_on);
162 if let Some(&first) = outputs.first() {
163 group.push(first);
164 }
165 all_source_outputs.extend(outputs);
166 }
167 if group.len() > 1 {
168 parallel_groups.push(group);
169 }
170 flatten_node(target, steps, parallel_groups, store, &all_source_outputs)
171 }
172 CompositionNode::Retry { stage, .. } => {
173 flatten_node(stage, steps, parallel_groups, store, depends_on)
174 }
175 CompositionNode::Let { bindings, body } => {
176 let mut group = Vec::new();
179 let mut binding_outputs = Vec::new();
180 for node in bindings.values() {
181 let outputs = flatten_node(node, steps, parallel_groups, store, depends_on);
182 if let Some(&first) = outputs.first() {
183 group.push(first);
184 }
185 binding_outputs.extend(outputs);
186 }
187 if group.len() > 1 {
188 parallel_groups.push(group);
189 }
190 let mut body_deps = depends_on.to_vec();
193 body_deps.extend(binding_outputs);
194 flatten_node(body, steps, parallel_groups, store, &body_deps)
195 }
196 }
197}
198
199fn estimate_cost(steps: &[ExecutionStep], store: &(impl StageStore + ?Sized)) -> CostSummary {
200 let mut total_time: u64 = 0;
201 let mut total_tokens: u64 = 0;
202 let mut max_memory: u64 = 0;
203
204 for step in steps {
205 if let Ok(Some(stage)) = store.get(&step.stage_id) {
206 if let Some(t) = stage.cost.time_ms_p50 {
207 total_time += t;
208 }
209 if let Some(t) = stage.cost.tokens_est {
210 total_tokens += t;
211 }
212 if let Some(m) = stage.cost.memory_mb {
213 max_memory = max_memory.max(m);
214 }
215 }
216 }
217
218 CostSummary {
219 total_time_ms_p50: if total_time > 0 {
220 Some(total_time)
221 } else {
222 None
223 },
224 total_tokens_est: if total_tokens > 0 {
225 Some(total_tokens)
226 } else {
227 None
228 },
229 total_memory_mb_peak: if max_memory > 0 {
230 Some(max_memory)
231 } else {
232 None
233 },
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use noether_store::MemoryStore;
    use std::collections::BTreeMap;

    /// Builds a signature-pinned `Stage` node with no config.
    fn stage(id: &str) -> CompositionNode {
        CompositionNode::Stage {
            id: StageId(id.into()),
            pinning: crate::lagrange::Pinning::Signature,
            config: None,
        }
    }

    /// Plans `node` against a fresh, empty in-memory store.
    fn plan_without_metadata(node: &CompositionNode) -> ExecutionPlan {
        let store = MemoryStore::new();
        plan_graph(node, &store)
    }

    #[test]
    fn plan_single_stage() {
        let plan = plan_without_metadata(&stage("a"));
        assert_eq!(plan.steps.len(), 1);

        let step = &plan.steps[0];
        assert_eq!(step.stage_id, StageId("a".into()));
        assert!(step.depends_on.is_empty());
    }

    #[test]
    fn plan_sequential_has_dependencies() {
        let chain = CompositionNode::Sequential {
            stages: ["a", "b", "c"].iter().copied().map(stage).collect(),
        };
        let plan = plan_without_metadata(&chain);
        assert_eq!(plan.steps.len(), 3);
        assert!(plan.steps[0].depends_on.is_empty());

        // Every later step depends exactly on its immediate predecessor.
        for (idx, step) in plan.steps.iter().enumerate().skip(1) {
            assert_eq!(step.depends_on, vec![idx - 1]);
        }
    }

    #[test]
    fn plan_parallel_creates_group() {
        let fork = CompositionNode::Parallel {
            branches: BTreeMap::from([("a".into(), stage("s1")), ("b".into(), stage("s2"))]),
        };
        let plan = plan_without_metadata(&fork);
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.parallel_groups.len(), 1);
        assert_eq!(plan.parallel_groups[0].len(), 2);
    }

    #[test]
    fn plan_sequential_with_parallel() {
        let pipeline = CompositionNode::Sequential {
            stages: vec![
                stage("input"),
                CompositionNode::Parallel {
                    branches: BTreeMap::from([
                        ("a".into(), stage("s1")),
                        ("b".into(), stage("s2")),
                    ]),
                },
                stage("output"),
            ],
        };
        let plan = plan_without_metadata(&pipeline);
        assert_eq!(plan.steps.len(), 4);

        // Both branch steps hang off the input step...
        for branch_step in &plan.steps[1..3] {
            assert!(branch_step.depends_on.contains(&0));
        }

        // ...and the output step waits on both branches.
        let output_deps = &plan.steps[3].depends_on;
        assert!(output_deps.contains(&1));
        assert!(output_deps.contains(&2));
    }
}