1use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use awaken_contract::contract::message::ToolCall;
8use awaken_contract::contract::suspension::ToolCallOutcome;
9use awaken_contract::contract::tool::{Tool, ToolCallContext, ToolOutput, ToolResult};
10use awaken_contract::state::StateCommand;
11
12#[cfg(feature = "background")]
13use crate::extensions::background::{ToolLineageContext, scope_tool_lineage_context};
14
/// Everything the run loop needs to record one executed tool call: the
/// original request, the tool's result, the derived outcome, and the state
/// command emitted alongside the output.
pub struct ToolExecutionResult {
    /// The tool call as issued (id, name, arguments).
    pub call: ToolCall,
    /// Raw result returned by the tool (success, error, or suspended).
    pub result: ToolResult,
    /// Outcome classification derived from `result` via
    /// `ToolCallOutcome::from_tool_result`.
    pub outcome: ToolCallOutcome,
    /// State command carried on the tool's `ToolOutput`.
    pub command: StateCommand,
}
23
/// Errors an executor itself can produce. Note that per-tool failures are
/// reported inside `ToolExecutionResult`, not through this type.
#[derive(Debug, thiserror::Error)]
pub enum ToolExecutorError {
    /// The whole batch was cancelled before completion.
    #[error("tool execution cancelled")]
    Cancelled,
    /// The executor failed for a reason outside any single tool.
    #[error("tool execution failed: {0}")]
    Failed(String),
}
32
/// Strategy for running a batch of tool calls against a tool registry.
///
/// Implementations decide ordering and concurrency. They return one
/// `ToolExecutionResult` per call actually executed — an executor may stop
/// early (e.g. the sequential executor halts on suspension).
#[async_trait]
pub trait ToolExecutor: Send + Sync {
    /// Executes `calls` against `tools`. `base_ctx` is cloned per call with
    /// `call_id` and `tool_name` filled in by the executor.
    async fn execute(
        &self,
        tools: &HashMap<String, Arc<dyn Tool>>,
        calls: &[ToolCall],
        base_ctx: &ToolCallContext,
    ) -> Result<Vec<ToolExecutionResult>, ToolExecutorError>;

    /// Stable identifier for this executor.
    fn name(&self) -> &'static str;

