ntex_mqtt/
inflight.rs

//! Service that limits the number of in-flight async requests.
2use std::{cell::Cell, future::poll_fn, rc::Rc, task::Context, task::Poll};
3
4use ntex_service::{Middleware, Service, ServiceCtx};
5use ntex_util::{future::join, task::LocalWaker};
6
/// Request type that the in-flight limiter can account for.
///
/// Besides reporting a size, a request tells the limiter whether it opens a
/// publish and whether it is a payload chunk, so the limiter can let chunks
/// that follow a publish through.
pub trait SizedRequest {
    /// Returns `true` if this request is a publish.
    ///
    /// After such a request, readiness skips the in-flight limit check until
    /// a request that is not a chunk arrives.
    fn is_publish(&self) -> bool;

    /// Returns `true` if this request is a payload chunk.
    fn is_chunk(&self) -> bool;

    /// Amount added to the in-flight size total while this request is processed.
    fn size(&self) -> u32;
}
15
16/// Service that can limit number of in-flight async requests.
17///
18/// Default is 16 in-flight messages and 64kb size
19pub struct InFlightService {
20    max_receive: u16,
21    max_receive_size: usize,
22}
23
24impl Default for InFlightService {
25    fn default() -> Self {
26        Self { max_receive: 16, max_receive_size: 65535 }
27    }
28}
29
30impl InFlightService {
31    /// Create new `InFlightService` middleware
32    ///
33    /// By default max receive is 16 and max size is 64kb
34    pub fn new(max_receive: u16, max_receive_size: usize) -> Self {
35        Self { max_receive, max_receive_size }
36    }
37
38    /// Number of inbound in-flight concurrent messages.
39    ///
40    /// By default max receive number is set to 16 messages
41    pub fn max_receive(mut self, val: u16) -> Self {
42        self.max_receive = val;
43        self
44    }
45
46    /// Total size of inbound in-flight messages.
47    ///
48    /// By default total inbound in-flight size is set to 64Kb
49    pub fn max_receive_size(mut self, val: usize) -> Self {
50        self.max_receive_size = val;
51        self
52    }
53}
54
55impl<S> Middleware<S> for InFlightService {
56    type Service = InFlightServiceImpl<S>;
57
58    #[inline]
59    fn create(&self, service: S) -> Self::Service {
60        InFlightServiceImpl {
61            service,
62            publish: Cell::new(false),
63            count: Counter::new(self.max_receive, self.max_receive_size),
64        }
65    }
66}
67
/// Service created by the `InFlightService` middleware.
///
/// Forwards requests to the wrapped service while tracking how many requests,
/// and how much total size, are currently in flight.
pub struct InFlightServiceImpl<S> {
    /// Shared in-flight counters and their limits.
    count: Counter,
    /// The wrapped service.
    service: S,
    /// Set by `call` after a request whose `is_publish()` is true. Cleared by
    /// the first following request whose `is_chunk()` is false. While set,
    /// `ready` skips the in-flight limit check.
    publish: Cell<bool>,
}
73
74impl<S, R> Service<R> for InFlightServiceImpl<S>
75where
76    S: Service<R>,
77    R: SizedRequest + 'static,
78{
79    type Response = S::Response;
80    type Error = S::Error;
81
82    #[inline]
83    async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), S::Error> {
84        if self.publish.get() || self.count.is_available() {
85            ctx.ready(&self.service).await
86        } else {
87            join(self.count.available(), ctx.ready(&self.service)).await.1
88        }
89    }
90
91    #[inline]
92    async fn call(&self, req: R, ctx: ServiceCtx<'_, Self>) -> Result<S::Response, S::Error> {
93        // process payload chunks
94        if self.publish.get() && !req.is_chunk() {
95            self.publish.set(false);
96        }
97        if req.is_publish() {
98            self.publish.set(true);
99        }
100
101        let size = if self.count.0.max_size > 0 { req.size() } else { 0 };
102        let task_guard = self.count.get(size);
103        let result = ctx.call(&self.service, req).await;
104        drop(task_guard);
105        result
106    }
107
108    ntex_service::forward_poll!(service);
109    ntex_service::forward_shutdown!(service);
110}
111
/// Handle to the shared in-flight state.
///
/// Guards returned by `Counter::get` hold their own `Rc` to the same inner
/// state.
struct Counter(Rc<CounterInner>);
113
/// In-flight accounting shared between a `Counter` and its guards.
struct CounterInner {
    /// Maximum number of in-flight requests; `0` means unlimited.
    max_cap: u16,
    /// Current number of in-flight requests.
    cur_cap: Cell<u16>,
    /// Maximum total size of in-flight requests; `0` means unlimited.
    max_size: usize,
    /// Current total size of in-flight requests.
    cur_size: Cell<usize>,
    /// Waker registered by `available()`; woken by `inc`/`dec` when a limit is
    /// reached or released.
    task: LocalWaker,
}
121
122impl Counter {
123    fn new(max_cap: u16, max_size: usize) -> Self {
124        Counter(Rc::new(CounterInner {
125            max_cap,
126            max_size,
127            cur_cap: Cell::new(0),
128            cur_size: Cell::new(0),
129            task: LocalWaker::new(),
130        }))
131    }
132
133    fn get(&self, size: u32) -> CounterGuard {
134        CounterGuard::new(size, self.0.clone())
135    }
136
137    fn is_available(&self) -> bool {
138        (self.0.max_cap == 0 || self.0.cur_cap.get() < self.0.max_cap)
139            && (self.0.max_size == 0 || self.0.cur_size.get() <= self.0.max_size)
140    }
141
142    async fn available(&self) {
143        poll_fn(|cx| if self.0.available(cx) { Poll::Ready(()) } else { Poll::Pending }).await
144    }
145}
146
147struct CounterGuard(u32, Rc<CounterInner>);
148
149impl CounterGuard {
150    fn new(size: u32, inner: Rc<CounterInner>) -> Self {
151        inner.inc(size);
152        CounterGuard(size, inner)
153    }
154}
155
156impl Unpin for CounterGuard {}
157
158impl Drop for CounterGuard {
159    fn drop(&mut self) {
160        self.1.dec(self.0);
161    }
162}
163
164impl CounterInner {
165    fn inc(&self, size: u32) {
166        let cur_cap = self.cur_cap.get() + 1;
167        self.cur_cap.set(cur_cap);
168        let cur_size = self.cur_size.get() + size as usize;
169        self.cur_size.set(cur_size);
170
171        if cur_cap == self.max_cap || cur_size >= self.max_size {
172            self.task.wake();
173        }
174    }
175
176    fn dec(&self, size: u32) {
177        let num = self.cur_cap.get();
178        self.cur_cap.set(num - 1);
179
180        let cur_size = self.cur_size.get();
181        let new_size = cur_size - (size as usize);
182        self.cur_size.set(new_size);
183
184        if num == self.max_cap || (cur_size > self.max_size && new_size <= self.max_size) {
185            self.task.wake();
186        }
187    }
188
189    fn available(&self, cx: &Context<'_>) -> bool {
190        self.task.register(cx.waker());
191        (self.max_cap == 0 || self.cur_cap.get() < self.max_cap)
192            && (self.max_size == 0 || self.cur_size.get() <= self.max_size)
193    }
194}
195
#[cfg(test)]
mod tests {
    use std::{future::poll_fn, time::Duration};

