ntex_io/
dispatcher.rs

1//! Framed transport dispatcher
2#![allow(clippy::let_underscore_future)]
3use std::task::{Context, Poll, ready};
4use std::{cell::Cell, future::Future, pin::Pin, rc::Rc};
5
6use ntex_codec::{Decoder, Encoder};
7use ntex_service::{IntoService, Pipeline, PipelineBinding, PipelineCall, Service};
8use ntex_util::{future::Either, spawn, time::Seconds};
9
10use crate::{Decoded, DispatchItem, IoBoxed, IoStatusUpdate, RecvError};
11
/// Item type of the codec's [`Encoder`] side; the dispatcher's service
/// returns `Option<Response<U>>` and a returned value is encoded to the io.
type Response<U> = <U as Encoder>::Item;
13
pin_project_lite::pin_project! {
    /// Dispatcher is a future that reads frames from a byte stream
    /// and passes them to the service.
    ///
    /// The future resolves once the dispatcher has been stopped and the
    /// service shutdown has completed; it yields the service error, if one
    /// was recorded.
    pub struct Dispatcher<S, U>
    where
        S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
        U: Encoder,
        U: Decoder,
        U: 'static,
    {
        inner: DispatcherInner<S, U>,
    }
}
27
bitflags::bitflags! {
    // Internal dispatcher status flags, stored in `DispatcherShared::flags`.
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    struct Flags: u8  {
        // service readiness check returned an error
        const READY_ERR     = 0b0000001;
        // io failed while the dispatcher was draining in the `Stop` state
        const IO_ERR        = 0b0000010;
        // keep-alive timeout is configured (non-zero)
        const KA_ENABLED    = 0b0000100;
        // keep-alive timer has been started
        const KA_TIMEOUT    = 0b0001000;
        // frame read-rate timer has been started
        const READ_TIMEOUT  = 0b0010000;
        // set when the in-flight counter drops to zero
        const IDLE          = 0b0100000;
    }
}
39
/// Dispatcher state owned by the [`Dispatcher`] future itself.
struct DispatcherInner<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder + 'static,
{
    /// Current state of the poll state machine.
    st: DispatcherState,
    /// Service error kept until the future completes; it becomes the
    /// future's `Err` output.
    error: Option<S::Error>,
    /// State shared with spawned service-call tasks.
    shared: Rc<DispatcherShared<S, U>>,
    /// Service call that is polled inline by the dispatcher (at most one).
    response: Option<PipelineCall<S, DispatchItem<U>>>,
    /// Bytes remaining in the read buffer at the latest decode attempt
    /// while the read-rate timer is active.
    read_remains: u32,
    /// Value of `read_remains` at the previous read-rate timer expiry.
    read_remains_prev: u32,
    /// Remaining budget for extending the read-rate timer.
    read_max_timeout: Seconds,
}
53
/// State shared (via `Rc`) between the dispatcher future and the tasks it
/// spawns for concurrently running service calls.
pub(crate) struct DispatcherShared<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder,
{
    /// Io object frames are read from and responses are written to.
    io: IoBoxed,
    /// Codec used for both decoding and encoding.
    codec: U,
    /// Bound service pipeline that handles dispatch items.
    service: PipelineBinding<S, DispatchItem<U>>,
    /// Dispatcher status flags.
    flags: Cell<Flags>,
    /// Latest error produced by a service call or by response encoding;
    /// taken by the dispatcher on its next readiness check.
    error: Cell<Option<DispatcherError<S::Error, <U as Encoder>::Error>>>,
    /// Number of service calls that have not completed yet.
    inflight: Cell<u32>,
}
66
/// States of the dispatcher poll loop.
#[derive(Copy, Clone, Debug)]
enum DispatcherState {
    /// Normal operation: decode frames and hand them to the service.
    Processing,
    /// Write back-pressure is active; wait for the write buffer to flush.
    Backpressure,
    /// Dispatcher is stopping: drain in-flight calls, then shut down io.
    Stop,
    /// Io is shut down; wait for the service shutdown to complete.
    Shutdown,
}
74
/// Error recorded while handling the result of a service call.
#[derive(Debug)]
enum DispatcherError<S, U> {
    /// Encoding the service response failed.
    Encoder(U),
    /// The service call itself returned an error.
    Service(S),
}
80
/// Outcome of a service readiness check performed by the dispatcher.
enum PollService<U: Encoder + Decoder> {
    /// An item that must be delivered to the service.
    Item(DispatchItem<U>),
    /// Dispatcher state changed; restart the poll loop.
    Continue,
    /// Service is ready and no error is pending.
    Ready,
}
86
87impl<S, U> From<Either<S, U>> for DispatcherError<S, U> {
88    fn from(err: Either<S, U>) -> Self {
89        match err {
90            Either::Left(err) => DispatcherError::Service(err),
91            Either::Right(err) => DispatcherError::Encoder(err),
92        }
93    }
94}
95
96impl<S, U> Dispatcher<S, U>
97where
98    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
99    U: Decoder + Encoder + 'static,
100{
101    /// Construct new `Dispatcher` instance.
102    pub fn new<Io, F>(io: Io, codec: U, service: F) -> Dispatcher<S, U>
103    where
104        IoBoxed: From<Io>,
105        F: IntoService<S, DispatchItem<U>>,
106    {
107        let io = IoBoxed::from(io);
108        let flags = if io.cfg().keepalive_timeout().is_zero() {
109            Flags::empty()
110        } else {
111            Flags::KA_ENABLED
112        };
113
114        let shared = Rc::new(DispatcherShared {
115            io,
116            codec,
117            flags: Cell::new(flags),
118            error: Cell::new(None),
119            inflight: Cell::new(0),
120            service: Pipeline::new(service.into_service()).bind(),
121        });
122
123        Dispatcher {
124            inner: DispatcherInner {
125                shared,
126                response: None,
127                error: None,
128                read_remains: 0,
129                read_remains_prev: 0,
130                read_max_timeout: Seconds::ZERO,
131                st: DispatcherState::Processing,
132            },
133        }
134    }
135}
136
137impl<S, U> DispatcherShared<S, U>
138where
139    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
140    U: Encoder + Decoder,
141{
142    fn handle_result(&self, item: Result<S::Response, S::Error>, io: &IoBoxed, wake: bool) {
143        match item {
144            Ok(Some(val)) => {
145                if let Err(err) = io.encode(val, &self.codec) {
146                    self.error.set(Some(DispatcherError::Encoder(err)))
147                }
148            }
149            Err(err) => self.error.set(Some(DispatcherError::Service(err))),
150            Ok(None) => (),
151        }
152        let inflight = self.inflight.get() - 1;
153        self.inflight.set(inflight);
154        if inflight == 0 {
155            self.insert_flags(Flags::IDLE);
156        }
157        if wake {
158            io.wake();
159        }
160    }
161
162    fn contains(&self, f: Flags) -> bool {
163        self.flags.get().intersects(f)
164    }
165
166    fn insert_flags(&self, f: Flags) {
167        let mut flags = self.flags.get();
168        flags.insert(f);
169        self.flags.set(flags);
170    }
171
172    fn remove_flags(&self, f: Flags) -> bool {
173        let mut flags = self.flags.get();
174        if flags.intersects(f) {
175            flags.remove(f);
176            self.flags.set(flags);
177            true
178        } else {
179            false
180        }
181    }
182}
183
/// Drives the dispatcher state machine: `Processing` -> (`Backpressure`) ->
/// `Stop` -> `Shutdown`. Completes with the recorded service error, if any,
/// once the service shutdown has finished.
impl<S, U> Future for Dispatcher<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
    U: Decoder + Encoder + 'static,
{
    type Output = Result<(), S::Error>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.as_mut().project();
        let slf = &mut this.inner;

        // handle service response future
        // (the one call that is polled inline instead of being spawned)
        if let Some(fut) = slf.response.as_mut() {
            if let Poll::Ready(item) = Pin::new(fut).poll(cx) {
                slf.shared.handle_result(item, &slf.shared.io, false);
                slf.response = None;
            }
        }

        loop {
            match slf.st {
                DispatcherState::Processing => {
                    // `ready!` returns `Poll::Pending` from `poll` while the
                    // service is not ready and io reports nothing
                    let item = match ready!(slf.poll_service(cx)) {
                        PollService::Ready => {
                            // decode incoming bytes if buffer is ready
                            match slf.shared.io.poll_recv_decode(&slf.shared.codec, cx) {
                                Ok(decoded) => {
                                    slf.update_timer(&decoded);
                                    if let Some(el) = decoded.item {
                                        DispatchItem::Item(el)
                                    } else {
                                        // no complete frame yet
                                        return Poll::Pending;
                                    }
                                }
                                Err(RecvError::KeepAlive) => {
                                    // timer fired: either a real timeout
                                    // (stop) or the timer was extended
                                    if let Err(err) = slf.handle_timeout() {
                                        slf.st = DispatcherState::Stop;
                                        err
                                    } else {
                                        continue;
                                    }
                                }
                                Err(RecvError::WriteBackpressure) => {
                                    // instruct write task to notify dispatcher when data is flushed
                                    slf.st = DispatcherState::Backpressure;
                                    DispatchItem::WBackPressureEnabled
                                }
                                Err(RecvError::Decoder(err)) => {
                                    log::trace!(
                                        "{}: Decoder error, stopping dispatcher: {:?}",
                                        slf.shared.io.tag(),
                                        err
                                    );
                                    slf.st = DispatcherState::Stop;
                                    DispatchItem::DecoderError(err)
                                }
                                Err(RecvError::PeerGone(err)) => {
                                    log::trace!(
                                        "{}: Peer is gone, stopping dispatcher: {:?}",
                                        slf.shared.io.tag(),
                                        err
                                    );
                                    slf.st = DispatcherState::Stop;
                                    DispatchItem::Disconnect(err)
                                }
                            }
                        }
                        PollService::Item(item) => item,
                        PollService::Continue => continue,
                    };

                    slf.call_service(cx, item);
                }
                // handle write back-pressure
                DispatcherState::Backpressure => {
                    match ready!(slf.poll_service(cx)) {
                        PollService::Ready => (),
                        PollService::Item(item) => slf.call_service(cx, item),
                        PollService::Continue => continue,
                    };

                    // wait for the flush; then tell the service whether
                    // back-pressure is over or the io failed
                    let item = if let Err(err) = ready!(slf.shared.io.poll_flush(cx, false))
                    {
                        slf.st = DispatcherState::Stop;
                        DispatchItem::Disconnect(Some(err))
                    } else {
                        slf.st = DispatcherState::Processing;
                        DispatchItem::WBackPressureDisabled
                    };
                    slf.call_service(cx, item);
                }
                // drain service responses and shutdown io
                DispatcherState::Stop => {
                    slf.shared.io.stop_timer();

                    // service may rely on poll_ready for response results
                    // (skipped once readiness has already reported an error)
                    if !slf.shared.contains(Flags::READY_ERR) {
                        if let Poll::Ready(res) = slf.shared.service.poll_ready(cx) {
                            if res.is_err() {
                                slf.shared.insert_flags(Flags::READY_ERR);
                            }
                        }
                    }

                    if slf.shared.inflight.get() == 0 {
                        // nothing in flight: shut down io, then the service
                        if slf.shared.io.poll_shutdown(cx).is_ready() {
                            slf.st = DispatcherState::Shutdown;
                            continue;
                        }
                    } else if !slf.shared.contains(Flags::IO_ERR) {
                        // calls still in flight: keep watching io status
                        match ready!(slf.shared.io.poll_status_update(cx)) {
                            IoStatusUpdate::PeerGone(_) | IoStatusUpdate::KeepAlive => {
                                slf.shared.insert_flags(Flags::IO_ERR);
                                continue;
                            }
                            IoStatusUpdate::WriteBackpressure => {
                                if ready!(slf.shared.io.poll_flush(cx, true)).is_err() {
                                    slf.shared.insert_flags(Flags::IO_ERR);
                                }
                                continue;
                            }
                        }
                    } else {
                        // io already failed; only register for dispatch wake-ups
                        slf.shared.io.poll_dispatch(cx);
                    }
                    return Poll::Pending;
                }
                // shutdown service
                DispatcherState::Shutdown => {
                    return if slf.shared.service.poll_shutdown(cx).is_ready() {
                        log::trace!(
                            "{}: Service shutdown is completed, stop",
                            slf.shared.io.tag()
                        );

                        // report the stored service error, if there is one
                        Poll::Ready(if let Some(err) = slf.error.take() {
                            Err(err)
                        } else {
                            Ok(())
                        })
                    } else {
                        Poll::Pending
                    };
                }
            }
        }
    }
}
332
333impl<S, U> DispatcherInner<S, U>
334where
335    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
336    U: Decoder + Encoder + 'static,
337{
338    fn call_service(&mut self, cx: &mut Context<'_>, item: DispatchItem<U>) {
339        let mut fut = self.shared.service.call_nowait(item);
340        let inflight = self.shared.inflight.get() + 1;
341        self.shared.inflight.set(inflight);
342        if inflight == 1 {
343            self.shared.remove_flags(Flags::IDLE);
344        }
345
346        // optimize first call
347        if self.response.is_none() {
348            if let Poll::Ready(result) = Pin::new(&mut fut).poll(cx) {
349                self.shared.handle_result(result, &self.shared.io, false);
350            } else {
351                self.response = Some(fut);
352            }
353        } else {
354            let shared = self.shared.clone();
355            spawn(async move {
356                let result = fut.await;
357                shared.handle_result(result, &shared.io, true);
358            });
359        }
360    }
361
362    fn check_error(&mut self) -> PollService<U> {
363        // check for errors
364        if let Some(err) = self.shared.error.take() {
365            log::trace!(
366                "{}: Error occured, stopping dispatcher",
367                self.shared.io.tag()
368            );
369            self.st = DispatcherState::Stop;
370
371            match err {
372                DispatcherError::Encoder(err) => {
373                    PollService::Item(DispatchItem::EncoderError(err))
374                }
375                DispatcherError::Service(err) => {
376                    self.error = Some(err);
377                    PollService::Continue
378                }
379            }
380        } else {
381            PollService::Ready
382        }
383    }
384
385    fn poll_service(&mut self, cx: &mut Context<'_>) -> Poll<PollService<U>> {
386        // wait until service becomes ready
387        match self.shared.service.poll_ready(cx) {
388            Poll::Ready(Ok(_)) => Poll::Ready(self.check_error()),
389            // pause io read task
390            Poll::Pending => {
391                log::trace!(
392                    "{}: Service is not ready, register dispatcher",
393                    self.shared.io.tag()
394                );
395
396                // remove all timers
397                self.shared
398                    .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
399                self.shared.io.stop_timer();
400
401                match ready!(self.shared.io.poll_read_pause(cx)) {
402                    IoStatusUpdate::KeepAlive => {
403                        log::trace!(
404                            "{}: Keep-alive error, stopping dispatcher during pause",
405                            self.shared.io.tag()
406                        );
407                        self.st = DispatcherState::Stop;
408                        Poll::Ready(PollService::Item(DispatchItem::KeepAliveTimeout))
409                    }
410                    IoStatusUpdate::PeerGone(err) => {
411                        log::trace!(
412                            "{}: Peer is gone during pause, stopping dispatcher: {:?}",
413                            self.shared.io.tag(),
414                            err
415                        );
416                        self.st = DispatcherState::Stop;
417                        Poll::Ready(PollService::Item(DispatchItem::Disconnect(err)))
418                    }
419                    IoStatusUpdate::WriteBackpressure => {
420                        self.st = DispatcherState::Backpressure;
421                        Poll::Ready(PollService::Item(DispatchItem::WBackPressureEnabled))
422                    }
423                }
424            }
425            // handle service readiness error
426            Poll::Ready(Err(err)) => {
427                log::trace!(
428                    "{}: Service readiness check failed, stopping",
429                    self.shared.io.tag()
430                );
431                self.st = DispatcherState::Stop;
432                self.error = Some(err);
433                self.shared.insert_flags(Flags::READY_ERR);
434                Poll::Ready(PollService::Item(DispatchItem::Disconnect(None)))
435            }
436        }
437    }
438
439    fn update_timer(&mut self, decoded: &Decoded<<U as Decoder>::Item>) {
440        // got parsed frame
441        if decoded.item.is_some() {
442            self.read_remains = 0;
443            self.shared
444                .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
445        } else if self.shared.contains(Flags::READ_TIMEOUT) {
446            // received new data but not enough for parsing complete frame
447            self.read_remains = decoded.remains as u32;
448        } else if self.read_remains == 0 && decoded.remains == 0 {
449            // no new data, start keep-alive timer
450            if self.shared.contains(Flags::KA_ENABLED)
451                && !self.shared.contains(Flags::KA_TIMEOUT)
452            {
453                log::trace!(
454                    "{}: Start keep-alive timer {:?}",
455                    self.shared.io.tag(),
456                    self.shared.io.cfg().keepalive_timeout()
457                );
458                self.shared.insert_flags(Flags::KA_TIMEOUT);
459                self.shared
460                    .io
461                    .start_timer(self.shared.io.cfg().keepalive_timeout());
462            }
463        } else if let Some(params) = self.shared.io.cfg().frame_read_rate() {
464            // we got new data but not enough to parse single frame
465            // start read timer
466            self.shared.insert_flags(Flags::READ_TIMEOUT);
467
468            self.read_remains = decoded.remains as u32;
469            self.read_remains_prev = 0;
470            self.read_max_timeout = params.max_timeout;
471            self.shared.io.start_timer(params.timeout);
472        }
473    }
474
475    fn handle_timeout(&mut self) -> Result<(), DispatchItem<U>> {
476        // check read timer
477        if self.shared.contains(Flags::READ_TIMEOUT) {
478            if let Some(params) = self.shared.io.cfg().frame_read_rate() {
479                let total = self.read_remains - self.read_remains_prev;
480
481                // read rate, start timer for next period
482                if total > params.rate {
483                    self.read_remains_prev = self.read_remains;
484                    self.read_remains = 0;
485
486                    if !params.max_timeout.is_zero() {
487                        self.read_max_timeout = Seconds(
488                            self.read_max_timeout.0.saturating_sub(params.timeout.0),
489                        );
490                    }
491
492                    if params.max_timeout.is_zero() || !self.read_max_timeout.is_zero() {
493                        log::trace!(
494                            "{}: Frame read rate {:?}, extend timer",
495                            self.shared.io.tag(),
496                            total
497                        );
498                        self.shared.io.start_timer(params.timeout);
499                        return Ok(());
500                    }
501                    log::trace!(
502                        "{}: Max payload timeout has been reached",
503                        self.shared.io.tag()
504                    );
505                }
506                Err(DispatchItem::ReadTimeout)
507            } else {
508                Ok(())
509            }
510        } else if self.shared.contains(Flags::KA_TIMEOUT | Flags::IDLE) {
511            log::trace!(
512                "{}: Keep-alive error, stopping dispatcher",
513                self.shared.io.tag()
514            );
515            Err(DispatchItem::KeepAliveTimeout)
516        } else {
517            Ok(())
518        }
519    }
520}
521
522#[cfg(test)]
523mod tests {
524    use std::sync::{Arc, Mutex, atomic::AtomicBool, atomic::Ordering::Relaxed};
525    use std::{cell::RefCell, io};
526
527    use ntex_bytes::{Bytes, BytesMut};
528    use ntex_codec::BytesCodec;
529    use ntex_service::{ServiceCtx, cfg::SharedCfg};
530    use ntex_util::{time::Millis, time::sleep};
531    use rand::Rng;
532
533    use super::*;
534    use crate::{Flags, Io, IoConfig, IoRef, testing::IoTest};
535
    /// Test handle wrapping the `IoRef` of the io driven by a dispatcher.
    pub(crate) struct State(IoRef);
537
    impl State {
        /// Returns the flags reported by the wrapped `IoRef`.
        fn flags(&self) -> Flags {
            self.0.flags()
        }

        /// Borrows the wrapped `IoRef`.
        fn io(&self) -> &IoRef {
            &self.0
        }

        /// Calls `close()` on the wrapped `IoRef`.
        fn close(&self) {
            self.0.close();
        }
    }
551
    /// Test codec; the wrapped value is the fixed frame length its decoder
    /// splits off the buffer.
    #[derive(Copy, Clone)]
    struct BCodec(usize);
554
555    impl Encoder for BCodec {
556        type Item = Bytes;
557        type Error = io::Error;
558
559        fn encode(&self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
560            dst.extend_from_slice(&item[..]);
561            Ok(())
562        }
563    }
564
565    impl Decoder for BCodec {
566        type Item = BytesMut;
567        type Error = io::Error;
568
569        fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
570            if src.len() < self.0 {
571                Ok(None)
572            } else {
573                Ok(Some(src.split_to(self.0)))
574            }
575        }
576    }
577
578    impl<S, U> Dispatcher<S, U>
579    where
580        S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
581        U: Decoder + Encoder + 'static,
582    {
583        /// Construct new `Dispatcher` instance
584        pub(crate) fn debug(io: Io, codec: U, service: S) -> (Self, State) {
585            let flags = if io.cfg().keepalive_timeout().is_zero() {
586                super::Flags::empty()
587            } else {
588                super::Flags::KA_ENABLED
589            };
590
591            let inner = State(io.get_ref());
592            io.start_timer(Seconds::ONE);
593
594            let shared = Rc::new(DispatcherShared {
595                codec,
596                io: io.into(),
597                flags: Cell::new(flags),
598                error: Cell::new(None),
599                inflight: Cell::new(0),
600                service: Pipeline::new(service).bind(),
601            });
602
603            (
604                Dispatcher {
605                    inner: DispatcherInner {
606                        shared,
607                        error: None,
608                        st: DispatcherState::Processing,
609                        response: None,
610                        read_remains: 0,
611                        read_remains_prev: 0,
612                        read_max_timeout: Seconds::ZERO,
613                    },
614                },
615                inner,
616            )
617        }
618    }
619
    /// Echo round-trip: the service replies with each received frame after a
    /// 50ms delay, and the dispatcher is dropped once the client closes.
    #[ntex::test]
    async fn test_basic() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, _) = Dispatcher::debug(
            Io::from(server),
            BytesCodec,
            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
                sleep(Millis(50)).await;
                // only data frames are expected in this test
                if let DispatchItem::Item(msg) = msg {
                    Ok::<_, ()>(Some(msg.freeze()))
                } else {
                    panic!()
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        sleep(Millis(25)).await;
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        // second request on the same connection
        client.write("GET /test HTTP/1\r\n\r\n");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        client.close().await;
        assert!(client.is_server_dropped());

        // exercise the derived `Clone` and `Debug` impls of the flags type
        assert!(format!("{:?}", super::Flags::KA_TIMEOUT.clone()).contains("KA_TIMEOUT"));
    }
655
    /// Data encoded directly through the io handle reaches the client, and
    /// closing the io from the server side drops the dispatcher.
    #[ntex::test]
    async fn test_sink() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, st) = Dispatcher::debug(
            Io::from(server),
            BytesCodec,
            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
                // echo data frames, accept disconnect, reject anything else
                if let DispatchItem::Item(msg) = msg {
                    Ok::<_, ()>(Some(msg.freeze()))
                } else if let DispatchItem::Disconnect(_) = msg {
                    Ok(None)
                } else {
                    panic!()
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        // write through the io handle, bypassing the service
        assert!(
            st.io()
                .encode(Bytes::from_static(b"test"), &BytesCodec)
                .is_ok()
        );
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        st.close();
        sleep(Millis(1500)).await;
        assert!(client.is_server_dropped());
    }
694
    /// A service that always fails: already-encoded data is still flushed,
    /// the write side is closed, and the dispatcher goes away once the
    /// client closes its side.
    #[ntex::test]
    async fn test_err_in_service() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, state) = Dispatcher::debug(
            Io::from(server),
            BytesCodec,
            ntex_service::fn_service(|_: DispatchItem<BytesCodec>| async move {
                Err::<Option<Bytes>, _>(())
            }),
        );
        // queue data while the remote buffer capacity is still zero
        state
            .io()
            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
            .unwrap();
        spawn(async move {
            let _ = disp.await;
        });

        // buffer should be flushed
        client.remote_buffer_cap(1024);
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        // write side must be closed, dispatcher waiting for read side to close
        sleep(Millis(250)).await;
        assert!(client.is_closed());

        // close read side
        client.close().await;

        // dispatcher is closed
        assert!(client.is_server_dropped());
    }
