atomr_agents_callable/
lib.rs1mod decorators;
4mod pipeline;
5
6pub use decorators::{
7 with_config, with_fallbacks, with_retry, with_timeout, Branch, Lambda, RetryPolicy, RunConfig,
8 WithConfig, WithFallbacks, WithRetry, WithTimeout,
9};
10pub use pipeline::{fan_out, Pipeline};
11
12use std::sync::Arc;
13
14use async_trait::async_trait;
15use atomr_agents_core::{CallCtx, Result, Value};
16
17#[async_trait]
21pub trait Callable: Send + Sync + 'static {
22 async fn call(&self, input: Value, ctx: CallCtx) -> Result<Value>;
23
24 fn label(&self) -> &str {
27 std::any::type_name::<Self>()
28 }
29}
30
31pub type CallableHandle = Arc<dyn Callable>;
34
/// Adapter that exposes an async closure `Fn(Value, CallCtx) -> impl Future`
/// as a [`Callable`] when `F` meets the bounds of its `Callable` impl.
///
/// Build one with [`FnCallable::new`] (label `"fn"`) or with
/// [`FnCallable::labeled`] to choose the label reported by [`Callable::label`].
pub struct FnCallable<F> {
    /// The wrapped closure; it is invoked once per [`Callable::call`].
    inner: F,
    /// Static label returned from [`Callable::label`].
    label: &'static str,
}

45impl<F> FnCallable<F> {
46 pub fn new(f: F) -> Self {
47 Self {
48 inner: f,
49 label: "fn",
50 }
51 }
52
53 pub fn labeled(label: &'static str, f: F) -> Self {
54 Self { inner: f, label }
55 }
56}
57
58#[async_trait]
59impl<F, Fut> Callable for FnCallable<F>
60where
61 F: Fn(Value, CallCtx) -> Fut + Send + Sync + 'static,
62 Fut: std::future::Future<Output = Result<Value>> + Send + 'static,
63{
64 async fn call(&self, input: Value, ctx: CallCtx) -> Result<Value> {
65 (self.inner)(input, ctx).await
66 }
67
68 fn label(&self) -> &str {
69 self.label
70 }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_agents_core::{IterationBudget, MoneyBudget, TimeBudget, TokenBudget};
    use std::time::Duration;

    /// Builds a fresh call context with generous budgets for unit tests.
    fn ctx() -> CallCtx {
        CallCtx {
            agent_id: None,
            tokens: TokenBudget::new(1000),
            time: TimeBudget::new(Duration::from_secs(10)),
            money: MoneyBudget::from_usd(1.0),
            iterations: IterationBudget::new(10),
            trace: vec![],
        }
    }

    /// Hand-written implementor that does not override `label`, so it
    /// exercises the trait's default implementation.
    struct Echo;

    #[async_trait]
    impl Callable for Echo {
        async fn call(&self, input: Value, _ctx: CallCtx) -> Result<Value> {
            Ok(input)
        }
    }

    #[tokio::test]
    async fn fn_callable_round_trips() {
        let c = FnCallable::new(|input: Value, _ctx| async move { Ok(input) });
        let v = serde_json::json!({"hello": "world"});
        let out = c.call(v.clone(), ctx()).await.unwrap();
        assert_eq!(out, v);
        // `new` must report the generic "fn" label.
        assert_eq!(c.label(), "fn");
    }

    #[tokio::test]
    async fn handle_is_dyn_safe() {
        let h: CallableHandle =
            Arc::new(FnCallable::labeled("echo", |input: Value, _ctx| async move {
                Ok(input)
            }));
        let out = h.call(serde_json::json!(42), ctx()).await.unwrap();
        assert_eq!(out, serde_json::json!(42));
        assert_eq!(h.label(), "echo");
    }

    #[tokio::test]
    async fn default_label_names_the_implementing_type() {
        let e = Echo;
        let out = e.call(serde_json::json!("ping"), ctx()).await.unwrap();
        assert_eq!(out, serde_json::json!("ping"));
        // `type_name`'s exact format is unspecified, so only check that the
        // type's own name appears in it.
        assert!(e.label().contains("Echo"), "unexpected label: {}", e.label());
    }
}