Skip to main content

rs_zero/resil/
concurrency.rs

1use std::{future::Future, sync::Arc};
2
3use thiserror::Error;
4use tokio::sync::Semaphore;
5
6/// Error returned when concurrency capacity is exhausted.
7#[derive(Debug, Error, PartialEq, Eq)]
8#[error("concurrency limit reached")]
9pub struct ConcurrencyLimitError;
10
11/// Async concurrency limiter backed by a semaphore.
12#[derive(Debug, Clone)]
13pub struct ConcurrencyLimit {
14    semaphore: Arc<Semaphore>,
15}
16
17impl ConcurrencyLimit {
18    /// Creates a concurrency limiter with `max` in-flight operations.
19    pub fn new(max: usize) -> Self {
20        Self {
21            semaphore: Arc::new(Semaphore::new(max)),
22        }
23    }
24
25    /// Runs the future after acquiring capacity.
26    pub async fn run<F, T>(&self, future: F) -> Result<T, ConcurrencyLimitError>
27    where
28        F: Future<Output = T>,
29    {
30        let permit = self
31            .semaphore
32            .clone()
33            .try_acquire_owned()
34            .map_err(|_| ConcurrencyLimitError)?;
35
36        let result = future.await;
37        drop(permit);
38        Ok(result)
39    }
40}
41
#[cfg(test)]
mod tests {
    use super::{ConcurrencyLimit, ConcurrencyLimitError};
    use tokio::sync::oneshot;

    /// With a limit of one, a second operation is rejected while the first is
    /// still running, and capacity comes back once the first one completes.
    ///
    /// The first operation is parked on a oneshot channel instead of a timer
    /// so the test does not depend on wall-clock scheduling.
    #[tokio::test]
    async fn rejects_when_limit_is_exhausted() {
        let limit = ConcurrencyLimit::new(1);
        let held = limit.clone();

        // `started` tells the test the permit is held; `release` lets the
        // first operation finish on demand.
        let (started_tx, started_rx) = oneshot::channel::<()>();
        let (release_tx, release_rx) = oneshot::channel::<()>();

        let handle = tokio::spawn(async move {
            held.run(async move {
                let _ = started_tx.send(());
                let _ = release_rx.await;
                1
            })
            .await
        });

        // Only resolves after the inner future has been polled, i.e. after
        // the single permit has been taken.
        started_rx.await.expect("first operation started");

        let second = limit.run(async { 2 }).await;
        assert_eq!(second, Err(ConcurrencyLimitError));

        release_tx
            .send(())
            .expect("first operation is still waiting");
        assert_eq!(handle.await.expect("join"), Ok(1));

        // The permit is returned once the first operation has completed.
        assert_eq!(limit.run(async { 3 }).await, Ok(3));
    }
}