Skip to main content

ntex_mqtt/
io.rs

1//! Framed transport dispatcher
2use std::task::{Context, Poll, ready};
3use std::{cell::Cell, cell::RefCell, collections::VecDeque, future::Future, pin::Pin, rc::Rc};
4
5use ntex_codec::{Decoder, Encoder};
6use ntex_dispatcher::{Control, DispatchItem, Reason};
7use ntex_io::{Decoded, IoBoxed, IoRef, IoStatusUpdate, RecvError};
8use ntex_service::{IntoService, Pipeline, PipelineBinding, PipelineCall, Service};
9use ntex_util::channel::condition::Condition;
10use ntex_util::{future::Either, future::select, spawn, task::LocalWaker, time::Seconds};
11
/// Frame type produced by the codec's encoder; the service answers each
/// dispatched item with an optional value of this type.
type Response<U> = <U as Encoder>::Item;
/// Ordered queue of service-call results that have not been written to the io
/// stream yet; a slot stays `ServiceResult::Pending` until its call completes.
type Queue<T, E> = RefCell<VecDeque<ServiceResult<Result<T, E>>>>;
14
pin_project_lite::pin_project! {
    /// Dispatcher for mqtt protocol
    ///
    /// A future that reads frames from the io stream, decodes them with the
    /// codec, passes every item to the service and encodes the service
    /// responses back to the io stream in request order. It resolves after the
    /// io stream and then the service have been shut down.
    pub(crate) struct Dispatcher<S, U>
    where
        S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
        S: 'static,
        U: Encoder,
        U: Decoder,
        U: 'static,
    {
        // all dispatcher state; not structurally pinned
        inner: DispatcherInner<S, U>
    }
}
28
bitflags::bitflags! {
    /// Internal dispatcher status bits.
    #[derive(Copy, Clone, Eq, PartialEq, Debug)]
    struct Flags: u8  {
        /// The service readiness check failed; while stopping, `poll_ready`
        /// is no longer called.
        const READY_ERR     = 0b0000_0001;
        /// While stopping, the io reported peer-gone / keep-alive or a flush
        /// failed; io status updates are no longer awaited.
        const IO_ERR        = 0b0000_0010;
        /// Keep-alive timer is enabled (non-zero keep-alive timeout).
        const KA_ENABLED    = 0b0000_0100;
        /// Keep-alive timer is currently armed.
        const KA_TIMEOUT    = 0b0000_1000;
        /// Frame read-rate timer is currently armed.
        const READ_TIMEOUT  = 0b0001_0000;
        // NOTE(review): not referenced in the visible part of this module.
        const READY         = 0b0010_0000;
        // NOTE(review): not referenced in the visible part of this module.
        const READY_TASK    = 0b0100_0000;
        /// The in-place response future was created for an item whose result
        /// must be handled as a stop (see `handle_result`'s `stop` argument).
        const RESPONSE_STOP = 0b1000_0000;
    }
}
42
/// Dispatcher state owned by the `Dispatcher` future.
struct DispatcherInner<S: Service<DispatchItem<U>>, U: Encoder + Decoder + 'static> {
    /// Io stream frames are read from and responses are written to.
    io: IoBoxed,
    /// Timer and error status bits.
    flags: Flags,
    /// Codec used to decode incoming frames and encode responses.
    codec: U,
    /// Service that handles every dispatched item.
    service: PipelineBinding<S, DispatchItem<U>>,
    /// Current stage of the dispatcher state machine.
    st: IoDispatcherState,
    /// State shared with spawned response tasks.
    state: Rc<DispatcherState<S, U>>,
    /// `decoded.remains` seen by the frame read-rate timer since its last tick.
    read_remains: u32,
    /// `read_remains` value recorded when the read-rate timer was last extended.
    read_remains_prev: u32,
    /// Remaining budget for extending the read-rate timer; reduced by the
    /// timer period on every extension when a max timeout is configured.
    read_max_timeout: Seconds,
    /// Keep-alive timeout; zero disables the keep-alive timer.
    keepalive_timeout: Seconds,
}
55
/// State shared between the dispatcher future and spawned response tasks.
struct DispatcherState<S: Service<DispatchItem<U>>, U: Encoder + Decoder + 'static> {
    /// Pending error; `check_error` takes it and starts stopping the dispatcher.
    error: Cell<Option<IoDispatcherError<S::Error, <U as Encoder>::Error>>>,
    /// Response index of the slot at the front of `queue`.
    base: Cell<usize>,
    /// Results of in-flight service calls, in request order.
    queue: Queue<S::Response, S::Error>,
    /// Registered with the dispatcher task's waker on every poll.
    waker: LocalWaker,
    /// Notified when a stop is handled; spawned response tasks waiting on it
    /// release their queue slot with `Ok(None)` instead of awaiting their call.
    stopping: Condition,
    /// Service call polled directly by the dispatcher future.
    response: Cell<ResponseCall<S, U>>,
    /// Response index (queue slot) assigned to `response`.
    response_idx: Cell<usize>,
}
65
/// Service call that the dispatcher future polls in place instead of spawning.
#[derive(Default)]
enum ResponseCall<S: Service<DispatchItem<U>>, U: Encoder + Decoder + 'static> {
    /// The call is still running.
    Call(PipelineCall<S, DispatchItem<U>>),
    /// The call was dropped because of a stop; its queue slot still has to be
    /// released by the dispatcher.
    Canceled,
    /// No call is being polled in place.
    #[default]
    Empty,
}
73
/// One slot of the ordered response queue.
enum ServiceResult<T> {
    /// The service call has not completed yet.
    Pending,
    /// The service call completed; its result waits to be written.
    Ready(T),
}

