// lash_core/session/tool_execution.rs

1use std::sync::Arc;
2
3use super::execution_context::ModeExecutionContext;
4use crate::tool_dispatch::{dispatch_tool_call_with_execution_context, schedule_tool_batch};
5use crate::{
6    ModelToolReturn, SessionEvent, ToolCallOutput, ToolCallRecord, ToolCancellation, ToolContext,
7    ToolFailure, ToolFailureClass, TurnActivityId, TurnEvent,
8};
9
10#[derive(Clone, Debug)]
11pub struct ModeToolBatchItem {
12    pub id: String,
13    pub name: String,
14    pub args: serde_json::Value,
15}
16
17#[derive(Clone, Debug)]
18pub struct ModeToolReply {
19    pub output: ToolCallOutput,
20    pub record: Option<ToolCallRecord>,
21}
22
23impl ModeToolReply {
24    pub fn success(value: serde_json::Value) -> Self {
25        Self {
26            output: ToolCallOutput::success(value),
27            record: None,
28        }
29    }
30
31    pub fn error(value: serde_json::Value) -> Self {
32        let message = value
33            .as_str()
34            .map(ToOwned::to_owned)
35            .unwrap_or_else(|| value.to_string());
36        let mut failure = ToolFailure::tool(ToolFailureClass::Execution, "tool_error", message);
37        failure.raw =
38            Some(serde_json::from_value(value).unwrap_or_else(|_| {
39                crate::ToolValue::String("unserializable tool error".to_string())
40            }));
41        Self {
42            output: ToolCallOutput::failure(failure),
43            record: None,
44        }
45    }
46
47    pub fn from_output(output: ToolCallOutput) -> Self {
48        Self {
49            output,
50            record: None,
51        }
52    }
53
54    pub fn cancelled(message: impl Into<String>) -> Self {
55        Self::from_output(ToolCallOutput::cancelled(ToolCancellation::runtime(
56            message,
57        )))
58    }
59
60    pub(crate) fn with_record(mut self, record: ToolCallRecord) -> Self {
61        self.record = Some(record);
62        self
63    }
64}
65
66#[derive(Clone, Debug)]
67pub(crate) struct CompletedModeToolCall {
68    pub index: usize,
69    pub completed: crate::sansio::CompletedToolCall,
70    pub record: ToolCallRecord,
71}
72
73impl ModeExecutionContext {
74    pub(crate) async fn execute_tool_call(
75        &self,
76        call_id: String,
77        name: String,
78        args: serde_json::Value,
79        index: usize,
80        replay: Option<crate::llm::types::ProviderReplayMeta>,
81    ) -> CompletedModeToolCall {
82        let _ = self
83            .dispatch
84            .event_tx
85            .send(SessionEvent::ToolCallStart {
86                call_id: Some(call_id.clone()),
87                name: name.clone(),
88                args: args.clone(),
89            })
90            .await;
91        let tool_correlation_id = TurnActivityId::new(format!("tool:{call_id}"));
92        self.emit_turn_activity(
93            tool_correlation_id.clone(),
94            TurnEvent::ToolCallStarted {
95                call_id: Some(call_id.clone()),
96                name: name.clone(),
97                args: args.clone(),
98            },
99        )
100        .await;
101
102        let (progress_tx, mut progress_rx) =
103            tokio::sync::mpsc::unbounded_channel::<crate::SandboxMessage>();
104        let event_tx = self.dispatch.event_tx.clone();
105        let progress_handle = tokio::spawn(async move {
106            while let Some(sandbox_msg) = progress_rx.recv().await {
107                if sandbox_msg.kind != "lashlang_code" {
108                    let _ = event_tx
109                        .send(SessionEvent::Message {
110                            text: sandbox_msg.text,
111                            kind: sandbox_msg.kind,
112                        })
113                        .await;
114                }
115            }
116        });
117
118        let mut tool_context = ToolContext::new(
119            self.dispatch.session_id.clone(),
120            Arc::clone(&self.dispatch.host),
121            self.dispatch.turn_context.clone(),
122            Arc::clone(&self.dispatch.attachment_store),
123            Some(call_id.clone()),
124        );
125        tool_context.cancellation_token = self.cancellation_token.clone();
126        let mut outcome = dispatch_tool_call_with_execution_context(
127            &self.dispatch,
128            name,
129            args,
130            Some(&progress_tx),
131            tool_context,
132        )
133        .await;
134        outcome.record.call_id = Some(call_id.clone());
135        drop(progress_tx);
136        let _ = progress_handle.await;
137
138        let output = outcome.record.output.clone();
139        let model_return = match self
140            .dispatch
141            .plugins
142            .project_tool_result(crate::plugin::ToolResultProjectionContext {
143                session_id: self.dispatch.session_id.clone(),
144                tool_name: outcome.record.tool.clone(),
145                args: outcome.record.args.clone(),
146                output: output.clone(),
147                duration_ms: outcome.record.duration_ms,
148                call_id: call_id.clone(),
149            })
150            .await
151        {
152            Ok(projected) => projected,
153            Err(err) => ModelToolReturn::text(
154                call_id.clone(),
155                outcome.record.tool.clone(),
156                err.to_string(),
157            ),
158        };
159
160        self.emit_turn_activity(
161            tool_correlation_id,
162            TurnEvent::ToolCallCompleted {
163                call_id: Some(call_id.clone()),
164                name: outcome.record.tool.clone(),
165                args: outcome.record.args.clone(),
166                output: output.clone(),
167                duration_ms: outcome.record.duration_ms,
168            },
169        )
170        .await;
171
172        let record = ToolCallRecord {
173            call_id: Some(call_id.clone()),
174            tool: outcome.record.tool.clone(),
175            args: outcome.record.args.clone(),
176            output: output.clone(),
177            duration_ms: outcome.record.duration_ms,
178        };
179        CompletedModeToolCall {
180            index,
181            completed: crate::sansio::CompletedToolCall {
182                call_id,
183                tool_name: outcome.record.tool,
184                args: outcome.record.args,
185                output,
186                model_return,
187                duration_ms: outcome.record.duration_ms,
188                replay,
189            },
190            record,
191        }
192    }
193
194    pub async fn call_tool(
195        &self,
196        call_id: String,
197        name: String,
198        args: serde_json::Value,
199        index: usize,
200    ) -> ModeToolReply {
201        if name == "list_async_handles" {
202            let live_monitor_tasks = self.live_monitor_tasks().await;
203            return self.list_async_handles(live_monitor_tasks);
204        }
205        if name == "monitor" {
206            return self.start_monitor_handle_call(call_id, args, index).await;
207        }
208        let executed = self
209            .execute_tool_call(call_id, name, args, index, None)
210            .await;
211        let reply = ModeToolReply::from_output(executed.completed.output);
212        reply.with_record(executed.record)
213    }
214
215    pub async fn call_tool_batch(&self, calls: Vec<ModeToolBatchItem>) -> Vec<ModeToolReply> {
216        let indexed_calls = calls.into_iter().enumerate().collect::<Vec<_>>();
217        schedule_tool_batch(
218            indexed_calls,
219            |(index, _)| *index,
220            |(_, call)| self.tool_execution_mode(&call.name),
221            |(index, call)| {
222                let ctx = self.clone();
223                async move { ctx.call_tool(call.id, call.name, call.args, index).await }
224            },
225        )
226        .await
227    }
228
229    pub async fn start_tool_call(
230        &self,
231        call_id: String,
232        name: String,
233        args: serde_json::Value,
234    ) -> ModeToolReply {
235        if name == "monitor" {
236            return self.start_monitor_handle_call(call_id, args, 0).await;
237        }
238        self.start_async_tool_call(call_id, name, args).await
239    }
240
241    pub async fn await_tool_handle(
242        &self,
243        _call_id: String,
244        handle: serde_json::Value,
245    ) -> ModeToolReply {
246        self.await_async_tool_handle(handle).await
247    }
248
249    pub async fn cancel_tool_handle(
250        &self,
251        _call_id: String,
252        handle: serde_json::Value,
253    ) -> ModeToolReply {
254        self.cancel_async_tool_handle(handle).await
255    }
256}