Skip to main content

ntex_dispatcher/
lib.rs

1//! Framed transport dispatcher
2#![allow(clippy::let_underscore_future)]
3use std::task::{Context, Poll, ready};
4use std::{cell::Cell, fmt, future::Future, io, pin::Pin, rc::Rc};
5
6use ntex_codec::{Decoder, Encoder};
7use ntex_io::{Decoded, IoBoxed, IoStatusUpdate, RecvError};
8use ntex_service::{IntoService, Pipeline, PipelineBinding, PipelineCall, Service};
9use ntex_util::{future::Either, spawn, time::Seconds};
10
/// Response type returned by the service: the item type of the codec's encoder.
type Response<U> = <U as Encoder>::Item;

/// Dispatch item
///
/// Message delivered by the dispatcher to the service.
pub enum DispatchItem<U: Encoder + Decoder> {
    /// Parsed item
    ///
    /// A frame decoded from the incoming byte stream by the codec.
    Item(<U as Decoder>::Item),
    /// Dispatcher control message
    Control(Control),
    /// Dispatcher is stopping
    ///
    /// Carries the reason the dispatcher is shutting down.
    Stop(Reason<U>),
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
/// Dispatcher control message
pub enum Control {
    /// Write back-pressure enabled
    ///
    /// Sent when the io layer reports write back-pressure; the dispatcher then
    /// waits for the write buffer to be flushed before reading more frames.
    WBackPressureEnabled,
    /// Write back-pressure disabled
    ///
    /// Sent once the write buffer has been flushed after back-pressure.
    WBackPressureDisabled,
}

/// Dispatcher disconnect message
///
/// Reason delivered to the service with [`DispatchItem::Stop`].
pub enum Reason<U: Encoder + Decoder> {
    /// Socket has been disconnected
    ///
    /// Holds the underlying io error, if any (peer gone, or a failed flush
    /// while under write back-pressure).
    Io(Option<io::Error>),
    /// Encoder error
    ///
    /// Encoding a service response failed.
    Encoder(<U as Encoder>::Error),
    /// Decoder error
    Decoder(<U as Decoder>::Error),
    /// Keep alive timeout
    KeepAliveTimeout,
    /// Frame read timeout
    ///
    /// A partially received frame did not arrive at the configured frame read
    /// rate, or the maximum frame read timeout was exhausted.
    ReadTimeout,
}
45
pin_project_lite::pin_project! {
    /// Dispatcher - is a future that reads frames from bytes stream
    /// and passes them to the service.
    ///
    /// Service responses are encoded with the same codec and written back to
    /// the io stream. The future resolves after the io stream and then the
    /// service have been shut down, yielding the service error, if any.
    pub struct Dispatcher<S, U>
    where
        S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
        // NOTE(review): bounds are kept one per line; pin_project_lite's
        // where-clause grammar may not accept `Encoder + Decoder` — confirm
        // before merging them.
        U: Encoder,
        U: Decoder,
        U: 'static,
    {
        inner: DispatcherInner<S, U>,
    }
}
59
60bitflags::bitflags! {
61    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
62    struct Flags: u8  {
63        const READY_ERR     = 0b0000001;
64        const IO_ERR        = 0b0000010;
65        const KA_ENABLED    = 0b0000100;
66        const KA_TIMEOUT    = 0b0001000;
67        const READ_TIMEOUT  = 0b0010000;
68        const IDLE          = 0b0100000;
69    }
70}
71
72struct DispatcherInner<S, U>
73where
74    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
75    U: Encoder + Decoder + 'static,
76{
77    st: DispatcherState,
78    error: Option<S::Error>,
79    shared: Rc<DispatcherShared<S, U>>,
80    response: Option<PipelineCall<S, DispatchItem<U>>>,
81    read_remains: u32,
82    read_remains_prev: u32,
83    read_max_timeout: Seconds,
84}
85
86pub(crate) struct DispatcherShared<S, U>
87where
88    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
89    U: Encoder + Decoder,
90{
91    io: IoBoxed,
92    codec: U,
93    service: PipelineBinding<S, DispatchItem<U>>,
94    flags: Cell<Flags>,
95    error: Cell<Option<DispatcherError<S::Error, <U as Encoder>::Error>>>,
96    inflight: Cell<u32>,
97}
98
/// Phase of the dispatcher state machine.
#[derive(Clone, Copy, Debug)]
enum DispatcherState {
    /// Reading frames and calling the service.
    Processing,
    /// Waiting for the write buffer to be flushed.
    Backpressure,
    /// Draining in-flight calls, then shutting down io.
    Stop,
    /// Shutting down the service.
    Shutdown,
}

/// Error recorded by a finished service call.
#[derive(Debug)]
enum DispatcherError<SvcErr, EncErr> {
    /// Encoding the service response failed.
    Encoder(EncErr),
    /// The service call itself returned an error.
    Service(SvcErr),
}
112
113enum PollService<U: Encoder + Decoder> {
114    Item(DispatchItem<U>),
115    ItemWait(DispatchItem<U>),
116    Continue,
117    Ready,
118}
119
120impl<S, U> From<Either<S, U>> for DispatcherError<S, U> {
121    fn from(err: Either<S, U>) -> Self {
122        match err {
123            Either::Left(err) => DispatcherError::Service(err),
124            Either::Right(err) => DispatcherError::Encoder(err),
125        }
126    }
127}
128
129impl<S, U> Dispatcher<S, U>
130where
131    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
132    U: Decoder + Encoder + 'static,
133{
134    /// Construct new `Dispatcher` instance.
135    pub fn new<Io, F>(io: Io, codec: U, service: F) -> Dispatcher<S, U>
136    where
137        IoBoxed: From<Io>,
138        F: IntoService<S, DispatchItem<U>>,
139    {
140        let io = IoBoxed::from(io);
141        let flags = if io.cfg().keepalive_timeout().is_zero() {
142            Flags::empty()
143        } else {
144            Flags::KA_ENABLED
145        };
146
147        let shared = Rc::new(DispatcherShared {
148            io,
149            codec,
150            flags: Cell::new(flags),
151            error: Cell::new(None),
152            inflight: Cell::new(0),
153            service: Pipeline::new(service.into_service()).bind(),
154        });
155
156        Dispatcher {
157            inner: DispatcherInner {
158                shared,
159                response: None,
160                error: None,
161                read_remains: 0,
162                read_remains_prev: 0,
163                read_max_timeout: Seconds::ZERO,
164                st: DispatcherState::Processing,
165            },
166        }
167    }
168}
169
170impl<S, U> DispatcherShared<S, U>
171where
172    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
173    U: Encoder + Decoder,
174{
175    fn handle_result(&self, item: Result<S::Response, S::Error>, io: &IoBoxed, wake: bool) {
176        match item {
177            Ok(Some(val)) => {
178                if let Err(err) = io.encode(val, &self.codec) {
179                    self.error.set(Some(DispatcherError::Encoder(err)))
180                }
181            }
182            Err(err) => self.error.set(Some(DispatcherError::Service(err))),
183            Ok(None) => (),
184        }
185        let inflight = self.inflight.get() - 1;
186        self.inflight.set(inflight);
187        if inflight == 0 {
188            self.insert_flags(Flags::IDLE);
189        }
190        if wake {
191            io.wake();
192        }
193    }
194
195    fn contains(&self, f: Flags) -> bool {
196        self.flags.get().intersects(f)
197    }
198
199    fn insert_flags(&self, f: Flags) {
200        let mut flags = self.flags.get();
201        flags.insert(f);
202        self.flags.set(flags);
203    }
204
205    fn remove_flags(&self, f: Flags) -> bool {
206        let mut flags = self.flags.get();
207        if flags.intersects(f) {
208            flags.remove(f);
209            self.flags.set(flags);
210            true
211        } else {
212            false
213        }
214    }
215}
216
/// Drives the dispatcher state machine.
///
/// Resolves after the io stream and then the service have been shut down. The
/// result is `Err` when a service call or the service readiness check failed,
/// carrying that service error.
impl<S, U> Future for Dispatcher<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
    U: Decoder + Encoder + 'static,
{
    type Output = Result<(), S::Error>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.as_mut().project();
        let slf = &mut this.inner;

        // handle service response future
        // (the single call polled inline, see `DispatcherInner::call_service`)
        if let Some(fut) = slf.response.as_mut()
            && let Poll::Ready(item) = Pin::new(fut).poll(cx)
        {
            slf.shared.handle_result(item, &slf.shared.io, false);
            slf.response = None;
        }

        loop {
            match slf.st {
                // read frames and pass them to the service;
                // `nowait == true` dispatches via `call_nowait`, `false` via `call`
                DispatcherState::Processing => {
                    let (item, nowait) = match ready!(slf.poll_service(cx)) {
                        PollService::Ready => {
                            // decode incoming bytes if buffer is ready
                            match slf.shared.io.poll_recv_decode(&slf.shared.codec, cx) {
                                Ok(decoded) => {
                                    slf.update_timer(&decoded);
                                    if let Some(el) = decoded.item {
                                        (DispatchItem::Item(el), true)
                                    } else {
                                        return Poll::Pending;
                                    }
                                }
                                // io timer fired: either the read-rate timer was
                                // extended (`Ok`) or the dispatcher must stop
                                Err(RecvError::KeepAlive) => {
                                    if let Err(ctl) = slf.handle_timeout() {
                                        slf.st = DispatcherState::Stop;
                                        (DispatchItem::Stop(ctl), true)
                                    } else {
                                        continue;
                                    }
                                }
                                Err(RecvError::WriteBackpressure) => {
                                    // instruct write task to notify dispatcher when data is flushed
                                    slf.st = DispatcherState::Backpressure;
                                    (
                                        DispatchItem::Control(
                                            Control::WBackPressureEnabled,
                                        ),
                                        true,
                                    )
                                }
                                Err(RecvError::Decoder(err)) => {
                                    log::trace!(
                                        "{}: Decoder error, stopping dispatcher: {:?}",
                                        slf.shared.io.tag(),
                                        err
                                    );
                                    slf.st = DispatcherState::Stop;
                                    (DispatchItem::Stop(Reason::Decoder(err)), true)
                                }
                                Err(RecvError::PeerGone(err)) => {
                                    log::trace!(
                                        "{}: Peer is gone, stopping dispatcher: {:?}",
                                        slf.shared.io.tag(),
                                        err
                                    );
                                    slf.st = DispatcherState::Stop;
                                    (DispatchItem::Stop(Reason::Io(err)), true)
                                }
                            }
                        }
                        PollService::Item(item) => (item, true),
                        PollService::ItemWait(item) => (item, false),
                        PollService::Continue => continue,
                    };

                    slf.call_service(cx, item, nowait);
                }
                // handle write back-pressure
                // (forward any pending item first, then wait for the flush)
                DispatcherState::Backpressure => {
                    match ready!(slf.poll_service(cx)) {
                        PollService::Ready => (),
                        PollService::Item(item) => slf.call_service(cx, item, true),
                        PollService::ItemWait(item) => slf.call_service(cx, item, false),
                        PollService::Continue => continue,
                    };

                    let item = if let Err(err) = ready!(slf.shared.io.poll_flush(cx, false))
                    {
                        slf.st = DispatcherState::Stop;
                        DispatchItem::Stop(Reason::Io(Some(err)))
                    } else {
                        slf.st = DispatcherState::Processing;
                        DispatchItem::Control(Control::WBackPressureDisabled)
                    };
                    slf.call_service(cx, item, false);
                }
                // drain service responses and shutdown io
                DispatcherState::Stop => {
                    slf.shared.io.stop_timer();

                    // service may rely on poll_ready for response results;
                    // READY_ERR stops re-polling readiness after it failed once
                    if !slf.shared.contains(Flags::READY_ERR)
                        && let Poll::Ready(res) = slf.shared.service.poll_ready(cx)
                        && res.is_err()
                    {
                        slf.shared.insert_flags(Flags::READY_ERR);
                    }

                    if slf.shared.inflight.get() == 0 {
                        // all calls finished: shut down io, then the service
                        if slf.shared.io.poll_shutdown(cx).is_ready() {
                            slf.st = DispatcherState::Shutdown;
                            continue;
                        }
                    } else if !slf.shared.contains(Flags::IO_ERR) {
                        // calls still in flight: keep servicing io status
                        // updates until io fails
                        match ready!(slf.shared.io.poll_status_update(cx)) {
                            IoStatusUpdate::PeerGone(_) | IoStatusUpdate::KeepAlive => {
                                slf.shared.insert_flags(Flags::IO_ERR);
                                continue;
                            }
                            IoStatusUpdate::WriteBackpressure => {
                                if ready!(slf.shared.io.poll_flush(cx, true)).is_err() {
                                    slf.shared.insert_flags(Flags::IO_ERR);
                                }
                                continue;
                            }
                        }
                    } else {
                        // io already failed; wait for in-flight calls, which
                        // call `io.wake()` when they finish
                        slf.shared.io.poll_dispatch(cx);
                    }
                    return Poll::Pending;
                }
                // shutdown service
                DispatcherState::Shutdown => {
                    return if slf.shared.service.poll_shutdown(cx).is_ready() {
                        log::trace!(
                            "{}: Service shutdown is completed, stop",
                            slf.shared.io.tag()
                        );

                        Poll::Ready(if let Some(err) = slf.error.take() {
                            Err(err)
                        } else {
                            Ok(())
                        })
                    } else {
                        Poll::Pending
                    };
                }
            }
        }
    }
}
371
372impl<S, U> DispatcherInner<S, U>
373where
374    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
375    U: Decoder + Encoder + 'static,
376{
377    fn call_service(&mut self, cx: &mut Context<'_>, item: DispatchItem<U>, nowait: bool) {
378        let mut fut = if nowait {
379            self.shared.service.call_nowait(item)
380        } else {
381            self.shared.service.call(item)
382        };
383        let inflight = self.shared.inflight.get() + 1;
384        self.shared.inflight.set(inflight);
385        if inflight == 1 {
386            self.shared.remove_flags(Flags::IDLE);
387        }
388
389        // optimize first call
390        if self.response.is_none() {
391            if let Poll::Ready(result) = Pin::new(&mut fut).poll(cx) {
392                self.shared.handle_result(result, &self.shared.io, false);
393            } else {
394                self.response = Some(fut);
395            }
396        } else {
397            let shared = self.shared.clone();
398            spawn(async move {
399                let result = fut.await;
400                shared.handle_result(result, &shared.io, true);
401            });
402        }
403    }
404
405    fn check_error(&mut self) -> PollService<U> {
406        // check for errors
407        if let Some(err) = self.shared.error.take() {
408            log::trace!(
409                "{}: Error occured, stopping dispatcher",
410                self.shared.io.tag()
411            );
412            self.st = DispatcherState::Stop;
413
414            match err {
415                DispatcherError::Encoder(err) => {
416                    PollService::Item(DispatchItem::Stop(Reason::Encoder(err)))
417                }
418                DispatcherError::Service(err) => {
419                    self.error = Some(err);
420                    PollService::Continue
421                }
422            }
423        } else {
424            PollService::Ready
425        }
426    }
427
428    fn poll_service(&mut self, cx: &mut Context<'_>) -> Poll<PollService<U>> {
429        // wait until service becomes ready
430        match self.shared.service.poll_ready(cx) {
431            Poll::Ready(Ok(_)) => Poll::Ready(self.check_error()),
432            // pause io read task
433            Poll::Pending => {
434                log::trace!(
435                    "{}: Service is not ready, register dispatcher",
436                    self.shared.io.tag()
437                );
438
439                // remove all timers
440                self.shared
441                    .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
442                self.shared.io.stop_timer();
443
444                match ready!(self.shared.io.poll_read_pause(cx)) {
445                    IoStatusUpdate::KeepAlive => {
446                        log::trace!(
447                            "{}: Keep-alive error, stopping dispatcher during pause",
448                            self.shared.io.tag()
449                        );
450                        self.st = DispatcherState::Stop;
451                        Poll::Ready(PollService::ItemWait(DispatchItem::Stop(
452                            Reason::KeepAliveTimeout,
453                        )))
454                    }
455                    IoStatusUpdate::PeerGone(err) => {
456                        log::trace!(
457                            "{}: Peer is gone during pause, stopping dispatcher: {:?}",
458                            self.shared.io.tag(),
459                            err
460                        );
461                        self.st = DispatcherState::Stop;
462                        Poll::Ready(PollService::ItemWait(DispatchItem::Stop(Reason::Io(
463                            err,
464                        ))))
465                    }
466                    IoStatusUpdate::WriteBackpressure => {
467                        self.st = DispatcherState::Backpressure;
468                        Poll::Ready(PollService::ItemWait(DispatchItem::Control(
469                            Control::WBackPressureEnabled,
470                        )))
471                    }
472                }
473            }
474            // handle service readiness error
475            Poll::Ready(Err(err)) => {
476                log::trace!(
477                    "{}: Service readiness check failed, stopping",
478                    self.shared.io.tag()
479                );
480                self.st = DispatcherState::Stop;
481                self.error = Some(err);
482                self.shared.insert_flags(Flags::READY_ERR);
483                Poll::Ready(PollService::Continue)
484            }
485        }
486    }
487
488    fn update_timer(&mut self, decoded: &Decoded<<U as Decoder>::Item>) {
489        // got parsed frame
490        if decoded.item.is_some() {
491            self.read_remains = 0;
492            self.shared
493                .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
494        } else if self.shared.contains(Flags::READ_TIMEOUT) {
495            // received new data but not enough for parsing complete frame
496            self.read_remains = decoded.remains as u32;
497        } else if self.read_remains == 0 && decoded.remains == 0 {
498            // no new data, start keep-alive timer
499            if self.shared.contains(Flags::KA_ENABLED)
500                && !self.shared.contains(Flags::KA_TIMEOUT)
501            {
502                log::trace!(
503                    "{}: Start keep-alive timer {:?}",
504                    self.shared.io.tag(),
505                    self.shared.io.cfg().keepalive_timeout()
506                );
507                self.shared.insert_flags(Flags::KA_TIMEOUT);
508                self.shared
509                    .io
510                    .start_timer(self.shared.io.cfg().keepalive_timeout());
511            }
512        } else if let Some(params) = self.shared.io.cfg().frame_read_rate() {
513            // we got new data but not enough to parse single frame
514            // start read timer
515            self.shared.insert_flags(Flags::READ_TIMEOUT);
516
517            self.read_remains = decoded.remains as u32;
518            self.read_remains_prev = 0;
519            self.read_max_timeout = params.max_timeout;
520            self.shared.io.start_timer(params.timeout);
521        }
522    }
523
524    fn handle_timeout(&mut self) -> Result<(), Reason<U>> {
525        // check read timer
526        if self.shared.contains(Flags::READ_TIMEOUT) {
527            if let Some(params) = self.shared.io.cfg().frame_read_rate() {
528                let total = self.read_remains - self.read_remains_prev;
529
530                // read rate, start timer for next period
531                if total > params.rate {
532                    self.read_remains_prev = self.read_remains;
533                    self.read_remains = 0;
534
535                    if !params.max_timeout.is_zero() {
536                        self.read_max_timeout = Seconds(
537                            self.read_max_timeout.0.saturating_sub(params.timeout.0),
538                        );
539                    }
540
541                    if params.max_timeout.is_zero() || !self.read_max_timeout.is_zero() {
542                        log::trace!(
543                            "{}: Frame read rate {:?}, extend timer",
544                            self.shared.io.tag(),
545                            total
546                        );
547                        self.shared.io.start_timer(params.timeout);
548                        return Ok(());
549                    }
550                    log::trace!(
551                        "{}: Max payload timeout has been reached",
552                        self.shared.io.tag()
553                    );
554                }
555                Err(Reason::ReadTimeout)
556            } else {
557                Ok(())
558            }
559        } else if self.shared.contains(Flags::KA_TIMEOUT | Flags::IDLE) {
560            log::trace!(
561                "{}: Keep-alive error, stopping dispatcher",
562                self.shared.io.tag()
563            );
564            Err(Reason::KeepAliveTimeout)
565        } else {
566            Ok(())
567        }
568    }
569}
570
571impl<U> fmt::Debug for DispatchItem<U>
572where
573    U: Encoder + Decoder,
574    <U as Decoder>::Item: fmt::Debug,
575{
576    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
577        match *self {
578            DispatchItem::Item(ref item) => {
579                write!(fmt, "DispatchItem::Item({item:?})")
580            }
581            DispatchItem::Control(ref e) => {
582                write!(fmt, "DispatchItem::Control({e:?})")
583            }
584            DispatchItem::Stop(ref e) => {
585                write!(fmt, "DispatchItem::Stop({e:?})")
586            }
587        }
588    }
589}
590
591impl<U> fmt::Debug for Reason<U>
592where
593    U: Encoder + Decoder,
594    <U as Encoder>::Error: fmt::Debug,
595    <U as Decoder>::Error: fmt::Debug,
596{
597    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
598        match *self {
599            Reason::Io(ref err) => {
600                write!(fmt, "Reason::Io({err:?})")
601            }
602            Reason::Encoder(ref err) => {
603                write!(fmt, "Reason::Encoder({err:?})")
604            }
605            Reason::Decoder(ref err) => {
606                write!(fmt, "Reason::Decoder({err:?})")
607            }
608            Reason::KeepAliveTimeout => {
609                write!(fmt, "Reason::KeepAliveTimeout")
610            }
611            Reason::ReadTimeout => {
612                write!(fmt, "Reason::ReadTimeout")
613            }
614        }
615    }
616}
617
618#[cfg(test)]
619mod tests {
620    use std::sync::{Arc, Mutex, atomic::AtomicBool, atomic::Ordering::Relaxed};
621    use std::{cell::RefCell, io};
622
623    use ntex_bytes::{Bytes, BytesMut};
624    use ntex_codec::BytesCodec;
625    use ntex_io::{Flags, Io, IoConfig, IoRef, testing::IoTest};
626    use ntex_service::{ServiceCtx, cfg::SharedCfg};
627    use ntex_util::{time::Millis, time::sleep};
628    use rand::Rng;
629
630    use super::*;
631
632    pub(crate) struct State(IoRef);
633
634    impl State {
635        fn flags(&self) -> Flags {
636            self.0.flags()
637        }
638
639        fn io(&self) -> &IoRef {
640            &self.0
641        }
642
643        fn close(&self) {
644            self.0.close();
645        }
646    }
647
648    #[derive(Copy, Clone)]
649    struct BCodec(usize);
650
651    impl Encoder for BCodec {
652        type Item = Bytes;
653        type Error = io::Error;
654
655        fn encode(&self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
656            dst.extend_from_slice(&item[..]);
657            Ok(())
658        }
659    }
660
661    impl Decoder for BCodec {
662        type Item = Bytes;
663        type Error = io::Error;
664
665        fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
666            if src.len() < self.0 {
667                Ok(None)
668            } else {
669                Ok(Some(src.split_to(self.0)))
670            }
671        }
672    }
673
674    impl<S, U> Dispatcher<S, U>
675    where
676        S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
677        U: Decoder + Encoder + 'static,
678    {
679        /// Construct new `Dispatcher` instance
680        pub(crate) fn debug(io: Io, codec: U, service: S) -> (Self, State) {
681            let flags = if io.cfg().keepalive_timeout().is_zero() {
682                super::Flags::empty()
683            } else {
684                super::Flags::KA_ENABLED
685            };
686
687            let inner = State(io.get_ref());
688            io.start_timer(Seconds::ONE);
689
690            let shared = Rc::new(DispatcherShared {
691                codec,
692                io: io.into(),
693                flags: Cell::new(flags),
694                error: Cell::new(None),
695                inflight: Cell::new(0),
696                service: Pipeline::new(service).bind(),
697            });
698
699            (
700                Dispatcher {
701                    inner: DispatcherInner {
702                        shared,
703                        error: None,
704                        st: DispatcherState::Processing,
705                        response: None,
706                        read_remains: 0,
707                        read_remains_prev: 0,
708                        read_max_timeout: Seconds::ZERO,
709                    },
710                },
711                inner,
712            )
713        }
714    }
715
716    #[ntex::test]
717    async fn basics() {
718        let (client, server) = IoTest::create();
719        client.remote_buffer_cap(1024);
720        client.write("GET /test HTTP/1\r\n\r\n");
721
722        let (disp, _) = Dispatcher::debug(
723            Io::from(server),
724            BytesCodec,
725            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
726                sleep(Millis(50)).await;
727                if let DispatchItem::Item(msg) = msg {
728                    Ok::<_, ()>(Some(msg))
729                } else {
730                    panic!()
731                }
732            }),
733        );
734        spawn(async move {
735            let _ = disp.await;
736        });
737
738        sleep(Millis(25)).await;
739        let buf = client.read().await.unwrap();
740        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
741
742        client.write("GET /test HTTP/1\r\n\r\n");
743        let buf = client.read().await.unwrap();
744        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
745
746        client.close().await;
747        assert!(client.is_server_dropped());
748
749        assert!(format!("{:?}", super::Flags::KA_TIMEOUT.clone()).contains("KA_TIMEOUT"));
750    }
751
752    #[ntex::test]
753    async fn sink() {
754        let (client, server) = IoTest::create();
755        client.remote_buffer_cap(1024);
756        client.write("GET /test HTTP/1\r\n\r\n");
757
758        let (disp, st) = Dispatcher::debug(
759            Io::from(server),
760            BytesCodec,
761            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
762                if let DispatchItem::Item(msg) = msg {
763                    Ok::<_, ()>(Some(msg))
764                } else if let DispatchItem::Stop(_) = msg {
765                    Ok(None)
766                } else {
767                    panic!()
768                }
769            }),
770        );
771        spawn(async move {
772            let _ = disp.await;
773        });
774
775        let buf = client.read().await.unwrap();
776        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
777
778        assert!(
779            st.io()
780                .encode(Bytes::from_static(b"test"), &BytesCodec)
781                .is_ok()
782        );
783        let buf = client.read().await.unwrap();
784        assert_eq!(buf, Bytes::from_static(b"test"));
785
786        st.close();
787        sleep(Millis(1500)).await;
788        assert!(client.is_server_dropped());
789    }
790
791    #[ntex::test]
792    async fn err_in_service() {
793        let (client, server) = IoTest::create();
794        client.remote_buffer_cap(0);
795        client.write("GET /test HTTP/1\r\n\r\n");
796
797        let (disp, state) = Dispatcher::debug(
798            Io::from(server),
799            BytesCodec,
800            ntex_service::fn_service(|_: DispatchItem<BytesCodec>| async move {
801                Err::<Option<Bytes>, _>(())
802            }),
803        );
804        state
805            .io()
806            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
807            .unwrap();
808        spawn(async move {
809            let _ = disp.await;
810        });
811
812        // buffer should be flushed
813        client.remote_buffer_cap(1024);
814        let buf = client.read().await.unwrap();
815        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
816
817        // write side must be closed, dispatcher waiting for read side to close
818        sleep(Millis(250)).await;
819        assert!(client.is_closed());
820
821        // close read side
822        client.close().await;
823
824        // dispatcher is closed
825        assert!(client.is_server_dropped());
826    }
827
828    #[ntex::test]
829    async fn err_in_service_ready() {
830        let (client, server) = IoTest::create();
831        client.remote_buffer_cap(0);
832        client.write("GET /test HTTP/1\r\n\r\n");
833
834        let counter = Rc::new(Cell::new(0));
835
836        struct Srv(Rc<Cell<usize>>);
837
838        impl Service<DispatchItem<BytesCodec>> for Srv {
839            type Response = Option<Response<BytesCodec>>;
840            type Error = &'static str;
841
842            async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
843                self.0.set(self.0.get() + 1);
844                Err("test")
845            }
846
847            async fn call(
848                &self,
849                _: DispatchItem<BytesCodec>,
850                _: ServiceCtx<'_, Self>,
851            ) -> Result<Self::Response, Self::Error> {
852                Ok(None)
853            }
854        }
855
856        let (disp, state) =
857            Dispatcher::debug(Io::from(server), BytesCodec, Srv(counter.clone()));
858        spawn(async move {
859            let res = disp.await;
860            assert_eq!(res, Err("test"));
861        });
862
863        state
864            .io()
865            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
866            .unwrap();
867
868        // buffer should be flushed
869        client.remote_buffer_cap(1024);
870        let buf = client.read().await.unwrap();
871        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
872
873        // write side must be closed, dispatcher waiting for read side to close
874        sleep(Millis(250)).await;
875        assert!(client.is_closed());
876
877        // close read side
878        client.close().await;
879        assert!(client.is_server_dropped());
880
881        // service must be checked for readiness only once
882        assert_eq!(counter.get(), 1);
883    }
884
    /// Write back-pressure: a response much larger than the peer will accept
    /// must produce `Control::WBackPressureEnabled`, and draining the write
    /// buffer afterwards must produce `Control::WBackPressureDisabled`.
    #[ntex::test]
    async fn write_backpressure() {
        let (client, server) = IoTest::create();
        // do not allow to write to socket
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        // Events recorded by the service, in order:
        // 0 = item received, 1 = back-pressure enabled, 2 = back-pressure disabled.
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        // NOTE(review): presumably the first two `set_write_buf` arguments are
        // the high/low watermarks that toggle write back-pressure — confirm
        // against `IoConfig`.
        let io = Io::new(
            server,
            SharedCfg::new("TEST").add(
                IoConfig::new()
                    .set_read_buf(8 * 1024, 1024, 16)
                    .set_write_buf(16 * 1024, 1024, 16),
            ),
        );

        let (disp, state) = Dispatcher::debug(
            io,
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(_) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            // Reply with 65_536 random alphanumeric bytes, which
                            // cannot be flushed while the peer's buffer cap is 0.
                            let bytes = rand::rng()
                                .sample_iter(&rand::distr::Alphanumeric)
                                .take(65_536)
                                .map(char::from)
                                .collect::<String>();
                            return Ok::<_, ()>(Some(Bytes::from(bytes)));
                        }
                        DispatchItem::Control(Control::WBackPressureEnabled) => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        DispatchItem::Control(Control::WBackPressureDisabled) => {
                            data.lock().unwrap().borrow_mut().push(2);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );

        spawn(async move {
            let _ = disp.await;
        });

        // nothing has reached the peer yet
        let buf = client.read_any();
        assert_eq!(buf, Bytes::from_static(b""));
        client.write("GET /test HTTP/1\r\n\r\n");
        sleep(Millis(25)).await;

        // buf must be consumed
        assert_eq!(client.remote_buffer(|buf| buf.len()), 0);

        // response message: all 65536 bytes are still sitting in the write buffer
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 65536);

        // after raising the cap the write buffer shrinks by 10240 bytes
        // (65536 - 10240 = 55296)
        client.remote_buffer_cap(10240);
        sleep(Millis(50)).await;
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 55296);

        // ...and then by another 48056 bytes (55296 - 48056 = 7240)
        client.remote_buffer_cap(48056);
        sleep(Millis(50)).await;
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 7240);

        // backpressure disabled: one item, then enabled, then disabled
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]);
    }
