use somatize_compiler::{CompileMode, SimpleFilterRegistry, compile};
use somatize_core::cache::{CacheKey, CacheStore, EntryMeta};
use somatize_core::error::Result;
use somatize_core::filter::{FilterKind, FilterMeta, StreamMode};
use somatize_core::graph::{Edge, Graph, Node, linear_pipeline};
use somatize_core::schema::{DataType, Schema};
use somatize_core::value::Value;
use std::collections::HashSet;
use std::sync::Mutex;
/// Builds `FilterMeta` named `"test"` with the given `kind` and `differentiable` flag.
///
/// All remaining fields are fixed: cacheable, `StreamMode::FixedState`,
/// `Distribution::Local`, and no declared input/output schema.
fn make_meta(kind: FilterKind, differentiable: bool) -> FilterMeta {
    let distribution = somatize_core::filter::Distribution::Local;
    let stream_mode = StreamMode::FixedState;
    FilterMeta {
        name: "test".into(),
        kind,
        differentiable,
        cacheable: true,
        stream_mode,
        distribution,
        input_schema: None,
        output_schema: None,
    }
}
/// Test double for a cache backend that records only which keys are present.
struct MockCache {
    // Guarded by a mutex because the cache methods take `&self`.
    entries: Mutex<HashSet<CacheKey>>,
}
impl MockCache {
    /// Creates a mock cache with no recorded keys.
    fn new() -> Self {
        let entries = Mutex::new(HashSet::new());
        Self { entries }
    }
    /// Marks `key` as present in the mock cache.
    fn insert(&self, key: CacheKey) {
        let mut recorded = self.entries.lock().unwrap();
        recorded.insert(key);
    }
}
/// Minimal `CacheStore` implementation: only `exists` consults recorded keys.
/// Every other operation is a successful no-op (reads always miss).
impl CacheStore for MockCache {
    fn get(&self, _key: &CacheKey) -> Result<Option<Value>> {
        Ok(None)
    }
    fn put(&self, _key: &CacheKey, _value: &Value) -> Result<()> {
        Ok(())
    }
    fn exists(&self, key: &CacheKey) -> Result<bool> {
        let recorded = self.entries.lock().unwrap();
        let present = recorded.contains(key);
        Ok(present)
    }
    fn remove(&self, _key: &CacheKey) -> Result<()> {
        Ok(())
    }
    fn metadata(&self, _key: &CacheKey) -> Result<Option<EntryMeta>> {
        Ok(None)
    }
}
/// Two opaque, non-differentiable filters (`o1`, `o2`) interrupt a chain of
/// trainable differentiable filters. Compiling in `Inference` mode is expected
/// to produce exactly two diagnostics, reported in pipeline order: `o1`, then `o2`.
#[test]
fn gradient_multiple_interruptions() {
    let graph = linear_pipeline(vec![
        Node::new("d1", "D1", "F"),
        Node::new("o1", "O1", "F"),
        Node::new("d2", "D2", "F"),
        Node::new("o2", "O2", "F"),
        Node::new("d3", "D3", "F"),
    ]);
    let mut reg = SimpleFilterRegistry::new();
    // Each node's config hash is derived from its id bytes.
    for (id, kind, differentiable) in [
        ("d1", FilterKind::Trainable, true),
        ("o1", FilterKind::Opaque, false),
        ("d2", FilterKind::Trainable, true),
        ("o2", FilterKind::Opaque, false),
        ("d3", FilterKind::Trainable, true),
    ] {
        reg.register_meta(
            id,
            make_meta(kind, differentiable),
            CacheKey::hash_data(id.as_bytes()),
        );
    }
    let result = compile(&graph, &reg, CompileMode::Inference, None).unwrap();
    assert_eq!(
        result.diagnostics.len(),
        2,
        "expected 2 gradient warnings, got: {:?}",
        result.diagnostics
    );
    assert_eq!(result.diagnostics[0].node_id, "o1");
    assert_eq!(result.diagnostics[1].node_id, "o2");
}
/// A pipeline made entirely of opaque, non-differentiable filters is expected
/// to yield exactly one diagnostic, reported on the first node (`o1`), rather
/// than one diagnostic per opaque node.
#[test]
fn gradient_all_opaque_single_warning() {
    let graph = linear_pipeline(vec![
        Node::new("o1", "O1", "F"),
        Node::new("o2", "O2", "F"),
        Node::new("o3", "O3", "F"),
    ]);
    let mut reg = SimpleFilterRegistry::new();
    for id in ["o1", "o2", "o3"] {
        reg.register_meta(
            id,
            make_meta(FilterKind::Opaque, false),
            CacheKey::hash_data(id.as_bytes()),
        );
    }
    let result = compile(&graph, &reg, CompileMode::Inference, None).unwrap();
    assert_eq!(result.diagnostics.len(), 1);
    assert_eq!(result.diagnostics[0].node_id, "o1");
}
/// Diamond graph `root -> {b1, b2} -> merge`, with a precomputed cache key seeded
/// for every node. Compiling in `Inference` mode against that cache is expected to
/// mark all four nodes as cached.
///
/// NOTE(review): the seeded keys are `CacheKey::from_parts` over the node's config
/// hash followed by its parents' keys (in edge order). This mirrors the derivation
/// the test assumes the compiler uses; confirm against the compiler if this fails.
#[test]
fn cache_diamond_cascade() {
    let mut graph = Graph::new();
    graph.add_node(Node::new("root", "Root", "F"));
    graph.add_node(Node::new("b1", "B1", "F"));
    graph.add_node(Node::new("b2", "B2", "F"));
    graph.add_node(Node::new("merge", "Merge", "F"));
    graph.add_edge(Edge::data("e1", "root", "b1"));
    graph.add_edge(Edge::data("e2", "root", "b2"));
    graph.add_edge(Edge::data("e3", "b1", "merge"));
    graph.add_edge(Edge::data("e4", "b2", "merge"));
    let mut reg = SimpleFilterRegistry::new();
    for id in ["root", "b1", "b2", "merge"] {
        reg.register_meta(
            id,
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(id.as_bytes()),
        );
    }
    let cache = MockCache::new();
    let root_key = CacheKey::from_parts(&[&CacheKey::hash_data(b"root").0]);
    cache.insert(root_key.clone());
    let b1_key = CacheKey::from_parts(&[&CacheKey::hash_data(b"b1").0, &root_key.0]);
    cache.insert(b1_key.clone());
    let b2_key = CacheKey::from_parts(&[&CacheKey::hash_data(b"b2").0, &root_key.0]);
    cache.insert(b2_key.clone());
    let merge_key =
        CacheKey::from_parts(&[&CacheKey::hash_data(b"merge").0, &b1_key.0, &b2_key.0]);
    cache.insert(merge_key);
    let result = compile(&graph, &reg, CompileMode::Inference, Some(&cache)).unwrap();
    assert_eq!(
        result.plan.cached_count(),
        4,
        "all 4 nodes should be cached"
    );
}
/// Only `a` is registered; `b` has no metadata in the registry. Compilation is
/// still expected to succeed, and both nodes are expected to appear in the plan.
#[test]
fn compile_with_unregistered_node() {
    let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);
    let mut reg = SimpleFilterRegistry::new();
    reg.register_meta(
        "a",
        make_meta(FilterKind::Trainable, true),
        CacheKey::hash_data(b"a"),
    );
    let result = compile(&graph, &reg, CompileMode::Inference, None).unwrap();
    assert_eq!(result.plan.node_count(), 2);
}
/// A 20-node linear chain (`n0` .. `n19`), each node registered with a distinct
/// config hash, is expected to compile with every node present in the plan.
#[test]
fn compile_deep_chain() {
    let nodes: Vec<Node> = (0..20)
        .map(|i| Node::new(format!("n{i}"), format!("N{i}"), "F"))
        .collect();
    let graph = linear_pipeline(nodes);
    let mut reg = SimpleFilterRegistry::new();
    for i in 0..20 {
        reg.register_meta(
            format!("n{i}"),
            make_meta(FilterKind::Trainable, true),
            CacheKey::hash_data(format!("config_{i}").as_bytes()),
        );
    }
    let result = compile(&graph, &reg, CompileMode::Inference, None).unwrap();
    assert_eq!(result.plan.node_count(), 20);
}
/// Compiles the same two-node pipeline under each `CompileMode` against a cache
/// that holds only `a`'s key. `Inference` is expected to report one cached node.
/// `Differentiable` and `NoCache` are expected to report none, even though the
/// cache is supplied.
#[test]
fn all_compile_modes() {
    let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);
    let mut reg = SimpleFilterRegistry::new();
    reg.register_meta(
        "a",
        make_meta(FilterKind::Trainable, true),
        CacheKey::hash_data(b"a"),
    );
    reg.register_meta(
        "b",
        make_meta(FilterKind::Trainable, true),
        CacheKey::hash_data(b"b"),
    );
    let cache = MockCache::new();
    // `a` is the root, so its key has no parent parts.
    let a_key = CacheKey::from_parts(&[&CacheKey::hash_data(b"a").0]);
    cache.insert(a_key);
    let r1 = compile(&graph, &reg, CompileMode::Inference, Some(&cache)).unwrap();
    assert_eq!(r1.plan.cached_count(), 1);
    let r2 = compile(&graph, &reg, CompileMode::Differentiable, Some(&cache)).unwrap();
    assert_eq!(r2.plan.cached_count(), 0);
    let r3 = compile(&graph, &reg, CompileMode::NoCache, Some(&cache)).unwrap();
    assert_eq!(r3.plan.cached_count(), 0);
}
/// Builds trainable, differentiable, cacheable `FilterMeta` named `"typed"` with
/// the given output and input schemas. Note the argument order: `output` comes first.
fn meta_with_schemas(output: Option<Schema>, input: Option<Schema>) -> FilterMeta {
    let output_schema = output;
    let input_schema = input;
    FilterMeta {
        name: "typed".into(),
        kind: FilterKind::Trainable,
        differentiable: true,
        cacheable: true,
        stream_mode: StreamMode::FixedState,
        distribution: somatize_core::filter::Distribution::Local,
        output_schema,
        input_schema,
    }
}
/// `a` outputs a `Float64` vector of length 128 and `b` expects the same schema.
/// No "schema mismatch" diagnostic is expected.
#[test]
fn schema_compatible_no_warnings() {
    let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);
    let mut reg = SimpleFilterRegistry::new();
    reg.register_meta(
        "a",
        meta_with_schemas(Some(Schema::vector(DataType::Float64, 128)), None),
        CacheKey::hash_data(b"a"),
    );
    reg.register_meta(
        "b",
        meta_with_schemas(None, Some(Schema::vector(DataType::Float64, 128))),
        CacheKey::hash_data(b"b"),
    );
    let result = compile(&graph, &reg, CompileMode::Inference, None).unwrap();
    let schema_warnings: Vec<_> = result
        .diagnostics
        .iter()
        .filter(|d| d.message.contains("schema mismatch"))
        .collect();
    assert!(schema_warnings.is_empty(), "should have no schema warnings");
}
/// `a` outputs `Float64` while `b` expects `Int64` (same length 128). Exactly one
/// "schema mismatch" diagnostic is expected, and its message should name both
/// dtypes ("f64" and "i64").
#[test]
fn schema_incompatible_dtype_warns() {
    let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);
    let mut reg = SimpleFilterRegistry::new();
    reg.register_meta(
        "a",
        meta_with_schemas(Some(Schema::vector(DataType::Float64, 128)), None),
        CacheKey::hash_data(b"a"),
    );
    reg.register_meta(
        "b",
        meta_with_schemas(None, Some(Schema::vector(DataType::Int64, 128))),
        CacheKey::hash_data(b"b"),
    );
    let result = compile(&graph, &reg, CompileMode::Inference, None).unwrap();
    let schema_warnings: Vec<_> = result
        .diagnostics
        .iter()
        .filter(|d| d.message.contains("schema mismatch"))
        .collect();
    assert_eq!(schema_warnings.len(), 1);
    assert!(schema_warnings[0].message.contains("f64"));
    assert!(schema_warnings[0].message.contains("i64"));
}
/// Same dtype (`Float64`) but different vector lengths (128 vs 256). Exactly one
/// "schema mismatch" diagnostic is expected.
#[test]
fn schema_incompatible_shape_warns() {
    let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);
    let mut reg = SimpleFilterRegistry::new();
    reg.register_meta(
        "a",
        meta_with_schemas(Some(Schema::vector(DataType::Float64, 128)), None),
        CacheKey::hash_data(b"a"),
    );
    reg.register_meta(
        "b",
        meta_with_schemas(None, Some(Schema::vector(DataType::Float64, 256))),
        CacheKey::hash_data(b"b"),
    );
    let result = compile(&graph, &reg, CompileMode::Inference, None).unwrap();
    let schema_warnings: Vec<_> = result
        .diagnostics
        .iter()
        .filter(|d| d.message.contains("schema mismatch"))
        .collect();
    assert_eq!(schema_warnings.len(), 1);
}
/// `a` outputs a batched `Float64` schema with inner shape `[128]`, and `b` expects
/// a fixed 32x128 matrix. No "schema mismatch" diagnostic is expected.
/// NOTE(review): this presumes `Schema::batched` uses a dynamic leading dimension
/// that unifies with 32; confirm against `Schema`'s definition.
#[test]
fn schema_dynamic_compatible_with_fixed() {
    let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);
    let mut reg = SimpleFilterRegistry::new();
    reg.register_meta(
        "a",
        meta_with_schemas(Some(Schema::batched(DataType::Float64, &[128])), None),
        CacheKey::hash_data(b"a"),
    );
    reg.register_meta(
        "b",
        meta_with_schemas(None, Some(Schema::matrix(DataType::Float64, 32, 128))),
        CacheKey::hash_data(b"b"),
    );
    let result = compile(&graph, &reg, CompileMode::Inference, None).unwrap();
    let schema_warnings: Vec<_> = result
        .diagnostics
        .iter()
        .filter(|d| d.message.contains("schema mismatch"))
        .collect();
    assert!(schema_warnings.is_empty());
}
/// Neither filter declares an input or output schema, and both are differentiable,
/// so compilation is expected to emit no diagnostics at all.
#[test]
fn schema_none_skips_validation() {
    let graph = linear_pipeline(vec![Node::new("a", "A", "F"), Node::new("b", "B", "F")]);
    let mut reg = SimpleFilterRegistry::new();
    reg.register_meta(
        "a",
        make_meta(FilterKind::Trainable, true),
        CacheKey::hash_data(b"a"),
    );
    reg.register_meta(
        "b",
        make_meta(FilterKind::Trainable, true),
        CacheKey::hash_data(b"b"),
    );
    let result = compile(&graph, &reg, CompileMode::Inference, None).unwrap();
    assert!(result.diagnostics.is_empty());
}