1use tower_async_service::Service;
6
7use crate::BoxError;
8
9pub mod policy;
10pub use policy::{Policy, PolicyOutput};
11
12mod layer;
13pub use layer::LimitLayer;
14
/// Middleware that consults a [`Policy`] before forwarding each request to
/// the wrapped service.
///
/// `T` is the wrapped service and `P` is the policy that decides, per
/// request, whether the request may proceed.
// `Clone` is derived: it requires `T: Clone` and `P: Clone` and clones each
// field, which is exactly what a hand-written impl would do.
#[derive(Debug, Clone)]
pub struct Limit<T, P> {
    /// The wrapped service that handles admitted requests.
    inner: T,
    /// The policy consulted before every request.
    policy: P,
}

impl<T, P> Limit<T, P> {
    /// Creates a new [`Limit`] that wraps `inner` and consults `policy`
    /// before each request is forwarded to it.
    pub fn new(inner: T, policy: P) -> Self {
        Limit { inner, policy }
    }
}
42
43impl<T, P, Request> Service<Request> for Limit<T, P>
44where
45 T: Service<Request>,
46 T::Error: Into<BoxError>,
47 P: policy::Policy<Request>,
48 P::Error: Into<BoxError>,
49{
50 type Response = T::Response;
51 type Error = BoxError;
52
53 async fn call(&self, request: Request) -> Result<Self::Response, Self::Error> {
54 let mut request = request;
55 loop {
56 match self.policy.check(&mut request).await {
57 policy::PolicyOutput::Ready(guard) => {
58 let _ = guard;
59 return self.inner.call(request).await.map_err(Into::into);
60 }
61 policy::PolicyOutput::Abort(err) => return Err(err.into()),
62 policy::PolicyOutput::Retry => (),
63 }
64 }
65 }
66}
67
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use crate::limit::policy::ConcurrentPolicy;
    use crate::service_fn;

    use super::*;

    use futures_util::future::join_all;
    use tower_async_layer::Layer;
    use tower_async_service::Service;

    /// Two overlapping calls through services that share a concurrency limit
    /// of one: exactly one must succeed (echoing its input) and the other
    /// must be rejected.
    #[tokio::test]
    async fn test_limit() {
        /// Echoes the request after a short sleep so that both calls overlap.
        async fn handle_request<Request>(req: Request) -> Result<Request, Infallible> {
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
            Ok(req)
        }

        let layer: LimitLayer<ConcurrentPolicy<()>> = LimitLayer::new(ConcurrentPolicy::new(1));

        let first = layer.layer(service_fn(handle_request));
        let second = layer.layer(service_fn(handle_request));

        let outcomes = join_all(vec![first.call("Hello"), second.call("Hello")]).await;

        let (accepted, rejected): (Vec<_>, Vec<_>) =
            outcomes.into_iter().partition(|outcome| outcome.is_ok());
        assert_eq!(accepted.len(), 1);
        assert_eq!(rejected.len(), 1);
        for outcome in accepted {
            assert_eq!(outcome.unwrap(), "Hello");
        }
    }
}