Skip to main content

lash_core/session/
tool_execution.rs

1use super::execution_context::RuntimeExecutionContext;
2use crate::tool_dispatch::{
3    ToolCallLaunch, ToolDispatchOutcome, ToolPreparationOutcome,
4    dispatch_prepared_tool_call_launch_with_execution_context,
5    finalize_tool_result_with_execution_context, prepare_tool_call_with_context,
6    schedule_tool_batch,
7};
8use crate::{
9    ModelToolReturn, SessionEvent, ToolCallOutput, ToolCallRecord, ToolCancellation, ToolFailure,
10    ToolFailureClass, ToolResult, TurnActivityId, TurnEvent,
11};
12
13#[derive(Clone, Debug)]
14pub struct ToolInvocation {
15    pub id: String,
16    pub name: String,
17    pub args: serde_json::Value,
18}
19
20#[derive(Clone, Debug)]
21pub struct ToolInvocationReply {
22    pub output: ToolCallOutput,
23    pub record: Option<ToolCallRecord>,
24}
25
26impl ToolInvocationReply {
27    pub fn success(value: serde_json::Value) -> Self {
28        Self {
29            output: ToolCallOutput::success(value),
30            record: None,
31        }
32    }
33
34    pub fn error(value: serde_json::Value) -> Self {
35        let message = value
36            .as_str()
37            .map(ToOwned::to_owned)
38            .unwrap_or_else(|| value.to_string());
39        let mut failure = ToolFailure::tool(ToolFailureClass::Execution, "tool_error", message);
40        failure.raw =
41            Some(serde_json::from_value(value).unwrap_or_else(|_| {
42                crate::ToolValue::String("unserializable tool error".to_string())
43            }));
44        Self {
45            output: ToolCallOutput::failure(failure),
46            record: None,
47        }
48    }
49
50    pub fn from_output(output: ToolCallOutput) -> Self {
51        Self {
52            output,
53            record: None,
54        }
55    }
56
57    pub fn cancelled(message: impl Into<String>) -> Self {
58        Self::from_output(ToolCallOutput::cancelled(ToolCancellation::runtime(
59            message,
60        )))
61    }
62
63    pub(crate) fn with_record(mut self, record: ToolCallRecord) -> Self {
64        self.record = Some(record);
65        self
66    }
67}
68
69#[derive(Clone, Debug)]
70pub(crate) struct CompletedProtocolToolCall {
71    pub completed: crate::sansio::CompletedToolCall,
72    pub record: ToolCallRecord,
73}
74
75pub(crate) enum ProtocolToolCallLaunch {
76    Done(CompletedProtocolToolCall),
77    Pending(crate::tool_dispatch::PendingToolDispatchOutcome),
78}
79
80#[derive(Clone)]
81pub(crate) struct PreparedToolRun {
82    pub prepared: crate::PreparedToolCall,
83    pub index: usize,
84    pub parent_invocation: Option<crate::RuntimeInvocation>,
85    pub activity_id: TurnActivityId,
86}
87
88impl RuntimeExecutionContext<'_> {
89    fn prepared_tool_run(
90        &self,
91        prepared: crate::PreparedToolCall,
92        index: usize,
93        parent_invocation: Option<crate::RuntimeInvocation>,
94    ) -> PreparedToolRun {
95        let activity_id = TurnActivityId::new(format!("tool:{}", prepared.call_id));
96        PreparedToolRun {
97            prepared,
98            index,
99            parent_invocation,
100            activity_id,
101        }
102    }
103
104    #[expect(
105        clippy::too_many_arguments,
106        reason = "tool execution carries explicit runtime call metadata"
107    )]
108    pub(crate) async fn execute_tool_call(
109        &self,
110        call_id: String,
111        name: String,
112        args: serde_json::Value,
113        index: usize,
114        replay: Option<crate::llm::types::ProviderReplayMeta>,
115        parent_invocation: Option<crate::RuntimeInvocation>,
116        lashlang_execution_call_site: Option<crate::ToolLashlangExecutionCallSite>,
117    ) -> CompletedProtocolToolCall {
118        let _ = self
119            .dispatch
120            .event_tx
121            .send(SessionEvent::ToolCallStart {
122                call_id: Some(call_id.clone()),
123                name: name.clone(),
124                args: args.clone(),
125            })
126            .await;
127        let tool_correlation_id = TurnActivityId::new(format!("tool:{call_id}"));
128        self.emit_turn_activity(
129            tool_correlation_id.clone(),
130            TurnEvent::ToolCallStarted {
131                call_id: Some(call_id.clone()),
132                name: name.clone(),
133                args: args.clone(),
134            },
135        )
136        .await;
137
138        let parent_invocation = parent_invocation.or_else(|| self.parent_invocation.clone());
139        let mut dispatch = (*self.dispatch).clone();
140        dispatch.parent_invocation = parent_invocation.clone();
141        let pending = crate::sansio::PendingToolCall {
142            call_id: call_id.clone(),
143            tool_name: name,
144            args,
145            replay: replay.clone(),
146        };
147        let launch =
148            match prepare_tool_call_with_context(&dispatch, pending, Some(call_id.clone())).await {
149                ToolPreparationOutcome::Prepared(prepared) => {
150                    let dispatch_context = std::sync::Arc::new(dispatch.clone());
151                    let tool_context =
152                        crate::ToolContext::from_dispatch(std::sync::Arc::clone(&dispatch_context))
153                            .prepared_call(&prepared)
154                            .cancellation_token(self.cancellation_token.clone())
155                            .runtime_process_id(self.runtime_process_id.clone())
156                            .parent_invocation(parent_invocation.clone())
157                            .lashlang_execution_call_site(lashlang_execution_call_site.clone())
158                            .build();
159                    dispatch_prepared_tool_call_launch_with_execution_context(
160                        dispatch_context.as_ref(),
161                        prepared,
162                        None,
163                        tool_context,
164                    )
165                    .await
166                }
167                ToolPreparationOutcome::Completed(outcome) => ToolCallLaunch::Done(*outcome),
168            };
169        let mut outcome = match launch {
170            ToolCallLaunch::Done(outcome) => outcome,
171            ToolCallLaunch::Pending(pending) => {
172                self.await_pending_tool_dispatch_outcome(
173                    &call_id,
174                    parent_invocation.clone(),
175                    pending,
176                    self.cancellation_token.clone(),
177                )
178                .await
179            }
180        };
181        outcome.record.call_id = Some(call_id.clone());
182
183        self.complete_tool_call(index, call_id, replay, outcome, tool_correlation_id)
184            .await
185    }
186
187    pub(crate) async fn prepare_tool_call(
188        &self,
189        pending: crate::sansio::PendingToolCall,
190    ) -> ToolPreparationOutcome {
191        let call_id = Some(pending.call_id.clone());
192        prepare_tool_call_with_context(self.dispatch.as_ref(), pending, call_id).await
193    }
194
195    pub(crate) async fn execute_prepared_tool_call_launch(
196        &self,
197        prepared: crate::PreparedToolCall,
198        index: usize,
199        parent_invocation: Option<crate::RuntimeInvocation>,
200    ) -> crate::runtime::ToolCallLaunch {
201        match Box::pin(self.execute_prepared_tool_call_launch_inner(
202            prepared,
203            index,
204            parent_invocation,
205        ))
206        .await
207        {
208            ProtocolToolCallLaunch::Done(completed) => crate::runtime::ToolCallLaunch::Done {
209                result: completed.completed,
210            },
211            ProtocolToolCallLaunch::Pending(pending) => crate::runtime::ToolCallLaunch::Pending {
212                key: pending.key,
213                pending: pending.pending,
214                duration_ms: pending.duration_ms,
215            },
216        }
217    }
218
219    async fn execute_prepared_tool_call_launch_inner(
220        &self,
221        prepared: crate::PreparedToolCall,
222        index: usize,
223        parent_invocation: Option<crate::RuntimeInvocation>,
224    ) -> ProtocolToolCallLaunch {
225        let call_id = prepared.call_id.clone();
226        let name = prepared.tool_name.clone();
227        let args = prepared.args.clone();
228        let replay = prepared.replay.clone();
229        let parent_invocation = parent_invocation.or_else(|| self.parent_invocation.clone());
230        let run = self.prepared_tool_run(prepared, index, parent_invocation);
231        let prepared = run.prepared.clone();
232        let _ = self
233            .dispatch
234            .event_tx
235            .send(SessionEvent::ToolCallStart {
236                call_id: Some(call_id.clone()),
237                name: name.clone(),
238                args: args.clone(),
239            })
240            .await;
241        let tool_correlation_id = run.activity_id.clone();
242        self.emit_turn_activity(
243            tool_correlation_id.clone(),
244            TurnEvent::ToolCallStarted {
245                call_id: Some(call_id.clone()),
246                name: name.clone(),
247                args: args.clone(),
248            },
249        )
250        .await;
251
252        let tool_context = crate::ToolContext::from_dispatch(std::sync::Arc::clone(&self.dispatch))
253            .prepared_call(&prepared)
254            .cancellation_token(self.cancellation_token.clone())
255            .runtime_process_id(self.runtime_process_id.clone())
256            .parent_invocation(run.parent_invocation.clone())
257            .build();
258        let outcome = Box::pin(dispatch_prepared_tool_call_launch_with_execution_context(
259            self.dispatch.as_ref(),
260            prepared,
261            None,
262            tool_context,
263        ))
264        .await;
265        match outcome {
266            ToolCallLaunch::Done(mut outcome) => {
267                outcome.record.call_id = Some(call_id.clone());
268                tokio::task::yield_now().await;
269                let completed = self
270                    .complete_tool_call(run.index, call_id, replay, outcome, tool_correlation_id)
271                    .await;
272                ProtocolToolCallLaunch::Done(completed)
273            }
274            ToolCallLaunch::Pending(pending) => ProtocolToolCallLaunch::Pending(pending),
275        }
276    }
277
278    pub(super) async fn await_process_with_cancellation(
279        &self,
280        process_id: &str,
281        parent_invocation: Option<crate::RuntimeInvocation>,
282        cancellation: Option<tokio_util::sync::CancellationToken>,
283    ) -> Result<crate::ProcessAwaitOutput, crate::PluginError> {
284        let _phase = self.named_phase("rlm_process.await_handle");
285        if let Some(cancellation) = cancellation {
286            tokio::select! {
287                result = self.dispatch.processes.await_process(
288                    process_id,
289                    self.process_scope(parent_invocation.clone()),
290                ) => result,
291                _ = cancellation.cancelled() => {
292                    let _ = self.dispatch.processes.cancel(
293                        &self.dispatch.session_id,
294                        process_id,
295                        self.process_scope(parent_invocation.clone()),
296                    ).await;
297                    self.dispatch.processes.await_process(
298                        process_id,
299                        self.process_scope(parent_invocation),
300                    ).await
301                }
302            }
303        } else {
304            self.dispatch
305                .processes
306                .await_process(process_id, self.process_scope(parent_invocation))
307                .await
308        }
309    }
310
311    pub(crate) async fn complete_tool_call(
312        &self,
313        _index: usize,
314        call_id: String,
315        replay: Option<crate::llm::types::ProviderReplayMeta>,
316        outcome: ToolDispatchOutcome,
317        tool_correlation_id: TurnActivityId,
318    ) -> CompletedProtocolToolCall {
319        let output = outcome.record.output.clone();
320        let projection_output = output.clone();
321        let projection_tool_name = outcome.record.tool.clone();
322        let projection_args = outcome.record.args.clone();
323        let projection_duration_ms = outcome.record.duration_ms;
324        let projection_call_id = call_id.clone();
325        tokio::task::yield_now().await;
326        let plugins = std::sync::Arc::clone(&self.dispatch.plugins);
327        let projection_context = crate::plugin::ToolResultProjectionContext {
328            session_id: self.dispatch.session_id.clone(),
329            tool_name: projection_tool_name,
330            args: projection_args,
331            output: projection_output,
332            duration_ms: projection_duration_ms,
333            call_id: projection_call_id,
334        };
335        let model_return = match plugins.project_tool_result(projection_context).await {
336            Ok(projected) => projected,
337            Err(err) => ModelToolReturn::text(
338                call_id.clone(),
339                outcome.record.tool.clone(),
340                err.to_string(),
341            ),
342        };
343
344        self.emit_turn_activity(
345            tool_correlation_id,
346            TurnEvent::ToolCallCompleted {
347                call_id: Some(call_id.clone()),
348                name: outcome.record.tool.clone(),
349                args: outcome.record.args.clone(),
350                output: output.clone(),
351                duration_ms: outcome.record.duration_ms,
352            },
353        )
354        .await;
355
356        let record = ToolCallRecord {
357            call_id: Some(call_id.clone()),
358            tool: outcome.record.tool.clone(),
359            args: outcome.record.args.clone(),
360            output: output.clone(),
361            duration_ms: outcome.record.duration_ms,
362        };
363        CompletedProtocolToolCall {
364            completed: crate::sansio::CompletedToolCall {
365                call_id,
366                tool_name: outcome.record.tool,
367                args: outcome.record.args,
368                output,
369                model_return,
370                duration_ms: outcome.record.duration_ms,
371                replay,
372            },
373            record,
374        }
375    }
376
377    pub(crate) async fn pending_completion_dispatch_outcome(
378        &self,
379        tool_name: String,
380        args: serde_json::Value,
381        resolution: crate::Resolution,
382        duration_ms: u64,
383    ) -> ToolDispatchOutcome {
384        let output = crate::tool_result::tool_output_from_completion_resolution(resolution);
385        let result = finalize_tool_result_with_execution_context(
386            self.dispatch.as_ref(),
387            &tool_name,
388            &args,
389            ToolResult::from_output(output),
390            duration_ms,
391        )
392        .await;
393        let output = result.into_done_output().unwrap_or_else(|_| {
394            ToolCallOutput::failure(ToolFailure::runtime(
395                ToolFailureClass::Internal,
396                "pending_tool_not_finalized",
397                "pending tool result reached a completed-output projection path",
398            ))
399        });
400        ToolDispatchOutcome {
401            record: ToolCallRecord {
402                call_id: None,
403                tool: tool_name,
404                args,
405                output,
406                duration_ms,
407            },
408        }
409    }
410
411    async fn await_pending_tool_dispatch_outcome(
412        &self,
413        call_id: &str,
414        parent_invocation: Option<crate::RuntimeInvocation>,
415        pending: crate::tool_dispatch::PendingToolDispatchOutcome,
416        cancellation: Option<tokio_util::sync::CancellationToken>,
417    ) -> ToolDispatchOutcome {
418        let fallback;
419        let parent = if let Some(parent) = parent_invocation.as_ref() {
420            parent
421        } else {
422            fallback = crate::RuntimeInvocation::effect(
423                crate::RuntimeScope::new(&self.dispatch.session_id),
424                format!("tool:{call_id}:await"),
425                crate::RuntimeEffectKind::AwaitEvent,
426                format!("tool:{call_id}:await"),
427            );
428            &fallback
429        };
430        let parent_effect_id = parent.effect_id().unwrap_or("tool");
431        let invocation = crate::runtime::causal::child_effect_invocation(
432            parent,
433            format!("{parent_effect_id}:{call_id}:await"),
434            crate::RuntimeEffectKind::AwaitEvent,
435            format!("{call_id}:await"),
436        );
437        let cancellation = cancellation.unwrap_or_default();
438        let deadline = pending
439            .pending
440            .deadline
441            .map(|duration| std::time::Instant::now() + duration);
442        let outcome = self
443            .dispatch
444            .effect_controller
445            .controller()
446            .execute_effect(
447                crate::RuntimeEffectEnvelope::new(
448                    invocation,
449                    crate::RuntimeEffectCommand::AwaitEvent { key: pending.key },
450                ),
451                crate::RuntimeEffectLocalExecutor::await_event(cancellation, deadline),
452            )
453            .await;
454        let resolution = match outcome.and_then(crate::RuntimeEffectOutcome::into_await_event) {
455            Ok(resolution) => resolution,
456            Err(err) => {
457                return ToolDispatchOutcome {
458                    record: ToolCallRecord {
459                        call_id: None,
460                        tool: pending.tool_name,
461                        args: pending.args,
462                        output: ToolCallOutput::failure(ToolFailure::runtime(
463                            ToolFailureClass::Internal,
464                            "pending_tool_completion_failed",
465                            err.to_string(),
466                        )),
467                        duration_ms: pending.duration_ms,
468                    },
469                };
470            }
471        };
472        self.pending_completion_dispatch_outcome(
473            pending.tool_name,
474            pending.args,
475            resolution,
476            pending.duration_ms,
477        )
478        .await
479    }
480
481    pub async fn call_tool(
482        &self,
483        call_id: String,
484        name: String,
485        args: serde_json::Value,
486        index: usize,
487    ) -> ToolInvocationReply {
488        let executed = self
489            .execute_tool_call(call_id, name, args, index, None, None, None)
490            .await;
491        let reply = ToolInvocationReply::from_output(executed.completed.output);
492        reply.with_record(executed.record)
493    }
494
495    pub async fn call_tool_with_lashlang_execution_call_site(
496        &self,
497        call_id: String,
498        name: String,
499        args: serde_json::Value,
500        index: usize,
501        call_site: crate::ToolLashlangExecutionCallSite,
502    ) -> ToolInvocationReply {
503        let executed = self
504            .execute_tool_call(call_id, name, args, index, None, None, Some(call_site))
505            .await;
506        let reply = ToolInvocationReply::from_output(executed.completed.output);
507        reply.with_record(executed.record)
508    }
509
510    pub async fn call_tool_batch(&self, calls: Vec<ToolInvocation>) -> Vec<ToolInvocationReply> {
511        let indexed_calls = calls.into_iter().enumerate().collect::<Vec<_>>();
512        schedule_tool_batch(
513            indexed_calls,
514            |(index, _)| *index,
515            |(_, call)| self.tool_scheduling(&call.name),
516            |(index, call)| {
517                let ctx = self.clone();
518                async move { ctx.call_tool(call.id, call.name, call.args, index).await }
519            },
520        )
521        .await
522    }
523
524    pub async fn start_tool_call(
525        &self,
526        call_id: String,
527        name: String,
528        args: serde_json::Value,
529    ) -> ToolInvocationReply {
530        self.start_tool_process(call_id, name, args).await
531    }
532
533    pub async fn await_tool_handle(
534        &self,
535        call_id: String,
536        handle: serde_json::Value,
537    ) -> ToolInvocationReply {
538        self.await_process_handle(call_id, handle).await
539    }
540
541    pub async fn cancel_tool_handle(
542        &self,
543        call_id: String,
544        handle: serde_json::Value,
545    ) -> ToolInvocationReply {
546        self.cancel_process_handle(call_id, handle).await
547    }
548
549    pub async fn signal_tool_handle(
550        &self,
551        call_id: String,
552        handle: serde_json::Value,
553        signal_name: String,
554        payload: serde_json::Value,
555    ) -> ToolInvocationReply {
556        self.signal_process_handle(call_id, handle, signal_name, payload)
557            .await
558    }
559}