ntex_io/
dispatcher.rs

1//! Framed transport dispatcher
2#![allow(clippy::let_underscore_future)]
3use std::task::{ready, Context, Poll};
4use std::{cell::Cell, future::Future, pin::Pin, rc::Rc};
5
6use ntex_codec::{Decoder, Encoder};
7use ntex_service::{IntoService, Pipeline, PipelineBinding, PipelineCall, Service};
8use ntex_util::{future::Either, spawn, time::Seconds};
9
10use crate::{Decoded, DispatchItem, IoBoxed, IoStatusUpdate, RecvError};
11
12type Response<U> = <U as Encoder>::Item;
13
14#[derive(Clone, Debug)]
15/// Shared dispatcher configuration
16pub struct DispatcherConfig(Rc<DispatcherConfigInner>);
17
18#[derive(Debug)]
19struct DispatcherConfigInner {
20    keepalive_timeout: Cell<Seconds>,
21    disconnect_timeout: Cell<Seconds>,
22    frame_read_enabled: Cell<bool>,
23    frame_read_rate: Cell<u16>,
24    frame_read_timeout: Cell<Seconds>,
25    frame_read_max_timeout: Cell<Seconds>,
26}
27
28impl Default for DispatcherConfig {
29    fn default() -> Self {
30        DispatcherConfig(Rc::new(DispatcherConfigInner {
31            keepalive_timeout: Cell::new(Seconds(30)),
32            disconnect_timeout: Cell::new(Seconds(1)),
33            frame_read_rate: Cell::new(0),
34            frame_read_enabled: Cell::new(false),
35            frame_read_timeout: Cell::new(Seconds::ZERO),
36            frame_read_max_timeout: Cell::new(Seconds::ZERO),
37        }))
38    }
39}
40
41impl DispatcherConfig {
42    #[inline]
43    /// Get keep-alive timeout
44    pub fn keepalive_timeout(&self) -> Seconds {
45        self.0.keepalive_timeout.get()
46    }
47
48    #[inline]
49    /// Get disconnect timeout
50    pub fn disconnect_timeout(&self) -> Seconds {
51        self.0.disconnect_timeout.get()
52    }
53
54    #[inline]
55    /// Get frame read rate
56    pub fn frame_read_rate(&self) -> Option<(Seconds, Seconds, u16)> {
57        if self.0.frame_read_enabled.get() {
58            Some((
59                self.0.frame_read_timeout.get(),
60                self.0.frame_read_max_timeout.get(),
61                self.0.frame_read_rate.get(),
62            ))
63        } else {
64            None
65        }
66    }
67
68    /// Set keep-alive timeout in seconds.
69    ///
70    /// To disable timeout set value to 0.
71    ///
72    /// By default keep-alive timeout is set to 30 seconds.
73    pub fn set_keepalive_timeout(&self, timeout: Seconds) -> &Self {
74        self.0.keepalive_timeout.set(timeout);
75        self
76    }
77
78    /// Set connection disconnect timeout.
79    ///
80    /// Defines a timeout for disconnect connection. If a disconnect procedure does not complete
81    /// within this time, the connection get dropped.
82    ///
83    /// To disable timeout set value to 0.
84    ///
85    /// By default disconnect timeout is set to 1 seconds.
86    pub fn set_disconnect_timeout(&self, timeout: Seconds) -> &Self {
87        self.0.disconnect_timeout.set(timeout);
88        self
89    }
90
91    /// Set read rate parameters for single frame.
92    ///
93    /// Set read timeout, max timeout and rate for reading payload. If the client
94    /// sends `rate` amount of data within `timeout` period of time, extend timeout by `timeout` seconds.
95    /// But no more than `max_timeout` timeout.
96    ///
97    /// By default frame read rate is disabled.
98    pub fn set_frame_read_rate(
99        &self,
100        timeout: Seconds,
101        max_timeout: Seconds,
102        rate: u16,
103    ) -> &Self {
104        self.0.frame_read_enabled.set(!timeout.is_zero());
105        self.0.frame_read_timeout.set(timeout);
106        self.0.frame_read_max_timeout.set(max_timeout);
107        self.0.frame_read_rate.set(rate);
108        self
109    }
110}
111
112pin_project_lite::pin_project! {
113    /// Dispatcher - is a future that reads frames from bytes stream
114    /// and pass then to the service.
115    pub struct Dispatcher<S, U>
116    where
117        S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
118        U: Encoder,
119        U: Decoder,
120        U: 'static,
121    {
122        inner: DispatcherInner<S, U>,
123    }
124}
125
126bitflags::bitflags! {
127    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
128    struct Flags: u8  {
129        const READY_ERR     = 0b0000001;
130        const IO_ERR        = 0b0000010;
131        const KA_ENABLED    = 0b0000100;
132        const KA_TIMEOUT    = 0b0001000;
133        const READ_TIMEOUT  = 0b0010000;
134        const IDLE          = 0b0100000;
135    }
136}
137
138struct DispatcherInner<S, U>
139where
140    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
141    U: Encoder + Decoder + 'static,
142{
143    st: DispatcherState,
144    error: Option<S::Error>,
145    shared: Rc<DispatcherShared<S, U>>,
146    response: Option<PipelineCall<S, DispatchItem<U>>>,
147    cfg: DispatcherConfig,
148    read_remains: u32,
149    read_remains_prev: u32,
150    read_max_timeout: Seconds,
151}
152
153pub(crate) struct DispatcherShared<S, U>
154where
155    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
156    U: Encoder + Decoder,
157{
158    io: IoBoxed,
159    codec: U,
160    service: PipelineBinding<S, DispatchItem<U>>,
161    flags: Cell<Flags>,
162    error: Cell<Option<DispatcherError<S::Error, <U as Encoder>::Error>>>,
163    inflight: Cell<u32>,
164}
165
166#[derive(Copy, Clone, Debug)]
167enum DispatcherState {
168    Processing,
169    Backpressure,
170    Stop,
171    Shutdown,
172}
173
174#[derive(Debug)]
175enum DispatcherError<S, U> {
176    Encoder(U),
177    Service(S),
178}
179
180enum PollService<U: Encoder + Decoder> {
181    Item(DispatchItem<U>),
182    Continue,
183    Ready,
184}
185
186impl<S, U> From<Either<S, U>> for DispatcherError<S, U> {
187    fn from(err: Either<S, U>) -> Self {
188        match err {
189            Either::Left(err) => DispatcherError::Service(err),
190            Either::Right(err) => DispatcherError::Encoder(err),
191        }
192    }
193}
194
195impl<S, U> Dispatcher<S, U>
196where
197    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
198    U: Decoder + Encoder + 'static,
199{
200    /// Construct new `Dispatcher` instance.
201    pub fn new<Io, F>(
202        io: Io,
203        codec: U,
204        service: F,
205        cfg: &DispatcherConfig,
206    ) -> Dispatcher<S, U>
207    where
208        IoBoxed: From<Io>,
209        F: IntoService<S, DispatchItem<U>>,
210    {
211        let io = IoBoxed::from(io);
212        io.set_disconnect_timeout(cfg.disconnect_timeout());
213
214        let flags = if cfg.keepalive_timeout().is_zero() {
215            Flags::empty()
216        } else {
217            Flags::KA_ENABLED
218        };
219
220        let shared = Rc::new(DispatcherShared {
221            io,
222            codec,
223            flags: Cell::new(flags),
224            error: Cell::new(None),
225            inflight: Cell::new(0),
226            service: Pipeline::new(service.into_service()).bind(),
227        });
228
229        Dispatcher {
230            inner: DispatcherInner {
231                shared,
232                cfg: cfg.clone(),
233                response: None,
234                error: None,
235                read_remains: 0,
236                read_remains_prev: 0,
237                read_max_timeout: Seconds::ZERO,
238                st: DispatcherState::Processing,
239            },
240        }
241    }
242}
243
244impl<S, U> DispatcherShared<S, U>
245where
246    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
247    U: Encoder + Decoder,
248{
249    fn handle_result(&self, item: Result<S::Response, S::Error>, io: &IoBoxed, wake: bool) {
250        match item {
251            Ok(Some(val)) => {
252                if let Err(err) = io.encode(val, &self.codec) {
253                    self.error.set(Some(DispatcherError::Encoder(err)))
254                }
255            }
256            Err(err) => self.error.set(Some(DispatcherError::Service(err))),
257            Ok(None) => (),
258        }
259        let inflight = self.inflight.get() - 1;
260        self.inflight.set(inflight);
261        if inflight == 0 {
262            self.insert_flags(Flags::IDLE);
263        }
264        if wake {
265            io.wake();
266        }
267    }
268
269    fn contains(&self, f: Flags) -> bool {
270        self.flags.get().intersects(f)
271    }
272
273    fn insert_flags(&self, f: Flags) {
274        let mut flags = self.flags.get();
275        flags.insert(f);
276        self.flags.set(flags);
277    }
278
279    fn remove_flags(&self, f: Flags) -> bool {
280        let mut flags = self.flags.get();
281        if flags.intersects(f) {
282            flags.remove(f);
283            self.flags.set(flags);
284            true
285        } else {
286            false
287        }
288    }
289}
290
291impl<S, U> Future for Dispatcher<S, U>
292where
293    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
294    U: Decoder + Encoder + 'static,
295{
296    type Output = Result<(), S::Error>;
297
298    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
299        let mut this = self.as_mut().project();
300        let slf = &mut this.inner;
301
302        // handle service response future
303        if let Some(fut) = slf.response.as_mut() {
304            if let Poll::Ready(item) = Pin::new(fut).poll(cx) {
305                slf.shared.handle_result(item, &slf.shared.io, false);
306                slf.response = None;
307            }
308        }
309
310        loop {
311            match slf.st {
312                DispatcherState::Processing => {
313                    let item = match ready!(slf.poll_service(cx)) {
314                        PollService::Ready => {
315                            // decode incoming bytes if buffer is ready
316                            match slf.shared.io.poll_recv_decode(&slf.shared.codec, cx) {
317                                Ok(decoded) => {
318                                    slf.update_timer(&decoded);
319                                    if let Some(el) = decoded.item {
320                                        DispatchItem::Item(el)
321                                    } else {
322                                        return Poll::Pending;
323                                    }
324                                }
325                                Err(RecvError::KeepAlive) => {
326                                    if let Err(err) = slf.handle_timeout() {
327                                        slf.st = DispatcherState::Stop;
328                                        err
329                                    } else {
330                                        continue;
331                                    }
332                                }
333                                Err(RecvError::Stop) => {
334                                    log::trace!(
335                                        "{}: Dispatcher is instructed to stop",
336                                        slf.shared.io.tag()
337                                    );
338                                    slf.st = DispatcherState::Stop;
339                                    continue;
340                                }
341                                Err(RecvError::WriteBackpressure) => {
342                                    // instruct write task to notify dispatcher when data is flushed
343                                    slf.st = DispatcherState::Backpressure;
344                                    DispatchItem::WBackPressureEnabled
345                                }
346                                Err(RecvError::Decoder(err)) => {
347                                    log::trace!(
348                                        "{}: Decoder error, stopping dispatcher: {:?}",
349                                        slf.shared.io.tag(),
350                                        err
351                                    );
352                                    slf.st = DispatcherState::Stop;
353                                    DispatchItem::DecoderError(err)
354                                }
355                                Err(RecvError::PeerGone(err)) => {
356                                    log::trace!(
357                                        "{}: Peer is gone, stopping dispatcher: {:?}",
358                                        slf.shared.io.tag(),
359                                        err
360                                    );
361                                    slf.st = DispatcherState::Stop;
362                                    DispatchItem::Disconnect(err)
363                                }
364                            }
365                        }
366                        PollService::Item(item) => item,
367                        PollService::Continue => continue,
368                    };
369
370                    slf.call_service(cx, item);
371                }
372                // handle write back-pressure
373                DispatcherState::Backpressure => {
374                    match ready!(slf.poll_service(cx)) {
375                        PollService::Ready => (),
376                        PollService::Item(item) => slf.call_service(cx, item),
377                        PollService::Continue => continue,
378                    };
379
380                    let item = if let Err(err) = ready!(slf.shared.io.poll_flush(cx, false))
381                    {
382                        slf.st = DispatcherState::Stop;
383                        DispatchItem::Disconnect(Some(err))
384                    } else {
385                        slf.st = DispatcherState::Processing;
386                        DispatchItem::WBackPressureDisabled
387                    };
388                    slf.call_service(cx, item);
389                }
390                // drain service responses and shutdown io
391                DispatcherState::Stop => {
392                    slf.shared.io.stop_timer();
393
394                    // service may relay on poll_ready for response results
395                    if !slf.shared.contains(Flags::READY_ERR) {
396                        if let Poll::Ready(res) = slf.shared.service.poll_ready(cx) {
397                            if res.is_err() {
398                                slf.shared.insert_flags(Flags::READY_ERR);
399                            }
400                        }
401                    }
402
403                    if slf.shared.inflight.get() == 0 {
404                        if slf.shared.io.poll_shutdown(cx).is_ready() {
405                            slf.st = DispatcherState::Shutdown;
406                            continue;
407                        }
408                    } else if !slf.shared.contains(Flags::IO_ERR) {
409                        match ready!(slf.shared.io.poll_status_update(cx)) {
410                            IoStatusUpdate::PeerGone(_)
411                            | IoStatusUpdate::Stop
412                            | IoStatusUpdate::KeepAlive => {
413                                slf.shared.insert_flags(Flags::IO_ERR);
414                                continue;
415                            }
416                            IoStatusUpdate::WriteBackpressure => {
417                                if ready!(slf.shared.io.poll_flush(cx, true)).is_err() {
418                                    slf.shared.insert_flags(Flags::IO_ERR);
419                                }
420                                continue;
421                            }
422                        }
423                    } else {
424                        slf.shared.io.poll_dispatch(cx);
425                    }
426                    return Poll::Pending;
427                }
428                // shutdown service
429                DispatcherState::Shutdown => {
430                    return if slf.shared.service.poll_shutdown(cx).is_ready() {
431                        log::trace!(
432                            "{}: Service shutdown is completed, stop",
433                            slf.shared.io.tag()
434                        );
435
436                        Poll::Ready(if let Some(err) = slf.error.take() {
437                            Err(err)
438                        } else {
439                            Ok(())
440                        })
441                    } else {
442                        Poll::Pending
443                    };
444                }
445            }
446        }
447    }
448}
449
450impl<S, U> DispatcherInner<S, U>
451where
452    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
453    U: Decoder + Encoder + 'static,
454{
455    fn call_service(&mut self, cx: &mut Context<'_>, item: DispatchItem<U>) {
456        let mut fut = self.shared.service.call_nowait(item);
457        let inflight = self.shared.inflight.get() + 1;
458        self.shared.inflight.set(inflight);
459        if inflight == 1 {
460            self.shared.remove_flags(Flags::IDLE);
461        }
462
463        // optimize first call
464        if self.response.is_none() {
465            if let Poll::Ready(result) = Pin::new(&mut fut).poll(cx) {
466                self.shared.handle_result(result, &self.shared.io, false);
467            } else {
468                self.response = Some(fut);
469            }
470        } else {
471            let shared = self.shared.clone();
472            spawn(async move {
473                let result = fut.await;
474                shared.handle_result(result, &shared.io, true);
475            });
476        }
477    }
478
479    fn check_error(&mut self) -> PollService<U> {
480        // check for errors
481        if let Some(err) = self.shared.error.take() {
482            log::trace!(
483                "{}: Error occured, stopping dispatcher",
484                self.shared.io.tag()
485            );
486            self.st = DispatcherState::Stop;
487
488            match err {
489                DispatcherError::Encoder(err) => {
490                    PollService::Item(DispatchItem::EncoderError(err))
491                }
492                DispatcherError::Service(err) => {
493                    self.error = Some(err);
494                    PollService::Continue
495                }
496            }
497        } else {
498            PollService::Ready
499        }
500    }
501
502    fn poll_service(&mut self, cx: &mut Context<'_>) -> Poll<PollService<U>> {
503        // wait until service becomes ready
504        match self.shared.service.poll_ready(cx) {
505            Poll::Ready(Ok(_)) => Poll::Ready(self.check_error()),
506            // pause io read task
507            Poll::Pending => {
508                log::trace!(
509                    "{}: Service is not ready, register dispatcher",
510                    self.shared.io.tag()
511                );
512
513                // remove all timers
514                self.shared
515                    .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
516                self.shared.io.stop_timer();
517
518                match ready!(self.shared.io.poll_read_pause(cx)) {
519                    IoStatusUpdate::KeepAlive => {
520                        log::trace!(
521                            "{}: Keep-alive error, stopping dispatcher during pause",
522                            self.shared.io.tag()
523                        );
524                        self.st = DispatcherState::Stop;
525                        Poll::Ready(PollService::Item(DispatchItem::KeepAliveTimeout))
526                    }
527                    IoStatusUpdate::Stop => {
528                        log::trace!(
529                            "{}: Dispatcher is instructed to stop during pause",
530                            self.shared.io.tag()
531                        );
532                        self.st = DispatcherState::Stop;
533                        Poll::Ready(PollService::Continue)
534                    }
535                    IoStatusUpdate::PeerGone(err) => {
536                        log::trace!(
537                            "{}: Peer is gone during pause, stopping dispatcher: {:?}",
538                            self.shared.io.tag(),
539                            err
540                        );
541                        self.st = DispatcherState::Stop;
542                        Poll::Ready(PollService::Item(DispatchItem::Disconnect(err)))
543                    }
544                    IoStatusUpdate::WriteBackpressure => {
545                        self.st = DispatcherState::Backpressure;
546                        Poll::Ready(PollService::Item(DispatchItem::WBackPressureEnabled))
547                    }
548                }
549            }
550            // handle service readiness error
551            Poll::Ready(Err(err)) => {
552                log::trace!(
553                    "{}: Service readiness check failed, stopping",
554                    self.shared.io.tag()
555                );
556                self.st = DispatcherState::Stop;
557                self.error = Some(err);
558                self.shared.insert_flags(Flags::READY_ERR);
559                Poll::Ready(PollService::Item(DispatchItem::Disconnect(None)))
560            }
561        }
562    }
563
564    fn update_timer(&mut self, decoded: &Decoded<<U as Decoder>::Item>) {
565        // got parsed frame
566        if decoded.item.is_some() {
567            self.read_remains = 0;
568            self.shared
569                .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
570        } else if self.shared.contains(Flags::READ_TIMEOUT) {
571            // received new data but not enough for parsing complete frame
572            self.read_remains = decoded.remains as u32;
573        } else if self.read_remains == 0 && decoded.remains == 0 {
574            // no new data, start keep-alive timer
575            if self.shared.contains(Flags::KA_ENABLED)
576                && !self.shared.contains(Flags::KA_TIMEOUT)
577            {
578                log::trace!(
579                    "{}: Start keep-alive timer {:?}",
580                    self.shared.io.tag(),
581                    self.cfg.keepalive_timeout()
582                );
583                self.shared.insert_flags(Flags::KA_TIMEOUT);
584                self.shared.io.start_timer(self.cfg.keepalive_timeout());
585            }
586        } else if let Some((timeout, max, _)) = self.cfg.frame_read_rate() {
587            // we got new data but not enough to parse single frame
588            // start read timer
589            self.shared.insert_flags(Flags::READ_TIMEOUT);
590
591            self.read_remains = decoded.remains as u32;
592            self.read_remains_prev = 0;
593            self.read_max_timeout = max;
594            self.shared.io.start_timer(timeout);
595        }
596    }
597
598    fn handle_timeout(&mut self) -> Result<(), DispatchItem<U>> {
599        // check read timer
600        if self.shared.contains(Flags::READ_TIMEOUT) {
601            if let Some((timeout, max, rate)) = self.cfg.frame_read_rate() {
602                let total = (self.read_remains - self.read_remains_prev)
603                    .try_into()
604                    .unwrap_or(u16::MAX);
605
606                // read rate, start timer for next period
607                if total > rate {
608                    self.read_remains_prev = self.read_remains;
609                    self.read_remains = 0;
610
611                    if !max.is_zero() {
612                        self.read_max_timeout =
613                            Seconds(self.read_max_timeout.0.saturating_sub(timeout.0));
614                    }
615
616                    if max.is_zero() || !self.read_max_timeout.is_zero() {
617                        log::trace!(
618                            "{}: Frame read rate {:?}, extend timer",
619                            self.shared.io.tag(),
620                            total
621                        );
622                        self.shared.io.start_timer(timeout);
623                        return Ok(());
624                    }
625                    log::trace!(
626                        "{}: Max payload timeout has been reached",
627                        self.shared.io.tag()
628                    );
629                }
630                Err(DispatchItem::ReadTimeout)
631            } else {
632                Ok(())
633            }
634        } else if self.shared.contains(Flags::KA_TIMEOUT | Flags::IDLE) {
635            log::trace!(
636                "{}: Keep-alive error, stopping dispatcher",
637                self.shared.io.tag()
638            );
639            Err(DispatchItem::KeepAliveTimeout)
640        } else {
641            Ok(())
642        }
643    }
644}
645
646#[cfg(test)]
647mod tests {
648    use std::sync::{atomic::AtomicBool, atomic::Ordering::Relaxed, Arc, Mutex};
649    use std::{cell::RefCell, io};
650
651    use ntex_bytes::{Bytes, BytesMut, PoolId, PoolRef};
652    use ntex_codec::BytesCodec;
653    use ntex_service::ServiceCtx;
654    use ntex_util::{time::sleep, time::Millis};
655    use rand::Rng;
656
657    use super::*;
658    use crate::{testing::IoTest, Flags, Io, IoRef};
659
660    pub(crate) struct State(IoRef);
661
662    impl State {
663        fn flags(&self) -> Flags {
664            self.0.flags()
665        }
666
667        fn io(&self) -> &IoRef {
668            &self.0
669        }
670
671        fn close(&self) {
672            self.0 .0.insert_flags(Flags::DSP_STOP);
673            self.0 .0.dispatch_task.wake();
674        }
675
676        fn set_memory_pool(&self, pool: PoolRef) {
677            self.0 .0.pool.set(pool);
678        }
679    }
680
681    #[derive(Copy, Clone)]
682    struct BCodec(usize);
683
684    impl Encoder for BCodec {
685        type Item = Bytes;
686        type Error = io::Error;
687
688        fn encode(&self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
689            dst.extend_from_slice(&item[..]);
690            Ok(())
691        }
692    }
693
694    impl Decoder for BCodec {
695        type Item = BytesMut;
696        type Error = io::Error;
697
698        fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
699            if src.len() < self.0 {
700                Ok(None)
701            } else {
702                Ok(Some(src.split_to(self.0)))
703            }
704        }
705    }
706
707    impl<S, U> Dispatcher<S, U>
708    where
709        S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
710        U: Decoder + Encoder + 'static,
711    {
712        /// Construct new `Dispatcher` instance
713        pub(crate) fn debug(io: Io, codec: U, service: S) -> (Self, State) {
714            let cfg = DispatcherConfig::default()
715                .set_keepalive_timeout(Seconds(1))
716                .clone();
717            Self::debug_cfg(io, codec, service, cfg)
718        }
719
720        /// Construct new `Dispatcher` instance
721        pub(crate) fn debug_cfg(
722            state: Io,
723            codec: U,
724            service: S,
725            cfg: DispatcherConfig,
726        ) -> (Self, State) {
727            state.set_disconnect_timeout(cfg.disconnect_timeout());
728            state.set_tag("DBG");
729
730            let flags = if cfg.keepalive_timeout().is_zero() {
731                super::Flags::empty()
732            } else {
733                super::Flags::KA_ENABLED
734            };
735
736            let inner = State(state.get_ref());
737            state.start_timer(Seconds::ONE);
738
739            let shared = Rc::new(DispatcherShared {
740                codec,
741                io: state.into(),
742                flags: Cell::new(flags),
743                error: Cell::new(None),
744                inflight: Cell::new(0),
745                service: Pipeline::new(service).bind(),
746            });
747
748            (
749                Dispatcher {
750                    inner: DispatcherInner {
751                        cfg,
752                        shared,
753                        error: None,
754                        st: DispatcherState::Processing,
755                        response: None,
756                        read_remains: 0,
757                        read_remains_prev: 0,
758                        read_max_timeout: Seconds::ZERO,
759                    },
760                },
761                inner,
762            )
763        }
764    }
765
766    #[ntex::test]
767    async fn test_basic() {
768        let (client, server) = IoTest::create();
769        client.remote_buffer_cap(1024);
770        client.write("GET /test HTTP/1\r\n\r\n");
771
772        let (disp, _) = Dispatcher::debug(
773            Io::new(server),
774            BytesCodec,
775            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
776                sleep(Millis(50)).await;
777                if let DispatchItem::Item(msg) = msg {
778                    Ok::<_, ()>(Some(msg.freeze()))
779                } else {
780                    panic!()
781                }
782            }),
783        );
784        spawn(async move {
785            let _ = disp.await;
786        });
787
788        sleep(Millis(25)).await;
789        let buf = client.read().await.unwrap();
790        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
791
792        client.write("GET /test HTTP/1\r\n\r\n");
793        let buf = client.read().await.unwrap();
794        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
795
796        client.close().await;
797        assert!(client.is_server_dropped());
798
799        assert!(format!("{:?}", super::Flags::KA_TIMEOUT.clone()).contains("KA_TIMEOUT"));
800    }
801
802    #[ntex::test]
803    async fn test_sink() {
804        let (client, server) = IoTest::create();
805        client.remote_buffer_cap(1024);
806        client.write("GET /test HTTP/1\r\n\r\n");
807
808        let (disp, st) = Dispatcher::debug(
809            Io::new(server),
810            BytesCodec,
811            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
812                if let DispatchItem::Item(msg) = msg {
813                    Ok::<_, ()>(Some(msg.freeze()))
814                } else {
815                    panic!()
816                }
817            }),
818        );
819        spawn(async move {
820            let _ = disp.await;
821        });
822
823        let buf = client.read().await.unwrap();
824        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
825
826        assert!(st
827            .io()
828            .encode(Bytes::from_static(b"test"), &BytesCodec)
829            .is_ok());
830        let buf = client.read().await.unwrap();
831        assert_eq!(buf, Bytes::from_static(b"test"));
832
833        st.close();
834        sleep(Millis(1500)).await;
835        assert!(client.is_server_dropped());
836    }
837
838    #[ntex::test]
839    async fn test_err_in_service() {
840        let (client, server) = IoTest::create();
841        client.remote_buffer_cap(0);
842        client.write("GET /test HTTP/1\r\n\r\n");
843
844        let (disp, state) = Dispatcher::debug(
845            Io::new(server),
846            BytesCodec,
847            ntex_service::fn_service(|_: DispatchItem<BytesCodec>| async move {
848                Err::<Option<Bytes>, _>(())
849            }),
850        );
851        state
852            .io()
853            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
854            .unwrap();
855        spawn(async move {
856            let _ = disp.await;
857        });
858
859        // buffer should be flushed
860        client.remote_buffer_cap(1024);
861        let buf = client.read().await.unwrap();
862        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
863
864        // write side must be closed, dispatcher waiting for read side to close
865        sleep(Millis(250)).await;
866        assert!(client.is_closed());
867
868        // close read side
869        client.close().await;
870
871        // dispatcher is closed
872        assert!(client.is_server_dropped());
873    }
874
875    #[ntex::test]
876    async fn test_err_in_service_ready() {
877        let (client, server) = IoTest::create();
878        client.remote_buffer_cap(0);
879        client.write("GET /test HTTP/1\r\n\r\n");
880
881        let counter = Rc::new(Cell::new(0));
882
883        struct Srv(Rc<Cell<usize>>);
884
885        impl Service<DispatchItem<BytesCodec>> for Srv {
886            type Response = Option<Response<BytesCodec>>;
887            type Error = &'static str;
888
889            async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
890                self.0.set(self.0.get() + 1);
891                Err("test")
892            }
893
894            async fn call(
895                &self,
896                _: DispatchItem<BytesCodec>,
897                _: ServiceCtx<'_, Self>,
898            ) -> Result<Self::Response, Self::Error> {
899                Ok(None)
900            }
901        }
902
903        let (disp, state) =
904            Dispatcher::debug(Io::new(server), BytesCodec, Srv(counter.clone()));
905        spawn(async move {
906            let res = disp.await;
907            assert_eq!(res, Err("test"));
908        });
909
910        state
911            .io()
912            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
913            .unwrap();
914
915        // buffer should be flushed
916        client.remote_buffer_cap(1024);
917        let buf = client.read().await.unwrap();
918        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
919
920        // write side must be closed, dispatcher waiting for read side to close
921        sleep(Millis(250)).await;
922        assert!(client.is_closed());
923
924        // close read side
925        client.close().await;
926        assert!(client.is_server_dropped());
927
928        // service must be checked for readiness only once
929        assert_eq!(counter.get(), 1);
930    }
931
932    #[ntex::test]
933    async fn test_write_backpressure() {
934        let (client, server) = IoTest::create();
935        // do not allow to write to socket
936        client.remote_buffer_cap(0);
937        client.write("GET /test HTTP/1\r\n\r\n");
938
939        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
940        let data2 = data.clone();
941
942        let (disp, state) = Dispatcher::debug(
943            Io::new(server),
944            BytesCodec,
945            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
946                let data = data2.clone();
947                async move {
948                    match msg {
949                        DispatchItem::Item(_) => {
950                            data.lock().unwrap().borrow_mut().push(0);
951                            let bytes = rand::rng()
952                                .sample_iter(&rand::distr::Alphanumeric)
953                                .take(65_536)
954                                .map(char::from)
955                                .collect::<String>();
956                            return Ok::<_, ()>(Some(Bytes::from(bytes)));
957                        }
958                        DispatchItem::WBackPressureEnabled => {
959                            data.lock().unwrap().borrow_mut().push(1);
960                        }
961                        DispatchItem::WBackPressureDisabled => {
962                            data.lock().unwrap().borrow_mut().push(2);
963                        }
964                        _ => (),
965                    }
966                    Ok(None)
967                }
968            }),
969        );
970        let pool = PoolId::P10.pool_ref();
971        pool.set_read_params(8 * 1024, 1024);
972        pool.set_write_params(16 * 1024, 1024);
973        state.set_memory_pool(pool);
974
975        spawn(async move {
976            let _ = disp.await;
977        });
978
979        let buf = client.read_any();
980        assert_eq!(buf, Bytes::from_static(b""));
981        client.write("GET /test HTTP/1\r\n\r\n");
982        sleep(Millis(25)).await;
983
984        // buf must be consumed
985        assert_eq!(client.remote_buffer(|buf| buf.len()), 0);
986
987        // response message
988        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 65536);
989
990        client.remote_buffer_cap(10240);
991        sleep(Millis(50)).await;
992        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 55296);
993
994        client.remote_buffer_cap(45056);
995        sleep(Millis(50)).await;
996        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 10240);
997
998        // backpressure disabled
999        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]);
1000    }
1001
1002    #[ntex::test]
1003    async fn test_disconnect_during_read_backpressure() {
1004        let (client, server) = IoTest::create();
1005        client.remote_buffer_cap(0);
1006
1007        let (disp, state) = Dispatcher::debug(
1008            Io::new(server),
1009            BytesCodec,
1010            ntex_util::services::inflight::InFlightService::new(
1011                1,
1012                ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| async move {
1013                    if let DispatchItem::Item(_) = msg {
1014                        sleep(Millis(500)).await;
1015                        Ok::<_, ()>(None)
1016                    } else {
1017                        Ok(None)
1018                    }
1019                }),
1020            ),
1021        );
1022        disp.inner.cfg.set_keepalive_timeout(Seconds::ZERO);
1023        let pool = PoolId::P10.pool_ref();
1024        pool.set_read_params(1024, 512);
1025        state.set_memory_pool(pool);
1026
1027        let (tx, rx) = ntex::channel::oneshot::channel();
1028        ntex::rt::spawn(async move {
1029            let _ = disp.await;
1030            let _ = tx.send(());
1031        });
1032
1033        let bytes = rand::rng()
1034            .sample_iter(&rand::distr::Alphanumeric)
1035            .take(1024)
1036            .map(char::from)
1037            .collect::<String>();
1038        client.write(bytes.clone());
1039        sleep(Millis(25)).await;
1040        client.write(bytes);
1041        sleep(Millis(25)).await;
1042
1043        // close read side
1044        state.close();
1045        let _ = rx.recv().await;
1046    }
1047
1048    #[ntex::test]
1049    async fn test_keepalive() {
1050        let (client, server) = IoTest::create();
1051        client.remote_buffer_cap(1024);
1052        client.write("GET /test HTTP/1\r\n\r\n");
1053
1054        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1055        let data2 = data.clone();
1056
1057        let cfg = DispatcherConfig::default()
1058            .set_disconnect_timeout(Seconds(1))
1059            .set_keepalive_timeout(Seconds(1))
1060            .clone();
1061
1062        let (disp, state) = Dispatcher::debug_cfg(
1063            Io::new(server),
1064            BytesCodec,
1065            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
1066                let data = data2.clone();
1067                async move {
1068                    match msg {
1069                        DispatchItem::Item(bytes) => {
1070                            data.lock().unwrap().borrow_mut().push(0);
1071                            return Ok::<_, ()>(Some(bytes.freeze()));
1072                        }
1073                        DispatchItem::KeepAliveTimeout => {
1074                            data.lock().unwrap().borrow_mut().push(1);
1075                        }
1076                        _ => (),
1077                    }
1078                    Ok(None)
1079                }
1080            }),
1081            cfg,
1082        );
1083        spawn(async move {
1084            let _ = disp.await;
1085        });
1086
1087        let buf = client.read().await.unwrap();
1088        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
1089        sleep(Millis(2000)).await;
1090
1091        // write side must be closed, dispatcher should fail with keep-alive
1092        let flags = state.flags();
1093        assert!(flags.contains(Flags::IO_STOPPING));
1094        assert!(client.is_closed());
1095        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
1096    }
1097
1098    #[ntex::test]
1099    async fn test_keepalive2() {
1100        let (client, server) = IoTest::create();
1101        client.remote_buffer_cap(1024);
1102
1103        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1104        let data2 = data.clone();
1105
1106        let cfg = DispatcherConfig::default()
1107            .set_keepalive_timeout(Seconds(1))
1108            .set_frame_read_rate(Seconds(1), Seconds(2), 2)
1109            .clone();
1110
1111        let (disp, state) = Dispatcher::debug_cfg(
1112            Io::new(server),
1113            BCodec(8),
1114            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
1115                let data = data2.clone();
1116                async move {
1117                    match msg {
1118                        DispatchItem::Item(bytes) => {
1119                            data.lock().unwrap().borrow_mut().push(0);
1120                            return Ok::<_, ()>(Some(bytes.freeze()));
1121                        }
1122                        DispatchItem::KeepAliveTimeout => {
1123                            data.lock().unwrap().borrow_mut().push(1);
1124                        }
1125                        _ => (),
1126                    }
1127                    Ok(None)
1128                }
1129            }),
1130            cfg,
1131        );
1132        spawn(async move {
1133            let _ = disp.await;
1134        });
1135
1136        client.write("12345678");
1137        let buf = client.read().await.unwrap();
1138        assert_eq!(buf, Bytes::from_static(b"12345678"));
1139        sleep(Millis(2000)).await;
1140
1141        // write side must be closed, dispatcher should fail with keep-alive
1142        let flags = state.flags();
1143        assert!(flags.contains(Flags::IO_STOPPING));
1144        assert!(client.is_closed());
1145        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
1146    }
1147
1148    /// Update keep-alive timer after receiving frame
1149    #[ntex::test]
1150    async fn test_keepalive3() {
1151        let (client, server) = IoTest::create();
1152        client.remote_buffer_cap(1024);
1153
1154        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1155        let data2 = data.clone();
1156
1157        let cfg = DispatcherConfig::default()
1158            .set_keepalive_timeout(Seconds(2))
1159            .set_frame_read_rate(Seconds(1), Seconds(2), 2)
1160            .clone();
1161
1162        let (disp, _) = Dispatcher::debug_cfg(
1163            Io::new(server),
1164            BCodec(1),
1165            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
1166                let data = data2.clone();
1167                async move {
1168                    match msg {
1169                        DispatchItem::Item(bytes) => {
1170                            data.lock().unwrap().borrow_mut().push(0);
1171                            return Ok::<_, ()>(Some(bytes.freeze()));
1172                        }
1173                        DispatchItem::KeepAliveTimeout => {
1174                            data.lock().unwrap().borrow_mut().push(1);
1175                        }
1176                        _ => (),
1177                    }
1178                    Ok(None)
1179                }
1180            }),
1181            cfg,
1182        );
1183        spawn(async move {
1184            let _ = disp.await;
1185        });
1186
1187        client.write("1");
1188        let buf = client.read().await.unwrap();
1189        assert_eq!(buf, Bytes::from_static(b"1"));
1190        sleep(Millis(750)).await;
1191
1192        client.write("2");
1193        let buf = client.read().await.unwrap();
1194        assert_eq!(buf, Bytes::from_static(b"2"));
1195
1196        sleep(Millis(750)).await;
1197        client.write("3");
1198        let buf = client.read().await.unwrap();
1199        assert_eq!(buf, Bytes::from_static(b"3"));
1200
1201        sleep(Millis(750)).await;
1202        assert!(!client.is_closed());
1203        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 0, 0]);
1204    }
1205
1206    #[ntex::test]
1207    async fn test_read_timeout() {
1208        let (client, server) = IoTest::create();
1209        client.remote_buffer_cap(1024);
1210
1211        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1212        let data2 = data.clone();
1213
1214        let (disp, state) = Dispatcher::debug(
1215            Io::new(server),
1216            BCodec(8),
1217            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
1218                let data = data2.clone();
1219                async move {
1220                    match msg {
1221                        DispatchItem::Item(bytes) => {
1222                            data.lock().unwrap().borrow_mut().push(0);
1223                            return Ok::<_, ()>(Some(bytes.freeze()));
1224                        }
1225                        DispatchItem::ReadTimeout => {
1226                            data.lock().unwrap().borrow_mut().push(1);
1227                        }
1228                        _ => (),
1229                    }
1230                    Ok(None)
1231                }
1232            }),
1233        );
1234        spawn(async move {
1235            disp.inner
1236                .cfg
1237                .set_keepalive_timeout(Seconds::ZERO)
1238                .set_frame_read_rate(Seconds(1), Seconds(2), 2);
1239            let _ = disp.await;
1240        });
1241
1242        client.write("12345678");
1243        let buf = client.read().await.unwrap();
1244        assert_eq!(buf, Bytes::from_static(b"12345678"));
1245
1246        client.write("1");
1247        sleep(Millis(1000)).await;
1248        assert!(!state.flags().contains(Flags::IO_STOPPING));
1249        client.write("23");
1250        sleep(Millis(1000)).await;
1251        assert!(!state.flags().contains(Flags::IO_STOPPING));
1252        client.write("4");
1253        sleep(Millis(2000)).await;
1254
1255        // write side must be closed, dispatcher should fail with keep-alive
1256        assert!(state.flags().contains(Flags::IO_STOPPING));
1257        assert!(client.is_closed());
1258        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
1259    }
1260
1261    #[ntex::test]
1262    async fn test_idle_timeout() {
1263        let (client, server) = IoTest::create();
1264        client.remote_buffer_cap(1024);
1265
1266        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1267        let data2 = data.clone();
1268
1269        let io = Io::new(server);
1270        let ioref = io.get_ref();
1271
1272        let (disp, state) = Dispatcher::debug_cfg(
1273            io,
1274            BCodec(1),
1275            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
1276                let ioref = ioref.clone();
1277                ntex::rt::spawn(async move {
1278                    sleep(Millis(500)).await;
1279                    ioref.notify_timeout();
1280                });
1281                let data = data2.clone();
1282                async move {
1283                    match msg {
1284                        DispatchItem::Item(bytes) => {
1285                            data.lock().unwrap().borrow_mut().push(0);
1286                            return Ok::<_, ()>(Some(bytes.freeze()));
1287                        }
1288                        DispatchItem::ReadTimeout => {
1289                            data.lock().unwrap().borrow_mut().push(1);
1290                        }
1291                        _ => (),
1292                    }
1293                    Ok(None)
1294                }
1295            }),
1296            DispatcherConfig::default()
1297                .set_keepalive_timeout(Seconds::ZERO)
1298                .clone(),
1299        );
1300        spawn(async move {
1301            let _ = disp.await;
1302        });
1303
1304        client.write("1");
1305        let buf = client.read().await.unwrap();
1306        assert_eq!(buf, Bytes::from_static(b"1"));
1307
1308        sleep(Millis(1000)).await;
1309        assert!(state.flags().contains(Flags::IO_STOPPING));
1310        assert!(client.is_closed());
1311    }
1312
1313    #[ntex::test]
1314    async fn test_unhandled_data() {
1315        let handled = Arc::new(AtomicBool::new(false));
1316        let handled2 = handled.clone();
1317
1318        let (client, server) = IoTest::create();
1319        client.remote_buffer_cap(1024);
1320        client.write("GET /test HTTP/1\r\n\r\n");
1321
1322        let (disp, _) = Dispatcher::debug(
1323            Io::new(server),
1324            BytesCodec,
1325            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
1326                handled2.store(true, Relaxed);
1327                async move {
1328                    sleep(Millis(50)).await;
1329                    if let DispatchItem::Item(msg) = msg {
1330                        Ok::<_, ()>(Some(msg.freeze()))
1331                    } else if let DispatchItem::Disconnect(_) = msg {
1332                        Ok::<_, ()>(None)
1333                    } else {
1334                        panic!()
1335                    }
1336                }
1337            }),
1338        );
1339        client.close().await;
1340        spawn(async move {
1341            let _ = disp.await;
1342        });
1343        sleep(Millis(50)).await;
1344
1345        assert!(handled.load(Relaxed));
1346    }
1347}