use somatize_compiler::{CompileMode, ExecutionPlan};
use somatize_core::cache::CacheKey;
use somatize_core::error::{Result, SomaError};
use somatize_core::filter::{Filter, FilterKind, FilterMeta, StreamMode};
use somatize_core::graph::{Edge, Graph, Node};
use somatize_core::store::{DataStore, LocalDataStore};
use somatize_core::value::Value;
use somatize_runtime::cache::MemoryCache;
use somatize_runtime::executor::{Context, RemoteExecutor};
use somatize_runtime::filter_library::FilterLibrary;
use somatize_runtime::graph_session::GraphSession;
use somatize_runtime::*;
use std::sync::Arc;
/// Stateless test filter that multiplies every tensor element by two.
///
/// Any value that is not a `Value::Tensor` is handed back unchanged.
struct Doubler;
impl Filter for Doubler {
    fn config_hash(&self) -> CacheKey {
        CacheKey::from_parts(&[b"Doubler"])
    }

    /// Nothing to learn: the returned state is always `Value::Empty`.
    fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
        Ok(Value::Empty)
    }

    fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
        if let Value::Tensor { values, shape } = x {
            Ok(Value::tensor(
                values.iter().map(|&v| 2.0 * v).collect(),
                shape.clone(),
            ))
        } else {
            // Non-tensor inputs pass straight through.
            Ok(x.clone())
        }
    }

    fn meta(&self) -> FilterMeta {
        FilterMeta {
            name: "Doubler".into(),
            kind: FilterKind::Stateless,
            cacheable: true,
            differentiable: true,
            stream_mode: StreamMode::FixedState,
            distribution: somatize_core::filter::Distribution::Local,
            input_schema: None,
            output_schema: None,
        }
    }
}
/// Trainable test filter that centres a tensor on a mean learned during `fit`.
///
/// `fit` stores `{"mean": <f64>}` as JSON state. `forward` subtracts that mean from
/// every element, falling back to `0.0` when the state carries no usable `"mean"`.
struct MeanFilter;
impl Filter for MeanFilter {
    fn config_hash(&self) -> CacheKey {
        CacheKey::from_parts(&[b"Mean"])
    }

    /// Learns the arithmetic mean of the training tensor.
    ///
    /// # Errors
    /// Returns `SomaError::Other` if `x` is not a tensor, or if it is empty. The mean
    /// of zero elements is undefined (0/0 = NaN), and a NaN state would silently
    /// corrupt every subsequent `forward` output.
    fn fit(&self, x: &Value, _y: Option<&Value>) -> Result<Value> {
        let (data, _) = x
            .as_tensor()
            .ok_or_else(|| SomaError::Other("need tensor".into()))?;
        if data.is_empty() {
            return Err(SomaError::Other("need non-empty tensor".into()));
        }
        let mean = data.iter().sum::<f64>() / data.len() as f64;
        Ok(Value::json(serde_json::json!({"mean": mean})))
    }

    /// Subtracts the learned mean from every element, preserving the input shape.
    ///
    /// # Errors
    /// Returns `SomaError::Other` if `x` is not a tensor.
    fn forward(&self, x: &Value, state: &Value) -> Result<Value> {
        let (data, shape) = x
            .as_tensor()
            .ok_or_else(|| SomaError::Other("need tensor".into()))?;
        // Missing or non-numeric state degrades to an identity transform.
        let mean = state
            .as_json()
            .and_then(|j| j["mean"].as_f64())
            .unwrap_or(0.0);
        Ok(Value::tensor(
            data.iter().map(|v| v - mean).collect(),
            shape.to_vec(),
        ))
    }

    fn meta(&self) -> FilterMeta {
        FilterMeta {
            name: "Mean".into(),
            kind: FilterKind::Trainable,
            cacheable: true,
            differentiable: true,
            stream_mode: StreamMode::FixedState,
            distribution: somatize_core::filter::Distribution::Local,
            input_schema: None,
            output_schema: None,
        }
    }
}
/// Stateless condition filter whose output always names the `"left"` arm.
///
/// It ignores its input, so branch tests get a deterministic selection.
struct BranchCondition;
impl Filter for BranchCondition {
    fn config_hash(&self) -> CacheKey {
        CacheKey::from_parts(&[b"BranchCond"])
    }

