use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;

use futures::future::join_all;
use regex::Regex;
use tokio::time::sleep;

use crate::agent::core::tools::{normalize_tool_name, ToolError, ToolRegistry, ToolResult};

use super::condition::Condition;
use super::context::ExecutionContext;
use super::expr::ToolExpr;
use super::parallel::ParallelWait;

19pub struct CompositionExecutor {
24 registry: Arc<ToolRegistry>,
26}
27
/// Heap-allocated, pinned, `Send` future borrowing for `'a`.
///
/// `CompositionExecutor::execute_internal` returns this type because it recurses into
/// itself for composite expressions, and a recursive `async fn` needs a boxed future to
/// have a finite size.
type BoxFuture<'a, T> = Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;

30impl CompositionExecutor {
31 pub fn new(registry: Arc<ToolRegistry>) -> Self {
33 Self { registry }
34 }
35
36 pub async fn execute(
41 &self,
42 expr: &ToolExpr,
43 ctx: &mut ExecutionContext,
44 ) -> Result<ToolResult, ToolError> {
45 let result = self.execute_internal(expr, ctx).await;
46
47 ctx.log_step(Self::expr_name(expr).to_string(), result.clone());
48
49 if let Ok(value) = &result {
50 ctx.bind("_last".to_string(), value.clone());
51 }
52
53 result
54 }
55
56 fn execute_internal<'a>(
57 &'a self,
58 expr: &'a ToolExpr,
59 ctx: &'a mut ExecutionContext,
60 ) -> BoxFuture<'a, Result<ToolResult, ToolError>> {
61 Box::pin(async move {
62 match expr {
63 ToolExpr::Call { tool, args } => self.execute_call(tool, args).await,
64 ToolExpr::Sequence { steps, fail_fast } => {
65 self.execute_sequence(steps, *fail_fast, ctx).await
66 }
67 ToolExpr::Parallel { branches, wait } => {
68 self.execute_parallel(branches, wait, ctx).await
69 }
70 ToolExpr::Choice {
71 condition,
72 then_branch,
73 else_branch,
74 } => {
75 self.execute_choice(condition, then_branch, else_branch.as_deref(), ctx)
76 .await
77 }
78 ToolExpr::Retry {
79 expr,
80 max_attempts,
81 delay_ms,
82 } => {
83 self.execute_retry(expr, *max_attempts, *delay_ms, ctx)
84 .await
85 }
86 ToolExpr::Let { var, expr, body } => self.execute_let(var, expr, body, ctx).await,
87 ToolExpr::Var(name) => self.execute_var(name, ctx),
88 }
89 })
90 }
91
92 async fn execute_call(
93 &self,
94 tool: &str,
95 args: &serde_json::Value,
96 ) -> Result<ToolResult, ToolError> {
97 let normalized = normalize_tool_name(tool);
98 let tool_impl = self
99 .registry
100 .get(normalized)
101 .ok_or_else(|| ToolError::NotFound(format!("Tool '{}' not found", normalized)))?;
102
103 tool_impl.execute(args.clone()).await
104 }
105
106 async fn execute_sequence(
107 &self,
108 steps: &[ToolExpr],
109 fail_fast: bool,
110 ctx: &mut ExecutionContext,
111 ) -> Result<ToolResult, ToolError> {
112 let mut last_result = Self::default_result("empty sequence", true);
113
114 for step in steps {
115 match self.execute_internal(step, ctx).await {
116 Ok(result) => {
117 ctx.bind("_last".to_string(), result.clone());
118 let should_stop = fail_fast && !result.success;
119 last_result = result;
120
121 if should_stop {
122 return Ok(last_result);
123 }
124 }
125 Err(error) => {
126 if fail_fast {
127 return Err(error);
128 }
129
130 let failure = Self::default_result(error.to_string(), false);
131 ctx.bind("_last".to_string(), failure.clone());
132 last_result = failure;
133 }
134 }
135 }
136
137 Ok(last_result)
138 }
139
140 async fn execute_parallel(
141 &self,
142 branches: &[ToolExpr],
143 wait: &ParallelWait,
144 ctx: &ExecutionContext,
145 ) -> Result<ToolResult, ToolError> {
146 if branches.is_empty() {
147 return Ok(Self::default_result("empty parallel", true));
148 }
149
150 let futures = branches.iter().map(|branch| {
151 let mut branch_ctx = ctx.clone();
152 async move { self.execute_internal(branch, &mut branch_ctx).await }
153 });
154
155 let results = join_all(futures).await;
156
157 match wait {
158 ParallelWait::All => self.resolve_parallel_all(results),
159 ParallelWait::Any => self.resolve_parallel_any(results),
160 ParallelWait::N(target) => self.resolve_parallel_n(results, branches.len(), *target),
161 }
162 }
163
164 fn resolve_parallel_all(
165 &self,
166 results: Vec<Result<ToolResult, ToolError>>,
167 ) -> Result<ToolResult, ToolError> {
168 let mut last_success = None;
169
170 for result in results {
171 match result {
172 Ok(tool_result) => {
173 if !tool_result.success {
174 return Ok(tool_result);
175 }
176 last_success = Some(tool_result);
177 }
178 Err(error) => return Err(error),
179 }
180 }
181
182 Ok(last_success.unwrap_or_else(|| Self::default_result("all branches completed", true)))
183 }
184
185 fn resolve_parallel_any(
186 &self,
187 results: Vec<Result<ToolResult, ToolError>>,
188 ) -> Result<ToolResult, ToolError> {
189 let mut first_failure = None;
190 let mut last_error = None;
191
192 for result in results {
193 match result {
194 Ok(tool_result) if tool_result.success => return Ok(tool_result),
195 Ok(tool_result) => {
196 if first_failure.is_none() {
197 first_failure = Some(tool_result);
198 }
199 }
200 Err(error) => last_error = Some(error),
201 }
202 }
203
204 if let Some(failure) = first_failure {
205 return Ok(failure);
206 }
207
208 Err(last_error
209 .unwrap_or_else(|| ToolError::Execution("no parallel branch succeeded".to_string())))
210 }
211
212 fn resolve_parallel_n(
213 &self,
214 results: Vec<Result<ToolResult, ToolError>>,
215 branch_count: usize,
216 target: usize,
217 ) -> Result<ToolResult, ToolError> {
218 let mut success_count = 0;
219 let mut last_success = None;
220
221 for result in results {
222 match result {
223 Ok(tool_result) => {
224 if tool_result.success {
225 success_count += 1;
226 last_success = Some(tool_result);
227 }
228 }
229 Err(error) => return Err(error),
230 }
231 }
232
233 if success_count >= target {
234 return Ok(last_success
235 .unwrap_or_else(|| Self::default_result("required branches succeeded", true)));
236 }
237
238 Ok(Self::default_result(
239 format!("only {success_count} of {branch_count} branches succeeded; required {target}"),
240 false,
241 ))
242 }
243
244 async fn execute_choice(
245 &self,
246 condition: &Condition,
247 then_branch: &ToolExpr,
248 else_branch: Option<&ToolExpr>,
249 ctx: &mut ExecutionContext,
250 ) -> Result<ToolResult, ToolError> {
251 let last_result = ctx
252 .lookup("_last")
253 .cloned()
254 .unwrap_or_else(|| Self::default_result("{}", true));
255
256 if self.evaluate_condition(condition, &last_result) {
257 self.execute_internal(then_branch, ctx).await
258 } else if let Some(else_expr) = else_branch {
259 self.execute_internal(else_expr, ctx).await
260 } else {
261 Ok(Self::default_result("condition not met", true))
262 }
263 }
264
265 async fn execute_retry(
266 &self,
267 expr: &ToolExpr,
268 max_attempts: u32,
269 delay_ms: u64,
270 ctx: &mut ExecutionContext,
271 ) -> Result<ToolResult, ToolError> {
272 let attempts = max_attempts.max(1);
273 let mut last_error = None;
274
275 for attempt in 0..attempts {
276 match self.execute_internal(expr, ctx).await {
277 Ok(result) if result.success => return Ok(result),
278 Ok(result) => last_error = Some(ToolError::Execution(result.result)),
279 Err(error) => last_error = Some(error),
280 }
281
282 if attempt + 1 < attempts && delay_ms > 0 {
283 sleep(Duration::from_millis(delay_ms)).await;
284 }
285 }
286
287 Err(last_error
288 .unwrap_or_else(|| ToolError::Execution("retry attempts exhausted".to_string())))
289 }
290
291 async fn execute_let(
292 &self,
293 var: &str,
294 expr: &ToolExpr,
295 body: &ToolExpr,
296 ctx: &mut ExecutionContext,
297 ) -> Result<ToolResult, ToolError> {
298 let value = self.execute_internal(expr, ctx).await?;
299 ctx.bind(var.to_string(), value.clone());
300 ctx.bind("_last".to_string(), value);
301 self.execute_internal(body, ctx).await
302 }
303
304 fn execute_var(&self, name: &str, ctx: &ExecutionContext) -> Result<ToolResult, ToolError> {
305 ctx.lookup(name)
306 .cloned()
307 .ok_or_else(|| ToolError::Execution(format!("Variable not found: {}", name)))
308 }
309
310 fn evaluate_condition(&self, condition: &Condition, result: &ToolResult) -> bool {
311 match condition {
312 Condition::Success => result.success,
313 Condition::Contains { path, value } => {
314 Self::extract_value_at_path(&result.result, path)
315 .map(|current| current.contains(value))
316 .unwrap_or(false)
317 }
318 Condition::Matches { path, pattern } => {
319 Self::extract_value_at_path(&result.result, path)
320 .map(|current| {
321 Regex::new(pattern)
322 .map(|regex| regex.is_match(¤t))
323 .unwrap_or(false)
324 })
325 .unwrap_or(false)
326 }
327 Condition::And { conditions } => conditions
328 .iter()
329 .all(|inner| self.evaluate_condition(inner, result)),
330 Condition::Or { conditions } => conditions
331 .iter()
332 .any(|inner| self.evaluate_condition(inner, result)),
333 }
334 }
335
336 fn extract_value_at_path(payload: &str, path: &str) -> Option<String> {
337 let parsed: serde_json::Value = serde_json::from_str(payload).ok()?;
338
339 if path.is_empty() {
340 return Some(Self::value_as_string(&parsed));
341 }
342
343 let mut current = &parsed;
344
345 for segment in path.split('.') {
346 if let Ok(index) = segment.parse::<usize>() {
347 current = current.get(index)?;
348 } else {
349 current = current.get(segment)?;
350 }
351 }
352
353 Some(Self::value_as_string(current))
354 }
355
356 fn value_as_string(value: &serde_json::Value) -> String {
357 match value {
358 serde_json::Value::String(inner) => inner.clone(),
359 _ => value.to_string(),
360 }
361 }
362
363 fn expr_name(expr: &ToolExpr) -> &'static str {
364 match expr {
365 ToolExpr::Call { .. } => "call",
366 ToolExpr::Sequence { .. } => "sequence",
367 ToolExpr::Parallel { .. } => "parallel",
368 ToolExpr::Choice { .. } => "choice",
369 ToolExpr::Retry { .. } => "retry",
370 ToolExpr::Let { .. } => "let",
371 ToolExpr::Var(_) => "var",
372 }
373 }
374
375 fn default_result(result: impl Into<String>, success: bool) -> ToolResult {
376 ToolResult {
377 success,
378 result: result.into(),
379 display_preference: None,
380 }
381 }
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::core::tools::Tool;
    use async_trait::async_trait;
    use serde_json::json;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Succeeds with its serialized JSON arguments as the result text.
    struct EchoArgsTool;

    #[async_trait]
    impl Tool for EchoArgsTool {
        fn name(&self) -> &str {
            "echo_args"
        }

        fn description(&self) -> &str {
            "echoes input args"
        }

        fn parameters_schema(&self) -> serde_json::Value {
            json!({ "type": "object" })
        }

        async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError> {
            Ok(ToolResult {
                success: true,
                result: args.to_string(),
                display_preference: None,
            })
        }
    }

    /// Always returns the configured success flag and result text.
    struct StaticTool {
        name: &'static str,
        success: bool,
        result: &'static str,
    }

    #[async_trait]
    impl Tool for StaticTool {
        fn name(&self) -> &str {
            self.name
        }

        fn description(&self) -> &str {
            "static tool"
        }

        fn parameters_schema(&self) -> serde_json::Value {
            json!({ "type": "object" })
        }

        async fn execute(&self, _args: serde_json::Value) -> Result<ToolResult, ToolError> {
            Ok(ToolResult {
                success: self.success,
                result: self.result.to_string(),
                display_preference: None,
            })
        }
    }

    /// Always returns `Err(ToolError::Execution)`.
    struct ErrorTool {
        name: &'static str,
    }

    #[async_trait]
    impl Tool for ErrorTool {
        fn name(&self) -> &str {
            self.name
        }

        fn description(&self) -> &str {
            "always errors"
        }

        fn parameters_schema(&self) -> serde_json::Value {
            json!({ "type": "object" })
        }

        async fn execute(&self, _args: serde_json::Value) -> Result<ToolResult, ToolError> {
            Err(ToolError::Execution(format!("{} failed", self.name)))
        }
    }

    /// Errors on its first `fail_until` calls, then succeeds with `attempt-N`.
    struct FlakyTool {
        attempts: Arc<AtomicUsize>,
        fail_until: usize,
    }

    #[async_trait]
    impl Tool for FlakyTool {
        fn name(&self) -> &str {
            "flaky"
        }

        fn description(&self) -> &str {
            "fails until a threshold"
        }

        fn parameters_schema(&self) -> serde_json::Value {
            json!({ "type": "object" })
        }

        async fn execute(&self, _args: serde_json::Value) -> Result<ToolResult, ToolError> {
            let attempt = self.attempts.fetch_add(1, Ordering::SeqCst) + 1;
            if attempt <= self.fail_until {
                return Err(ToolError::Execution("transient failure".to_string()));
            }

            Ok(ToolResult {
                success: true,
                result: format!("attempt-{attempt}"),
                display_preference: None,
            })
        }
    }

    /// Builds an executor with every test tool registered.
    ///
    /// Also returns the shared call counter used by the `flaky` tool.
    fn setup_executor() -> (CompositionExecutor, Arc<AtomicUsize>) {
        let registry = Arc::new(ToolRegistry::new());
        let attempts = Arc::new(AtomicUsize::new(0));

        registry.register(EchoArgsTool).unwrap();
        registry
            .register(StaticTool {
                name: "ok",
                success: true,
                result: "ok-result",
            })
            .unwrap();
        registry
            .register(StaticTool {
                name: "status_ready",
                success: true,
                result: r#"{"status":"ready","email":"agent@example.com"}"#,
            })
            .unwrap();
        registry
            .register(StaticTool {
                name: "then_branch",
                success: true,
                result: "then",
            })
            .unwrap();
        registry
            .register(StaticTool {
                name: "else_branch",
                success: true,
                result: "else",
            })
            .unwrap();
        registry
            .register(StaticTool {
                name: "soft_fail",
                success: false,
                result: "not-good",
            })
            .unwrap();
        registry.register(ErrorTool { name: "hard_fail" }).unwrap();
        registry
            .register(FlakyTool {
                attempts: Arc::clone(&attempts),
                fail_until: 2,
            })
            .unwrap();

        (CompositionExecutor::new(registry), attempts)
    }

    #[tokio::test]
    async fn executes_call_variant() {
        let (executor, _) = setup_executor();
        let mut ctx = ExecutionContext::new();

        let expr = ToolExpr::call("echo_args", json!({ "value": 42 }));
        let result = executor.execute(&expr, &mut ctx).await.unwrap();

        assert!(result.success);
        assert_eq!(result.result, r#"{"value":42}"#);
    }

    #[tokio::test]
    async fn executes_sequence_with_continue_on_error() {
        let (executor, _) = setup_executor();
        let mut ctx = ExecutionContext::new();

        let expr = ToolExpr::sequence_with_fail_fast(
            vec![
                ToolExpr::call("hard_fail", json!({})),
                ToolExpr::call("ok", json!({})),
            ],
            false,
        );

        let result = executor.execute(&expr, &mut ctx).await.unwrap();

        assert!(result.success);
        assert_eq!(result.result, "ok-result");
    }

    #[tokio::test]
    async fn executes_parallel_and_choice_variants() {
        let (executor, _) = setup_executor();
        let mut ctx = ExecutionContext::new();

        let parallel = ToolExpr::parallel_with_wait(
            vec![
                ToolExpr::call("soft_fail", json!({})),
                ToolExpr::call("ok", json!({})),
            ],
            ParallelWait::Any,
        );

        let parallel_result = executor.execute(&parallel, &mut ctx).await.unwrap();
        assert!(parallel_result.success);
        assert_eq!(parallel_result.result, "ok-result");

        let choice = ToolExpr::sequence(vec![
            ToolExpr::call("status_ready", json!({})),
            ToolExpr::choice_with_else(
                Condition::Contains {
                    path: "status".to_string(),
                    value: "ready".to_string(),
                },
                ToolExpr::call("then_branch", json!({})),
                ToolExpr::call("else_branch", json!({})),
            ),
        ]);

        let choice_result = executor.execute(&choice, &mut ctx).await.unwrap();
        assert_eq!(choice_result.result, "then");
    }

    #[tokio::test]
    async fn parallel_n_reports_shortfall() {
        let (executor, _) = setup_executor();
        let mut ctx = ExecutionContext::new();

        let expr = ToolExpr::parallel_with_wait(
            vec![
                ToolExpr::call("ok", json!({})),
                ToolExpr::call("soft_fail", json!({})),
            ],
            ParallelWait::N(2),
        );

        let result = executor.execute(&expr, &mut ctx).await.unwrap();
        assert!(!result.success);
        assert_eq!(result.result, "only 1 of 2 branches succeeded; required 2");
    }

    #[tokio::test]
    async fn executes_retry_and_let_var_variants() {
        let (executor, attempts) = setup_executor();
        let mut ctx = ExecutionContext::new();

        let retry_expr = ToolExpr::retry_with_params(ToolExpr::call("flaky", json!({})), 3, 0);
        let retry_result = executor.execute(&retry_expr, &mut ctx).await.unwrap();
        assert!(retry_result.success);
        assert_eq!(attempts.load(Ordering::SeqCst), 3);

        let let_expr = ToolExpr::let_binding(
            "saved",
            ToolExpr::call("ok", json!({})),
            ToolExpr::var("saved"),
        );
        let let_result = executor.execute(&let_expr, &mut ctx).await.unwrap();
        assert_eq!(let_result.result, "ok-result");

        let missing = ToolExpr::var("missing");
        let error = executor.execute(&missing, &mut ctx).await.unwrap_err();
        assert!(matches!(error, ToolError::Execution(_)));
    }

    #[test]
    fn extracts_nested_array_values() {
        let payload = r#"{"items":[{"name":"first"},{"name":"second"}]}"#;

        assert_eq!(
            CompositionExecutor::extract_value_at_path(payload, "items.1.name").as_deref(),
            Some("second")
        );
        assert_eq!(CompositionExecutor::extract_value_at_path(payload, "items.5"), None);
    }

    #[test]
    fn evaluates_nested_conditions() {
        let executor = CompositionExecutor::new(Arc::new(ToolRegistry::new()));
        let result = ToolResult {
            success: true,
            result: r#"{"status":"ready","email":"agent@example.com"}"#.to_string(),
            display_preference: None,
        };

        let condition = Condition::And {
            conditions: vec![
                Condition::Success,
                Condition::Or {
                    conditions: vec![
                        Condition::Contains {
                            path: "status".to_string(),
                            value: "ready".to_string(),
                        },
                        Condition::Matches {
                            path: "email".to_string(),
                            pattern: ".+@example\\.com".to_string(),
                        },
                    ],
                },
            ],
        };

        assert!(executor.evaluate_condition(&condition, &result));
    }
}