1use std::{cell::Cell, future::poll_fn, rc::Rc, task::Context, task::Poll};
3
4use ntex_service::{Service, ServiceCtx};
5use ntex_util::{future::join, task::LocalWaker};
6
/// A request whose weight and MQTT publish role can be inspected by the
/// in-flight limiter.
pub trait SizedRequest {
    /// Whether this request starts a publish that may be followed by chunks.
    fn is_publish(&self) -> bool;

    /// Whether this request is a payload chunk of an in-progress publish.
    fn is_chunk(&self) -> bool;

    /// Size of the request, counted against the limiter's size budget.
    fn size(&self) -> u32;
}
15
/// Service wrapper that delays readiness while too many requests (or too
/// large an accumulated request size) are in flight in the wrapped service.
pub struct InFlightServiceImpl<S> {
    // In-flight request/size accounting; its inner state is shared (via `Rc`)
    // with the guard held by each outstanding `call`.
    count: Counter,
    // The wrapped service.
    service: S,
    // Set after a publish request; cleared by the next non-chunk request.
    // While set, `ready` does not wait on the in-flight limits so that chunk
    // requests belonging to the publish can proceed.
    publish: Cell<bool>,
}
21
22impl<S> InFlightServiceImpl<S> {
23 pub fn new(max_cap: u16, max_size: usize, service: S) -> Self {
24 InFlightServiceImpl {
25 service,
26 publish: Cell::new(false),
27 count: Counter::new(max_cap, max_size),
28 }
29 }
30}
31
32impl<S, R> Service<R> for InFlightServiceImpl<S>
33where
34 S: Service<R>,
35 R: SizedRequest + 'static,
36{
37 type Response = S::Response;
38 type Error = S::Error;
39
40 #[inline]
41 async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), S::Error> {
42 if self.publish.get() || self.count.is_available() {
43 ctx.ready(&self.service).await
44 } else {
45 join(self.count.available(), ctx.ready(&self.service)).await.1
46 }
47 }
48
49 #[inline]
50 async fn call(&self, req: R, ctx: ServiceCtx<'_, Self>) -> Result<S::Response, S::Error> {
51 if self.publish.get() && !req.is_chunk() {
53 self.publish.set(false);
54 }
55 if req.is_publish() {
56 self.publish.set(true);
57 }
58
59 let size = if self.count.0.max_size > 0 { req.size() } else { 0 };
60 let task_guard = self.count.get(size);
61 let result = ctx.call(&self.service, req).await;
62 drop(task_guard);
63 result
64 }
65
66 ntex_service::forward_poll!(service);
67 ntex_service::forward_shutdown!(service);
68}
69
/// Handle to the shared in-flight accounting state; cloned into each
/// `CounterGuard` it hands out.
struct Counter(Rc<CounterInner>);
71
/// Shared in-flight accounting state.
struct CounterInner {
    // Maximum number of in-flight requests; `0` disables this limit.
    max_cap: u16,
    // Current number of in-flight requests.
    cur_cap: Cell<u16>,
    // Maximum accumulated size of in-flight requests; `0` disables this limit.
    max_size: usize,
    // Current accumulated size of in-flight requests.
    cur_size: Cell<usize>,
    // Waker registered by `available`, woken by `inc`/`dec` on limit changes.
    task: LocalWaker,
}
79
80impl Counter {
81 fn new(max_cap: u16, max_size: usize) -> Self {
82 Counter(Rc::new(CounterInner {
83 max_cap,
84 max_size,
85 cur_cap: Cell::new(0),
86 cur_size: Cell::new(0),
87 task: LocalWaker::new(),
88 }))
89 }
90
91 fn get(&self, size: u32) -> CounterGuard {
92 CounterGuard::new(size, self.0.clone())
93 }
94
95 fn is_available(&self) -> bool {
96 (self.0.max_cap == 0 || self.0.cur_cap.get() < self.0.max_cap)
97 && (self.0.max_size == 0 || self.0.cur_size.get() <= self.0.max_size)
98 }
99
100 async fn available(&self) {
101 poll_fn(|cx| {
102 if self.0.available(cx) {
103 Poll::Ready(())
104 } else {
105 Poll::Pending
106 }
107 })
108 .await
109 }
110}
111
112struct CounterGuard(u32, Rc<CounterInner>);
113
114impl CounterGuard {
115 fn new(size: u32, inner: Rc<CounterInner>) -> Self {
116 inner.inc(size);
117 CounterGuard(size, inner)
118 }
119}
120
121impl Unpin for CounterGuard {}
122
123impl Drop for CounterGuard {
124 fn drop(&mut self) {
125 self.1.dec(self.0);
126 }
127}
128
129impl CounterInner {
130 fn inc(&self, size: u32) {
131 let cur_cap = self.cur_cap.get() + 1;
132 self.cur_cap.set(cur_cap);
133 let cur_size = self.cur_size.get() + size as usize;
134 self.cur_size.set(cur_size);
135
136 if cur_cap == self.max_cap || cur_size >= self.max_size {
137 self.task.wake();
138 }
139 }
140
141 fn dec(&self, size: u32) {
142 let num = self.cur_cap.get();
143 self.cur_cap.set(num - 1);
144
145 let cur_size = self.cur_size.get();
146 let new_size = cur_size - (size as usize);
147 self.cur_size.set(new_size);
148
149 if num == self.max_cap || (cur_size > self.max_size && new_size <= self.max_size) {
150 self.task.wake();
151 }
152 }
153
154 fn available(&self, cx: &Context<'_>) -> bool {
155 self.task.register(cx.waker());
156 (self.max_cap == 0 || self.cur_cap.get() < self.max_cap)
157 && (self.max_size == 0 || self.cur_size.get() <= self.max_size)
158 }
159}
160
#[cfg(test)]
mod tests {
    use std::{future::poll_fn, time::Duration};

