Skip to main content

ntex_dispatcher/
lib.rs

1//! Framed transport dispatcher
2#![deny(clippy::pedantic)]
3#![allow(clippy::cast_possible_truncation)]
4use std::task::{Context, Poll, ready};
5use std::{cell::Cell, fmt, future::Future, io, pin::Pin, rc::Rc};
6
7use ntex_codec::{Decoder, Encoder};
8use ntex_io::{Decoded, IoBoxed, IoStatusUpdate, RecvError};
9use ntex_service::{IntoService, Pipeline, PipelineBinding, PipelineCall, Service};
10use ntex_util::{future::Either, spawn, time::Seconds};
11
/// Item type a service returns (wrapped in `Option`) for the dispatcher to
/// write back to the io through the codec's `Encoder`.
type Response<U> = <U as Encoder>::Item;
13
/// Dispatch item
///
/// Every message the dispatcher hands to the service is wrapped in this type.
pub enum DispatchItem<U: Encoder + Decoder> {
    /// Parsed item
    ///
    /// A frame produced by the codec's `Decoder` from the io read buffer.
    Item(<U as Decoder>::Item),
    /// Dispatcher control message
    Control(Control),
    /// Dispatcher is stopping
    ///
    /// Once this is sent, the dispatcher stops reading frames and only drains
    /// outstanding service calls before shutting down.
    Stop(Reason<U>),
}
23
/// Dispatcher control message
///
/// Sent to the service whenever the dispatcher's write back-pressure state
/// changes.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Control {
    /// Write back-pressure enabled: the dispatcher stops reading frames until
    /// the write buffer has been flushed.
    WBackPressureEnabled,
    /// Write back-pressure disabled: the write buffer was flushed and normal
    /// frame processing resumes.
    WBackPressureDisabled,
}
32
/// Dispatcher disconnect message
///
/// Carried by [`DispatchItem::Stop`] to tell the service why the dispatcher
/// is stopping.
pub enum Reason<U: Encoder + Decoder> {
    /// Socket has been disconnected
    ///
    /// Also used when flushing the write buffer fails; the error is the one
    /// reported by the io layer, if any.
    Io(Option<io::Error>),
    /// Encoder error
    ///
    /// Encoding a service response into the io failed.
    Encoder(<U as Encoder>::Error),
    /// Decoder error
    Decoder(<U as Decoder>::Error),
    /// Keep alive timeout
    KeepAliveTimeout,
    /// Frame read timeout
    ///
    /// A partially received frame did not grow fast enough to satisfy the
    /// configured frame read rate.
    ReadTimeout,
}
46
pin_project_lite::pin_project! {
    /// Dispatcher is a future that reads frames from a byte stream
    /// and passes them to the service.
    ///
    /// Responses returned by the service are encoded back into the same io.
    pub struct Dispatcher<S, U>
    where
        // NOTE(review): bounds are listed one per line, presumably because of
        // pin_project_lite's restricted where-clause syntax — confirm before merging.
        S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
        U: Encoder,
        U: Decoder,
        U: 'static,
    {
        inner: DispatcherInner<S, U>,
    }
}
60
bitflags::bitflags! {
    /// Dispatcher state bits stored in `DispatcherShared::flags`.
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    struct Flags: u8  {
        /// Service readiness check failed; `poll_ready` is not polled again while stopping.
        const READY_ERR     = 0b000_0001;
        /// Io reported peer-gone/keep-alive or a flush failure while stopping.
        const IO_ERR        = 0b000_0010;
        /// A non-zero keep-alive timeout is configured.
        const KA_ENABLED    = 0b000_0100;
        /// The keep-alive timer is running.
        const KA_TIMEOUT    = 0b000_1000;
        /// The frame read-rate timer is running for a partially received frame.
        const READ_TIMEOUT  = 0b001_0000;
        /// The in-flight service call counter dropped to zero.
        const IDLE          = 0b010_0000;
    }
}
72
/// State owned by the [`Dispatcher`] future itself; everything that spawned
/// service-call tasks also need lives in the `Rc`-shared [`DispatcherShared`].
struct DispatcherInner<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder + 'static,
{
    // current stage of the `poll` state machine
    st: DispatcherState,
    // service error returned from `poll` once shutdown has completed
    error: Option<S::Error>,
    // state shared with spawned service-call tasks
    shared: Rc<DispatcherShared<S, U>>,
    // the single service call polled directly by the dispatcher; further
    // concurrent calls are spawned as separate tasks
    response: Option<PipelineCall<S, DispatchItem<U>>>,
    // `decoded.remains` of the latest partial decode, used for frame read-rate
    // checks (presumably bytes buffered for the incomplete frame)
    read_remains: u32,
    // `read_remains` as of the previous read-rate check
    read_remains_prev: u32,
    // remaining budget before the frame read timer may no longer be extended
    read_max_timeout: Seconds,
}
86
/// State shared between the dispatcher future and the service-call tasks it
/// spawns; all mutation goes through `Cell`s so it works behind `Rc`.
pub(crate) struct DispatcherShared<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder,
{
    io: IoBoxed,
    // decodes incoming frames and encodes service responses
    codec: U,
    service: PipelineBinding<S, DispatchItem<U>>,
    // dispatcher `Flags`
    flags: Cell<Flags>,
    // latest error recorded by a completed service call; taken by `check_error`
    error: Cell<Option<DispatcherError<S::Error, <U as Encoder>::Error>>>,
    // number of service calls that have not completed yet
    inflight: Cell<u32>,
}
99
/// Stage of the dispatcher's `poll` state machine.
#[derive(Debug, Clone, Copy)]
enum DispatcherState {
    /// Reading frames and forwarding them to the service.
    Processing,
    /// Waiting for the write buffer to flush before reading resumes.
    Backpressure,
    /// Draining in-flight service calls, then shutting down the io.
    Stop,
    /// Waiting for the service's own shutdown to complete.
    Shutdown,
}
107
/// Failure recorded by a completed service call, reported to the dispatcher
/// on its next readiness check.
#[derive(Debug)]
enum DispatcherError<S, U> {
    /// Encoding the service response failed.
    Encoder(U),
    /// The service call returned an error.
    Service(S),
}
113
/// Outcome of `DispatcherInner::poll_service`.
enum PollService<U: Encoder + Decoder> {
    /// Item to pass to the service via `call_nowait`.
    Item(DispatchItem<U>),
    /// Item to pass to the service via `call`.
    ItemWait(DispatchItem<U>),
    /// Dispatcher state changed; re-run the `poll` loop.
    Continue,
    /// Service is ready and no call error is pending.
    Ready,
}
120
121impl<S, U> From<Either<S, U>> for DispatcherError<S, U> {
122    fn from(err: Either<S, U>) -> Self {
123        match err {
124            Either::Left(err) => DispatcherError::Service(err),
125            Either::Right(err) => DispatcherError::Encoder(err),
126        }
127    }
128}
129
130impl<S, U> Dispatcher<S, U>
131where
132    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
133    U: Decoder + Encoder + 'static,
134{
135    /// Construct new `Dispatcher` instance.
136    pub fn new<Io, F>(io: Io, codec: U, service: F) -> Dispatcher<S, U>
137    where
138        IoBoxed: From<Io>,
139        F: IntoService<S, DispatchItem<U>>,
140    {
141        let io = IoBoxed::from(io);
142        let flags = if io.cfg().keepalive_timeout().is_zero() {
143            Flags::empty()
144        } else {
145            Flags::KA_ENABLED
146        };
147
148        let shared = Rc::new(DispatcherShared {
149            io,
150            codec,
151            flags: Cell::new(flags),
152            error: Cell::new(None),
153            inflight: Cell::new(0),
154            service: Pipeline::new(service.into_service()).bind(),
155        });
156
157        Dispatcher {
158            inner: DispatcherInner {
159                shared,
160                response: None,
161                error: None,
162                read_remains: 0,
163                read_remains_prev: 0,
164                read_max_timeout: Seconds::ZERO,
165                st: DispatcherState::Processing,
166            },
167        }
168    }
169}
170
171impl<S, U> DispatcherShared<S, U>
172where
173    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
174    U: Encoder + Decoder,
175{
176    fn handle_result(&self, item: Result<S::Response, S::Error>, io: &IoBoxed, wake: bool) {
177        match item {
178            Ok(Some(val)) => {
179                if let Err(err) = io.encode(val, &self.codec) {
180                    self.error.set(Some(DispatcherError::Encoder(err)));
181                }
182            }
183            Err(err) => self.error.set(Some(DispatcherError::Service(err))),
184            Ok(None) => (),
185        }
186        let inflight = self.inflight.get() - 1;
187        self.inflight.set(inflight);
188        if inflight == 0 {
189            self.insert_flags(Flags::IDLE);
190        }
191        if wake {
192            io.wake();
193        }
194    }
195
196    fn contains(&self, f: Flags) -> bool {
197        self.flags.get().intersects(f)
198    }
199
200    fn insert_flags(&self, f: Flags) {
201        let mut flags = self.flags.get();
202        flags.insert(f);
203        self.flags.set(flags);
204    }
205
206    fn remove_flags(&self, f: Flags) -> bool {
207        let mut flags = self.flags.get();
208        if flags.intersects(f) {
209            flags.remove(f);
210            self.flags.set(flags);
211            true
212        } else {
213            false
214        }
215    }
216}
217
/// Drives the dispatcher state machine: decodes frames from the io, forwards
/// them (plus control/stop notifications) to the service, writes service
/// responses back through the codec, then shuts down the io and the service.
///
/// Resolves to `Ok(())` after a clean shutdown, or to the service error that
/// stopped the dispatcher (from `poll_ready` or from a service call).
impl<S, U> Future for Dispatcher<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
    U: Decoder + Encoder + 'static,
{
    type Output = Result<(), S::Error>;

    #[allow(clippy::too_many_lines)]
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.as_mut().project();
        let inner = this.inner;

        // handle service response future
        // (only one call is tracked here; `call_service` spawns any others)
        if let Some(fut) = inner.response.as_mut()
            && let Poll::Ready(item) = Pin::new(fut).poll(cx)
        {
            inner.shared.handle_result(item, &inner.shared.io, false);
            inner.response = None;
        }

        loop {
            match inner.st {
                DispatcherState::Processing => {
                    // `nowait == true` makes `call_service` use `call_nowait`
                    let (item, nowait) = match ready!(inner.poll_service(cx)) {
                        PollService::Ready => {
                            // decode incoming bytes if buffer is ready
                            match inner.shared.io.poll_recv_decode(&inner.shared.codec, cx)
                            {
                                Ok(decoded) => {
                                    inner.update_timer(&decoded);
                                    if let Some(el) = decoded.item {
                                        (DispatchItem::Item(el), true)
                                    } else {
                                        // no complete frame yet; presumably
                                        // `poll_recv_decode` registered `cx`
                                        return Poll::Pending;
                                    }
                                }
                                Err(RecvError::KeepAlive) => {
                                    // timer fired: either extend the read timer
                                    // or stop with the timeout reason
                                    if let Err(ctl) = inner.handle_timeout() {
                                        inner.st = DispatcherState::Stop;
                                        (DispatchItem::Stop(ctl), true)
                                    } else {
                                        continue;
                                    }
                                }
                                Err(RecvError::WriteBackpressure) => {
                                    // instruct write task to notify dispatcher when data is flushed
                                    inner.st = DispatcherState::Backpressure;
                                    (
                                        DispatchItem::Control(
                                            Control::WBackPressureEnabled,
                                        ),
                                        true,
                                    )
                                }
                                Err(RecvError::Decoder(err)) => {
                                    log::trace!(
                                        "{}: Decoder error, stopping dispatcher: {:?}",
                                        inner.shared.io.tag(),
                                        err
                                    );
                                    inner.st = DispatcherState::Stop;
                                    (DispatchItem::Stop(Reason::Decoder(err)), true)
                                }
                                Err(RecvError::PeerGone(err)) => {
                                    log::trace!(
                                        "{}: Peer is gone, stopping dispatcher: {:?}",
                                        inner.shared.io.tag(),
                                        err
                                    );
                                    inner.st = DispatcherState::Stop;
                                    (DispatchItem::Stop(Reason::Io(err)), true)
                                }
                            }
                        }
                        PollService::Item(item) => (item, true),
                        PollService::ItemWait(item) => (item, false),
                        PollService::Continue => continue,
                    };

                    inner.call_service(cx, item, nowait);
                }
                // handle write back-pressure
                DispatcherState::Backpressure => {
                    // keep delivering pending errors/io events while waiting
                    match ready!(inner.poll_service(cx)) {
                        PollService::Ready => (),
                        PollService::Item(item) => inner.call_service(cx, item, true),
                        PollService::ItemWait(item) => inner.call_service(cx, item, false),
                        PollService::Continue => continue,
                    }

                    // resume processing once the write buffer is flushed, or
                    // stop if flushing failed
                    let item =
                        if let Err(err) = ready!(inner.shared.io.poll_flush(cx, false)) {
                            inner.st = DispatcherState::Stop;
                            DispatchItem::Stop(Reason::Io(Some(err)))
                        } else {
                            inner.st = DispatcherState::Processing;
                            DispatchItem::Control(Control::WBackPressureDisabled)
                        };
                    inner.call_service(cx, item, false);
                }
                // drain service responses and shutdown io
                DispatcherState::Stop => {
                    inner.shared.io.stop_timer();

                    // service may rely on poll_ready for response results
                    if !inner.shared.contains(Flags::READY_ERR)
                        && let Poll::Ready(res) = inner.shared.service.poll_ready(cx)
                        && res.is_err()
                    {
                        inner.shared.insert_flags(Flags::READY_ERR);
                    }

                    if inner.shared.inflight.get() == 0 {
                        // all calls completed: shut the io down
                        if inner.shared.io.poll_shutdown(cx).is_ready() {
                            inner.st = DispatcherState::Shutdown;
                            continue;
                        }
                    } else if !inner.shared.contains(Flags::IO_ERR) {
                        // calls still in flight: keep servicing io status
                        // (record failures, flush write back-pressure)
                        match ready!(inner.shared.io.poll_status_update(cx)) {
                            IoStatusUpdate::PeerGone(_) | IoStatusUpdate::KeepAlive => {
                                inner.shared.insert_flags(Flags::IO_ERR);
                                continue;
                            }
                            IoStatusUpdate::WriteBackpressure => {
                                if ready!(inner.shared.io.poll_flush(cx, true)).is_err() {
                                    inner.shared.insert_flags(Flags::IO_ERR);
                                }
                                continue;
                            }
                        }
                    } else {
                        // io already failed; NOTE(review): presumably this only
                        // registers `cx` for dispatcher wake-ups — confirm
                        inner.shared.io.poll_dispatch(cx);
                    }
                    return Poll::Pending;
                }
                // shutdown service
                DispatcherState::Shutdown => {
                    return if inner.shared.service.poll_shutdown(cx).is_ready() {
                        log::trace!(
                            "{}: Service shutdown is completed, stop",
                            inner.shared.io.tag()
                        );

                        Poll::Ready(if let Some(err) = inner.error.take() {
                            Err(err)
                        } else {
                            Ok(())
                        })
                    } else {
                        Poll::Pending
                    };
                }
            }
        }
    }
}
374
375impl<S, U> DispatcherInner<S, U>
376where
377    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
378    U: Decoder + Encoder + 'static,
379{
380    fn call_service(&mut self, cx: &mut Context<'_>, item: DispatchItem<U>, nowait: bool) {
381        let mut fut = if nowait {
382            self.shared.service.call_nowait(item)
383        } else {
384            self.shared.service.call(item)
385        };
386        let inflight = self.shared.inflight.get() + 1;
387        self.shared.inflight.set(inflight);
388        if inflight == 1 {
389            self.shared.remove_flags(Flags::IDLE);
390        }
391
392        // optimize first call
393        if self.response.is_none() {
394            if let Poll::Ready(result) = Pin::new(&mut fut).poll(cx) {
395                self.shared.handle_result(result, &self.shared.io, false);
396            } else {
397                self.response = Some(fut);
398            }
399        } else {
400            let shared = self.shared.clone();
401            spawn(async move {
402                let result = fut.await;
403                shared.handle_result(result, &shared.io, true);
404            });
405        }
406    }
407
408    fn check_error(&mut self) -> PollService<U> {
409        // check for errors
410        if let Some(err) = self.shared.error.take() {
411            log::trace!(
412                "{}: Error occured, stopping dispatcher",
413                self.shared.io.tag()
414            );
415            self.st = DispatcherState::Stop;
416
417            match err {
418                DispatcherError::Encoder(err) => {
419                    PollService::Item(DispatchItem::Stop(Reason::Encoder(err)))
420                }
421                DispatcherError::Service(err) => {
422                    self.error = Some(err);
423                    PollService::Continue
424                }
425            }
426        } else {
427            PollService::Ready
428        }
429    }
430
431    fn poll_service(&mut self, cx: &mut Context<'_>) -> Poll<PollService<U>> {
432        // wait until service becomes ready
433        match self.shared.service.poll_ready(cx) {
434            Poll::Ready(Ok(())) => Poll::Ready(self.check_error()),
435            // pause io read task
436            Poll::Pending => {
437                log::trace!(
438                    "{}: Service is not ready, register dispatcher",
439                    self.shared.io.tag()
440                );
441
442                // remove all timers
443                self.shared
444                    .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
445                self.shared.io.stop_timer();
446
447                match ready!(self.shared.io.poll_read_pause(cx)) {
448                    IoStatusUpdate::KeepAlive => {
449                        if self.shared.contains(Flags::KA_ENABLED) {
450                            log::trace!(
451                                "{}: Keep-alive error, stopping dispatcher during pause",
452                                self.shared.io.tag()
453                            );
454                            self.st = DispatcherState::Stop;
455                            Poll::Ready(PollService::ItemWait(DispatchItem::Stop(
456                                Reason::KeepAliveTimeout,
457                            )))
458                        } else {
459                            // ignore spurious DSP_TIMEOUT when keep-alive is disabled
460                            Poll::Ready(PollService::Continue)
461                        }
462                    }
463                    IoStatusUpdate::PeerGone(err) => {
464                        log::trace!(
465                            "{}: Peer is gone during pause, stopping dispatcher: {:?}",
466                            self.shared.io.tag(),
467                            err
468                        );
469                        self.st = DispatcherState::Stop;
470                        Poll::Ready(PollService::ItemWait(DispatchItem::Stop(Reason::Io(
471                            err,
472                        ))))
473                    }
474                    IoStatusUpdate::WriteBackpressure => {
475                        self.st = DispatcherState::Backpressure;
476                        Poll::Ready(PollService::ItemWait(DispatchItem::Control(
477                            Control::WBackPressureEnabled,
478                        )))
479                    }
480                }
481            }
482            // handle service readiness error
483            Poll::Ready(Err(err)) => {
484                log::trace!(
485                    "{}: Service readiness check failed, stopping",
486                    self.shared.io.tag()
487                );
488                self.st = DispatcherState::Stop;
489                self.error = Some(err);
490                self.shared.insert_flags(Flags::READY_ERR);
491                Poll::Ready(PollService::Continue)
492            }
493        }
494    }
495
496    fn update_timer(&mut self, decoded: &Decoded<<U as Decoder>::Item>) {
497        // got parsed frame
498        if decoded.item.is_some() {
499            self.read_remains = 0;
500            self.shared
501                .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
502        } else if self.shared.contains(Flags::READ_TIMEOUT) {
503            // received new data but not enough for parsing complete frame
504            self.read_remains = decoded.remains as u32;
505        } else if self.read_remains == 0 && decoded.remains == 0 {
506            // no new data, start keep-alive timer
507            if self.shared.contains(Flags::KA_ENABLED)
508                && !self.shared.contains(Flags::KA_TIMEOUT)
509            {
510                log::trace!(
511                    "{}: Start keep-alive timer {:?}",
512                    self.shared.io.tag(),
513                    self.shared.io.cfg().keepalive_timeout()
514                );
515                self.shared.insert_flags(Flags::KA_TIMEOUT);
516                self.shared
517                    .io
518                    .start_timer(self.shared.io.cfg().keepalive_timeout());
519            }
520        } else if let Some(params) = self.shared.io.cfg().frame_read_rate() {
521            // we got new data but not enough to parse single frame
522            // start read timer
523            self.shared.insert_flags(Flags::READ_TIMEOUT);
524
525            self.read_remains = decoded.remains as u32;
526            self.read_remains_prev = 0;
527            self.read_max_timeout = params.max_timeout;
528            self.shared.io.start_timer(params.timeout);
529        }
530    }
531
532    fn handle_timeout(&mut self) -> Result<(), Reason<U>> {
533        // check read timer
534        if self.shared.contains(Flags::READ_TIMEOUT) {
535            if let Some(params) = self.shared.io.cfg().frame_read_rate() {
536                let total = self.read_remains - self.read_remains_prev;
537
538                // read rate, start timer for next period
539                if total > params.rate {
540                    self.read_remains_prev = self.read_remains;
541                    self.read_remains = 0;
542
543                    if !params.max_timeout.is_zero() {
544                        self.read_max_timeout = Seconds(
545                            self.read_max_timeout.0.saturating_sub(params.timeout.0),
546                        );
547                    }
548
549                    if params.max_timeout.is_zero() || !self.read_max_timeout.is_zero() {
550                        log::trace!(
551                            "{}: Frame read rate {:?}, extend timer",
552                            self.shared.io.tag(),
553                            total
554                        );
555                        self.shared.io.start_timer(params.timeout);
556                        return Ok(());
557                    }
558                    log::trace!(
559                        "{}: Max payload timeout has been reached",
560                        self.shared.io.tag()
561                    );
562                }
563                Err(Reason::ReadTimeout)
564            } else {
565                Ok(())
566            }
567        } else if self.shared.contains(Flags::KA_TIMEOUT | Flags::IDLE) {
568            log::trace!(
569                "{}: Keep-alive error, stopping dispatcher",
570                self.shared.io.tag()
571            );
572            Err(Reason::KeepAliveTimeout)
573        } else {
574            Ok(())
575        }
576    }
577}
578
579impl<U> fmt::Debug for DispatchItem<U>
580where
581    U: Encoder + Decoder,
582    <U as Decoder>::Item: fmt::Debug,
583{
584    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
585        match *self {
586            DispatchItem::Item(ref item) => {
587                write!(fmt, "DispatchItem::Item({item:?})")
588            }
589            DispatchItem::Control(ref e) => {
590                write!(fmt, "DispatchItem::Control({e:?})")
591            }
592            DispatchItem::Stop(ref e) => {
593                write!(fmt, "DispatchItem::Stop({e:?})")
594            }
595        }
596    }
597}
598
599impl<U> fmt::Debug for Reason<U>
600where
601    U: Encoder + Decoder,
602    <U as Encoder>::Error: fmt::Debug,
603    <U as Decoder>::Error: fmt::Debug,
604{
605    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
606        match *self {
607            Reason::Io(ref err) => {
608                write!(fmt, "Reason::Io({err:?})")
609            }
610            Reason::Encoder(ref err) => {
611                write!(fmt, "Reason::Encoder({err:?})")
612            }
613            Reason::Decoder(ref err) => {
614                write!(fmt, "Reason::Decoder({err:?})")
615            }
616            Reason::KeepAliveTimeout => {
617                write!(fmt, "Reason::KeepAliveTimeout")
618            }
619            Reason::ReadTimeout => {
620                write!(fmt, "Reason::ReadTimeout")
621            }
622        }
623    }
624}
625
626#[cfg(test)]
627mod tests {
628    use std::sync::{Arc, Mutex, atomic::AtomicBool, atomic::Ordering::Relaxed};
629    use std::{cell::RefCell, io};
630
631    use ntex_bytes::{Bytes, BytesMut};
632    use ntex_codec::BytesCodec;
633    use ntex_io::{Flags, Io, IoConfig, IoRef, testing::IoTest};
634    use ntex_service::{ServiceCtx, cfg::SharedCfg};
635    use ntex_util::{time::Millis, time::sleep};
636    use rand::Rng;
637
638    use super::*;
639
640    pub(crate) struct State(IoRef);
641
642    impl State {
643        fn flags(&self) -> Flags {
644            self.0.flags()
645        }
646
647        fn io(&self) -> &IoRef {
648            &self.0
649        }
650
651        fn close(&self) {
652            self.0.close();
653        }
654    }
655
656    #[derive(Copy, Clone)]
657    struct BCodec(usize);
658
659    impl Encoder for BCodec {
660        type Item = Bytes;
661        type Error = io::Error;
662
663        fn encode(&self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
664            dst.extend_from_slice(&item[..]);
665            Ok(())
666        }
667    }
668
669    impl Decoder for BCodec {
670        type Item = Bytes;
671        type Error = io::Error;
672
673        fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
674            if src.len() < self.0 {
675                Ok(None)
676            } else {
677                Ok(Some(src.split_to(self.0)))
678            }
679        }
680    }
681
682    impl<S, U> Dispatcher<S, U>
683    where
684        S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
685        U: Decoder + Encoder + 'static,
686    {
687        /// Construct new `Dispatcher` instance
688        pub(crate) fn debug(io: Io, codec: U, service: S) -> (Self, State) {
689            let flags = if io.cfg().keepalive_timeout().is_zero() {
690                super::Flags::empty()
691            } else {
692                super::Flags::KA_ENABLED
693            };
694
695            let inner = State(io.get_ref());
696            io.start_timer(Seconds::ONE);
697
698            let shared = Rc::new(DispatcherShared {
699                codec,
700                io: io.into(),
701                flags: Cell::new(flags),
702                error: Cell::new(None),
703                inflight: Cell::new(0),
704                service: Pipeline::new(service).bind(),
705            });
706
707            (
708                Dispatcher {
709                    inner: DispatcherInner {
710                        shared,
711                        error: None,
712                        st: DispatcherState::Processing,
713                        response: None,
714                        read_remains: 0,
715                        read_remains_prev: 0,
716                        read_max_timeout: Seconds::ZERO,
717                    },
718                },
719                inner,
720            )
721        }
722    }
723
724    #[ntex::test]
725    async fn basics() {
726        let (client, server) = IoTest::create();
727        client.remote_buffer_cap(1024);
728        client.write("GET /test HTTP/1\r\n\r\n");
729
730        let (disp, _) = Dispatcher::debug(
731            Io::from(server),
732            BytesCodec,
733            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
734                sleep(Millis(50)).await;
735                if let DispatchItem::Item(msg) = msg {
736                    Ok::<_, ()>(Some(msg))
737                } else {
738                    panic!()
739                }
740            }),
741        );
742        spawn(async move {
743            let _ = disp.await;
744        });
745
746        sleep(Millis(25)).await;
747        let buf = client.read().await.unwrap();
748        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
749
750        client.write("GET /test HTTP/1\r\n\r\n");
751        let buf = client.read().await.unwrap();
752        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
753
754        client.close().await;
755        assert!(client.is_server_dropped());
756
757        assert!(format!("{:?}", super::Flags::KA_TIMEOUT.clone()).contains("KA_TIMEOUT"));
758    }
759
760    #[ntex::test]
761    async fn sink() {
762        let (client, server) = IoTest::create();
763        client.remote_buffer_cap(1024);
764        client.write("GET /test HTTP/1\r\n\r\n");
765
766        let (disp, st) = Dispatcher::debug(
767            Io::from(server),
768            BytesCodec,
769            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
770                if let DispatchItem::Item(msg) = msg {
771                    Ok::<_, ()>(Some(msg))
772                } else if let DispatchItem::Stop(_) = msg {
773                    Ok(None)
774                } else {
775                    panic!()
776                }
777            }),
778        );
779        spawn(async move {
780            let _ = disp.await;
781        });
782
783        let buf = client.read().await.unwrap();
784        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
785
786        assert!(
787            st.io()
788                .encode(Bytes::from_static(b"test"), &BytesCodec)
789                .is_ok()
790        );
791        let buf = client.read().await.unwrap();
792        assert_eq!(buf, Bytes::from_static(b"test"));
793
794        st.close();
795        sleep(Millis(1500)).await;
796        assert!(client.is_server_dropped());
797    }
798
799    #[ntex::test]
800    async fn err_in_service() {
801        let (client, server) = IoTest::create();
802        client.remote_buffer_cap(0);
803        client.write("GET /test HTTP/1\r\n\r\n");
804
805        let (disp, state) = Dispatcher::debug(
806            Io::from(server),
807            BytesCodec,
808            ntex_service::fn_service(|_: DispatchItem<BytesCodec>| async move {
809                Err::<Option<Bytes>, _>(())
810            }),
811        );
812        state
813            .io()
814            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
815            .unwrap();
816        spawn(async move {
817            let _ = disp.await;
818        });
819
820        // buffer should be flushed
821        client.remote_buffer_cap(1024);
822        let buf = client.read().await.unwrap();
823        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
824
825        // write side must be closed, dispatcher waiting for read side to close
826        sleep(Millis(250)).await;
827        assert!(client.is_closed());
828
829        // close read side
830        client.close().await;
831
832        // dispatcher is closed
833        assert!(client.is_server_dropped());
834    }
835
836    #[ntex::test]
837    #[allow(clippy::items_after_statements)]
838    async fn err_in_service_ready() {
839        let (client, server) = IoTest::create();
840        client.remote_buffer_cap(0);
841        client.write("GET /test HTTP/1\r\n\r\n");
842
843        let counter = Rc::new(Cell::new(0));
844
845        struct Srv(Rc<Cell<usize>>);
846
847        impl Service<DispatchItem<BytesCodec>> for Srv {
848            type Response = Option<Response<BytesCodec>>;
849            type Error = &'static str;
850
851            async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
852                self.0.set(self.0.get() + 1);
853                Err("test")
854            }
855
856            async fn call(
857                &self,
858                _: DispatchItem<BytesCodec>,
859                _: ServiceCtx<'_, Self>,
860            ) -> Result<Self::Response, Self::Error> {
861                Ok(None)
862            }
863        }
864
865        let (disp, state) =
866            Dispatcher::debug(Io::from(server), BytesCodec, Srv(counter.clone()));
867        spawn(async move {
868            let res = disp.await;
869            assert_eq!(res, Err("test"));
870        });
871
872        state
873            .io()
874            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
875            .unwrap();
876
877        // buffer should be flushed
878        client.remote_buffer_cap(1024);
879        let buf = client.read().await.unwrap();
880        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
881
882        // write side must be closed, dispatcher waiting for read side to close
883        sleep(Millis(250)).await;
884        assert!(client.is_closed());
885
886        // close read side
887        client.close().await;
888        assert!(client.is_server_dropped());
889
890        // service must be checked for readiness only once
891        assert_eq!(counter.get(), 1);
892    }
893
894    #[ntex::test]
895    async fn write_backpressure() {
896        let (client, server) = IoTest::create();
897        // do not allow to write to socket
898        client.remote_buffer_cap(0);
899        client.write("GET /test HTTP/1\r\n\r\n");
900
901        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
902        let data2 = data.clone();
903
904        let io = Io::new(
905            server,
906            SharedCfg::new("TEST").add(
907                IoConfig::new()
908                    .set_read_buf(8 * 1024, 1024, 16)
909                    .set_write_buf(16 * 1024, 1024, 16),
910            ),
911        );
912
913        let (disp, state) = Dispatcher::debug(
914            io,
915            BytesCodec,
916            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
917                let data = data2.clone();
918                async move {
919                    match msg {
920                        DispatchItem::Item(_) => {
921                            data.lock().unwrap().borrow_mut().push(0);
922                            let bytes = rand::rng()
923                                .sample_iter(&rand::distr::Alphanumeric)
924                                .take(65_536)
925                                .map(char::from)
926                                .collect::<String>();
927                            return Ok::<_, ()>(Some(Bytes::from(bytes)));
928                        }
929                        DispatchItem::Control(Control::WBackPressureEnabled) => {
930                            data.lock().unwrap().borrow_mut().push(1);
931                        }
932                        DispatchItem::Control(Control::WBackPressureDisabled) => {
933                            data.lock().unwrap().borrow_mut().push(2);
934                        }
935                        _ => (),
936                    }
937                    Ok(None)
938                }
939            }),
940        );
941
942        spawn(async move {
943            let _ = disp.await;
944        });
945
946        let buf = client.read_any();
947        assert_eq!(buf, Bytes::from_static(b""));
948        client.write("GET /test HTTP/1\r\n\r\n");
949        sleep(Millis(25)).await;
950
951        // buf must be consumed
952        assert_eq!(client.remote_buffer(|buf| buf.len()), 0);
953
954        // response message
955        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 65536);
956
957        client.remote_buffer_cap(10240);
958        sleep(Millis(50)).await;
959        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 55296);
960
961        client.remote_buffer_cap(48056);
962        sleep(Millis(50)).await;
963        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 7240);
964
965        // backpressure disabled
966        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]);
967    }
968
969    #[ntex::test]
970    async fn disconnect_during_read_backpressure() {
971        let (client, server) = IoTest::create();
972        client.remote_buffer_cap(0);
973
974        let (disp, state) = Dispatcher::debug(
975            Io::new(
976                server,
977                SharedCfg::new("TEST").add(
978                    IoConfig::new()
979                        .set_keepalive_timeout(Seconds::ZERO)
980                        .set_read_buf(1024, 512, 16),
981                ),
982            ),
983            BytesCodec,
984            ntex_util::services::inflight::InFlightService::new(
985                1,
986                ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| async move {
987                    if let DispatchItem::Item(_) = msg {
988                        sleep(Millis(500)).await;
989                        Ok::<_, ()>(None)
990                    } else {
991                        Ok(None)
992                    }
993                }),
994            ),
995        );
996
997        let (tx, rx) = ntex::channel::oneshot::channel();
998        ntex::rt::spawn(async move {
999            let _ = disp.await;
1000            let _ = tx.send(());
1001        });
1002
1003        let bytes = rand::rng()
1004            .sample_iter(&rand::distr::Alphanumeric)
1005            .take(1024)
1006            .map(char::from)
1007            .collect::<String>();
1008        client.write(bytes.clone());
1009        sleep(Millis(25)).await;
1010        client.write(bytes);
1011        sleep(Millis(25)).await;
1012
1013        // close read side
1014        state.close();
1015        let _ = rx.recv().await;
1016    }
1017
1018    #[ntex::test]
1019    async fn keepalive() {
1020        let (client, server) = IoTest::create();
1021        client.remote_buffer_cap(1024);
1022        client.write("GET /test HTTP/1\r\n\r\n");
1023
1024        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1025        let data2 = data.clone();
1026
1027        let cfg = SharedCfg::new("DBG").add(
1028            IoConfig::new()
1029                .set_disconnect_timeout(Seconds(1))
1030                .set_keepalive_timeout(Seconds(1)),
1031        );
1032
1033        let (disp, state) = Dispatcher::debug(
1034            Io::new(server, cfg),
1035            BytesCodec,
1036            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
1037                let data = data2.clone();
1038                async move {
1039                    match msg {
1040                        DispatchItem::Item(bytes) => {
1041                            data.lock().unwrap().borrow_mut().push(0);
1042                            return Ok::<_, ()>(Some(bytes));
1043                        }
1044                        DispatchItem::Stop(Reason::KeepAliveTimeout) => {
1045                            data.lock().unwrap().borrow_mut().push(1);
1046                        }
1047                        _ => (),
1048                    }
1049                    Ok(None)
1050                }
1051            }),
1052        );
1053        spawn(async move {
1054            let _ = disp.await;
1055        });
1056
1057        let buf = client.read().await.unwrap();
1058        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
1059        sleep(Millis(2000)).await;
1060
1061        // write side must be closed, dispatcher should fail with keep-alive
1062        let flags = state.flags();
1063        assert!(flags.contains(Flags::IO_STOPPING));
1064        assert!(client.is_closed());
1065        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
1066    }
1067
1068    #[ntex::test]
1069    async fn keepalive2() {
1070        let (client, server) = IoTest::create();
1071        client.remote_buffer_cap(1024);
1072
1073        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1074        let data2 = data.clone();
1075
1076        let cfg = SharedCfg::new("DBG").add(
1077            IoConfig::new()
1078                .set_keepalive_timeout(Seconds(1))
1079                .set_frame_read_rate(Seconds(1), Seconds(2), 2),
1080        );
1081
1082        let (disp, state) = Dispatcher::debug(
1083            Io::new(server, cfg),
1084            BCodec(8),
1085            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
1086                let data = data2.clone();
1087                async move {
1088                    match msg {
1089                        DispatchItem::Item(bytes) => {
1090                            data.lock().unwrap().borrow_mut().push(0);
1091                            return Ok::<_, ()>(Some(bytes));
1092                        }
1093                        DispatchItem::Stop(Reason::KeepAliveTimeout) => {
1094                            data.lock().unwrap().borrow_mut().push(1);
1095                        }
1096                        _ => (),
1097                    }
1098                    Ok(None)
1099                }
1100            }),
1101        );
1102        spawn(async move {
1103            let _ = disp.await;
1104        });
1105
1106        client.write("12345678");
1107        let buf = client.read().await.unwrap();
1108        assert_eq!(buf, Bytes::from_static(b"12345678"));
1109        sleep(Millis(2000)).await;
1110
1111        // write side must be closed, dispatcher should fail with keep-alive
1112        let flags = state.flags();
1113        assert!(flags.contains(Flags::IO_STOPPING));
1114        assert!(client.is_closed());
1115        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
1116    }
1117
1118    /// Update keep-alive timer after receiving frame
1119    #[ntex::test]
1120    async fn keepalive3() {
1121        let (client, server) = IoTest::create();
1122        client.remote_buffer_cap(1024);
1123
1124        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1125        let data2 = data.clone();
1126
1127        let cfg = SharedCfg::new("DBG").add(
1128            IoConfig::new()
1129                .set_keepalive_timeout(Seconds(2))
1130                .set_frame_read_rate(Seconds(1), Seconds(2), 2),
1131        );
1132
1133        let (disp, _) = Dispatcher::debug(
1134            Io::new(server, cfg),
1135            BCodec(1),
1136            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
1137                let data = data2.clone();
1138                async move {
1139                    match msg {
1140                        DispatchItem::Item(bytes) => {
1141                            data.lock().unwrap().borrow_mut().push(0);
1142                            return Ok::<_, ()>(Some(bytes));
1143                        }
1144                        DispatchItem::Stop(Reason::KeepAliveTimeout) => {
1145                            data.lock().unwrap().borrow_mut().push(1);
1146                        }
1147                        _ => (),
1148                    }
1149                    Ok(None)
1150                }
1151            }),
1152        );
1153        spawn(async move {
1154            let _ = disp.await;
1155        });
1156
1157        client.write("1");
1158        let buf = client.read().await.unwrap();
1159        assert_eq!(buf, Bytes::from_static(b"1"));
1160        sleep(Millis(750)).await;
1161
1162        client.write("2");
1163        let buf = client.read().await.unwrap();
1164        assert_eq!(buf, Bytes::from_static(b"2"));
1165
1166        sleep(Millis(750)).await;
1167        client.write("3");
1168        let buf = client.read().await.unwrap();
1169        assert_eq!(buf, Bytes::from_static(b"3"));
1170
1171        sleep(Millis(750)).await;
1172        assert!(!client.is_closed());
1173        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 0, 0]);
1174    }
1175
1176    #[ntex::test]
1177    async fn read_timeout() {
1178        let (client, server) = IoTest::create();
1179        client.remote_buffer_cap(1024);
1180
1181        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1182        let data2 = data.clone();
1183
1184        let io = Io::new(
1185            server,
1186            SharedCfg::new("TEST").add(
1187                IoConfig::new()
1188                    .set_keepalive_timeout(Seconds::ZERO)
1189                    .set_frame_read_rate(Seconds(1), Seconds(2), 2),
1190            ),
1191        );
1192
1193        let (disp, state) = Dispatcher::debug(
1194            io,
1195            BCodec(8),
1196            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
1197                let data = data2.clone();
1198                async move {
1199                    match msg {
1200                        DispatchItem::Item(bytes) => {
1201                            data.lock().unwrap().borrow_mut().push(0);
1202                            return Ok::<_, ()>(Some(bytes));
1203                        }
1204                        DispatchItem::Stop(Reason::ReadTimeout) => {
1205                            data.lock().unwrap().borrow_mut().push(1);
1206                        }
1207                        _ => (),
1208                    }
1209                    Ok(None)
1210                }
1211            }),
1212        );
1213        spawn(async move {
1214            let _ = disp.await;
1215        });
1216
1217        client.write("12345678");
1218        let buf = client.read().await.unwrap();
1219        assert_eq!(buf, Bytes::from_static(b"12345678"));
1220
1221        client.write("1");
1222        sleep(Millis(1000)).await;
1223        assert!(!state.flags().contains(Flags::IO_STOPPING));
1224        client.write("23");
1225        sleep(Millis(1000)).await;
1226        assert!(!state.flags().contains(Flags::IO_STOPPING));
1227        client.write("4");
1228        sleep(Millis(2000)).await;
1229
1230        // write side must be closed, dispatcher should fail with keep-alive
1231        assert!(state.flags().contains(Flags::IO_STOPPING));
1232        assert!(client.is_closed());
1233        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
1234    }
1235
1236    #[ntex::test]
1237    async fn idle_timeout() {
1238        let (client, server) = IoTest::create();
1239        client.remote_buffer_cap(1024);
1240
1241        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1242        let data2 = data.clone();
1243
1244        let io = Io::new(
1245            server,
1246            SharedCfg::new("DBG").add(IoConfig::new().set_keepalive_timeout(Seconds::ZERO)),
1247        );
1248        let ioref = io.get_ref();
1249
1250        let (disp, state) = Dispatcher::debug(
1251            io,
1252            BCodec(1),
1253            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
1254                let ioref = ioref.clone();
1255                ntex::rt::spawn(async move {
1256                    sleep(Millis(500)).await;
1257                    ioref.notify_timeout();
1258                });
1259                let data = data2.clone();
1260                async move {
1261                    match msg {
1262                        DispatchItem::Item(bytes) => {
1263                            data.lock().unwrap().borrow_mut().push(0);
1264                            return Ok::<_, ()>(Some(bytes));
1265                        }
1266                        DispatchItem::Stop(Reason::ReadTimeout) => {
1267                            data.lock().unwrap().borrow_mut().push(1);
1268                        }
1269                        _ => (),
1270                    }
1271                    Ok(None)
1272                }
1273            }),
1274        );
1275        spawn(async move {
1276            let _ = disp.await;
1277        });
1278
1279        client.write("1");
1280        let buf = client.read().await.unwrap();
1281        assert_eq!(buf, Bytes::from_static(b"1"));
1282
1283        sleep(Millis(1000)).await;
1284        assert!(state.flags().contains(Flags::IO_STOPPING));
1285        assert!(client.is_closed());
1286    }
1287
1288    #[ntex::test]
1289    async fn unhandled_data() {
1290        let handled = Arc::new(AtomicBool::new(false));
1291        let handled2 = handled.clone();
1292
1293        let (client, server) = IoTest::create();
1294        client.remote_buffer_cap(1024);
1295        client.write("GET /test HTTP/1\r\n\r\n");
1296
1297        let (disp, _) = Dispatcher::debug(
1298            Io::from(server),
1299            BytesCodec,
1300            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
1301                handled2.store(true, Relaxed);
1302                async move {
1303                    sleep(Millis(50)).await;
1304                    if let DispatchItem::Item(msg) = msg {
1305                        Ok::<_, ()>(Some(msg))
1306                    } else if let DispatchItem::Stop(_) = msg {
1307                        Ok::<_, ()>(None)
1308                    } else {
1309                        panic!()
1310                    }
1311                }
1312            }),
1313        );
1314        client.close().await;
1315        spawn(async move {
1316            let _ = disp.await;
1317        });
1318        sleep(Millis(50)).await;
1319
1320        assert!(handled.load(Relaxed));
1321    }
1322}