Skip to main content

bamboo_agent/agent/core/composition/
expr.rs

//! Tool Expression DSL - a serializable tool composition language
//!
//! This module provides a declarative DSL for composing tool calls that can be
//! serialized to, and deserialized from, both YAML and JSON.
5
6use crate::agent::core::tools::ToolError;
7use serde::{Deserialize, Serialize};
8
9use super::condition::Condition;
10use super::parallel::ParallelWait;
11
12/// Tool expression DSL for composing tool calls
13///
14/// This enum represents the AST (Abstract Syntax Tree) for the tool composition DSL.
15/// Each variant represents a different composition operation.
16#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
17#[serde(rename_all = "snake_case", tag = "type")]
18pub enum ToolExpr {
19    /// Execute a single tool call
20    Call {
21        /// Name of the tool to execute
22        tool: String,
23        /// Arguments to pass to the tool (JSON object)
24        args: serde_json::Value,
25    },
26    /// Execute a sequence of expressions
27    Sequence {
28        /// Steps to execute in order
29        steps: Vec<ToolExpr>,
30        /// Whether to stop on first error (default: true)
31        #[serde(default = "default_fail_fast")]
32        fail_fast: bool,
33    },
34    /// Execute branches in parallel
35    Parallel {
36        /// Branches to execute concurrently
37        branches: Vec<ToolExpr>,
38        /// Wait strategy: All, First, or Any
39        #[serde(default)]
40        wait: ParallelWait,
41    },
42    /// Conditional execution
43    Choice {
44        /// Condition to evaluate
45        condition: Condition,
46        /// Expression to execute if condition is true
47        then_branch: Box<ToolExpr>,
48        /// Expression to execute if condition is false
49        else_branch: Option<Box<ToolExpr>>,
50    },
51    /// Retry with backoff
52    Retry {
53        /// Expression to retry
54        expr: Box<ToolExpr>,
55        /// Maximum number of retry attempts (default: 3)
56        #[serde(default = "default_max_attempts")]
57        max_attempts: u32,
58        /// Delay between retries in milliseconds (default: 1000)
59        #[serde(default = "default_delay_ms")]
60        delay_ms: u64,
61    },
62    /// Variable binding
63    Let {
64        /// Variable name
65        var: String,
66        /// Expression to bind
67        expr: Box<ToolExpr>,
68        /// Body expression that uses the variable
69        body: Box<ToolExpr>,
70    },
71    /// Variable reference
72    Var(String),
73}
74
/// Serde default for `ToolExpr::Sequence::fail_fast` when the field is omitted:
/// stop the sequence at the first error.
const fn default_fail_fast() -> bool {
    true
}

/// Serde default for `ToolExpr::Retry::max_attempts` when the field is omitted.
const fn default_max_attempts() -> u32 {
    3
}

