1use std::collections::BTreeMap;
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use atomr_agents_core::{AgentError, CallCtx, Result, Value};
13
14use crate::{Callable, CallableHandle};
15
/// How a pipeline stage merges its callable's output into the running value.
#[derive(Debug, Clone, Copy)]
enum StageKind {
    /// The callable's output replaces the running value.
    Sequential,
    /// The callable's output is inserted under the stage's key into an object
    /// built from the running value.
    Assign,
}
22
23#[derive(Clone)]
24struct Stage {
25 kind: StageKind,
26 key: Option<String>,
28 callable: CallableHandle,
29}
30
31pub struct Pipeline {
33 stages: Vec<Stage>,
34 label: String,
35}
36
37impl Pipeline {
38 pub fn from(c: CallableHandle) -> Self {
39 let label = c.label().to_string();
40 Self {
41 stages: vec![Stage {
42 kind: StageKind::Sequential,
43 key: None,
44 callable: c,
45 }],
46 label,
47 }
48 }
49
50 pub fn then(mut self, c: CallableHandle) -> Self {
52 self.label = format!("{} | {}", self.label, c.label());
53 self.stages.push(Stage {
54 kind: StageKind::Sequential,
55 key: None,
56 callable: c,
57 });
58 self
59 }
60
61 pub fn passthrough(self) -> Self {
63 let identity: CallableHandle = Arc::new(crate::FnCallable::labeled(
64 "passthrough",
65 |v: Value, _ctx| async move { Ok(v) },
66 ));
67 if self.stages.is_empty() {
68 return Pipeline::from(identity);
69 }
70 self.then(identity)
71 }
72
73 pub fn assign(mut self, key: impl Into<String>, c: CallableHandle) -> Self {
77 let key = key.into();
78 self.label = format!("{}.assign({})", self.label, key);
79 self.stages.push(Stage {
80 kind: StageKind::Assign,
81 key: Some(key),
82 callable: c,
83 });
84 self
85 }
86
87 pub fn fan_out_with(mut self, branches: Vec<(String, CallableHandle)>) -> Self {
91 let names: Vec<&str> = branches.iter().map(|(k, _)| k.as_str()).collect();
92 let label = format!("{} | fan_out({})", self.label, names.join(","));
93 let stage_callable = FanOutCallable::new(branches);
94 let handle: CallableHandle = Arc::new(stage_callable);
95 self.label = label;
96 self.stages.push(Stage {
97 kind: StageKind::Sequential,
98 key: None,
99 callable: handle,
100 });
101 self
102 }
103
104 pub fn build(self) -> CallableHandle {
105 Arc::new(BuiltPipeline {
106 stages: self.stages,
107 label: self.label,
108 })
109 }
110}
111
112struct BuiltPipeline {
113 stages: Vec<Stage>,
114 label: String,
115}
116
117#[async_trait]
118impl Callable for BuiltPipeline {
119 async fn call(&self, input: Value, ctx: CallCtx) -> Result<Value> {
120 let mut current = input;
121 for stage in &self.stages {
122 match stage.kind {
123 StageKind::Sequential => {
124 current = stage.callable.call(current, ctx.clone()).await?;
125 }
126 StageKind::Assign => {
127 let key = stage
128 .key
129 .as_ref()
130 .ok_or_else(|| AgentError::Internal("assign without key".into()))?;
131 let derived = stage.callable.call(current.clone(), ctx.clone()).await?;
132 let mut obj = match current {
133 Value::Object(m) => m,
134 other => {
135 let mut m = serde_json::Map::new();
136 m.insert("input".into(), other);
137 m
138 }
139 };
140 obj.insert(key.clone(), derived);
141 current = Value::Object(obj);
142 }
143 }
144 }
145 Ok(current)
146 }
147
148 fn label(&self) -> &str {
149 &self.label
150 }
151}
152
153pub fn fan_out(branches: Vec<(String, CallableHandle)>) -> CallableHandle {
156 Arc::new(FanOutCallable::new(branches))
157}
158
159struct FanOutCallable {
160 branches: BTreeMap<String, CallableHandle>,
161 label: String,
162}
163
164impl FanOutCallable {
165 fn new(branches: Vec<(String, CallableHandle)>) -> Self {
166 let label = format!(
167 "fan_out({})",
168 branches
169 .iter()
170 .map(|(k, _)| k.as_str())
171 .collect::<Vec<_>>()
172 .join(",")
173 );
174 Self {
175 branches: branches.into_iter().collect(),
176 label,
177 }
178 }
179}
180
181#[async_trait]
182impl Callable for FanOutCallable {
183 async fn call(&self, input: Value, ctx: CallCtx) -> Result<Value> {
184 let mut handles = Vec::with_capacity(self.branches.len());
185 for (k, c) in &self.branches {
186 let k = k.clone();
187 let c = c.clone();
188 let inp = input.clone();
189 let ctx = ctx.clone();
190 handles.push(tokio::spawn(async move {
191 let out = c.call(inp, ctx).await?;
192 Ok::<_, AgentError>((k, out))
193 }));
194 }
195 let mut out = serde_json::Map::new();
196 for h in handles {
197 let (k, v) = h.await.map_err(|e| AgentError::Internal(e.to_string()))??;
198 out.insert(k, v);
199 }
200 Ok(Value::Object(out))
201 }
202
203 fn label(&self) -> &str {
204 &self.label
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FnCallable;
    use atomr_agents_core::{IterationBudget, MoneyBudget, TimeBudget, TokenBudget};
    use std::time::Duration;

    /// Fresh call context with generous budgets for tests.
    fn ctx() -> CallCtx {
        CallCtx {
            agent_id: None,
            tokens: TokenBudget::new(1000),
            time: TimeBudget::new(Duration::from_secs(10)),
            money: MoneyBudget::from_usd(1.0),
            iterations: IterationBudget::new(10),
            trace: vec![],
        }
    }

    /// Identity callable with the given label.
    fn echo(label: &'static str) -> CallableHandle {
        Arc::new(FnCallable::labeled(label, |v: Value, _ctx| async move { Ok(v) }))
    }

    /// Callable that appends `suffix` to a string input (non-strings count as `""`).
    fn append_str(label: &'static str, suffix: &'static str) -> CallableHandle {
        Arc::new(FnCallable::labeled(label, move |v: Value, _ctx| async move {
            let s = v.as_str().unwrap_or("").to_string() + suffix;
            Ok(Value::String(s))
        }))
    }

    #[tokio::test]
    async fn pipeline_then_chains_sequentially() {
        let p = Pipeline::from(append_str("a", "A"))
            .then(append_str("b", "B"))
            .then(append_str("c", "C"))
            .build();
        let out = p.call(Value::String(String::new()), ctx()).await.unwrap();
        assert_eq!(out, Value::String("ABC".into()));
    }

    /// Checks that every branch sees the same input and is keyed by its name.
    /// (This test does not itself assert that the branches run concurrently.)
    #[tokio::test]
    async fn fan_out_runs_branches_in_parallel() {
        let p = Pipeline::from(echo("seed"))
            .fan_out_with(vec![
                ("upper".into(), append_str("u", "U")),
                ("lower".into(), append_str("l", "l")),
            ])
            .build();
        let out = p.call(Value::String("x".into()), ctx()).await.unwrap();
        assert_eq!(out["upper"], Value::String("xU".into()));
        assert_eq!(out["lower"], Value::String("xl".into()));
    }

    #[tokio::test]
    async fn fan_out_propagates_branch_error() {
        let failing: CallableHandle =
            Arc::new(FnCallable::labeled("boom", |_v: Value, _ctx| async move {
                Result::<Value>::Err(AgentError::Internal("boom".into()))
            }));
        let p = fan_out(vec![("ok".into(), echo("ok")), ("bad".into(), failing)]);
        assert_eq!(p.label(), "fan_out(ok,bad)");
        assert!(p.call(Value::Null, ctx()).await.is_err());
    }

    #[tokio::test]
    async fn assign_adds_key_keeping_input_fields() {
        let derive = Arc::new(FnCallable::labeled("len", |v: Value, _ctx| async move {
            let n = v.as_object().map(|m| m.len()).unwrap_or(0);
            Ok(Value::from(n))
        }));
        let p = Pipeline::from(echo("seed")).assign("size", derive).build();
        let out = p.call(serde_json::json!({"a": 1, "b": 2}), ctx()).await.unwrap();
        assert_eq!(out["a"], Value::from(1));
        assert_eq!(out["size"], Value::from(2));
    }

    #[tokio::test]
    async fn assign_wraps_non_object_input() {
        let p = Pipeline::from(echo("seed"))
            .assign("tag", append_str("t", "!"))
            .build();
        let out = p.call(Value::String("x".into()), ctx()).await.unwrap();
        assert_eq!(out, serde_json::json!({"input": "x", "tag": "x!"}));
    }

    #[tokio::test]
    async fn passthrough_forwards_input_unchanged() {
        let p = Pipeline::from(append_str("a", "A")).passthrough().build();
        assert_eq!(p.label(), "a | passthrough");
        let out = p.call(Value::String("x".into()), ctx()).await.unwrap();
        assert_eq!(out, Value::String("xA".into()));
    }

    #[test]
    fn pipeline_label_describes_stages() {
        let p = Pipeline::from(echo("a"))
            .then(echo("b"))
            .fan_out_with(vec![("x".into(), echo("x")), ("y".into(), echo("y"))])
            .assign("k", echo("k"))
            .build();
        assert_eq!(p.label(), "a | b | fan_out(x,y).assign(k)");
    }
}