1use std::{cell::Cell, future::poll_fn, rc::Rc, task::Context, task::Poll};
3
4use ntex_service::{Middleware, Service, ServiceCtx};
5use ntex_util::{future::join, task::LocalWaker};
6
/// A request that can be accounted for by [`InFlightService`].
///
/// The middleware uses these hooks to track how much work is in flight, and
/// to keep readiness open while a publish is being streamed in chunks.
pub trait SizedRequest {
    /// `true` if this request starts a publish; while that publish is
    /// considered in progress, the in-flight limits do not gate readiness.
    fn is_publish(&self) -> bool;

    /// `true` if this request is a continuation chunk; any non-chunk request
    /// ends a publish that was in progress.
    fn is_chunk(&self) -> bool;

    /// Size charged against the total-size limit while the request is in flight.
    fn size(&self) -> u32;
}
15
/// Middleware that bounds how much work the wrapped service has in flight.
///
/// Two independent limits are enforced, and setting either one to `0` disables it:
/// * `max_receive`: the number of requests being processed concurrently;
/// * `max_receive_size`: the combined [`SizedRequest::size`] of those requests.
///
/// When a limit is exhausted, the wrapped service stops reporting readiness
/// until in-flight requests complete. The exception is while a chunked publish
/// is being streamed.
pub struct InFlightService {
    max_receive: u16,
    max_receive_size: usize,
}

impl Default for InFlightService {
    /// Allows 16 concurrent requests with a combined size of 65535.
    fn default() -> Self {
        Self::new(16, 65535)
    }
}

impl InFlightService {
    /// Creates the middleware with explicit limits (`0` disables a limit).
    pub fn new(max_receive: u16, max_receive_size: usize) -> Self {
        Self { max_receive, max_receive_size }
    }

    /// Sets the maximum number of concurrently processed requests (`0` = unlimited).
    pub fn max_receive(self, val: u16) -> Self {
        Self { max_receive: val, ..self }
    }

    /// Sets the maximum combined size of concurrently processed requests (`0` = unlimited).
    pub fn max_receive_size(self, val: usize) -> Self {
        Self { max_receive_size: val, ..self }
    }
}
54
55impl<S> Middleware<S> for InFlightService {
56 type Service = InFlightServiceImpl<S>;
57
58 #[inline]
59 fn create(&self, service: S) -> Self::Service {
60 InFlightServiceImpl {
61 service,
62 publish: Cell::new(false),
63 count: Counter::new(self.max_receive, self.max_receive_size),
64 }
65 }
66}
67
/// Service produced by [`InFlightService`]: forwards to the inner service while
/// tracking in-flight requests against the configured limits.
pub struct InFlightServiceImpl<S> {
    // In-flight accounting; each `call` holds a guard on it for the duration of the inner call.
    count: Counter,
    // The wrapped service.
    service: S,
    // Set by a publish request and cleared by the next non-chunk request;
    // while set, `ready` does not wait on the in-flight limits.
    publish: Cell<bool>,
}
73
74impl<S, R> Service<R> for InFlightServiceImpl<S>
75where
76 S: Service<R>,
77 R: SizedRequest + 'static,
78{
79 type Response = S::Response;
80 type Error = S::Error;
81
82 #[inline]
83 async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), S::Error> {
84 if self.publish.get() || self.count.is_available() {
85 ctx.ready(&self.service).await
86 } else {
87 join(self.count.available(), ctx.ready(&self.service)).await.1
88 }
89 }
90
91 #[inline]
92 async fn call(&self, req: R, ctx: ServiceCtx<'_, Self>) -> Result<S::Response, S::Error> {
93 if self.publish.get() && !req.is_chunk() {
95 self.publish.set(false);
96 }
97 if req.is_publish() {
98 self.publish.set(true);
99 }
100
101 let size = if self.count.0.max_size > 0 { req.size() } else { 0 };
102 let task_guard = self.count.get(size);
103 let result = ctx.call(&self.service, req).await;
104 drop(task_guard);
105 result
106 }
107
108 ntex_service::forward_poll!(service);
109 ntex_service::forward_shutdown!(service);
110}
111
112struct Counter(Rc<CounterInner>);
113
114struct CounterInner {
115 max_cap: u16,
116 cur_cap: Cell<u16>,
117 max_size: usize,
118 cur_size: Cell<usize>,
119 task: LocalWaker,
120}
121
122impl Counter {
123 fn new(max_cap: u16, max_size: usize) -> Self {
124 Counter(Rc::new(CounterInner {
125 max_cap,
126 max_size,
127 cur_cap: Cell::new(0),
128 cur_size: Cell::new(0),
129 task: LocalWaker::new(),
130 }))
131 }
132
133 fn get(&self, size: u32) -> CounterGuard {
134 CounterGuard::new(size, self.0.clone())
135 }
136
137 fn is_available(&self) -> bool {
138 (self.0.max_cap == 0 || self.0.cur_cap.get() < self.0.max_cap)
139 && (self.0.max_size == 0 || self.0.cur_size.get() <= self.0.max_size)
140 }
141
142 async fn available(&self) {
143 poll_fn(|cx| if self.0.available(cx) { Poll::Ready(()) } else { Poll::Pending }).await
144 }
145}
146
147struct CounterGuard(u32, Rc<CounterInner>);
148
149impl CounterGuard {
150 fn new(size: u32, inner: Rc<CounterInner>) -> Self {
151 inner.inc(size);
152 CounterGuard(size, inner)
153 }
154}
155
156impl Unpin for CounterGuard {}
157
158impl Drop for CounterGuard {
159 fn drop(&mut self) {
160 self.1.dec(self.0);
161 }
162}
163
164impl CounterInner {
165 fn inc(&self, size: u32) {
166 let cur_cap = self.cur_cap.get() + 1;
167 self.cur_cap.set(cur_cap);
168 let cur_size = self.cur_size.get() + size as usize;
169 self.cur_size.set(cur_size);
170
171 if cur_cap == self.max_cap || cur_size >= self.max_size {
172 self.task.wake();
173 }
174 }
175
176 fn dec(&self, size: u32) {
177 let num = self.cur_cap.get();
178 self.cur_cap.set(num - 1);
179
180 let cur_size = self.cur_size.get();
181 let new_size = cur_size - (size as usize);
182 self.cur_size.set(new_size);
183
184 if num == self.max_cap || (cur_size > self.max_size && new_size <= self.max_size) {
185 self.task.wake();
186 }
187 }
188
189 fn available(&self, cx: &Context<'_>) -> bool {
190 self.task.register(cx.waker());
191 (self.max_cap == 0 || self.cur_cap.get() < self.max_cap)
192 && (self.max_size == 0 || self.cur_size.get() <= self.max_size)
193 }
194}
195
#[cfg(test)]
mod tests {
    use std::{future::poll_fn, time::Duration};

