atomr_agents_workflow/
subgraph.rs1use std::sync::Arc;
14
15use async_trait::async_trait;
16use atomr_agents_callable::Callable;
17use atomr_agents_core::{AgentError, CallCtx, Result, RunId, Value, WorkflowId};
18use atomr_agents_state::{Checkpointer, RunState, StateSchema};
19
20use crate::dag::Dag;
21use crate::state_runner::{StatefulRunner, StatefulStep};
22
23pub struct Subgraph {
28 pub workflow_id: WorkflowId,
29 pub run_id: RunId,
30 pub dag: Dag<Arc<dyn StatefulStep>>,
31 pub schema: Arc<StateSchema>,
32 pub checkpointer: Arc<dyn Checkpointer>,
33 pub input_channels: Vec<String>,
34 pub output_channels: Vec<String>,
35}
36
37#[async_trait]
38impl Callable for Subgraph {
39 async fn call(&self, input: Value, _ctx: CallCtx) -> Result<Value> {
40 let parent_obj = match input {
43 Value::Object(m) => m,
44 other => {
45 return Err(AgentError::Workflow(format!(
46 "subgraph: expected object input, got {other}"
47 )));
48 }
49 };
50 let mut child_state = RunState::new(self.schema.clone());
51 let mut writes = Vec::new();
52 for k in &self.input_channels {
53 if let Some(v) = parent_obj.get(k) {
54 writes.push((k.clone(), v.clone()));
55 }
56 }
57 child_state.merge_writes(writes)?;
58
59 self.checkpointer
63 .save(atomr_agents_state::Snapshot {
64 key: atomr_agents_state::CheckpointKey {
65 workflow_id: self.workflow_id.clone(),
66 run_id: self.run_id.clone(),
67 super_step: 0,
68 },
69 values: child_state.snapshot(),
70 label: "subgraph-seed".into(),
71 timestamp_ms: now_ms(),
72 })
73 .await?;
74
75 let runner = StatefulRunner {
76 workflow_id: self.workflow_id.clone(),
77 run_id: self.run_id.clone(),
78 dag: clone_dag(&self.dag),
79 schema: self.schema.clone(),
80 checkpointer: self.checkpointer.clone(),
81 };
82 let final_state = runner.run().await?;
83 let mut outputs = serde_json::Map::new();
84 for k in &self.output_channels {
85 outputs.insert(k.clone(), final_state.read(k).clone());
86 }
87 Ok(serde_json::json!({
88 "outputs": Value::Object(outputs),
89 "private_state": final_state.snapshot(),
90 }))
91 }
92
93 fn label(&self) -> &str {
94 self.workflow_id.as_str()
95 }
96}
97
98fn clone_dag(d: &Dag<Arc<dyn StatefulStep>>) -> Dag<Arc<dyn StatefulStep>> {
99 Dag {
100 steps: d.steps.clone(),
101 edges: d.edges.clone(),
102 entry: d.entry.clone(),
103 }
104}
105
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch. If the
/// millisecond count ever exceeds `i64::MAX`, the result saturates there instead
/// of being truncated.
fn now_ms() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        // Clamp before narrowing: a bare `u128 as i64` keeps only the low 64 bits
        // and could silently produce a wrong (even negative) timestamp.
        .map(|d| d.as_millis().min(i64::MAX as u128) as i64)
        .unwrap_or(0)
}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::Dag;
    use crate::state_runner::FnStatefulStep;
    use atomr_agents_core::{IterationBudget, MoneyBudget, TimeBudget, TokenBudget};
    use atomr_agents_state::{AppendMessages, InMemoryCheckpointer, MergeMap, StateSchema};
    use serde_json::json;
    use std::time::Duration;

    /// Child schema with a `messages` channel (AppendMessages reducer) and a
    /// `notes` channel (MergeMap reducer).
    fn child_schema() -> Arc<StateSchema> {
        Arc::new(
            StateSchema::builder()
                .add("messages", AppendMessages)
                .add("notes", MergeMap)
                .build(),
        )
    }

    /// A call context with small budgets, sufficient for a single-step child run.
    fn ctx() -> CallCtx {
        CallCtx {
            agent_id: None,
            tokens: TokenBudget::new(1000),
            time: TimeBudget::new(Duration::from_secs(5)),
            money: MoneyBudget::from_usd(0.10),
            iterations: IterationBudget::new(10),
            trace: vec![],
        }
    }

    /// Child step that records how many messages it saw in `notes.child_saw` and
    /// appends one message of its own.
    fn child_step() -> Arc<dyn StatefulStep> {
        Arc::new(FnStatefulStep(|s: &RunState| {
            let n = s.read("messages").as_array().map(|v| v.len()).unwrap_or(0);
            async move {
                Ok(vec![
                    (
                        "messages".into(),
                        json!([{"id": format!("c-{n}"), "text": "child added"}]),
                    ),
                    ("notes".into(), json!({"child_saw": n})),
                ])
            }
        }))
    }

    #[tokio::test]
    async fn subgraph_projects_in_then_out() {
        let dag: Dag<Arc<dyn StatefulStep>> = Dag::builder("a").step("a", child_step()).build();
        let sub = Subgraph {
            workflow_id: WorkflowId::from("child-wf"),
            run_id: RunId::from("child-run"),
            dag,
            schema: child_schema(),
            checkpointer: Arc::new(InMemoryCheckpointer::new()),
            input_channels: vec!["messages".into()],
            output_channels: vec!["notes".into()],
        };
        let parent_input = json!({
            "messages": [{"id": "p-1", "text": "from parent"}],
            "config": {"unrelated": true},
        });
        let out = sub.call(parent_input, ctx()).await.unwrap();

        // Out: only the declared output channels are exposed.
        assert!(out["outputs"]["notes"]["child_saw"].is_number());
        assert!(
            out["outputs"].get("messages").is_none(),
            "non-output channel leaked into outputs: {out}"
        );

        // In: the child keeps its own channels privately, but the parent's
        // undeclared `config` key is never projected into the child state.
        assert!(out["private_state"]["messages"].is_array());
        assert!(
            out["private_state"].get("config").is_none(),
            "non-input parent key leaked into child state: {out}"
        );
    }
}