Skip to main content

bamboo_domain/session/composition/
expr.rs

1//! Tool expression AST types — pure data, no agent dependencies.
2
3use serde::{Deserialize, Serialize};
4
5use super::condition::Condition;
6use super::parallel::ParallelWait;
7
8/// Tool expression DSL for composing tool calls.
9///
10/// This enum represents the AST for the tool composition DSL.
11/// Each variant represents a different composition operation.
12#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
13#[serde(rename_all = "snake_case", tag = "type")]
14pub enum ToolExpr {
15    /// Execute a single tool call
16    Call {
17        /// Name of the tool to execute
18        tool: String,
19        /// Arguments to pass to the tool (JSON object)
20        args: serde_json::Value,
21    },
22    /// Execute a sequence of expressions
23    Sequence {
24        /// Steps to execute in order
25        steps: Vec<ToolExpr>,
26        /// Whether to stop on first error (default: true)
27        #[serde(default = "default_fail_fast")]
28        fail_fast: bool,
29    },
30    /// Execute branches in parallel
31    Parallel {
32        /// Branches to execute concurrently
33        branches: Vec<ToolExpr>,
34        /// Wait strategy: All, First, or Any
35        #[serde(default)]
36        wait: ParallelWait,
37    },
38    /// Conditional execution
39    Choice {
40        /// Condition to evaluate
41        condition: Condition,
42        /// Expression to execute if condition is true
43        then_branch: Box<ToolExpr>,
44        /// Expression to execute if condition is false
45        else_branch: Option<Box<ToolExpr>>,
46    },
47    /// Retry with backoff
48    Retry {
49        /// Expression to retry
50        expr: Box<ToolExpr>,
51        /// Maximum number of retry attempts (default: 3)
52        #[serde(default = "default_max_attempts")]
53        max_attempts: u32,
54        /// Delay between retries in milliseconds (default: 1000)
55        #[serde(default = "default_delay_ms")]
56        delay_ms: u64,
57    },
58    /// Variable binding
59    Let {
60        /// Variable name
61        var: String,
62        /// Expression to bind
63        expr: Box<ToolExpr>,
64        /// Body expression that use the variable
65        body: Box<ToolExpr>,
66    },
67    /// Variable reference
68    Var(String),
69}
70
/// Serde default for `ToolExpr::Sequence::fail_fast`: stop on the first error.
const fn default_fail_fast() -> bool {
    true
}

/// Serde default for `ToolExpr::Retry::max_attempts`.
const fn default_max_attempts() -> u32 {
    3
}

