1use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use awaken_contract::contract::message::ToolCall;
8use awaken_contract::contract::suspension::ToolCallOutcome;
9use awaken_contract::contract::tool::{Tool, ToolCallContext, ToolOutput, ToolResult};
10use awaken_contract::state::StateCommand;
11
/// Result of executing a single [`ToolCall`], as returned by [`ToolExecutor::execute`].
pub struct ToolExecutionResult {
    /// The call that was executed (an owned copy of the request).
    pub call: ToolCall,
    /// The tool's result. The executors in this module synthesize an error
    /// result when the tool is unknown, rejects its arguments, or fails.
    pub result: ToolResult,
    /// Classification of `result`, derived via `ToolCallOutcome::from_tool_result`.
    pub outcome: ToolCallOutcome,
    /// The `command` carried by the tool's `ToolOutput`.
    pub command: StateCommand,
}
20
/// Batch-level error a [`ToolExecutor`] may return.
///
/// Failures of individual tools are reported per call through
/// [`ToolExecutionResult`], not through this type; the executors defined in
/// this module always return `Ok`.
#[derive(Debug, thiserror::Error)]
pub enum ToolExecutorError {
    /// Execution of the batch was cancelled.
    #[error("tool execution cancelled")]
    Cancelled,
    /// Execution of the batch failed; the payload describes why.
    #[error("tool execution failed: {0}")]
    Failed(String),
}
29
/// Strategy for running a batch of tool calls against a registry of tools.
#[async_trait]
pub trait ToolExecutor: Send + Sync {
    /// Executes `calls`, resolving each tool by `ToolCall::name` in `tools`.
    ///
    /// `base_ctx` is the template context; the implementations in this module
    /// clone it per call and overwrite `call_id` and `tool_name`.
    ///
    /// Returns one [`ToolExecutionResult`] per call that was actually run.
    /// An implementation may run fewer than `calls.len()` calls; the
    /// sequential executor stops after the first suspended call.
    async fn execute(
        &self,
        tools: &HashMap<String, Arc<dyn Tool>>,
        calls: &[ToolCall],
        base_ctx: &ToolCallContext,
    ) -> Result<Vec<ToolExecutionResult>, ToolExecutorError>;

    /// Stable identifier of this executor strategy (e.g. `"sequential"`).
    fn name(&self) -> &'static str;

