Skip to main content

bamboo_tools/
parallel.rs

1//! Parallel tool execution runtime.
2//!
3//! Inspired by Codex's `ToolCallRuntime`, this module provides a concurrency
4//! manager for tool calls using an RwLock strategy:
5//!
6//! - **Read-only tools** (Read, Grep, Glob, etc.) acquire a *read* lock and
7//!   can execute concurrently with other read-only tools.
8//! - **Mutating tools** (Write, Edit, Bash, etc.) acquire a *write* lock and
9//!   run exclusively — blocking other tools until they finish.
10//!
11//! This ensures that multiple reads can happen in parallel while mutations
12//! are safely serialized.
13
14use std::sync::Arc;
15use std::time::Instant;
16
17use tokio::sync::RwLock;
18
19use crate::orchestrator::ToolMutability;
20use bamboo_agent_core::{ToolCall, ToolError, ToolExecutionContext, ToolExecutor, ToolResult};
21
22/// The parallel tool call runtime.
23///
24/// Wraps a `ToolExecutor` and adds concurrency control via RwLock.
25/// Clone is cheap — all state is behind Arc.
26#[derive(Clone)]
27pub struct ToolCallRuntime {
28    executor: Arc<dyn ToolExecutor>,
29    /// RwLock for concurrency control:
30    /// - Read lock = parallel-safe (multiple readers)
31    /// - Write lock = exclusive (single writer)
32    parallel_lock: Arc<RwLock<()>>,
33}
34
35/// Result of a tool call execution with timing metadata.
36#[derive(Debug)]
37pub struct ToolCallResult {
38    pub call_id: String,
39    pub tool_name: String,
40    pub result: Result<ToolResult, ToolError>,
41    pub elapsed_ms: u64,
42    pub was_parallel: bool,
43}
44
45impl ToolCallRuntime {
46    /// Create a new runtime wrapping the given executor.
47    pub fn new(executor: Arc<dyn ToolExecutor>) -> Self {
48        Self {
49            executor,
50            parallel_lock: Arc::new(RwLock::new(())),
51        }
52    }
53
54    /// Determine if a tool supports parallel execution.
55    pub fn supports_parallel(executor: &Arc<dyn ToolExecutor>, call: &ToolCall) -> bool {
56        // Compute mutability + concurrency-safety together so the executor parses
57        // the call's arguments once instead of once per classification (issue #17).
58        let (mutability, concurrency_safe) = executor.call_parallel_classification(call);
59        mutability == ToolMutability::ReadOnly && concurrency_safe
60    }
61
62    /// Execute a single tool call with appropriate concurrency control.
63    pub async fn execute(&self, call: &ToolCall, ctx: ToolExecutionContext<'_>) -> ToolCallResult {
64        let tool_name = call.function.name.trim().to_string();
65        let parallel = Self::supports_parallel(&self.executor, call);
66        let started = Instant::now();
67
68        let result = if parallel {
69            // Read lock — allows concurrent execution with other readers
70            let _guard = self.parallel_lock.read().await;
71            self.executor.execute_with_context(call, ctx).await
72        } else {
73            // Write lock — exclusive execution
74            let _guard = self.parallel_lock.write().await;
75            self.executor.execute_with_context(call, ctx).await
76        };
77
78        ToolCallResult {
79            call_id: call.id.clone(),
80            tool_name,
81            result,
82            elapsed_ms: started.elapsed().as_millis() as u64,
83            was_parallel: parallel,
84        }
85    }
86
87    /// Execute multiple tool calls with automatic parallel/sequential scheduling.
88    ///
89    /// Tool calls are partitioned into batches:
90    /// - Consecutive read-only tools run concurrently
91    /// - Mutating tools run one at a time
92    /// - Order is preserved (mutating tools act as barriers)
93    pub async fn execute_batch(
94        &self,
95        calls: Vec<(ToolCall, ToolExecutionContext<'_>)>,
96    ) -> Vec<ToolCallResult> {
97        if calls.is_empty() {
98            return Vec::new();
99        }
100
101        // Split into groups: consecutive parallel-safe calls are batched
102        let mut results = Vec::with_capacity(calls.len());
103        let mut parallel_batch: Vec<(ToolCall, ToolExecutionContext<'_>)> = Vec::new();
104
105        for (call, ctx) in calls {
106            if Self::supports_parallel(&self.executor, &call) {
107                parallel_batch.push((call, ctx));
108            } else {
109                // Flush any pending parallel batch first
110                if !parallel_batch.is_empty() {
111                    let batch_results = self.execute_parallel_group(parallel_batch).await;
112                    results.extend(batch_results);
113                    parallel_batch = Vec::new();
114                }
115                // Execute the mutating tool sequentially
116                let result = self.execute(&call, ctx).await;
117                results.push(result);
118            }
119        }
120
121        // Flush remaining parallel batch
122        if !parallel_batch.is_empty() {
123            let batch_results = self.execute_parallel_group(parallel_batch).await;
124            results.extend(batch_results);
125        }
126
127        results
128    }
129
130    /// Execute a group of parallel-safe tool calls concurrently.
131    async fn execute_parallel_group(
132        &self,
133        calls: Vec<(ToolCall, ToolExecutionContext<'_>)>,
134    ) -> Vec<ToolCallResult> {
135        let futures: Vec<_> = calls
136            .into_iter()
137            .map(|(call, ctx)| {
138                let runtime = self.clone();
139                async move { runtime.execute(&call, ctx).await }
140            })
141            .collect();
142
143        futures::future::join_all(futures).await
144    }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use bamboo_agent_core::{FunctionCall, ToolSchema};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Build a function-style tool call named `name` with empty JSON arguments
    /// and the id `call_<name>`.
    fn make_call(name: &str) -> ToolCall {
        ToolCall {
            id: format!("call_{}", name),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: "{}".to_string(),
            },
        }
    }

    /// Test executor that always succeeds and records how many calls it
    /// served and how many were in flight at the same time.
    struct CountingExecutor {
        /// Total number of executions started.
        call_count: AtomicUsize,
        /// Highest number of simultaneously running executions observed.
        max_concurrent: AtomicUsize,
        /// Number of executions currently in flight.
        current_concurrent: AtomicUsize,
        /// Artificial latency added to each execution (zero disables it).
        delay: Duration,
    }

    impl CountingExecutor {
        fn new(delay: Duration) -> Self {
            Self {
                call_count: AtomicUsize::new(0),
                max_concurrent: AtomicUsize::new(0),
                current_concurrent: AtomicUsize::new(0),
                delay,
            }
        }
    }

    #[async_trait]
    impl ToolExecutor for CountingExecutor {
        async fn execute(&self, call: &ToolCall) -> Result<ToolResult, ToolError> {
            self.execute_with_context(call, ToolExecutionContext::none("test"))
                .await
        }

        async fn execute_with_context(
            &self,
            _call: &ToolCall,
            _ctx: ToolExecutionContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);

            // Track concurrency: bump the in-flight counter and fold the new
            // value into the running maximum in one atomic step.
            let current = self.current_concurrent.fetch_add(1, Ordering::SeqCst) + 1;
            self.max_concurrent.fetch_max(current, Ordering::SeqCst);

            if self.delay > Duration::ZERO {
                tokio::time::sleep(self.delay).await;
            }

            self.current_concurrent.fetch_sub(1, Ordering::SeqCst);

            Ok(ToolResult {
                success: true,
                result: "ok".to_string(),
                display_preference: None,
                images: Vec::new(),
            })
        }

        fn list_tools(&self) -> Vec<ToolSchema> {
            vec![]
        }
    }

    /// Read/Grep/Glob are classified as parallel-safe; Bash/Write/Edit are not.
    #[test]
    fn test_supports_parallel() {
        let executor: Arc<dyn ToolExecutor> = Arc::new(CountingExecutor::new(Duration::ZERO));
        assert!(ToolCallRuntime::supports_parallel(
            &executor,
            &make_call("Read")
        ));
        assert!(ToolCallRuntime::supports_parallel(
            &executor,
            &make_call("Grep")
        ));
        assert!(ToolCallRuntime::supports_parallel(
            &executor,
            &make_call("Glob")
        ));
        assert!(!ToolCallRuntime::supports_parallel(
            &executor,
            &make_call("Bash")
        ));
        assert!(!ToolCallRuntime::supports_parallel(
            &executor,
            &make_call("Write")
        ));
        assert!(!ToolCallRuntime::supports_parallel(
            &executor,
            &make_call("Edit")
        ));
    }

    /// A single parallel-safe call reaches the executor exactly once.
    #[tokio::test]
    async fn test_single_call_works() {
        let executor = Arc::new(CountingExecutor::new(Duration::ZERO));
        let runtime = ToolCallRuntime::new(executor.clone());
        let call = make_call("Read");
        let ctx = ToolExecutionContext::none("test");

        let result = runtime.execute(&call, ctx).await;
        assert!(result.result.is_ok());
        assert!(result.was_parallel);
        assert_eq!(result.tool_name, "Read");
        assert_eq!(executor.call_count.load(Ordering::SeqCst), 1);
    }

    /// A mutating call is reported as not parallel.
    #[tokio::test]
    async fn test_mutating_call_is_sequential() {
        let executor = Arc::new(CountingExecutor::new(Duration::ZERO));
        let runtime = ToolCallRuntime::new(executor.clone());
        let call = make_call("Bash");
        let ctx = ToolExecutionContext::none("test");

        let result = runtime.execute(&call, ctx).await;
        assert!(result.result.is_ok());
        assert!(!result.was_parallel);
    }

    /// Mutating calls issued concurrently never overlap inside the executor,
    /// because each one holds the exclusive lock while it runs.
    #[tokio::test]
    async fn test_mutating_calls_are_exclusive() {
        let executor = Arc::new(CountingExecutor::new(Duration::from_millis(10)));
        let runtime = ToolCallRuntime::new(executor.clone());

        let handles: Vec<_> = (0..3)
            .map(|_| {
                let rt = runtime.clone();
                let call = make_call("Bash");
                tokio::spawn(
                    async move { rt.execute(&call, ToolExecutionContext::none("test")).await },
                )
            })
            .collect();

        let results: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        assert!(results.iter().all(|r| r.result.is_ok()));
        assert!(results.iter().all(|r| !r.was_parallel));
        assert_eq!(executor.call_count.load(Ordering::SeqCst), 3);
        assert_eq!(
            executor.max_concurrent.load(Ordering::SeqCst),
            1,
            "Mutating calls must not overlap"
        );
    }

    /// Parallel-safe calls issued concurrently overlap inside the executor.
    #[tokio::test]
    async fn test_parallel_reads_are_concurrent() {
        let executor = Arc::new(CountingExecutor::new(Duration::from_millis(50)));
        let runtime = ToolCallRuntime::new(executor.clone());

        // Execute 3 reads concurrently
        let handles: Vec<_> = (0..3)
            .map(|_| {
                let rt = runtime.clone();
                let call = make_call("Read");
                tokio::spawn(
                    async move { rt.execute(&call, ToolExecutionContext::none("test")).await },
                )
            })
            .collect();

        let results: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        // All should succeed
        assert!(results.iter().all(|r| r.result.is_ok()));
        assert!(results.iter().all(|r| r.was_parallel));

        // Max concurrency should be > 1 (parallel execution)
        let max_conc = executor.max_concurrent.load(Ordering::SeqCst);
        assert!(
            max_conc >= 2,
            "Expected parallel execution, got max_concurrent={}",
            max_conc
        );
    }

    /// An empty batch yields an empty result list.
    #[tokio::test]
    async fn test_batch_empty() {
        let executor: Arc<dyn ToolExecutor> = Arc::new(CountingExecutor::new(Duration::ZERO));
        let runtime = ToolCallRuntime::new(executor);
        let results = runtime.execute_batch(vec![]).await;
        assert!(results.is_empty());
    }

    /// A mixed batch keeps input order, runs the leading reads together, and
    /// treats the mutating call as a barrier.
    #[tokio::test]
    async fn test_batch_mixed() {
        // A small delay keeps the two leading reads in flight at the same time
        // so their overlap is observable.
        let executor = Arc::new(CountingExecutor::new(Duration::from_millis(10)));
        let runtime = ToolCallRuntime::new(executor.clone());

        let calls: Vec<_> = vec![
            (make_call("Read"), ToolExecutionContext::none("test")),
            (make_call("Grep"), ToolExecutionContext::none("test")),
            (make_call("Bash"), ToolExecutionContext::none("test")), // barrier
            (make_call("Glob"), ToolExecutionContext::none("test")),
        ];

        let results = runtime.execute_batch(calls).await;
        assert_eq!(results.len(), 4);
        assert!(results.iter().all(|r| r.result.is_ok()));

        // Results come back in input order with matching ids.
        let names: Vec<&str> = results.iter().map(|r| r.tool_name.as_str()).collect();
        assert_eq!(names, ["Read", "Grep", "Bash", "Glob"]);
        let ids: Vec<&str> = results.iter().map(|r| r.call_id.as_str()).collect();
        assert_eq!(ids, ["call_Read", "call_Grep", "call_Bash", "call_Glob"]);

        // Only the mutating call ran under the exclusive lock.
        let flags: Vec<bool> = results.iter().map(|r| r.was_parallel).collect();
        assert_eq!(flags, [true, true, false, true]);

        // Every call reached the executor once; the largest concurrent group
        // was the two reads ahead of the barrier.
        assert_eq!(executor.call_count.load(Ordering::SeqCst), 4);
        assert_eq!(executor.max_concurrent.load(Ordering::SeqCst), 2);
    }
}