1use std::sync::Arc;
8use std::time::Duration;
9
10use async_trait::async_trait;
11use atomr_agents_core::{AgentError, CallCtx, Result, Value};
12
13use crate::{Callable, CallableHandle, FnCallable};
14
/// Backoff schedule used by [`WithRetry`].
///
/// Between failed attempts the wrapper sleeps; the sleep starts at
/// `initial_backoff`, is multiplied by `backoff_multiplier` after each retry,
/// and the grown delay is capped at `max_backoff`.
#[derive(Debug, Clone, Copy)]
pub struct RetryPolicy {
    /// Maximum number of calls to the wrapped callable, counting the first one.
    pub max_attempts: u32,
    /// Delay before the first retry.
    pub initial_backoff: Duration,
    /// Factor applied to the delay after every retry.
    pub backoff_multiplier: f32,
    /// Upper bound on the grown delay.
    pub max_backoff: Duration,
}

impl Default for RetryPolicy {
    /// Three attempts, 50 ms initial backoff doubling per retry, capped at 5 s.
    fn default() -> Self {
        RetryPolicy {
            max_attempts: 3,
            backoff_multiplier: 2.0,
            initial_backoff: Duration::from_millis(50),
            max_backoff: Duration::from_millis(5_000),
        }
    }
}
37
38pub struct WithRetry {
39 inner: CallableHandle,
40 policy: RetryPolicy,
41 label: String,
42}
43
44impl WithRetry {
45 pub fn new(inner: CallableHandle, policy: RetryPolicy) -> Self {
46 let label = format!("retry({})", inner.label());
47 Self { inner, policy, label }
48 }
49}
50
51pub fn with_retry(inner: CallableHandle, policy: RetryPolicy) -> CallableHandle {
52 Arc::new(WithRetry::new(inner, policy))
53}
54
55#[async_trait]
56impl Callable for WithRetry {
57 async fn call(&self, input: Value, ctx: CallCtx) -> Result<Value> {
58 let mut delay = self.policy.initial_backoff;
59 let mut last_err: Option<AgentError> = None;
60 for attempt in 0..self.policy.max_attempts {
61 match self.inner.call(input.clone(), ctx.clone()).await {
62 Ok(v) => return Ok(v),
63 Err(e) => {
64 last_err = Some(e);
65 if attempt + 1 == self.policy.max_attempts {
66 break;
67 }
68 tokio::time::sleep(delay).await;
69 let next_ms = (delay.as_millis() as f32 * self.policy.backoff_multiplier) as u64;
70 delay = Duration::from_millis(next_ms).min(self.policy.max_backoff);
71 }
72 }
73 }
74 Err(last_err.unwrap_or_else(|| AgentError::Internal("retry exhausted with no error".into())))
75 }
76
77 fn label(&self) -> &str {
78 &self.label
79 }
80}
81
82pub struct WithFallbacks {
87 primary: CallableHandle,
88 alternates: Vec<CallableHandle>,
89 label: String,
90}
91
92impl WithFallbacks {
93 pub fn new(primary: CallableHandle, alternates: Vec<CallableHandle>) -> Self {
94 let label = format!("fallback({})", primary.label());
95 Self {
96 primary,
97 alternates,
98 label,
99 }
100 }
101}
102
103pub fn with_fallbacks(primary: CallableHandle, alternates: Vec<CallableHandle>) -> CallableHandle {
104 Arc::new(WithFallbacks::new(primary, alternates))
105}
106
107#[async_trait]
108impl Callable for WithFallbacks {
109 async fn call(&self, input: Value, ctx: CallCtx) -> Result<Value> {
110 if let Ok(v) = self.primary.call(input.clone(), ctx.clone()).await {
111 return Ok(v);
112 }
113 let mut last_err = None;
114 for alt in &self.alternates {
115 match alt.call(input.clone(), ctx.clone()).await {
116 Ok(v) => return Ok(v),
117 Err(e) => last_err = Some(e),
118 }
119 }
120 Err(last_err.unwrap_or_else(|| AgentError::Internal("fallbacks exhausted".into())))
121 }
122
123 fn label(&self) -> &str {
124 &self.label
125 }
126}
127
128#[derive(Clone, Default, Debug)]
133pub struct RunConfig {
134 pub run_name: Option<String>,
135 pub tags: Vec<String>,
136 pub metadata: serde_json::Map<String, Value>,
138}
139
140pub struct WithConfig {
141 inner: CallableHandle,
142 config: RunConfig,
143 label: String,
144}
145
146impl WithConfig {
147 pub fn new(inner: CallableHandle, config: RunConfig) -> Self {
148 let label = config
149 .run_name
150 .clone()
151 .unwrap_or_else(|| format!("config({})", inner.label()));
152 Self { inner, config, label }
153 }
154}
155
156pub fn with_config(inner: CallableHandle, config: RunConfig) -> CallableHandle {
157 Arc::new(WithConfig::new(inner, config))
158}
159
160#[async_trait]
161impl Callable for WithConfig {
162 async fn call(&self, input: Value, ctx: CallCtx) -> Result<Value> {
163 let mut ctx = ctx;
164 if let Some(name) = &self.config.run_name {
165 ctx.trace.push(format!("run:{name}"));
166 }
167 for t in &self.config.tags {
168 ctx.trace.push(format!("tag:{t}"));
169 }
170 self.inner.call(input, ctx).await
171 }
172
173 fn label(&self) -> &str {
174 &self.label
175 }
176}
177
178pub struct WithTimeout {
183 inner: CallableHandle,
184 duration: Duration,
185 label: String,
186}
187
188impl WithTimeout {
189 pub fn new(inner: CallableHandle, duration: Duration) -> Self {
190 let label = format!("timeout({})", inner.label());
191 Self {
192 inner,
193 duration,
194 label,
195 }
196 }
197}
198
199pub fn with_timeout(inner: CallableHandle, duration: Duration) -> CallableHandle {
200 Arc::new(WithTimeout::new(inner, duration))
201}
202
203#[async_trait]
204impl Callable for WithTimeout {
205 async fn call(&self, input: Value, ctx: CallCtx) -> Result<Value> {
206 match tokio::time::timeout(self.duration, self.inner.call(input, ctx)).await {
207 Ok(r) => r,
208 Err(_) => Err(AgentError::Internal(format!(
209 "timed out after {:?}",
210 self.duration
211 ))),
212 }
213 }
214
215 fn label(&self) -> &str {
216 &self.label
217 }
218}
219
220pub struct Branch {
225 predicate: Arc<dyn Fn(&Value) -> bool + Send + Sync + 'static>,
226 if_true: CallableHandle,
227 if_false: CallableHandle,
228 label: String,
229}
230
231impl Branch {
232 pub fn new<F>(predicate: F, if_true: CallableHandle, if_false: CallableHandle) -> Self
233 where
234 F: Fn(&Value) -> bool + Send + Sync + 'static,
235 {
236 let label = format!("branch({} | {})", if_true.label(), if_false.label());
237 Self {
238 predicate: Arc::new(predicate),
239 if_true,
240 if_false,
241 label,
242 }
243 }
244}
245
246#[async_trait]
247impl Callable for Branch {
248 async fn call(&self, input: Value, ctx: CallCtx) -> Result<Value> {
249 if (self.predicate)(&input) {
250 self.if_true.call(input, ctx).await
251 } else {
252 self.if_false.call(input, ctx).await
253 }
254 }
255
256 fn label(&self) -> &str {
257 &self.label
258 }
259}
260
/// Alias for [`FnCallable`], so closure-backed callables can be referred to
/// by the name `Lambda`.
pub type Lambda<F> = FnCallable<F>;
267
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_agents_core::{IterationBudget, MoneyBudget, TimeBudget, TokenBudget};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Fresh call context with generous budgets and an empty trace.
    fn ctx() -> CallCtx {
        CallCtx {
            agent_id: None,
            tokens: TokenBudget::new(1000),
            time: TimeBudget::new(Duration::from_secs(10)),
            money: MoneyBudget::from_usd(1.0),
            iterations: IterationBudget::new(10),
            trace: vec![],
        }
    }

    /// Retry policy with a constant 1 ms backoff so retry tests stay fast.
    fn fast_policy(max_attempts: u32) -> RetryPolicy {
        RetryPolicy {
            max_attempts,
            initial_backoff: Duration::from_millis(1),
            backoff_multiplier: 1.0,
            max_backoff: Duration::from_millis(1),
        }
    }

    #[tokio::test]
    async fn retry_succeeds_after_two_failures() {
        let count = Arc::new(AtomicU32::new(0));
        let count_clone = count.clone();
        let flaky = Arc::new(FnCallable::labeled("flaky", move |v: Value, _ctx| {
            let count = count_clone.clone();
            async move {
                let n = count.fetch_add(1, Ordering::SeqCst);
                if n < 2 {
                    Err(AgentError::Internal(format!("attempt {n} failed")))
                } else {
                    Ok(v)
                }
            }
        }));
        let retried = with_retry(flaky, fast_policy(5));
        let out = retried.call(Value::from("ok"), ctx()).await.unwrap();
        assert_eq!(out, Value::from("ok"));
        assert_eq!(count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retry_exhausts_then_errors() {
        let count = Arc::new(AtomicU32::new(0));
        let count_clone = count.clone();
        let always_fail = Arc::new(FnCallable::labeled("nope", move |_v: Value, _ctx| {
            let count = count_clone.clone();
            async move {
                count.fetch_add(1, Ordering::SeqCst);
                Err::<Value, _>(AgentError::Internal("boom".into()))
            }
        }));
        let retried = with_retry(always_fail, fast_policy(2));
        let err = retried.call(Value::Null, ctx()).await.unwrap_err();
        // The wrapped callable's own error surfaces after exactly `max_attempts` calls.
        assert!(matches!(err, AgentError::Internal(ref m) if m == "boom"));
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn fallback_uses_alternate_after_primary_failure() {
        let primary = Arc::new(FnCallable::labeled("p", |_v: Value, _ctx| async {
            Err::<Value, _>(AgentError::Inference("primary down".into()))
        }));
        let alt = Arc::new(FnCallable::labeled("alt", |v: Value, _ctx| async move { Ok(v) }));
        let composed = with_fallbacks(primary, vec![alt]);
        let out = composed.call(Value::from(42), ctx()).await.unwrap();
        assert_eq!(out, Value::from(42));
    }

    #[tokio::test]
    async fn fallback_skips_alternates_when_primary_succeeds() {
        let alt_calls = Arc::new(AtomicU32::new(0));
        let alt_calls_clone = alt_calls.clone();
        let primary = Arc::new(FnCallable::labeled("p", |v: Value, _ctx| async move { Ok(v) }));
        let alt = Arc::new(FnCallable::labeled("alt", move |v: Value, _ctx| {
            let alt_calls = alt_calls_clone.clone();
            async move {
                alt_calls.fetch_add(1, Ordering::SeqCst);
                Ok(v)
            }
        }));
        let composed = with_fallbacks(primary, vec![alt]);
        let out = composed.call(Value::from(7), ctx()).await.unwrap();
        assert_eq!(out, Value::from(7));
        assert_eq!(alt_calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn fallback_returns_last_error_when_all_fail() {
        let primary = Arc::new(FnCallable::labeled("p", |_v: Value, _ctx| async {
            Err::<Value, _>(AgentError::Inference("primary down".into()))
        }));
        // Explicitly typed so that two distinct closure types share one Vec.
        let first: CallableHandle = Arc::new(FnCallable::labeled("a1", |_v: Value, _ctx| async {
            Err::<Value, _>(AgentError::Internal("a1 down".into()))
        }));
        let second: CallableHandle = Arc::new(FnCallable::labeled("a2", |_v: Value, _ctx| async {
            Err::<Value, _>(AgentError::Internal("a2 down".into()))
        }));
        let composed = with_fallbacks(primary, vec![first, second]);
        let err = composed.call(Value::Null, ctx()).await.unwrap_err();
        assert!(matches!(err, AgentError::Internal(ref m) if m == "a2 down"));
    }

    #[tokio::test]
    async fn timeout_fires() {
        let slow = Arc::new(FnCallable::labeled("slow", |_v: Value, _ctx| async {
            tokio::time::sleep(Duration::from_millis(50)).await;
            Ok(Value::Null)
        }));
        let bounded = with_timeout(slow, Duration::from_millis(5));
        let r = bounded.call(Value::Null, ctx()).await;
        assert!(r.is_err());
    }

    #[tokio::test]
    async fn config_pushes_run_name_and_tags_into_trace() {
        let inner = Arc::new(FnCallable::labeled(
            "inner",
            |_v: Value, ctx: CallCtx| async move { Ok(Value::from(serde_json::json!({"trace": ctx.trace}))) },
        ));
        let configured = with_config(
            inner,
            RunConfig {
                run_name: Some("my-run".into()),
                tags: vec!["alpha".into(), "beta".into()],
                metadata: Default::default(),
            },
        );
        let out = configured.call(Value::Null, ctx()).await.unwrap();
        let trace = out["trace"].as_array().unwrap();
        let s: Vec<String> = trace.iter().map(|v| v.as_str().unwrap().to_string()).collect();
        // Starting from an empty trace, the run name precedes the tags in declaration order.
        assert_eq!(s, vec!["run:my-run", "tag:alpha", "tag:beta"]);
    }

    #[tokio::test]
    async fn branch_routes_on_predicate() {
        let big = Arc::new(FnCallable::labeled("big", |_v: Value, _ctx| async {
            Ok(Value::from("big"))
        }));
        let small = Arc::new(FnCallable::labeled("small", |_v: Value, _ctx| async {
            Ok(Value::from("small"))
        }));
        let b = Branch::new(|v: &Value| v.as_i64().unwrap_or(0) > 10, big, small);
        assert_eq!(b.call(Value::from(42), ctx()).await.unwrap(), Value::from("big"));
        assert_eq!(b.call(Value::from(1), ctx()).await.unwrap(), Value::from("small"));
    }
}