    /// Whether this executor expects state to be applied incrementally
    /// between calls. Defaults to `false`; the sequential executor
    /// overrides it to `true`.
    fn requires_incremental_state(&self) -> bool {
        false
    }
}
53
/// Executor that runs tool calls one at a time, in order, and stops the
/// batch as soon as a call suspends.
#[derive(Debug, Clone, Copy, Default)]
pub struct SequentialToolExecutor;
59
60#[async_trait]
61impl ToolExecutor for SequentialToolExecutor {
62 async fn execute(
63 &self,
64 tools: &HashMap<String, Arc<dyn Tool>>,
65 calls: &[ToolCall],
66 base_ctx: &ToolCallContext,
67 ) -> Result<Vec<ToolExecutionResult>, ToolExecutorError> {
68 let mut results = Vec::with_capacity(calls.len());
69
70 for call in calls {
71 let mut ctx = base_ctx.clone();
72 ctx.call_id = call.id.clone();
73 ctx.tool_name = call.name.clone();
74 let output = execute_single_tool(tools, call, &ctx).await;
75 let outcome = ToolCallOutcome::from_tool_result(&output.result);
76
77 results.push(ToolExecutionResult {
78 call: call.clone(),
79 result: output.result,
80 outcome,
81 command: output.command,
82 });
83
84 if results
86 .last()
87 .is_some_and(|r| r.outcome == ToolCallOutcome::Suspended)
88 {
89 break;
90 }
91 }
92
93 Ok(results)
94 }
95
96 fn name(&self) -> &'static str {
97 "sequential"
98 }
99
100 fn requires_incremental_state(&self) -> bool {
101 true
102 }
103}
104
/// How decisions for suspended tool calls are replayed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecisionReplayPolicy {
    /// Replay each decision as soon as it is made.
    Immediate,
    /// Collect and replay decisions for all suspended calls as one batch.
    BatchAllSuspended,
}
113
/// Operating mode of the parallel executor; determines the decision
/// replay policy (see `ParallelToolExecutor::decision_replay_policy`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParallelMode {
    /// Approvals are batched: all suspended decisions replay together.
    BatchApproval,
    /// Decisions replay immediately as each one is made.
    Streaming,
}
122
/// Executor that runs all tool calls concurrently and never stops the
/// batch on suspension. Construct via `batch_approval()` or `streaming()`;
/// `Default` is streaming.
#[derive(Debug, Clone, Copy)]
pub struct ParallelToolExecutor {
    // Mode only affects the replay policy and reported name, not how the
    // calls themselves are executed.
    mode: ParallelMode,
}
132
133impl ParallelToolExecutor {
134 pub const fn batch_approval() -> Self {
135 Self {
136 mode: ParallelMode::BatchApproval,
137 }
138 }
139
140 pub const fn streaming() -> Self {
141 Self {
142 mode: ParallelMode::Streaming,
143 }
144 }
145
146 pub fn decision_replay_policy(&self) -> DecisionReplayPolicy {
148 match self.mode {
149 ParallelMode::BatchApproval => DecisionReplayPolicy::BatchAllSuspended,
150 ParallelMode::Streaming => DecisionReplayPolicy::Immediate,
151 }
152 }
153
154 pub fn requires_conflict_check(&self) -> bool {
156 true
157 }
158}
159
impl Default for ParallelToolExecutor {
    /// Streaming is the default mode (pinned by the
    /// `parallel_default_is_streaming` test).
    fn default() -> Self {
        Self::streaming()
    }
}
165
166#[async_trait]
167impl ToolExecutor for ParallelToolExecutor {
168 async fn execute(
169 &self,
170 tools: &HashMap<String, Arc<dyn Tool>>,
171 calls: &[ToolCall],
172 base_ctx: &ToolCallContext,
173 ) -> Result<Vec<ToolExecutionResult>, ToolExecutorError> {
174 use futures::future::join_all;
175
176 let futures: Vec<_> = calls
177 .iter()
178 .map(|call| {
179 let tools = tools.clone();
180 let call = call.clone();
181 let mut ctx = base_ctx.clone();
182 ctx.call_id = call.id.clone();
183 ctx.tool_name = call.name.clone();
184 async move {
185 let output = execute_single_tool(&tools, &call, &ctx).await;
186 let outcome = ToolCallOutcome::from_tool_result(&output.result);
187 ToolExecutionResult {
188 call,
189 result: output.result,
190 outcome,
191 command: output.command,
192 }
193 }
194 })
195 .collect();
196
197 Ok(join_all(futures).await)
198 }
199
200 fn name(&self) -> &'static str {
201 match self.mode {
202 ParallelMode::BatchApproval => "parallel_batch_approval",
203 ParallelMode::Streaming => "parallel_streaming",
204 }
205 }
206}
207
208pub(crate) async fn execute_single_tool(
210 tools: &HashMap<String, Arc<dyn Tool>>,
211 call: &ToolCall,
212 ctx: &ToolCallContext,
213) -> ToolOutput {
214 let Some(tool) = tools.get(&call.name) else {
215 return ToolResult::error(&call.name, format!("tool '{}' not found", call.name)).into();
216 };
217
218 if let Err(e) = tool.validate_args(&call.arguments) {
219 return ToolResult::error(&call.name, e.to_string()).into();
220 }
221
222 match execute_tool_with_runtime_context(tool, call, ctx).await {
223 Ok(output) => output,
224 Err(e) => ToolResult::error(&call.name, e.to_string()).into(),
225 }
226}
227
/// With the `background` feature enabled, the tool future is polled inside
/// a lineage scope so background work spawned by the tool is attributed to
/// this run/call/agent (see the lineage test at the bottom of this file).
#[cfg(feature = "background")]
async fn execute_tool_with_runtime_context(
    tool: &Arc<dyn Tool>,
    call: &ToolCall,
    ctx: &ToolCallContext,
) -> Result<ToolOutput, awaken_contract::contract::tool::ToolError> {
    let lineage = ToolLineageContext {
        run_id: ctx.run_identity.run_id.clone(),
        call_id: ctx.call_id.clone(),
        agent_id: ctx.run_identity.agent_id.clone(),
    };
    // The future is created here but only polled inside the scope.
    let tool_future = tool.execute(call.arguments.clone(), ctx);
    scope_tool_lineage_context(lineage, tool_future).await
}
244
245#[cfg(not(feature = "background"))]
246async fn execute_tool_with_runtime_context(
247 tool: &Arc<dyn Tool>,
248 call: &ToolCall,
249 ctx: &ToolCallContext,
250) -> Result<ToolOutput, awaken_contract::contract::tool::ToolError> {
251 tool.execute(call.arguments.clone(), ctx).await
252}
253
254#[cfg(test)]
255mod tests {
256 use super::*;
257 use awaken_contract::contract::tool::{ToolDescriptor, ToolError, ToolOutput};
258 use serde_json::{Value, json};
259
260 #[cfg(feature = "background")]
261 use crate::extensions::background::{
262 BackgroundTaskManager, BackgroundTaskPlugin, TaskParentContext, TaskResult,
263 };
264 #[cfg(feature = "background")]
265 use crate::phase::ExecutionEnv;
266 #[cfg(feature = "background")]
267 use crate::plugins::Plugin;
268 #[cfg(feature = "background")]
269 use crate::state::StateStore;
270 #[cfg(feature = "background")]
271 use awaken_contract::contract::identity::{RunIdentity, RunOrigin};
272
    // Happy-path fixture: echoes back the "message" argument.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("echo", "echo", "Echoes input")
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext,
        ) -> Result<ToolOutput, ToolError> {
            let msg = args
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("no message")
                .to_string();
            Ok(ToolResult::success_with_message("echo", args, msg).into())
        }
    }

    // Fixture that always returns an execution error.
    struct FailingTool;

    #[async_trait]
    impl Tool for FailingTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("failing", "failing", "Always fails")
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext,
        ) -> Result<ToolOutput, ToolError> {
            Err(ToolError::ExecutionFailed("intentional failure".into()))
        }
    }

    // Fixture that always suspends (a result awaiting approval).
    struct SuspendingTool;

    #[async_trait]
    impl Tool for SuspendingTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("suspending", "suspending", "Returns pending")
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext,
        ) -> Result<ToolOutput, ToolError> {
            Ok(ToolResult::suspended("suspending", "needs approval").into())
        }
    }

    // Builds the name -> tool registry executors expect, keyed by each
    // tool's descriptor id.
    fn tool_map(tools: Vec<Arc<dyn Tool>>) -> HashMap<String, Arc<dyn Tool>> {
        tools.into_iter().map(|t| (t.descriptor().id, t)).collect()
    }
