use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Semaphore;
use crate::AdkError;
use crate::context::{BackpressurePolicy, ToolConcurrencyConfig};
/// Guard returned by [`ToolConcurrencyManager::acquire`].
///
/// It owns whatever semaphore permits were taken for a tool invocation. The
/// fields are never read by production code; they exist only so the permits
/// live as long as this value and are handed back to their semaphores when it
/// is dropped.
pub struct ConcurrencyPermit {
// Permit taken from the manager's global semaphore, if one was acquired.
_global: Option<tokio::sync::OwnedSemaphorePermit>,
// Permit taken from the tool's own semaphore, if one was acquired.
_per_tool: Option<tokio::sync::OwnedSemaphorePermit>,
}
/// Limits how many tool invocations may hold a permit at the same time.
///
/// Built from a `ToolConcurrencyConfig`; combines an optional global limit,
/// optional per-tool limits, and a backpressure policy that decides what
/// happens when no permit is currently available.
pub struct ToolConcurrencyManager {
// Semaphore used for tools without their own limit; `None` when the config
// sets no `max_concurrency`.
global_semaphore: Option<Arc<Semaphore>>,
// One semaphore per tool name that has a dedicated limit in the config.
per_tool_semaphores: HashMap<String, Arc<Semaphore>>,
// Chooses between waiting for a permit (`Queue`) and failing at once (`Fail`).
backpressure: BackpressurePolicy,
}
impl ToolConcurrencyManager {
pub fn new(config: &ToolConcurrencyConfig) -> Self {
let global_semaphore = config.max_concurrency.map(|n| Arc::new(Semaphore::new(n)));
let per_tool_semaphores = config
.per_tool
.iter()
.map(|(name, &limit)| (name.clone(), Arc::new(Semaphore::new(limit))))
.collect();
Self { global_semaphore, per_tool_semaphores, backpressure: config.backpressure.clone() }
}
pub fn has_limits(&self) -> bool {
self.global_semaphore.is_some() || !self.per_tool_semaphores.is_empty()
}
pub async fn acquire(&self, tool_name: &str) -> Result<ConcurrencyPermit, AdkError> {
let has_per_tool = self.per_tool_semaphores.contains_key(tool_name);
let per_tool_permit = if has_per_tool {
let sem = self.per_tool_semaphores[tool_name].clone();
Some(self.acquire_permit(sem, tool_name).await?)
} else {
None
};
let global_permit = if !has_per_tool {
match &self.global_semaphore {
Some(sem) => Some(self.acquire_permit(sem.clone(), tool_name).await?),
None => None,
}
} else {
None
};
Ok(ConcurrencyPermit { _global: global_permit, _per_tool: per_tool_permit })
}
async fn acquire_permit(
&self,
semaphore: Arc<Semaphore>,
tool_name: &str,
) -> Result<tokio::sync::OwnedSemaphorePermit, AdkError> {
match self.backpressure {
BackpressurePolicy::Queue => semaphore
.acquire_owned()
.await
.map_err(|_| AdkError::tool(format!("concurrency semaphore closed: {tool_name}"))),
BackpressurePolicy::Fail => semaphore
.try_acquire_owned()
.map_err(|_| AdkError::tool(format!("concurrency limit reached: {tool_name}"))),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A default config reports no limits and still hands out a permit.
    #[tokio::test]
    async fn test_unlimited_concurrency() {
        let manager = ToolConcurrencyManager::new(&ToolConcurrencyConfig::default());
        assert!(!manager.has_limits());

        let outcome = manager.acquire("any_tool").await;
        assert!(outcome.is_ok());
    }

    /// With a global limit of two under `Queue`, two permits can be held at once.
    #[tokio::test]
    async fn test_global_limit_queue_policy() {
        let manager = ToolConcurrencyManager::new(&ToolConcurrencyConfig {
            max_concurrency: Some(2),
            backpressure: BackpressurePolicy::Queue,
            ..Default::default()
        });
        assert!(manager.has_limits());

        let _first = manager.acquire("tool_a").await.unwrap();
        let _second = manager.acquire("tool_b").await.unwrap();
    }

    /// With a global limit of one under `Fail`, a second acquire errors while
    /// the first permit is still held.
    #[tokio::test]
    async fn test_global_limit_fail_policy() {
        let manager = ToolConcurrencyManager::new(&ToolConcurrencyConfig {
            max_concurrency: Some(1),
            backpressure: BackpressurePolicy::Fail,
            ..Default::default()
        });

        let _held = manager.acquire("tool_a").await.unwrap();
        assert!(manager.acquire("tool_b").await.is_err());
    }

    /// A per-tool limit applies to that tool, while other tools draw from the
    /// global semaphore.
    #[tokio::test]
    async fn test_per_tool_override() {
        let manager = ToolConcurrencyManager::new(&ToolConcurrencyConfig {
            max_concurrency: Some(10),
            per_tool: HashMap::from([("limited_tool".to_string(), 1)]),
            backpressure: BackpressurePolicy::Fail,
        });

        let _held = manager.acquire("limited_tool").await.unwrap();
        assert!(manager.acquire("limited_tool").await.is_err());

        let other = manager.acquire("other_tool").await.unwrap();
        assert!(other._global.is_some());
    }

    /// Dropping a global permit makes it available to the next acquire.
    #[tokio::test]
    async fn test_permit_release_on_drop() {
        let manager = ToolConcurrencyManager::new(&ToolConcurrencyConfig {
            max_concurrency: Some(1),
            backpressure: BackpressurePolicy::Fail,
            ..Default::default()
        });

        let first = manager.acquire("tool").await.unwrap();
        drop(first);

        assert!(manager.acquire("tool").await.is_ok());
    }

    /// Dropping a per-tool permit makes it available to the next acquire.
    #[tokio::test]
    async fn test_per_tool_permit_release_on_drop() {
        let manager = ToolConcurrencyManager::new(&ToolConcurrencyConfig {
            per_tool: HashMap::from([("special".to_string(), 1)]),
            backpressure: BackpressurePolicy::Fail,
            ..Default::default()
        });

        let first = manager.acquire("special").await.unwrap();
        drop(first);

        assert!(manager.acquire("special").await.is_ok());
    }
}