Skip to main content

burger/
concurrency_limit.rs

1//! The [`ServiceExt::concurrency_limit`](crate::ServiceExt::concurrency_limit) combinator returns
2//! [`ConcurrencyLimit`] which restricts the number of inflight [calls](Service::call) to a specified
3//! value.
4//!
5//! # Example
6//!
7//! ```rust
8//! use burger::*;
9//! # use tokio::{join, time::sleep};
10//! # use std::time::Duration;
11//!
12//! # #[tokio::main]
13//! # async fn main() {
14//! let svc = service_fn(|x| async move {
15//!     sleep(Duration::from_secs(1)).await;
16//!     2 * x
17//! })
18//! .concurrency_limit(1);
19//! let (a, b) = join! {
20//!     svc.oneshot(6),
21//!     svc.oneshot(2)
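//! };
//! // With a limit of 1 the second call cannot start until the first has finished,
//! // so `join!` completes after roughly two seconds rather than one.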
23//! # }
24//! ```
25//!
26//! # Load
27//!
//! The [`Load::load`] implementation on [`ConcurrencyLimit`] defers to the inner service.
29
30use std::fmt;
31
32use tokio::sync::{Semaphore, SemaphorePermit};
33
34use crate::{load::Load, Middleware, Service};
35
36/// A wrapper for the [`ServiceExt::concurrency_limit`](crate::ServiceExt::concurrency_limit)
37/// combinator.
38///
39/// See the [module](crate::concurrency_limit) for more information.
40#[derive(Debug)]
41pub struct ConcurrencyLimit<S> {
42    inner: S,
43    semaphore: Semaphore,
44}
45
46impl<S> ConcurrencyLimit<S> {
47    pub(crate) fn new(inner: S, n_permits: usize) -> Self {
48        Self {
49            inner,
50            semaphore: Semaphore::new(n_permits),
51        }
52    }
53}
54
55/// The [`Service::Permit`] type for [`ConcurrencyLimit`].
56pub struct ConcurrencyLimitPermit<'a, S, Request>
57where
58    S: Service<Request> + 'a,
59{
60    inner: S::Permit<'a>,
61    _semaphore_permit: SemaphorePermit<'a>,
62}
63
64impl<'a, S, Request> fmt::Debug for ConcurrencyLimitPermit<'a, S, Request>
65where
66    S: Service<Request>,
67    S::Permit<'a>: fmt::Debug,
68{
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70        f.debug_struct("ConcurrencyLimitPermit")
71            .field("inner", &self.inner)
72            .field("_semaphore_permit", &self._semaphore_permit)
73            .finish()
74    }
75}
76
77impl<Request, S> Service<Request> for ConcurrencyLimit<S>
78where
79    S: Service<Request>,
80{
81    type Response = S::Response;
82    type Permit<'a> = ConcurrencyLimitPermit<'a, S, Request>
83    where
84        S: 'a;
85
86    async fn acquire(&self) -> Self::Permit<'_> {
87        ConcurrencyLimitPermit {
88            _semaphore_permit: self.semaphore.acquire().await.expect("not closed"),
89            inner: self.inner.acquire().await,
90        }
91    }
92
93    async fn call(permit: Self::Permit<'_>, request: Request) -> Self::Response {
94        S::call(permit.inner, request).await
95    }
96}
97
98impl<S> Load for ConcurrencyLimit<S>
99where
100    S: Load,
101{
102    type Metric = S::Metric;
103
104    fn load(&self) -> Self::Metric {
105        self.inner.load()
106    }
107}
108
109impl<S, T> Middleware<S> for ConcurrencyLimit<T>
110where
111    T: Middleware<S>,
112{
113    type Service = ConcurrencyLimit<T::Service>;
114
115    fn apply(self, svc: S) -> Self::Service {
116        let Self { inner, semaphore } = self;
117        ConcurrencyLimit {
118            inner: inner.apply(svc),
119            semaphore,
120        }
121    }
122}