Skip to main content

adk_core/
tool_concurrency.rs

//! Semaphore-based concurrency control for parallel tool execution.
//!
//! This module provides [`ToolConcurrencyManager`], which enforces configurable
//! concurrency limits on tool calls, supporting both a global limit and per-tool
//! overrides. When a limit is reached, the [`BackpressurePolicy`] determines
//! whether excess calls queue or fail immediately.
//!
//! # Architecture
//!
//! The manager holds an optional global semaphore and a map of per-tool semaphores.
//! When a tool call is requested via [`ToolConcurrencyManager::acquire`], the manager
//! checks for a per-tool semaphore first; if none exists, it falls back to the global
//! semaphore. A tool with a per-tool override is therefore governed only by its own
//! limit and does not consume a global slot. The returned [`ConcurrencyPermit`] is an
//! RAII guard that releases its semaphore slot on drop.
//!
//! # Example
//!
//! ```rust,ignore
//! use adk_core::{
//!     BackpressurePolicy, ToolConcurrencyConfig, ToolConcurrencyManager,
//! };
//! use std::collections::HashMap;
//!
//! let config = ToolConcurrencyConfig {
//!     max_concurrency: Some(5),
//!     per_tool: HashMap::from([("web_scraper".to_string(), 2)]),
//!     backpressure: BackpressurePolicy::Queue,
//! };
//!
//! let manager = ToolConcurrencyManager::new(&config);
//!
//! // Acquire a permit — blocks if limit reached (Queue policy)
//! let permit = manager.acquire("web_scraper").await.unwrap();
//! // ... execute tool ...
//! drop(permit); // releases the semaphore
//! ```
37
38use std::collections::HashMap;
39use std::sync::Arc;
40
41use tokio::sync::Semaphore;
42
43use crate::AdkError;
44use crate::context::{BackpressurePolicy, ToolConcurrencyConfig};
45
46/// RAII guard that releases semaphore permits on drop.
47///
48/// A `ConcurrencyPermit` holds at most one global permit and one per-tool permit.
49/// When the permit is dropped, the underlying semaphore slots are released,
50/// allowing queued tool calls to proceed.
51///
52/// # Example
53///
54/// ```rust,ignore
55/// use adk_core::{ToolConcurrencyConfig, ToolConcurrencyManager};
56///
57/// let config = ToolConcurrencyConfig {
58///     max_concurrency: Some(3),
59///     ..Default::default()
60/// };
61/// let manager = ToolConcurrencyManager::new(&config);
62///
63/// let permit = manager.acquire("my_tool").await.unwrap();
64/// // Tool executes while permit is held...
65/// drop(permit); // Semaphore slot released
66/// ```
67pub struct ConcurrencyPermit {
68    _global: Option<tokio::sync::OwnedSemaphorePermit>,
69    _per_tool: Option<tokio::sync::OwnedSemaphorePermit>,
70}
71
72/// Manages semaphores for tool concurrency enforcement.
73///
74/// Created from a [`ToolConcurrencyConfig`], the manager pre-allocates semaphores
75/// for the global limit and each per-tool override. Use [`acquire`](Self::acquire)
76/// to obtain a [`ConcurrencyPermit`] before executing a tool.
77///
78/// # Example
79///
80/// ```rust,ignore
81/// use adk_core::{
82///     BackpressurePolicy, ToolConcurrencyConfig, ToolConcurrencyManager,
83/// };
84/// use std::collections::HashMap;
85///
86/// let config = ToolConcurrencyConfig {
87///     max_concurrency: Some(5),
88///     per_tool: HashMap::from([("expensive_tool".to_string(), 1)]),
89///     backpressure: BackpressurePolicy::Queue,
90/// };
91///
92/// let manager = ToolConcurrencyManager::new(&config);
93///
94/// // Only 1 "expensive_tool" can run at a time
95/// let permit = manager.acquire("expensive_tool").await.unwrap();
96/// // ... run tool ...
97/// drop(permit);
98///
99/// // Other tools use the global limit of 5
100/// let permit = manager.acquire("cheap_tool").await.unwrap();
101/// drop(permit);
102/// ```
103pub struct ToolConcurrencyManager {
104    global_semaphore: Option<Arc<Semaphore>>,
105    per_tool_semaphores: HashMap<String, Arc<Semaphore>>,
106    backpressure: BackpressurePolicy,
107}
108
109impl ToolConcurrencyManager {
110    /// Create a new manager from the given configuration.
111    ///
112    /// Allocates semaphores based on the config:
113    /// - A global semaphore with `max_concurrency` permits (if set)
114    /// - Per-tool semaphores for each entry in `per_tool`
115    ///
116    /// # Example
117    ///
118    /// ```rust
119    /// use adk_core::{ToolConcurrencyConfig, ToolConcurrencyManager};
120    ///
121    /// let config = ToolConcurrencyConfig {
122    ///     max_concurrency: Some(10),
123    ///     ..Default::default()
124    /// };
125    /// let manager = ToolConcurrencyManager::new(&config);
126    /// ```
127    pub fn new(config: &ToolConcurrencyConfig) -> Self {
128        let global_semaphore = config.max_concurrency.map(|n| Arc::new(Semaphore::new(n)));
129
130        let per_tool_semaphores = config
131            .per_tool
132            .iter()
133            .map(|(name, &limit)| (name.clone(), Arc::new(Semaphore::new(limit))))
134            .collect();
135
136        Self { global_semaphore, per_tool_semaphores, backpressure: config.backpressure.clone() }
137    }
138
139    /// Returns `true` if this manager has any concurrency limits configured.
140    ///
141    /// When no limits are configured (no global limit and no per-tool overrides),
142    /// calling [`acquire`](Self::acquire) always succeeds immediately with no
143    /// semaphore enforcement.
144    pub fn has_limits(&self) -> bool {
145        self.global_semaphore.is_some() || !self.per_tool_semaphores.is_empty()
146    }
147
148    /// Acquire a permit for the named tool.
149    ///
150    /// If a per-tool override exists for `tool_name`, the per-tool semaphore is used.
151    /// Otherwise, the global semaphore is used (if configured). When neither a per-tool
152    /// override nor a global limit is configured, a permit is returned immediately with
153    /// no semaphore enforcement.
154    ///
155    /// # Errors
156    ///
157    /// Returns `AdkError` when [`BackpressurePolicy::Fail`] is configured and no
158    /// permit is immediately available.
159    ///
160    /// # Example
161    ///
162    /// ```rust,ignore
163    /// use adk_core::{
164    ///     BackpressurePolicy, ToolConcurrencyConfig, ToolConcurrencyManager,
165    /// };
166    ///
167    /// let config = ToolConcurrencyConfig {
168    ///     max_concurrency: Some(1),
169    ///     backpressure: BackpressurePolicy::Fail,
170    ///     ..Default::default()
171    /// };
172    /// let manager = ToolConcurrencyManager::new(&config);
173    ///
174    /// // First acquire succeeds
175    /// let permit1 = manager.acquire("tool_a").await.unwrap();
176    ///
177    /// // Second acquire fails immediately (Fail policy)
178    /// let result = manager.acquire("tool_b").await;
179    /// assert!(result.is_err());
180    ///
181    /// drop(permit1);
182    /// ```
183    pub async fn acquire(&self, tool_name: &str) -> Result<ConcurrencyPermit, AdkError> {
184        // Determine which semaphore to use: per-tool takes precedence
185        let has_per_tool = self.per_tool_semaphores.contains_key(tool_name);
186
187        let per_tool_permit = if has_per_tool {
188            let sem = self.per_tool_semaphores[tool_name].clone();
189            Some(self.acquire_permit(sem, tool_name).await?)
190        } else {
191            None
192        };
193
194        // If there's no per-tool override, use the global semaphore
195        let global_permit = if !has_per_tool {
196            match &self.global_semaphore {
197                Some(sem) => Some(self.acquire_permit(sem.clone(), tool_name).await?),
198                None => None,
199            }
200        } else {
201            None
202        };
203
204        Ok(ConcurrencyPermit { _global: global_permit, _per_tool: per_tool_permit })
205    }
206
207    /// Acquire a single permit from the given semaphore, respecting backpressure policy.
208    async fn acquire_permit(
209        &self,
210        semaphore: Arc<Semaphore>,
211        tool_name: &str,
212    ) -> Result<tokio::sync::OwnedSemaphorePermit, AdkError> {
213        match self.backpressure {
214            BackpressurePolicy::Queue => semaphore
215                .acquire_owned()
216                .await
217                .map_err(|_| AdkError::tool(format!("concurrency semaphore closed: {tool_name}"))),
218            BackpressurePolicy::Fail => semaphore
219                .try_acquire_owned()
220                .map_err(|_| AdkError::tool(format!("concurrency limit reached: {tool_name}"))),
221        }
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_unlimited_concurrency() {
        let config = ToolConcurrencyConfig::default();
        let manager = ToolConcurrencyManager::new(&config);

        assert!(!manager.has_limits());
        let permit = manager.acquire("any_tool").await.unwrap();
        // With no limits configured the guard carries no semaphore slots at all.
        assert!(permit._global.is_none());
        assert!(permit._per_tool.is_none());
    }

    #[tokio::test]
    async fn test_global_limit_queue_policy() {
        let config = ToolConcurrencyConfig {
            max_concurrency: Some(2),
            backpressure: BackpressurePolicy::Queue,
            ..Default::default()
        };
        let manager = ToolConcurrencyManager::new(&config);

        assert!(manager.has_limits());
        let _p1 = manager.acquire("tool_a").await.unwrap();
        let _p2 = manager.acquire("tool_b").await.unwrap();
    }

    #[tokio::test]
    async fn test_queue_policy_waits_for_release() {
        let config = ToolConcurrencyConfig {
            max_concurrency: Some(1),
            backpressure: BackpressurePolicy::Queue,
            ..Default::default()
        };
        let manager = Arc::new(ToolConcurrencyManager::new(&config));

        let first = manager.acquire("tool_a").await.unwrap();

        let waiter = tokio::spawn({
            let manager = Arc::clone(&manager);
            async move { manager.acquire("tool_b").await.is_ok() }
        });

        // Let the spawned task run; it must stay parked while the only slot is held.
        tokio::task::yield_now().await;
        assert!(!waiter.is_finished());

        // Releasing the slot lets the queued acquire complete successfully.
        drop(first);
        assert!(waiter.await.unwrap());
    }

    #[tokio::test]
    async fn test_global_limit_fail_policy() {
        let config = ToolConcurrencyConfig {
            max_concurrency: Some(1),
            backpressure: BackpressurePolicy::Fail,
            ..Default::default()
        };
        let manager = ToolConcurrencyManager::new(&config);

        let _p1 = manager.acquire("tool_a").await.unwrap();
        let result = manager.acquire("tool_b").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_per_tool_override() {
        let config = ToolConcurrencyConfig {
            max_concurrency: Some(10),
            per_tool: HashMap::from([("limited_tool".to_string(), 1)]),
            backpressure: BackpressurePolicy::Fail,
        };
        let manager = ToolConcurrencyManager::new(&config);

        // Per-tool limit of 1
        let p1 = manager.acquire("limited_tool").await.unwrap();
        assert!(p1._per_tool.is_some());
        assert!(p1._global.is_none());
        let result = manager.acquire("limited_tool").await;
        assert!(result.is_err());

        // Other tools use global limit of 10
        let p2 = manager.acquire("other_tool").await.unwrap();
        assert!(p2._global.is_some());
        assert!(p2._per_tool.is_none());
    }

    #[tokio::test]
    async fn test_per_tool_override_bypasses_global_limit() {
        let config = ToolConcurrencyConfig {
            max_concurrency: Some(1),
            per_tool: HashMap::from([("special".to_string(), 1)]),
            backpressure: BackpressurePolicy::Fail,
        };
        let manager = ToolConcurrencyManager::new(&config);

        // Holding the per-tool slot does not consume the single global slot...
        let _special = manager.acquire("special").await.unwrap();
        let other = manager.acquire("other").await.unwrap();
        assert!(other._global.is_some());

        // ...but the global limit still applies to tools without an override.
        assert!(manager.acquire("another").await.is_err());
    }

    #[tokio::test]
    async fn test_permit_release_on_drop() {
        let config = ToolConcurrencyConfig {
            max_concurrency: Some(1),
            backpressure: BackpressurePolicy::Fail,
            ..Default::default()
        };
        let manager = ToolConcurrencyManager::new(&config);

        let permit = manager.acquire("tool").await.unwrap();
        drop(permit);

        // After drop, we can acquire again
        let result = manager.acquire("tool").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_per_tool_permit_release_on_drop() {
        let config = ToolConcurrencyConfig {
            per_tool: HashMap::from([("special".to_string(), 1)]),
            backpressure: BackpressurePolicy::Fail,
            ..Default::default()
        };
        let manager = ToolConcurrencyManager::new(&config);

        let permit = manager.acquire("special").await.unwrap();
        drop(permit);

        // After drop, we can acquire again
        let result = manager.acquire("special").await;
        assert!(result.is_ok());
    }
}