requiem_utils/
inflight.rs1use std::convert::Infallible;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use requiem_service::{IntoService, Service, Transform};
7use futures::future::{ok, Ready};
8
9use super::counter::{Counter, CounterGuard};
10
/// Transform that limits how many requests a wrapped service may have in
/// flight at the same time.
///
/// Applying it to a service yields an [`InFlightService`] built with the
/// configured limit. The limit used by [`Default`] is 15.
#[derive(Debug, Clone)]
pub struct InFlight {
    /// Limit handed to every [`InFlightService`] this transform creates.
    max_inflight: usize,
}

impl InFlight {
    /// Limit applied by [`InFlight::default`].
    const DEFAULT_MAX_INFLIGHT: usize = 15;

    /// Creates a transform whose services are constructed with a limit of
    /// `max` in-flight requests.
    pub fn new(max: usize) -> Self {
        Self { max_inflight: max }
    }
}

impl Default for InFlight {
    /// Creates a transform with the default limit of 15 in-flight requests.
    fn default() -> Self {
        Self::new(Self::DEFAULT_MAX_INFLIGHT)
    }
}
30
31impl<S> Transform<S> for InFlight
32where
33 S: Service,
34{
35 type Request = S::Request;
36 type Response = S::Response;
37 type Error = S::Error;
38 type InitError = Infallible;
39 type Transform = InFlightService<S>;
40 type Future = Ready<Result<Self::Transform, Self::InitError>>;
41
42 fn new_transform(&self, service: S) -> Self::Future {
43 ok(InFlightService::new(self.max_inflight, service))
44 }
45}
46
47pub struct InFlightService<S> {
48 count: Counter,
49 service: S,
50}
51
52impl<S> InFlightService<S>
53where
54 S: Service,
55{
56 pub fn new<U>(max: usize, service: U) -> Self
57 where
58 U: IntoService<S>,
59 {
60 Self {
61 count: Counter::new(max),
62 service: service.into_service(),
63 }
64 }
65}
66
67impl<T> Service for InFlightService<T>
68where
69 T: Service,
70{
71 type Request = T::Request;
72 type Response = T::Response;
73 type Error = T::Error;
74 type Future = InFlightServiceResponse<T>;
75
76 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
77 if let Poll::Pending = self.service.poll_ready(cx)? {
78 Poll::Pending
79 } else if !self.count.available(cx) {
80 log::trace!("InFlight limit exceeded");
81 Poll::Pending
82 } else {
83 Poll::Ready(Ok(()))
84 }
85 }
86
87 fn call(&mut self, req: T::Request) -> Self::Future {
88 InFlightServiceResponse {
89 fut: self.service.call(req),
90 _guard: self.count.get(),
91 }
92 }
93}
94
95#[doc(hidden)]
96#[pin_project::pin_project]
97pub struct InFlightServiceResponse<T: Service> {
98 #[pin]
99 fut: T::Future,
100 _guard: CounterGuard,
101}
102
103impl<T: Service> Future for InFlightServiceResponse<T> {
104 type Output = Result<T::Response, T::Error>;
105
106 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
107 self.project().fut.poll(cx)
108 }
109}
110
#[cfg(test)]
mod tests {

    use std::task::{Context, Poll};
    use std::time::Duration;

    use super::*;
    use requiem_service::{apply, fn_factory, Service, ServiceFactory};
    use futures::future::{lazy, ok, FutureExt, LocalBoxFuture};

    /// Test service that is always ready and whose calls complete after the
    /// wrapped duration has elapsed.
    struct SleepService(Duration);

    impl Service for SleepService {
        type Request = ();
        type Response = ();
        type Error = ();
        type Future = LocalBoxFuture<'static, Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: ()) -> Self::Future {
            let sleep = requiem_rt::time::delay_for(self.0);
            sleep.then(|_| ok::<_, ()>(())).boxed_local()
        }
    }

    /// With a limit of one, the wrapper is ready, then pending while a call
    /// is outstanding, then ready again once that call has completed.
    #[requiem_rt::test]
    async fn test_transform() {
        let delay = Duration::from_millis(50);
        let mut limited = InFlightService::new(1, SleepService(delay));

        let before = lazy(|cx| limited.poll_ready(cx)).await;
        assert_eq!(before, Poll::Ready(Ok(())));

        let outstanding = limited.call(());
        let during = lazy(|cx| limited.poll_ready(cx)).await;
        assert_eq!(during, Poll::Pending);

        let _ = outstanding.await;
        let after = lazy(|cx| limited.poll_ready(cx)).await;
        assert_eq!(after, Poll::Ready(Ok(())));
    }

    /// Same scenario as `test_transform`, but the service is produced by
    /// applying the `InFlight` transform to a factory.
    #[requiem_rt::test]
    async fn test_newtransform() {
        let delay = Duration::from_millis(50);

        let factory = apply(InFlight::new(1), fn_factory(|| ok(SleepService(delay))));
        let mut limited = factory.new_service(&()).await.unwrap();

        let before = lazy(|cx| limited.poll_ready(cx)).await;
        assert_eq!(before, Poll::Ready(Ok(())));

        let outstanding = limited.call(());
        let during = lazy(|cx| limited.poll_ready(cx)).await;
        assert_eq!(during, Poll::Pending);

        let _ = outstanding.await;
        let after = lazy(|cx| limited.poll_ready(cx)).await;
        assert_eq!(after, Poll::Ready(Ok(())));
    }
}