1use crate::cache::MemoryCache;
8use crate::event_bus::EventBus;
9use crate::executor::{self, Context, GraphInfo, RemoteExecutor};
10use crate::filter_library::FilterLibrary;
11use somatize_compiler::{CompileMode, CompileResult, compile};
12use somatize_core::cache::{CacheKey, CacheStore};
13use somatize_core::error::{Result, SomaError};
14use somatize_core::event::Event;
15use somatize_core::filter::FilterKind;
16use somatize_core::graph::Graph;
17use somatize_core::store::{DataRef, DataStore};
18use somatize_core::util::timestamp_id;
19use somatize_core::value::Value;
20use std::collections::HashMap;
21use std::sync::Arc;
22
23pub struct GraphSession {
35 graph: Graph,
36 library: FilterLibrary,
37 cache: Arc<dyn CacheStore>,
38 event_bus: Arc<EventBus>,
39 data_store: Option<Arc<dyn DataStore>>,
40 remote_executor: Option<Arc<dyn RemoteExecutor>>,
41 fitted: bool,
42}
43
44impl GraphSession {
45 pub fn new(graph: Graph, library: FilterLibrary) -> Self {
46 Self {
47 graph,
48 library,
49 cache: Arc::new(MemoryCache::default()),
50 event_bus: Arc::new(EventBus::new(256)),
51 data_store: None,
52 remote_executor: None,
53 fitted: false,
54 }
55 }
56
57 pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
58 self.cache = cache;
59 self
60 }
61
62 pub fn with_event_bus(mut self, bus: Arc<EventBus>) -> Self {
63 self.event_bus = bus;
64 self
65 }
66
67 pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
68 self.data_store = Some(store);
69 self
70 }
71
72 pub fn with_remote_executor(mut self, executor: Arc<dyn RemoteExecutor>) -> Self {
73 self.remote_executor = Some(executor);
74 self
75 }
76
77 pub fn compile(&self, mode: CompileMode) -> Result<CompileResult> {
81 compile(&self.graph, &self.library, mode, Some(self.cache.as_ref()))
82 }
83
84 pub fn run(&mut self, mode: CompileMode) -> Result<HashMap<String, Value>> {
86 let CompileResult { plan, diagnostics } =
87 compile(&self.graph, &self.library, mode, Some(self.cache.as_ref()))?;
88
89 for diag in &diagnostics {
90 tracing::warn!("compile diagnostic: {:?}", diag);
91 }
92
93 let graph_info = GraphInfo::from_graph(&self.graph);
94 let mut ctx = Context::new(self.event_bus.clone(), timestamp_id("graph_run"))
95 .with_graph_info(graph_info);
96
97 if let Some(store) = &self.data_store {
98 ctx = ctx.with_data_store(store.clone());
99 }
100 if let Some(remote) = &self.remote_executor {
101 ctx = ctx.with_remote_executor(remote.clone());
102 }
103
104 executor::execute(&plan, &mut ctx, &self.library, self.cache.as_ref())?;
105
106 Ok(ctx
107 .store
108 .into_iter()
109 .filter_map(|(k, vv)| vv.as_value().cloned().map(|v| (k, v)))
110 .collect())
111 }
112
113 pub fn fit(&mut self, x: &Value, y: Option<&Value>) -> Result<HashMap<String, Value>> {
121 self.graph.validate()?;
122 let sorted = self.graph.topological_sort()?;
123 let graph_info = GraphInfo::from_graph(&self.graph);
124
125 let run_id = timestamp_id("graph_fit");
126 let mut outputs: HashMap<String, Value> = HashMap::new();
127
128 let roots = self.graph.roots();
130 for root_id in &roots {
131 outputs.insert(format!("__input_{root_id}"), x.clone());
132 }
133
134 for node_id in &sorted {
135 let filter = self
136 .library
137 .get(node_id)
138 .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
139
140 self.event_bus.emit(Event::NodeStarted {
141 run_id: run_id.clone(),
142 node_id: node_id.to_string(),
143 kind: filter.meta().kind,
144 });
145
146 let preds = graph_info.predecessors(node_id);
148 let input = match preds.len() {
149 0 => x.clone(),
150 1 => outputs.get(&preds[0]).cloned().unwrap_or_else(|| x.clone()),
151 _ => {
152 let mut merged = serde_json::Map::new();
153 for pred_id in preds {
154 if let Some(val) = outputs.get(pred_id.as_str()) {
155 let json_val =
156 serde_json::to_value(val).unwrap_or(serde_json::Value::Null);
157 merged.insert(pred_id.clone(), json_val);
158 }
159 }
160 Value::Json(serde_json::Value::Object(merged))
161 }
162 };
163
164 let meta = filter.meta();
165 let start = std::time::Instant::now();
166
167 let (state, output) = if meta.kind == FilterKind::Trainable {
169 let data_hash =
170 CacheKey::hash_data(&serde_json::to_vec(&input).unwrap_or_default());
171 let state_key = CacheKey::for_state(&filter.config_hash(), &data_hash);
172
173 let state = if let Some(cached) = self.cache.get(&state_key)? {
174 cached
175 } else {
176 let s = filter.fit(&input, y)?;
177 self.cache.put(&state_key, &s)?;
178 s
179 };
180
181 let output = filter.forward(&input, &state)?;
182 self.library.set_state(node_id.to_string(), state.clone());
184 (state, output)
185 } else {
186 let output = filter.forward(&input, &Value::Empty)?;
187 (Value::Empty, output)
188 };
189
190 let _ = state; self.event_bus.emit(Event::NodeCompleted {
193 run_id: run_id.clone(),
194 node_id: node_id.to_string(),
195 duration: start.elapsed(),
196 output_summary: format!("{output}"),
197 });
198
199 outputs.insert(node_id.to_string(), output);
200 }
201
202 self.fitted = true;
203 Ok(outputs)
204 }
205
206 pub fn forward(&self, x: &Value) -> Result<Value> {
211 let CompileResult { plan, .. } = compile(
212 &self.graph,
213 &self.library,
214 CompileMode::Inference,
215 Some(self.cache.as_ref()),
216 )?;
217
218 let graph_info = GraphInfo::from_graph(&self.graph);
219 let mut ctx = Context::new(self.event_bus.clone(), timestamp_id("graph_forward"))
220 .with_graph_info(graph_info);
221
222 if let Some(store) = &self.data_store {
223 ctx = ctx.with_data_store(store.clone());
224 }
225 if let Some(remote) = &self.remote_executor {
226 ctx = ctx.with_remote_executor(remote.clone());
227 }
228
229 let roots = self.graph.roots();
231 if roots.len() == 1 {
232 ctx.set(format!("__input_{}", roots[0]), x.clone());
233 }
234 ctx.set("__input__", x.clone());
235
236 executor::execute(&plan, &mut ctx, &self.library, self.cache.as_ref())?;
237
238 let leaves = self.graph.leaves();
240 let mut extract = |id: &str| -> Option<Value> {
241 ctx.store.remove(id).and_then(|vv| vv.as_value().cloned())
242 };
243
244 if let Some(leaf_id) = leaves.first() {
245 extract(leaf_id).ok_or_else(|| {
246 SomaError::Other(format!("leaf node '{leaf_id}' produced no output"))
247 })
248 } else {
249 ctx.execution_order
250 .last()
251 .and_then(|id| extract(id))
252 .ok_or_else(|| SomaError::Other("no output produced".into()))
253 }
254 }
255
256 pub fn forward_batched(&self, data_ref: &DataRef, batch_size: usize) -> Result<Value> {
258 let store = self
259 .data_store
260 .as_ref()
261 .ok_or_else(|| SomaError::Execution {
262 node_id: "session".into(),
263 message: "forward_batched requires a data store (use with_data_store)".into(),
264 })?;
265
266 let meta = store.meta(data_ref)?;
267 let total_rows = meta.total_rows;
268 if total_rows == 0 {
269 return Ok(Value::Empty);
270 }
271
272 let mut all_values: Vec<f64> = Vec::new();
273 let mut result_shape: Option<Vec<usize>> = None;
274 let mut rows_processed = 0;
275
276 while rows_processed < total_rows {
277 let batch_len = batch_size.min(total_rows - rows_processed);
278 let batch = store.get_rows(data_ref, rows_processed, batch_len)?;
279 let output = self.forward(&batch)?;
280
281 if let Value::Tensor { values, shape } = &output {
282 if result_shape.is_none() {
283 result_shape = Some(shape.clone());
284 }
285 all_values.extend_from_slice(values);
286 } else {
287 return Ok(output);
288 }
289
290 rows_processed += batch_len;
291 }
292
293 match result_shape {
294 Some(mut shape) => {
295 shape[0] = total_rows;
296 Ok(Value::tensor(all_values, shape))
297 }
298 None => Ok(Value::Empty),
299 }
300 }
301
302 pub fn persist_states(&self) -> Result<DataRef> {
306 let store = self
307 .data_store
308 .as_ref()
309 .ok_or_else(|| SomaError::Execution {
310 node_id: "session".into(),
311 message: "persist_states requires a data store".into(),
312 })?;
313
314 let sorted = self.graph.topological_sort()?;
315 let mut states_map = serde_json::Map::new();
316 for node_id in &sorted {
317 if let Some(state) = self.library.get_state(node_id) {
318 let json = serde_json::to_value(state)
319 .map_err(|e| SomaError::Other(format!("state serialize: {e}")))?;
320 states_map.insert(node_id.to_string(), json);
321 }
322 }
323
324 let states_value = Value::Json(serde_json::Value::Object(states_map));
325 let key = CacheKey::from_parts(&[b"graph_states", self.graph_config_hash().as_bytes()]);
326 store.put(&key, &states_value)
327 }
328
329 pub fn load_states(&mut self, data_ref: &DataRef) -> Result<()> {
331 let store = self
332 .data_store
333 .as_ref()
334 .ok_or_else(|| SomaError::Execution {
335 node_id: "session".into(),
336 message: "load_states requires a data store".into(),
337 })?;
338
339 let states_value = store.get(data_ref)?;
340 let states_json = states_value
341 .as_json()
342 .ok_or_else(|| SomaError::Other("persisted states must be JSON".into()))?;
343 let obj = states_json
344 .as_object()
345 .ok_or_else(|| SomaError::Other("persisted states must be a JSON object".into()))?;
346
347 for (node_id, json_val) in obj {
348 let value: Value = serde_json::from_value(json_val.clone())
349 .map_err(|e| SomaError::Other(format!("state deserialize: {e}")))?;
350 self.library.set_state(node_id.clone(), value);
351 }
352
353 self.fitted = true;
354 Ok(())
355 }
356
357 pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Event> {
361 self.event_bus.subscribe()
362 }
363
364 pub fn event_bus(&self) -> &Arc<EventBus> {
366 &self.event_bus
367 }
368
369 pub fn is_fitted(&self) -> bool {
371 self.fitted
372 }
373
374 pub fn graph(&self) -> &Graph {
376 &self.graph
377 }
378
379 pub fn library(&self) -> &FilterLibrary {
381 &self.library
382 }
383
384 pub fn library_mut(&mut self) -> &mut FilterLibrary {
386 &mut self.library
387 }
388
389 fn graph_config_hash(&self) -> String {
392 let node_ids: Vec<&str> = self.graph.nodes.iter().map(|n| n.id.as_str()).collect();
393 node_ids.join(",")
394 }
395}
396
397pub fn graph_run(
401 graph: &Graph,
402 library: &FilterLibrary,
403 mode: CompileMode,
404 cache: &dyn CacheStore,
405) -> Result<HashMap<String, Value>> {
406 let CompileResult { plan, diagnostics } = compile(graph, library, mode, Some(cache))?;
407
408 for diag in &diagnostics {
409 tracing::warn!("compile diagnostic: {:?}", diag);
410 }
411
412 let bus = Arc::new(EventBus::new(256));
413 let graph_info = GraphInfo::from_graph(graph);
414
415 let mut ctx = Context::new(bus, timestamp_id("graph_run")).with_graph_info(graph_info);
416
417 executor::execute(&plan, &mut ctx, library, cache)?;
418
419 Ok(ctx
420 .store
421 .into_iter()
422 .filter_map(|(k, vv)| vv.as_value().cloned().map(|v| (k, v)))
423 .collect())
424}
425
426pub fn graph_fit(
428 graph: &Graph,
429 library: &FilterLibrary,
430 x: &Value,
431 y: Option<&Value>,
432 cache: &dyn CacheStore,
433) -> Result<HashMap<String, Value>> {
434 graph.validate()?;
435 let sorted = graph.topological_sort()?;
436 let graph_info = GraphInfo::from_graph(graph);
437
438 let bus = Arc::new(EventBus::new(256));
439 let run_id = timestamp_id("graph_fit");
440
441 let mut outputs: HashMap<String, Value> = HashMap::new();
442
443 let roots = graph.roots();
444 for root_id in &roots {
445 outputs.insert(format!("__input_{root_id}"), x.clone());
446 }
447
448 for node_id in &sorted {
449 let filter = library
450 .get(node_id)
451 .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
452
453 bus.emit(Event::NodeStarted {
454 run_id: run_id.clone(),
455 node_id: node_id.to_string(),
456 kind: filter.meta().kind,
457 });
458
459 let preds = graph_info.predecessors(node_id);
460 let input = match preds.len() {
461 0 => x.clone(),
462 1 => outputs.get(&preds[0]).cloned().unwrap_or_else(|| x.clone()),
463 _ => {
464 let mut merged = serde_json::Map::new();
465 for pred_id in preds {
466 if let Some(val) = outputs.get(pred_id.as_str()) {
467 let json_val = serde_json::to_value(val).unwrap_or(serde_json::Value::Null);
468 merged.insert(pred_id.clone(), json_val);
469 }
470 }
471 Value::Json(serde_json::Value::Object(merged))
472 }
473 };
474
475 let meta = filter.meta();
476 let start = std::time::Instant::now();
477
478 let (state, output) = if meta.kind == FilterKind::Trainable {
479 let data_hash = CacheKey::hash_data(&serde_json::to_vec(&input).unwrap_or_default());
480 let state_key = CacheKey::for_state(&filter.config_hash(), &data_hash);
481
482 let state = if let Some(cached) = cache.get(&state_key)? {
483 cached
484 } else {
485 let s = filter.fit(&input, y)?;
486 cache.put(&state_key, &s)?;
487 s
488 };
489
490 let output = filter.forward(&input, &state)?;
491 (state, output)
492 } else {
493 let output = filter.forward(&input, &Value::Empty)?;
494 (Value::Empty, output)
495 };
496
497 let _ = state;
498
499 bus.emit(Event::NodeCompleted {
500 run_id: run_id.clone(),
501 node_id: node_id.to_string(),
502 duration: start.elapsed(),
503 output_summary: format!("{output}"),
504 });
505
506 outputs.insert(node_id.to_string(), output);
507 }
508
509 Ok(outputs)
510}
511
512pub fn graph_predict(
514 graph: &Graph,
515 library: &FilterLibrary,
516 x: &Value,
517 cache: &dyn CacheStore,
518) -> Result<Value> {
519 let CompileResult { plan, .. } = compile(graph, library, CompileMode::Inference, Some(cache))?;
520
521 let bus = Arc::new(EventBus::new(256));
522 let graph_info = GraphInfo::from_graph(graph);
523 let mut ctx = Context::new(bus, timestamp_id("graph_predict")).with_graph_info(graph_info);
524
525 let roots = graph.roots();
526 if roots.len() == 1 {
527 ctx.set(format!("__input_{}", roots[0]), x.clone());
528 }
529 ctx.set("__input__", x.clone());
530
531 executor::execute(&plan, &mut ctx, library, cache)?;
532
533 let leaves = graph.leaves();
534 let mut extract =
535 |id: &str| -> Option<Value> { ctx.store.remove(id).and_then(|vv| vv.as_value().cloned()) };
536
537 if let Some(leaf_id) = leaves.first() {
538 extract(leaf_id)
539 .ok_or_else(|| SomaError::Other(format!("leaf node '{leaf_id}' produced no output")))
540 } else {
541 ctx.execution_order
542 .last()
543 .and_then(|id| extract(id))
544 .ok_or_else(|| SomaError::Other("no output produced".into()))
545 }
546}
547
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::MemoryCache;
    use somatize_compiler::FilterRegistry;
    use somatize_core::cache::CacheKey;
    use somatize_core::error::Result;
    use somatize_core::filter::{FilterKind, FilterMeta, StreamMode};
    use somatize_core::graph::{Edge, Node};

    /// Metadata shared by every test filter: cacheable, fixed-state, local, no schemas.
    fn test_meta(name: &'static str, kind: FilterKind, differentiable: bool) -> FilterMeta {
        FilterMeta {
            name: name.into(),
            kind,
            cacheable: true,
            differentiable,
            stream_mode: StreamMode::FixedState,
            distribution: somatize_core::filter::Distribution::Local,
            input_schema: None,
            output_schema: None,
        }
    }

    /// Apply `f` to every element of a tensor input, keeping its shape.
    /// Non-tensor inputs fail with `SomaError::Other("need tensor")`.
    fn map_tensor(x: &Value, f: impl Fn(f64) -> f64) -> Result<Value> {
        let (data, shape) = x
            .as_tensor()
            .ok_or_else(|| SomaError::Other("need tensor".into()))?;
        let mapped = data.iter().map(|&v| f(v)).collect();
        Ok(Value::tensor(mapped, shape.to_vec()))
    }

    /// Stateless filter that multiplies every element by two.
    struct DoublerFilter;

    impl somatize_core::filter::Filter for DoublerFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Doubler"])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            map_tensor(x, |v| v * 2.0)
        }
        fn meta(&self) -> FilterMeta {
            test_meta("Doubler", FilterKind::Stateless, true)
        }
    }

    /// Stateless filter that adds a constant offset to every element.
    struct AdderFilter(f64);

    impl somatize_core::filter::Filter for AdderFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Adder", &self.0.to_le_bytes()])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            map_tensor(x, |v| v + self.0)
        }
        fn meta(&self) -> FilterMeta {
            test_meta("Adder", FilterKind::Stateless, true)
        }
    }

    /// Trainable filter: `fit` learns the mean, `forward` subtracts it (0.0 if absent).
    struct MeanFilter;

    impl somatize_core::filter::Filter for MeanFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Mean"])
        }
        fn fit(&self, x: &Value, _y: Option<&Value>) -> Result<Value> {
            let (data, _) = x
                .as_tensor()
                .ok_or_else(|| SomaError::Other("need tensor".into()))?;
            let mean = data.iter().sum::<f64>() / data.len() as f64;
            Ok(Value::json(serde_json::json!({ "mean": mean })))
        }
        fn forward(&self, x: &Value, state: &Value) -> Result<Value> {
            let mean = state
                .as_json()
                .and_then(|j| j["mean"].as_f64())
                .unwrap_or(0.0);
            map_tensor(x, |v| v - mean)
        }
        fn meta(&self) -> FilterMeta {
            test_meta("Mean", FilterKind::Trainable, true)
        }
    }

    /// Stateless pass-through filter, used to observe what a multi-input node receives.
    struct MergeFilter;

    impl somatize_core::filter::Filter for MergeFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Merge"])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            Ok(x.clone())
        }
        fn meta(&self) -> FilterMeta {
            test_meta("Merge", FilterKind::Stateless, false)
        }
    }

    /// Build a chain `ids[0] -> ids[1] -> ...`; every `Node::new` argument is the id
    /// itself, and edges are named `e0`, `e1`, ...
    fn linear_graph(ids: &[&str]) -> Graph {
        let mut graph = Graph::new();
        for &id in ids {
            graph.nodes.push(Node::new(id, id, id));
        }
        for (i, (&from, &to)) in ids.iter().zip(ids.iter().skip(1)).enumerate() {
            graph.edges.push(Edge::data(format!("e{i}"), from, to));
        }
        graph
    }

    #[test]
    fn session_run_linear() {
        let mut library = FilterLibrary::new();
        library.register("double", Box::new(DoublerFilter));
        library.register("add", Box::new(AdderFilter(10.0)));
        let session = GraphSession::new(linear_graph(&["double", "add"]), library)
            .with_cache(Arc::new(MemoryCache::default()));

        // `run` sets no input, so drive the compiled plan with an explicit `__input__`.
        let plan = session.compile(CompileMode::NoCache).unwrap().plan;
        let mut ctx = Context::new(Arc::new(EventBus::new(64)), "test")
            .with_graph_info(GraphInfo::from_graph(session.graph()));
        ctx.set("__input__", Value::tensor(vec![1.0, 2.0, 3.0], vec![3]));
        executor::execute(&plan, &mut ctx, session.library(), &MemoryCache::default()).unwrap();

        let outputs: HashMap<String, Value> = ctx
            .store
            .into_iter()
            .filter_map(|(key, stored)| stored.as_value().cloned().map(|v| (key, v)))
            .collect();
        let (data, _) = outputs["add"].as_tensor().unwrap();
        assert_eq!(data, &[12.0, 14.0, 16.0]);
    }

    #[test]
    fn session_fit_and_forward() {
        let mut library = FilterLibrary::new();
        library.register("mean", Box::new(MeanFilter));
        library.register("double", Box::new(DoublerFilter));
        let mut session = GraphSession::new(linear_graph(&["mean", "double"]), library);

        let x = Value::tensor(vec![10.0, 20.0, 30.0], vec![3]);
        let outputs = session.fit(&x, None).unwrap();

        // mean = 20 -> centred [-10, 0, 10] -> doubled [-20, 0, 20]
        let (data, _) = outputs["double"].as_tensor().unwrap();
        assert_eq!(data, &[-20.0, 0.0, 20.0]);
        assert!(session.is_fitted());
    }

    #[test]
    fn session_compile_diagnostics() {
        let mut library = FilterLibrary::new();
        library.register("double", Box::new(DoublerFilter));
        let session = GraphSession::new(linear_graph(&["double"]), library);

        let compiled = session.compile(CompileMode::NoCache).unwrap();
        assert!(compiled.plan.node_count() > 0);
    }

    #[test]
    fn graph_run_linear() {
        let graph = linear_graph(&["double", "add"]);
        let mut library = FilterLibrary::new();
        library.register("double", Box::new(DoublerFilter));
        library.register("add", Box::new(AdderFilter(10.0)));
        let cache = MemoryCache::default();

        let plan = compile(&graph, &library, CompileMode::NoCache, None).unwrap().plan;
        let mut ctx = Context::new(Arc::new(EventBus::new(64)), "test")
            .with_graph_info(GraphInfo::from_graph(&graph));
        ctx.set("__input__", Value::tensor(vec![1.0, 2.0, 3.0], vec![3]));
        executor::execute(&plan, &mut ctx, &library, &cache).unwrap();

        let outputs: HashMap<String, Value> = ctx
            .store
            .into_iter()
            .filter_map(|(key, stored)| stored.as_value().cloned().map(|v| (key, v)))
            .collect();
        let (data, _) = outputs["add"].as_tensor().unwrap();
        assert_eq!(data, &[12.0, 14.0, 16.0]);
    }

    #[test]
    fn graph_run_diamond() {
        // Two independent roots feeding one merge node.
        let mut graph = Graph::new();
        graph.nodes.push(Node::new("double", "Double", "double"));
        graph.nodes.push(Node::new("add", "Add", "add"));
        graph.nodes.push(Node::new("merge", "Merge", "merge"));
        graph.edges.push(Edge::data("e1", "double", "merge"));
        graph.edges.push(Edge::data("e2", "add", "merge"));

        let mut library = FilterLibrary::new();
        library.register("double", Box::new(DoublerFilter));
        library.register("add", Box::new(AdderFilter(100.0)));
        library.register("merge", Box::new(MergeFilter));

        let cache = MemoryCache::default();
        let plan = compile(&graph, &library, CompileMode::NoCache, None).unwrap().plan;

        let mut ctx = Context::new(Arc::new(EventBus::new(64)), "test")
            .with_graph_info(GraphInfo::from_graph(&graph));
        ctx.set("__input__", Value::tensor(vec![5.0], vec![1]));
        executor::execute(&plan, &mut ctx, &library, &cache).unwrap();

        let merged = ctx.get("merge").unwrap();
        assert!(
            merged.as_json().is_some(),
            "merge should receive JSON from multiple predecessors"
        );
    }

    #[test]
    fn graph_fit_trainable() {
        let graph = linear_graph(&["mean", "double"]);
        let mut library = FilterLibrary::new();
        library.register("mean", Box::new(MeanFilter));
        library.register("double", Box::new(DoublerFilter));

        let cache = MemoryCache::default();
        let x = Value::tensor(vec![10.0, 20.0, 30.0], vec![3]);
        let outputs = graph_fit(&graph, &library, &x, None, &cache).unwrap();

        let (data, _) = outputs["double"].as_tensor().unwrap();
        assert_eq!(data, &[-20.0, 0.0, 20.0]);
        // The fitted Mean state must have been written to the cache.
        assert!(!cache.is_empty());
    }

    #[test]
    fn filter_library_registry_compat() {
        let mut library = FilterLibrary::new();
        library.register("a", Box::new(DoublerFilter));

        let registry: &dyn FilterRegistry = &library;
        let meta = registry.meta("a").expect("filter 'a' is registered");
        assert_eq!(meta.name, "Doubler");
        assert!(registry.config_hash("a").is_some());
        assert!(registry.meta("b").is_none());
    }
}