use std::sync::Arc;
use tokio::sync::{Semaphore, SemaphorePermit};
use tracing::{debug, trace};
use super::config::ConcurrencyConfig;
use super::rate_limiter::RateLimiter;
/// Gates requests behind an optional rate limiter and an optional concurrency semaphore.
///
/// Cloning is cheap: clones share the same semaphore and rate limiter through `Arc`,
/// so limits are enforced across all clones.
#[derive(Clone)]
pub struct ConcurrencyController {
    /// Configuration this controller was built from.
    config: ConcurrencyConfig,
    /// Semaphore bounding concurrently held permits; always allocated, but only consulted
    /// when `config.semaphore_enabled` is true.
    semaphore: Arc<Semaphore>,
    /// Rate limiter; `Some` only when `config.enabled` is true.
    rate_limiter: Option<Arc<RateLimiter>>,
}
impl ConcurrencyController {
pub fn new(config: ConcurrencyConfig) -> Self {
let semaphore = Arc::new(Semaphore::new(config.max_concurrent_requests));
let rate_limiter = if config.enabled {
Some(Arc::new(RateLimiter::new(config.requests_per_minute)))
} else {
None
};
Self {
semaphore,
rate_limiter,
config,
}
}
pub fn with_defaults() -> Self {
Self::new(ConcurrencyConfig::default())
}
pub fn high_throughput() -> Self {
Self::new(ConcurrencyConfig::high_throughput())
}
pub fn conservative() -> Self {
Self::new(ConcurrencyConfig::conservative())
}
pub fn unlimited() -> Self {
Self::new(ConcurrencyConfig::unlimited())
}
pub async fn acquire(&self) -> Option<SemaphorePermit<'_>> {
if let Some(ref limiter) = self.rate_limiter {
trace!("Waiting for rate limiter");
limiter.acquire().await;
debug!("Rate limiter: token acquired");
}
if self.config.semaphore_enabled {
trace!("Waiting for semaphore permit");
let permit = self.semaphore.acquire().await.unwrap();
debug!(
"Semaphore: permit acquired (available: {})",
self.semaphore.available_permits()
);
Some(permit)
} else {
None
}
}
pub fn try_acquire(&self) -> Option<SemaphorePermit<'_>> {
if let Some(ref limiter) = self.rate_limiter {
if !limiter.try_acquire() {
return None;
}
}
if self.config.semaphore_enabled {
self.semaphore.try_acquire().ok()
} else {
None
}
}
pub fn available_permits(&self) -> usize {
self.semaphore.available_permits()
}
pub fn config(&self) -> &ConcurrencyConfig {
&self.config
}
pub fn rate_limiter(&self) -> Option<&RateLimiter> {
self.rate_limiter.as_deref()
}
}
impl std::fmt::Debug for ConcurrencyController {
    /// Prints the configuration values plus the current free-permit count; the rate
    /// limiter itself is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cfg = &self.config;
        let mut builder = f.debug_struct("ConcurrencyController");
        builder.field("max_concurrent_requests", &cfg.max_concurrent_requests);
        builder.field("requests_per_minute", &cfg.requests_per_minute);
        builder.field("rate_limiting_enabled", &cfg.enabled);
        builder.field("semaphore_enabled", &cfg.semaphore_enabled);
        builder.field("available_permits", &self.semaphore.available_permits());
        builder.finish()
    }
}
impl Default for ConcurrencyController {
fn default() -> Self {
Self::with_defaults()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a controller with rate limiting disabled so tests exercise only the semaphore.
    fn semaphore_only(max_concurrent_requests: usize, semaphore_enabled: bool) -> ConcurrencyController {
        ConcurrencyController::new(ConcurrencyConfig {
            max_concurrent_requests,
            requests_per_minute: 100,
            enabled: false,
            semaphore_enabled,
        })
    }

    #[tokio::test]
    async fn test_controller_acquire() {
        let controller = semaphore_only(2, true);
        let permit1 = controller.acquire().await;
        assert!(permit1.is_some());
        assert_eq!(controller.available_permits(), 1);
        let permit2 = controller.acquire().await;
        assert!(permit2.is_some());
        assert_eq!(controller.available_permits(), 0);
        drop(permit1);
        assert_eq!(controller.available_permits(), 1);
    }

    #[test]
    fn test_controller_try_acquire_exhaustion_and_release() {
        let controller = semaphore_only(1, true);
        let permit = controller.try_acquire();
        assert!(permit.is_some());
        assert_eq!(controller.available_permits(), 0);
        // No permits left: the non-blocking path must refuse.
        assert!(controller.try_acquire().is_none());
        drop(permit);
        assert_eq!(controller.available_permits(), 1);
        assert!(controller.try_acquire().is_some());
    }

    #[tokio::test]
    async fn test_controller_semaphore_disabled() {
        let controller = semaphore_only(1, false);
        assert!(controller.rate_limiter().is_none());
        // With the semaphore disabled no permit is handed out and none is consumed.
        assert!(controller.acquire().await.is_none());
        assert!(controller.try_acquire().is_none());
        assert_eq!(controller.available_permits(), 1);
    }

    #[test]
    fn test_controller_creation() {
        let controller = ConcurrencyController::with_defaults();
        assert!(controller.available_permits() > 0);
    }
}