1use crate::event_bus::EventBus;
7use crate::filter_library::FilterLibrary;
8use somatize_compiler::ExecutionPlan;
9use somatize_core::cache::CacheStore;
10use somatize_core::error::{Result, SomaError};
11use somatize_core::event::Event;
12use somatize_core::store::DataStore;
13use somatize_core::value::Value;
14use somatize_core::virtual_value::VirtualValue;
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::time::Instant;
18
/// Predecessor lists for the nodes of a graph, keyed by node id.
///
/// [`resolve_input`] consults this to decide which upstream outputs feed a node.
#[derive(Clone, Debug, Default)]
pub struct GraphInfo {
    /// Node id → ids of the nodes that feed it, in the order they were supplied.
    predecessors: HashMap<String, Vec<String>>,
}
28
29impl GraphInfo {
30 pub fn new() -> Self {
31 Self::default()
32 }
33
34 pub fn set_predecessors(&mut self, node_id: impl Into<String>, preds: Vec<String>) {
36 self.predecessors.insert(node_id.into(), preds);
37 }
38
39 pub fn from_graph(graph: &somatize_core::graph::Graph) -> Self {
41 let mut info = Self::new();
42 for node in &graph.nodes {
43 let preds: Vec<String> = graph
44 .predecessors(&node.id)
45 .into_iter()
46 .map(|s| s.to_string())
47 .collect();
48 info.set_predecessors(node.id.clone(), preds);
49 }
50 info
51 }
52
53 pub fn for_linear(node_ids: &[&str]) -> Self {
55 let mut info = Self::new();
56 for (i, &id) in node_ids.iter().enumerate() {
57 let preds = if i > 0 {
58 vec![node_ids[i - 1].to_string()]
59 } else {
60 vec![]
61 };
62 info.set_predecessors(id, preds);
63 }
64 info
65 }
66
67 pub fn predecessors(&self, node_id: &str) -> &[String] {
69 self.predecessors
70 .get(node_id)
71 .map(|v| v.as_slice())
72 .unwrap_or(&[])
73 }
74}
75
76pub struct Context {
82 pub store: HashMap<String, VirtualValue>,
84 pub event_bus: Arc<EventBus>,
86 pub run_id: String,
88 pub execution_order: Vec<String>,
90 pub graph_info: GraphInfo,
92 pub transport: Option<Arc<dyn crate::runner::Transport>>,
94 pub data_store: Option<Arc<dyn DataStore>>,
96 pub spill_threshold: usize,
99}
100
101impl Context {
102 pub fn new(event_bus: Arc<EventBus>, run_id: impl Into<String>) -> Self {
103 Self {
104 store: HashMap::new(),
105 event_bus,
106 run_id: run_id.into(),
107 execution_order: Vec::new(),
108 graph_info: GraphInfo::new(),
109 transport: None,
110 data_store: None,
111 spill_threshold: 0,
112 }
113 }
114
115 pub fn with_graph_info(mut self, info: GraphInfo) -> Self {
116 self.graph_info = info;
117 self
118 }
119
120 pub fn with_transport(mut self, transport: Arc<dyn crate::runner::Transport>) -> Self {
121 self.transport = Some(transport);
122 self
123 }
124
125 pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
126 self.data_store = Some(store);
127 self
128 }
129
130 pub fn with_spill_threshold(mut self, bytes: usize) -> Self {
134 self.spill_threshold = bytes;
135 self
136 }
137
138 fn maybe_spill(&self, node_id: &str, value: Value) -> VirtualValue {
141 if self.spill_threshold > 0
142 && let Some(store) = &self.data_store
143 {
144 let size = value.size() * 8; if size >= self.spill_threshold {
146 let key = somatize_core::cache::CacheKey::from_parts(&[
147 self.run_id.as_bytes(),
148 node_id.as_bytes(),
149 ]);
150 let vv_for_schema = VirtualValue::materialized(value.clone());
151 let schema = vv_for_schema.schema().clone();
152 if let Ok(_data_ref) = store.put(&key, &value) {
153 tracing::debug!("spilled node `{node_id}` ({size} bytes) to DataStore");
154 return VirtualValue::cached(key, schema);
155 }
156 }
157 }
158 VirtualValue::materialized(value)
159 }
160
161 pub fn get(&self, node_id: &str) -> Option<&Value> {
163 self.store.get(node_id).and_then(|vv| vv.as_value())
164 }
165
166 pub fn get_virtual(&self, node_id: &str) -> Option<&VirtualValue> {
168 self.store.get(node_id)
169 }
170
171 pub fn set(&mut self, node_id: impl Into<String>, value: Value) {
173 let id = node_id.into();
174 self.execution_order.push(id.clone());
175 self.store.insert(id, VirtualValue::materialized(value));
176 }
177
178 pub fn set_virtual(&mut self, node_id: impl Into<String>, vv: VirtualValue) {
180 let id = node_id.into();
181 self.execution_order.push(id.clone());
182 self.store.insert(id, vv);
183 }
184
185 fn snapshot(&self) -> Self {
186 Self {
187 store: self.store.clone(),
188 event_bus: self.event_bus.clone(),
189 run_id: self.run_id.clone(),
190 execution_order: self.execution_order.clone(),
191 graph_info: self.graph_info.clone(),
192 transport: self.transport.clone(),
193 data_store: self.data_store.clone(),
194 spill_threshold: self.spill_threshold,
195 }
196 }
197}
198
199pub trait Executable {
203 fn execute(
204 &self,
205 ctx: &mut Context,
206 filters: &FilterLibrary,
207 cache: &dyn CacheStore,
208 ) -> Result<()>;
209}
210
211impl Executable for ExecutionPlan {
212 fn execute(
213 &self,
214 ctx: &mut Context,
215 filters: &FilterLibrary,
216 cache: &dyn CacheStore,
217 ) -> Result<()> {
218 match self {
219 ExecutionPlan::Empty => Ok(()),
220
221 ExecutionPlan::Execute { node_id } => execute_node(node_id, ctx, filters, cache),
222
223 ExecutionPlan::Cached { node_id, key } => {
224 let start = Instant::now();
225 let value = cache.get(key)?.ok_or_else(|| {
226 SomaError::Cache(format!(
227 "expected cached value for node `{node_id}` not found"
228 ))
229 })?;
230 ctx.set(node_id.clone(), value);
231 ctx.event_bus.emit(Event::NodeCacheHit {
232 run_id: ctx.run_id.clone(),
233 node_id: node_id.clone(),
234 key: key.clone(),
235 tier: somatize_core::cache::CacheTier::Memory,
236 load_time: start.elapsed(),
237 });
238 Ok(())
239 }
240
241 ExecutionPlan::Sequence(steps) => {
242 for step in steps {
243 step.execute(ctx, filters, cache)?;
244 }
245 Ok(())
246 }
247
248 ExecutionPlan::Parallel(branches) => execute_parallel(branches, ctx, filters, cache),
249
250 ExecutionPlan::Loop {
251 node_id,
252 body,
253 max_iterations,
254 } => {
255 let max = max_iterations.unwrap_or(100);
256 for i in 0..max {
257 body.execute(ctx, filters, cache)?;
258
259 let should_stop = ctx
262 .execution_order
263 .last()
264 .and_then(|last_id| ctx.get(last_id))
265 .map(|v| match v {
266 Value::Json(j) => {
267 j.as_bool() == Some(true)
268 || j.as_str().map(|s| s == "done" || s == "stop") == Some(true)
269 || j.get("done").and_then(|d| d.as_bool()) == Some(true)
270 }
271 Value::Empty => true,
272 _ => false,
273 })
274 .unwrap_or(false);
275
276 if should_stop {
277 ctx.event_bus.emit(Event::NodeCompleted {
278 run_id: ctx.run_id.clone(),
279 node_id: node_id.clone(),
280 duration: std::time::Duration::ZERO,
281 output_summary: format!("Loop terminated at iteration {}", i + 1),
282 });
283 break;
284 }
285 }
286 Ok(())
287 }
288
289 ExecutionPlan::Branch { node_id, arms } => {
290 execute_node(node_id, ctx, filters, cache)?;
292
293 let condition = ctx.get(node_id).cloned().unwrap_or(Value::Empty);
295
296 let selected_arm = match &condition {
298 Value::Json(j) => {
299 let selector = j
301 .as_str()
302 .map(String::from)
303 .or_else(|| j.as_bool().map(|b| b.to_string()))
304 .or_else(|| j.get("branch").and_then(|b| b.as_str()).map(String::from))
305 .unwrap_or_else(|| "true".to_string());
306
307 arms.iter()
308 .find(|(label, _)| label == &selector)
309 .or_else(|| {
310 arms.iter()
311 .find(|(label, _)| label == "default" || label == "else")
312 })
313 .or_else(|| arms.first())
314 }
315 _ => arms.first(),
316 };
317
318 if let Some((label, plan)) = selected_arm {
319 ctx.event_bus.emit(Event::NodeCompleted {
320 run_id: ctx.run_id.clone(),
321 node_id: node_id.clone(),
322 duration: std::time::Duration::ZERO,
323 output_summary: format!("Branch selected: {label}"),
324 });
325 plan.execute(ctx, filters, cache)?;
326 }
327 Ok(())
328 }
329
330 ExecutionPlan::Remote {
331 node_id,
332 target: _,
333 plan,
334 } => {
335 if let Some(transport) = &ctx.transport {
336 let input = ctx
338 .graph_info
339 .predecessors(node_id)
340 .first()
341 .and_then(|pred| ctx.get(pred));
342
343 let result = transport.execute_node(node_id, input)?;
344 ctx.set(node_id.clone(), result);
345 Ok(())
346 } else {
347 plan.execute(ctx, filters, cache)
349 }
350 }
351
352 ExecutionPlan::Composite { node_ids } => {
353 for nid in node_ids {
356 execute_node(nid, ctx, filters, cache)?;
357 }
358 Ok(())
359 }
360
361 ExecutionPlan::Stream {
362 node_ids,
363 chunk_size,
364 } => execute_stream(node_ids, *chunk_size, ctx, filters, cache),
365
366 _ => {
367 tracing::warn!("Unhandled ExecutionPlan variant");
368 Ok(())
369 }
370 }
371 }
372}
373
374pub fn execute(
376 plan: &ExecutionPlan,
377 ctx: &mut Context,
378 filters: &FilterLibrary,
379 cache: &dyn CacheStore,
380) -> Result<()> {
381 plan.execute(ctx, filters, cache)
382}
383
384fn execute_node(
386 node_id: &str,
387 ctx: &mut Context,
388 filters: &FilterLibrary,
389 _cache: &dyn CacheStore,
390) -> Result<()> {
391 let start = Instant::now();
392
393 let filter = filters
394 .get(node_id)
395 .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
396
397 ctx.event_bus.emit(Event::NodeStarted {
398 run_id: ctx.run_id.clone(),
399 node_id: node_id.to_string(),
400 kind: filter.meta().kind,
401 });
402
403 let _span = tracing::info_span!("execute_node", %node_id).entered();
404
405 let input = resolve_input(node_id, ctx);
406 let state = filters.get_state(node_id);
410 let state_ref: &Value = state.as_deref().unwrap_or(&Value::Empty);
411
412 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
414 filter.forward(&input, state_ref)
415 }));
416
417 let result = match result {
418 Ok(inner) => inner,
419 Err(panic) => {
420 let msg = panic
421 .downcast_ref::<String>()
422 .map(|s| s.as_str())
423 .or_else(|| panic.downcast_ref::<&str>().copied())
424 .unwrap_or("unknown panic");
425 tracing::error!(node_id, "filter panicked: {msg}");
426 Err(SomaError::Execution {
427 node_id: node_id.to_string(),
428 message: format!("filter panicked: {msg}"),
429 })
430 }
431 };
432
433 match result {
434 Ok(output) => {
435 let duration = start.elapsed();
436 let summary = format!("{output}");
437 let vv = ctx.maybe_spill(node_id, output);
438 ctx.set_virtual(node_id, vv);
439 ctx.event_bus.emit(Event::NodeCompleted {
440 run_id: ctx.run_id.clone(),
441 node_id: node_id.to_string(),
442 duration,
443 output_summary: summary,
444 });
445 Ok(())
446 }
447 Err(e) => {
448 tracing::error!(node_id, error = %e, "node execution failed");
449 ctx.event_bus.emit(Event::NodeFailed {
450 run_id: ctx.run_id.clone(),
451 node_id: node_id.to_string(),
452 error: e.to_string(),
453 });
454 Err(e)
455 }
456 }
457}
458
459fn execute_parallel(
464 branches: &[ExecutionPlan],
465 ctx: &mut Context,
466 filters: &FilterLibrary,
467 cache: &dyn CacheStore,
468) -> Result<()> {
469 let snapshot_keys: Arc<std::collections::HashSet<String>> =
470 Arc::new(ctx.store.keys().cloned().collect());
471
472 let results: Vec<Result<Vec<(String, VirtualValue)>>> = std::thread::scope(|s| {
474 let handles: Vec<_> = branches
475 .iter()
476 .map(|branch| {
477 let mut branch_ctx = ctx.snapshot();
478 let keys = snapshot_keys.clone();
479 s.spawn(move || {
480 execute(branch, &mut branch_ctx, filters, cache)?;
481 let new_entries: Vec<(String, VirtualValue)> = branch_ctx
482 .store
483 .into_iter()
484 .filter(|(k, _)| !keys.contains(k))
485 .collect();
486 Ok(new_entries)
487 })
488 })
489 .collect();
490
491 handles.into_iter().map(|h| h.join().unwrap()).collect()
492 });
493
494 for result in results {
496 let entries = result?;
497 for (key, vv) in entries {
498 ctx.set_virtual(key, vv);
499 }
500 }
501
502 Ok(())
503}
504
505fn resolve_value(vv: &VirtualValue, data_store: &Option<Arc<dyn DataStore>>) -> Option<Value> {
507 match vv {
508 VirtualValue::Materialized { value, .. } => Some(value.clone()),
509 VirtualValue::Cached { key, .. } => {
510 if let Some(store) = data_store {
512 let data_ref = somatize_core::store::DataRef::Cached {
513 cache_key: key.clone(),
514 };
515 store.get(&data_ref).ok()
516 } else {
517 None
518 }
519 }
520 _ => None,
521 }
522}
523
524pub fn resolve_input(node_id: &str, ctx: &Context) -> Value {
527 let preds = ctx.graph_info.predecessors(node_id);
528
529 let resolve_node = |id: &str| -> Option<Value> {
530 ctx.store
531 .get(id)
532 .and_then(|vv| resolve_value(vv, &ctx.data_store))
533 };
534
535 match preds.len() {
536 0 => ctx
537 .execution_order
538 .last()
539 .and_then(|id| resolve_node(id))
540 .unwrap_or(Value::Empty),
541 1 => resolve_node(&preds[0]).unwrap_or(Value::Empty),
542 _ => {
543 let mut merged = serde_json::Map::new();
544 for pred_id in preds {
545 if let Some(val) = resolve_node(pred_id) {
546 let json_val = serde_json::to_value(&val).unwrap_or(serde_json::Value::Null);
547 merged.insert(pred_id.clone(), json_val);
548 }
549 }
550 Value::json(serde_json::Value::Object(merged))
551 }
552 }
553}
554
555fn execute_stream(
557 node_ids: &[String],
558 chunk_size: usize,
559 ctx: &mut Context,
560 filters: &FilterLibrary,
561 cache: &dyn CacheStore,
562) -> Result<()> {
563 use crate::executors::{FittedFilter, StreamExecutor};
564
565 let start = Instant::now();
566
567 let fitted: Vec<FittedFilter> = node_ids
569 .iter()
570 .map(|nid| {
571 let filter = filters
572 .get(nid)
573 .ok_or_else(|| SomaError::NodeNotFound(nid.clone()))?;
574 let state = filters
575 .get_state(nid)
576 .unwrap_or_else(|| Arc::new(Value::Empty));
577 Ok(FittedFilter {
578 name: nid.clone(),
579 filter,
580 state,
581 })
582 })
583 .collect::<Result<_>>()?;
584
585 let first_id = node_ids
587 .first()
588 .ok_or_else(|| SomaError::Other("stream plan has no nodes".into()))?;
589 let input = resolve_input(first_id, ctx);
590
591 let chunks = chunk_value(&input, chunk_size);
593
594 let mut executor = StreamExecutor::new(fitted);
596 if let Some(c) = cache_as_arc(cache) {
597 executor = executor.with_cache(c);
598 }
599
600 let last_id = node_ids.last().unwrap();
601
602 let mut all_data: Vec<f64> = Vec::new();
607 let mut result_shape: Option<Vec<usize>> = None;
608 let mut non_tensor_output: Option<Value> = None;
609
610 let mut append_output = |output: Value| {
611 match &output {
612 Value::Tensor { values, shape } => {
613 if result_shape.is_none() {
614 result_shape = Some(shape.clone());
615 }
616 all_data.extend_from_slice(values.as_slice());
617 }
619 _ => {
620 non_tensor_output = Some(output);
621 }
622 }
623 };
624
625 for (i, chunk) in chunks.into_iter().enumerate() {
626 ctx.event_bus.emit(Event::NodeStarted {
627 run_id: ctx.run_id.clone(),
628 node_id: format!("{last_id}#chunk_{i}"),
629 kind: somatize_core::filter::FilterKind::Stateless,
630 });
631 if let Some(output) = executor.process_chunk(chunk)? {
632 append_output(output);
633 }
634 }
635
636 if let Some(flushed) = executor.flush()? {
638 append_output(flushed);
639 }
640
641 let result = if let Some(mut shape) = result_shape {
643 let row_size: usize = shape.iter().skip(1).product::<usize>().max(1);
645 shape[0] = all_data.len() / row_size;
646 Value::tensor(all_data, shape)
647 } else {
648 non_tensor_output.unwrap_or(Value::Empty)
649 };
650
651 let duration = start.elapsed();
652 ctx.set(last_id.clone(), result);
653 ctx.event_bus.emit(Event::NodeCompleted {
654 run_id: ctx.run_id.clone(),
655 node_id: last_id.clone(),
656 duration,
657 output_summary: format!(
658 "stream: {} chunks through {} filters",
659 executor.chunks_processed(),
660 node_ids.len()
661 ),
662 });
663
664 Ok(())
665}
666
667fn chunk_value(x: &Value, chunk_size: usize) -> Vec<Value> {
669 match x {
670 Value::Tensor { values, shape } if !values.is_empty() && chunk_size > 0 => {
671 let row_size = if shape.len() > 1 {
672 shape[1..].iter().product()
673 } else {
674 1
675 };
676 let n_rows = shape[0];
677 let mut chunks = Vec::new();
678 for start in (0..n_rows).step_by(chunk_size) {
679 let end = (start + chunk_size).min(n_rows);
680 let flat_start = start * row_size;
681 let flat_end = end * row_size;
682 let chunk_vals = values[flat_start..flat_end].to_vec();
683 let mut chunk_shape = shape.clone();
684 chunk_shape[0] = end - start;
685 chunks.push(Value::tensor(chunk_vals, chunk_shape));
686 }
687 chunks
688 }
689 _ => vec![x.clone()],
690 }
691}
692
693fn cache_as_arc(_cache: &dyn CacheStore) -> Option<Arc<dyn CacheStore>> {
696 None
701}
702
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::MemoryCache;
    use somatize_core::cache::CacheKey;
    use somatize_core::filter::{Filter, FilterKind, FilterMeta, StreamMode};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Stateless filter that multiplies every tensor element by two.
    struct DoublerFilter;

    impl Filter for DoublerFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Doubler"])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            match x {
                Value::Tensor { values, shape } => {
                    let doubled: Vec<f64> = values.iter().map(|v| v * 2.0).collect();
                    Ok(Value::tensor(doubled, shape.clone()))
                }
                _ => Ok(x.clone()),
            }
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "Doubler".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Stateless filter that adds `amount` to every tensor element.
    struct AdderFilter {
        amount: f64,
    }

    impl Filter for AdderFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Adder", &self.amount.to_le_bytes()])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            match x {
                Value::Tensor { values, shape } => {
                    let added: Vec<f64> = values.iter().map(|v| v + self.amount).collect();
                    Ok(Value::tensor(added, shape.clone()))
                }
                _ => Ok(x.clone()),
            }
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "Adder".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Counters shared by `SlowFilter`s that record how many `forward` calls overlapped.
    #[derive(Clone, Default)]
    struct ConcurrencyProbe {
        in_flight: Arc<AtomicUsize>,
        peak: Arc<AtomicUsize>,
    }

    /// Pass-through filter that sleeps for `delay_ms` and reports its activity to `probe`.
    struct SlowFilter {
        id: String,
        delay_ms: u64,
        probe: ConcurrencyProbe,
    }

    impl Filter for SlowFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Slow", self.id.as_bytes()])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            let now = self.probe.in_flight.fetch_add(1, Ordering::SeqCst) + 1;
            self.probe.peak.fetch_max(now, Ordering::SeqCst);
            std::thread::sleep(std::time::Duration::from_millis(self.delay_ms));
            self.probe.in_flight.fetch_sub(1, Ordering::SeqCst);
            Ok(x.clone())
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: format!("Slow_{}", self.id),
                kind: FilterKind::Stateless,
                cacheable: false,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    fn setup() -> (Arc<EventBus>, MemoryCache) {
        (Arc::new(EventBus::new(64)), MemoryCache::default())
    }

    #[test]
    fn execute_single_node() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0, 2.0, 3.0], vec![3]));
        ctx.graph_info
            .set_predecessors("doubler", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("doubler", Box::new(DoublerFilter));

        let plan = ExecutionPlan::Execute {
            node_id: "doubler".into(),
        };

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let result = ctx.get("doubler").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn execute_sequence_with_graph_info() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0, 2.0], vec![2]));

        let graph_info = GraphInfo::for_linear(&["input", "add", "double"]);
        ctx.graph_info = graph_info;

        let mut filters = FilterLibrary::new();
        filters.register("add", Box::new(AdderFilter { amount: 10.0 }));
        filters.register("double", Box::new(DoublerFilter));

        let plan = ExecutionPlan::Sequence(vec![
            ExecutionPlan::Execute {
                node_id: "add".into(),
            },
            ExecutionPlan::Execute {
                node_id: "double".into(),
            },
        ]);

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let result = ctx.get("double").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[22.0, 24.0]);
    }

    #[test]
    fn execute_cached_node() {
        let (bus, cache) = setup();
        let key = CacheKey::hash_data(b"cached_result");
        let cached_value = Value::tensor(vec![99.0], vec![1]);
        cache.put(&key, &cached_value).unwrap();

        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();

        let plan = ExecutionPlan::Cached {
            node_id: "cached_node".into(),
            key,
        };

        execute(&plan, &mut ctx, &filters, &cache).unwrap();
        assert_eq!(*ctx.get("cached_node").unwrap(), cached_value);
    }

    #[test]
    fn execute_emits_events() {
        let bus = Arc::new(EventBus::new(64));
        let cache = MemoryCache::default();
        let mut rx = bus.subscribe();

        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0], vec![1]));
        ctx.graph_info
            .set_predecessors("double", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("double", Box::new(DoublerFilter));

        execute(
            &ExecutionPlan::Execute {
                node_id: "double".into(),
            },
            &mut ctx,
            &filters,
            &cache,
        )
        .unwrap();

        let e1 = rx.try_recv().unwrap();
        assert!(matches!(e1, Event::NodeStarted { .. }));
        let e2 = rx.try_recv().unwrap();
        assert!(matches!(e2, Event::NodeCompleted { .. }));
    }

    #[test]
    fn execute_missing_filter_errors() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();

        let result = execute(
            &ExecutionPlan::Execute {
                node_id: "nonexistent".into(),
            },
            &mut ctx,
            &filters,
            &cache,
        );
        assert!(matches!(result, Err(SomaError::NodeNotFound(_))));
    }

    #[test]
    fn execute_empty_plan() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();
        execute(&ExecutionPlan::Empty, &mut ctx, &filters, &cache).unwrap();
    }

    #[test]
    fn execute_parallel_branches_merge_outputs() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![5.0], vec![1]));
        ctx.graph_info
            .set_predecessors("double", vec!["input".into()]);
        ctx.graph_info.set_predecessors("add", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("double", Box::new(DoublerFilter));
        filters.register("add", Box::new(AdderFilter { amount: 100.0 }));

        let plan = ExecutionPlan::Parallel(vec![
            ExecutionPlan::Execute {
                node_id: "double".into(),
            },
            ExecutionPlan::Execute {
                node_id: "add".into(),
            },
        ]);

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let double_out = ctx.get("double").unwrap().as_tensor().unwrap().0;
        assert_eq!(double_out, &[10.0]);
        let add_out = ctx.get("add").unwrap().as_tensor().unwrap().0;
        assert_eq!(add_out, &[105.0]);
    }

    #[test]
    fn parallel_branches_run_concurrently() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0], vec![1]));
        ctx.graph_info
            .set_predecessors("slow_a", vec!["input".into()]);
        ctx.graph_info
            .set_predecessors("slow_b", vec!["input".into()]);

        let probe = ConcurrencyProbe::default();
        let mut filters = FilterLibrary::new();
        filters.register(
            "slow_a",
            Box::new(SlowFilter {
                id: "a".into(),
                delay_ms: 50,
                probe: probe.clone(),
            }),
        );
        filters.register(
            "slow_b",
            Box::new(SlowFilter {
                id: "b".into(),
                delay_ms: 50,
                probe: probe.clone(),
            }),
        );

        let plan = ExecutionPlan::Parallel(vec![
            ExecutionPlan::Execute {
                node_id: "slow_a".into(),
            },
            ExecutionPlan::Execute {
                node_id: "slow_b".into(),
            },
        ]);

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        // Assert on observed overlap instead of wall-clock time: a wall-clock threshold
        // is flaky on loaded machines, while sequential execution can never have both
        // filters in flight at once.
        assert_eq!(
            probe.peak.load(Ordering::SeqCst),
            2,
            "parallel branches never ran at the same time"
        );

        assert!(ctx.get("slow_a").is_some());
        assert!(ctx.get("slow_b").is_some());
    }

    #[test]
    fn resolve_input_single_predecessor() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("A", Value::tensor(vec![42.0], vec![1]));
        ctx.graph_info.set_predecessors("B", vec!["A".into()]);

        let input = resolve_input("B", &ctx);
        let (data, _) = input.as_tensor().unwrap();
        assert_eq!(data, &[42.0]);
    }

    #[test]
    fn resolve_input_multiple_predecessors() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("A", Value::tensor(vec![1.0], vec![1]));
        ctx.set("B", Value::tensor(vec![2.0], vec![1]));
        ctx.graph_info
            .set_predecessors("C", vec!["A".into(), "B".into()]);

        let input = resolve_input("C", &ctx);
        let json = input.as_json().unwrap();
        assert!(json.get("A").is_some());
        assert!(json.get("B").is_some());
    }

    #[test]
    fn resolve_input_no_predecessors_fallback() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("prev", Value::tensor(vec![7.0], vec![1]));

        let input = resolve_input("root", &ctx);
        let (data, _) = input.as_tensor().unwrap();
        assert_eq!(data, &[7.0]);
    }

    #[test]
    fn graph_info_from_linear() {
        let info = GraphInfo::for_linear(&["a", "b", "c"]);
        assert!(info.predecessors("a").is_empty());
        assert_eq!(info.predecessors("b"), &["a"]);
        assert_eq!(info.predecessors("c"), &["b"]);
    }

    #[test]
    fn execute_stream_chunks_input() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_stream");
        ctx.set(
            "__input__",
            Value::tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]),
        );
        ctx.graph_info
            .set_predecessors("double", vec!["__input__".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("double", Box::new(DoublerFilter));

        let plan = ExecutionPlan::Stream {
            node_ids: vec!["double".into()],
            chunk_size: 2,
        };

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let result = ctx.get("double").unwrap();
        let (data, shape) = result.as_tensor().unwrap();
        assert_eq!(data, &[2.0, 4.0, 6.0, 8.0, 10.0, 12.0]);
        assert_eq!(shape, &[6]);
    }

    #[test]
    fn execute_stream_chain() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_stream_chain");
        ctx.set(
            "__input__",
            Value::tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]),
        );
        ctx.graph_info
            .set_predecessors("double", vec!["__input__".into()]);
        ctx.graph_info
            .set_predecessors("add", vec!["double".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("double", Box::new(DoublerFilter));
        filters.register("add", Box::new(AdderFilter { amount: 10.0 }));

        let plan = ExecutionPlan::Stream {
            node_ids: vec!["double".into(), "add".into()],
            chunk_size: 2,
        };

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let result = ctx.get("add").unwrap();
        let (data, shape) = result.as_tensor().unwrap();
        assert_eq!(data, &[12.0, 14.0, 16.0, 18.0]);
        assert_eq!(shape, &[4]);
    }

    #[test]
    fn execute_stream_single_chunk() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_stream_single");
        ctx.set("__input__", Value::tensor(vec![5.0, 10.0], vec![2]));
        ctx.graph_info
            .set_predecessors("double", vec!["__input__".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("double", Box::new(DoublerFilter));

        let plan = ExecutionPlan::Stream {
            node_ids: vec!["double".into()],
            chunk_size: 1000,
        };

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let result = ctx.get("double").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[10.0, 20.0]);
    }
}