    use ntex_service::Pipeline;
    use ntex_util::{future::lazy, task::LocalWaker, time::sleep};

    use super::*;

    /// Test service whose `call` sleeps for the wrapped duration and then
    /// returns `Ok(())`. It is always ready.
    struct SleepService(Duration);

    impl Service<()> for SleepService {
        type Response = ();
        type Error = ();

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            let fut = sleep(self.0);
            let _ = fut.await;
            Ok::<_, ()>(())
        }
    }

    // Every unit request reports size 12 and is neither a publish nor a chunk,
    // so the publish bypass in `ready` never kicks in during these tests.
    impl SizedRequest for () {
        fn size(&self) -> u32 {
            12
        }

        fn is_publish(&self) -> bool {
            false
        }

        fn is_chunk(&self) -> bool {
            false
        }
    }

    /// Count limit of 1 with the size limit disabled: while one call is in
    /// flight the service is not ready, and it becomes ready after the call ends.
    #[ntex_macros::rt_test]
    async fn test_inflight() {
        let wait_time = Duration::from_millis(50);

        let srv =
            Pipeline::new(InFlightService::new(1, 0).create(SleepService(wait_time))).bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let srv2 = srv.clone();
        ntex_util::spawn(async move {
            let _ = srv2.call(()).await;
        });
        // ~25ms in: the 50ms call still holds the only slot.
        ntex_util::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);

        // ~75ms in: the call has finished and released its slot.
        ntex_util::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
        assert!(lazy(|cx| srv.poll_shutdown(cx)).await.is_ready());
    }

    /// Size limit of 10 with the count limit disabled: a single request of
    /// size 12 exceeds the limit until it completes.
    #[ntex_macros::rt_test]
    async fn test_inflight2() {
        let wait_time = Duration::from_millis(50);

        let srv =
            Pipeline::new(InFlightService::new(0, 10).create(SleepService(wait_time))).bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let srv2 = srv.clone();
        ntex_util::spawn(async move {
            let _ = srv2.call(()).await;
        });
        // ~25ms in: 12 > 10, so the size limit blocks readiness.
        ntex_util::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);

        // ~125ms in: the call has finished and its size was released.
        ntex_util::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
    }

    /// Test service whose `ready` stays pending while `cnt` is true.
    ///
    /// `call` sets `cnt` for the duration of its sleep and wakes the
    /// registered waker both when it sets `cnt` and when it clears it.
    struct Srv2 {
        dur: Duration,
        cnt: Cell<bool>,
        waker: LocalWaker,
    }

    impl Service<()> for Srv2 {
        type Response = ();
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            poll_fn(|cx| {
                if !self.cnt.get() {
                    Poll::Ready(Ok(()))
                } else {
                    self.waker.register(cx.waker());
                    Poll::Pending
                }
            })
            .await
        }

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            let fut = sleep(self.dur);
            self.cnt.set(true);
            self.waker.wake();

            let _ = fut.await;
            self.cnt.set(false);
            self.waker.wake();
            Ok::<_, ()>(())
        }
    }

    /// InflightService::poll_ready() must always register its waker;
    /// otherwise it can lose a wake-up if the inner service's poll_ready
    /// does not wake the dispatcher.
    #[ntex_macros::rt_test]
    async fn test_inflight3() {
        let wait_time = Duration::from_millis(50);

        let srv = Pipeline::new(InFlightService::new(1, 10).create(Srv2 {
            dur: wait_time,
            cnt: Cell::new(false),
            waker: LocalWaker::new(),
        }))
        .bind();
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

        let srv2 = srv.clone();
        ntex_util::spawn(async move {
            let _ = srv2.call(()).await;
        });
        // ~25ms in: both the count limit and the inner service block readiness.
        ntex_util::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending);

        // A second task also waits for readiness. Both waiters must eventually
        // be woken once the call finishes.
        let srv2 = srv.clone();
        let (tx, rx) = ntex_util::channel::oneshot::channel();
        ntex_util::spawn(async move {
            let _ = poll_fn(|cx| srv2.poll_ready(cx)).await;
            let _ = tx.send(());
        });
        assert_eq!(poll_fn(|cx| srv.poll_ready(cx)).await, Ok(()));

        let _ = rx.await;
    }
}