// rs_zero/resil/concurrency.rs

use std::{future::Future, sync::Arc};

use thiserror::Error;
use tokio::sync::Semaphore;

6#[derive(Debug, Error, PartialEq, Eq)]
8#[error("concurrency limit reached")]
9pub struct ConcurrencyLimitError;
10
11#[derive(Debug, Clone)]
13pub struct ConcurrencyLimit {
14 semaphore: Arc<Semaphore>,
15}
16
17impl ConcurrencyLimit {
18 pub fn new(max: usize) -> Self {
20 Self {
21 semaphore: Arc::new(Semaphore::new(max)),
22 }
23 }
24
25 pub async fn run<F, T>(&self, future: F) -> Result<T, ConcurrencyLimitError>
27 where
28 F: Future<Output = T>,
29 {
30 let permit = self
31 .semaphore
32 .clone()
33 .try_acquire_owned()
34 .map_err(|_| ConcurrencyLimitError)?;
35
36 let result = future.await;
37 drop(permit);
38 Ok(result)
39 }
40}
41
#[cfg(test)]
mod tests {
    use super::{ConcurrencyLimit, ConcurrencyLimitError};
    use std::time::Duration;
    use tokio::sync::oneshot;

    /// With a limit of 1, a second call made while the first is provably in
    /// flight must be rejected, and the first call must still complete.
    ///
    /// The handshake uses oneshot channels instead of sleeps, so the outcome does
    /// not depend on how the scheduler orders timer wake-ups.
    #[tokio::test]
    async fn rejects_when_limit_is_exhausted() {
        let limit = ConcurrencyLimit::new(1);
        let held = limit.clone();
        let (entered_tx, entered_rx) = oneshot::channel::<()>();
        let (release_tx, release_rx) = oneshot::channel::<()>();

        let handle = tokio::spawn(async move {
            held.run(async move {
                // This code only runs inside `run`, so the single permit is taken.
                entered_tx.send(()).expect("test is waiting for entry");
                release_rx.await.expect("test sends release");
                1
            })
            .await
        });

        entered_rx.await.expect("first call entered the limiter");
        let second = limit.run(async { 2 }).await;
        assert_eq!(second, Err(ConcurrencyLimitError));

        release_tx.send(()).expect("first call is waiting for release");
        assert_eq!(handle.await.expect("join").expect("first result"), 1);
    }

    /// A completed call gives its slot back, so sequential calls all succeed.
    #[tokio::test]
    async fn releases_slot_after_completion() {
        let limit = ConcurrencyLimit::new(1);
        assert_eq!(limit.run(async { 1 }).await, Ok(1));
        assert_eq!(limit.run(async { 2 }).await, Ok(2));
    }

    /// Dropping an in-flight `run` future (here via a timeout) releases its slot.
    #[tokio::test]
    async fn cancelled_call_releases_slot() {
        let limit = ConcurrencyLimit::new(1);
        let timed_out = tokio::time::timeout(
            Duration::from_millis(1),
            limit.run(std::future::pending::<()>()),
        )
        .await;
        assert!(timed_out.is_err());
        assert_eq!(limit.run(async { 3 }).await, Ok(3));
    }

    /// A zero limit rejects every call and never polls the submitted future.
    #[tokio::test]
    async fn zero_limit_rejects_without_polling() {
        let limit = ConcurrencyLimit::new(0);
        let mut polled = false;
        let result = limit.run(async { polled = true; }).await;
        assert_eq!(result, Err(ConcurrencyLimitError));
        assert!(!polled);
    }
}