rama_core/layer/limit/policy/
mod.rs

1//! Limit policies for [`super::Limit`]
2//! define how requests are handled when the limit is reached
3//! for a given request.
4//!
5//! [`Option`] can be used to disable a limit policy for some scenarios
6//! while enabling it for others.
7//!
8//! # Policy Maps
9//!
10//! A policy which applies a [`Policy`] based on a [`Matcher`].
//! These can be made by using a `Vec` of ([`Matcher`], [`Policy`]) tuples.
//! To avoid cloning, it is best to wrap the outermost `Vec` in an `Arc`.
13//!
14//! The first matching policy is used.
15//! If no policy matches, the request is allowed to proceed as well.
16//! If you want to enforce a default policy, you can add a policy with a [`Matcher`] that always matches,
17//! such as the bool `true`.
18//!
//! Note that the [`Matcher`]s will not receive the mutable [`Extensions`],
//! as policies are not intended to keep track of what is matched on.
21//!
22//! It is this policy that you want to use in case you want to rate limit only
23//! external sockets or you want to rate limit specific domains/paths only for http requests.
24//! See the [`http_rate_limit.rs`] example for a use case.
25//!
26//! [`Matcher`]: crate::matcher::Matcher
27//! [`Extensions`]: crate::context::Extensions
//! [`http_rate_limit.rs`]: https://github.com/plabayo/rama/blob/main/examples/http_rate_limit.rs
29
30use crate::Context;
31use crate::error::BoxError;
32use std::{convert::Infallible, fmt, sync::Arc};
33
34mod concurrent;
35#[doc(inline)]
36pub use concurrent::{ConcurrentCounter, ConcurrentPolicy, ConcurrentTracker, LimitReached};
37
38mod matcher;
39
40/// The full result of a limit policy.
41pub struct PolicyResult<State, Request, Guard, Error> {
42    /// The input context
43    pub ctx: Context<State>,
44    /// The input request
45    pub request: Request,
46    /// The output part of the limit policy.
47    pub output: PolicyOutput<Guard, Error>,
48}
49
50impl<State: fmt::Debug, Request: fmt::Debug, Guard: fmt::Debug, Error: fmt::Debug> std::fmt::Debug
51    for PolicyResult<State, Request, Guard, Error>
52{
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("PolicyResult")
55            .field("ctx", &self.ctx)
56            .field("request", &self.request)
57            .field("output", &self.output)
58            .finish()
59    }
60}
61
/// The output part of a limit policy.
pub enum PolicyOutput<Guard, Error> {
    /// The request is allowed to proceed,
    /// and the guard is returned to release the limit when it is dropped,
    /// which should be done after the request is completed.
    Ready(Guard),
    /// The request is not allowed to proceed, and should be aborted.
    Abort(Error),
    /// The request is not allowed to proceed, but should be retried.
    Retry,
}

