ntex_mqtt/
inflight.rs

//! Service that limits the number of in-flight async requests.
2use std::{cell::Cell, future::poll_fn, rc::Rc, task::Context, task::Poll};
3
4use ntex_service::{Service, ServiceCtx};
5use ntex_util::{future::join, task::LocalWaker};
6
/// A request that the in-flight limiter can measure and classify.
pub trait SizedRequest {
    /// Weight of the request; it is added to the limiter's in-flight size
    /// total for as long as the request is being processed.
    fn size(&self) -> u32;

    /// Whether the request is a publish. The limiter admits the request that
    /// follows a publish without consulting its limits.
    fn is_publish(&self) -> bool;

    /// Whether the request is a payload chunk (presumably of a preceding
    /// publish); chunks keep the limiter's "publish in progress" state alive.
    fn is_chunk(&self) -> bool;
}
15
16pub struct InFlightServiceImpl<S> {
17    count: Counter,
18    service: S,
19    publish: Cell<bool>,
20}
21
22impl<S> InFlightServiceImpl<S> {
23    pub fn new(max_cap: u16, max_size: usize, service: S) -> Self {
24        InFlightServiceImpl {
25            service,
26            publish: Cell::new(false),
27            count: Counter::new(max_cap, max_size),
28        }
29    }
30}
31
32impl<S, R> Service<R> for InFlightServiceImpl<S>
33where
34    S: Service<R>,
35    R: SizedRequest + 'static,
36{
37    type Response = S::Response;
38    type Error = S::Error;
39
40    #[inline]
41    async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), S::Error> {
42        if self.publish.get() || self.count.is_available() {
43            ctx.ready(&self.service).await
44        } else {
45            join(self.count.available(), ctx.ready(&self.service)).await.1
46        }
47    }
48
49    #[inline]
50    async fn call(&self, req: R, ctx: ServiceCtx<'_, Self>) -> Result<S::Response, S::Error> {
51        // process payload chunks
52        if self.publish.get() && !req.is_chunk() {
53            self.publish.set(false);
54        }
55        if req.is_publish() {
56            self.publish.set(true);
57        }
58
59        let size = if self.count.0.max_size > 0 { req.size() } else { 0 };
60        let task_guard = self.count.get(size);
61        let result = ctx.call(&self.service, req).await;
62        drop(task_guard);
63        result
64    }
65
66    ntex_service::forward_poll!(service);
67    ntex_service::forward_shutdown!(service);
68}
69
70struct Counter(Rc<CounterInner>);
71
72struct CounterInner {
73    max_cap: u16,
74    cur_cap: Cell<u16>,
75    max_size: usize,
76    cur_size: Cell<usize>,
77    task: LocalWaker,
78}
79
80impl Counter {
81    fn new(max_cap: u16, max_size: usize) -> Self {
82        Counter(Rc::new(CounterInner {
83            max_cap,
84            max_size,
85            cur_cap: Cell::new(0),
86            cur_size: Cell::new(0),
87            task: LocalWaker::new(),
88        }))
89    }
90
91    fn get(&self, size: u32) -> CounterGuard {
92        CounterGuard::new(size, self.0.clone())
93    }
94
95    fn is_available(&self) -> bool {
96        (self.0.max_cap == 0 || self.0.cur_cap.get() < self.0.max_cap)
97            && (self.0.max_size == 0 || self.0.cur_size.get() <= self.0.max_size)
98    }
99
100    async fn available(&self) {
101        poll_fn(|cx| {
102            if self.0.available(cx) {
103                Poll::Ready(())
104            } else {
105                Poll::Pending
106            }
107        })
108        .await
109    }
110}
111
112struct CounterGuard(u32, Rc<CounterInner>);
113
114impl CounterGuard {
115    fn new(size: u32, inner: Rc<CounterInner>) -> Self {
116        inner.inc(size);
117        CounterGuard(size, inner)
118    }
119}
120
121impl Unpin for CounterGuard {}
122
123impl Drop for CounterGuard {
124    fn drop(&mut self) {
125        self.1.dec(self.0);
126    }
127}
128
129impl CounterInner {
130    fn inc(&self, size: u32) {
131        let cur_cap = self.cur_cap.get() + 1;
132        self.cur_cap.set(cur_cap);
133        let cur_size = self.cur_size.get() + size as usize;
134        self.cur_size.set(cur_size);
135
136        if cur_cap == self.max_cap || cur_size >= self.max_size {
137            self.task.wake();
138        }
139    }
140
141    fn dec(&self, size: u32) {
142        let num = self.cur_cap.get();
143        self.cur_cap.set(num - 1);
144
145        let cur_size = self.cur_size.get();
146        let new_size = cur_size - (size as usize);
147        self.cur_size.set(new_size);
148
149        if num == self.max_cap || (cur_size > self.max_size && new_size <= self.max_size) {
150            self.task.wake();
151        }
152    }
153
154    fn available(&self, cx: &Context<'_>) -> bool {
155        self.task.register(cx.waker());
156        (self.max_cap == 0 || self.cur_cap.get() < self.max_cap)
157            && (self.max_size == 0 || self.cur_size.get() <= self.max_size)
158    }
159}
160
#[cfg(test)]
mod tests {
    use std::{future::poll_fn, time::Duration};