impl<T> ServiceResult<T> {
    /// Moves a completed result out of the slot and resets the slot to
    /// `Pending`.
    ///
    /// Returns `None` (leaving the slot `Pending`) when nothing was ready.
    fn take(&mut self) -> Option<T> {
        if let ServiceResult::Ready(value) = std::mem::replace(self, ServiceResult::Pending) {
            Some(value)
        } else {
            None
        }
    }
}
88
/// Stages of the dispatcher state machine.
#[derive(Copy, Clone, Debug)]
enum IoDispatcherState {
    /// Reading and decoding frames and calling the service.
    Processing,
    /// Write back-pressure is enabled; waiting for the write buffer to flush.
    Backpressure,
    /// Stopping: waiting for queued responses, then shutting down the io.
    Stop,
    /// Io is shut down; waiting for the service to shut down.
    Shutdown,
}
96
/// Error recorded by the dispatcher while handling the service.
pub(crate) enum IoDispatcherError<S, U> {
    /// Encoding a service response failed.
    Encoder(U),
    /// A service call or the service readiness check failed.
    Service(S),
}
101
/// Outcome of checking service readiness before the next frame is read.
enum PollService<U: Encoder + Decoder> {
    /// Item to dispatch with `call_nowait`.
    Item(DispatchItem<U>),
    /// Item to dispatch with `call`.
    ItemWait(DispatchItem<U>),
    /// Dispatcher state changed; restart the dispatcher loop.
    Continue,
    /// Service is ready and no error is pending.
    Ready,
}
108
109impl<S, U> From<S> for IoDispatcherError<S, U> {
110    fn from(err: S) -> Self {
111        IoDispatcherError::Service(err)
112    }
113}
114
115impl<S, U> Dispatcher<S, U>
116where
117    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
118    U: Decoder + Encoder + Clone + 'static,
119    <U as Encoder>::Item: 'static,
120{
121    /// Construct new `Dispatcher` instance with outgoing messages stream.
122    pub(crate) fn new<F: IntoService<S, DispatchItem<U>>>(
123        io: IoBoxed,
124        codec: U,
125        service: F,
126    ) -> Self {
127        let state = Rc::new(DispatcherState {
128            error: Cell::new(None),
129            base: Cell::new(0),
130            queue: RefCell::new(VecDeque::new()),
131            waker: LocalWaker::default(),
132            response: Cell::new(ResponseCall::Empty),
133            response_idx: Cell::new(0),
134            stopping: Condition::new(),
135        });
136        let keepalive_timeout = io.cfg().keepalive_timeout();
137
138        Dispatcher {
139            inner: DispatcherInner {
140                io,
141                codec,
142                state,
143                keepalive_timeout,
144                flags: if keepalive_timeout.is_zero() {
145                    Flags::KA_ENABLED
146                } else {
147                    Flags::empty()
148                },
149                service: Pipeline::new(service.into_service()).bind(),
150                st: IoDispatcherState::Processing,
151                read_remains: 0,
152                read_remains_prev: 0,
153                read_max_timeout: Seconds::ZERO,
154            },
155        }
156    }
157
158    /// Set keep-alive timeout in seconds.
159    ///
160    /// To disable timeout set value to 0.
161    ///
162    /// By default keep-alive timeout is set to 30 seconds.
163    pub(crate) fn keepalive_timeout(mut self, timeout: Seconds) -> Self {
164        self.inner.keepalive_timeout = timeout;
165        if timeout.is_zero() {
166            self.inner.flags.remove(Flags::KA_ENABLED);
167        } else {
168            self.inner.flags.insert(Flags::KA_ENABLED);
169        }
170        self
171    }
172}
173
174impl<S, U> DispatcherState<S, U>
175where
176    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
177    U: Encoder + Decoder,
178    <U as Encoder>::Item: 'static,
179{
180    fn handle_result(
181        &self,
182        item: Result<S::Response, S::Error>,
183        response_idx: usize,
184        io: &IoRef,
185        codec: &U,
186        stop: bool,
187    ) -> bool {
188        let mut queue = self.queue.borrow_mut();
189
190        if stop {
191            self.stopping.notify();
192
193            // remove in-place message handler
194            let resp = self.response.take();
195            if matches!(resp, ResponseCall::Call(_) | ResponseCall::Canceled) {
196                self.response.set(ResponseCall::Canceled);
197            }
198        }
199
200        let idx = response_idx.wrapping_sub(self.base.get());
201
202        // handle first response
203        if idx == 0 {
204            let _ = queue.pop_front();
205            self.base.set(self.base.get().wrapping_add(1));
206            match item {
207                Err(err) => {
208                    self.error.set(Some(err.into()));
209                }
210                Ok(Some(item)) => {
211                    if let Err(err) = io.encode(item, codec) {
212                        self.error.set(Some(IoDispatcherError::Encoder(err)));
213                    }
214                }
215                Ok(None) => (),
216            }
217
218            // check remaining response
219            while let Some(item) = queue.front_mut().and_then(ServiceResult::take) {
220                let _ = queue.pop_front();
221                self.base.set(self.base.get().wrapping_add(1));
222                match item {
223                    Err(err) => {
224                        self.error.set(Some(err.into()));
225                    }
226                    Ok(Some(item)) => {
227                        if let Err(err) = io.encode(item, codec) {
228                            self.error.set(Some(IoDispatcherError::Encoder(err)));
229                        }
230                    }
231                    Ok(None) => (),
232                }
233            }
234
235            queue.is_empty()
236        } else {
237            queue[idx] = ServiceResult::Ready(item);
238            false
239        }
240    }
241}
242
impl<S, U> Future for Dispatcher<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
    U: Decoder + Encoder + Clone + 'static,
    <U as Encoder>::Item: 'static,
{
    /// Resolves to `Err` only when a service error is recorded at shutdown;
    /// every other way of finishing resolves to `Ok(())`.
    type Output = Result<(), S::Error>;

    /// Drives the dispatcher state machine.
    ///
    /// Each poll first advances the in-place response future (if any). It
    /// then loops over the `IoDispatcherState` stages until a stage returns
    /// `Pending` or the service shutdown completes.
    #[allow(clippy::too_many_lines)]
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.as_mut().project();
        let inner = this.inner;

        inner.state.waker.register(cx.waker());

        // handle service response future
        match inner.state.response.take() {
            ResponseCall::Call(mut fut) => {
                if let Poll::Ready(item) = Pin::new(&mut fut).poll(cx) {
                    // consume the stop marker set when this call was dispatched
                    let stop = if inner.flags.contains(Flags::RESPONSE_STOP) {
                        inner.flags.remove(Flags::RESPONSE_STOP);
                        true
                    } else {
                        false
                    };
                    inner.state.handle_result(
                        item,
                        inner.state.response_idx.get(),
                        inner.io.as_ref(),
                        &inner.codec,
                        stop,
                    );
                } else {
                    // not finished, keep polling it in place
                    inner.state.response.set(ResponseCall::Call(fut));
                }
            }
            ResponseCall::Canceled => {
                // the call was dropped by a stop; release its queue slot
                inner.state.handle_result(
                    Ok(None),
                    inner.state.response_idx.get(),
                    inner.io.as_ref(),
                    &inner.codec,
                    true,
                );
            }
            ResponseCall::Empty => {}
        }

        loop {
            match inner.st {
                IoDispatcherState::Processing => {
                    // `nowait` selects `call_nowait` over `call`; `stop` makes the
                    // result handling notify `stopping` and cancel in-place work
                    let (item, nowait, stop) = match ready!(inner.poll_service(cx)) {
                        PollService::Ready => {
                            // decode incoming bytes stream
                            match inner.io.poll_recv_decode(&inner.codec, cx) {
                                Ok(decoded) => {
                                    inner.update_timer(&decoded);
                                    if let Some(el) = decoded.item {
                                        (DispatchItem::Item(el), true, false)
                                    } else {
                                        return Poll::Pending;
                                    }
                                }
                                Err(RecvError::KeepAlive) => {
                                    // io timer expired; `handle_timeout` checks which
                                    // timer (read-rate or keep-alive) was armed
                                    if let Err(err) = inner.handle_timeout() {
                                        inner.stop();
                                        (DispatchItem::Stop(err), true, true)
                                    } else {
                                        continue;
                                    }
                                }
                                Err(RecvError::WriteBackpressure) => {
                                    inner.st = IoDispatcherState::Backpressure;
                                    (
                                        DispatchItem::Control(Control::WBackPressureEnabled),
                                        true,
                                        false,
                                    )
                                }
                                Err(RecvError::Decoder(err)) => {
                                    inner.stop();
                                    (DispatchItem::Stop(Reason::Decoder(err)), true, true)
                                }
                                Err(RecvError::PeerGone(err)) => {
                                    inner.stop();
                                    (DispatchItem::Stop(Reason::Io(err)), true, true)
                                }
                            }
                        }
                        PollService::Item(item) => (item, true, false),
                        PollService::ItemWait(item) => (item, false, false),
                        PollService::Continue => continue,
                    };

                    inner.call_service(cx, item, nowait, stop);
                }
                // handle write back-pressure
                IoDispatcherState::Backpressure => {
                    match ready!(inner.poll_service(cx)) {
                        PollService::Ready => (),
                        PollService::Item(item) => inner.call_service(cx, item, true, false),
                        PollService::ItemWait(item) => {
                            inner.call_service(cx, item, false, false);
                        }
                        PollService::Continue => continue,
                    }

                    // wait for the write buffer to flush, then resume processing
                    let item = if let Err(err) = ready!(inner.io.poll_flush(cx, false)) {
                        inner.stop();
                        DispatchItem::Stop(Reason::Io(Some(err)))
                    } else {
                        inner.st = IoDispatcherState::Processing;
                        DispatchItem::Control(Control::WBackPressureDisabled)
                    };
                    inner.call_service(cx, item, false, false);
                }

                // drain service responses and shutdown io
                IoDispatcherState::Stop => {
                    inner.io.stop_timer();

                    // service may relay on poll_ready for response results
                    if !inner.flags.contains(Flags::READY_ERR)
                        && let Poll::Ready(res) = inner.service.poll_ready(cx)
                        && res.is_err()
                    {
                        inner.flags.insert(Flags::READY_ERR);
                    }

                    if inner.state.queue.borrow().is_empty() {
                        // all responses written: shut the io down
                        if inner.io.poll_shutdown(cx).is_ready() {
                            log::trace!("{}: io shutdown completed", inner.io.tag());
                            inner.st = IoDispatcherState::Shutdown;
                            continue;
                        }
                    } else if !inner.flags.contains(Flags::IO_ERR) {
                        // responses outstanding: keep watching the io until it fails
                        match ready!(inner.io.poll_status_update(cx)) {
                            IoStatusUpdate::PeerGone(_) | IoStatusUpdate::KeepAlive => {
                                inner.flags.insert(Flags::IO_ERR);
                                continue;
                            }
                            IoStatusUpdate::WriteBackpressure => {
                                if ready!(inner.io.poll_flush(cx, true)).is_err() {
                                    inner.flags.insert(Flags::IO_ERR);
                                }
                                continue;
                            }
                        }
                    } else {
                        inner.io.poll_dispatch(cx);
                    }
                    return Poll::Pending;
                }
                // shutdown service
                IoDispatcherState::Shutdown => {
                    return if inner.service.poll_shutdown(cx).is_ready() {
                        log::trace!("{}: Service shutdown is completed, stop", inner.io.tag());

                        // only a service error is surfaced to the caller
                        Poll::Ready(
                            if let Some(IoDispatcherError::Service(err)) =
                                inner.state.error.take()
                            {
                                Err(err)
                            } else {
                                Ok(())
                            },
                        )
                    } else {
                        Poll::Pending
                    };
                }
            }
        }
    }
}
418
419impl<S, U> DispatcherInner<S, U>
420where
421    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
422    U: Decoder + Encoder + Clone + 'static,
423    <U as Encoder>::Item: 'static,
424{
425    fn stop(&mut self) {
426        self.st = IoDispatcherState::Stop;
427    }
428
429    fn call_service(
430        &mut self,
431        cx: &mut Context<'_>,
432        item: DispatchItem<U>,
433        nowait: bool,
434        stop: bool,
435    ) {
436        let mut fut = if nowait {
437            self.service.call_nowait(item)
438        } else {
439            self.service.call(item)
440        };
441        let mut queue = self.state.queue.borrow_mut();
442
443        // optimize first call
444        let resp = self.state.response.take();
445        if matches!(resp, ResponseCall::Call(_) | ResponseCall::Canceled) {
446            // first call is running
447            self.state.response.set(resp);
448
449            let response_idx = self.state.base.get().wrapping_add(queue.len());
450            queue.push_back(ServiceResult::Pending);
451
452            let st = self.io.get_ref();
453            let codec = self.codec.clone();
454            let state = self.state.clone();
455
456            spawn(async move {
457                let empty_q = match select(fut, state.stopping.wait()).await {
458                    Either::Left(item) => {
459                        state.handle_result(item, response_idx, &st, &codec, stop)
460                    }
461                    Either::Right(()) => {
462                        state.handle_result(Ok(None), response_idx, &st, &codec, stop)
463                    }
464                };
465                if empty_q || stop {
466                    st.wake();
467                }
468            });
469        } else if let Poll::Ready(res) = Pin::new(&mut fut).poll(cx) {
470            // check if current result is only response
471            if queue.is_empty() {
472                match res {
473                    Err(err) => {
474                        self.state.error.set(Some(err.into()));
475                    }
476                    Ok(Some(item)) => {
477                        if let Err(err) = self.io.encode(item, &self.codec) {
478                            self.state.error.set(Some(IoDispatcherError::Encoder(err)));
479                        }
480                    }
481                    Ok(None) => (),
482                }
483            } else {
484                if stop {
485                    self.state.stopping.notify();
486                }
487                queue.push_back(ServiceResult::Ready(res));
488                self.state.response_idx.set(self.state.base.get().wrapping_add(queue.len()));
489            }
490        } else {
491            if stop {
492                self.flags.insert(Flags::RESPONSE_STOP);
493            }
494            self.state.response.set(ResponseCall::Call(fut));
495            self.state.response_idx.set(self.state.base.get().wrapping_add(queue.len()));
496            queue.push_back(ServiceResult::Pending);
497        }
498    }
499
500    fn check_error(&mut self) -> PollService<U> {
501        // check for errors
502        if let Some(err) = self.state.error.take() {
503            log::trace!("{}: Error occured, stopping dispatcher", self.io.tag());
504            self.stop();
505            match err {
506                IoDispatcherError::Encoder(err) => {
507                    PollService::Item(DispatchItem::Stop(Reason::Encoder(err)))
508                }
509                IoDispatcherError::Service(err) => {
510                    self.state.error.set(Some(IoDispatcherError::Service(err)));
511                    PollService::Continue
512                }
513            }
514        } else {
515            PollService::Ready
516        }
517    }
518
519    fn poll_service(&mut self, cx: &mut Context<'_>) -> Poll<PollService<U>> {
520        match self.service.poll_ready(cx) {
521            Poll::Ready(Ok(())) => Poll::Ready(self.check_error()),
522            // pause io read task
523            Poll::Pending => {
524                log::trace!("{}: Service is not ready, pause read task", self.io.tag());
525
526                // remove timers
527                self.flags.remove(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT);
528                self.io.stop_timer();
529
530                match ready!(self.io.poll_read_pause(cx)) {
531                    IoStatusUpdate::KeepAlive => {
532                        log::trace!(
533                            "{}: Keep-alive error, stopping dispatcher during pause",
534                            self.io.tag()
535                        );
536                        self.stop();
537                        Poll::Ready(PollService::ItemWait(DispatchItem::Stop(
538                            Reason::KeepAliveTimeout,
539                        )))
540                    }
541                    IoStatusUpdate::PeerGone(err) => {
542                        log::trace!(
543                            "{}: Peer is gone during pause, stopping dispatcher: {:?}",
544                            self.io.tag(),
545                            err
546                        );
547                        self.stop();
548                        Poll::Ready(PollService::ItemWait(DispatchItem::Stop(Reason::Io(err))))
549                    }
550                    IoStatusUpdate::WriteBackpressure => {
551                        self.st = IoDispatcherState::Backpressure;
552                        Poll::Ready(PollService::ItemWait(DispatchItem::Control(
553                            Control::WBackPressureEnabled,
554                        )))
555                    }
556                }
557            }
558            // handle service readiness error
559            Poll::Ready(Err(err)) => {
560                log::error!("{}: Service readiness check failed, stopping", self.io.tag());
561                self.stop();
562                self.flags.insert(Flags::READY_ERR);
563                self.state.error.set(Some(IoDispatcherError::Service(err)));
564                Poll::Ready(PollService::Continue)
565            }
566        }
567    }
568
569    fn update_timer(&mut self, decoded: &Decoded<<U as Decoder>::Item>) {
570        // got parsed frame
571        if decoded.item.is_some() {
572            self.read_remains = 0;
573            self.flags.remove(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT);
574        } else if self.flags.contains(Flags::READ_TIMEOUT) {
575            // received new data but not enough for parsing complete frame
576            self.read_remains = decoded.remains as u32;
577        } else if self.read_remains == 0 && decoded.remains == 0 {
578            // no new data, start keep-alive timer
579            if self.flags.contains(Flags::KA_ENABLED) && !self.flags.contains(Flags::KA_TIMEOUT)
580            {
581                log::trace!(
582                    "{}: Start keep-alive timer {:?}",
583                    self.io.tag(),
584                    self.keepalive_timeout
585                );
586                self.flags.insert(Flags::KA_TIMEOUT);
587                self.io.start_timer(self.keepalive_timeout);
588            }
589        } else if let Some(params) = self.io.cfg().frame_read_rate() {
590            // we got new data but not enough to parse single frame
591            // start read timer
592            self.flags.insert(Flags::READ_TIMEOUT);
593
594            self.read_remains = decoded.remains as u32;
595            self.read_remains_prev = 0;
596            self.read_max_timeout = params.max_timeout;
597            self.io.start_timer(params.timeout);
598
599            log::trace!("{}: Start frame read timer {:?}", self.io.tag(), params.timeout);
600        }
601    }
602
603    fn handle_timeout(&mut self) -> Result<(), Reason<U>> {
604        // check read timer
605        if self.flags.contains(Flags::READ_TIMEOUT) {
606            if let Some(params) = self.io.cfg().frame_read_rate() {
607                let total = self.read_remains - self.read_remains_prev;
608
609                // read rate, start timer for next period
610                if total > params.rate {
611                    self.read_remains_prev = self.read_remains;
612                    self.read_remains = 0;
613
614                    if !params.max_timeout.is_zero() {
615                        self.read_max_timeout =
616                            Seconds(self.read_max_timeout.0.saturating_sub(params.timeout.0));
617                    }
618
619                    if params.max_timeout.is_zero() || !self.read_max_timeout.is_zero() {
620                        log::trace!(
621                            "{}: Frame read rate {:?}, extend timer",
622                            self.io.tag(),
623                            total
624                        );
625                        self.io.start_timer(params.timeout);
626                        return Ok(());
627                    }
628                }
629                log::trace!("{}: Max payload timeout has been reached", self.io.tag());
630                return Err(Reason::ReadTimeout);
631            }
632        } else if self.flags.contains(Flags::KA_TIMEOUT) {
633            log::trace!("{}: Keep-alive error, stopping dispatcher", self.io.tag());
634            return Err(Reason::KeepAliveTimeout);
635        }
636        Ok(())
637    }
638}
639
640#[cfg(test)]
641#[allow(clippy::items_after_statements)]
642mod tests {
643    use std::sync::{Arc, Mutex, atomic::AtomicBool, atomic::Ordering};
644    use std::{cell::Cell, io};
645
646    use ntex_bytes::{Bytes, BytesMut};
647    use ntex_codec::BytesCodec;
648    use ntex_io::{self as nio, IoConfig, testing::IoTest as Io};
649    use ntex_service::{ServiceCtx, cfg::SharedCfg};
650    use ntex_util::channel::condition::Condition;
651    use ntex_util::time::{Millis, sleep};
652    use rand::Rng;
653
654    use super::*;
655
656    impl<S, U> Dispatcher<S, U>
657    where
658        S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
659        S::Error: 'static,
660        U: Decoder + Encoder + 'static,
661        <U as Encoder>::Item: 'static,
662    {
663        /// Construct new `Dispatcher` instance
664        pub(crate) fn new_debug<F: IntoService<S, DispatchItem<U>>>(
665            io: nio::Io,
666            codec: U,
667            service: F,
668        ) -> (Self, nio::IoRef) {
669            let keepalive_timeout = io.cfg().keepalive_timeout();
670            let rio = io.get_ref();
671
672            let state = Rc::new(DispatcherState {
673                error: Cell::new(None),
674                base: Cell::new(0),
675                waker: LocalWaker::default(),
676                queue: RefCell::new(VecDeque::new()),
677                stopping: Condition::new(),
678                response: Cell::new(ResponseCall::Empty),
679                response_idx: Cell::new(0),
680            });
681
682            (
683                Dispatcher {
684                    inner: DispatcherInner {
685                        codec,
686                        state,
687                        keepalive_timeout,
688                        service: Pipeline::new(service.into_service()).bind(),
689                        io: IoBoxed::from(io),
690                        st: IoDispatcherState::Processing,
691                        flags: if keepalive_timeout.is_zero() {
692                            Flags::empty()
693                        } else {
694                            Flags::KA_ENABLED
695                        },
696                        read_remains: 0,
697                        read_remains_prev: 0,
698                        read_max_timeout: Seconds::ZERO,
699                    },
700                },
701                rio,
702            )
703        }
704    }
705
    /// Echo service with a 50ms delay: one request is written before the
    /// dispatcher starts and one 25ms later. Both must come back unchanged,
    /// and closing the client must drop the server side.
    #[ntex::test]
    async fn test_basic() {
        let (client, server) = Io::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, _) = Dispatcher::new_debug(
            nio::Io::new(server, SharedCfg::new("DBG")),
            BytesCodec,
            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
                sleep(Millis(50)).await;
                if let DispatchItem::Item(msg) = msg {
                    Ok::<_, ()>(Some(msg))
                } else {
                    panic!()
                }
            }),
        );
        ntex_util::spawn(async move {
            let _ = disp.await;
        });
        // second request arrives while the first call is still sleeping
        sleep(Millis(25)).await;
        client.write("GET /test HTTP/1\r\n\r\n");

        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        client.close().await;
        assert!(client.is_server_dropped());
    }