    use ntex_service::Pipeline;
    use ntex_util::{future::lazy, task::LocalWaker, time::sleep};

    use super::*;

    // Inner service whose `call` just sleeps for the configured duration.
    struct SleepService(Duration);

    impl Service<()> for SleepService {
        type Response = ();
        type Error = ();

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            let fut = sleep(self.0);
            let _ = fut.await;
            Ok::<_, ()>(())
        }
    }

    // Every test request has size 12 and is neither a publish nor a chunk.
    impl SizedRequest for () {
        fn size(&self) -> u32 {
            12
        }

        fn is_publish(&self) -> bool {
            false
        }

        fn is_chunk(&self) -> bool {
            false
        }
    }

    // Count limit only (1 in flight, size limit disabled): readiness is Pending
    // while one 50ms call is running and Ready again after it completes.
    #[ntex_macros::rt_test]
    async fn test_inflight() {
        let wait_time = Duration::from_millis(50);

        let srv =
            Pipeline::new(InFlightService::new(1, 0).create(SleepService(wait_time))).bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let srv2 = srv.clone();
        ntex_util::spawn(async move {
            let _ = srv2.call(()).await;
        });
        ntex_util::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);

        ntex_util::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
        assert!(lazy(|cx| srv.poll_shutdown(cx)).await.is_ready());
    }

    // Size limit only (count disabled, max size 10): a single size-12 request
    // exceeds the limit, so readiness is Pending until that call completes.
    #[ntex_macros::rt_test]
    async fn test_inflight2() {
        let wait_time = Duration::from_millis(50);

        let srv =
            Pipeline::new(InFlightService::new(0, 10).create(SleepService(wait_time))).bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let srv2 = srv.clone();
        ntex_util::spawn(async move {
            let _ = srv2.call(()).await;
        });
        ntex_util::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);

        ntex_util::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
    }

    // Inner service that reports not-ready while a call is in progress
    // (`cnt` is true), waking the registered waker on each state change.
    struct Srv2 {
        dur: Duration,
        cnt: Cell<bool>,
        waker: LocalWaker,
    }

    impl Service<()> for Srv2 {
        type Response = ();
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            poll_fn(|cx| {
                if !self.cnt.get() {
                    Poll::Ready(Ok(()))
                } else {
                    self.waker.register(cx.waker());
                    Poll::Pending
                }
            })
            .await
        }

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            let fut = sleep(self.dur);
            self.cnt.set(true);
            self.waker.wake();

            let _ = fut.await;
            self.cnt.set(false);
            self.waker.wake();
            Ok::<_, ()>(())
        }
    }

    // Both limits enabled with an inner service that is itself not ready during
    // a call: while the first call runs readiness is Pending; afterwards the main
    // task's readiness resolves to Ok(()) and a concurrently polling spawned task
    // also gets through and signals via the oneshot channel.
    #[ntex_macros::rt_test]
    async fn test_inflight3() {
        let wait_time = Duration::from_millis(50);

        let srv = Pipeline::new(InFlightService::new(1, 10).create(Srv2 {
            dur: wait_time,
            cnt: Cell::new(false),
            waker: LocalWaker::new(),
        }))
        .bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let srv2 = srv.clone();
        ntex_util::spawn(async move {
            let _ = srv2.call(()).await;
        });
        ntex_util::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);

        let srv2 = srv.clone();
        let (tx, rx) = ntex_util::channel::oneshot::channel();
        ntex_util::spawn(async move {
            let _ = poll_fn(|cx| srv2.poll_ready(cx)).await;
            let _ = tx.send(());
        });
        assert_eq!(poll_fn(|cx| srv.poll_ready(cx)).await, Ok(()));

        let _ = rx.await;
    }
}