332
    // A single successful call yields one Succeeded result.
    #[tokio::test]
    async fn sequential_single_tool_success() {
        let tools = tool_map(vec![Arc::new(EchoTool)]);
        let calls = vec![ToolCall::new("c1", "echo", json!({"message": "hi"}))];
        let executor = SequentialToolExecutor;

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
        assert!(results[0].result.is_success());
    }

    // A failed tool is reported per-call; the batch itself still succeeds.
    #[tokio::test]
    async fn sequential_partial_failure() {
        let tools = tool_map(vec![Arc::new(EchoTool), Arc::new(FailingTool)]);
        let calls = vec![
            ToolCall::new("c1", "echo", json!({"message": "ok"})),
            ToolCall::new("c2", "failing", json!({})),
        ];
        let executor = SequentialToolExecutor;

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
        assert_eq!(results[1].outcome, ToolCallOutcome::Failed);
    }

    // Suspension short-circuits: calls after a suspended one never run.
    #[tokio::test]
    async fn sequential_stops_after_first_suspension() {
        let tools = tool_map(vec![Arc::new(SuspendingTool), Arc::new(EchoTool)]);
        let calls = vec![
            ToolCall::new("c1", "suspending", json!({})),
            ToolCall::new("c2", "echo", json!({"message": "should not run"})),
        ];
        let executor = SequentialToolExecutor;

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 1, "should stop after suspended tool");
        assert_eq!(results[0].outcome, ToolCallOutcome::Suspended);
    }

    // An unknown tool name becomes a Failed result, not an executor error.
    #[tokio::test]
    async fn sequential_unknown_tool_returns_error() {
        let tools = tool_map(vec![]);
        let calls = vec![ToolCall::new("c1", "nonexistent", json!({}))];
        let executor = SequentialToolExecutor;

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].outcome, ToolCallOutcome::Failed);
        assert!(results[0].result.is_error());
    }

    // No calls in, no results out.
    #[tokio::test]
    async fn sequential_empty_calls() {
        let tools = tool_map(vec![Arc::new(EchoTool)]);
        let executor = SequentialToolExecutor;

        let results = executor
            .execute(&tools, &[], &ToolCallContext::test_default())
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    // Parallel executor: all calls run and all succeed.
    #[tokio::test]
    async fn parallel_all_succeed() {
        let tools = tool_map(vec![Arc::new(EchoTool)]);
        let calls = vec![
            ToolCall::new("c1", "echo", json!({"message": "first"})),
            ToolCall::new("c2", "echo", json!({"message": "second"})),
        ];
        let executor = ParallelToolExecutor::streaming();

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert!(
            results
                .iter()
                .all(|r| r.outcome == ToolCallOutcome::Succeeded)
        );
    }
434
    // Parallel executor reports mixed success/failure per call.
    #[tokio::test]
    async fn parallel_partial_failure() {
        let tools = tool_map(vec![Arc::new(EchoTool), Arc::new(FailingTool)]);
        let calls = vec![
            ToolCall::new("c1", "echo", json!({"message": "ok"})),
            ToolCall::new("c2", "failing", json!({})),
        ];
        let executor = ParallelToolExecutor::streaming();

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        let successes = results
            .iter()
            .filter(|r| r.outcome == ToolCallOutcome::Succeeded)
            .count();
        let failures = results
            .iter()
            .filter(|r| r.outcome == ToolCallOutcome::Failed)
            .count();
        assert_eq!(successes, 1);
        assert_eq!(failures, 1);
    }

    // Unlike sequential, a suspension does not stop the parallel batch.
    #[tokio::test]
    async fn parallel_does_not_stop_on_suspension() {
        let tools = tool_map(vec![Arc::new(SuspendingTool), Arc::new(EchoTool)]);
        let calls = vec![
            ToolCall::new("c1", "suspending", json!({})),
            ToolCall::new("c2", "echo", json!({"message": "runs anyway"})),
        ];
        let executor = ParallelToolExecutor::streaming();

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        let suspended = results
            .iter()
            .filter(|r| r.outcome == ToolCallOutcome::Suspended)
            .count();
        let succeeded = results
            .iter()
            .filter(|r| r.outcome == ToolCallOutcome::Succeeded)
            .count();
        assert_eq!(suspended, 1);
        assert_eq!(succeeded, 1);
    }

    // Empty batch is a no-op for the parallel executor too.
    #[tokio::test]
    async fn parallel_empty_calls() {
        let tools = tool_map(vec![Arc::new(EchoTool)]);
        let executor = ParallelToolExecutor::streaming();

        let results = executor
            .execute(&tools, &[], &ToolCallContext::test_default())
            .await
            .unwrap();
        assert!(results.is_empty());
    }
499
    // Executor names are stable identifiers keyed on mode.
    #[test]
    fn executor_names() {
        assert_eq!(SequentialToolExecutor.name(), "sequential");
        assert_eq!(
            ParallelToolExecutor::streaming().name(),
            "parallel_streaming"
        );
        assert_eq!(
            ParallelToolExecutor::batch_approval().name(),
            "parallel_batch_approval"
        );
    }

    // Default construction must be the streaming mode.
    #[test]
    fn parallel_default_is_streaming() {
        let executor = ParallelToolExecutor::default();
        assert_eq!(executor.name(), "parallel_streaming");
        assert_eq!(
            executor.decision_replay_policy(),
            DecisionReplayPolicy::Immediate
        );
    }

    // BatchApproval mode maps to the BatchAllSuspended replay policy.
    #[test]
    fn parallel_batch_approval_policy() {
        let executor = ParallelToolExecutor::batch_approval();
        assert_eq!(
            executor.decision_replay_policy(),
            DecisionReplayPolicy::BatchAllSuspended
        );
        assert!(executor.requires_conflict_check());
    }

    // Streaming mode maps to the Immediate replay policy.
    #[test]
    fn parallel_streaming_policy() {
        let executor = ParallelToolExecutor::streaming();
        assert_eq!(
            executor.decision_replay_policy(),
            DecisionReplayPolicy::Immediate
        );
        assert!(executor.requires_conflict_check());
    }

    // Batch-approval mode executes calls the same way as streaming.
    #[tokio::test]
    async fn batch_approval_executes_all_concurrently() {
        let tools = tool_map(vec![Arc::new(EchoTool)]);
        let calls = vec![
            ToolCall::new("c1", "echo", json!({"message": "a"})),
            ToolCall::new("c2", "echo", json!({"message": "b"})),
        ];
        let executor = ParallelToolExecutor::batch_approval();

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert!(
            results
                .iter()
                .all(|r| r.outcome == ToolCallOutcome::Succeeded)
        );
    }

    // Batch-approval mode also keeps running past a suspension.
    #[tokio::test]
    async fn batch_approval_does_not_stop_on_suspension() {
        let tools = tool_map(vec![Arc::new(SuspendingTool), Arc::new(EchoTool)]);
        let calls = vec![
            ToolCall::new("c1", "suspending", json!({})),
            ToolCall::new("c2", "echo", json!({"message": "runs anyway"})),
        ];
        let executor = ParallelToolExecutor::batch_approval();

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
    }
579
    // Fixture that atomically counts its executions, so tests can verify
    // every call actually ran.
    struct CountingTool {
        count: Arc<std::sync::atomic::AtomicUsize>,
    }

    impl CountingTool {
        fn new() -> Self {
            Self {
                count: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
            }
        }

        fn call_count(&self) -> usize {
            self.count.load(std::sync::atomic::Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl Tool for CountingTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("counting", "counting", "Counts calls")
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext,
        ) -> Result<ToolOutput, ToolError> {
            // fetch_add returns the previous value; report 1-based count.
            let n = self.count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            Ok(ToolResult::success("counting", json!({"call_number": n + 1})).into())
        }
    }

    // Sequential execution runs every call and preserves call order.
    #[tokio::test]
    async fn sequential_multiple_calls_ordered() {
        let counting = Arc::new(CountingTool::new());
        let tools = tool_map(vec![counting.clone() as Arc<dyn Tool>]);
        let calls = vec![
            ToolCall::new("c1", "counting", json!({})),
            ToolCall::new("c2", "counting", json!({})),
            ToolCall::new("c3", "counting", json!({})),
        ];
        let executor = SequentialToolExecutor;

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(counting.call_count(), 3);
        for (i, result) in results.iter().enumerate() {
            assert_eq!(result.call.id, format!("c{}", i + 1));
            assert_eq!(result.outcome, ToolCallOutcome::Succeeded);
        }
    }
640
    // A plain failure (unlike a suspension) does not halt the sequence.
    #[tokio::test]
    async fn sequential_failure_does_not_stop_execution() {
        let tools = tool_map(vec![Arc::new(FailingTool), Arc::new(EchoTool)]);
        let calls = vec![
            ToolCall::new("c1", "failing", json!({})),
            ToolCall::new("c2", "echo", json!({"message": "still runs"})),
        ];
        let executor = SequentialToolExecutor;

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].outcome, ToolCallOutcome::Failed);
        assert_eq!(results[1].outcome, ToolCallOutcome::Succeeded);
    }

    // A mid-batch suspension drops the remaining calls.
    #[tokio::test]
    async fn sequential_suspension_in_middle_stops_remaining() {
        let tools = tool_map(vec![
            Arc::new(EchoTool),
            Arc::new(SuspendingTool),
            Arc::new(EchoTool),
        ]);
        let calls = vec![
            ToolCall::new("c1", "echo", json!({"message": "first"})),
            ToolCall::new("c2", "suspending", json!({})),
            ToolCall::new("c3", "echo", json!({"message": "should not run"})),
        ];
        let executor = SequentialToolExecutor;

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 2, "should stop after suspension");
        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
        assert_eq!(results[1].outcome, ToolCallOutcome::Suspended);
    }

    // All-fail parallel batch still returns one result per call.
    #[tokio::test]
    async fn parallel_all_fail() {
        let tools = tool_map(vec![Arc::new(FailingTool)]);
        let calls = vec![
            ToolCall::new("c1", "failing", json!({})),
            ToolCall::new("c2", "failing", json!({})),
        ];
        let executor = ParallelToolExecutor::streaming();

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.outcome == ToolCallOutcome::Failed));
    }

    // Unknown tools surface as per-call errors in parallel mode too.
    #[tokio::test]
    async fn parallel_unknown_tool_returns_error() {
        let tools = tool_map(vec![]);
        let calls = vec![
            ToolCall::new("c1", "nonexistent_a", json!({})),
            ToolCall::new("c2", "nonexistent_b", json!({})),
        ];
        let executor = ParallelToolExecutor::streaming();

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.outcome == ToolCallOutcome::Failed));
        for r in &results {
            assert!(r.result.is_error());
        }
    }

    // Every call in a parallel batch actually executes the tool.
    #[tokio::test]
    async fn parallel_counting_tool_all_called() {
        let counting = Arc::new(CountingTool::new());
        let tools = tool_map(vec![counting.clone() as Arc<dyn Tool>]);
        let calls = vec![
            ToolCall::new("c1", "counting", json!({})),
            ToolCall::new("c2", "counting", json!({})),
            ToolCall::new("c3", "counting", json!({})),
        ];
        let executor = ParallelToolExecutor::streaming();

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(counting.call_count(), 3);
    }
