1use std::sync::Arc;
15use std::time::Instant;
16
17use tokio::sync::RwLock;
18
19use crate::orchestrator::ToolMutability;
20use bamboo_agent_core::{ToolCall, ToolError, ToolExecutionContext, ToolExecutor, ToolResult};
21
22#[derive(Clone)]
27pub struct ToolCallRuntime {
28 executor: Arc<dyn ToolExecutor>,
29 parallel_lock: Arc<RwLock<()>>,
33}
34
35#[derive(Debug)]
37pub struct ToolCallResult {
38 pub call_id: String,
39 pub tool_name: String,
40 pub result: Result<ToolResult, ToolError>,
41 pub elapsed_ms: u64,
42 pub was_parallel: bool,
43}
44
45impl ToolCallRuntime {
46 pub fn new(executor: Arc<dyn ToolExecutor>) -> Self {
48 Self {
49 executor,
50 parallel_lock: Arc::new(RwLock::new(())),
51 }
52 }
53
54 pub fn supports_parallel(executor: &Arc<dyn ToolExecutor>, call: &ToolCall) -> bool {
56 executor.call_mutability(call) == ToolMutability::ReadOnly
57 && executor.call_concurrency_safe(call)
58 }
59
60 pub async fn execute(&self, call: &ToolCall, ctx: ToolExecutionContext<'_>) -> ToolCallResult {
62 let tool_name = call.function.name.trim().to_string();
63 let parallel = Self::supports_parallel(&self.executor, call);
64 let started = Instant::now();
65
66 let result = if parallel {
67 let _guard = self.parallel_lock.read().await;
69 self.executor.execute_with_context(call, ctx).await
70 } else {
71 let _guard = self.parallel_lock.write().await;
73 self.executor.execute_with_context(call, ctx).await
74 };
75
76 ToolCallResult {
77 call_id: call.id.clone(),
78 tool_name,
79 result,
80 elapsed_ms: started.elapsed().as_millis() as u64,
81 was_parallel: parallel,
82 }
83 }
84
85 pub async fn execute_batch(
92 &self,
93 calls: Vec<(ToolCall, ToolExecutionContext<'_>)>,
94 ) -> Vec<ToolCallResult> {
95 if calls.is_empty() {
96 return Vec::new();
97 }
98
99 let mut results = Vec::with_capacity(calls.len());
101 let mut parallel_batch: Vec<(ToolCall, ToolExecutionContext<'_>)> = Vec::new();
102
103 for (call, ctx) in calls {
104 if Self::supports_parallel(&self.executor, &call) {
105 parallel_batch.push((call, ctx));
106 } else {
107 if !parallel_batch.is_empty() {
109 let batch_results = self.execute_parallel_group(parallel_batch).await;
110 results.extend(batch_results);
111 parallel_batch = Vec::new();
112 }
113 let result = self.execute(&call, ctx).await;
115 results.push(result);
116 }
117 }
118
119 if !parallel_batch.is_empty() {
121 let batch_results = self.execute_parallel_group(parallel_batch).await;
122 results.extend(batch_results);
123 }
124
125 results
126 }
127
128 async fn execute_parallel_group(
130 &self,
131 calls: Vec<(ToolCall, ToolExecutionContext<'_>)>,
132 ) -> Vec<ToolCallResult> {
133 let futures: Vec<_> = calls
134 .into_iter()
135 .map(|(call, ctx)| {
136 let runtime = self.clone();
137 async move { runtime.execute(&call, ctx).await }
138 })
139 .collect();
140
141 futures::future::join_all(futures).await
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use bamboo_agent_core::{FunctionCall, ToolSchema};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Builds a function-type tool call named `name` with id `call_<name>`
    /// and empty JSON arguments.
    fn make_call(name: &str) -> ToolCall {
        ToolCall {
            id: format!("call_{}", name),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: "{}".to_string(),
            },
        }
    }

    /// Test executor that always succeeds, optionally sleeping for `delay`.
    ///
    /// It records how many calls it served and the peak number of calls that
    /// were in flight at the same time.
    struct CountingExecutor {
        call_count: AtomicUsize,
        max_concurrent: AtomicUsize,
        current_concurrent: AtomicUsize,
        delay: Duration,
    }

    impl CountingExecutor {
        fn new(delay: Duration) -> Self {
            Self {
                call_count: AtomicUsize::new(0),
                max_concurrent: AtomicUsize::new(0),
                current_concurrent: AtomicUsize::new(0),
                delay,
            }
        }
    }

    #[async_trait]
    impl ToolExecutor for CountingExecutor {
        async fn execute(&self, call: &ToolCall) -> Result<ToolResult, ToolError> {
            self.execute_with_context(call, ToolExecutionContext::none("test"))
                .await
        }

        async fn execute_with_context(
            &self,
            _call: &ToolCall,
            _ctx: ToolExecutionContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);

            // Record the high-water mark of overlapping calls.
            let in_flight = self.current_concurrent.fetch_add(1, Ordering::SeqCst) + 1;
            self.max_concurrent.fetch_max(in_flight, Ordering::SeqCst);

            if !self.delay.is_zero() {
                tokio::time::sleep(self.delay).await;
            }

            self.current_concurrent.fetch_sub(1, Ordering::SeqCst);

            Ok(ToolResult {
                success: true,
                result: "ok".to_string(),
                display_preference: None,
            })
        }

        fn list_tools(&self) -> Vec<ToolSchema> {
            vec![]
        }
    }

    #[test]
    fn test_supports_parallel() {
        let executor: Arc<dyn ToolExecutor> = Arc::new(CountingExecutor::new(Duration::ZERO));
        for name in ["Read", "Grep", "Glob"] {
            assert!(
                ToolCallRuntime::supports_parallel(&executor, &make_call(name)),
                "{} should be parallel-safe",
                name
            );
        }
        for name in ["Bash", "Write", "Edit"] {
            assert!(
                !ToolCallRuntime::supports_parallel(&executor, &make_call(name)),
                "{} should not be parallel-safe",
                name
            );
        }
    }

    #[tokio::test]
    async fn test_single_call_works() {
        let executor = Arc::new(CountingExecutor::new(Duration::ZERO));
        let runtime = ToolCallRuntime::new(executor.clone());
        let call = make_call("Read");
        let ctx = ToolExecutionContext::none("test");

        let result = runtime.execute(&call, ctx).await;
        assert!(result.result.is_ok());
        assert!(result.was_parallel);
        assert_eq!(result.tool_name, "Read");
        assert_eq!(executor.call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_mutating_call_is_sequential() {
        let executor = Arc::new(CountingExecutor::new(Duration::ZERO));
        let runtime = ToolCallRuntime::new(executor.clone());
        let call = make_call("Bash");
        let ctx = ToolExecutionContext::none("test");

        let result = runtime.execute(&call, ctx).await;
        assert!(result.result.is_ok());
        assert!(!result.was_parallel);
    }

    #[tokio::test]
    async fn test_parallel_reads_are_concurrent() {
        let executor = Arc::new(CountingExecutor::new(Duration::from_millis(50)));
        let runtime = ToolCallRuntime::new(executor.clone());

        let handles: Vec<_> = (0..3)
            .map(|_| {
                let rt = runtime.clone();
                let call = make_call("Read");
                tokio::spawn(
                    async move { rt.execute(&call, ToolExecutionContext::none("test")).await },
                )
            })
            .collect();

        let results: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        assert!(results.iter().all(|r| r.result.is_ok()));
        assert!(results.iter().all(|r| r.was_parallel));

        let max_conc = executor.max_concurrent.load(Ordering::SeqCst);
        assert!(
            max_conc >= 2,
            "Expected parallel execution, got max_concurrent={}",
            max_conc
        );
    }

    #[tokio::test]
    async fn test_mutating_calls_are_exclusive() {
        let executor = Arc::new(CountingExecutor::new(Duration::from_millis(20)));
        let runtime = ToolCallRuntime::new(executor.clone());

        let handles: Vec<_> = (0..3)
            .map(|_| {
                let rt = runtime.clone();
                let call = make_call("Bash");
                tokio::spawn(
                    async move { rt.execute(&call, ToolExecutionContext::none("test")).await },
                )
            })
            .collect();

        for handle in futures::future::join_all(handles).await {
            let result = handle.unwrap();
            assert!(result.result.is_ok());
            assert!(!result.was_parallel);
        }

        // The write guard must keep mutating calls from ever overlapping.
        assert_eq!(executor.max_concurrent.load(Ordering::SeqCst), 1);
        assert_eq!(executor.call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_batch_empty() {
        let executor: Arc<dyn ToolExecutor> = Arc::new(CountingExecutor::new(Duration::ZERO));
        let runtime = ToolCallRuntime::new(executor);
        let results = runtime.execute_batch(vec![]).await;
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_batch_mixed() {
        let executor = Arc::new(CountingExecutor::new(Duration::ZERO));
        let runtime = ToolCallRuntime::new(executor.clone());

        let calls: Vec<_> = ["Read", "Grep", "Bash", "Glob"]
            .iter()
            .map(|name| (make_call(name), ToolExecutionContext::none("test")))
            .collect();

        let results = runtime.execute_batch(calls).await;
        assert_eq!(results.len(), 4);
        assert!(results.iter().all(|r| r.result.is_ok()));

        // Results must come back in input order, with the right lock mode each.
        let names: Vec<&str> = results.iter().map(|r| r.tool_name.as_str()).collect();
        assert_eq!(names, ["Read", "Grep", "Bash", "Glob"]);
        let parallel: Vec<bool> = results.iter().map(|r| r.was_parallel).collect();
        assert_eq!(parallel, [true, true, false, true]);
        assert_eq!(executor.call_count.load(Ordering::SeqCst), 4);
    }

    #[tokio::test]
    async fn test_batch_groups_reads_concurrently() {
        let executor = Arc::new(CountingExecutor::new(Duration::from_millis(30)));
        let runtime = ToolCallRuntime::new(executor.clone());

        let calls: Vec<_> = ["Read", "Grep", "Bash"]
            .iter()
            .map(|name| (make_call(name), ToolExecutionContext::none("test")))
            .collect();

        let results = runtime.execute_batch(calls).await;
        assert_eq!(results.len(), 3);

        // Read and Grep form one parallel group and must overlap.
        let max_conc = executor.max_concurrent.load(Ordering::SeqCst);
        assert!(
            max_conc >= 2,
            "Expected batched reads to overlap, got max_concurrent={}",
            max_conc
        );
    }
}