rama_core/layer/limit/mod.rs

//! A middleware that limits requests based on a [`Policy`], such as the number of in-flight requests.
2//!
3//! See [`Limit`].
4
5use std::fmt;
6
7use crate::error::BoxError;
8use crate::{Context, Service};
9use into_response::{ErrorIntoResponse, ErrorIntoResponseFn};
10use rama_utils::macros::define_inner_service_accessors;
11
12pub mod policy;
13use policy::UnlimitedPolicy;
14pub use policy::{Policy, PolicyOutput};
15
16mod layer;
17#[doc(inline)]
18pub use layer::LimitLayer;
19
20mod into_response;
21
/// A [`Service`] middleware that checks every request against a limit [`Policy`]
/// before forwarding it to the wrapped service.
///
/// Type parameters:
///
/// - `S`: the wrapped (inner) service;
/// - `P`: the policy consulted for each incoming request;
/// - `F`: how a policy rejection is surfaced. With `()` (the default) the rejection
///   becomes an error of the middleware; [`Limit::with_error_into_response_fn`]
///   installs a function that maps the rejection into a result of the inner service.
///
/// [`Policy`]: crate::layer::limit::Policy
pub struct Limit<S, P, F = ()> {
    /// Service that admitted requests are forwarded to.
    inner: S,
    /// Policy checked for every incoming request.
    policy: P,
    /// Hook deciding what to do with a policy error (`()` by default).
    error_into_response: F,
}
30
31impl<S, P> Limit<S, P, ()> {
32    /// Creates a new [`Limit`] from a limit [`Policy`],
33    /// wrapping the given [`Service`].
34    pub const fn new(inner: S, policy: P) -> Self {
35        Limit {
36            inner,
37            policy,
38            error_into_response: (),
39        }
40    }
41
42    /// Attach a function to this [`Limit`] to allow you to turn the Policy error
43    /// into a Result fully compatible with the inner `Service` Result.
44    pub fn with_error_into_response_fn<F>(self, f: F) -> Limit<S, P, ErrorIntoResponseFn<F>> {
45        Limit {
46            inner: self.inner,
47            policy: self.policy,
48            error_into_response: ErrorIntoResponseFn(f),
49        }
50    }
51
52    define_inner_service_accessors!();
53}
54
55impl<T> Limit<T, UnlimitedPolicy, ()> {
56    /// Creates a new [`Limit`] with an unlimited policy.
57    ///
58    /// Meaning that all requests are allowed to proceed.
59    pub const fn unlimited(inner: T) -> Self {
60        Limit {
61            inner,
62            policy: UnlimitedPolicy,
63            error_into_response: (),
64        }
65    }
66}
67
68impl<T: fmt::Debug, P: fmt::Debug, F: fmt::Debug> fmt::Debug for Limit<T, P, F> {
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70        f.debug_struct("Limit")
71            .field("inner", &self.inner)
72            .field("policy", &self.policy)
73            .field("error_into_response", &self.error_into_response)
74            .finish()
75    }
76}
77
78impl<T, P, F> Clone for Limit<T, P, F>
79where
80    T: Clone,
81    P: Clone,
82    F: Clone,
83{
84    fn clone(&self) -> Self {
85        Limit {
86            inner: self.inner.clone(),
87            policy: self.policy.clone(),
88            error_into_response: self.error_into_response.clone(),
89        }
90    }
91}
92
93impl<T, P, State, Request> Service<State, Request> for Limit<T, P, ()>
94where
95    T: Service<State, Request, Error: Into<BoxError>>,
96    P: policy::Policy<State, Request, Error: Into<BoxError>>,
97    Request: Send + Sync + 'static,
98    State: Clone + Send + Sync + 'static,
99{
100    type Response = T::Response;
101    type Error = BoxError;
102
103    async fn serve(
104        &self,
105        mut ctx: Context<State>,
106        mut request: Request,
107    ) -> Result<Self::Response, Self::Error> {
108        loop {
109            let result = self.policy.check(ctx, request).await;
110            ctx = result.ctx;
111            request = result.request;
112
113            match result.output {
114                policy::PolicyOutput::Ready(guard) => {
115                    let _ = guard;
116                    return self.inner.serve(ctx, request).await.map_err(Into::into);
117                }
118                policy::PolicyOutput::Abort(err) => return Err(err.into()),
119                policy::PolicyOutput::Retry => (),
120            }
121        }
122    }
123}
124
125impl<T, P, F, State, Request, FnResponse, FnError> Service<State, Request>
126    for Limit<T, P, ErrorIntoResponseFn<F>>
127where
128    T: Service<State, Request>,
129    P: policy::Policy<State, Request>,
130    F: Fn(P::Error) -> Result<FnResponse, FnError> + Send + Sync + 'static,
131    FnResponse: Into<T::Response> + Send + 'static,
132    FnError: Into<T::Error> + Send + Sync + 'static,
133    Request: Send + Sync + 'static,
134    State: Clone + Send + Sync + 'static,
135{
136    type Response = T::Response;
137    type Error = T::Error;
138
139    async fn serve(
140        &self,
141        mut ctx: Context<State>,
142        mut request: Request,
143    ) -> Result<Self::Response, Self::Error> {
144        loop {
145            let result = self.policy.check(ctx, request).await;
146            ctx = result.ctx;
147            request = result.request;
148
149            match result.output {
150                policy::PolicyOutput::Ready(guard) => {
151                    let _ = guard;
152                    return self.inner.serve(ctx, request).await;
153                }
154                policy::PolicyOutput::Abort(err) => {
155                    return match self.error_into_response.error_into_response(err) {
156                        Ok(ok) => Ok(ok.into()),
157                        Err(err) => Err(err.into()),
158                    };
159                }
160                policy::PolicyOutput::Retry => (),
161            }
162        }
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::policy::ConcurrentPolicy;
    use super::*;

    use crate::{Context, Layer, Service, service::service_fn};
    use std::convert::Infallible;
    use std::time::Duration;

    use futures_lite::future::zip;

    /// With a concurrency limit of one, two simultaneous requests must end up
    /// as exactly one success and one rejection.
    #[tokio::test]
    async fn test_limit() {
        async fn slow_echo<State, Request>(
            _ctx: Context<State>,
            req: Request,
        ) -> Result<Request, Infallible> {
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok(req)
        }

        let layer: LimitLayer<ConcurrentPolicy<_, _>> = LimitLayer::new(ConcurrentPolicy::max(1));

        let first = layer.layer(service_fn(slow_echo));
        let second = layer.layer(service_fn(slow_echo));

        let (first_result, second_result) = zip(
            first.serve(Context::default(), "Hello"),
            second.serve(Context::default(), "Hello"),
        )
        .await;

        // exactly one of the two requests may get through the limit
        if first_result.is_ok() {
            assert_eq!(first_result.unwrap(), "Hello");
            assert!(second_result.is_err());
        } else {
            assert_eq!(second_result.unwrap(), "Hello");
        }
    }

    /// A rejected request is turned into a regular response by the attached
    /// error-into-response function.
    #[tokio::test]
    async fn test_with_error_into_response_fn() {
        async fn always_good<State, Request>(
            _ctx: Context<State>,
            _req: Request,
        ) -> Result<&'static str, Infallible> {
            Ok("good")
        }

        let layer: LimitLayer<ConcurrentPolicy<_, _>, _> =
            LimitLayer::new(ConcurrentPolicy::max(0))
                .with_error_into_response_fn(|_err| Ok::<_, Infallible>("bad"));

        let limited = layer.layer(service_fn(always_good));

        let response = limited.serve(Context::default(), "Hello").await.unwrap();
        assert_eq!(response, "bad");
    }

    /// A limit of zero rejects every request.
    #[tokio::test]
    async fn test_zero_limit() {
        async fn echo<State, Request>(
            _ctx: Context<State>,
            req: Request,
        ) -> Result<Request, Infallible> {
            Ok(req)
        }

        let layer: LimitLayer<ConcurrentPolicy<_, _>> = LimitLayer::new(ConcurrentPolicy::max(0));

        let limited = layer.layer(service_fn(echo));
        let outcome = limited.serve(Context::default(), "Hello").await;
        assert!(outcome.is_err());
    }
}