738
    // Fixture with a custom validate_args: rejects any call missing
    // "required_field" before the tool body runs.
    struct StrictArgsTool;

    #[async_trait]
    impl Tool for StrictArgsTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("strict", "strict", "Validates args")
        }

        fn validate_args(&self, args: &Value) -> Result<(), ToolError> {
            if args.get("required_field").is_none() {
                return Err(ToolError::InvalidArguments("missing required_field".into()));
            }
            Ok(())
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext,
        ) -> Result<ToolOutput, ToolError> {
            Ok(ToolResult::success("strict", args).into())
        }
    }
763
764 #[tokio::test]
765 async fn sequential_validates_args_before_execute() {
766 let tools = tool_map(vec![Arc::new(StrictArgsTool)]);
767 let calls = vec![ToolCall::new("c1", "strict", json!({}))]; let executor = SequentialToolExecutor;
769
770 let results = executor
771 .execute(&tools, &calls, &ToolCallContext::test_default())
772 .await
773 .unwrap();
774 assert_eq!(results.len(), 1);
775 assert_eq!(results[0].outcome, ToolCallOutcome::Failed);
776 assert!(results[0].result.is_error());
777 }
778
    // When validation passes, the strict tool executes normally.
    #[tokio::test]
    async fn sequential_valid_args_succeeds() {
        let tools = tool_map(vec![Arc::new(StrictArgsTool)]);
        let calls = vec![ToolCall::new(
            "c1",
            "strict",
            json!({"required_field": "present"}),
        )];
        let executor = SequentialToolExecutor;

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
    }

    // The parallel path performs the same pre-execution validation.
    #[tokio::test]
    async fn parallel_validates_args_before_execute() {
        let tools = tool_map(vec![Arc::new(StrictArgsTool)]);
        let calls = vec![ToolCall::new("c1", "strict", json!({}))];
        let executor = ParallelToolExecutor::streaming();

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].outcome, ToolCallOutcome::Failed);
    }

    // ToolExecutionResult is a plain struct; fields round-trip as set.
    #[test]
    fn tool_execution_result_fields() {
        let result = ToolExecutionResult {
            call: ToolCall::new("c1", "echo", json!({})),
            result: ToolResult::success("echo", json!({"ok": true})),
            outcome: ToolCallOutcome::Succeeded,
            command: StateCommand::default(),
        };
        assert_eq!(result.call.id, "c1");
        assert_eq!(result.outcome, ToolCallOutcome::Succeeded);
    }

    // thiserror-derived Display includes the message payload.
    #[test]
    fn tool_executor_error_display() {
        let err = ToolExecutorError::Cancelled;
        assert!(err.to_string().contains("cancelled"));
        let err2 = ToolExecutorError::Failed("some reason".into());
        assert!(err2.to_string().contains("some reason"));
    }
