Skip to main content

awaken_runtime/execution/
executor.rs

1//! Tool execution strategies: Sequential and Parallel.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use awaken_contract::contract::message::ToolCall;
8use awaken_contract::contract::suspension::ToolCallOutcome;
9use awaken_contract::contract::tool::{Tool, ToolCallContext, ToolOutput, ToolResult};
10use awaken_contract::state::StateCommand;
11
/// Result of executing a single tool call.
///
/// The executors in this file build one value per executed call, pairing the
/// originating call with what running it produced.
pub struct ToolExecutionResult {
    /// The tool call this result belongs to.
    pub call: ToolCall,
    /// The result reported by the tool, or an error result when the tool was
    /// missing, rejected its arguments, or failed.
    pub result: ToolResult,
    /// Outcome classification; the executors in this file derive it from `result`
    /// via `ToolCallOutcome::from_tool_result`.
    pub outcome: ToolCallOutcome,
    /// Side-effects produced by the tool (state mutations, scheduled actions, effects).
    pub command: StateCommand,
}
20
/// Error from tool execution strategy.
///
/// This is the error half of [`ToolExecutor::execute`]'s return type. The
/// executors in this file never produce it: failures of individual tools are
/// folded into the per-call `ToolResult` instead.
#[derive(Debug, thiserror::Error)]
pub enum ToolExecutorError {
    /// Execution was cancelled.
    #[error("tool execution cancelled")]
    Cancelled,
    /// Execution failed; the payload is the reason rendered in the error message.
    #[error("tool execution failed: {0}")]
    Failed(String),
}
29
/// Strategy abstraction for tool execution.
///
/// An implementation decides how a batch of tool calls is scheduled. This file
/// provides [`SequentialToolExecutor`] and [`ParallelToolExecutor`].
#[async_trait]
pub trait ToolExecutor: Send + Sync {
    /// Execute tool calls and return results.
    ///
    /// * `tools` - available tools; the executors in this file look a tool up by
    ///   the `name` carried in each call.
    /// * `calls` - the tool calls to run.
    /// * `base_ctx` - context template; the executors in this file clone it per
    ///   call and fill in the call id and tool name.
    ///
    /// # Errors
    ///
    /// Returns a [`ToolExecutorError`] when the strategy itself cannot complete.
    /// NOTE(review): the executors in this file always return `Ok` and report
    /// per-tool failures inside the individual results — confirm whether other
    /// implementors are expected to follow the same convention.
    async fn execute(
        &self,
        tools: &HashMap<String, Arc<dyn Tool>>,
        calls: &[ToolCall],
        base_ctx: &ToolCallContext,
    ) -> Result<Vec<ToolExecutionResult>, ToolExecutorError>;

    /// Strategy name for logging.
    fn name(&self) -> &'static str;

