// bamboo_tools/parallel.rs

//! Parallel tool execution runtime.
//!
//! Inspired by Codex's `ToolCallRuntime`, this module provides a concurrency
//! manager for tool calls using an RwLock strategy:
//!
//! - **Read-only tools** (Read, Grep, Glob, etc.) acquire a *read* lock and
//!   can execute concurrently with other read-only tools.
//! - **Mutating tools** (Write, Edit, Bash, etc.) acquire a *write* lock and
//!   run exclusively — blocking other tools until they finish.
//!
//! This ensures that multiple reads can happen in parallel while mutations
//! are safely serialized.

use std::sync::Arc;
use std::time::Instant;

use tokio::sync::RwLock;

use crate::orchestrator::ToolMutability;
use bamboo_agent_core::{ToolCall, ToolError, ToolExecutionContext, ToolExecutor, ToolResult};

22/// The parallel tool call runtime.
23///
24/// Wraps a `ToolExecutor` and adds concurrency control via RwLock.
25/// Clone is cheap — all state is behind Arc.
26#[derive(Clone)]
27pub struct ToolCallRuntime {
28    executor: Arc<dyn ToolExecutor>,
29    /// RwLock for concurrency control:
30    /// - Read lock = parallel-safe (multiple readers)
31    /// - Write lock = exclusive (single writer)
32    parallel_lock: Arc<RwLock<()>>,
33}
34
35/// Result of a tool call execution with timing metadata.
36#[derive(Debug)]
37pub struct ToolCallResult {
38    pub call_id: String,
39    pub tool_name: String,
40    pub result: Result<ToolResult, ToolError>,
41    pub elapsed_ms: u64,
42    pub was_parallel: bool,
43}
44
45impl ToolCallRuntime {
46    /// Create a new runtime wrapping the given executor.
47    pub fn new(executor: Arc<dyn ToolExecutor>) -> Self {
48        Self {
49            executor,
50            parallel_lock: Arc::new(RwLock::new(())),
51        }
52    }
53
54    /// Determine if a tool supports parallel execution.
55    pub fn supports_parallel(executor: &Arc<dyn ToolExecutor>, call: &ToolCall) -> bool {
56        executor.call_mutability(call) == ToolMutability::ReadOnly
57            && executor.call_concurrency_safe(call)
58    }
59
60    /// Execute a single tool call with appropriate concurrency control.
61    pub async fn execute(&self, call: &ToolCall, ctx: ToolExecutionContext<'_>) -> ToolCallResult {
62        let tool_name = call.function.name.trim().to_string();
63        let parallel = Self::supports_parallel(&self.executor, call);
64        let started = Instant::now();
65
66        let result = if parallel {
67            // Read lock — allows concurrent execution with other readers
68            let _guard = self.parallel_lock.read().await;
69            self.executor.execute_with_context(call, ctx).await
70        } else {
71            // Write lock — exclusive execution
72            let _guard = self.parallel_lock.write().await;
73            self.executor.execute_with_context(call, ctx).await
74        };
75
76        ToolCallResult {
77            call_id: call.id.clone(),
78            tool_name,
79            result,
80            elapsed_ms: started.elapsed().as_millis() as u64,
81            was_parallel: parallel,
82        }
83    }
84
85    /// Execute multiple tool calls with automatic parallel/sequential scheduling.
86    ///
87    /// Tool calls are partitioned into batches:
88    /// - Consecutive read-only tools run concurrently
89    /// - Mutating tools run one at a time
90    /// - Order is preserved (mutating tools act as barriers)
91    pub async fn execute_batch(
92        &self,
93        calls: Vec<(ToolCall, ToolExecutionContext<'_>)>,
94    ) -> Vec<ToolCallResult> {
95        if calls.is_empty() {
96            return Vec::new();
97        }
98
99        // Split into groups: consecutive parallel-safe calls are batched
100        let mut results = Vec::with_capacity(calls.len());
101        let mut parallel_batch: Vec<(ToolCall, ToolExecutionContext<'_>)> = Vec::new();
102
103        for (call, ctx) in calls {
104            if Self::supports_parallel(&self.executor, &call) {
105                parallel_batch.push((call, ctx));
106            } else {
107                // Flush any pending parallel batch first
108                if !parallel_batch.is_empty() {
109                    let batch_results = self.execute_parallel_group(parallel_batch).await;
110                    results.extend(batch_results);
111                    parallel_batch = Vec::new();
112                }
113                // Execute the mutating tool sequentially
114                let result = self.execute(&call, ctx).await;
115                results.push(result);
116            }
117        }
118
119        // Flush remaining parallel batch
120        if !parallel_batch.is_empty() {
121            let batch_results = self.execute_parallel_group(parallel_batch).await;
122            results.extend(batch_results);
123        }
124
125        results
126    }
127
128    /// Execute a group of parallel-safe tool calls concurrently.
129    async fn execute_parallel_group(
130        &self,
131        calls: Vec<(ToolCall, ToolExecutionContext<'_>)>,
132    ) -> Vec<ToolCallResult> {
133        let futures: Vec<_> = calls
134            .into_iter()
135            .map(|(call, ctx)| {
136                let runtime = self.clone();
137                async move { runtime.execute(&call, ctx).await }
138            })
139            .collect();
140
141        futures::future::join_all(futures).await
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use bamboo_agent_core::{FunctionCall, ToolSchema};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Build a minimal `ToolCall` for the named tool with empty arguments.
    fn make_call(name: &str) -> ToolCall {
        ToolCall {
            id: format!("call_{}", name),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: "{}".to_string(),
            },
        }
    }

    /// Test executor that counts invocations and records the peak number of
    /// concurrently-running calls, optionally sleeping to widen the window.
    struct CountingExecutor {
        call_count: AtomicUsize,
        // Highest number of simultaneously in-flight calls observed.
        max_concurrent: Arc<std::sync::Mutex<usize>>,
        current_concurrent: Arc<AtomicUsize>,
        delay: Duration,
    }

    impl CountingExecutor {
        fn new(delay: Duration) -> Self {
            Self {
                call_count: AtomicUsize::new(0),
                max_concurrent: Arc::new(std::sync::Mutex::new(0)),
                current_concurrent: Arc::new(AtomicUsize::new(0)),
                delay,
            }
        }
    }

    #[async_trait]
    impl ToolExecutor for CountingExecutor {
        async fn execute(&self, _call: &ToolCall) -> Result<ToolResult, ToolError> {
            self.execute_with_context(_call, ToolExecutionContext::none("test"))
                .await
        }

        async fn execute_with_context(
            &self,
            _call: &ToolCall,
            _ctx: ToolExecutionContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);

            // Track concurrency: bump the in-flight counter and record the peak.
            // The std Mutex is released before the await below.
            let current = self.current_concurrent.fetch_add(1, Ordering::SeqCst) + 1;
            {
                let mut max = self.max_concurrent.lock().unwrap();
                if current > *max {
                    *max = current;
                }
            }

            if self.delay > Duration::ZERO {
                tokio::time::sleep(self.delay).await;
            }

            self.current_concurrent.fetch_sub(1, Ordering::SeqCst);

            Ok(ToolResult {
                success: true,
                result: "ok".to_string(),
                display_preference: None,
            })
        }

        fn list_tools(&self) -> Vec<ToolSchema> {
            vec![]
        }
    }

    #[test]
    fn test_supports_parallel() {
        let executor: Arc<dyn ToolExecutor> = Arc::new(CountingExecutor::new(Duration::ZERO));
        assert!(ToolCallRuntime::supports_parallel(
            &executor,
            &make_call("Read")
        ));
        assert!(ToolCallRuntime::supports_parallel(
            &executor,
            &make_call("Grep")
        ));
        assert!(ToolCallRuntime::supports_parallel(
            &executor,
            &make_call("Glob")
        ));
        assert!(!ToolCallRuntime::supports_parallel(
            &executor,
            &make_call("Bash")
        ));
        assert!(!ToolCallRuntime::supports_parallel(
            &executor,
            &make_call("Write")
        ));
        assert!(!ToolCallRuntime::supports_parallel(
            &executor,
            &make_call("Edit")
        ));
    }

    #[tokio::test]
    async fn test_single_call_works() {
        let executor = Arc::new(CountingExecutor::new(Duration::ZERO));
        let runtime = ToolCallRuntime::new(executor.clone());
        let call = make_call("Read");
        let ctx = ToolExecutionContext::none("test");

        let result = runtime.execute(&call, ctx).await;
        assert!(result.result.is_ok());
        assert!(result.was_parallel);
        assert_eq!(result.tool_name, "Read");
        assert_eq!(executor.call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_mutating_call_is_sequential() {
        let executor = Arc::new(CountingExecutor::new(Duration::ZERO));
        let runtime = ToolCallRuntime::new(executor.clone());
        let call = make_call("Bash");
        let ctx = ToolExecutionContext::none("test");

        let result = runtime.execute(&call, ctx).await;
        assert!(result.result.is_ok());
        assert!(!result.was_parallel);
    }

    #[tokio::test]
    async fn test_parallel_reads_are_concurrent() {
        let executor = Arc::new(CountingExecutor::new(Duration::from_millis(50)));
        let runtime = ToolCallRuntime::new(executor.clone());

        // Execute 3 reads concurrently via spawned tasks.
        let handles: Vec<_> = (0..3)
            .map(|_| {
                let rt = runtime.clone();
                let call = make_call("Read");
                tokio::spawn(
                    async move { rt.execute(&call, ToolExecutionContext::none("test")).await },
                )
            })
            .collect();

        let results: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        // All should succeed
        assert!(results.iter().all(|r| r.result.is_ok()));
        assert!(results.iter().all(|r| r.was_parallel));

        // Max concurrency should be > 1 (parallel execution)
        let max_conc = *executor.max_concurrent.lock().unwrap();
        assert!(
            max_conc >= 2,
            "Expected parallel execution, got max_concurrent={}",
            max_conc
        );
    }

    #[tokio::test]
    async fn test_batch_empty() {
        let executor: Arc<dyn ToolExecutor> = Arc::new(CountingExecutor::new(Duration::ZERO));
        let runtime = ToolCallRuntime::new(executor);
        let results = runtime.execute_batch(vec![]).await;
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_batch_mixed() {
        let executor: Arc<dyn ToolExecutor> = Arc::new(CountingExecutor::new(Duration::ZERO));
        let runtime = ToolCallRuntime::new(executor);

        let calls: Vec<_> = vec![
            (make_call("Read"), ToolExecutionContext::none("test")),
            (make_call("Grep"), ToolExecutionContext::none("test")),
            (make_call("Bash"), ToolExecutionContext::none("test")), // barrier
            (make_call("Glob"), ToolExecutionContext::none("test")),
        ];

        let results = runtime.execute_batch(calls).await;
        assert_eq!(results.len(), 4);
        assert!(results.iter().all(|r| r.result.is_ok()));
    }
}