/// Serde default for `ToolExpr::Retry::delay_ms` (milliseconds between retries).
const fn default_delay_ms() -> u64 {
    1000
}
82
83impl ToolExpr {
84    /// Create a simple tool call expression
85    pub fn call(tool: impl Into<String>, args: serde_json::Value) -> Self {
86        ToolExpr::Call {
87            tool: tool.into(),
88            args,
89        }
90    }
91
92    /// Create a sequence expression with fail_fast=true
93    pub fn sequence(steps: Vec<ToolExpr>) -> Self {
94        ToolExpr::Sequence {
95            steps,
96            fail_fast: true,
97        }
98    }
99
100    /// Create a sequence expression with custom fail_fast
101    pub fn sequence_with_fail_fast(steps: Vec<ToolExpr>, fail_fast: bool) -> Self {
102        ToolExpr::Sequence { steps, fail_fast }
103    }
104
105    /// Create a parallel expression
106    pub fn parallel(branches: Vec<ToolExpr>) -> Self {
107        ToolExpr::Parallel {
108            branches,
109            wait: ParallelWait::All,
110        }
111    }
112
113    /// Create a parallel expression with custom wait strategy
114    pub fn parallel_with_wait(branches: Vec<ToolExpr>, wait: ParallelWait) -> Self {
115        ToolExpr::Parallel { branches, wait }
116    }
117
118    /// Create a conditional expression
119    pub fn choice(condition: Condition, then_branch: ToolExpr) -> Self {
120        ToolExpr::Choice {
121            condition,
122            then_branch: Box::new(then_branch),
123            else_branch: None,
124        }
125    }
126
127    /// Create a conditional expression with else branch
128    pub fn choice_with_else(
129        condition: Condition,
130        then_branch: ToolExpr,
131        else_branch: ToolExpr,
132    ) -> Self {
133        ToolExpr::Choice {
134            condition,
135            then_branch: Box::new(then_branch),
136            else_branch: Some(Box::new(else_branch)),
137        }
138    }
139
140    /// Create a retry expression with defaults
141    pub fn retry(expr: ToolExpr) -> Self {
142        ToolExpr::Retry {
143            expr: Box::new(expr),
144            max_attempts: 3,
145            delay_ms: 1000,
146        }
147    }
148
149    /// Create a retry expression with custom parameters
150    pub fn retry_with_params(expr: ToolExpr, max_attempts: u32, delay_ms: u64) -> Self {
151        ToolExpr::Retry {
152            expr: Box::new(expr),
153            max_attempts,
154            delay_ms,
155        }
156    }
157
158    /// Create a let binding expression
159    pub fn let_binding(var: impl Into<String>, expr: ToolExpr, body: ToolExpr) -> Self {
160        ToolExpr::Let {
161            var: var.into(),
162            expr: Box::new(expr),
163            body: Box::new(body),
164        }
165    }
166
167    /// Create a variable reference
168    pub fn var(name: impl Into<String>) -> Self {
169        ToolExpr::Var(name.into())
170    }
171
172    /// Serialize to YAML string
173    pub fn to_yaml(&self) -> Result<String, serde_yaml::Error> {
174        serde_yaml::to_string(self)
175    }
176
177    /// Deserialize from YAML string
178    pub fn from_yaml(yaml: &str) -> Result<Self, serde_yaml::Error> {
179        serde_yaml::from_str(yaml)
180    }
181
182    /// Serialize to JSON string
183    pub fn to_json(&self) -> Result<String, serde_json::Error> {
184        serde_json::to_string(self)
185    }
186
187    /// Deserialize from JSON string
188    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
189        serde_json::from_str(json)
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_call_expr() {
        let expr = ToolExpr::call("read_file", json!({"path": "/tmp/test"}));
        match expr {
            ToolExpr::Call { tool, args } => {
                assert_eq!(tool, "read_file");
                assert_eq!(args["path"], "/tmp/test");
            }
            _ => panic!("Expected Call variant"),
        }
    }

    #[test]
    fn test_sequence_expr() {
        let steps = vec![
            ToolExpr::call("step1", json!({})),
            ToolExpr::call("step2", json!({})),
        ];
        let expr = ToolExpr::sequence(steps);
        match expr {
            ToolExpr::Sequence { steps, fail_fast } => {
                assert_eq!(steps.len(), 2);
                assert!(fail_fast);
            }
            _ => panic!("Expected Sequence variant"),
        }
    }

    #[test]
    fn test_yaml_roundtrip() {
        let expr = ToolExpr::sequence(vec![
            ToolExpr::call("step1", json!({"arg": 1})),
            ToolExpr::call("step2", json!({"arg": 2})),
        ]);
        let yaml = expr.to_yaml().unwrap();
        let deserialized = ToolExpr::from_yaml(&yaml).unwrap();
        assert_eq!(expr, deserialized);
    }

    #[test]
    fn test_json_roundtrip() {
        let expr = ToolExpr::choice_with_else(
            Condition::Success,
            ToolExpr::call("on_success", json!({})),
            ToolExpr::call("on_failure", json!({})),
        );
        let json_str = expr.to_json().unwrap();
        let deserialized = ToolExpr::from_json(&json_str).unwrap();
        assert_eq!(expr, deserialized);
    }

    /// An omitted `fail_fast` must parse to the same value `sequence()` uses.
    #[test]
    fn test_sequence_fail_fast_defaults_when_omitted() {
        let parsed = ToolExpr::from_json(r#"{"type":"sequence","steps":[]}"#).unwrap();
        assert_eq!(parsed, ToolExpr::sequence(vec![]));
    }

    /// Omitted `max_attempts` / `delay_ms` must parse to the `retry()` defaults.
    #[test]
    fn test_retry_defaults_when_omitted() {
        let parsed = ToolExpr::from_json(
            r#"{"type":"retry","expr":{"type":"call","tool":"fetch","args":{}}}"#,
        )
        .unwrap();
        assert_eq!(parsed, ToolExpr::retry(ToolExpr::call("fetch", json!({}))));
    }

    /// Nested Let/Retry/Choice (without else) survives a JSON round-trip.
    #[test]
    fn test_nested_json_roundtrip() {
        let expr = ToolExpr::let_binding(
            "result",
            ToolExpr::retry_with_params(ToolExpr::call("fetch", json!({"url": "x"})), 5, 250),
            ToolExpr::choice(Condition::Success, ToolExpr::call("store", json!({}))),
        );
        let json_str = expr.to_json().unwrap();
        assert_eq!(ToolExpr::from_json(&json_str).unwrap(), expr);
    }
}