use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{OwnedSemaphorePermit, RwLock, Semaphore};
/// Guard for one concurrency slot handed out by `ConcurrentLimiter::try_acquire`.
///
/// The slot goes back to its key's semaphore when this value is dropped, so it
/// must be kept alive for as long as the guarded work runs.
#[derive(Debug)]
#[must_use = "dropping a ConcurrentPermit immediately releases its concurrency slot"]
pub struct ConcurrentPermit {
    // Never read: the field is held only so that dropping this guard drops the
    // semaphore permit and returns the slot.
    #[allow(dead_code)]
    permit: OwnedSemaphorePermit,
}
/// Per-key state: the limit a key was registered with, plus the semaphore
/// that enforces it.
///
/// Cloning is cheap and yields a handle to the same underlying semaphore.
#[derive(Clone)]
struct ClientSemaphore {
    /// Permit count the semaphore was created with.
    max_concurrent: usize,
    /// Permit pool shared by every clone of this value.
    semaphore: Arc<Semaphore>,
}
impl ClientSemaphore {
    /// Creates a semaphore holding `max_concurrent` permits.
    fn new(max_concurrent: usize) -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(max_concurrent)),
            max_concurrent,
        }
    }

    /// Takes one permit without waiting; `None` when no permit is free.
    ///
    /// Borrows `self`: the returned permit holds its own `Arc` to the
    /// semaphore, so callers do not have to clone this wrapper first.
    fn try_acquire_owned(&self) -> Option<OwnedSemaphorePermit> {
        Arc::clone(&self.semaphore).try_acquire_owned().ok()
    }

    /// Number of permits currently free.
    fn available(&self) -> usize {
        self.semaphore.available_permits()
    }

    /// Number of permits the semaphore was created with.
    fn max(&self) -> usize {
        self.max_concurrent
    }
}
/// Limits how many operations may run at once per string key.
///
/// Each key gets its own semaphore, created lazily on the key's first
/// `try_acquire`. Clones share the same key map.
pub struct ConcurrentLimiter {
    /// Key -> per-key semaphore, shared between clones through the `Arc`.
    semaphores: Arc<RwLock<HashMap<String, ClientSemaphore>>>,
}
impl ConcurrentLimiter {
    /// Creates a limiter that tracks no keys yet.
    pub fn new() -> Self {
        Self {
            semaphores: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Tries to reserve one concurrency slot for `key` without waiting.
    ///
    /// The first call for a key creates its semaphore with `max_concurrent`
    /// permits. Later calls for the same key reuse that semaphore and ignore
    /// their own `max_concurrent`, until `cleanup` removes the idle entry.
    ///
    /// Returns `Some(permit)` if a slot was free and `None` if every slot for
    /// `key` is in use. The slot is released when the permit is dropped.
    ///
    /// # Panics
    ///
    /// Panics if a new entry is created with `max_concurrent` larger than
    /// tokio's `Semaphore::MAX_PERMITS`, because `Semaphore::new` panics in
    /// that case.
    pub async fn try_acquire(
        &self,
        key: &str,
        max_concurrent: usize,
    ) -> Option<ConcurrentPermit> {
        // Fast path: the key is already known, so a shared read lock suffices.
        // The permit is taken *while* the lock is held. This way `cleanup`,
        // which needs the write lock, cannot evict the entry between lookup and
        // acquire and leave this permit on an orphaned semaphore.
        {
            let semaphores = self.semaphores.read().await;
            if let Some(semaphore) = semaphores.get(key) {
                return semaphore
                    .clone()
                    .try_acquire_owned()
                    .map(|permit| ConcurrentPermit { permit });
            }
        }

        // Slow path: first request for this key. The entry API re-checks,
        // because another task may have inserted the key after the read lock
        // was released. The guard stays alive until the permit is taken.
        let mut semaphores = self.semaphores.write().await;
        let semaphore = semaphores
            .entry(key.to_string())
            .or_insert_with(|| ClientSemaphore::new(max_concurrent))
            .clone();
        semaphore
            .try_acquire_owned()
            .map(|permit| ConcurrentPermit { permit })
    }

    /// Free permits for `key`. Returns `None` if the key is not tracked
    /// (never acquired, or removed by `cleanup`).
    pub async fn available_permits(&self, key: &str) -> Option<usize> {
        let semaphores = self.semaphores.read().await;
        semaphores.get(key).map(|s| s.available())
    }

    /// Limit `key` was registered with. Returns `None` if the key is not
    /// tracked.
    pub async fn max_concurrent(&self, key: &str) -> Option<usize> {
        let semaphores = self.semaphores.read().await;
        semaphores.get(key).map(|s| s.max())
    }

    /// Number of tracked keys, including idle ones not yet removed by `cleanup`.
    pub async fn client_count(&self) -> usize {
        self.semaphores.read().await.len()
    }

    /// Removes every key that has no outstanding permits.
    ///
    /// Runs under the write lock, so it cannot interleave with `try_acquire`,
    /// which takes its permit while holding the map lock.
    pub async fn cleanup(&self) {
        let mut semaphores = self.semaphores.write().await;
        semaphores.retain(|_, sem| sem.available() != sem.max());
    }
}
impl Default for ConcurrentLimiter {
fn default() -> Self {
Self::new()
}
}
/// A clone is another handle to the same shared key map, not a deep copy.
/// Permits taken through one clone are therefore visible through every other.
impl Clone for ConcurrentLimiter {
    fn clone(&self) -> Self {
        let semaphores = Arc::clone(&self.semaphores);
        Self { semaphores }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_concurrent_limiter_new() {
        let limiter = ConcurrentLimiter::new();
        assert_eq!(limiter.client_count().await, 0);
    }

    #[tokio::test]
    async fn test_try_acquire_first_permit() {
        let limiter = ConcurrentLimiter::new();
        let permit = limiter.try_acquire("user1", 10).await;
        assert!(permit.is_some());
        assert_eq!(limiter.available_permits("user1").await, Some(9));
    }

    #[tokio::test]
    async fn test_permit_released_on_drop() {
        let limiter = ConcurrentLimiter::new();
        {
            let _permit = limiter.try_acquire("user1", 10).await;
            assert_eq!(limiter.available_permits("user1").await, Some(9));
        }
        assert_eq!(limiter.available_permits("user1").await, Some(10));
    }

    #[tokio::test]
    async fn test_concurrent_limit_enforced() {
        let limiter = ConcurrentLimiter::new();
        let mut permits = Vec::new();
        for _ in 0..5 {
            let permit = limiter.try_acquire("user1", 5).await;
            assert!(permit.is_some(), "Should acquire permit");
            permits.push(permit);
        }
        let permit = limiter.try_acquire("user1", 5).await;
        assert!(permit.is_none(), "Should not acquire 6th permit");
        permits.pop();
        let permit = limiter.try_acquire("user1", 5).await;
        assert!(permit.is_some(), "Should acquire after release");
    }

    #[tokio::test]
    async fn test_different_clients_separate_limits() {
        let limiter = ConcurrentLimiter::new();
        let mut user1_permits = Vec::new();
        for _ in 0..3 {
            let permit = limiter.try_acquire("user1", 3).await;
            user1_permits.push(permit);
        }
        assert!(limiter.try_acquire("user1", 3).await.is_none());
        let user2_permit = limiter.try_acquire("user2", 3).await;
        assert!(user2_permit.is_some());
    }

    #[tokio::test]
    async fn test_available_permits() {
        let limiter = ConcurrentLimiter::new();
        assert_eq!(limiter.available_permits("user1").await, None);
        let _permit1 = limiter.try_acquire("user1", 5).await;
        assert_eq!(limiter.available_permits("user1").await, Some(4));
        let _permit2 = limiter.try_acquire("user1", 5).await;
        assert_eq!(limiter.available_permits("user1").await, Some(3));
    }

    #[tokio::test]
    async fn test_max_concurrent() {
        let limiter = ConcurrentLimiter::new();
        let _permit = limiter.try_acquire("user1", 10).await;
        assert_eq!(limiter.max_concurrent("user1").await, Some(10));
        let _permit = limiter.try_acquire("user2", 50).await;
        assert_eq!(limiter.max_concurrent("user2").await, Some(50));
    }

    #[tokio::test]
    async fn test_limit_fixed_on_first_use() {
        // A later caller passing a larger limit must not widen an existing key.
        let limiter = ConcurrentLimiter::new();
        let first = limiter.try_acquire("user1", 1).await;
        assert!(first.is_some());
        assert!(limiter.try_acquire("user1", 10).await.is_none());
        assert_eq!(limiter.max_concurrent("user1").await, Some(1));
    }

    #[tokio::test]
    async fn test_concurrent_limiter_clone() {
        let limiter1 = ConcurrentLimiter::new();
        let limiter2 = limiter1.clone();
        let _permit = limiter1.try_acquire("user1", 10).await;
        assert_eq!(limiter2.available_permits("user1").await, Some(9));
    }

    #[tokio::test]
    async fn test_concurrent_requests_simulation() {
        let limiter = Arc::new(ConcurrentLimiter::new());
        let max_concurrent = 3;
        let mut handles = Vec::new();
        for i in 0..10 {
            let limiter_clone = Arc::clone(&limiter);
            let handle = tokio::spawn(async move {
                if let Some(_permit) = limiter_clone.try_acquire("user1", max_concurrent).await {
                    tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                    Ok::<_, ()>(i)
                } else {
                    Err(())
                }
            });
            handles.push(handle);
        }
        let results = futures::future::join_all(handles).await;
        let successes = results.iter().filter(|r| matches!(r, Ok(Ok(_)))).count();
        let failures = results.iter().filter(|r| matches!(r, Ok(Err(_)))).count();
        assert!(successes > 0, "Some requests should succeed");
        assert!(failures > 0, "Some requests should fail due to concurrency limit");
        assert_eq!(successes + failures, 10);
    }

    #[tokio::test]
    async fn test_cleanup() {
        let limiter = ConcurrentLimiter::new();
        {
            let _permit = limiter.try_acquire("user1", 5).await;
            let _permit = limiter.try_acquire("user2", 5).await;
        }
        assert_eq!(limiter.client_count().await, 2);
        limiter.cleanup().await;
        assert_eq!(limiter.client_count().await, 0);
    }

    #[tokio::test]
    async fn test_cleanup_keeps_busy_clients() {
        let limiter = ConcurrentLimiter::new();
        let busy = limiter.try_acquire("busy", 2).await;
        assert!(busy.is_some());
        {
            let _idle = limiter.try_acquire("idle", 2).await;
        }
        limiter.cleanup().await;
        assert_eq!(limiter.client_count().await, 1);
        assert_eq!(limiter.available_permits("busy").await, Some(1));
        assert_eq!(limiter.available_permits("idle").await, None);
    }
}