lash_core/session/tool_execution.rs

1use super::execution_context::RuntimeExecutionContext;
2use crate::tool_dispatch::{
3    ToolCallLaunch, ToolDispatchOutcome, ToolPreparationOutcome,
4    dispatch_prepared_tool_call_launch_with_execution_context,
5    finalize_tool_result_with_execution_context, prepare_tool_call_with_context,
6    schedule_tool_batch,
7};
8use crate::{
9    ModelToolReturn, SessionEvent, ToolCallOutput, ToolCallRecord, ToolCancellation, ToolFailure,
10    ToolFailureClass, ToolResult, TurnActivityId, TurnEvent,
11};
12
13#[derive(Clone, Debug)]
14pub struct ToolInvocation {
15    pub id: String,
16    pub name: String,
17    pub args: serde_json::Value,
18}
19
20#[derive(Clone, Debug)]
21pub struct ToolInvocationReply {
22    pub output: ToolCallOutput,
23    pub record: Option<ToolCallRecord>,
24}
25
26impl ToolInvocationReply {
27    pub fn success(value: serde_json::Value) -> Self {
28        Self {
29            output: ToolCallOutput::success(value),
30            record: None,
31        }
32    }
33
34    pub fn error(value: serde_json::Value) -> Self {
35        let message = value
36            .as_str()
37            .map(ToOwned::to_owned)
38            .unwrap_or_else(|| value.to_string());
39        let mut failure = ToolFailure::tool(ToolFailureClass::Execution, "tool_error", message);
40        failure.raw =
41            Some(serde_json::from_value(value).unwrap_or_else(|_| {
42                crate::ToolValue::String("unserializable tool error".to_string())
43            }));
44        Self {
45            output: ToolCallOutput::failure(failure),
46            record: None,
47        }
48    }
49
50    pub fn from_output(output: ToolCallOutput) -> Self {
51        Self {
52            output,
53            record: None,
54        }
55    }
56
57    pub fn cancelled(message: impl Into<String>) -> Self {
58        Self::from_output(ToolCallOutput::cancelled(ToolCancellation::runtime(
59            message,
60        )))
61    }
62
63    pub(crate) fn with_record(mut self, record: ToolCallRecord) -> Self {
64        self.record = Some(record);
65        self
66    }
67}
68
69#[derive(Clone, Debug)]
70pub(crate) struct CompletedProtocolToolCall {
71    pub completed: crate::sansio::CompletedToolCall,
72    pub record: ToolCallRecord,
73}
74
75pub(crate) enum ProtocolToolCallLaunch {
76    Done(CompletedProtocolToolCall),
77    Pending(crate::tool_dispatch::PendingToolDispatchOutcome),
78}
79
80#[derive(Clone)]
81pub(crate) struct PreparedToolRun {
82    pub prepared: crate::PreparedToolCall,
83    pub index: usize,
84    pub parent_invocation: Option<crate::RuntimeInvocation>,
85    pub activity_id: TurnActivityId,
86}
87
88impl RuntimeExecutionContext<'_> {
89    fn prepared_tool_run(
90        &self,
91        prepared: crate::PreparedToolCall,
92        index: usize,
93        parent_invocation: Option<crate::RuntimeInvocation>,
94    ) -> PreparedToolRun {
95        let activity_id = TurnActivityId::new(format!("tool:{}", prepared.call_id));
96        PreparedToolRun {
97            prepared,
98            index,
99            parent_invocation,
100            activity_id,
101        }
102    }
103
104    #[expect(
105        clippy::too_many_arguments,
106        reason = "tool execution carries explicit runtime call metadata"
107    )]
108    pub(crate) async fn execute_tool_call(
109        &self,
110        call_id: String,
111        name: String,
112        args: serde_json::Value,
113        index: usize,
114        replay: Option<crate::llm::types::ProviderReplayMeta>,
115        parent_invocation: Option<crate::RuntimeInvocation>,
116        lashlang_execution_call_site: Option<crate::ToolLashlangExecutionCallSite>,
117    ) -> CompletedProtocolToolCall {
118        let _ = self
119            .dispatch
120            .event_tx
121            .send(SessionEvent::ToolCallStart {
122                call_id: Some(call_id.clone()),
123                name: name.clone(),
124                args: args.clone(),
125            })
126            .await;
127        let tool_correlation_id = TurnActivityId::new(format!("tool:{call_id}"));
128        self.emit_turn_activity(
129            tool_correlation_id.clone(),
130            TurnEvent::ToolCallStarted {
131                call_id: Some(call_id.clone()),
132                name: name.clone(),
133                args: args.clone(),
134            },
135        )
136        .await;
137
138        let parent_invocation = parent_invocation.or_else(|| self.parent_invocation.clone());
139        let mut dispatch = (*self.dispatch).clone();
140        dispatch.parent_invocation = parent_invocation.clone();
141        let pending = crate::sansio::PendingToolCall {
142            call_id: call_id.clone(),
143            tool_name: name,
144            args,
145            replay: replay.clone(),
146        };
147        let launch =
148            match prepare_tool_call_with_context(&dispatch, pending, Some(call_id.clone())).await {
149                ToolPreparationOutcome::Prepared(prepared) => {
150                    let dispatch_context = std::sync::Arc::new(dispatch.clone());
151                    let mut tool_context =
152                        crate::ToolContext::from_dispatch(std::sync::Arc::clone(&dispatch_context))
153                            .prepared_call(&prepared)
154                            .cancellation_token(self.cancellation_token.clone())
155                            .runtime_process_id(self.runtime_process_id.clone())
156                            .parent_invocation(parent_invocation.clone())
157                            .lashlang_execution_call_site(lashlang_execution_call_site.clone());
158                    if let Some(process_events) = self.process_event_context.as_ref() {
159                        tool_context = tool_context.process_events(
160                            process_events.process_id.clone(),
161                            std::sync::Arc::clone(&process_events.registry),
162                            process_events.store.clone(),
163                            process_events.session_store_factory.clone(),
164                            process_events.queued_work_poke.clone(),
165                        );
166                    }
167                    let tool_context = tool_context.build();
168                    dispatch_prepared_tool_call_launch_with_execution_context(
169                        dispatch_context.as_ref(),
170                        prepared,
171                        None,
172                        tool_context,
173                    )
174                    .await
175                }
176                ToolPreparationOutcome::Completed(outcome) => ToolCallLaunch::Done(*outcome),
177            };
178        let mut outcome = match launch {
179            ToolCallLaunch::Done(outcome) => outcome,
180            ToolCallLaunch::Pending(pending) => {
181                self.await_pending_tool_dispatch_outcome(
182                    &call_id,
183                    parent_invocation.clone(),
184                    pending,
185                    self.cancellation_token.clone(),
186                )
187                .await
188            }
189        };
190        outcome.record.call_id = Some(call_id.clone());
191
192        self.complete_tool_call(index, call_id, replay, outcome, tool_correlation_id)
193            .await
194    }
195
196    pub(crate) async fn prepare_tool_call(
197        &self,
198        pending: crate::sansio::PendingToolCall,
199    ) -> ToolPreparationOutcome {
200        let call_id = Some(pending.call_id.clone());
201        prepare_tool_call_with_context(self.dispatch.as_ref(), pending, call_id).await
202    }
203
204    pub(crate) async fn execute_prepared_tool_call_launch(
205        &self,
206        prepared: crate::PreparedToolCall,
207        index: usize,
208        parent_invocation: Option<crate::RuntimeInvocation>,
209    ) -> crate::runtime::ToolCallLaunch {
210        match Box::pin(self.execute_prepared_tool_call_launch_inner(
211            prepared,
212            index,
213            parent_invocation,
214        ))
215        .await
216        {
217            ProtocolToolCallLaunch::Done(completed) => crate::runtime::ToolCallLaunch::Done {
218                result: completed.completed,
219            },
220            ProtocolToolCallLaunch::Pending(pending) => crate::runtime::ToolCallLaunch::Pending {
221                key: pending.key,
222                pending: pending.pending,
223                duration_ms: pending.duration_ms,
224            },
225        }
226    }
227
228    async fn execute_prepared_tool_call_launch_inner(
229        &self,
230        prepared: crate::PreparedToolCall,
231        index: usize,
232        parent_invocation: Option<crate::RuntimeInvocation>,
233    ) -> ProtocolToolCallLaunch {
234        let call_id = prepared.call_id.clone();
235        let name = prepared.tool_name.clone();
236        let args = prepared.args.clone();
237        let replay = prepared.replay.clone();
238        let parent_invocation = parent_invocation.or_else(|| self.parent_invocation.clone());
239        let run = self.prepared_tool_run(prepared, index, parent_invocation);
240        let prepared = run.prepared.clone();
241        let _ = self
242            .dispatch
243            .event_tx
244            .send(SessionEvent::ToolCallStart {
245                call_id: Some(call_id.clone()),
246                name: name.clone(),
247                args: args.clone(),
248            })
249            .await;
250        let tool_correlation_id = run.activity_id.clone();
251        self.emit_turn_activity(
252            tool_correlation_id.clone(),
253            TurnEvent::ToolCallStarted {
254                call_id: Some(call_id.clone()),
255                name: name.clone(),
256                args: args.clone(),
257            },
258        )
259        .await;
260
261        let mut tool_context =
262            crate::ToolContext::from_dispatch(std::sync::Arc::clone(&self.dispatch))
263                .prepared_call(&prepared)
264                .cancellation_token(self.cancellation_token.clone())
265                .runtime_process_id(self.runtime_process_id.clone())
266                .parent_invocation(run.parent_invocation.clone());
267        if let Some(process_events) = self.process_event_context.as_ref() {
268            tool_context = tool_context.process_events(
269                process_events.process_id.clone(),
270                std::sync::Arc::clone(&process_events.registry),
271                process_events.store.clone(),
272                process_events.session_store_factory.clone(),
273                process_events.queued_work_poke.clone(),
274            );
275        }
276        let tool_context = tool_context.build();
277        let outcome = Box::pin(dispatch_prepared_tool_call_launch_with_execution_context(
278            self.dispatch.as_ref(),
279            prepared,
280            None,
281            tool_context,
282        ))
283        .await;
284        match outcome {
285            ToolCallLaunch::Done(mut outcome) => {
286                outcome.record.call_id = Some(call_id.clone());
287                tokio::task::yield_now().await;
288                let completed = self
289                    .complete_tool_call(run.index, call_id, replay, outcome, tool_correlation_id)
290                    .await;
291                ProtocolToolCallLaunch::Done(completed)
292            }
293            ToolCallLaunch::Pending(pending) => ProtocolToolCallLaunch::Pending(pending),
294        }
295    }
296
297    pub(super) async fn await_process_with_cancellation(
298        &self,
299        process_id: &str,
300        parent_invocation: Option<crate::RuntimeInvocation>,
301        cancellation: Option<tokio_util::sync::CancellationToken>,
302    ) -> Result<crate::ProcessAwaitOutput, crate::PluginError> {
303        let _phase = self.named_phase("rlm_process.await_handle");
304        if let Some(cancellation) = cancellation {
305            tokio::select! {
306                result = self.dispatch.processes.await_process(
307                    process_id,
308                    self.process_scope(parent_invocation.clone()),
309                ) => result,
310                _ = cancellation.cancelled() => {
311                    let _ = self.dispatch.processes.cancel(
312                        &self.dispatch.session_id,
313                        process_id,
314                        self.process_scope(parent_invocation.clone()),
315                    ).await;
316                    self.dispatch.processes.await_process(
317                        process_id,
318                        self.process_scope(parent_invocation),
319                    ).await
320                }
321            }
322        } else {
323            self.dispatch
324                .processes
325                .await_process(process_id, self.process_scope(parent_invocation))
326                .await
327        }
328    }
329
330    pub(crate) async fn complete_tool_call(
331        &self,
332        _index: usize,
333        call_id: String,
334        replay: Option<crate::llm::types::ProviderReplayMeta>,
335        outcome: ToolDispatchOutcome,
336        tool_correlation_id: TurnActivityId,
337    ) -> CompletedProtocolToolCall {
338        let output = outcome.record.output.clone();
339        let projection_output = output.clone();
340        let projection_tool_name = outcome.record.tool.clone();
341        let projection_args = outcome.record.args.clone();
342        let projection_duration_ms = outcome.record.duration_ms;
343        let projection_call_id = call_id.clone();
344        tokio::task::yield_now().await;
345        let plugins = std::sync::Arc::clone(&self.dispatch.plugins);
346        let projection_context = crate::plugin::ToolResultProjectionContext {
347            session_id: self.dispatch.session_id.clone(),
348            tool_name: projection_tool_name,
349            args: projection_args,
350            output: projection_output,
351            duration_ms: projection_duration_ms,
352            call_id: projection_call_id,
353        };
354        let model_return = match plugins.project_tool_result(projection_context).await {
355            Ok(projected) => projected,
356            Err(err) => ModelToolReturn::text(
357                call_id.clone(),
358                outcome.record.tool.clone(),
359                err.to_string(),
360            ),
361        };
362
363        self.emit_turn_activity(
364            tool_correlation_id,
365            TurnEvent::ToolCallCompleted {
366                call_id: Some(call_id.clone()),
367                name: outcome.record.tool.clone(),
368                args: outcome.record.args.clone(),
369                output: output.clone(),
370                duration_ms: outcome.record.duration_ms,
371            },
372        )
373        .await;
374
375        let record = ToolCallRecord {
376            call_id: Some(call_id.clone()),
377            tool: outcome.record.tool.clone(),
378            args: outcome.record.args.clone(),
379            output: output.clone(),
380            duration_ms: outcome.record.duration_ms,
381        };
382        CompletedProtocolToolCall {
383            completed: crate::sansio::CompletedToolCall {
384                call_id,
385                tool_name: outcome.record.tool,
386                args: outcome.record.args,
387                output,
388                model_return,
389                duration_ms: outcome.record.duration_ms,
390                replay,
391            },
392            record,
393        }
394    }
395
396    pub(crate) async fn pending_completion_dispatch_outcome(
397        &self,
398        tool_name: String,
399        args: serde_json::Value,
400        resolution: crate::Resolution,
401        duration_ms: u64,
402    ) -> ToolDispatchOutcome {
403        let output = crate::tool_result::tool_output_from_completion_resolution(resolution);
404        let result = finalize_tool_result_with_execution_context(
405            self.dispatch.as_ref(),
406            &tool_name,
407            &args,
408            ToolResult::from_output(output),
409            duration_ms,
410        )
411        .await;
412        let output = result.into_done_output().unwrap_or_else(|_| {
413            ToolCallOutput::failure(ToolFailure::runtime(
414                ToolFailureClass::Internal,
415                "pending_tool_not_finalized",
416                "pending tool result reached a completed-output projection path",
417            ))
418        });
419        ToolDispatchOutcome {
420            record: ToolCallRecord {
421                call_id: None,
422                tool: tool_name,
423                args,
424                output,
425                duration_ms,
426            },
427        }
428    }
429
430    async fn await_pending_tool_dispatch_outcome(
431        &self,
432        call_id: &str,
433        parent_invocation: Option<crate::RuntimeInvocation>,
434        pending: crate::tool_dispatch::PendingToolDispatchOutcome,
435        cancellation: Option<tokio_util::sync::CancellationToken>,
436    ) -> ToolDispatchOutcome {
437        let fallback;
438        let parent = if let Some(parent) = parent_invocation.as_ref() {
439            parent
440        } else {
441            fallback = crate::RuntimeInvocation::effect(
442                crate::RuntimeScope::new(&self.dispatch.session_id),
443                format!("tool:{call_id}:await"),
444                crate::RuntimeEffectKind::AwaitEvent,
445                format!("tool:{call_id}:await"),
446            );
447            &fallback
448        };
449        let parent_effect_id = parent.effect_id().unwrap_or("tool");
450        let invocation = crate::runtime::causal::child_effect_invocation(
451            parent,
452            format!("{parent_effect_id}:{call_id}:await"),
453            crate::RuntimeEffectKind::AwaitEvent,
454            format!("{call_id}:await"),
455        );
456        let cancellation = cancellation.unwrap_or_default();
457        let deadline = pending
458            .pending
459            .deadline
460            .map(|duration| std::time::Instant::now() + duration);
461        let outcome = self
462            .dispatch
463            .effect_controller
464            .controller()
465            .execute_effect(
466                crate::RuntimeEffectEnvelope::new(
467                    invocation,
468                    crate::RuntimeEffectCommand::AwaitEvent { key: pending.key },
469                ),
470                crate::RuntimeEffectLocalExecutor::await_event(cancellation, deadline),
471            )
472            .await;
473        let resolution = match outcome.and_then(crate::RuntimeEffectOutcome::into_await_event) {
474            Ok(resolution) => resolution,
475            Err(err) => {
476                return ToolDispatchOutcome {
477                    record: ToolCallRecord {
478                        call_id: None,
479                        tool: pending.tool_name,
480                        args: pending.args,
481                        output: ToolCallOutput::failure(ToolFailure::runtime(
482                            ToolFailureClass::Internal,
483                            "pending_tool_completion_failed",
484                            err.to_string(),
485                        )),
486                        duration_ms: pending.duration_ms,
487                    },
488                };
489            }
490        };
491        self.pending_completion_dispatch_outcome(
492            pending.tool_name,
493            pending.args,
494            resolution,
495            pending.duration_ms,
496        )
497        .await
498    }
499
500    pub async fn call_tool(
501        &self,
502        call_id: String,
503        name: String,
504        args: serde_json::Value,
505        index: usize,
506    ) -> ToolInvocationReply {
507        let executed = self
508            .execute_tool_call(call_id, name, args, index, None, None, None)
509            .await;
510        let reply = ToolInvocationReply::from_output(executed.completed.output);
511        reply.with_record(executed.record)
512    }
513
514    pub async fn call_tool_with_lashlang_execution_call_site(
515        &self,
516        call_id: String,
517        name: String,
518        args: serde_json::Value,
519        index: usize,
520        call_site: crate::ToolLashlangExecutionCallSite,
521    ) -> ToolInvocationReply {
522        let executed = self
523            .execute_tool_call(call_id, name, args, index, None, None, Some(call_site))
524            .await;
525        let reply = ToolInvocationReply::from_output(executed.completed.output);
526        reply.with_record(executed.record)
527    }
528
529    pub async fn call_tool_batch(&self, calls: Vec<ToolInvocation>) -> Vec<ToolInvocationReply> {
530        let indexed_calls = calls.into_iter().enumerate().collect::<Vec<_>>();
531        schedule_tool_batch(
532            indexed_calls,
533            |(index, _)| *index,
534            |(_, call)| self.tool_scheduling(&call.name),
535            |(index, call)| {
536                let ctx = self.clone();
537                async move { ctx.call_tool(call.id, call.name, call.args, index).await }
538            },
539        )
540        .await
541    }
542
543    pub async fn start_tool_call(
544        &self,
545        call_id: String,
546        name: String,
547        args: serde_json::Value,
548    ) -> ToolInvocationReply {
549        self.start_tool_process(call_id, name, args).await
550    }
551
552    pub async fn await_tool_handle(
553        &self,
554        call_id: String,
555        handle: serde_json::Value,
556    ) -> ToolInvocationReply {
557        self.await_process_handle(call_id, handle).await
558    }
559
560    pub async fn cancel_tool_handle(
561        &self,
562        call_id: String,
563        handle: serde_json::Value,
564    ) -> ToolInvocationReply {
565        self.cancel_process_handle(call_id, handle).await
566    }
567
568    pub async fn signal_tool_handle(
569        &self,
570        call_id: String,
571        handle: serde_json::Value,
572        signal_name: String,
573        payload: serde_json::Value,
574    ) -> ToolInvocationReply {
575        self.signal_process_handle(call_id, handle, signal_name, payload)
576            .await
577    }
578}