Skip to main content

lash_core/
tool_dispatch.rs

1use std::future::Future;
2use std::sync::Arc;
3use std::time::Instant;
4
5use futures_util::stream::{FuturesUnordered, StreamExt};
6use tokio::sync::mpsc;
7
8use crate::plugin::{
9    PluginDirective, PluginSession, ToolCallHookContext, ToolHookHost, ToolResultHookContext,
10    emit_plugin_surface_events,
11};
12use crate::tool_executor::execute_tool_call;
13use crate::tool_schema::validate_tool_input;
14use crate::{
15    ProgressSender, SessionEvent, ToolCallRecord, ToolContext, ToolExecutionMode, ToolFailure,
16    ToolFailureClass, ToolManifest, ToolProvider, ToolResult, ToolSurface, TurnInjectionBridge,
17};
18
/// Everything the dispatcher needs to run tool calls for one session.
#[derive(Clone)]
pub struct ToolDispatchContext {
    /// Plugin session whose `before_tool_call` / `after_tool_call` hooks wrap
    /// each dispatch; also queried for mode-native tool providers, the
    /// execution mode, and tool access rules.
    pub plugins: Arc<PluginSession>,
    /// Fallback provider used to resolve manifests and contracts when no
    /// mode-native provider knows the tool.
    pub tools: Arc<dyn ToolProvider>,
    /// Tool surface searched first when resolving a tool's manifest and
    /// execution mode.
    pub surface: Arc<ToolSurface>,
    /// Host handed to hook and tool contexts; also used to create sessions
    /// and emit trace events requested by plugin directives.
    pub host: Arc<dyn ToolHookHost>,
    /// Session identifier passed to hook and tool contexts.
    pub session_id: String,
    /// Channel used when forwarding plugin-emitted surface events.
    pub event_tx: mpsc::Sender<SessionEvent>,
    /// Receives messages enqueued by `after_tool_call` directives.
    pub turn_injection_bridge: TurnInjectionBridge,
    /// Attachment store given to each `ToolContext` built by the dispatcher.
    pub attachment_store: Arc<dyn crate::AttachmentStore>,
    /// Turn context cloned into hook and tool contexts.
    pub turn_context: crate::TurnContext,
}
31
/// Result of dispatching a single tool call.
#[derive(Clone)]
pub(crate) struct ToolDispatchOutcome {
    /// Record of the call: tool name, final arguments, output, and duration.
    pub record: ToolCallRecord,
}
36
/// One tool call within a batch handed to the scheduler.
#[derive(Clone)]
pub struct ParallelToolCallSpec {
    /// Position of the call in the batch; outcomes are returned sorted by it.
    pub index: usize,
    /// Name of the tool to invoke.
    pub tool_name: String,
    /// JSON arguments for the call.
    pub args: serde_json::Value,
}
43
/// Outcome of one call from a batch, tagged with its originating index.
#[derive(Clone)]
pub struct ParallelToolCallOutcome {
    /// Copied from the `ParallelToolCallSpec::index` that produced this outcome.
    pub index: usize,
    /// Record of the dispatched call.
    pub record: ToolCallRecord,
}
49
50pub(crate) async fn dispatch_tool_call(
51    context: &ToolDispatchContext,
52    tool_name: String,
53    args: serde_json::Value,
54    progress: Option<&ProgressSender>,
55) -> ToolDispatchOutcome {
56    let tool_context = ToolContext::new(
57        context.session_id.clone(),
58        Arc::clone(&context.host),
59        context.turn_context.clone(),
60        Arc::clone(&context.attachment_store),
61        None,
62    );
63    dispatch_tool_call_with_execution_context(context, tool_name, args, progress, tool_context)
64        .await
65}
66
/// Run one tool call through the full dispatch pipeline and record the result.
///
/// Steps, in order:
/// 1. Resolve a callable manifest for `tool_name`; otherwise fail with
///    `tool_unavailable`.
/// 2. Run the plugins' `before_tool_call` hook and apply its directives, which
///    may replace the arguments or short-circuit the call.
/// 3. Resolve the tool contract (mode-native providers first, then
///    `context.tools`) and validate the arguments against it.
/// 4. Execute the tool, measuring wall-clock duration in milliseconds.
/// 5. Run the plugins' `after_tool_call` hook and apply its directives, which
///    may replace the result.
///
/// Every early return (steps 1-3) records a duration of `0`. The returned
/// record always carries the *final* arguments, i.e. after any
/// `ReplaceToolArgs` directive has been applied.
pub(crate) async fn dispatch_tool_call_with_execution_context(
    context: &ToolDispatchContext,
    tool_name: String,
    args: serde_json::Value,
    progress: Option<&ProgressSender>,
    tool_context: ToolContext,
) -> ToolDispatchOutcome {
    // Step 1: the tool must be known and callable before any hook runs.
    let Some(manifest) = resolve_callable_manifest(context, &tool_name) else {
        return outcome(
            tool_name,
            args,
            runtime_failure(
                ToolFailureClass::Unavailable,
                "tool_unavailable",
                "Tool is unavailable in this session",
            ),
            0,
        );
    };
    // Rebound as mutable so a `ReplaceToolArgs` directive can overwrite it.
    let mut args = args;

    // Step 2: ask plugins for directives; a hook error aborts the dispatch.
    let directives = match context
        .plugins
        .before_tool_call(ToolCallHookContext::new(
            context.session_id.clone(),
            tool_name.clone(),
            args.clone(),
            context.turn_context.clone(),
            Arc::clone(&context.host),
        ))
        .await
    {
        Ok(directives) => directives,
        Err(err) => {
            return outcome(
                tool_name,
                args,
                runtime_failure(
                    ToolFailureClass::Internal,
                    "before_tool_call_failed",
                    err.to_string(),
                ),
                0,
            );
        }
    };

    // When set, the tool is not executed and this result is recorded instead.
    // NOTE(review): only some arms `break` (failed session creation, handoff,
    // failed trace emission). `ShortCircuitTool`, `AbortTurn` and
    // `EnqueueMessages` keep iterating, so a later directive can overwrite the
    // stored result or still replace `args` — confirm this is intended.
    let mut short_circuit: Option<ToolResult> = None;
    for emitted in directives {
        let plugin_id = emitted.plugin_id;
        let directive = emitted.value;
        match directive {
            PluginDirective::CreateSession { request } => {
                if let Err(err) = context.host.create_session(*request).await {
                    short_circuit = Some(ToolResult::err_fmt(err.to_string()));
                    break;
                }
            }
            // Session handoff is rejected in the before-hook.
            PluginDirective::HandoffSession { .. } => {
                short_circuit = Some(ToolResult::err_fmt(
                    "before_tool_call does not support session handoff",
                ));
                break;
            }
            PluginDirective::ReplaceToolArgs { args: replacement } => {
                args = replacement;
            }
            PluginDirective::ShortCircuitTool { output } => {
                short_circuit = Some(ToolResult::from_output(output));
            }
            PluginDirective::AbortTurn { message, .. } => {
                short_circuit = Some(ToolResult::err_fmt(message));
            }
            PluginDirective::EmitEvents { events } => {
                emit_plugin_surface_events(&context.event_tx, &plugin_id, events).await;
            }
            PluginDirective::EmitTrace {
                name,
                payload,
                context: trace_context,
            } => {
                // Trace names are namespaced as `plugin.<plugin_id>.<name>`.
                if let Err(err) = context
                    .host
                    .emit_trace_event(
                        *trace_context,
                        lash_trace::TraceEvent::Custom {
                            name: format!("plugin.{plugin_id}.{name}"),
                            payload,
                        },
                    )
                    .await
                {
                    short_circuit = Some(ToolResult::err_fmt(err.to_string()));
                    break;
                }
            }
            // Message injection is rejected in the before-hook (no `break`).
            PluginDirective::EnqueueMessages { .. } => {
                short_circuit = Some(ToolResult::err_fmt(
                    "before_tool_call does not support message injection",
                ));
            }
        }
    }
    if let Some(result) = short_circuit {
        return outcome(tool_name, args, result, 0);
    }

    // Step 3: the contract is resolved only now, after the hooks have run and
    // only if the call was not short-circuited.
    let contract = context
        .plugins
        .mode_native_tools()
        .iter()
        .find_map(|provider| provider.resolve_contract(&tool_name))
        .or_else(|| context.tools.resolve_contract(&tool_name));
    let Some(contract) = contract else {
        return outcome(
            tool_name,
            args,
            runtime_failure(
                ToolFailureClass::Unavailable,
                "tool_contract_unavailable",
                "Tool contract is unavailable in this session",
            ),
            0,
        );
    };
    // Validation sees the possibly replaced arguments.
    if let Err(err) = validate_tool_input(&contract, &args) {
        return outcome(
            tool_name,
            args,
            runtime_failure(ToolFailureClass::InvalidRequest, "invalid_tool_args", err),
            0,
        );
    }

    // Step 4: execute and time the call. The timer covers only
    // `execute_tool_call`, not the hooks or validation.
    let tool_start = Instant::now();
    let result = execute_tool_call(
        context,
        &manifest,
        &tool_name,
        &args,
        progress,
        tool_context,
    )
    .await;
    let duration_ms = tool_start.elapsed().as_millis() as u64;

    // Step 5: let plugins observe and optionally rewrite the result.
    let result = match context
        .plugins
        .after_tool_call(ToolResultHookContext::new(
            context.session_id.clone(),
            tool_name.clone(),
            args.clone(),
            result.clone(),
            duration_ms,
            context.turn_context.clone(),
            Arc::clone(&context.host),
        ))
        .await
    {
        Ok(directives) => {
            // Starts as the tool's own result; directives may overwrite it.
            // NOTE(review): as in the before-hook, `ShortCircuitTool`,
            // `AbortTurn` and `ReplaceToolArgs` do not `break`, so later
            // directives are still processed — confirm this is intended.
            let mut final_result = result;
            for emitted in directives {
                let plugin_id = emitted.plugin_id;
                let directive = emitted.value;
                match directive {
                    PluginDirective::CreateSession { request } => {
                        if let Err(err) = context.host.create_session(*request).await {
                            final_result = ToolResult::failure(ToolFailure::runtime(
                                ToolFailureClass::Internal,
                                "plugin_session_create_failed",
                                err.to_string(),
                            ));
                            break;
                        }
                    }
                    PluginDirective::HandoffSession { .. } => {
                        final_result =
                            ToolResult::err_fmt("after_tool_call does not support session handoff");
                        break;
                    }
                    PluginDirective::ShortCircuitTool { output } => {
                        final_result = ToolResult::from_output(output);
                    }
                    PluginDirective::AbortTurn { message, .. } => {
                        final_result = ToolResult::err_fmt(message);
                    }
                    PluginDirective::EmitEvents { events } => {
                        emit_plugin_surface_events(&context.event_tx, &plugin_id, events).await;
                    }
                    PluginDirective::EmitTrace {
                        name,
                        payload,
                        context: trace_context,
                    } => {
                        if let Err(err) = context
                            .host
                            .emit_trace_event(
                                *trace_context,
                                lash_trace::TraceEvent::Custom {
                                    name: format!("plugin.{plugin_id}.{name}"),
                                    payload,
                                },
                            )
                            .await
                        {
                            final_result = ToolResult::err_fmt(err.to_string());
                            break;
                        }
                    }
                    // Unlike the before-hook, message injection is accepted here.
                    PluginDirective::EnqueueMessages { messages } => {
                        if let Err(err) = context.turn_injection_bridge.enqueue(messages) {
                            final_result = ToolResult::err_fmt(err);
                            break;
                        }
                    }
                    // The tool has already run, so replacing arguments is an error.
                    PluginDirective::ReplaceToolArgs { .. } => {
                        final_result = ToolResult::err_fmt(
                            "after_tool_call only supports abort, short-circuit, session creation, events, and message injection",
                        );
                    }
                }
            }
            final_result
        }
        // A failing after-hook replaces the tool's result with an internal failure.
        Err(err) => runtime_failure(
            ToolFailureClass::Internal,
            "after_tool_call_failed",
            err.to_string(),
        ),
    };

    outcome(tool_name, args, result, duration_ms)
}
300
301fn resolve_callable_manifest(
302    context: &ToolDispatchContext,
303    tool_name: &str,
304) -> Option<ToolManifest> {
305    if let Some(entry) = context
306        .surface
307        .tools
308        .iter()
309        .find(|tool| tool.manifest.name == tool_name)
310    {
311        return entry
312            .availability
313            .is_callable()
314            .then(|| entry.manifest.clone());
315    }
316
317    let mode = context.plugins.execution_mode();
318    let visible_and_callable = |manifest: ToolManifest| {
319        if context.plugins.tool_access().hides(&manifest.name) {
320            return None;
321        }
322        manifest
323            .effective_availability(&mode)
324            .is_callable()
325            .then_some(manifest)
326    };
327
328    for provider in context.plugins.mode_native_tools() {
329        if let Some(manifest) = provider
330            .resolve_manifest(tool_name)
331            .and_then(&visible_and_callable)
332        {
333            return Some(manifest);
334        }
335    }
336
337    context
338        .tools
339        .resolve_manifest(tool_name)
340        .and_then(visible_and_callable)
341}
342
343pub(crate) async fn dispatch_parallel_tool_call(
344    context: Arc<ToolDispatchContext>,
345    spec: ParallelToolCallSpec,
346    progress: Option<ProgressSender>,
347) -> ParallelToolCallOutcome {
348    let outcome = dispatch_tool_call(&context, spec.tool_name, spec.args, progress.as_ref()).await;
349    ParallelToolCallOutcome {
350        index: spec.index,
351        record: outcome.record,
352    }
353}
354
355/// Resolve the [`ToolExecutionMode`] declared on a tool's definition. Unknown
356/// tool names default to [`ToolExecutionMode::Parallel`] — the dispatcher
357/// will still surface an "unknown tool" error via the normal path.
358pub(crate) fn resolve_tool_execution_mode(
359    context: &ToolDispatchContext,
360    tool_name: &str,
361) -> ToolExecutionMode {
362    context
363        .surface
364        .tools
365        .iter()
366        .find(|def| def.manifest.name == tool_name)
367        .map(|def| def.manifest.execution_mode)
368        .unwrap_or_default()
369}
370
371/// Schedule a batch using Lash's tool execution policy.
372///
373/// Parallel-safe tools run concurrently first, then serial tools run
374/// one-at-a-time in original index order. Returned outputs are sorted by the
375/// same original index so callers keep their source/model ordering.
376pub(crate) async fn schedule_tool_batch<T, O, IndexOf, ModeOf, Run, Fut>(
377    items: Vec<T>,
378    index_of: IndexOf,
379    mode_of: ModeOf,
380    run: Run,
381) -> Vec<O>
382where
383    T: Send + 'static,
384    O: Send + 'static,
385    IndexOf: Fn(&T) -> usize,
386    ModeOf: Fn(&T) -> ToolExecutionMode,
387    Run: Fn(T) -> Fut,
388    Fut: Future<Output = O> + Send,
389{
390    let mut parallel_items = Vec::new();
391    let mut serial_items = Vec::new();
392    for item in items {
393        let index = index_of(&item);
394        match mode_of(&item) {
395            ToolExecutionMode::Parallel => parallel_items.push((index, item)),
396            ToolExecutionMode::Serial => serial_items.push((index, item)),
397        }
398    }
399
400    let mut outcomes = Vec::new();
401
402    let mut pending = FuturesUnordered::new();
403    for (index, item) in parallel_items {
404        let future = run(item);
405        pending.push(async move { (index, future.await) });
406    }
407    while let Some(outcome) = pending.next().await {
408        outcomes.push(outcome);
409    }
410
411    serial_items.sort_by_key(|(index, _)| *index);
412    for (index, item) in serial_items {
413        outcomes.push((index, run(item).await));
414    }
415
416    outcomes.sort_by_key(|(index, _)| *index);
417    outcomes.into_iter().map(|(_, outcome)| outcome).collect()
418}
419
420/// Dispatch a batch of tool calls produced by one model response.
421pub async fn dispatch_parallel_tool_calls(
422    context: Arc<ToolDispatchContext>,
423    specs: Vec<ParallelToolCallSpec>,
424    progress: Option<&ProgressSender>,
425) -> Vec<ParallelToolCallOutcome> {
426    let progress = progress.cloned();
427    schedule_tool_batch(
428        specs,
429        |spec| spec.index,
430        {
431            let context = Arc::clone(&context);
432            move |spec| resolve_tool_execution_mode(&context, &spec.tool_name)
433        },
434        move |spec| dispatch_parallel_tool_call(Arc::clone(&context), spec, progress.clone()),
435    )
436    .await
437}
438
439fn outcome(
440    tool_name: String,
441    args: serde_json::Value,
442    result: ToolResult,
443    duration_ms: u64,
444) -> ToolDispatchOutcome {
445    let record = ToolCallRecord {
446        call_id: None,
447        tool: tool_name,
448        args,
449        output: *result.output,
450        duration_ms,
451    };
452    ToolDispatchOutcome { record }
453}
454
455fn runtime_failure(
456    class: ToolFailureClass,
457    code: impl Into<String>,
458    message: impl Into<String>,
459) -> ToolResult {
460    ToolResult::failure(ToolFailure::runtime(class, code, message))
461}
462
463#[cfg(test)]
464mod tests {
465    use super::*;
466    use crate::plugin::{PluginHost, StaticPluginFactory};
467    use crate::{
468        ExecutionMode, ToolCall, ToolCallOutcome, ToolProvider, ToolRetryDisposition,
469        ToolRetryPolicy,
470    };
471    use serde_json::json;
472    use std::collections::BTreeMap;
473    use std::sync::atomic::{AtomicUsize, Ordering};
474    use tokio::sync::Barrier;
475    use tokio::time::{Duration, timeout};
476
    /// `(probe name, start instant, end instant)` recorded for one scheduled run.
    type ExecutionWindow = (&'static str, Instant, Instant);
    /// Mutex-guarded list of execution windows, shared between probe futures.
    type SharedExecutionWindows = Arc<std::sync::Mutex<Vec<ExecutionWindow>>>;
    /// `(attempt_number, max_attempts, idempotency_key)` as seen by a tool call.
    type AttemptObservation = (u32, u32, Option<String>);
    /// Mutex-guarded list of attempt observations.
    type SharedAttemptObservations = Arc<std::sync::Mutex<Vec<AttemptObservation>>>;
481
482    fn test_tool(name: &str, execution_mode: ToolExecutionMode) -> crate::ToolDefinition {
483        crate::ToolDefinition::raw(
484            name,
485            "",
486            crate::ToolDefinition::default_input_schema(),
487            json!({ "type": "string" }),
488        )
489        .with_execution_mode(execution_mode)
490    }
491
492    fn beta_tool() -> crate::ToolDefinition {
493        crate::ToolDefinition::raw(
494            "beta",
495            "",
496            json!({
497                "type": "object",
498                "properties": {
499                    "value": { "type": "string" }
500                },
501                "required": ["value"],
502                "additionalProperties": false
503            }),
504            json!({ "type": "string" }),
505        )
506        .with_execution_mode(ToolExecutionMode::Parallel)
507    }
508
509    fn named_beta_tool(name: &str) -> crate::ToolDefinition {
510        crate::ToolDefinition::raw(
511            name,
512            "",
513            json!({
514                "type": "object",
515                "properties": {
516                    "value": { "type": "string" }
517                },
518                "required": ["value"],
519                "additionalProperties": false
520            }),
521            json!({ "type": "string" }),
522        )
523        .with_execution_mode(ToolExecutionMode::Parallel)
524    }
525
526    fn manifests(definitions: Vec<crate::ToolDefinition>) -> Vec<crate::ToolManifest> {
527        definitions
528            .into_iter()
529            .map(|tool| tool.manifest())
530            .collect()
531    }
532
533    fn contract_from(
534        definitions: Vec<crate::ToolDefinition>,
535        name: &str,
536    ) -> Option<Arc<crate::ToolContract>> {
537        definitions
538            .into_iter()
539            .find(|tool| tool.name == name)
540            .map(|tool| Arc::new(tool.contract()))
541    }
542
    /// Fake work item fed to `schedule_tool_batch` in the scheduler test.
    #[derive(Clone)]
    struct ScheduledProbe {
        /// Original position in the batch.
        index: usize,
        /// Label recorded in the execution window and returned as the output.
        name: &'static str,
        /// Bucket the scheduler should place this probe in.
        mode: ToolExecutionMode,
        /// How long the probe sleeps to simulate work.
        delay: Duration,
    }
550
551    #[tokio::test]
552    async fn scheduler_runs_parallel_bucket_then_serial_and_preserves_order() {
553        let windows: SharedExecutionWindows = Arc::new(std::sync::Mutex::new(Vec::new()));
554        let probes = vec![
555            ScheduledProbe {
556                index: 0,
557                name: "parallel_slow",
558                mode: ToolExecutionMode::Parallel,
559                delay: Duration::from_millis(40),
560            },
561            ScheduledProbe {
562                index: 1,
563                name: "serial",
564                mode: ToolExecutionMode::Serial,
565                delay: Duration::from_millis(10),
566            },
567            ScheduledProbe {
568                index: 2,
569                name: "parallel_fast",
570                mode: ToolExecutionMode::Parallel,
571                delay: Duration::from_millis(5),
572            },
573        ];
574
575        let outputs = schedule_tool_batch(probes, |probe| probe.index, |probe| probe.mode, {
576            let windows = Arc::clone(&windows);
577            move |probe| {
578                let windows = Arc::clone(&windows);
579                async move {
580                    let start = Instant::now();
581                    tokio::time::sleep(probe.delay).await;
582                    let end = Instant::now();
583                    windows
584                        .lock()
585                        .expect("windows")
586                        .push((probe.name, start, end));
587                    probe.name
588                }
589            }
590        })
591        .await;
592
593        assert_eq!(outputs, ["parallel_slow", "serial", "parallel_fast"]);
594
595        let recorded = windows.lock().expect("windows").clone();
596        let parallel_slow = recorded
597            .iter()
598            .find(|(name, _, _)| *name == "parallel_slow")
599            .expect("parallel_slow");
600        let parallel_fast = recorded
601            .iter()
602            .find(|(name, _, _)| *name == "parallel_fast")
603            .expect("parallel_fast");
604        let serial = recorded
605            .iter()
606            .find(|(name, _, _)| *name == "serial")
607            .expect("serial");
608
609        assert!(
610            parallel_fast.1 < parallel_slow.2,
611            "parallel tools should overlap even when completion order differs"
612        );
613        assert!(
614            serial.1 >= parallel_slow.2 && serial.1 >= parallel_fast.2,
615            "serial tool should start after the parallel bucket completes"
616        );
617    }
618
619    struct MockTools;
620
621    #[async_trait::async_trait]
622    impl ToolProvider for MockTools {
623        fn tool_manifests(&self) -> Vec<crate::ToolManifest> {
624            manifests(vec![
625                test_tool("alpha", ToolExecutionMode::Parallel),
626                beta_tool(),
627            ])
628        }
629
630        fn resolve_contract(&self, name: &str) -> Option<Arc<crate::ToolContract>> {
631            contract_from(
632                vec![test_tool("alpha", ToolExecutionMode::Parallel), beta_tool()],
633                name,
634            )
635        }
636
637        async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
638            match call.name {
639                "alpha" => ToolResult::ok(json!("alpha")),
640                "beta" => {
641                    if call.args.get("value").and_then(|value| value.as_str()) == Some("fail") {
642                        ToolResult::err_fmt("beta failed")
643                    } else {
644                        ToolResult::ok(json!(
645                            call.args.get("value").cloned().unwrap_or(json!(null))
646                        ))
647                    }
648                }
649                other => ToolResult::err_fmt(format!("Unknown tool: {other}")),
650            }
651        }
652    }
653
654    struct ParallelProbeTools {
655        barrier: Arc<Barrier>,
656        started: Arc<AtomicUsize>,
657    }
658
659    #[async_trait::async_trait]
660    impl ToolProvider for ParallelProbeTools {
661        fn tool_manifests(&self) -> Vec<crate::ToolManifest> {
662            manifests(vec![
663                test_tool("probe_a", ToolExecutionMode::Parallel),
664                test_tool("probe_b", ToolExecutionMode::Parallel),
665            ])
666        }
667
668        fn resolve_contract(&self, name: &str) -> Option<Arc<crate::ToolContract>> {
669            contract_from(
670                vec![
671                    test_tool("probe_a", ToolExecutionMode::Parallel),
672                    test_tool("probe_b", ToolExecutionMode::Parallel),
673                ],
674                name,
675            )
676        }
677
678        async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
679            self.started.fetch_add(1, Ordering::SeqCst);
680            let waited = timeout(Duration::from_millis(100), self.barrier.wait()).await;
681            match waited {
682                Ok(_) => ToolResult::ok(json!(call.name)),
683                Err(_) => ToolResult::err_fmt(format!("{} did not overlap with peer", call.name)),
684            }
685        }
686    }
687
688    struct StrictMcpTools {
689        executed: Arc<AtomicUsize>,
690    }
691
692    #[async_trait::async_trait]
693    impl ToolProvider for StrictMcpTools {
694        fn tool_manifests(&self) -> Vec<crate::ToolManifest> {
695            manifests(vec![strict_mcp_tool_definition()])
696        }
697
698        fn resolve_contract(&self, name: &str) -> Option<Arc<crate::ToolContract>> {
699            (name == "mcp__appworld__venmo_show_transactions")
700                .then(|| Arc::new(strict_mcp_tool_definition().contract()))
701        }
702
703        async fn execute(&self, _call: ToolCall<'_>) -> ToolResult {
704            self.executed.fetch_add(1, Ordering::SeqCst);
705            ToolResult::ok(json!({ "executed": true }))
706        }
707    }
708
709    fn strict_mcp_tool_definition() -> crate::ToolDefinition {
710        crate::ToolDefinition::raw(
711            "mcp__appworld__venmo_show_transactions",
712            "Show Venmo transactions",
713            json!({
714                "type": "object",
715                "properties": {
716                    "min_created_at": { "type": "string" },
717                    "max_created_at": { "type": "string" },
718                    "limit": { "type": "integer", "maximum": 100 }
719                },
720                "required": ["limit"]
721            }),
722            json!({ "type": "object", "additionalProperties": true }),
723        )
724    }
725
726    fn strict_mcp_dispatch_context(executed: Arc<AtomicUsize>) -> ToolDispatchContext {
727        let (event_tx, _event_rx) = mpsc::channel(8);
728        let plugins = test_plugins(Arc::new(StrictMcpTools { executed }));
729        let tools = plugins.tools();
730        let surface = plugins.tool_surface("session", ExecutionMode::standard());
731        ToolDispatchContext {
732            plugins,
733            tools,
734            surface,
735            host: Arc::new(MockSessionManager::default()),
736            session_id: "session".to_string(),
737            event_tx,
738            turn_injection_bridge: crate::TurnInjectionBridge::new(),
739            attachment_store: Arc::new(crate::InMemoryAttachmentStore::new()),
740            turn_context: crate::TurnContext::default(),
741        }
742    }
743
744    fn test_plugins(provider: Arc<dyn ToolProvider>) -> Arc<PluginSession> {
745        PluginHost::new(vec![Arc::new(StaticPluginFactory::new(
746            "test_tools",
747            crate::PluginSpec::new().with_tool_provider(Arc::clone(&provider)),
748        ))])
749        .build_standard_session("root", None)
750        .expect("plugin session")
751    }
752
753    use crate::testing::MockSessionManager;
754
755    fn dispatch_context() -> ToolDispatchContext {
756        let (event_tx, _event_rx) = mpsc::channel(8);
757        let plugins = test_plugins(Arc::new(MockTools));
758        let tools = plugins.tools();
759        let surface = plugins.tool_surface("session", ExecutionMode::standard());
760        ToolDispatchContext {
761            plugins,
762            tools,
763            surface,
764            host: Arc::new(MockSessionManager::default()),
765            session_id: "session".to_string(),
766            event_tx,
767            turn_injection_bridge: crate::TurnInjectionBridge::new(),
768            attachment_store: Arc::new(crate::InMemoryAttachmentStore::new()),
769            turn_context: crate::TurnContext::default(),
770        }
771    }
772
    /// Provider for the `beta` tool that counts contract lookups and executions.
    struct CountingContractTools {
        /// Incremented on every `resolve_contract` call, whatever the name.
        contracts_resolved: Arc<AtomicUsize>,
        /// Incremented on every `execute` call.
        executed: Arc<AtomicUsize>,
    }