959
960    #[ntex::test]
961    async fn disconnect_during_read_backpressure() {
962        let (client, server) = IoTest::create();
963        client.remote_buffer_cap(0);
964
965        let (disp, state) = Dispatcher::debug(
966            Io::new(
967                server,
968                SharedCfg::new("TEST").add(
969                    IoConfig::new()
970                        .set_keepalive_timeout(Seconds::ZERO)
971                        .set_read_buf(1024, 512, 16),
972                ),
973            ),
974            BytesCodec,
975            ntex_util::services::inflight::InFlightService::new(
976                1,
977                ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| async move {
978                    if let DispatchItem::Item(_) = msg {
979                        sleep(Millis(500)).await;
980                        Ok::<_, ()>(None)
981                    } else {
982                        Ok(None)
983                    }
984                }),
985            ),
986        );
987
988        let (tx, rx) = ntex::channel::oneshot::channel();
989        ntex::rt::spawn(async move {
990            let _ = disp.await;
991            let _ = tx.send(());
992        });
993
994        let bytes = rand::rng()
995            .sample_iter(&rand::distr::Alphanumeric)
996            .take(1024)
997            .map(char::from)
998            .collect::<String>();
999        client.write(bytes.clone());
1000        sleep(Millis(25)).await;
1001        client.write(bytes);
1002        sleep(Millis(25)).await;
1003
1004        // close read side
1005        state.close();
1006        let _ = rx.recv().await;
1007    }
1008
1009    #[ntex::test]
1010    async fn keepalive() {
1011        let (client, server) = IoTest::create();
1012        client.remote_buffer_cap(1024);
1013        client.write("GET /test HTTP/1\r\n\r\n");
1014
1015        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1016        let data2 = data.clone();
1017
1018        let cfg = SharedCfg::new("DBG").add(
1019            IoConfig::new()
1020                .set_disconnect_timeout(Seconds(1))
1021                .set_keepalive_timeout(Seconds(1)),
1022        );
1023
1024        let (disp, state) = Dispatcher::debug(
1025            Io::new(server, cfg),
1026            BytesCodec,
1027            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
1028                let data = data2.clone();
1029                async move {
1030                    match msg {
1031                        DispatchItem::Item(bytes) => {
1032                            data.lock().unwrap().borrow_mut().push(0);
1033                            return Ok::<_, ()>(Some(bytes));
1034                        }
1035                        DispatchItem::Stop(Reason::KeepAliveTimeout) => {
1036                            data.lock().unwrap().borrow_mut().push(1);
1037                        }
1038                        _ => (),
1039                    }
1040                    Ok(None)
1041                }
1042            }),
1043        );
1044        spawn(async move {
1045            let _ = disp.await;
1046        });
1047
1048        let buf = client.read().await.unwrap();
1049        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
1050        sleep(Millis(2000)).await;
1051
1052        // write side must be closed, dispatcher should fail with keep-alive
1053        let flags = state.flags();
1054        assert!(flags.contains(Flags::IO_STOPPING));
1055        assert!(client.is_closed());
1056        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
1057    }
1058
1059    #[ntex::test]
1060    async fn keepalive2() {
1061        let (client, server) = IoTest::create();
1062        client.remote_buffer_cap(1024);
1063
1064        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1065        let data2 = data.clone();
1066
1067        let cfg = SharedCfg::new("DBG").add(
1068            IoConfig::new()
1069                .set_keepalive_timeout(Seconds(1))
1070                .set_frame_read_rate(Seconds(1), Seconds(2), 2),
1071        );
1072
1073        let (disp, state) = Dispatcher::debug(
1074            Io::new(server, cfg),
1075            BCodec(8),
1076            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
1077                let data = data2.clone();
1078                async move {
1079                    match msg {
1080                        DispatchItem::Item(bytes) => {
1081                            data.lock().unwrap().borrow_mut().push(0);
1082                            return Ok::<_, ()>(Some(bytes));
1083                        }
1084                        DispatchItem::Stop(Reason::KeepAliveTimeout) => {
1085                            data.lock().unwrap().borrow_mut().push(1);
1086                        }
1087                        _ => (),
1088                    }
1089                    Ok(None)
1090                }
1091            }),
1092        );
1093        spawn(async move {
1094            let _ = disp.await;
1095        });
1096
1097        client.write("12345678");
1098        let buf = client.read().await.unwrap();
1099        assert_eq!(buf, Bytes::from_static(b"12345678"));
1100        sleep(Millis(2000)).await;
1101
1102        // write side must be closed, dispatcher should fail with keep-alive
1103        let flags = state.flags();
1104        assert!(flags.contains(Flags::IO_STOPPING));
1105        assert!(client.is_closed());
1106        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
1107    }
1108
1109    /// Update keep-alive timer after receiving frame
1110    #[ntex::test]
1111    async fn keepalive3() {
1112        let (client, server) = IoTest::create();
1113        client.remote_buffer_cap(1024);
1114
1115        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1116        let data2 = data.clone();
1117
1118        let cfg = SharedCfg::new("DBG").add(
1119            IoConfig::new()
1120                .set_keepalive_timeout(Seconds(2))
1121                .set_frame_read_rate(Seconds(1), Seconds(2), 2),
1122        );
1123
1124        let (disp, _) = Dispatcher::debug(
1125            Io::new(server, cfg),
1126            BCodec(1),
1127            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
1128                let data = data2.clone();
1129                async move {
1130                    match msg {
1131                        DispatchItem::Item(bytes) => {
1132                            data.lock().unwrap().borrow_mut().push(0);
1133                            return Ok::<_, ()>(Some(bytes));
1134                        }
1135                        DispatchItem::Stop(Reason::KeepAliveTimeout) => {
1136                            data.lock().unwrap().borrow_mut().push(1);
1137                        }
1138                        _ => (),
1139                    }
1140                    Ok(None)
1141                }
1142            }),
1143        );
1144        spawn(async move {
1145            let _ = disp.await;
1146        });
1147
1148        client.write("1");
1149        let buf = client.read().await.unwrap();
1150        assert_eq!(buf, Bytes::from_static(b"1"));
1151        sleep(Millis(750)).await;
1152
1153        client.write("2");
1154        let buf = client.read().await.unwrap();
1155        assert_eq!(buf, Bytes::from_static(b"2"));
1156
1157        sleep(Millis(750)).await;
1158        client.write("3");
1159        let buf = client.read().await.unwrap();
1160        assert_eq!(buf, Bytes::from_static(b"3"));
1161
1162        sleep(Millis(750)).await;
1163        assert!(!client.is_closed());
1164        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 0, 0]);
1165    }
1166
1167    #[ntex::test]
1168    async fn read_timeout() {
1169        let (client, server) = IoTest::create();
1170        client.remote_buffer_cap(1024);
1171
1172        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1173        let data2 = data.clone();
1174
1175        let io = Io::new(
1176            server,
1177            SharedCfg::new("TEST").add(
1178                IoConfig::new()
1179                    .set_keepalive_timeout(Seconds::ZERO)
1180                    .set_frame_read_rate(Seconds(1), Seconds(2), 2),
1181            ),
1182        );
1183
1184        let (disp, state) = Dispatcher::debug(
1185            io,
1186            BCodec(8),
1187            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
1188                let data = data2.clone();
1189                async move {
1190                    match msg {
1191                        DispatchItem::Item(bytes) => {
1192                            data.lock().unwrap().borrow_mut().push(0);
1193                            return Ok::<_, ()>(Some(bytes));
1194                        }
1195                        DispatchItem::Stop(Reason::ReadTimeout) => {
1196                            data.lock().unwrap().borrow_mut().push(1);
1197                        }
1198                        _ => (),
1199                    }
1200                    Ok(None)
1201                }
1202            }),
1203        );
1204        spawn(async move {
1205            let _ = disp.await;
1206        });
1207
1208        client.write("12345678");
1209        let buf = client.read().await.unwrap();
1210        assert_eq!(buf, Bytes::from_static(b"12345678"));
1211
1212        client.write("1");
1213        sleep(Millis(1000)).await;
1214        assert!(!state.flags().contains(Flags::IO_STOPPING));
1215        client.write("23");
1216        sleep(Millis(1000)).await;
1217        assert!(!state.flags().contains(Flags::IO_STOPPING));
1218        client.write("4");
1219        sleep(Millis(2000)).await;
1220
1221        // write side must be closed, dispatcher should fail with keep-alive
1222        assert!(state.flags().contains(Flags::IO_STOPPING));
1223        assert!(client.is_closed());
1224        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
1225    }
1226
1227    #[ntex::test]
1228    async fn idle_timeout() {
1229        let (client, server) = IoTest::create();
1230        client.remote_buffer_cap(1024);
1231
1232        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1233        let data2 = data.clone();
1234
1235        let io = Io::new(
1236            server,
1237            SharedCfg::new("DBG").add(IoConfig::new().set_keepalive_timeout(Seconds::ZERO)),
1238        );
1239        let ioref = io.get_ref();
1240
1241        let (disp, state) = Dispatcher::debug(
1242            io,
1243            BCodec(1),
1244            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
1245                let ioref = ioref.clone();
1246                ntex::rt::spawn(async move {
1247                    sleep(Millis(500)).await;
1248                    ioref.notify_timeout();
1249                });
1250                let data = data2.clone();
1251                async move {
1252                    match msg {
1253                        DispatchItem::Item(bytes) => {
1254                            data.lock().unwrap().borrow_mut().push(0);
1255                            return Ok::<_, ()>(Some(bytes));
1256                        }
1257                        DispatchItem::Stop(Reason::ReadTimeout) => {
1258                            data.lock().unwrap().borrow_mut().push(1);
1259                        }
1260                        _ => (),
1261                    }
1262                    Ok(None)
1263                }
1264            }),
1265        );
1266        spawn(async move {
1267            let _ = disp.await;
1268        });
1269
1270        client.write("1");
1271        let buf = client.read().await.unwrap();
1272        assert_eq!(buf, Bytes::from_static(b"1"));
1273
1274        sleep(Millis(1000)).await;
1275        assert!(state.flags().contains(Flags::IO_STOPPING));
1276        assert!(client.is_closed());
1277    }
1278
1279    #[ntex::test]
1280    async fn unhandled_data() {
1281        let handled = Arc::new(AtomicBool::new(false));
1282        let handled2 = handled.clone();
1283
1284        let (client, server) = IoTest::create();
1285        client.remote_buffer_cap(1024);
1286        client.write("GET /test HTTP/1\r\n\r\n");
1287
1288        let (disp, _) = Dispatcher::debug(
1289            Io::from(server),
1290            BytesCodec,
1291            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
1292                handled2.store(true, Relaxed);
1293                async move {
1294                    sleep(Millis(50)).await;
1295                    if let DispatchItem::Item(msg) = msg {
1296                        Ok::<_, ()>(Some(msg))
1297                    } else if let DispatchItem::Stop(_) = msg {
1298                        Ok::<_, ()>(None)
1299                    } else {
1300                        panic!()
1301                    }
1302                }
1303            }),
1304        );
1305        client.close().await;
1306        spawn(async move {
1307            let _ = disp.await;
1308        });
1309        sleep(Millis(50)).await;
1310
1311        assert!(handled.load(Relaxed));
1312    }
1313}