731
    /// A service whose readiness check fails: the dispatcher flushes pending
    /// data, resolves with the readiness error, and calls `ready` only once.
    #[ntex::test]
    async fn test_err_in_service_ready() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        // counts how many times `ready` is invoked
        let counter = Rc::new(Cell::new(0));

        struct Srv(Rc<Cell<usize>>);

        impl Service<DispatchItem<BytesCodec>> for Srv {
            type Response = Option<Response<BytesCodec>>;
            type Error = &'static str;

            async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
                self.0.set(self.0.get() + 1);
                Err("test")
            }

            async fn call(
                &self,
                _: DispatchItem<BytesCodec>,
                _: ServiceCtx<'_, Self>,
            ) -> Result<Self::Response, Self::Error> {
                Ok(None)
            }
        }

        let (disp, state) =
            Dispatcher::debug(Io::from(server), BytesCodec, Srv(counter.clone()));
        spawn(async move {
            // the readiness error becomes the dispatcher's result
            let res = disp.await;
            assert_eq!(res, Err("test"));
        });

        state
            .io()
            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
            .unwrap();

        // buffer should be flushed
        client.remote_buffer_cap(1024);
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        // write side must be closed, dispatcher waiting for read side to close
        sleep(Millis(250)).await;
        assert!(client.is_closed());

        // close read side
        client.close().await;
        assert!(client.is_server_dropped());

        // service must be checked for readiness only once
        assert_eq!(counter.get(), 1);
    }