777
    /// Provider that lists no manifests but resolves `host_only` by exact name.
    struct ExactDispatchTools {
        /// Incremented on every `resolve_contract` call, whatever the name.
        contracts_resolved: Arc<AtomicUsize>,
        /// Incremented on every `execute` call.
        executed: Arc<AtomicUsize>,
        /// When `false`, `resolve_contract` returns `None` even for `host_only`.
        contract_available: bool,
    }
783
    /// Provider whose only tool, `hidden`, is declared with availability off.
    struct HiddenDispatchTools {
        /// Incremented on every `resolve_contract` call, whatever the name.
        contracts_resolved: Arc<AtomicUsize>,
        /// Incremented on every `execute` call.
        executed: Arc<AtomicUsize>,
    }
788
    /// Single-tool provider used to observe retry behaviour.
    struct RetryProbeTools {
        /// The one tool this provider exposes.
        definition: crate::ToolDefinition,
        /// Number of `execute` calls made so far.
        attempts: Arc<AtomicUsize>,
        /// 1-based attempt count from which `execute` starts returning success.
        successes_after: usize,
        /// When set, `execute` returns a cancelled result instead of
        /// succeeding or failing.
        cancel_on_first: bool,
        /// `(attempt_number, max_attempts, idempotency_key)` seen on each call.
        observed_attempts: SharedAttemptObservations,
    }
