Skip to main content

somatize_compiler/
plan.rs

//! Execution plan — the compiled representation of a pipeline.
//!
//! Variants: Sequence, Parallel, Execute, Cached, Loop, Branch, Remote, Composite, Stream, Empty.
//! Plans are data-free (no filter implementations) and serializable.
5
6use serde::{Deserialize, Serialize};
7use somatize_core::cache::CacheKey;
8use somatize_core::filter::RemoteTarget;
9use somatize_core::graph::NodeId;
10use std::fmt;
11
/// A compiled execution plan produced by the compiler.
///
/// This is a recursive tree that the runtime walks to execute a pipeline.
/// The compiler resolves caching, parallelism, and distribution before
/// the runtime sees the plan.
///
/// Plans refer to graph nodes only by [`NodeId`]; they hold no filter
/// implementations, which is what lets the whole tree derive
/// `Serialize`/`Deserialize`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ExecutionPlan {
    /// Execute steps sequentially, one after another.
    Sequence(Vec<ExecutionPlan>),

    /// Execute branches concurrently (fork-join).
    Parallel(Vec<ExecutionPlan>),

    /// Execute a single filter node.
    Execute { node_id: NodeId },

    /// Load result from cache (resolved at compile time).
    ///
    /// `key` is the cache entry holding the result for `node_id`.
    Cached { node_id: NodeId, key: CacheKey },

    /// Iterate: execute body for each item in a collection.
    ///
    /// `node_id` is the loop's own node; it is counted and listed alongside
    /// the nodes of `body`. `max_iterations` is an optional cap on the number
    /// of iterations. NOTE(review): `None` presumably means "unbounded" —
    /// confirm against the runtime.
    Loop {
        node_id: NodeId,
        body: Box<ExecutionPlan>,
        max_iterations: Option<usize>,
    },

    /// Conditional branching: evaluate condition, pick an arm.
    ///
    /// Each arm is a `(label, plan)` pair; the label is what `Display` and
    /// the Mermaid renderer show for the arm. NOTE(review): presumably the
    /// label is matched against the condition's outcome — confirm.
    Branch {
        node_id: NodeId,
        arms: Vec<(String, ExecutionPlan)>,
    },

    /// Execute a sub-plan on a remote worker.
    ///
    /// `target` identifies where `plan` runs. NOTE(review): `node_id` is
    /// reported by `node_ids()` but not counted by `node_count()`.
    Remote {
        node_id: NodeId,
        target: RemoteTarget,
        plan: Box<ExecutionPlan>,
    },

    /// Execute multiple differentiable nodes as a single block.
    /// The executor passes tensors directly between filters (no Value conversion),
    /// preserving PyTorch autograd for gradient flow.
    ///
    /// `node_ids` is rendered as a chain in vector order (`a → b → c`).
    Composite { node_ids: Vec<NodeId> },

    /// Streaming execution: process input in chunks through a filter chain.
    /// Each filter's StreamMode (FixedState/Evolving/Barrier) defines its
    /// per-chunk contract. Results flow progressively — no full materialization.
    ///
    /// `node_ids` is the chain in vector order; `chunk_size` is the size of
    /// each input chunk (units are defined by the runtime — TODO confirm).
    Stream {
        node_ids: Vec<NodeId>,
        chunk_size: usize,
    },

    /// No-op: nothing to execute (e.g. empty graph).
    Empty,
}
68
69impl ExecutionPlan {
70    /// Count total nodes in the plan (Execute + Cached).
71    pub fn node_count(&self) -> usize {
72        match self {
73            Self::Execute { .. } | Self::Cached { .. } => 1,
74            Self::Composite { node_ids } | Self::Stream { node_ids, .. } => node_ids.len(),
75            Self::Sequence(steps) | Self::Parallel(steps) => {
76                steps.iter().map(|s| s.node_count()).sum()
77            }
78            Self::Loop { body, .. } => 1 + body.node_count(),
79            Self::Branch { arms, .. } => {
80                1 + arms.iter().map(|(_, p)| p.node_count()).sum::<usize>()
81            }
82            Self::Remote { plan, .. } => plan.node_count(),
83            Self::Empty => 0,
84        }
85    }
86
87    /// Count cached nodes in the plan.
88    pub fn cached_count(&self) -> usize {
89        match self {
90            Self::Cached { .. } => 1,
91            Self::Execute { .. } | Self::Composite { .. } | Self::Stream { .. } => 0,
92            Self::Sequence(steps) | Self::Parallel(steps) => {
93                steps.iter().map(|s| s.cached_count()).sum()
94            }
95            Self::Loop { body, .. } => body.cached_count(),
96            Self::Branch { arms, .. } => arms.iter().map(|(_, p)| p.cached_count()).sum(),
97            Self::Remote { plan, .. } => plan.cached_count(),
98            Self::Empty => 0,
99        }
100    }
101
102    /// Count parallel branches at the top level of the plan.
103    pub fn parallel_branch_count(&self) -> usize {
104        match self {
105            Self::Parallel(branches) => branches.len(),
106            Self::Sequence(steps) => steps.iter().map(|s| s.parallel_branch_count()).sum(),
107            Self::Execute { .. }
108            | Self::Cached { .. }
109            | Self::Loop { .. }
110            | Self::Branch { .. }
111            | Self::Remote { .. }
112            | Self::Composite { .. }
113            | Self::Stream { .. }
114            | Self::Empty => 0,
115        }
116    }
117
118    /// Collect all node IDs referenced in the plan.
119    pub fn node_ids(&self) -> Vec<&str> {
120        match self {
121            Self::Execute { node_id } | Self::Cached { node_id, .. } => vec![node_id.as_str()],
122            Self::Sequence(steps) | Self::Parallel(steps) => {
123                steps.iter().flat_map(|s| s.node_ids()).collect()
124            }
125            Self::Loop { node_id, body, .. } => {
126                let mut ids = vec![node_id.as_str()];
127                ids.extend(body.node_ids());
128                ids
129            }
130            Self::Branch { node_id, arms, .. } => {
131                let mut ids = vec![node_id.as_str()];
132                for (_, p) in arms {
133                    ids.extend(p.node_ids());
134                }
135                ids
136            }
137            Self::Remote { node_id, plan, .. } => {
138                let mut ids = vec![node_id.as_str()];
139                ids.extend(plan.node_ids());
140                ids
141            }
142            Self::Composite { node_ids } | Self::Stream { node_ids, .. } => {
143                node_ids.iter().map(|s| s.as_str()).collect()
144            }
145            Self::Empty => vec![],
146        }
147    }
148
149    /// Create a PlanSummary for event payloads.
150    pub fn summary(&self) -> somatize_core::event::PlanSummary {
151        somatize_core::event::PlanSummary {
152            total_nodes: self.node_count(),
153            cached_nodes: self.cached_count(),
154            parallel_branches: self.parallel_branch_count(),
155        }
156    }
157
158    /// Flatten unnecessary nesting (e.g. Sequence of one element).
159    pub fn simplify(self) -> Self {
160        match self {
161            Self::Sequence(mut steps) => {
162                steps = steps.into_iter().map(|s| s.simplify()).collect();
163                steps.retain(|s| !matches!(s, Self::Empty));
164                match steps.len() {
165                    0 => Self::Empty,
166                    1 => steps.into_iter().next().unwrap(),
167                    _ => Self::Sequence(steps),
168                }
169            }
170            Self::Parallel(mut branches) => {
171                branches = branches.into_iter().map(|b| b.simplify()).collect();
172                branches.retain(|b| !matches!(b, Self::Empty));
173                match branches.len() {
174                    0 => Self::Empty,
175                    1 => branches.into_iter().next().unwrap(),
176                    _ => Self::Parallel(branches),
177                }
178            }
179            other => other,
180        }
181    }
182}
183
184impl ExecutionPlan {
185    /// Render the execution plan as a Mermaid flowchart.
186    pub fn to_mermaid(&self) -> String {
187        let mut out = String::from("graph TD\n");
188        let mut counter = 0;
189        self.mermaid_nodes(&mut out, &mut counter, None);
190        out
191    }
192
193    fn mermaid_nodes(&self, out: &mut String, counter: &mut usize, parent: Option<&str>) {
194        use std::fmt::Write;
195        match self {
196            Self::Execute { node_id } => {
197                let _ = writeln!(out, "    {node_id}[{node_id}]");
198                if let Some(p) = parent {
199                    let _ = writeln!(out, "    {p} --> {node_id}");
200                }
201            }
202            Self::Cached { node_id, .. } => {
203                let _ = writeln!(out, "    {node_id}[/{node_id} cached/]");
204                if let Some(p) = parent {
205                    let _ = writeln!(out, "    {p} --> {node_id}");
206                }
207            }
208            Self::Sequence(steps) => {
209                let mut prev = parent.map(String::from);
210                for step in steps {
211                    step.mermaid_nodes(out, counter, prev.as_deref());
212                    prev = step.first_node_id().map(String::from);
213                }
214            }
215            Self::Parallel(branches) => {
216                let fork_id = format!("fork_{counter}");
217                *counter += 1;
218                let _ = writeln!(out, "    {fork_id}{{{{fork}}}}");
219                if let Some(p) = parent {
220                    let _ = writeln!(out, "    {p} --> {fork_id}");
221                }
222                for branch in branches {
223                    branch.mermaid_nodes(out, counter, Some(&fork_id));
224                }
225            }
226            Self::Loop {
227                node_id,
228                body,
229                max_iterations,
230            } => {
231                let label = match max_iterations {
232                    Some(n) => format!("{node_id} loop max={n}"),
233                    None => format!("{node_id} loop"),
234                };
235                let _ = writeln!(out, "    {node_id}(({label}))");
236                if let Some(p) = parent {
237                    let _ = writeln!(out, "    {p} --> {node_id}");
238                }
239                body.mermaid_nodes(out, counter, Some(node_id));
240            }
241            Self::Branch { node_id, arms } => {
242                let _ = writeln!(out, "    {node_id}{{{{{node_id}}}}}");
243                if let Some(p) = parent {
244                    let _ = writeln!(out, "    {p} --> {node_id}");
245                }
246                for (label, plan) in arms {
247                    let arm_id = format!("arm_{counter}");
248                    *counter += 1;
249                    let _ = writeln!(out, "    {node_id} -->|{label}| {arm_id}[{label}]");
250                    plan.mermaid_nodes(out, counter, Some(&arm_id));
251                }
252            }
253            Self::Remote {
254                node_id,
255                target,
256                plan,
257            } => {
258                let _ = writeln!(out, "    {node_id}>{{{node_id} remote: {target:?}}}]");
259                if let Some(p) = parent {
260                    let _ = writeln!(out, "    {p} --> {node_id}");
261                }
262                plan.mermaid_nodes(out, counter, Some(node_id));
263            }
264            Self::Composite { node_ids } | Self::Stream { node_ids, .. } => {
265                use std::fmt::Write;
266                let stream_label = matches!(self, Self::Stream { .. });
267                let mut prev: Option<&str> = None;
268                for nid in node_ids {
269                    if stream_label {
270                        let _ = writeln!(out, "    {nid}([{nid} stream])");
271                    } else {
272                        let _ = writeln!(out, "    {nid}[{nid}]");
273                    }
274                    if let Some(p) = prev.or(parent) {
275                        let _ = writeln!(out, "    {p} --> {nid}");
276                    }
277                    prev = Some(nid);
278                }
279            }
280            Self::Empty => {}
281        }
282    }
283
284    fn first_node_id(&self) -> Option<&str> {
285        match self {
286            Self::Execute { node_id } | Self::Cached { node_id, .. } => Some(node_id),
287            Self::Sequence(steps) => steps.first().and_then(|s| s.first_node_id()),
288            Self::Parallel(_) => None,
289            Self::Loop { node_id, .. }
290            | Self::Branch { node_id, .. }
291            | Self::Remote { node_id, .. } => Some(node_id),
292            Self::Composite { node_ids } | Self::Stream { node_ids, .. } => {
293                node_ids.first().map(|s| s.as_str())
294            }
295            Self::Empty => None,
296        }
297    }
298}
299
300impl fmt::Display for ExecutionPlan {
301    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
302        self.fmt_indent(f, 0)
303    }
304}
305
306impl ExecutionPlan {
307    fn fmt_indent(&self, f: &mut fmt::Formatter<'_>, indent: usize) -> fmt::Result {
308        let pad = "  ".repeat(indent);
309        match self {
310            Self::Sequence(steps) => {
311                writeln!(f, "{pad}Sequence:")?;
312                for step in steps {
313                    step.fmt_indent(f, indent + 1)?;
314                }
315                Ok(())
316            }
317            Self::Parallel(branches) => {
318                writeln!(f, "{pad}Parallel:")?;
319                for branch in branches {
320                    branch.fmt_indent(f, indent + 1)?;
321                }
322                Ok(())
323            }
324            Self::Execute { node_id } => writeln!(f, "{pad}Execute({node_id})"),
325            Self::Cached { node_id, key } => writeln!(f, "{pad}Cached({node_id}, {key})"),
326            Self::Loop {
327                node_id,
328                body,
329                max_iterations,
330            } => {
331                writeln!(f, "{pad}Loop({node_id}, max={max_iterations:?}):")?;
332                body.fmt_indent(f, indent + 1)
333            }
334            Self::Branch { node_id, arms } => {
335                writeln!(f, "{pad}Branch({node_id}):")?;
336                for (label, plan) in arms {
337                    writeln!(f, "{pad}  [{label}]:")?;
338                    plan.fmt_indent(f, indent + 2)?;
339                }
340                Ok(())
341            }
342            Self::Remote {
343                node_id,
344                target,
345                plan,
346            } => {
347                writeln!(f, "{pad}Remote({node_id}, target={target:?}):")?;
348                plan.fmt_indent(f, indent + 1)
349            }
350            Self::Composite { node_ids } => {
351                let ids = node_ids
352                    .iter()
353                    .map(|s| s.as_str())
354                    .collect::<Vec<_>>()
355                    .join(" \u{2192} ");
356                writeln!(f, "{pad}Composite[{ids}]")
357            }
358            Self::Stream {
359                node_ids,
360                chunk_size,
361            } => {
362                let ids = node_ids
363                    .iter()
364                    .map(|s| s.as_str())
365                    .collect::<Vec<_>>()
366                    .join(" \u{2192} ");
367                writeln!(f, "{pad}Stream[{ids}](chunk_size={chunk_size})")
368            }
369            Self::Empty => writeln!(f, "{pad}Empty"),
370        }
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn node_count_linear() {
        let plan = ExecutionPlan::Sequence(vec![
            ExecutionPlan::Execute {
                node_id: "a".into(),
            },
            ExecutionPlan::Execute {
                node_id: "b".into(),
            },
            ExecutionPlan::Execute {
                node_id: "c".into(),
            },
        ]);
        assert_eq!(plan.node_count(), 3);
        assert_eq!(plan.cached_count(), 0);
    }

    #[test]
    fn cached_count() {
        let plan = ExecutionPlan::Sequence(vec![
            ExecutionPlan::Cached {
                node_id: "a".into(),
                key: CacheKey::hash_data(b"a"),
            },
            ExecutionPlan::Execute {
                node_id: "b".into(),
            },
            ExecutionPlan::Cached {
                node_id: "c".into(),
                key: CacheKey::hash_data(b"c"),
            },
        ]);
        assert_eq!(plan.node_count(), 3);
        assert_eq!(plan.cached_count(), 2);
    }

    #[test]
    fn parallel_branch_count() {
        let plan = ExecutionPlan::Sequence(vec![
            ExecutionPlan::Execute {
                node_id: "a".into(),
            },
            ExecutionPlan::Parallel(vec![
                ExecutionPlan::Execute {
                    node_id: "b".into(),
                },
                ExecutionPlan::Execute {
                    node_id: "c".into(),
                },
                ExecutionPlan::Execute {
                    node_id: "d".into(),
                },
            ]),
            ExecutionPlan::Execute {
                node_id: "e".into(),
            },
        ]);
        assert_eq!(plan.parallel_branch_count(), 3);
        assert_eq!(plan.node_count(), 5);
    }

    #[test]
    fn node_ids_collected() {
        let plan = ExecutionPlan::Sequence(vec![
            ExecutionPlan::Cached {
                node_id: "a".into(),
                key: CacheKey::hash_data(b"a"),
            },
            ExecutionPlan::Execute {
                node_id: "b".into(),
            },
        ]);
        let ids = plan.node_ids();
        assert_eq!(ids, vec!["a", "b"]);
    }

    /// Control-flow and chain variants: counting and ID collection.
    #[test]
    fn counts_and_ids_for_control_flow_and_chains() {
        let looped = ExecutionPlan::Loop {
            node_id: "l".into(),
            body: Box::new(ExecutionPlan::Sequence(vec![
                ExecutionPlan::Execute {
                    node_id: "a".into(),
                },
                ExecutionPlan::Execute {
                    node_id: "b".into(),
                },
            ])),
            max_iterations: Some(3),
        };
        assert_eq!(looped.node_count(), 3);
        assert_eq!(looped.cached_count(), 0);
        assert_eq!(looped.node_ids(), vec!["l", "a", "b"]);

        let branch = ExecutionPlan::Branch {
            node_id: "br".into(),
            arms: vec![
                (
                    "yes".to_string(),
                    ExecutionPlan::Execute {
                        node_id: "x".into(),
                    },
                ),
                ("no".to_string(), ExecutionPlan::Empty),
            ],
        };
        assert_eq!(branch.node_count(), 2);
        assert_eq!(branch.node_ids(), vec!["br", "x"]);

        let composite = ExecutionPlan::Composite {
            node_ids: vec!["a".into(), "b".into(), "c".into()],
        };
        assert_eq!(composite.node_count(), 3);
        assert_eq!(composite.cached_count(), 0);
        assert_eq!(composite.node_ids(), vec!["a", "b", "c"]);

        let stream = ExecutionPlan::Stream {
            node_ids: vec!["a".into(), "b".into()],
            chunk_size: 8,
        };
        assert_eq!(stream.node_count(), 2);
        assert_eq!(stream.parallel_branch_count(), 0);
    }

    #[test]
    fn simplify_removes_empty() {
        let plan = ExecutionPlan::Sequence(vec![
            ExecutionPlan::Empty,
            ExecutionPlan::Execute {
                node_id: "a".into(),
            },
            ExecutionPlan::Empty,
        ]);
        let simplified = plan.simplify();
        assert!(matches!(simplified, ExecutionPlan::Execute { .. }));
    }

    #[test]
    fn simplify_unwraps_single_element() {
        let plan = ExecutionPlan::Sequence(vec![ExecutionPlan::Execute {
            node_id: "a".into(),
        }]);
        let simplified = plan.simplify();
        assert!(matches!(simplified, ExecutionPlan::Execute { .. }));
    }

    #[test]
    fn simplify_preserves_multi() {
        let plan = ExecutionPlan::Sequence(vec![
            ExecutionPlan::Execute {
                node_id: "a".into(),
            },
            ExecutionPlan::Execute {
                node_id: "b".into(),
            },
        ]);
        let simplified = plan.simplify();
        assert!(matches!(simplified, ExecutionPlan::Sequence(_)));
    }

    #[test]
    fn simplify_all_empty_parallel_is_empty() {
        let plan = ExecutionPlan::Parallel(vec![ExecutionPlan::Empty, ExecutionPlan::Empty]);
        assert!(matches!(plan.simplify(), ExecutionPlan::Empty));
    }

    #[test]
    fn display_format() {
        let plan = ExecutionPlan::Sequence(vec![
            ExecutionPlan::Execute {
                node_id: "scaler".into(),
            },
            ExecutionPlan::Parallel(vec![
                ExecutionPlan::Execute {
                    node_id: "pca".into(),
                },
                ExecutionPlan::Execute {
                    node_id: "umap".into(),
                },
            ]),
            ExecutionPlan::Execute {
                node_id: "svm".into(),
            },
        ]);
        let output = format!("{plan}");
        assert!(output.contains("Sequence:"));
        assert!(output.contains("Parallel:"));
        assert!(output.contains("Execute(scaler)"));
        assert!(output.contains("Execute(pca)"));
    }

    #[test]
    fn display_exact_for_loop_and_chains() {
        let looped = ExecutionPlan::Loop {
            node_id: "l".into(),
            body: Box::new(ExecutionPlan::Execute {
                node_id: "a".into(),
            }),
            max_iterations: Some(3),
        };
        assert_eq!(format!("{looped}"), "Loop(l, max=Some(3)):\n  Execute(a)\n");

        let stream = ExecutionPlan::Stream {
            node_ids: vec!["a".into(), "b".into()],
            chunk_size: 8,
        };
        assert_eq!(format!("{stream}"), "Stream[a \u{2192} b](chunk_size=8)\n");

        let composite = ExecutionPlan::Composite {
            node_ids: vec!["x".into(), "y".into()],
        };
        assert_eq!(format!("{composite}"), "Composite[x \u{2192} y]\n");

        let branch = ExecutionPlan::Branch {
            node_id: "br".into(),
            arms: vec![(
                "yes".to_string(),
                ExecutionPlan::Execute {
                    node_id: "x".into(),
                },
            )],
        };
        assert_eq!(format!("{branch}"), "Branch(br):\n  [yes]:\n    Execute(x)\n");
    }

    #[test]
    fn summary_values() {
        let plan = ExecutionPlan::Sequence(vec![
            ExecutionPlan::Cached {
                node_id: "a".into(),
                key: CacheKey::hash_data(b"a"),
            },
            ExecutionPlan::Parallel(vec![
                ExecutionPlan::Execute {
                    node_id: "b".into(),
                },
                ExecutionPlan::Execute {
                    node_id: "c".into(),
                },
            ]),
            ExecutionPlan::Execute {
                node_id: "d".into(),
            },
        ]);
        let summary = plan.summary();
        assert_eq!(summary.total_nodes, 4);
        assert_eq!(summary.cached_nodes, 1);
        assert_eq!(summary.parallel_branches, 2);
    }

    #[test]
    fn serde_roundtrip() {
        let plan = ExecutionPlan::Sequence(vec![
            ExecutionPlan::Cached {
                node_id: "a".into(),
                key: CacheKey::hash_data(b"test"),
            },
            ExecutionPlan::Execute {
                node_id: "b".into(),
            },
        ]);
        let json = serde_json::to_string(&plan).unwrap();
        let deserialized: ExecutionPlan = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.node_count(), 2);
    }

    #[test]
    fn empty_plan() {
        let plan = ExecutionPlan::Empty;
        assert_eq!(plan.node_count(), 0);
        assert_eq!(plan.cached_count(), 0);
        assert!(plan.node_ids().is_empty());
    }

    #[test]
    fn to_mermaid_sequence() {
        let plan = ExecutionPlan::Sequence(vec![
            ExecutionPlan::Execute {
                node_id: "scaler".into(),
            },
            ExecutionPlan::Execute {
                node_id: "model".into(),
            },
        ]);
        let m = plan.to_mermaid();
        assert!(m.starts_with("graph TD"));
        assert!(m.contains("scaler[scaler]"));
        assert!(m.contains("model[model]"));
        assert!(m.contains("scaler --> model"));
    }

    #[test]
    fn to_mermaid_parallel() {
        let plan = ExecutionPlan::Parallel(vec![
            ExecutionPlan::Execute {
                node_id: "a".into(),
            },
            ExecutionPlan::Execute {
                node_id: "b".into(),
            },
        ]);
        let m = plan.to_mermaid();
        assert!(m.contains("fork_0{"));
        assert!(m.contains("fork_0 --> a"));
        assert!(m.contains("fork_0 --> b"));
    }

    #[test]
    fn to_mermaid_cached() {
        let plan = ExecutionPlan::Cached {
            node_id: "x".into(),
            key: CacheKey::hash_data(b"x"),
        };
        let m = plan.to_mermaid();
        assert!(m.contains("x[/x cached/]"));
    }

    #[test]
    fn to_mermaid_composite_and_stream_chains() {
        let composite = ExecutionPlan::Composite {
            node_ids: vec!["a".into(), "b".into()],
        };
        let m = composite.to_mermaid();
        assert!(m.contains("a[a]"));
        assert!(m.contains("b[b]"));
        assert!(m.contains("a --> b"));

        let stream = ExecutionPlan::Stream {
            node_ids: vec!["s".into(), "t".into()],
            chunk_size: 2,
        };
        let m = stream.to_mermaid();
        assert!(m.contains("s([s stream])"));
        assert!(m.contains("t([t stream])"));
        assert!(m.contains("s --> t"));
    }
}