Skip to main content

somatize_compiler/
compiler.rs

//! Graph → ExecutionPlan compiler.
//!
//! Compilation passes, in order: graph validation → topological sort →
//! gradient-flow check → schema validation → parallelism detection →
//! cache resolution → distribution wrapping → differentiable collapsing →
//! simplification.

6use crate::plan::ExecutionPlan;
7use somatize_core::cache::{CacheKey, CacheStore};
8use somatize_core::error::Result;
9use somatize_core::filter::{Filter, FilterMeta};
10use somatize_core::graph::{Graph, NodeId};
11use std::collections::{HashMap, HashSet};
12
/// How aggressively the compiler may reuse cached results.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompileMode {
    /// Reuse cached outputs: a node whose output is already stored is not
    /// executed again.
    Inference,
    /// Keep forward passes live so gradients can flow; intended to cache
    /// states only, never forward outputs.
    Differentiable,
    /// Bypass the cache entirely and re-execute every node.
    NoCache,
}
23
24/// Diagnostic message emitted during compilation.
25#[derive(Debug, Clone)]
26pub struct Diagnostic {
27    pub node_id: NodeId,
28    pub level: DiagnosticLevel,
29    pub message: String,
30}
31
/// Severity attached to a compiler diagnostic.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DiagnosticLevel {
    /// Likely problem; compilation still produced a plan.
    Warning,
    /// Informational note.
    Info,
}
37
38/// Compiled result: the plan plus any diagnostics.
39pub struct CompileResult {
40    pub plan: ExecutionPlan,
41    pub diagnostics: Vec<Diagnostic>,
42}
43
44/// Registry that maps node IDs to their filter metadata.
45/// The compiler needs metadata (cacheable, differentiable, etc.)
46/// but doesn't need the actual filter implementations.
47pub trait FilterRegistry: Send + Sync {
48    fn meta(&self, node_id: &str) -> Option<FilterMeta>;
49    fn config_hash(&self, node_id: &str) -> Option<CacheKey>;
50}
51
52/// Simple in-memory filter registry for compilation.
53pub struct SimpleFilterRegistry {
54    entries: HashMap<String, (FilterMeta, CacheKey)>,
55}
56
57impl SimpleFilterRegistry {
58    pub fn new() -> Self {
59        Self {
60            entries: HashMap::new(),
61        }
62    }
63
64    pub fn register(&mut self, node_id: impl Into<String>, filter: &dyn Filter) {
65        let id = node_id.into();
66        self.entries
67            .insert(id, (filter.meta(), filter.config_hash()));
68    }
69
70    pub fn register_meta(
71        &mut self,
72        node_id: impl Into<String>,
73        meta: FilterMeta,
74        config_hash: CacheKey,
75    ) {
76        self.entries.insert(node_id.into(), (meta, config_hash));
77    }
78}
79
80impl Default for SimpleFilterRegistry {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
86impl FilterRegistry for SimpleFilterRegistry {
87    fn meta(&self, node_id: &str) -> Option<FilterMeta> {
88        self.entries.get(node_id).map(|(m, _)| m.clone())
89    }
90
91    fn config_hash(&self, node_id: &str) -> Option<CacheKey> {
92        self.entries.get(node_id).map(|(_, h)| h.clone())
93    }
94}
95
96/// Compiles a Graph into an ExecutionPlan.
97pub struct Compiler<'a> {
98    graph: &'a Graph,
99    registry: &'a dyn FilterRegistry,
100    mode: CompileMode,
101    diagnostics: Vec<Diagnostic>,
102}
103
104impl<'a> Compiler<'a> {
105    pub fn new(graph: &'a Graph, registry: &'a dyn FilterRegistry, mode: CompileMode) -> Self {
106        Self {
107            graph,
108            registry,
109            mode,
110            diagnostics: Vec::new(),
111        }
112    }
113
114    /// Compile the graph into an execution plan.
115    pub fn compile(mut self, cache: Option<&dyn CacheStore>) -> Result<CompileResult> {
116        self.graph.validate()?;
117
118        let sorted = self.graph.topological_sort()?;
119
120        if sorted.is_empty() {
121            return Ok(CompileResult {
122                plan: ExecutionPlan::Empty,
123                diagnostics: self.diagnostics,
124            });
125        }
126
127        // Check gradient flow
128        self.check_gradient_flow(&sorted);
129
130        // Validate schema compatibility
131        self.validate_schemas(&sorted);
132
133        // Build the structural plan (detect parallelism)
134        let plan = self.build_plan(&sorted);
135
136        // Resolve caching if applicable
137        let plan = if let Some(cache) = cache {
138            self.resolve_cache(plan, cache, &sorted)?
139        } else {
140            plan
141        };
142
143        // Resolve distribution (wrap Remote nodes)
144        let plan = self.resolve_distribution(plan);
145
146        // Collapse consecutive differentiable nodes into Composite blocks
147        let plan = self.collapse_differentiable(plan);
148
149        let plan = plan.simplify();
150
151        Ok(CompileResult {
152            plan,
153            diagnostics: self.diagnostics,
154        })
155    }
156
157    /// Build a plan from topologically sorted nodes, detecting parallelism.
158    fn build_plan(&self, sorted: &[&str]) -> ExecutionPlan {
159        // Compute topological levels (nodes at the same level can run in parallel)
160        let levels = self.compute_levels(sorted);
161
162        let mut plan_steps: Vec<ExecutionPlan> = Vec::new();
163
164        for level in &levels {
165            if level.len() == 1 {
166                plan_steps.push(self.plan_for_node(level[0]));
167            } else {
168                let branches: Vec<ExecutionPlan> =
169                    level.iter().map(|id| self.plan_for_node(id)).collect();
170                plan_steps.push(ExecutionPlan::Parallel(branches));
171            }
172        }
173
174        if plan_steps.len() == 1 {
175            plan_steps.into_iter().next().unwrap()
176        } else {
177            ExecutionPlan::Sequence(plan_steps)
178        }
179    }
180
181    /// Generate the execution plan for a single node based on its kind.
182    fn plan_for_node(&self, node_id: &str) -> ExecutionPlan {
183        use somatize_core::graph::NodeKind;
184
185        let node = match self.graph.node(node_id) {
186            Some(n) => n,
187            None => {
188                return ExecutionPlan::Execute {
189                    node_id: node_id.to_string(),
190                };
191            }
192        };
193
194        match &node.kind {
195            NodeKind::Filter { .. } => ExecutionPlan::Execute {
196                node_id: node_id.to_string(),
197            },
198
199            NodeKind::SubGraph { graph } => {
200                // Recursively compile the inner graph
201                let inner_compiler = Compiler::new(graph, self.registry, self.mode);
202                match inner_compiler.compile(None) {
203                    Ok(result) => result.plan,
204                    Err(_) => ExecutionPlan::Execute {
205                        node_id: node_id.to_string(),
206                    },
207                }
208            }
209
210            NodeKind::Loop { max_iterations } => {
211                // The body consists of the successors of this loop node.
212                // Build a sub-plan from the successor chain.
213                let successors = self.graph.successors(node_id);
214                let body = if successors.len() == 1 {
215                    self.plan_for_node(successors[0])
216                } else if successors.len() > 1 {
217                    let branches: Vec<ExecutionPlan> =
218                        successors.iter().map(|id| self.plan_for_node(id)).collect();
219                    ExecutionPlan::Parallel(branches)
220                } else {
221                    ExecutionPlan::Empty
222                };
223                ExecutionPlan::Loop {
224                    node_id: node_id.to_string(),
225                    body: Box::new(body),
226                    max_iterations: *max_iterations,
227                }
228            }
229
230            NodeKind::Branch => {
231                // Arms come from control edges leaving this node.
232                let arms: Vec<(String, ExecutionPlan)> = self
233                    .graph
234                    .edges
235                    .iter()
236                    .filter(|e| e.source == node_id)
237                    .map(|e| {
238                        let label = e.label.clone().unwrap_or_else(|| e.target.clone());
239                        let plan = self.plan_for_node(&e.target);
240                        (label, plan)
241                    })
242                    .collect();
243                ExecutionPlan::Branch {
244                    node_id: node_id.to_string(),
245                    arms,
246                }
247            }
248
249            _ => ExecutionPlan::Execute {
250                node_id: node_id.to_string(),
251            },
252        }
253    }
254
255    /// Compute topological levels: groups of nodes that can execute concurrently.
256    /// Each node's level = max(predecessor levels) + 1.
257    fn compute_levels<'b>(&self, sorted: &[&'b str]) -> Vec<Vec<&'b str>> {
258        let mut node_level: HashMap<&str, usize> = HashMap::new();
259        let mut max_level: usize = 0;
260
261        for &node in sorted {
262            let preds = self.graph.predecessors(node);
263            let level = if preds.is_empty() {
264                0
265            } else {
266                preds
267                    .iter()
268                    .map(|p| node_level.get(p).copied().unwrap_or(0) + 1)
269                    .max()
270                    .unwrap_or(0)
271            };
272            node_level.insert(node, level);
273            if level > max_level {
274                max_level = level;
275            }
276        }
277
278        let mut levels: Vec<Vec<&str>> = vec![Vec::new(); max_level + 1];
279        for &node in sorted {
280            let level = node_level[node];
281            levels[level].push(node);
282        }
283
284        // Remove empty levels (shouldn't happen but defensive)
285        levels.retain(|l| !l.is_empty());
286        levels
287    }
288
289    /// Resolve caching: replace Execute nodes with Cached when possible.
290    /// Implements cascade invalidation.
291    fn resolve_cache(
292        &self,
293        plan: ExecutionPlan,
294        cache: &dyn CacheStore,
295        sorted: &[&str],
296    ) -> Result<ExecutionPlan> {
297        if self.mode == CompileMode::NoCache {
298            return Ok(plan);
299        }
300
301        // Compute cache keys for all nodes in topological order.
302        // A node's key depends on its config + its predecessors' keys.
303        let mut node_keys: HashMap<String, CacheKey> = HashMap::new();
304        let mut cached_nodes: HashSet<String> = HashSet::new();
305
306        for &node_id in sorted {
307            let config_hash = match self.registry.config_hash(node_id) {
308                Some(h) => h,
309                None => continue, // no filter registered, can't cache
310            };
311
312            let meta = self.registry.meta(node_id);
313            let cacheable = meta.as_ref().is_some_and(|m| m.cacheable);
314
315            // In differentiable mode, only cache states (not forward outputs)
316            // For simplicity at this stage, we skip caching in differentiable mode
317            let can_cache = cacheable && self.mode == CompileMode::Inference;
318
319            // Build the cache key from config + predecessor output keys
320            let pred_ids = self.graph.predecessors(node_id);
321            let mut key_parts: Vec<Vec<u8>> = vec![config_hash.0.to_vec()];
322            for pred in &pred_ids {
323                if let Some(pred_key) = node_keys.get(*pred) {
324                    key_parts.push(pred_key.0.to_vec());
325                } else {
326                    // Predecessor should always be processed first in topological order.
327                    // If missing, the cache key will be incomplete but won't panic.
328                    debug_assert!(
329                        false,
330                        "predecessor `{pred}` of `{node_id}` not in node_keys - \
331                         topological order may be broken"
332                    );
333                }
334            }
335            let parts_refs: Vec<&[u8]> = key_parts.iter().map(|p| p.as_slice()).collect();
336            let key = CacheKey::from_parts(&parts_refs);
337            node_keys.insert(node_id.to_string(), key.clone());
338
339            // Check if this node's output exists in cache
340            if can_cache {
341                // Only use cache if ALL predecessors are also cached (or roots).
342                // This ensures cascade invalidation: if any upstream re-executed,
343                // the key is already different (since it includes predecessor keys).
344                if cache.exists(&key)? {
345                    cached_nodes.insert(node_id.to_string());
346                }
347            }
348        }
349
350        // Replace Execute nodes with Cached where applicable
351        Ok(self.apply_cache_to_plan(plan, &cached_nodes, &node_keys))
352    }
353
354    fn apply_cache_to_plan(
355        &self,
356        plan: ExecutionPlan,
357        cached: &HashSet<String>,
358        keys: &HashMap<String, CacheKey>,
359    ) -> ExecutionPlan {
360        match plan {
361            ExecutionPlan::Execute { ref node_id } => {
362                if cached.contains(node_id)
363                    && let Some(key) = keys.get(node_id)
364                {
365                    return ExecutionPlan::Cached {
366                        node_id: node_id.clone(),
367                        key: key.clone(),
368                    };
369                }
370                plan
371            }
372            ExecutionPlan::Sequence(steps) => ExecutionPlan::Sequence(
373                steps
374                    .into_iter()
375                    .map(|s| self.apply_cache_to_plan(s, cached, keys))
376                    .collect(),
377            ),
378            ExecutionPlan::Parallel(branches) => ExecutionPlan::Parallel(
379                branches
380                    .into_iter()
381                    .map(|b| self.apply_cache_to_plan(b, cached, keys))
382                    .collect(),
383            ),
384            other => other,
385        }
386    }
387
388    /// Wrap nodes with Remote distribution in ExecutionPlan::Remote.
389    fn resolve_distribution(&self, plan: ExecutionPlan) -> ExecutionPlan {
390        match plan {
391            ExecutionPlan::Execute { ref node_id } => {
392                if let Some(meta) = self.registry.meta(node_id) {
393                    match &meta.distribution {
394                        somatize_core::filter::Distribution::Remote(target) => {
395                            ExecutionPlan::Remote {
396                                node_id: node_id.clone(),
397                                target: target.clone(),
398                                plan: Box::new(plan),
399                            }
400                        }
401                        _ => plan,
402                    }
403                } else {
404                    plan
405                }
406            }
407            ExecutionPlan::Sequence(steps) => ExecutionPlan::Sequence(
408                steps
409                    .into_iter()
410                    .map(|s| self.resolve_distribution(s))
411                    .collect(),
412            ),
413            ExecutionPlan::Parallel(branches) => ExecutionPlan::Parallel(
414                branches
415                    .into_iter()
416                    .map(|b| self.resolve_distribution(b))
417                    .collect(),
418            ),
419            ExecutionPlan::Composite { ref node_ids } => {
420                // If ALL nodes in the composite have a Remote target, wrap the
421                // entire composite in a single Remote (using the first node's
422                // target). Otherwise keep it local.
423                let targets: Vec<_> = node_ids
424                    .iter()
425                    .filter_map(|nid| {
426                        self.registry.meta(nid).and_then(|m| match &m.distribution {
427                            somatize_core::filter::Distribution::Remote(t) => Some(t.clone()),
428                            _ => None,
429                        })
430                    })
431                    .collect();
432
433                if targets.len() == node_ids.len() && !targets.is_empty() {
434                    let first_id = node_ids[0].clone();
435                    ExecutionPlan::Remote {
436                        node_id: first_id,
437                        target: targets.into_iter().next().unwrap(),
438                        plan: Box::new(plan),
439                    }
440                } else {
441                    plan
442                }
443            }
444            other => other,
445        }
446    }
447
448    /// Collapse consecutive differentiable Execute nodes into Composite blocks.
449    ///
450    /// A `Composite` groups nodes that should share a PyTorch autograd session.
451    /// Only groups 2+ consecutive `Execute` nodes where `meta.differentiable == true`.
452    fn collapse_differentiable(&self, plan: ExecutionPlan) -> ExecutionPlan {
453        match plan {
454            ExecutionPlan::Sequence(steps) => {
455                let mut result: Vec<ExecutionPlan> = Vec::new();
456                let mut diff_group: Vec<String> = Vec::new();
457
458                for step in steps {
459                    if let ExecutionPlan::Execute { ref node_id } = step
460                        && self
461                            .registry
462                            .meta(node_id)
463                            .map(|m| m.differentiable)
464                            .unwrap_or(false)
465                    {
466                        diff_group.push(node_id.clone());
467                        continue;
468                    }
469                    // Flush accumulated differentiable group
470                    Self::flush_diff_group(&mut diff_group, &mut result);
471                    result.push(self.collapse_differentiable(step));
472                }
473                Self::flush_diff_group(&mut diff_group, &mut result);
474
475                if result.len() == 1 {
476                    result.pop().unwrap()
477                } else {
478                    ExecutionPlan::Sequence(result)
479                }
480            }
481            ExecutionPlan::Parallel(branches) => ExecutionPlan::Parallel(
482                branches
483                    .into_iter()
484                    .map(|b| self.collapse_differentiable(b))
485                    .collect(),
486            ),
487            ExecutionPlan::Remote {
488                node_id,
489                target,
490                plan,
491            } => ExecutionPlan::Remote {
492                node_id,
493                target,
494                plan: Box::new(self.collapse_differentiable(*plan)),
495            },
496            other => other,
497        }
498    }
499
500    fn flush_diff_group(group: &mut Vec<String>, result: &mut Vec<ExecutionPlan>) {
501        if group.len() > 1 {
502            result.push(ExecutionPlan::Composite {
503                node_ids: std::mem::take(group),
504            });
505        } else if let Some(id) = group.pop() {
506            result.push(ExecutionPlan::Execute { node_id: id });
507        }
508    }
509
510    /// Validate schema compatibility between connected filters.
511    ///
512    /// For each edge (A → B), checks that A's output_schema is compatible
513    /// with B's input_schema. Emits warnings (not errors) for mismatches,
514    /// since schemas are optional and None means "accepts anything".
515    fn validate_schemas(&mut self, sorted: &[&str]) {
516        for &node_id in sorted {
517            let input_schema = self
518                .registry
519                .meta(node_id)
520                .and_then(|m| m.input_schema.clone());
521
522            // Skip if this node accepts anything
523            let Some(expected_input) = input_schema else {
524                continue;
525            };
526
527            // Check each predecessor's output schema
528            for pred_id in self.graph.predecessors(node_id) {
529                let pred_output = self
530                    .registry
531                    .meta(pred_id)
532                    .and_then(|m| m.output_schema.clone());
533
534                let Some(actual_output) = pred_output else {
535                    continue; // predecessor output unknown, skip
536                };
537
538                if !actual_output.is_compatible_with(&expected_input) {
539                    self.diagnostics.push(Diagnostic {
540                        node_id: node_id.to_string(),
541                        level: DiagnosticLevel::Warning,
542                        message: format!(
543                            "schema mismatch: `{pred_id}` outputs {actual_output} \
544                             but `{node_id}` expects {expected_input}",
545                        ),
546                    });
547                }
548            }
549        }
550    }
551
552    /// Check gradient flow and emit warnings for each interruption.
553    ///
554    /// Gradient flow can restart after an opaque node (differentiable nodes
555    /// after an opaque one can still propagate gradients among themselves),
556    /// but gradients from before the interruption are lost.
557    fn check_gradient_flow(&mut self, sorted: &[&str]) {
558        let mut gradient_flows = true;
559
560        for &node_id in sorted {
561            if let Some(meta) = self.registry.meta(node_id) {
562                if gradient_flows && !meta.differentiable {
563                    self.diagnostics.push(Diagnostic {
564                        node_id: node_id.to_string(),
565                        level: DiagnosticLevel::Warning,
566                        message: format!(
567                            "gradient flow interrupted at `{}` ({:?}). \
568                             Gradients from upstream will not reach downstream filters \
569                             through this node.",
570                            node_id, meta.kind,
571                        ),
572                    });
573                    gradient_flows = false;
574                } else if !gradient_flows && meta.differentiable {
575                    // Gradient flow restarts: differentiable nodes after the
576                    // interruption can propagate gradients among themselves
577                    gradient_flows = true;
578                }
579            }
580        }
581    }
582}
583
584/// Convenience function: compile a graph with default settings.
585pub fn compile(
586    graph: &Graph,
587    registry: &dyn FilterRegistry,
588    mode: CompileMode,
589    cache: Option<&dyn CacheStore>,
590) -> Result<CompileResult> {
591    Compiler::new(graph, registry, mode).compile(cache)
592}
593
594#[cfg(test)]
595mod tests {
596    use super::*;
597    use somatize_core::cache::EntryMeta;
598    use somatize_core::error::SomaError;
599    use somatize_core::filter::{FilterKind, StreamMode};
600    use somatize_core::graph::{Edge, Graph, Node, linear_pipeline};
601    use somatize_core::value::Value;
602    use std::sync::Mutex;
603
604    // ── Mock cache store ──
605
606    struct MockCacheStore {
607        entries: Mutex<HashSet<CacheKey>>,
608    }
609
610    impl MockCacheStore {
611        fn new() -> Self {
612            Self {
613                entries: Mutex::new(HashSet::new()),
614            }
615        }
616
617        fn insert(&self, key: CacheKey) {
618            self.entries.lock().unwrap().insert(key);
619        }
620    }
621
622    impl CacheStore for MockCacheStore {
623        fn get(&self, _key: &CacheKey) -> Result<Option<Value>> {
624            Ok(None)
625        }
626        fn put(&self, _key: &CacheKey, _value: &Value) -> Result<()> {
627            Ok(())
628        }
629        fn exists(&self, key: &CacheKey) -> Result<bool> {
630            Ok(self.entries.lock().unwrap().contains(key))
631        }
632        fn remove(&self, _key: &CacheKey) -> Result<()> {
633            Ok(())
634        }
635        fn metadata(&self, _key: &CacheKey) -> Result<Option<EntryMeta>> {
636            Ok(None)
637        }
638    }
639
640    // ── Helpers ──
641
642    fn make_meta(kind: FilterKind, differentiable: bool) -> FilterMeta {
643        FilterMeta {
644            name: "test".into(),
645            kind,
646            cacheable: true,
647            differentiable,
648            stream_mode: StreamMode::FixedState,
649            distribution: somatize_core::filter::Distribution::Local,
650            input_schema: None,
651            output_schema: None,
652        }
653    }
654
655    fn register_nodes(registry: &mut SimpleFilterRegistry, ids: &[&str], meta: FilterMeta) {
656        for (i, id) in ids.iter().enumerate() {
657            let hash = CacheKey::from_parts(&[id.as_bytes(), &[i as u8]]);
658            registry.register_meta(*id, meta.clone(), hash);
659        }
660    }
661
662    // ── Tests ──
663
664    #[test]
665    fn compile_empty_graph() {
666        let graph = Graph::new();
667        let registry = SimpleFilterRegistry::new();
668        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
669        assert!(matches!(result.plan, ExecutionPlan::Empty));
670    }
671
672    #[test]
673    fn compile_single_node() {
674        let mut graph = Graph::new();
675        graph.add_node(Node::new("a", "A", "F"));
676        let mut registry = SimpleFilterRegistry::new();
677        register_nodes(
678            &mut registry,
679            &["a"],
680            make_meta(FilterKind::Trainable, true),
681        );
682
683        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
684        assert!(matches!(result.plan, ExecutionPlan::Execute { .. }));
685    }
686
687    #[test]
688    fn compile_linear_pipeline_produces_sequence() {
689        let graph = linear_pipeline(vec![
690            Node::new("a", "Scaler", "F"),
691            Node::new("b", "PCA", "F"),
692            Node::new("c", "SVM", "F"),
693        ]);
694        let mut registry = SimpleFilterRegistry::new();
695        register_nodes(
696            &mut registry,
697            &["a", "b", "c"],
698            make_meta(FilterKind::Trainable, true),
699        );
700
701        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
702
703        // All 3 nodes are differentiable → collapsed into Composite
704        if let ExecutionPlan::Composite { node_ids } = &result.plan {
705            assert_eq!(node_ids, &["a", "b", "c"]);
706        } else {
707            panic!("expected Composite, got: {:?}", result.plan);
708        }
709    }
710
711    #[test]
712    fn compile_diamond_detects_parallelism() {
713        let mut graph = Graph::new();
714        graph.add_node(Node::new("root", "Root", "F"));
715        graph.add_node(Node::new("b1", "B1", "F"));
716        graph.add_node(Node::new("b2", "B2", "F"));
717        graph.add_node(Node::new("merge", "Merge", "F"));
718        graph.add_edge(Edge::data("e1", "root", "b1"));
719        graph.add_edge(Edge::data("e2", "root", "b2"));
720        graph.add_edge(Edge::data("e3", "b1", "merge"));
721        graph.add_edge(Edge::data("e4", "b2", "merge"));
722
723        let mut registry = SimpleFilterRegistry::new();
724        register_nodes(
725            &mut registry,
726            &["root", "b1", "b2", "merge"],
727            make_meta(FilterKind::Trainable, true),
728        );
729
730        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
731
732        // Should be: Sequence(Execute(root), Parallel(Execute(b1), Execute(b2)), Execute(merge))
733        if let ExecutionPlan::Sequence(steps) = &result.plan {
734            assert_eq!(steps.len(), 3);
735            assert!(matches!(&steps[0], ExecutionPlan::Execute { node_id } if node_id == "root"));
736            assert!(matches!(&steps[1], ExecutionPlan::Parallel(branches) if branches.len() == 2));
737            assert!(matches!(&steps[2], ExecutionPlan::Execute { node_id } if node_id == "merge"));
738        } else {
739            panic!("expected Sequence, got: {:?}", result.plan);
740        }
741    }
742
743    #[test]
744    fn compile_independent_roots_parallel() {
745        let mut graph = Graph::new();
746        graph.add_node(Node::new("a", "A", "F"));
747        graph.add_node(Node::new("b", "B", "F"));
748        // No edges: fully independent
749
750        let mut registry = SimpleFilterRegistry::new();
751        register_nodes(
752            &mut registry,
753            &["a", "b"],
754            make_meta(FilterKind::Trainable, true),
755        );
756
757        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
758
759        // Both at level 0 → Parallel
760        assert!(matches!(result.plan, ExecutionPlan::Parallel(_)));
761    }
762
763    #[test]
764    fn cache_resolution_replaces_cached_nodes() {
765        let graph = linear_pipeline(vec![
766            Node::new("a", "Scaler", "F"),
767            Node::new("b", "PCA", "F"),
768            Node::new("c", "SVM", "F"),
769        ]);
770
771        let mut registry = SimpleFilterRegistry::new();
772        register_nodes(
773            &mut registry,
774            &["a", "b", "c"],
775            make_meta(FilterKind::Trainable, true),
776        );
777
778        // Pre-compute the cache key for node "a" (same logic as compiler)
779        let a_config = registry.config_hash("a").unwrap();
780        let a_cache_key = CacheKey::from_parts(&[&a_config.0]);
781
782        let cache = MockCacheStore::new();
783        cache.insert(a_cache_key);
784
785        let result = compile(&graph, &registry, CompileMode::Inference, Some(&cache)).unwrap();
786
787        // "a" is cached, "b"+"c" are differentiable → Composite
788        if let ExecutionPlan::Sequence(steps) = &result.plan {
789            assert!(
790                matches!(&steps[0], ExecutionPlan::Cached { node_id, .. } if node_id == "a"),
791                "first node should be cached, got: {:?}",
792                steps[0]
793            );
794            assert!(
795                matches!(&steps[1], ExecutionPlan::Composite { node_ids } if node_ids == &["b", "c"]),
796                "b+c should be Composite, got: {:?}",
797                steps[1]
798            );
799        } else {
800            panic!("expected Sequence, got: {:?}", result.plan);
801        }
802    }
803
#[test]
fn cascade_invalidation_different_config_changes_keys() {
    // Configuration v1: node "a" feeding node "b".
    let mut reg1 = SimpleFilterRegistry::new();
    reg1.register_meta(
        "a",
        make_meta(FilterKind::Trainable, true),
        CacheKey::hash_data(b"scaler_v1"),
    );
    reg1.register_meta(
        "b",
        make_meta(FilterKind::Trainable, true),
        CacheKey::hash_data(b"pca_v1"),
    );

    // Configuration v2: only "a"'s config changes; "b" is registered identically.
    let mut reg2 = SimpleFilterRegistry::new();
    reg2.register_meta(
        "a",
        make_meta(FilterKind::Trainable, true),
        CacheKey::hash_data(b"scaler_v2"), // changed!
    );
    reg2.register_meta(
        "b",
        make_meta(FilterKind::Trainable, true),
        CacheKey::hash_data(b"pca_v1"), // same
    );

    // Read the config hashes back from the registries so the keys below are
    // derived from what each registry actually stores.
    let a_config_v1 = reg1.config_hash("a").unwrap();
    let b_config_v1 = reg1.config_hash("b").unwrap();
    let a_config_v2 = reg2.config_hash("a").unwrap();
    let b_config_v2 = reg2.config_hash("b").unwrap();

    // Precondition: "b"'s own config is identical in both registries, so any
    // change in its key below can only come from the upstream node "a".
    assert_ne!(a_config_v1, a_config_v2);
    assert_eq!(b_config_v1, b_config_v2);

    // Chained keys, built the same way for both configurations: the root key
    // comes from its config alone, and "b"'s key folds in "a"'s key.
    let a_key_v1 = CacheKey::from_parts(&[&a_config_v1.0]);
    let b_key_v1 = CacheKey::from_parts(&[&b_config_v1.0, &a_key_v1.0]);

    let a_key_v2 = CacheKey::from_parts(&[&a_config_v2.0]);
    let b_key_v2 = CacheKey::from_parts(&[&b_config_v2.0, &a_key_v2.0]);

    // a changed → a's key changed
    assert_ne!(a_key_v1, a_key_v2);
    // b's config didn't change but a's key is in b's key → b's key also changed (cascade)
    assert_ne!(b_key_v1, b_key_v2);
}
847
848    #[test]
849    fn no_cache_mode_skips_all_caching() {
850        let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);
851
852        let mut registry = SimpleFilterRegistry::new();
853        register_nodes(
854            &mut registry,
855            &["a", "b"],
856            make_meta(FilterKind::Trainable, true),
857        );
858
859        // Put everything in cache
860        let a_config = registry.config_hash("a").unwrap();
861        let a_key = CacheKey::from_parts(&[&a_config.0]);
862        let cache = MockCacheStore::new();
863        cache.insert(a_key);
864
865        let result = compile(&graph, &registry, CompileMode::NoCache, Some(&cache)).unwrap();
866
867        // Nothing should be cached
868        assert_eq!(result.plan.cached_count(), 0);
869    }
870
871    #[test]
872    fn differentiable_mode_skips_output_caching() {
873        let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);
874
875        let mut registry = SimpleFilterRegistry::new();
876        register_nodes(
877            &mut registry,
878            &["a", "b"],
879            make_meta(FilterKind::Trainable, true),
880        );
881
882        let a_config = registry.config_hash("a").unwrap();
883        let a_key = CacheKey::from_parts(&[&a_config.0]);
884        let cache = MockCacheStore::new();
885        cache.insert(a_key);
886
887        let result = compile(&graph, &registry, CompileMode::Differentiable, Some(&cache)).unwrap();
888
889        // Differentiable mode should not cache forward outputs
890        assert_eq!(result.plan.cached_count(), 0);
891    }
892
#[test]
fn gradient_flow_diagnostic_on_opaque() {
    let graph = linear_pipeline(vec![
        Node::new("scaler", "Scaler", "F"),
        Node::new("tree", "DecisionTree", "F"),
        Node::new("linear", "Linear", "F"),
    ]);

    // (node id, kind, differentiable?, config bytes). Only the opaque "tree"
    // node is non-differentiable, so it alone should break gradient flow.
    let node_specs = [
        ("scaler", FilterKind::Trainable, true, &b"s"[..]),
        ("tree", FilterKind::Opaque, false, &b"t"[..]),
        ("linear", FilterKind::Trainable, true, &b"l"[..]),
    ];
    let mut registry = SimpleFilterRegistry::new();
    for (node_id, kind, differentiable, config) in node_specs {
        registry.register_meta(
            node_id,
            make_meta(kind, differentiable),
            CacheKey::hash_data(config),
        );
    }

    let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();

    // Exactly one diagnostic: a warning attached to the opaque node.
    assert_eq!(result.diagnostics.len(), 1);
    let diagnostic = &result.diagnostics[0];
    assert_eq!(diagnostic.node_id, "tree");
    assert_eq!(diagnostic.level, DiagnosticLevel::Warning);
    assert!(diagnostic.message.contains("gradient flow interrupted"));
}
929
#[test]
fn no_diagnostic_when_all_differentiable() {
    let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);

    let mut registry = SimpleFilterRegistry::new();
    register_nodes(
        &mut registry,
        &["a", "b"],
        make_meta(FilterKind::Trainable, true),
    );

    let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();

    // Every node is differentiable, so gradient flow is never interrupted.
    // On failure, show the offending diagnostics instead of a bare `false`.
    assert!(
        result.diagnostics.is_empty(),
        "expected no diagnostics, got: {:?}",
        result.diagnostics
    );
}
944
#[test]
fn compile_rejects_cycle() {
    // Two nodes wired a -> b -> a form a cycle.
    let mut graph = Graph::new();
    graph.add_node(Node::new("a", "A", "F"));
    graph.add_node(Node::new("b", "B", "F"));
    graph.add_edge(Edge::data("e1", "a", "b"));
    graph.add_edge(Edge::data("e2", "b", "a"));

    let registry = SimpleFilterRegistry::new();

    // Match explicitly so a failure reports what actually happened
    // (a different error, or an unexpected successful compile).
    match compile(&graph, &registry, CompileMode::Inference, None) {
        Err(SomaError::CycleDetected) => {}
        Err(other) => panic!("expected CycleDetected, got error: {:?}", other),
        Ok(_) => panic!("expected CycleDetected, but compilation succeeded"),
    }
}
957
#[test]
fn plan_summary_is_accurate() {
    // Diamond graph: "root" fans out to "b1" and "b2", which join at "end".
    let nodes = [("root", "Root"), ("b1", "B1"), ("b2", "B2"), ("end", "End")];
    let edges = [
        ("e1", "root", "b1"),
        ("e2", "root", "b2"),
        ("e3", "b1", "end"),
        ("e4", "b2", "end"),
    ];

    let mut graph = Graph::new();
    for (node_id, name) in nodes {
        graph.add_node(Node::new(node_id, name, "F"));
    }
    for (edge_id, from, to) in edges {
        graph.add_edge(Edge::data(edge_id, from, to));
    }

    let mut registry = SimpleFilterRegistry::new();
    register_nodes(
        &mut registry,
        &["root", "b1", "b2", "end"],
        make_meta(FilterKind::Trainable, true),
    );

    let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
    let summary = result.plan.summary();

    // All four nodes are counted, and the two middle nodes form the branches.
    assert_eq!(summary.total_nodes, 4);
    assert_eq!(summary.parallel_branches, 2);
}
982
#[test]
fn distribution_wraps_remote_nodes() {
    use somatize_core::filter::{Distribution, RemoteTarget};

    let graph = linear_pipeline(vec![
        Node::new("preprocess", "Preprocess", "F"),
        Node::new("gpu_train", "GpuTrain", "F"),
        Node::new("evaluate", "Evaluate", "F"),
    ]);

    let mut registry = SimpleFilterRegistry::new();
    // preprocess: local
    registry.register_meta(
        "preprocess",
        make_meta(FilterKind::Trainable, true),
        CacheKey::hash_data(b"pre"),
    );
    // gpu_train: remote on GPU tag
    let mut gpu_meta = make_meta(FilterKind::Trainable, true);
    gpu_meta.distribution = Distribution::Remote(RemoteTarget::Tag("gpu".into()));
    registry.register_meta("gpu_train", gpu_meta, CacheKey::hash_data(b"gpu"));
    // evaluate: local
    registry.register_meta(
        "evaluate",
        make_meta(FilterKind::Trainable, true),
        CacheKey::hash_data(b"eval"),
    );

    let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();

    // Should be: Sequence(Execute(preprocess), Remote(gpu_train, ...), Execute(evaluate))
    if let ExecutionPlan::Sequence(steps) = &result.plan {
        assert_eq!(steps.len(), 3);
        // Every step assertion reports the actual step on failure.
        assert!(
            matches!(&steps[0], ExecutionPlan::Execute { node_id } if node_id == "preprocess"),
            "expected Execute(preprocess), got: {:?}",
            steps[0]
        );
        assert!(
            matches!(&steps[1], ExecutionPlan::Remote { node_id, target, .. }
                if node_id == "gpu_train"
                && *target == RemoteTarget::Tag("gpu".into())
            ),
            "expected Remote, got: {:?}",
            steps[1]
        );
        assert!(
            matches!(&steps[2], ExecutionPlan::Execute { node_id } if node_id == "evaluate"),
            "expected Execute(evaluate), got: {:?}",
            steps[2]
        );
    } else {
        panic!("expected Sequence, got: {:?}", result.plan);
    }
}
1034
#[test]
fn local_distribution_not_wrapped() {
    let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);

    let mut registry = SimpleFilterRegistry::new();
    register_nodes(
        &mut registry,
        &["a", "b"],
        make_meta(FilterKind::Trainable, true),
    );

    let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();

    // Both nodes are still present in the plan.
    assert_eq!(result.plan.node_ids().len(), 2);

    // No step may be wrapped in `Remote`. The top-level plan is not guaranteed
    // to be a `Sequence` (adjacent differentiable nodes may be fused into a
    // `Composite`), so a non-`Sequence` plan is checked as a single step
    // instead of silently skipping the assertion.
    let steps: &[ExecutionPlan] = match &result.plan {
        ExecutionPlan::Sequence(steps) => &steps[..],
        other => std::slice::from_ref(other),
    };
    assert!(
        steps
            .iter()
            .all(|s| !matches!(s, ExecutionPlan::Remote { .. })),
        "local nodes must not be wrapped in Remote, got: {:?}",
        result.plan
    );
}
1060}