ntex_io/
dispatcher.rs

1//! Framed transport dispatcher
2#![allow(clippy::let_underscore_future)]
3use std::task::{ready, Context, Poll};
4use std::{cell::Cell, future::Future, pin::Pin, rc::Rc};
5
6use ntex_codec::{Decoder, Encoder};
7use ntex_service::{IntoService, Pipeline, PipelineBinding, PipelineCall, Service};
8use ntex_util::{future::Either, spawn, time::Seconds};
9
10use crate::{Decoded, DispatchItem, IoBoxed, IoStatusUpdate, RecvError};
11
12type Response<U> = <U as Encoder>::Item;
13
#[derive(Clone, Debug)]
/// Shared dispatcher configuration.
///
/// The settings live behind an `Rc`, so cloning a `DispatcherConfig` yields
/// another handle to the same values: a change made through one handle is
/// observed through every clone.
pub struct DispatcherConfig(Rc<DispatcherConfigInner>);
17
/// Storage for the values exposed by `DispatcherConfig`.
///
/// Every field is a `Cell`, which lets the setters work through `&self`.
#[derive(Debug)]
struct DispatcherConfigInner {
    // keep-alive timeout; a zero value means keep-alive is not enabled
    keepalive_timeout: Cell<Seconds>,
    // timeout handed to the io object via `set_disconnect_timeout`
    disconnect_timeout: Cell<Seconds>,
    // `true` once `set_frame_read_rate` was called with a non-zero timeout
    frame_read_enabled: Cell<bool>,
    // amount of data that must arrive within one timeout period
    frame_read_rate: Cell<u16>,
    // length of a single frame-read timer period
    frame_read_timeout: Cell<Seconds>,
    // upper bound for the accumulated frame-read timeout; zero means no bound
    frame_read_max_timeout: Cell<Seconds>,
}
27
28impl Default for DispatcherConfig {
29    fn default() -> Self {
30        DispatcherConfig(Rc::new(DispatcherConfigInner {
31            keepalive_timeout: Cell::new(Seconds(30)),
32            disconnect_timeout: Cell::new(Seconds(1)),
33            frame_read_rate: Cell::new(0),
34            frame_read_enabled: Cell::new(false),
35            frame_read_timeout: Cell::new(Seconds::ZERO),
36            frame_read_max_timeout: Cell::new(Seconds::ZERO),
37        }))
38    }
39}
40
41impl DispatcherConfig {
42    #[inline]
43    /// Get keep-alive timeout
44    pub fn keepalive_timeout(&self) -> Seconds {
45        self.0.keepalive_timeout.get()
46    }
47
48    #[inline]
49    /// Get disconnect timeout
50    pub fn disconnect_timeout(&self) -> Seconds {
51        self.0.disconnect_timeout.get()
52    }
53
54    #[inline]
55    /// Get frame read rate
56    pub fn frame_read_rate(&self) -> Option<(Seconds, Seconds, u16)> {
57        if self.0.frame_read_enabled.get() {
58            Some((
59                self.0.frame_read_timeout.get(),
60                self.0.frame_read_max_timeout.get(),
61                self.0.frame_read_rate.get(),
62            ))
63        } else {
64            None
65        }
66    }
67
68    /// Set keep-alive timeout in seconds.
69    ///
70    /// To disable timeout set value to 0.
71    ///
72    /// By default keep-alive timeout is set to 30 seconds.
73    pub fn set_keepalive_timeout(&self, timeout: Seconds) -> &Self {
74        self.0.keepalive_timeout.set(timeout);
75        self
76    }
77
78    /// Set connection disconnect timeout.
79    ///
80    /// Defines a timeout for disconnect connection. If a disconnect procedure does not complete
81    /// within this time, the connection get dropped.
82    ///
83    /// To disable timeout set value to 0.
84    ///
85    /// By default disconnect timeout is set to 1 seconds.
86    pub fn set_disconnect_timeout(&self, timeout: Seconds) -> &Self {
87        self.0.disconnect_timeout.set(timeout);
88        self
89    }
90
91    /// Set read rate parameters for single frame.
92    ///
93    /// Set read timeout, max timeout and rate for reading payload. If the client
94    /// sends `rate` amount of data within `timeout` period of time, extend timeout by `timeout` seconds.
95    /// But no more than `max_timeout` timeout.
96    ///
97    /// By default frame read rate is disabled.
98    pub fn set_frame_read_rate(
99        &self,
100        timeout: Seconds,
101        max_timeout: Seconds,
102        rate: u16,
103    ) -> &Self {
104        self.0.frame_read_enabled.set(!timeout.is_zero());
105        self.0.frame_read_timeout.set(timeout);
106        self.0.frame_read_max_timeout.set(max_timeout);
107        self.0.frame_read_rate.set(rate);
108        self
109    }
110}
111
pin_project_lite::pin_project! {
    /// Dispatcher is a future that reads frames from a byte stream
    /// and passes them to the service.
    pub struct Dispatcher<S, U>
    where
        S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
        U: Encoder,
        U: Decoder,
        U: 'static,
    {
        inner: DispatcherInner<S, U>,
    }
}
125
bitflags::bitflags! {
    // Internal bookkeeping bits of the dispatcher.
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    struct Flags: u8  {
        // the service readiness check returned an error
        const READY_ERR     = 0b0000001;
        // io reported a terminal status (or a flush failed) while stopping
        const IO_ERR        = 0b0000010;
        // a non-zero keep-alive timeout is configured
        const KA_ENABLED    = 0b0000100;
        // the keep-alive timer has been started
        const KA_TIMEOUT    = 0b0001000;
        // the frame read timer has been started
        const READ_TIMEOUT  = 0b0010000;
    }
}
136
/// State owned by the `Dispatcher` future itself.
struct DispatcherInner<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder + 'static,
{
    // current state of the dispatcher state machine
    st: DispatcherState,
    // service error returned as the future's output once shutdown completes
    error: Option<S::Error>,
    // bookkeeping bits, see `Flags`
    flags: Flags,
    // state that is also handed to spawned response tasks
    shared: Rc<DispatcherShared<S, U>>,
    // service call that is polled inline by the dispatcher, if any
    response: Option<PipelineCall<S, DispatchItem<U>>>,
    // configuration this dispatcher was created with
    cfg: DispatcherConfig,
    // `remains` of the most recent decode attempt while the read timer runs
    read_remains: u32,
    // value of `read_remains` captured when the read timer was last extended
    read_remains_prev: u32,
    // what is left of the max frame read timeout budget
    read_max_timeout: Seconds,
}
152
/// State shared between the dispatcher and the tasks it spawns to await
/// service responses.
pub(crate) struct DispatcherShared<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder,
{
    // io object frames are read from and responses are encoded into
    io: IoBoxed,
    // codec used for both decoding and encoding
    codec: U,
    // bound service pipeline that handles dispatch items
    service: PipelineBinding<S, DispatchItem<U>>,
    // error recorded by `handle_result`, picked up later by the dispatcher
    error: Cell<Option<DispatcherError<S::Error, <U as Encoder>::Error>>>,
    // number of service calls started whose result was not handled yet
    inflight: Cell<u32>,
}
164
/// States of the dispatcher state machine driven by `Dispatcher::poll`.
#[derive(Copy, Clone, Debug)]
enum DispatcherState {
    // read and decode frames, pass them to the service
    Processing,
    // write back-pressure is active, wait for the write buffer to flush
    Backpressure,
    // stop requested: wait for in-flight service calls, then shut down io
    Stop,
    // io is shut down, wait for the service shutdown to complete
    Shutdown,
}
172
/// Error recorded while handling a service response.
#[derive(Debug)]
enum DispatcherError<S, U> {
    // encoding the service response failed
    Encoder(U),
    // the service call itself returned an error
    Service(S),
}
178
/// Outcome of checking the service before the dispatcher reads more input.
enum PollService<U: Encoder + Decoder> {
    // an item that must be delivered to the service right away
    Item(DispatchItem<U>),
    // dispatcher state changed, restart the state machine loop
    Continue,
    // service is ready and no error is pending
    Ready,
}
184
185impl<S, U> From<Either<S, U>> for DispatcherError<S, U> {
186    fn from(err: Either<S, U>) -> Self {
187        match err {
188            Either::Left(err) => DispatcherError::Service(err),
189            Either::Right(err) => DispatcherError::Encoder(err),
190        }
191    }
192}
193
194impl<S, U> Dispatcher<S, U>
195where
196    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
197    U: Decoder + Encoder + 'static,
198{
199    /// Construct new `Dispatcher` instance.
200    pub fn new<Io, F>(
201        io: Io,
202        codec: U,
203        service: F,
204        cfg: &DispatcherConfig,
205    ) -> Dispatcher<S, U>
206    where
207        IoBoxed: From<Io>,
208        F: IntoService<S, DispatchItem<U>>,
209    {
210        let io = IoBoxed::from(io);
211        io.set_disconnect_timeout(cfg.disconnect_timeout());
212
213        let flags = if cfg.keepalive_timeout().is_zero() {
214            Flags::empty()
215        } else {
216            Flags::KA_ENABLED
217        };
218
219        let shared = Rc::new(DispatcherShared {
220            io,
221            codec,
222            error: Cell::new(None),
223            inflight: Cell::new(0),
224            service: Pipeline::new(service.into_service()).bind(),
225        });
226
227        Dispatcher {
228            inner: DispatcherInner {
229                shared,
230                flags,
231                cfg: cfg.clone(),
232                response: None,
233                error: None,
234                read_remains: 0,
235                read_remains_prev: 0,
236                read_max_timeout: Seconds::ZERO,
237                st: DispatcherState::Processing,
238            },
239        }
240    }
241}
242
243impl<S, U> DispatcherShared<S, U>
244where
245    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
246    U: Encoder + Decoder,
247{
248    fn handle_result(&self, item: Result<S::Response, S::Error>, io: &IoBoxed, wake: bool) {
249        match item {
250            Ok(Some(val)) => {
251                if let Err(err) = io.encode(val, &self.codec) {
252                    self.error.set(Some(DispatcherError::Encoder(err)))
253                }
254            }
255            Err(err) => self.error.set(Some(DispatcherError::Service(err))),
256            Ok(None) => (),
257        }
258        self.inflight.set(self.inflight.get() - 1);
259        if wake {
260            io.wake();
261        }
262    }
263}
264
impl<S, U> Future for Dispatcher<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
    U: Decoder + Encoder + 'static,
{
    /// `Ok(())` when no service error was recorded, otherwise that error.
    type Output = Result<(), S::Error>;

    /// Drive the dispatcher state machine.
    ///
    /// First polls the inline service response (if any), then loops over the
    /// current `DispatcherState` until a step has to wait or the service
    /// shutdown completes.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.as_mut().project();
        let slf = &mut this.inner;

        // handle service response future
        if let Some(fut) = slf.response.as_mut() {
            if let Poll::Ready(item) = Pin::new(fut).poll(cx) {
                slf.shared.handle_result(item, &slf.shared.io, false);
                slf.response = None;
            }
        }

        loop {
            match slf.st {
                DispatcherState::Processing => {
                    // pick the next item for the service: either a decoded frame
                    // or a control item describing an io/service condition
                    let item = match ready!(slf.poll_service(cx)) {
                        PollService::Ready => {
                            // decode incoming bytes if buffer is ready
                            match slf.shared.io.poll_recv_decode(&slf.shared.codec, cx) {
                                Ok(decoded) => {
                                    slf.update_timer(&decoded);
                                    if let Some(el) = decoded.item {
                                        DispatchItem::Item(el)
                                    } else {
                                        // no complete frame available yet
                                        return Poll::Pending;
                                    }
                                }
                                Err(RecvError::KeepAlive) => {
                                    // timer fired; `handle_timeout` decides whether it
                                    // is a real timeout or the timer was extended
                                    if let Err(err) = slf.handle_timeout() {
                                        slf.st = DispatcherState::Stop;
                                        err
                                    } else {
                                        continue;
                                    }
                                }
                                Err(RecvError::Stop) => {
                                    log::trace!(
                                        "{}: Dispatcher is instructed to stop",
                                        slf.shared.io.tag()
                                    );
                                    slf.st = DispatcherState::Stop;
                                    continue;
                                }
                                Err(RecvError::WriteBackpressure) => {
                                    // instruct write task to notify dispatcher when data is flushed
                                    slf.st = DispatcherState::Backpressure;
                                    DispatchItem::WBackPressureEnabled
                                }
                                Err(RecvError::Decoder(err)) => {
                                    log::trace!(
                                        "{}: Decoder error, stopping dispatcher: {:?}",
                                        slf.shared.io.tag(),
                                        err
                                    );
                                    slf.st = DispatcherState::Stop;
                                    DispatchItem::DecoderError(err)
                                }
                                Err(RecvError::PeerGone(err)) => {
                                    log::trace!(
                                        "{}: Peer is gone, stopping dispatcher: {:?}",
                                        slf.shared.io.tag(),
                                        err
                                    );
                                    slf.st = DispatcherState::Stop;
                                    DispatchItem::Disconnect(err)
                                }
                            }
                        }
                        PollService::Item(item) => item,
                        PollService::Continue => continue,
                    };

                    slf.call_service(cx, item);
                }
                // handle write back-pressure
                DispatcherState::Backpressure => {
                    match ready!(slf.poll_service(cx)) {
                        PollService::Ready => (),
                        PollService::Item(item) => slf.call_service(cx, item),
                        PollService::Continue => continue,
                    };

                    // wait for the flush; its result selects the next state and
                    // the item reported to the service
                    let item = if let Err(err) = ready!(slf.shared.io.poll_flush(cx, false))
                    {
                        slf.st = DispatcherState::Stop;
                        DispatchItem::Disconnect(Some(err))
                    } else {
                        slf.st = DispatcherState::Processing;
                        DispatchItem::WBackPressureDisabled
                    };
                    slf.call_service(cx, item);
                }
                // drain service responses and shutdown io
                DispatcherState::Stop => {
                    slf.shared.io.stop_timer();

                    // service may rely on poll_ready for response results;
                    // stop polling readiness once it has reported an error
                    if !slf.flags.contains(Flags::READY_ERR) {
                        if let Poll::Ready(res) = slf.shared.service.poll_ready(cx) {
                            if res.is_err() {
                                slf.flags.insert(Flags::READY_ERR);
                            }
                        }
                    }

                    if slf.shared.inflight.get() == 0 {
                        // nothing in flight anymore, io can be shut down
                        if slf.shared.io.poll_shutdown(cx).is_ready() {
                            slf.st = DispatcherState::Shutdown;
                            continue;
                        }
                    } else if !slf.flags.contains(Flags::IO_ERR) {
                        // calls are still in flight: keep watching io status
                        match ready!(slf.shared.io.poll_status_update(cx)) {
                            IoStatusUpdate::PeerGone(_)
                            | IoStatusUpdate::Stop
                            | IoStatusUpdate::KeepAlive => {
                                slf.flags.insert(Flags::IO_ERR);
                                continue;
                            }
                            IoStatusUpdate::WriteBackpressure => {
                                if ready!(slf.shared.io.poll_flush(cx, true)).is_err() {
                                    slf.flags.insert(Flags::IO_ERR);
                                }
                                continue;
                            }
                        }
                    } else {
                        // io already failed; presumably this only registers the
                        // dispatcher task for wake-ups — TODO confirm in `poll_dispatch`
                        slf.shared.io.poll_dispatch(cx);
                    }
                    return Poll::Pending;
                }
                // shutdown service
                DispatcherState::Shutdown => {
                    return if slf.shared.service.poll_shutdown(cx).is_ready() {
                        log::trace!(
                            "{}: Service shutdown is completed, stop",
                            slf.shared.io.tag()
                        );

                        // report the stored service error, if there is one
                        Poll::Ready(if let Some(err) = slf.error.take() {
                            Err(err)
                        } else {
                            Ok(())
                        })
                    } else {
                        Poll::Pending
                    };
                }
            }
        }
    }
}
423
424impl<S, U> DispatcherInner<S, U>
425where
426    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
427    U: Decoder + Encoder + 'static,
428{
429    fn call_service(&mut self, cx: &mut Context<'_>, item: DispatchItem<U>) {
430        let mut fut = self.shared.service.call_nowait(item);
431        self.shared.inflight.set(self.shared.inflight.get() + 1);
432
433        // optimize first call
434        if self.response.is_none() {
435            if let Poll::Ready(result) = Pin::new(&mut fut).poll(cx) {
436                self.shared.handle_result(result, &self.shared.io, false);
437            } else {
438                self.response = Some(fut);
439            }
440        } else {
441            let shared = self.shared.clone();
442            spawn(async move {
443                let result = fut.await;
444                shared.handle_result(result, &shared.io, true);
445            });
446        }
447    }
448
449    fn check_error(&mut self) -> PollService<U> {
450        // check for errors
451        if let Some(err) = self.shared.error.take() {
452            log::trace!(
453                "{}: Error occured, stopping dispatcher",
454                self.shared.io.tag()
455            );
456            self.st = DispatcherState::Stop;
457
458            match err {
459                DispatcherError::Encoder(err) => {
460                    PollService::Item(DispatchItem::EncoderError(err))
461                }
462                DispatcherError::Service(err) => {
463                    self.error = Some(err);
464                    PollService::Continue
465                }
466            }
467        } else {
468            PollService::Ready
469        }
470    }
471
    /// Check service readiness before more input is processed.
    ///
    /// - Service ready: returns the result of `check_error`.
    /// - Service not ready: clears the timer flags, stops the io timer and
    ///   waits on `poll_read_pause`; an io status update reported there is
    ///   translated into a state change and/or an item for the service.
    /// - Readiness error: stores the error, switches to `Stop` and yields
    ///   `DispatchItem::Disconnect(None)`.
    fn poll_service(&mut self, cx: &mut Context<'_>) -> Poll<PollService<U>> {
        // wait until service becomes ready
        match self.shared.service.poll_ready(cx) {
            Poll::Ready(Ok(_)) => Poll::Ready(self.check_error()),
            // pause io read task
            Poll::Pending => {
                log::trace!(
                    "{}: Service is not ready, register dispatcher",
                    self.shared.io.tag()
                );

                // remove all timers
                self.flags.remove(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT);
                self.shared.io.stop_timer();

                // `ready!` returns `Poll::Pending` from this function until io
                // reports a status update
                match ready!(self.shared.io.poll_read_pause(cx)) {
                    IoStatusUpdate::KeepAlive => {
                        log::trace!(
                            "{}: Keep-alive error, stopping dispatcher during pause",
                            self.shared.io.tag()
                        );
                        self.st = DispatcherState::Stop;
                        Poll::Ready(PollService::Item(DispatchItem::KeepAliveTimeout))
                    }
                    IoStatusUpdate::Stop => {
                        log::trace!(
                            "{}: Dispatcher is instructed to stop during pause",
                            self.shared.io.tag()
                        );
                        self.st = DispatcherState::Stop;
                        Poll::Ready(PollService::Continue)
                    }
                    IoStatusUpdate::PeerGone(err) => {
                        log::trace!(
                            "{}: Peer is gone during pause, stopping dispatcher: {:?}",
                            self.shared.io.tag(),
                            err
                        );
                        self.st = DispatcherState::Stop;
                        Poll::Ready(PollService::Item(DispatchItem::Disconnect(err)))
                    }
                    IoStatusUpdate::WriteBackpressure => {
                        self.st = DispatcherState::Backpressure;
                        Poll::Ready(PollService::Item(DispatchItem::WBackPressureEnabled))
                    }
                }
            }
            // handle service readiness error
            Poll::Ready(Err(err)) => {
                log::trace!(
                    "{}: Service readiness check failed, stopping",
                    self.shared.io.tag()
                );
                self.st = DispatcherState::Stop;
                self.error = Some(err);
                // remember the failure so the `Stop` state does not poll readiness again
                self.flags.insert(Flags::READY_ERR);
                Poll::Ready(PollService::Item(DispatchItem::Disconnect(None)))
            }
        }
    }
532
    /// Update timer state after a decode attempt.
    ///
    /// The branches are checked in order:
    /// 1. a frame was decoded: reset the read counter and clear both timer flags;
    /// 2. the read timer is already active: only record `decoded.remains`;
    /// 3. nothing is buffered: start the keep-alive timer if it is enabled
    ///    and not yet started;
    /// 4. otherwise, when a frame read rate is configured: start the read timer.
    fn update_timer(&mut self, decoded: &Decoded<<U as Decoder>::Item>) {
        // got parsed frame
        if decoded.item.is_some() {
            self.read_remains = 0;
            self.flags.remove(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT);
        } else if self.flags.contains(Flags::READ_TIMEOUT) {
            // received new data but not enough for parsing complete frame
            self.read_remains = decoded.remains as u32;
        } else if self.read_remains == 0 && decoded.remains == 0 {
            // no new data, start keep-alive timer
            if self.flags.contains(Flags::KA_ENABLED)
                && !self.flags.contains(Flags::KA_TIMEOUT)
            {
                log::trace!(
                    "{}: Start keep-alive timer {:?}",
                    self.shared.io.tag(),
                    self.cfg.keepalive_timeout()
                );
                self.flags.insert(Flags::KA_TIMEOUT);
                self.shared.io.start_timer(self.cfg.keepalive_timeout());
            }
        } else if let Some((timeout, max, _)) = self.cfg.frame_read_rate() {
            // we got new data but not enough to parse single frame
            // start read timer
            self.flags.insert(Flags::READ_TIMEOUT);

            // snapshot the counters the rate check in `handle_timeout` works with
            self.read_remains = decoded.remains as u32;
            self.read_remains_prev = 0;
            self.read_max_timeout = max;
            self.shared.io.start_timer(timeout);
        }
    }
565
566    fn handle_timeout(&mut self) -> Result<(), DispatchItem<U>> {
567        // check read timer
568        if self.flags.contains(Flags::READ_TIMEOUT) {
569            if let Some((timeout, max, rate)) = self.cfg.frame_read_rate() {
570                let total = (self.read_remains - self.read_remains_prev)
571                    .try_into()
572                    .unwrap_or(u16::MAX);
573
574                // read rate, start timer for next period
575                if total > rate {
576                    self.read_remains_prev = self.read_remains;
577                    self.read_remains = 0;
578
579                    if !max.is_zero() {
580                        self.read_max_timeout =
581                            Seconds(self.read_max_timeout.0.saturating_sub(timeout.0));
582                    }
583
584                    if max.is_zero() || !self.read_max_timeout.is_zero() {
585                        log::trace!(
586                            "{}: Frame read rate {:?}, extend timer",
587                            self.shared.io.tag(),
588                            total
589                        );
590                        self.shared.io.start_timer(timeout);
591                        return Ok(());
592                    }
593                    log::trace!(
594                        "{}: Max payload timeout has been reached",
595                        self.shared.io.tag()
596                    );
597                }
598                Err(DispatchItem::ReadTimeout)
599            } else {
600                Ok(())
601            }
602        } else if self.flags.contains(Flags::KA_TIMEOUT) {
603            log::trace!(
604                "{}: Keep-alive error, stopping dispatcher",
605                self.shared.io.tag()
606            );
607            Err(DispatchItem::KeepAliveTimeout)
608        } else {
609            Ok(())
610        }
611    }
612}
613
614#[cfg(test)]
615mod tests {
616    use std::sync::{atomic::AtomicBool, atomic::Ordering::Relaxed, Arc, Mutex};
617    use std::{cell::RefCell, io};
618
619    use ntex_bytes::{Bytes, BytesMut, PoolId, PoolRef};
620    use ntex_codec::BytesCodec;
621    use ntex_service::ServiceCtx;
622    use ntex_util::{time::sleep, time::Millis};
623    use rand::Rng;
624
625    use super::*;
626    use crate::{testing::IoTest, Flags, Io, IoRef, IoStream};
627
628    pub(crate) struct State(IoRef);
629
630    impl State {
631        fn flags(&self) -> Flags {
632            self.0.flags()
633        }
634
635        fn io(&self) -> &IoRef {
636            &self.0
637        }
638
639        fn close(&self) {
640            self.0 .0.insert_flags(Flags::DSP_STOP);
641            self.0 .0.dispatch_task.wake();
642        }
643
644        fn set_memory_pool(&self, pool: PoolRef) {
645            self.0 .0.pool.set(pool);
646        }
647    }
648
649    #[derive(Copy, Clone)]
650    struct BCodec(usize);
651
652    impl Encoder for BCodec {
653        type Item = Bytes;
654        type Error = io::Error;
655
656        fn encode(&self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
657            dst.extend_from_slice(&item[..]);
658            Ok(())
659        }
660    }
661
662    impl Decoder for BCodec {
663        type Item = BytesMut;
664        type Error = io::Error;
665
666        fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
667            if src.len() < self.0 {
668                Ok(None)
669            } else {
670                Ok(Some(src.split_to(self.0)))
671            }
672        }
673    }
674
675    impl<S, U> Dispatcher<S, U>
676    where
677        S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
678        U: Decoder + Encoder + 'static,
679    {
680        /// Construct new `Dispatcher` instance
681        pub(crate) fn debug<T: IoStream>(io: T, codec: U, service: S) -> (Self, State) {
682            let cfg = DispatcherConfig::default()
683                .set_keepalive_timeout(Seconds(1))
684                .clone();
685            Self::debug_cfg(io, codec, service, cfg)
686        }
687
688        /// Construct new `Dispatcher` instance
689        pub(crate) fn debug_cfg<T: IoStream>(
690            io: T,
691            codec: U,
692            service: S,
693            cfg: DispatcherConfig,
694        ) -> (Self, State) {
695            let state = Io::new(io);
696            state.set_disconnect_timeout(cfg.disconnect_timeout());
697            state.set_tag("DBG");
698
699            let flags = if cfg.keepalive_timeout().is_zero() {
700                super::Flags::empty()
701            } else {
702                super::Flags::KA_ENABLED
703            };
704
705            let inner = State(state.get_ref());
706            state.start_timer(Seconds::ONE);
707
708            let shared = Rc::new(DispatcherShared {
709                codec,
710                io: state.into(),
711                error: Cell::new(None),
712                inflight: Cell::new(0),
713                service: Pipeline::new(service).bind(),
714            });
715
716            (
717                Dispatcher {
718                    inner: DispatcherInner {
719                        error: None,
720                        st: DispatcherState::Processing,
721                        response: None,
722                        read_remains: 0,
723                        read_remains_prev: 0,
724                        read_max_timeout: Seconds::ZERO,
725                        shared,
726                        cfg,
727                        flags,
728                    },
729                },
730                inner,
731            )
732        }
733    }
734
735    #[ntex::test]
736    async fn test_basic() {
737        let (client, server) = IoTest::create();
738        client.remote_buffer_cap(1024);
739        client.write("GET /test HTTP/1\r\n\r\n");
740
741        let (disp, _) = Dispatcher::debug(
742            server,
743            BytesCodec,
744            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
745                sleep(Millis(50)).await;
746                if let DispatchItem::Item(msg) = msg {
747                    Ok::<_, ()>(Some(msg.freeze()))
748                } else {
749                    panic!()
750                }
751            }),
752        );
753        spawn(async move {
754            let _ = disp.await;
755        });
756
757        sleep(Millis(25)).await;
758        let buf = client.read().await.unwrap();
759        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
760
761        client.write("GET /test HTTP/1\r\n\r\n");
762        let buf = client.read().await.unwrap();
763        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
764
765        client.close().await;
766        assert!(client.is_server_dropped());
767
768        assert!(format!("{:?}", super::Flags::KA_TIMEOUT.clone()).contains("KA_TIMEOUT"));
769    }
770
771    #[ntex::test]
772    async fn test_sink() {
773        let (client, server) = IoTest::create();
774        client.remote_buffer_cap(1024);
775        client.write("GET /test HTTP/1\r\n\r\n");
776
777        let (disp, st) = Dispatcher::debug(
778            server,
779            BytesCodec,
780            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
781                if let DispatchItem::Item(msg) = msg {
782                    Ok::<_, ()>(Some(msg.freeze()))
783                } else {
784                    panic!()
785                }
786            }),
787        );
788        spawn(async move {
789            let _ = disp.await;
790        });
791
792        let buf = client.read().await.unwrap();
793        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
794
795        assert!(st
796            .io()
797            .encode(Bytes::from_static(b"test"), &BytesCodec)
798            .is_ok());
799        let buf = client.read().await.unwrap();
800        assert_eq!(buf, Bytes::from_static(b"test"));
801
802        st.close();
803        sleep(Millis(1500)).await;
804        assert!(client.is_server_dropped());
805    }
806
807    #[ntex::test]
808    async fn test_err_in_service() {
809        let (client, server) = IoTest::create();
810        client.remote_buffer_cap(0);
811        client.write("GET /test HTTP/1\r\n\r\n");
812
813        let (disp, state) = Dispatcher::debug(
814            server,
815            BytesCodec,
816            ntex_service::fn_service(|_: DispatchItem<BytesCodec>| async move {
817                Err::<Option<Bytes>, _>(())
818            }),
819        );
820        state
821            .io()
822            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
823            .unwrap();
824        spawn(async move {
825            let _ = disp.await;
826        });
827
828        // buffer should be flushed
829        client.remote_buffer_cap(1024);
830        let buf = client.read().await.unwrap();
831        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
832
833        // write side must be closed, dispatcher waiting for read side to close
834        sleep(Millis(250)).await;
835        assert!(client.is_closed());
836
837        // close read side
838        client.close().await;
839
840        // dispatcher is closed
841        assert!(client.is_server_dropped());
842    }
843
844    #[ntex::test]
845    async fn test_err_in_service_ready() {
846        let (client, server) = IoTest::create();
847        client.remote_buffer_cap(0);
848        client.write("GET /test HTTP/1\r\n\r\n");
849
850        let counter = Rc::new(Cell::new(0));
851
852        struct Srv(Rc<Cell<usize>>);
853
854        impl Service<DispatchItem<BytesCodec>> for Srv {
855            type Response = Option<Response<BytesCodec>>;
856            type Error = &'static str;
857
858            async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
859                self.0.set(self.0.get() + 1);
860                Err("test")
861            }
862
863            async fn call(
864                &self,
865                _: DispatchItem<BytesCodec>,
866                _: ServiceCtx<'_, Self>,
867            ) -> Result<Self::Response, Self::Error> {
868                Ok(None)
869            }
870        }
871
872        let (disp, state) = Dispatcher::debug(server, BytesCodec, Srv(counter.clone()));
873        spawn(async move {
874            let res = disp.await;
875            assert_eq!(res, Err("test"));
876        });
877
878        state
879            .io()
880            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
881            .unwrap();
882
883        // buffer should be flushed
884        client.remote_buffer_cap(1024);
885        let buf = client.read().await.unwrap();
886        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
887
888        // write side must be closed, dispatcher waiting for read side to close
889        sleep(Millis(250)).await;
890        assert!(client.is_closed());
891
892        // close read side
893        client.close().await;
894        assert!(client.is_server_dropped());
895
896        // service must be checked for readiness only once
897        assert_eq!(counter.get(), 1);
898    }
899
900    #[ntex::test]
901    async fn test_write_backpressure() {
902        let (client, server) = IoTest::create();
903        // do not allow to write to socket
904        client.remote_buffer_cap(0);
905        client.write("GET /test HTTP/1\r\n\r\n");
906
907        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
908        let data2 = data.clone();
909
910        let (disp, state) = Dispatcher::debug(
911            server,
912            BytesCodec,
913            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
914                let data = data2.clone();
915                async move {
916                    match msg {
917                        DispatchItem::Item(_) => {
918                            data.lock().unwrap().borrow_mut().push(0);
919                            let bytes = rand::rng()
920                                .sample_iter(&rand::distr::Alphanumeric)
921                                .take(65_536)
922                                .map(char::from)
923                                .collect::<String>();
924                            return Ok::<_, ()>(Some(Bytes::from(bytes)));
925                        }
926                        DispatchItem::WBackPressureEnabled => {
927                            data.lock().unwrap().borrow_mut().push(1);
928                        }
929                        DispatchItem::WBackPressureDisabled => {
930                            data.lock().unwrap().borrow_mut().push(2);
931                        }
932                        _ => (),
933                    }
934                    Ok(None)
935                }
936            }),
937        );
938        let pool = PoolId::P10.pool_ref();
939        pool.set_read_params(8 * 1024, 1024);
940        pool.set_write_params(16 * 1024, 1024);
941        state.set_memory_pool(pool);
942
943        spawn(async move {
944            let _ = disp.await;
945        });
946
947        let buf = client.read_any();
948        assert_eq!(buf, Bytes::from_static(b""));
949        client.write("GET /test HTTP/1\r\n\r\n");
950        sleep(Millis(25)).await;
951
952        // buf must be consumed
953        assert_eq!(client.remote_buffer(|buf| buf.len()), 0);
954
955        // response message
956        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 65536);
957
958        client.remote_buffer_cap(10240);
959        sleep(Millis(50)).await;
960        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 55296);
961
962        client.remote_buffer_cap(45056);
963        sleep(Millis(50)).await;
964        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 10240);
965
966        // backpressure disabled
967        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]);
968    }
969
970    #[ntex::test]
971    async fn test_disconnect_during_read_backpressure() {
972        let (client, server) = IoTest::create();
973        client.remote_buffer_cap(0);
974
975        let (disp, state) = Dispatcher::debug(
976            server,
977            BytesCodec,
978            ntex_util::services::inflight::InFlightService::new(
979                1,
980                ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| async move {
981                    if let DispatchItem::Item(_) = msg {
982                        sleep(Millis(500)).await;
983                        Ok::<_, ()>(None)
984                    } else {
985                        Ok(None)
986                    }
987                }),
988            ),
989        );
990        disp.inner.cfg.set_keepalive_timeout(Seconds::ZERO);
991        let pool = PoolId::P10.pool_ref();
992        pool.set_read_params(1024, 512);
993        state.set_memory_pool(pool);
994
995        let (tx, rx) = ntex::channel::oneshot::channel();
996        ntex::rt::spawn(async move {
997            let _ = disp.await;
998            let _ = tx.send(());
999        });
1000
1001        let bytes = rand::rng()
1002            .sample_iter(&rand::distr::Alphanumeric)
1003            .take(1024)
1004            .map(char::from)
1005            .collect::<String>();
1006        client.write(bytes.clone());
1007        sleep(Millis(25)).await;
1008        client.write(bytes);
1009        sleep(Millis(25)).await;
1010
1011        // close read side
1012        state.close();
1013        let _ = rx.recv().await;
1014    }
1015
1016    #[ntex::test]
1017    async fn test_keepalive() {
1018        let (client, server) = IoTest::create();
1019        client.remote_buffer_cap(1024);
1020        client.write("GET /test HTTP/1\r\n\r\n");
1021
1022        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1023        let data2 = data.clone();
1024
1025        let cfg = DispatcherConfig::default()
1026            .set_disconnect_timeout(Seconds(1))
1027            .set_keepalive_timeout(Seconds(1))
1028            .clone();
1029
1030        let (disp, state) = Dispatcher::debug_cfg(
1031            server,
1032            BytesCodec,
1033            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
1034                let data = data2.clone();
1035                async move {
1036                    match msg {
1037                        DispatchItem::Item(bytes) => {
1038                            data.lock().unwrap().borrow_mut().push(0);
1039                            return Ok::<_, ()>(Some(bytes.freeze()));
1040                        }
1041                        DispatchItem::KeepAliveTimeout => {
1042                            data.lock().unwrap().borrow_mut().push(1);
1043                        }
1044                        _ => (),
1045                    }
1046                    Ok(None)
1047                }
1048            }),
1049            cfg,
1050        );
1051        spawn(async move {
1052            let _ = disp.await;
1053        });
1054
1055        let buf = client.read().await.unwrap();
1056        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
1057        sleep(Millis(2000)).await;
1058
1059        // write side must be closed, dispatcher should fail with keep-alive
1060        let flags = state.flags();
1061        assert!(flags.contains(Flags::IO_STOPPING));
1062        assert!(client.is_closed());
1063        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
1064    }
1065
1066    #[ntex::test]
1067    async fn test_keepalive2() {
1068        let (client, server) = IoTest::create();
1069        client.remote_buffer_cap(1024);
1070
1071        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1072        let data2 = data.clone();
1073
1074        let cfg = DispatcherConfig::default()
1075            .set_keepalive_timeout(Seconds(1))
1076            .set_frame_read_rate(Seconds(1), Seconds(2), 2)
1077            .clone();
1078
1079        let (disp, state) = Dispatcher::debug_cfg(
1080            server,
1081            BCodec(8),
1082            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
1083                let data = data2.clone();
1084                async move {
1085                    match msg {
1086                        DispatchItem::Item(bytes) => {
1087                            data.lock().unwrap().borrow_mut().push(0);
1088                            return Ok::<_, ()>(Some(bytes.freeze()));
1089                        }
1090                        DispatchItem::KeepAliveTimeout => {
1091                            data.lock().unwrap().borrow_mut().push(1);
1092                        }
1093                        _ => (),
1094                    }
1095                    Ok(None)
1096                }
1097            }),
1098            cfg,
1099        );
1100        spawn(async move {
1101            let _ = disp.await;
1102        });
1103
1104        client.write("12345678");
1105        let buf = client.read().await.unwrap();
1106        assert_eq!(buf, Bytes::from_static(b"12345678"));
1107        sleep(Millis(2000)).await;
1108
1109        // write side must be closed, dispatcher should fail with keep-alive
1110        let flags = state.flags();
1111        assert!(flags.contains(Flags::IO_STOPPING));
1112        assert!(client.is_closed());
1113        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
1114    }
1115
1116    /// Update keep-alive timer after receiving frame
1117    #[ntex::test]
1118    async fn test_keepalive3() {
1119        let (client, server) = IoTest::create();
1120        client.remote_buffer_cap(1024);
1121
1122        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1123        let data2 = data.clone();
1124
1125        let cfg = DispatcherConfig::default()
1126            .set_keepalive_timeout(Seconds(2))
1127            .set_frame_read_rate(Seconds(1), Seconds(2), 2)
1128            .clone();
1129
1130        let (disp, _) = Dispatcher::debug_cfg(
1131            server,
1132            BCodec(1),
1133            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
1134                let data = data2.clone();
1135                async move {
1136                    match msg {
1137                        DispatchItem::Item(bytes) => {
1138                            data.lock().unwrap().borrow_mut().push(0);
1139                            return Ok::<_, ()>(Some(bytes.freeze()));
1140                        }
1141                        DispatchItem::KeepAliveTimeout => {
1142                            data.lock().unwrap().borrow_mut().push(1);
1143                        }
1144                        _ => (),
1145                    }
1146                    Ok(None)
1147                }
1148            }),
1149            cfg,
1150        );
1151        spawn(async move {
1152            let _ = disp.await;
1153        });
1154
1155        client.write("1");
1156        let buf = client.read().await.unwrap();
1157        assert_eq!(buf, Bytes::from_static(b"1"));
1158        sleep(Millis(750)).await;
1159
1160        client.write("2");
1161        let buf = client.read().await.unwrap();
1162        assert_eq!(buf, Bytes::from_static(b"2"));
1163
1164        sleep(Millis(750)).await;
1165        client.write("3");
1166        let buf = client.read().await.unwrap();
1167        assert_eq!(buf, Bytes::from_static(b"3"));
1168
1169        sleep(Millis(750)).await;
1170        assert!(!client.is_closed());
1171        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 0, 0]);
1172    }
1173
1174    #[ntex::test]
1175    async fn test_read_timeout() {
1176        let (client, server) = IoTest::create();
1177        client.remote_buffer_cap(1024);
1178
1179        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1180        let data2 = data.clone();
1181
1182        let (disp, state) = Dispatcher::debug(
1183            server,
1184            BCodec(8),
1185            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
1186                let data = data2.clone();
1187                async move {
1188                    match msg {
1189                        DispatchItem::Item(bytes) => {
1190                            data.lock().unwrap().borrow_mut().push(0);
1191                            return Ok::<_, ()>(Some(bytes.freeze()));
1192                        }
1193                        DispatchItem::ReadTimeout => {
1194                            data.lock().unwrap().borrow_mut().push(1);
1195                        }
1196                        _ => (),
1197                    }
1198                    Ok(None)
1199                }
1200            }),
1201        );
1202        spawn(async move {
1203            disp.inner
1204                .cfg
1205                .set_keepalive_timeout(Seconds::ZERO)
1206                .set_frame_read_rate(Seconds(1), Seconds(2), 2);
1207            let _ = disp.await;
1208        });
1209
1210        client.write("12345678");
1211        let buf = client.read().await.unwrap();
1212        assert_eq!(buf, Bytes::from_static(b"12345678"));
1213
1214        client.write("1");
1215        sleep(Millis(1000)).await;
1216        assert!(!state.flags().contains(Flags::IO_STOPPING));
1217        client.write("23");
1218        sleep(Millis(1000)).await;
1219        assert!(!state.flags().contains(Flags::IO_STOPPING));
1220        client.write("4");
1221        sleep(Millis(2000)).await;
1222
1223        // write side must be closed, dispatcher should fail with keep-alive
1224        assert!(state.flags().contains(Flags::IO_STOPPING));
1225        assert!(client.is_closed());
1226        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
1227    }
1228
1229    #[ntex::test]
1230    async fn test_unhandled_data() {
1231        let handled = Arc::new(AtomicBool::new(false));
1232        let handled2 = handled.clone();
1233
1234        let (client, server) = IoTest::create();
1235        client.remote_buffer_cap(1024);
1236        client.write("GET /test HTTP/1\r\n\r\n");
1237
1238        let (disp, _) = Dispatcher::debug(
1239            server,
1240            BytesCodec,
1241            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
1242                handled2.store(true, Relaxed);
1243                async move {
1244                    sleep(Millis(50)).await;
1245                    if let DispatchItem::Item(msg) = msg {
1246                        Ok::<_, ()>(Some(msg.freeze()))
1247                    } else if let DispatchItem::Disconnect(_) = msg {
1248                        Ok::<_, ()>(None)
1249                    } else {
1250                        panic!()
1251                    }
1252                }
1253            }),
1254        );
1255        client.close().await;
1256        spawn(async move {
1257            let _ = disp.await;
1258        });
1259        sleep(Millis(50)).await;
1260
1261        assert!(handled.load(Relaxed));
1262    }
1263}