    use ntex_service::Pipeline;
    use ntex_util::{future::lazy, task::LocalWaker, time::sleep};

    use super::*;

    /// Polls `$srv` for readiness exactly once from the current task.
    macro_rules! poll_ready_once {
        ($srv:expr) => {
            lazy(|cx| $srv.poll_ready(cx)).await
        };
    }

    /// Inner service whose every call just sleeps for the given duration.
    struct Delay(Duration);

    impl Service<()> for Delay {
        type Response = ();
        type Error = ();

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            sleep(self.0).await;
            Ok(())
        }
    }

    /// Every unit request weighs 12 and is neither a publish nor a chunk.
    impl SizedRequest for () {
        fn size(&self) -> u32 {
            12
        }

        fn is_publish(&self) -> bool {
            false
        }

        fn is_chunk(&self) -> bool {
            false
        }
    }

    #[ntex::test]
    async fn test_inflight() {
        let delay = Duration::from_millis(50);
        let srv = Pipeline::new(InFlightServiceImpl::new(1, 0, Delay(delay))).bind();
        assert_eq!(poll_ready_once!(srv), Poll::Ready(Ok(())));

        // Occupy the only slot with a background call.
        let busy = srv.clone();
        ntex_util::spawn(async move {
            let _ = busy.call(()).await;
        });
        sleep(Duration::from_millis(25)).await;
        assert_eq!(poll_ready_once!(srv), Poll::Pending);

        // After the background call finishes the slot is free again.
        sleep(Duration::from_millis(50)).await;
        assert_eq!(poll_ready_once!(srv), Poll::Ready(Ok(())));
        assert!(lazy(|cx| srv.poll_shutdown(cx)).await.is_ready());
    }

    #[ntex::test]
    async fn test_inflight2() {
        let delay = Duration::from_millis(50);
        let srv = Pipeline::new(InFlightServiceImpl::new(0, 10, Delay(delay))).bind();
        assert_eq!(poll_ready_once!(srv), Poll::Ready(Ok(())));

        // A single 12-unit request already exceeds the 10-unit size limit.
        let busy = srv.clone();
        ntex_util::spawn(async move {
            let _ = busy.call(()).await;
        });
        sleep(Duration::from_millis(25)).await;
        assert_eq!(poll_ready_once!(srv), Poll::Pending);

        sleep(Duration::from_millis(100)).await;
        assert_eq!(poll_ready_once!(srv), Poll::Ready(Ok(())));
    }

    /// Inner service that is not ready while a call is running; it wakes its
    /// single registered waker whenever it switches between busy and idle.
    struct Gated {
        delay: Duration,
        busy: Cell<bool>,
        waker: LocalWaker,
    }

    impl Service<()> for Gated {
        type Response = ();
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            poll_fn(|cx| {
                if self.busy.get() {
                    self.waker.register(cx.waker());
                    Poll::Pending
                } else {
                    Poll::Ready(Ok(()))
                }
            })
            .await
        }

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            // Arm the timer before marking the service busy.
            let timer = sleep(self.delay);
            self.busy.set(true);
            self.waker.wake();

            timer.await;
            self.busy.set(false);
            self.waker.wake();
            Ok(())
        }
    }

    /// The in-flight service's readiness must register its waker every time;
    /// otherwise a wake-up can be lost when the inner service's readiness does
    /// not itself wake the dispatcher.
    #[ntex::test]
    async fn test_inflight3() {
        let inner = Gated {
            delay: Duration::from_millis(50),
            busy: Cell::new(false),
            waker: LocalWaker::new(),
        };
        let srv = Pipeline::new(InFlightServiceImpl::new(1, 10, inner)).bind();
        assert_eq!(poll_ready_once!(srv), Poll::Ready(Ok(())));

        let caller = srv.clone();
        ntex_util::spawn(async move {
            let _ = caller.call(()).await;
        });
        sleep(Duration::from_millis(25)).await;
        assert_eq!(poll_ready_once!(srv), Poll::Pending);

        // A second task waits for readiness and reports back once it has it,
        // while this task waits for readiness as well.
        let waiter = srv.clone();
        let (tx, rx) = ntex_util::channel::oneshot::channel();
        ntex_util::spawn(async move {
            let _ = poll_fn(|cx| waiter.poll_ready(cx)).await;
            let _ = tx.send(());
        });
        assert_eq!(poll_fn(|cx| srv.poll_ready(cx)).await, Ok(()));

        let _ = rx.await;
    }
}