rama_core/layer/limit/policy/
mod.rs

1//! Limit policies for [`super::Limit`]
2//! define how requests are handled when the limit is reached
3//! for a given request.
4//!
5//! [`Option`] can be used to disable a limit policy for some scenarios
6//! while enabling it for others.
7//!
8//! # Policy Maps
9//!
//! A policy which applies a [`Policy`] based on a [`Matcher`].
//! These can be made by using a Vec<([`Matcher`], [`Policy`])>.
//! To avoid cloning, it is best to wrap the outermost `Vec` in an `Arc<...>`.
13//!
14//! The first matching policy is used.
15//! If no policy matches, the request is allowed to proceed as well.
16//! If you want to enforce a default policy, you can add a policy with a [`Matcher`] that always matches,
17//! such as the bool `true`.
18//!
//! Note that the [`Matcher`]s will not receive the mutable [`Extensions`],
//! as policies are not intended to keep track of what is matched on.
21//!
22//! It is this policy that you want to use in case you want to rate limit only
23//! external sockets or you want to rate limit specific domains/paths only for http requests.
24//! See the [`http_rate_limit.rs`] example for a use case.
25//!
26//! [`Matcher`]: crate::matcher::Matcher
27//! [`Extensions`]: crate::context::Extensions
//! [`http_rate_limit.rs`]: https://github.com/plabayo/rama/blob/main/examples/http_rate_limit.rs
29
30use crate::error::BoxError;
31use crate::Context;
32use std::{convert::Infallible, fmt, sync::Arc};
33
34mod concurrent;
35#[doc(inline)]
36pub use concurrent::{ConcurrentCounter, ConcurrentPolicy, ConcurrentTracker, LimitReached};
37
38mod matcher;
39
40/// The full result of a limit policy.
41pub struct PolicyResult<State, Request, Guard, Error> {
42    /// The input context
43    pub ctx: Context<State>,
44    /// The input request
45    pub request: Request,
46    /// The output part of the limit policy.
47    pub output: PolicyOutput<Guard, Error>,
48}
49
50impl<State: fmt::Debug, Request: fmt::Debug, Guard: fmt::Debug, Error: fmt::Debug> std::fmt::Debug
51    for PolicyResult<State, Request, Guard, Error>
52{
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("PolicyResult")
55            .field("ctx", &self.ctx)
56            .field("request", &self.request)
57            .field("output", &self.output)
58            .finish()
59    }
60}
61
/// The verdict part of a limit policy.
pub enum PolicyOutput<Guard, Error> {
    /// The request may proceed.
    ///
    /// The guard releases the limit when it is dropped,
    /// which should happen once the request has completed.
    Ready(Guard),
    /// The request may not proceed and should be aborted.
    Abort(Error),
    /// The request may not proceed, but should be retried.
    Retry,
}

