1use std::sync::Arc;
2
3use super::execution_context::ModeExecutionContext;
4use crate::tool_dispatch::{dispatch_tool_call_with_execution_context, schedule_tool_batch};
5use crate::{
6 ModelToolReturn, SessionEvent, ToolCallOutput, ToolCallRecord, ToolCancellation, ToolContext,
7 ToolFailure, ToolFailureClass, TurnActivityId, TurnEvent,
8};
9
10#[derive(Clone, Debug)]
11pub struct ModeToolBatchItem {
12 pub id: String,
13 pub name: String,
14 pub args: serde_json::Value,
15}
16
17#[derive(Clone, Debug)]
18pub struct ModeToolReply {
19 pub output: ToolCallOutput,
20 pub record: Option<ToolCallRecord>,
21}
22
23impl ModeToolReply {
24 pub fn success(value: serde_json::Value) -> Self {
25 Self {
26 output: ToolCallOutput::success(value),
27 record: None,
28 }
29 }
30
31 pub fn error(value: serde_json::Value) -> Self {
32 let message = value
33 .as_str()
34 .map(ToOwned::to_owned)
35 .unwrap_or_else(|| value.to_string());
36 let mut failure = ToolFailure::tool(ToolFailureClass::Execution, "tool_error", message);
37 failure.raw =
38 Some(serde_json::from_value(value).unwrap_or_else(|_| {
39 crate::ToolValue::String("unserializable tool error".to_string())
40 }));
41 Self {
42 output: ToolCallOutput::failure(failure),
43 record: None,
44 }
45 }
46
47 pub fn from_output(output: ToolCallOutput) -> Self {
48 Self {
49 output,
50 record: None,
51 }
52 }
53
54 pub fn cancelled(message: impl Into<String>) -> Self {
55 Self::from_output(ToolCallOutput::cancelled(ToolCancellation::runtime(
56 message,
57 )))
58 }
59
60 pub(crate) fn with_record(mut self, record: ToolCallRecord) -> Self {
61 self.record = Some(record);
62 self
63 }
64}
65
66#[derive(Clone, Debug)]
67pub(crate) struct CompletedModeToolCall {
68 pub index: usize,
69 pub completed: crate::sansio::CompletedToolCall,
70 pub record: ToolCallRecord,
71}
72
73impl ModeExecutionContext {
74 pub(crate) async fn execute_tool_call(
75 &self,
76 call_id: String,
77 name: String,
78 args: serde_json::Value,
79 index: usize,
80 replay: Option<crate::llm::types::ProviderReplayMeta>,
81 ) -> CompletedModeToolCall {
82 let _ = self
83 .dispatch
84 .event_tx
85 .send(SessionEvent::ToolCallStart {
86 call_id: Some(call_id.clone()),
87 name: name.clone(),
88 args: args.clone(),
89 })
90 .await;
91 let tool_correlation_id = TurnActivityId::new(format!("tool:{call_id}"));
92 self.emit_turn_activity(
93 tool_correlation_id.clone(),
94 TurnEvent::ToolCallStarted {
95 call_id: Some(call_id.clone()),
96 name: name.clone(),
97 args: args.clone(),
98 },
99 )
100 .await;
101
102 let (progress_tx, mut progress_rx) =
103 tokio::sync::mpsc::unbounded_channel::<crate::SandboxMessage>();
104 let event_tx = self.dispatch.event_tx.clone();
105 let progress_handle = tokio::spawn(async move {
106 while let Some(sandbox_msg) = progress_rx.recv().await {
107 if sandbox_msg.kind != "lashlang_code" {
108 let _ = event_tx
109 .send(SessionEvent::Message {
110 text: sandbox_msg.text,
111 kind: sandbox_msg.kind,
112 })
113 .await;
114 }
115 }
116 });
117
118 let mut tool_context = ToolContext::new(
119 self.dispatch.session_id.clone(),
120 Arc::clone(&self.dispatch.host),
121 self.dispatch.turn_context.clone(),
122 Arc::clone(&self.dispatch.attachment_store),
123 Some(call_id.clone()),
124 );
125 tool_context.cancellation_token = self.cancellation_token.clone();
126 let mut outcome = dispatch_tool_call_with_execution_context(
127 &self.dispatch,
128 name,
129 args,
130 Some(&progress_tx),
131 tool_context,
132 )
133 .await;
134 outcome.record.call_id = Some(call_id.clone());
135 drop(progress_tx);
136 let _ = progress_handle.await;
137
138 let output = outcome.record.output.clone();
139 let model_return = match self
140 .dispatch
141 .plugins
142 .project_tool_result(crate::plugin::ToolResultProjectionContext {
143 session_id: self.dispatch.session_id.clone(),
144 tool_name: outcome.record.tool.clone(),
145 args: outcome.record.args.clone(),
146 output: output.clone(),
147 duration_ms: outcome.record.duration_ms,
148 call_id: call_id.clone(),
149 })
150 .await
151 {
152 Ok(projected) => projected,
153 Err(err) => ModelToolReturn::text(
154 call_id.clone(),
155 outcome.record.tool.clone(),
156 err.to_string(),
157 ),
158 };
159
160 self.emit_turn_activity(
161 tool_correlation_id,
162 TurnEvent::ToolCallCompleted {
163 call_id: Some(call_id.clone()),
164 name: outcome.record.tool.clone(),
165 args: outcome.record.args.clone(),
166 output: output.clone(),
167 duration_ms: outcome.record.duration_ms,
168 },
169 )
170 .await;
171
172 let record = ToolCallRecord {
173 call_id: Some(call_id.clone()),
174 tool: outcome.record.tool.clone(),
175 args: outcome.record.args.clone(),
176 output: output.clone(),
177 duration_ms: outcome.record.duration_ms,
178 };
179 CompletedModeToolCall {
180 index,
181 completed: crate::sansio::CompletedToolCall {
182 call_id,
183 tool_name: outcome.record.tool,
184 args: outcome.record.args,
185 output,
186 model_return,
187 duration_ms: outcome.record.duration_ms,
188 replay,
189 },
190 record,
191 }
192 }
193
194 pub async fn call_tool(
195 &self,
196 call_id: String,
197 name: String,
198 args: serde_json::Value,
199 index: usize,
200 ) -> ModeToolReply {
201 if name == "list_async_handles" {
202 let live_monitor_tasks = self.live_monitor_tasks().await;
203 return self.list_async_handles(live_monitor_tasks);
204 }
205 if name == "monitor" {
206 return self.start_monitor_handle_call(call_id, args, index).await;
207 }
208 let executed = self
209 .execute_tool_call(call_id, name, args, index, None)
210 .await;
211 let reply = ModeToolReply::from_output(executed.completed.output);
212 reply.with_record(executed.record)
213 }
214
215 pub async fn call_tool_batch(&self, calls: Vec<ModeToolBatchItem>) -> Vec<ModeToolReply> {
216 let indexed_calls = calls.into_iter().enumerate().collect::<Vec<_>>();
217 schedule_tool_batch(
218 indexed_calls,
219 |(index, _)| *index,
220 |(_, call)| self.tool_execution_mode(&call.name),
221 |(index, call)| {
222 let ctx = self.clone();
223 async move { ctx.call_tool(call.id, call.name, call.args, index).await }
224 },
225 )
226 .await
227 }
228
229 pub async fn start_tool_call(
230 &self,
231 call_id: String,
232 name: String,
233 args: serde_json::Value,
234 ) -> ModeToolReply {
235 if name == "monitor" {
236 return self.start_monitor_handle_call(call_id, args, 0).await;
237 }
238 self.start_async_tool_call(call_id, name, args).await
239 }
240
241 pub async fn await_tool_handle(
242 &self,
243 _call_id: String,
244 handle: serde_json::Value,
245 ) -> ModeToolReply {
246 self.await_async_tool_handle(handle).await
247 }
248
249 pub async fn cancel_tool_handle(
250 &self,
251 _call_id: String,
252 handle: serde_json::Value,
253 ) -> ModeToolReply {
254 self.cancel_async_tool_handle(handle).await
255 }
256}