Skip to main content

somatize_compiler/
compiler.rs

1//! Graph → ExecutionPlan compiler.
2//!
//! Compilation phases: validation → topological sort → gradient-flow check →
//! schema validation → parallelism detection → cache resolution →
//! distribution wrapping → differentiable collapsing → simplification.
5
6use crate::plan::ExecutionPlan;
7use somatize_core::cache::{CacheKey, CacheStore};
8use somatize_core::error::Result;
9use somatize_core::filter::{Filter, FilterMeta};
10use somatize_core::graph::{Graph, NodeId};
11use std::collections::{HashMap, HashSet};
12
/// Selects how the compiler treats previously cached node outputs.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CompileMode {
    /// Reuse cached outputs: nodes whose results are already cached are skipped.
    Inference,
    /// Keep forward passes live so gradients can flow; only states may be cached.
    Differentiable,
    /// Ignore the cache completely and re-execute every node.
    NoCache,
}
23
24/// Diagnostic message emitted during compilation.
25#[derive(Debug, Clone)]
26pub struct Diagnostic {
27    pub node_id: NodeId,
28    pub level: DiagnosticLevel,
29    pub message: String,
30}
31
/// Severity attached to a compiler diagnostic.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DiagnosticLevel {
    /// A likely problem that does not stop compilation.
    Warning,
    /// An informational note.
    Info,
}
37
38/// Compiled result: the plan plus any diagnostics.
39pub struct CompileResult {
40    pub plan: ExecutionPlan,
41    pub diagnostics: Vec<Diagnostic>,
42}
43
44/// Registry that maps node IDs to their filter metadata.
45/// The compiler needs metadata (cacheable, differentiable, etc.)
46/// but doesn't need the actual filter implementations.
47pub trait FilterRegistry: Send + Sync {
48    fn meta(&self, node_id: &str) -> Option<FilterMeta>;
49    fn config_hash(&self, node_id: &str) -> Option<CacheKey>;
50}
51
52/// Simple in-memory filter registry for compilation.
53pub struct SimpleFilterRegistry {
54    entries: HashMap<String, (FilterMeta, CacheKey)>,
55}
56
57impl SimpleFilterRegistry {
58    pub fn new() -> Self {
59        Self {
60            entries: HashMap::new(),
61        }
62    }
63
64    pub fn register(&mut self, node_id: impl Into<String>, filter: &dyn Filter) {
65        let id = node_id.into();
66        self.entries
67            .insert(id, (filter.meta(), filter.config_hash()));
68    }
69
70    pub fn register_meta(
71        &mut self,
72        node_id: impl Into<String>,
73        meta: FilterMeta,
74        config_hash: CacheKey,
75    ) {
76        self.entries.insert(node_id.into(), (meta, config_hash));
77    }
78}
79
80impl Default for SimpleFilterRegistry {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
86impl FilterRegistry for SimpleFilterRegistry {
87    fn meta(&self, node_id: &str) -> Option<FilterMeta> {
88        self.entries.get(node_id).map(|(m, _)| m.clone())
89    }
90
91    fn config_hash(&self, node_id: &str) -> Option<CacheKey> {
92        self.entries.get(node_id).map(|(_, h)| h.clone())
93    }
94}
95
96/// Compiles a Graph into an ExecutionPlan.
97pub struct Compiler<'a> {
98    graph: &'a Graph,
99    registry: &'a dyn FilterRegistry,
100    mode: CompileMode,
101    diagnostics: Vec<Diagnostic>,
102}
103
104impl<'a> Compiler<'a> {
105    pub fn new(graph: &'a Graph, registry: &'a dyn FilterRegistry, mode: CompileMode) -> Self {
106        Self {
107            graph,
108            registry,
109            mode,
110            diagnostics: Vec::new(),
111        }
112    }
113
114    /// Compile the graph into an execution plan.
115    pub fn compile(mut self, cache: Option<&dyn CacheStore>) -> Result<CompileResult> {
116        self.graph.validate()?;
117
118        let sorted = self.graph.topological_sort()?;
119
120        if sorted.is_empty() {
121            return Ok(CompileResult {
122                plan: ExecutionPlan::Empty,
123                diagnostics: self.diagnostics,
124            });
125        }
126
127        // Check gradient flow
128        self.check_gradient_flow(&sorted);
129
130        // Validate schema compatibility
131        self.validate_schemas(&sorted);
132
133        // Build the structural plan (detect parallelism)
134        let plan = self.build_plan(&sorted);
135
136        // Resolve caching if applicable
137        let plan = if let Some(cache) = cache {
138            self.resolve_cache(plan, cache, &sorted)?
139        } else {
140            plan
141        };
142
143        // Resolve distribution (wrap Remote nodes)
144        let plan = self.resolve_distribution(plan);
145
146        // Collapse consecutive differentiable nodes into Composite blocks
147        let plan = self.collapse_differentiable(plan);
148
149        let plan = plan.simplify();
150
151        Ok(CompileResult {
152            plan,
153            diagnostics: self.diagnostics,
154        })
155    }
156
157    /// Build a plan from topologically sorted nodes, detecting parallelism.
158    fn build_plan(&self, sorted: &[&str]) -> ExecutionPlan {
159        // Compute topological levels (nodes at the same level can run in parallel)
160        let levels = self.compute_levels(sorted);
161
162        let mut plan_steps: Vec<ExecutionPlan> = Vec::new();
163
164        for level in &levels {
165            if level.len() == 1 {
166                plan_steps.push(self.plan_for_node(level[0]));
167            } else {
168                let branches: Vec<ExecutionPlan> =
169                    level.iter().map(|id| self.plan_for_node(id)).collect();
170                plan_steps.push(ExecutionPlan::Parallel(branches));
171            }
172        }
173
174        if plan_steps.len() == 1 {
175            plan_steps.into_iter().next().unwrap()
176        } else {
177            ExecutionPlan::Sequence(plan_steps)
178        }
179    }
180
181    /// Generate the execution plan for a single node based on its kind.
182    fn plan_for_node(&self, node_id: &str) -> ExecutionPlan {
183        use somatize_core::graph::NodeKind;
184
185        let node = match self.graph.node(node_id) {
186            Some(n) => n,
187            None => {
188                return ExecutionPlan::Execute {
189                    node_id: node_id.to_string(),
190                };
191            }
192        };
193
194        match &node.kind {
195            NodeKind::Filter { .. } => ExecutionPlan::Execute {
196                node_id: node_id.to_string(),
197            },
198
199            NodeKind::SubGraph { graph } => {
200                // Recursively compile the inner graph
201                let inner_compiler = Compiler::new(graph, self.registry, self.mode);
202                match inner_compiler.compile(None) {
203                    Ok(result) => result.plan,
204                    Err(_) => ExecutionPlan::Execute {
205                        node_id: node_id.to_string(),
206                    },
207                }
208            }
209
210            NodeKind::Loop { max_iterations } => {
211                // The body consists of the successors of this loop node.
212                // Build a sub-plan from the successor chain.
213                let successors = self.graph.successors(node_id);
214                let body = if successors.len() == 1 {
215                    self.plan_for_node(successors[0])
216                } else if successors.len() > 1 {
217                    let branches: Vec<ExecutionPlan> =
218                        successors.iter().map(|id| self.plan_for_node(id)).collect();
219                    ExecutionPlan::Parallel(branches)
220                } else {
221                    ExecutionPlan::Empty
222                };
223                ExecutionPlan::Loop {
224                    node_id: node_id.to_string(),
225                    body: Box::new(body),
226                    max_iterations: *max_iterations,
227                }
228            }
229
230            NodeKind::Branch => {
231                // Arms come from control edges leaving this node.
232                let arms: Vec<(String, ExecutionPlan)> = self
233                    .graph
234                    .edges
235                    .iter()
236                    .filter(|e| e.source == node_id)
237                    .map(|e| {
238                        let label = e.label.clone().unwrap_or_else(|| e.target.clone());
239                        let plan = self.plan_for_node(&e.target);
240                        (label, plan)
241                    })
242                    .collect();
243                ExecutionPlan::Branch {
244                    node_id: node_id.to_string(),
245                    arms,
246                }
247            }
248
249            _ => ExecutionPlan::Execute {
250                node_id: node_id.to_string(),
251            },
252        }
253    }
254
255    /// Compute topological levels: groups of nodes that can execute concurrently.
256    /// Each node's level = max(predecessor levels) + 1.
257    fn compute_levels<'b>(&self, sorted: &[&'b str]) -> Vec<Vec<&'b str>> {
258        let mut node_level: HashMap<&str, usize> = HashMap::new();
259        let mut max_level: usize = 0;
260
261        for &node in sorted {
262            let preds = self.graph.predecessors(node);
263            let level = if preds.is_empty() {
264                0
265            } else {
266                preds
267                    .iter()
268                    .map(|p| node_level.get(p).copied().unwrap_or(0) + 1)
269                    .max()
270                    .unwrap_or(0)
271            };
272            node_level.insert(node, level);
273            if level > max_level {
274                max_level = level;
275            }
276        }
277
278        let mut levels: Vec<Vec<&str>> = vec![Vec::new(); max_level + 1];
279        for &node in sorted {
280            let level = node_level[node];
281            levels[level].push(node);
282        }
283
284        // Remove empty levels (shouldn't happen but defensive)
285        levels.retain(|l| !l.is_empty());
286        levels
287    }
288
289    /// Resolve caching: replace Execute nodes with Cached when possible.
290    /// Implements cascade invalidation.
291    fn resolve_cache(
292        &self,
293        plan: ExecutionPlan,
294        cache: &dyn CacheStore,
295        sorted: &[&str],
296    ) -> Result<ExecutionPlan> {
297        if self.mode == CompileMode::NoCache {
298            return Ok(plan);
299        }
300
301        // Compute cache keys for all nodes in topological order.
302        // A node's key depends on its config + its predecessors' keys.
303        let mut node_keys: HashMap<String, CacheKey> = HashMap::new();
304        let mut cached_nodes: HashSet<String> = HashSet::new();
305
306        for &node_id in sorted {
307            let config_hash = match self.registry.config_hash(node_id) {
308                Some(h) => h,
309                None => continue, // no filter registered, can't cache
310            };
311
312            let meta = self.registry.meta(node_id);
313            let cacheable = meta.as_ref().is_some_and(|m| m.cacheable);
314
315            // In differentiable mode, only cache states (not forward outputs)
316            // For simplicity at this stage, we skip caching in differentiable mode
317            let can_cache = cacheable && self.mode == CompileMode::Inference;
318
319            // Build the cache key from config + predecessor output keys
320            let pred_ids = self.graph.predecessors(node_id);
321            let mut key_parts: Vec<Vec<u8>> = vec![config_hash.0.to_vec()];
322            for pred in &pred_ids {
323                if let Some(pred_key) = node_keys.get(*pred) {
324                    key_parts.push(pred_key.0.to_vec());
325                } else {
326                    // Predecessor should always be processed first in topological order.
327                    // If missing, the cache key will be incomplete but won't panic.
328                    debug_assert!(
329                        false,
330                        "predecessor `{pred}` of `{node_id}` not in node_keys - \
331                         topological order may be broken"
332                    );
333                }
334            }
335            let parts_refs: Vec<&[u8]> = key_parts.iter().map(|p| p.as_slice()).collect();
336            let key = CacheKey::from_parts(&parts_refs);
337            node_keys.insert(node_id.to_string(), key.clone());
338
339            // Check if this node's output exists in cache
340            if can_cache {
341                // Only use cache if ALL predecessors are also cached (or roots).
342                // This ensures cascade invalidation: if any upstream re-executed,
343                // the key is already different (since it includes predecessor keys).
344                if cache.exists(&key)? {
345                    cached_nodes.insert(node_id.to_string());
346                }
347            }
348        }
349
350        // Replace Execute nodes with Cached where applicable
351        Ok(self.apply_cache_to_plan(plan, &cached_nodes, &node_keys))
352    }
353
354    fn apply_cache_to_plan(
355        &self,
356        plan: ExecutionPlan,
357        cached: &HashSet<String>,
358        keys: &HashMap<String, CacheKey>,
359    ) -> ExecutionPlan {
360        match plan {
361            ExecutionPlan::Execute { ref node_id } => {
362                if cached.contains(node_id)
363                    && let Some(key) = keys.get(node_id)
364                {
365                    return ExecutionPlan::Cached {
366                        node_id: node_id.clone(),
367                        key: key.clone(),
368                    };
369                }
370                plan
371            }
372            ExecutionPlan::Sequence(steps) => ExecutionPlan::Sequence(
373                steps
374                    .into_iter()
375                    .map(|s| self.apply_cache_to_plan(s, cached, keys))
376                    .collect(),
377            ),
378            ExecutionPlan::Parallel(branches) => ExecutionPlan::Parallel(
379                branches
380                    .into_iter()
381                    .map(|b| self.apply_cache_to_plan(b, cached, keys))
382                    .collect(),
383            ),
384            other => other,
385        }
386    }
387
388    /// Wrap nodes with Remote distribution in ExecutionPlan::Remote.
389    fn resolve_distribution(&self, plan: ExecutionPlan) -> ExecutionPlan {
390        match plan {
391            ExecutionPlan::Execute { ref node_id } => {
392                if let Some(meta) = self.registry.meta(node_id) {
393                    match &meta.distribution {
394                        somatize_core::filter::Distribution::Remote(target) => {
395                            ExecutionPlan::Remote {
396                                node_id: node_id.clone(),
397                                target: target.clone(),
398                                plan: Box::new(plan),
399                            }
400                        }
401                        _ => plan,
402                    }
403                } else {
404                    plan
405                }
406            }
407            ExecutionPlan::Sequence(steps) => ExecutionPlan::Sequence(
408                steps
409                    .into_iter()
410                    .map(|s| self.resolve_distribution(s))
411                    .collect(),
412            ),
413            ExecutionPlan::Parallel(branches) => ExecutionPlan::Parallel(
414                branches
415                    .into_iter()
416                    .map(|b| self.resolve_distribution(b))
417                    .collect(),
418            ),
419            ExecutionPlan::Composite { ref node_ids } => {
420                // If ALL nodes in the composite have a Remote target, wrap the
421                // entire composite in a single Remote (using the first node's
422                // target). Otherwise keep it local.
423                let targets: Vec<_> = node_ids
424                    .iter()
425                    .filter_map(|nid| {
426                        self.registry.meta(nid).and_then(|m| match &m.distribution {
427                            somatize_core::filter::Distribution::Remote(t) => Some(t.clone()),
428                            _ => None,
429                        })
430                    })
431                    .collect();
432
433                if targets.len() == node_ids.len() && !targets.is_empty() {
434                    let first_id = node_ids[0].clone();
435                    ExecutionPlan::Remote {
436                        node_id: first_id,
437                        target: targets.into_iter().next().unwrap(),
438                        plan: Box::new(plan),
439                    }
440                } else {
441                    plan
442                }
443            }
444            other => other,
445        }
446    }
447
448    /// Collapse consecutive differentiable Execute nodes into Composite blocks.
449    ///
450    /// A `Composite` groups nodes that should share a PyTorch autograd session.
451    /// Only groups 2+ consecutive `Execute` nodes where `meta.differentiable == true`.
452    fn collapse_differentiable(&self, plan: ExecutionPlan) -> ExecutionPlan {
453        match plan {
454            ExecutionPlan::Sequence(steps) => {
455                let mut result: Vec<ExecutionPlan> = Vec::new();
456                let mut diff_group: Vec<String> = Vec::new();
457
458                for step in steps {
459                    if let ExecutionPlan::Execute { ref node_id } = step
460                        && self
461                            .registry
462                            .meta(node_id)
463                            .map(|m| m.differentiable)
464                            .unwrap_or(false)
465                    {
466                        diff_group.push(node_id.clone());
467                        continue;
468                    }
469                    // Flush accumulated differentiable group
470                    Self::flush_diff_group(&mut diff_group, &mut result);
471                    result.push(self.collapse_differentiable(step));
472                }
473                Self::flush_diff_group(&mut diff_group, &mut result);
474
475                if result.len() == 1 {
476                    result.pop().unwrap()
477                } else {
478                    ExecutionPlan::Sequence(result)
479                }
480            }
481            ExecutionPlan::Parallel(branches) => ExecutionPlan::Parallel(
482                branches
483                    .into_iter()
484                    .map(|b| self.collapse_differentiable(b))
485                    .collect(),
486            ),
487            ExecutionPlan::Remote {
488                node_id,
489                target,
490                plan,
491            } => ExecutionPlan::Remote {
492                node_id,
493                target,
494                plan: Box::new(self.collapse_differentiable(*plan)),
495            },
496            other => other,
497        }
498    }
499
500    fn flush_diff_group(group: &mut Vec<String>, result: &mut Vec<ExecutionPlan>) {
501        if group.len() > 1 {
502            result.push(ExecutionPlan::Composite {
503                node_ids: std::mem::take(group),
504            });
505        } else if let Some(id) = group.pop() {
506            result.push(ExecutionPlan::Execute { node_id: id });
507        }
508    }
509
510    /// Validate schema compatibility between connected filters.
511    ///
512    /// For each edge (A → B), checks that A's output_schema is compatible
513    /// with B's input_schema. Emits warnings (not errors) for mismatches,
514    /// since schemas are optional and None means "accepts anything".
515    fn validate_schemas(&mut self, sorted: &[&str]) {
516        for &node_id in sorted {
517            let input_schema = self
518                .registry
519                .meta(node_id)
520                .and_then(|m| m.input_schema.clone());
521
522            // Skip if this node accepts anything
523            let Some(expected_input) = input_schema else {
524                continue;
525            };
526
527            // Check each predecessor's output schema
528            for pred_id in self.graph.predecessors(node_id) {
529                let pred_output = self
530                    .registry
531                    .meta(pred_id)
532                    .and_then(|m| m.output_schema.clone());
533
534                let Some(actual_output) = pred_output else {
535                    continue; // predecessor output unknown, skip
536                };
537
538                if !actual_output.is_compatible_with(&expected_input) {
539                    self.diagnostics.push(Diagnostic {
540                        node_id: node_id.to_string(),
541                        level: DiagnosticLevel::Warning,
542                        message: format!(
543                            "schema mismatch: `{pred_id}` outputs {actual_output} \
544                             but `{node_id}` expects {expected_input}",
545                        ),
546                    });
547                }
548            }
549        }
550    }
551
552    /// Check gradient flow and emit warnings for each interruption.
553    ///
554    /// Gradient flow can restart after an opaque node (differentiable nodes
555    /// after an opaque one can still propagate gradients among themselves),
556    /// but gradients from before the interruption are lost.
557    fn check_gradient_flow(&mut self, sorted: &[&str]) {
558        let mut gradient_flows = true;
559
560        for &node_id in sorted {
561            if let Some(meta) = self.registry.meta(node_id) {
562                if gradient_flows && !meta.differentiable {
563                    self.diagnostics.push(Diagnostic {
564                        node_id: node_id.to_string(),
565                        level: DiagnosticLevel::Warning,
566                        message: format!(
567                            "gradient flow interrupted at `{}` ({:?}). \
568                             Gradients from upstream will not reach downstream filters \
569                             through this node.",
570                            node_id, meta.kind,
571                        ),
572                    });
573                    gradient_flows = false;
574                } else if !gradient_flows && meta.differentiable {
575                    // Gradient flow restarts: differentiable nodes after the
576                    // interruption can propagate gradients among themselves
577                    gradient_flows = true;
578                }
579            }
580        }
581    }
582}
583
584/// Convenience function: compile a graph with default settings.
585pub fn compile(
586    graph: &Graph,
587    registry: &dyn FilterRegistry,
588    mode: CompileMode,
589    cache: Option<&dyn CacheStore>,
590) -> Result<CompileResult> {
591    Compiler::new(graph, registry, mode).compile(cache)
592}
593
594/// Compile a graph for streaming execution.
595///
596/// Produces an `ExecutionPlan::Stream` wrapping the topologically sorted
597/// node chain. The runtime will chunk input and process through a
598/// `StreamExecutor` that respects each filter's `StreamMode`.
599pub fn compile_stream(
600    graph: &Graph,
601    _registry: &dyn FilterRegistry,
602    chunk_size: usize,
603) -> Result<CompileResult> {
604    graph.validate()?;
605    let sorted = graph.topological_sort()?;
606
607    if sorted.is_empty() {
608        return Ok(CompileResult {
609            plan: ExecutionPlan::Empty,
610            diagnostics: Vec::new(),
611        });
612    }
613
614    let node_ids: Vec<NodeId> = sorted.into_iter().map(|s| s.to_string()).collect();
615    let plan = ExecutionPlan::Stream {
616        node_ids,
617        chunk_size,
618    };
619
620    Ok(CompileResult {
621        plan,
622        diagnostics: Vec::new(),
623    })
624}
625
626#[cfg(test)]
627mod tests {
628    use super::*;
629    use somatize_core::cache::EntryMeta;
630    use somatize_core::error::SomaError;
631    use somatize_core::filter::{FilterKind, StreamMode};
632    use somatize_core::graph::{Edge, Graph, Node, linear_pipeline};
633    use somatize_core::value::Value;
634    use std::sync::Mutex;
635
636    // ── Mock cache store ──
637
638    struct MockCacheStore {
639        entries: Mutex<HashSet<CacheKey>>,
640    }
641
642    impl MockCacheStore {
643        fn new() -> Self {
644            Self {
645                entries: Mutex::new(HashSet::new()),
646            }
647        }
648
649        fn insert(&self, key: CacheKey) {
650            self.entries.lock().unwrap().insert(key);
651        }
652    }
653
654    impl CacheStore for MockCacheStore {
655        fn get(&self, _key: &CacheKey) -> Result<Option<Value>> {
656            Ok(None)
657        }
658        fn put(&self, _key: &CacheKey, _value: &Value) -> Result<()> {
659            Ok(())
660        }
661        fn exists(&self, key: &CacheKey) -> Result<bool> {
662            Ok(self.entries.lock().unwrap().contains(key))
663        }
664        fn remove(&self, _key: &CacheKey) -> Result<()> {
665            Ok(())
666        }
667        fn metadata(&self, _key: &CacheKey) -> Result<Option<EntryMeta>> {
668            Ok(None)
669        }
670    }
671
672    // ── Helpers ──
673
674    fn make_meta(kind: FilterKind, differentiable: bool) -> FilterMeta {
675        FilterMeta {
676            name: "test".into(),
677            kind,
678            cacheable: true,
679            differentiable,
680            stream_mode: StreamMode::FixedState,
681            distribution: somatize_core::filter::Distribution::Local,
682            input_schema: None,
683            output_schema: None,
684        }
685    }
686
687    fn register_nodes(registry: &mut SimpleFilterRegistry, ids: &[&str], meta: FilterMeta) {
688        for (i, id) in ids.iter().enumerate() {
689            let hash = CacheKey::from_parts(&[id.as_bytes(), &[i as u8]]);
690            registry.register_meta(*id, meta.clone(), hash);
691        }
692    }
693
694    // ── Tests ──
695
696    #[test]
697    fn compile_empty_graph() {
698        let graph = Graph::new();
699        let registry = SimpleFilterRegistry::new();
700        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
701        assert!(matches!(result.plan, ExecutionPlan::Empty));
702    }
703
704    #[test]
705    fn compile_single_node() {
706        let mut graph = Graph::new();
707        graph.add_node(Node::new("a", "A", "F"));
708        let mut registry = SimpleFilterRegistry::new();
709        register_nodes(
710            &mut registry,
711            &["a"],
712            make_meta(FilterKind::Trainable, true),
713        );
714
715        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
716        assert!(matches!(result.plan, ExecutionPlan::Execute { .. }));
717    }
718
719    #[test]
720    fn compile_linear_pipeline_produces_sequence() {
721        let graph = linear_pipeline(vec![
722            Node::new("a", "Scaler", "F"),
723            Node::new("b", "PCA", "F"),
724            Node::new("c", "SVM", "F"),
725        ]);
726        let mut registry = SimpleFilterRegistry::new();
727        register_nodes(
728            &mut registry,
729            &["a", "b", "c"],
730            make_meta(FilterKind::Trainable, true),
731        );
732
733        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
734
735        // All 3 nodes are differentiable → collapsed into Composite
736        if let ExecutionPlan::Composite { node_ids } = &result.plan {
737            assert_eq!(node_ids, &["a", "b", "c"]);
738        } else {
739            panic!("expected Composite, got: {:?}", result.plan);
740        }
741    }
742
743    #[test]
744    fn compile_diamond_detects_parallelism() {
745        let mut graph = Graph::new();
746        graph.add_node(Node::new("root", "Root", "F"));
747        graph.add_node(Node::new("b1", "B1", "F"));
748        graph.add_node(Node::new("b2", "B2", "F"));
749        graph.add_node(Node::new("merge", "Merge", "F"));
750        graph.add_edge(Edge::data("e1", "root", "b1"));
751        graph.add_edge(Edge::data("e2", "root", "b2"));
752        graph.add_edge(Edge::data("e3", "b1", "merge"));
753        graph.add_edge(Edge::data("e4", "b2", "merge"));
754
755        let mut registry = SimpleFilterRegistry::new();
756        register_nodes(
757            &mut registry,
758            &["root", "b1", "b2", "merge"],
759            make_meta(FilterKind::Trainable, true),
760        );
761
762        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
763
764        // Should be: Sequence(Execute(root), Parallel(Execute(b1), Execute(b2)), Execute(merge))
765        if let ExecutionPlan::Sequence(steps) = &result.plan {
766            assert_eq!(steps.len(), 3);
767            assert!(matches!(&steps[0], ExecutionPlan::Execute { node_id } if node_id == "root"));
768            assert!(matches!(&steps[1], ExecutionPlan::Parallel(branches) if branches.len() == 2));
769            assert!(matches!(&steps[2], ExecutionPlan::Execute { node_id } if node_id == "merge"));
770        } else {
771            panic!("expected Sequence, got: {:?}", result.plan);
772        }
773    }
774
775    #[test]
776    fn compile_independent_roots_parallel() {
777        let mut graph = Graph::new();
778        graph.add_node(Node::new("a", "A", "F"));
779        graph.add_node(Node::new("b", "B", "F"));
780        // No edges: fully independent
781
782        let mut registry = SimpleFilterRegistry::new();
783        register_nodes(
784            &mut registry,
785            &["a", "b"],
786            make_meta(FilterKind::Trainable, true),
787        );
788
789        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
790
791        // Both at level 0 → Parallel
792        assert!(matches!(result.plan, ExecutionPlan::Parallel(_)));
793    }
794
795    #[test]
796    fn cache_resolution_replaces_cached_nodes() {
797        let graph = linear_pipeline(vec![
798            Node::new("a", "Scaler", "F"),
799            Node::new("b", "PCA", "F"),
800            Node::new("c", "SVM", "F"),
801        ]);
802
803        let mut registry = SimpleFilterRegistry::new();
804        register_nodes(
805            &mut registry,
806            &["a", "b", "c"],
807            make_meta(FilterKind::Trainable, true),
808        );
809
810        // Pre-compute the cache key for node "a" (same logic as compiler)
811        let a_config = registry.config_hash("a").unwrap();
812        let a_cache_key = CacheKey::from_parts(&[&a_config.0]);
813
814        let cache = MockCacheStore::new();
815        cache.insert(a_cache_key);
816
817        let result = compile(&graph, &registry, CompileMode::Inference, Some(&cache)).unwrap();
818
819        // "a" is cached, "b"+"c" are differentiable → Composite
820        if let ExecutionPlan::Sequence(steps) = &result.plan {
821            assert!(
822                matches!(&steps[0], ExecutionPlan::Cached { node_id, .. } if node_id == "a"),
823                "first node should be cached, got: {:?}",
824                steps[0]
825            );
826            assert!(
827                matches!(&steps[1], ExecutionPlan::Composite { node_ids } if node_ids == &["b", "c"]),
828                "b+c should be Composite, got: {:?}",
829                steps[1]
830            );
831        } else {
832            panic!("expected Sequence, got: {:?}", result.plan);
833        }
834    }
835
836    #[test]
837    fn cascade_invalidation_different_config_changes_keys() {
838        // Register with config hash "v1"
839        let mut reg1 = SimpleFilterRegistry::new();
840        reg1.register_meta(
841            "a",
842            make_meta(FilterKind::Trainable, true),
843            CacheKey::hash_data(b"scaler_v1"),
844        );
845        reg1.register_meta(
846            "b",
847            make_meta(FilterKind::Trainable, true),
848            CacheKey::hash_data(b"pca_v1"),
849        );
850
851        // Register with config hash "v2" for node "a"
852        let mut reg2 = SimpleFilterRegistry::new();
853        reg2.register_meta(
854            "a",
855            make_meta(FilterKind::Trainable, true),
856            CacheKey::hash_data(b"scaler_v2"), // changed!
857        );
858        reg2.register_meta(
859            "b",
860            make_meta(FilterKind::Trainable, true),
861            CacheKey::hash_data(b"pca_v1"), // same
862        );
863
864        // Compute keys for both configurations
865        // The plans have same structure but when cache keys are computed,
866        // changing "a" config changes "b"'s key too (cascade).
867        // We verify this by computing keys manually:
868        let a_key_v1 = CacheKey::from_parts(&[&CacheKey::hash_data(b"scaler_v1").0]);
869        let b_key_v1 = CacheKey::from_parts(&[&CacheKey::hash_data(b"pca_v1").0, &a_key_v1.0]);
870
871        let a_key_v2 = CacheKey::from_parts(&[&CacheKey::hash_data(b"scaler_v2").0]);
872        let b_key_v2 = CacheKey::from_parts(&[&CacheKey::hash_data(b"pca_v1").0, &a_key_v2.0]);
873
874        // a changed → a's key changed
875        assert_ne!(a_key_v1, a_key_v2);
876        // b's config didn't change but a's key is in b's key → b's key also changed
877        assert_ne!(b_key_v1, b_key_v2);
878    }
879
880    #[test]
881    fn no_cache_mode_skips_all_caching() {
882        let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);
883
884        let mut registry = SimpleFilterRegistry::new();
885        register_nodes(
886            &mut registry,
887            &["a", "b"],
888            make_meta(FilterKind::Trainable, true),
889        );
890
891        // Put everything in cache
892        let a_config = registry.config_hash("a").unwrap();
893        let a_key = CacheKey::from_parts(&[&a_config.0]);
894        let cache = MockCacheStore::new();
895        cache.insert(a_key);
896
897        let result = compile(&graph, &registry, CompileMode::NoCache, Some(&cache)).unwrap();
898
899        // Nothing should be cached
900        assert_eq!(result.plan.cached_count(), 0);
901    }
902
903    #[test]
904    fn differentiable_mode_skips_output_caching() {
905        let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);
906
907        let mut registry = SimpleFilterRegistry::new();
908        register_nodes(
909            &mut registry,
910            &["a", "b"],
911            make_meta(FilterKind::Trainable, true),
912        );
913
914        let a_config = registry.config_hash("a").unwrap();
915        let a_key = CacheKey::from_parts(&[&a_config.0]);
916        let cache = MockCacheStore::new();
917        cache.insert(a_key);
918
919        let result = compile(&graph, &registry, CompileMode::Differentiable, Some(&cache)).unwrap();
920
921        // Differentiable mode should not cache forward outputs
922        assert_eq!(result.plan.cached_count(), 0);
923    }
924
925    #[test]
926    fn gradient_flow_diagnostic_on_opaque() {
927        let graph = linear_pipeline(vec![
928            Node::new("scaler", "Scaler", "F"),
929            Node::new("tree", "DecisionTree", "F"),
930            Node::new("linear", "Linear", "F"),
931        ]);
932
933        let mut registry = SimpleFilterRegistry::new();
934        registry.register_meta(
935            "scaler",
936            make_meta(FilterKind::Trainable, true),
937            CacheKey::hash_data(b"s"),
938        );
939        registry.register_meta(
940            "tree",
941            make_meta(FilterKind::Opaque, false), // not differentiable
942            CacheKey::hash_data(b"t"),
943        );
944        registry.register_meta(
945            "linear",
946            make_meta(FilterKind::Trainable, true),
947            CacheKey::hash_data(b"l"),
948        );
949
950        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
951
952        assert_eq!(result.diagnostics.len(), 1);
953        assert_eq!(result.diagnostics[0].node_id, "tree");
954        assert_eq!(result.diagnostics[0].level, DiagnosticLevel::Warning);
955        assert!(
956            result.diagnostics[0]
957                .message
958                .contains("gradient flow interrupted")
959        );
960    }
961
962    #[test]
963    fn no_diagnostic_when_all_differentiable() {
964        let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);
965
966        let mut registry = SimpleFilterRegistry::new();
967        register_nodes(
968            &mut registry,
969            &["a", "b"],
970            make_meta(FilterKind::Trainable, true),
971        );
972
973        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
974        assert!(result.diagnostics.is_empty());
975    }
976
977    #[test]
978    fn compile_rejects_cycle() {
979        let mut graph = Graph::new();
980        graph.add_node(Node::new("a", "A", "F"));
981        graph.add_node(Node::new("b", "B", "F"));
982        graph.add_edge(Edge::data("e1", "a", "b"));
983        graph.add_edge(Edge::data("e2", "b", "a"));
984
985        let registry = SimpleFilterRegistry::new();
986        let result = compile(&graph, &registry, CompileMode::Inference, None);
987        assert!(matches!(result, Err(SomaError::CycleDetected)));
988    }
989
990    #[test]
991    fn plan_summary_is_accurate() {
992        let mut graph = Graph::new();
993        graph.add_node(Node::new("root", "Root", "F"));
994        graph.add_node(Node::new("b1", "B1", "F"));
995        graph.add_node(Node::new("b2", "B2", "F"));
996        graph.add_node(Node::new("end", "End", "F"));
997        graph.add_edge(Edge::data("e1", "root", "b1"));
998        graph.add_edge(Edge::data("e2", "root", "b2"));
999        graph.add_edge(Edge::data("e3", "b1", "end"));
1000        graph.add_edge(Edge::data("e4", "b2", "end"));
1001
1002        let mut registry = SimpleFilterRegistry::new();
1003        register_nodes(
1004            &mut registry,
1005            &["root", "b1", "b2", "end"],
1006            make_meta(FilterKind::Trainable, true),
1007        );
1008
1009        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
1010        let summary = result.plan.summary();
1011        assert_eq!(summary.total_nodes, 4);
1012        assert_eq!(summary.parallel_branches, 2);
1013    }
1014
1015    #[test]
1016    fn distribution_wraps_remote_nodes() {
1017        let graph = linear_pipeline(vec![
1018            Node::new("preprocess", "Preprocess", "F"),
1019            Node::new("gpu_train", "GpuTrain", "F"),
1020            Node::new("evaluate", "Evaluate", "F"),
1021        ]);
1022
1023        let mut registry = SimpleFilterRegistry::new();
1024        // preprocess: local
1025        registry.register_meta(
1026            "preprocess",
1027            make_meta(FilterKind::Trainable, true),
1028            CacheKey::hash_data(b"pre"),
1029        );
1030        // gpu_train: remote on GPU tag
1031        let mut gpu_meta = make_meta(FilterKind::Trainable, true);
1032        gpu_meta.distribution = somatize_core::filter::Distribution::Remote(
1033            somatize_core::filter::RemoteTarget::Tag("gpu".into()),
1034        );
1035        registry.register_meta("gpu_train", gpu_meta, CacheKey::hash_data(b"gpu"));
1036        // evaluate: local
1037        registry.register_meta(
1038            "evaluate",
1039            make_meta(FilterKind::Trainable, true),
1040            CacheKey::hash_data(b"eval"),
1041        );
1042
1043        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
1044
1045        // Should be: Sequence(Execute(preprocess), Remote(gpu_train, ...), Execute(evaluate))
1046        if let ExecutionPlan::Sequence(steps) = &result.plan {
1047            assert_eq!(steps.len(), 3);
1048            assert!(
1049                matches!(&steps[0], ExecutionPlan::Execute { node_id } if node_id == "preprocess")
1050            );
1051            assert!(
1052                matches!(&steps[1], ExecutionPlan::Remote { node_id, target, .. }
1053                    if node_id == "gpu_train"
1054                    && *target == somatize_core::filter::RemoteTarget::Tag("gpu".into())
1055                ),
1056                "expected Remote, got: {:?}",
1057                steps[1]
1058            );
1059            assert!(
1060                matches!(&steps[2], ExecutionPlan::Execute { node_id } if node_id == "evaluate")
1061            );
1062        } else {
1063            panic!("expected Sequence, got: {:?}", result.plan);
1064        }
1065    }
1066
1067    #[test]
1068    fn local_distribution_not_wrapped() {
1069        let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);
1070
1071        let mut registry = SimpleFilterRegistry::new();
1072        register_nodes(
1073            &mut registry,
1074            &["a", "b"],
1075            make_meta(FilterKind::Trainable, true),
1076        );
1077
1078        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
1079
1080        // No Remote nodes
1081        let ids = result.plan.node_ids();
1082        assert_eq!(ids.len(), 2);
1083        // Should all be Execute, no Remote wrapper
1084        if let ExecutionPlan::Sequence(steps) = &result.plan {
1085            assert!(
1086                steps
1087                    .iter()
1088                    .all(|s| matches!(s, ExecutionPlan::Execute { .. }))
1089            );
1090        }
1091    }
1092}