1use std::collections::HashMap;
2use std::sync::atomic::{AtomicI64, Ordering};
3use std::sync::Arc;
4
5use tokio::sync::{mpsc, Mutex};
6use tokio_util::sync::CancellationToken;
7
8use super::registry::ToolRegistry;
9use super::types::{ToolBatchResult, ToolContext, ToolRequest, ToolResult};
10use crate::controller::types::TurnId;
11
/// Runs tool requests on spawned tasks and reports their outcomes over channels.
///
/// Each request in a batch is executed on its own `tokio` task. Every finished
/// tool is reported individually, and once all tools of a batch have finished a
/// single batch result is reported as well.
pub struct ToolExecutor {
    /// Registry used to look tools up by name when a request is executed.
    registry: Arc<ToolRegistry>,
    /// Channel on which each individual tool result is sent as soon as it is available.
    tool_result_tx: mpsc::Sender<ToolResult>,
    /// Channel on which the combined result of a batch is sent once the batch is complete.
    batch_result_tx: mpsc::Sender<ToolBatchResult>,
    /// Monotonic counter used to hand out batch ids; the first batch gets id 1.
    batch_counter: AtomicI64,
}
19
20impl ToolExecutor {
21 pub fn new(
28 registry: Arc<ToolRegistry>,
29 tool_result_tx: mpsc::Sender<ToolResult>,
30 batch_result_tx: mpsc::Sender<ToolBatchResult>,
31 ) -> Self {
32 Self {
33 registry,
34 tool_result_tx,
35 batch_result_tx,
36 batch_counter: AtomicI64::new(0),
37 }
38 }
39
40 pub async fn execute_batch(
47 &self,
48 session_id: i64,
49 turn_id: Option<TurnId>,
50 requests: Vec<ToolRequest>,
51 cancel_token: CancellationToken,
52 ) -> i64 {
53 let batch_id = self.batch_counter.fetch_add(1, Ordering::SeqCst) + 1;
54 let expected_count = requests.len();
55
56 if expected_count == 0 {
57 let batch_result = ToolBatchResult {
59 batch_id,
60 session_id,
61 turn_id,
62 results: Vec::new(),
63 };
64 if let Err(e) = self.batch_result_tx.send(batch_result).await {
65 tracing::debug!("Failed to send empty batch result: {}", e);
66 }
67 return batch_id;
68 }
69
70 tracing::debug!(
71 batch_id,
72 session_id,
73 tool_count = expected_count,
74 "Starting tool batch execution"
75 );
76
77 let batch = Arc::new(ToolExecutorBatch {
79 batch_id,
80 session_id,
81 turn_id: turn_id.clone(),
82 tool_result_tx: self.tool_result_tx.clone(),
83 batch_result_tx: self.batch_result_tx.clone(),
84 requests: requests.clone(),
85 results: Mutex::new(HashMap::new()),
86 expected_count,
87 });
88
89 for request in requests {
91 let batch = batch.clone();
92 let registry = self.registry.clone();
93 let cancel = cancel_token.clone();
94 let turn_id = turn_id.clone();
95
96 tokio::spawn(async move {
97 batch
98 .run_tool(registry, request, turn_id, cancel)
99 .await;
100 });
101 }
102
103 batch_id
104 }
105
106 pub async fn execute(
108 &self,
109 session_id: i64,
110 turn_id: Option<TurnId>,
111 request: ToolRequest,
112 cancel_token: CancellationToken,
113 ) -> i64 {
114 self.execute_batch(session_id, turn_id, vec![request], cancel_token)
115 .await
116 }
117}
118
/// Shared state for one in-flight batch; every task of the batch holds an
/// `Arc` to the same instance.
struct ToolExecutorBatch {
    /// Id assigned to this batch by the executor.
    batch_id: i64,
    /// Session the batch belongs to; copied into every result.
    session_id: i64,
    /// Turn the batch belongs to, if any; copied into the batch result.
    turn_id: Option<TurnId>,
    /// Channel for individual tool results.
    tool_result_tx: mpsc::Sender<ToolResult>,
    /// Channel for the combined batch result.
    batch_result_tx: mpsc::Sender<ToolBatchResult>,
    /// The original requests, kept so the batch result can follow request order.
    requests: Vec<ToolRequest>,
    /// Finished results keyed by `tool_use_id`.
    // NOTE(review): completion is detected by comparing this map's size with
    // `expected_count`, which presumes every request in a batch has a distinct
    // `tool_use_id` — confirm callers guarantee that.
    results: Mutex<HashMap<String, ToolResult>>,
    /// Number of requests in the batch.
    expected_count: usize,
}
130
131impl ToolExecutorBatch {
132 async fn run_tool(
134 &self,
135 registry: Arc<ToolRegistry>,
136 request: ToolRequest,
137 turn_id: Option<TurnId>,
138 cancel_token: CancellationToken,
139 ) {
140 let tool_use_id = request.tool_use_id.clone();
141 let tool_name = request.tool_name.clone();
142 let input = request.input.clone();
143
144 tracing::debug!(
145 batch_id = self.batch_id,
146 session_id = self.session_id,
147 tool_name = %tool_name,
148 tool_use_id = %tool_use_id,
149 "Starting tool execution"
150 );
151
152 let tool = registry.get(&tool_name).await;
154
155 let result = match tool {
156 None => {
157 tracing::warn!(
159 batch_id = self.batch_id,
160 tool_name = %tool_name,
161 "Tool not found in registry"
162 );
163 ToolResult::error(
164 self.session_id,
165 tool_name,
166 tool_use_id,
167 input,
168 format!("Tool not found: {}", request.tool_name),
169 turn_id,
170 )
171 }
172 Some(tool) => {
173 let display_name = Some(tool.display_config().display_name);
175
176 let context = ToolContext {
178 session_id: self.session_id,
179 tool_use_id: tool_use_id.clone(),
180 turn_id: turn_id.clone(),
181 };
182
183 tokio::select! {
185 exec_result = tool.execute(context, input.clone()) => {
186 match exec_result {
187 Ok(content) => {
188 tracing::info!(
189 batch_id = self.batch_id,
190 tool_name = %tool_name,
191 result_bytes = content.len(),
192 "Tool execution succeeded"
193 );
194 let compact_summary = Some(tool.compact_summary(&input, &content));
196 ToolResult::success(
197 self.session_id,
198 tool_name,
199 display_name,
200 tool_use_id,
201 input,
202 content,
203 turn_id,
204 compact_summary,
205 )
206 }
207 Err(error) => {
208 tracing::warn!(
209 batch_id = self.batch_id,
210 tool_name = %tool_name,
211 error = %error,
212 "Tool execution failed"
213 );
214 ToolResult::error(
215 self.session_id,
216 tool_name,
217 tool_use_id,
218 input,
219 error,
220 turn_id,
221 )
222 }
223 }
224 }
225 _ = cancel_token.cancelled() => {
226 tracing::warn!(
227 batch_id = self.batch_id,
228 tool_name = %tool_name,
229 "Tool execution cancelled"
230 );
231 ToolResult::timeout(
232 self.session_id,
233 tool_name,
234 tool_use_id,
235 input,
236 turn_id,
237 )
238 }
239 }
240 }
241 };
242
243 self.add_result(result).await;
244 }
245
246 async fn add_result(&self, result: ToolResult) {
248 if let Err(e) = self.tool_result_tx.send(result.clone()).await {
250 tracing::debug!("Failed to send tool result: {}", e);
251 }
252
253 let mut results = self.results.lock().await;
254 results.insert(result.tool_use_id.clone(), result);
255
256 tracing::debug!(
257 batch_id = self.batch_id,
258 completed = results.len(),
259 expected = self.expected_count,
260 "Tool completed in batch"
261 );
262
263 if results.len() == self.expected_count {
265 self.send_batch_result(&results).await;
266 }
267 }
268
269 async fn send_batch_result(&self, results: &HashMap<String, ToolResult>) {
271 let ordered_results: Vec<ToolResult> = self
273 .requests
274 .iter()
275 .filter_map(|req| results.get(&req.tool_use_id).cloned())
276 .collect();
277
278 let batch_result = ToolBatchResult {
279 batch_id: self.batch_id,
280 session_id: self.session_id,
281 turn_id: self.turn_id.clone(),
282 results: ordered_results,
283 };
284
285 tracing::debug!(
286 batch_id = self.batch_id,
287 session_id = self.session_id,
288 result_count = batch_result.results.len(),
289 "Sending batch result"
290 );
291
292 if let Err(e) = self.batch_result_tx.send(batch_result).await {
293 tracing::debug!("Failed to send batch result: {}", e);
294 }
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::tools::types::{Executable, ToolResultStatus, ToolType};
    use std::future::Future;
    use std::pin::Pin;
    use std::time::Duration;

    /// Upper bound on how long a test waits for a channel message. Generous
    /// compared with the work done, but short enough that a lost result fails
    /// the test instead of hanging the suite.
    const RECV_TIMEOUT: Duration = Duration::from_secs(5);

    /// Receives the next message from `rx`, panicking if none arrives within
    /// [`RECV_TIMEOUT`] or if the channel is closed.
    async fn recv_or_fail<T>(rx: &mut mpsc::Receiver<T>) -> T {
        tokio::time::timeout(RECV_TIMEOUT, rx.recv())
            .await
            .expect("timed out waiting for a channel message")
            .expect("channel closed before a message arrived")
    }

    /// Tool that replies with `"Echo: <message>"`, or `"Echo: no message"` when
    /// the input has no string `message` entry.
    struct EchoTool;

    impl Executable for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "Echoes input back"
        }

        fn input_schema(&self) -> &str {
            r#"{"type":"object","properties":{"message":{"type":"string"}}}"#
        }

        fn tool_type(&self) -> ToolType {
            ToolType::Custom
        }

        fn execute(
            &self,
            _context: ToolContext,
            input: HashMap<String, serde_json::Value>,
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
            let message = input
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("no message")
                .to_string();
            Box::pin(async move { Ok(format!("Echo: {}", message)) })
        }
    }

    /// Tool that sleeps for ten seconds before succeeding; used to exercise
    /// cancellation.
    struct SlowTool;

    impl Executable for SlowTool {
        fn name(&self) -> &str {
            "slow"
        }

        fn description(&self) -> &str {
            "A slow tool for testing timeouts"
        }

        fn input_schema(&self) -> &str {
            r#"{"type":"object"}"#
        }

        fn tool_type(&self) -> ToolType {
            ToolType::Custom
        }

        fn execute(
            &self,
            _context: ToolContext,
            _input: HashMap<String, serde_json::Value>,
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
            Box::pin(async {
                tokio::time::sleep(Duration::from_secs(10)).await;
                Ok("done".to_string())
            })
        }
    }

    #[tokio::test]
    async fn test_execute_single_tool() {
        let registry = Arc::new(ToolRegistry::new());
        registry.register(Arc::new(EchoTool)).await.unwrap();

        let (tool_tx, mut tool_rx) = mpsc::channel(10);
        let (batch_tx, mut batch_rx) = mpsc::channel(10);

        let executor = ToolExecutor::new(registry, tool_tx, batch_tx);

        let mut input = HashMap::new();
        input.insert(
            "message".to_string(),
            serde_json::Value::String("hello".to_string()),
        );

        let request = ToolRequest {
            tool_use_id: "test_1".to_string(),
            tool_name: "echo".to_string(),
            input,
        };

        let cancel = CancellationToken::new();
        let batch_id = executor.execute(1, None, request, cancel).await;

        let result = recv_or_fail(&mut tool_rx).await;
        assert_eq!(result.status, ToolResultStatus::Success);
        assert!(result.content.contains("Echo: hello"));

        let batch = recv_or_fail(&mut batch_rx).await;
        assert_eq!(batch.batch_id, batch_id);
        assert_eq!(batch.results.len(), 1);
    }

    #[tokio::test]
    async fn test_execute_batch() {
        let registry = Arc::new(ToolRegistry::new());
        registry.register(Arc::new(EchoTool)).await.unwrap();

        let (tool_tx, mut tool_rx) = mpsc::channel(10);
        let (batch_tx, mut batch_rx) = mpsc::channel(10);

        let executor = ToolExecutor::new(registry, tool_tx, batch_tx);

        let requests: Vec<ToolRequest> = (0..3)
            .map(|i| {
                let mut input = HashMap::new();
                input.insert(
                    "message".to_string(),
                    serde_json::Value::String(format!("msg_{}", i)),
                );
                ToolRequest {
                    tool_use_id: format!("tool_{}", i),
                    tool_name: "echo".to_string(),
                    input,
                }
            })
            .collect();

        let cancel = CancellationToken::new();
        let batch_id = executor.execute_batch(1, None, requests, cancel).await;

        // Individual results may arrive in any order.
        for _ in 0..3 {
            let result = recv_or_fail(&mut tool_rx).await;
            assert_eq!(result.status, ToolResultStatus::Success);
        }

        // The batch result must follow request order regardless of which tool
        // finished first.
        let batch = recv_or_fail(&mut batch_rx).await;
        assert_eq!(batch.batch_id, batch_id);
        let ids: Vec<&str> = batch
            .results
            .iter()
            .map(|r| r.tool_use_id.as_str())
            .collect();
        assert_eq!(ids, ["tool_0", "tool_1", "tool_2"]);
    }

    #[tokio::test]
    async fn test_execute_empty_batch() {
        let registry = Arc::new(ToolRegistry::new());

        let (tool_tx, _tool_rx) = mpsc::channel(10);
        let (batch_tx, mut batch_rx) = mpsc::channel(10);

        let executor = ToolExecutor::new(registry, tool_tx, batch_tx);

        let cancel = CancellationToken::new();
        let batch_id = executor.execute_batch(7, None, Vec::new(), cancel).await;

        // An empty batch is reported immediately, with no results.
        let batch = recv_or_fail(&mut batch_rx).await;
        assert_eq!(batch.batch_id, batch_id);
        assert_eq!(batch.session_id, 7);
        assert!(batch.turn_id.is_none());
        assert!(batch.results.is_empty());
    }

    #[tokio::test]
    async fn test_tool_not_found() {
        let registry = Arc::new(ToolRegistry::new());

        let (tool_tx, mut tool_rx) = mpsc::channel(10);
        let (batch_tx, _batch_rx) = mpsc::channel(10);

        let executor = ToolExecutor::new(registry, tool_tx, batch_tx);

        let request = ToolRequest {
            tool_use_id: "test_1".to_string(),
            tool_name: "nonexistent".to_string(),
            input: HashMap::new(),
        };

        let cancel = CancellationToken::new();
        executor.execute(1, None, request, cancel).await;

        let result = recv_or_fail(&mut tool_rx).await;
        assert_eq!(result.status, ToolResultStatus::Error);
        assert!(result.error.unwrap().contains("not found"));
    }

    #[tokio::test]
    async fn test_tool_cancellation() {
        let registry = Arc::new(ToolRegistry::new());
        registry.register(Arc::new(SlowTool)).await.unwrap();

        let (tool_tx, mut tool_rx) = mpsc::channel(10);
        let (batch_tx, _batch_rx) = mpsc::channel(10);

        let executor = ToolExecutor::new(registry, tool_tx, batch_tx);

        let request = ToolRequest {
            tool_use_id: "test_1".to_string(),
            tool_name: "slow".to_string(),
            input: HashMap::new(),
        };

        let cancel = CancellationToken::new();
        let cancel_clone = cancel.clone();

        executor.execute(1, None, request, cancel).await;

        // Cancel shortly after the tool has started sleeping.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            cancel_clone.cancel();
        });

        // Cancellation is reported with the timeout status.
        let result = recv_or_fail(&mut tool_rx).await;
        assert_eq!(result.status, ToolResultStatus::Timeout);
    }
}