/// Serde default for `ToolExpr::Retry::delay_ms` (milliseconds) when the field is omitted.
const fn default_delay_ms() -> u64 {
    1000
}
86
87impl ToolExpr {
88    /// Create a simple tool call expression
89    pub fn call(tool: impl Into<String>, args: serde_json::Value) -> Self {
90        ToolExpr::Call {
91            tool: tool.into(),
92            args,
93        }
94    }
95
96    /// Create a sequence expression with fail_fast=true
97    pub fn sequence(steps: Vec<ToolExpr>) -> Self {
98        ToolExpr::Sequence {
99            steps,
100            fail_fast: true,
101        }
102    }
103
104    /// Create a sequence expression with custom fail_fast
105    pub fn sequence_with_fail_fast(steps: Vec<ToolExpr>, fail_fast: bool) -> Self {
106        ToolExpr::Sequence { steps, fail_fast }
107    }
108
109    /// Create a parallel expression
110    pub fn parallel(branches: Vec<ToolExpr>) -> Self {
111        ToolExpr::Parallel {
112            branches,
113            wait: ParallelWait::All,
114        }
115    }
116
117    /// Create a parallel expression with custom wait strategy
118    pub fn parallel_with_wait(branches: Vec<ToolExpr>, wait: ParallelWait) -> Self {
119        ToolExpr::Parallel { branches, wait }
120    }
121
122    /// Create a conditional expression
123    pub fn choice(condition: Condition, then_branch: ToolExpr) -> Self {
124        ToolExpr::Choice {
125            condition,
126            then_branch: Box::new(then_branch),
127            else_branch: None,
128        }
129    }
130
131    /// Create a conditional expression with else branch
132    pub fn choice_with_else(
133        condition: Condition,
134        then_branch: ToolExpr,
135        else_branch: ToolExpr,
136    ) -> Self {
137        ToolExpr::Choice {
138            condition,
139            then_branch: Box::new(then_branch),
140            else_branch: Some(Box::new(else_branch)),
141        }
142    }
143
144    /// Create a retry expression with defaults
145    pub fn retry(expr: ToolExpr) -> Self {
146        ToolExpr::Retry {
147            expr: Box::new(expr),
148            max_attempts: 3,
149            delay_ms: 1000,
150        }
151    }
152
153    /// Create a retry expression with custom parameters
154    pub fn retry_with_params(expr: ToolExpr, max_attempts: u32, delay_ms: u64) -> Self {
155        ToolExpr::Retry {
156            expr: Box::new(expr),
157            max_attempts,
158            delay_ms,
159        }
160    }
161
162    /// Create a let binding expression
163    pub fn let_binding(var: impl Into<String>, expr: ToolExpr, body: ToolExpr) -> Self {
164        ToolExpr::Let {
165            var: var.into(),
166            expr: Box::new(expr),
167            body: Box::new(body),
168        }
169    }
170
171    /// Create a variable reference
172    pub fn var(name: impl Into<String>) -> Self {
173        ToolExpr::Var(name.into())
174    }
175
176    /// Serialize to YAML string
177    pub fn to_yaml(&self) -> Result<String, serde_yaml::Error> {
178        serde_yaml::to_string(self)
179    }
180
181    /// Deserialize from YAML string
182    pub fn from_yaml(yaml: &str) -> Result<Self, serde_yaml::Error> {
183        serde_yaml::from_str(yaml)
184    }
185
186    /// Serialize to JSON string
187    pub fn to_json(&self) -> Result<String, serde_json::Error> {
188        serde_json::to_string(self)
189    }
190
191    /// Deserialize from JSON string
192    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
193        serde_json::from_str(json)
194    }
195}
196
197/// Composition error types
198#[derive(Debug, Clone)]
199pub enum CompositionError {
200    ToolError(ToolError),
201    VariableNotFound(String),
202    InvalidExpression(String),
203    MaxRetriesExceeded,
204}
205
206impl std::fmt::Display for CompositionError {
207    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208        match self {
209            CompositionError::ToolError(e) => write!(f, "Tool error: {}", e),
210            CompositionError::VariableNotFound(v) => write!(f, "Variable not found: {}", v),
211            CompositionError::InvalidExpression(e) => write!(f, "Invalid expression: {}", e),
212            CompositionError::MaxRetriesExceeded => write!(f, "Maximum retry attempts exceeded"),
213        }
214    }
215}
216
217impl std::error::Error for CompositionError {
218    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
219        match self {
220            CompositionError::ToolError(e) => Some(e),
221            _ => None,
222        }
223    }
224}
225
226impl From<ToolError> for CompositionError {
227    fn from(e: ToolError) -> Self {
228        CompositionError::ToolError(e)
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_call_expr() {
        let expr = ToolExpr::call("read_file", json!({"path": "/tmp/test"}));

        match expr {
            ToolExpr::Call { tool, args } => {
                assert_eq!(tool, "read_file");
                assert_eq!(args["path"], "/tmp/test");
            }
            _ => panic!("Expected Call variant"),
        }
    }

    #[test]
    fn test_sequence_expr() {
        let steps = vec![
            ToolExpr::call("step1", json!({})),
            ToolExpr::call("step2", json!({})),
        ];
        let expr = ToolExpr::sequence(steps);

        match expr {
            ToolExpr::Sequence { steps, fail_fast } => {
                assert_eq!(steps.len(), 2);
                assert!(fail_fast);
            }
            _ => panic!("Expected Sequence variant"),
        }
    }

    #[test]
    fn test_parallel_expr() {
        let branches = vec![
            ToolExpr::call("branch1", json!({})),
            ToolExpr::call("branch2", json!({})),
        ];
        let expr = ToolExpr::parallel(branches);

        match expr {
            ToolExpr::Parallel { branches, wait } => {
                assert_eq!(branches.len(), 2);
                assert_eq!(wait, ParallelWait::All);
            }
            _ => panic!("Expected Parallel variant"),
        }
    }

    #[test]
    fn test_choice_expr() {
        let condition = Condition::Success;
        let then_branch = ToolExpr::call("success_handler", json!({}));
        let else_branch = ToolExpr::call("failure_handler", json!({}));

        let expr = ToolExpr::choice_with_else(condition, then_branch, else_branch);

        match expr {
            ToolExpr::Choice {
                condition: _,
                then_branch,
                else_branch,
            } => {
                match *then_branch {
                    ToolExpr::Call { tool, .. } => assert_eq!(tool, "success_handler"),
                    _ => panic!("Expected Call in then_branch"),
                }
                match else_branch.map(|branch| *branch) {
                    Some(ToolExpr::Call { tool, .. }) => assert_eq!(tool, "failure_handler"),
                    other => panic!("Expected Call in else_branch, got {:?}", other),
                }
            }
            _ => panic!("Expected Choice variant"),
        }
    }

    #[test]
    fn test_retry_expr() {
        let inner = ToolExpr::call("risky_op", json!({}));
        let expr = ToolExpr::retry_with_params(inner, 5, 500);

        match expr {
            ToolExpr::Retry {
                expr: _,
                max_attempts,
                delay_ms,
            } => {
                assert_eq!(max_attempts, 5);
                assert_eq!(delay_ms, 500);
            }
            _ => panic!("Expected Retry variant"),
        }
    }

    #[test]
    fn test_retry_defaults() {
        let expr = ToolExpr::retry(ToolExpr::call("risky_op", json!({})));

        match expr {
            ToolExpr::Retry {
                max_attempts,
                delay_ms,
                ..
            } => {
                assert_eq!(max_attempts, 3);
                assert_eq!(delay_ms, 1000);
            }
            _ => panic!("Expected Retry variant"),
        }
    }

    #[test]
    fn test_let_expr() {
        let expr = ToolExpr::let_binding(
            "result",
            ToolExpr::call("fetch", json!({"url": "http://example.com"})),
            ToolExpr::call("process", json!({"data": "${result}"})),
        );

        match expr {
            ToolExpr::Let { var, expr, body } => {
                assert_eq!(var, "result");
                assert!(matches!(*expr, ToolExpr::Call { .. }));
                assert!(matches!(*body, ToolExpr::Call { .. }));
            }
            _ => panic!("Expected Let variant"),
        }
    }

    #[test]
    fn test_yaml_roundtrip() {
        let expr = ToolExpr::sequence(vec![
            ToolExpr::call("step1", json!({"arg": 1})),
            ToolExpr::call("step2", json!({"arg": 2})),
        ]);

        let yaml = expr.to_yaml().unwrap();
        let deserialized = ToolExpr::from_yaml(&yaml).unwrap();

        assert_eq!(expr, deserialized);
    }

    #[test]
    fn test_json_roundtrip() {
        let expr = ToolExpr::choice_with_else(
            Condition::Success,
            ToolExpr::call("on_success", json!({})),
            ToolExpr::call("on_failure", json!({})),
        );

        let json_str = expr.to_json().unwrap();
        let deserialized = ToolExpr::from_json(&json_str).unwrap();

        assert_eq!(expr, deserialized);
    }

    #[test]
    fn test_choice_without_else_json_roundtrip() {
        let expr = ToolExpr::choice(Condition::Success, ToolExpr::call("on_success", json!({})));

        let json_str = expr.to_json().unwrap();
        let deserialized = ToolExpr::from_json(&json_str).unwrap();

        assert_eq!(expr, deserialized);
    }

    #[test]
    fn test_yaml_deserialization() {
        let yaml = r#"
type: sequence
steps:
  - type: call
    tool: read_file
    args:
      path: /tmp/test.txt
  - type: call
    tool: process
    args:
      data: "hello"
fail_fast: true
"#;

        let expr: ToolExpr = serde_yaml::from_str(yaml).unwrap();
        match expr {
            ToolExpr::Sequence { steps, fail_fast } => {
                assert!(fail_fast);
                assert_eq!(steps.len(), 2);
                assert_eq!(
                    steps[0],
                    ToolExpr::call("read_file", json!({"path": "/tmp/test.txt"}))
                );
                assert_eq!(steps[1], ToolExpr::call("process", json!({"data": "hello"})));
            }
            other => panic!("Expected Sequence variant, got {:?}", other),
        }
    }

    #[test]
    fn test_yaml_defaults_applied_when_fields_omitted() {
        // `fail_fast`, `max_attempts` and `delay_ms` are all omitted and must
        // fall back to their serde defaults.
        let yaml = r#"
type: sequence
steps:
  - type: retry
    expr:
      type: call
      tool: flaky
      args: {}
"#;

        let expr = ToolExpr::from_yaml(yaml).unwrap();
        match &expr {
            ToolExpr::Sequence { steps, fail_fast } => {
                assert!(*fail_fast);
                match steps.as_slice() {
                    [ToolExpr::Retry {
                        max_attempts,
                        delay_ms,
                        ..
                    }] => {
                        assert_eq!(*max_attempts, 3);
                        assert_eq!(*delay_ms, 1000);
                    }
                    other => panic!("Expected a single Retry step, got {:?}", other),
                }
            }
            other => panic!("Expected Sequence variant, got {:?}", other),
        }

        // Constructors and serde defaults must agree.
        let constructed =
            ToolExpr::sequence(vec![ToolExpr::retry(ToolExpr::call("flaky", json!({})))]);
        assert_eq!(expr, constructed);
    }

    #[test]
    fn test_composition_error_display_and_source() {
        use std::error::Error as _;

        let err = CompositionError::VariableNotFound("result".to_string());
        assert_eq!(err.to_string(), "Variable not found: result");
        assert!(err.source().is_none());

        let err = CompositionError::InvalidExpression("empty sequence".to_string());
        assert_eq!(err.to_string(), "Invalid expression: empty sequence");

        assert_eq!(
            CompositionError::MaxRetriesExceeded.to_string(),
            "Maximum retry attempts exceeded"
        );
    }
}