1use std::future::Future;
2use std::sync::Arc;
3use std::time::Instant;
4
5use futures_util::stream::{FuturesUnordered, StreamExt};
6use tokio::sync::mpsc;
7
8use crate::plugin::{
9 PluginDirective, PluginSession, ToolCallHookContext, ToolHookHost, ToolResultHookContext,
10 emit_plugin_surface_events,
11};
12use crate::tool_executor::execute_tool_call;
13use crate::tool_schema::validate_tool_input;
14use crate::{
15 ProgressSender, SessionEvent, ToolCallRecord, ToolContext, ToolExecutionMode, ToolFailure,
16 ToolFailureClass, ToolManifest, ToolProvider, ToolResult, ToolSurface, TurnInjectionBridge,
17};
18
/// Shared state needed to dispatch tool calls for one session.
///
/// Bundles the plugin session (hooks, mode-native providers, access rules), the
/// fallback tool provider, the session's tool surface, and the sinks used for
/// plugin side effects (host, event channel, turn-injection bridge).
#[derive(Clone)]
pub struct ToolDispatchContext {
    /// Plugin session: supplies the `before_tool_call` / `after_tool_call` hooks,
    /// mode-native tool providers, tool-access rules and the execution mode.
    pub plugins: Arc<PluginSession>,
    /// Fallback provider, consulted after the mode-native providers when
    /// resolving manifests and contracts.
    pub tools: Arc<dyn ToolProvider>,
    /// The session's tool surface. A tool listed here is resolved (and its
    /// execution mode read) from its surface entry.
    pub surface: Arc<ToolSurface>,
    /// Host used by plugin directives to create sessions and emit trace events;
    /// also passed to hook contexts and to the `ToolContext` built by
    /// `dispatch_tool_call`.
    pub host: Arc<dyn ToolHookHost>,
    /// Id of the session the dispatched calls belong to.
    pub session_id: String,
    /// Channel on which plugin `EmitEvents` directives are forwarded.
    pub event_tx: mpsc::Sender<SessionEvent>,
    /// Queue that receives messages injected by `after_tool_call` plugins.
    pub turn_injection_bridge: TurnInjectionBridge,
    /// Attachment store passed to the `ToolContext` built by `dispatch_tool_call`.
    pub attachment_store: Arc<dyn crate::AttachmentStore>,
    /// Turn-level context passed to plugin hooks and to the `ToolContext`.
    pub turn_context: crate::TurnContext,
}
31
/// Result of dispatching a single tool call.
#[derive(Clone)]
pub(crate) struct ToolDispatchOutcome {
    /// Record of the call: tool name, the (possibly plugin-replaced) arguments,
    /// the output and the measured duration (0 when the tool never executed).
    pub record: ToolCallRecord,
}
36
/// One tool call in a batch submitted to `dispatch_parallel_tool_calls`.
#[derive(Clone)]
pub struct ParallelToolCallSpec {
    /// Caller-chosen position of this call. Batch results are sorted by it, and
    /// serial calls run in ascending `index` order.
    pub index: usize,
    /// Name of the tool to invoke.
    pub tool_name: String,
    /// JSON arguments for the tool.
    pub args: serde_json::Value,
}
43
/// Result of one call from a parallel batch.
#[derive(Clone)]
pub struct ParallelToolCallOutcome {
    /// `index` copied from the originating [`ParallelToolCallSpec`].
    pub index: usize,
    /// Record produced by dispatching that spec.
    pub record: ToolCallRecord,
}
49
50pub(crate) async fn dispatch_tool_call(
51 context: &ToolDispatchContext,
52 tool_name: String,
53 args: serde_json::Value,
54 progress: Option<&ProgressSender>,
55) -> ToolDispatchOutcome {
56 let tool_context = ToolContext::new(
57 context.session_id.clone(),
58 Arc::clone(&context.host),
59 context.turn_context.clone(),
60 Arc::clone(&context.attachment_store),
61 None,
62 );
63 dispatch_tool_call_with_execution_context(context, tool_name, args, progress, tool_context)
64 .await
65}
66
67pub(crate) async fn dispatch_tool_call_with_execution_context(
68 context: &ToolDispatchContext,
69 tool_name: String,
70 args: serde_json::Value,
71 progress: Option<&ProgressSender>,
72 tool_context: ToolContext,
73) -> ToolDispatchOutcome {
74 let Some(manifest) = resolve_callable_manifest(context, &tool_name) else {
75 return outcome(
76 tool_name,
77 args,
78 runtime_failure(
79 ToolFailureClass::Unavailable,
80 "tool_unavailable",
81 "Tool is unavailable in this session",
82 ),
83 0,
84 );
85 };
86 let mut args = args;
87
88 let directives = match context
89 .plugins
90 .before_tool_call(ToolCallHookContext::new(
91 context.session_id.clone(),
92 tool_name.clone(),
93 args.clone(),
94 context.turn_context.clone(),
95 Arc::clone(&context.host),
96 ))
97 .await
98 {
99 Ok(directives) => directives,
100 Err(err) => {
101 return outcome(
102 tool_name,
103 args,
104 runtime_failure(
105 ToolFailureClass::Internal,
106 "before_tool_call_failed",
107 err.to_string(),
108 ),
109 0,
110 );
111 }
112 };
113
114 let mut short_circuit: Option<ToolResult> = None;
115 for emitted in directives {
116 let plugin_id = emitted.plugin_id;
117 let directive = emitted.value;
118 match directive {
119 PluginDirective::CreateSession { request } => {
120 if let Err(err) = context.host.create_session(*request).await {
121 short_circuit = Some(ToolResult::err_fmt(err.to_string()));
122 break;
123 }
124 }
125 PluginDirective::HandoffSession { .. } => {
126 short_circuit = Some(ToolResult::err_fmt(
127 "before_tool_call does not support session handoff",
128 ));
129 break;
130 }
131 PluginDirective::ReplaceToolArgs { args: replacement } => {
132 args = replacement;
133 }
134 PluginDirective::ShortCircuitTool { output } => {
135 short_circuit = Some(ToolResult::from_output(output));
136 }
137 PluginDirective::AbortTurn { message, .. } => {
138 short_circuit = Some(ToolResult::err_fmt(message));
139 }
140 PluginDirective::EmitEvents { events } => {
141 emit_plugin_surface_events(&context.event_tx, &plugin_id, events).await;
142 }
143 PluginDirective::EmitTrace {
144 name,
145 payload,
146 context: trace_context,
147 } => {
148 if let Err(err) = context
149 .host
150 .emit_trace_event(
151 *trace_context,
152 lash_trace::TraceEvent::Custom {
153 name: format!("plugin.{plugin_id}.{name}"),
154 payload,
155 },
156 )
157 .await
158 {
159 short_circuit = Some(ToolResult::err_fmt(err.to_string()));
160 break;
161 }
162 }
163 PluginDirective::EnqueueMessages { .. } => {
164 short_circuit = Some(ToolResult::err_fmt(
165 "before_tool_call does not support message injection",
166 ));
167 }
168 }
169 }
170 if let Some(result) = short_circuit {
171 return outcome(tool_name, args, result, 0);
172 }
173
174 let contract = context
175 .plugins
176 .mode_native_tools()
177 .iter()
178 .find_map(|provider| provider.resolve_contract(&tool_name))
179 .or_else(|| context.tools.resolve_contract(&tool_name));
180 let Some(contract) = contract else {
181 return outcome(
182 tool_name,
183 args,
184 runtime_failure(
185 ToolFailureClass::Unavailable,
186 "tool_contract_unavailable",
187 "Tool contract is unavailable in this session",
188 ),
189 0,
190 );
191 };
192 if let Err(err) = validate_tool_input(&contract, &args) {
193 return outcome(
194 tool_name,
195 args,
196 runtime_failure(ToolFailureClass::InvalidRequest, "invalid_tool_args", err),
197 0,
198 );
199 }
200
201 let tool_start = Instant::now();
202 let result = execute_tool_call(
203 context,
204 &manifest,
205 &tool_name,
206 &args,
207 progress,
208 tool_context,
209 )
210 .await;
211 let duration_ms = tool_start.elapsed().as_millis() as u64;
212
213 let result = match context
214 .plugins
215 .after_tool_call(ToolResultHookContext::new(
216 context.session_id.clone(),
217 tool_name.clone(),
218 args.clone(),
219 result.clone(),
220 duration_ms,
221 context.turn_context.clone(),
222 Arc::clone(&context.host),
223 ))
224 .await
225 {
226 Ok(directives) => {
227 let mut final_result = result;
228 for emitted in directives {
229 let plugin_id = emitted.plugin_id;
230 let directive = emitted.value;
231 match directive {
232 PluginDirective::CreateSession { request } => {
233 if let Err(err) = context.host.create_session(*request).await {
234 final_result = ToolResult::failure(ToolFailure::runtime(
235 ToolFailureClass::Internal,
236 "plugin_session_create_failed",
237 err.to_string(),
238 ));
239 break;
240 }
241 }
242 PluginDirective::HandoffSession { .. } => {
243 final_result =
244 ToolResult::err_fmt("after_tool_call does not support session handoff");
245 break;
246 }
247 PluginDirective::ShortCircuitTool { output } => {
248 final_result = ToolResult::from_output(output);
249 }
250 PluginDirective::AbortTurn { message, .. } => {
251 final_result = ToolResult::err_fmt(message);
252 }
253 PluginDirective::EmitEvents { events } => {
254 emit_plugin_surface_events(&context.event_tx, &plugin_id, events).await;
255 }
256 PluginDirective::EmitTrace {
257 name,
258 payload,
259 context: trace_context,
260 } => {
261 if let Err(err) = context
262 .host
263 .emit_trace_event(
264 *trace_context,
265 lash_trace::TraceEvent::Custom {
266 name: format!("plugin.{plugin_id}.{name}"),
267 payload,
268 },
269 )
270 .await
271 {
272 final_result = ToolResult::err_fmt(err.to_string());
273 break;
274 }
275 }
276 PluginDirective::EnqueueMessages { messages } => {
277 if let Err(err) = context.turn_injection_bridge.enqueue(messages) {
278 final_result = ToolResult::err_fmt(err);
279 break;
280 }
281 }
282 PluginDirective::ReplaceToolArgs { .. } => {
283 final_result = ToolResult::err_fmt(
284 "after_tool_call only supports abort, short-circuit, session creation, events, and message injection",
285 );
286 }
287 }
288 }
289 final_result
290 }
291 Err(err) => runtime_failure(
292 ToolFailureClass::Internal,
293 "after_tool_call_failed",
294 err.to_string(),
295 ),
296 };
297
298 outcome(tool_name, args, result, duration_ms)
299}
300
301fn resolve_callable_manifest(
302 context: &ToolDispatchContext,
303 tool_name: &str,
304) -> Option<ToolManifest> {
305 if let Some(entry) = context
306 .surface
307 .tools
308 .iter()
309 .find(|tool| tool.manifest.name == tool_name)
310 {
311 return entry
312 .availability
313 .is_callable()
314 .then(|| entry.manifest.clone());
315 }
316
317 let mode = context.plugins.execution_mode();
318 let visible_and_callable = |manifest: ToolManifest| {
319 if context.plugins.tool_access().hides(&manifest.name) {
320 return None;
321 }
322 manifest
323 .effective_availability(&mode)
324 .is_callable()
325 .then_some(manifest)
326 };
327
328 for provider in context.plugins.mode_native_tools() {
329 if let Some(manifest) = provider
330 .resolve_manifest(tool_name)
331 .and_then(&visible_and_callable)
332 {
333 return Some(manifest);
334 }
335 }
336
337 context
338 .tools
339 .resolve_manifest(tool_name)
340 .and_then(visible_and_callable)
341}
342
343pub(crate) async fn dispatch_parallel_tool_call(
344 context: Arc<ToolDispatchContext>,
345 spec: ParallelToolCallSpec,
346 progress: Option<ProgressSender>,
347) -> ParallelToolCallOutcome {
348 let outcome = dispatch_tool_call(&context, spec.tool_name, spec.args, progress.as_ref()).await;
349 ParallelToolCallOutcome {
350 index: spec.index,
351 record: outcome.record,
352 }
353}
354
355pub(crate) fn resolve_tool_execution_mode(
359 context: &ToolDispatchContext,
360 tool_name: &str,
361) -> ToolExecutionMode {
362 context
363 .surface
364 .tools
365 .iter()
366 .find(|def| def.manifest.name == tool_name)
367 .map(|def| def.manifest.execution_mode)
368 .unwrap_or_default()
369}
370
371pub(crate) async fn schedule_tool_batch<T, O, IndexOf, ModeOf, Run, Fut>(
377 items: Vec<T>,
378 index_of: IndexOf,
379 mode_of: ModeOf,
380 run: Run,
381) -> Vec<O>
382where
383 T: Send + 'static,
384 O: Send + 'static,
385 IndexOf: Fn(&T) -> usize,
386 ModeOf: Fn(&T) -> ToolExecutionMode,
387 Run: Fn(T) -> Fut,
388 Fut: Future<Output = O> + Send,
389{
390 let mut parallel_items = Vec::new();
391 let mut serial_items = Vec::new();
392 for item in items {
393 let index = index_of(&item);
394 match mode_of(&item) {
395 ToolExecutionMode::Parallel => parallel_items.push((index, item)),
396 ToolExecutionMode::Serial => serial_items.push((index, item)),
397 }
398 }
399
400 let mut outcomes = Vec::new();
401
402 let mut pending = FuturesUnordered::new();
403 for (index, item) in parallel_items {
404 let future = run(item);
405 pending.push(async move { (index, future.await) });
406 }
407 while let Some(outcome) = pending.next().await {
408 outcomes.push(outcome);
409 }
410
411 serial_items.sort_by_key(|(index, _)| *index);
412 for (index, item) in serial_items {
413 outcomes.push((index, run(item).await));
414 }
415
416 outcomes.sort_by_key(|(index, _)| *index);
417 outcomes.into_iter().map(|(_, outcome)| outcome).collect()
418}
419
420pub async fn dispatch_parallel_tool_calls(
422 context: Arc<ToolDispatchContext>,
423 specs: Vec<ParallelToolCallSpec>,
424 progress: Option<&ProgressSender>,
425) -> Vec<ParallelToolCallOutcome> {
426 let progress = progress.cloned();
427 schedule_tool_batch(
428 specs,
429 |spec| spec.index,
430 {
431 let context = Arc::clone(&context);
432 move |spec| resolve_tool_execution_mode(&context, &spec.tool_name)
433 },
434 move |spec| dispatch_parallel_tool_call(Arc::clone(&context), spec, progress.clone()),
435 )
436 .await
437}
438
439fn outcome(
440 tool_name: String,
441 args: serde_json::Value,
442 result: ToolResult,
443 duration_ms: u64,
444) -> ToolDispatchOutcome {
445 let record = ToolCallRecord {
446 call_id: None,
447 tool: tool_name,
448 args,
449 output: *result.output,
450 duration_ms,
451 };
452 ToolDispatchOutcome { record }
453}
454
455fn runtime_failure(
456 class: ToolFailureClass,
457 code: impl Into<String>,
458 message: impl Into<String>,
459) -> ToolResult {
460 ToolResult::failure(ToolFailure::runtime(class, code, message))
461}
462
463#[cfg(test)]
464mod tests {
465 use super::*;
466 use crate::plugin::{PluginHost, StaticPluginFactory};
467 use crate::{
468 ExecutionMode, ToolCall, ToolCallOutcome, ToolProvider, ToolRetryDisposition,
469 ToolRetryPolicy,
470 };
471 use serde_json::json;
472 use std::collections::BTreeMap;
473 use std::sync::atomic::{AtomicUsize, Ordering};
474 use tokio::sync::Barrier;
475 use tokio::time::{Duration, timeout};
476
    /// `(probe name, start, end)` of one probe's execution.
    type ExecutionWindow = (&'static str, Instant, Instant);
    /// Execution windows pushed by concurrently running probes.
    type SharedExecutionWindows = Arc<std::sync::Mutex<Vec<ExecutionWindow>>>;
    /// `(attempt_number, max_attempts, idempotency_key)` as seen by one tool attempt.
    type AttemptObservation = (u32, u32, Option<String>);
    /// Observations appended by `RetryProbeTools` on every `execute` call.
    type SharedAttemptObservations = Arc<std::sync::Mutex<Vec<AttemptObservation>>>;
481
482 fn test_tool(name: &str, execution_mode: ToolExecutionMode) -> crate::ToolDefinition {
483 crate::ToolDefinition::raw(
484 name,
485 "",
486 crate::ToolDefinition::default_input_schema(),
487 json!({ "type": "string" }),
488 )
489 .with_execution_mode(execution_mode)
490 }
491
492 fn beta_tool() -> crate::ToolDefinition {
493 crate::ToolDefinition::raw(
494 "beta",
495 "",
496 json!({
497 "type": "object",
498 "properties": {
499 "value": { "type": "string" }
500 },
501 "required": ["value"],
502 "additionalProperties": false
503 }),
504 json!({ "type": "string" }),
505 )
506 .with_execution_mode(ToolExecutionMode::Parallel)
507 }
508
509 fn named_beta_tool(name: &str) -> crate::ToolDefinition {
510 crate::ToolDefinition::raw(
511 name,
512 "",
513 json!({
514 "type": "object",
515 "properties": {
516 "value": { "type": "string" }
517 },
518 "required": ["value"],
519 "additionalProperties": false
520 }),
521 json!({ "type": "string" }),
522 )
523 .with_execution_mode(ToolExecutionMode::Parallel)
524 }
525
526 fn manifests(definitions: Vec<crate::ToolDefinition>) -> Vec<crate::ToolManifest> {
527 definitions
528 .into_iter()
529 .map(|tool| tool.manifest())
530 .collect()
531 }
532
533 fn contract_from(
534 definitions: Vec<crate::ToolDefinition>,
535 name: &str,
536 ) -> Option<Arc<crate::ToolContract>> {
537 definitions
538 .into_iter()
539 .find(|tool| tool.name == name)
540 .map(|tool| Arc::new(tool.contract()))
541 }
542
    /// Synthetic item for `schedule_tool_batch`. In the scheduler test it sleeps
    /// for `delay`, records its execution window and returns `name`.
    #[derive(Clone)]
    struct ScheduledProbe {
        /// Position the scheduler uses to order outputs.
        index: usize,
        /// Label returned as the probe's output.
        name: &'static str,
        /// Bucket the scheduler places this probe in.
        mode: ToolExecutionMode,
        /// How long the probe sleeps while "running".
        delay: Duration,
    }
550
    /// Parallel probes must overlap, the serial probe must start only after the
    /// whole parallel bucket has finished, and outputs must come back in index
    /// order regardless of completion order.
    #[tokio::test]
    async fn scheduler_runs_parallel_bucket_then_serial_and_preserves_order() {
        let windows: SharedExecutionWindows = Arc::new(std::sync::Mutex::new(Vec::new()));
        // Index 0 is the slowest parallel probe and index 2 the fastest, so the
        // parallel bucket is expected to complete out of index order.
        let probes = vec![
            ScheduledProbe {
                index: 0,
                name: "parallel_slow",
                mode: ToolExecutionMode::Parallel,
                delay: Duration::from_millis(40),
            },
            ScheduledProbe {
                index: 1,
                name: "serial",
                mode: ToolExecutionMode::Serial,
                delay: Duration::from_millis(10),
            },
            ScheduledProbe {
                index: 2,
                name: "parallel_fast",
                mode: ToolExecutionMode::Parallel,
                delay: Duration::from_millis(5),
            },
        ];

        let outputs = schedule_tool_batch(probes, |probe| probe.index, |probe| probe.mode, {
            let windows = Arc::clone(&windows);
            move |probe| {
                let windows = Arc::clone(&windows);
                async move {
                    // Record the wall-clock window during which this probe ran.
                    let start = Instant::now();
                    tokio::time::sleep(probe.delay).await;
                    let end = Instant::now();
                    windows
                        .lock()
                        .expect("windows")
                        .push((probe.name, start, end));
                    probe.name
                }
            }
        })
        .await;

        // Outputs follow index order, not completion order.
        assert_eq!(outputs, ["parallel_slow", "serial", "parallel_fast"]);

        let recorded = windows.lock().expect("windows").clone();
        let parallel_slow = recorded
            .iter()
            .find(|(name, _, _)| *name == "parallel_slow")
            .expect("parallel_slow");
        let parallel_fast = recorded
            .iter()
            .find(|(name, _, _)| *name == "parallel_fast")
            .expect("parallel_fast");
        let serial = recorded
            .iter()
            .find(|(name, _, _)| *name == "serial")
            .expect("serial");

        // The fast probe started before the slow one ended, i.e. they overlapped.
        assert!(
            parallel_fast.1 < parallel_slow.2,
            "parallel tools should overlap even when completion order differs"
        );
        assert!(
            serial.1 >= parallel_slow.2 && serial.1 >= parallel_fast.2,
            "serial tool should start after the parallel bucket completes"
        );
    }
618
619 struct MockTools;
620
621 #[async_trait::async_trait]
622 impl ToolProvider for MockTools {
623 fn tool_manifests(&self) -> Vec<crate::ToolManifest> {
624 manifests(vec![
625 test_tool("alpha", ToolExecutionMode::Parallel),
626 beta_tool(),
627 ])
628 }
629
630 fn resolve_contract(&self, name: &str) -> Option<Arc<crate::ToolContract>> {
631 contract_from(
632 vec![test_tool("alpha", ToolExecutionMode::Parallel), beta_tool()],
633 name,
634 )
635 }
636
637 async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
638 match call.name {
639 "alpha" => ToolResult::ok(json!("alpha")),
640 "beta" => {
641 if call.args.get("value").and_then(|value| value.as_str()) == Some("fail") {
642 ToolResult::err_fmt("beta failed")
643 } else {
644 ToolResult::ok(json!(
645 call.args.get("value").cloned().unwrap_or(json!(null))
646 ))
647 }
648 }
649 other => ToolResult::err_fmt(format!("Unknown tool: {other}")),
650 }
651 }
652 }
653
654 struct ParallelProbeTools {
655 barrier: Arc<Barrier>,
656 started: Arc<AtomicUsize>,
657 }
658
659 #[async_trait::async_trait]
660 impl ToolProvider for ParallelProbeTools {
661 fn tool_manifests(&self) -> Vec<crate::ToolManifest> {
662 manifests(vec![
663 test_tool("probe_a", ToolExecutionMode::Parallel),
664 test_tool("probe_b", ToolExecutionMode::Parallel),
665 ])
666 }
667
668 fn resolve_contract(&self, name: &str) -> Option<Arc<crate::ToolContract>> {
669 contract_from(
670 vec![
671 test_tool("probe_a", ToolExecutionMode::Parallel),
672 test_tool("probe_b", ToolExecutionMode::Parallel),
673 ],
674 name,
675 )
676 }
677
678 async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
679 self.started.fetch_add(1, Ordering::SeqCst);
680 let waited = timeout(Duration::from_millis(100), self.barrier.wait()).await;
681 match waited {
682 Ok(_) => ToolResult::ok(json!(call.name)),
683 Err(_) => ToolResult::err_fmt(format!("{} did not overlap with peer", call.name)),
684 }
685 }
686 }
687
688 struct StrictMcpTools {
689 executed: Arc<AtomicUsize>,
690 }
691
692 #[async_trait::async_trait]
693 impl ToolProvider for StrictMcpTools {
694 fn tool_manifests(&self) -> Vec<crate::ToolManifest> {
695 manifests(vec![strict_mcp_tool_definition()])
696 }
697
698 fn resolve_contract(&self, name: &str) -> Option<Arc<crate::ToolContract>> {
699 (name == "mcp__appworld__venmo_show_transactions")
700 .then(|| Arc::new(strict_mcp_tool_definition().contract()))
701 }
702
703 async fn execute(&self, _call: ToolCall<'_>) -> ToolResult {
704 self.executed.fetch_add(1, Ordering::SeqCst);
705 ToolResult::ok(json!({ "executed": true }))
706 }
707 }
708
709 fn strict_mcp_tool_definition() -> crate::ToolDefinition {
710 crate::ToolDefinition::raw(
711 "mcp__appworld__venmo_show_transactions",
712 "Show Venmo transactions",
713 json!({
714 "type": "object",
715 "properties": {
716 "min_created_at": { "type": "string" },
717 "max_created_at": { "type": "string" },
718 "limit": { "type": "integer", "maximum": 100 }
719 },
720 "required": ["limit"]
721 }),
722 json!({ "type": "object", "additionalProperties": true }),
723 )
724 }
725
726 fn strict_mcp_dispatch_context(executed: Arc<AtomicUsize>) -> ToolDispatchContext {
727 let (event_tx, _event_rx) = mpsc::channel(8);
728 let plugins = test_plugins(Arc::new(StrictMcpTools { executed }));
729 let tools = plugins.tools();
730 let surface = plugins.tool_surface("session", ExecutionMode::standard());
731 ToolDispatchContext {
732 plugins,
733 tools,
734 surface,
735 host: Arc::new(MockSessionManager::default()),
736 session_id: "session".to_string(),
737 event_tx,
738 turn_injection_bridge: crate::TurnInjectionBridge::new(),
739 attachment_store: Arc::new(crate::InMemoryAttachmentStore::new()),
740 turn_context: crate::TurnContext::default(),
741 }
742 }
743
744 fn test_plugins(provider: Arc<dyn ToolProvider>) -> Arc<PluginSession> {
745 PluginHost::new(vec![Arc::new(StaticPluginFactory::new(
746 "test_tools",
747 crate::PluginSpec::new().with_tool_provider(Arc::clone(&provider)),
748 ))])
749 .build_standard_session("root", None)
750 .expect("plugin session")
751 }
752
753 use crate::testing::MockSessionManager;
754
755 fn dispatch_context() -> ToolDispatchContext {
756 let (event_tx, _event_rx) = mpsc::channel(8);
757 let plugins = test_plugins(Arc::new(MockTools));
758 let tools = plugins.tools();
759 let surface = plugins.tool_surface("session", ExecutionMode::standard());
760 ToolDispatchContext {
761 plugins,
762 tools,
763 surface,
764 host: Arc::new(MockSessionManager::default()),
765 session_id: "session".to_string(),
766 event_tx,
767 turn_injection_bridge: crate::TurnInjectionBridge::new(),
768 attachment_store: Arc::new(crate::InMemoryAttachmentStore::new()),
769 turn_context: crate::TurnContext::default(),
770 }
771 }
772
    /// Provider exposing only `beta` that counts contract lookups and executions.
    struct CountingContractTools {
        // Incremented on every `resolve_contract` call, whatever name is asked for.
        contracts_resolved: Arc<AtomicUsize>,
        // Incremented on every `execute` call.
        executed: Arc<AtomicUsize>,
    }

    /// Provider that lists no manifests but resolves `host_only` on exact lookup,
    /// exercising dispatch of tools absent from the surface.
    struct ExactDispatchTools {
        // Incremented on every `resolve_contract` call.
        contracts_resolved: Arc<AtomicUsize>,
        // Incremented on every `execute` call.
        executed: Arc<AtomicUsize>,
        // When false, `resolve_contract` returns `None` even for `host_only`.
        contract_available: bool,
    }

    /// Provider listing a `hidden` tool whose availability is configured off.
    struct HiddenDispatchTools {
        // Incremented on every `resolve_contract` call.
        contracts_resolved: Arc<AtomicUsize>,
        // Incremented on every `execute` call.
        executed: Arc<AtomicUsize>,
    }

    /// Provider for retry tests: records what each attempt observes and fails
    /// with a retryable error until enough attempts have been made.
    struct RetryProbeTools {
        // The single tool this provider exposes (its retry policy drives dispatch).
        definition: crate::ToolDefinition,
        // Number of `execute` calls so far.
        attempts: Arc<AtomicUsize>,
        // 1-based attempt number from which `execute` starts succeeding.
        successes_after: usize,
        // Despite the name, when true *every* call returns a cancelled result.
        cancel_on_first: bool,
        // Per-attempt (attempt_number, max_attempts, idempotency_key) records.
        observed_attempts: SharedAttemptObservations,
    }