796
797    #[async_trait::async_trait]
798    impl ToolProvider for CountingContractTools {
799        fn tool_manifests(&self) -> Vec<crate::ToolManifest> {
800            manifests(vec![beta_tool()])
801        }
802
803        fn resolve_contract(&self, name: &str) -> Option<Arc<crate::ToolContract>> {
804            self.contracts_resolved.fetch_add(1, Ordering::SeqCst);
805            (name == "beta").then(|| Arc::new(beta_tool().contract()))
806        }
807
808        async fn execute(&self, _call: ToolCall<'_>) -> ToolResult {
809            self.executed.fetch_add(1, Ordering::SeqCst);
810            ToolResult::ok(json!("ok"))
811        }
812    }
813
814    #[async_trait::async_trait]
815    impl ToolProvider for ExactDispatchTools {
816        fn tool_manifests(&self) -> Vec<crate::ToolManifest> {
817            Vec::new()
818        }
819
820        fn resolve_manifest(&self, name: &str) -> Option<crate::ToolManifest> {
821            (name == "host_only").then(|| named_beta_tool("host_only").manifest())
822        }
823
824        fn resolve_contract(&self, name: &str) -> Option<Arc<crate::ToolContract>> {
825            self.contracts_resolved.fetch_add(1, Ordering::SeqCst);
826            (self.contract_available && name == "host_only")
827                .then(|| Arc::new(named_beta_tool("host_only").contract()))
828        }
829
830        async fn execute(&self, _call: ToolCall<'_>) -> ToolResult {
831            self.executed.fetch_add(1, Ordering::SeqCst);
832            ToolResult::ok(json!("host"))
833        }
834    }
835
836    #[async_trait::async_trait]
837    impl ToolProvider for HiddenDispatchTools {
838        fn tool_manifests(&self) -> Vec<crate::ToolManifest> {
839            manifests(vec![
840                named_beta_tool("hidden").with_availability(crate::ToolAvailabilityConfig::off()),
841            ])
842        }
843
844        fn resolve_contract(&self, name: &str) -> Option<Arc<crate::ToolContract>> {
845            self.contracts_resolved.fetch_add(1, Ordering::SeqCst);
846            (name == "hidden").then(|| Arc::new(named_beta_tool("hidden").contract()))
847        }
848
849        async fn execute(&self, _call: ToolCall<'_>) -> ToolResult {
850            self.executed.fetch_add(1, Ordering::SeqCst);
851            ToolResult::ok(json!("hidden"))
852        }
853    }
854
855    #[async_trait::async_trait]
856    impl ToolProvider for RetryProbeTools {
857        fn tool_manifests(&self) -> Vec<crate::ToolManifest> {
858            manifests(vec![self.definition.clone()])
859        }
860
861        fn resolve_contract(&self, name: &str) -> Option<Arc<crate::ToolContract>> {
862            (name == self.definition.name).then(|| Arc::new(self.definition.contract()))
863        }
864
865        async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
866            self.observed_attempts.lock().expect("attempts").push((
867                call.context.attempt_number(),
868                call.context.max_attempts(),
869                call.context.idempotency_key().map(str::to_string),
870            ));
871            let attempt_index = self.attempts.fetch_add(1, Ordering::SeqCst) + 1;
872            if self.cancel_on_first {
873                return ToolResult::cancelled("cancelled");
874            }
875            if attempt_index >= self.successes_after {
876                return ToolResult::ok(json!({ "attempt": attempt_index }));
877            }
878            ToolResult::retryable_failure(
879                crate::ToolFailureClass::External,
880                "transient",
881                "transient failure",
882                Some(0),
883            )
884        }
885    }
886
887    fn lazy_contract_dispatch_context(
888        contracts_resolved: Arc<AtomicUsize>,
889        executed: Arc<AtomicUsize>,
890    ) -> ToolDispatchContext {
891        let (event_tx, _event_rx) = mpsc::channel(8);
892        let provider: Arc<dyn ToolProvider> = Arc::new(CountingContractTools {
893            contracts_resolved,
894            executed,
895        });
896        let tools = Arc::clone(&provider);
897        let surface = Arc::new(crate::ToolSurface::from_tools(
898            provider.tool_manifests(),
899            ExecutionMode::standard(),
900            BTreeMap::new(),
901        ));
902        ToolDispatchContext {
903            plugins: test_plugins(provider),
904            tools,
905            surface,
906            host: Arc::new(MockSessionManager::default()),
907            session_id: "session".to_string(),
908            event_tx,
909            turn_injection_bridge: crate::TurnInjectionBridge::new(),
910            attachment_store: Arc::new(crate::InMemoryAttachmentStore::new()),
911            turn_context: crate::TurnContext::default(),
912        }
913    }
914
915    fn exact_dispatch_context(provider: Arc<dyn ToolProvider>) -> ToolDispatchContext {
916        let (event_tx, _event_rx) = mpsc::channel(8);
917        let plugins = test_plugins(Arc::clone(&provider));
918        let tools = plugins.tools();
919        let surface = plugins.tool_surface("session", ExecutionMode::standard());
920        ToolDispatchContext {
921            plugins,
922            tools,
923            surface,
924            host: Arc::new(MockSessionManager::default()),
925            session_id: "session".to_string(),
926            event_tx,
927            turn_injection_bridge: crate::TurnInjectionBridge::new(),
928            attachment_store: Arc::new(crate::InMemoryAttachmentStore::new()),
929            turn_context: crate::TurnContext::default(),
930        }
931    }
932
933    fn retry_tool(name: &str, retry_policy: ToolRetryPolicy) -> crate::ToolDefinition {
934        named_beta_tool(name)
935            .with_execution_mode(ToolExecutionMode::Parallel)
936            .with_retry_policy(retry_policy)
937    }
938
939    fn retry_dispatch_context(
940        retry_policy: ToolRetryPolicy,
941        attempts: Arc<AtomicUsize>,
942        successes_after: usize,
943        cancel_on_first: bool,
944        observed_attempts: SharedAttemptObservations,
945    ) -> ToolDispatchContext {
946        exact_dispatch_context(Arc::new(RetryProbeTools {
947            definition: retry_tool("retry_probe", retry_policy),
948            attempts,
949            successes_after,
950            cancel_on_first,
951            observed_attempts,
952        }))
953    }
954
955    fn parallel_dispatch_context(
956        barrier: Arc<Barrier>,
957        started: Arc<AtomicUsize>,
958    ) -> ToolDispatchContext {
959        let (event_tx, _event_rx) = mpsc::channel(8);
960        let plugins = test_plugins(Arc::new(ParallelProbeTools { barrier, started }));
961        let tools = plugins.tools();
962        let surface = plugins.tool_surface("session", ExecutionMode::standard());
963        ToolDispatchContext {
964            plugins,
965            tools,
966            surface,
967            host: Arc::new(MockSessionManager::default()),
968            session_id: "session".to_string(),
969            event_tx,
970            turn_injection_bridge: crate::TurnInjectionBridge::new(),
971            attachment_store: Arc::new(crate::InMemoryAttachmentStore::new()),
972            turn_context: crate::TurnContext::default(),
973        }
974    }
975
976    #[tokio::test]
977    async fn dispatch_rejects_invalid_args_before_provider_execution() {
978        let outcome =
979            dispatch_tool_call(&dispatch_context(), "beta".to_string(), json!({}), None).await;
980
981        assert!(!outcome.record.output.is_success());
982        assert_eq!(
983            outcome.record.output.value_for_projection()["message"],
984            json!("value: required property missing")
985        );
986    }
987
988    #[tokio::test]
989    async fn dispatch_resolves_contract_only_for_called_tool_before_execution() {
990        let contracts_resolved = Arc::new(AtomicUsize::new(0));
991        let executed = Arc::new(AtomicUsize::new(0));
992        let outcome = dispatch_tool_call(
993            &lazy_contract_dispatch_context(Arc::clone(&contracts_resolved), Arc::clone(&executed)),
994            "beta".to_string(),
995            json!({ "value": "ok" }),
996            None,
997        )
998        .await;
999
1000        assert!(outcome.record.output.is_success());
1001        assert_eq!(contracts_resolved.load(Ordering::SeqCst), 1);
1002        assert_eq!(executed.load(Ordering::SeqCst), 1);
1003    }
1004
1005    #[tokio::test]
1006    async fn dispatch_exact_resolves_missing_surface_tool_and_executes_owner() {
1007        let contracts_resolved = Arc::new(AtomicUsize::new(0));
1008        let executed = Arc::new(AtomicUsize::new(0));
1009        let provider: Arc<dyn ToolProvider> = Arc::new(ExactDispatchTools {
1010            contracts_resolved: Arc::clone(&contracts_resolved),
1011            executed: Arc::clone(&executed),
1012            contract_available: true,
1013        });
1014        let outcome = dispatch_tool_call(
1015            &exact_dispatch_context(provider),
1016            "host_only".to_string(),
1017            json!({ "value": "ok" }),
1018            None,
1019        )
1020        .await;
1021
1022        assert!(outcome.record.output.is_success());
1023        assert_eq!(outcome.record.output.value_for_projection(), json!("host"));
1024        assert_eq!(contracts_resolved.load(Ordering::SeqCst), 1);
1025        assert_eq!(executed.load(Ordering::SeqCst), 1);
1026    }
1027
1028    #[tokio::test]
1029    async fn dispatch_contract_unavailable_skips_execution() {
1030        let contracts_resolved = Arc::new(AtomicUsize::new(0));
1031        let executed = Arc::new(AtomicUsize::new(0));
1032        let provider: Arc<dyn ToolProvider> = Arc::new(ExactDispatchTools {
1033            contracts_resolved: Arc::clone(&contracts_resolved),
1034            executed: Arc::clone(&executed),
1035            contract_available: false,
1036        });
1037        let outcome = dispatch_tool_call(
1038            &exact_dispatch_context(provider),
1039            "host_only".to_string(),
1040            json!({ "value": "ok" }),
1041            None,
1042        )
1043        .await;
1044
1045        assert!(!outcome.record.output.is_success());
1046        assert_eq!(
1047            outcome.record.output.value_for_projection()["message"],
1048            json!("Tool contract is unavailable in this session")
1049        );
1050        assert_eq!(contracts_resolved.load(Ordering::SeqCst), 1);
1051        assert_eq!(executed.load(Ordering::SeqCst), 0);
1052    }
1053
1054    #[tokio::test]
1055    async fn dispatch_rejects_hidden_tool_before_contract_resolution() {
1056        let contracts_resolved = Arc::new(AtomicUsize::new(0));
1057        let executed = Arc::new(AtomicUsize::new(0));
1058        let provider: Arc<dyn ToolProvider> = Arc::new(HiddenDispatchTools {
1059            contracts_resolved: Arc::clone(&contracts_resolved),
1060            executed: Arc::clone(&executed),
1061        });
1062        let outcome = dispatch_tool_call(
1063            &exact_dispatch_context(provider),
1064            "hidden".to_string(),
1065            json!({ "value": "ok" }),
1066            None,
1067        )
1068        .await;
1069
1070        assert!(!outcome.record.output.is_success());
1071        assert_eq!(
1072            outcome.record.output.value_for_projection()["message"],
1073            json!("Tool is unavailable in this session")
1074        );
1075        assert_eq!(contracts_resolved.load(Ordering::SeqCst), 0);
1076        assert_eq!(executed.load(Ordering::SeqCst), 0);
1077    }
1078
1079    #[tokio::test]
1080    async fn dispatch_rejects_unknown_mcp_args_before_provider_execution() {
1081        let executed = Arc::new(AtomicUsize::new(0));
1082        let outcome = dispatch_tool_call(
1083            &strict_mcp_dispatch_context(Arc::clone(&executed)),
1084            "mcp__appworld__venmo_show_transactions".to_string(),
1085            json!({
1086                "min_datetime": "2024-01-01T00:00:00Z",
1087                "limit": 20
1088            }),
1089            None,
1090        )
1091        .await;
1092
1093        assert!(!outcome.record.output.is_success());
1094        assert_eq!(
1095            outcome.record.output.value_for_projection()["message"],
1096            json!("min_datetime: unexpected property")
1097        );
1098        assert_eq!(executed.load(Ordering::SeqCst), 0);
1099    }
1100
1101    #[tokio::test]
1102    async fn default_retry_policy_never_retries_safe_failures() {
1103        let attempts = Arc::new(AtomicUsize::new(0));
1104        let observed = Arc::new(std::sync::Mutex::new(Vec::new()));
1105        let outcome = dispatch_tool_call(
1106            &retry_dispatch_context(
1107                ToolRetryPolicy::Never,
1108                Arc::clone(&attempts),
1109                usize::MAX,
1110                false,
1111                Arc::clone(&observed),
1112            ),
1113            "retry_probe".to_string(),
1114            json!({ "value": "ok" }),
1115            None,
1116        )
1117        .await;
1118
1119        assert!(!outcome.record.output.is_success());
1120        assert_eq!(attempts.load(Ordering::SeqCst), 1);
1121        assert_eq!(observed.lock().expect("observed")[0].0, 1);
1122    }
1123
1124    #[tokio::test]
1125    async fn safe_retry_policy_retries_safe_failure_and_stops_on_success() {
1126        let attempts = Arc::new(AtomicUsize::new(0));
1127        let observed = Arc::new(std::sync::Mutex::new(Vec::new()));
1128        let outcome = dispatch_tool_call(
1129            &retry_dispatch_context(
1130                ToolRetryPolicy::safe(3, 0, 0),
1131                Arc::clone(&attempts),
1132                2,
1133                false,
1134                Arc::clone(&observed),
1135            ),
1136            "retry_probe".to_string(),
1137            json!({ "value": "ok" }),
1138            None,
1139        )
1140        .await;
1141
1142        assert!(outcome.record.output.is_success());
1143        assert_eq!(attempts.load(Ordering::SeqCst), 2);
1144        assert_eq!(
1145            observed
1146                .lock()
1147                .expect("observed")
1148                .iter()
1149                .map(|(attempt, max, _)| (*attempt, *max))
1150                .collect::<Vec<_>>(),
1151            vec![(1, 3), (2, 3)]
1152        );
1153    }
1154
1155    #[tokio::test]
1156    async fn safe_retry_policy_marks_exhausted_after_final_attempt() {
1157        let attempts = Arc::new(AtomicUsize::new(0));
1158        let observed = Arc::new(std::sync::Mutex::new(Vec::new()));
1159        let outcome = dispatch_tool_call(
1160            &retry_dispatch_context(
1161                ToolRetryPolicy::safe(2, 0, 0),
1162                Arc::clone(&attempts),
1163                usize::MAX,
1164                false,
1165                Arc::clone(&observed),
1166            ),
1167            "retry_probe".to_string(),
1168            json!({ "value": "ok" }),
1169            None,
1170        )
1171        .await;
1172
1173        assert!(!outcome.record.output.is_success());
1174        assert_eq!(attempts.load(Ordering::SeqCst), 2);
1175        let ToolCallOutcome::Failure(failure) = outcome.record.output.outcome else {
1176            panic!("expected failure");
1177        };
1178        assert_eq!(
1179            failure.retry,
1180            ToolRetryDisposition::Exhausted { attempts: 2 }
1181        );
1182    }
1183
1184    #[tokio::test]
1185    async fn cancellation_stops_retry_immediately() {
1186        let attempts = Arc::new(AtomicUsize::new(0));
1187        let observed = Arc::new(std::sync::Mutex::new(Vec::new()));
1188        let outcome = dispatch_tool_call(
1189            &retry_dispatch_context(
1190                ToolRetryPolicy::safe(3, 0, 0),
1191                Arc::clone(&attempts),
1192                usize::MAX,
1193                true,
1194                Arc::clone(&observed),
1195            ),
1196            "retry_probe".to_string(),
1197            json!({ "value": "ok" }),
1198            None,
1199        )
1200        .await;
1201
1202        assert!(!outcome.record.output.is_success());
1203        assert_eq!(attempts.load(Ordering::SeqCst), 1);
1204        assert!(matches!(
1205            outcome.record.output.outcome,
1206            ToolCallOutcome::Cancelled(_)
1207        ));
1208    }
1209
1210    #[tokio::test]
1211    async fn retry_context_has_stable_idempotency_key_across_attempts() {
1212        let attempts = Arc::new(AtomicUsize::new(0));
1213        let observed = Arc::new(std::sync::Mutex::new(Vec::new()));
1214        let context = retry_dispatch_context(
1215            ToolRetryPolicy::safe(3, 0, 0),
1216            Arc::clone(&attempts),
1217            3,
1218            false,
1219            Arc::clone(&observed),
1220        );
1221        let tool_context = ToolContext::new(
1222            context.session_id.clone(),
1223            Arc::clone(&context.host),
1224            context.turn_context.clone(),
1225            Arc::clone(&context.attachment_store),
1226            Some("call-1".to_string()),
1227        );
1228        let outcome = dispatch_tool_call_with_execution_context(
1229            &context,
1230            "retry_probe".to_string(),
1231            json!({ "value": "ok" }),
1232            None,
1233            tool_context,
1234        )
1235        .await;
1236
1237        assert!(outcome.record.output.is_success());
1238        let observed = observed.lock().expect("observed");
1239        assert_eq!(observed.len(), 3);
1240        assert_eq!(
1241            observed
1242                .iter()
1243                .map(|(attempt, max, _)| (*attempt, *max))
1244                .collect::<Vec<_>>(),
1245            vec![(1, 3), (2, 3), (3, 3)]
1246        );
1247        let keys = observed
1248            .iter()
1249            .map(|(_, _, key)| key.clone())
1250            .collect::<Vec<_>>();
1251        assert!(keys.iter().all(|key| key == &keys[0]));
1252        assert_eq!(
1253            keys[0].as_deref(),
1254            Some("lash-tool:session:call-1:retry_probe")
1255        );
1256    }
1257
1258    #[tokio::test]
1259    async fn idempotent_retry_policy_requires_stable_key() {
1260        let attempts = Arc::new(AtomicUsize::new(0));
1261        let observed = Arc::new(std::sync::Mutex::new(Vec::new()));
1262        let outcome = dispatch_tool_call(
1263            &retry_dispatch_context(
1264                ToolRetryPolicy::idempotent(3, 0, 0),
1265                Arc::clone(&attempts),
1266                usize::MAX,
1267                false,
1268                Arc::clone(&observed),
1269            ),
1270            "retry_probe".to_string(),
1271            json!({ "value": "ok" }),
1272            None,
1273        )
1274        .await;
1275
1276        assert!(!outcome.record.output.is_success());
1277        assert_eq!(attempts.load(Ordering::SeqCst), 1);
1278        assert_eq!(observed.lock().expect("observed")[0].1, 1);
1279    }
1280
1281    #[tokio::test]
1282    async fn batch_executes_nested_calls_and_preserves_partial_failures() {
1283        let outcome = dispatch_tool_call(
1284            &dispatch_context(),
1285            "batch".to_string(),
1286            json!({
1287                "tool_calls": [
1288                    {"tool": "alpha", "parameters": {}},
1289                    {"tool": "beta", "parameters": {"value": "ok"}},
1290                    {"tool": "beta", "parameters": {"value": "fail"}}
1291                ]
1292            }),
1293            None,
1294        )
1295        .await;
1296
1297        assert!(outcome.record.output.is_success());
1298        assert_eq!(outcome.record.tool, "batch");
1299        let value = outcome.record.output.value_for_projection();
1300        let results = value
1301            .get("results")
1302            .and_then(|value| value.as_array())
1303            .expect("results");
1304        assert_eq!(results.len(), 3);
1305        assert_eq!(
1306            results
1307                .iter()
1308                .filter(|item| item.get("success").and_then(|value| value.as_bool()) == Some(true))
1309                .count(),
1310            2
1311        );
1312        assert_eq!(results[0].get("tool"), Some(&json!("alpha")));
1313        assert_eq!(
1314            results[2]
1315                .get("error")
1316                .and_then(|value| value.get("message")),
1317            Some(&json!("beta failed"))
1318        );
1319    }
1320
1321    #[tokio::test]
1322    async fn batch_rejects_nested_batch_as_partial_failure() {
1323        let outcome = dispatch_tool_call(
1324            &dispatch_context(),
1325            "batch".to_string(),
1326            json!({
1327                "tool_calls": [
1328                    {"tool": "batch", "parameters": {"tool_calls": []}}
1329                ]
1330            }),
1331            None,
1332        )
1333        .await;
1334
1335        assert!(outcome.record.output.is_success());
1336        let value = outcome.record.output.value_for_projection();
1337        let first = value
1338            .get("results")
1339            .and_then(|value| value.as_array())
1340            .and_then(|items| items.first())
1341            .expect("first result");
1342        assert_eq!(
1343            first.get("error"),
1344            Some(&json!("Tool 'batch' is not allowed inside batch"))
1345        );
1346    }
1347
1348    #[tokio::test]
1349    async fn batch_marks_overflow_calls_as_failures() {
1350        let tool_calls = (0..26)
1351            .map(|_| json!({"tool": "alpha", "parameters": {}}))
1352            .collect::<Vec<_>>();
1353
1354        let outcome = dispatch_tool_call(
1355            &dispatch_context(),
1356            "batch".to_string(),
1357            json!({ "tool_calls": tool_calls }),
1358            None,
1359        )
1360        .await;
1361
1362        assert!(!outcome.record.output.is_success());
1363        let value = outcome.record.output.value_for_projection();
1364        let error = value
1365            .get("message")
1366            .and_then(|value| value.as_str())
1367            .expect("string error result");
1368        assert!(
1369            error.contains("tool_calls") && error.contains("items <= 25"),
1370            "{error}",
1371        );
1372    }
1373
1374    #[tokio::test]
1375    async fn batch_calls_make_progress_concurrently() {
1376        let barrier = Arc::new(Barrier::new(2));
1377        let started = Arc::new(AtomicUsize::new(0));
1378        let outcome = dispatch_tool_call(
1379            &parallel_dispatch_context(Arc::clone(&barrier), Arc::clone(&started)),
1380            "batch".to_string(),
1381            json!({
1382                "tool_calls": [
1383                    {"tool": "probe_a", "parameters": {}},
1384                    {"tool": "probe_b", "parameters": {}}
1385                ]
1386            }),
1387            None,
1388        )
1389        .await;
1390
1391        assert!(outcome.record.output.is_success());
1392        assert_eq!(started.load(Ordering::SeqCst), 2);
1393        let value = outcome.record.output.value_for_projection();
1394        let results = value
1395            .get("results")
1396            .and_then(|value| value.as_array())
1397            .expect("results");
1398        assert_eq!(results.len(), 2);
1399        assert!(
1400            results
1401                .iter()
1402                .all(|item| item.get("success").and_then(|value| value.as_bool()) == Some(true))
1403        );
1404    }
1405
1406    /// A tool provider whose tools are marked [`ToolExecutionMode::Serial`]
1407    /// and log (start, end) instants around a sleep into a shared `Mutex`.
1408    struct SerialProbeTools {
1409        /// (tool_name, start_instant, end_instant)
1410        log: Arc<std::sync::Mutex<Vec<(String, Instant, Instant)>>>,
1411    }
1412
1413    #[async_trait::async_trait]
1414    impl ToolProvider for SerialProbeTools {
1415        fn tool_manifests(&self) -> Vec<crate::ToolManifest> {
1416            manifests(vec![
1417                test_tool("serial_a", ToolExecutionMode::Serial),
1418                test_tool("serial_b", ToolExecutionMode::Serial),
1419            ])
1420        }
1421
1422        fn resolve_contract(&self, name: &str) -> Option<Arc<crate::ToolContract>> {
1423            contract_from(
1424                vec![
1425                    test_tool("serial_a", ToolExecutionMode::Serial),
1426                    test_tool("serial_b", ToolExecutionMode::Serial),
1427                ],
1428                name,
1429            )
1430        }
1431
1432        async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
1433            let start = Instant::now();
1434            // Sleep long enough that if the two tools *were* dispatched
1435            // concurrently, their windows would overlap by a detectable
1436            // margin.
1437            tokio::time::sleep(Duration::from_millis(40)).await;
1438            let end = Instant::now();
1439            self.log
1440                .lock()
1441                .expect("serial probe log")
1442                .push((call.name.to_string(), start, end));
1443            ToolResult::ok(json!(call.name))
1444        }
1445    }
1446
1447    fn serial_dispatch_context(
1448        log: Arc<std::sync::Mutex<Vec<(String, Instant, Instant)>>>,
1449    ) -> ToolDispatchContext {
1450        let (event_tx, _event_rx) = mpsc::channel(8);
1451        let plugins = test_plugins(Arc::new(SerialProbeTools { log }));
1452        let tools = plugins.tools();
1453        let surface = plugins.tool_surface("session", ExecutionMode::standard());
1454        ToolDispatchContext {
1455            plugins,
1456            tools,
1457            surface,
1458            host: Arc::new(MockSessionManager::default()),
1459            session_id: "session".to_string(),
1460            event_tx,
1461            turn_injection_bridge: crate::TurnInjectionBridge::new(),
1462            attachment_store: Arc::new(crate::InMemoryAttachmentStore::new()),
1463            turn_context: crate::TurnContext::default(),
1464        }
1465    }
1466
1467    /// Two Serial tools in the same batch must not interleave: the second
1468    /// call's start instant must be at or after the first call's end
1469    /// instant.
1470    #[tokio::test]
1471    async fn serial_tools_do_not_interleave() {
1472        let log: Arc<std::sync::Mutex<Vec<(String, Instant, Instant)>>> =
1473            Arc::new(std::sync::Mutex::new(Vec::new()));
1474        let context = Arc::new(serial_dispatch_context(Arc::clone(&log)));
1475
1476        let specs = vec![
1477            ParallelToolCallSpec {
1478                index: 0,
1479                tool_name: "serial_a".to_string(),
1480                args: json!({}),
1481            },
1482            ParallelToolCallSpec {
1483                index: 1,
1484                tool_name: "serial_b".to_string(),
1485                args: json!({}),
1486            },
1487        ];
1488
1489        let outcomes = dispatch_parallel_tool_calls(context, specs, None).await;
1490
1491        assert_eq!(outcomes.len(), 2);
1492        assert!(
1493            outcomes
1494                .iter()
1495                .all(|outcome| outcome.record.output.is_success())
1496        );
1497        // Outcomes are sorted by original index.
1498        assert_eq!(outcomes[0].index, 0);
1499        assert_eq!(outcomes[1].index, 1);
1500        assert_eq!(outcomes[0].record.tool, "serial_a");
1501        assert_eq!(outcomes[1].record.tool, "serial_b");
1502
1503        let entries = log.lock().expect("log").clone();
1504        assert_eq!(entries.len(), 2, "both serial tools must have executed");
1505        // Sort entries by start time so we compare the first-to-run vs
1506        // second-to-run regardless of which tool happened to go first.
1507        let mut sorted = entries;
1508        sorted.sort_by_key(|(_, start, _)| *start);
1509        let (first_name, _first_start, first_end) = &sorted[0];
1510        let (second_name, second_start, _second_end) = &sorted[1];
1511        assert_ne!(first_name, second_name, "both tools should have run");
1512        assert!(
1513            second_start >= first_end,
1514            "serial tool ranges must not overlap: first ended at {:?}, second started at {:?}",
1515            first_end,
1516            second_start,
1517        );
1518    }
1519
1520    struct SerialRetryProbeTools {
1521        log: Arc<std::sync::Mutex<Vec<(String, Instant, Instant)>>>,
1522        attempts_a: Arc<AtomicUsize>,
1523        attempts_b: Arc<AtomicUsize>,
1524    }
1525
1526    #[async_trait::async_trait]
1527    impl ToolProvider for SerialRetryProbeTools {
1528        fn tool_manifests(&self) -> Vec<crate::ToolManifest> {
1529            manifests(vec![
1530                test_tool("serial_retry_a", ToolExecutionMode::Serial)
1531                    .with_retry_policy(ToolRetryPolicy::safe(2, 0, 0)),
1532                test_tool("serial_retry_b", ToolExecutionMode::Serial)
1533                    .with_retry_policy(ToolRetryPolicy::safe(2, 0, 0)),
1534            ])
1535        }
1536
1537        fn resolve_contract(&self, name: &str) -> Option<Arc<crate::ToolContract>> {
1538            contract_from(
1539                vec![
1540                    test_tool("serial_retry_a", ToolExecutionMode::Serial)
1541                        .with_retry_policy(ToolRetryPolicy::safe(2, 0, 0)),
1542                    test_tool("serial_retry_b", ToolExecutionMode::Serial)
1543                        .with_retry_policy(ToolRetryPolicy::safe(2, 0, 0)),
1544                ],
1545                name,
1546            )
1547        }
1548
1549        async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
1550            let start = Instant::now();
1551            tokio::time::sleep(Duration::from_millis(20)).await;
1552            let end = Instant::now();
1553            self.log
1554                .lock()
1555                .expect("serial retry log")
1556                .push((call.name.to_string(), start, end));
1557
1558            let attempt = match call.name {
1559                "serial_retry_a" => self.attempts_a.fetch_add(1, Ordering::SeqCst) + 1,
1560                "serial_retry_b" => self.attempts_b.fetch_add(1, Ordering::SeqCst) + 1,
1561                _ => 1,
1562            };
1563            if attempt == 1 {
1564                ToolResult::retryable_failure(
1565                    crate::ToolFailureClass::External,
1566                    "transient",
1567                    "transient failure",
1568                    Some(0),
1569                )
1570            } else {
1571                ToolResult::ok(json!(call.name))
1572            }
1573        }
1574    }
1575
1576    #[tokio::test]
1577    async fn serial_tool_retries_do_not_overlap_other_serial_calls() {
1578        let log = Arc::new(std::sync::Mutex::new(Vec::new()));
1579        let attempts_a = Arc::new(AtomicUsize::new(0));
1580        let attempts_b = Arc::new(AtomicUsize::new(0));
1581        let provider = Arc::new(SerialRetryProbeTools {
1582            log: Arc::clone(&log),
1583            attempts_a: Arc::clone(&attempts_a),
1584            attempts_b: Arc::clone(&attempts_b),
1585        });
1586        let (event_tx, _event_rx) = mpsc::channel(8);
1587        let plugins = test_plugins(provider);
1588        let tools = plugins.tools();
1589        let surface = plugins.tool_surface("session", ExecutionMode::standard());
1590        let context = Arc::new(ToolDispatchContext {
1591            plugins,
1592            tools,
1593            surface,
1594            host: Arc::new(MockSessionManager::default()),
1595            session_id: "session".to_string(),
1596            event_tx,
1597            turn_injection_bridge: crate::TurnInjectionBridge::new(),
1598            attachment_store: Arc::new(crate::InMemoryAttachmentStore::new()),
1599            turn_context: crate::TurnContext::default(),
1600        });
1601
1602        let outcomes = dispatch_parallel_tool_calls(
1603            context,
1604            vec![
1605                ParallelToolCallSpec {
1606                    index: 0,
1607                    tool_name: "serial_retry_a".to_string(),
1608                    args: json!({}),
1609                },
1610                ParallelToolCallSpec {
1611                    index: 1,
1612                    tool_name: "serial_retry_b".to_string(),
1613                    args: json!({}),
1614                },
1615            ],
1616            None,
1617        )
1618        .await;
1619
1620        assert!(
1621            outcomes
1622                .iter()
1623                .all(|outcome| outcome.record.output.is_success())
1624        );
1625        assert_eq!(attempts_a.load(Ordering::SeqCst), 2);
1626        assert_eq!(attempts_b.load(Ordering::SeqCst), 2);
1627
1628        let mut entries = log.lock().expect("serial retry log").clone();
1629        entries.sort_by_key(|(_, start, _)| *start);
1630        assert_eq!(entries.len(), 4);
1631        for window in entries.windows(2) {
1632            assert!(
1633                window[1].1 >= window[0].2,
1634                "serial retry windows must not overlap: {:?} then {:?}",
1635                window[0],
1636                window[1],
1637            );
1638        }
1639    }
1640
1641    /// When a batch contains a mix of parallel and serial tools, the
1642    /// parallel-safe ones should still run concurrently with each other
1643    /// (verified via a Barrier), and the serial one should run separately
1644    /// without interleaving with any parallel peer's window.
1645    #[tokio::test]
1646    async fn mixed_batch_runs_parallel_tools_concurrently_and_serial_alone() {
1647        struct MixedTools {
1648            barrier: Arc<Barrier>,
1649            serial_window: Arc<std::sync::Mutex<Option<(Instant, Instant)>>>,
1650            parallel_windows: Arc<std::sync::Mutex<Vec<(String, Instant, Instant)>>>,
1651        }
1652
1653        #[async_trait::async_trait]
1654        impl ToolProvider for MixedTools {
1655            fn tool_manifests(&self) -> Vec<crate::ToolManifest> {
1656                manifests(vec![
1657                    test_tool("par_a", ToolExecutionMode::Parallel),
1658                    test_tool("par_b", ToolExecutionMode::Parallel),
1659                    test_tool("ser", ToolExecutionMode::Serial),
1660                ])
1661            }
1662
1663            fn resolve_contract(&self, name: &str) -> Option<Arc<crate::ToolContract>> {
1664                contract_from(
1665                    vec![
1666                        test_tool("par_a", ToolExecutionMode::Parallel),
1667                        test_tool("par_b", ToolExecutionMode::Parallel),
1668                        test_tool("ser", ToolExecutionMode::Serial),
1669                    ],
1670                    name,
1671                )
1672            }
1673
1674            async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
1675                let name = call.name;
1676                if name == "ser" {
1677                    let start = Instant::now();
1678                    tokio::time::sleep(Duration::from_millis(30)).await;
1679                    let end = Instant::now();
1680                    *self.serial_window.lock().expect("serial window") = Some((start, end));
1681                    ToolResult::ok(json!(name))
1682                } else {
1683                    let start = Instant::now();
1684                    // Block until both parallel tools have reached this
1685                    // barrier — proves they're running concurrently.
1686                    let waited = timeout(Duration::from_millis(200), self.barrier.wait()).await;
1687                    let end = Instant::now();
1688                    self.parallel_windows
1689                        .lock()
1690                        .expect("parallel windows")
1691                        .push((name.to_string(), start, end));
1692                    match waited {
1693                        Ok(_) => ToolResult::ok(json!(name)),
1694                        Err(_) => ToolResult::err_fmt(format!("{name} did not overlap with peer")),
1695                    }
1696                }
1697            }
1698        }
1699
1700        let barrier = Arc::new(Barrier::new(2));
1701        let serial_window = Arc::new(std::sync::Mutex::new(None));
1702        let parallel_windows = Arc::new(std::sync::Mutex::new(Vec::new()));
1703        let (event_tx, _event_rx) = mpsc::channel(8);
1704        let provider = Arc::new(MixedTools {
1705            barrier: Arc::clone(&barrier),
1706            serial_window: Arc::clone(&serial_window),
1707            parallel_windows: Arc::clone(&parallel_windows),
1708        });
1709        let plugins = test_plugins(provider);
1710        let tools = plugins.tools();
1711        let surface = plugins.tool_surface("session", ExecutionMode::standard());
1712        let context = Arc::new(ToolDispatchContext {
1713            plugins,
1714            tools,
1715            surface,
1716            host: Arc::new(MockSessionManager::default()),
1717            session_id: "session".to_string(),
1718            event_tx,
1719            turn_injection_bridge: crate::TurnInjectionBridge::new(),
1720            attachment_store: Arc::new(crate::InMemoryAttachmentStore::new()),
1721            turn_context: crate::TurnContext::default(),
1722        });
1723
1724        let specs = vec![
1725            ParallelToolCallSpec {
1726                index: 0,
1727                tool_name: "par_a".to_string(),
1728                args: json!({}),
1729            },
1730            ParallelToolCallSpec {
1731                index: 1,
1732                tool_name: "ser".to_string(),
1733                args: json!({}),
1734            },
1735            ParallelToolCallSpec {
1736                index: 2,
1737                tool_name: "par_b".to_string(),
1738                args: json!({}),
1739            },
1740        ];
1741
1742        let outcomes = dispatch_parallel_tool_calls(context, specs, None).await;
1743
1744        assert_eq!(outcomes.len(), 3);
1745        assert!(
1746            outcomes
1747                .iter()
1748                .all(|outcome| outcome.record.output.is_success()),
1749            "all tools should succeed: {:?}",
1750            outcomes
1751                .iter()
1752                .map(|outcome| (&outcome.record.tool, outcome.record.output.is_success()))
1753                .collect::<Vec<_>>()
1754        );
1755
1756        let pw = parallel_windows.lock().expect("parallel windows");
1757        assert_eq!(pw.len(), 2);
1758        let sw = serial_window
1759            .lock()
1760            .expect("serial window")
1761            .expect("serial window recorded");
1762
1763        // The serial tool's window must not overlap either parallel
1764        // tool's window (Option A: serial runs after parallel).
1765        for (name, p_start, p_end) in pw.iter() {
1766            assert!(
1767                sw.0 >= *p_end || sw.1 <= *p_start,
1768                "serial window {:?} overlaps parallel window {} {:?}..{:?}",
1769                sw,
1770                name,
1771                p_start,
1772                p_end,
1773            );
1774        }
1775    }
1776}