739
    /// Closing the client while a service call is still sleeping must drop
    /// the server side and release every `OnDrop` clone held by the service
    /// and its in-flight calls.
    #[ntex::test]
    async fn test_drop_connection() {
        let (client, server) = Io::create();
        client.remote_buffer_cap(1024);
        client.write("test");

        // sets the flag when the last clone (besides `ops`) is dropped
        #[derive(Clone)]
        struct OnDrop(Rc<Cell<bool>>);
        impl Drop for OnDrop {
            fn drop(&mut self) {
                if Rc::strong_count(&self.0) == 2 {
                    self.0.set(true);
                }
            }
        }
        let ops = Rc::new(Cell::new(false));
        let on_drop = OnDrop(ops.clone());

        let (disp, _) = Dispatcher::new_debug(
            nio::Io::new(server, SharedCfg::new("DBG")),
            BytesCodec,
            ntex_service::fn_service(async move |msg: DispatchItem<BytesCodec>| {
                let _on_drop = on_drop.clone();
                if let DispatchItem::Item(msg) = msg {
                    if msg == "test" {
                        sleep(Millis(500)).await;
                    }
                    Ok::<_, ()>(Some(msg))
                } else {
                    Ok::<_, ()>(None)
                }
            }),
        );
        ntex_util::spawn(async move {
            let _ = disp.await;
        });
        sleep(Millis(25)).await;
        client.write("pl1");
        client.close().await;
        assert!(client.is_server_dropped());
        // service dropped?
        assert!(ops.get());
    }
