Skip to main content

atomr_agents_workflow/
subgraph.rs

//! Subgraphs with shared channels.
//!
//! A `Subgraph` is a `StatefulRunner` packaged so a parent workflow
//! can call it as a step. The parent declares two projection lists:
//!
//! - `input_channels`: keys read from the parent's input object and written
//!   into a fresh child state (through the child schema's reducers) as its
//!   initial values.
//! - `output_channels`: keys read from the child's final state and returned
//!   under `outputs`, for the parent to merge back through its own reducers.
//!
//! Channels not in either list are *private* to the child: they are never
//! projected back as outputs. The full child snapshot is still returned as
//! `private_state` so callers can inspect them.
12
13use std::sync::Arc;
14
15use async_trait::async_trait;
16use atomr_agents_callable::Callable;
17use atomr_agents_core::{AgentError, CallCtx, Result, RunId, Value, WorkflowId};
18use atomr_agents_state::{Checkpointer, RunState, StateSchema};
19
20use crate::dag::Dag;
21use crate::state_runner::{StatefulRunner, StatefulStep};
22
23/// Subgraph-as-callable. Returns a JSON object with two keys:
24/// `outputs` (the projected output channels) and `private_state`
25/// (the full child snapshot, included when callers want to inspect
26/// child-only channels).
27pub struct Subgraph {
28    pub workflow_id: WorkflowId,
29    pub run_id: RunId,
30    pub dag: Dag<Arc<dyn StatefulStep>>,
31    pub schema: Arc<StateSchema>,
32    pub checkpointer: Arc<dyn Checkpointer>,
33    pub input_channels: Vec<String>,
34    pub output_channels: Vec<String>,
35}
36
37#[async_trait]
38impl Callable for Subgraph {
39    async fn call(&self, input: Value, _ctx: CallCtx) -> Result<Value> {
40        // Build the child's RunState by projecting the parent input
41        // (a JSON object) through `input_channels`.
42        let parent_obj = match input {
43            Value::Object(m) => m,
44            other => {
45                return Err(AgentError::Workflow(format!(
46                    "subgraph: expected object input, got {other}"
47                )));
48            }
49        };
50        let mut child_state = RunState::new(self.schema.clone());
51        let mut writes = Vec::new();
52        for k in &self.input_channels {
53            if let Some(v) = parent_obj.get(k) {
54                writes.push((k.clone(), v.clone()));
55            }
56        }
57        child_state.merge_writes(writes)?;
58
59        // Persist the seeded state as super_step 0 so the runner
60        // resumes from there (i.e. it'll skip super_step 0 and start
61        // running the first DAG layer).
62        self.checkpointer
63            .save(atomr_agents_state::Snapshot {
64                key: atomr_agents_state::CheckpointKey {
65                    workflow_id: self.workflow_id.clone(),
66                    run_id: self.run_id.clone(),
67                    super_step: 0,
68                },
69                values: child_state.snapshot(),
70                label: "subgraph-seed".into(),
71                timestamp_ms: now_ms(),
72            })
73            .await?;
74
75        let runner = StatefulRunner {
76            workflow_id: self.workflow_id.clone(),
77            run_id: self.run_id.clone(),
78            dag: clone_dag(&self.dag),
79            schema: self.schema.clone(),
80            checkpointer: self.checkpointer.clone(),
81        };
82        let final_state = runner.run().await?;
83        let mut outputs = serde_json::Map::new();
84        for k in &self.output_channels {
85            outputs.insert(k.clone(), final_state.read(k).clone());
86        }
87        Ok(serde_json::json!({
88            "outputs": Value::Object(outputs),
89            "private_state": final_state.snapshot(),
90        }))
91    }
92
93    fn label(&self) -> &str {
94        self.workflow_id.as_str()
95    }
96}
97
98fn clone_dag(d: &Dag<Arc<dyn StatefulStep>>) -> Dag<Arc<dyn StatefulStep>> {
99    Dag {
100        steps: d.steps.clone(),
101        edges: d.edges.clone(),
102        entry: d.entry.clone(),
103    }
104}
105
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Best-effort: returns `0` if the system clock reads earlier than the
/// epoch. Saturates at `i64::MAX` instead of wrapping if the millisecond
/// count ever exceeds the `i64` range.
fn now_ms() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        // `as_millis` is a `u128`; clamp before narrowing so the cast is
        // lossless rather than a silent wrap.
        .map(|d| d.as_millis().min(i64::MAX as u128) as i64)
        .unwrap_or(0)
}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::Dag;
    use crate::state_runner::FnStatefulStep;
    use atomr_agents_core::{IterationBudget, MoneyBudget, TimeBudget, TokenBudget};
    use atomr_agents_state::{AppendMessages, InMemoryCheckpointer, MergeMap, StateSchema};
    use serde_json::json;
    use std::time::Duration;

    /// Child schema: `messages` (appending reducer) and `notes` (map-merge reducer).
    fn child_schema() -> Arc<StateSchema> {
        Arc::new(
            StateSchema::builder()
                .add("messages", AppendMessages)
                .add("notes", MergeMap)
                .build(),
        )
    }

    fn ctx() -> CallCtx {
        CallCtx {
            agent_id: None,
            tokens: TokenBudget::new(1000),
            time: TimeBudget::new(Duration::from_secs(5)),
            money: MoneyBudget::from_usd(0.10),
            iterations: IterationBudget::new(10),
            trace: vec![],
        }
    }

    /// Child step: records in `notes.child_saw` how many messages it saw,
    /// then appends one message of its own.
    fn child_step() -> Arc<dyn StatefulStep> {
        Arc::new(FnStatefulStep(|s: &RunState| {
            let n = s.read("messages").as_array().map(|v| v.len()).unwrap_or(0);
            async move {
                Ok(vec![
                    (
                        "messages".into(),
                        json!([{"id": format!("c-{n}"), "text": "child added"}]),
                    ),
                    ("notes".into(), json!({"child_saw": n})),
                ])
            }
        }))
    }

    /// Single-step subgraph that imports `messages` and exports `notes`.
    fn messages_to_notes_subgraph() -> Subgraph {
        let dag: Dag<Arc<dyn StatefulStep>> = Dag::builder("a").step("a", child_step()).build();
        Subgraph {
            workflow_id: WorkflowId::from("child-wf"),
            run_id: RunId::from("child-run"),
            dag,
            schema: child_schema(),
            checkpointer: Arc::new(InMemoryCheckpointer::new()),
            input_channels: vec!["messages".into()],
            output_channels: vec!["notes".into()],
        }
    }

    #[tokio::test]
    async fn subgraph_projects_in_then_out() {
        let sub = messages_to_notes_subgraph();
        let parent_input = json!({
            "messages": [{"id": "p-1", "text": "from parent"}],
            "config": {"unrelated": true},
        });
        let out = sub.call(parent_input, ctx()).await.unwrap();

        // Input projection: the child saw exactly the one parent message.
        // A bare `is_number()` would also pass if nothing was projected in.
        assert_eq!(out["outputs"]["notes"]["child_saw"].as_u64(), Some(1));

        // Output projection: only the declared output channel comes back.
        let outputs = out["outputs"]
            .as_object()
            .expect("`outputs` must be a JSON object");
        assert_eq!(outputs.len(), 1);
        assert!(outputs.contains_key("notes"));

        // private_state is the full child snapshot, so it includes channels
        // (like `messages`) that were not exported.
        assert!(out["private_state"]["messages"].is_array());
    }

    #[tokio::test]
    async fn subgraph_rejects_non_object_input() {
        let sub = messages_to_notes_subgraph();
        let err = sub
            .call(json!(["not", "an", "object"]), ctx())
            .await
            .unwrap_err();
        assert!(matches!(err, AgentError::Workflow(_)));
    }
}