use std::collections::HashSet;
use std::sync::Arc;
use super::*;
use crate::graph::builder::{GraphBuilder, NodeContext};
use crate::graph::checkpoint::{Checkpointer, InMemoryCheckpointer};
use crate::graph::command::NodeResult;
use crate::graph::reducer::ClosureStateReducer;
use crate::harness::ids::{NodeId, RunId};
/// Builds a bare-bones [`NodeContext`] for the node named `id`.
///
/// The run id is always `"run-test"` and the step is `1`. Every optional
/// field is left as `None` and the recursion-frame list starts empty.
fn ctx_for(id: &str) -> NodeContext {
    // Compute the only two non-trivial fields up front.
    let node_id = NodeId::from(id);
    let run_id = RunId::new("run-test");

    NodeContext {
        node_id,
        run_id,
        step: 1,
        recursion_frames: Vec::new(),
        // Nothing optional is populated by this fixture.
        thread_id: None,
        resume: None,
        fork: None,
        send_arg: None,
        root_run_id: None,
        child_runs: None,
    }
}
/// Compiles a one-node graph over `i32` whose single node, `"add"`, emits an
/// update equal to the incoming state plus ten.
///
/// `"add"` is both the entry and the finish node. Panics if compilation fails.
fn child_add_ten() -> CompiledGraph<i32, i32> {
    let with_node = GraphBuilder::<i32, i32>::overwrite().add_node(
        "add",
        |state: i32, _ctx: NodeContext| async move {
            let bumped = state + 10;
            Ok(NodeResult::Update(bumped))
        },
    );

    with_node
        .set_entry("add")
        .set_finish("add")
        .compile()
        .unwrap()
}
/// A subgraph embedded via `shared_subgraph_node` runs after a regular parent
/// node: `pre` adds one, the child adds ten, so a run from `0` ends at `11`.
#[tokio::test]
async fn shared_state_subgraph() {
    let embedded = child_add_ten();

    let builder = GraphBuilder::<i32, i32>::overwrite()
        .add_node("pre", |state: i32, _ctx: NodeContext| async move {
            Ok(NodeResult::Update(state + 1))
        })
        .add_node("child", shared_subgraph_node(embedded));

    // Wire pre -> child, with the child as the finishing node.
    let parent = builder
        .set_entry("pre")
        .add_edge("pre", "child")
        .set_finish("child")
        .compile()
        .unwrap();

    let outcome = parent.run(0).await.unwrap();
    assert_eq!(outcome.state, 11);
}
/// Parent-graph state used by the adapter tests: a label plus a numeric score.
#[derive(Debug, Clone, PartialEq)]
struct ParentState {
    /// Label carried alongside the score.
    name: String,
    /// Numeric value that the adapter tests hand to the child graph.
    score: i32,
}
/// `adapter_subgraph_node` projects the parent's `score` into the `i32` child
/// and writes the child's result back, leaving `name` untouched: a score of
/// `5` comes back as `15` and the name stays `"alice"`.
#[tokio::test]
async fn adapter_subgraph_maps_state() {
    /// Parent -> child projection: the child only sees the score.
    fn to_child(parent: &ParentState) -> i32 {
        parent.score
    }

    /// Child -> parent fold: keep the name, replace the score.
    fn from_child(parent: &ParentState, child_score: i32) -> ParentState {
        ParentState {
            name: parent.name.clone(),
            score: child_score,
        }
    }

    let child = child_add_ten();
    let score_node = adapter_subgraph_node(child, to_child, from_child);

    let parent = GraphBuilder::<ParentState, ParentState>::new()
        .set_reducer(ClosureStateReducer::new(|_old, new: ParentState| Ok(new)))
        .add_node("score", score_node)
        .set_entry("score")
        .set_finish("score")
        .compile()
        .unwrap();

    let initial = ParentState {
        name: "alice".to_string(),
        score: 5,
    };
    let run = parent.run(initial).await.unwrap();

    assert_eq!(run.state.name, "alice");
    assert_eq!(run.state.score, 15);
}
/// The fold-back closure sees both the parent state and the child's output:
/// here it adds the child result (`15`) to the original score (`5`) and
/// suffixes the name, giving `20` and `"bob-scored"`.
#[tokio::test]
async fn adapter_folds_child_output_with_parent_context() {
    let seed = ParentState {
        name: "bob".to_string(),
        score: 5,
    };

    let score_node = adapter_subgraph_node(
        child_add_ten(),
        |parent: &ParentState| parent.score,
        |parent: &ParentState, child_score: i32| ParentState {
            name: format!("{}-scored", parent.name),
            score: parent.score + child_score,
        },
    );

    let graph = GraphBuilder::<ParentState, ParentState>::new()
        .set_reducer(ClosureStateReducer::new(|_old, new: ParentState| Ok(new)))
        .add_node("score", score_node)
        .set_entry("score")
        .set_finish("score")
        .compile()
        .unwrap();

    let run = graph.run(seed).await.unwrap();

    assert_eq!(run.state.score, 20);
    assert_eq!(run.state.name, "bob-scored");
}
/// A freshly compiled graph has an empty namespace; `namespaced` returns a
/// graph whose namespace is the embedding node's id.
#[test]
fn namespaced_clone_appends_embedding_node_id() {
    let base = child_add_ten();
    assert!(base.namespace().is_empty());

    let embed_ctx = ctx_for("embed");
    let scoped = namespaced(&base, &embed_ctx);
    assert_eq!(scoped.namespace(), &["embed".to_string()]);
}
/// Applying `namespaced` twice stacks the node ids in order, and a sibling
/// scoped from the same base under a different node id gets a different
/// namespace.
#[test]
fn nested_namespaces_accumulate_and_stay_distinct() {
    let base = child_add_ten();

    // outer -> inner: the second call appends to the first call's namespace.
    let outer = namespaced(&base, &ctx_for("outer"));
    let inner = namespaced(&outer, &ctx_for("inner"));
    assert_eq!(
        inner.namespace(),
        &["outer".to_string(), "inner".to_string()]
    );

    // A sibling scoped directly from the base must not collide with `outer`.
    let sibling = namespaced(&base, &ctx_for("other"));
    assert_ne!(outer.namespace(), sibling.namespace());
}
/// Two namespaced copies of one checkpointed child, run on the same thread
/// `"t"`, leave two checkpoints with distinct ids: one under `["branch_a"]`
/// and one under `["branch_b"]`.
#[tokio::test]
async fn namespaced_children_persist_under_isolated_namespaces() {
    let ckpt = Arc::new(InMemoryCheckpointer::<i32>::new());
    let base = child_add_ten().with_checkpointer(ckpt.clone());

    let branch_a = namespaced(&base, &ctx_for("branch_a"));
    let branch_b = namespaced(&base, &ctx_for("branch_b"));
    branch_a.run_with_thread("t", 0).await.unwrap();
    branch_b.run_with_thread("t", 1).await.unwrap();

    let stored = ckpt.list("t").await.unwrap();
    assert_eq!(stored.len(), 2);

    // Checkpoint ids must be unique even though both runs share a thread.
    let unique_ids: HashSet<&str> = stored
        .iter()
        .map(|meta| meta.checkpoint_id.as_str())
        .collect();
    assert_eq!(unique_ids.len(), 2);

    // Each branch must have written under its own single-segment namespace.
    let has_namespace =
        |segment: &str| stored.iter().any(|meta| meta.namespace == vec![segment.to_string()]);
    assert!(has_namespace("branch_a"));
    assert!(has_namespace("branch_b"));
}
/// When parent and embedded child share one checkpointer, a run on thread
/// `"t"` stores two checkpoints: one with an empty namespace and one under
/// `["child"]`. The child's checkpoint can be read back through a graph
/// namespaced with the same node id.
#[tokio::test]
async fn embedded_child_persists_under_parent_thread_and_child_namespace() {
    let ckpt = Arc::new(InMemoryCheckpointer::<i32>::new());
    let child = child_add_ten().with_checkpointer(ckpt.clone());

    let compiled = GraphBuilder::<i32, i32>::overwrite()
        .add_node("child", shared_subgraph_node(child.clone()))
        .set_entry("child")
        .set_finish("child")
        .compile()
        .unwrap();
    let parent = compiled.with_checkpointer(ckpt.clone());

    let run = parent.run_with_thread("t", 0).await.unwrap();
    assert_eq!(run.state, 10);

    let stored = ckpt.list("t").await.unwrap();
    assert_eq!(stored.len(), 2);
    // One of the two entries carries an empty (root) namespace.
    assert!(stored.iter().any(|meta| meta.namespace.is_empty()));

    let child_meta = stored
        .iter()
        .find(|meta| meta.namespace == vec!["child".to_string()])
        .expect("child checkpoint is stored under the embedding namespace");

    // Re-scope the child the same way the embedding did, then load its state.
    let scoped_child = namespaced(&child, &ctx_for("child"));
    let lookup = scoped_child
        .get_state("t", Some(&child_meta.checkpoint_id))
        .await
        .unwrap();
    let child_state = lookup.expect("child checkpoint can be loaded from the parent thread");

    assert_eq!(child_state.values, 10);
    assert_eq!(child_state.next_nodes, Vec::<NodeId>::new());
}
/// Running a parent with one embedded subgraph records exactly one child run.
/// That child run is attributed to node `"child"`, has its own run id, and
/// names the parent run as its root. The parent is its own root, has no
/// parent run, and its run tree lists the child.
#[tokio::test]
async fn subgraph_child_run_distinct_and_shares_root() {
    let parent = GraphBuilder::<i32, i32>::overwrite()
        .add_node("child", shared_subgraph_node(child_add_ten()))
        .set_entry("child")
        .set_finish("child")
        .compile()
        .unwrap();

    let run = parent.run(0).await.unwrap();
    assert_eq!(run.state, 10);
    assert_eq!(run.child_runs.len(), 1);

    // Identity of the recorded child run.
    let recorded = &run.child_runs[0];
    assert_eq!(recorded.node.as_str(), "child");
    assert_ne!(recorded.run_id, run.run_id);
    assert_eq!(recorded.root_run_id, run.run_id);

    // The parent run is the root of the whole invocation.
    assert_eq!(run.root_run_id, run.run_id);
    assert!(run.parent_run_id.is_none());

    // The derived tree mirrors the same relationship.
    let tree = run.run_tree();
    assert!(tree.is_root());
    assert_eq!(tree.children.len(), 1);
    assert_eq!(tree.children[0].run_id, recorded.run_id);
}
/// With a subgraph nested inside another subgraph, the top-level run still
/// reports a single direct child run (node `"child"`) whose root is the
/// top-level run and whose run id differs from it.
///
/// NOTE(review): nothing here asserts on the grandchild's run id or root.
/// Confirm whether that should be covered as the test name suggests.
#[tokio::test]
async fn nested_subgraphs_produce_distinct_ids_sharing_one_root() {
    // Innermost layer: the plain add-ten graph.
    let innermost = child_add_ten();

    // Middle layer embeds the innermost graph as node "grandchild".
    let middle = GraphBuilder::<i32, i32>::overwrite()
        .add_node("grandchild", shared_subgraph_node(innermost))
        .set_entry("grandchild")
        .set_finish("grandchild")
        .compile()
        .unwrap();

    // Outer layer embeds the middle graph as node "child".
    let outer = GraphBuilder::<i32, i32>::overwrite()
        .add_node("child", shared_subgraph_node(middle))
        .set_entry("child")
        .set_finish("child")
        .compile()
        .unwrap();

    let run = outer.run(0).await.unwrap();
    assert_eq!(run.state, 10);
    assert_eq!(run.child_runs.len(), 1);

    let direct_child = &run.child_runs[0];
    assert_eq!(direct_child.node.as_str(), "child");
    assert_eq!(direct_child.root_run_id, run.run_id);
    assert_ne!(direct_child.run_id, run.run_id);
}
/// Two subgraph nodes run back to back (`a` then `b`) each record their own
/// child run: both are rooted at the parent run, both differ from it, and
/// they differ from each other. The final state is `20`.
#[tokio::test]
async fn parent_frames_balanced_after_subgraph_returns() {
    let first = shared_subgraph_node(child_add_ten());
    let second = shared_subgraph_node(child_add_ten());

    let parent = GraphBuilder::<i32, i32>::overwrite()
        .add_node("a", first)
        .add_node("b", second)
        .set_entry("a")
        .add_edge("a", "b")
        .set_finish("b")
        .compile()
        .unwrap();

    let run = parent.run(0).await.unwrap();
    assert_eq!(run.state, 20);
    assert_eq!(run.child_runs.len(), 2);

    // Child runs are recorded in execution order.
    let order: Vec<&str> = run
        .child_runs
        .iter()
        .map(|child| child.node.as_str())
        .collect();
    assert_eq!(order, vec!["a", "b"]);

    for child in &run.child_runs {
        assert_eq!(child.root_run_id, run.run_id);
        assert_ne!(child.run_id, run.run_id);
    }

    // Sibling subgraph runs must not share a run id.
    assert_ne!(run.child_runs[0].run_id, run.child_runs[1].run_id);
}
/// After a checkpointed run with one embedded subgraph, at least one
/// checkpoint on thread `"t"` carries a `"child_runs"` metadata array with an
/// entry whose `"node"` is `"child"`.
#[tokio::test]
async fn child_runs_recorded_in_checkpoint_metadata() {
    let ckpt = Arc::new(InMemoryCheckpointer::<i32>::new());

    let compiled = GraphBuilder::<i32, i32>::overwrite()
        .add_node("child", shared_subgraph_node(child_add_ten()))
        .set_entry("child")
        .set_finish("child")
        .compile()
        .unwrap();
    let parent = compiled.with_checkpointer(ckpt.clone());

    let run = parent.run_with_thread("t", 0).await.unwrap();
    assert_eq!(run.child_runs.len(), 1);

    let stored = ckpt.list("t").await.unwrap();

    // Scan the checkpoints one by one and stop at the first that mentions
    // the child.
    let mut found = false;
    for meta in &stored {
        let checkpoint = ckpt
            .get("t", Some(&meta.checkpoint_id))
            .await
            .unwrap()
            .unwrap();

        let mentions_child = checkpoint
            .metadata
            .get("child_runs")
            .and_then(|value| value.as_array())
            .is_some_and(|entries| {
                entries.iter().any(|entry| {
                    entry.get("node").and_then(|node| node.as_str()) == Some("child")
                })
            });

        if mentions_child {
            found = true;
            break;
        }
    }

    assert!(found, "child_runs not found in any checkpoint metadata");
}