use std::sync::Arc;
use async_trait::async_trait;
use atomr_agents_callable::Callable;
use atomr_agents_core::{AgentError, CallCtx, Result, RunId, Value, WorkflowId};
use atomr_agents_state::{Checkpointer, RunState, StateSchema};
use crate::dag::Dag;
use crate::state_runner::{StatefulRunner, StatefulStep};
/// A nested workflow that can be invoked through the [`Callable`] trait.
///
/// When called, selected keys of the caller's JSON object are copied into a
/// fresh child [`RunState`], a seed snapshot is saved through `checkpointer`,
/// the child `dag` is executed by a [`StatefulRunner`], and selected channels
/// of the final child state are returned to the caller.
pub struct Subgraph {
/// Identifier of the child workflow. Used in the seed checkpoint key, handed
/// to the runner, and returned by `Callable::label`.
pub workflow_id: WorkflowId,
/// Run identifier used in the seed checkpoint key and handed to the runner.
/// NOTE(review): the same value is reused on every `call` — confirm whether
/// callers are expected to build one `Subgraph` per invocation.
pub run_id: RunId,
/// Graph of steps executed by the child runner.
pub dag: Dag<Arc<dyn StatefulStep>>,
/// Schema used to create the child state; also handed to the runner.
pub schema: Arc<StateSchema>,
/// Receives the seed snapshot (super-step 0) and is shared with the runner.
pub checkpointer: Arc<dyn Checkpointer>,
/// Keys copied from the caller's input object into the child state before the
/// run. Keys absent from the input are skipped.
pub input_channels: Vec<String>,
/// Channels read from the final child state and returned under `"outputs"`.
pub output_channels: Vec<String>,
}
#[async_trait]
impl Callable for Subgraph {
    /// Runs the child workflow against a projection of `input`.
    ///
    /// Steps, in order:
    /// 1. `input` must be a JSON object; every key listed in `input_channels`
    ///    that is present in it is merged into a fresh child state.
    /// 2. The resulting state is saved as a snapshot at super-step 0 with the
    ///    label `"subgraph-seed"`.
    /// 3. A [`StatefulRunner`] is built from this subgraph's ids, DAG, schema
    ///    and checkpointer, then run to completion.
    /// 4. The result is an object with two keys: `"outputs"` (the channels
    ///    named in `output_channels`, read from the final state) and
    ///    `"private_state"` (a snapshot of the whole final state).
    ///
    /// The call context is currently unused.
    ///
    /// # Errors
    ///
    /// Returns `AgentError::Workflow` when `input` is not an object, and
    /// propagates any error from merging the seed writes, saving the seed
    /// snapshot, or running the child graph.
    async fn call(&self, input: Value, _ctx: CallCtx) -> Result<Value> {
        // Only an object can be projected onto named channels.
        let parent_obj = match input {
            Value::Object(fields) => fields,
            other => {
                let message = format!("subgraph: expected object input, got {other}");
                return Err(AgentError::Workflow(message));
            }
        };

        // Project the requested parent keys into the child state; keys missing
        // from the parent object are silently skipped.
        let mut child_state = RunState::new(self.schema.clone());
        let seed_writes: Vec<(String, Value)> = self
            .input_channels
            .iter()
            .filter_map(|channel| {
                parent_obj
                    .get(channel)
                    .map(|value| (channel.clone(), value.clone()))
            })
            .collect();
        child_state.merge_writes(seed_writes)?;

        // Persist the seeded state before the runner starts.
        // NOTE(review): presumably the runner picks this checkpoint up as its
        // starting state — confirm against `StatefulRunner::run`.
        let seed_key = atomr_agents_state::CheckpointKey {
            workflow_id: self.workflow_id.clone(),
            run_id: self.run_id.clone(),
            super_step: 0,
        };
        let seed_snapshot = atomr_agents_state::Snapshot {
            key: seed_key,
            values: child_state.snapshot(),
            label: "subgraph-seed".into(),
            timestamp_ms: now_ms(),
        };
        self.checkpointer.save(seed_snapshot).await?;

        let child_runner = StatefulRunner {
            workflow_id: self.workflow_id.clone(),
            run_id: self.run_id.clone(),
            dag: clone_dag(&self.dag),
            schema: self.schema.clone(),
            checkpointer: self.checkpointer.clone(),
        };
        let final_state = child_runner.run().await?;

        // Project the requested channels back out for the caller.
        let outputs: serde_json::Map<String, Value> = self
            .output_channels
            .iter()
            .map(|channel| (channel.clone(), final_state.read(channel).clone()))
            .collect();

        Ok(serde_json::json!({
            "outputs": Value::Object(outputs),
            "private_state": final_state.snapshot(),
        }))
    }