impl<Guard: fmt::Debug, Error: fmt::Debug> fmt::Debug for PolicyOutput<Guard, Error> {
    /// Formats the output as `PolicyOutput::<Variant>(<payload>)`.
    ///
    /// The payload is written through [`fmt::Formatter::debug_tuple`] so that
    /// formatter flags such as the alternate (`{:#?}`) flag are forwarded to the
    /// guard / error instead of being dropped. Plain `{:?}` output is unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Ready(guard) => f.debug_tuple("PolicyOutput::Ready").field(guard).finish(),
            Self::Abort(error) => f.debug_tuple("PolicyOutput::Abort").field(error).finish(),
            // No payload to format, so the variant name is written as-is.
            Self::Retry => f.write_str("PolicyOutput::Retry"),
        }
    }
}
83
/// A limit [`Policy`] is used to determine whether a request is allowed to proceed,
/// and if not, how to handle it.
///
/// A policy takes ownership of the [`Context`] and the request for the duration
/// of the check and hands both back inside the returned [`PolicyResult`],
/// next to the [`PolicyOutput`] that carries the actual decision.
pub trait Policy<State, Request>: Send + Sync + 'static {
    /// The guard type that is returned when the request is allowed to proceed.
    ///
    /// See [`PolicyOutput::Ready`].
    type Guard: Send + 'static;
    /// The error type that is returned when the request is not allowed to proceed,
    /// and should be aborted.
    ///
    /// See [`PolicyOutput::Abort`].
    type Error: Send + 'static;

    /// Check whether the request is allowed to proceed.
    ///
    /// Optionally modify the request before it is passed to the inner service,
    /// which can be used to add metadata to the request regarding how the request
    /// was handled by this limit policy.
    ///
    /// The returned future is `Send` and may borrow from `self`; it resolves to a
    /// [`PolicyResult`] containing the (possibly modified) context and request
    /// together with the policy's decision.
    fn check(
        &self,
        ctx: Context<State>,
        request: Request,
    ) -> impl Future<Output = PolicyResult<State, Request, Self::Guard, Self::Error>> + Send + '_;
}
108
109impl<State, Request, P> Policy<State, Request> for Option<P>
110where
111    P: Policy<State, Request>,
112    State: Clone + Send + Sync + 'static,
113    Request: Send + 'static,
114{
115    type Guard = Option<P::Guard>;
116    type Error = P::Error;
117
118    async fn check(
119        &self,
120        ctx: Context<State>,
121        request: Request,
122    ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
123        match self {
124            Some(policy) => {
125                let result = policy.check(ctx, request).await;
126                match result.output {
127                    PolicyOutput::Ready(guard) => PolicyResult {
128                        ctx: result.ctx,
129                        request: result.request,
130                        output: PolicyOutput::Ready(Some(guard)),
131                    },
132                    PolicyOutput::Abort(err) => PolicyResult {
133                        ctx: result.ctx,
134                        request: result.request,
135                        output: PolicyOutput::Abort(err),
136                    },
137                    PolicyOutput::Retry => PolicyResult {
138                        ctx: result.ctx,
139                        request: result.request,
140                        output: PolicyOutput::Retry,
141                    },
142                }
143            }
144            None => PolicyResult {
145                ctx,
146                request,
147                output: PolicyOutput::Ready(None),
148            },
149        }
150    }
151}
152
153impl<State, Request, P> Policy<State, Request> for &'static P
154where
155    P: Policy<State, Request>,
156    State: Clone + Send + Sync + 'static,
157    Request: Send + 'static,
158{
159    type Guard = P::Guard;
160    type Error = P::Error;
161
162    async fn check(
163        &self,
164        ctx: Context<State>,
165        request: Request,
166    ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
167        (**self).check(ctx, request).await
168    }
169}
170
171impl<State, Request, P> Policy<State, Request> for Arc<P>
172where
173    P: Policy<State, Request>,
174    State: Clone + Send + Sync + 'static,
175    Request: Send + 'static,
176{
177    type Guard = P::Guard;
178    type Error = P::Error;
179
180    async fn check(
181        &self,
182        ctx: Context<State>,
183        request: Request,
184    ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
185        self.as_ref().check(ctx, request).await
186    }
187}
188
189impl<State, Request, P> Policy<State, Request> for Box<P>
190where
191    P: Policy<State, Request>,
192    State: Clone + Send + Sync + 'static,
193    Request: Send + 'static,
194{
195    type Guard = P::Guard;
196    type Error = P::Error;
197
198    async fn check(
199        &self,
200        ctx: Context<State>,
201        request: Request,
202    ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
203        self.as_ref().check(ctx, request).await
204    }
205}
206
/// A policy without any limit: every request is allowed to proceed.
#[non_exhaustive]
#[derive(Debug, Clone, Default)]
pub struct UnlimitedPolicy;

impl UnlimitedPolicy {
    /// Construct a fresh [`UnlimitedPolicy`].
    pub const fn new() -> Self {
        Self
    }
}
218
219impl<State, Request> Policy<State, Request> for UnlimitedPolicy
220where
221    State: Clone + Send + Sync + 'static,
222    Request: Send + 'static,
223{
224    type Guard = ();
225    type Error = Infallible;
226
227    async fn check(
228        &self,
229        ctx: Context<State>,
230        request: Request,
231    ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
232        PolicyResult {
233            ctx,
234            request,
235            output: PolicyOutput::Ready(()),
236        }
237    }
238}
239
// Implements `Policy` for one of the `crate::combinators` either-style enums.
//
// `$id` is the name of the combinator enum; each `$param` is used both as the
// generic type parameter and as the name of the matching enum variant.
// The active variant's policy is checked, its guard is wrapped in the same
// variant of the combinator, and its error is converted into a `BoxError`.
macro_rules! impl_limit_policy_either {
    ($id:ident, $($param:ident),+ $(,)?) => {
        impl<$($param),+, State, Request> Policy<State, Request> for crate::combinators::$id<$($param),+>
        where
            $(
                $param: Policy<State, Request>,
                $param::Error: Into<BoxError>,
            )+
            Request: Send + 'static,
            State: Clone + Send + Sync + 'static,
        {
            type Guard = crate::combinators::$id<$($param::Guard),+>;
            type Error = BoxError;

            async fn check(
                &self,
                ctx: Context<State>,
                req: Request,
            ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
                match self {
                    $(
                        crate::combinators::$id::$param(policy) => {
                            let PolicyResult {
                                ctx,
                                request,
                                output,
                            } = policy.check(ctx, req).await;

                            PolicyResult {
                                ctx,
                                request,
                                // Re-tag the guard with the variant it came from and
                                // unify the error type; retry carries no payload.
                                output: match output {
                                    PolicyOutput::Ready(guard) => {
                                        PolicyOutput::Ready(crate::combinators::$id::$param(guard))
                                    }
                                    PolicyOutput::Abort(err) => PolicyOutput::Abort(err.into()),
                                    PolicyOutput::Retry => PolicyOutput::Retry,
                                },
                            }
                        }
                    )+
                }
            }
        }
    };
}
287
// Hand `impl_limit_policy_either` to the combinators' `impl_either!` helper.
// NOTE(review): presumably this expands the macro once per `Either*` combinator
// type — confirm against `crate::combinators::impl_either`.
crate::combinators::impl_either!(impl_limit_policy_either);