783
    /// Three "test" frames are written 50ms apart, and every service call
    /// blocks on a shared condition. Once the condition is notified, all three
    /// responses must be written back.
    #[ntex::test]
    async fn test_ordering() {
        let (client, server) = Io::create();
        client.remote_buffer_cap(1024);
        client.write("test");

        let condition = Condition::new();
        let waiter = condition.wait();

        let (disp, _) = Dispatcher::new_debug(
            nio::Io::new(server, SharedCfg::new("DBG")),
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                let waiter = waiter.clone();
                async move {
                    waiter.await;
                    if let DispatchItem::Item(msg) = msg {
                        Ok::<_, ()>(Some(msg))
                    } else if matches!(msg, DispatchItem::Stop(Reason::Io(_))) {
                        Ok(None)
                    } else {
                        panic!()
                    }
                }
            }),
        );
        ntex_util::spawn(async move {
            let _ = disp.await;
        });
        sleep(Millis(50)).await;

        client.write("test");
        sleep(Millis(50)).await;
        client.write("test");
        sleep(Millis(50)).await;
        // release all blocked service calls at once
        condition.notify();

        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"testtesttest"));

        client.close().await;
        assert!(client.is_server_dropped());
    }
827
    /// On disconnect, call control service and after call completion
    /// drop in-flight publish handlers
    #[ntex::test]
    async fn test_disconnect_ordering() {
        // Events recorded by the service, in the order they are observed.
        #[derive(Debug, Copy, Clone, PartialEq, Eq)]
        enum Info {
            Publish,
            PublishDrop,
            Disconnect,
        }

        // Records `PublishDrop` when a publish handler's guard is dropped.
        // `condition` is never notified in this test, so a publish handler can
        // only drop its guard by having its future dropped.
        struct OnDrop(Rc<RefCell<Vec<Info>>>);
        impl Drop for OnDrop {
            fn drop(&mut self) {
                self.0.borrow_mut().push(Info::PublishDrop);
            }
        }

        let condition = Condition::new();
        let waiter = condition.wait();
        let ops = Rc::new(RefCell::new(Vec::new()));
        let ops2 = ops.clone();

        // Starts a dispatcher over a fresh in-memory connection and returns the
        // client end. It is invoked twice below, hence `run_server.clone()` for
        // the first run.
        let run_server = async || -> Io {
            let (client, server) = Io::create();
            client.remote_buffer_cap(1024);

            let (disp, _) = Dispatcher::new_debug(
                nio::Io::new(server, SharedCfg::new("DBG")),
                BytesCodec,
                ntex_service::fn_service(async move |msg: DispatchItem<BytesCodec>| {
                    if let DispatchItem::Item(msg) = msg {
                        if msg == b"1" {
                            // slow frame that completes on its own; not logged
                            sleep(Millis(75)).await;
                        } else {
                            // "publish": log it, then block until dropped
                            ops2.borrow_mut().push(Info::Publish);
                            let on_drop = OnDrop(ops2.clone());
                            waiter.clone().await;
                            drop(on_drop);
                        }
                        Ok::<_, ()>(Some(msg))
                    } else if matches!(msg, DispatchItem::Stop(Reason::Io(_))) {
                        // disconnect handler takes a while; in-flight publishes
                        // must only be dropped after it has completed
                        sleep(Millis(25)).await;
                        ops2.borrow_mut().push(Info::Disconnect);
                        Ok(None)
                    } else {
                        panic!()
                    }
                }),
            );
            ntex_util::spawn(async move {
                let _ = disp.await;
            });
            sleep(Millis(50)).await;

            client
        };
        let client = run_server.clone()().await;

        // scenario 1: two publishes in flight, then the peer disconnects
        client.write("test");
        sleep(Millis(50)).await;
        client.write("test");
        sleep(Millis(50)).await;
        client.close().await;
        assert!(client.is_server_dropped());
        // give the dispatcher time to run the Stop handler and drop handlers
        sleep(Millis(150)).await;

        assert_eq!(
            &[
                Info::Publish,
                Info::Publish,
                Info::Disconnect,
                Info::PublishDrop,
                Info::PublishDrop
            ][..],
            &*ops.borrow()
        );

        // scenario 2: a slow-but-completing frame "1" precedes the two
        // publishes; the recorded event order must be the same
        ops.borrow_mut().clear();
        let client = run_server().await;

        client.write("1");
        sleep(Millis(50)).await;

        client.write("test");
        sleep(Millis(50)).await;
        client.write("test");
        sleep(Millis(50)).await;
        client.close().await;
        assert!(client.is_server_dropped());
        sleep(Millis(150)).await;

        assert_eq!(
            &[
                Info::Publish,
                Info::Publish,
                Info::Disconnect,
                Info::PublishDrop,
                Info::PublishDrop
            ][..],
            &*ops.borrow()
        );
    }
