1use std::fmt;
6
7use crate::error::BoxError;
8use crate::{Context, Service};
9use into_response::{ErrorIntoResponse, ErrorIntoResponseFn};
10use rama_utils::macros::define_inner_service_accessors;
11
12pub mod policy;
13use policy::UnlimitedPolicy;
14pub use policy::{Policy, PolicyOutput};
15
16mod layer;
17#[doc(inline)]
18pub use layer::LimitLayer;
19
20mod into_response;
21
/// Middleware that consults a policy before handing a request to the
/// wrapped service.
///
/// The third type parameter defaults to `()`, meaning no error mapper is
/// installed. [`Limit::with_error_into_response_fn`] swaps it for an
/// `ErrorIntoResponseFn` wrapper around a user-supplied function.
pub struct Limit<S, P, F = ()> {
    /// The wrapped service that receives requests once the policy reports
    /// it is ready.
    inner: S,
    /// The policy that is checked before every call to `inner`.
    policy: P,
    /// Either `()` or an `ErrorIntoResponseFn`; when present it converts an
    /// error the policy aborted with into the inner service's result type.
    error_into_response: F,
}
30
31impl<S, P> Limit<S, P, ()> {
32 pub const fn new(inner: S, policy: P) -> Self {
35 Limit {
36 inner,
37 policy,
38 error_into_response: (),
39 }
40 }
41
42 pub fn with_error_into_response_fn<F>(self, f: F) -> Limit<S, P, ErrorIntoResponseFn<F>> {
45 Limit {
46 inner: self.inner,
47 policy: self.policy,
48 error_into_response: ErrorIntoResponseFn(f),
49 }
50 }
51
52 define_inner_service_accessors!();
53}
54
55impl<T> Limit<T, UnlimitedPolicy, ()> {
56 pub const fn unlimited(inner: T) -> Self {
60 Limit {
61 inner,
62 policy: UnlimitedPolicy,
63 error_into_response: (),
64 }
65 }
66}
67
68impl<T: fmt::Debug, P: fmt::Debug, F: fmt::Debug> fmt::Debug for Limit<T, P, F> {
69 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70 f.debug_struct("Limit")
71 .field("inner", &self.inner)
72 .field("policy", &self.policy)
73 .field("error_into_response", &self.error_into_response)
74 .finish()
75 }
76}
77
78impl<T, P, F> Clone for Limit<T, P, F>
79where
80 T: Clone,
81 P: Clone,
82 F: Clone,
83{
84 fn clone(&self) -> Self {
85 Limit {
86 inner: self.inner.clone(),
87 policy: self.policy.clone(),
88 error_into_response: self.error_into_response.clone(),
89 }
90 }
91}
92
93impl<T, P, State, Request> Service<State, Request> for Limit<T, P, ()>
94where
95 T: Service<State, Request, Error: Into<BoxError>>,
96 P: policy::Policy<State, Request, Error: Into<BoxError>>,
97 Request: Send + Sync + 'static,
98 State: Clone + Send + Sync + 'static,
99{
100 type Response = T::Response;
101 type Error = BoxError;
102
103 async fn serve(
104 &self,
105 mut ctx: Context<State>,
106 mut request: Request,
107 ) -> Result<Self::Response, Self::Error> {
108 loop {
109 let result = self.policy.check(ctx, request).await;
110 ctx = result.ctx;
111 request = result.request;
112
113 match result.output {
114 policy::PolicyOutput::Ready(guard) => {
115 let _ = guard;
116 return self.inner.serve(ctx, request).await.map_err(Into::into);
117 }
118 policy::PolicyOutput::Abort(err) => return Err(err.into()),
119 policy::PolicyOutput::Retry => (),
120 }
121 }
122 }
123}
124
125impl<T, P, F, State, Request, FnResponse, FnError> Service<State, Request>
126 for Limit<T, P, ErrorIntoResponseFn<F>>
127where
128 T: Service<State, Request>,
129 P: policy::Policy<State, Request>,
130 F: Fn(P::Error) -> Result<FnResponse, FnError> + Send + Sync + 'static,
131 FnResponse: Into<T::Response> + Send + 'static,
132 FnError: Into<T::Error> + Send + Sync + 'static,
133 Request: Send + Sync + 'static,
134 State: Clone + Send + Sync + 'static,
135{
136 type Response = T::Response;
137 type Error = T::Error;
138
139 async fn serve(
140 &self,
141 mut ctx: Context<State>,
142 mut request: Request,
143 ) -> Result<Self::Response, Self::Error> {
144 loop {
145 let result = self.policy.check(ctx, request).await;
146 ctx = result.ctx;
147 request = result.request;
148
149 match result.output {
150 policy::PolicyOutput::Ready(guard) => {
151 let _ = guard;
152 return self.inner.serve(ctx, request).await;
153 }
154 policy::PolicyOutput::Abort(err) => {
155 return match self.error_into_response.error_into_response(err) {
156 Ok(ok) => Ok(ok.into()),
157 Err(err) => Err(err.into()),
158 };
159 }
160 policy::PolicyOutput::Retry => (),
161 }
162 }
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::policy::ConcurrentPolicy;
    use super::*;

    use crate::{Context, Layer, Service, service::service_fn};
    use std::convert::Infallible;

    use futures_lite::future::zip;

    /// Two requests driven together through services built from one layer
    /// with a concurrency maximum of one: one must succeed, the other fail.
    #[tokio::test]
    async fn test_limit() {
        // Echoes the request back after a 100 ms delay
        // (presumably so the two calls overlap; confirm against the policy).
        async fn handle_request<State, Request>(
            _ctx: Context<State>,
            req: Request,
        ) -> Result<Request, Infallible> {
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
            Ok(req)
        }

        let layer: LimitLayer<ConcurrentPolicy<_, _>> = LimitLayer::new(ConcurrentPolicy::max(1));

        let first_service = layer.layer(service_fn(handle_request));
        let second_service = layer.layer(service_fn(handle_request));

        let first_call = first_service.serve(Context::default(), "Hello");
        let second_call = second_service.serve(Context::default(), "Hello");

        match zip(first_call, second_call).await {
            // The first call was rejected, so the second must have passed.
            (Err(_), second) => assert_eq!(second.unwrap(), "Hello"),
            // The first call passed, so the second must have been rejected.
            (Ok(first), second) => {
                assert_eq!(first, "Hello");
                assert!(second.is_err());
            }
        }
    }

    /// With a maximum of zero and an error mapper installed, the caller
    /// receives the mapper's response instead of the handler's.
    #[tokio::test]
    async fn test_with_error_into_response_fn() {
        async fn handle_request<State, Request>(
            _ctx: Context<State>,
            _req: Request,
        ) -> Result<&'static str, Infallible> {
            Ok("good")
        }

        let layer: LimitLayer<ConcurrentPolicy<_, _>, _> =
            LimitLayer::new(ConcurrentPolicy::max(0))
                .with_error_into_response_fn(|_| Ok::<_, Infallible>("bad"));

        let service = layer.layer(service_fn(handle_request));

        let response = service.serve(Context::default(), "Hello").await.unwrap();
        assert_eq!("bad", response);
    }

    /// With a maximum of zero and no error mapper, the call fails.
    #[tokio::test]
    async fn test_zero_limit() {
        async fn handle_request<State, Request>(
            _ctx: Context<State>,
            req: Request,
        ) -> Result<Request, Infallible> {
            Ok(req)
        }

        let layer: LimitLayer<ConcurrentPolicy<_, _>> = LimitLayer::new(ConcurrentPolicy::max(0));

        let service = layer.layer(service_fn(handle_request));
        let outcome = service.serve(Context::default(), "Hello").await;
        assert!(outcome.is_err());
    }
}