Skip to main content

lash_core/session/
tool_execution.rs

1use super::execution_context::RuntimeExecutionContext;
2use crate::tool_dispatch::{
3    ToolCallLaunch, ToolDispatchOutcome, ToolPreparationOutcome,
4    dispatch_prepared_tool_call_launch_with_execution_context,
5    finalize_tool_result_with_execution_context, prepare_tool_call_with_context,
6    schedule_tool_batch,
7};
8use crate::{
9    ModelToolReturn, SessionEvent, ToolCallOutput, ToolCallRecord, ToolCancellation, ToolFailure,
10    ToolFailureClass, ToolResult, TurnActivityId, TurnEvent,
11};
12
13#[derive(Clone)]
14pub struct ToolInvocation {
15    pub id: String,
16    pub name: String,
17    pub args: serde_json::Value,
18    pub child_execution_trace_hook: Option<crate::ToolChildExecutionTraceHook>,
19}
20
21impl ToolInvocation {
22    pub fn new(id: impl Into<String>, name: impl Into<String>, args: serde_json::Value) -> Self {
23        Self {
24            id: id.into(),
25            name: name.into(),
26            args,
27            child_execution_trace_hook: None,
28        }
29    }
30
31    pub fn with_child_execution_trace_hook(
32        mut self,
33        hook: crate::ToolChildExecutionTraceHook,
34    ) -> Self {
35        self.child_execution_trace_hook = Some(hook);
36        self
37    }
38}
39
40impl std::fmt::Debug for ToolInvocation {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        f.debug_struct("ToolInvocation")
43            .field("id", &self.id)
44            .field("name", &self.name)
45            .field("args", &self.args)
46            .field(
47                "child_execution_trace_hook",
48                &self.child_execution_trace_hook.as_ref().map(|_| "<hook>"),
49            )
50            .finish()
51    }
52}
53
54#[derive(Clone, Debug)]
55pub struct ToolInvocationReply {
56    pub output: ToolCallOutput,
57    pub record: Option<ToolCallRecord>,
58}
59
60impl ToolInvocationReply {
61    pub fn success(value: serde_json::Value) -> Self {
62        Self {
63            output: ToolCallOutput::success(value),
64            record: None,
65        }
66    }
67
68    pub fn error(value: serde_json::Value) -> Self {
69        let message = value
70            .as_str()
71            .map(ToOwned::to_owned)
72            .unwrap_or_else(|| value.to_string());
73        let mut failure = ToolFailure::tool(ToolFailureClass::Execution, "tool_error", message);
74        failure.raw =
75            Some(serde_json::from_value(value).unwrap_or_else(|_| {
76                crate::ToolValue::String("unserializable tool error".to_string())
77            }));
78        Self {
79            output: ToolCallOutput::failure(failure),
80            record: None,
81        }
82    }
83
84    pub fn from_output(output: ToolCallOutput) -> Self {
85        Self {
86            output,
87            record: None,
88        }
89    }
90
91    pub fn cancelled(message: impl Into<String>) -> Self {
92        Self::from_output(ToolCallOutput::cancelled(ToolCancellation::runtime(
93            message,
94        )))
95    }
96
97    pub(crate) fn with_record(mut self, record: ToolCallRecord) -> Self {
98        self.record = Some(record);
99        self
100    }
101}
102
/// A finished tool call in protocol form, paired with the record kept for it.
#[derive(Clone, Debug)]
pub(crate) struct CompletedProtocolToolCall {
    /// The completed call (output, model-facing return, replay metadata) in sans-io form.
    pub completed: crate::sansio::CompletedToolCall,
    /// The call record; as built by `complete_tool_call`, its `call_id` is the protocol call id.
    pub record: ToolCallRecord,
}
108
/// Result of launching an already-prepared tool call.
pub(crate) enum ProtocolToolCallLaunch {
    /// The tool finished during launch and its result has already been completed.
    Done(CompletedProtocolToolCall),
    /// The tool is still outstanding; the caller is responsible for awaiting its completion.
    Pending(crate::tool_dispatch::PendingToolDispatchOutcome),
}
113
/// A prepared tool call bundled with the per-run metadata used while executing it.
#[derive(Clone)]
pub(crate) struct PreparedToolRun {
    /// The prepared call to dispatch.
    pub prepared: crate::PreparedToolCall,
    /// Caller-supplied index (e.g. the call's position in a batch); forwarded on completion.
    pub index: usize,
    /// Effective parent invocation: the explicit one, or the context's own when absent.
    pub parent_invocation: Option<crate::RuntimeInvocation>,
    /// Turn activity id (`tool:{call_id}`) correlating this run's turn events.
    pub activity_id: TurnActivityId,
}
121
122impl RuntimeExecutionContext<'_> {
123    fn prepared_tool_run(
124        &self,
125        prepared: crate::PreparedToolCall,
126        index: usize,
127        parent_invocation: Option<crate::RuntimeInvocation>,
128    ) -> PreparedToolRun {
129        let activity_id = TurnActivityId::new(format!("tool:{}", prepared.call_id));
130        PreparedToolRun {
131            prepared,
132            index,
133            parent_invocation,
134            activity_id,
135        }
136    }
137
138    #[expect(
139        clippy::too_many_arguments,
140        reason = "tool execution carries explicit runtime call metadata"
141    )]
142    pub(crate) async fn execute_tool_call(
143        &self,
144        call_id: String,
145        name: String,
146        args: serde_json::Value,
147        index: usize,
148        replay: Option<crate::llm::types::ProviderReplayMeta>,
149        parent_invocation: Option<crate::RuntimeInvocation>,
150        child_execution_trace_hook: Option<crate::ToolChildExecutionTraceHook>,
151    ) -> CompletedProtocolToolCall {
152        let _ = self
153            .dispatch
154            .event_tx
155            .send(SessionEvent::ToolCallStart {
156                call_id: Some(call_id.clone()),
157                name: name.clone(),
158                args: args.clone(),
159            })
160            .await;
161        let tool_correlation_id = TurnActivityId::new(format!("tool:{call_id}"));
162        self.emit_turn_activity(
163            tool_correlation_id.clone(),
164            TurnEvent::ToolCallStarted {
165                call_id: Some(call_id.clone()),
166                name: name.clone(),
167                args: args.clone(),
168            },
169        )
170        .await;
171
172        let parent_invocation = parent_invocation.or_else(|| self.parent_invocation.clone());
173        let mut dispatch = (*self.dispatch).clone();
174        dispatch.parent_invocation = parent_invocation.clone();
175        let pending = crate::sansio::PendingToolCall {
176            call_id: call_id.clone(),
177            tool_name: name,
178            args,
179            replay: replay.clone(),
180        };
181        let launch =
182            match prepare_tool_call_with_context(&dispatch, pending, Some(call_id.clone())).await {
183                ToolPreparationOutcome::Prepared(prepared) => {
184                    let dispatch_context = std::sync::Arc::new(dispatch.clone());
185                    let mut tool_context =
186                        crate::ToolContext::from_dispatch(std::sync::Arc::clone(&dispatch_context))
187                            .prepared_call(&prepared)
188                            .cancellation_token(self.cancellation_token.clone())
189                            .runtime_process_id(self.runtime_process_id.clone())
190                            .parent_invocation(parent_invocation.clone())
191                            .child_execution_trace_hook(child_execution_trace_hook.clone());
192                    if let Some(process_events) = self.process_event_context.as_ref() {
193                        tool_context = tool_context.process_events(
194                            process_events.process_id.clone(),
195                            std::sync::Arc::clone(&process_events.registry),
196                            process_events.store.clone(),
197                            process_events.session_store_factory.clone(),
198                            process_events.queued_work_poke.clone(),
199                        );
200                    }
201                    let tool_context = tool_context.build();
202                    dispatch_prepared_tool_call_launch_with_execution_context(
203                        dispatch_context.as_ref(),
204                        prepared,
205                        None,
206                        tool_context,
207                    )
208                    .await
209                }
210                ToolPreparationOutcome::Completed(outcome) => ToolCallLaunch::Done(*outcome),
211            };
212        let mut outcome = match launch {
213            ToolCallLaunch::Done(outcome) => outcome,
214            ToolCallLaunch::Pending(pending) => {
215                self.await_pending_tool_dispatch_outcome(
216                    &call_id,
217                    parent_invocation.clone(),
218                    pending,
219                    self.cancellation_token.clone(),
220                )
221                .await
222            }
223        };
224        outcome.record.call_id = Some(call_id.clone());
225
226        self.complete_tool_call(index, call_id, replay, outcome, tool_correlation_id)
227            .await
228    }
229
230    pub(crate) async fn prepare_tool_call(
231        &self,
232        pending: crate::sansio::PendingToolCall,
233    ) -> ToolPreparationOutcome {
234        let call_id = Some(pending.call_id.clone());
235        prepare_tool_call_with_context(self.dispatch.as_ref(), pending, call_id).await
236    }
237
238    pub(crate) async fn execute_prepared_tool_call_launch(
239        &self,
240        prepared: crate::PreparedToolCall,
241        index: usize,
242        parent_invocation: Option<crate::RuntimeInvocation>,
243    ) -> crate::runtime::ToolCallLaunch {
244        match Box::pin(self.execute_prepared_tool_call_launch_inner(
245            prepared,
246            index,
247            parent_invocation,
248        ))
249        .await
250        {
251            ProtocolToolCallLaunch::Done(completed) => crate::runtime::ToolCallLaunch::Done {
252                result: completed.completed,
253            },
254            ProtocolToolCallLaunch::Pending(pending) => crate::runtime::ToolCallLaunch::Pending {
255                key: pending.key,
256                pending: pending.pending,
257                duration_ms: pending.duration_ms,
258            },
259        }
260    }
261
262    async fn execute_prepared_tool_call_launch_inner(
263        &self,
264        prepared: crate::PreparedToolCall,
265        index: usize,
266        parent_invocation: Option<crate::RuntimeInvocation>,
267    ) -> ProtocolToolCallLaunch {
268        let call_id = prepared.call_id.clone();
269        let name = prepared.tool_name.clone();
270        let args = prepared.args.clone();
271        let replay = prepared.replay.clone();
272        let parent_invocation = parent_invocation.or_else(|| self.parent_invocation.clone());
273        let run = self.prepared_tool_run(prepared, index, parent_invocation);
274        let prepared = run.prepared.clone();
275        let _ = self
276            .dispatch
277            .event_tx
278            .send(SessionEvent::ToolCallStart {
279                call_id: Some(call_id.clone()),
280                name: name.clone(),
281                args: args.clone(),
282            })
283            .await;
284        let tool_correlation_id = run.activity_id.clone();
285        self.emit_turn_activity(
286            tool_correlation_id.clone(),
287            TurnEvent::ToolCallStarted {
288                call_id: Some(call_id.clone()),
289                name: name.clone(),
290                args: args.clone(),
291            },
292        )
293        .await;
294
295        let mut tool_context =
296            crate::ToolContext::from_dispatch(std::sync::Arc::clone(&self.dispatch))
297                .prepared_call(&prepared)
298                .cancellation_token(self.cancellation_token.clone())
299                .runtime_process_id(self.runtime_process_id.clone())
300                .parent_invocation(run.parent_invocation.clone());
301        if let Some(process_events) = self.process_event_context.as_ref() {
302            tool_context = tool_context.process_events(
303                process_events.process_id.clone(),
304                std::sync::Arc::clone(&process_events.registry),
305                process_events.store.clone(),
306                process_events.session_store_factory.clone(),
307                process_events.queued_work_poke.clone(),
308            );
309        }
310        let tool_context = tool_context.build();
311        let outcome = Box::pin(dispatch_prepared_tool_call_launch_with_execution_context(
312            self.dispatch.as_ref(),
313            prepared,
314            None,
315            tool_context,
316        ))
317        .await;
318        match outcome {
319            ToolCallLaunch::Done(mut outcome) => {
320                outcome.record.call_id = Some(call_id.clone());
321                tokio::task::yield_now().await;
322                let completed = self
323                    .complete_tool_call(run.index, call_id, replay, outcome, tool_correlation_id)
324                    .await;
325                ProtocolToolCallLaunch::Done(completed)
326            }
327            ToolCallLaunch::Pending(pending) => ProtocolToolCallLaunch::Pending(pending),
328        }
329    }
330
331    pub(super) async fn await_process_with_cancellation(
332        &self,
333        process_id: &str,
334        parent_invocation: Option<crate::RuntimeInvocation>,
335        cancellation: Option<tokio_util::sync::CancellationToken>,
336    ) -> Result<crate::ProcessAwaitOutput, crate::PluginError> {
337        let _phase = self.named_phase("process.await_handle");
338        if let Some(cancellation) = cancellation {
339            tokio::select! {
340                result = self.dispatch.processes.await_process(
341                    process_id,
342                    self.process_scope(parent_invocation.clone()),
343                ) => result,
344                _ = cancellation.cancelled() => {
345                    let _ = self.dispatch.processes.cancel(
346                        &self.dispatch.session_id,
347                        process_id,
348                        self.process_scope(parent_invocation.clone()),
349                    ).await;
350                    self.dispatch.processes.await_process(
351                        process_id,
352                        self.process_scope(parent_invocation),
353                    ).await
354                }
355            }
356        } else {
357            self.dispatch
358                .processes
359                .await_process(process_id, self.process_scope(parent_invocation))
360                .await
361        }
362    }
363
364    pub(crate) async fn complete_tool_call(
365        &self,
366        _index: usize,
367        call_id: String,
368        replay: Option<crate::llm::types::ProviderReplayMeta>,
369        outcome: ToolDispatchOutcome,
370        tool_correlation_id: TurnActivityId,
371    ) -> CompletedProtocolToolCall {
372        let output = outcome.record.output.clone();
373        let projection_output = output.clone();
374        let projection_tool_name = outcome.record.tool.clone();
375        let projection_args = outcome.record.args.clone();
376        let projection_duration_ms = outcome.record.duration_ms;
377        let projection_call_id = call_id.clone();
378        tokio::task::yield_now().await;
379        let plugins = std::sync::Arc::clone(&self.dispatch.plugins);
380        let projection_context = crate::plugin::ToolResultProjectionContext {
381            session_id: self.dispatch.session_id.clone(),
382            tool_name: projection_tool_name,
383            args: projection_args,
384            output: projection_output,
385            duration_ms: projection_duration_ms,
386            call_id: projection_call_id,
387        };
388        let model_return = match plugins.project_tool_result(projection_context).await {
389            Ok(projected) => projected,
390            Err(err) => ModelToolReturn::text(
391                call_id.clone(),
392                outcome.record.tool.clone(),
393                err.to_string(),
394            ),
395        };
396
397        self.emit_turn_activity(
398            tool_correlation_id,
399            TurnEvent::ToolCallCompleted {
400                call_id: Some(call_id.clone()),
401                name: outcome.record.tool.clone(),
402                args: outcome.record.args.clone(),
403                output: output.clone(),
404                duration_ms: outcome.record.duration_ms,
405            },
406        )
407        .await;
408
409        let record = ToolCallRecord {
410            call_id: Some(call_id.clone()),
411            tool: outcome.record.tool.clone(),
412            args: outcome.record.args.clone(),
413            output: output.clone(),
414            duration_ms: outcome.record.duration_ms,
415        };
416        CompletedProtocolToolCall {
417            completed: crate::sansio::CompletedToolCall {
418                call_id,
419                tool_name: outcome.record.tool,
420                args: outcome.record.args,
421                output,
422                model_return,
423                duration_ms: outcome.record.duration_ms,
424                replay,
425            },
426            record,
427        }
428    }
429
430    pub(crate) async fn pending_completion_dispatch_outcome(
431        &self,
432        tool_name: String,
433        args: serde_json::Value,
434        resolution: crate::Resolution,
435        duration_ms: u64,
436    ) -> ToolDispatchOutcome {
437        let output = crate::tool_result::tool_output_from_completion_resolution(resolution);
438        let result = finalize_tool_result_with_execution_context(
439            self.dispatch.as_ref(),
440            &tool_name,
441            &args,
442            ToolResult::from_output(output),
443            duration_ms,
444        )
445        .await;
446        let output = result.into_done_output().unwrap_or_else(|_| {
447            ToolCallOutput::failure(ToolFailure::runtime(
448                ToolFailureClass::Internal,
449                "pending_tool_not_finalized",
450                "pending tool result reached a completed-output projection path",
451            ))
452        });
453        ToolDispatchOutcome {
454            record: ToolCallRecord {
455                call_id: None,
456                tool: tool_name,
457                args,
458                output,
459                duration_ms,
460            },
461        }
462    }
463
464    async fn await_pending_tool_dispatch_outcome(
465        &self,
466        call_id: &str,
467        parent_invocation: Option<crate::RuntimeInvocation>,
468        pending: crate::tool_dispatch::PendingToolDispatchOutcome,
469        cancellation: Option<tokio_util::sync::CancellationToken>,
470    ) -> ToolDispatchOutcome {
471        let fallback;
472        let parent = if let Some(parent) = parent_invocation.as_ref() {
473            parent
474        } else {
475            fallback = crate::RuntimeInvocation::effect(
476                crate::RuntimeScope::new(&self.dispatch.session_id),
477                format!("tool:{call_id}:await"),
478                crate::RuntimeEffectKind::AwaitEvent,
479                format!("tool:{call_id}:await"),
480            );
481            &fallback
482        };
483        let parent_effect_id = parent.effect_id().unwrap_or("tool");
484        let invocation = crate::runtime::causal::child_effect_invocation(
485            parent,
486            format!("{parent_effect_id}:{call_id}:await"),
487            crate::RuntimeEffectKind::AwaitEvent,
488            format!("{call_id}:await"),
489        );
490        let cancellation = cancellation.unwrap_or_default();
491        let deadline = pending
492            .pending
493            .deadline
494            .map(|duration| std::time::Instant::now() + duration);
495        let outcome = self
496            .dispatch
497            .effect_controller
498            .controller()
499            .execute_effect(
500                crate::RuntimeEffectEnvelope::new(
501                    invocation,
502                    crate::RuntimeEffectCommand::AwaitEvent { key: pending.key },
503                ),
504                crate::RuntimeEffectLocalExecutor::await_event(cancellation, deadline),
505            )
506            .await;
507        let resolution = match outcome.and_then(crate::RuntimeEffectOutcome::into_await_event) {
508            Ok(resolution) => resolution,
509            Err(err) => {
510                return ToolDispatchOutcome {
511                    record: ToolCallRecord {
512                        call_id: None,
513                        tool: pending.tool_name,
514                        args: pending.args,
515                        output: ToolCallOutput::failure(ToolFailure::runtime(
516                            ToolFailureClass::Internal,
517                            "pending_tool_completion_failed",
518                            err.to_string(),
519                        )),
520                        duration_ms: pending.duration_ms,
521                    },
522                };
523            }
524        };
525        self.pending_completion_dispatch_outcome(
526            pending.tool_name,
527            pending.args,
528            resolution,
529            pending.duration_ms,
530        )
531        .await
532    }
533
534    pub async fn call_tool(
535        &self,
536        call_id: String,
537        name: String,
538        args: serde_json::Value,
539        index: usize,
540    ) -> ToolInvocationReply {
541        let executed = self
542            .execute_tool_call(call_id, name, args, index, None, None, None)
543            .await;
544        let reply = ToolInvocationReply::from_output(executed.completed.output);
545        reply.with_record(executed.record)
546    }
547
548    pub async fn call_tool_with_child_execution_trace_hook(
549        &self,
550        call_id: String,
551        name: String,
552        args: serde_json::Value,
553        index: usize,
554        trace_hook: crate::ToolChildExecutionTraceHook,
555    ) -> ToolInvocationReply {
556        let executed = self
557            .execute_tool_call(call_id, name, args, index, None, None, Some(trace_hook))
558            .await;
559        let reply = ToolInvocationReply::from_output(executed.completed.output);
560        reply.with_record(executed.record)
561    }
562
563    pub async fn call_tool_batch(&self, calls: Vec<ToolInvocation>) -> Vec<ToolInvocationReply> {
564        let indexed_calls = calls.into_iter().enumerate().collect::<Vec<_>>();
565        schedule_tool_batch(
566            indexed_calls,
567            |(index, _)| *index,
568            |(_, call)| self.tool_scheduling(&call.name),
569            |(index, call)| {
570                let ctx = self.clone();
571                async move {
572                    let executed = ctx
573                        .execute_tool_call(
574                            call.id,
575                            call.name,
576                            call.args,
577                            index,
578                            None,
579                            None,
580                            call.child_execution_trace_hook,
581                        )
582                        .await;
583                    ToolInvocationReply::from_output(executed.completed.output)
584                        .with_record(executed.record)
585                }
586            },
587        )
588        .await
589    }
590
591    pub async fn start_tool_call(
592        &self,
593        call_id: String,
594        name: String,
595        args: serde_json::Value,
596    ) -> ToolInvocationReply {
597        self.start_tool_process(call_id, name, args).await
598    }
599
600    pub async fn await_tool_handle(
601        &self,
602        call_id: String,
603        handle: serde_json::Value,
604    ) -> ToolInvocationReply {
605        self.await_process_handle(call_id, handle).await
606    }
607
608    pub async fn cancel_tool_handle(
609        &self,
610        call_id: String,
611        handle: serde_json::Value,
612    ) -> ToolInvocationReply {
613        self.cancel_process_handle(call_id, handle).await
614    }
615
616    pub async fn signal_tool_handle(
617        &self,
618        call_id: String,
619        handle: serde_json::Value,
620        signal_name: String,
621        payload: serde_json::Value,
622    ) -> ToolInvocationReply {
623        self.signal_process_handle(call_id, handle, signal_name, payload)
624            .await
625    }
626}