Skip to main content

lash_core/session/
tool_execution.rs

1use super::execution_context::RuntimeExecutionContext;
2use crate::tool_dispatch::{
3    ToolDispatchOutcome, ToolPreparationOutcome,
4    dispatch_prepared_tool_call_with_execution_context, prepare_tool_call_with_context,
5    schedule_tool_batch,
6};
7use crate::{
8    ModelToolReturn, SessionEvent, ToolCallOutput, ToolCallRecord, ToolCancellation, ToolFailure,
9    ToolFailureClass, TurnActivityId, TurnEvent,
10};
11
12#[derive(Clone, Debug)]
13pub struct ToolInvocation {
14    pub id: String,
15    pub name: String,
16    pub args: serde_json::Value,
17}
18
19#[derive(Clone, Debug)]
20pub struct ToolInvocationReply {
21    pub output: ToolCallOutput,
22    pub record: Option<ToolCallRecord>,
23}
24
25impl ToolInvocationReply {
26    pub fn success(value: serde_json::Value) -> Self {
27        Self {
28            output: ToolCallOutput::success(value),
29            record: None,
30        }
31    }
32
33    pub fn error(value: serde_json::Value) -> Self {
34        let message = value
35            .as_str()
36            .map(ToOwned::to_owned)
37            .unwrap_or_else(|| value.to_string());
38        let mut failure = ToolFailure::tool(ToolFailureClass::Execution, "tool_error", message);
39        failure.raw =
40            Some(serde_json::from_value(value).unwrap_or_else(|_| {
41                crate::ToolValue::String("unserializable tool error".to_string())
42            }));
43        Self {
44            output: ToolCallOutput::failure(failure),
45            record: None,
46        }
47    }
48
49    pub fn from_output(output: ToolCallOutput) -> Self {
50        Self {
51            output,
52            record: None,
53        }
54    }
55
56    pub fn cancelled(message: impl Into<String>) -> Self {
57        Self::from_output(ToolCallOutput::cancelled(ToolCancellation::runtime(
58            message,
59        )))
60    }
61
62    pub(crate) fn with_record(mut self, record: ToolCallRecord) -> Self {
63        self.record = Some(record);
64        self
65    }
66}
67
68#[derive(Clone, Debug)]
69pub(crate) struct CompletedProtocolToolCall {
70    pub index: usize,
71    pub completed: crate::sansio::CompletedToolCall,
72    pub record: ToolCallRecord,
73}
74
75#[derive(Clone)]
76pub(crate) struct PreparedToolRun {
77    pub prepared: crate::PreparedToolCall,
78    pub index: usize,
79    pub parent_invocation: Option<crate::RuntimeInvocation>,
80    pub activity_id: TurnActivityId,
81}
82
83impl RuntimeExecutionContext<'_> {
84    fn prepared_tool_run(
85        &self,
86        prepared: crate::PreparedToolCall,
87        index: usize,
88        parent_invocation: Option<crate::RuntimeInvocation>,
89    ) -> PreparedToolRun {
90        let activity_id = TurnActivityId::new(format!("tool:{}", prepared.call_id));
91        PreparedToolRun {
92            prepared,
93            index,
94            parent_invocation,
95            activity_id,
96        }
97    }
98
99    #[expect(
100        clippy::too_many_arguments,
101        reason = "tool execution carries explicit runtime call metadata"
102    )]
103    pub(crate) async fn execute_tool_call(
104        &self,
105        call_id: String,
106        name: String,
107        args: serde_json::Value,
108        index: usize,
109        replay: Option<crate::llm::types::ProviderReplayMeta>,
110        parent_invocation: Option<crate::RuntimeInvocation>,
111        lashlang_execution_call_site: Option<crate::ToolLashlangExecutionCallSite>,
112    ) -> CompletedProtocolToolCall {
113        let _ = self
114            .dispatch
115            .event_tx
116            .send(SessionEvent::ToolCallStart {
117                call_id: Some(call_id.clone()),
118                name: name.clone(),
119                args: args.clone(),
120            })
121            .await;
122        let tool_correlation_id = TurnActivityId::new(format!("tool:{call_id}"));
123        self.emit_turn_activity(
124            tool_correlation_id.clone(),
125            TurnEvent::ToolCallStarted {
126                call_id: Some(call_id.clone()),
127                name: name.clone(),
128                args: args.clone(),
129            },
130        )
131        .await;
132
133        let parent_invocation = parent_invocation.or_else(|| self.parent_invocation.clone());
134        let mut dispatch = (*self.dispatch).clone();
135        dispatch.parent_invocation = parent_invocation.clone();
136        let pending = crate::sansio::PendingToolCall {
137            call_id: call_id.clone(),
138            tool_name: name,
139            args,
140            replay: replay.clone(),
141        };
142        let mut outcome =
143            match prepare_tool_call_with_context(&dispatch, pending, Some(call_id.clone())).await {
144                ToolPreparationOutcome::Prepared(prepared) => {
145                    let dispatch_context = std::sync::Arc::new(dispatch.clone());
146                    let tool_context =
147                        crate::ToolContext::from_dispatch(std::sync::Arc::clone(&dispatch_context))
148                            .prepared_call(&prepared)
149                            .cancellation_token(self.cancellation_token.clone())
150                            .parent_invocation(parent_invocation.clone())
151                            .lashlang_execution_call_site(lashlang_execution_call_site.clone())
152                            .build();
153                    dispatch_prepared_tool_call_with_execution_context(
154                        dispatch_context.as_ref(),
155                        prepared,
156                        None,
157                        tool_context,
158                    )
159                    .await
160                }
161                ToolPreparationOutcome::Completed(outcome) => *outcome,
162            };
163        outcome.record.call_id = Some(call_id.clone());
164
165        self.complete_tool_call(index, call_id, replay, outcome, tool_correlation_id)
166            .await
167    }
168
169    pub(crate) async fn prepare_tool_call(
170        &self,
171        pending: crate::sansio::PendingToolCall,
172    ) -> ToolPreparationOutcome {
173        let call_id = Some(pending.call_id.clone());
174        prepare_tool_call_with_context(self.dispatch.as_ref(), pending, call_id).await
175    }
176
177    pub(crate) async fn execute_prepared_tool_call(
178        &self,
179        prepared: crate::PreparedToolCall,
180        index: usize,
181        parent_invocation: Option<crate::RuntimeInvocation>,
182    ) -> CompletedProtocolToolCall {
183        self.execute_prepared_tool_call_inner(prepared, index, parent_invocation)
184            .await
185    }
186
187    async fn execute_prepared_tool_call_inner(
188        &self,
189        prepared: crate::PreparedToolCall,
190        index: usize,
191        parent_invocation: Option<crate::RuntimeInvocation>,
192    ) -> CompletedProtocolToolCall {
193        let call_id = prepared.call_id.clone();
194        let name = prepared.tool_name.clone();
195        let args = prepared.args.clone();
196        let replay = prepared.replay.clone();
197        let parent_invocation = parent_invocation.or_else(|| self.parent_invocation.clone());
198        let run = self.prepared_tool_run(prepared, index, parent_invocation);
199        let prepared = run.prepared.clone();
200        let _ = self
201            .dispatch
202            .event_tx
203            .send(SessionEvent::ToolCallStart {
204                call_id: Some(call_id.clone()),
205                name: name.clone(),
206                args: args.clone(),
207            })
208            .await;
209        let tool_correlation_id = run.activity_id.clone();
210        self.emit_turn_activity(
211            tool_correlation_id.clone(),
212            TurnEvent::ToolCallStarted {
213                call_id: Some(call_id.clone()),
214                name: name.clone(),
215                args: args.clone(),
216            },
217        )
218        .await;
219
220        let tool_context = crate::ToolContext::from_dispatch(std::sync::Arc::clone(&self.dispatch))
221            .prepared_call(&prepared)
222            .cancellation_token(self.cancellation_token.clone())
223            .parent_invocation(run.parent_invocation.clone())
224            .build();
225        let mut outcome = dispatch_prepared_tool_call_with_execution_context(
226            self.dispatch.as_ref(),
227            prepared,
228            None,
229            tool_context,
230        )
231        .await;
232        outcome.record.call_id = Some(call_id.clone());
233        tokio::task::yield_now().await;
234
235        self.complete_tool_call(run.index, call_id, replay, outcome, tool_correlation_id)
236            .await
237    }
238
239    pub(super) async fn await_process_with_cancellation(
240        &self,
241        process_id: &str,
242        parent_invocation: Option<crate::RuntimeInvocation>,
243        cancellation: Option<tokio_util::sync::CancellationToken>,
244    ) -> Result<crate::ProcessAwaitOutput, crate::PluginError> {
245        if let Some(cancellation) = cancellation {
246            tokio::select! {
247                result = self.dispatch.processes.await_process(
248                    process_id,
249                    self.process_scope(parent_invocation.clone()),
250                ) => result,
251                _ = cancellation.cancelled() => {
252                    let _ = self.dispatch.processes.cancel(
253                        &self.dispatch.session_id,
254                        process_id,
255                        self.process_scope(parent_invocation.clone()),
256                    ).await;
257                    self.dispatch.processes.await_process(
258                        process_id,
259                        self.process_scope(parent_invocation),
260                    ).await
261                }
262            }
263        } else {
264            self.dispatch
265                .processes
266                .await_process(process_id, self.process_scope(parent_invocation))
267                .await
268        }
269    }
270
271    pub(crate) async fn complete_tool_call(
272        &self,
273        index: usize,
274        call_id: String,
275        replay: Option<crate::llm::types::ProviderReplayMeta>,
276        outcome: ToolDispatchOutcome,
277        tool_correlation_id: TurnActivityId,
278    ) -> CompletedProtocolToolCall {
279        let output = outcome.record.output.clone();
280        let projection_output = output.clone();
281        let projection_tool_name = outcome.record.tool.clone();
282        let projection_args = outcome.record.args.clone();
283        let projection_duration_ms = outcome.record.duration_ms;
284        let projection_call_id = call_id.clone();
285        tokio::task::yield_now().await;
286        let plugins = std::sync::Arc::clone(&self.dispatch.plugins);
287        let projection_context = crate::plugin::ToolResultProjectionContext {
288            session_id: self.dispatch.session_id.clone(),
289            tool_name: projection_tool_name,
290            args: projection_args,
291            output: projection_output,
292            duration_ms: projection_duration_ms,
293            call_id: projection_call_id,
294        };
295        let model_return = match plugins.project_tool_result(projection_context).await {
296            Ok(projected) => projected,
297            Err(err) => ModelToolReturn::text(
298                call_id.clone(),
299                outcome.record.tool.clone(),
300                err.to_string(),
301            ),
302        };
303
304        self.emit_turn_activity(
305            tool_correlation_id,
306            TurnEvent::ToolCallCompleted {
307                call_id: Some(call_id.clone()),
308                name: outcome.record.tool.clone(),
309                args: outcome.record.args.clone(),
310                output: output.clone(),
311                duration_ms: outcome.record.duration_ms,
312            },
313        )
314        .await;
315
316        let record = ToolCallRecord {
317            call_id: Some(call_id.clone()),
318            tool: outcome.record.tool.clone(),
319            args: outcome.record.args.clone(),
320            output: output.clone(),
321            duration_ms: outcome.record.duration_ms,
322        };
323        CompletedProtocolToolCall {
324            index,
325            completed: crate::sansio::CompletedToolCall {
326                call_id,
327                tool_name: outcome.record.tool,
328                args: outcome.record.args,
329                output,
330                model_return,
331                duration_ms: outcome.record.duration_ms,
332                replay,
333            },
334            record,
335        }
336    }
337
338    pub async fn call_tool(
339        &self,
340        call_id: String,
341        name: String,
342        args: serde_json::Value,
343        index: usize,
344    ) -> ToolInvocationReply {
345        let executed = self
346            .execute_tool_call(call_id, name, args, index, None, None, None)
347            .await;
348        let reply = ToolInvocationReply::from_output(executed.completed.output);
349        reply.with_record(executed.record)
350    }
351
352    pub async fn call_tool_with_lashlang_execution_call_site(
353        &self,
354        call_id: String,
355        name: String,
356        args: serde_json::Value,
357        index: usize,
358        call_site: crate::ToolLashlangExecutionCallSite,
359    ) -> ToolInvocationReply {
360        let executed = self
361            .execute_tool_call(call_id, name, args, index, None, None, Some(call_site))
362            .await;
363        let reply = ToolInvocationReply::from_output(executed.completed.output);
364        reply.with_record(executed.record)
365    }
366
367    pub async fn call_tool_batch(&self, calls: Vec<ToolInvocation>) -> Vec<ToolInvocationReply> {
368        let indexed_calls = calls.into_iter().enumerate().collect::<Vec<_>>();
369        schedule_tool_batch(
370            indexed_calls,
371            |(index, _)| *index,
372            |(_, call)| self.tool_scheduling(&call.name),
373            |(index, call)| {
374                let ctx = self.clone();
375                async move { ctx.call_tool(call.id, call.name, call.args, index).await }
376            },
377        )
378        .await
379    }
380
381    pub async fn start_tool_call(
382        &self,
383        call_id: String,
384        name: String,
385        args: serde_json::Value,
386    ) -> ToolInvocationReply {
387        self.start_tool_process(call_id, name, args).await
388    }
389
390    pub async fn await_tool_handle(
391        &self,
392        call_id: String,
393        handle: serde_json::Value,
394    ) -> ToolInvocationReply {
395        self.await_process_handle(call_id, handle).await
396    }
397
398    pub async fn cancel_tool_handle(
399        &self,
400        call_id: String,
401        handle: serde_json::Value,
402    ) -> ToolInvocationReply {
403        self.cancel_process_handle(call_id, handle).await
404    }
405
406    pub async fn signal_tool_handle(
407        &self,
408        call_id: String,
409        handle: serde_json::Value,
410        payload: serde_json::Value,
411    ) -> ToolInvocationReply {
412        self.signal_process_handle(call_id, handle, payload).await
413    }
414}