830
    // Fixture that records the call_id/tool_name it was executed with, so
    // tests can verify the per-call context plumbing.
    struct ContextCaptureTool {
        captured_call_id: Arc<std::sync::Mutex<String>>,
        captured_tool_name: Arc<std::sync::Mutex<String>>,
    }

    impl ContextCaptureTool {
        fn new() -> Self {
            Self {
                captured_call_id: Arc::new(std::sync::Mutex::new(String::new())),
                captured_tool_name: Arc::new(std::sync::Mutex::new(String::new())),
            }
        }
    }

    #[async_trait]
    impl Tool for ContextCaptureTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("capture", "capture", "Captures context")
        }

        async fn execute(
            &self,
            _args: Value,
            ctx: &ToolCallContext,
        ) -> Result<ToolOutput, ToolError> {
            *self.captured_call_id.lock().unwrap() = ctx.call_id.clone();
            *self.captured_tool_name.lock().unwrap() = ctx.tool_name.clone();
            Ok(ToolResult::success("capture", json!({"captured": true})).into())
        }
    }
866
867 #[tokio::test]
868 async fn execute_single_tool_preserves_call_context() {
869 let capture = Arc::new(ContextCaptureTool::new());
870 let tools = tool_map(vec![capture.clone() as Arc<dyn Tool>]);
871 let call = ToolCall::new("call-42", "capture", json!({}));
872 let ctx = ToolCallContext::test_default();
873
874 let output = execute_single_tool(&tools, &call, &ctx).await;
875 assert!(output.result.is_success());
876 }
879
    // Missing-tool errors name the tool that was requested.
    #[tokio::test]
    async fn execute_single_tool_missing_returns_error_with_name() {
        let tools: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        let call = ToolCall::new("c1", "ghost_tool", json!({}));
        let ctx = ToolCallContext::test_default();

        let output = execute_single_tool(&tools, &call, &ctx).await;
        assert!(output.result.is_error());
        assert!(
            output
                .result
                .message
                .as_deref()
                .unwrap_or("")
                .contains("ghost_tool")
        );
    }

    // Validation failures are returned as error results, not panics.
    #[tokio::test]
    async fn execute_single_tool_validates_args() {
        let tools = tool_map(vec![Arc::new(StrictArgsTool)]);
        let call = ToolCall::new("c1", "strict", json!({"wrong": "field"}));
        let ctx = ToolCallContext::test_default();

        let output = execute_single_tool(&tools, &call, &ctx).await;
        assert!(output.result.is_error());
    }

    // The sequential executor stamps each per-call context with the call id.
    #[tokio::test]
    async fn sequential_context_call_id_set_per_tool() {
        let capture = Arc::new(ContextCaptureTool::new());
        let tools = tool_map(vec![capture.clone() as Arc<dyn Tool>]);
        let calls = vec![ToolCall::new("unique-id-99", "capture", json!({}))];
        let executor = SequentialToolExecutor;

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].call.id, "unique-id-99");
        assert_eq!(*capture.captured_call_id.lock().unwrap(), "unique-id-99");
    }

    // Success and failure both continue the batch; only suspension stops it.
    #[tokio::test]
    async fn sequential_mixed_success_failure_suspension_order() {
        let tools = tool_map(vec![
            Arc::new(EchoTool),
            Arc::new(FailingTool),
            Arc::new(SuspendingTool),
        ]);
        let calls = vec![
            ToolCall::new("c1", "echo", json!({"message": "hi"})),
            ToolCall::new("c2", "failing", json!({})),
            ToolCall::new("c3", "suspending", json!({})),
            ToolCall::new("c4", "echo", json!({"message": "should not run"})),
        ];
        let executor = SequentialToolExecutor;

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 3, "stops after suspension at c3");
        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
        assert_eq!(results[1].outcome, ToolCallOutcome::Failed);
        assert_eq!(results[2].outcome, ToolCallOutcome::Suspended);
    }

    // join_all preserves submission order in the parallel results.
    #[tokio::test]
    async fn parallel_preserves_result_order() {
        let counting = Arc::new(CountingTool::new());
        let tools = tool_map(vec![counting.clone() as Arc<dyn Tool>]);
        let calls: Vec<_> = (0..5)
            .map(|i| ToolCall::new(format!("c{i}"), "counting", json!({})))
            .collect();
        let executor = ParallelToolExecutor::streaming();

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 5);
        for (i, r) in results.iter().enumerate() {
            assert_eq!(r.call.id, format!("c{i}"));
        }
    }