788
    /// Write back-pressure: a 64KiB response against a blocked socket must
    /// produce `WBackPressureEnabled`, and `WBackPressureDisabled` once the
    /// write buffer drains.
    #[ntex::test]
    async fn test_write_backpressure() {
        let (client, server) = IoTest::create();
        // do not allow to write to socket
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        // event log: 0 = item, 1 = back-pressure enabled, 2 = disabled
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let io = Io::new(
            server,
            SharedCfg::new("TEST").add(
                IoConfig::new()
                    .set_read_buf(8 * 1024, 1024, 16)
                    .set_write_buf(16 * 1024, 1024, 16),
            ),
        );

        let (disp, state) = Dispatcher::debug(
            io,
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(_) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            // reply with 65536 random alphanumeric characters
                            let bytes = rand::rng()
                                .sample_iter(&rand::distr::Alphanumeric)
                                .take(65_536)
                                .map(char::from)
                                .collect::<String>();
                            return Ok::<_, ()>(Some(Bytes::from(bytes)));
                        }
                        DispatchItem::WBackPressureEnabled => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        DispatchItem::WBackPressureDisabled => {
                            data.lock().unwrap().borrow_mut().push(2);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );

        spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read_any();
        assert_eq!(buf, Bytes::from_static(b""));
        client.write("GET /test HTTP/1\r\n\r\n");
        sleep(Millis(25)).await;

        // buf must be consumed
        assert_eq!(client.remote_buffer(|buf| buf.len()), 0);

        // response message
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 65536);

        // 65536 - 10240 = 55296 bytes left after the first partial flush
        client.remote_buffer_cap(10240);
        sleep(Millis(50)).await;
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 55296);

        // 55296 - 48056 = 7240 bytes left after the second partial flush
        client.remote_buffer_cap(48056);
        sleep(Millis(50)).await;
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 7240);

        // backpressure disabled
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]);
    }