impl<Guard, Error> fmt::Debug for PolicyOutput<Guard, Error>
where
    Guard: fmt::Debug,
    Error: fmt::Debug,
{
    /// Formats the variant as `PolicyOutput::<Variant>`, with the payload
    /// (if any) rendered through its own `Debug` implementation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Retry => f.write_str("PolicyOutput::Retry"),
            Self::Abort(ref error) => write!(f, "PolicyOutput::Abort({:?})", error),
            Self::Ready(ref guard) => write!(f, "PolicyOutput::Ready({:?})", guard),
        }
    }
}
83
84/// A limit [`Policy`] is used to determine whether a request is allowed to proceed,
85/// and if not, how to handle it.
86pub trait Policy<State, Request>: Send + Sync + 'static {
87    /// The guard type that is returned when the request is allowed to proceed.
88    ///
89    /// See [`PolicyOutput::Ready`].
90    type Guard: Send + 'static;
91    /// The error type that is returned when the request is not allowed to proceed,
92    /// and should be aborted.
93    ///
94    /// See [`PolicyOutput::Abort`].
95    type Error: Send + Sync + 'static;
96
97    /// Check whether the request is allowed to proceed.
98    ///
99    /// Optionally modify the request before it is passed to the inner service,
100    /// which can be used to add metadata to the request regarding how the request
101    /// was handled by this limit policy.
102    fn check(
103        &self,
104        ctx: Context<State>,
105        request: Request,
106    ) -> impl std::future::Future<Output = PolicyResult<State, Request, Self::Guard, Self::Error>>
107           + Send
108           + '_;
109}
110
111impl<State, Request, P> Policy<State, Request> for Option<P>
112where
113    P: Policy<State, Request>,
114    State: Clone + Send + Sync + 'static,
115    Request: Send + 'static,
116{
117    type Guard = Option<P::Guard>;
118    type Error = P::Error;
119
120    async fn check(
121        &self,
122        ctx: Context<State>,
123        request: Request,
124    ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
125        match self {
126            Some(policy) => {
127                let result = policy.check(ctx, request).await;
128                match result.output {
129                    PolicyOutput::Ready(guard) => PolicyResult {
130                        ctx: result.ctx,
131                        request: result.request,
132                        output: PolicyOutput::Ready(Some(guard)),
133                    },
134                    PolicyOutput::Abort(err) => PolicyResult {
135                        ctx: result.ctx,
136                        request: result.request,
137                        output: PolicyOutput::Abort(err),
138                    },
139                    PolicyOutput::Retry => PolicyResult {
140                        ctx: result.ctx,
141                        request: result.request,
142                        output: PolicyOutput::Retry,
143                    },
144                }
145            }
146            None => PolicyResult {
147                ctx,
148                request,
149                output: PolicyOutput::Ready(None),
150            },
151        }
152    }
153}
154
155impl<State, Request, P> Policy<State, Request> for &'static P
156where
157    P: Policy<State, Request>,
158    State: Clone + Send + Sync + 'static,
159    Request: Send + 'static,
160{
161    type Guard = P::Guard;
162    type Error = P::Error;
163
164    async fn check(
165        &self,
166        ctx: Context<State>,
167        request: Request,
168    ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
169        (**self).check(ctx, request).await
170    }
171}
172
173impl<State, Request, P> Policy<State, Request> for Arc<P>
174where
175    P: Policy<State, Request>,
176    State: Clone + Send + Sync + 'static,
177    Request: Send + 'static,
178{
179    type Guard = P::Guard;
180    type Error = P::Error;
181
182    async fn check(
183        &self,
184        ctx: Context<State>,
185        request: Request,
186    ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
187        self.as_ref().check(ctx, request).await
188    }
189}
190
191impl<State, Request, P> Policy<State, Request> for Box<P>
192where
193    P: Policy<State, Request>,
194    State: Clone + Send + Sync + 'static,
195    Request: Send + 'static,
196{
197    type Guard = P::Guard;
198    type Error = P::Error;
199
200    async fn check(
201        &self,
202        ctx: Context<State>,
203        request: Request,
204    ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
205        self.as_ref().check(ctx, request).await
206    }
207}
208
/// A policy without any limit: every request is allowed to proceed.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct UnlimitedPolicy;

impl UnlimitedPolicy {
    /// Creates a new [`UnlimitedPolicy`].
    pub const fn new() -> Self {
        Self
    }
}
220
221impl<State, Request> Policy<State, Request> for UnlimitedPolicy
222where
223    State: Clone + Send + Sync + 'static,
224    Request: Send + 'static,
225{
226    type Guard = ();
227    type Error = Infallible;
228
229    async fn check(
230        &self,
231        ctx: Context<State>,
232        request: Request,
233    ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
234        PolicyResult {
235            ctx,
236            request,
237            output: PolicyOutput::Ready(()),
238        }
239    }
240}
241
242macro_rules! impl_limit_policy_either {
243    ($id:ident, $($param:ident),+ $(,)?) => {
244        impl<$($param),+, State, Request> Policy<State, Request> for crate::combinators::$id<$($param),+>
245        where
246            $(
247                $param: Policy<State, Request>,
248                $param::Error: Into<BoxError>,
249            )+
250            Request: Send + 'static,
251            State: Clone + Send + Sync + 'static,
252        {
253            type Guard = crate::combinators::$id<$($param::Guard),+>;
254            type Error = BoxError;
255
256            async fn check(
257                &self,
258                ctx: Context<State>,
259                req: Request,
260            ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
261                match self {
262                    $(
263                        crate::combinators::$id::$param(policy) => {
264                            let result = policy.check(ctx, req).await;
265                            match result.output {
266                                PolicyOutput::Ready(guard) => PolicyResult {
267                                    ctx: result.ctx,
268                                    request: result.request,
269                                    output: PolicyOutput::Ready(crate::combinators::$id::$param(guard)),
270                                },
271                                PolicyOutput::Abort(err) => PolicyResult {
272                                    ctx: result.ctx,
273                                    request: result.request,
274                                    output: PolicyOutput::Abort(err.into()),
275                                },
276                                PolicyOutput::Retry => PolicyResult {
277                                    ctx: result.ctx,
278                                    request: result.request,
279                                    output: PolicyOutput::Retry,
280                                },
281                            }
282                        }
283                    )+
284                }
285            }
286        }
287    };
288}
289
290crate::combinators::impl_either!(impl_limit_policy_either);