1use std::collections::HashMap;
2use std::sync::atomic::{AtomicI64, Ordering};
3use std::sync::Arc;
4
5use tokio::sync::{mpsc, Mutex};
6use tokio_util::sync::CancellationToken;
7
8use super::registry::ToolRegistry;
9use super::types::{ToolBatchResult, ToolContext, ToolRequest, ToolResult};
10use crate::controller::types::TurnId;
11
12pub struct ToolExecutor {
14 registry: Arc<ToolRegistry>,
15 tool_result_tx: mpsc::Sender<ToolResult>,
16 batch_result_tx: mpsc::Sender<ToolBatchResult>,
17 batch_counter: AtomicI64,
18}
19
20impl ToolExecutor {
21 pub fn new(
28 registry: Arc<ToolRegistry>,
29 tool_result_tx: mpsc::Sender<ToolResult>,
30 batch_result_tx: mpsc::Sender<ToolBatchResult>,
31 ) -> Self {
32 Self {
33 registry,
34 tool_result_tx,
35 batch_result_tx,
36 batch_counter: AtomicI64::new(0),
37 }
38 }
39
40 pub async fn execute_batch(
47 &self,
48 session_id: i64,
49 turn_id: Option<TurnId>,
50 requests: Vec<ToolRequest>,
51 cancel_token: CancellationToken,
52 ) -> i64 {
53 let batch_id = self.batch_counter.fetch_add(1, Ordering::SeqCst) + 1;
54 let expected_count = requests.len();
55
56 if expected_count == 0 {
57 let batch_result = ToolBatchResult {
59 batch_id,
60 session_id,
61 turn_id,
62 results: Vec::new(),
63 };
64 let _ = self.batch_result_tx.send(batch_result).await;
65 return batch_id;
66 }
67
68 tracing::debug!(
69 batch_id,
70 session_id,
71 tool_count = expected_count,
72 "Starting tool batch execution"
73 );
74
75 let batch = Arc::new(ToolExecutorBatch {
77 batch_id,
78 session_id,
79 turn_id: turn_id.clone(),
80 tool_result_tx: self.tool_result_tx.clone(),
81 batch_result_tx: self.batch_result_tx.clone(),
82 requests: requests.clone(),
83 results: Mutex::new(HashMap::new()),
84 expected_count,
85 });
86
87 for request in requests {
89 let batch = batch.clone();
90 let registry = self.registry.clone();
91 let cancel = cancel_token.clone();
92 let turn_id = turn_id.clone();
93
94 tokio::spawn(async move {
95 batch
96 .run_tool(registry, request, turn_id, cancel)
97 .await;
98 });
99 }
100
101 batch_id
102 }
103
104 pub async fn execute(
106 &self,
107 session_id: i64,
108 turn_id: Option<TurnId>,
109 request: ToolRequest,
110 cancel_token: CancellationToken,
111 ) -> i64 {
112 self.execute_batch(session_id, turn_id, vec![request], cancel_token)
113 .await
114 }
115}
116
117struct ToolExecutorBatch {
119 batch_id: i64,
120 session_id: i64,
121 turn_id: Option<TurnId>,
122 tool_result_tx: mpsc::Sender<ToolResult>,
123 batch_result_tx: mpsc::Sender<ToolBatchResult>,
124 requests: Vec<ToolRequest>,
125 results: Mutex<HashMap<String, ToolResult>>,
126 expected_count: usize,
127}
128
129impl ToolExecutorBatch {
130 async fn run_tool(
132 &self,
133 registry: Arc<ToolRegistry>,
134 request: ToolRequest,
135 turn_id: Option<TurnId>,
136 cancel_token: CancellationToken,
137 ) {
138 let tool_use_id = request.tool_use_id.clone();
139 let tool_name = request.tool_name.clone();
140 let input = request.input.clone();
141
142 tracing::debug!(
143 batch_id = self.batch_id,
144 session_id = self.session_id,
145 tool_name = %tool_name,
146 tool_use_id = %tool_use_id,
147 "Starting tool execution"
148 );
149
150 let tool = registry.get(&tool_name).await;
152
153 let result = match tool {
154 None => {
155 tracing::warn!(
157 batch_id = self.batch_id,
158 tool_name = %tool_name,
159 "Tool not found in registry"
160 );
161 ToolResult::error(
162 self.session_id,
163 tool_name,
164 tool_use_id,
165 input,
166 format!("Tool not found: {}", request.tool_name),
167 turn_id,
168 )
169 }
170 Some(tool) => {
171 let display_name = Some(tool.display_config().display_name);
173
174 let context = ToolContext {
176 session_id: self.session_id,
177 tool_use_id: tool_use_id.clone(),
178 turn_id: turn_id.clone(),
179 };
180
181 tokio::select! {
183 exec_result = tool.execute(context, input.clone()) => {
184 match exec_result {
185 Ok(content) => {
186 tracing::info!(
187 batch_id = self.batch_id,
188 tool_name = %tool_name,
189 result_bytes = content.len(),
190 "Tool execution succeeded"
191 );
192 let compact_summary = Some(tool.compact_summary(&input, &content));
194 ToolResult::success(
195 self.session_id,
196 tool_name,
197 display_name,
198 tool_use_id,
199 input,
200 content,
201 turn_id,
202 compact_summary,
203 )
204 }
205 Err(error) => {
206 tracing::warn!(
207 batch_id = self.batch_id,
208 tool_name = %tool_name,
209 error = %error,
210 "Tool execution failed"
211 );
212 ToolResult::error(
213 self.session_id,
214 tool_name,
215 tool_use_id,
216 input,
217 error,
218 turn_id,
219 )
220 }
221 }
222 }
223 _ = cancel_token.cancelled() => {
224 tracing::warn!(
225 batch_id = self.batch_id,
226 tool_name = %tool_name,
227 "Tool execution cancelled"
228 );
229 ToolResult::timeout(
230 self.session_id,
231 tool_name,
232 tool_use_id,
233 input,
234 turn_id,
235 )
236 }
237 }
238 }
239 };
240
241 self.add_result(result).await;
242 }
243
244 async fn add_result(&self, result: ToolResult) {
246 let _ = self.tool_result_tx.send(result.clone()).await;
248
249 let mut results = self.results.lock().await;
250 results.insert(result.tool_use_id.clone(), result);
251
252 tracing::debug!(
253 batch_id = self.batch_id,
254 completed = results.len(),
255 expected = self.expected_count,
256 "Tool completed in batch"
257 );
258
259 if results.len() == self.expected_count {
261 self.send_batch_result(&results).await;
262 }
263 }
264
265 async fn send_batch_result(&self, results: &HashMap<String, ToolResult>) {
267 let ordered_results: Vec<ToolResult> = self
269 .requests
270 .iter()
271 .filter_map(|req| results.get(&req.tool_use_id).cloned())
272 .collect();
273
274 let batch_result = ToolBatchResult {
275 batch_id: self.batch_id,
276 session_id: self.session_id,
277 turn_id: self.turn_id.clone(),
278 results: ordered_results,
279 };
280
281 tracing::debug!(
282 batch_id = self.batch_id,
283 session_id = self.session_id,
284 result_count = batch_result.results.len(),
285 "Sending batch result"
286 );
287
288 let _ = self.batch_result_tx.send(batch_result).await;
289 }
290}
291
292#[cfg(test)]
293mod tests {
294 use super::*;
295 use crate::controller::tools::types::{Executable, ToolResultStatus, ToolType};
296 use std::future::Future;
297 use std::pin::Pin;
298 use std::time::Duration;
299
300 struct EchoTool;
301
302 impl Executable for EchoTool {
303 fn name(&self) -> &str {
304 "echo"
305 }
306
307 fn description(&self) -> &str {
308 "Echoes input back"
309 }
310
311 fn input_schema(&self) -> &str {
312 r#"{"type":"object","properties":{"message":{"type":"string"}}}"#
313 }
314
315 fn tool_type(&self) -> ToolType {
316 ToolType::Custom
317 }
318
319 fn execute(
320 &self,
321 _context: ToolContext,
322 input: HashMap<String, serde_json::Value>,
323 ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
324 let message = input
325 .get("message")
326 .and_then(|v| v.as_str())
327 .unwrap_or("no message")
328 .to_string();
329 Box::pin(async move { Ok(format!("Echo: {}", message)) })
330 }
331 }
332
333 #[allow(dead_code)]
334 struct SlowTool;
335
336 impl Executable for SlowTool {
337 fn name(&self) -> &str {
338 "slow"
339 }
340
341 fn description(&self) -> &str {
342 "A slow tool for testing timeouts"
343 }
344
345 fn input_schema(&self) -> &str {
346 r#"{"type":"object"}"#
347 }
348
349 fn tool_type(&self) -> ToolType {
350 ToolType::Custom
351 }
352
353 fn execute(
354 &self,
355 _context: ToolContext,
356 _input: HashMap<String, serde_json::Value>,
357 ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
358 Box::pin(async {
359 tokio::time::sleep(Duration::from_secs(10)).await;
360 Ok("done".to_string())
361 })
362 }
363 }
364
365 #[tokio::test]
366 async fn test_execute_single_tool() {
367 let registry = Arc::new(ToolRegistry::new());
368 registry.register(Arc::new(EchoTool)).await.unwrap();
369
370 let (tool_tx, mut tool_rx) = mpsc::channel(10);
371 let (batch_tx, mut batch_rx) = mpsc::channel(10);
372
373 let executor = ToolExecutor::new(registry, tool_tx, batch_tx);
374
375 let mut input = HashMap::new();
376 input.insert(
377 "message".to_string(),
378 serde_json::Value::String("hello".to_string()),
379 );
380
381 let request = ToolRequest {
382 tool_use_id: "test_1".to_string(),
383 tool_name: "echo".to_string(),
384 input,
385 };
386
387 let cancel = CancellationToken::new();
388 executor.execute(1, None, request, cancel).await;
389
390 let result = tool_rx.recv().await.unwrap();
392 assert_eq!(result.status, ToolResultStatus::Success);
393 assert!(result.content.contains("Echo: hello"));
394
395 let batch = batch_rx.recv().await.unwrap();
397 assert_eq!(batch.results.len(), 1);
398 }
399
400 #[tokio::test]
401 async fn test_execute_batch() {
402 let registry = Arc::new(ToolRegistry::new());
403 registry.register(Arc::new(EchoTool)).await.unwrap();
404
405 let (tool_tx, mut tool_rx) = mpsc::channel(10);
406 let (batch_tx, mut batch_rx) = mpsc::channel(10);
407
408 let executor = ToolExecutor::new(registry, tool_tx, batch_tx);
409
410 let requests: Vec<ToolRequest> = (0..3)
411 .map(|i| {
412 let mut input = HashMap::new();
413 input.insert(
414 "message".to_string(),
415 serde_json::Value::String(format!("msg_{}", i)),
416 );
417 ToolRequest {
418 tool_use_id: format!("tool_{}", i),
419 tool_name: "echo".to_string(),
420 input,
421 }
422 })
423 .collect();
424
425 let cancel = CancellationToken::new();
426 executor.execute_batch(1, None, requests, cancel).await;
427
428 for _ in 0..3 {
430 let result = tool_rx.recv().await.unwrap();
431 assert_eq!(result.status, ToolResultStatus::Success);
432 }
433
434 let batch = batch_rx.recv().await.unwrap();
436 assert_eq!(batch.results.len(), 3);
437 }
438
439 #[tokio::test]
440 async fn test_tool_not_found() {
441 let registry = Arc::new(ToolRegistry::new());
442
443 let (tool_tx, mut tool_rx) = mpsc::channel(10);
444 let (batch_tx, _batch_rx) = mpsc::channel(10);
445
446 let executor = ToolExecutor::new(registry, tool_tx, batch_tx);
447
448 let request = ToolRequest {
449 tool_use_id: "test_1".to_string(),
450 tool_name: "nonexistent".to_string(),
451 input: HashMap::new(),
452 };
453
454 let cancel = CancellationToken::new();
455 executor.execute(1, None, request, cancel).await;
456
457 let result = tool_rx.recv().await.unwrap();
458 assert_eq!(result.status, ToolResultStatus::Error);
459 assert!(result.error.unwrap().contains("not found"));
460 }
461}