    /// Stateless: the returned state is always `Value::Empty`.
    fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
        Ok(Value::Empty)
    }

    fn forward(&self, _x: &Value, _state: &Value) -> Result<Value> {
        let selected_arm = serde_json::json!("left");
        Ok(Value::json(selected_arm))
    }

    fn meta(&self) -> FilterMeta {
        FilterMeta {
            name: "BranchCond".into(),
            kind: FilterKind::Stateless,
            cacheable: false,
            differentiable: false,
            stream_mode: StreamMode::FixedState,
            distribution: somatize_core::filter::Distribution::Local,
            input_schema: None,
            output_schema: None,
        }
    }
}
/// Test filter that echoes its input for the first `stop_at` calls and emits
/// `{"done": true}` on every call after that.
struct StopFilter {
    /// Number of `forward` calls made so far.
    call_count: std::sync::atomic::AtomicUsize,
    /// Number of pass-through calls before the filter starts reporting done.
    stop_at: usize,
}
impl Filter for StopFilter {
    fn config_hash(&self) -> CacheKey {
        CacheKey::from_parts(&[b"Stop"])
    }

    /// Stateless: the returned state is always `Value::Empty`.
    fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
        Ok(Value::Empty)
    }

    fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
        use std::sync::atomic::Ordering;

        // `fetch_add` yields the count *before* this call was recorded.
        let calls_before = self.call_count.fetch_add(1, Ordering::SeqCst);
        if calls_before < self.stop_at {
            return Ok(x.clone());
        }
        Ok(Value::json(serde_json::json!({"done": true})))
    }

    fn meta(&self) -> FilterMeta {
        FilterMeta {
            name: "Stop".into(),
            kind: FilterKind::Stateless,
            cacheable: false,
            differentiable: false,
            stream_mode: StreamMode::FixedState,
            distribution: somatize_core::filter::Distribution::Local,
            input_schema: None,
            output_schema: None,
        }
    }
}
/// Builds a chain-shaped graph.
///
/// The graph has one node per entry of `ids` (the id is passed for all three
/// `Node::new` arguments) and a data edge named `e{i}` from `ids[i]` to `ids[i + 1]`.
fn linear_graph(ids: &[&str]) -> Graph {
    let mut graph = Graph::new();
    for id in ids.iter().copied() {
        graph.nodes.push(Node::new(id, id, id));
    }
    // Pair each id with its successor; yields nothing for fewer than two ids.
    let successors = ids.iter().copied().skip(1);
    for (i, (from, to)) in ids.iter().copied().zip(successors).enumerate() {
        graph.edges.push(Edge::data(format!("e{i}"), from, to));
    }
    graph
}
/// A two-node chain of doublers runs to completion without caching.
#[test]
fn session_run_returns_all_outputs() {
    let mut library = FilterLibrary::new();
    library.register("doubler", Box::new(Doubler));
    library.register("doubler2", Box::new(Doubler));

    let mut session = GraphSession::new(linear_graph(&["doubler", "doubler2"]), library);
    let outcome = session.run(CompileMode::NoCache);
    assert!(outcome.is_ok());
}
/// Fitting a mean -> doubler chain marks the session as fitted, and `forward`
/// can then be invoked. Its outcome is intentionally not inspected here.
#[test]
fn session_forward_after_fit() {
    let mut library = FilterLibrary::new();
    library.register("mean", Box::new(MeanFilter));
    library.register("doubler", Box::new(Doubler));
    let mut session = GraphSession::new(linear_graph(&["mean", "doubler"]), library);

    let training = Value::tensor(vec![10.0, 20.0, 30.0], vec![3]);
    session.fit(&training, None).unwrap();
    assert!(session.is_fitted());

    let probe = Value::tensor(vec![20.0], vec![1]);
    let _outcome = session.forward(&probe);
}
/// Every compile mode yields a plan with at least one node for a single-node graph.
#[test]
fn session_compile_all_modes() {
    let mut library = FilterLibrary::new();
    library.register("doubler", Box::new(Doubler));
    let session = GraphSession::new(linear_graph(&["doubler"]), library);

    let modes = [
        CompileMode::Inference,
        CompileMode::Differentiable,
        CompileMode::NoCache,
    ];
    for mode in modes {
        let compiled = session.compile(mode).unwrap();
        assert!(compiled.plan.node_count() > 0);
    }
}
/// Attaching a local, temp-dir-backed data store does not break compilation.
#[test]
fn session_with_data_store() {
    let mut library = FilterLibrary::new();
    library.register("doubler", Box::new(Doubler));

    // Keep the TempDir guard alive for the whole test.
    let scratch = tempfile::tempdir().unwrap();
    let local_store = Arc::new(LocalDataStore::new(scratch.path()));

    let session =
        GraphSession::new(linear_graph(&["doubler"]), library).with_data_store(local_store);
    assert!(session.compile(CompileMode::NoCache).is_ok());
}
/// Attaching a remote executor does not break compilation.
#[test]
fn session_with_remote_executor() {
    /// Remote executor stub that ignores its arguments and always answers `[42.0]`.
    struct StubRemote;
    impl RemoteExecutor for StubRemote {
        fn execute_remote(
            &self,
            _node_id: &str,
            _target: &somatize_core::filter::RemoteTarget,
            _input: Option<&Value>,
        ) -> Result<Value> {
            Ok(Value::tensor(vec![42.0], vec![1]))
        }
    }

    let mut library = FilterLibrary::new();
    library.register("doubler", Box::new(Doubler));

    let session = GraphSession::new(linear_graph(&["doubler"]), library)
        .with_remote_executor(Arc::new(StubRemote));
    assert!(session.compile(CompileMode::NoCache).is_ok());
}
/// `graph()`, `library()` and `library_mut()` expose the session's internals.
#[test]
fn session_graph_and_library_accessors() {
    let mut library = FilterLibrary::new();
    library.register("a", Box::new(Doubler));
    let mut session = GraphSession::new(linear_graph(&["a"]), library);

    assert_eq!(session.graph().nodes.len(), 1);
    assert!(session.library().get("a").is_some());

    // A filter registered through the mutable accessor is visible via the shared one.
    session.library_mut().register("b", Box::new(Doubler));
    assert!(session.library().get("b").is_some());
}
/// State fitted in one session can be persisted to a shared store and loaded
/// into a fresh session over the same graph.
#[test]
fn session_persist_and_load_states() {
    let graph = linear_graph(&["mean"]);
    let scratch = tempfile::tempdir().unwrap();
    let shared_store: Arc<dyn DataStore> = Arc::new(LocalDataStore::new(scratch.path()));

    // Trainer session: fit, then write the learned state out.
    let mut trainer_lib = FilterLibrary::new();
    trainer_lib.register("mean", Box::new(MeanFilter));
    let mut trainer =
        GraphSession::new(graph.clone(), trainer_lib).with_data_store(Arc::clone(&shared_store));
    let training = Value::tensor(vec![10.0, 20.0, 30.0], vec![3]);
    trainer.fit(&training, None).unwrap();
    let saved = trainer.persist_states().unwrap();

    // Loader session: never fitted directly, but restores the persisted state.
    let mut loader_lib = FilterLibrary::new();
    loader_lib.register("mean", Box::new(MeanFilter));
    let mut loader = GraphSession::new(graph, loader_lib).with_data_store(shared_store);
    loader.load_states(&saved).unwrap();
    assert!(loader.is_fitted());
}
/// Persisting state requires a data store: a session built without
/// `with_data_store` must report an error from `persist_states`.
#[test]
fn session_persist_without_datastore_errors() {
    let graph = linear_graph(&["mean"]);
    let mut lib = FilterLibrary::new();
    lib.register("mean", Box::new(MeanFilter));
    let session = GraphSession::new(graph, lib);

    let result = session.persist_states();
    assert!(
        result.is_err(),
        "persist_states should fail when no data store is configured"
    );
}
/// A loop whose body eventually emits `{"done": true}` finishes without error,
/// well within its 100-iteration cap.
#[test]
fn executor_loop_terminates_on_done() {
    let mut library = FilterLibrary::new();
    let stopper = StopFilter {
        call_count: std::sync::atomic::AtomicUsize::new(0),
        stop_at: 3,
    };
    library.register("stopper", Box::new(stopper));

    let body = ExecutionPlan::Execute {
        node_id: "stopper".into(),
    };
    let plan = ExecutionPlan::Loop {
        node_id: "loop".into(),
        body: Box::new(body),
        max_iterations: Some(100),
    };

    let cache = MemoryCache::default();
    let mut ctx = Context::new(Arc::new(EventBus::new(64)), "loop_test");
    ctx.set("input", Value::tensor(vec![1.0], vec![1]));
    ctx.graph_info
        .set_predecessors("stopper", vec!["input".into()]);

    execute(&plan, &mut ctx, &library, &cache).unwrap();
}
/// A `Branch` plan runs only the arm named by the condition node's output.
///
/// `BranchCondition` always emits `"left"`, so `left_doubler` must produce an
/// output and `right_doubler` must not.
#[test]
fn executor_branch_selects_arm() {
    let bus = Arc::new(EventBus::new(64));
    let cache = MemoryCache::default();
    let mut ctx = Context::new(bus, "branch_test");
    ctx.set("input", Value::tensor(vec![1.0], vec![1]));
    ctx.graph_info
        .set_predecessors("cond", vec!["input".into()]);
    ctx.graph_info
        .set_predecessors("left_doubler", vec!["cond".into()]);
    ctx.graph_info
        .set_predecessors("right_doubler", vec!["cond".into()]);

    let mut lib = FilterLibrary::new();
    lib.register("cond", Box::new(BranchCondition));
    lib.register("left_doubler", Box::new(Doubler));
    lib.register("right_doubler", Box::new(Doubler));

    let plan = ExecutionPlan::Branch {
        node_id: "cond".into(),
        arms: vec![
            (
                "left".into(),
                ExecutionPlan::Execute {
                    node_id: "left_doubler".into(),
                },
            ),
            (
                "right".into(),
                ExecutionPlan::Execute {
                    node_id: "right_doubler".into(),
                },
            ),
        ],
    };
    execute(&plan, &mut ctx, &lib, &cache).unwrap();

    assert!(ctx.get("left_doubler").is_some(), "left arm should execute");
    // Without this check, an executor that ran every arm would still pass.
    assert!(
        ctx.get("right_doubler").is_none(),
        "right arm should not execute when the condition selects \"left\""
    );
}
/// With no remote executor configured on the context, a `Remote` plan runs its
/// inner plan locally: the registered `Doubler` turns `[5.0]` into `[10.0]`.
#[test]
fn executor_remote_falls_back_to_local() {
    let bus = Arc::new(EventBus::new(64));
    let cache = MemoryCache::default();
    let mut ctx = Context::new(bus, "remote_test");
    ctx.set("input", Value::tensor(vec![5.0], vec![1]));
    ctx.graph_info
        .set_predecessors("doubler", vec!["input".into()]);

    let mut lib = FilterLibrary::new();
    lib.register("doubler", Box::new(Doubler));

    let plan = ExecutionPlan::Remote {
        node_id: "doubler".into(),
        target: somatize_core::filter::RemoteTarget::Tag("gpu".into()),
        plan: Box::new(ExecutionPlan::Execute {
            node_id: "doubler".into(),
        }),
    };
    execute(&plan, &mut ctx, &lib, &cache).unwrap();

    let result = ctx.get("doubler").expect("doubler output should be recorded");
    let (data, _) = result.as_tensor().expect("doubler output should be a tensor");
    assert_eq!(data, &[10.0], "local fallback should double the input");
}
/// With a remote executor configured, a `Remote` plan is delegated to it.
///
/// `TestRemote` and the locally registered `Doubler` both double `[7.0]` into
/// `[14.0]`, so the output value alone cannot tell the two paths apart. The
/// `REMOTE_CALLED` flag proves the remote executor was actually invoked.
#[test]
fn executor_remote_with_executor() {
    /// Set by `TestRemote::execute_remote`. Only this test touches it.
    static REMOTE_CALLED: std::sync::atomic::AtomicBool =
        std::sync::atomic::AtomicBool::new(false);

    struct TestRemote;
    impl RemoteExecutor for TestRemote {
        fn execute_remote(
            &self,
            _node_id: &str,
            _target: &somatize_core::filter::RemoteTarget,
            input: Option<&Value>,
        ) -> Result<Value> {
            REMOTE_CALLED.store(true, std::sync::atomic::Ordering::SeqCst);
            match input {
                Some(Value::Tensor { values, shape }) => Ok(Value::tensor(
                    values.iter().map(|v| v * 2.0).collect(),
                    shape.clone(),
                )),
                // Sentinel that makes a missing or non-tensor input obvious in the assertion.
                _ => Ok(Value::tensor(vec![99.0], vec![1])),
            }
        }
    }

    let bus = Arc::new(EventBus::new(64));
    let cache = MemoryCache::default();
    let mut ctx = Context::new(bus, "remote_exec_test").with_remote_executor(Arc::new(TestRemote));
    ctx.set("input", Value::tensor(vec![7.0], vec![1]));
    ctx.graph_info
        .set_predecessors("remote_node", vec!["input".into()]);

    let mut lib = FilterLibrary::new();
    lib.register("remote_node", Box::new(Doubler));

    let plan = ExecutionPlan::Remote {
        node_id: "remote_node".into(),
        target: somatize_core::filter::RemoteTarget::Tag("gpu".into()),
        plan: Box::new(ExecutionPlan::Execute {
            node_id: "remote_node".into(),
        }),
    };
    execute(&plan, &mut ctx, &lib, &cache).unwrap();

    assert!(
        REMOTE_CALLED.load(std::sync::atomic::Ordering::SeqCst),
        "remote executor should have been invoked instead of the local fallback"
    );
    let result = ctx.get("remote_node").expect("remote_node output should be recorded");
    let (data, _) = result.as_tensor().expect("remote_node output should be a tensor");
    assert_eq!(data, &[14.0]);
}
/// With a data store and a spill threshold of 100 configured on the context, the
/// doubler's output for a 100-element input is expected to be recorded as a
/// virtual (spilled) value. The units of the threshold are defined by `Context`.
#[test]
fn executor_spills_large_values_to_datastore() {
    let scratch = tempfile::tempdir().unwrap();
    let spill_store: Arc<dyn DataStore> = Arc::new(LocalDataStore::new(scratch.path()));
    let cache = MemoryCache::default();

    let mut ctx = Context::new(Arc::new(EventBus::new(64)), "spill_test")
        .with_data_store(spill_store)
        .with_spill_threshold(100);
    ctx.set("input", Value::tensor(vec![1.0; 100], vec![100]));
    ctx.graph_info
        .set_predecessors("doubler", vec!["input".into()]);

    let mut library = FilterLibrary::new();
    library.register("doubler", Box::new(Doubler));

    let step = ExecutionPlan::Execute {
        node_id: "doubler".into(),
    };
    execute(&step, &mut ctx, &library, &cache).unwrap();

    let spilled = ctx.get_virtual("doubler");
    assert!(spilled.is_some());
}
/// The free-function `graph_run` entry point executes a single-node graph.
#[test]
fn graph_run_free_function() {
    let cache = MemoryCache::default();
    let mut library = FilterLibrary::new();
    library.register("doubler", Box::new(Doubler));
    let graph = linear_graph(&["doubler"]);

    let outcome = somatize_runtime::graph_run(&graph, &library, CompileMode::NoCache, &cache);
    assert!(outcome.is_ok());
}
/// `graph_fit` fits a trainable filter and returns its transformed output.
///
/// The mean of `[10, 20]` is 15, so the centred output is `[-5, 5]`.
#[test]
fn graph_fit_free_function_trainable() {
    let cache = MemoryCache::default();
    let mut library = FilterLibrary::new();
    library.register("mean", Box::new(MeanFilter));
    let graph = linear_graph(&["mean"]);

    let training = Value::tensor(vec![10.0, 20.0], vec![2]);
    let outputs = somatize_runtime::graph_fit(&graph, &library, &training, None, &cache).unwrap();

    assert!(outputs.contains_key("mean"));
    let (centred, _shape) = outputs["mean"].as_tensor().unwrap();
    assert_eq!(centred, &[-5.0, 5.0]);
}
/// Smoke test: `graph_predict` can be called on a single-node graph.
/// Its result is intentionally not inspected.
#[test]
fn graph_predict_free_function() {
    let cache = MemoryCache::default();
    let mut library = FilterLibrary::new();
    library.register("doubler", Box::new(Doubler));
    let graph = linear_graph(&["doubler"]);

    let sample = Value::tensor(vec![3.0], vec![1]);
    let _outcome = somatize_runtime::graph_predict(&graph, &library, &sample, &cache);
}