970
    // Parallel (batch-approval) runs every call regardless of outcome mix.
    #[tokio::test]
    async fn parallel_mixed_success_failure_suspension() {
        let tools = tool_map(vec![
            Arc::new(EchoTool),
            Arc::new(FailingTool),
            Arc::new(SuspendingTool),
        ]);
        let calls = vec![
            ToolCall::new("c1", "echo", json!({"message": "hi"})),
            ToolCall::new("c2", "failing", json!({})),
            ToolCall::new("c3", "suspending", json!({})),
        ];
        let executor = ParallelToolExecutor::batch_approval();

        let results = executor
            .execute(&tools, &calls, &ToolCallContext::test_default())
            .await
            .unwrap();
        assert_eq!(results.len(), 3, "parallel runs all regardless");
        assert_eq!(results[0].outcome, ToolCallOutcome::Succeeded);
        assert_eq!(results[1].outcome, ToolCallOutcome::Failed);
        assert_eq!(results[2].outcome, ToolCallOutcome::Suspended);
    }

    // Sequential opts in to incremental state.
    #[test]
    fn sequential_requires_incremental_state() {
        let executor = SequentialToolExecutor;
        assert!(executor.requires_incremental_state());
    }

    // Both parallel modes keep the trait default (false).
    #[test]
    fn parallel_does_not_require_incremental_state() {
        let executor = ParallelToolExecutor::streaming();
        assert!(!executor.requires_incremental_state());
        let batch = ParallelToolExecutor::batch_approval();
        assert!(!batch.requires_incremental_state());
    }

    // Successful results carry the executed tool's name.
    #[tokio::test]
    async fn execute_single_tool_success_returns_correct_tool_name() {
        let tools = tool_map(vec![Arc::new(EchoTool)]);
        let call = ToolCall::new("c1", "echo", json!({"message": "test"}));
        let ctx = ToolCallContext::test_default();

        let output = execute_single_tool(&tools, &call, &ctx).await;
        assert!(output.result.is_success());
        assert_eq!(output.result.tool_name, "echo");
    }
