Skip to main content

awaken_runtime/execution/
executor.rs

1//! Tool execution strategies: Sequential and Parallel.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use awaken_contract::contract::message::ToolCall;
8use awaken_contract::contract::suspension::ToolCallOutcome;
9use awaken_contract::contract::tool::{Tool, ToolCallContext, ToolOutput, ToolResult};
10use awaken_contract::state::StateCommand;
11
12#[cfg(feature = "background")]
13use crate::extensions::background::{ToolLineageContext, scope_tool_lineage_context};
14
15/// Result of executing a single tool call.
16pub struct ToolExecutionResult {
17    pub call: ToolCall,
18    pub result: ToolResult,
19    pub outcome: ToolCallOutcome,
20    /// Side-effects produced by the tool (state mutations, scheduled actions, effects).
21    pub command: StateCommand,
22}
23
24/// Error from tool execution strategy.
25#[derive(Debug, thiserror::Error)]
26pub enum ToolExecutorError {
27    #[error("tool execution cancelled")]
28    Cancelled,
29    #[error("tool execution failed: {0}")]
30    Failed(String),
31}
32
33/// Strategy abstraction for tool execution.
34#[async_trait]
35pub trait ToolExecutor: Send + Sync {
36    /// Execute tool calls and return results.
37    async fn execute(
38        &self,
39        tools: &HashMap<String, Arc<dyn Tool>>,
40        calls: &[ToolCall],
41        base_ctx: &ToolCallContext,
42    ) -> Result<Vec<ToolExecutionResult>, ToolExecutorError>;
43
44    /// Strategy name for logging.
45    fn name(&self) -> &'static str;
46
47    /// Whether the executor needs state refreshed between individual tool calls.
48    /// Sequential executors return true; parallel executors return false.
49    fn requires_incremental_state(&self) -> bool {
50        false
51    }
52}
53
/// Strategy that runs tool calls one at a time, in the order they were given.
///
/// Execution halts at the first suspended call. Keeping the context fresh
/// between calls is the caller's responsibility.
#[derive(Debug, Clone, Copy, Default)]
pub struct SequentialToolExecutor;
59
60#[async_trait]
61impl ToolExecutor for SequentialToolExecutor {
62    async fn execute(
63        &self,
64        tools: &HashMap<String, Arc<dyn Tool>>,
65        calls: &[ToolCall],
66        base_ctx: &ToolCallContext,
67    ) -> Result<Vec<ToolExecutionResult>, ToolExecutorError> {
68        let mut results = Vec::with_capacity(calls.len());
69
70        for call in calls {
71            let mut ctx = base_ctx.clone();
72            ctx.call_id = call.id.clone();
73            ctx.tool_name = call.name.clone();
74            let output = execute_single_tool(tools, call, &ctx).await;
75            let outcome = ToolCallOutcome::from_tool_result(&output.result);
76
77            results.push(ToolExecutionResult {
78                call: call.clone(),
79                result: output.result,
80                outcome,
81                command: output.command,
82            });
83
84            // Stop at first suspension
85            if results
86                .last()
87                .is_some_and(|r| r.outcome == ToolCallOutcome::Suspended)
88            {
89                break;
90            }
91        }
92
93        Ok(results)
94    }
95
96    fn name(&self) -> &'static str {
97        "sequential"
98    }
99
100    fn requires_incremental_state(&self) -> bool {
101        true
102    }
103}
104
/// Determines when resume decisions for suspended calls are fed back into
/// tool execution.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecisionReplayPolicy {
    /// A resolved suspended call is replayed as soon as its decision arrives.
    Immediate,
    /// Nothing is replayed until every currently suspended call has a decision.
    BatchAllSuspended,
}
113
/// Flavour of parallel execution used by [`ParallelToolExecutor`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParallelMode {
    /// Run everything concurrently behind a batch approval gate: wait for all decisions.
    BatchApproval,
    /// Run everything concurrently with streaming results: replay decisions immediately.
    Streaming,
}
122
123/// Execute all tool calls concurrently. All tools see the same frozen snapshot.
124///
125/// Two modes differ in how suspension decisions are replayed:
126/// - `BatchApproval`: wait for all suspended calls to have decisions before replay
127/// - `Streaming`: replay each decision immediately as it arrives
128#[derive(Debug, Clone, Copy)]
129pub struct ParallelToolExecutor {
130    mode: ParallelMode,
131}
132
133impl ParallelToolExecutor {
134    pub const fn batch_approval() -> Self {
135        Self {
136            mode: ParallelMode::BatchApproval,
137        }
138    }
139
140    pub const fn streaming() -> Self {
141        Self {
142            mode: ParallelMode::Streaming,
143        }
144    }
145
146    /// How the runtime should replay resolved suspend decisions.
147    pub fn decision_replay_policy(&self) -> DecisionReplayPolicy {
148        match self.mode {
149            ParallelMode::BatchApproval => DecisionReplayPolicy::BatchAllSuspended,
150            ParallelMode::Streaming => DecisionReplayPolicy::Immediate,
151        }
152    }
153
154    /// Whether the runtime should enforce parallel patch conflict checks.
155    pub fn requires_conflict_check(&self) -> bool {
156        true
157    }
158}
159
160impl Default for ParallelToolExecutor {
161    fn default() -> Self {
162        Self::streaming()
163    }
164}
165
166#[async_trait]
167impl ToolExecutor for ParallelToolExecutor {
168    async fn execute(
169        &self,
170        tools: &HashMap<String, Arc<dyn Tool>>,
171        calls: &[ToolCall],
172        base_ctx: &ToolCallContext,
173    ) -> Result<Vec<ToolExecutionResult>, ToolExecutorError> {
174        use futures::future::join_all;
175
176        let futures: Vec<_> = calls
177            .iter()
178            .map(|call| {
179                let tools = tools.clone();
180                let call = call.clone();
181                let mut ctx = base_ctx.clone();
182                ctx.call_id = call.id.clone();
183                ctx.tool_name = call.name.clone();
184                async move {
185                    let output = execute_single_tool(&tools, &call, &ctx).await;
186                    let outcome = ToolCallOutcome::from_tool_result(&output.result);
187                    ToolExecutionResult {
188                        call,
189                        result: output.result,
190                        outcome,
191                        command: output.command,
192                    }
193                }
194            })
195            .collect();
196
197        Ok(join_all(futures).await)
198    }
199
200    fn name(&self) -> &'static str {
201        match self.mode {
202            ParallelMode::BatchApproval => "parallel_batch_approval",
203            ParallelMode::Streaming => "parallel_streaming",
204        }
205    }
206}
207
208/// Execute a single tool call, never panicking.
209pub(crate) async fn execute_single_tool(
210    tools: &HashMap<String, Arc<dyn Tool>>,
211    call: &ToolCall,
212    ctx: &ToolCallContext,
213) -> ToolOutput {
214    let Some(tool) = tools.get(&call.name) else {
215        return ToolResult::error(&call.name, format!("tool '{}' not found", call.name)).into();
216    };
217
218    if let Err(e) = tool.validate_args(&call.arguments) {
219        return ToolResult::error(&call.name, e.to_string()).into();
220    }
221
222    match execute_tool_with_runtime_context(tool, call, ctx).await {
223        Ok(output) => output,
224        Err(e) => ToolResult::error(&call.name, e.to_string()).into(),
225    }
226}
227
/// Invokes the tool inside a lineage scope (built with the `background` feature).
///
/// The scope is given the run id, call id and agent id taken from `ctx`;
/// the tool itself receives a clone of the call's arguments plus `ctx`.
#[cfg(feature = "background")]
async fn execute_tool_with_runtime_context(
    tool: &Arc<dyn Tool>,
    call: &ToolCall,
    ctx: &ToolCallContext,
) -> Result<ToolOutput, awaken_contract::contract::tool::ToolError> {
    let lineage = ToolLineageContext {
        run_id: ctx.run_identity.run_id.clone(),
        call_id: ctx.call_id.clone(),
        agent_id: ctx.run_identity.agent_id.clone(),
    };
    let invocation = tool.execute(call.arguments.clone(), ctx);

    scope_tool_lineage_context(lineage, invocation).await
}
244
245#[cfg(not(feature = "background"))]
246async fn execute_tool_with_runtime_context(
247    tool: &Arc<dyn Tool>,
248    call: &ToolCall,
249    ctx: &ToolCallContext,
250) -> Result<ToolOutput, awaken_contract::contract::tool::ToolError> {
251    tool.execute(call.arguments.clone(), ctx).await
252}
253
254#[cfg(test)]
255mod tests {
256    use super::*;
257    use awaken_contract::contract::tool::{ToolDescriptor, ToolError, ToolOutput};
258    use serde_json::{Value, json};
259
260    #[cfg(feature = "background")]
261    use crate::extensions::background::{
262        BackgroundTaskManager, BackgroundTaskPlugin, TaskParentContext, TaskResult,
263    };
264    #[cfg(feature = "background")]
265    use crate::phase::ExecutionEnv;
266    #[cfg(feature = "background")]
267    use crate::plugins::Plugin;
268    #[cfg(feature = "background")]
269    use crate::state::StateStore;
270    #[cfg(feature = "background")]
271    use awaken_contract::contract::identity::{RunIdentity, RunOrigin};
272
273    struct EchoTool;
274
275    #[async_trait]
276    impl Tool for EchoTool {
277        fn descriptor(&self) -> ToolDescriptor {
278            ToolDescriptor::new("echo", "echo", "Echoes input")
279        }
280
281        async fn execute(
282            &self,
283            args: Value,
284            _ctx: &ToolCallContext,
285        ) -> Result<ToolOutput, ToolError> {
286            let msg = args
287                .get("message")
288                .and_then(|v| v.as_str())
289                .unwrap_or("no message")
290                .to_string();
291            Ok(ToolResult::success_with_message("echo", args, msg).into())
292        }
293    }
294
295    struct FailingTool;
296
297    #[async_trait]
298    impl Tool for FailingTool {
299        fn descriptor(&self) -> ToolDescriptor {
300            ToolDescriptor::new("failing", "failing", "Always fails")
301        }
302
303        async fn execute(
304            &self,
305            _args: Value,
306            _ctx: &ToolCallContext,
307        ) -> Result<ToolOutput, ToolError> {
308            Err(ToolError::ExecutionFailed("intentional failure".into()))
309        }
310    }
311
312    struct SuspendingTool;
313
314    #[async_trait]
315    impl Tool for SuspendingTool {
316        fn descriptor(&self) -> ToolDescriptor {
317            ToolDescriptor::new("suspending", "suspending", "Returns pending")
318        }
319
320        async fn execute(
321            &self,
322            _args: Value,
323            _ctx: &ToolCallContext,
324        ) -> Result<ToolOutput, ToolError> {
325            Ok(ToolResult::suspended("suspending", "needs approval").into())
326        }
327    }
328
329    fn tool_map(tools: Vec<Arc<dyn Tool>>) -> HashMap<String, Arc<dyn Tool>> {
330        tools.into_iter().map(|t| (t.descriptor().id, t)).collect()
331    }
332
333    // -- Sequential tests --
334
335    #[tokio::test]
336    async fn sequential_single_tool_success() {
337        let tools = tool_map(vec![Arc::new(EchoTool)]);
338        let calls = vec![ToolCall::new("c1", "echo", json!({"message": "hi"}))];
339        let executor = SequentialToolExecutor;
340
341        let results = executor
342            .execute(&tools, &calls, &ToolCallContext::test_default())
343            .await
344            .unwrap();
345        assert_eq!(results.len(), 1);
346        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
347        assert!(results[0].result.is_success());
348    }
349
350    #[tokio::test]
351    async fn sequential_partial_failure() {
352        let tools = tool_map(vec![Arc::new(EchoTool), Arc::new(FailingTool)]);
353        let calls = vec![
354            ToolCall::new("c1", "echo", json!({"message": "ok"})),
355            ToolCall::new("c2", "failing", json!({})),
356        ];
357        let executor = SequentialToolExecutor;
358
359        let results = executor
360            .execute(&tools, &calls, &ToolCallContext::test_default())
361            .await
362            .unwrap();
363        assert_eq!(results.len(), 2);
364        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
365        assert_eq!(results[1].outcome, ToolCallOutcome::Failed);
366    }
367
368    #[tokio::test]
369    async fn sequential_stops_after_first_suspension() {
370        let tools = tool_map(vec![Arc::new(SuspendingTool), Arc::new(EchoTool)]);
371        let calls = vec![
372            ToolCall::new("c1", "suspending", json!({})),
373            ToolCall::new("c2", "echo", json!({"message": "should not run"})),
374        ];
375        let executor = SequentialToolExecutor;
376
377        let results = executor
378            .execute(&tools, &calls, &ToolCallContext::test_default())
379            .await
380            .unwrap();
381        assert_eq!(results.len(), 1, "should stop after suspended tool");
382        assert_eq!(results[0].outcome, ToolCallOutcome::Suspended);
383    }
384
385    #[tokio::test]
386    async fn sequential_unknown_tool_returns_error() {
387        let tools = tool_map(vec![]);
388        let calls = vec![ToolCall::new("c1", "nonexistent", json!({}))];
389        let executor = SequentialToolExecutor;
390
391        let results = executor
392            .execute(&tools, &calls, &ToolCallContext::test_default())
393            .await
394            .unwrap();
395        assert_eq!(results.len(), 1);
396        assert_eq!(results[0].outcome, ToolCallOutcome::Failed);
397        assert!(results[0].result.is_error());
398    }
399
400    #[tokio::test]
401    async fn sequential_empty_calls() {
402        let tools = tool_map(vec![Arc::new(EchoTool)]);
403        let executor = SequentialToolExecutor;
404
405        let results = executor
406            .execute(&tools, &[], &ToolCallContext::test_default())
407            .await
408            .unwrap();
409        assert!(results.is_empty());
410    }
411
412    // -- Parallel tests --
413
414    #[tokio::test]
415    async fn parallel_all_succeed() {
416        let tools = tool_map(vec![Arc::new(EchoTool)]);
417        let calls = vec![
418            ToolCall::new("c1", "echo", json!({"message": "first"})),
419            ToolCall::new("c2", "echo", json!({"message": "second"})),
420        ];
421        let executor = ParallelToolExecutor::streaming();
422
423        let results = executor
424            .execute(&tools, &calls, &ToolCallContext::test_default())
425            .await
426            .unwrap();
427        assert_eq!(results.len(), 2);
428        assert!(
429            results
430                .iter()
431                .all(|r| r.outcome == ToolCallOutcome::Succeeded)
432        );
433    }
434
435    #[tokio::test]
436    async fn parallel_partial_failure() {
437        let tools = tool_map(vec![Arc::new(EchoTool), Arc::new(FailingTool)]);
438        let calls = vec![
439            ToolCall::new("c1", "echo", json!({"message": "ok"})),
440            ToolCall::new("c2", "failing", json!({})),
441        ];
442        let executor = ParallelToolExecutor::streaming();
443
444        let results = executor
445            .execute(&tools, &calls, &ToolCallContext::test_default())
446            .await
447            .unwrap();
448        assert_eq!(results.len(), 2);
449        let successes = results
450            .iter()
451            .filter(|r| r.outcome == ToolCallOutcome::Succeeded)
452            .count();
453        let failures = results
454            .iter()
455            .filter(|r| r.outcome == ToolCallOutcome::Failed)
456            .count();
457        assert_eq!(successes, 1);
458        assert_eq!(failures, 1);
459    }
460
461    #[tokio::test]
462    async fn parallel_does_not_stop_on_suspension() {
463        let tools = tool_map(vec![Arc::new(SuspendingTool), Arc::new(EchoTool)]);
464        let calls = vec![
465            ToolCall::new("c1", "suspending", json!({})),
466            ToolCall::new("c2", "echo", json!({"message": "runs anyway"})),
467        ];
468        let executor = ParallelToolExecutor::streaming();
469
470        let results = executor
471            .execute(&tools, &calls, &ToolCallContext::test_default())
472            .await
473            .unwrap();
474        // Parallel executes ALL tools regardless of suspension
475        assert_eq!(results.len(), 2);
476        let suspended = results
477            .iter()
478            .filter(|r| r.outcome == ToolCallOutcome::Suspended)
479            .count();
480        let succeeded = results
481            .iter()
482            .filter(|r| r.outcome == ToolCallOutcome::Succeeded)
483            .count();
484        assert_eq!(suspended, 1);
485        assert_eq!(succeeded, 1);
486    }
487
488    #[tokio::test]
489    async fn parallel_empty_calls() {
490        let tools = tool_map(vec![Arc::new(EchoTool)]);
491        let executor = ParallelToolExecutor::streaming();
492
493        let results = executor
494            .execute(&tools, &[], &ToolCallContext::test_default())
495            .await
496            .unwrap();
497        assert!(results.is_empty());
498    }
499
500    #[test]
501    fn executor_names() {
502        assert_eq!(SequentialToolExecutor.name(), "sequential");
503        assert_eq!(
504            ParallelToolExecutor::streaming().name(),
505            "parallel_streaming"
506        );
507        assert_eq!(
508            ParallelToolExecutor::batch_approval().name(),
509            "parallel_batch_approval"
510        );
511    }
512
513    #[test]
514    fn parallel_default_is_streaming() {
515        let executor = ParallelToolExecutor::default();
516        assert_eq!(executor.name(), "parallel_streaming");
517        assert_eq!(
518            executor.decision_replay_policy(),
519            DecisionReplayPolicy::Immediate
520        );
521    }
522
523    #[test]
524    fn parallel_batch_approval_policy() {
525        let executor = ParallelToolExecutor::batch_approval();
526        assert_eq!(
527            executor.decision_replay_policy(),
528            DecisionReplayPolicy::BatchAllSuspended
529        );
530        assert!(executor.requires_conflict_check());
531    }
532
533    #[test]
534    fn parallel_streaming_policy() {
535        let executor = ParallelToolExecutor::streaming();
536        assert_eq!(
537            executor.decision_replay_policy(),
538            DecisionReplayPolicy::Immediate
539        );
540        assert!(executor.requires_conflict_check());
541    }
542
543    #[tokio::test]
544    async fn batch_approval_executes_all_concurrently() {
545        let tools = tool_map(vec![Arc::new(EchoTool)]);
546        let calls = vec![
547            ToolCall::new("c1", "echo", json!({"message": "a"})),
548            ToolCall::new("c2", "echo", json!({"message": "b"})),
549        ];
550        let executor = ParallelToolExecutor::batch_approval();
551
552        let results = executor
553            .execute(&tools, &calls, &ToolCallContext::test_default())
554            .await
555            .unwrap();
556        assert_eq!(results.len(), 2);
557        assert!(
558            results
559                .iter()
560                .all(|r| r.outcome == ToolCallOutcome::Succeeded)
561        );
562    }
563
564    #[tokio::test]
565    async fn batch_approval_does_not_stop_on_suspension() {
566        let tools = tool_map(vec![Arc::new(SuspendingTool), Arc::new(EchoTool)]);
567        let calls = vec![
568            ToolCall::new("c1", "suspending", json!({})),
569            ToolCall::new("c2", "echo", json!({"message": "runs anyway"})),
570        ];
571        let executor = ParallelToolExecutor::batch_approval();
572
573        let results = executor
574            .execute(&tools, &calls, &ToolCallContext::test_default())
575            .await
576            .unwrap();
577        assert_eq!(results.len(), 2);
578    }
579
580    // -----------------------------------------------------------------------
581    // Migrated from uncarve: additional tool executor tests
582    // -----------------------------------------------------------------------
583
584    /// A tool that counts how many times it's been called.
585    struct CountingTool {
586        count: Arc<std::sync::atomic::AtomicUsize>,
587    }
588
589    impl CountingTool {
590        fn new() -> Self {
591            Self {
592                count: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
593            }
594        }
595
596        fn call_count(&self) -> usize {
597            self.count.load(std::sync::atomic::Ordering::SeqCst)
598        }
599    }
600
601    #[async_trait]
602    impl Tool for CountingTool {
603        fn descriptor(&self) -> ToolDescriptor {
604            ToolDescriptor::new("counting", "counting", "Counts calls")
605        }
606
607        async fn execute(
608            &self,
609            _args: Value,
610            _ctx: &ToolCallContext,
611        ) -> Result<ToolOutput, ToolError> {
612            let n = self.count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
613            Ok(ToolResult::success("counting", json!({"call_number": n + 1})).into())
614        }
615    }
616
617    #[tokio::test]
618    async fn sequential_multiple_calls_ordered() {
619        let counting = Arc::new(CountingTool::new());
620        let tools = tool_map(vec![counting.clone() as Arc<dyn Tool>]);
621        let calls = vec![
622            ToolCall::new("c1", "counting", json!({})),
623            ToolCall::new("c2", "counting", json!({})),
624            ToolCall::new("c3", "counting", json!({})),
625        ];
626        let executor = SequentialToolExecutor;
627
628        let results = executor
629            .execute(&tools, &calls, &ToolCallContext::test_default())
630            .await
631            .unwrap();
632        assert_eq!(results.len(), 3);
633        assert_eq!(counting.call_count(), 3);
634        // Verify order is preserved
635        for (i, result) in results.iter().enumerate() {
636            assert_eq!(result.call.id, format!("c{}", i + 1));
637            assert_eq!(result.outcome, ToolCallOutcome::Succeeded);
638        }
639    }
640
641    #[tokio::test]
642    async fn sequential_failure_does_not_stop_execution() {
643        // Unlike suspension, failures do NOT stop sequential execution
644        let tools = tool_map(vec![Arc::new(FailingTool), Arc::new(EchoTool)]);
645        let calls = vec![
646            ToolCall::new("c1", "failing", json!({})),
647            ToolCall::new("c2", "echo", json!({"message": "still runs"})),
648        ];
649        let executor = SequentialToolExecutor;
650
651        let results = executor
652            .execute(&tools, &calls, &ToolCallContext::test_default())
653            .await
654            .unwrap();
655        assert_eq!(results.len(), 2);
656        assert_eq!(results[0].outcome, ToolCallOutcome::Failed);
657        assert_eq!(results[1].outcome, ToolCallOutcome::Succeeded);
658    }
659
660    #[tokio::test]
661    async fn sequential_suspension_in_middle_stops_remaining() {
662        let tools = tool_map(vec![
663            Arc::new(EchoTool),
664            Arc::new(SuspendingTool),
665            Arc::new(EchoTool),
666        ]);
667        let calls = vec![
668            ToolCall::new("c1", "echo", json!({"message": "first"})),
669            ToolCall::new("c2", "suspending", json!({})),
670            ToolCall::new("c3", "echo", json!({"message": "should not run"})),
671        ];
672        let executor = SequentialToolExecutor;
673
674        let results = executor
675            .execute(&tools, &calls, &ToolCallContext::test_default())
676            .await
677            .unwrap();
678        assert_eq!(results.len(), 2, "should stop after suspension");
679        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
680        assert_eq!(results[1].outcome, ToolCallOutcome::Suspended);
681    }
682
683    #[tokio::test]
684    async fn parallel_all_fail() {
685        let tools = tool_map(vec![Arc::new(FailingTool)]);
686        let calls = vec![
687            ToolCall::new("c1", "failing", json!({})),
688            ToolCall::new("c2", "failing", json!({})),
689        ];
690        let executor = ParallelToolExecutor::streaming();
691
692        let results = executor
693            .execute(&tools, &calls, &ToolCallContext::test_default())
694            .await
695            .unwrap();
696        assert_eq!(results.len(), 2);
697        assert!(results.iter().all(|r| r.outcome == ToolCallOutcome::Failed));
698    }
699
700    #[tokio::test]
701    async fn parallel_unknown_tool_returns_error() {
702        let tools = tool_map(vec![]);
703        let calls = vec![
704            ToolCall::new("c1", "nonexistent_a", json!({})),
705            ToolCall::new("c2", "nonexistent_b", json!({})),
706        ];
707        let executor = ParallelToolExecutor::streaming();
708
709        let results = executor
710            .execute(&tools, &calls, &ToolCallContext::test_default())
711            .await
712            .unwrap();
713        assert_eq!(results.len(), 2);
714        assert!(results.iter().all(|r| r.outcome == ToolCallOutcome::Failed));
715        for r in &results {
716            assert!(r.result.is_error());
717        }
718    }
719
720    #[tokio::test]
721    async fn parallel_counting_tool_all_called() {
722        let counting = Arc::new(CountingTool::new());
723        let tools = tool_map(vec![counting.clone() as Arc<dyn Tool>]);
724        let calls = vec![
725            ToolCall::new("c1", "counting", json!({})),
726            ToolCall::new("c2", "counting", json!({})),
727            ToolCall::new("c3", "counting", json!({})),
728        ];
729        let executor = ParallelToolExecutor::streaming();
730
731        let results = executor
732            .execute(&tools, &calls, &ToolCallContext::test_default())
733            .await
734            .unwrap();
735        assert_eq!(results.len(), 3);
736        assert_eq!(counting.call_count(), 3);
737    }
738
739    /// Validate that tool args are validated before execution.
740    struct StrictArgsTool;
741
742    #[async_trait]
743    impl Tool for StrictArgsTool {
744        fn descriptor(&self) -> ToolDescriptor {
745            ToolDescriptor::new("strict", "strict", "Validates args")
746        }
747
748        fn validate_args(&self, args: &Value) -> Result<(), ToolError> {
749            if args.get("required_field").is_none() {
750                return Err(ToolError::InvalidArguments("missing required_field".into()));
751            }
752            Ok(())
753        }
754
755        async fn execute(
756            &self,
757            args: Value,
758            _ctx: &ToolCallContext,
759        ) -> Result<ToolOutput, ToolError> {
760            Ok(ToolResult::success("strict", args).into())
761        }
762    }
763
764    #[tokio::test]
765    async fn sequential_validates_args_before_execute() {
766        let tools = tool_map(vec![Arc::new(StrictArgsTool)]);
767        let calls = vec![ToolCall::new("c1", "strict", json!({}))]; // missing required_field
768        let executor = SequentialToolExecutor;
769
770        let results = executor
771            .execute(&tools, &calls, &ToolCallContext::test_default())
772            .await
773            .unwrap();
774        assert_eq!(results.len(), 1);
775        assert_eq!(results[0].outcome, ToolCallOutcome::Failed);
776        assert!(results[0].result.is_error());
777    }
778
779    #[tokio::test]
780    async fn sequential_valid_args_succeeds() {
781        let tools = tool_map(vec![Arc::new(StrictArgsTool)]);
782        let calls = vec![ToolCall::new(
783            "c1",
784            "strict",
785            json!({"required_field": "present"}),
786        )];
787        let executor = SequentialToolExecutor;
788
789        let results = executor
790            .execute(&tools, &calls, &ToolCallContext::test_default())
791            .await
792            .unwrap();
793        assert_eq!(results.len(), 1);
794        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
795    }
796
797    #[tokio::test]
798    async fn parallel_validates_args_before_execute() {
799        let tools = tool_map(vec![Arc::new(StrictArgsTool)]);
800        let calls = vec![ToolCall::new("c1", "strict", json!({}))];
801        let executor = ParallelToolExecutor::streaming();
802
803        let results = executor
804            .execute(&tools, &calls, &ToolCallContext::test_default())
805            .await
806            .unwrap();
807        assert_eq!(results.len(), 1);
808        assert_eq!(results[0].outcome, ToolCallOutcome::Failed);
809    }
810
811    #[test]
812    fn tool_execution_result_fields() {
813        let result = ToolExecutionResult {
814            call: ToolCall::new("c1", "echo", json!({})),
815            result: ToolResult::success("echo", json!({"ok": true})),
816            outcome: ToolCallOutcome::Succeeded,
817            command: StateCommand::default(),
818        };
819        assert_eq!(result.call.id, "c1");
820        assert_eq!(result.outcome, ToolCallOutcome::Succeeded);
821    }
822
823    #[test]
824    fn tool_executor_error_display() {
825        let err = ToolExecutorError::Cancelled;
826        assert!(err.to_string().contains("cancelled"));
827        let err2 = ToolExecutorError::Failed("some reason".into());
828        assert!(err2.to_string().contains("some reason"));
829    }
830
831    // -----------------------------------------------------------------------
832    // Additional coverage: context preservation, mixed scenarios, edge cases
833    // -----------------------------------------------------------------------
834
835    /// Tool that captures the context it receives for later inspection.
836    struct ContextCaptureTool {
837        captured_call_id: Arc<std::sync::Mutex<String>>,
838        captured_tool_name: Arc<std::sync::Mutex<String>>,
839    }
840
841    impl ContextCaptureTool {
842        fn new() -> Self {
843            Self {
844                captured_call_id: Arc::new(std::sync::Mutex::new(String::new())),
845                captured_tool_name: Arc::new(std::sync::Mutex::new(String::new())),
846            }
847        }
848    }
849
850    #[async_trait]
851    impl Tool for ContextCaptureTool {
852        fn descriptor(&self) -> ToolDescriptor {
853            ToolDescriptor::new("capture", "capture", "Captures context")
854        }
855
856        async fn execute(
857            &self,
858            _args: Value,
859            ctx: &ToolCallContext,
860        ) -> Result<ToolOutput, ToolError> {
861            *self.captured_call_id.lock().unwrap() = ctx.call_id.clone();
862            *self.captured_tool_name.lock().unwrap() = ctx.tool_name.clone();
863            Ok(ToolResult::success("capture", json!({"captured": true})).into())
864        }
865    }
866
867    #[tokio::test]
868    async fn execute_single_tool_preserves_call_context() {
869        let capture = Arc::new(ContextCaptureTool::new());
870        let tools = tool_map(vec![capture.clone() as Arc<dyn Tool>]);
871        let call = ToolCall::new("call-42", "capture", json!({}));
872        let ctx = ToolCallContext::test_default();
873
874        let output = execute_single_tool(&tools, &call, &ctx).await;
875        assert!(output.result.is_success());
876        // execute_single_tool sets call_id and tool_name from the call's context
877        // (the caller is responsible for setting ctx fields, which the executor does)
878    }
879
880    #[tokio::test]
881    async fn execute_single_tool_missing_returns_error_with_name() {
882        let tools: HashMap<String, Arc<dyn Tool>> = HashMap::new();
883        let call = ToolCall::new("c1", "ghost_tool", json!({}));
884        let ctx = ToolCallContext::test_default();
885
886        let output = execute_single_tool(&tools, &call, &ctx).await;
887        assert!(output.result.is_error());
888        assert!(
889            output
890                .result
891                .message
892                .as_deref()
893                .unwrap_or("")
894                .contains("ghost_tool")
895        );
896    }
897
898    #[tokio::test]
899    async fn execute_single_tool_validates_args() {
900        let tools = tool_map(vec![Arc::new(StrictArgsTool)]);
901        let call = ToolCall::new("c1", "strict", json!({"wrong": "field"}));
902        let ctx = ToolCallContext::test_default();
903
904        let output = execute_single_tool(&tools, &call, &ctx).await;
905        assert!(output.result.is_error());
906    }
907
908    #[tokio::test]
909    async fn sequential_context_call_id_set_per_tool() {
910        let capture = Arc::new(ContextCaptureTool::new());
911        let tools = tool_map(vec![capture.clone() as Arc<dyn Tool>]);
912        let calls = vec![ToolCall::new("unique-id-99", "capture", json!({}))];
913        let executor = SequentialToolExecutor;
914
915        let results = executor
916            .execute(&tools, &calls, &ToolCallContext::test_default())
917            .await
918            .unwrap();
919        assert_eq!(results.len(), 1);
920        assert_eq!(results[0].call.id, "unique-id-99");
921        // The executor sets ctx.call_id = call.id before passing to execute_single_tool
922        assert_eq!(*capture.captured_call_id.lock().unwrap(), "unique-id-99");
923    }
924
925    #[tokio::test]
926    async fn sequential_mixed_success_failure_suspension_order() {
927        let tools = tool_map(vec![
928            Arc::new(EchoTool),
929            Arc::new(FailingTool),
930            Arc::new(SuspendingTool),
931        ]);
932        // success, fail, suspend — suspension stops execution
933        let calls = vec![
934            ToolCall::new("c1", "echo", json!({"message": "hi"})),
935            ToolCall::new("c2", "failing", json!({})),
936            ToolCall::new("c3", "suspending", json!({})),
937            ToolCall::new("c4", "echo", json!({"message": "should not run"})),
938        ];
939        let executor = SequentialToolExecutor;
940
941        let results = executor
942            .execute(&tools, &calls, &ToolCallContext::test_default())
943            .await
944            .unwrap();
945        assert_eq!(results.len(), 3, "stops after suspension at c3");
946        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
947        assert_eq!(results[1].outcome, ToolCallOutcome::Failed);
948        assert_eq!(results[2].outcome, ToolCallOutcome::Suspended);
949    }
950
951    #[tokio::test]
952    async fn parallel_preserves_result_order() {
953        let counting = Arc::new(CountingTool::new());
954        let tools = tool_map(vec![counting.clone() as Arc<dyn Tool>]);
955        let calls: Vec<_> = (0..5)
956            .map(|i| ToolCall::new(format!("c{i}"), "counting", json!({})))
957            .collect();
958        let executor = ParallelToolExecutor::streaming();
959
960        let results = executor
961            .execute(&tools, &calls, &ToolCallContext::test_default())
962            .await
963            .unwrap();
964        assert_eq!(results.len(), 5);
965        // join_all preserves order of futures
966        for (i, r) in results.iter().enumerate() {
967            assert_eq!(r.call.id, format!("c{i}"));
968        }
969    }
970
971    #[tokio::test]
972    async fn parallel_mixed_success_failure_suspension() {
973        let tools = tool_map(vec![
974            Arc::new(EchoTool),
975            Arc::new(FailingTool),
976            Arc::new(SuspendingTool),
977        ]);
978        let calls = vec![
979            ToolCall::new("c1", "echo", json!({"message": "hi"})),
980            ToolCall::new("c2", "failing", json!({})),
981            ToolCall::new("c3", "suspending", json!({})),
982        ];
983        let executor = ParallelToolExecutor::batch_approval();
984
985        let results = executor
986            .execute(&tools, &calls, &ToolCallContext::test_default())
987            .await
988            .unwrap();
989        assert_eq!(results.len(), 3, "parallel runs all regardless");
990        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
991        assert_eq!(results[1].outcome, ToolCallOutcome::Failed);
992        assert_eq!(results[2].outcome, ToolCallOutcome::Suspended);
993    }
994
995    #[test]
996    fn sequential_requires_incremental_state() {
997        let executor = SequentialToolExecutor;
998        assert!(executor.requires_incremental_state());
999    }
1000
1001    #[test]
1002    fn parallel_does_not_require_incremental_state() {
1003        let executor = ParallelToolExecutor::streaming();
1004        assert!(!executor.requires_incremental_state());
1005        let batch = ParallelToolExecutor::batch_approval();
1006        assert!(!batch.requires_incremental_state());
1007    }
1008
1009    #[tokio::test]
1010    async fn execute_single_tool_success_returns_correct_tool_name() {
1011        let tools = tool_map(vec![Arc::new(EchoTool)]);
1012        let call = ToolCall::new("c1", "echo", json!({"message": "test"}));
1013        let ctx = ToolCallContext::test_default();
1014
1015        let output = execute_single_tool(&tools, &call, &ctx).await;
1016        assert!(output.result.is_success());
1017        assert_eq!(output.result.tool_name, "echo");
1018    }
1019
    /// Test tool that spawns a background task through a shared manager and
    /// records the id of the task it spawned.
    #[cfg(feature = "background")]
    struct SpawnBackgroundTaskTool {
        /// Manager used by `execute` to spawn the background task.
        manager: Arc<BackgroundTaskManager>,
        /// Id of the spawned task; `None` until a spawn has succeeded.
        spawned_task_id: Arc<std::sync::Mutex<Option<String>>>,
    }
1025
1026    #[cfg(feature = "background")]
1027    #[async_trait]
1028    impl Tool for SpawnBackgroundTaskTool {
1029        fn descriptor(&self) -> ToolDescriptor {
1030            ToolDescriptor::new(
1031                "spawn_background",
1032                "spawn_background",
1033                "Spawns a background task",
1034            )
1035        }
1036
1037        async fn execute(
1038            &self,
1039            _args: Value,
1040            _ctx: &ToolCallContext,
1041        ) -> Result<ToolOutput, ToolError> {
1042            let task_id = self
1043                .manager
1044                .spawn(
1045                    "thread-1",
1046                    "background",
1047                    None,
1048                    "spawned from tool",
1049                    TaskParentContext::default(),
1050                    |_| async { TaskResult::Success(json!({"ok": true})) },
1051                )
1052                .await
1053                .map_err(|error| ToolError::ExecutionFailed(error.to_string()))?;
1054            *self.spawned_task_id.lock().unwrap() = Some(task_id.clone());
1055            Ok(ToolResult::success("spawn_background", json!({"task_id": task_id})).into())
1056        }
1057    }
1058
1059    #[cfg(feature = "background")]
1060    #[tokio::test]
1061    async fn execute_single_tool_scopes_tool_lineage_for_background_spawns() {
1062        let store = StateStore::new();
1063        let manager = Arc::new(BackgroundTaskManager::new());
1064        manager.set_store(store.clone());
1065        let plugin: Arc<dyn Plugin> = Arc::new(BackgroundTaskPlugin::new(manager.clone()));
1066        let env = ExecutionEnv::from_plugins(&[plugin], &Default::default()).unwrap();
1067        store.register_keys(&env.key_registrations).unwrap();
1068
1069        let spawned_task_id = Arc::new(std::sync::Mutex::new(None::<String>));
1070        let tools = tool_map(vec![Arc::new(SpawnBackgroundTaskTool {
1071            manager: manager.clone(),
1072            spawned_task_id: spawned_task_id.clone(),
1073        }) as Arc<dyn Tool>]);
1074        let executor = SequentialToolExecutor;
1075        let calls = vec![ToolCall::new("call-77", "spawn_background", json!({}))];
1076
1077        let mut ctx = ToolCallContext::test_default();
1078        ctx.run_identity = RunIdentity::new(
1079            "thread-1".into(),
1080            None,
1081            "run-77".into(),
1082            None,
1083            "agent-77".into(),
1084            RunOrigin::User,
1085        );
1086        ctx.cancellation_token = Some(crate::cancellation::CancellationToken::new());
1087
1088        let results = executor.execute(&tools, &calls, &ctx).await.unwrap();
1089        assert_eq!(results.len(), 1);
1090        assert!(results[0].result.is_success());
1091
1092        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1093
1094        let task_id = spawned_task_id
1095            .lock()
1096            .unwrap()
1097            .clone()
1098            .expect("spawned task id should be recorded");
1099        let summary = manager
1100            .get(&task_id)
1101            .await
1102            .expect("spawned background task should be queryable");
1103        assert_eq!(summary.parent_context.run_id.as_deref(), Some("run-77"));
1104        assert_eq!(summary.parent_context.call_id.as_deref(), Some("call-77"));
1105        assert_eq!(summary.parent_context.agent_id.as_deref(), Some("agent-77"));
1106    }
1107}