796
797 #[async_trait::async_trait]
798 impl ToolProvider for CountingContractTools {
799 fn tool_manifests(&self) -> Vec<crate::ToolManifest> {
800 manifests(vec![beta_tool()])
801 }
802
803 fn resolve_contract(&self, name: &str) -> Option<Arc<crate::ToolContract>> {
804 self.contracts_resolved.fetch_add(1, Ordering::SeqCst);
805 (name == "beta").then(|| Arc::new(beta_tool().contract()))
806 }
807
808 async fn execute(&self, _call: ToolCall<'_>) -> ToolResult {
809 self.executed.fetch_add(1, Ordering::SeqCst);
810 ToolResult::ok(json!("ok"))
811 }
812 }
813
814 #[async_trait::async_trait]
815 impl ToolProvider for ExactDispatchTools {
816 fn tool_manifests(&self) -> Vec<crate::ToolManifest> {
817 Vec::new()
818 }
819
820 fn resolve_manifest(&self, name: &str) -> Option<crate::ToolManifest> {
821 (name == "host_only").then(|| named_beta_tool("host_only").manifest())
822 }
823
824 fn resolve_contract(&self, name: &str) -> Option<Arc<crate::ToolContract>> {
825 self.contracts_resolved.fetch_add(1, Ordering::SeqCst);
826 (self.contract_available && name == "host_only")
827 .then(|| Arc::new(named_beta_tool("host_only").contract()))
828 }
829
830 async fn execute(&self, _call: ToolCall<'_>) -> ToolResult {
831 self.executed.fetch_add(1, Ordering::SeqCst);
832 ToolResult::ok(json!("host"))
833 }
834 }
835
836 #[async_trait::async_trait]
837 impl ToolProvider for HiddenDispatchTools {
838 fn tool_manifests(&self) -> Vec<crate::ToolManifest> {
839 manifests(vec![
840 named_beta_tool("hidden").with_availability(crate::ToolAvailabilityConfig::off()),
841 ])
842 }
843
844 fn resolve_contract(&self, name: &str) -> Option<Arc<crate::ToolContract>> {
845 self.contracts_resolved.fetch_add(1, Ordering::SeqCst);
846 (name == "hidden").then(|| Arc::new(named_beta_tool("hidden").contract()))
847 }
848
849 async fn execute(&self, _call: ToolCall<'_>) -> ToolResult {
850 self.executed.fetch_add(1, Ordering::SeqCst);
851 ToolResult::ok(json!("hidden"))
852 }
853 }
854
855 #[async_trait::async_trait]
856 impl ToolProvider for RetryProbeTools {
857 fn tool_manifests(&self) -> Vec<crate::ToolManifest> {
858 manifests(vec![self.definition.clone()])
859 }
860
861 fn resolve_contract(&self, name: &str) -> Option<Arc<crate::ToolContract>> {
862 (name == self.definition.name).then(|| Arc::new(self.definition.contract()))
863 }
864
865 async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
866 self.observed_attempts.lock().expect("attempts").push((
867 call.context.attempt_number(),
868 call.context.max_attempts(),
869 call.context.idempotency_key().map(str::to_string),
870 ));
871 let attempt_index = self.attempts.fetch_add(1, Ordering::SeqCst) + 1;
872 if self.cancel_on_first {
873 return ToolResult::cancelled("cancelled");
874 }
875 if attempt_index >= self.successes_after {
876 return ToolResult::ok(json!({ "attempt": attempt_index }));
877 }
878 ToolResult::retryable_failure(
879 crate::ToolFailureClass::External,
880 "transient",
881 "transient failure",
882 Some(0),
883 )
884 }
885 }
886
887 fn lazy_contract_dispatch_context(
888 contracts_resolved: Arc<AtomicUsize>,
889 executed: Arc<AtomicUsize>,
890 ) -> ToolDispatchContext {
891 let (event_tx, _event_rx) = mpsc::channel(8);
892 let provider: Arc<dyn ToolProvider> = Arc::new(CountingContractTools {
893 contracts_resolved,
894 executed,
895 });
896 let tools = Arc::clone(&provider);
897 let surface = Arc::new(crate::ToolSurface::from_tools(
898 provider.tool_manifests(),
899 ExecutionMode::standard(),
900 BTreeMap::new(),
901 ));
902 ToolDispatchContext {
903 plugins: test_plugins(provider),
904 tools,
905 surface,
906 host: Arc::new(MockSessionManager::default()),
907 session_id: "session".to_string(),
908 event_tx,
909 turn_injection_bridge: crate::TurnInjectionBridge::new(),
910 attachment_store: Arc::new(crate::InMemoryAttachmentStore::new()),
911 turn_context: crate::TurnContext::default(),
912 }
913 }
914
915 fn exact_dispatch_context(provider: Arc<dyn ToolProvider>) -> ToolDispatchContext {
916 let (event_tx, _event_rx) = mpsc::channel(8);
917 let plugins = test_plugins(Arc::clone(&provider));
918 let tools = plugins.tools();
919 let surface = plugins.tool_surface("session", ExecutionMode::standard());
920 ToolDispatchContext {
921 plugins,
922 tools,
923 surface,
924 host: Arc::new(MockSessionManager::default()),
925 session_id: "session".to_string(),
926 event_tx,
927 turn_injection_bridge: crate::TurnInjectionBridge::new(),
928 attachment_store: Arc::new(crate::InMemoryAttachmentStore::new()),
929 turn_context: crate::TurnContext::default(),
930 }
931 }
932
933 fn retry_tool(name: &str, retry_policy: ToolRetryPolicy) -> crate::ToolDefinition {
934 named_beta_tool(name)
935 .with_execution_mode(ToolExecutionMode::Parallel)
936 .with_retry_policy(retry_policy)
937 }
938
939 fn retry_dispatch_context(
940 retry_policy: ToolRetryPolicy,
941 attempts: Arc<AtomicUsize>,
942 successes_after: usize,
943 cancel_on_first: bool,
944 observed_attempts: SharedAttemptObservations,
945 ) -> ToolDispatchContext {
946 exact_dispatch_context(Arc::new(RetryProbeTools {
947 definition: retry_tool("retry_probe", retry_policy),
948 attempts,
949 successes_after,
950 cancel_on_first,
951 observed_attempts,
952 }))
953 }
954
955 fn parallel_dispatch_context(
956 barrier: Arc<Barrier>,
957 started: Arc<AtomicUsize>,
958 ) -> ToolDispatchContext {
959 let (event_tx, _event_rx) = mpsc::channel(8);
960 let plugins = test_plugins(Arc::new(ParallelProbeTools { barrier, started }));
961 let tools = plugins.tools();
962 let surface = plugins.tool_surface("session", ExecutionMode::standard());
963 ToolDispatchContext {
964 plugins,
965 tools,
966 surface,
967 host: Arc::new(MockSessionManager::default()),
968 session_id: "session".to_string(),
969 event_tx,
970 turn_injection_bridge: crate::TurnInjectionBridge::new(),
971 attachment_store: Arc::new(crate::InMemoryAttachmentStore::new()),
972 turn_context: crate::TurnContext::default(),
973 }
974 }
975
976 #[tokio::test]
977 async fn dispatch_rejects_invalid_args_before_provider_execution() {
978 let outcome =
979 dispatch_tool_call(&dispatch_context(), "beta".to_string(), json!({}), None).await;
980
981 assert!(!outcome.record.output.is_success());
982 assert_eq!(
983 outcome.record.output.value_for_projection()["message"],
984 json!("value: required property missing")
985 );
986 }
987
988 #[tokio::test]
989 async fn dispatch_resolves_contract_only_for_called_tool_before_execution() {
990 let contracts_resolved = Arc::new(AtomicUsize::new(0));
991 let executed = Arc::new(AtomicUsize::new(0));
992 let outcome = dispatch_tool_call(
993 &lazy_contract_dispatch_context(Arc::clone(&contracts_resolved), Arc::clone(&executed)),
994 "beta".to_string(),
995 json!({ "value": "ok" }),
996 None,
997 )
998 .await;
999
1000 assert!(outcome.record.output.is_success());
1001 assert_eq!(contracts_resolved.load(Ordering::SeqCst), 1);
1002 assert_eq!(executed.load(Ordering::SeqCst), 1);
1003 }
1004
1005 #[tokio::test]
1006 async fn dispatch_exact_resolves_missing_surface_tool_and_executes_owner() {
1007 let contracts_resolved = Arc::new(AtomicUsize::new(0));
1008 let executed = Arc::new(AtomicUsize::new(0));
1009 let provider: Arc<dyn ToolProvider> = Arc::new(ExactDispatchTools {
1010 contracts_resolved: Arc::clone(&contracts_resolved),
1011 executed: Arc::clone(&executed),
1012 contract_available: true,
1013 });
1014 let outcome = dispatch_tool_call(
1015 &exact_dispatch_context(provider),
1016 "host_only".to_string(),
1017 json!({ "value": "ok" }),
1018 None,
1019 )
1020 .await;
1021
1022 assert!(outcome.record.output.is_success());
1023 assert_eq!(outcome.record.output.value_for_projection(), json!("host"));
1024 assert_eq!(contracts_resolved.load(Ordering::SeqCst), 1);
1025 assert_eq!(executed.load(Ordering::SeqCst), 1);
1026 }
1027
1028 #[tokio::test]
1029 async fn dispatch_contract_unavailable_skips_execution() {
1030 let contracts_resolved = Arc::new(AtomicUsize::new(0));
1031 let executed = Arc::new(AtomicUsize::new(0));
1032 let provider: Arc<dyn ToolProvider> = Arc::new(ExactDispatchTools {
1033 contracts_resolved: Arc::clone(&contracts_resolved),
1034 executed: Arc::clone(&executed),
1035 contract_available: false,
1036 });
1037 let outcome = dispatch_tool_call(
1038 &exact_dispatch_context(provider),
1039 "host_only".to_string(),
1040 json!({ "value": "ok" }),
1041 None,
1042 )
1043 .await;
1044
1045 assert!(!outcome.record.output.is_success());
1046 assert_eq!(
1047 outcome.record.output.value_for_projection()["message"],
1048 json!("Tool contract is unavailable in this session")
1049 );
1050 assert_eq!(contracts_resolved.load(Ordering::SeqCst), 1);
1051 assert_eq!(executed.load(Ordering::SeqCst), 0);
1052 }
1053
1054 #[tokio::test]
1055 async fn dispatch_rejects_hidden_tool_before_contract_resolution() {
1056 let contracts_resolved = Arc::new(AtomicUsize::new(0));
1057 let executed = Arc::new(AtomicUsize::new(0));
1058 let provider: Arc<dyn ToolProvider> = Arc::new(HiddenDispatchTools {
1059 contracts_resolved: Arc::clone(&contracts_resolved),
1060 executed: Arc::clone(&executed),
1061 });
1062 let outcome = dispatch_tool_call(
1063 &exact_dispatch_context(provider),
1064 "hidden".to_string(),
1065 json!({ "value": "ok" }),
1066 None,
1067 )
1068 .await;
1069
1070 assert!(!outcome.record.output.is_success());
1071 assert_eq!(
1072 outcome.record.output.value_for_projection()["message"],
1073 json!("Tool is unavailable in this session")
1074 );
1075 assert_eq!(contracts_resolved.load(Ordering::SeqCst), 0);
1076 assert_eq!(executed.load(Ordering::SeqCst), 0);
1077 }
1078
1079 #[tokio::test]
1080 async fn dispatch_rejects_unknown_mcp_args_before_provider_execution() {
1081 let executed = Arc::new(AtomicUsize::new(0));
1082 let outcome = dispatch_tool_call(
1083 &strict_mcp_dispatch_context(Arc::clone(&executed)),
1084 "mcp__appworld__venmo_show_transactions".to_string(),
1085 json!({
1086 "min_datetime": "2024-01-01T00:00:00Z",
1087 "limit": 20
1088 }),
1089 None,
1090 )
1091 .await;
1092
1093 assert!(!outcome.record.output.is_success());
1094 assert_eq!(
1095 outcome.record.output.value_for_projection()["message"],
1096 json!("min_datetime: unexpected property")
1097 );
1098 assert_eq!(executed.load(Ordering::SeqCst), 0);
1099 }
1100
1101 #[tokio::test]
1102 async fn default_retry_policy_never_retries_safe_failures() {
1103 let attempts = Arc::new(AtomicUsize::new(0));
1104 let observed = Arc::new(std::sync::Mutex::new(Vec::new()));
1105 let outcome = dispatch_tool_call(
1106 &retry_dispatch_context(
1107 ToolRetryPolicy::Never,
1108 Arc::clone(&attempts),
1109 usize::MAX,
1110 false,
1111 Arc::clone(&observed),
1112 ),
1113 "retry_probe".to_string(),
1114 json!({ "value": "ok" }),
1115 None,
1116 )
1117 .await;
1118
1119 assert!(!outcome.record.output.is_success());
1120 assert_eq!(attempts.load(Ordering::SeqCst), 1);
1121 assert_eq!(observed.lock().expect("observed")[0].0, 1);
1122 }
1123
1124 #[tokio::test]
1125 async fn safe_retry_policy_retries_safe_failure_and_stops_on_success() {
1126 let attempts = Arc::new(AtomicUsize::new(0));
1127 let observed = Arc::new(std::sync::Mutex::new(Vec::new()));
1128 let outcome = dispatch_tool_call(
1129 &retry_dispatch_context(
1130 ToolRetryPolicy::safe(3, 0, 0),
1131 Arc::clone(&attempts),
1132 2,
1133 false,
1134 Arc::clone(&observed),
1135 ),
1136 "retry_probe".to_string(),
1137 json!({ "value": "ok" }),
1138 None,
1139 )
1140 .await;
1141
1142 assert!(outcome.record.output.is_success());
1143 assert_eq!(attempts.load(Ordering::SeqCst), 2);
1144 assert_eq!(
1145 observed
1146 .lock()
1147 .expect("observed")
1148 .iter()
1149 .map(|(attempt, max, _)| (*attempt, *max))
1150 .collect::<Vec<_>>(),
1151 vec![(1, 3), (2, 3)]
1152 );
1153 }
1154
1155 #[tokio::test]
1156 async fn safe_retry_policy_marks_exhausted_after_final_attempt() {
1157 let attempts = Arc::new(AtomicUsize::new(0));
1158 let observed = Arc::new(std::sync::Mutex::new(Vec::new()));
1159 let outcome = dispatch_tool_call(
1160 &retry_dispatch_context(
1161 ToolRetryPolicy::safe(2, 0, 0),
1162 Arc::clone(&attempts),
1163 usize::MAX,
1164 false,
1165 Arc::clone(&observed),
1166 ),
1167 "retry_probe".to_string(),
1168 json!({ "value": "ok" }),
1169 None,
1170 )
1171 .await;
1172
1173 assert!(!outcome.record.output.is_success());
1174 assert_eq!(attempts.load(Ordering::SeqCst), 2);
1175 let ToolCallOutcome::Failure(failure) = outcome.record.output.outcome else {
1176 panic!("expected failure");
1177 };
1178 assert_eq!(
1179 failure.retry,
1180 ToolRetryDisposition::Exhausted { attempts: 2 }
1181 );
1182 }
1183
1184 #[tokio::test]
1185 async fn cancellation_stops_retry_immediately() {
1186 let attempts = Arc::new(AtomicUsize::new(0));
1187 let observed = Arc::new(std::sync::Mutex::new(Vec::new()));
1188 let outcome = dispatch_tool_call(
1189 &retry_dispatch_context(
1190 ToolRetryPolicy::safe(3, 0, 0),
1191 Arc::clone(&attempts),
1192 usize::MAX,
1193 true,
1194 Arc::clone(&observed),
1195 ),
1196 "retry_probe".to_string(),
1197 json!({ "value": "ok" }),
1198 None,
1199 )
1200 .await;
1201
1202 assert!(!outcome.record.output.is_success());
1203 assert_eq!(attempts.load(Ordering::SeqCst), 1);
1204 assert!(matches!(
1205 outcome.record.output.outcome,
1206 ToolCallOutcome::Cancelled(_)
1207 ));
1208 }
1209
1210 #[tokio::test]
1211 async fn retry_context_has_stable_idempotency_key_across_attempts() {
1212 let attempts = Arc::new(AtomicUsize::new(0));
1213 let observed = Arc::new(std::sync::Mutex::new(Vec::new()));
1214 let context = retry_dispatch_context(
1215 ToolRetryPolicy::safe(3, 0, 0),
1216 Arc::clone(&attempts),
1217 3,
1218 false,
1219 Arc::clone(&observed),
1220 );
1221 let tool_context = ToolContext::new(
1222 context.session_id.clone(),
1223 Arc::clone(&context.host),
1224 context.turn_context.clone(),
1225 Arc::clone(&context.attachment_store),
1226 Some("call-1".to_string()),
1227 );
1228 let outcome = dispatch_tool_call_with_execution_context(
1229 &context,
1230 "retry_probe".to_string(),
1231 json!({ "value": "ok" }),
1232 None,
1233 tool_context,
1234 )
1235 .await;
1236
1237 assert!(outcome.record.output.is_success());
1238 let observed = observed.lock().expect("observed");
1239 assert_eq!(observed.len(), 3);
1240 assert_eq!(
1241 observed
1242 .iter()
1243 .map(|(attempt, max, _)| (*attempt, *max))
1244 .collect::<Vec<_>>(),
1245 vec![(1, 3), (2, 3), (3, 3)]
1246 );
1247 let keys = observed
1248 .iter()
1249 .map(|(_, _, key)| key.clone())
1250 .collect::<Vec<_>>();
1251 assert!(keys.iter().all(|key| key == &keys[0]));
1252 assert_eq!(
1253 keys[0].as_deref(),
1254 Some("lash-tool:session:call-1:retry_probe")
1255 );
1256 }
1257
1258 #[tokio::test]
1259 async fn idempotent_retry_policy_requires_stable_key() {
1260 let attempts = Arc::new(AtomicUsize::new(0));
1261 let observed = Arc::new(std::sync::Mutex::new(Vec::new()));
1262 let outcome = dispatch_tool_call(
1263 &retry_dispatch_context(
1264 ToolRetryPolicy::idempotent(3, 0, 0),
1265 Arc::clone(&attempts),
1266 usize::MAX,
1267 false,
1268 Arc::clone(&observed),
1269 ),
1270 "retry_probe".to_string(),
1271 json!({ "value": "ok" }),
1272 None,
1273 )
1274 .await;
1275
1276 assert!(!outcome.record.output.is_success());
1277 assert_eq!(attempts.load(Ordering::SeqCst), 1);
1278 assert_eq!(observed.lock().expect("observed")[0].1, 1);
1279 }
1280
1281 #[tokio::test]
1282 async fn batch_executes_nested_calls_and_preserves_partial_failures() {
1283 let outcome = dispatch_tool_call(
1284 &dispatch_context(),
1285 "batch".to_string(),
1286 json!({
1287 "tool_calls": [
1288 {"tool": "alpha", "parameters": {}},
1289 {"tool": "beta", "parameters": {"value": "ok"}},
1290 {"tool": "beta", "parameters": {"value": "fail"}}
1291 ]
1292 }),
1293 None,
1294 )
1295 .await;
1296
1297 assert!(outcome.record.output.is_success());
1298 assert_eq!(outcome.record.tool, "batch");
1299 let value = outcome.record.output.value_for_projection();
1300 let results = value
1301 .get("results")
1302 .and_then(|value| value.as_array())
1303 .expect("results");
1304 assert_eq!(results.len(), 3);
1305 assert_eq!(
1306 results
1307 .iter()
1308 .filter(|item| item.get("success").and_then(|value| value.as_bool()) == Some(true))
1309 .count(),
1310 2
1311 );
1312 assert_eq!(results[0].get("tool"), Some(&json!("alpha")));
1313 assert_eq!(
1314 results[2]
1315 .get("error")
1316 .and_then(|value| value.get("message")),
1317 Some(&json!("beta failed"))
1318 );
1319 }
1320
1321 #[tokio::test]
1322 async fn batch_rejects_nested_batch_as_partial_failure() {
1323 let outcome = dispatch_tool_call(
1324 &dispatch_context(),
1325 "batch".to_string(),
1326 json!({
1327 "tool_calls": [
1328 {"tool": "batch", "parameters": {"tool_calls": []}}
1329 ]
1330 }),
1331 None,
1332 )
1333 .await;
1334
1335 assert!(outcome.record.output.is_success());
1336 let value = outcome.record.output.value_for_projection();
1337 let first = value
1338 .get("results")
1339 .and_then(|value| value.as_array())
1340 .and_then(|items| items.first())
1341 .expect("first result");
1342 assert_eq!(
1343 first.get("error"),
1344 Some(&json!("Tool 'batch' is not allowed inside batch"))
1345 );
1346 }
1347
1348 #[tokio::test]
1349 async fn batch_marks_overflow_calls_as_failures() {
1350 let tool_calls = (0..26)
1351 .map(|_| json!({"tool": "alpha", "parameters": {}}))
1352 .collect::<Vec<_>>();
1353
1354 let outcome = dispatch_tool_call(
1355 &dispatch_context(),
1356 "batch".to_string(),
1357 json!({ "tool_calls": tool_calls }),
1358 None,
1359 )
1360 .await;
1361
1362 assert!(!outcome.record.output.is_success());
1363 let value = outcome.record.output.value_for_projection();
1364 let error = value
1365 .get("message")
1366 .and_then(|value| value.as_str())
1367 .expect("string error result");
1368 assert!(
1369 error.contains("tool_calls") && error.contains("items <= 25"),
1370 "{error}",
1371 );
1372 }
1373
1374 #[tokio::test]
1375 async fn batch_calls_make_progress_concurrently() {
1376 let barrier = Arc::new(Barrier::new(2));
1377 let started = Arc::new(AtomicUsize::new(0));
1378 let outcome = dispatch_tool_call(
1379 ¶llel_dispatch_context(Arc::clone(&barrier), Arc::clone(&started)),
1380 "batch".to_string(),
1381 json!({
1382 "tool_calls": [
1383 {"tool": "probe_a", "parameters": {}},
1384 {"tool": "probe_b", "parameters": {}}
1385 ]
1386 }),
1387 None,
1388 )
1389 .await;
1390
1391 assert!(outcome.record.output.is_success());
1392 assert_eq!(started.load(Ordering::SeqCst), 2);
1393 let value = outcome.record.output.value_for_projection();
1394 let results = value
1395 .get("results")
1396 .and_then(|value| value.as_array())
1397 .expect("results");
1398 assert_eq!(results.len(), 2);
1399 assert!(
1400 results
1401 .iter()
1402 .all(|item| item.get("success").and_then(|value| value.as_bool()) == Some(true))
1403 );
1404 }
1405
/// Test provider exposing two `ToolExecutionMode::Serial` tools (`serial_a`, `serial_b`).
/// It records when each execution started and ended so tests can check for overlap.
struct SerialProbeTools {
    /// One `(tool name, start, end)` entry is appended per `execute` call.
    log: Arc<std::sync::Mutex<Vec<(String, Instant, Instant)>>>,
}
1412
1413 #[async_trait::async_trait]
1414 impl ToolProvider for SerialProbeTools {
1415 fn tool_manifests(&self) -> Vec<crate::ToolManifest> {
1416 manifests(vec![
1417 test_tool("serial_a", ToolExecutionMode::Serial),
1418 test_tool("serial_b", ToolExecutionMode::Serial),
1419 ])
1420 }
1421
1422 fn resolve_contract(&self, name: &str) -> Option<Arc<crate::ToolContract>> {
1423 contract_from(
1424 vec![
1425 test_tool("serial_a", ToolExecutionMode::Serial),
1426 test_tool("serial_b", ToolExecutionMode::Serial),
1427 ],
1428 name,
1429 )
1430 }
1431
1432 async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
1433 let start = Instant::now();
1434 tokio::time::sleep(Duration::from_millis(40)).await;
1438 let end = Instant::now();
1439 self.log
1440 .lock()
1441 .expect("serial probe log")
1442 .push((call.name.to_string(), start, end));
1443 ToolResult::ok(json!(call.name))
1444 }
1445 }
1446
1447 fn serial_dispatch_context(
1448 log: Arc<std::sync::Mutex<Vec<(String, Instant, Instant)>>>,
1449 ) -> ToolDispatchContext {
1450 let (event_tx, _event_rx) = mpsc::channel(8);
1451 let plugins = test_plugins(Arc::new(SerialProbeTools { log }));
1452 let tools = plugins.tools();
1453 let surface = plugins.tool_surface("session", ExecutionMode::standard());
1454 ToolDispatchContext {
1455 plugins,
1456 tools,
1457 surface,
1458 host: Arc::new(MockSessionManager::default()),
1459 session_id: "session".to_string(),
1460 event_tx,
1461 turn_injection_bridge: crate::TurnInjectionBridge::new(),
1462 attachment_store: Arc::new(crate::InMemoryAttachmentStore::new()),
1463 turn_context: crate::TurnContext::default(),
1464 }
1465 }
1466
1467 #[tokio::test]
1471 async fn serial_tools_do_not_interleave() {
1472 let log: Arc<std::sync::Mutex<Vec<(String, Instant, Instant)>>> =
1473 Arc::new(std::sync::Mutex::new(Vec::new()));
1474 let context = Arc::new(serial_dispatch_context(Arc::clone(&log)));
1475
1476 let specs = vec![
1477 ParallelToolCallSpec {
1478 index: 0,
1479 tool_name: "serial_a".to_string(),
1480 args: json!({}),
1481 },
1482 ParallelToolCallSpec {
1483 index: 1,
1484 tool_name: "serial_b".to_string(),
1485 args: json!({}),
1486 },
1487 ];
1488
1489 let outcomes = dispatch_parallel_tool_calls(context, specs, None).await;
1490
1491 assert_eq!(outcomes.len(), 2);
1492 assert!(
1493 outcomes
1494 .iter()
1495 .all(|outcome| outcome.record.output.is_success())
1496 );
1497 assert_eq!(outcomes[0].index, 0);
1499 assert_eq!(outcomes[1].index, 1);
1500 assert_eq!(outcomes[0].record.tool, "serial_a");
1501 assert_eq!(outcomes[1].record.tool, "serial_b");
1502
1503 let entries = log.lock().expect("log").clone();
1504 assert_eq!(entries.len(), 2, "both serial tools must have executed");
1505 let mut sorted = entries;
1508 sorted.sort_by_key(|(_, start, _)| *start);
1509 let (first_name, _first_start, first_end) = &sorted[0];
1510 let (second_name, second_start, _second_end) = &sorted[1];
1511 assert_ne!(first_name, second_name, "both tools should have run");
1512 assert!(
1513 second_start >= first_end,
1514 "serial tool ranges must not overlap: first ended at {:?}, second started at {:?}",
1515 first_end,
1516 second_start,
1517 );
1518 }
1519
/// Test provider exposing two serial tools (`serial_retry_a`, `serial_retry_b`) with a
/// two-attempt safe retry policy. Each tool fails retryably on its first attempt and succeeds
/// afterwards. Every execution window is logged.
struct SerialRetryProbeTools {
    /// One `(tool name, start, end)` entry is appended per `execute` call, including failed
    /// attempts.
    log: Arc<std::sync::Mutex<Vec<(String, Instant, Instant)>>>,
    /// Number of `execute` calls seen for `serial_retry_a`.
    attempts_a: Arc<AtomicUsize>,
    /// Number of `execute` calls seen for `serial_retry_b`.
    attempts_b: Arc<AtomicUsize>,
}
1525
1526 #[async_trait::async_trait]
1527 impl ToolProvider for SerialRetryProbeTools {
1528 fn tool_manifests(&self) -> Vec<crate::ToolManifest> {
1529 manifests(vec![
1530 test_tool("serial_retry_a", ToolExecutionMode::Serial)
1531 .with_retry_policy(ToolRetryPolicy::safe(2, 0, 0)),
1532 test_tool("serial_retry_b", ToolExecutionMode::Serial)
1533 .with_retry_policy(ToolRetryPolicy::safe(2, 0, 0)),
1534 ])
1535 }
1536
1537 fn resolve_contract(&self, name: &str) -> Option<Arc<crate::ToolContract>> {
1538 contract_from(
1539 vec![
1540 test_tool("serial_retry_a", ToolExecutionMode::Serial)
1541 .with_retry_policy(ToolRetryPolicy::safe(2, 0, 0)),
1542 test_tool("serial_retry_b", ToolExecutionMode::Serial)
1543 .with_retry_policy(ToolRetryPolicy::safe(2, 0, 0)),
1544 ],
1545 name,
1546 )
1547 }
1548
1549 async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
1550 let start = Instant::now();
1551 tokio::time::sleep(Duration::from_millis(20)).await;
1552 let end = Instant::now();
1553 self.log
1554 .lock()
1555 .expect("serial retry log")
1556 .push((call.name.to_string(), start, end));
1557
1558 let attempt = match call.name {
1559 "serial_retry_a" => self.attempts_a.fetch_add(1, Ordering::SeqCst) + 1,
1560 "serial_retry_b" => self.attempts_b.fetch_add(1, Ordering::SeqCst) + 1,
1561 _ => 1,
1562 };
1563 if attempt == 1 {
1564 ToolResult::retryable_failure(
1565 crate::ToolFailureClass::External,
1566 "transient",
1567 "transient failure",
1568 Some(0),
1569 )
1570 } else {
1571 ToolResult::ok(json!(call.name))
1572 }
1573 }
1574 }
1575
1576 #[tokio::test]
1577 async fn serial_tool_retries_do_not_overlap_other_serial_calls() {
1578 let log = Arc::new(std::sync::Mutex::new(Vec::new()));
1579 let attempts_a = Arc::new(AtomicUsize::new(0));
1580 let attempts_b = Arc::new(AtomicUsize::new(0));
1581 let provider = Arc::new(SerialRetryProbeTools {
1582 log: Arc::clone(&log),
1583 attempts_a: Arc::clone(&attempts_a),
1584 attempts_b: Arc::clone(&attempts_b),
1585 });
1586 let (event_tx, _event_rx) = mpsc::channel(8);
1587 let plugins = test_plugins(provider);
1588 let tools = plugins.tools();
1589 let surface = plugins.tool_surface("session", ExecutionMode::standard());
1590 let context = Arc::new(ToolDispatchContext {
1591 plugins,
1592 tools,
1593 surface,
1594 host: Arc::new(MockSessionManager::default()),
1595 session_id: "session".to_string(),
1596 event_tx,
1597 turn_injection_bridge: crate::TurnInjectionBridge::new(),
1598 attachment_store: Arc::new(crate::InMemoryAttachmentStore::new()),
1599 turn_context: crate::TurnContext::default(),
1600 });
1601
1602 let outcomes = dispatch_parallel_tool_calls(
1603 context,
1604 vec![
1605 ParallelToolCallSpec {
1606 index: 0,
1607 tool_name: "serial_retry_a".to_string(),
1608 args: json!({}),
1609 },
1610 ParallelToolCallSpec {
1611 index: 1,
1612 tool_name: "serial_retry_b".to_string(),
1613 args: json!({}),
1614 },
1615 ],
1616 None,
1617 )
1618 .await;
1619
1620 assert!(
1621 outcomes
1622 .iter()
1623 .all(|outcome| outcome.record.output.is_success())
1624 );
1625 assert_eq!(attempts_a.load(Ordering::SeqCst), 2);
1626 assert_eq!(attempts_b.load(Ordering::SeqCst), 2);
1627
1628 let mut entries = log.lock().expect("serial retry log").clone();
1629 entries.sort_by_key(|(_, start, _)| *start);
1630 assert_eq!(entries.len(), 4);
1631 for window in entries.windows(2) {
1632 assert!(
1633 window[1].1 >= window[0].2,
1634 "serial retry windows must not overlap: {:?} then {:?}",
1635 window[0],
1636 window[1],
1637 );
1638 }
1639 }
1640
1641 #[tokio::test]
1646 async fn mixed_batch_runs_parallel_tools_concurrently_and_serial_alone() {
1647 struct MixedTools {
1648 barrier: Arc<Barrier>,
1649 serial_window: Arc<std::sync::Mutex<Option<(Instant, Instant)>>>,
1650 parallel_windows: Arc<std::sync::Mutex<Vec<(String, Instant, Instant)>>>,
1651 }
1652
1653 #[async_trait::async_trait]
1654 impl ToolProvider for MixedTools {
1655 fn tool_manifests(&self) -> Vec<crate::ToolManifest> {
1656 manifests(vec![
1657 test_tool("par_a", ToolExecutionMode::Parallel),
1658 test_tool("par_b", ToolExecutionMode::Parallel),
1659 test_tool("ser", ToolExecutionMode::Serial),
1660 ])
1661 }
1662
1663 fn resolve_contract(&self, name: &str) -> Option<Arc<crate::ToolContract>> {
1664 contract_from(
1665 vec![
1666 test_tool("par_a", ToolExecutionMode::Parallel),
1667 test_tool("par_b", ToolExecutionMode::Parallel),
1668 test_tool("ser", ToolExecutionMode::Serial),
1669 ],
1670 name,
1671 )
1672 }
1673
1674 async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
1675 let name = call.name;
1676 if name == "ser" {
1677 let start = Instant::now();
1678 tokio::time::sleep(Duration::from_millis(30)).await;
1679 let end = Instant::now();
1680 *self.serial_window.lock().expect("serial window") = Some((start, end));
1681 ToolResult::ok(json!(name))
1682 } else {
1683 let start = Instant::now();
1684 let waited = timeout(Duration::from_millis(200), self.barrier.wait()).await;
1687 let end = Instant::now();
1688 self.parallel_windows
1689 .lock()
1690 .expect("parallel windows")
1691 .push((name.to_string(), start, end));
1692 match waited {
1693 Ok(_) => ToolResult::ok(json!(name)),
1694 Err(_) => ToolResult::err_fmt(format!("{name} did not overlap with peer")),
1695 }
1696 }
1697 }
1698 }
1699
1700 let barrier = Arc::new(Barrier::new(2));
1701 let serial_window = Arc::new(std::sync::Mutex::new(None));
1702 let parallel_windows = Arc::new(std::sync::Mutex::new(Vec::new()));
1703 let (event_tx, _event_rx) = mpsc::channel(8);
1704 let provider = Arc::new(MixedTools {
1705 barrier: Arc::clone(&barrier),
1706 serial_window: Arc::clone(&serial_window),
1707 parallel_windows: Arc::clone(¶llel_windows),
1708 });
1709 let plugins = test_plugins(provider);
1710 let tools = plugins.tools();
1711 let surface = plugins.tool_surface("session", ExecutionMode::standard());
1712 let context = Arc::new(ToolDispatchContext {
1713 plugins,
1714 tools,
1715 surface,
1716 host: Arc::new(MockSessionManager::default()),
1717 session_id: "session".to_string(),
1718 event_tx,
1719 turn_injection_bridge: crate::TurnInjectionBridge::new(),
1720 attachment_store: Arc::new(crate::InMemoryAttachmentStore::new()),
1721 turn_context: crate::TurnContext::default(),
1722 });
1723
1724 let specs = vec![
1725 ParallelToolCallSpec {
1726 index: 0,
1727 tool_name: "par_a".to_string(),
1728 args: json!({}),
1729 },
1730 ParallelToolCallSpec {
1731 index: 1,
1732 tool_name: "ser".to_string(),
1733 args: json!({}),
1734 },
1735 ParallelToolCallSpec {
1736 index: 2,
1737 tool_name: "par_b".to_string(),
1738 args: json!({}),
1739 },
1740 ];
1741
1742 let outcomes = dispatch_parallel_tool_calls(context, specs, None).await;
1743
1744 assert_eq!(outcomes.len(), 3);
1745 assert!(
1746 outcomes
1747 .iter()
1748 .all(|outcome| outcome.record.output.is_success()),
1749 "all tools should succeed: {:?}",
1750 outcomes
1751 .iter()
1752 .map(|outcome| (&outcome.record.tool, outcome.record.output.is_success()))
1753 .collect::<Vec<_>>()
1754 );
1755
1756 let pw = parallel_windows.lock().expect("parallel windows");
1757 assert_eq!(pw.len(), 2);
1758 let sw = serial_window
1759 .lock()
1760 .expect("serial window")
1761 .expect("serial window recorded");
1762
1763 for (name, p_start, p_end) in pw.iter() {
1766 assert!(
1767 sw.0 >= *p_end || sw.1 <= *p_start,
1768 "serial window {:?} overlaps parallel window {} {:?}..{:?}",
1769 sw,
1770 name,
1771 p_start,
1772 p_end,
1773 );
1774 }
1775 }
1776}