Skip to main content

atomr_agents_callable/
decorators.rs

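//! Decorators over `Callable`.
//!
//! Each decorator is itself a `Callable`, so it can be inserted anywhere a
//! `CallableHandle` is expected — including as middleware around a single
//! step in a `Pipeline`. Provided here: retry with backoff, fallbacks, run
//! configuration (name/tags), timeouts, and predicate-based branching.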
6
7use std::sync::Arc;
8use std::time::Duration;
9
10use async_trait::async_trait;
11use atomr_agents_core::{AgentError, CallCtx, Result, Value};
12
13use crate::{Callable, CallableHandle, FnCallable};
14
15// --------------------------------------------------------------------
16// WithRetry
17// --------------------------------------------------------------------
18
/// Tuning knobs for [`WithRetry`].
///
/// Between failed attempts the decorator sleeps; each subsequent wait is the
/// previous one scaled by `backoff_multiplier` and clamped to `max_backoff`.
#[derive(Clone, Copy, Debug)]
pub struct RetryPolicy {
    /// Total number of calls to the wrapped callable, first call included.
    pub max_attempts: u32,
    /// Starting wait between attempts.
    pub initial_backoff: Duration,
    /// Factor applied to the wait after every sleep.
    pub backoff_multiplier: f32,
    /// Upper bound on the wait produced by each scaling step.
    pub max_backoff: Duration,
}

impl Default for RetryPolicy {
    /// Three attempts, 50 ms starting wait, doubling, capped at 5 s.
    fn default() -> Self {
        let initial_backoff = Duration::from_millis(50);
        let max_backoff = Duration::from_secs(5);
        RetryPolicy {
            max_attempts: 3,
            initial_backoff,
            backoff_multiplier: 2.0,
            max_backoff,
        }
    }
}
37
38pub struct WithRetry {
39    inner: CallableHandle,
40    policy: RetryPolicy,
41    label: String,
42}
43
44impl WithRetry {
45    pub fn new(inner: CallableHandle, policy: RetryPolicy) -> Self {
46        let label = format!("retry({})", inner.label());
47        Self { inner, policy, label }
48    }
49}
50
51pub fn with_retry(inner: CallableHandle, policy: RetryPolicy) -> CallableHandle {
52    Arc::new(WithRetry::new(inner, policy))
53}
54
55#[async_trait]
56impl Callable for WithRetry {
57    async fn call(&self, input: Value, ctx: CallCtx) -> Result<Value> {
58        let mut delay = self.policy.initial_backoff;
59        let mut last_err: Option<AgentError> = None;
60        for attempt in 0..self.policy.max_attempts {
61            match self.inner.call(input.clone(), ctx.clone()).await {
62                Ok(v) => return Ok(v),
63                Err(e) => {
64                    last_err = Some(e);
65                    if attempt + 1 == self.policy.max_attempts {
66                        break;
67                    }
68                    tokio::time::sleep(delay).await;
69                    let next_ms = (delay.as_millis() as f32 * self.policy.backoff_multiplier) as u64;
70                    delay = Duration::from_millis(next_ms).min(self.policy.max_backoff);
71                }
72            }
73        }
74        Err(last_err.unwrap_or_else(|| AgentError::Internal("retry exhausted with no error".into())))
75    }
76
77    fn label(&self) -> &str {
78        &self.label
79    }
80}
81
82// --------------------------------------------------------------------
83// WithFallbacks
84// --------------------------------------------------------------------
85
86pub struct WithFallbacks {
87    primary: CallableHandle,
88    alternates: Vec<CallableHandle>,
89    label: String,
90}
91
92impl WithFallbacks {
93    pub fn new(primary: CallableHandle, alternates: Vec<CallableHandle>) -> Self {
94        let label = format!("fallback({})", primary.label());
95        Self {
96            primary,
97            alternates,
98            label,
99        }
100    }
101}
102
103pub fn with_fallbacks(primary: CallableHandle, alternates: Vec<CallableHandle>) -> CallableHandle {
104    Arc::new(WithFallbacks::new(primary, alternates))
105}
106
107#[async_trait]
108impl Callable for WithFallbacks {
109    async fn call(&self, input: Value, ctx: CallCtx) -> Result<Value> {
110        if let Ok(v) = self.primary.call(input.clone(), ctx.clone()).await {
111            return Ok(v);
112        }
113        let mut last_err = None;
114        for alt in &self.alternates {
115            match alt.call(input.clone(), ctx.clone()).await {
116                Ok(v) => return Ok(v),
117                Err(e) => last_err = Some(e),
118            }
119        }
120        Err(last_err.unwrap_or_else(|| AgentError::Internal("fallbacks exhausted".into())))
121    }
122
123    fn label(&self) -> &str {
124        &self.label
125    }
126}
127
128// --------------------------------------------------------------------
129// WithConfig
130// --------------------------------------------------------------------
131
132#[derive(Clone, Default, Debug)]
133pub struct RunConfig {
134    pub run_name: Option<String>,
135    pub tags: Vec<String>,
136    /// Free-form metadata, JSON-encoded.
137    pub metadata: serde_json::Map<String, Value>,
138}
139
140pub struct WithConfig {
141    inner: CallableHandle,
142    config: RunConfig,
143    label: String,
144}
145
146impl WithConfig {
147    pub fn new(inner: CallableHandle, config: RunConfig) -> Self {
148        let label = config
149            .run_name
150            .clone()
151            .unwrap_or_else(|| format!("config({})", inner.label()));
152        Self { inner, config, label }
153    }
154}
155
156pub fn with_config(inner: CallableHandle, config: RunConfig) -> CallableHandle {
157    Arc::new(WithConfig::new(inner, config))
158}
159
160#[async_trait]
161impl Callable for WithConfig {
162    async fn call(&self, input: Value, ctx: CallCtx) -> Result<Value> {
163        let mut ctx = ctx;
164        if let Some(name) = &self.config.run_name {
165            ctx.trace.push(format!("run:{name}"));
166        }
167        for t in &self.config.tags {
168            ctx.trace.push(format!("tag:{t}"));
169        }
170        self.inner.call(input, ctx).await
171    }
172
173    fn label(&self) -> &str {
174        &self.label
175    }
176}
177
178// --------------------------------------------------------------------
179// WithTimeout
180// --------------------------------------------------------------------
181
182pub struct WithTimeout {
183    inner: CallableHandle,
184    duration: Duration,
185    label: String,
186}
187
188impl WithTimeout {
189    pub fn new(inner: CallableHandle, duration: Duration) -> Self {
190        let label = format!("timeout({})", inner.label());
191        Self {
192            inner,
193            duration,
194            label,
195        }
196    }
197}
198
199pub fn with_timeout(inner: CallableHandle, duration: Duration) -> CallableHandle {
200    Arc::new(WithTimeout::new(inner, duration))
201}
202
203#[async_trait]
204impl Callable for WithTimeout {
205    async fn call(&self, input: Value, ctx: CallCtx) -> Result<Value> {
206        match tokio::time::timeout(self.duration, self.inner.call(input, ctx)).await {
207            Ok(r) => r,
208            Err(_) => Err(AgentError::Internal(format!(
209                "timed out after {:?}",
210                self.duration
211            ))),
212        }
213    }
214
215    fn label(&self) -> &str {
216        &self.label
217    }
218}
219
220// --------------------------------------------------------------------
221// Branch — RunnableBranch analogue
222// --------------------------------------------------------------------
223
224pub struct Branch {
225    predicate: Arc<dyn Fn(&Value) -> bool + Send + Sync + 'static>,
226    if_true: CallableHandle,
227    if_false: CallableHandle,
228    label: String,
229}
230
231impl Branch {
232    pub fn new<F>(predicate: F, if_true: CallableHandle, if_false: CallableHandle) -> Self
233    where
234        F: Fn(&Value) -> bool + Send + Sync + 'static,
235    {
236        let label = format!("branch({} | {})", if_true.label(), if_false.label());
237        Self {
238            predicate: Arc::new(predicate),
239            if_true,
240            if_false,
241            label,
242        }
243    }
244}
245
246#[async_trait]
247impl Callable for Branch {
248    async fn call(&self, input: Value, ctx: CallCtx) -> Result<Value> {
249        if (self.predicate)(&input) {
250            self.if_true.call(input, ctx).await
251        } else {
252            self.if_false.call(input, ctx).await
253        }
254    }
255
256    fn label(&self) -> &str {
257        &self.label
258    }
259}
260
261// --------------------------------------------------------------------
262// Lambda — RunnableLambda alias
263// --------------------------------------------------------------------
264
265/// Type-alias for users who prefer the LangChain name.
266pub type Lambda<F> = FnCallable<F>;
267
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_agents_core::{IterationBudget, MoneyBudget, TimeBudget, TokenBudget};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Call context with fixed budgets and an empty trace, shared by every test.
    fn fresh_ctx() -> CallCtx {
        CallCtx {
            agent_id: None,
            tokens: TokenBudget::new(1000),
            time: TimeBudget::new(Duration::from_secs(10)),
            money: MoneyBudget::from_usd(1.0),
            iterations: IterationBudget::new(10),
            trace: Vec::new(),
        }
    }

    /// Retry policy with a constant 1 ms wait so retry tests stay fast.
    fn fast_policy(max_attempts: u32) -> RetryPolicy {
        RetryPolicy {
            max_attempts,
            initial_backoff: Duration::from_millis(1),
            backoff_multiplier: 1.0,
            max_backoff: Duration::from_millis(1),
        }
    }

    #[tokio::test]
    async fn retry_succeeds_after_two_failures() {
        let calls = Arc::new(AtomicU32::new(0));
        let flaky = {
            let calls = Arc::clone(&calls);
            Arc::new(FnCallable::labeled("flaky", move |v: Value, _ctx| {
                let calls = Arc::clone(&calls);
                async move {
                    let attempt = calls.fetch_add(1, Ordering::SeqCst);
                    if attempt >= 2 {
                        Ok(v)
                    } else {
                        Err(AgentError::Internal(format!("attempt {attempt} failed")))
                    }
                }
            }))
        };
        let out = with_retry(flaky, fast_policy(5))
            .call(Value::from("ok"), fresh_ctx())
            .await
            .unwrap();
        assert_eq!(out, Value::from("ok"));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retry_exhausts_then_errors() {
        let always_fail = Arc::new(FnCallable::labeled("nope", |_v: Value, _ctx| async {
            Err::<Value, _>(AgentError::Internal("boom".into()))
        }));
        let outcome = with_retry(always_fail, fast_policy(2))
            .call(Value::Null, fresh_ctx())
            .await;
        assert!(outcome.is_err(), "every attempt fails, so the decorator must fail");
    }

    #[tokio::test]
    async fn fallback_uses_alternate_after_primary_failure() {
        let primary = Arc::new(FnCallable::labeled("p", |_v: Value, _ctx| async {
            Err::<Value, _>(AgentError::Inference("primary down".into()))
        }));
        let echo = Arc::new(FnCallable::labeled("alt", |v: Value, _ctx| async move { Ok(v) }));
        let out = with_fallbacks(primary, vec![echo])
            .call(Value::from(42), fresh_ctx())
            .await
            .unwrap();
        assert_eq!(out, Value::from(42));
    }

    #[tokio::test]
    async fn timeout_fires() {
        let slow = Arc::new(FnCallable::labeled("slow", |_v: Value, _ctx| async {
            tokio::time::sleep(Duration::from_millis(50)).await;
            Ok(Value::Null)
        }));
        let outcome = with_timeout(slow, Duration::from_millis(5))
            .call(Value::Null, fresh_ctx())
            .await;
        assert!(outcome.is_err(), "a 50 ms callable must not finish inside 5 ms");
    }

    #[tokio::test]
    async fn config_pushes_run_name_and_tags_into_trace() {
        // Echoes the trace it receives so the test can inspect what
        // `WithConfig` added.
        let echo_trace = Arc::new(FnCallable::labeled(
            "inner",
            |_v: Value, ctx: CallCtx| async move { Ok(Value::from(serde_json::json!({"trace": ctx.trace}))) },
        ));
        let config = RunConfig {
            run_name: Some("my-run".into()),
            tags: vec!["alpha".into(), "beta".into()],
            metadata: Default::default(),
        };
        let out = with_config(echo_trace, config)
            .call(Value::Null, fresh_ctx())
            .await
            .unwrap();
        let trace: Vec<&str> = out["trace"]
            .as_array()
            .unwrap()
            .iter()
            .map(|entry| entry.as_str().unwrap())
            .collect();
        for expected in ["run:my-run", "tag:alpha", "tag:beta"] {
            assert!(trace.contains(&expected), "missing trace entry {expected}");
        }
    }

    #[tokio::test]
    async fn branch_routes_on_predicate() {
        let big = Arc::new(FnCallable::labeled("big", |_v: Value, _ctx| async {
            Ok(Value::from("big"))
        }));
        let small = Arc::new(FnCallable::labeled("small", |_v: Value, _ctx| async {
            Ok(Value::from("small"))
        }));
        let router = Branch::new(|v: &Value| v.as_i64().unwrap_or(0) > 10, big, small);
        for (input, expected) in [(42, "big"), (1, "small")] {
            let got = router.call(Value::from(input), fresh_ctx()).await.unwrap();
            assert_eq!(got, Value::from(expected));
        }
    }
}