1use crate::plan::ExecutionPlan;
7use somatize_core::cache::{CacheKey, CacheStore};
8use somatize_core::error::Result;
9use somatize_core::filter::{Filter, FilterMeta};
10use somatize_core::graph::{Graph, NodeId};
11use std::collections::{HashMap, HashSet};
12
/// Selects which optimisation passes the compiler applies.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompileMode {
    /// Regular execution: cacheable nodes whose outputs already exist in the
    /// supplied cache store are replaced by cache reads.
    Inference,
    /// No node is ever replaced by a cached result.
    Differentiable,
    /// Cache resolution is skipped entirely.
    NoCache,
}
23
24#[derive(Debug, Clone)]
26pub struct Diagnostic {
27 pub node_id: NodeId,
28 pub level: DiagnosticLevel,
29 pub message: String,
30}
31
/// Severity of a [`Diagnostic`]. Neither level aborts compilation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DiagnosticLevel {
    /// Something is likely wrong with the graph (e.g. a schema mismatch).
    Warning,
    /// Purely informational note.
    Info,
}
37
38pub struct CompileResult {
40 pub plan: ExecutionPlan,
41 pub diagnostics: Vec<Diagnostic>,
42}
43
44pub trait FilterRegistry: Send + Sync {
48 fn meta(&self, node_id: &str) -> Option<FilterMeta>;
49 fn config_hash(&self, node_id: &str) -> Option<CacheKey>;
50}
51
52pub struct SimpleFilterRegistry {
54 entries: HashMap<String, (FilterMeta, CacheKey)>,
55}
56
57impl SimpleFilterRegistry {
58 pub fn new() -> Self {
59 Self {
60 entries: HashMap::new(),
61 }
62 }
63
64 pub fn register(&mut self, node_id: impl Into<String>, filter: &dyn Filter) {
65 let id = node_id.into();
66 self.entries
67 .insert(id, (filter.meta(), filter.config_hash()));
68 }
69
70 pub fn register_meta(
71 &mut self,
72 node_id: impl Into<String>,
73 meta: FilterMeta,
74 config_hash: CacheKey,
75 ) {
76 self.entries.insert(node_id.into(), (meta, config_hash));
77 }
78}
79
80impl Default for SimpleFilterRegistry {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
86impl FilterRegistry for SimpleFilterRegistry {
87 fn meta(&self, node_id: &str) -> Option<FilterMeta> {
88 self.entries.get(node_id).map(|(m, _)| m.clone())
89 }
90
91 fn config_hash(&self, node_id: &str) -> Option<CacheKey> {
92 self.entries.get(node_id).map(|(_, h)| h.clone())
93 }
94}
95
96pub struct Compiler<'a> {
98 graph: &'a Graph,
99 registry: &'a dyn FilterRegistry,
100 mode: CompileMode,
101 diagnostics: Vec<Diagnostic>,
102}
103
104impl<'a> Compiler<'a> {
105 pub fn new(graph: &'a Graph, registry: &'a dyn FilterRegistry, mode: CompileMode) -> Self {
106 Self {
107 graph,
108 registry,
109 mode,
110 diagnostics: Vec::new(),
111 }
112 }
113
114 pub fn compile(mut self, cache: Option<&dyn CacheStore>) -> Result<CompileResult> {
116 self.graph.validate()?;
117
118 let sorted = self.graph.topological_sort()?;
119
120 if sorted.is_empty() {
121 return Ok(CompileResult {
122 plan: ExecutionPlan::Empty,
123 diagnostics: self.diagnostics,
124 });
125 }
126
127 self.check_gradient_flow(&sorted);
129
130 self.validate_schemas(&sorted);
132
133 let plan = self.build_plan(&sorted);
135
136 let plan = if let Some(cache) = cache {
138 self.resolve_cache(plan, cache, &sorted)?
139 } else {
140 plan
141 };
142
143 let plan = self.resolve_distribution(plan);
145
146 let plan = self.collapse_differentiable(plan);
148
149 let plan = plan.simplify();
150
151 Ok(CompileResult {
152 plan,
153 diagnostics: self.diagnostics,
154 })
155 }
156
157 fn build_plan(&self, sorted: &[&str]) -> ExecutionPlan {
159 let levels = self.compute_levels(sorted);
161
162 let mut plan_steps: Vec<ExecutionPlan> = Vec::new();
163
164 for level in &levels {
165 if level.len() == 1 {
166 plan_steps.push(self.plan_for_node(level[0]));
167 } else {
168 let branches: Vec<ExecutionPlan> =
169 level.iter().map(|id| self.plan_for_node(id)).collect();
170 plan_steps.push(ExecutionPlan::Parallel(branches));
171 }
172 }
173
174 if plan_steps.len() == 1 {
175 plan_steps.into_iter().next().unwrap()
176 } else {
177 ExecutionPlan::Sequence(plan_steps)
178 }
179 }
180
181 fn plan_for_node(&self, node_id: &str) -> ExecutionPlan {
183 use somatize_core::graph::NodeKind;
184
185 let node = match self.graph.node(node_id) {
186 Some(n) => n,
187 None => {
188 return ExecutionPlan::Execute {
189 node_id: node_id.to_string(),
190 };
191 }
192 };
193
194 match &node.kind {
195 NodeKind::Filter { .. } => ExecutionPlan::Execute {
196 node_id: node_id.to_string(),
197 },
198
199 NodeKind::SubGraph { graph } => {
200 let inner_compiler = Compiler::new(graph, self.registry, self.mode);
202 match inner_compiler.compile(None) {
203 Ok(result) => result.plan,
204 Err(_) => ExecutionPlan::Execute {
205 node_id: node_id.to_string(),
206 },
207 }
208 }
209
210 NodeKind::Loop { max_iterations } => {
211 let successors = self.graph.successors(node_id);
214 let body = if successors.len() == 1 {
215 self.plan_for_node(successors[0])
216 } else if successors.len() > 1 {
217 let branches: Vec<ExecutionPlan> =
218 successors.iter().map(|id| self.plan_for_node(id)).collect();
219 ExecutionPlan::Parallel(branches)
220 } else {
221 ExecutionPlan::Empty
222 };
223 ExecutionPlan::Loop {
224 node_id: node_id.to_string(),
225 body: Box::new(body),
226 max_iterations: *max_iterations,
227 }
228 }
229
230 NodeKind::Branch => {
231 let arms: Vec<(String, ExecutionPlan)> = self
233 .graph
234 .edges
235 .iter()
236 .filter(|e| e.source == node_id)
237 .map(|e| {
238 let label = e.label.clone().unwrap_or_else(|| e.target.clone());
239 let plan = self.plan_for_node(&e.target);
240 (label, plan)
241 })
242 .collect();
243 ExecutionPlan::Branch {
244 node_id: node_id.to_string(),
245 arms,
246 }
247 }
248
249 _ => ExecutionPlan::Execute {
250 node_id: node_id.to_string(),
251 },
252 }
253 }
254
255 fn compute_levels<'b>(&self, sorted: &[&'b str]) -> Vec<Vec<&'b str>> {
258 let mut node_level: HashMap<&str, usize> = HashMap::new();
259 let mut max_level: usize = 0;
260
261 for &node in sorted {
262 let preds = self.graph.predecessors(node);
263 let level = if preds.is_empty() {
264 0
265 } else {
266 preds
267 .iter()
268 .map(|p| node_level.get(p).copied().unwrap_or(0) + 1)
269 .max()
270 .unwrap_or(0)
271 };
272 node_level.insert(node, level);
273 if level > max_level {
274 max_level = level;
275 }
276 }
277
278 let mut levels: Vec<Vec<&str>> = vec![Vec::new(); max_level + 1];
279 for &node in sorted {
280 let level = node_level[node];
281 levels[level].push(node);
282 }
283
284 levels.retain(|l| !l.is_empty());
286 levels
287 }
288
289 fn resolve_cache(
292 &self,
293 plan: ExecutionPlan,
294 cache: &dyn CacheStore,
295 sorted: &[&str],
296 ) -> Result<ExecutionPlan> {
297 if self.mode == CompileMode::NoCache {
298 return Ok(plan);
299 }
300
301 let mut node_keys: HashMap<String, CacheKey> = HashMap::new();
304 let mut cached_nodes: HashSet<String> = HashSet::new();
305
306 for &node_id in sorted {
307 let config_hash = match self.registry.config_hash(node_id) {
308 Some(h) => h,
309 None => continue, };
311
312 let meta = self.registry.meta(node_id);
313 let cacheable = meta.as_ref().is_some_and(|m| m.cacheable);
314
315 let can_cache = cacheable && self.mode == CompileMode::Inference;
318
319 let pred_ids = self.graph.predecessors(node_id);
321 let mut key_parts: Vec<Vec<u8>> = vec![config_hash.0.to_vec()];
322 for pred in &pred_ids {
323 if let Some(pred_key) = node_keys.get(*pred) {
324 key_parts.push(pred_key.0.to_vec());
325 } else {
326 debug_assert!(
329 false,
330 "predecessor `{pred}` of `{node_id}` not in node_keys - \
331 topological order may be broken"
332 );
333 }
334 }
335 let parts_refs: Vec<&[u8]> = key_parts.iter().map(|p| p.as_slice()).collect();
336 let key = CacheKey::from_parts(&parts_refs);
337 node_keys.insert(node_id.to_string(), key.clone());
338
339 if can_cache {
341 if cache.exists(&key)? {
345 cached_nodes.insert(node_id.to_string());
346 }
347 }
348 }
349
350 Ok(self.apply_cache_to_plan(plan, &cached_nodes, &node_keys))
352 }
353
354 fn apply_cache_to_plan(
355 &self,
356 plan: ExecutionPlan,
357 cached: &HashSet<String>,
358 keys: &HashMap<String, CacheKey>,
359 ) -> ExecutionPlan {
360 match plan {
361 ExecutionPlan::Execute { ref node_id } => {
362 if cached.contains(node_id)
363 && let Some(key) = keys.get(node_id)
364 {
365 return ExecutionPlan::Cached {
366 node_id: node_id.clone(),
367 key: key.clone(),
368 };
369 }
370 plan
371 }
372 ExecutionPlan::Sequence(steps) => ExecutionPlan::Sequence(
373 steps
374 .into_iter()
375 .map(|s| self.apply_cache_to_plan(s, cached, keys))
376 .collect(),
377 ),
378 ExecutionPlan::Parallel(branches) => ExecutionPlan::Parallel(
379 branches
380 .into_iter()
381 .map(|b| self.apply_cache_to_plan(b, cached, keys))
382 .collect(),
383 ),
384 other => other,
385 }
386 }
387
388 fn resolve_distribution(&self, plan: ExecutionPlan) -> ExecutionPlan {
390 match plan {
391 ExecutionPlan::Execute { ref node_id } => {
392 if let Some(meta) = self.registry.meta(node_id) {
393 match &meta.distribution {
394 somatize_core::filter::Distribution::Remote(target) => {
395 ExecutionPlan::Remote {
396 node_id: node_id.clone(),
397 target: target.clone(),
398 plan: Box::new(plan),
399 }
400 }
401 _ => plan,
402 }
403 } else {
404 plan
405 }
406 }
407 ExecutionPlan::Sequence(steps) => ExecutionPlan::Sequence(
408 steps
409 .into_iter()
410 .map(|s| self.resolve_distribution(s))
411 .collect(),
412 ),
413 ExecutionPlan::Parallel(branches) => ExecutionPlan::Parallel(
414 branches
415 .into_iter()
416 .map(|b| self.resolve_distribution(b))
417 .collect(),
418 ),
419 ExecutionPlan::Composite { ref node_ids } => {
420 let targets: Vec<_> = node_ids
424 .iter()
425 .filter_map(|nid| {
426 self.registry.meta(nid).and_then(|m| match &m.distribution {
427 somatize_core::filter::Distribution::Remote(t) => Some(t.clone()),
428 _ => None,
429 })
430 })
431 .collect();
432
433 if targets.len() == node_ids.len() && !targets.is_empty() {
434 let first_id = node_ids[0].clone();
435 ExecutionPlan::Remote {
436 node_id: first_id,
437 target: targets.into_iter().next().unwrap(),
438 plan: Box::new(plan),
439 }
440 } else {
441 plan
442 }
443 }
444 other => other,
445 }
446 }
447
448 fn collapse_differentiable(&self, plan: ExecutionPlan) -> ExecutionPlan {
453 match plan {
454 ExecutionPlan::Sequence(steps) => {
455 let mut result: Vec<ExecutionPlan> = Vec::new();
456 let mut diff_group: Vec<String> = Vec::new();
457
458 for step in steps {
459 if let ExecutionPlan::Execute { ref node_id } = step
460 && self
461 .registry
462 .meta(node_id)
463 .map(|m| m.differentiable)
464 .unwrap_or(false)
465 {
466 diff_group.push(node_id.clone());
467 continue;
468 }
469 Self::flush_diff_group(&mut diff_group, &mut result);
471 result.push(self.collapse_differentiable(step));
472 }
473 Self::flush_diff_group(&mut diff_group, &mut result);
474
475 if result.len() == 1 {
476 result.pop().unwrap()
477 } else {
478 ExecutionPlan::Sequence(result)
479 }
480 }
481 ExecutionPlan::Parallel(branches) => ExecutionPlan::Parallel(
482 branches
483 .into_iter()
484 .map(|b| self.collapse_differentiable(b))
485 .collect(),
486 ),
487 ExecutionPlan::Remote {
488 node_id,
489 target,
490 plan,
491 } => ExecutionPlan::Remote {
492 node_id,
493 target,
494 plan: Box::new(self.collapse_differentiable(*plan)),
495 },
496 other => other,
497 }
498 }
499
500 fn flush_diff_group(group: &mut Vec<String>, result: &mut Vec<ExecutionPlan>) {
501 if group.len() > 1 {
502 result.push(ExecutionPlan::Composite {
503 node_ids: std::mem::take(group),
504 });
505 } else if let Some(id) = group.pop() {
506 result.push(ExecutionPlan::Execute { node_id: id });
507 }
508 }
509
510 fn validate_schemas(&mut self, sorted: &[&str]) {
516 for &node_id in sorted {
517 let input_schema = self
518 .registry
519 .meta(node_id)
520 .and_then(|m| m.input_schema.clone());
521
522 let Some(expected_input) = input_schema else {
524 continue;
525 };
526
527 for pred_id in self.graph.predecessors(node_id) {
529 let pred_output = self
530 .registry
531 .meta(pred_id)
532 .and_then(|m| m.output_schema.clone());
533
534 let Some(actual_output) = pred_output else {
535 continue; };
537
538 if !actual_output.is_compatible_with(&expected_input) {
539 self.diagnostics.push(Diagnostic {
540 node_id: node_id.to_string(),
541 level: DiagnosticLevel::Warning,
542 message: format!(
543 "schema mismatch: `{pred_id}` outputs {actual_output} \
544 but `{node_id}` expects {expected_input}",
545 ),
546 });
547 }
548 }
549 }
550 }
551
552 fn check_gradient_flow(&mut self, sorted: &[&str]) {
558 let mut gradient_flows = true;
559
560 for &node_id in sorted {
561 if let Some(meta) = self.registry.meta(node_id) {
562 if gradient_flows && !meta.differentiable {
563 self.diagnostics.push(Diagnostic {
564 node_id: node_id.to_string(),
565 level: DiagnosticLevel::Warning,
566 message: format!(
567 "gradient flow interrupted at `{}` ({:?}). \
568 Gradients from upstream will not reach downstream filters \
569 through this node.",
570 node_id, meta.kind,
571 ),
572 });
573 gradient_flows = false;
574 } else if !gradient_flows && meta.differentiable {
575 gradient_flows = true;
578 }
579 }
580 }
581 }
582}
583
584pub fn compile(
586 graph: &Graph,
587 registry: &dyn FilterRegistry,
588 mode: CompileMode,
589 cache: Option<&dyn CacheStore>,
590) -> Result<CompileResult> {
591 Compiler::new(graph, registry, mode).compile(cache)
592}
593
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::cache::EntryMeta;
    use somatize_core::error::SomaError;
    use somatize_core::filter::{FilterKind, StreamMode};
    use somatize_core::graph::{Edge, Graph, Node, linear_pipeline};
    use somatize_core::value::Value;
    use std::sync::Mutex;

    /// Cache stub that only tracks which keys "exist".
    ///
    /// `get` always misses, and `put`/`remove` are no-ops.
    struct MockCacheStore {
        entries: Mutex<HashSet<CacheKey>>,
    }

    impl MockCacheStore {
        fn new() -> Self {
            Self {
                entries: Mutex::new(HashSet::new()),
            }
        }

        fn insert(&self, key: CacheKey) {
            self.entries.lock().unwrap().insert(key);
        }
    }

    impl CacheStore for MockCacheStore {
        fn get(&self, _key: &CacheKey) -> Result<Option<Value>> {
            Ok(None)
        }
        fn put(&self, _key: &CacheKey, _value: &Value) -> Result<()> {
            Ok(())
        }
        fn exists(&self, key: &CacheKey) -> Result<bool> {
            Ok(self.entries.lock().unwrap().contains(key))
        }
        fn remove(&self, _key: &CacheKey) -> Result<()> {
            Ok(())
        }
        fn metadata(&self, _key: &CacheKey) -> Result<Option<EntryMeta>> {
            Ok(None)
        }
    }

    /// Cacheable, locally distributed metadata with no schemas.
    fn make_meta(kind: FilterKind, differentiable: bool) -> FilterMeta {
        FilterMeta {
            name: "test".into(),
            kind,
            cacheable: true,
            differentiable,
            stream_mode: StreamMode::FixedState,
            distribution: somatize_core::filter::Distribution::Local,
            input_schema: None,
            output_schema: None,
        }
    }

    /// Registers every id with `meta` and a config hash derived from the id
    /// and its position.
    fn register_nodes(registry: &mut SimpleFilterRegistry, ids: &[&str], meta: FilterMeta) {
        for (i, id) in ids.iter().enumerate() {
            let hash = CacheKey::from_parts(&[id.as_bytes(), &[i as u8]]);
            registry.register_meta(*id, meta.clone(), hash);
        }
    }

    #[test]
    fn compile_empty_graph() {
        let graph = Graph::new();
        let registry = SimpleFilterRegistry::new();
        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
        assert!(matches!(result.plan, ExecutionPlan::Empty));
    }

    #[test]
    fn compile_single_node() {
        let mut graph = Graph::new();
        graph.add_node(Node::new("a", "A", "F"));
        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a"],
            make_meta(FilterKind::Trainable, true),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
        assert!(matches!(result.plan, ExecutionPlan::Execute { .. }));
    }

    #[test]
    fn compile_linear_pipeline_produces_sequence() {
        let graph = linear_pipeline(vec![
            Node::new("a", "Scaler", "F"),
            Node::new("b", "PCA", "F"),
            Node::new("c", "SVM", "F"),
        ]);
        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a", "b", "c"],
            make_meta(FilterKind::Trainable, true),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();

        // All three are differentiable, so the sequence is fused into one Composite.
        if let ExecutionPlan::Composite { node_ids } = &result.plan {
            assert_eq!(node_ids, &["a", "b", "c"]);
        } else {
            panic!("expected Composite, got: {:?}", result.plan);
        }
    }

    #[test]
    fn compile_diamond_detects_parallelism() {
        let mut graph = Graph::new();
        graph.add_node(Node::new("root", "Root", "F"));
        graph.add_node(Node::new("b1", "B1", "F"));
        graph.add_node(Node::new("b2", "B2", "F"));
        graph.add_node(Node::new("merge", "Merge", "F"));
        graph.add_edge(Edge::data("e1", "root", "b1"));
        graph.add_edge(Edge::data("e2", "root", "b2"));
        graph.add_edge(Edge::data("e3", "b1", "merge"));
        graph.add_edge(Edge::data("e4", "b2", "merge"));

        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["root", "b1", "b2", "merge"],
            make_meta(FilterKind::Trainable, true),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();

        if let ExecutionPlan::Sequence(steps) = &result.plan {
            assert_eq!(steps.len(), 3);
            assert!(matches!(&steps[0], ExecutionPlan::Execute { node_id } if node_id == "root"));
            assert!(matches!(&steps[1], ExecutionPlan::Parallel(branches) if branches.len() == 2));
            assert!(matches!(&steps[2], ExecutionPlan::Execute { node_id } if node_id == "merge"));
        } else {
            panic!("expected Sequence, got: {:?}", result.plan);
        }
    }

    #[test]
    fn compile_independent_roots_parallel() {
        let mut graph = Graph::new();
        graph.add_node(Node::new("a", "A", "F"));
        graph.add_node(Node::new("b", "B", "F"));
        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a", "b"],
            make_meta(FilterKind::Trainable, true),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();

        assert!(matches!(result.plan, ExecutionPlan::Parallel(_)));
    }

    #[test]
    fn cache_resolution_replaces_cached_nodes() {
        let graph = linear_pipeline(vec![
            Node::new("a", "Scaler", "F"),
            Node::new("b", "PCA", "F"),
            Node::new("c", "SVM", "F"),
        ]);

        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a", "b", "c"],
            make_meta(FilterKind::Trainable, true),
        );

        // A root node's key is derived from its config hash alone.
        let a_config = registry.config_hash("a").unwrap();
        let a_cache_key = CacheKey::from_parts(&[&a_config.0]);

        let cache = MockCacheStore::new();
        cache.insert(a_cache_key);

        let result = compile(&graph, &registry, CompileMode::Inference, Some(&cache)).unwrap();

        if let ExecutionPlan::Sequence(steps) = &result.plan {
            assert!(
                matches!(&steps[0], ExecutionPlan::Cached { node_id, .. } if node_id == "a"),
                "first node should be cached, got: {:?}",
                steps[0]
            );
            assert!(
                matches!(&steps[1], ExecutionPlan::Composite { node_ids } if node_ids == &["b", "c"]),
                "b+c should be Composite, got: {:?}",
                steps[1]
            );
        } else {
            panic!("expected Sequence, got: {:?}", result.plan);
        }
    }

    #[test]
    fn cascade_invalidation_different_config_changes_keys() {
        let graph = linear_pipeline(vec![
            Node::new("a", "Scaler", "F"),
            Node::new("b", "PCA", "F"),
        ]);

        let mut reg1 = SimpleFilterRegistry::new();
        reg1.register_meta(
            "a",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"scaler_v1"),
        );
        reg1.register_meta(
            "b",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"pca_v1"),
        );

        // Only the upstream config changes; `b`'s own config is identical.
        let mut reg2 = SimpleFilterRegistry::new();
        reg2.register_meta(
            "a",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"scaler_v2"),
        );
        reg2.register_meta(
            "b",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"pca_v1"),
        );

        let a_key_v1 = CacheKey::from_parts(&[&CacheKey::hash_data(b"scaler_v1").0]);
        let b_key_v1 = CacheKey::from_parts(&[&CacheKey::hash_data(b"pca_v1").0, &a_key_v1.0]);

        let a_key_v2 = CacheKey::from_parts(&[&CacheKey::hash_data(b"scaler_v2").0]);
        let b_key_v2 = CacheKey::from_parts(&[&CacheKey::hash_data(b"pca_v1").0, &a_key_v2.0]);

        assert_ne!(a_key_v1, a_key_v2);
        // `b` changes too, because its key includes its predecessor's key.
        assert_ne!(b_key_v1, b_key_v2);

        // Only the v1 outputs are cached.
        let cache = MockCacheStore::new();
        cache.insert(a_key_v1);
        cache.insert(b_key_v1);

        let v1 = compile(&graph, &reg1, CompileMode::Inference, Some(&cache)).unwrap();
        assert_eq!(v1.plan.cached_count(), 2);

        let v2 = compile(&graph, &reg2, CompileMode::Inference, Some(&cache)).unwrap();
        assert_eq!(v2.plan.cached_count(), 0);
    }

    #[test]
    fn cache_skips_nodes_downstream_of_unregistered_node() {
        // `a` is unregistered, so `b`'s inputs cannot be fingerprinted. `b` must
        // not be served from the cache, even though a key built from its config
        // hash alone is present.
        let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);

        let mut registry = SimpleFilterRegistry::new();
        registry.register_meta(
            "b",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"b"),
        );

        let b_config = registry.config_hash("b").unwrap();
        let cache = MockCacheStore::new();
        cache.insert(CacheKey::from_parts(&[&b_config.0]));

        let result = compile(&graph, &registry, CompileMode::Inference, Some(&cache)).unwrap();
        assert_eq!(result.plan.cached_count(), 0);
    }

    #[test]
    fn no_cache_mode_skips_all_caching() {
        let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);

        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a", "b"],
            make_meta(FilterKind::Trainable, true),
        );

        let a_config = registry.config_hash("a").unwrap();
        let a_key = CacheKey::from_parts(&[&a_config.0]);
        let cache = MockCacheStore::new();
        cache.insert(a_key);

        let result = compile(&graph, &registry, CompileMode::NoCache, Some(&cache)).unwrap();

        assert_eq!(result.plan.cached_count(), 0);
    }

    #[test]
    fn differentiable_mode_skips_output_caching() {
        let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);

        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a", "b"],
            make_meta(FilterKind::Trainable, true),
        );

        let a_config = registry.config_hash("a").unwrap();
        let a_key = CacheKey::from_parts(&[&a_config.0]);
        let cache = MockCacheStore::new();
        cache.insert(a_key);

        let result =
            compile(&graph, &registry, CompileMode::Differentiable, Some(&cache)).unwrap();

        assert_eq!(result.plan.cached_count(), 0);
    }

    #[test]
    fn gradient_flow_diagnostic_on_opaque() {
        let graph = linear_pipeline(vec![
            Node::new("scaler", "Scaler", "F"),
            Node::new("tree", "DecisionTree", "F"),
            Node::new("linear", "Linear", "F"),
        ]);

        let mut registry = SimpleFilterRegistry::new();
        registry.register_meta(
            "scaler",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"s"),
        );
        registry.register_meta(
            "tree",
            make_meta(FilterKind::Opaque, false),
            CacheKey::hash_data(b"t"),
        );
        registry.register_meta(
            "linear",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"l"),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();

        assert_eq!(result.diagnostics.len(), 1);
        assert_eq!(result.diagnostics[0].node_id, "tree");
        assert_eq!(result.diagnostics[0].level, DiagnosticLevel::Warning);
        assert!(
            result.diagnostics[0]
                .message
                .contains("gradient flow interrupted")
        );
    }

    #[test]
    fn no_diagnostic_when_all_differentiable() {
        let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);

        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a", "b"],
            make_meta(FilterKind::Trainable, true),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
        assert!(result.diagnostics.is_empty());
    }

    #[test]
    fn compile_rejects_cycle() {
        let mut graph = Graph::new();
        graph.add_node(Node::new("a", "A", "F"));
        graph.add_node(Node::new("b", "B", "F"));
        graph.add_edge(Edge::data("e1", "a", "b"));
        graph.add_edge(Edge::data("e2", "b", "a"));

        let registry = SimpleFilterRegistry::new();
        let result = compile(&graph, &registry, CompileMode::Inference, None);
        assert!(matches!(result, Err(SomaError::CycleDetected)));
    }

    #[test]
    fn plan_summary_is_accurate() {
        let mut graph = Graph::new();
        graph.add_node(Node::new("root", "Root", "F"));
        graph.add_node(Node::new("b1", "B1", "F"));
        graph.add_node(Node::new("b2", "B2", "F"));
        graph.add_node(Node::new("end", "End", "F"));
        graph.add_edge(Edge::data("e1", "root", "b1"));
        graph.add_edge(Edge::data("e2", "root", "b2"));
        graph.add_edge(Edge::data("e3", "b1", "end"));
        graph.add_edge(Edge::data("e4", "b2", "end"));

        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["root", "b1", "b2", "end"],
            make_meta(FilterKind::Trainable, true),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();
        let summary = result.plan.summary();
        assert_eq!(summary.total_nodes, 4);
        assert_eq!(summary.parallel_branches, 2);
    }

    #[test]
    fn distribution_wraps_remote_nodes() {
        let graph = linear_pipeline(vec![
            Node::new("preprocess", "Preprocess", "F"),
            Node::new("gpu_train", "GpuTrain", "F"),
            Node::new("evaluate", "Evaluate", "F"),
        ]);

        let mut registry = SimpleFilterRegistry::new();
        registry.register_meta(
            "preprocess",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"pre"),
        );
        let mut gpu_meta = make_meta(FilterKind::Trainable, true);
        gpu_meta.distribution = somatize_core::filter::Distribution::Remote(
            somatize_core::filter::RemoteTarget::Tag("gpu".into()),
        );
        registry.register_meta("gpu_train", gpu_meta, CacheKey::hash_data(b"gpu"));
        registry.register_meta(
            "evaluate",
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(b"eval"),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();

        if let ExecutionPlan::Sequence(steps) = &result.plan {
            assert_eq!(steps.len(), 3);
            assert!(
                matches!(&steps[0], ExecutionPlan::Execute { node_id } if node_id == "preprocess")
            );
            assert!(
                matches!(&steps[1], ExecutionPlan::Remote { node_id, target, .. }
                    if node_id == "gpu_train"
                        && *target == somatize_core::filter::RemoteTarget::Tag("gpu".into())
                ),
                "expected Remote, got: {:?}",
                steps[1]
            );
            assert!(
                matches!(&steps[2], ExecutionPlan::Execute { node_id } if node_id == "evaluate")
            );
        } else {
            panic!("expected Sequence, got: {:?}", result.plan);
        }
    }

    #[test]
    fn local_distribution_not_wrapped() {
        let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);

        // Non-differentiable, so the steps are not fused into a Composite and
        // stay as plain `Execute` steps that can be inspected.
        let mut registry = SimpleFilterRegistry::new();
        register_nodes(
            &mut registry,
            &["a", "b"],
            make_meta(FilterKind::Opaque, false),
        );

        let result = compile(&graph, &registry, CompileMode::Inference, None).unwrap();

        assert_eq!(result.plan.node_ids().len(), 2);
        match &result.plan {
            ExecutionPlan::Sequence(steps) => assert!(
                steps
                    .iter()
                    .all(|s| matches!(s, ExecutionPlan::Execute { .. })),
                "local nodes must not be wrapped, got: {steps:?}"
            ),
            other => panic!("expected Sequence, got: {other:?}"),
        }
    }
}