use std::{future::Future, sync::Arc};
use thiserror::Error;
use tokio::sync::Semaphore;
#[derive(Debug, Error, PartialEq, Eq)]
#[error("concurrency limit reached")]
pub struct ConcurrencyLimitError;
/// Cloneable handle that caps how many futures may be inside [`ConcurrencyLimit::run`]
/// at the same time.
///
/// Cloning copies the `Arc`, not the semaphore. Every clone therefore draws from the same
/// pool of permits, and the cap applies across all of them.
#[derive(Clone, Debug)]
pub struct ConcurrencyLimit {
    // Pool of permits shared by every clone of this handle.
    semaphore: Arc<Semaphore>,
}
impl ConcurrencyLimit {
    /// Creates a limit that admits at most `max` concurrent calls to [`run`](Self::run).
    ///
    /// A `max` of `0` produces a limit that rejects every call.
    ///
    /// # Panics
    ///
    /// Panics if `max` exceeds [`Semaphore::MAX_PERMITS`], because [`Semaphore::new`] does.
    pub fn new(max: usize) -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(max)),
        }
    }

    /// Awaits `future` while holding one permit, failing fast if none is available.
    ///
    /// This never waits for a permit to become free. The permit is held for as long as
    /// `future` is being awaited. It is released when this call finishes or when the
    /// returned future is dropped (for example, on cancellation).
    ///
    /// # Errors
    ///
    /// Returns [`ConcurrencyLimitError`] if every permit is currently in use. In that case
    /// `future` is dropped without ever being polled.
    pub async fn run<F, T>(&self, future: F) -> Result<T, ConcurrencyLimitError>
    where
        F: Future<Output = T>,
    {
        // A borrowed permit is enough: it only has to outlive this call, which already
        // borrows `self`, so there is no need to bump the `Arc` refcount per call.
        // Nothing in this impl closes the semaphore, so any failure here means
        // "no permits left".
        let _permit = self
            .semaphore
            .try_acquire()
            .map_err(|_| ConcurrencyLimitError)?;
        Ok(future.await)
    }
}
#[cfg(test)]
mod tests {
    use super::{ConcurrencyLimit, ConcurrencyLimitError};
    use std::cell::Cell;
    use tokio::sync::oneshot;

    /// A second caller is rejected while the only permit is held.
    ///
    /// Ordering is enforced with channels rather than sleeps, so the test does not depend
    /// on scheduler timing.
    #[tokio::test]
    async fn rejects_when_limit_is_exhausted() {
        let limit = ConcurrencyLimit::new(1);
        let held = limit.clone();
        let (started_tx, started_rx) = oneshot::channel::<()>();
        let (release_tx, release_rx) = oneshot::channel::<()>();
        let handle = tokio::spawn(async move {
            held.run(async move {
                // We are inside `run`, so the single permit is taken.
                started_tx
                    .send(())
                    .expect("test body waits for the start signal");
                release_rx
                    .await
                    .expect("test body sends the release signal");
                1
            })
            .await
        });

        started_rx.await.expect("first task reaches its body");
        assert_eq!(limit.run(async { 2 }).await, Err(ConcurrencyLimitError));

        release_tx
            .send(())
            .expect("first task waits for the release signal");
        assert_eq!(handle.await.expect("join").expect("first result"), 1);
    }

    /// Finishing a call hands its permit back for the next caller.
    #[tokio::test]
    async fn releases_permit_after_completion() {
        let limit = ConcurrencyLimit::new(1);
        assert_eq!(limit.run(async { 1 }).await, Ok(1));
        assert_eq!(limit.run(async { 2 }).await, Ok(2));
    }

    /// With zero permits every call is rejected and the future is never polled.
    #[tokio::test]
    async fn zero_limit_rejects_without_polling() {
        let limit = ConcurrencyLimit::new(0);
        let polled = Cell::new(false);
        let result = limit
            .run(async {
                polled.set(true);
            })
            .await;
        assert_eq!(result, Err(ConcurrencyLimitError));
        assert!(!polled.get());
    }
}