1use serde::{Deserialize, Serialize};
7use somatize_core::cache::CacheKey;
8use somatize_core::filter::RemoteTarget;
9use somatize_core::graph::NodeId;
10use std::fmt;
11
/// A tree-shaped plan over graph nodes, each referenced by a [`NodeId`].
///
/// Leaves name a single node (`Execute`, `Cached`) or a fixed chain of nodes
/// (`Composite`, `Stream`). `Sequence`, `Parallel`, `Loop`, `Branch` and
/// `Remote` nest further plans, and `Empty` is a no-op placeholder.
///
/// The type derives `Serialize`/`Deserialize`. Assuming no container-level
/// `#[serde(...)]` attributes, serde's default externally tagged
/// representation applies, so renaming a variant or field changes the
/// serialized format. Because of `#[non_exhaustive]`, code outside this crate
/// must include a wildcard arm when matching on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ExecutionPlan {
    /// Sub-plans listed in order. The helpers in this module treat the list
    /// as a chain (`to_mermaid` links consecutive steps with edges).
    Sequence(Vec<ExecutionPlan>),

    /// Sub-plans with no ordering between them. Presumably they are run
    /// concurrently; the executor is not in this file, so confirm there.
    /// Each entry is one branch.
    Parallel(Vec<ExecutionPlan>),

    /// A single node; counts as one node.
    Execute { node_id: NodeId },

    /// A single node paired with the cache key of its result. Presumably the
    /// result is served from the cache instead of re-running the node; verify
    /// against the executor.
    Cached { node_id: NodeId, key: CacheKey },

    /// A loop owned by the node `node_id`, repeating `body`.
    Loop {
        /// The loop's own node (counted and listed in addition to `body`).
        node_id: NodeId,
        /// Sub-plan repeated by the loop.
        body: Box<ExecutionPlan>,
        /// Iteration cap. `None` presumably means "no explicit cap";
        /// TODO confirm against the executor.
        max_iterations: Option<usize>,
    },

    /// A branching node with labelled arms. How an arm is selected is not
    /// shown in this file.
    Branch {
        /// The node that owns the branch.
        node_id: NodeId,
        /// `(label, sub-plan)` pairs, one per arm.
        arms: Vec<(String, ExecutionPlan)>,
    },

    /// A sub-plan associated with a remote target (presumably dispatched
    /// there for execution; confirm with the executor).
    Remote {
        /// The node that represents the remote hop.
        node_id: NodeId,
        /// Where the sub-plan is sent; only its `Debug` form is used here.
        target: RemoteTarget,
        /// The sub-plan associated with `target`.
        plan: Box<ExecutionPlan>,
    },

    /// An ordered chain of node ids treated as one plan step. Presumably the
    /// chain is fused into a single execution unit; confirm in the executor.
    Composite { node_ids: Vec<NodeId> },

    /// An ordered chain of node ids processed in chunks.
    Stream {
        /// The chained nodes, in order.
        node_ids: Vec<NodeId>,
        /// Chunk size. The unit (rows, bytes, ...) is not established in
        /// this file — TODO confirm.
        chunk_size: usize,
    },

    /// A no-op plan: it contributes nothing to any count or id list.
    Empty,
}
68
69impl ExecutionPlan {
70 pub fn node_count(&self) -> usize {
72 match self {
73 Self::Execute { .. } | Self::Cached { .. } => 1,
74 Self::Composite { node_ids } | Self::Stream { node_ids, .. } => node_ids.len(),
75 Self::Sequence(steps) | Self::Parallel(steps) => {
76 steps.iter().map(|s| s.node_count()).sum()
77 }
78 Self::Loop { body, .. } => 1 + body.node_count(),
79 Self::Branch { arms, .. } => {
80 1 + arms.iter().map(|(_, p)| p.node_count()).sum::<usize>()
81 }
82 Self::Remote { plan, .. } => plan.node_count(),
83 Self::Empty => 0,
84 }
85 }
86
87 pub fn cached_count(&self) -> usize {
89 match self {
90 Self::Cached { .. } => 1,
91 Self::Execute { .. } | Self::Composite { .. } | Self::Stream { .. } => 0,
92 Self::Sequence(steps) | Self::Parallel(steps) => {
93 steps.iter().map(|s| s.cached_count()).sum()
94 }
95 Self::Loop { body, .. } => body.cached_count(),
96 Self::Branch { arms, .. } => arms.iter().map(|(_, p)| p.cached_count()).sum(),
97 Self::Remote { plan, .. } => plan.cached_count(),
98 Self::Empty => 0,
99 }
100 }
101
102 pub fn parallel_branch_count(&self) -> usize {
104 match self {
105 Self::Parallel(branches) => branches.len(),
106 Self::Sequence(steps) => steps.iter().map(|s| s.parallel_branch_count()).sum(),
107 Self::Execute { .. }
108 | Self::Cached { .. }
109 | Self::Loop { .. }
110 | Self::Branch { .. }
111 | Self::Remote { .. }
112 | Self::Composite { .. }
113 | Self::Stream { .. }
114 | Self::Empty => 0,
115 }
116 }
117
118 pub fn node_ids(&self) -> Vec<&str> {
120 match self {
121 Self::Execute { node_id } | Self::Cached { node_id, .. } => vec![node_id.as_str()],
122 Self::Sequence(steps) | Self::Parallel(steps) => {
123 steps.iter().flat_map(|s| s.node_ids()).collect()
124 }
125 Self::Loop { node_id, body, .. } => {
126 let mut ids = vec![node_id.as_str()];
127 ids.extend(body.node_ids());
128 ids
129 }
130 Self::Branch { node_id, arms, .. } => {
131 let mut ids = vec![node_id.as_str()];
132 for (_, p) in arms {
133 ids.extend(p.node_ids());
134 }
135 ids
136 }
137 Self::Remote { node_id, plan, .. } => {
138 let mut ids = vec![node_id.as_str()];
139 ids.extend(plan.node_ids());
140 ids
141 }
142 Self::Composite { node_ids } | Self::Stream { node_ids, .. } => {
143 node_ids.iter().map(|s| s.as_str()).collect()
144 }
145 Self::Empty => vec![],
146 }
147 }
148
149 pub fn summary(&self) -> somatize_core::event::PlanSummary {
151 somatize_core::event::PlanSummary {
152 total_nodes: self.node_count(),
153 cached_nodes: self.cached_count(),
154 parallel_branches: self.parallel_branch_count(),
155 }
156 }
157
158 pub fn simplify(self) -> Self {
160 match self {
161 Self::Sequence(mut steps) => {
162 steps = steps.into_iter().map(|s| s.simplify()).collect();
163 steps.retain(|s| !matches!(s, Self::Empty));
164 match steps.len() {
165 0 => Self::Empty,
166 1 => steps.into_iter().next().unwrap(),
167 _ => Self::Sequence(steps),
168 }
169 }
170 Self::Parallel(mut branches) => {
171 branches = branches.into_iter().map(|b| b.simplify()).collect();
172 branches.retain(|b| !matches!(b, Self::Empty));
173 match branches.len() {
174 0 => Self::Empty,
175 1 => branches.into_iter().next().unwrap(),
176 _ => Self::Parallel(branches),
177 }
178 }
179 other => other,
180 }
181 }
182}
183
184impl ExecutionPlan {
185 pub fn to_mermaid(&self) -> String {
187 let mut out = String::from("graph TD\n");
188 let mut counter = 0;
189 self.mermaid_nodes(&mut out, &mut counter, None);
190 out
191 }
192
193 fn mermaid_nodes(&self, out: &mut String, counter: &mut usize, parent: Option<&str>) {
194 use std::fmt::Write;
195 match self {
196 Self::Execute { node_id } => {
197 let _ = writeln!(out, " {node_id}[{node_id}]");
198 if let Some(p) = parent {
199 let _ = writeln!(out, " {p} --> {node_id}");
200 }
201 }
202 Self::Cached { node_id, .. } => {
203 let _ = writeln!(out, " {node_id}[/{node_id} cached/]");
204 if let Some(p) = parent {
205 let _ = writeln!(out, " {p} --> {node_id}");
206 }
207 }
208 Self::Sequence(steps) => {
209 let mut prev = parent.map(String::from);
210 for step in steps {
211 step.mermaid_nodes(out, counter, prev.as_deref());
212 prev = step.first_node_id().map(String::from);
213 }
214 }
215 Self::Parallel(branches) => {
216 let fork_id = format!("fork_{counter}");
217 *counter += 1;
218 let _ = writeln!(out, " {fork_id}{{{{fork}}}}");
219 if let Some(p) = parent {
220 let _ = writeln!(out, " {p} --> {fork_id}");
221 }
222 for branch in branches {
223 branch.mermaid_nodes(out, counter, Some(&fork_id));
224 }
225 }
226 Self::Loop {
227 node_id,
228 body,
229 max_iterations,
230 } => {
231 let label = match max_iterations {
232 Some(n) => format!("{node_id} loop max={n}"),
233 None => format!("{node_id} loop"),
234 };
235 let _ = writeln!(out, " {node_id}(({label}))");
236 if let Some(p) = parent {
237 let _ = writeln!(out, " {p} --> {node_id}");
238 }
239 body.mermaid_nodes(out, counter, Some(node_id));
240 }
241 Self::Branch { node_id, arms } => {
242 let _ = writeln!(out, " {node_id}{{{{{node_id}}}}}");
243 if let Some(p) = parent {
244 let _ = writeln!(out, " {p} --> {node_id}");
245 }
246 for (label, plan) in arms {
247 let arm_id = format!("arm_{counter}");
248 *counter += 1;
249 let _ = writeln!(out, " {node_id} -->|{label}| {arm_id}[{label}]");
250 plan.mermaid_nodes(out, counter, Some(&arm_id));
251 }
252 }
253 Self::Remote {
254 node_id,
255 target,
256 plan,
257 } => {
258 let _ = writeln!(out, " {node_id}>{{{node_id} remote: {target:?}}}]");
259 if let Some(p) = parent {
260 let _ = writeln!(out, " {p} --> {node_id}");
261 }
262 plan.mermaid_nodes(out, counter, Some(node_id));
263 }
264 Self::Composite { node_ids } | Self::Stream { node_ids, .. } => {
265 use std::fmt::Write;
266 let stream_label = matches!(self, Self::Stream { .. });
267 let mut prev: Option<&str> = None;
268 for nid in node_ids {
269 if stream_label {
270 let _ = writeln!(out, " {nid}([{nid} stream])");
271 } else {
272 let _ = writeln!(out, " {nid}[{nid}]");
273 }
274 if let Some(p) = prev.or(parent) {
275 let _ = writeln!(out, " {p} --> {nid}");
276 }
277 prev = Some(nid);
278 }
279 }
280 Self::Empty => {}
281 }
282 }
283
284 fn first_node_id(&self) -> Option<&str> {
285 match self {
286 Self::Execute { node_id } | Self::Cached { node_id, .. } => Some(node_id),
287 Self::Sequence(steps) => steps.first().and_then(|s| s.first_node_id()),
288 Self::Parallel(_) => None,
289 Self::Loop { node_id, .. }
290 | Self::Branch { node_id, .. }
291 | Self::Remote { node_id, .. } => Some(node_id),
292 Self::Composite { node_ids } | Self::Stream { node_ids, .. } => {
293 node_ids.first().map(|s| s.as_str())
294 }
295 Self::Empty => None,
296 }
297 }
298}
299
300impl fmt::Display for ExecutionPlan {
301 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
302 self.fmt_indent(f, 0)
303 }
304}
305
306impl ExecutionPlan {
307 fn fmt_indent(&self, f: &mut fmt::Formatter<'_>, indent: usize) -> fmt::Result {
308 let pad = " ".repeat(indent);
309 match self {
310 Self::Sequence(steps) => {
311 writeln!(f, "{pad}Sequence:")?;
312 for step in steps {
313 step.fmt_indent(f, indent + 1)?;
314 }
315 Ok(())
316 }
317 Self::Parallel(branches) => {
318 writeln!(f, "{pad}Parallel:")?;
319 for branch in branches {
320 branch.fmt_indent(f, indent + 1)?;
321 }
322 Ok(())
323 }
324 Self::Execute { node_id } => writeln!(f, "{pad}Execute({node_id})"),
325 Self::Cached { node_id, key } => writeln!(f, "{pad}Cached({node_id}, {key})"),
326 Self::Loop {
327 node_id,
328 body,
329 max_iterations,
330 } => {
331 writeln!(f, "{pad}Loop({node_id}, max={max_iterations:?}):")?;
332 body.fmt_indent(f, indent + 1)
333 }
334 Self::Branch { node_id, arms } => {
335 writeln!(f, "{pad}Branch({node_id}):")?;
336 for (label, plan) in arms {
337 writeln!(f, "{pad} [{label}]:")?;
338 plan.fmt_indent(f, indent + 2)?;
339 }
340 Ok(())
341 }
342 Self::Remote {
343 node_id,
344 target,
345 plan,
346 } => {
347 writeln!(f, "{pad}Remote({node_id}, target={target:?}):")?;
348 plan.fmt_indent(f, indent + 1)
349 }
350 Self::Composite { node_ids } => {
351 let ids = node_ids
352 .iter()
353 .map(|s| s.as_str())
354 .collect::<Vec<_>>()
355 .join(" \u{2192} ");
356 writeln!(f, "{pad}Composite[{ids}]")
357 }
358 Self::Stream {
359 node_ids,
360 chunk_size,
361 } => {
362 let ids = node_ids
363 .iter()
364 .map(|s| s.as_str())
365 .collect::<Vec<_>>()
366 .join(" \u{2192} ");
367 writeln!(f, "{pad}Stream[{ids}](chunk_size={chunk_size})")
368 }
369 Self::Empty => writeln!(f, "{pad}Empty"),
370 }
371 }
372}
373
374#[cfg(test)]
375mod tests {
376 use super::*;
377
378 #[test]
379 fn node_count_linear() {
380 let plan = ExecutionPlan::Sequence(vec![
381 ExecutionPlan::Execute {
382 node_id: "a".into(),
383 },
384 ExecutionPlan::Execute {
385 node_id: "b".into(),
386 },
387 ExecutionPlan::Execute {
388 node_id: "c".into(),
389 },
390 ]);
391 assert_eq!(plan.node_count(), 3);
392 assert_eq!(plan.cached_count(), 0);
393 }
394
395 #[test]
396 fn cached_count() {
397 let plan = ExecutionPlan::Sequence(vec![
398 ExecutionPlan::Cached {
399 node_id: "a".into(),
400 key: CacheKey::hash_data(b"a"),
401 },
402 ExecutionPlan::Execute {
403 node_id: "b".into(),
404 },
405 ExecutionPlan::Cached {
406 node_id: "c".into(),
407 key: CacheKey::hash_data(b"c"),
408 },
409 ]);
410 assert_eq!(plan.node_count(), 3);
411 assert_eq!(plan.cached_count(), 2);
412 }
413
414 #[test]
415 fn parallel_branch_count() {
416 let plan = ExecutionPlan::Sequence(vec![
417 ExecutionPlan::Execute {
418 node_id: "a".into(),
419 },
420 ExecutionPlan::Parallel(vec![
421 ExecutionPlan::Execute {
422 node_id: "b".into(),
423 },
424 ExecutionPlan::Execute {
425 node_id: "c".into(),
426 },
427 ExecutionPlan::Execute {
428 node_id: "d".into(),
429 },
430 ]),
431 ExecutionPlan::Execute {
432 node_id: "e".into(),
433 },
434 ]);
435 assert_eq!(plan.parallel_branch_count(), 3);
436 assert_eq!(plan.node_count(), 5);
437 }
438
439 #[test]
440 fn node_ids_collected() {
441 let plan = ExecutionPlan::Sequence(vec![
442 ExecutionPlan::Cached {
443 node_id: "a".into(),
444 key: CacheKey::hash_data(b"a"),
445 },
446 ExecutionPlan::Execute {
447 node_id: "b".into(),
448 },
449 ]);
450 let ids = plan.node_ids();
451 assert_eq!(ids, vec!["a", "b"]);
452 }
453
454 #[test]
455 fn simplify_removes_empty() {
456 let plan = ExecutionPlan::Sequence(vec![
457 ExecutionPlan::Empty,
458 ExecutionPlan::Execute {
459 node_id: "a".into(),
460 },
461 ExecutionPlan::Empty,
462 ]);
463 let simplified = plan.simplify();
464 assert!(matches!(simplified, ExecutionPlan::Execute { .. }));
465 }
466
467 #[test]
468 fn simplify_unwraps_single_element() {
469 let plan = ExecutionPlan::Sequence(vec![ExecutionPlan::Execute {
470 node_id: "a".into(),
471 }]);
472 let simplified = plan.simplify();
473 assert!(matches!(simplified, ExecutionPlan::Execute { .. }));
474 }
475
476 #[test]
477 fn simplify_preserves_multi() {
478 let plan = ExecutionPlan::Sequence(vec![
479 ExecutionPlan::Execute {
480 node_id: "a".into(),
481 },
482 ExecutionPlan::Execute {
483 node_id: "b".into(),
484 },
485 ]);
486 let simplified = plan.simplify();
487 assert!(matches!(simplified, ExecutionPlan::Sequence(_)));
488 }
489
490 #[test]
491 fn display_format() {
492 let plan = ExecutionPlan::Sequence(vec![
493 ExecutionPlan::Execute {
494 node_id: "scaler".into(),
495 },
496 ExecutionPlan::Parallel(vec![
497 ExecutionPlan::Execute {
498 node_id: "pca".into(),
499 },
500 ExecutionPlan::Execute {
501 node_id: "umap".into(),
502 },
503 ]),
504 ExecutionPlan::Execute {
505 node_id: "svm".into(),
506 },
507 ]);
508 let output = format!("{plan}");
509 assert!(output.contains("Sequence:"));
510 assert!(output.contains("Parallel:"));
511 assert!(output.contains("Execute(scaler)"));
512 assert!(output.contains("Execute(pca)"));
513 }
514
515 #[test]
516 fn summary_values() {
517 let plan = ExecutionPlan::Sequence(vec![
518 ExecutionPlan::Cached {
519 node_id: "a".into(),
520 key: CacheKey::hash_data(b"a"),
521 },
522 ExecutionPlan::Parallel(vec![
523 ExecutionPlan::Execute {
524 node_id: "b".into(),
525 },
526 ExecutionPlan::Execute {
527 node_id: "c".into(),
528 },
529 ]),
530 ExecutionPlan::Execute {
531 node_id: "d".into(),
532 },
533 ]);
534 let summary = plan.summary();
535 assert_eq!(summary.total_nodes, 4);
536 assert_eq!(summary.cached_nodes, 1);
537 assert_eq!(summary.parallel_branches, 2);
538 }
539
540 #[test]
541 fn serde_roundtrip() {
542 let plan = ExecutionPlan::Sequence(vec![
543 ExecutionPlan::Cached {
544 node_id: "a".into(),
545 key: CacheKey::hash_data(b"test"),
546 },
547 ExecutionPlan::Execute {
548 node_id: "b".into(),
549 },
550 ]);
551 let json = serde_json::to_string(&plan).unwrap();
552 let deserialized: ExecutionPlan = serde_json::from_str(&json).unwrap();
553 assert_eq!(deserialized.node_count(), 2);
554 }
555
556 #[test]
557 fn empty_plan() {
558 let plan = ExecutionPlan::Empty;
559 assert_eq!(plan.node_count(), 0);
560 assert_eq!(plan.cached_count(), 0);
561 assert!(plan.node_ids().is_empty());
562 }
563
564 #[test]
565 fn to_mermaid_sequence() {
566 let plan = ExecutionPlan::Sequence(vec![
567 ExecutionPlan::Execute {
568 node_id: "scaler".into(),
569 },
570 ExecutionPlan::Execute {
571 node_id: "model".into(),
572 },
573 ]);
574 let m = plan.to_mermaid();
575 assert!(m.starts_with("graph TD"));
576 assert!(m.contains("scaler[scaler]"));
577 assert!(m.contains("model[model]"));
578 assert!(m.contains("scaler --> model"));
579 }
580
581 #[test]
582 fn to_mermaid_parallel() {
583 let plan = ExecutionPlan::Parallel(vec![
584 ExecutionPlan::Execute {
585 node_id: "a".into(),
586 },
587 ExecutionPlan::Execute {
588 node_id: "b".into(),
589 },
590 ]);
591 let m = plan.to_mermaid();
592 assert!(m.contains("fork_0{"));
593 assert!(m.contains("fork_0 --> a"));
594 assert!(m.contains("fork_0 --> b"));
595 }
596
597 #[test]
598 fn to_mermaid_cached() {
599 let plan = ExecutionPlan::Cached {
600 node_id: "x".into(),
601 key: CacheKey::hash_data(b"x"),
602 };
603 let m = plan.to_mermaid();
604 assert!(m.contains("x[/x cached/]"));
605 }
606}