    /// Uses the child workflow id as this callable's label.
    fn label(&self) -> &str {
        self.workflow_id.as_str()
    }
}
/// Builds an owned [`Dag`] from a borrowed one by cloning each of its fields.
///
/// Used where an owned graph is required but only a shared reference is
/// available (the runner takes its DAG by value).
/// NOTE(review): presumably `Dag` does not implement `Clone` itself — confirm.
fn clone_dag(d: &Dag<Arc<dyn StatefulStep>>) -> Dag<Arc<dyn StatefulStep>> {
    let Dag {
        steps,
        edges,
        entry,
    } = d;
    Dag {
        steps: steps.clone(),
        edges: edges.clone(),
        entry: entry.clone(),
    }
}
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// Yields `0` when the system clock reports a time earlier than the epoch.
fn now_ms() -> i64 {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    match since_epoch {
        // `as_millis` is a `u128`; the value is narrowed to `i64` for the caller.
        Ok(elapsed) => elapsed.as_millis() as i64,
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::Dag;
    use crate::state_runner::FnStatefulStep;
    use atomr_agents_core::{IterationBudget, MoneyBudget, TimeBudget, TokenBudget};
    use atomr_agents_state::{AppendMessages, InMemoryCheckpointer, MergeMap, StateSchema};
    use serde_json::json;
    use std::time::Duration;

    /// Child schema with a `messages` channel (`AppendMessages`) and a `notes`
    /// channel (`MergeMap`).
    fn child_schema() -> Arc<StateSchema> {
        let schema = StateSchema::builder()
            .add("messages", AppendMessages)
            .add("notes", MergeMap)
            .build();
        Arc::new(schema)
    }

    /// Call context with small fixed budgets, no agent id and an empty trace.
    fn ctx() -> CallCtx {
        let tokens = TokenBudget::new(1000);
        let time = TimeBudget::new(Duration::from_secs(5));
        let money = MoneyBudget::from_usd(0.10);
        let iterations = IterationBudget::new(10);
        CallCtx {
            agent_id: None,
            tokens,
            time,
            money,
            iterations,
            trace: vec![],
        }
    }

    /// Step that counts the entries it reads from `messages`, then writes one
    /// new message and records the count under `notes.child_saw`.
    fn child_step() -> Arc<dyn StatefulStep> {
        Arc::new(FnStatefulStep(|state: &RunState| {
            // Compute the count up front so the returned future does not
            // borrow `state`.
            let n = state
                .read("messages")
                .as_array()
                .map_or(0, |messages| messages.len());
            async move {
                Ok(vec![
                    (
                        "messages".into(),
                        json!([{"id": format!("c-{n}"), "text": "child added"}]),
                    ),
                    ("notes".into(), json!({"child_saw": n})),
                ])
            }
        }))
    }

    /// Only `messages` is projected in and only `notes` is projected out; the
    /// full child state is still visible under `private_state`.
    #[tokio::test]
    async fn subgraph_projects_in_then_out() {
        let single_step: Dag<Arc<dyn StatefulStep>> =
            Dag::builder("a").step("a", child_step()).build();
        let subgraph = Subgraph {
            workflow_id: WorkflowId::from("child-wf"),
            run_id: RunId::from("child-run"),
            dag: single_step,
            schema: child_schema(),
            checkpointer: Arc::new(InMemoryCheckpointer::new()),
            input_channels: vec![String::from("messages")],
            output_channels: vec![String::from("notes")],
        };

        // `config` is not listed as an input channel.
        let parent_input = json!({
            "messages": [{"id": "p-1", "text": "from parent"}],
            "config": {"unrelated": true},
        });

        let result = subgraph.call(parent_input, ctx()).await.unwrap();

        let projected_notes = &result["outputs"]["notes"];
        assert!(projected_notes["child_saw"].is_number());
        let private_messages = &result["private_state"]["messages"];
        assert!(private_messages.is_array());
    }
}