Skip to main content

somatize_compiler/
compiler.rs

//! Graph → ExecutionPlan compiler.
//!
//! Compilation phases: graph validation → topological sort → gradient-flow
//! check → schema validation → parallelism detection → cache resolution →
//! distribution wrapping → simplification.
5
6use crate::plan::ExecutionPlan;
7use somatize_core::cache::{CacheKey, CacheStore};
8use somatize_core::error::Result;
9use somatize_core::filter::{Filter, FilterMeta};
10use somatize_core::graph::{Graph, NodeId};
11use std::collections::{HashMap, HashSet};
12
/// Selects how the compiler treats cached results.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompileMode {
    /// Full caching: nodes whose outputs are already cached are skipped.
    Inference,
    /// Cache states only: forwards are re-executed so gradients can flow.
    Differentiable,
    /// Caching disabled: every node is forced to re-execute.
    NoCache,
}
23
24/// Diagnostic message emitted during compilation.
25#[derive(Debug, Clone)]
26pub struct Diagnostic {
27    pub node_id: NodeId,
28    pub level: DiagnosticLevel,
29    pub message: String,
30}
31
/// Severity attached to a [`Diagnostic`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DiagnosticLevel {
    /// Something is likely wrong, but compilation still succeeds.
    Warning,
    /// Purely informational note.
    Info,
}
37
38/// Compiled result: the plan plus any diagnostics.
39pub struct CompileResult {
40    pub plan: ExecutionPlan,
41    pub diagnostics: Vec<Diagnostic>,
42}
43
44/// Registry that maps node IDs to their filter metadata.
45/// The compiler needs metadata (cacheable, differentiable, etc.)
46/// but doesn't need the actual filter implementations.
47pub trait FilterRegistry: Send + Sync {
48    fn meta(&self, node_id: &str) -> Option<FilterMeta>;
49    fn config_hash(&self, node_id: &str) -> Option<CacheKey>;
50}
51
52/// Simple in-memory filter registry for compilation.
53pub struct SimpleFilterRegistry {
54    entries: HashMap<String, (FilterMeta, CacheKey)>,
55}
56
57impl SimpleFilterRegistry {
58    pub fn new() -> Self {
59        Self {
60            entries: HashMap::new(),
61        }
62    }
63
64    pub fn register(&mut self, node_id: impl Into<String>, filter: &dyn Filter) {
65        let id = node_id.into();
66        self.entries
67            .insert(id, (filter.meta(), filter.config_hash()));
68    }
69
70    pub fn register_meta(
71        &mut self,
72        node_id: impl Into<String>,
73        meta: FilterMeta,
74        config_hash: CacheKey,
75    ) {
76        self.entries.insert(node_id.into(), (meta, config_hash));
77    }
78}
79
80impl Default for SimpleFilterRegistry {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
86impl FilterRegistry for SimpleFilterRegistry {
87    fn meta(&self, node_id: &str) -> Option<FilterMeta> {
88        self.entries.get(node_id).map(|(m, _)| m.clone())
89    }
90
91    fn config_hash(&self, node_id: &str) -> Option<CacheKey> {
92        self.entries.get(node_id).map(|(_, h)| h.clone())
93    }
94}
95
96/// Compiles a Graph into an ExecutionPlan.
97pub struct Compiler<'a> {
98    graph: &'a Graph,
99    registry: &'a dyn FilterRegistry,
100    mode: CompileMode,
101    diagnostics: Vec<Diagnostic>,
102}
103
104impl<'a> Compiler<'a> {
105    pub fn new(graph: &'a Graph, registry: &'a dyn FilterRegistry, mode: CompileMode) -> Self {
106        Self {
107            graph,
108            registry,
109            mode,
110            diagnostics: Vec::new(),
111        }
112    }
113
114    /// Compile the graph into an execution plan.
115    pub fn compile(mut self, cache: Option<&dyn CacheStore>) -> Result<CompileResult> {
116        self.graph.validate()?;
117
118        let sorted = self.graph.topological_sort()?;
119
120        if sorted.is_empty() {
121            return Ok(CompileResult {
122                plan: ExecutionPlan::Empty,
123                diagnostics: self.diagnostics,
124            });
125        }
126
127        // Check gradient flow
128        self.check_gradient_flow(&sorted);
129
130        // Validate schema compatibility
131        self.validate_schemas(&sorted);
132
133        // Build the structural plan (detect parallelism)
134        let plan = self.build_plan(&sorted);
135
136        // Resolve caching if applicable
137        let plan = if let Some(cache) = cache {
138            self.resolve_cache(plan, cache, &sorted)?
139        } else {
140            plan
141        };
142
143        // Resolve distribution (wrap Remote nodes)
144        let plan = self.resolve_distribution(plan);
145
146        let plan = plan.simplify();
147
148        Ok(CompileResult {
149            plan,
150            diagnostics: self.diagnostics,
151        })
152    }
153
154    /// Build a plan from topologically sorted nodes, detecting parallelism.
155    fn build_plan(&self, sorted: &[&str]) -> ExecutionPlan {
156        // Compute topological levels (nodes at the same level can run in parallel)
157        let levels = self.compute_levels(sorted);
158
159        let mut plan_steps: Vec<ExecutionPlan> = Vec::new();
160
161        for level in &levels {
162            if level.len() == 1 {
163                plan_steps.push(self.plan_for_node(level[0]));
164            } else {
165                let branches: Vec<ExecutionPlan> =
166                    level.iter().map(|id| self.plan_for_node(id)).collect();
167                plan_steps.push(ExecutionPlan::Parallel(branches));
168            }
169        }
170
171        if plan_steps.len() == 1 {
172            plan_steps.into_iter().next().unwrap()
173        } else {
174            ExecutionPlan::Sequence(plan_steps)
175        }
176    }
177
178    /// Generate the execution plan for a single node based on its kind.
179    fn plan_for_node(&self, node_id: &str) -> ExecutionPlan {
180        use somatize_core::graph::NodeKind;
181
182        let node = match self.graph.node(node_id) {
183            Some(n) => n,
184            None => {
185                return ExecutionPlan::Execute {
186                    node_id: node_id.to_string(),
187                };
188            }
189        };
190
191        match &node.kind {
192            NodeKind::Filter { .. } => ExecutionPlan::Execute {
193                node_id: node_id.to_string(),
194            },
195
196            NodeKind::SubGraph { graph } => {
197                // Recursively compile the inner graph
198                let inner_compiler = Compiler::new(graph, self.registry, self.mode);
199                match inner_compiler.compile(None) {
200                    Ok(result) => result.plan,
201                    Err(_) => ExecutionPlan::Execute {
202                        node_id: node_id.to_string(),
203                    },
204                }
205            }
206
207            NodeKind::Loop { max_iterations } => {
208                // The body consists of the successors of this loop node.
209                // Build a sub-plan from the successor chain.
210                let successors = self.graph.successors(node_id);
211                let body = if successors.len() == 1 {
212                    self.plan_for_node(successors[0])
213                } else if successors.len() > 1 {
214                    let branches: Vec<ExecutionPlan> =
215                        successors.iter().map(|id| self.plan_for_node(id)).collect();
216                    ExecutionPlan::Parallel(branches)
217                } else {
218                    ExecutionPlan::Empty
219                };
220                ExecutionPlan::Loop {
221                    node_id: node_id.to_string(),
222                    body: Box::new(body),
223                    max_iterations: *max_iterations,
224                }
225            }
226
227            NodeKind::Branch => {
228                // Arms come from control edges leaving this node.
229                let arms: Vec<(String, ExecutionPlan)> = self
230                    .graph
231                    .edges
232                    .iter()
233                    .filter(|e| e.source == node_id)
234                    .map(|e| {
235                        let label = e.label.clone().unwrap_or_else(|| e.target.clone());
236                        let plan = self.plan_for_node(&e.target);
237                        (label, plan)
238                    })
239                    .collect();
240                ExecutionPlan::Branch {
241                    node_id: node_id.to_string(),
242                    arms,
243                }
244            }
245
246            _ => ExecutionPlan::Execute {
247                node_id: node_id.to_string(),
248            },
249        }
250    }
251
252    /// Compute topological levels: groups of nodes that can execute concurrently.
253    /// Each node's level = max(predecessor levels) + 1.
254    fn compute_levels<'b>(&self, sorted: &[&'b str]) -> Vec<Vec<&'b str>> {
255        let mut node_level: HashMap<&str, usize> = HashMap::new();
256        let mut max_level: usize = 0;
257
258        for &node in sorted {
259            let preds = self.graph.predecessors(node);
260            let level = if preds.is_empty() {
261                0
262            } else {
263                preds
264                    .iter()
265                    .map(|p| node_level.get(p).copied().unwrap_or(0) + 1)
266                    .max()
267                    .unwrap_or(0)
268            };
269            node_level.insert(node, level);
270            if level > max_level {
271                max_level = level;
272            }
273        }
274
275        let mut levels: Vec<Vec<&str>> = vec![Vec::new(); max_level + 1];
276        for &node in sorted {
277            let level = node_level[node];
278            levels[level].push(node);
279        }
280
281        // Remove empty levels (shouldn't happen but defensive)
282        levels.retain(|l| !l.is_empty());
283        levels
284    }
285
286    /// Resolve caching: replace Execute nodes with Cached when possible.
287    /// Implements cascade invalidation.
288    fn resolve_cache(
289        &self,
290        plan: ExecutionPlan,
291        cache: &dyn CacheStore,
292        sorted: &[&str],
293    ) -> Result<ExecutionPlan> {
294        if self.mode == CompileMode::NoCache {
295            return Ok(plan);
296        }
297
298        // Compute cache keys for all nodes in topological order.
299        // A node's key depends on its config + its predecessors' keys.
300        let mut node_keys: HashMap<String, CacheKey> = HashMap::new();
301        let mut cached_nodes: HashSet<String> = HashSet::new();
302
303        for &node_id in sorted {
304            let config_hash = match self.registry.config_hash(node_id) {
305                Some(h) => h,
306                None => continue, // no filter registered, can't cache
307            };
308
309            let meta = self.registry.meta(node_id);
310            let cacheable = meta.as_ref().is_some_and(|m| m.cacheable);
311
312            // In differentiable mode, only cache states (not forward outputs)
313            // For simplicity at this stage, we skip caching in differentiable mode
314            let can_cache = cacheable && self.mode == CompileMode::Inference;
315
316            // Build the cache key from config + predecessor output keys
317            let pred_ids = self.graph.predecessors(node_id);
318            let mut key_parts: Vec<Vec<u8>> = vec![config_hash.0.to_vec()];
319            for pred in &pred_ids {
320                if let Some(pred_key) = node_keys.get(*pred) {
321                    key_parts.push(pred_key.0.to_vec());
322                } else {
323                    // Predecessor should always be processed first in topological order.
324                    // If missing, the cache key will be incomplete but won't panic.
325                    debug_assert!(
326                        false,
327                        "predecessor `{pred}` of `{node_id}` not in node_keys - \
328                         topological order may be broken"
329                    );
330                }
331            }
332            let parts_refs: Vec<&[u8]> = key_parts.iter().map(|p| p.as_slice()).collect();
333            let key = CacheKey::from_parts(&parts_refs);
334            node_keys.insert(node_id.to_string(), key.clone());
335
336            // Check if this node's output exists in cache
337            if can_cache {
338                // Only use cache if ALL predecessors are also cached (or roots).
339                // This ensures cascade invalidation: if any upstream re-executed,
340                // the key is already different (since it includes predecessor keys).
341                if cache.exists(&key)? {
342                    cached_nodes.insert(node_id.to_string());
343                }
344            }
345        }
346
347        // Replace Execute nodes with Cached where applicable
348        Ok(self.apply_cache_to_plan(plan, &cached_nodes, &node_keys))
349    }
350
351    fn apply_cache_to_plan(
352        &self,
353        plan: ExecutionPlan,
354        cached: &HashSet<String>,
355        keys: &HashMap<String, CacheKey>,
356    ) -> ExecutionPlan {
357        match plan {
358            ExecutionPlan::Execute { ref node_id } => {
359                if cached.contains(node_id)
360                    && let Some(key) = keys.get(node_id)
361                {
362                    return ExecutionPlan::Cached {
363                        node_id: node_id.clone(),
364                        key: key.clone(),
365                    };
366                }
367                plan
368            }
369            ExecutionPlan::Sequence(steps) => ExecutionPlan::Sequence(
370                steps
371                    .into_iter()
372                    .map(|s| self.apply_cache_to_plan(s, cached, keys))
373                    .collect(),
374            ),
375            ExecutionPlan::Parallel(branches) => ExecutionPlan::Parallel(
376                branches
377                    .into_iter()
378                    .map(|b| self.apply_cache_to_plan(b, cached, keys))
379                    .collect(),
380            ),
381            other => other,
382        }
383    }
384
385    /// Wrap nodes with Remote distribution in ExecutionPlan::Remote.
386    fn resolve_distribution(&self, plan: ExecutionPlan) -> ExecutionPlan {
387        match plan {
388            ExecutionPlan::Execute { ref node_id } => {
389                if let Some(meta) = self.registry.meta(node_id) {
390                    match &meta.distribution {
391                        somatize_core::filter::Distribution::Remote(target) => {
392                            ExecutionPlan::Remote {
393                                node_id: node_id.clone(),
394                                target: target.clone(),
395                                plan: Box::new(plan),
396                            }
397                        }
398                        _ => plan,
399                    }
400                } else {
401                    plan
402                }
403            }
404            ExecutionPlan::Sequence(steps) => ExecutionPlan::Sequence(
405                steps
406                    .into_iter()
407                    .map(|s| self.resolve_distribution(s))
408                    .collect(),
409            ),
410            ExecutionPlan::Parallel(branches) => ExecutionPlan::Parallel(
411                branches
412                    .into_iter()
413                    .map(|b| self.resolve_distribution(b))
414                    .collect(),
415            ),
416            other => other,
417        }
418    }
419
420    /// Validate schema compatibility between connected filters.
421    ///
422    /// For each edge (A → B), checks that A's output_schema is compatible
423    /// with B's input_schema. Emits warnings (not errors) for mismatches,
424    /// since schemas are optional and None means "accepts anything".
425    fn validate_schemas(&mut self, sorted: &[&str]) {
426        for &node_id in sorted {
427            let input_schema = self
428                .registry
429                .meta(node_id)
430                .and_then(|m| m.input_schema.clone());
431
432            // Skip if this node accepts anything
433            let Some(expected_input) = input_schema else {
434                continue;
435            };
436
437            // Check each predecessor's output schema
438            for pred_id in self.graph.predecessors(node_id) {
439                let pred_output = self
440                    .registry
441                    .meta(pred_id)
442                    .and_then(|m| m.output_schema.clone());
443
444                let Some(actual_output) = pred_output else {
445                    continue; // predecessor output unknown, skip
446                };
447
448                if !actual_output.is_compatible_with(&expected_input) {
449                    self.diagnostics.push(Diagnostic {
450                        node_id: node_id.to_string(),
451                        level: DiagnosticLevel::Warning,
452                        message: format!(
453                            "schema mismatch: `{pred_id}` outputs {actual_output} \
454                             but `{node_id}` expects {expected_input}",
455                        ),
456                    });
457                }
458            }
459        }
460    }
461
462    /// Check gradient flow and emit warnings for each interruption.
463    ///
464    /// Gradient flow can restart after an opaque node (differentiable nodes
465    /// after an opaque one can still propagate gradients among themselves),
466    /// but gradients from before the interruption are lost.
467    fn check_gradient_flow(&mut self, sorted: &[&str]) {
468        let mut gradient_flows = true;
469
470        for &node_id in sorted {
471            if let Some(meta) = self.registry.meta(node_id) {
472                if gradient_flows && !meta.differentiable {
473                    self.diagnostics.push(Diagnostic {
474                        node_id: node_id.to_string(),
475                        level: DiagnosticLevel::Warning,
476                        message: format!(
477                            "gradient flow interrupted at `{}` ({:?}). \
478                             Gradients from upstream will not reach downstream filters \
479                             through this node.",
480                            node_id, meta.kind,
481                        ),
482                    });
483                    gradient_flows = false;
484                } else if !gradient_flows && meta.differentiable {
485                    // Gradient flow restarts: differentiable nodes after the
486                    // interruption can propagate gradients among themselves
487                    gradient_flows = true;
488                }
489            }
490        }
491    }
492}
493
494/// Convenience function: compile a graph with default settings.
495pub fn compile(
496    graph: &Graph,
497    registry: &dyn FilterRegistry,
498    mode: CompileMode,
499    cache: Option<&dyn CacheStore>,
500) -> Result<CompileResult> {
501    Compiler::new(graph, registry, mode).compile(cache)
502}
503
504#[cfg(test)]
505mod tests {
506    use super::*;
507    use somatize_core::cache::EntryMeta;
508    use somatize_core::error::SomaError;
509    use somatize_core::filter::{FilterKind, StreamMode};
510    use somatize_core::graph::{Edge, Graph, Node, linear_pipeline};
511    use somatize_core::value::Value;
512    use std::sync::Mutex;
513
514    // ── Mock cache store ──
515
516    struct MockCacheStore {
517        entries: Mutex<HashSet<CacheKey>>,
518    }
519
520    impl MockCacheStore {
521        fn new() -> Self {
522            Self {
523                entries: Mutex::new(HashSet::new()),
524            }
525        }
526
527        fn insert(&self, key: CacheKey) {
528            self.entries.lock().unwrap().insert(key);
529        }
530    }
531
532    impl CacheStore for MockCacheStore {
533        fn get(&self, _key: &CacheKey) -> Result<Option<Value>> {
534            Ok(None)
535        }
536        fn put(&self, _key: &CacheKey, _value: &Value) -> Result<()> {
537            Ok(())
538        }
539        fn exists(&self, key: &CacheKey) -> Result<bool> {
540            Ok(self.entries.lock().unwrap().contains(key))
541        }
542        fn remove(&self, _key: &CacheKey) -> Result<()> {
543            Ok(())
544        }
545        fn metadata(&self, _key: &CacheKey) -> Result<Option<EntryMeta>> {
546            Ok(None)
547        }
548    }
549
550    // ── Helpers ──
551
552    fn make_meta(kind: FilterKind, differentiable: bool) -> FilterMeta {
553        FilterMeta {
554            name: "test".into(),
555            kind,
556            cacheable: true,
557            differentiable,
558            stream_mode: StreamMode::FixedState,
559            distribution: somatize_core::filter::Distribution::Local,
560            input_schema: None,
561            output_schema: None,
562        }
563    }
564
565    fn register_nodes(registry: &mut SimpleFilterRegistry, ids: &[&str], meta: FilterMeta) {
566        for (i, id) in ids.iter().enumerate() {
567            let hash = CacheKey::from_parts(&[id.as_bytes(), &[i as u8]]);
568            registry.register_meta(*id, meta.clone(), hash);
569        }
570    }
571
572    // ── Tests ──
573
574    #[test]
575    fn compile_empty_graph() {
576        let graph = Graph::new();
577        let registry = SimpleFilterRegistry::new();
578        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
579        assert!(matches!(result.plan, ExecutionPlan::Empty));
580    }
581
582    #[test]
583    fn compile_single_node() {
584        let mut graph = Graph::new();
585        graph.add_node(Node::new("a", "A", "F"));
586        let mut registry = SimpleFilterRegistry::new();
587        register_nodes(
588            &mut registry,
589            &["a"],
590            make_meta(FilterKind::Trainable, true),
591        );
592
593        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
594        assert!(matches!(result.plan, ExecutionPlan::Execute { .. }));
595    }
596
597    #[test]
598    fn compile_linear_pipeline_produces_sequence() {
599        let graph = linear_pipeline(vec![
600            Node::new("a", "Scaler", "F"),
601            Node::new("b", "PCA", "F"),
602            Node::new("c", "SVM", "F"),
603        ]);
604        let mut registry = SimpleFilterRegistry::new();
605        register_nodes(
606            &mut registry,
607            &["a", "b", "c"],
608            make_meta(FilterKind::Trainable, true),
609        );
610
611        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
612
613        if let ExecutionPlan::Sequence(steps) = &result.plan {
614            assert_eq!(steps.len(), 3);
615            assert!(
616                steps
617                    .iter()
618                    .all(|s| matches!(s, ExecutionPlan::Execute { .. }))
619            );
620        } else {
621            panic!("expected Sequence, got: {:?}", result.plan);
622        }
623    }
624
625    #[test]
626    fn compile_diamond_detects_parallelism() {
627        let mut graph = Graph::new();
628        graph.add_node(Node::new("root", "Root", "F"));
629        graph.add_node(Node::new("b1", "B1", "F"));
630        graph.add_node(Node::new("b2", "B2", "F"));
631        graph.add_node(Node::new("merge", "Merge", "F"));
632        graph.add_edge(Edge::data("e1", "root", "b1"));
633        graph.add_edge(Edge::data("e2", "root", "b2"));
634        graph.add_edge(Edge::data("e3", "b1", "merge"));
635        graph.add_edge(Edge::data("e4", "b2", "merge"));
636
637        let mut registry = SimpleFilterRegistry::new();
638        register_nodes(
639            &mut registry,
640            &["root", "b1", "b2", "merge"],
641            make_meta(FilterKind::Trainable, true),
642        );
643
644        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
645
646        // Should be: Sequence(Execute(root), Parallel(Execute(b1), Execute(b2)), Execute(merge))
647        if let ExecutionPlan::Sequence(steps) = &result.plan {
648            assert_eq!(steps.len(), 3);
649            assert!(matches!(&steps[0], ExecutionPlan::Execute { node_id } if node_id == "root"));
650            assert!(matches!(&steps[1], ExecutionPlan::Parallel(branches) if branches.len() == 2));
651            assert!(matches!(&steps[2], ExecutionPlan::Execute { node_id } if node_id == "merge"));
652        } else {
653            panic!("expected Sequence, got: {:?}", result.plan);
654        }
655    }
656
657    #[test]
658    fn compile_independent_roots_parallel() {
659        let mut graph = Graph::new();
660        graph.add_node(Node::new("a", "A", "F"));
661        graph.add_node(Node::new("b", "B", "F"));
662        // No edges: fully independent
663
664        let mut registry = SimpleFilterRegistry::new();
665        register_nodes(
666            &mut registry,
667            &["a", "b"],
668            make_meta(FilterKind::Trainable, true),
669        );
670
671        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
672
673        // Both at level 0 → Parallel
674        assert!(matches!(result.plan, ExecutionPlan::Parallel(_)));
675    }
676
677    #[test]
678    fn cache_resolution_replaces_cached_nodes() {
679        let graph = linear_pipeline(vec![
680            Node::new("a", "Scaler", "F"),
681            Node::new("b", "PCA", "F"),
682            Node::new("c", "SVM", "F"),
683        ]);
684
685        let mut registry = SimpleFilterRegistry::new();
686        register_nodes(
687            &mut registry,
688            &["a", "b", "c"],
689            make_meta(FilterKind::Trainable, true),
690        );
691
692        // Pre-compute the cache key for node "a" (same logic as compiler)
693        let a_config = registry.config_hash("a").unwrap();
694        let a_cache_key = CacheKey::from_parts(&[&a_config.0]);
695
696        let cache = MockCacheStore::new();
697        cache.insert(a_cache_key);
698
699        let result = compile(&graph, &registry, CompileMode::Inference, Some(&cache)).unwrap();
700
701        if let ExecutionPlan::Sequence(steps) = &result.plan {
702            assert!(
703                matches!(&steps[0], ExecutionPlan::Cached { node_id, .. } if node_id == "a"),
704                "first node should be cached, got: {:?}",
705                steps[0]
706            );
707            assert!(matches!(&steps[1], ExecutionPlan::Execute { .. }));
708            assert!(matches!(&steps[2], ExecutionPlan::Execute { .. }));
709        } else {
710            panic!("expected Sequence, got: {:?}", result.plan);
711        }
712    }
713
714    #[test]
715    fn cascade_invalidation_different_config_changes_keys() {
716        // Register with config hash "v1"
717        let mut reg1 = SimpleFilterRegistry::new();
718        reg1.register_meta(
719            "a",
720            make_meta(FilterKind::Trainable, true),
721            CacheKey::hash_data(b"scaler_v1"),
722        );
723        reg1.register_meta(
724            "b",
725            make_meta(FilterKind::Trainable, true),
726            CacheKey::hash_data(b"pca_v1"),
727        );
728
729        // Register with config hash "v2" for node "a"
730        let mut reg2 = SimpleFilterRegistry::new();
731        reg2.register_meta(
732            "a",
733            make_meta(FilterKind::Trainable, true),
734            CacheKey::hash_data(b"scaler_v2"), // changed!
735        );
736        reg2.register_meta(
737            "b",
738            make_meta(FilterKind::Trainable, true),
739            CacheKey::hash_data(b"pca_v1"), // same
740        );
741
742        // Compute keys for both configurations
743        // The plans have same structure but when cache keys are computed,
744        // changing "a" config changes "b"'s key too (cascade).
745        // We verify this by computing keys manually:
746        let a_key_v1 = CacheKey::from_parts(&[&CacheKey::hash_data(b"scaler_v1").0]);
747        let b_key_v1 = CacheKey::from_parts(&[&CacheKey::hash_data(b"pca_v1").0, &a_key_v1.0]);
748
749        let a_key_v2 = CacheKey::from_parts(&[&CacheKey::hash_data(b"scaler_v2").0]);
750        let b_key_v2 = CacheKey::from_parts(&[&CacheKey::hash_data(b"pca_v1").0, &a_key_v2.0]);
751
752        // a changed → a's key changed
753        assert_ne!(a_key_v1, a_key_v2);
754        // b's config didn't change but a's key is in b's key → b's key also changed
755        assert_ne!(b_key_v1, b_key_v2);
756    }
757
758    #[test]
759    fn no_cache_mode_skips_all_caching() {
760        let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);
761
762        let mut registry = SimpleFilterRegistry::new();
763        register_nodes(
764            &mut registry,
765            &["a", "b"],
766            make_meta(FilterKind::Trainable, true),
767        );
768
769        // Put everything in cache
770        let a_config = registry.config_hash("a").unwrap();
771        let a_key = CacheKey::from_parts(&[&a_config.0]);
772        let cache = MockCacheStore::new();
773        cache.insert(a_key);
774
775        let result = compile(&graph, &registry, CompileMode::NoCache, Some(&cache)).unwrap();
776
777        // Nothing should be cached
778        assert_eq!(result.plan.cached_count(), 0);
779    }
780
781    #[test]
782    fn differentiable_mode_skips_output_caching() {
783        let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);
784
785        let mut registry = SimpleFilterRegistry::new();
786        register_nodes(
787            &mut registry,
788            &["a", "b"],
789            make_meta(FilterKind::Trainable, true),
790        );
791
792        let a_config = registry.config_hash("a").unwrap();
793        let a_key = CacheKey::from_parts(&[&a_config.0]);
794        let cache = MockCacheStore::new();
795        cache.insert(a_key);
796
797        let result = compile(&graph, &registry, CompileMode::Differentiable, Some(&cache)).unwrap();
798
799        // Differentiable mode should not cache forward outputs
800        assert_eq!(result.plan.cached_count(), 0);
801    }
802
803    #[test]
804    fn gradient_flow_diagnostic_on_opaque() {
805        let graph = linear_pipeline(vec![
806            Node::new("scaler", "Scaler", "F"),
807            Node::new("tree", "DecisionTree", "F"),
808            Node::new("linear", "Linear", "F"),
809        ]);
810
811        let mut registry = SimpleFilterRegistry::new();
812        registry.register_meta(
813            "scaler",
814            make_meta(FilterKind::Trainable, true),
815            CacheKey::hash_data(b"s"),
816        );
817        registry.register_meta(
818            "tree",
819            make_meta(FilterKind::Opaque, false), // not differentiable
820            CacheKey::hash_data(b"t"),
821        );
822        registry.register_meta(
823            "linear",
824            make_meta(FilterKind::Trainable, true),
825            CacheKey::hash_data(b"l"),
826        );
827
828        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
829
830        assert_eq!(result.diagnostics.len(), 1);
831        assert_eq!(result.diagnostics[0].node_id, "tree");
832        assert_eq!(result.diagnostics[0].level, DiagnosticLevel::Warning);
833        assert!(
834            result.diagnostics[0]
835                .message
836                .contains("gradient flow interrupted")
837        );
838    }
839
840    #[test]
841    fn no_diagnostic_when_all_differentiable() {
842        let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);
843
844        let mut registry = SimpleFilterRegistry::new();
845        register_nodes(
846            &mut registry,
847            &["a", "b"],
848            make_meta(FilterKind::Trainable, true),
849        );
850
851        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
852        assert!(result.diagnostics.is_empty());
853    }
854
855    #[test]
856    fn compile_rejects_cycle() {
857        let mut graph = Graph::new();
858        graph.add_node(Node::new("a", "A", "F"));
859        graph.add_node(Node::new("b", "B", "F"));
860        graph.add_edge(Edge::data("e1", "a", "b"));
861        graph.add_edge(Edge::data("e2", "b", "a"));
862
863        let registry = SimpleFilterRegistry::new();
864        let result = compile(&graph, &registry, CompileMode::Inference, None);
865        assert!(matches!(result, Err(SomaError::CycleDetected)));
866    }
867
868    #[test]
869    fn plan_summary_is_accurate() {
870        let mut graph = Graph::new();
871        graph.add_node(Node::new("root", "Root", "F"));
872        graph.add_node(Node::new("b1", "B1", "F"));
873        graph.add_node(Node::new("b2", "B2", "F"));
874        graph.add_node(Node::new("end", "End", "F"));
875        graph.add_edge(Edge::data("e1", "root", "b1"));
876        graph.add_edge(Edge::data("e2", "root", "b2"));
877        graph.add_edge(Edge::data("e3", "b1", "end"));
878        graph.add_edge(Edge::data("e4", "b2", "end"));
879
880        let mut registry = SimpleFilterRegistry::new();
881        register_nodes(
882            &mut registry,
883            &["root", "b1", "b2", "end"],
884            make_meta(FilterKind::Trainable, true),
885        );
886
887        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
888        let summary = result.plan.summary();
889        assert_eq!(summary.total_nodes, 4);
890        assert_eq!(summary.parallel_branches, 2);
891    }
892
893    #[test]
894    fn distribution_wraps_remote_nodes() {
895        let graph = linear_pipeline(vec![
896            Node::new("preprocess", "Preprocess", "F"),
897            Node::new("gpu_train", "GpuTrain", "F"),
898            Node::new("evaluate", "Evaluate", "F"),
899        ]);
900
901        let mut registry = SimpleFilterRegistry::new();
902        // preprocess: local
903        registry.register_meta(
904            "preprocess",
905            make_meta(FilterKind::Trainable, true),
906            CacheKey::hash_data(b"pre"),
907        );
908        // gpu_train: remote on GPU tag
909        let mut gpu_meta = make_meta(FilterKind::Trainable, true);
910        gpu_meta.distribution = somatize_core::filter::Distribution::Remote(
911            somatize_core::filter::RemoteTarget::Tag("gpu".into()),
912        );
913        registry.register_meta("gpu_train", gpu_meta, CacheKey::hash_data(b"gpu"));
914        // evaluate: local
915        registry.register_meta(
916            "evaluate",
917            make_meta(FilterKind::Trainable, true),
918            CacheKey::hash_data(b"eval"),
919        );
920
921        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
922
923        // Should be: Sequence(Execute(preprocess), Remote(gpu_train, ...), Execute(evaluate))
924        if let ExecutionPlan::Sequence(steps) = &result.plan {
925            assert_eq!(steps.len(), 3);
926            assert!(
927                matches!(&steps[0], ExecutionPlan::Execute { node_id } if node_id == "preprocess")
928            );
929            assert!(
930                matches!(&steps[1], ExecutionPlan::Remote { node_id, target, .. }
931                    if node_id == "gpu_train"
932                    && *target == somatize_core::filter::RemoteTarget::Tag("gpu".into())
933                ),
934                "expected Remote, got: {:?}",
935                steps[1]
936            );
937            assert!(
938                matches!(&steps[2], ExecutionPlan::Execute { node_id } if node_id == "evaluate")
939            );
940        } else {
941            panic!("expected Sequence, got: {:?}", result.plan);
942        }
943    }
944
945    #[test]
946    fn local_distribution_not_wrapped() {
947        let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);
948
949        let mut registry = SimpleFilterRegistry::new();
950        register_nodes(
951            &mut registry,
952            &["a", "b"],
953            make_meta(FilterKind::Trainable, true),
954        );
955
956        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
957
958        // No Remote nodes
959        let ids = result.plan.node_ids();
960        assert_eq!(ids.len(), 2);
961        // Should all be Execute, no Remote wrapper
962        if let ExecutionPlan::Sequence(steps) = &result.plan {
963            assert!(
964                steps
965                    .iter()
966                    .all(|s| matches!(s, ExecutionPlan::Execute { .. }))
967            );
968        }
969    }
970}