rama_core/layer/limit/policy/
mod.rs1use crate::Context;
31use crate::error::BoxError;
32use std::{convert::Infallible, fmt, sync::Arc};
33
34mod concurrent;
35#[doc(inline)]
36pub use concurrent::{ConcurrentCounter, ConcurrentPolicy, ConcurrentTracker, LimitReached};
37
38mod matcher;
39
40pub struct PolicyResult<State, Request, Guard, Error> {
42 pub ctx: Context<State>,
44 pub request: Request,
46 pub output: PolicyOutput<Guard, Error>,
48}
49
50impl<State: fmt::Debug, Request: fmt::Debug, Guard: fmt::Debug, Error: fmt::Debug> std::fmt::Debug
51 for PolicyResult<State, Request, Guard, Error>
52{
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 f.debug_struct("PolicyResult")
55 .field("ctx", &self.ctx)
56 .field("request", &self.request)
57 .field("output", &self.output)
58 .finish()
59 }
60}
61
/// The verdict carried by a [`PolicyResult`].
pub enum PolicyOutput<Guard, Error> {
    /// The policy produced a guard.
    Ready(Guard),
    /// The policy produced an error.
    Abort(Error),
    /// The policy produced neither a guard nor an error.
    ///
    /// NOTE(review): presumably asks the caller to run the check again —
    /// confirm against the layer that consumes this value.
    Retry,
}

impl<Guard, Error> fmt::Debug for PolicyOutput<Guard, Error>
where
    Guard: fmt::Debug,
    Error: fmt::Debug,
{
    /// Writes the variant as `PolicyOutput::<Variant>`, followed by the
    /// `Debug` form of its payload in parentheses when it has one.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Ready(guard) => write!(f, "PolicyOutput::Ready({guard:?})"),
            Self::Abort(error) => write!(f, "PolicyOutput::Abort({error:?})"),
            Self::Retry => write!(f, "PolicyOutput::Retry"),
        }
    }
}
83
84pub trait Policy<State, Request>: Send + Sync + 'static {
87 type Guard: Send + 'static;
91 type Error: Send + 'static;
96
97 fn check(
103 &self,
104 ctx: Context<State>,
105 request: Request,
106 ) -> impl Future<Output = PolicyResult<State, Request, Self::Guard, Self::Error>> + Send + '_;
107}
108
109impl<State, Request, P> Policy<State, Request> for Option<P>
110where
111 P: Policy<State, Request>,
112 State: Clone + Send + Sync + 'static,
113 Request: Send + 'static,
114{
115 type Guard = Option<P::Guard>;
116 type Error = P::Error;
117
118 async fn check(
119 &self,
120 ctx: Context<State>,
121 request: Request,
122 ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
123 match self {
124 Some(policy) => {
125 let result = policy.check(ctx, request).await;
126 match result.output {
127 PolicyOutput::Ready(guard) => PolicyResult {
128 ctx: result.ctx,
129 request: result.request,
130 output: PolicyOutput::Ready(Some(guard)),
131 },
132 PolicyOutput::Abort(err) => PolicyResult {
133 ctx: result.ctx,
134 request: result.request,
135 output: PolicyOutput::Abort(err),
136 },
137 PolicyOutput::Retry => PolicyResult {
138 ctx: result.ctx,
139 request: result.request,
140 output: PolicyOutput::Retry,
141 },
142 }
143 }
144 None => PolicyResult {
145 ctx,
146 request,
147 output: PolicyOutput::Ready(None),
148 },
149 }
150 }
151}
152
153impl<State, Request, P> Policy<State, Request> for &'static P
154where
155 P: Policy<State, Request>,
156 State: Clone + Send + Sync + 'static,
157 Request: Send + 'static,
158{
159 type Guard = P::Guard;
160 type Error = P::Error;
161
162 async fn check(
163 &self,
164 ctx: Context<State>,
165 request: Request,
166 ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
167 (**self).check(ctx, request).await
168 }
169}
170
171impl<State, Request, P> Policy<State, Request> for Arc<P>
172where
173 P: Policy<State, Request>,
174 State: Clone + Send + Sync + 'static,
175 Request: Send + 'static,
176{
177 type Guard = P::Guard;
178 type Error = P::Error;
179
180 async fn check(
181 &self,
182 ctx: Context<State>,
183 request: Request,
184 ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
185 self.as_ref().check(ctx, request).await
186 }
187}
188
189impl<State, Request, P> Policy<State, Request> for Box<P>
190where
191 P: Policy<State, Request>,
192 State: Clone + Send + Sync + 'static,
193 Request: Send + 'static,
194{
195 type Guard = P::Guard;
196 type Error = P::Error;
197
198 async fn check(
199 &self,
200 ctx: Context<State>,
201 request: Request,
202 ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
203 self.as_ref().check(ctx, request).await
204 }
205}
206
/// A policy that places no limit on requests: its [`Policy`] implementation
/// reports [`PolicyOutput::Ready`] for every check.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct UnlimitedPolicy;

impl UnlimitedPolicy {
    /// Creates a new [`UnlimitedPolicy`]; usable in `const` contexts.
    pub const fn new() -> Self {
        Self
    }
}
218
219impl<State, Request> Policy<State, Request> for UnlimitedPolicy
220where
221 State: Clone + Send + Sync + 'static,
222 Request: Send + 'static,
223{
224 type Guard = ();
225 type Error = Infallible;
226
227 async fn check(
228 &self,
229 ctx: Context<State>,
230 request: Request,
231 ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
232 PolicyResult {
233 ctx,
234 request,
235 output: PolicyOutput::Ready(()),
236 }
237 }
238}
239
240macro_rules! impl_limit_policy_either {
241 ($id:ident, $($param:ident),+ $(,)?) => {
242 impl<$($param),+, State, Request> Policy<State, Request> for crate::combinators::$id<$($param),+>
243 where
244 $(
245 $param: Policy<State, Request>,
246 $param::Error: Into<BoxError>,
247 )+
248 Request: Send + 'static,
249 State: Clone + Send + Sync + 'static,
250 {
251 type Guard = crate::combinators::$id<$($param::Guard),+>;
252 type Error = BoxError;
253
254 async fn check(
255 &self,
256 ctx: Context<State>,
257 req: Request,
258 ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
259 match self {
260 $(
261 crate::combinators::$id::$param(policy) => {
262 let result = policy.check(ctx, req).await;
263 match result.output {
264 PolicyOutput::Ready(guard) => PolicyResult {
265 ctx: result.ctx,
266 request: result.request,
267 output: PolicyOutput::Ready(crate::combinators::$id::$param(guard)),
268 },
269 PolicyOutput::Abort(err) => PolicyResult {
270 ctx: result.ctx,
271 request: result.request,
272 output: PolicyOutput::Abort(err.into()),
273 },
274 PolicyOutput::Retry => PolicyResult {
275 ctx: result.ctx,
276 request: result.request,
277 output: PolicyOutput::Retry,
278 },
279 }
280 }
281 )+
282 }
283 }
284 }
285 };
286}
287
288crate::combinators::impl_either!(impl_limit_policy_either);