1use crate::lagrange::CompositionNode;
2use noether_core::effects::Effect;
3use noether_core::stage::StageId;
4use noether_store::StageStore;
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
8pub enum ExecutionMode {
9 Inline,
10 Process,
11 Remote,
12}
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct ExecutionStep {
16 pub step_index: usize,
17 pub stage_id: StageId,
18 pub mode: ExecutionMode,
19 pub depends_on: Vec<usize>,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct CostSummary {
24 pub total_time_ms_p50: Option<u64>,
25 pub total_tokens_est: Option<u64>,
26 pub total_memory_mb_peak: Option<u64>,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct ExecutionPlan {
31 pub steps: Vec<ExecutionStep>,
32 pub cost: CostSummary,
33 pub parallel_groups: Vec<Vec<usize>>,
34}
35
36pub fn plan_graph(node: &CompositionNode, store: &(impl StageStore + ?Sized)) -> ExecutionPlan {
38 let mut steps = Vec::new();
39 let mut parallel_groups = Vec::new();
40 flatten_node(node, &mut steps, &mut parallel_groups, store, &[]);
41
42 let cost = estimate_cost(&steps, store);
43
44 ExecutionPlan {
45 steps,
46 cost,
47 parallel_groups,
48 }
49}
50
51fn flatten_node(
53 node: &CompositionNode,
54 steps: &mut Vec<ExecutionStep>,
55 parallel_groups: &mut Vec<Vec<usize>>,
56 store: &(impl StageStore + ?Sized),
57 depends_on: &[usize],
58) -> Vec<usize> {
59 match node {
60 CompositionNode::Stage { id, .. } => {
61 let idx = steps.len();
62 steps.push(ExecutionStep {
63 step_index: idx,
64 stage_id: id.clone(),
65 mode: ExecutionMode::Inline,
66 depends_on: depends_on.to_vec(),
67 });
68 vec![idx]
69 }
70 CompositionNode::Const { .. } => {
71 depends_on.to_vec()
74 }
75 CompositionNode::RemoteStage { .. } => {
76 depends_on.to_vec()
79 }
80 CompositionNode::Sequential { stages } => {
81 let mut prev_indices = depends_on.to_vec();
82
83 let start_step = steps.len();
84 for stage in stages {
85 prev_indices = flatten_node(stage, steps, parallel_groups, store, &prev_indices);
86 }
87 let end_step = steps.len();
88
89 let all_direct_pure_stages = stages.iter().all(|s| {
92 if let CompositionNode::Stage { id, .. } = s {
93 store
94 .get(id)
95 .ok()
96 .flatten()
97 .map(|st| st.signature.effects.contains(&Effect::Pure))
98 .unwrap_or(false)
99 } else {
100 false
101 }
102 });
103
104 if all_direct_pure_stages && stages.len() > 1 {
105 let group: Vec<usize> = (start_step..end_step).collect();
106 if group.len() > 1 {
107 parallel_groups.push(group);
108 }
109 }
110
111 prev_indices
112 }
113 CompositionNode::Parallel { branches } => {
114 let mut group = Vec::new();
115 let mut all_outputs = Vec::new();
116 for node in branches.values() {
117 let outputs = flatten_node(node, steps, parallel_groups, store, depends_on);
118 if let Some(&first) = outputs.first() {
120 group.push(first);
121 }
122 all_outputs.extend(outputs);
123 }
124 if group.len() > 1 {
125 parallel_groups.push(group);
126 }
127 all_outputs
128 }
129 CompositionNode::Branch {
130 predicate,
131 if_true,
132 if_false,
133 } => {
134 let pred_out = flatten_node(predicate, steps, parallel_groups, store, depends_on);
135 let true_out = flatten_node(if_true, steps, parallel_groups, store, &pred_out);
136 let false_out = flatten_node(if_false, steps, parallel_groups, store, &pred_out);
137 let mut combined = true_out;
138 combined.extend(false_out);
139 combined
140 }
141 CompositionNode::Fanout { source, targets } => {
142 let source_out = flatten_node(source, steps, parallel_groups, store, depends_on);
143 let mut group = Vec::new();
144 let mut all_outputs = Vec::new();
145 for target in targets {
146 let outputs = flatten_node(target, steps, parallel_groups, store, &source_out);
147 if let Some(&first) = outputs.first() {
148 group.push(first);
149 }
150 all_outputs.extend(outputs);
151 }
152 if group.len() > 1 {
153 parallel_groups.push(group);
154 }
155 all_outputs
156 }
157 CompositionNode::Merge { sources, target } => {
158 let mut all_source_outputs = Vec::new();
159 let mut group = Vec::new();
160 for src in sources {
161 let outputs = flatten_node(src, steps, parallel_groups, store, depends_on);
162 if let Some(&first) = outputs.first() {
163 group.push(first);
164 }
165 all_source_outputs.extend(outputs);
166 }
167 if group.len() > 1 {
168 parallel_groups.push(group);
169 }
170 flatten_node(target, steps, parallel_groups, store, &all_source_outputs)
171 }
172 CompositionNode::Retry { stage, .. } => {
173 flatten_node(stage, steps, parallel_groups, store, depends_on)
174 }
175 CompositionNode::Let { bindings, body } => {
176 let mut group = Vec::new();
179 let mut binding_outputs = Vec::new();
180 for node in bindings.values() {
181 let outputs = flatten_node(node, steps, parallel_groups, store, depends_on);
182 if let Some(&first) = outputs.first() {
183 group.push(first);
184 }
185 binding_outputs.extend(outputs);
186 }
187 if group.len() > 1 {
188 parallel_groups.push(group);
189 }
190 let mut body_deps = depends_on.to_vec();
193 body_deps.extend(binding_outputs);
194 flatten_node(body, steps, parallel_groups, store, &body_deps)
195 }
196 }
197}
198
199fn estimate_cost(steps: &[ExecutionStep], store: &(impl StageStore + ?Sized)) -> CostSummary {
200 let mut total_time: u64 = 0;
201 let mut total_tokens: u64 = 0;
202 let mut max_memory: u64 = 0;
203
204 for step in steps {
205 if let Ok(Some(stage)) = store.get(&step.stage_id) {
206 if let Some(t) = stage.cost.time_ms_p50 {
207 total_time += t;
208 }
209 if let Some(t) = stage.cost.tokens_est {
210 total_tokens += t;
211 }
212 if let Some(m) = stage.cost.memory_mb {
213 max_memory = max_memory.max(m);
214 }
215 }
216 }
217
218 CostSummary {
219 total_time_ms_p50: if total_time > 0 {
220 Some(total_time)
221 } else {
222 None
223 },
224 total_tokens_est: if total_tokens > 0 {
225 Some(total_tokens)
226 } else {
227 None
228 },
229 total_memory_mb_peak: if max_memory > 0 {
230 Some(max_memory)
231 } else {
232 None
233 },
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use noether_store::MemoryStore;
    use std::collections::BTreeMap;

    /// Builds a leaf stage node with the given id and no config.
    fn stage(id: &str) -> CompositionNode {
        let id = StageId(id.into());
        CompositionNode::Stage { id, config: None }
    }

    /// Builds a parallel node with branches `a -> s1` and `b -> s2`.
    fn two_branch_parallel() -> CompositionNode {
        CompositionNode::Parallel {
            branches: BTreeMap::from([("a".into(), stage("s1")), ("b".into(), stage("s2"))]),
        }
    }

    #[test]
    fn plan_single_stage() {
        let store = MemoryStore::new();
        let plan = plan_graph(&stage("a"), &store);

        assert_eq!(plan.steps.len(), 1);
        let only = &plan.steps[0];
        assert_eq!(only.stage_id, StageId("a".into()));
        assert!(only.depends_on.is_empty());
    }

    #[test]
    fn plan_sequential_has_dependencies() {
        let store = MemoryStore::new();
        let chain = CompositionNode::Sequential {
            stages: vec![stage("a"), stage("b"), stage("c")],
        };

        let plan = plan_graph(&chain, &store);

        assert_eq!(plan.steps.len(), 3);
        // Each step depends on exactly its predecessor; the first on nothing.
        assert!(plan.steps[0].depends_on.is_empty());
        for idx in 1..3 {
            assert_eq!(plan.steps[idx].depends_on, vec![idx - 1]);
        }
    }

    #[test]
    fn plan_parallel_creates_group() {
        let store = MemoryStore::new();

        let plan = plan_graph(&two_branch_parallel(), &store);

        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.parallel_groups.len(), 1);
        assert_eq!(plan.parallel_groups[0].len(), 2);
    }

    #[test]
    fn plan_sequential_with_parallel() {
        let store = MemoryStore::new();
        let pipeline = CompositionNode::Sequential {
            stages: vec![stage("input"), two_branch_parallel(), stage("output")],
        };

        let plan = plan_graph(&pipeline, &store);

        assert_eq!(plan.steps.len(), 4);
        // Both parallel branches hang off the input step...
        for branch_step in 1..=2 {
            assert!(plan.steps[branch_step].depends_on.contains(&0));
        }
        // ...and the output step waits for both branches.
        for branch_step in 1..=2 {
            assert!(plan.steps[3].depends_on.contains(&branch_step));
        }
    }
}