// atomr_agents_callable/pipeline.rs

//! `Pipeline` — composable builder over `Callable`s.
//!
//! LCEL's `prompt | model | parser` becomes
//! `Pipeline::from(prompt).then(model).then(parser).build()`.
//! Fan-out becomes `.fan_out_with(vec![("a".into(), ca), ("b".into(), cb)])`.
//! Every result is itself a `Callable`, so pipelines compose recursively.

use std::collections::BTreeMap;
use std::sync::Arc;

use async_trait::async_trait;
use atomr_agents_core::{AgentError, CallCtx, Result, Value};

use crate::{Callable, CallableHandle};

/// How a stage's result is folded back into the value flowing through a
/// pipeline.
#[derive(Debug, Clone, Copy)]
enum StageKind {
    /// The stage's output replaces the running value.
    Sequential,
    /// The stage's output is stored under a key on the running value; the
    /// value's existing fields are preserved.
    Assign,
}

23#[derive(Clone)]
24struct Stage {
25    kind: StageKind,
26    /// Used by `Assign`/`Passthrough`.
27    key: Option<String>,
28    callable: CallableHandle,
29}
30
31/// Builder over a sequence of `Callable`s.
32pub struct Pipeline {
33    stages: Vec<Stage>,
34    label: String,
35}
36
37impl Pipeline {
38    pub fn from(c: CallableHandle) -> Self {
39        let label = c.label().to_string();
40        Self {
41            stages: vec![Stage {
42                kind: StageKind::Sequential,
43                key: None,
44                callable: c,
45            }],
46            label,
47        }
48    }
49
50    /// `prompt | model` — chain another stage.
51    pub fn then(mut self, c: CallableHandle) -> Self {
52        self.label = format!("{} | {}", self.label, c.label());
53        self.stages.push(Stage {
54            kind: StageKind::Sequential,
55            key: None,
56            callable: c,
57        });
58        self
59    }
60
61    /// Pass input through unchanged. Useful as a starting node.
62    pub fn passthrough(self) -> Self {
63        let identity: CallableHandle = Arc::new(crate::FnCallable::labeled(
64            "passthrough",
65            |v: Value, _ctx| async move { Ok(v) },
66        ));
67        if self.stages.is_empty() {
68            return Pipeline::from(identity);
69        }
70        self.then(identity)
71    }
72
73    /// `RunnablePassthrough.assign(key=fn)` — run `c` on the *current*
74    /// input and add the result under `key`, leaving original input
75    /// fields intact.
76    pub fn assign(mut self, key: impl Into<String>, c: CallableHandle) -> Self {
77        let key = key.into();
78        self.label = format!("{}.assign({})", self.label, key);
79        self.stages.push(Stage {
80            kind: StageKind::Assign,
81            key: Some(key),
82            callable: c,
83        });
84        self
85    }
86
87    /// `RunnableParallel({a: ca, b: cb})` as an inline stage. Each
88    /// branch runs concurrently on the current input; output is a
89    /// JSON object keyed by branch name.
90    pub fn fan_out_with(mut self, branches: Vec<(String, CallableHandle)>) -> Self {
91        let names: Vec<&str> = branches.iter().map(|(k, _)| k.as_str()).collect();
92        let label = format!("{} | fan_out({})", self.label, names.join(","));
93        let stage_callable = FanOutCallable::new(branches);
94        let handle: CallableHandle = Arc::new(stage_callable);
95        self.label = label;
96        self.stages.push(Stage {
97            kind: StageKind::Sequential,
98            key: None,
99            callable: handle,
100        });
101        self
102    }
103
104    pub fn build(self) -> CallableHandle {
105        Arc::new(BuiltPipeline {
106            stages: self.stages,
107            label: self.label,
108        })
109    }
110}
111
112struct BuiltPipeline {
113    stages: Vec<Stage>,
114    label: String,
115}
116
117#[async_trait]
118impl Callable for BuiltPipeline {
119    async fn call(&self, input: Value, ctx: CallCtx) -> Result<Value> {
120        let mut current = input;
121        for stage in &self.stages {
122            match stage.kind {
123                StageKind::Sequential => {
124                    current = stage.callable.call(current, ctx.clone()).await?;
125                }
126                StageKind::Assign => {
127                    let key = stage
128                        .key
129                        .as_ref()
130                        .ok_or_else(|| AgentError::Internal("assign without key".into()))?;
131                    let derived = stage.callable.call(current.clone(), ctx.clone()).await?;
132                    let mut obj = match current {
133                        Value::Object(m) => m,
134                        other => {
135                            let mut m = serde_json::Map::new();
136                            m.insert("input".into(), other);
137                            m
138                        }
139                    };
140                    obj.insert(key.clone(), derived);
141                    current = Value::Object(obj);
142                }
143            }
144        }
145        Ok(current)
146    }
147
148    fn label(&self) -> &str {
149        &self.label
150    }
151}
152
153/// Standalone fan-out factory. `fan_out([("a", ca), ("b", cb)])` —
154/// equivalent to LCEL's `RunnableParallel` outside of a pipeline.
155pub fn fan_out(branches: Vec<(String, CallableHandle)>) -> CallableHandle {
156    Arc::new(FanOutCallable::new(branches))
157}
158
159struct FanOutCallable {
160    branches: BTreeMap<String, CallableHandle>,
161    label: String,
162}
163
164impl FanOutCallable {
165    fn new(branches: Vec<(String, CallableHandle)>) -> Self {
166        let label = format!(
167            "fan_out({})",
168            branches
169                .iter()
170                .map(|(k, _)| k.as_str())
171                .collect::<Vec<_>>()
172                .join(",")
173        );
174        Self {
175            branches: branches.into_iter().collect(),
176            label,
177        }
178    }
179}
180
181#[async_trait]
182impl Callable for FanOutCallable {
183    async fn call(&self, input: Value, ctx: CallCtx) -> Result<Value> {
184        let mut handles = Vec::with_capacity(self.branches.len());
185        for (k, c) in &self.branches {
186            let k = k.clone();
187            let c = c.clone();
188            let inp = input.clone();
189            let ctx = ctx.clone();
190            handles.push(tokio::spawn(async move {
191                let out = c.call(inp, ctx).await?;
192                Ok::<_, AgentError>((k, out))
193            }));
194        }
195        let mut out = serde_json::Map::new();
196        for h in handles {
197            let (k, v) = h.await.map_err(|e| AgentError::Internal(e.to_string()))??;
198            out.insert(k, v);
199        }
200        Ok(Value::Object(out))
201    }
202
203    fn label(&self) -> &str {
204        &self.label
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FnCallable;
    use atomr_agents_core::{IterationBudget, MoneyBudget, TimeBudget, TokenBudget};
    use std::time::Duration;

    fn ctx() -> CallCtx {
        CallCtx {
            agent_id: None,
            tokens: TokenBudget::new(1000),
            time: TimeBudget::new(Duration::from_secs(10)),
            money: MoneyBudget::from_usd(1.0),
            iterations: IterationBudget::new(10),
            trace: vec![],
        }
    }

    fn echo(label: &'static str) -> CallableHandle {
        Arc::new(FnCallable::labeled(label, |v: Value, _ctx| async move { Ok(v) }))
    }

    fn append_str(label: &'static str, suffix: &'static str) -> CallableHandle {
        Arc::new(FnCallable::labeled(label, move |v: Value, _ctx| async move {
            let s = v.as_str().unwrap_or("").to_string() + suffix;
            Ok(Value::String(s))
        }))
    }

    /// A callable that always fails with `AgentError::Internal("boom")`.
    fn failing(label: &'static str) -> CallableHandle {
        Arc::new(FnCallable::labeled(label, |_v: Value, _ctx| async move {
            Err(AgentError::Internal("boom".into()))
        }))
    }

    #[tokio::test]
    async fn pipeline_then_chains_sequentially() {
        let p = Pipeline::from(append_str("a", "A"))
            .then(append_str("b", "B"))
            .then(append_str("c", "C"))
            .build();
        let out = p.call(Value::String(String::new()), ctx()).await.unwrap();
        assert_eq!(out, Value::String("ABC".into()));
    }

    #[tokio::test]
    async fn passthrough_forwards_value_unchanged() {
        let p = Pipeline::from(append_str("a", "A"))
            .passthrough()
            .then(append_str("b", "B"))
            .build();
        let out = p.call(Value::String(String::new()), ctx()).await.unwrap();
        assert_eq!(out, Value::String("AB".into()));
    }

    #[tokio::test]
    async fn fan_out_runs_branches_in_parallel() {
        let p = Pipeline::from(echo("seed"))
            .fan_out_with(vec![
                ("upper".into(), append_str("u", "U")),
                ("lower".into(), append_str("l", "l")),
            ])
            .build();
        let out = p.call(Value::String("x".into()), ctx()).await.unwrap();
        assert_eq!(out["upper"], Value::String("xU".into()));
        assert_eq!(out["lower"], Value::String("xl".into()));
    }

    #[tokio::test]
    async fn fan_out_collects_every_branch_by_name() {
        let p = fan_out(vec![
            ("b".into(), append_str("b", "B")),
            ("a".into(), append_str("a", "A")),
        ]);
        let out = p.call(Value::String("x".into()), ctx()).await.unwrap();
        assert_eq!(out, serde_json::json!({"a": "xA", "b": "xB"}));
    }

    #[tokio::test]
    async fn fan_out_with_no_branches_yields_empty_object() {
        let p = fan_out(vec![]);
        let out = p.call(Value::String("x".into()), ctx()).await.unwrap();
        assert_eq!(out, serde_json::json!({}));
    }

    #[tokio::test]
    async fn fan_out_propagates_branch_error() {
        let p = Pipeline::from(echo("seed"))
            .fan_out_with(vec![
                ("ok".into(), echo("ok")),
                ("bad".into(), failing("bad")),
            ])
            .build();
        assert!(p.call(Value::String("x".into()), ctx()).await.is_err());
    }

    #[test]
    fn fan_out_label_lists_branch_names_in_given_order() {
        let p = fan_out(vec![("b".into(), echo("b")), ("a".into(), echo("a"))]);
        assert_eq!(p.label(), "fan_out(b,a)");
    }

    #[tokio::test]
    async fn assign_adds_key_keeping_input_fields() {
        let derive = Arc::new(FnCallable::labeled("len", |v: Value, _ctx| async move {
            let n = v.as_object().map(|m| m.len()).unwrap_or(0);
            Ok(Value::from(n))
        }));
        let p = Pipeline::from(echo("seed")).assign("size", derive).build();
        let out = p.call(serde_json::json!({"a": 1, "b": 2}), ctx()).await.unwrap();
        assert_eq!(out["a"], Value::from(1));
        assert_eq!(out["size"], Value::from(2));
    }

    #[tokio::test]
    async fn assign_wraps_non_object_input_under_input_key() {
        let p = Pipeline::from(echo("seed")).assign("copy", echo("copy")).build();
        let out = p.call(Value::String("x".into()), ctx()).await.unwrap();
        assert_eq!(out, serde_json::json!({"input": "x", "copy": "x"}));
    }
}