1use crate::event_bus::EventBus;
7use crate::filter_library::FilterLibrary;
8use somatize_compiler::ExecutionPlan;
9use somatize_core::cache::CacheStore;
10use somatize_core::error::{Result, SomaError};
11use somatize_core::event::Event;
12use somatize_core::store::DataStore;
13use somatize_core::value::Value;
14use somatize_core::virtual_value::VirtualValue;
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::time::Instant;
18
/// Predecessor table for the nodes of an execution graph.
///
/// Maps a node id to the ids of the nodes whose outputs feed it, kept in the
/// order in which they were recorded.
#[derive(Debug, Clone, Default)]
pub struct GraphInfo {
    /// Node id -> ordered predecessor ids.
    predecessors: HashMap<String, Vec<String>>,
}
28
29impl GraphInfo {
30 pub fn new() -> Self {
31 Self::default()
32 }
33
34 pub fn set_predecessors(&mut self, node_id: impl Into<String>, preds: Vec<String>) {
36 self.predecessors.insert(node_id.into(), preds);
37 }
38
39 pub fn from_graph(graph: &somatize_core::graph::Graph) -> Self {
41 let mut info = Self::new();
42 for node in &graph.nodes {
43 let preds: Vec<String> = graph
44 .predecessors(&node.id)
45 .into_iter()
46 .map(|s| s.to_string())
47 .collect();
48 info.set_predecessors(node.id.clone(), preds);
49 }
50 info
51 }
52
53 pub fn for_linear(node_ids: &[&str]) -> Self {
55 let mut info = Self::new();
56 for (i, &id) in node_ids.iter().enumerate() {
57 let preds = if i > 0 {
58 vec![node_ids[i - 1].to_string()]
59 } else {
60 vec![]
61 };
62 info.set_predecessors(id, preds);
63 }
64 info
65 }
66
67 pub fn predecessors(&self, node_id: &str) -> &[String] {
69 self.predecessors
70 .get(node_id)
71 .map(|v| v.as_slice())
72 .unwrap_or(&[])
73 }
74}
75
76pub trait RemoteExecutor: Send + Sync {
82 fn execute_remote(
84 &self,
85 node_id: &str,
86 target: &somatize_core::filter::RemoteTarget,
87 input: Option<&Value>,
88 ) -> Result<Value>;
89}
90
91pub struct Context {
97 pub store: HashMap<String, VirtualValue>,
99 pub event_bus: Arc<EventBus>,
101 pub run_id: String,
103 pub execution_order: Vec<String>,
105 pub graph_info: GraphInfo,
107 pub remote_executor: Option<Arc<dyn RemoteExecutor>>,
109 pub data_store: Option<Arc<dyn DataStore>>,
111 pub spill_threshold: usize,
114}
115
116impl Context {
117 pub fn new(event_bus: Arc<EventBus>, run_id: impl Into<String>) -> Self {
118 Self {
119 store: HashMap::new(),
120 event_bus,
121 run_id: run_id.into(),
122 execution_order: Vec::new(),
123 graph_info: GraphInfo::new(),
124 remote_executor: None,
125 data_store: None,
126 spill_threshold: 0,
127 }
128 }
129
130 pub fn with_graph_info(mut self, info: GraphInfo) -> Self {
131 self.graph_info = info;
132 self
133 }
134
135 pub fn with_remote_executor(mut self, executor: Arc<dyn RemoteExecutor>) -> Self {
136 self.remote_executor = Some(executor);
137 self
138 }
139
140 pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
141 self.data_store = Some(store);
142 self
143 }
144
145 pub fn with_spill_threshold(mut self, bytes: usize) -> Self {
149 self.spill_threshold = bytes;
150 self
151 }
152
153 fn maybe_spill(&self, node_id: &str, value: Value) -> VirtualValue {
156 if self.spill_threshold > 0
157 && let Some(store) = &self.data_store
158 {
159 let size = value.size() * 8; if size >= self.spill_threshold {
161 let key = somatize_core::cache::CacheKey::from_parts(&[
162 self.run_id.as_bytes(),
163 node_id.as_bytes(),
164 ]);
165 let vv_for_schema = VirtualValue::materialized(value.clone());
166 let schema = vv_for_schema.schema().clone();
167 if let Ok(_data_ref) = store.put(&key, &value) {
168 tracing::debug!("spilled node `{node_id}` ({size} bytes) to DataStore");
169 return VirtualValue::cached(key, schema);
170 }
171 }
172 }
173 VirtualValue::materialized(value)
174 }
175
176 pub fn get(&self, node_id: &str) -> Option<&Value> {
178 self.store.get(node_id).and_then(|vv| vv.as_value())
179 }
180
181 pub fn get_virtual(&self, node_id: &str) -> Option<&VirtualValue> {
183 self.store.get(node_id)
184 }
185
186 pub fn set(&mut self, node_id: impl Into<String>, value: Value) {
188 let id = node_id.into();
189 self.execution_order.push(id.clone());
190 self.store.insert(id, VirtualValue::materialized(value));
191 }
192
193 pub fn set_virtual(&mut self, node_id: impl Into<String>, vv: VirtualValue) {
195 let id = node_id.into();
196 self.execution_order.push(id.clone());
197 self.store.insert(id, vv);
198 }
199
200 fn snapshot(&self) -> Self {
201 Self {
202 store: self.store.clone(),
203 event_bus: self.event_bus.clone(),
204 run_id: self.run_id.clone(),
205 execution_order: self.execution_order.clone(),
206 graph_info: self.graph_info.clone(),
207 remote_executor: self.remote_executor.clone(),
208 data_store: self.data_store.clone(),
209 spill_threshold: self.spill_threshold,
210 }
211 }
212}
213
214pub fn execute(
217 plan: &ExecutionPlan,
218 ctx: &mut Context,
219 filters: &FilterLibrary,
220 cache: &dyn CacheStore,
221) -> Result<()> {
222 match plan {
223 ExecutionPlan::Empty => Ok(()),
224
225 ExecutionPlan::Execute { node_id } => execute_node(node_id, ctx, filters, cache),
226
227 ExecutionPlan::Cached { node_id, key } => {
228 let start = Instant::now();
229 let value = cache.get(key)?.ok_or_else(|| {
230 SomaError::Cache(format!(
231 "expected cached value for node `{node_id}` not found"
232 ))
233 })?;
234 ctx.set(node_id.clone(), value);
235 ctx.event_bus.emit(Event::NodeCacheHit {
236 run_id: ctx.run_id.clone(),
237 node_id: node_id.clone(),
238 key: key.clone(),
239 tier: somatize_core::cache::CacheTier::Memory,
240 load_time: start.elapsed(),
241 });
242 Ok(())
243 }
244
245 ExecutionPlan::Sequence(steps) => {
246 for step in steps {
247 execute(step, ctx, filters, cache)?;
248 }
249 Ok(())
250 }
251
252 ExecutionPlan::Parallel(branches) => execute_parallel(branches, ctx, filters, cache),
253
254 ExecutionPlan::Loop {
255 node_id,
256 body,
257 max_iterations,
258 } => {
259 let max = max_iterations.unwrap_or(100);
260 for i in 0..max {
261 execute(body, ctx, filters, cache)?;
262
263 let should_stop = ctx
266 .execution_order
267 .last()
268 .and_then(|last_id| ctx.get(last_id))
269 .map(|v| match v {
270 Value::Json(j) => {
271 j.as_bool() == Some(true)
272 || j.as_str().map(|s| s == "done" || s == "stop") == Some(true)
273 || j.get("done").and_then(|d| d.as_bool()) == Some(true)
274 }
275 Value::Empty => true,
276 _ => false,
277 })
278 .unwrap_or(false);
279
280 if should_stop {
281 ctx.event_bus.emit(Event::NodeCompleted {
282 run_id: ctx.run_id.clone(),
283 node_id: node_id.clone(),
284 duration: std::time::Duration::ZERO,
285 output_summary: format!("Loop terminated at iteration {}", i + 1),
286 });
287 break;
288 }
289 }
290 Ok(())
291 }
292
293 ExecutionPlan::Branch { node_id, arms } => {
294 execute_node(node_id, ctx, filters, cache)?;
296
297 let condition = ctx.get(node_id).cloned().unwrap_or(Value::Empty);
299
300 let selected_arm = match &condition {
302 Value::Json(j) => {
303 let selector = j
305 .as_str()
306 .map(String::from)
307 .or_else(|| j.as_bool().map(|b| b.to_string()))
308 .or_else(|| j.get("branch").and_then(|b| b.as_str()).map(String::from))
309 .unwrap_or_else(|| "true".to_string());
310
311 arms.iter()
312 .find(|(label, _)| label == &selector)
313 .or_else(|| {
314 arms.iter()
315 .find(|(label, _)| label == "default" || label == "else")
316 })
317 .or_else(|| arms.first())
318 }
319 _ => arms.first(),
320 };
321
322 if let Some((label, plan)) = selected_arm {
323 ctx.event_bus.emit(Event::NodeCompleted {
324 run_id: ctx.run_id.clone(),
325 node_id: node_id.clone(),
326 duration: std::time::Duration::ZERO,
327 output_summary: format!("Branch selected: {label}"),
328 });
329 execute(plan, ctx, filters, cache)?;
330 }
331 Ok(())
332 }
333
334 ExecutionPlan::Remote {
335 node_id,
336 target,
337 plan,
338 } => {
339 if let Some(remote) = &ctx.remote_executor {
340 let input = ctx
342 .graph_info
343 .predecessors(node_id)
344 .first()
345 .and_then(|pred| ctx.get(pred));
346
347 let result = remote.execute_remote(node_id, target, input)?;
348 ctx.set(node_id.clone(), result);
349 ctx.execution_order.push(node_id.clone());
350 Ok(())
351 } else {
352 execute(plan, ctx, filters, cache)
354 }
355 }
356
357 _ => {
358 tracing::warn!("Unhandled ExecutionPlan variant");
359 Ok(())
360 }
361 }
362}
363
364fn execute_node(
366 node_id: &str,
367 ctx: &mut Context,
368 filters: &FilterLibrary,
369 _cache: &dyn CacheStore,
370) -> Result<()> {
371 let start = Instant::now();
372
373 let filter = filters
374 .get(node_id)
375 .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
376
377 ctx.event_bus.emit(Event::NodeStarted {
378 run_id: ctx.run_id.clone(),
379 node_id: node_id.to_string(),
380 kind: filter.meta().kind,
381 });
382
383 let input = resolve_input(node_id, ctx);
384 let state = filters.get_state(node_id).cloned().unwrap_or(Value::Empty);
385 let result = filter.forward(&input, &state);
386
387 match result {
388 Ok(output) => {
389 let duration = start.elapsed();
390 let summary = format!("{output}");
391 let vv = ctx.maybe_spill(node_id, output);
392 ctx.set_virtual(node_id, vv);
393 ctx.event_bus.emit(Event::NodeCompleted {
394 run_id: ctx.run_id.clone(),
395 node_id: node_id.to_string(),
396 duration,
397 output_summary: summary,
398 });
399 Ok(())
400 }
401 Err(e) => {
402 ctx.event_bus.emit(Event::NodeFailed {
403 run_id: ctx.run_id.clone(),
404 node_id: node_id.to_string(),
405 error: e.to_string(),
406 });
407 Err(e)
408 }
409 }
410}
411
412fn execute_parallel(
417 branches: &[ExecutionPlan],
418 ctx: &mut Context,
419 filters: &FilterLibrary,
420 cache: &dyn CacheStore,
421) -> Result<()> {
422 let snapshot_keys: Arc<std::collections::HashSet<String>> =
423 Arc::new(ctx.store.keys().cloned().collect());
424
425 let results: Vec<Result<Vec<(String, VirtualValue)>>> = std::thread::scope(|s| {
427 let handles: Vec<_> = branches
428 .iter()
429 .map(|branch| {
430 let mut branch_ctx = ctx.snapshot();
431 let keys = snapshot_keys.clone();
432 s.spawn(move || {
433 execute(branch, &mut branch_ctx, filters, cache)?;
434 let new_entries: Vec<(String, VirtualValue)> = branch_ctx
435 .store
436 .into_iter()
437 .filter(|(k, _)| !keys.contains(k))
438 .collect();
439 Ok(new_entries)
440 })
441 })
442 .collect();
443
444 handles.into_iter().map(|h| h.join().unwrap()).collect()
445 });
446
447 for result in results {
449 let entries = result?;
450 for (key, vv) in entries {
451 ctx.set_virtual(key, vv);
452 }
453 }
454
455 Ok(())
456}
457
458fn resolve_value(vv: &VirtualValue, data_store: &Option<Arc<dyn DataStore>>) -> Option<Value> {
460 match vv {
461 VirtualValue::Materialized { value, .. } => Some(value.clone()),
462 VirtualValue::Cached { key, .. } => {
463 if let Some(store) = data_store {
465 let data_ref = somatize_core::store::DataRef::Cached {
466 cache_key: key.clone(),
467 };
468 store.get(&data_ref).ok()
469 } else {
470 None
471 }
472 }
473 _ => None,
474 }
475}
476
477pub fn resolve_input(node_id: &str, ctx: &Context) -> Value {
480 let preds = ctx.graph_info.predecessors(node_id);
481
482 let resolve_node = |id: &str| -> Option<Value> {
483 ctx.store
484 .get(id)
485 .and_then(|vv| resolve_value(vv, &ctx.data_store))
486 };
487
488 match preds.len() {
489 0 => ctx
490 .execution_order
491 .last()
492 .and_then(|id| resolve_node(id))
493 .unwrap_or(Value::Empty),
494 1 => resolve_node(&preds[0]).unwrap_or(Value::Empty),
495 _ => {
496 let mut merged = serde_json::Map::new();
497 for pred_id in preds {
498 if let Some(val) = resolve_node(pred_id) {
499 let json_val = serde_json::to_value(&val).unwrap_or(serde_json::Value::Null);
500 merged.insert(pred_id.clone(), json_val);
501 }
502 }
503 Value::Json(serde_json::Value::Object(merged))
504 }
505 }
506}
507
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::MemoryCache;
    use somatize_core::cache::CacheKey;
    use somatize_core::filter::{Filter, FilterKind, FilterMeta, StreamMode};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Doubles every tensor element; other values pass through unchanged.
    struct DoublerFilter;

    impl Filter for DoublerFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Doubler"])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            match x {
                Value::Tensor { values, shape } => {
                    let doubled: Vec<f64> = values.iter().map(|v| v * 2.0).collect();
                    Ok(Value::tensor(doubled, shape.clone()))
                }
                _ => Ok(x.clone()),
            }
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "Doubler".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }
    }

    /// Adds `amount` to every tensor element; other values pass through unchanged.
    struct AdderFilter {
        amount: f64,
    }

    impl Filter for AdderFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Adder", &self.amount.to_le_bytes()])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            match x {
                Value::Tensor { values, shape } => {
                    let added: Vec<f64> = values.iter().map(|v| v + self.amount).collect();
                    Ok(Value::tensor(added, shape.clone()))
                }
                _ => Ok(x.clone()),
            }
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "Adder".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }
    }

    /// Upper bound on how long a `RendezvousFilter` waits for its peers.
    const RENDEZVOUS_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2);

    /// Passes its input through, but first waits (at most `RENDEZVOUS_TIMEOUT`)
    /// until `parties` instances sharing the same counters are inside `forward`
    /// at the same time.
    ///
    /// `peak` records the highest number of simultaneous callers observed. This
    /// lets a test prove that branches overlapped without relying on wall-clock
    /// thresholds that are flaky on loaded machines.
    struct RendezvousFilter {
        id: String,
        parties: usize,
        in_flight: Arc<AtomicUsize>,
        peak: Arc<AtomicUsize>,
    }

    impl Filter for RendezvousFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Rendezvous", self.id.as_bytes()])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            let now = self.in_flight.fetch_add(1, Ordering::SeqCst) + 1;
            self.peak.fetch_max(now, Ordering::SeqCst);
            // Bounded wait: sequential execution makes the test fail (slowly)
            // instead of hanging.
            let deadline = Instant::now() + RENDEZVOUS_TIMEOUT;
            while self.peak.load(Ordering::SeqCst) < self.parties && Instant::now() < deadline {
                std::thread::sleep(std::time::Duration::from_millis(1));
            }
            self.in_flight.fetch_sub(1, Ordering::SeqCst);
            Ok(x.clone())
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: format!("Rendezvous_{}", self.id),
                kind: FilterKind::Stateless,
                cacheable: false,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }
    }

    fn setup() -> (Arc<EventBus>, MemoryCache) {
        (Arc::new(EventBus::new(64)), MemoryCache::default())
    }

    #[test]
    fn execute_single_node() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0, 2.0, 3.0], vec![3]));
        ctx.graph_info
            .set_predecessors("doubler", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("doubler", Box::new(DoublerFilter));

        let plan = ExecutionPlan::Execute {
            node_id: "doubler".into(),
        };

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let result = ctx.get("doubler").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn execute_sequence_with_graph_info() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0, 2.0], vec![2]));

        let graph_info = GraphInfo::for_linear(&["input", "add", "double"]);
        ctx.graph_info = graph_info;

        let mut filters = FilterLibrary::new();
        filters.register("add", Box::new(AdderFilter { amount: 10.0 }));
        filters.register("double", Box::new(DoublerFilter));

        let plan = ExecutionPlan::Sequence(vec![
            ExecutionPlan::Execute {
                node_id: "add".into(),
            },
            ExecutionPlan::Execute {
                node_id: "double".into(),
            },
        ]);

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let result = ctx.get("double").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[22.0, 24.0]);
    }

    #[test]
    fn execute_cached_node() {
        let (bus, cache) = setup();
        let key = CacheKey::hash_data(b"cached_result");
        let cached_value = Value::tensor(vec![99.0], vec![1]);
        cache.put(&key, &cached_value).unwrap();

        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();

        let plan = ExecutionPlan::Cached {
            node_id: "cached_node".into(),
            key,
        };

        execute(&plan, &mut ctx, &filters, &cache).unwrap();
        assert_eq!(*ctx.get("cached_node").unwrap(), cached_value);
    }

    #[test]
    fn execute_emits_events() {
        let bus = Arc::new(EventBus::new(64));
        let cache = MemoryCache::default();
        let mut rx = bus.subscribe();

        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0], vec![1]));
        ctx.graph_info
            .set_predecessors("double", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("double", Box::new(DoublerFilter));

        execute(
            &ExecutionPlan::Execute {
                node_id: "double".into(),
            },
            &mut ctx,
            &filters,
            &cache,
        )
        .unwrap();

        let e1 = rx.try_recv().unwrap();
        assert!(matches!(e1, Event::NodeStarted { .. }));
        let e2 = rx.try_recv().unwrap();
        assert!(matches!(e2, Event::NodeCompleted { .. }));
    }

    #[test]
    fn execute_missing_filter_errors() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();

        let result = execute(
            &ExecutionPlan::Execute {
                node_id: "nonexistent".into(),
            },
            &mut ctx,
            &filters,
            &cache,
        );
        assert!(matches!(result, Err(SomaError::NodeNotFound(_))));
    }

    #[test]
    fn execute_empty_plan() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();
        execute(&ExecutionPlan::Empty, &mut ctx, &filters, &cache).unwrap();
    }

    #[test]
    fn execute_parallel_branches_merge_outputs() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![5.0], vec![1]));
        ctx.graph_info
            .set_predecessors("double", vec!["input".into()]);
        ctx.graph_info.set_predecessors("add", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("double", Box::new(DoublerFilter));
        filters.register("add", Box::new(AdderFilter { amount: 100.0 }));

        let plan = ExecutionPlan::Parallel(vec![
            ExecutionPlan::Execute {
                node_id: "double".into(),
            },
            ExecutionPlan::Execute {
                node_id: "add".into(),
            },
        ]);

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let double_out = ctx.get("double").unwrap().as_tensor().unwrap().0;
        assert_eq!(double_out, &[10.0]);
        let add_out = ctx.get("add").unwrap().as_tensor().unwrap().0;
        assert_eq!(add_out, &[105.0]);
    }

    #[test]
    fn parallel_branches_run_concurrently() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0], vec![1]));
        ctx.graph_info
            .set_predecessors("slow_a", vec!["input".into()]);
        ctx.graph_info
            .set_predecessors("slow_b", vec!["input".into()]);

        // Both filters share counters so they can observe each other.
        let in_flight = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));
        let mut filters = FilterLibrary::new();
        for (node_id, id) in [("slow_a", "a"), ("slow_b", "b")] {
            filters.register(
                node_id,
                Box::new(RendezvousFilter {
                    id: id.into(),
                    parties: 2,
                    in_flight: Arc::clone(&in_flight),
                    peak: Arc::clone(&peak),
                }),
            );
        }

        let plan = ExecutionPlan::Parallel(vec![
            ExecutionPlan::Execute {
                node_id: "slow_a".into(),
            },
            ExecutionPlan::Execute {
                node_id: "slow_b".into(),
            },
        ]);

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        assert_eq!(
            peak.load(Ordering::SeqCst),
            2,
            "parallel branches never overlapped; they appear to have run sequentially"
        );
        assert!(ctx.get("slow_a").is_some());
        assert!(ctx.get("slow_b").is_some());
    }

    #[test]
    fn resolve_input_single_predecessor() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("A", Value::tensor(vec![42.0], vec![1]));
        ctx.graph_info.set_predecessors("B", vec!["A".into()]);

        let input = resolve_input("B", &ctx);
        let (data, _) = input.as_tensor().unwrap();
        assert_eq!(data, &[42.0]);
    }

    #[test]
    fn resolve_input_multiple_predecessors() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("A", Value::tensor(vec![1.0], vec![1]));
        ctx.set("B", Value::tensor(vec![2.0], vec![1]));
        ctx.graph_info
            .set_predecessors("C", vec!["A".into(), "B".into()]);

        let input = resolve_input("C", &ctx);
        let json = input.as_json().unwrap();
        assert!(json.get("A").is_some());
        assert!(json.get("B").is_some());
    }

    #[test]
    fn resolve_input_no_predecessors_fallback() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("prev", Value::tensor(vec![7.0], vec![1]));

        let input = resolve_input("root", &ctx);
        let (data, _) = input.as_tensor().unwrap();
        assert_eq!(data, &[7.0]);
    }

    #[test]
    fn graph_info_from_linear() {
        let info = GraphInfo::for_linear(&["a", "b", "c"]);
        assert!(info.predecessors("a").is_empty());
        assert_eq!(info.predecessors("b"), &["a"]);
        assert_eq!(info.predecessors("c"), &["b"]);
    }
}