1use std::sync::Arc;
8use std::time::Duration;
9
10use async_trait::async_trait;
11use atomr_agents_core::{AgentError, CallCtx, Result, Value};
12
13use crate::{Callable, CallableHandle, FnCallable};
14
/// Settings that govern how [`WithRetry`] repeats a failing call.
#[derive(Clone, Copy, Debug)]
pub struct RetryPolicy {
    /// Total number of times the wrapped callable may be invoked
    /// (the first try counts as one attempt).
    pub max_attempts: u32,
    /// Pause inserted between the first failure and the second attempt.
    pub initial_backoff: Duration,
    /// Factor by which the pause is scaled after each further failure.
    pub backoff_multiplier: f32,
    /// Ceiling applied to every pause that is derived by scaling.
    pub max_backoff: Duration,
}

impl Default for RetryPolicy {
    /// Three attempts, a 50 ms first pause that doubles each time, capped at 5 s.
    fn default() -> Self {
        let initial_backoff = Duration::from_millis(50);
        let max_backoff = Duration::from_secs(5);
        RetryPolicy {
            max_attempts: 3,
            initial_backoff,
            backoff_multiplier: 2.0,
            max_backoff,
        }
    }
}
37
38pub struct WithRetry {
39 inner: CallableHandle,
40 policy: RetryPolicy,
41 label: String,
42}
43
44impl WithRetry {
45 pub fn new(inner: CallableHandle, policy: RetryPolicy) -> Self {
46 let label = format!("retry({})", inner.label());
47 Self { inner, policy, label }
48 }
49}
50
51pub fn with_retry(inner: CallableHandle, policy: RetryPolicy) -> CallableHandle {
52 Arc::new(WithRetry::new(inner, policy))
53}
54
55#[async_trait]
56impl Callable for WithRetry {
57 async fn call(&self, input: Value, ctx: CallCtx) -> Result<Value> {
58 let mut delay = self.policy.initial_backoff;
59 let mut last_err: Option<AgentError> = None;
60 for attempt in 0..self.policy.max_attempts {
61 match self.inner.call(input.clone(), ctx.clone()).await {
62 Ok(v) => return Ok(v),
63 Err(e) => {
64 last_err = Some(e);
65 if attempt + 1 == self.policy.max_attempts {
66 break;
67 }
68 tokio::time::sleep(delay).await;
69 let next_ms = (delay.as_millis() as f32 * self.policy.backoff_multiplier) as u64;
70 delay = Duration::from_millis(next_ms).min(self.policy.max_backoff);
71 }
72 }
73 }
74 Err(last_err.unwrap_or_else(|| AgentError::Internal("retry exhausted with no error".into())))
75 }
76
77 fn label(&self) -> &str {
78 &self.label
79 }
80}
81
82pub struct WithFallbacks {
87 primary: CallableHandle,
88 alternates: Vec<CallableHandle>,
89 label: String,
90}
91
92impl WithFallbacks {
93 pub fn new(primary: CallableHandle, alternates: Vec<CallableHandle>) -> Self {
94 let label = format!("fallback({})", primary.label());
95 Self {
96 primary,
97 alternates,
98 label,
99 }
100 }
101}
102
103pub fn with_fallbacks(primary: CallableHandle, alternates: Vec<CallableHandle>) -> CallableHandle {
104 Arc::new(WithFallbacks::new(primary, alternates))
105}
106
107#[async_trait]
108impl Callable for WithFallbacks {
109 async fn call(&self, input: Value, ctx: CallCtx) -> Result<Value> {
110 if let Ok(v) = self.primary.call(input.clone(), ctx.clone()).await {
111 return Ok(v);
112 }
113 let mut last_err = None;
114 for alt in &self.alternates {
115 match alt.call(input.clone(), ctx.clone()).await {
116 Ok(v) => return Ok(v),
117 Err(e) => last_err = Some(e),
118 }
119 }
120 Err(last_err.unwrap_or_else(|| AgentError::Internal("fallbacks exhausted".into())))
121 }
122
123 fn label(&self) -> &str {
124 &self.label
125 }
126}
127
/// Per-run annotations attached to a callable through [`WithConfig`].
#[derive(Clone, Default, Debug)]
pub struct RunConfig {
    /// Optional run name. When set, `WithConfig` uses it as its label and
    /// pushes `run:<name>` onto the call context's trace.
    pub run_name: Option<String>,
    /// Tags; `WithConfig` pushes each one onto the trace as `tag:<tag>`.
    pub tags: Vec<String>,
    /// Free-form key/value data. NOTE(review): nothing in this file reads
    /// this field — confirm where (or whether) it is consumed.
    pub metadata: serde_json::Map<String, Value>,
}
139
140pub struct WithConfig {
141 inner: CallableHandle,
142 config: RunConfig,
143 label: String,
144}
145
146impl WithConfig {
147 pub fn new(inner: CallableHandle, config: RunConfig) -> Self {
148 let label = config
149 .run_name
150 .clone()
151 .unwrap_or_else(|| format!("config({})", inner.label()));
152 Self { inner, config, label }
153 }
154}
155
156pub fn with_config(inner: CallableHandle, config: RunConfig) -> CallableHandle {
157 Arc::new(WithConfig::new(inner, config))
158}
159
160#[async_trait]
161impl Callable for WithConfig {
162 async fn call(&self, input: Value, ctx: CallCtx) -> Result<Value> {
163 let mut ctx = ctx;
164 if let Some(name) = &self.config.run_name {
165 ctx.trace.push(format!("run:{name}"));
166 }
167 for t in &self.config.tags {
168 ctx.trace.push(format!("tag:{t}"));
169 }
170 self.inner.call(input, ctx).await
171 }
172
173 fn label(&self) -> &str {
174 &self.label
175 }
176}
177
178pub struct WithTimeout {
183 inner: CallableHandle,
184 duration: Duration,
185 label: String,
186}
187
188impl WithTimeout {
189 pub fn new(inner: CallableHandle, duration: Duration) -> Self {
190 let label = format!("timeout({})", inner.label());
191 Self {
192 inner,
193 duration,
194 label,
195 }
196 }
197}
198
199pub fn with_timeout(inner: CallableHandle, duration: Duration) -> CallableHandle {
200 Arc::new(WithTimeout::new(inner, duration))
201}
202
203#[async_trait]
204impl Callable for WithTimeout {
205 async fn call(&self, input: Value, ctx: CallCtx) -> Result<Value> {
206 match tokio::time::timeout(self.duration, self.inner.call(input, ctx)).await {
207 Ok(r) => r,
208 Err(_) => Err(AgentError::Internal(format!(
209 "timed out after {:?}",
210 self.duration
211 ))),
212 }
213 }
214
215 fn label(&self) -> &str {
216 &self.label
217 }
218}
219
220pub struct Branch {
225 predicate: Arc<dyn Fn(&Value) -> bool + Send + Sync + 'static>,
226 if_true: CallableHandle,
227 if_false: CallableHandle,
228 label: String,
229}
230
231impl Branch {
232 pub fn new<F>(predicate: F, if_true: CallableHandle, if_false: CallableHandle) -> Self
233 where
234 F: Fn(&Value) -> bool + Send + Sync + 'static,
235 {
236 let label = format!("branch({} | {})", if_true.label(), if_false.label());
237 Self {
238 predicate: Arc::new(predicate),
239 if_true,
240 if_false,
241 label,
242 }
243 }
244}
245
246#[async_trait]
247impl Callable for Branch {
248 async fn call(&self, input: Value, ctx: CallCtx) -> Result<Value> {
249 if (self.predicate)(&input) {
250 self.if_true.call(input, ctx).await
251 } else {
252 self.if_false.call(input, ctx).await
253 }
254 }
255
256 fn label(&self) -> &str {
257 &self.label
258 }
259}
260
/// Alias that exposes [`FnCallable`] under the name `Lambda`.
pub type Lambda<F> = FnCallable<F>;
267
// Unit tests for the combinators in this file. Each test builds small
// `FnCallable` stubs and drives them through one wrapper.
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_agents_core::{IterationBudget, MoneyBudget, TimeBudget, TokenBudget};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Builds a fresh `CallCtx` with fixed budgets and an empty trace.
    fn ctx() -> CallCtx {
        CallCtx {
            agent_id: None,
            tokens: TokenBudget::new(1000),
            time: TimeBudget::new(Duration::from_secs(10)),
            money: MoneyBudget::from_usd(1.0),
            iterations: IterationBudget::new(10),
            trace: vec![],
            extensions: Default::default(),
        }
    }

    /// A callable that fails twice and then succeeds is retried to success,
    /// and is invoked exactly three times.
    #[tokio::test]
    async fn retry_succeeds_after_two_failures() {
        // Shared counter so the test can observe how often the stub ran.
        let count = Arc::new(AtomicU32::new(0));
        let count_clone = count.clone();
        let flaky = Arc::new(FnCallable::labeled("flaky", move |v: Value, _ctx| {
            let count = count_clone.clone();
            async move {
                let n = count.fetch_add(1, Ordering::SeqCst);
                // Calls 0 and 1 fail; the third call echoes the input.
                if n < 2 {
                    Err(AgentError::Internal(format!("attempt {n} failed")))
                } else {
                    Ok(v)
                }
            }
        }));
        // 1 ms backoff with multiplier 1.0 keeps the test fast.
        let retried = with_retry(
            flaky,
            RetryPolicy {
                max_attempts: 5,
                initial_backoff: Duration::from_millis(1),
                backoff_multiplier: 1.0,
                max_backoff: Duration::from_millis(1),
            },
        );
        let out = retried.call(Value::from("ok"), ctx()).await.unwrap();
        assert_eq!(out, Value::from("ok"));
        assert_eq!(count.load(Ordering::SeqCst), 3);
    }

    /// A callable that always fails yields an error once attempts run out.
    #[tokio::test]
    async fn retry_exhausts_then_errors() {
        let always_fail = Arc::new(FnCallable::labeled("nope", |_v: Value, _ctx| async {
            Err::<Value, _>(AgentError::Internal("boom".into()))
        }));
        let retried = with_retry(
            always_fail,
            RetryPolicy {
                max_attempts: 2,
                initial_backoff: Duration::from_millis(1),
                backoff_multiplier: 1.0,
                max_backoff: Duration::from_millis(1),
            },
        );
        let r = retried.call(Value::Null, ctx()).await;
        assert!(r.is_err());
    }

    /// When the primary fails, the alternate's result is returned.
    #[tokio::test]
    async fn fallback_uses_alternate_after_primary_failure() {
        let primary = Arc::new(FnCallable::labeled("p", |_v: Value, _ctx| async {
            Err::<Value, _>(AgentError::Inference("primary down".into()))
        }));
        // The alternate simply echoes its input.
        let alt = Arc::new(FnCallable::labeled("alt", |v: Value, _ctx| async move { Ok(v) }));
        let composed = with_fallbacks(primary, vec![alt]);
        let out = composed.call(Value::from(42), ctx()).await.unwrap();
        assert_eq!(out, Value::from(42));
    }

    /// A 50 ms callable under a 5 ms timeout returns an error.
    #[tokio::test]
    async fn timeout_fires() {
        let slow = Arc::new(FnCallable::labeled("slow", |_v: Value, _ctx| async {
            tokio::time::sleep(Duration::from_millis(50)).await;
            Ok(Value::Null)
        }));
        let bounded = with_timeout(slow, Duration::from_millis(5));
        let r = bounded.call(Value::Null, ctx()).await;
        assert!(r.is_err());
    }

    /// The run name and every tag show up in the trace seen by the inner callable.
    #[tokio::test]
    async fn config_pushes_run_name_and_tags_into_trace() {
        // The stub returns the trace it received so the test can inspect it.
        let inner = Arc::new(FnCallable::labeled(
            "inner",
            |_v: Value, ctx: CallCtx| async move { Ok(Value::from(serde_json::json!({"trace": ctx.trace}))) },
        ));
        let configured = with_config(
            inner,
            RunConfig {
                run_name: Some("my-run".into()),
                tags: vec!["alpha".into(), "beta".into()],
                metadata: Default::default(),
            },
        );
        let out = configured.call(Value::Null, ctx()).await.unwrap();
        let trace = out["trace"].as_array().unwrap();
        let s: Vec<String> = trace.iter().map(|v| v.as_str().unwrap().to_string()).collect();
        assert!(s.contains(&"run:my-run".to_string()));
        assert!(s.contains(&"tag:alpha".to_string()));
        assert!(s.contains(&"tag:beta".to_string()));
    }

    /// Inputs above 10 go to the `big` callable, everything else to `small`.
    #[tokio::test]
    async fn branch_routes_on_predicate() {
        let big = Arc::new(FnCallable::labeled("big", |_v: Value, _ctx| async {
            Ok(Value::from("big"))
        }));
        let small = Arc::new(FnCallable::labeled("small", |_v: Value, _ctx| async {
            Ok(Value::from("small"))
        }));
        // Non-integer inputs are treated as 0 by the predicate.
        let b = Branch::new(|v: &Value| v.as_i64().unwrap_or(0) > 10, big, small);
        assert_eq!(b.call(Value::from(42), ctx()).await.unwrap(), Value::from("big"));
        assert_eq!(b.call(Value::from(1), ctx()).await.unwrap(), Value::from("small"));
    }
}