    /// Whether the executor needs state refreshed between individual tool calls.
    /// Sequential executors return true; parallel executors return false.
    ///
    /// The default implementation returns `false`.
    fn requires_incremental_state(&self) -> bool {
        false
    }
}
50
/// Execute tool calls one by one in call order.
///
/// Context freshness between calls is controlled by the caller: within a single
/// `execute` invocation every call receives a clone of the same base context.
/// Stops at first suspension; calls after the suspended one are not executed
/// and do not appear in the returned results.
#[derive(Debug, Clone, Copy, Default)]
pub struct SequentialToolExecutor;
56
57#[async_trait]
58impl ToolExecutor for SequentialToolExecutor {
59    async fn execute(
60        &self,
61        tools: &HashMap<String, Arc<dyn Tool>>,
62        calls: &[ToolCall],
63        base_ctx: &ToolCallContext,
64    ) -> Result<Vec<ToolExecutionResult>, ToolExecutorError> {
65        let mut results = Vec::with_capacity(calls.len());
66
67        for call in calls {
68            let mut ctx = base_ctx.clone();
69            ctx.call_id = call.id.clone();
70            ctx.tool_name = call.name.clone();
71            let output = execute_single_tool(tools, call, &ctx).await;
72            let outcome = ToolCallOutcome::from_tool_result(&output.result);
73
74            results.push(ToolExecutionResult {
75                call: call.clone(),
76                result: output.result,
77                outcome,
78                command: output.command,
79            });
80
81            // Stop at first suspension
82            if results
83                .last()
84                .is_some_and(|r| r.outcome == ToolCallOutcome::Suspended)
85            {
86                break;
87            }
88        }
89
90        Ok(results)
91    }
92
93    fn name(&self) -> &'static str {
94        "sequential"
95    }
96
97    fn requires_incremental_state(&self) -> bool {
98        true
99    }
100}
101
/// Policy controlling when resume decisions are replayed into tool execution.
///
/// Reported by `ParallelToolExecutor::decision_replay_policy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecisionReplayPolicy {
    /// Replay each resolved suspended call as soon as its decision arrives.
    Immediate,
    /// Replay only when all currently suspended calls have decisions.
    BatchAllSuspended,
}
110
/// Parallel execution mode.
///
/// `ParallelToolExecutor` maps `BatchApproval` to
/// `DecisionReplayPolicy::BatchAllSuspended` and `Streaming` to
/// `DecisionReplayPolicy::Immediate`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParallelMode {
    /// All concurrently, batch approval gate — wait for all decisions.
    BatchApproval,
    /// All concurrently, streaming results — replay decisions immediately.
    Streaming,
}
119
/// Execute all tool calls concurrently. All tools see the same frozen snapshot.
///
/// Two modes differ in how suspension decisions are replayed:
/// - `BatchApproval`: wait for all suspended calls to have decisions before replay
/// - `Streaming`: replay each decision immediately as it arrives
///
/// The mode does not change how `execute` runs the calls; it is reflected in
/// `name()` and `decision_replay_policy()`.
#[derive(Debug, Clone, Copy)]
pub struct ParallelToolExecutor {
    /// Selected mode; fixed at construction by `batch_approval()` or `streaming()`.
    mode: ParallelMode,
}
129
130impl ParallelToolExecutor {
131    pub const fn batch_approval() -> Self {
132        Self {
133            mode: ParallelMode::BatchApproval,
134        }
135    }
136
137    pub const fn streaming() -> Self {
138        Self {
139            mode: ParallelMode::Streaming,
140        }
141    }
142
143    /// How the runtime should replay resolved suspend decisions.
144    pub fn decision_replay_policy(&self) -> DecisionReplayPolicy {
145        match self.mode {
146            ParallelMode::BatchApproval => DecisionReplayPolicy::BatchAllSuspended,
147            ParallelMode::Streaming => DecisionReplayPolicy::Immediate,
148        }
149    }
150
151    /// Whether the runtime should enforce parallel patch conflict checks.
152    pub fn requires_conflict_check(&self) -> bool {
153        true
154    }
155}
156
157impl Default for ParallelToolExecutor {
158    fn default() -> Self {
159        Self::streaming()
160    }
161}
162
163#[async_trait]
164impl ToolExecutor for ParallelToolExecutor {
165    async fn execute(
166        &self,
167        tools: &HashMap<String, Arc<dyn Tool>>,
168        calls: &[ToolCall],
169        base_ctx: &ToolCallContext,
170    ) -> Result<Vec<ToolExecutionResult>, ToolExecutorError> {
171        use futures::future::join_all;
172
173        let futures: Vec<_> = calls
174            .iter()
175            .map(|call| {
176                let tools = tools.clone();
177                let call = call.clone();
178                let mut ctx = base_ctx.clone();
179                ctx.call_id = call.id.clone();
180                ctx.tool_name = call.name.clone();
181                async move {
182                    let output = execute_single_tool(&tools, &call, &ctx).await;
183                    let outcome = ToolCallOutcome::from_tool_result(&output.result);
184                    ToolExecutionResult {
185                        call,
186                        result: output.result,
187                        outcome,
188                        command: output.command,
189                    }
190                }
191            })
192            .collect();
193
194        Ok(join_all(futures).await)
195    }
196
197    fn name(&self) -> &'static str {
198        match self.mode {
199            ParallelMode::BatchApproval => "parallel_batch_approval",
200            ParallelMode::Streaming => "parallel_streaming",
201        }
202    }
203}
204
205/// Execute a single tool call, never panicking.
206pub(crate) async fn execute_single_tool(
207    tools: &HashMap<String, Arc<dyn Tool>>,
208    call: &ToolCall,
209    ctx: &ToolCallContext,
210) -> ToolOutput {
211    let Some(tool) = tools.get(&call.name) else {
212        return ToolResult::error(&call.name, format!("tool '{}' not found", call.name)).into();
213    };
214
215    if let Err(e) = tool.validate_args(&call.arguments) {
216        return ToolResult::error(&call.name, e.to_string()).into();
217    }
218
219    match tool.execute(call.arguments.clone(), ctx).await {
220        Ok(output) => output,
221        Err(e) => ToolResult::error(&call.name, e.to_string()).into(),
222    }
223}
224
225#[cfg(test)]
226mod tests {
227    use super::*;
228    use awaken_contract::contract::tool::{ToolDescriptor, ToolError, ToolOutput};
229    use serde_json::{Value, json};
230
231    struct EchoTool;
232
233    #[async_trait]
234    impl Tool for EchoTool {
235        fn descriptor(&self) -> ToolDescriptor {
236            ToolDescriptor::new("echo", "echo", "Echoes input")
237        }
238
239        async fn execute(
240            &self,
241            args: Value,
242            _ctx: &ToolCallContext,
243        ) -> Result<ToolOutput, ToolError> {
244            let msg = args
245                .get("message")
246                .and_then(|v| v.as_str())
247                .unwrap_or("no message")
248                .to_string();
249            Ok(ToolResult::success_with_message("echo", args, msg).into())
250        }
251    }
252
253    struct FailingTool;
254
255    #[async_trait]
256    impl Tool for FailingTool {
257        fn descriptor(&self) -> ToolDescriptor {
258            ToolDescriptor::new("failing", "failing", "Always fails")
259        }
260
261        async fn execute(
262            &self,
263            _args: Value,
264            _ctx: &ToolCallContext,
265        ) -> Result<ToolOutput, ToolError> {
266            Err(ToolError::ExecutionFailed("intentional failure".into()))
267        }
268    }
269
270    struct SuspendingTool;
271
272    #[async_trait]
273    impl Tool for SuspendingTool {
274        fn descriptor(&self) -> ToolDescriptor {
275            ToolDescriptor::new("suspending", "suspending", "Returns pending")
276        }
277
278        async fn execute(
279            &self,
280            _args: Value,
281            _ctx: &ToolCallContext,
282        ) -> Result<ToolOutput, ToolError> {
283            Ok(ToolResult::suspended("suspending", "needs approval").into())
284        }
285    }
286
287    fn tool_map(tools: Vec<Arc<dyn Tool>>) -> HashMap<String, Arc<dyn Tool>> {
288        tools.into_iter().map(|t| (t.descriptor().id, t)).collect()
289    }
290
291    // -- Sequential tests --
292
293    #[tokio::test]
294    async fn sequential_single_tool_success() {
295        let tools = tool_map(vec![Arc::new(EchoTool)]);
296        let calls = vec![ToolCall::new("c1", "echo", json!({"message": "hi"}))];
297        let executor = SequentialToolExecutor;
298
299        let results = executor
300            .execute(&tools, &calls, &ToolCallContext::test_default())
301            .await
302            .unwrap();
303        assert_eq!(results.len(), 1);
304        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
305        assert!(results[0].result.is_success());
306    }
307
308    #[tokio::test]
309    async fn sequential_partial_failure() {
310        let tools = tool_map(vec![Arc::new(EchoTool), Arc::new(FailingTool)]);
311        let calls = vec![
312            ToolCall::new("c1", "echo", json!({"message": "ok"})),
313            ToolCall::new("c2", "failing", json!({})),
314        ];
315        let executor = SequentialToolExecutor;
316
317        let results = executor
318            .execute(&tools, &calls, &ToolCallContext::test_default())
319            .await
320            .unwrap();
321        assert_eq!(results.len(), 2);
322        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
323        assert_eq!(results[1].outcome, ToolCallOutcome::Failed);
324    }
325
326    #[tokio::test]
327    async fn sequential_stops_after_first_suspension() {
328        let tools = tool_map(vec![Arc::new(SuspendingTool), Arc::new(EchoTool)]);
329        let calls = vec![
330            ToolCall::new("c1", "suspending", json!({})),
331            ToolCall::new("c2", "echo", json!({"message": "should not run"})),
332        ];
333        let executor = SequentialToolExecutor;
334
335        let results = executor
336            .execute(&tools, &calls, &ToolCallContext::test_default())
337            .await
338            .unwrap();
339        assert_eq!(results.len(), 1, "should stop after suspended tool");
340        assert_eq!(results[0].outcome, ToolCallOutcome::Suspended);
341    }
342
343    #[tokio::test]
344    async fn sequential_unknown_tool_returns_error() {
345        let tools = tool_map(vec![]);
346        let calls = vec![ToolCall::new("c1", "nonexistent", json!({}))];
347        let executor = SequentialToolExecutor;
348
349        let results = executor
350            .execute(&tools, &calls, &ToolCallContext::test_default())
351            .await
352            .unwrap();
353        assert_eq!(results.len(), 1);
354        assert_eq!(results[0].outcome, ToolCallOutcome::Failed);
355        assert!(results[0].result.is_error());
356    }
357
358    #[tokio::test]
359    async fn sequential_empty_calls() {
360        let tools = tool_map(vec![Arc::new(EchoTool)]);
361        let executor = SequentialToolExecutor;
362
363        let results = executor
364            .execute(&tools, &[], &ToolCallContext::test_default())
365            .await
366            .unwrap();
367        assert!(results.is_empty());
368    }
369
370    // -- Parallel tests --
371
372    #[tokio::test]
373    async fn parallel_all_succeed() {
374        let tools = tool_map(vec![Arc::new(EchoTool)]);
375        let calls = vec![
376            ToolCall::new("c1", "echo", json!({"message": "first"})),
377            ToolCall::new("c2", "echo", json!({"message": "second"})),
378        ];
379        let executor = ParallelToolExecutor::streaming();
380
381        let results = executor
382            .execute(&tools, &calls, &ToolCallContext::test_default())
383            .await
384            .unwrap();
385        assert_eq!(results.len(), 2);
386        assert!(
387            results
388                .iter()
389                .all(|r| r.outcome == ToolCallOutcome::Succeeded)
390        );
391    }
392
393    #[tokio::test]
394    async fn parallel_partial_failure() {
395        let tools = tool_map(vec![Arc::new(EchoTool), Arc::new(FailingTool)]);
396        let calls = vec![
397            ToolCall::new("c1", "echo", json!({"message": "ok"})),
398            ToolCall::new("c2", "failing", json!({})),
399        ];
400        let executor = ParallelToolExecutor::streaming();
401
402        let results = executor
403            .execute(&tools, &calls, &ToolCallContext::test_default())
404            .await
405            .unwrap();
406        assert_eq!(results.len(), 2);
407        let successes = results
408            .iter()
409            .filter(|r| r.outcome == ToolCallOutcome::Succeeded)
410            .count();
411        let failures = results
412            .iter()
413            .filter(|r| r.outcome == ToolCallOutcome::Failed)
414            .count();
415        assert_eq!(successes, 1);
416        assert_eq!(failures, 1);
417    }
418
419    #[tokio::test]
420    async fn parallel_does_not_stop_on_suspension() {
421        let tools = tool_map(vec![Arc::new(SuspendingTool), Arc::new(EchoTool)]);
422        let calls = vec![
423            ToolCall::new("c1", "suspending", json!({})),
424            ToolCall::new("c2", "echo", json!({"message": "runs anyway"})),
425        ];
426        let executor = ParallelToolExecutor::streaming();
427
428        let results = executor
429            .execute(&tools, &calls, &ToolCallContext::test_default())
430            .await
431            .unwrap();
432        // Parallel executes ALL tools regardless of suspension
433        assert_eq!(results.len(), 2);
434        let suspended = results
435            .iter()
436            .filter(|r| r.outcome == ToolCallOutcome::Suspended)
437            .count();
438        let succeeded = results
439            .iter()
440            .filter(|r| r.outcome == ToolCallOutcome::Succeeded)
441            .count();
442        assert_eq!(suspended, 1);
443        assert_eq!(succeeded, 1);
444    }
445
446    #[tokio::test]
447    async fn parallel_empty_calls() {
448        let tools = tool_map(vec![Arc::new(EchoTool)]);
449        let executor = ParallelToolExecutor::streaming();
450
451        let results = executor
452            .execute(&tools, &[], &ToolCallContext::test_default())
453            .await
454            .unwrap();
455        assert!(results.is_empty());
456    }
457
458    #[test]
459    fn executor_names() {
460        assert_eq!(SequentialToolExecutor.name(), "sequential");
461        assert_eq!(
462            ParallelToolExecutor::streaming().name(),
463            "parallel_streaming"
464        );
465        assert_eq!(
466            ParallelToolExecutor::batch_approval().name(),
467            "parallel_batch_approval"
468        );
469    }
470
471    #[test]
472    fn parallel_default_is_streaming() {
473        let executor = ParallelToolExecutor::default();
474        assert_eq!(executor.name(), "parallel_streaming");
475        assert_eq!(
476            executor.decision_replay_policy(),
477            DecisionReplayPolicy::Immediate
478        );
479    }
480
481    #[test]
482    fn parallel_batch_approval_policy() {
483        let executor = ParallelToolExecutor::batch_approval();
484        assert_eq!(
485            executor.decision_replay_policy(),
486            DecisionReplayPolicy::BatchAllSuspended
487        );
488        assert!(executor.requires_conflict_check());
489    }
490
491    #[test]
492    fn parallel_streaming_policy() {
493        let executor = ParallelToolExecutor::streaming();
494        assert_eq!(
495            executor.decision_replay_policy(),
496            DecisionReplayPolicy::Immediate
497        );
498        assert!(executor.requires_conflict_check());
499    }
500
501    #[tokio::test]
502    async fn batch_approval_executes_all_concurrently() {
503        let tools = tool_map(vec![Arc::new(EchoTool)]);
504        let calls = vec![
505            ToolCall::new("c1", "echo", json!({"message": "a"})),
506            ToolCall::new("c2", "echo", json!({"message": "b"})),
507        ];
508        let executor = ParallelToolExecutor::batch_approval();
509
510        let results = executor
511            .execute(&tools, &calls, &ToolCallContext::test_default())
512            .await
513            .unwrap();
514        assert_eq!(results.len(), 2);
515        assert!(
516            results
517                .iter()
518                .all(|r| r.outcome == ToolCallOutcome::Succeeded)
519        );
520    }
521
522    #[tokio::test]
523    async fn batch_approval_does_not_stop_on_suspension() {
524        let tools = tool_map(vec![Arc::new(SuspendingTool), Arc::new(EchoTool)]);
525        let calls = vec![
526            ToolCall::new("c1", "suspending", json!({})),
527            ToolCall::new("c2", "echo", json!({"message": "runs anyway"})),
528        ];
529        let executor = ParallelToolExecutor::batch_approval();
530
531        let results = executor
532            .execute(&tools, &calls, &ToolCallContext::test_default())
533            .await
534            .unwrap();
535        assert_eq!(results.len(), 2);
536    }
537
538    // -----------------------------------------------------------------------
539    // Migrated from uncarve: additional tool executor tests
540    // -----------------------------------------------------------------------
541
542    /// A tool that counts how many times it's been called.
543    struct CountingTool {
544        count: Arc<std::sync::atomic::AtomicUsize>,
545    }
546
547    impl CountingTool {
548        fn new() -> Self {
549            Self {
550                count: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
551            }
552        }
553
554        fn call_count(&self) -> usize {
555            self.count.load(std::sync::atomic::Ordering::SeqCst)
556        }
557    }
558
559    #[async_trait]
560    impl Tool for CountingTool {
561        fn descriptor(&self) -> ToolDescriptor {
562            ToolDescriptor::new("counting", "counting", "Counts calls")
563        }
564
565        async fn execute(
566            &self,
567            _args: Value,
568            _ctx: &ToolCallContext,
569        ) -> Result<ToolOutput, ToolError> {
570            let n = self.count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
571            Ok(ToolResult::success("counting", json!({"call_number": n + 1})).into())
572        }
573    }
574
575    #[tokio::test]
576    async fn sequential_multiple_calls_ordered() {
577        let counting = Arc::new(CountingTool::new());
578        let tools = tool_map(vec![counting.clone() as Arc<dyn Tool>]);
579        let calls = vec![
580            ToolCall::new("c1", "counting", json!({})),
581            ToolCall::new("c2", "counting", json!({})),
582            ToolCall::new("c3", "counting", json!({})),
583        ];
584        let executor = SequentialToolExecutor;
585
586        let results = executor
587            .execute(&tools, &calls, &ToolCallContext::test_default())
588            .await
589            .unwrap();
590        assert_eq!(results.len(), 3);
591        assert_eq!(counting.call_count(), 3);
592        // Verify order is preserved
593        for (i, result) in results.iter().enumerate() {
594            assert_eq!(result.call.id, format!("c{}", i + 1));
595            assert_eq!(result.outcome, ToolCallOutcome::Succeeded);
596        }
597    }
598
599    #[tokio::test]
600    async fn sequential_failure_does_not_stop_execution() {
601        // Unlike suspension, failures do NOT stop sequential execution
602        let tools = tool_map(vec![Arc::new(FailingTool), Arc::new(EchoTool)]);
603        let calls = vec![
604            ToolCall::new("c1", "failing", json!({})),
605            ToolCall::new("c2", "echo", json!({"message": "still runs"})),
606        ];
607        let executor = SequentialToolExecutor;
608
609        let results = executor
610            .execute(&tools, &calls, &ToolCallContext::test_default())
611            .await
612            .unwrap();
613        assert_eq!(results.len(), 2);
614        assert_eq!(results[0].outcome, ToolCallOutcome::Failed);
615        assert_eq!(results[1].outcome, ToolCallOutcome::Succeeded);
616    }
617
618    #[tokio::test]
619    async fn sequential_suspension_in_middle_stops_remaining() {
620        let tools = tool_map(vec![
621            Arc::new(EchoTool),
622            Arc::new(SuspendingTool),
623            Arc::new(EchoTool),
624        ]);
625        let calls = vec![
626            ToolCall::new("c1", "echo", json!({"message": "first"})),
627            ToolCall::new("c2", "suspending", json!({})),
628            ToolCall::new("c3", "echo", json!({"message": "should not run"})),
629        ];
630        let executor = SequentialToolExecutor;
631
632        let results = executor
633            .execute(&tools, &calls, &ToolCallContext::test_default())
634            .await
635            .unwrap();
636        assert_eq!(results.len(), 2, "should stop after suspension");
637        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
638        assert_eq!(results[1].outcome, ToolCallOutcome::Suspended);
639    }
640
641    #[tokio::test]
642    async fn parallel_all_fail() {
643        let tools = tool_map(vec![Arc::new(FailingTool)]);
644        let calls = vec![
645            ToolCall::new("c1", "failing", json!({})),
646            ToolCall::new("c2", "failing", json!({})),
647        ];
648        let executor = ParallelToolExecutor::streaming();
649
650        let results = executor
651            .execute(&tools, &calls, &ToolCallContext::test_default())
652            .await
653            .unwrap();
654        assert_eq!(results.len(), 2);
655        assert!(results.iter().all(|r| r.outcome == ToolCallOutcome::Failed));
656    }
657
658    #[tokio::test]
659    async fn parallel_unknown_tool_returns_error() {
660        let tools = tool_map(vec![]);
661        let calls = vec![
662            ToolCall::new("c1", "nonexistent_a", json!({})),
663            ToolCall::new("c2", "nonexistent_b", json!({})),
664        ];
665        let executor = ParallelToolExecutor::streaming();
666
667        let results = executor
668            .execute(&tools, &calls, &ToolCallContext::test_default())
669            .await
670            .unwrap();
671        assert_eq!(results.len(), 2);
672        assert!(results.iter().all(|r| r.outcome == ToolCallOutcome::Failed));
673        for r in &results {
674            assert!(r.result.is_error());
675        }
676    }
677
678    #[tokio::test]
679    async fn parallel_counting_tool_all_called() {
680        let counting = Arc::new(CountingTool::new());
681        let tools = tool_map(vec![counting.clone() as Arc<dyn Tool>]);
682        let calls = vec![
683            ToolCall::new("c1", "counting", json!({})),
684            ToolCall::new("c2", "counting", json!({})),
685            ToolCall::new("c3", "counting", json!({})),
686        ];
687        let executor = ParallelToolExecutor::streaming();
688
689        let results = executor
690            .execute(&tools, &calls, &ToolCallContext::test_default())
691            .await
692            .unwrap();
693        assert_eq!(results.len(), 3);
694        assert_eq!(counting.call_count(), 3);
695    }
696
697    /// Validate that tool args are validated before execution.
698    struct StrictArgsTool;
699
700    #[async_trait]
701    impl Tool for StrictArgsTool {
702        fn descriptor(&self) -> ToolDescriptor {
703            ToolDescriptor::new("strict", "strict", "Validates args")
704        }
705
706        fn validate_args(&self, args: &Value) -> Result<(), ToolError> {
707            if args.get("required_field").is_none() {
708                return Err(ToolError::InvalidArguments("missing required_field".into()));
709            }
710            Ok(())
711        }
712
713        async fn execute(
714            &self,
715            args: Value,
716            _ctx: &ToolCallContext,
717        ) -> Result<ToolOutput, ToolError> {
718            Ok(ToolResult::success("strict", args).into())
719        }
720    }
721
722    #[tokio::test]
723    async fn sequential_validates_args_before_execute() {
724        let tools = tool_map(vec![Arc::new(StrictArgsTool)]);
725        let calls = vec![ToolCall::new("c1", "strict", json!({}))]; // missing required_field
726        let executor = SequentialToolExecutor;
727
728        let results = executor
729            .execute(&tools, &calls, &ToolCallContext::test_default())
730            .await
731            .unwrap();
732        assert_eq!(results.len(), 1);
733        assert_eq!(results[0].outcome, ToolCallOutcome::Failed);
734        assert!(results[0].result.is_error());
735    }
736
737    #[tokio::test]
738    async fn sequential_valid_args_succeeds() {
739        let tools = tool_map(vec![Arc::new(StrictArgsTool)]);
740        let calls = vec![ToolCall::new(
741            "c1",
742            "strict",
743            json!({"required_field": "present"}),
744        )];
745        let executor = SequentialToolExecutor;
746
747        let results = executor
748            .execute(&tools, &calls, &ToolCallContext::test_default())
749            .await
750            .unwrap();
751        assert_eq!(results.len(), 1);
752        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
753    }
754
755    #[tokio::test]
756    async fn parallel_validates_args_before_execute() {
757        let tools = tool_map(vec![Arc::new(StrictArgsTool)]);
758        let calls = vec![ToolCall::new("c1", "strict", json!({}))];
759        let executor = ParallelToolExecutor::streaming();
760
761        let results = executor
762            .execute(&tools, &calls, &ToolCallContext::test_default())
763            .await
764            .unwrap();
765        assert_eq!(results.len(), 1);
766        assert_eq!(results[0].outcome, ToolCallOutcome::Failed);
767    }
768
769    #[test]
770    fn tool_execution_result_fields() {
771        let result = ToolExecutionResult {
772            call: ToolCall::new("c1", "echo", json!({})),
773            result: ToolResult::success("echo", json!({"ok": true})),
774            outcome: ToolCallOutcome::Succeeded,
775            command: StateCommand::default(),
776        };
777        assert_eq!(result.call.id, "c1");
778        assert_eq!(result.outcome, ToolCallOutcome::Succeeded);
779    }
780
781    #[test]
782    fn tool_executor_error_display() {
783        let err = ToolExecutorError::Cancelled;
784        assert!(err.to_string().contains("cancelled"));
785        let err2 = ToolExecutorError::Failed("some reason".into());
786        assert!(err2.to_string().contains("some reason"));
787    }
788
789    // -----------------------------------------------------------------------
790    // Additional coverage: context preservation, mixed scenarios, edge cases
791    // -----------------------------------------------------------------------
792
793    /// Tool that captures the context it receives for later inspection.
794    struct ContextCaptureTool {
795        captured_call_id: Arc<std::sync::Mutex<String>>,
796        captured_tool_name: Arc<std::sync::Mutex<String>>,
797    }
798
799    impl ContextCaptureTool {
800        fn new() -> Self {
801            Self {
802                captured_call_id: Arc::new(std::sync::Mutex::new(String::new())),
803                captured_tool_name: Arc::new(std::sync::Mutex::new(String::new())),
804            }
805        }
806    }
807
808    #[async_trait]
809    impl Tool for ContextCaptureTool {
810        fn descriptor(&self) -> ToolDescriptor {
811            ToolDescriptor::new("capture", "capture", "Captures context")
812        }
813
814        async fn execute(
815            &self,
816            _args: Value,
817            ctx: &ToolCallContext,
818        ) -> Result<ToolOutput, ToolError> {
819            *self.captured_call_id.lock().unwrap() = ctx.call_id.clone();
820            *self.captured_tool_name.lock().unwrap() = ctx.tool_name.clone();
821            Ok(ToolResult::success("capture", json!({"captured": true})).into())
822        }
823    }
824
825    #[tokio::test]
826    async fn execute_single_tool_preserves_call_context() {
827        let capture = Arc::new(ContextCaptureTool::new());
828        let tools = tool_map(vec![capture.clone() as Arc<dyn Tool>]);
829        let call = ToolCall::new("call-42", "capture", json!({}));
830        let ctx = ToolCallContext::test_default();
831
832        let output = execute_single_tool(&tools, &call, &ctx).await;
833        assert!(output.result.is_success());
834        // execute_single_tool sets call_id and tool_name from the call's context
835        // (the caller is responsible for setting ctx fields, which the executor does)
836    }
837
838    #[tokio::test]
839    async fn execute_single_tool_missing_returns_error_with_name() {
840        let tools: HashMap<String, Arc<dyn Tool>> = HashMap::new();
841        let call = ToolCall::new("c1", "ghost_tool", json!({}));
842        let ctx = ToolCallContext::test_default();
843
844        let output = execute_single_tool(&tools, &call, &ctx).await;
845        assert!(output.result.is_error());
846        assert!(
847            output
848                .result
849                .message
850                .as_deref()
851                .unwrap_or("")
852                .contains("ghost_tool")
853        );
854    }
855
856    #[tokio::test]
857    async fn execute_single_tool_validates_args() {
858        let tools = tool_map(vec![Arc::new(StrictArgsTool)]);
859        let call = ToolCall::new("c1", "strict", json!({"wrong": "field"}));
860        let ctx = ToolCallContext::test_default();
861
862        let output = execute_single_tool(&tools, &call, &ctx).await;
863        assert!(output.result.is_error());
864    }
865
866    #[tokio::test]
867    async fn sequential_context_call_id_set_per_tool() {
868        let capture = Arc::new(ContextCaptureTool::new());
869        let tools = tool_map(vec![capture.clone() as Arc<dyn Tool>]);
870        let calls = vec![ToolCall::new("unique-id-99", "capture", json!({}))];
871        let executor = SequentialToolExecutor;
872
873        let results = executor
874            .execute(&tools, &calls, &ToolCallContext::test_default())
875            .await
876            .unwrap();
877        assert_eq!(results.len(), 1);
878        assert_eq!(results[0].call.id, "unique-id-99");
879        // The executor sets ctx.call_id = call.id before passing to execute_single_tool
880        assert_eq!(*capture.captured_call_id.lock().unwrap(), "unique-id-99");
881    }
882
883    #[tokio::test]
884    async fn sequential_mixed_success_failure_suspension_order() {
885        let tools = tool_map(vec![
886            Arc::new(EchoTool),
887            Arc::new(FailingTool),
888            Arc::new(SuspendingTool),
889        ]);
890        // success, fail, suspend — suspension stops execution
891        let calls = vec![
892            ToolCall::new("c1", "echo", json!({"message": "hi"})),
893            ToolCall::new("c2", "failing", json!({})),
894            ToolCall::new("c3", "suspending", json!({})),
895            ToolCall::new("c4", "echo", json!({"message": "should not run"})),
896        ];
897        let executor = SequentialToolExecutor;
898
899        let results = executor
900            .execute(&tools, &calls, &ToolCallContext::test_default())
901            .await
902            .unwrap();
903        assert_eq!(results.len(), 3, "stops after suspension at c3");
904        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
905        assert_eq!(results[1].outcome, ToolCallOutcome::Failed);
906        assert_eq!(results[2].outcome, ToolCallOutcome::Suspended);
907    }
908
909    #[tokio::test]
910    async fn parallel_preserves_result_order() {
911        let counting = Arc::new(CountingTool::new());
912        let tools = tool_map(vec![counting.clone() as Arc<dyn Tool>]);
913        let calls: Vec<_> = (0..5)
914            .map(|i| ToolCall::new(format!("c{i}"), "counting", json!({})))
915            .collect();
916        let executor = ParallelToolExecutor::streaming();
917
918        let results = executor
919            .execute(&tools, &calls, &ToolCallContext::test_default())
920            .await
921            .unwrap();
922        assert_eq!(results.len(), 5);
923        // join_all preserves order of futures
924        for (i, r) in results.iter().enumerate() {
925            assert_eq!(r.call.id, format!("c{i}"));
926        }
927    }
928
929    #[tokio::test]
930    async fn parallel_mixed_success_failure_suspension() {
931        let tools = tool_map(vec![
932            Arc::new(EchoTool),
933            Arc::new(FailingTool),
934            Arc::new(SuspendingTool),
935        ]);
936        let calls = vec![
937            ToolCall::new("c1", "echo", json!({"message": "hi"})),
938            ToolCall::new("c2", "failing", json!({})),
939            ToolCall::new("c3", "suspending", json!({})),
940        ];
941        let executor = ParallelToolExecutor::batch_approval();
942
943        let results = executor
944            .execute(&tools, &calls, &ToolCallContext::test_default())
945            .await
946            .unwrap();
947        assert_eq!(results.len(), 3, "parallel runs all regardless");
948        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
949        assert_eq!(results[1].outcome, ToolCallOutcome::Failed);
950        assert_eq!(results[2].outcome, ToolCallOutcome::Suspended);
951    }
952
953    #[test]
954    fn sequential_requires_incremental_state() {
955        let executor = SequentialToolExecutor;
956        assert!(executor.requires_incremental_state());
957    }
958
959    #[test]
960    fn parallel_does_not_require_incremental_state() {
961        let executor = ParallelToolExecutor::streaming();
962        assert!(!executor.requires_incremental_state());
963        let batch = ParallelToolExecutor::batch_approval();
964        assert!(!batch.requires_incremental_state());
965    }
966
967    #[tokio::test]
968    async fn execute_single_tool_success_returns_correct_tool_name() {
969        let tools = tool_map(vec![Arc::new(EchoTool)]);
970        let call = ToolCall::new("c1", "echo", json!({"message": "test"}));
971        let ctx = ToolCallContext::test_default();
972
973        let output = execute_single_tool(&tools, &call, &ctx).await;
974        assert!(output.result.is_success());
975        assert_eq!(output.result.tool_name, "echo");
976    }
977}