Skip to main content

bamboo_agent/agent/core/composition/
executor.rs

//! Workflow executor for the tool composition DSL.
//!
//! This module provides the execution engine that runs tool composition
//! workflows defined with the `ToolExpr` DSL.
5
6use crate::agent::core::tools::{normalize_tool_name, ToolError, ToolRegistry, ToolResult};
7use futures::future::join_all;
8use regex::Regex;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::time::sleep;
13
14use super::condition::Condition;
15use super::context::ExecutionContext;
16use super::expr::ToolExpr;
17use super::parallel::ParallelWait;
18
19/// Executor for running tool composition workflows
20///
21/// The executor takes a tool registry and executes composition expressions,
22/// managing context, variables, and control flow.
23pub struct CompositionExecutor {
24    /// Tool registry for looking up and executing tools
25    registry: Arc<ToolRegistry>,
26}
27
/// Heap-allocated, `Send` future that may borrow data for `'a`.
///
/// Gives the recursive expression evaluator a nameable return type.
type BoxFuture<'a, T> = Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
29
30impl CompositionExecutor {
31    /// Create a new composition executor with the given tool registry
32    pub fn new(registry: Arc<ToolRegistry>) -> Self {
33        Self { registry }
34    }
35
36    /// Execute a tool expression with the given context
37    ///
38    /// This method executes the expression and updates the context with
39    /// the result, binding it to the `_last` variable.
40    pub async fn execute(
41        &self,
42        expr: &ToolExpr,
43        ctx: &mut ExecutionContext,
44    ) -> Result<ToolResult, ToolError> {
45        let result = self.execute_internal(expr, ctx).await;
46
47        ctx.log_step(Self::expr_name(expr).to_string(), result.clone());
48
49        if let Ok(value) = &result {
50            ctx.bind("_last".to_string(), value.clone());
51        }
52
53        result
54    }
55
56    fn execute_internal<'a>(
57        &'a self,
58        expr: &'a ToolExpr,
59        ctx: &'a mut ExecutionContext,
60    ) -> BoxFuture<'a, Result<ToolResult, ToolError>> {
61        Box::pin(async move {
62            match expr {
63                ToolExpr::Call { tool, args } => self.execute_call(tool, args).await,
64                ToolExpr::Sequence { steps, fail_fast } => {
65                    self.execute_sequence(steps, *fail_fast, ctx).await
66                }
67                ToolExpr::Parallel { branches, wait } => {
68                    self.execute_parallel(branches, wait, ctx).await
69                }
70                ToolExpr::Choice {
71                    condition,
72                    then_branch,
73                    else_branch,
74                } => {
75                    self.execute_choice(condition, then_branch, else_branch.as_deref(), ctx)
76                        .await
77                }
78                ToolExpr::Retry {
79                    expr,
80                    max_attempts,
81                    delay_ms,
82                } => {
83                    self.execute_retry(expr, *max_attempts, *delay_ms, ctx)
84                        .await
85                }
86                ToolExpr::Let { var, expr, body } => self.execute_let(var, expr, body, ctx).await,
87                ToolExpr::Var(name) => self.execute_var(name, ctx),
88            }
89        })
90    }
91
92    async fn execute_call(
93        &self,
94        tool: &str,
95        args: &serde_json::Value,
96    ) -> Result<ToolResult, ToolError> {
97        let normalized = normalize_tool_name(tool);
98        let tool_impl = self
99            .registry
100            .get(normalized)
101            .ok_or_else(|| ToolError::NotFound(format!("Tool '{}' not found", normalized)))?;
102
103        tool_impl.execute(args.clone()).await
104    }
105
106    async fn execute_sequence(
107        &self,
108        steps: &[ToolExpr],
109        fail_fast: bool,
110        ctx: &mut ExecutionContext,
111    ) -> Result<ToolResult, ToolError> {
112        let mut last_result = Self::default_result("empty sequence", true);
113
114        for step in steps {
115            match self.execute_internal(step, ctx).await {
116                Ok(result) => {
117                    ctx.bind("_last".to_string(), result.clone());
118                    let should_stop = fail_fast && !result.success;
119                    last_result = result;
120
121                    if should_stop {
122                        return Ok(last_result);
123                    }
124                }
125                Err(error) => {
126                    if fail_fast {
127                        return Err(error);
128                    }
129
130                    let failure = Self::default_result(error.to_string(), false);
131                    ctx.bind("_last".to_string(), failure.clone());
132                    last_result = failure;
133                }
134            }
135        }
136
137        Ok(last_result)
138    }
139
140    async fn execute_parallel(
141        &self,
142        branches: &[ToolExpr],
143        wait: &ParallelWait,
144        ctx: &ExecutionContext,
145    ) -> Result<ToolResult, ToolError> {
146        if branches.is_empty() {
147            return Ok(Self::default_result("empty parallel", true));
148        }
149
150        let futures = branches.iter().map(|branch| {
151            let mut branch_ctx = ctx.clone();
152            async move { self.execute_internal(branch, &mut branch_ctx).await }
153        });
154
155        let results = join_all(futures).await;
156
157        match wait {
158            ParallelWait::All => self.resolve_parallel_all(results),
159            ParallelWait::Any => self.resolve_parallel_any(results),
160            ParallelWait::N(target) => self.resolve_parallel_n(results, branches.len(), *target),
161        }
162    }
163
164    fn resolve_parallel_all(
165        &self,
166        results: Vec<Result<ToolResult, ToolError>>,
167    ) -> Result<ToolResult, ToolError> {
168        let mut last_success = None;
169
170        for result in results {
171            match result {
172                Ok(tool_result) => {
173                    if !tool_result.success {
174                        return Ok(tool_result);
175                    }
176                    last_success = Some(tool_result);
177                }
178                Err(error) => return Err(error),
179            }
180        }
181
182        Ok(last_success.unwrap_or_else(|| Self::default_result("all branches completed", true)))
183    }
184
185    fn resolve_parallel_any(
186        &self,
187        results: Vec<Result<ToolResult, ToolError>>,
188    ) -> Result<ToolResult, ToolError> {
189        let mut first_failure = None;
190        let mut last_error = None;
191
192        for result in results {
193            match result {
194                Ok(tool_result) if tool_result.success => return Ok(tool_result),
195                Ok(tool_result) => {
196                    if first_failure.is_none() {
197                        first_failure = Some(tool_result);
198                    }
199                }
200                Err(error) => last_error = Some(error),
201            }
202        }
203
204        if let Some(failure) = first_failure {
205            return Ok(failure);
206        }
207
208        Err(last_error
209            .unwrap_or_else(|| ToolError::Execution("no parallel branch succeeded".to_string())))
210    }
211
212    fn resolve_parallel_n(
213        &self,
214        results: Vec<Result<ToolResult, ToolError>>,
215        branch_count: usize,
216        target: usize,
217    ) -> Result<ToolResult, ToolError> {
218        let mut success_count = 0;
219        let mut last_success = None;
220
221        for result in results {
222            match result {
223                Ok(tool_result) => {
224                    if tool_result.success {
225                        success_count += 1;
226                        last_success = Some(tool_result);
227                    }
228                }
229                Err(error) => return Err(error),
230            }
231        }
232
233        if success_count >= target {
234            return Ok(last_success
235                .unwrap_or_else(|| Self::default_result("required branches succeeded", true)));
236        }
237
238        Ok(Self::default_result(
239            format!("only {success_count} of {branch_count} branches succeeded; required {target}"),
240            false,
241        ))
242    }
243
244    async fn execute_choice(
245        &self,
246        condition: &Condition,
247        then_branch: &ToolExpr,
248        else_branch: Option<&ToolExpr>,
249        ctx: &mut ExecutionContext,
250    ) -> Result<ToolResult, ToolError> {
251        let last_result = ctx
252            .lookup("_last")
253            .cloned()
254            .unwrap_or_else(|| Self::default_result("{}", true));
255
256        if self.evaluate_condition(condition, &last_result) {
257            self.execute_internal(then_branch, ctx).await
258        } else if let Some(else_expr) = else_branch {
259            self.execute_internal(else_expr, ctx).await
260        } else {
261            Ok(Self::default_result("condition not met", true))
262        }
263    }
264
265    async fn execute_retry(
266        &self,
267        expr: &ToolExpr,
268        max_attempts: u32,
269        delay_ms: u64,
270        ctx: &mut ExecutionContext,
271    ) -> Result<ToolResult, ToolError> {
272        let attempts = max_attempts.max(1);
273        let mut last_error = None;
274
275        for attempt in 0..attempts {
276            match self.execute_internal(expr, ctx).await {
277                Ok(result) if result.success => return Ok(result),
278                Ok(result) => last_error = Some(ToolError::Execution(result.result)),
279                Err(error) => last_error = Some(error),
280            }
281
282            if attempt + 1 < attempts && delay_ms > 0 {
283                sleep(Duration::from_millis(delay_ms)).await;
284            }
285        }
286
287        Err(last_error
288            .unwrap_or_else(|| ToolError::Execution("retry attempts exhausted".to_string())))
289    }
290
291    async fn execute_let(
292        &self,
293        var: &str,
294        expr: &ToolExpr,
295        body: &ToolExpr,
296        ctx: &mut ExecutionContext,
297    ) -> Result<ToolResult, ToolError> {
298        let value = self.execute_internal(expr, ctx).await?;
299        ctx.bind(var.to_string(), value.clone());
300        ctx.bind("_last".to_string(), value);
301        self.execute_internal(body, ctx).await
302    }
303
304    fn execute_var(&self, name: &str, ctx: &ExecutionContext) -> Result<ToolResult, ToolError> {
305        ctx.lookup(name)
306            .cloned()
307            .ok_or_else(|| ToolError::Execution(format!("Variable not found: {}", name)))
308    }
309
310    fn evaluate_condition(&self, condition: &Condition, result: &ToolResult) -> bool {
311        match condition {
312            Condition::Success => result.success,
313            Condition::Contains { path, value } => {
314                Self::extract_value_at_path(&result.result, path)
315                    .map(|current| current.contains(value))
316                    .unwrap_or(false)
317            }
318            Condition::Matches { path, pattern } => {
319                Self::extract_value_at_path(&result.result, path)
320                    .map(|current| {
321                        Regex::new(pattern)
322                            .map(|regex| regex.is_match(&current))
323                            .unwrap_or(false)
324                    })
325                    .unwrap_or(false)
326            }
327            Condition::And { conditions } => conditions
328                .iter()
329                .all(|inner| self.evaluate_condition(inner, result)),
330            Condition::Or { conditions } => conditions
331                .iter()
332                .any(|inner| self.evaluate_condition(inner, result)),
333        }
334    }
335
336    fn extract_value_at_path(payload: &str, path: &str) -> Option<String> {
337        let parsed: serde_json::Value = serde_json::from_str(payload).ok()?;
338
339        if path.is_empty() {
340            return Some(Self::value_as_string(&parsed));
341        }
342
343        let mut current = &parsed;
344
345        for segment in path.split('.') {
346            if let Ok(index) = segment.parse::<usize>() {
347                current = current.get(index)?;
348            } else {
349                current = current.get(segment)?;
350            }
351        }
352
353        Some(Self::value_as_string(current))
354    }
355
356    fn value_as_string(value: &serde_json::Value) -> String {
357        match value {
358            serde_json::Value::String(inner) => inner.clone(),
359            _ => value.to_string(),
360        }
361    }
362
363    fn expr_name(expr: &ToolExpr) -> &'static str {
364        match expr {
365            ToolExpr::Call { .. } => "call",
366            ToolExpr::Sequence { .. } => "sequence",
367            ToolExpr::Parallel { .. } => "parallel",
368            ToolExpr::Choice { .. } => "choice",
369            ToolExpr::Retry { .. } => "retry",
370            ToolExpr::Let { .. } => "let",
371            ToolExpr::Var(_) => "var",
372        }
373    }
374
375    fn default_result(result: impl Into<String>, success: bool) -> ToolResult {
376        ToolResult {
377            success,
378            result: result.into(),
379            display_preference: None,
380        }
381    }
382}
383
384#[cfg(test)]
385mod tests {
386    use super::*;
387    use crate::agent::core::tools::Tool;
388    use async_trait::async_trait;
389    use serde_json::json;
390    use std::sync::atomic::{AtomicUsize, Ordering};
391
392    struct EchoArgsTool;
393
394    #[async_trait]
395    impl Tool for EchoArgsTool {
396        fn name(&self) -> &str {
397            "echo_args"
398        }
399
400        fn description(&self) -> &str {
401            "echoes input args"
402        }
403
404        fn parameters_schema(&self) -> serde_json::Value {
405            json!({ "type": "object" })
406        }
407
408        async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError> {
409            Ok(ToolResult {
410                success: true,
411                result: args.to_string(),
412                display_preference: None,
413            })
414        }
415    }
416
417    struct StaticTool {
418        name: &'static str,
419        success: bool,
420        result: &'static str,
421    }
422
423    #[async_trait]
424    impl Tool for StaticTool {
425        fn name(&self) -> &str {
426            self.name
427        }
428
429        fn description(&self) -> &str {
430            "static tool"
431        }
432
433        fn parameters_schema(&self) -> serde_json::Value {
434            json!({ "type": "object" })
435        }
436
437        async fn execute(&self, _args: serde_json::Value) -> Result<ToolResult, ToolError> {
438            Ok(ToolResult {
439                success: self.success,
440                result: self.result.to_string(),
441                display_preference: None,
442            })
443        }
444    }
445
446    struct ErrorTool {
447        name: &'static str,
448    }
449
450    #[async_trait]
451    impl Tool for ErrorTool {
452        fn name(&self) -> &str {
453            self.name
454        }
455
456        fn description(&self) -> &str {
457            "always errors"
458        }
459
460        fn parameters_schema(&self) -> serde_json::Value {
461            json!({ "type": "object" })
462        }
463
464        async fn execute(&self, _args: serde_json::Value) -> Result<ToolResult, ToolError> {
465            Err(ToolError::Execution(format!("{} failed", self.name)))
466        }
467    }
468
469    struct FlakyTool {
470        attempts: Arc<AtomicUsize>,
471        fail_until: usize,
472    }
473
474    #[async_trait]
475    impl Tool for FlakyTool {
476        fn name(&self) -> &str {
477            "flaky"
478        }
479
480        fn description(&self) -> &str {
481            "fails until a threshold"
482        }
483
484        fn parameters_schema(&self) -> serde_json::Value {
485            json!({ "type": "object" })
486        }
487
488        async fn execute(&self, _args: serde_json::Value) -> Result<ToolResult, ToolError> {
489            let attempt = self.attempts.fetch_add(1, Ordering::SeqCst) + 1;
490            if attempt <= self.fail_until {
491                return Err(ToolError::Execution("transient failure".to_string()));
492            }
493
494            Ok(ToolResult {
495                success: true,
496                result: format!("attempt-{attempt}"),
497                display_preference: None,
498            })
499        }
500    }
501
502    fn setup_executor() -> (CompositionExecutor, Arc<AtomicUsize>) {
503        let registry = Arc::new(ToolRegistry::new());
504        let attempts = Arc::new(AtomicUsize::new(0));
505
506        registry.register(EchoArgsTool).unwrap();
507        registry
508            .register(StaticTool {
509                name: "ok",
510                success: true,
511                result: "ok-result",
512            })
513            .unwrap();
514        registry
515            .register(StaticTool {
516                name: "status_ready",
517                success: true,
518                result: r#"{"status":"ready","email":"agent@example.com"}"#,
519            })
520            .unwrap();
521        registry
522            .register(StaticTool {
523                name: "then_branch",
524                success: true,
525                result: "then",
526            })
527            .unwrap();
528        registry
529            .register(StaticTool {
530                name: "else_branch",
531                success: true,
532                result: "else",
533            })
534            .unwrap();
535        registry
536            .register(StaticTool {
537                name: "soft_fail",
538                success: false,
539                result: "not-good",
540            })
541            .unwrap();
542        registry.register(ErrorTool { name: "hard_fail" }).unwrap();
543        registry
544            .register(FlakyTool {
545                attempts: Arc::clone(&attempts),
546                fail_until: 2,
547            })
548            .unwrap();
549
550        (CompositionExecutor::new(registry), attempts)
551    }
552
553    #[tokio::test]
554    async fn executes_call_variant() {
555        let (executor, _) = setup_executor();
556        let mut ctx = ExecutionContext::new();
557
558        let expr = ToolExpr::call("echo_args", json!({ "value": 42 }));
559        let result = executor.execute(&expr, &mut ctx).await.unwrap();
560
561        assert!(result.success);
562        assert_eq!(result.result, r#"{"value":42}"#);
563    }
564
565    #[tokio::test]
566    async fn executes_sequence_with_continue_on_error() {
567        let (executor, _) = setup_executor();
568        let mut ctx = ExecutionContext::new();
569
570        let expr = ToolExpr::sequence_with_fail_fast(
571            vec![
572                ToolExpr::call("hard_fail", json!({})),
573                ToolExpr::call("ok", json!({})),
574            ],
575            false,
576        );
577
578        let result = executor.execute(&expr, &mut ctx).await.unwrap();
579
580        assert!(result.success);
581        assert_eq!(result.result, "ok-result");
582    }
583
584    #[tokio::test]
585    async fn executes_parallel_and_choice_variants() {
586        let (executor, _) = setup_executor();
587        let mut ctx = ExecutionContext::new();
588
589        let parallel = ToolExpr::parallel_with_wait(
590            vec![
591                ToolExpr::call("soft_fail", json!({})),
592                ToolExpr::call("ok", json!({})),
593            ],
594            ParallelWait::Any,
595        );
596
597        let parallel_result = executor.execute(&parallel, &mut ctx).await.unwrap();
598        assert!(parallel_result.success);
599        assert_eq!(parallel_result.result, "ok-result");
600
601        let choice = ToolExpr::sequence(vec![
602            ToolExpr::call("status_ready", json!({})),
603            ToolExpr::choice_with_else(
604                Condition::Contains {
605                    path: "status".to_string(),
606                    value: "ready".to_string(),
607                },
608                ToolExpr::call("then_branch", json!({})),
609                ToolExpr::call("else_branch", json!({})),
610            ),
611        ]);
612
613        let choice_result = executor.execute(&choice, &mut ctx).await.unwrap();
614        assert_eq!(choice_result.result, "then");
615    }
616
617    #[tokio::test]
618    async fn executes_retry_and_let_var_variants() {
619        let (executor, attempts) = setup_executor();
620        let mut ctx = ExecutionContext::new();
621
622        let retry_expr = ToolExpr::retry_with_params(ToolExpr::call("flaky", json!({})), 3, 0);
623        let retry_result = executor.execute(&retry_expr, &mut ctx).await.unwrap();
624        assert!(retry_result.success);
625        assert_eq!(attempts.load(Ordering::SeqCst), 3);
626
627        let let_expr = ToolExpr::let_binding(
628            "saved",
629            ToolExpr::call("ok", json!({})),
630            ToolExpr::var("saved"),
631        );
632        let let_result = executor.execute(&let_expr, &mut ctx).await.unwrap();
633        assert_eq!(let_result.result, "ok-result");
634
635        let missing = ToolExpr::var("missing");
636        let error = executor.execute(&missing, &mut ctx).await.unwrap_err();
637        assert!(matches!(error, ToolError::Execution(_)));
638    }
639
640    #[test]
641    fn evaluates_nested_conditions() {
642        let executor = CompositionExecutor::new(Arc::new(ToolRegistry::new()));
643        let result = ToolResult {
644            success: true,
645            result: r#"{"status":"ready","email":"agent@example.com"}"#.to_string(),
646            display_preference: None,
647        };
648
649        let condition = Condition::And {
650            conditions: vec![
651                Condition::Success,
652                Condition::Or {
653                    conditions: vec![
654                        Condition::Contains {
655                            path: "status".to_string(),
656                            value: "ready".to_string(),
657                        },
658                        Condition::Matches {
659                            path: "email".to_string(),
660                            pattern: ".+@example\\.com".to_string(),
661                        },
662                    ],
663                },
664            ],
665        };
666
667        assert!(executor.evaluate_condition(&condition, &result));
668    }
669}