rama_core/layer/limit/policy/
mod.rs1use crate::error::BoxError;
31use crate::Context;
32use std::{convert::Infallible, fmt, sync::Arc};
33
34mod concurrent;
35#[doc(inline)]
36pub use concurrent::{ConcurrentCounter, ConcurrentPolicy, ConcurrentTracker, LimitReached};
37
38mod matcher;
39
40pub struct PolicyResult<State, Request, Guard, Error> {
42 pub ctx: Context<State>,
44 pub request: Request,
46 pub output: PolicyOutput<Guard, Error>,
48}
49
50impl<State: fmt::Debug, Request: fmt::Debug, Guard: fmt::Debug, Error: fmt::Debug> std::fmt::Debug
51 for PolicyResult<State, Request, Guard, Error>
52{
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 f.debug_struct("PolicyResult")
55 .field("ctx", &self.ctx)
56 .field("request", &self.request)
57 .field("output", &self.output)
58 .finish()
59 }
60}
61
/// The verdict a [`Policy`] reaches for a single request.
pub enum PolicyOutput<Guard, Error> {
    /// The request is allowed; carries the policy's guard value.
    Ready(Guard),
    /// The request is rejected; carries the policy's error.
    Abort(Error),
    /// The request is neither allowed nor rejected at this time.
    /// NOTE(review): presumably the caller is expected to check again later — confirm.
    Retry,
}
73
74impl<Guard: fmt::Debug, Error: fmt::Debug> std::fmt::Debug for PolicyOutput<Guard, Error> {
75 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76 match self {
77 Self::Ready(guard) => write!(f, "PolicyOutput::Ready({guard:?})"),
78 Self::Abort(error) => write!(f, "PolicyOutput::Abort({error:?})"),
79 Self::Retry => write!(f, "PolicyOutput::Retry"),
80 }
81 }
82}
83
84pub trait Policy<State, Request>: Send + Sync + 'static {
87 type Guard: Send + 'static;
91 type Error: Send + Sync + 'static;
96
97 fn check(
103 &self,
104 ctx: Context<State>,
105 request: Request,
106 ) -> impl std::future::Future<Output = PolicyResult<State, Request, Self::Guard, Self::Error>>
107 + Send
108 + '_;
109}
110
111impl<State, Request, P> Policy<State, Request> for Option<P>
112where
113 P: Policy<State, Request>,
114 State: Clone + Send + Sync + 'static,
115 Request: Send + 'static,
116{
117 type Guard = Option<P::Guard>;
118 type Error = P::Error;
119
120 async fn check(
121 &self,
122 ctx: Context<State>,
123 request: Request,
124 ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
125 match self {
126 Some(policy) => {
127 let result = policy.check(ctx, request).await;
128 match result.output {
129 PolicyOutput::Ready(guard) => PolicyResult {
130 ctx: result.ctx,
131 request: result.request,
132 output: PolicyOutput::Ready(Some(guard)),
133 },
134 PolicyOutput::Abort(err) => PolicyResult {
135 ctx: result.ctx,
136 request: result.request,
137 output: PolicyOutput::Abort(err),
138 },
139 PolicyOutput::Retry => PolicyResult {
140 ctx: result.ctx,
141 request: result.request,
142 output: PolicyOutput::Retry,
143 },
144 }
145 }
146 None => PolicyResult {
147 ctx,
148 request,
149 output: PolicyOutput::Ready(None),
150 },
151 }
152 }
153}
154
155impl<State, Request, P> Policy<State, Request> for &'static P
156where
157 P: Policy<State, Request>,
158 State: Clone + Send + Sync + 'static,
159 Request: Send + 'static,
160{
161 type Guard = P::Guard;
162 type Error = P::Error;
163
164 async fn check(
165 &self,
166 ctx: Context<State>,
167 request: Request,
168 ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
169 (**self).check(ctx, request).await
170 }
171}
172
173impl<State, Request, P> Policy<State, Request> for Arc<P>
174where
175 P: Policy<State, Request>,
176 State: Clone + Send + Sync + 'static,
177 Request: Send + 'static,
178{
179 type Guard = P::Guard;
180 type Error = P::Error;
181
182 async fn check(
183 &self,
184 ctx: Context<State>,
185 request: Request,
186 ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
187 self.as_ref().check(ctx, request).await
188 }
189}
190
191impl<State, Request, P> Policy<State, Request> for Box<P>
192where
193 P: Policy<State, Request>,
194 State: Clone + Send + Sync + 'static,
195 Request: Send + 'static,
196{
197 type Guard = P::Guard;
198 type Error = P::Error;
199
200 async fn check(
201 &self,
202 ctx: Context<State>,
203 request: Request,
204 ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
205 self.as_ref().check(ctx, request).await
206 }
207}
208
/// A policy that never limits: every request is allowed with a `()` guard.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct UnlimitedPolicy;

impl UnlimitedPolicy {
    /// Creates a new [`UnlimitedPolicy`]; usable in `const` contexts.
    pub const fn new() -> Self {
        Self
    }
}
220
221impl<State, Request> Policy<State, Request> for UnlimitedPolicy
222where
223 State: Clone + Send + Sync + 'static,
224 Request: Send + 'static,
225{
226 type Guard = ();
227 type Error = Infallible;
228
229 async fn check(
230 &self,
231 ctx: Context<State>,
232 request: Request,
233 ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
234 PolicyResult {
235 ctx,
236 request,
237 output: PolicyOutput::Ready(()),
238 }
239 }
240}
241
242macro_rules! impl_limit_policy_either {
243 ($id:ident, $($param:ident),+ $(,)?) => {
244 impl<$($param),+, State, Request> Policy<State, Request> for crate::combinators::$id<$($param),+>
245 where
246 $(
247 $param: Policy<State, Request>,
248 $param::Error: Into<BoxError>,
249 )+
250 Request: Send + 'static,
251 State: Clone + Send + Sync + 'static,
252 {
253 type Guard = crate::combinators::$id<$($param::Guard),+>;
254 type Error = BoxError;
255
256 async fn check(
257 &self,
258 ctx: Context<State>,
259 req: Request,
260 ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
261 match self {
262 $(
263 crate::combinators::$id::$param(policy) => {
264 let result = policy.check(ctx, req).await;
265 match result.output {
266 PolicyOutput::Ready(guard) => PolicyResult {
267 ctx: result.ctx,
268 request: result.request,
269 output: PolicyOutput::Ready(crate::combinators::$id::$param(guard)),
270 },
271 PolicyOutput::Abort(err) => PolicyResult {
272 ctx: result.ctx,
273 request: result.request,
274 output: PolicyOutput::Abort(err.into()),
275 },
276 PolicyOutput::Retry => PolicyResult {
277 ctx: result.ctx,
278 request: result.request,
279 output: PolicyOutput::Retry,
280 },
281 }
282 }
283 )+
284 }
285 }
286 }
287 };
288}
289
290crate::combinators::impl_either!(impl_limit_policy_either);