    /// Whether this executor needs state applied incrementally between calls.
    ///
    /// NOTE(review): presumably this lets later calls observe earlier calls'
    /// state commands; confirm with the consumer of this flag. Defaults to
    /// `false`.
    fn requires_incremental_state(&self) -> bool {
        false
    }
}
50
/// Executes tool calls one at a time, in slice order, stopping right after the
/// first call whose outcome is [`ToolCallOutcome::Suspended`].
#[derive(Debug, Clone, Copy, Default)]
pub struct SequentialToolExecutor;
56
57#[async_trait]
58impl ToolExecutor for SequentialToolExecutor {
59 async fn execute(
60 &self,
61 tools: &HashMap<String, Arc<dyn Tool>>,
62 calls: &[ToolCall],
63 base_ctx: &ToolCallContext,
64 ) -> Result<Vec<ToolExecutionResult>, ToolExecutorError> {
65 let mut results = Vec::with_capacity(calls.len());
66
67 for call in calls {
68 let mut ctx = base_ctx.clone();
69 ctx.call_id = call.id.clone();
70 ctx.tool_name = call.name.clone();
71 let output = execute_single_tool(tools, call, &ctx).await;
72 let outcome = ToolCallOutcome::from_tool_result(&output.result);
73
74 results.push(ToolExecutionResult {
75 call: call.clone(),
76 result: output.result,
77 outcome,
78 command: output.command,
79 });
80
81 if results
83 .last()
84 .is_some_and(|r| r.outcome == ToolCallOutcome::Suspended)
85 {
86 break;
87 }
88 }
89
90 Ok(results)
91 }
92
93 fn name(&self) -> &'static str {
94 "sequential"
95 }
96
97 fn requires_incremental_state(&self) -> bool {
98 true
99 }
100}
101
/// How decisions for suspended tool calls are replayed, as reported by
/// `ParallelToolExecutor::decision_replay_policy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecisionReplayPolicy {
    /// Used by the streaming parallel mode.
    /// NOTE(review): presumably each decision is replayed as soon as it
    /// arrives; confirm with the consumer of this policy.
    Immediate,
    /// Used by the batch-approval parallel mode.
    /// NOTE(review): presumably replay waits until every suspended call in the
    /// batch has a decision; confirm with the consumer of this policy.
    BatchAllSuspended,
}
110
/// Operating mode of `ParallelToolExecutor`; selects its decision replay
/// policy and its reported executor name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParallelMode {
    /// Maps to `DecisionReplayPolicy::BatchAllSuspended`, name `"parallel_batch_approval"`.
    BatchApproval,
    /// Maps to `DecisionReplayPolicy::Immediate`, name `"parallel_streaming"`.
    /// This is the mode used by `ParallelToolExecutor::default()`.
    Streaming,
}
119
/// Executes all tool calls of a batch concurrently; unlike the sequential
/// executor, a suspended call does not prevent the remaining calls from running.
///
/// Construct with `ParallelToolExecutor::streaming()` (also the `Default`) or
/// `ParallelToolExecutor::batch_approval()`.
#[derive(Debug, Clone, Copy)]
pub struct ParallelToolExecutor {
    // Selects the decision replay policy and the executor name.
    mode: ParallelMode,
}
129
130impl ParallelToolExecutor {
131 pub const fn batch_approval() -> Self {
132 Self {
133 mode: ParallelMode::BatchApproval,
134 }
135 }
136
137 pub const fn streaming() -> Self {
138 Self {
139 mode: ParallelMode::Streaming,
140 }
141 }
142
143 pub fn decision_replay_policy(&self) -> DecisionReplayPolicy {
145 match self.mode {
146 ParallelMode::BatchApproval => DecisionReplayPolicy::BatchAllSuspended,
147 ParallelMode::Streaming => DecisionReplayPolicy::Immediate,
148 }
149 }
150
151 pub fn requires_conflict_check(&self) -> bool {
153 true
154 }
155}
156
157impl Default for ParallelToolExecutor {
158 fn default() -> Self {
159 Self::streaming()
160 }
161}
162
163#[async_trait]
164impl ToolExecutor for ParallelToolExecutor {
165 async fn execute(
166 &self,
167 tools: &HashMap<String, Arc<dyn Tool>>,
168 calls: &[ToolCall],
169 base_ctx: &ToolCallContext,
170 ) -> Result<Vec<ToolExecutionResult>, ToolExecutorError> {
171 use futures::future::join_all;
172
173 let futures: Vec<_> = calls
174 .iter()
175 .map(|call| {
176 let tools = tools.clone();
177 let call = call.clone();
178 let mut ctx = base_ctx.clone();
179 ctx.call_id = call.id.clone();
180 ctx.tool_name = call.name.clone();
181 async move {
182 let output = execute_single_tool(&tools, &call, &ctx).await;
183 let outcome = ToolCallOutcome::from_tool_result(&output.result);
184 ToolExecutionResult {
185 call,
186 result: output.result,
187 outcome,
188 command: output.command,
189 }
190 }
191 })
192 .collect();
193
194 Ok(join_all(futures).await)
195 }
196
197 fn name(&self) -> &'static str {
198 match self.mode {
199 ParallelMode::BatchApproval => "parallel_batch_approval",
200 ParallelMode::Streaming => "parallel_streaming",
201 }
202 }
203}
204
205pub(crate) async fn execute_single_tool(
207 tools: &HashMap<String, Arc<dyn Tool>>,
208 call: &ToolCall,
209 ctx: &ToolCallContext,
210) -> ToolOutput {
211 let Some(tool) = tools.get(&call.name) else {
212 return ToolResult::error(&call.name, format!("tool '{}' not found", call.name)).into();
213 };
214
215 if let Err(e) = tool.validate_args(&call.arguments) {
216 return ToolResult::error(&call.name, e.to_string()).into();
217 }
218
219 match tool.execute(call.arguments.clone(), ctx).await {
220 Ok(output) => output,
221 Err(e) => ToolResult::error(&call.name, e.to_string()).into(),
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use awaken_contract::contract::tool::{ToolDescriptor, ToolError, ToolOutput};
    use serde_json::{Value, json};

    /// Succeeds, echoing `args["message"]` (or "no message") as the result message.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("echo", "echo", "Echoes input")
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext,
        ) -> Result<ToolOutput, ToolError> {
            let msg = args
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("no message")
                .to_string();
            Ok(ToolResult::success_with_message("echo", args, msg).into())
        }
    }

    /// Always fails with `ToolError::ExecutionFailed`.
    struct FailingTool;

    #[async_trait]
    impl Tool for FailingTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("failing", "failing", "Always fails")
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext,
        ) -> Result<ToolOutput, ToolError> {
            Err(ToolError::ExecutionFailed("intentional failure".into()))
        }
    }

    /// Always returns a suspended result.
    struct SuspendingTool;

    #[async_trait]
    impl Tool for SuspendingTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("suspending", "suspending", "Returns pending")
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext,
        ) -> Result<ToolOutput, ToolError> {
            Ok(ToolResult::suspended("suspending", "needs approval").into())
        }
    }

    /// Builds a registry keyed by each tool's descriptor id.
    fn tool_map(tools: Vec<Arc<dyn Tool>>) -> HashMap<String, Arc<dyn Tool>> {
        tools.into_iter().map(|t| (t.descriptor().id, t)).collect()
    }

    // ---- SequentialToolExecutor ----

    #[tokio::test]
    async fn sequential_single_tool_success() {
        let tools = tool_map(vec![Arc::new(EchoTool)]);
        let calls = vec![ToolCall::new("c1", "echo", json!({"message": "hi"}))];
        let executor = SequentialToolExecutor;

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
        assert!(results[0].result.is_success());
    }

    #[tokio::test]
    async fn sequential_partial_failure() {
        let tools = tool_map(vec![Arc::new(EchoTool), Arc::new(FailingTool)]);
        let calls = vec![
            ToolCall::new("c1", "echo", json!({"message": "ok"})),
            ToolCall::new("c2", "failing", json!({})),
        ];
        let executor = SequentialToolExecutor;

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
        assert_eq!(results[1].outcome, ToolCallOutcome::Failed);
    }

    #[tokio::test]
    async fn sequential_stops_after_first_suspension() {
        let tools = tool_map(vec![Arc::new(SuspendingTool), Arc::new(EchoTool)]);
        let calls = vec![
            ToolCall::new("c1", "suspending", json!({})),
            ToolCall::new("c2", "echo", json!({"message": "should not run"})),
        ];
        let executor = SequentialToolExecutor;

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 1, "should stop after suspended tool");
        assert_eq!(results[0].outcome, ToolCallOutcome::Suspended);
    }

    #[tokio::test]
    async fn sequential_unknown_tool_returns_error() {
        let tools = tool_map(vec![]);
        let calls = vec![ToolCall::new("c1", "nonexistent", json!({}))];
        let executor = SequentialToolExecutor;

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].outcome, ToolCallOutcome::Failed);
        assert!(results[0].result.is_error());
    }

    #[tokio::test]
    async fn sequential_empty_calls() {
        let tools = tool_map(vec![Arc::new(EchoTool)]);
        let executor = SequentialToolExecutor;

        let results = executor
            .execute(&tools, &[], &ToolCallContext::test_default())
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    // ---- ParallelToolExecutor ----

    #[tokio::test]
    async fn parallel_all_succeed() {
        let tools = tool_map(vec![Arc::new(EchoTool)]);
        let calls = vec![
            ToolCall::new("c1", "echo", json!({"message": "first"})),
            ToolCall::new("c2", "echo", json!({"message": "second"})),
        ];
        let executor = ParallelToolExecutor::streaming();

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert!(
            results
                .iter()
                .all(|r| r.outcome == ToolCallOutcome::Succeeded)
        );
    }

    #[tokio::test]
    async fn parallel_partial_failure() {
        let tools = tool_map(vec![Arc::new(EchoTool), Arc::new(FailingTool)]);
        let calls = vec![
            ToolCall::new("c1", "echo", json!({"message": "ok"})),
            ToolCall::new("c2", "failing", json!({})),
        ];
        let executor = ParallelToolExecutor::streaming();

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        let successes = results
            .iter()
            .filter(|r| r.outcome == ToolCallOutcome::Succeeded)
            .count();
        let failures = results
            .iter()
            .filter(|r| r.outcome == ToolCallOutcome::Failed)
            .count();
        assert_eq!(successes, 1);
        assert_eq!(failures, 1);
    }

    #[tokio::test]
    async fn parallel_does_not_stop_on_suspension() {
        let tools = tool_map(vec![Arc::new(SuspendingTool), Arc::new(EchoTool)]);
        let calls = vec![
            ToolCall::new("c1", "suspending", json!({})),
            ToolCall::new("c2", "echo", json!({"message": "runs anyway"})),
        ];
        let executor = ParallelToolExecutor::streaming();

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        let suspended = results
            .iter()
            .filter(|r| r.outcome == ToolCallOutcome::Suspended)
            .count();
        let succeeded = results
            .iter()
            .filter(|r| r.outcome == ToolCallOutcome::Succeeded)
            .count();
        assert_eq!(suspended, 1);
        assert_eq!(succeeded, 1);
    }

    #[tokio::test]
    async fn parallel_empty_calls() {
        let tools = tool_map(vec![Arc::new(EchoTool)]);
        let executor = ParallelToolExecutor::streaming();

        let results = executor
            .execute(&tools, &[], &ToolCallContext::test_default())
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn executor_names() {
        assert_eq!(SequentialToolExecutor.name(), "sequential");
        assert_eq!(
            ParallelToolExecutor::streaming().name(),
            "parallel_streaming"
        );
        assert_eq!(
            ParallelToolExecutor::batch_approval().name(),
            "parallel_batch_approval"
        );
    }

    #[test]
    fn parallel_default_is_streaming() {
        let executor = ParallelToolExecutor::default();
        assert_eq!(executor.name(), "parallel_streaming");
        assert_eq!(
            executor.decision_replay_policy(),
            DecisionReplayPolicy::Immediate
        );
    }

    #[test]
    fn parallel_batch_approval_policy() {
        let executor = ParallelToolExecutor::batch_approval();
        assert_eq!(
            executor.decision_replay_policy(),
            DecisionReplayPolicy::BatchAllSuspended
        );
        assert!(executor.requires_conflict_check());
    }

    #[test]
    fn parallel_streaming_policy() {
        let executor = ParallelToolExecutor::streaming();
        assert_eq!(
            executor.decision_replay_policy(),
            DecisionReplayPolicy::Immediate
        );
        assert!(executor.requires_conflict_check());
    }

    #[tokio::test]
    async fn batch_approval_executes_all_concurrently() {
        let tools = tool_map(vec![Arc::new(EchoTool)]);
        let calls = vec![
            ToolCall::new("c1", "echo", json!({"message": "a"})),
            ToolCall::new("c2", "echo", json!({"message": "b"})),
        ];
        let executor = ParallelToolExecutor::batch_approval();

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert!(
            results
                .iter()
                .all(|r| r.outcome == ToolCallOutcome::Succeeded)
        );
    }

    #[tokio::test]
    async fn batch_approval_does_not_stop_on_suspension() {
        let tools = tool_map(vec![Arc::new(SuspendingTool), Arc::new(EchoTool)]);
        let calls = vec![
            ToolCall::new("c1", "suspending", json!({})),
            ToolCall::new("c2", "echo", json!({"message": "runs anyway"})),
        ];
        let executor = ParallelToolExecutor::batch_approval();

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
    }

    /// Counts how many times it is executed; the count is shared via `Arc`.
    struct CountingTool {
        count: Arc<std::sync::atomic::AtomicUsize>,
    }

    impl CountingTool {
        fn new() -> Self {
            Self {
                count: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
            }
        }

        fn call_count(&self) -> usize {
            self.count.load(std::sync::atomic::Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl Tool for CountingTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("counting", "counting", "Counts calls")
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext,
        ) -> Result<ToolOutput, ToolError> {
            let n = self.count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            Ok(ToolResult::success("counting", json!({"call_number": n + 1})).into())
        }
    }

    #[tokio::test]
    async fn sequential_multiple_calls_ordered() {
        let counting = Arc::new(CountingTool::new());
        let tools = tool_map(vec![counting.clone() as Arc<dyn Tool>]);
        let calls = vec![
            ToolCall::new("c1", "counting", json!({})),
            ToolCall::new("c2", "counting", json!({})),
            ToolCall::new("c3", "counting", json!({})),
        ];
        let executor = SequentialToolExecutor;

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(counting.call_count(), 3);
        // Results come back in call order.
        for (i, result) in results.iter().enumerate() {
            assert_eq!(result.call.id, format!("c{}", i + 1));
            assert_eq!(result.outcome, ToolCallOutcome::Succeeded);
        }
    }

    #[tokio::test]
    async fn sequential_failure_does_not_stop_execution() {
        let tools = tool_map(vec![Arc::new(FailingTool), Arc::new(EchoTool)]);
        let calls = vec![
            ToolCall::new("c1", "failing", json!({})),
            ToolCall::new("c2", "echo", json!({"message": "still runs"})),
        ];
        let executor = SequentialToolExecutor;

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].outcome, ToolCallOutcome::Failed);
        assert_eq!(results[1].outcome, ToolCallOutcome::Succeeded);
    }

    #[tokio::test]
    async fn sequential_suspension_in_middle_stops_remaining() {
        let tools = tool_map(vec![
            Arc::new(EchoTool),
            Arc::new(SuspendingTool),
            Arc::new(EchoTool),
        ]);
        let calls = vec![
            ToolCall::new("c1", "echo", json!({"message": "first"})),
            ToolCall::new("c2", "suspending", json!({})),
            ToolCall::new("c3", "echo", json!({"message": "should not run"})),
        ];
        let executor = SequentialToolExecutor;

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 2, "should stop after suspension");
        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
        assert_eq!(results[1].outcome, ToolCallOutcome::Suspended);
    }

    #[tokio::test]
    async fn parallel_all_fail() {
        let tools = tool_map(vec![Arc::new(FailingTool)]);
        let calls = vec![
            ToolCall::new("c1", "failing", json!({})),
            ToolCall::new("c2", "failing", json!({})),
        ];
        let executor = ParallelToolExecutor::streaming();

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.outcome == ToolCallOutcome::Failed));
    }

    #[tokio::test]
    async fn parallel_unknown_tool_returns_error() {
        let tools = tool_map(vec![]);
        let calls = vec![
            ToolCall::new("c1", "nonexistent_a", json!({})),
            ToolCall::new("c2", "nonexistent_b", json!({})),
        ];
        let executor = ParallelToolExecutor::streaming();

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.outcome == ToolCallOutcome::Failed));
        for r in &results {
            assert!(r.result.is_error());
        }
    }

    #[tokio::test]
    async fn parallel_counting_tool_all_called() {
        let counting = Arc::new(CountingTool::new());
        let tools = tool_map(vec![counting.clone() as Arc<dyn Tool>]);
        let calls = vec![
            ToolCall::new("c1", "counting", json!({})),
            ToolCall::new("c2", "counting", json!({})),
            ToolCall::new("c3", "counting", json!({})),
        ];
        let executor = ParallelToolExecutor::streaming();

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(counting.call_count(), 3);
    }

    /// Rejects arguments lacking `required_field`; otherwise echoes them back.
    struct StrictArgsTool;

    #[async_trait]
    impl Tool for StrictArgsTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("strict", "strict", "Validates args")
        }

        fn validate_args(&self, args: &Value) -> Result<(), ToolError> {
            if args.get("required_field").is_none() {
                return Err(ToolError::InvalidArguments("missing required_field".into()));
            }
            Ok(())
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext,
        ) -> Result<ToolOutput, ToolError> {
            Ok(ToolResult::success("strict", args).into())
        }
    }

    #[tokio::test]
    async fn sequential_validates_args_before_execute() {
        let tools = tool_map(vec![Arc::new(StrictArgsTool)]);
        let calls = vec![ToolCall::new("c1", "strict", json!({}))];
        let executor = SequentialToolExecutor;

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].outcome, ToolCallOutcome::Failed);
        assert!(results[0].result.is_error());
    }

    #[tokio::test]
    async fn sequential_valid_args_succeeds() {
        let tools = tool_map(vec![Arc::new(StrictArgsTool)]);
        let calls = vec![ToolCall::new(
            "c1",
            "strict",
            json!({"required_field": "present"}),
        )];
        let executor = SequentialToolExecutor;

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
    }

    #[tokio::test]
    async fn parallel_validates_args_before_execute() {
        let tools = tool_map(vec![Arc::new(StrictArgsTool)]);
        let calls = vec![ToolCall::new("c1", "strict", json!({}))];
        let executor = ParallelToolExecutor::streaming();

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].outcome, ToolCallOutcome::Failed);
    }

    #[test]
    fn tool_execution_result_fields() {
        let result = ToolExecutionResult {
            call: ToolCall::new("c1", "echo", json!({})),
            result: ToolResult::success("echo", json!({"ok": true})),
            outcome: ToolCallOutcome::Succeeded,
            command: StateCommand::default(),
        };
        assert_eq!(result.call.id, "c1");
        assert_eq!(result.outcome, ToolCallOutcome::Succeeded);
    }

    #[test]
    fn tool_executor_error_display() {
        let err = ToolExecutorError::Cancelled;
        assert!(err.to_string().contains("cancelled"));
        let err2 = ToolExecutorError::Failed("some reason".into());
        assert!(err2.to_string().contains("some reason"));
    }

    /// Records the `call_id` and `tool_name` of the context it is executed with.
    struct ContextCaptureTool {
        captured_call_id: Arc<std::sync::Mutex<String>>,
        captured_tool_name: Arc<std::sync::Mutex<String>>,
    }

    impl ContextCaptureTool {
        fn new() -> Self {
            Self {
                captured_call_id: Arc::new(std::sync::Mutex::new(String::new())),
                captured_tool_name: Arc::new(std::sync::Mutex::new(String::new())),
            }
        }
    }

    #[async_trait]
    impl Tool for ContextCaptureTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("capture", "capture", "Captures context")
        }

        async fn execute(
            &self,
            _args: Value,
            ctx: &ToolCallContext,
        ) -> Result<ToolOutput, ToolError> {
            *self.captured_call_id.lock().unwrap() = ctx.call_id.clone();
            *self.captured_tool_name.lock().unwrap() = ctx.tool_name.clone();
            Ok(ToolResult::success("capture", json!({"captured": true})).into())
        }
    }

    #[tokio::test]
    async fn execute_single_tool_preserves_call_context() {
        let capture = Arc::new(ContextCaptureTool::new());
        let tools = tool_map(vec![capture.clone() as Arc<dyn Tool>]);
        let call = ToolCall::new("call-42", "capture", json!({}));
        let mut ctx = ToolCallContext::test_default();
        ctx.call_id = call.id.clone();
        ctx.tool_name = call.name.clone();

        let output = execute_single_tool(&tools, &call, &ctx).await;
        assert!(output.result.is_success());
        // The tool must observe exactly the context it was handed.
        assert_eq!(*capture.captured_call_id.lock().unwrap(), "call-42");
        assert_eq!(*capture.captured_tool_name.lock().unwrap(), "capture");
    }

    #[tokio::test]
    async fn execute_single_tool_missing_returns_error_with_name() {
        let tools: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        let call = ToolCall::new("c1", "ghost_tool", json!({}));
        let ctx = ToolCallContext::test_default();

        let output = execute_single_tool(&tools, &call, &ctx).await;
        assert!(output.result.is_error());
        assert!(
            output
                .result
                .message
                .as_deref()
                .unwrap_or("")
                .contains("ghost_tool")
        );
    }

    #[tokio::test]
    async fn execute_single_tool_validates_args() {
        let tools = tool_map(vec![Arc::new(StrictArgsTool)]);
        let call = ToolCall::new("c1", "strict", json!({"wrong": "field"}));
        let ctx = ToolCallContext::test_default();

        let output = execute_single_tool(&tools, &call, &ctx).await;
        assert!(output.result.is_error());
    }

    #[tokio::test]
    async fn sequential_context_call_id_set_per_tool() {
        let capture = Arc::new(ContextCaptureTool::new());
        let tools = tool_map(vec![capture.clone() as Arc<dyn Tool>]);
        let calls = vec![ToolCall::new("unique-id-99", "capture", json!({}))];
        let executor = SequentialToolExecutor;

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].call.id, "unique-id-99");
        assert_eq!(*capture.captured_call_id.lock().unwrap(), "unique-id-99");
        assert_eq!(*capture.captured_tool_name.lock().unwrap(), "capture");
    }

    #[tokio::test]
    async fn parallel_context_call_id_set_per_tool() {
        let capture = Arc::new(ContextCaptureTool::new());
        let tools = tool_map(vec![capture.clone() as Arc<dyn Tool>]);
        let calls = vec![ToolCall::new("parallel-id-7", "capture", json!({}))];
        let executor = ParallelToolExecutor::streaming();

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].call.id, "parallel-id-7");
        assert_eq!(*capture.captured_call_id.lock().unwrap(), "parallel-id-7");
        assert_eq!(*capture.captured_tool_name.lock().unwrap(), "capture");
    }

    #[tokio::test]
    async fn sequential_mixed_success_failure_suspension_order() {
        let tools = tool_map(vec![
            Arc::new(EchoTool),
            Arc::new(FailingTool),
            Arc::new(SuspendingTool),
        ]);
        let calls = vec![
            ToolCall::new("c1", "echo", json!({"message": "hi"})),
            ToolCall::new("c2", "failing", json!({})),
            ToolCall::new("c3", "suspending", json!({})),
            ToolCall::new("c4", "echo", json!({"message": "should not run"})),
        ];
        let executor = SequentialToolExecutor;

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 3, "stops after suspension at c3");
        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
        assert_eq!(results[1].outcome, ToolCallOutcome::Failed);
        assert_eq!(results[2].outcome, ToolCallOutcome::Suspended);
    }

    #[tokio::test]
    async fn parallel_preserves_result_order() {
        let counting = Arc::new(CountingTool::new());
        let tools = tool_map(vec![counting.clone() as Arc<dyn Tool>]);
        let calls: Vec<_> = (0..5)
            .map(|i| ToolCall::new(format!("c{i}"), "counting", json!({})))
            .collect();
        let executor = ParallelToolExecutor::streaming();

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 5);
        // Results are returned in call order, not completion order.
        for (i, r) in results.iter().enumerate() {
            assert_eq!(r.call.id, format!("c{i}"));
        }
    }

    #[tokio::test]
    async fn parallel_mixed_success_failure_suspension() {
        let tools = tool_map(vec![
            Arc::new(EchoTool),
            Arc::new(FailingTool),
            Arc::new(SuspendingTool),
        ]);
        let calls = vec![
            ToolCall::new("c1", "echo", json!({"message": "hi"})),
            ToolCall::new("c2", "failing", json!({})),
            ToolCall::new("c3", "suspending", json!({})),
        ];
        let executor = ParallelToolExecutor::batch_approval();

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 3, "parallel runs all regardless");
        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
        assert_eq!(results[1].outcome, ToolCallOutcome::Failed);
        assert_eq!(results[2].outcome, ToolCallOutcome::Suspended);
    }

    #[test]
    fn sequential_requires_incremental_state() {
        let executor = SequentialToolExecutor;
        assert!(executor.requires_incremental_state());
    }

    #[test]
    fn parallel_does_not_require_incremental_state() {
        let executor = ParallelToolExecutor::streaming();
        assert!(!executor.requires_incremental_state());
        let batch = ParallelToolExecutor::batch_approval();
        assert!(!batch.requires_incremental_state());
    }

    #[tokio::test]
    async fn execute_single_tool_success_returns_correct_tool_name() {
        let tools = tool_map(vec![Arc::new(EchoTool)]);
        let call = ToolCall::new("c1", "echo", json!({"message": "test"}));
        let ctx = ToolCallContext::test_default();

        let output = execute_single_tool(&tools, &call, &ctx).await;
        assert!(output.result.is_success());
        assert_eq!(output.result.tool_name, "echo");
    }
}