bamboo_agent/agent/core/composition/
expr.rs1use crate::agent::core::tools::ToolError;
7use serde::{Deserialize, Serialize};
8
9use super::condition::Condition;
10use super::parallel::ParallelWait;
11
12#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
17#[serde(rename_all = "snake_case", tag = "type")]
18pub enum ToolExpr {
19 Call {
21 tool: String,
23 args: serde_json::Value,
25 },
26 Sequence {
28 steps: Vec<ToolExpr>,
30 #[serde(default = "default_fail_fast")]
32 fail_fast: bool,
33 },
34 Parallel {
36 branches: Vec<ToolExpr>,
38 #[serde(default)]
40 wait: ParallelWait,
41 },
42 Choice {
44 condition: Condition,
46 then_branch: Box<ToolExpr>,
48 else_branch: Option<Box<ToolExpr>>,
50 },
51 Retry {
53 expr: Box<ToolExpr>,
55 #[serde(default = "default_max_attempts")]
57 max_attempts: u32,
58 #[serde(default = "default_delay_ms")]
60 delay_ms: u64,
61 },
62 Let {
64 var: String,
66 expr: Box<ToolExpr>,
68 body: Box<ToolExpr>,
70 },
71 Var(String),
73}
74
/// Serde default for the `fail_fast` field of [`ToolExpr::Sequence`] when the
/// field is omitted from serialized input.
fn default_fail_fast() -> bool {
    const FAIL_FAST_BY_DEFAULT: bool = true;
    FAIL_FAST_BY_DEFAULT
}
78
/// Serde default for the `max_attempts` field of [`ToolExpr::Retry`] when the
/// field is omitted from serialized input.
fn default_max_attempts() -> u32 {
    const ATTEMPTS_BY_DEFAULT: u32 = 3;
    ATTEMPTS_BY_DEFAULT
}
82
/// Serde default for the `delay_ms` field of [`ToolExpr::Retry`] when the
/// field is omitted from serialized input.
fn default_delay_ms() -> u64 {
    const ONE_SECOND_IN_MS: u64 = 1_000;
    ONE_SECOND_IN_MS
}
86
87impl ToolExpr {
88 pub fn call(tool: impl Into<String>, args: serde_json::Value) -> Self {
90 ToolExpr::Call {
91 tool: tool.into(),
92 args,
93 }
94 }
95
96 pub fn sequence(steps: Vec<ToolExpr>) -> Self {
98 ToolExpr::Sequence {
99 steps,
100 fail_fast: true,
101 }
102 }
103
104 pub fn sequence_with_fail_fast(steps: Vec<ToolExpr>, fail_fast: bool) -> Self {
106 ToolExpr::Sequence { steps, fail_fast }
107 }
108
109 pub fn parallel(branches: Vec<ToolExpr>) -> Self {
111 ToolExpr::Parallel {
112 branches,
113 wait: ParallelWait::All,
114 }
115 }
116
117 pub fn parallel_with_wait(branches: Vec<ToolExpr>, wait: ParallelWait) -> Self {
119 ToolExpr::Parallel { branches, wait }
120 }
121
122 pub fn choice(condition: Condition, then_branch: ToolExpr) -> Self {
124 ToolExpr::Choice {
125 condition,
126 then_branch: Box::new(then_branch),
127 else_branch: None,
128 }
129 }
130
131 pub fn choice_with_else(
133 condition: Condition,
134 then_branch: ToolExpr,
135 else_branch: ToolExpr,
136 ) -> Self {
137 ToolExpr::Choice {
138 condition,
139 then_branch: Box::new(then_branch),
140 else_branch: Some(Box::new(else_branch)),
141 }
142 }
143
144 pub fn retry(expr: ToolExpr) -> Self {
146 ToolExpr::Retry {
147 expr: Box::new(expr),
148 max_attempts: 3,
149 delay_ms: 1000,
150 }
151 }
152
153 pub fn retry_with_params(expr: ToolExpr, max_attempts: u32, delay_ms: u64) -> Self {
155 ToolExpr::Retry {
156 expr: Box::new(expr),
157 max_attempts,
158 delay_ms,
159 }
160 }
161
162 pub fn let_binding(var: impl Into<String>, expr: ToolExpr, body: ToolExpr) -> Self {
164 ToolExpr::Let {
165 var: var.into(),
166 expr: Box::new(expr),
167 body: Box::new(body),
168 }
169 }
170
171 pub fn var(name: impl Into<String>) -> Self {
173 ToolExpr::Var(name.into())
174 }
175
176 pub fn to_yaml(&self) -> Result<String, serde_yaml::Error> {
178 serde_yaml::to_string(self)
179 }
180
181 pub fn from_yaml(yaml: &str) -> Result<Self, serde_yaml::Error> {
183 serde_yaml::from_str(yaml)
184 }
185
186 pub fn to_json(&self) -> Result<String, serde_json::Error> {
188 serde_json::to_string(self)
189 }
190
191 pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
193 serde_json::from_str(json)
194 }
195}
196
197#[derive(Debug, Clone)]
199pub enum CompositionError {
200 ToolError(ToolError),
201 VariableNotFound(String),
202 InvalidExpression(String),
203 MaxRetriesExceeded,
204}
205
206impl std::fmt::Display for CompositionError {
207 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208 match self {
209 CompositionError::ToolError(e) => write!(f, "Tool error: {}", e),
210 CompositionError::VariableNotFound(v) => write!(f, "Variable not found: {}", v),
211 CompositionError::InvalidExpression(e) => write!(f, "Invalid expression: {}", e),
212 CompositionError::MaxRetriesExceeded => write!(f, "Maximum retry attempts exceeded"),
213 }
214 }
215}
216
217impl std::error::Error for CompositionError {
218 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
219 match self {
220 CompositionError::ToolError(e) => Some(e),
221 _ => None,
222 }
223 }
224}
225
226impl From<ToolError> for CompositionError {
227 fn from(e: ToolError) -> Self {
228 CompositionError::ToolError(e)
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// `call` keeps the tool name and argument object as given.
    #[test]
    fn test_call_expr() {
        let built = ToolExpr::call("read_file", json!({"path": "/tmp/test"}));

        if let ToolExpr::Call { tool, args } = built {
            assert_eq!(tool, "read_file");
            assert_eq!(args["path"], "/tmp/test");
        } else {
            panic!("Expected Call variant");
        }
    }

    /// `sequence` keeps every step and enables `fail_fast`.
    #[test]
    fn test_sequence_expr() {
        let built = ToolExpr::sequence(vec![
            ToolExpr::call("step1", json!({})),
            ToolExpr::call("step2", json!({})),
        ]);

        if let ToolExpr::Sequence { steps, fail_fast } = built {
            assert_eq!(steps.len(), 2);
            assert!(fail_fast);
        } else {
            panic!("Expected Sequence variant");
        }
    }

    /// `parallel` keeps every branch and waits with `ParallelWait::All`.
    #[test]
    fn test_parallel_expr() {
        let built = ToolExpr::parallel(vec![
            ToolExpr::call("branch1", json!({})),
            ToolExpr::call("branch2", json!({})),
        ]);

        if let ToolExpr::Parallel { branches, wait } = built {
            assert_eq!(branches.len(), 2);
            assert_eq!(wait, ParallelWait::All);
        } else {
            panic!("Expected Parallel variant");
        }
    }

    /// `choice_with_else` stores both branches, `then` first.
    #[test]
    fn test_choice_expr() {
        let built = ToolExpr::choice_with_else(
            Condition::Success,
            ToolExpr::call("success_handler", json!({})),
            ToolExpr::call("failure_handler", json!({})),
        );

        if let ToolExpr::Choice {
            then_branch,
            else_branch,
            ..
        } = built
        {
            assert!(else_branch.is_some());
            if let ToolExpr::Call { tool, .. } = *then_branch {
                assert_eq!(tool, "success_handler");
            } else {
                panic!("Expected Call in then_branch");
            }
        } else {
            panic!("Expected Choice variant");
        }
    }

    /// `retry_with_params` stores the explicit attempt count and delay.
    #[test]
    fn test_retry_expr() {
        let built = ToolExpr::retry_with_params(ToolExpr::call("risky_op", json!({})), 5, 500);

        if let ToolExpr::Retry {
            max_attempts,
            delay_ms,
            ..
        } = built
        {
            assert_eq!(max_attempts, 5);
            assert_eq!(delay_ms, 500);
        } else {
            panic!("Expected Retry variant");
        }
    }

    /// `let_binding` stores the name, value expression and body.
    #[test]
    fn test_let_expr() {
        let fetch = ToolExpr::call("fetch", json!({"url": "http://example.com"}));
        let process = ToolExpr::call("process", json!({"data": "${result}"}));
        let built = ToolExpr::let_binding("result", fetch, process);

        if let ToolExpr::Let { var, expr, body } = built {
            assert_eq!(var, "result");
            assert!(matches!(*expr, ToolExpr::Call { .. }));
            assert!(matches!(*body, ToolExpr::Call { .. }));
        } else {
            panic!("Expected Let variant");
        }
    }

    /// A sequence survives a YAML round trip unchanged.
    #[test]
    fn test_yaml_roundtrip() {
        let original = ToolExpr::sequence(vec![
            ToolExpr::call("step1", json!({"arg": 1})),
            ToolExpr::call("step2", json!({"arg": 2})),
        ]);

        let restored = original
            .to_yaml()
            .and_then(|text| ToolExpr::from_yaml(&text))
            .unwrap();

        assert_eq!(original, restored);
    }

    /// A choice survives a JSON round trip unchanged.
    #[test]
    fn test_json_roundtrip() {
        let original = ToolExpr::choice_with_else(
            Condition::Success,
            ToolExpr::call("on_success", json!({})),
            ToolExpr::call("on_failure", json!({})),
        );

        let restored = original
            .to_json()
            .and_then(|text| ToolExpr::from_json(&text))
            .unwrap();

        assert_eq!(original, restored);
    }

    /// Hand-written YAML using the `type` tag parses into a `Sequence`.
    #[test]
    fn test_yaml_deserialization() {
        let yaml = r#"
type: sequence
steps:
  - type: call
    tool: read_file
    args:
      path: /tmp/test.txt
  - type: call
    tool: process
    args:
      data: "hello"
fail_fast: true
"#;

        let parsed = ToolExpr::from_yaml(yaml).unwrap();
        assert!(matches!(parsed, ToolExpr::Sequence { .. }));
    }
}