1019
    // Fixture (background feature only): spawns a background task from
    // inside a tool body and records the task id for later inspection.
    #[cfg(feature = "background")]
    struct SpawnBackgroundTaskTool {
        manager: Arc<BackgroundTaskManager>,
        spawned_task_id: Arc<std::sync::Mutex<Option<String>>>,
    }

    #[cfg(feature = "background")]
    #[async_trait]
    impl Tool for SpawnBackgroundTaskTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new(
                "spawn_background",
                "spawn_background",
                "Spawns a background task",
            )
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext,
        ) -> Result<ToolOutput, ToolError> {
            // Note: the tool passes TaskParentContext::default() — the
            // lineage is expected to come from the ambient scope installed
            // by execute_tool_with_runtime_context, not from the tool.
            let task_id = self
                .manager
                .spawn(
                    "thread-1",
                    "background",
                    None,
                    "spawned from tool",
                    TaskParentContext::default(),
                    |_| async { TaskResult::Success(json!({"ok": true})) },
                )
                .await
                .map_err(|error| ToolError::ExecutionFailed(error.to_string()))?;
            *self.spawned_task_id.lock().unwrap() = Some(task_id.clone());
            Ok(ToolResult::success("spawn_background", json!({"task_id": task_id})).into())
        }
    }

    // End-to-end: a background task spawned during tool execution inherits
    // the run/call/agent identity of the tool call that spawned it.
    #[cfg(feature = "background")]
    #[tokio::test]
    async fn execute_single_tool_scopes_tool_lineage_for_background_spawns() {
        let store = StateStore::new();
        let manager = Arc::new(BackgroundTaskManager::new());
        manager.set_store(store.clone());
        let plugin: Arc<dyn Plugin> = Arc::new(BackgroundTaskPlugin::new(manager.clone()));
        let env = ExecutionEnv::from_plugins(&[plugin], &Default::default()).unwrap();
        store.register_keys(&env.key_registrations).unwrap();

        let spawned_task_id = Arc::new(std::sync::Mutex::new(None::<String>));
        let tools = tool_map(vec![Arc::new(SpawnBackgroundTaskTool {
            manager: manager.clone(),
            spawned_task_id: spawned_task_id.clone(),
        }) as Arc<dyn Tool>]);
        let executor = SequentialToolExecutor;
        let calls = vec![ToolCall::new("call-77", "spawn_background", json!({}))];

        // Identity whose ids must show up on the spawned task's parent
        // context after execution.
        let mut ctx = ToolCallContext::test_default();
        ctx.run_identity = RunIdentity::new(
            "thread-1".into(),
            None,
            "run-77".into(),
            None,
            "agent-77".into(),
            RunOrigin::User,
        );
        ctx.cancellation_token = Some(crate::cancellation::CancellationToken::new());

        let results = executor.execute(&tools, &calls, &ctx).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].result.is_success());

        // Give the spawned task a moment to be registered with the manager.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let task_id = spawned_task_id
            .lock()
            .unwrap()
            .clone()
            .expect("spawned task id should be recorded");
        let summary = manager
            .get(&task_id)
            .await
            .expect("spawned background task should be queryable");
        assert_eq!(summary.parent_context.run_id.as_deref(), Some("run-77"));
        assert_eq!(summary.parent_context.call_id.as_deref(), Some("call-77"));
        assert_eq!(summary.parent_context.agent_id.as_deref(), Some("agent-77"));
    }
1107}