    use ntex_service::Pipeline;
    use ntex_util::{future::lazy, task::LocalWaker, time::sleep};

    use super::*;

    /// Test service whose every call completes after sleeping for the wrapped duration.
    struct SleepService(Duration);

    impl Service<()> for SleepService {
        type Response = ();
        type Error = ();

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            let fut = sleep(self.0);
            let _ = fut.await;
            Ok::<_, ()>(())
        }
    }

    // Test requests are plain (non-publish, non-chunk) requests of size 12.
    impl SizedRequest for () {
        fn size(&self) -> u32 {
            12
        }

        fn is_publish(&self) -> bool {
            false
        }

        fn is_chunk(&self) -> bool {
            false
        }
    }

    // Count limit only (max_cap = 1, max_size = 0): readiness is pending while
    // one call is in flight and ready again after it completes.
    #[ntex::test]
    async fn test_inflight() {
        let wait_time = Duration::from_millis(50);

        let srv = Pipeline::new(InFlightServiceImpl::new(1, 0, SleepService(wait_time))).bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let srv2 = srv.clone();
        ntex_util::spawn(async move {
            let _ = srv2.call(()).await;
        });
        // The spawned call is still sleeping here, so the single slot is taken.
        ntex_util::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);

        // By now the call has finished and released its slot.
        ntex_util::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
        assert!(lazy(|cx| srv.poll_shutdown(cx)).await.is_ready());
    }

    // Size limit only (max_cap = 0, max_size = 10): a single request of size 12
    // exceeds the budget, so readiness is pending until it completes.
    #[ntex::test]
    async fn test_inflight2() {
        let wait_time = Duration::from_millis(50);

        let srv =
            Pipeline::new(InFlightServiceImpl::new(0, 10, SleepService(wait_time))).bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let srv2 = srv.clone();
        ntex_util::spawn(async move {
            let _ = srv2.call(()).await;
        });
        ntex_util::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);

        ntex_util::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
    }

    /// Test service that reports not-ready while a call is running (`cnt` set)
    /// and wakes its registered waker when that state changes.
    struct Srv2 {
        dur: Duration,
        cnt: Cell<bool>,
        waker: LocalWaker,
    }

    impl Service<()> for Srv2 {
        type Response = ();
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            poll_fn(|cx| {
                if !self.cnt.get() {
                    Poll::Ready(Ok(()))
                } else {
                    self.waker.register(cx.waker());
                    Poll::Pending
                }
            })
            .await
        }

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            let fut = sleep(self.dur);
            self.cnt.set(true);
            self.waker.wake();

            let _ = fut.await;
            self.cnt.set(false);
            self.waker.wake();
            Ok::<_, ()>(())
        }
    }

    // Both limits configured, and the inner service is itself not ready during
    // a call. Readiness awaited from the test task and from a second spawned
    // task must both resolve once the in-flight call completes.
    #[ntex::test]
    async fn test_inflight3() {
        let wait_time = Duration::from_millis(50);

        let srv = Pipeline::new(InFlightServiceImpl::new(
            1,
            10,
            Srv2 { dur: wait_time, cnt: Cell::new(false), waker: LocalWaker::new() },
        ))
        .bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let srv2 = srv.clone();
        ntex_util::spawn(async move {
            let _ = srv2.call(()).await;
        });
        ntex_util::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);

        let srv2 = srv.clone();
        let (tx, rx) = ntex_util::channel::oneshot::channel();
        ntex_util::spawn(async move {
            let _ = poll_fn(|cx| srv2.poll_ready(cx)).await;
            let _ = tx.send(());
        });
        assert_eq!(poll_fn(|cx| srv.poll_ready(cx)).await, Ok(()));

        // The spawned waiter must also have observed readiness.
        let _ = rx.await;
    }
}