1use std::sync::Arc;
15use std::time::Instant;
16
17use tokio::sync::RwLock;
18
19use crate::orchestrator::ToolMutability;
20use bamboo_agent_core::{ToolCall, ToolError, ToolExecutionContext, ToolExecutor, ToolResult};
21
22#[derive(Clone)]
27pub struct ToolCallRuntime {
28 executor: Arc<dyn ToolExecutor>,
29 parallel_lock: Arc<RwLock<()>>,
33}
34
35#[derive(Debug)]
37pub struct ToolCallResult {
38 pub call_id: String,
39 pub tool_name: String,
40 pub result: Result<ToolResult, ToolError>,
41 pub elapsed_ms: u64,
42 pub was_parallel: bool,
43}
44
45impl ToolCallRuntime {
46 pub fn new(executor: Arc<dyn ToolExecutor>) -> Self {
48 Self {
49 executor,
50 parallel_lock: Arc::new(RwLock::new(())),
51 }
52 }
53
54 pub fn supports_parallel(executor: &Arc<dyn ToolExecutor>, call: &ToolCall) -> bool {
56 let (mutability, concurrency_safe) = executor.call_parallel_classification(call);
59 mutability == ToolMutability::ReadOnly && concurrency_safe
60 }
61
62 pub async fn execute(&self, call: &ToolCall, ctx: ToolExecutionContext<'_>) -> ToolCallResult {
64 let tool_name = call.function.name.trim().to_string();
65 let parallel = Self::supports_parallel(&self.executor, call);
66 let started = Instant::now();
67
68 let result = if parallel {
69 let _guard = self.parallel_lock.read().await;
71 self.executor.execute_with_context(call, ctx).await
72 } else {
73 let _guard = self.parallel_lock.write().await;
75 self.executor.execute_with_context(call, ctx).await
76 };
77
78 ToolCallResult {
79 call_id: call.id.clone(),
80 tool_name,
81 result,
82 elapsed_ms: started.elapsed().as_millis() as u64,
83 was_parallel: parallel,
84 }
85 }
86
87 pub async fn execute_batch(
94 &self,
95 calls: Vec<(ToolCall, ToolExecutionContext<'_>)>,
96 ) -> Vec<ToolCallResult> {
97 if calls.is_empty() {
98 return Vec::new();
99 }
100
101 let mut results = Vec::with_capacity(calls.len());
103 let mut parallel_batch: Vec<(ToolCall, ToolExecutionContext<'_>)> = Vec::new();
104
105 for (call, ctx) in calls {
106 if Self::supports_parallel(&self.executor, &call) {
107 parallel_batch.push((call, ctx));
108 } else {
109 if !parallel_batch.is_empty() {
111 let batch_results = self.execute_parallel_group(parallel_batch).await;
112 results.extend(batch_results);
113 parallel_batch = Vec::new();
114 }
115 let result = self.execute(&call, ctx).await;
117 results.push(result);
118 }
119 }
120
121 if !parallel_batch.is_empty() {
123 let batch_results = self.execute_parallel_group(parallel_batch).await;
124 results.extend(batch_results);
125 }
126
127 results
128 }
129
130 async fn execute_parallel_group(
132 &self,
133 calls: Vec<(ToolCall, ToolExecutionContext<'_>)>,
134 ) -> Vec<ToolCallResult> {
135 let futures: Vec<_> = calls
136 .into_iter()
137 .map(|(call, ctx)| {
138 let runtime = self.clone();
139 async move { runtime.execute(&call, ctx).await }
140 })
141 .collect();
142
143 futures::future::join_all(futures).await
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use bamboo_agent_core::{FunctionCall, ToolSchema};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Builds a minimal function-style tool call whose id is `call_<name>`.
    fn make_call(name: &str) -> ToolCall {
        ToolCall {
            id: format!("call_{}", name),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: "{}".to_string(),
            },
        }
    }

    /// Test executor that counts invocations and records the peak number of
    /// invocations that were in flight at the same time.
    struct CountingExecutor {
        /// Total number of `execute_with_context` invocations.
        call_count: AtomicUsize,
        /// Highest value `current_concurrent` has ever reached.
        max_concurrent: AtomicUsize,
        /// Invocations currently inside `execute_with_context`.
        current_concurrent: AtomicUsize,
        /// Simulated work per invocation; `Duration::ZERO` skips the sleep.
        delay: Duration,
    }

    impl CountingExecutor {
        fn new(delay: Duration) -> Self {
            Self {
                call_count: AtomicUsize::new(0),
                max_concurrent: AtomicUsize::new(0),
                current_concurrent: AtomicUsize::new(0),
                delay,
            }
        }
    }

    #[async_trait]
    impl ToolExecutor for CountingExecutor {
        async fn execute(&self, call: &ToolCall) -> Result<ToolResult, ToolError> {
            self.execute_with_context(call, ToolExecutionContext::none("test"))
                .await
        }

        async fn execute_with_context(
            &self,
            _call: &ToolCall,
            _ctx: ToolExecutionContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);

            let current = self.current_concurrent.fetch_add(1, Ordering::SeqCst) + 1;
            self.max_concurrent.fetch_max(current, Ordering::SeqCst);

            if self.delay > Duration::ZERO {
                tokio::time::sleep(self.delay).await;
            }

            self.current_concurrent.fetch_sub(1, Ordering::SeqCst);

            Ok(ToolResult {
                success: true,
                result: "ok".to_string(),
                display_preference: None,
                images: Vec::new(),
            })
        }

        fn list_tools(&self) -> Vec<ToolSchema> {
            vec![]
        }
    }

    #[test]
    fn test_supports_parallel() {
        let executor: Arc<dyn ToolExecutor> = Arc::new(CountingExecutor::new(Duration::ZERO));

        for name in ["Read", "Grep", "Glob"] {
            assert!(
                ToolCallRuntime::supports_parallel(&executor, &make_call(name)),
                "{} should be parallel-safe",
                name
            );
        }
        for name in ["Bash", "Write", "Edit"] {
            assert!(
                !ToolCallRuntime::supports_parallel(&executor, &make_call(name)),
                "{} should not be parallel-safe",
                name
            );
        }
    }

    #[tokio::test]
    async fn test_single_call_works() {
        let executor = Arc::new(CountingExecutor::new(Duration::ZERO));
        let runtime = ToolCallRuntime::new(executor.clone());
        let call = make_call("Read");
        let ctx = ToolExecutionContext::none("test");

        let result = runtime.execute(&call, ctx).await;
        assert!(result.result.is_ok());
        assert!(result.was_parallel);
        assert_eq!(result.tool_name, "Read");
        assert_eq!(executor.call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_mutating_call_is_sequential() {
        let executor = Arc::new(CountingExecutor::new(Duration::ZERO));
        let runtime = ToolCallRuntime::new(executor.clone());
        let call = make_call("Bash");
        let ctx = ToolExecutionContext::none("test");

        let result = runtime.execute(&call, ctx).await;
        assert!(result.result.is_ok());
        assert!(!result.was_parallel);
    }

    #[tokio::test]
    async fn test_parallel_reads_are_concurrent() {
        let executor = Arc::new(CountingExecutor::new(Duration::from_millis(50)));
        let runtime = ToolCallRuntime::new(executor.clone());

        let handles: Vec<_> = (0..3)
            .map(|_| {
                let rt = runtime.clone();
                let call = make_call("Read");
                tokio::spawn(
                    async move { rt.execute(&call, ToolExecutionContext::none("test")).await },
                )
            })
            .collect();

        let results: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        assert!(results.iter().all(|r| r.result.is_ok()));
        assert!(results.iter().all(|r| r.was_parallel));

        let max_conc = executor.max_concurrent.load(Ordering::SeqCst);
        assert!(
            max_conc >= 2,
            "Expected parallel execution, got max_concurrent={}",
            max_conc
        );
    }

    #[tokio::test]
    async fn test_batch_empty() {
        let executor: Arc<dyn ToolExecutor> = Arc::new(CountingExecutor::new(Duration::ZERO));
        let runtime = ToolCallRuntime::new(executor);
        let results = runtime.execute_batch(vec![]).await;
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_batch_mixed() {
        let executor: Arc<dyn ToolExecutor> = Arc::new(CountingExecutor::new(Duration::ZERO));
        let runtime = ToolCallRuntime::new(executor);

        let calls: Vec<_> = vec![
            (make_call("Read"), ToolExecutionContext::none("test")),
            (make_call("Grep"), ToolExecutionContext::none("test")),
            (make_call("Bash"), ToolExecutionContext::none("test")),
            (make_call("Glob"), ToolExecutionContext::none("test")),
        ];

        let results = runtime.execute_batch(calls).await;
        assert_eq!(results.len(), 4);
        assert!(results.iter().all(|r| r.result.is_ok()));

        // Results must come back in submission order with the expected routing.
        let ids: Vec<&str> = results.iter().map(|r| r.call_id.as_str()).collect();
        assert_eq!(ids, ["call_Read", "call_Grep", "call_Bash", "call_Glob"]);
        let parallel: Vec<bool> = results.iter().map(|r| r.was_parallel).collect();
        assert_eq!(parallel, [true, true, false, true]);
    }

    #[tokio::test]
    async fn test_batch_mutating_call_is_a_barrier() {
        let executor = Arc::new(CountingExecutor::new(Duration::from_millis(20)));
        let runtime = ToolCallRuntime::new(executor.clone());

        let calls: Vec<_> = vec![
            (make_call("Read"), ToolExecutionContext::none("test")),
            (make_call("Bash"), ToolExecutionContext::none("test")),
            (make_call("Grep"), ToolExecutionContext::none("test")),
        ];

        let results = runtime.execute_batch(calls).await;
        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.result.is_ok()));
        assert_eq!(executor.call_count.load(Ordering::SeqCst), 3);

        // "Bash" separates the two reads, so no two calls may ever overlap.
        assert_eq!(executor.max_concurrent.load(Ordering::SeqCst), 1);
    }
}