adk_core/tool_concurrency.rs
1//! Semaphore-based concurrency control for parallel tool execution.
2//!
3//! This module provides [`ToolConcurrencyManager`] which enforces configurable
4//! concurrency limits on tool calls, supporting both global limits and per-tool
5//! overrides. When limits are reached, the [`BackpressurePolicy`] determines
6//! whether excess calls queue or fail immediately.
7//!
8//! # Architecture
9//!
10//! The manager holds an optional global semaphore and a map of per-tool semaphores.
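//! When a tool call is requested via [`ToolConcurrencyManager::acquire`], the manager
//! checks for a per-tool semaphore first; if none exists, it falls back to the global
//! semaphore. A per-tool override therefore replaces the global limit for that tool:
//! calls to an overridden tool do not consume global permits. The returned
//! [`ConcurrencyPermit`] is an RAII guard that releases its permit back to the
//! semaphore on drop.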
15//!
16//! # Example
17//!
18//! ```rust,ignore
19//! use adk_core::{
20//! BackpressurePolicy, ToolConcurrencyConfig, ToolConcurrencyManager,
21//! };
22//! use std::collections::HashMap;
23//!
24//! let config = ToolConcurrencyConfig {
25//! max_concurrency: Some(5),
26//! per_tool: HashMap::from([("web_scraper".to_string(), 2)]),
27//! backpressure: BackpressurePolicy::Queue,
28//! };
29//!
30//! let manager = ToolConcurrencyManager::new(&config);
31//!
32//! // Acquire a permit — blocks if limit reached (Queue policy)
33//! let permit = manager.acquire("web_scraper").await.unwrap();
34//! // ... execute tool ...
35//! drop(permit); // releases the semaphore
36//! ```
37
38use std::collections::HashMap;
39use std::sync::Arc;
40
41use tokio::sync::Semaphore;
42
43use crate::AdkError;
44use crate::context::{BackpressurePolicy, ToolConcurrencyConfig};
45
46/// RAII guard that releases semaphore permits on drop.
47///
48/// A `ConcurrencyPermit` holds at most one global permit and one per-tool permit.
49/// When the permit is dropped, the underlying semaphore slots are released,
50/// allowing queued tool calls to proceed.
51///
52/// # Example
53///
54/// ```rust,ignore
55/// use adk_core::{ToolConcurrencyConfig, ToolConcurrencyManager};
56///
57/// let config = ToolConcurrencyConfig {
58/// max_concurrency: Some(3),
59/// ..Default::default()
60/// };
61/// let manager = ToolConcurrencyManager::new(&config);
62///
63/// let permit = manager.acquire("my_tool").await.unwrap();
64/// // Tool executes while permit is held...
65/// drop(permit); // Semaphore slot released
66/// ```
67pub struct ConcurrencyPermit {
68 _global: Option<tokio::sync::OwnedSemaphorePermit>,
69 _per_tool: Option<tokio::sync::OwnedSemaphorePermit>,
70}
71
72/// Manages semaphores for tool concurrency enforcement.
73///
74/// Created from a [`ToolConcurrencyConfig`], the manager pre-allocates semaphores
75/// for the global limit and each per-tool override. Use [`acquire`](Self::acquire)
76/// to obtain a [`ConcurrencyPermit`] before executing a tool.
77///
78/// # Example
79///
80/// ```rust,ignore
81/// use adk_core::{
82/// BackpressurePolicy, ToolConcurrencyConfig, ToolConcurrencyManager,
83/// };
84/// use std::collections::HashMap;
85///
86/// let config = ToolConcurrencyConfig {
87/// max_concurrency: Some(5),
88/// per_tool: HashMap::from([("expensive_tool".to_string(), 1)]),
89/// backpressure: BackpressurePolicy::Queue,
90/// };
91///
92/// let manager = ToolConcurrencyManager::new(&config);
93///
94/// // Only 1 "expensive_tool" can run at a time
95/// let permit = manager.acquire("expensive_tool").await.unwrap();
96/// // ... run tool ...
97/// drop(permit);
98///
99/// // Other tools use the global limit of 5
100/// let permit = manager.acquire("cheap_tool").await.unwrap();
101/// drop(permit);
102/// ```
103pub struct ToolConcurrencyManager {
104 global_semaphore: Option<Arc<Semaphore>>,
105 per_tool_semaphores: HashMap<String, Arc<Semaphore>>,
106 backpressure: BackpressurePolicy,
107}
108
109impl ToolConcurrencyManager {
110 /// Create a new manager from the given configuration.
111 ///
112 /// Allocates semaphores based on the config:
113 /// - A global semaphore with `max_concurrency` permits (if set)
114 /// - Per-tool semaphores for each entry in `per_tool`
115 ///
116 /// # Example
117 ///
118 /// ```rust
119 /// use adk_core::{ToolConcurrencyConfig, ToolConcurrencyManager};
120 ///
121 /// let config = ToolConcurrencyConfig {
122 /// max_concurrency: Some(10),
123 /// ..Default::default()
124 /// };
125 /// let manager = ToolConcurrencyManager::new(&config);
126 /// ```
127 pub fn new(config: &ToolConcurrencyConfig) -> Self {
128 let global_semaphore = config.max_concurrency.map(|n| Arc::new(Semaphore::new(n)));
129
130 let per_tool_semaphores = config
131 .per_tool
132 .iter()
133 .map(|(name, &limit)| (name.clone(), Arc::new(Semaphore::new(limit))))
134 .collect();
135
136 Self { global_semaphore, per_tool_semaphores, backpressure: config.backpressure.clone() }
137 }
138
139 /// Returns `true` if this manager has any concurrency limits configured.
140 ///
141 /// When no limits are configured (no global limit and no per-tool overrides),
142 /// calling [`acquire`](Self::acquire) always succeeds immediately with no
143 /// semaphore enforcement.
144 pub fn has_limits(&self) -> bool {
145 self.global_semaphore.is_some() || !self.per_tool_semaphores.is_empty()
146 }
147
148 /// Acquire a permit for the named tool.
149 ///
150 /// If a per-tool override exists for `tool_name`, the per-tool semaphore is used.
151 /// Otherwise, the global semaphore is used (if configured). When neither a per-tool
152 /// override nor a global limit is configured, a permit is returned immediately with
153 /// no semaphore enforcement.
154 ///
155 /// # Errors
156 ///
157 /// Returns `AdkError` when [`BackpressurePolicy::Fail`] is configured and no
158 /// permit is immediately available.
159 ///
160 /// # Example
161 ///
162 /// ```rust,ignore
163 /// use adk_core::{
164 /// BackpressurePolicy, ToolConcurrencyConfig, ToolConcurrencyManager,
165 /// };
166 ///
167 /// let config = ToolConcurrencyConfig {
168 /// max_concurrency: Some(1),
169 /// backpressure: BackpressurePolicy::Fail,
170 /// ..Default::default()
171 /// };
172 /// let manager = ToolConcurrencyManager::new(&config);
173 ///
174 /// // First acquire succeeds
175 /// let permit1 = manager.acquire("tool_a").await.unwrap();
176 ///
177 /// // Second acquire fails immediately (Fail policy)
178 /// let result = manager.acquire("tool_b").await;
179 /// assert!(result.is_err());
180 ///
181 /// drop(permit1);
182 /// ```
183 pub async fn acquire(&self, tool_name: &str) -> Result<ConcurrencyPermit, AdkError> {
184 // Determine which semaphore to use: per-tool takes precedence
185 let has_per_tool = self.per_tool_semaphores.contains_key(tool_name);
186
187 let per_tool_permit = if has_per_tool {
188 let sem = self.per_tool_semaphores[tool_name].clone();
189 Some(self.acquire_permit(sem, tool_name).await?)
190 } else {
191 None
192 };
193
194 // If there's no per-tool override, use the global semaphore
195 let global_permit = if !has_per_tool {
196 match &self.global_semaphore {
197 Some(sem) => Some(self.acquire_permit(sem.clone(), tool_name).await?),
198 None => None,
199 }
200 } else {
201 None
202 };
203
204 Ok(ConcurrencyPermit { _global: global_permit, _per_tool: per_tool_permit })
205 }
206
207 /// Acquire a single permit from the given semaphore, respecting backpressure policy.
208 async fn acquire_permit(
209 &self,
210 semaphore: Arc<Semaphore>,
211 tool_name: &str,
212 ) -> Result<tokio::sync::OwnedSemaphorePermit, AdkError> {
213 match self.backpressure {
214 BackpressurePolicy::Queue => semaphore
215 .acquire_owned()
216 .await
217 .map_err(|_| AdkError::tool(format!("concurrency semaphore closed: {tool_name}"))),
218 BackpressurePolicy::Fail => semaphore
219 .try_acquire_owned()
220 .map_err(|_| AdkError::tool(format!("concurrency limit reached: {tool_name}"))),
221 }
222 }
223}
224
225#[cfg(test)]
226mod tests {
227 use super::*;
228
229 #[tokio::test]
230 async fn test_unlimited_concurrency() {
231 let config = ToolConcurrencyConfig::default();
232 let manager = ToolConcurrencyManager::new(&config);
233
234 assert!(!manager.has_limits());
235 let permit = manager.acquire("any_tool").await;
236 assert!(permit.is_ok());
237 }
238
239 #[tokio::test]
240 async fn test_global_limit_queue_policy() {
241 let config = ToolConcurrencyConfig {
242 max_concurrency: Some(2),
243 backpressure: BackpressurePolicy::Queue,
244 ..Default::default()
245 };
246 let manager = ToolConcurrencyManager::new(&config);
247
248 assert!(manager.has_limits());
249 let _p1 = manager.acquire("tool_a").await.unwrap();
250 let _p2 = manager.acquire("tool_b").await.unwrap();
251 }
252
253 #[tokio::test]
254 async fn test_global_limit_fail_policy() {
255 let config = ToolConcurrencyConfig {
256 max_concurrency: Some(1),
257 backpressure: BackpressurePolicy::Fail,
258 ..Default::default()
259 };
260 let manager = ToolConcurrencyManager::new(&config);
261
262 let _p1 = manager.acquire("tool_a").await.unwrap();
263 let result = manager.acquire("tool_b").await;
264 assert!(result.is_err());
265 }
266
267 #[tokio::test]
268 async fn test_per_tool_override() {
269 let config = ToolConcurrencyConfig {
270 max_concurrency: Some(10),
271 per_tool: HashMap::from([("limited_tool".to_string(), 1)]),
272 backpressure: BackpressurePolicy::Fail,
273 };
274 let manager = ToolConcurrencyManager::new(&config);
275
276 // Per-tool limit of 1
277 let _p1 = manager.acquire("limited_tool").await.unwrap();
278 let result = manager.acquire("limited_tool").await;
279 assert!(result.is_err());
280
281 // Other tools use global limit of 10
282 let _p2 = manager.acquire("other_tool").await.unwrap();
283 assert!(_p2._global.is_some());
284 }
285
286 #[tokio::test]
287 async fn test_permit_release_on_drop() {
288 let config = ToolConcurrencyConfig {
289 max_concurrency: Some(1),
290 backpressure: BackpressurePolicy::Fail,
291 ..Default::default()
292 };
293 let manager = ToolConcurrencyManager::new(&config);
294
295 let permit = manager.acquire("tool").await.unwrap();
296 drop(permit);
297
298 // After drop, we can acquire again
299 let result = manager.acquire("tool").await;
300 assert!(result.is_ok());
301 }
302
303 #[tokio::test]
304 async fn test_per_tool_permit_release_on_drop() {
305 let config = ToolConcurrencyConfig {
306 per_tool: HashMap::from([("special".to_string(), 1)]),
307 backpressure: BackpressurePolicy::Fail,
308 ..Default::default()
309 };
310 let manager = ToolConcurrencyManager::new(&config);
311
312 let permit = manager.acquire("special").await.unwrap();
313 drop(permit);
314
315 // After drop, we can acquire again
316 let result = manager.acquire("special").await;
317 assert!(result.is_ok());
318 }
319}