932
933    #[ntex::test]
934    async fn test_sink() {
935        let (client, server) = Io::create();
936        client.remote_buffer_cap(1024);
937        client.write("GET /test HTTP/1\r\n\r\n");
938
939        let (disp, io) = Dispatcher::new_debug(
940            nio::Io::new(server, SharedCfg::new("DBG")),
941            BytesCodec,
942            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
943                if let DispatchItem::Item(msg) = msg {
944                    Ok::<_, ()>(Some(msg))
945                } else if let DispatchItem::Stop(Reason::Io(_)) = msg {
946                    Ok(None)
947                } else {
948                    panic!()
949                }
950            }),
951        );
952        ntex_util::spawn(async move {
953            let _ = disp.await;
954        });
955
956        let buf = client.read().await.unwrap();
957        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
958
959        assert!(io.encode(Bytes::from_static(b"test"), &BytesCodec).is_ok());
960        let buf = client.read().await.unwrap();
961        assert_eq!(buf, Bytes::from_static(b"test"));
962
963        io.close();
964        sleep(Millis(150)).await;
965        assert!(client.is_server_dropped());
966    }
967
968    #[ntex::test]
969    async fn test_err_in_service() {
970        let (client, server) = Io::create();
971        client.remote_buffer_cap(0);
972        client.write("GET /test HTTP/1\r\n\r\n");
973
974        let (disp, io) = Dispatcher::new_debug(
975            nio::Io::new(server, SharedCfg::new("DBG")),
976            BytesCodec,
977            ntex_service::fn_service(|_: DispatchItem<BytesCodec>| async move {
978                Err::<Option<Bytes>, _>(())
979            }),
980        );
981        ntex_util::spawn(async move {
982            let _ = disp.await;
983        });
984
985        io.encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec).unwrap();
986
987        // buffer should be flushed
988        client.remote_buffer_cap(1024);
989        let buf = client.read().await.unwrap();
990        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
991
992        // write side must be closed, dispatcher waiting for read side to close
993        sleep(Millis(50)).await;
994        assert!(client.is_closed());
995
996        // close read side
997        client.close().await;
998        assert!(client.is_server_dropped());
999    }
1000
1001    #[ntex::test]
1002    async fn test_err_in_service_ready() {
1003        struct Srv(Rc<Cell<usize>>);
1004
1005        impl Service<DispatchItem<BytesCodec>> for Srv {
1006            type Response = Option<Response<BytesCodec>>;
1007            type Error = ();
1008
1009            async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), ()> {
1010                self.0.set(self.0.get() + 1);
1011                Err(())
1012            }
1013
1014            async fn call(
1015                &self,
1016                _: DispatchItem<BytesCodec>,
1017                _: ServiceCtx<'_, Self>,
1018            ) -> Result<Option<Response<BytesCodec>>, ()> {
1019                Ok(None)
1020            }
1021        }
1022
1023        let (client, server) = Io::create();
1024        client.remote_buffer_cap(0);
1025        client.write("GET /test HTTP/1\r\n\r\n");
1026
1027        let counter = Rc::new(Cell::new(0));
1028
1029        let (disp, io) = Dispatcher::new_debug(
1030            nio::Io::new(server, SharedCfg::new("DBG")),
1031            BytesCodec,
1032            Srv(counter.clone()),
1033        );
1034        io.encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec).unwrap();
1035        ntex_util::spawn(async move {
1036            let _ = disp.await;
1037        });
1038
1039        // buffer should be flushed
1040        client.remote_buffer_cap(1024);
1041        let buf = client.read().await.unwrap();
1042        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
1043
1044        // write side must be closed, dispatcher waiting for read side to close
1045        sleep(Millis(50)).await;
1046        assert!(client.is_closed());
1047
1048        // close read side
1049        client.close().await;
1050        assert!(client.is_server_dropped());
1051
1052        // service must be checked for readiness only once
1053        assert_eq!(counter.get(), 1);
1054    }
1055
1056    #[ntex::test]
1057    async fn test_write_backpressure() {
1058        let (client, server) = Io::create();
1059        // do not allow to write to socket
1060        client.remote_buffer_cap(0);
1061        client.write("GET /test HTTP/1\r\n\r\n");
1062
1063        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1064        let data2 = data.clone();
1065
1066        let config = SharedCfg::new("DBG").add(
1067            IoConfig::new().set_read_buf(8 * 1024, 1024, 16).set_write_buf(32 * 1024, 1024, 16),
1068        );
1069
1070        let (disp, io) = Dispatcher::new_debug(
1071            nio::Io::new(server, config),
1072            BytesCodec,
1073            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
1074                let data = data2.clone();
1075                async move {
1076                    match msg {
1077                        DispatchItem::Item(_) => {
1078                            data.lock().unwrap().borrow_mut().push(0);
1079                            let bytes = rand::thread_rng()
1080                                .sample_iter(&rand::distributions::Alphanumeric)
1081                                .take(65_536)
1082                                .map(char::from)
1083                                .collect::<String>();
1084                            return Ok::<_, ()>(Some(Bytes::from(bytes)));
1085                        }
1086                        DispatchItem::Control(Control::WBackPressureEnabled) => {
1087                            data.lock().unwrap().borrow_mut().push(1);
1088                        }
1089                        DispatchItem::Control(Control::WBackPressureDisabled) => {
1090                            data.lock().unwrap().borrow_mut().push(2);
1091                        }
1092                        _ => (),
1093                    }
1094                    Ok(None)
1095                }
1096            }),
1097        );
1098
1099        ntex_util::spawn(async move {
1100            let _ = disp.await;
1101        });
1102
1103        let buf = client.read_any();
1104        assert_eq!(buf, Bytes::from_static(b""));
1105        client.write("GET /test HTTP/1\r\n\r\n");
1106        sleep(Millis(25)).await;
1107
1108        // buf must be consumed
1109        assert_eq!(client.remote_buffer(|buf| buf.len()), 0);
1110
1111        // response message
1112        assert_eq!(io.with_write_buf(|buf| buf.len()).unwrap(), 65536);
1113
1114        client.remote_buffer_cap(10240);
1115        sleep(Millis(50)).await;
1116        assert_eq!(io.with_write_buf(|buf| buf.len()).unwrap(), 55296);
1117
1118        client.remote_buffer_cap(45056);
1119        sleep(Millis(50)).await;
1120        assert_eq!(io.with_write_buf(|buf| buf.len()).unwrap(), 10240);
1121
1122        // backpressure disabled
1123        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]);
1124    }
1125
1126    #[ntex::test]
1127    async fn test_shutdown_dispatcher_waker() {
1128        let (client, server) = Io::create();
1129        let server = nio::Io::new(server, SharedCfg::new("DBG"));
1130        client.remote_buffer_cap(1024);
1131
1132        let flag = Rc::new(Cell::new(true));
1133        let flag2 = flag.clone();
1134        let server_ref = server.get_ref();
1135
1136        let (disp, _io) = Dispatcher::new_debug(
1137            server,
1138            BytesCodec,
1139            ntex_service::fn_service(async move |item: DispatchItem<BytesCodec>| {
1140                let first = flag2.get();
1141                flag2.set(false);
1142                if let DispatchItem::Item(b) = item {
1143                    if !first {
1144                        sleep(Millis(500)).await;
1145                    }
1146                    Ok(Some(b))
1147                } else {
1148                    server_ref.close();
1149                    Ok::<_, ()>(None)
1150                }
1151            }),
1152        );
1153        let (tx, rx) = ntex_util::channel::oneshot::channel();
1154        ntex_util::spawn(async move {
1155            let _ = disp.await;
1156            let _ = tx.send(());
1157        });
1158
1159        // send first message
1160        client.write(b"msg1");
1161        sleep(Millis(25)).await;
1162
1163        // send second message
1164        client.write(b"msg2");
1165
1166        // receive response to first message
1167        sleep(Millis(150)).await;
1168        let buf = client.read().await.unwrap();
1169        assert_eq!(buf, Bytes::from_static(b"msg1"));
1170
1171        // close read side
1172        client.close().await;
1173        let _ = rx.recv().await;
1174    }
1175
1176    /// Update keep-alive timer after receiving frame
1177    #[ntex::test]
1178    async fn test_keepalive() {
1179        let (client, server) = Io::create();
1180        client.remote_buffer_cap(1024);
1181
1182        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1183        let data2 = data.clone();
1184
1185        let (disp, _) = Dispatcher::new_debug(
1186            nio::Io::new(server, SharedCfg::new("DBG")),
1187            BytesCodec,
1188            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
1189                let data = data2.clone();
1190                async move {
1191                    match msg {
1192                        DispatchItem::Item(bytes) => {
1193                            data.lock().unwrap().borrow_mut().push(0);
1194                            return Ok::<_, ()>(Some(bytes));
1195                        }
1196                        DispatchItem::Stop(Reason::KeepAliveTimeout) => {
1197                            data.lock().unwrap().borrow_mut().push(1);
1198                        }
1199                        _ => (),
1200                    }
1201                    Ok(None)
1202                }
1203            }),
1204        );
1205        ntex_util::spawn(async move {
1206            let _ = disp.keepalive_timeout(Seconds(2)).await;
1207        });
1208
1209        client.write("1");
1210        let buf = client.read().await.unwrap();
1211        assert_eq!(buf, Bytes::from_static(b"1"));
1212        sleep(Millis(750)).await;
1213
1214        client.write("2");
1215        let buf = client.read().await.unwrap();
1216        assert_eq!(buf, Bytes::from_static(b"2"));
1217
1218        sleep(Millis(750)).await;
1219        client.write("3");
1220        let buf = client.read().await.unwrap();
1221        assert_eq!(buf, Bytes::from_static(b"3"));
1222
1223        sleep(Millis(750)).await;
1224        assert!(!client.is_closed());
1225        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 0, 0]);
1226    }
1227
1228    #[derive(Debug, Copy, Clone)]
1229    struct BytesLenCodec(usize);
1230
1231    impl Encoder for BytesLenCodec {
1232        type Item = Bytes;
1233        type Error = io::Error;
1234
1235        #[inline]
1236        fn encode(&self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
1237            dst.extend_from_slice(&item[..]);
1238            Ok(())
1239        }
1240    }
1241
1242    impl Decoder for BytesLenCodec {
1243        type Item = Bytes;
1244        type Error = io::Error;
1245
1246        fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
1247            if src.len() >= self.0 {
1248                Ok(Some(src.split_to(self.0)))
1249            } else {
1250                Ok(None)
1251            }
1252        }
1253    }
1254
1255    /// Do not use keep-alive timer if not configured
1256    #[ntex::test]
1257    async fn test_no_keepalive_err_after_frame_timeout() {
1258        let (client, server) = Io::create();
1259        client.remote_buffer_cap(1024);
1260
1261        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1262        let data2 = data.clone();
1263
1264        let config = SharedCfg::new("BDG").add(
1265            IoConfig::new().set_keepalive_timeout(Seconds(0)).set_frame_read_rate(
1266                Seconds(1),
1267                Seconds(2),
1268                2,
1269            ),
1270        );
1271
1272        let (disp, _) = Dispatcher::new_debug(
1273            nio::Io::new(server, config),
1274            BytesLenCodec(2),
1275            ntex_service::fn_service(move |msg: DispatchItem<BytesLenCodec>| {
1276                let data = data2.clone();
1277                async move {
1278                    match msg {
1279                        DispatchItem::Item(bytes) => {
1280                            data.lock().unwrap().borrow_mut().push(0);
1281                            return Ok::<_, ()>(Some(bytes));
1282                        }
1283                        DispatchItem::Stop(Reason::KeepAliveTimeout) => {
1284                            data.lock().unwrap().borrow_mut().push(1);
1285                        }
1286                        _ => (),
1287                    }
1288                    Ok(None)
1289                }
1290            }),
1291        );
1292        ntex_util::spawn(async move {
1293            let _ = disp.await;
1294        });
1295
1296        client.write("1");
1297        sleep(Millis(250)).await;
1298        client.write("2");
1299        let buf = client.read().await.unwrap();
1300        assert_eq!(buf, Bytes::from_static(b"12"));
1301        sleep(Millis(2000)).await;
1302
1303        assert_eq!(&data.lock().unwrap().borrow()[..], &[0]);
1304    }
1305
1306    #[ntex::test]
1307    async fn test_read_timeout() {
1308        let (client, server) = Io::create();
1309        client.remote_buffer_cap(1024);
1310
1311        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1312        let data2 = data.clone();
1313
1314        let config = SharedCfg::new("DBG").add(
1315            IoConfig::new().set_keepalive_timeout(Seconds::ZERO).set_frame_read_rate(
1316                Seconds(1),
1317                Seconds(2),
1318                2,
1319            ),
1320        );
1321
1322        let (disp, state) = Dispatcher::new_debug(
1323            nio::Io::new(server, config),
1324            BytesLenCodec(8),
1325            ntex_service::fn_service(move |msg: DispatchItem<BytesLenCodec>| {
1326                let data = data2.clone();
1327                async move {
1328                    match msg {
1329                        DispatchItem::Item(bytes) => {
1330                            data.lock().unwrap().borrow_mut().push(0);
1331                            return Ok::<_, ()>(Some(bytes));
1332                        }
1333                        DispatchItem::Stop(Reason::ReadTimeout) => {
1334                            data.lock().unwrap().borrow_mut().push(1);
1335                        }
1336                        _ => (),
1337                    }
1338                    Ok(None)
1339                }
1340            }),
1341        );
1342        ntex_util::spawn(async move {
1343            let _ = disp.await;
1344        });
1345
1346        client.write("12345678");
1347        let buf = client.read().await.unwrap();
1348        assert_eq!(buf, Bytes::from_static(b"12345678"));
1349
1350        client.write("1");
1351        sleep(Millis(1000)).await;
1352        assert!(!state.flags().contains(nio::Flags::IO_STOPPING));
1353        client.write("23");
1354        sleep(Millis(1000)).await;
1355        assert!(!state.flags().contains(nio::Flags::IO_STOPPING));
1356        client.write("4");
1357        sleep(Millis(2000)).await;
1358
1359        // write side must be closed, dispatcher should fail with keep-alive
1360        assert!(state.flags().contains(nio::Flags::IO_STOPPING));
1361        assert!(client.is_closed());
1362        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
1363    }
1364
1365    /// Do not use keep-alive timer if not configured
1366    #[ntex::test]
1367    async fn cancel_on_stop() {
1368        #[derive(Clone)]
1369        struct OnDrop(Arc<AtomicBool>);
1370        impl Drop for OnDrop {
1371            fn drop(&mut self) {
1372                self.0.store(true, Ordering::Relaxed);
1373            }
1374        }
1375
1376        let (client, server) = Io::create();
1377        client.remote_buffer_cap(1024);
1378
1379        let data = Arc::new(AtomicBool::new(false));
1380        let data2 = OnDrop(data.clone());
1381
1382        let config = SharedCfg::new("DBG").add(
1383            IoConfig::new().set_keepalive_timeout(Seconds(0)).set_frame_read_rate(
1384                Seconds(1),
1385                Seconds(2),
1386                2,
1387            ),
1388        );
1389
1390        let (disp, _) = Dispatcher::new_debug(
1391            nio::Io::new(server, config),
1392            BytesLenCodec(2),
1393            ntex_service::fn_service(move |msg: DispatchItem<BytesLenCodec>| {
1394                let data = data2.clone();
1395                async move {
1396                    if let DispatchItem::Item(bytes) = msg {
1397                        sleep(Millis(99_9999)).await;
1398                        drop(data);
1399                        return Ok::<_, ()>(Some(bytes));
1400                    }
1401                    Ok(None)
1402                }
1403            }),
1404        );
1405        ntex_util::spawn(async move {
1406            let _ = disp.await;
1407        });
1408
1409        client.write("1");
1410        client.close().await;
1411        sleep(Millis(250)).await;
1412
1413        assert!(&data.load(Ordering::Relaxed));
1414    }
1415}