863
864    #[ntex::test]
865    async fn test_disconnect_during_read_backpressure() {
866        let (client, server) = IoTest::create();
867        client.remote_buffer_cap(0);
868
869        let (disp, state) = Dispatcher::debug(
870            Io::new(
871                server,
872                SharedCfg::new("TEST").add(
873                    IoConfig::new()
874                        .set_keepalive_timeout(Seconds::ZERO)
875                        .set_read_buf(1024, 512, 16),
876                ),
877            ),
878            BytesCodec,
879            ntex_util::services::inflight::InFlightService::new(
880                1,
881                ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| async move {
882                    if let DispatchItem::Item(_) = msg {
883                        sleep(Millis(500)).await;
884                        Ok::<_, ()>(None)
885                    } else {
886                        Ok(None)
887                    }
888                }),
889            ),
890        );
891
892        let (tx, rx) = ntex::channel::oneshot::channel();
893        ntex::rt::spawn(async move {
894            let _ = disp.await;
895            let _ = tx.send(());
896        });
897
898        let bytes = rand::rng()
899            .sample_iter(&rand::distr::Alphanumeric)
900            .take(1024)
901            .map(char::from)
902            .collect::<String>();
903        client.write(bytes.clone());
904        sleep(Millis(25)).await;
905        client.write(bytes);
906        sleep(Millis(25)).await;
907
908        // close read side
909        state.close();
910        let _ = rx.recv().await;
911    }
912
913    #[ntex::test]
914    async fn test_keepalive() {
915        let (client, server) = IoTest::create();
916        client.remote_buffer_cap(1024);
917        client.write("GET /test HTTP/1\r\n\r\n");
918
919        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
920        let data2 = data.clone();
921
922        let cfg = SharedCfg::new("DBG").add(
923            IoConfig::new()
924                .set_disconnect_timeout(Seconds(1))
925                .set_keepalive_timeout(Seconds(1)),
926        );
927
928        let (disp, state) = Dispatcher::debug(
929            Io::new(server, cfg),
930            BytesCodec,
931            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
932                let data = data2.clone();
933                async move {
934                    match msg {
935                        DispatchItem::Item(bytes) => {
936                            data.lock().unwrap().borrow_mut().push(0);
937                            return Ok::<_, ()>(Some(bytes.freeze()));
938                        }
939                        DispatchItem::KeepAliveTimeout => {
940                            data.lock().unwrap().borrow_mut().push(1);
941                        }
942                        _ => (),
943                    }
944                    Ok(None)
945                }
946            }),
947        );
948        spawn(async move {
949            let _ = disp.await;
950        });
951
952        let buf = client.read().await.unwrap();
953        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
954        sleep(Millis(2000)).await;
955
956        // write side must be closed, dispatcher should fail with keep-alive
957        let flags = state.flags();
958        assert!(flags.contains(Flags::IO_STOPPING));
959        assert!(client.is_closed());
960        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
961    }
962
963    #[ntex::test]
964    async fn test_keepalive2() {
965        let (client, server) = IoTest::create();
966        client.remote_buffer_cap(1024);
967
968        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
969        let data2 = data.clone();
970
971        let cfg = SharedCfg::new("DBG").add(
972            IoConfig::new()
973                .set_keepalive_timeout(Seconds(1))
974                .set_frame_read_rate(Seconds(1), Seconds(2), 2),
975        );
976
977        let (disp, state) = Dispatcher::debug(
978            Io::new(server, cfg),
979            BCodec(8),
980            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
981                let data = data2.clone();
982                async move {
983                    match msg {
984                        DispatchItem::Item(bytes) => {
985                            data.lock().unwrap().borrow_mut().push(0);
986                            return Ok::<_, ()>(Some(bytes.freeze()));
987                        }
988                        DispatchItem::KeepAliveTimeout => {
989                            data.lock().unwrap().borrow_mut().push(1);
990                        }
991                        _ => (),
992                    }
993                    Ok(None)
994                }
995            }),
996        );
997        spawn(async move {
998            let _ = disp.await;
999        });
1000
1001        client.write("12345678");
1002        let buf = client.read().await.unwrap();
1003        assert_eq!(buf, Bytes::from_static(b"12345678"));
1004        sleep(Millis(2000)).await;
1005
1006        // write side must be closed, dispatcher should fail with keep-alive
1007        let flags = state.flags();
1008        assert!(flags.contains(Flags::IO_STOPPING));
1009        assert!(client.is_closed());
1010        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
1011    }
1012
1013    /// Update keep-alive timer after receiving frame
1014    #[ntex::test]
1015    async fn test_keepalive3() {
1016        let (client, server) = IoTest::create();
1017        client.remote_buffer_cap(1024);
1018
1019        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1020        let data2 = data.clone();
1021
1022        let cfg = SharedCfg::new("DBG").add(
1023            IoConfig::new()
1024                .set_keepalive_timeout(Seconds(2))
1025                .set_frame_read_rate(Seconds(1), Seconds(2), 2),
1026        );
1027
1028        let (disp, _) = Dispatcher::debug(
1029            Io::new(server, cfg),
1030            BCodec(1),
1031            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
1032                let data = data2.clone();
1033                async move {
1034                    match msg {
1035                        DispatchItem::Item(bytes) => {
1036                            data.lock().unwrap().borrow_mut().push(0);
1037                            return Ok::<_, ()>(Some(bytes.freeze()));
1038                        }
1039                        DispatchItem::KeepAliveTimeout => {
1040                            data.lock().unwrap().borrow_mut().push(1);
1041                        }
1042                        _ => (),
1043                    }
1044                    Ok(None)
1045                }
1046            }),
1047        );
1048        spawn(async move {
1049            let _ = disp.await;
1050        });
1051
1052        client.write("1");
1053        let buf = client.read().await.unwrap();
1054        assert_eq!(buf, Bytes::from_static(b"1"));
1055        sleep(Millis(750)).await;
1056
1057        client.write("2");
1058        let buf = client.read().await.unwrap();
1059        assert_eq!(buf, Bytes::from_static(b"2"));
1060
1061        sleep(Millis(750)).await;
1062        client.write("3");
1063        let buf = client.read().await.unwrap();
1064        assert_eq!(buf, Bytes::from_static(b"3"));
1065
1066        sleep(Millis(750)).await;
1067        assert!(!client.is_closed());
1068        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 0, 0]);
1069    }
1070
1071    #[ntex::test]
1072    async fn test_read_timeout() {
1073        let (client, server) = IoTest::create();
1074        client.remote_buffer_cap(1024);
1075
1076        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1077        let data2 = data.clone();
1078
1079        let io = Io::new(
1080            server,
1081            SharedCfg::new("TEST").add(
1082                IoConfig::new()
1083                    .set_keepalive_timeout(Seconds::ZERO)
1084                    .set_frame_read_rate(Seconds(1), Seconds(2), 2),
1085            ),
1086        );
1087
1088        let (disp, state) = Dispatcher::debug(
1089            io,
1090            BCodec(8),
1091            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
1092                let data = data2.clone();
1093                async move {
1094                    match msg {
1095                        DispatchItem::Item(bytes) => {
1096                            data.lock().unwrap().borrow_mut().push(0);
1097                            return Ok::<_, ()>(Some(bytes.freeze()));
1098                        }
1099                        DispatchItem::ReadTimeout => {
1100                            data.lock().unwrap().borrow_mut().push(1);
1101                        }
1102                        _ => (),
1103                    }
1104                    Ok(None)
1105                }
1106            }),
1107        );
1108        spawn(async move {
1109            let _ = disp.await;
1110        });
1111
1112        client.write("12345678");
1113        let buf = client.read().await.unwrap();
1114        assert_eq!(buf, Bytes::from_static(b"12345678"));
1115
1116        client.write("1");
1117        sleep(Millis(1000)).await;
1118        assert!(!state.flags().contains(Flags::IO_STOPPING));
1119        client.write("23");
1120        sleep(Millis(1000)).await;
1121        assert!(!state.flags().contains(Flags::IO_STOPPING));
1122        client.write("4");
1123        sleep(Millis(2000)).await;
1124
1125        // write side must be closed, dispatcher should fail with keep-alive
1126        assert!(state.flags().contains(Flags::IO_STOPPING));
1127        assert!(client.is_closed());
1128        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
1129    }
1130
1131    #[ntex::test]
1132    async fn test_idle_timeout() {
1133        let (client, server) = IoTest::create();
1134        client.remote_buffer_cap(1024);
1135
1136        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1137        let data2 = data.clone();
1138
1139        let io = Io::new(
1140            server,
1141            SharedCfg::new("DBG").add(IoConfig::new().set_keepalive_timeout(Seconds::ZERO)),
1142        );
1143        let ioref = io.get_ref();
1144
1145        let (disp, state) = Dispatcher::debug(
1146            io,
1147            BCodec(1),
1148            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
1149                let ioref = ioref.clone();
1150                ntex::rt::spawn(async move {
1151                    sleep(Millis(500)).await;
1152                    ioref.notify_timeout();
1153                });
1154                let data = data2.clone();
1155                async move {
1156                    match msg {
1157                        DispatchItem::Item(bytes) => {
1158                            data.lock().unwrap().borrow_mut().push(0);
1159                            return Ok::<_, ()>(Some(bytes.freeze()));
1160                        }
1161                        DispatchItem::ReadTimeout => {
1162                            data.lock().unwrap().borrow_mut().push(1);
1163                        }
1164                        _ => (),
1165                    }
1166                    Ok(None)
1167                }
1168            }),
1169        );
1170        spawn(async move {
1171            let _ = disp.await;
1172        });
1173
1174        client.write("1");
1175        let buf = client.read().await.unwrap();
1176        assert_eq!(buf, Bytes::from_static(b"1"));
1177
1178        sleep(Millis(1000)).await;
1179        assert!(state.flags().contains(Flags::IO_STOPPING));
1180        assert!(client.is_closed());
1181    }
1182
1183    #[ntex::test]
1184    async fn test_unhandled_data() {
1185        let handled = Arc::new(AtomicBool::new(false));
1186        let handled2 = handled.clone();
1187
1188        let (client, server) = IoTest::create();
1189        client.remote_buffer_cap(1024);
1190        client.write("GET /test HTTP/1\r\n\r\n");
1191
1192        let (disp, _) = Dispatcher::debug(
1193            Io::from(server),
1194            BytesCodec,
1195            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
1196                handled2.store(true, Relaxed);
1197                async move {
1198                    sleep(Millis(50)).await;
1199                    if let DispatchItem::Item(msg) = msg {
1200                        Ok::<_, ()>(Some(msg.freeze()))
1201                    } else if let DispatchItem::Disconnect(_) = msg {
1202                        Ok::<_, ()>(None)
1203                    } else {
1204                        panic!()
1205                    }
1206                }
1207            }),
1208        );
1209        client.close().await;
1210        spawn(async move {
1211            let _ = disp.await;
1212        });
1213        sleep(Millis(50)).await;
1214
1215        assert!(handled.load(Relaxed));
1216    }
1217}