ntex_io/
dispatcher.rs

1//! Framed transport dispatcher
2#![allow(clippy::let_underscore_future)]
3use std::task::{Context, Poll, ready};
4use std::{cell::Cell, future::Future, pin::Pin, rc::Rc};
5
6use ntex_codec::{Decoder, Encoder};
7use ntex_service::{IntoService, Pipeline, PipelineBinding, PipelineCall, Service};
8use ntex_util::{future::Either, spawn, time::Seconds};
9
10use crate::{Decoded, DispatchItem, IoBoxed, IoStatusUpdate, RecvError};
11
/// Frame type the service returns (`Some(frame)`); it is written back to the
/// io stream through the codec's `Encoder` implementation.
type Response<U> = <U as Encoder>::Item;
13
pin_project_lite::pin_project! {
    /// Dispatcher - is a future that reads frames from bytes stream
    /// and passes them to the service.
    ///
    /// Responses returned by the service as `Some(frame)` are encoded and
    /// written back to the same io stream. The future resolves once the io
    /// and the service have been shut down.
    pub struct Dispatcher<S, U>
    where
        S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
        U: Encoder,
        U: Decoder,
        U: 'static,
    {
        // all dispatcher state lives in the inner (unpinned) struct
        inner: DispatcherInner<S, U>,
    }
}
27
bitflags::bitflags! {
    /// Dispatcher-local state bits, stored in `DispatcherShared::flags`.
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    struct Flags: u8  {
        /// Service `poll_ready` returned an error; readiness is not re-checked while stopping.
        const READY_ERR     = 0b0000001;
        /// Io reported an error/stop while draining in-flight calls in `Stop`.
        const IO_ERR        = 0b0000010;
        /// Keep-alive handling is enabled (non-zero configured keep-alive timeout).
        const KA_ENABLED    = 0b0000100;
        /// Keep-alive timer has been started.
        const KA_TIMEOUT    = 0b0001000;
        /// Frame read-rate timer has been started for a partially received frame.
        const READ_TIMEOUT  = 0b0010000;
        /// No service calls are in flight.
        const IDLE          = 0b0100000;
    }
}
39
/// Mutable dispatcher state owned by the `Dispatcher` future.
struct DispatcherInner<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder + 'static,
{
    // current state of the dispatch state machine
    st: DispatcherState,
    // service error returned as the final result of the future
    error: Option<S::Error>,
    // state shared with service-call tasks spawned by `call_service`
    shared: Rc<DispatcherShared<S, U>>,
    // service call polled in place by the dispatcher (at most one at a time)
    response: Option<PipelineCall<S, DispatchItem<U>>>,
    // `decoded.remains` observed during the current read-rate period
    read_remains: u32,
    // `read_remains` value at the end of the previous read-rate period
    read_remains_prev: u32,
    // remaining budget of the frame read-rate max timeout
    read_max_timeout: Seconds,
}
53
/// State shared (via `Rc`) between the dispatcher future and the tasks that
/// complete service calls in the background.
pub(crate) struct DispatcherShared<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder,
{
    io: IoBoxed,
    codec: U,
    service: PipelineBinding<S, DispatchItem<U>>,
    flags: Cell<Flags>,
    // error recorded by `handle_result`, consumed by `DispatcherInner::check_error`
    error: Cell<Option<DispatcherError<S::Error, <U as Encoder>::Error>>>,
    // number of service calls started but not yet completed
    inflight: Cell<u32>,
}
66
/// States of the dispatcher state machine driven by `Dispatcher::poll`.
#[derive(Copy, Clone, Debug)]
enum DispatcherState {
    // read and decode frames, call the service
    Processing,
    // write buffer is under back-pressure; wait for it to flush
    Backpressure,
    // stop timers, wait for in-flight calls, then shut down io
    Stop,
    // shut down the service and resolve the future
    Shutdown,
}
74
/// Error recorded when a service call completes.
#[derive(Debug)]
enum DispatcherError<S, U> {
    // encoding the service response failed
    Encoder(U),
    // the service call itself failed
    Service(S),
}
80
/// Outcome of `DispatcherInner::poll_service`.
enum PollService<U: Encoder + Decoder> {
    // this item must be delivered to the service
    Item(DispatchItem<U>),
    // state changed without an item; re-run the state loop
    Continue,
    // service is ready and no pending error was found
    Ready,
}
86
87impl<S, U> From<Either<S, U>> for DispatcherError<S, U> {
88    fn from(err: Either<S, U>) -> Self {
89        match err {
90            Either::Left(err) => DispatcherError::Service(err),
91            Either::Right(err) => DispatcherError::Encoder(err),
92        }
93    }
94}
95
96impl<S, U> Dispatcher<S, U>
97where
98    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
99    U: Decoder + Encoder + 'static,
100{
101    /// Construct new `Dispatcher` instance.
102    pub fn new<Io, F>(io: Io, codec: U, service: F) -> Dispatcher<S, U>
103    where
104        IoBoxed: From<Io>,
105        F: IntoService<S, DispatchItem<U>>,
106    {
107        let io = IoBoxed::from(io);
108        let flags = if io.cfg().keepalive_timeout().is_zero() {
109            Flags::empty()
110        } else {
111            Flags::KA_ENABLED
112        };
113
114        let shared = Rc::new(DispatcherShared {
115            io,
116            codec,
117            flags: Cell::new(flags),
118            error: Cell::new(None),
119            inflight: Cell::new(0),
120            service: Pipeline::new(service.into_service()).bind(),
121        });
122
123        Dispatcher {
124            inner: DispatcherInner {
125                shared,
126                response: None,
127                error: None,
128                read_remains: 0,
129                read_remains_prev: 0,
130                read_max_timeout: Seconds::ZERO,
131                st: DispatcherState::Processing,
132            },
133        }
134    }
135}
136
137impl<S, U> DispatcherShared<S, U>
138where
139    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
140    U: Encoder + Decoder,
141{
142    fn handle_result(&self, item: Result<S::Response, S::Error>, io: &IoBoxed, wake: bool) {
143        match item {
144            Ok(Some(val)) => {
145                if let Err(err) = io.encode(val, &self.codec) {
146                    self.error.set(Some(DispatcherError::Encoder(err)))
147                }
148            }
149            Err(err) => self.error.set(Some(DispatcherError::Service(err))),
150            Ok(None) => (),
151        }
152        let inflight = self.inflight.get() - 1;
153        self.inflight.set(inflight);
154        if inflight == 0 {
155            self.insert_flags(Flags::IDLE);
156        }
157        if wake {
158            io.wake();
159        }
160    }
161
162    fn contains(&self, f: Flags) -> bool {
163        self.flags.get().intersects(f)
164    }
165
166    fn insert_flags(&self, f: Flags) {
167        let mut flags = self.flags.get();
168        flags.insert(f);
169        self.flags.set(flags);
170    }
171
172    fn remove_flags(&self, f: Flags) -> bool {
173        let mut flags = self.flags.get();
174        if flags.intersects(f) {
175            flags.remove(f);
176            self.flags.set(flags);
177            true
178        } else {
179            false
180        }
181    }
182}
183
184impl<S, U> Future for Dispatcher<S, U>
185where
186    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
187    U: Decoder + Encoder + 'static,
188{
189    type Output = Result<(), S::Error>;
190
191    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
192        let mut this = self.as_mut().project();
193        let slf = &mut this.inner;
194
195        // handle service response future
196        if let Some(fut) = slf.response.as_mut() {
197            if let Poll::Ready(item) = Pin::new(fut).poll(cx) {
198                slf.shared.handle_result(item, &slf.shared.io, false);
199                slf.response = None;
200            }
201        }
202
203        loop {
204            match slf.st {
205                DispatcherState::Processing => {
206                    let item = match ready!(slf.poll_service(cx)) {
207                        PollService::Ready => {
208                            // decode incoming bytes if buffer is ready
209                            match slf.shared.io.poll_recv_decode(&slf.shared.codec, cx) {
210                                Ok(decoded) => {
211                                    slf.update_timer(&decoded);
212                                    if let Some(el) = decoded.item {
213                                        DispatchItem::Item(el)
214                                    } else {
215                                        return Poll::Pending;
216                                    }
217                                }
218                                Err(RecvError::KeepAlive) => {
219                                    if let Err(err) = slf.handle_timeout() {
220                                        slf.st = DispatcherState::Stop;
221                                        err
222                                    } else {
223                                        continue;
224                                    }
225                                }
226                                Err(RecvError::Stop) => {
227                                    log::trace!(
228                                        "{}: Dispatcher is instructed to stop",
229                                        slf.shared.io.tag()
230                                    );
231                                    slf.st = DispatcherState::Stop;
232                                    continue;
233                                }
234                                Err(RecvError::WriteBackpressure) => {
235                                    // instruct write task to notify dispatcher when data is flushed
236                                    slf.st = DispatcherState::Backpressure;
237                                    DispatchItem::WBackPressureEnabled
238                                }
239                                Err(RecvError::Decoder(err)) => {
240                                    log::trace!(
241                                        "{}: Decoder error, stopping dispatcher: {:?}",
242                                        slf.shared.io.tag(),
243                                        err
244                                    );
245                                    slf.st = DispatcherState::Stop;
246                                    DispatchItem::DecoderError(err)
247                                }
248                                Err(RecvError::PeerGone(err)) => {
249                                    log::trace!(
250                                        "{}: Peer is gone, stopping dispatcher: {:?}",
251                                        slf.shared.io.tag(),
252                                        err
253                                    );
254                                    slf.st = DispatcherState::Stop;
255                                    DispatchItem::Disconnect(err)
256                                }
257                            }
258                        }
259                        PollService::Item(item) => item,
260                        PollService::Continue => continue,
261                    };
262
263                    slf.call_service(cx, item);
264                }
265                // handle write back-pressure
266                DispatcherState::Backpressure => {
267                    match ready!(slf.poll_service(cx)) {
268                        PollService::Ready => (),
269                        PollService::Item(item) => slf.call_service(cx, item),
270                        PollService::Continue => continue,
271                    };
272
273                    let item = if let Err(err) = ready!(slf.shared.io.poll_flush(cx, false))
274                    {
275                        slf.st = DispatcherState::Stop;
276                        DispatchItem::Disconnect(Some(err))
277                    } else {
278                        slf.st = DispatcherState::Processing;
279                        DispatchItem::WBackPressureDisabled
280                    };
281                    slf.call_service(cx, item);
282                }
283                // drain service responses and shutdown io
284                DispatcherState::Stop => {
285                    slf.shared.io.stop_timer();
286
287                    // service may relay on poll_ready for response results
288                    if !slf.shared.contains(Flags::READY_ERR) {
289                        if let Poll::Ready(res) = slf.shared.service.poll_ready(cx) {
290                            if res.is_err() {
291                                slf.shared.insert_flags(Flags::READY_ERR);
292                            }
293                        }
294                    }
295
296                    if slf.shared.inflight.get() == 0 {
297                        if slf.shared.io.poll_shutdown(cx).is_ready() {
298                            slf.st = DispatcherState::Shutdown;
299                            continue;
300                        }
301                    } else if !slf.shared.contains(Flags::IO_ERR) {
302                        match ready!(slf.shared.io.poll_status_update(cx)) {
303                            IoStatusUpdate::PeerGone(_)
304                            | IoStatusUpdate::Stop
305                            | IoStatusUpdate::KeepAlive => {
306                                slf.shared.insert_flags(Flags::IO_ERR);
307                                continue;
308                            }
309                            IoStatusUpdate::WriteBackpressure => {
310                                if ready!(slf.shared.io.poll_flush(cx, true)).is_err() {
311                                    slf.shared.insert_flags(Flags::IO_ERR);
312                                }
313                                continue;
314                            }
315                        }
316                    } else {
317                        slf.shared.io.poll_dispatch(cx);
318                    }
319                    return Poll::Pending;
320                }
321                // shutdown service
322                DispatcherState::Shutdown => {
323                    return if slf.shared.service.poll_shutdown(cx).is_ready() {
324                        log::trace!(
325                            "{}: Service shutdown is completed, stop",
326                            slf.shared.io.tag()
327                        );
328
329                        Poll::Ready(if let Some(err) = slf.error.take() {
330                            Err(err)
331                        } else {
332                            Ok(())
333                        })
334                    } else {
335                        Poll::Pending
336                    };
337                }
338            }
339        }
340    }
341}
342
343impl<S, U> DispatcherInner<S, U>
344where
345    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
346    U: Decoder + Encoder + 'static,
347{
348    fn call_service(&mut self, cx: &mut Context<'_>, item: DispatchItem<U>) {
349        let mut fut = self.shared.service.call_nowait(item);
350        let inflight = self.shared.inflight.get() + 1;
351        self.shared.inflight.set(inflight);
352        if inflight == 1 {
353            self.shared.remove_flags(Flags::IDLE);
354        }
355
356        // optimize first call
357        if self.response.is_none() {
358            if let Poll::Ready(result) = Pin::new(&mut fut).poll(cx) {
359                self.shared.handle_result(result, &self.shared.io, false);
360            } else {
361                self.response = Some(fut);
362            }
363        } else {
364            let shared = self.shared.clone();
365            spawn(async move {
366                let result = fut.await;
367                shared.handle_result(result, &shared.io, true);
368            });
369        }
370    }
371
372    fn check_error(&mut self) -> PollService<U> {
373        // check for errors
374        if let Some(err) = self.shared.error.take() {
375            log::trace!(
376                "{}: Error occured, stopping dispatcher",
377                self.shared.io.tag()
378            );
379            self.st = DispatcherState::Stop;
380
381            match err {
382                DispatcherError::Encoder(err) => {
383                    PollService::Item(DispatchItem::EncoderError(err))
384                }
385                DispatcherError::Service(err) => {
386                    self.error = Some(err);
387                    PollService::Continue
388                }
389            }
390        } else {
391            PollService::Ready
392        }
393    }
394
395    fn poll_service(&mut self, cx: &mut Context<'_>) -> Poll<PollService<U>> {
396        // wait until service becomes ready
397        match self.shared.service.poll_ready(cx) {
398            Poll::Ready(Ok(_)) => Poll::Ready(self.check_error()),
399            // pause io read task
400            Poll::Pending => {
401                log::trace!(
402                    "{}: Service is not ready, register dispatcher",
403                    self.shared.io.tag()
404                );
405
406                // remove all timers
407                self.shared
408                    .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
409                self.shared.io.stop_timer();
410
411                match ready!(self.shared.io.poll_read_pause(cx)) {
412                    IoStatusUpdate::KeepAlive => {
413                        log::trace!(
414                            "{}: Keep-alive error, stopping dispatcher during pause",
415                            self.shared.io.tag()
416                        );
417                        self.st = DispatcherState::Stop;
418                        Poll::Ready(PollService::Item(DispatchItem::KeepAliveTimeout))
419                    }
420                    IoStatusUpdate::Stop => {
421                        log::trace!(
422                            "{}: Dispatcher is instructed to stop during pause",
423                            self.shared.io.tag()
424                        );
425                        self.st = DispatcherState::Stop;
426                        Poll::Ready(PollService::Continue)
427                    }
428                    IoStatusUpdate::PeerGone(err) => {
429                        log::trace!(
430                            "{}: Peer is gone during pause, stopping dispatcher: {:?}",
431                            self.shared.io.tag(),
432                            err
433                        );
434                        self.st = DispatcherState::Stop;
435                        Poll::Ready(PollService::Item(DispatchItem::Disconnect(err)))
436                    }
437                    IoStatusUpdate::WriteBackpressure => {
438                        self.st = DispatcherState::Backpressure;
439                        Poll::Ready(PollService::Item(DispatchItem::WBackPressureEnabled))
440                    }
441                }
442            }
443            // handle service readiness error
444            Poll::Ready(Err(err)) => {
445                log::trace!(
446                    "{}: Service readiness check failed, stopping",
447                    self.shared.io.tag()
448                );
449                self.st = DispatcherState::Stop;
450                self.error = Some(err);
451                self.shared.insert_flags(Flags::READY_ERR);
452                Poll::Ready(PollService::Item(DispatchItem::Disconnect(None)))
453            }
454        }
455    }
456
457    fn update_timer(&mut self, decoded: &Decoded<<U as Decoder>::Item>) {
458        // got parsed frame
459        if decoded.item.is_some() {
460            self.read_remains = 0;
461            self.shared
462                .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
463        } else if self.shared.contains(Flags::READ_TIMEOUT) {
464            // received new data but not enough for parsing complete frame
465            self.read_remains = decoded.remains as u32;
466        } else if self.read_remains == 0 && decoded.remains == 0 {
467            // no new data, start keep-alive timer
468            if self.shared.contains(Flags::KA_ENABLED)
469                && !self.shared.contains(Flags::KA_TIMEOUT)
470            {
471                log::trace!(
472                    "{}: Start keep-alive timer {:?}",
473                    self.shared.io.tag(),
474                    self.shared.io.cfg().keepalive_timeout()
475                );
476                self.shared.insert_flags(Flags::KA_TIMEOUT);
477                self.shared
478                    .io
479                    .start_timer(self.shared.io.cfg().keepalive_timeout());
480            }
481        } else if let Some(params) = self.shared.io.cfg().frame_read_rate() {
482            // we got new data but not enough to parse single frame
483            // start read timer
484            self.shared.insert_flags(Flags::READ_TIMEOUT);
485
486            self.read_remains = decoded.remains as u32;
487            self.read_remains_prev = 0;
488            self.read_max_timeout = params.max_timeout;
489            self.shared.io.start_timer(params.timeout);
490        }
491    }
492
493    fn handle_timeout(&mut self) -> Result<(), DispatchItem<U>> {
494        // check read timer
495        if self.shared.contains(Flags::READ_TIMEOUT) {
496            if let Some(params) = self.shared.io.cfg().frame_read_rate() {
497                let total = self.read_remains - self.read_remains_prev;
498
499                // read rate, start timer for next period
500                if total > params.rate {
501                    self.read_remains_prev = self.read_remains;
502                    self.read_remains = 0;
503
504                    if !params.max_timeout.is_zero() {
505                        self.read_max_timeout = Seconds(
506                            self.read_max_timeout.0.saturating_sub(params.timeout.0),
507                        );
508                    }
509
510                    if params.max_timeout.is_zero() || !self.read_max_timeout.is_zero() {
511                        log::trace!(
512                            "{}: Frame read rate {:?}, extend timer",
513                            self.shared.io.tag(),
514                            total
515                        );
516                        self.shared.io.start_timer(params.timeout);
517                        return Ok(());
518                    }
519                    log::trace!(
520                        "{}: Max payload timeout has been reached",
521                        self.shared.io.tag()
522                    );
523                }
524                Err(DispatchItem::ReadTimeout)
525            } else {
526                Ok(())
527            }
528        } else if self.shared.contains(Flags::KA_TIMEOUT | Flags::IDLE) {
529            log::trace!(
530                "{}: Keep-alive error, stopping dispatcher",
531                self.shared.io.tag()
532            );
533            Err(DispatchItem::KeepAliveTimeout)
534        } else {
535            Ok(())
536        }
537    }
538}
539
540#[cfg(test)]
541mod tests {
542    use std::sync::{Arc, Mutex, atomic::AtomicBool, atomic::Ordering::Relaxed};
543    use std::{cell::RefCell, io};
544
545    use ntex_bytes::{Bytes, BytesMut};
546    use ntex_codec::BytesCodec;
547    use ntex_service::{ServiceCtx, cfg::SharedCfg};
548    use ntex_util::{time::Millis, time::sleep};
549    use rand::Rng;
550
551    use super::*;
552    use crate::{Flags, Io, IoConfig, IoRef, testing::IoTest};
553
/// Test handle sharing the dispatcher's io reference, used to inspect and
/// control the io from outside the dispatcher.
pub(crate) struct State(IoRef);
555
556    impl State {
557        fn flags(&self) -> Flags {
558            self.0.flags()
559        }
560
561        fn io(&self) -> &IoRef {
562            &self.0
563        }
564
565        fn close(&self) {
566            self.0.0.insert_flags(Flags::DSP_STOP);
567            self.0.0.dispatch_task.wake();
568        }
569    }
570
/// Test codec: decodes fixed-size frames of `.0` bytes and encodes `Bytes`
/// verbatim.
#[derive(Copy, Clone)]
struct BCodec(usize);
573
574    impl Encoder for BCodec {
575        type Item = Bytes;
576        type Error = io::Error;
577
578        fn encode(&self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
579            dst.extend_from_slice(&item[..]);
580            Ok(())
581        }
582    }
583
584    impl Decoder for BCodec {
585        type Item = BytesMut;
586        type Error = io::Error;
587
588        fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
589            if src.len() < self.0 {
590                Ok(None)
591            } else {
592                Ok(Some(src.split_to(self.0)))
593            }
594        }
595    }
596
597    impl<S, U> Dispatcher<S, U>
598    where
599        S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
600        U: Decoder + Encoder + 'static,
601    {
602        /// Construct new `Dispatcher` instance
603        pub(crate) fn debug(io: Io, codec: U, service: S) -> (Self, State) {
604            let flags = if io.cfg().keepalive_timeout().is_zero() {
605                super::Flags::empty()
606            } else {
607                super::Flags::KA_ENABLED
608            };
609
610            let inner = State(io.get_ref());
611            io.start_timer(Seconds::ONE);
612
613            let shared = Rc::new(DispatcherShared {
614                codec,
615                io: io.into(),
616                flags: Cell::new(flags),
617                error: Cell::new(None),
618                inflight: Cell::new(0),
619                service: Pipeline::new(service).bind(),
620            });
621
622            (
623                Dispatcher {
624                    inner: DispatcherInner {
625                        shared,
626                        error: None,
627                        st: DispatcherState::Processing,
628                        response: None,
629                        read_remains: 0,
630                        read_remains_prev: 0,
631                        read_max_timeout: Seconds::ZERO,
632                    },
633                },
634                inner,
635            )
636        }
637    }
638
// Echo round-trip: each chunk the client writes is returned unchanged by a
// (deliberately slow) service; closing the client drops the server io.
#[ntex::test]
async fn test_basic() {
    let (client, server) = IoTest::create();
    client.remote_buffer_cap(1024);
    client.write("GET /test HTTP/1\r\n\r\n");

    let (disp, _) = Dispatcher::debug(
        Io::from(server),
        BytesCodec,
        ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
            // delay the response so it is not ready on the first poll
            sleep(Millis(50)).await;
            if let DispatchItem::Item(msg) = msg {
                Ok::<_, ()>(Some(msg.freeze()))
            } else {
                panic!()
            }
        }),
    );
    spawn(async move {
        let _ = disp.await;
    });

    sleep(Millis(25)).await;
    let buf = client.read().await.unwrap();
    assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

    client.write("GET /test HTTP/1\r\n\r\n");
    let buf = client.read().await.unwrap();
    assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

    client.close().await;
    assert!(client.is_server_dropped());

    // the dispatcher flags' Debug output names the flags
    assert!(format!("{:?}", super::Flags::KA_TIMEOUT.clone()).contains("KA_TIMEOUT"));
}
674
// Frames encoded directly through the io handle (outside the service) reach
// the client; `State::close` stops the dispatcher and drops the server io.
#[ntex::test]
async fn test_sink() {
    let (client, server) = IoTest::create();
    client.remote_buffer_cap(1024);
    client.write("GET /test HTTP/1\r\n\r\n");

    let (disp, st) = Dispatcher::debug(
        Io::from(server),
        BytesCodec,
        ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
            if let DispatchItem::Item(msg) = msg {
                Ok::<_, ()>(Some(msg.freeze()))
            } else {
                panic!()
            }
        }),
    );
    spawn(async move {
        let _ = disp.await;
    });

    let buf = client.read().await.unwrap();
    assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

    assert!(
        st.io()
            .encode(Bytes::from_static(b"test"), &BytesCodec)
            .is_ok()
    );
    let buf = client.read().await.unwrap();
    assert_eq!(buf, Bytes::from_static(b"test"));

    st.close();
    sleep(Millis(1500)).await;
    assert!(client.is_server_dropped());
}
711
// A failing service call stops the dispatcher, but data already in the write
// buffer is still flushed before the write side is closed.
#[ntex::test]
async fn test_err_in_service() {
    let (client, server) = IoTest::create();
    // block writes so the pre-encoded frame stays buffered
    client.remote_buffer_cap(0);
    client.write("GET /test HTTP/1\r\n\r\n");

    let (disp, state) = Dispatcher::debug(
        Io::from(server),
        BytesCodec,
        ntex_service::fn_service(|_: DispatchItem<BytesCodec>| async move {
            Err::<Option<Bytes>, _>(())
        }),
    );
    state
        .io()
        .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
        .unwrap();
    spawn(async move {
        let _ = disp.await;
    });

    // buffer should be flushed
    client.remote_buffer_cap(1024);
    let buf = client.read().await.unwrap();
    assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

    // write side must be closed, dispatcher waiting for read side to close
    sleep(Millis(250)).await;
    assert!(client.is_closed());

    // close read side
    client.close().await;

    // dispatcher is closed
    assert!(client.is_server_dropped());
}
748
// A readiness error resolves the dispatcher with that error, still flushes
// buffered output, and `ready` is called exactly once (no re-check while
// stopping).
#[ntex::test]
async fn test_err_in_service_ready() {
    let (client, server) = IoTest::create();
    client.remote_buffer_cap(0);
    client.write("GET /test HTTP/1\r\n\r\n");

    // counts calls to `Srv::ready`
    let counter = Rc::new(Cell::new(0));

    struct Srv(Rc<Cell<usize>>);

    impl Service<DispatchItem<BytesCodec>> for Srv {
        type Response = Option<Response<BytesCodec>>;
        type Error = &'static str;

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            self.0.set(self.0.get() + 1);
            Err("test")
        }

        async fn call(
            &self,
            _: DispatchItem<BytesCodec>,
            _: ServiceCtx<'_, Self>,
        ) -> Result<Self::Response, Self::Error> {
            Ok(None)
        }
    }

    let (disp, state) =
        Dispatcher::debug(Io::from(server), BytesCodec, Srv(counter.clone()));
    spawn(async move {
        let res = disp.await;
        assert_eq!(res, Err("test"));
    });

    state
        .io()
        .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
        .unwrap();

    // buffer should be flushed
    client.remote_buffer_cap(1024);
    let buf = client.read().await.unwrap();
    assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

    // write side must be closed, dispatcher waiting for read side to close
    sleep(Millis(250)).await;
    assert!(client.is_closed());

    // close read side
    client.close().await;
    assert!(client.is_server_dropped());

    // service must be checked for readiness only once
    assert_eq!(counter.get(), 1);
}
805
// A 64 KiB response overflows the 16 KiB write buffer config; the service must
// see Item (0), then WBackPressureEnabled (1), then WBackPressureDisabled (2)
// once the client drains enough data.
#[ntex::test]
async fn test_write_backpressure() {
    let (client, server) = IoTest::create();
    // do not allow to write to socket
    client.remote_buffer_cap(0);
    client.write("GET /test HTTP/1\r\n\r\n");

    // sequence of events observed by the service
    let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
    let data2 = data.clone();

    let io = Io::new(
        server,
        SharedCfg::new("TEST").add(
            IoConfig::new()
                .set_read_buf(8 * 1024, 1024, 16)
                .set_write_buf(16 * 1024, 1024, 16),
        ),
    );

    let (disp, state) = Dispatcher::debug(
        io,
        BytesCodec,
        ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
            let data = data2.clone();
            async move {
                match msg {
                    DispatchItem::Item(_) => {
                        data.lock().unwrap().borrow_mut().push(0);
                        let bytes = rand::rng()
                            .sample_iter(&rand::distr::Alphanumeric)
                            .take(65_536)
                            .map(char::from)
                            .collect::<String>();
                        return Ok::<_, ()>(Some(Bytes::from(bytes)));
                    }
                    DispatchItem::WBackPressureEnabled => {
                        data.lock().unwrap().borrow_mut().push(1);
                    }
                    DispatchItem::WBackPressureDisabled => {
                        data.lock().unwrap().borrow_mut().push(2);
                    }
                    _ => (),
                }
                Ok(None)
            }
        }),
    );

    spawn(async move {
        let _ = disp.await;
    });

    let buf = client.read_any();
    assert_eq!(buf, Bytes::from_static(b""));
    client.write("GET /test HTTP/1\r\n\r\n");
    sleep(Millis(25)).await;

    // buf must be consumed
    assert_eq!(client.remote_buffer(|buf| buf.len()), 0);

    // response message
    assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 65536);

    client.remote_buffer_cap(10240);
    sleep(Millis(50)).await;
    assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 55296);

    client.remote_buffer_cap(48056);
    sleep(Millis(50)).await;
    assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 7240);

    // backpressure disabled
    assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]);
}
880
    /// Closing the io while a slow service call is still in progress must let
    /// the dispatcher future complete instead of hanging.
    #[ntex::test]
    async fn test_disconnect_during_read_backpressure() {
        let (client, server) = IoTest::create();
        // the peer accepts no bytes written by the server side
        client.remote_buffer_cap(0);

        let (disp, state) = Dispatcher::debug(
            Io::new(
                server,
                SharedCfg::new("TEST").add(
                    IoConfig::new()
                        .set_keepalive_timeout(Seconds::ZERO)
                        .set_read_buf(1024, 512, 16),
                ),
            ),
            BytesCodec,
            // NOTE(review): `InFlightService::new(1, ..)` presumably limits the
            // service to one concurrent call, so further reads must wait while
            // the first item sleeps for 500ms.
            ntex_util::services::inflight::InFlightService::new(
                1,
                ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| async move {
                    if let DispatchItem::Item(_) = msg {
                        sleep(Millis(500)).await;
                        Ok::<_, ()>(None)
                    } else {
                        Ok(None)
                    }
                }),
            ),
        );

        // `tx` fires only after the dispatcher future has completed.
        let (tx, rx) = ntex::channel::oneshot::channel();
        ntex::rt::spawn(async move {
            let _ = disp.await;
            let _ = tx.send(());
        });

        // Two writes, each 1024 bytes (the first `set_read_buf` argument).
        let bytes = rand::rng()
            .sample_iter(&rand::distr::Alphanumeric)
            .take(1024)
            .map(char::from)
            .collect::<String>();
        client.write(bytes.clone());
        sleep(Millis(25)).await;
        client.write(bytes);
        sleep(Millis(25)).await;

        // close the io via the debug handle while the first service call is
        // still sleeping, then wait for the dispatcher to finish
        state.close();
        let _ = rx.recv().await;
    }
929
930    #[ntex::test]
931    async fn test_keepalive() {
932        let (client, server) = IoTest::create();
933        client.remote_buffer_cap(1024);
934        client.write("GET /test HTTP/1\r\n\r\n");
935
936        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
937        let data2 = data.clone();
938
939        let cfg = SharedCfg::new("DBG").add(
940            IoConfig::new()
941                .set_disconnect_timeout(Seconds(1))
942                .set_keepalive_timeout(Seconds(1)),
943        );
944
945        let (disp, state) = Dispatcher::debug(
946            Io::new(server, cfg),
947            BytesCodec,
948            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
949                let data = data2.clone();
950                async move {
951                    match msg {
952                        DispatchItem::Item(bytes) => {
953                            data.lock().unwrap().borrow_mut().push(0);
954                            return Ok::<_, ()>(Some(bytes.freeze()));
955                        }
956                        DispatchItem::KeepAliveTimeout => {
957                            data.lock().unwrap().borrow_mut().push(1);
958                        }
959                        _ => (),
960                    }
961                    Ok(None)
962                }
963            }),
964        );
965        spawn(async move {
966            let _ = disp.await;
967        });
968
969        let buf = client.read().await.unwrap();
970        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
971        sleep(Millis(2000)).await;
972
973        // write side must be closed, dispatcher should fail with keep-alive
974        let flags = state.flags();
975        assert!(flags.contains(Flags::IO_STOPPING));
976        assert!(client.is_closed());
977        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
978    }
979
980    #[ntex::test]
981    async fn test_keepalive2() {
982        let (client, server) = IoTest::create();
983        client.remote_buffer_cap(1024);
984
985        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
986        let data2 = data.clone();
987
988        let cfg = SharedCfg::new("DBG").add(
989            IoConfig::new()
990                .set_keepalive_timeout(Seconds(1))
991                .set_frame_read_rate(Seconds(1), Seconds(2), 2),
992        );
993
994        let (disp, state) = Dispatcher::debug(
995            Io::new(server, cfg),
996            BCodec(8),
997            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
998                let data = data2.clone();
999                async move {
1000                    match msg {
1001                        DispatchItem::Item(bytes) => {
1002                            data.lock().unwrap().borrow_mut().push(0);
1003                            return Ok::<_, ()>(Some(bytes.freeze()));
1004                        }
1005                        DispatchItem::KeepAliveTimeout => {
1006                            data.lock().unwrap().borrow_mut().push(1);
1007                        }
1008                        _ => (),
1009                    }
1010                    Ok(None)
1011                }
1012            }),
1013        );
1014        spawn(async move {
1015            let _ = disp.await;
1016        });
1017
1018        client.write("12345678");
1019        let buf = client.read().await.unwrap();
1020        assert_eq!(buf, Bytes::from_static(b"12345678"));
1021        sleep(Millis(2000)).await;
1022
1023        // write side must be closed, dispatcher should fail with keep-alive
1024        let flags = state.flags();
1025        assert!(flags.contains(Flags::IO_STOPPING));
1026        assert!(client.is_closed());
1027        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
1028    }
1029
1030    /// Update keep-alive timer after receiving frame
1031    #[ntex::test]
1032    async fn test_keepalive3() {
1033        let (client, server) = IoTest::create();
1034        client.remote_buffer_cap(1024);
1035
1036        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1037        let data2 = data.clone();
1038
1039        let cfg = SharedCfg::new("DBG").add(
1040            IoConfig::new()
1041                .set_keepalive_timeout(Seconds(2))
1042                .set_frame_read_rate(Seconds(1), Seconds(2), 2),
1043        );
1044
1045        let (disp, _) = Dispatcher::debug(
1046            Io::new(server, cfg),
1047            BCodec(1),
1048            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
1049                let data = data2.clone();
1050                async move {
1051                    match msg {
1052                        DispatchItem::Item(bytes) => {
1053                            data.lock().unwrap().borrow_mut().push(0);
1054                            return Ok::<_, ()>(Some(bytes.freeze()));
1055                        }
1056                        DispatchItem::KeepAliveTimeout => {
1057                            data.lock().unwrap().borrow_mut().push(1);
1058                        }
1059                        _ => (),
1060                    }
1061                    Ok(None)
1062                }
1063            }),
1064        );
1065        spawn(async move {
1066            let _ = disp.await;
1067        });
1068
1069        client.write("1");
1070        let buf = client.read().await.unwrap();
1071        assert_eq!(buf, Bytes::from_static(b"1"));
1072        sleep(Millis(750)).await;
1073
1074        client.write("2");
1075        let buf = client.read().await.unwrap();
1076        assert_eq!(buf, Bytes::from_static(b"2"));
1077
1078        sleep(Millis(750)).await;
1079        client.write("3");
1080        let buf = client.read().await.unwrap();
1081        assert_eq!(buf, Bytes::from_static(b"3"));
1082
1083        sleep(Millis(750)).await;
1084        assert!(!client.is_closed());
1085        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 0, 0]);
1086    }
1087
    /// Frame read-rate timeout: a frame that trickles in too slowly must produce
    /// `ReadTimeout` and stop the io. Keep-alive is set to zero here
    /// (presumably disabled), so only the read-rate check can fire.
    #[ntex::test]
    async fn test_read_timeout() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        // Service event log: 0 = item echoed, 1 = read timeout.
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        // NOTE(review): the `set_frame_read_rate` arguments are presumably
        // (check interval, max total time, min bytes per interval); confirm
        // against `IoConfig`.
        let io = Io::new(
            server,
            SharedCfg::new("TEST").add(
                IoConfig::new()
                    .set_keepalive_timeout(Seconds::ZERO)
                    .set_frame_read_rate(Seconds(1), Seconds(2), 2),
            ),
        );

        let (disp, state) = Dispatcher::debug(
            io,
            BCodec(8),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::ReadTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        // one complete frame is echoed back
        client.write("12345678");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"12345678"));

        // start the next frame but feed it slowly; the io must survive the
        // first two checks and stop after the third write stalls
        client.write("1");
        sleep(Millis(1000)).await;
        assert!(!state.flags().contains(Flags::IO_STOPPING));
        client.write("23");
        sleep(Millis(1000)).await;
        assert!(!state.flags().contains(Flags::IO_STOPPING));
        client.write("4");
        sleep(Millis(2000)).await;

        // write side must be closed, dispatcher should fail with read timeout
        assert!(state.flags().contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }
1147
    /// Timeout notification from outside the dispatcher: every service call
    /// spawns a task that calls `notify_timeout()` on the io 500ms later.
    /// One second after the echoed item, the io must be stopping and the
    /// client closed.
    #[ntex::test]
    async fn test_idle_timeout() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        // Service event log: 0 = item echoed, 1 = read timeout.
        // NOTE(review): this log is filled by the service but never asserted on.
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let io = Io::new(
            server,
            SharedCfg::new("DBG").add(IoConfig::new().set_keepalive_timeout(Seconds::ZERO)),
        );
        // extra io handle, so the service can trigger the timeout notification
        let ioref = io.get_ref();

        let (disp, state) = Dispatcher::debug(
            io,
            BCodec(1),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                // Runs on every call, before the returned future is polled:
                // schedule `notify_timeout()` 500ms from now.
                let ioref = ioref.clone();
                ntex::rt::spawn(async move {
                    sleep(Millis(500)).await;
                    ioref.notify_timeout();
                });
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::ReadTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        client.write("1");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"1"));

        // the notification scheduled by the first call has fired by now
        sleep(Millis(1000)).await;
        assert!(state.flags().contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
    }
1199
1200    #[ntex::test]
1201    async fn test_unhandled_data() {
1202        let handled = Arc::new(AtomicBool::new(false));
1203        let handled2 = handled.clone();
1204
1205        let (client, server) = IoTest::create();
1206        client.remote_buffer_cap(1024);
1207        client.write("GET /test HTTP/1\r\n\r\n");
1208
1209        let (disp, _) = Dispatcher::debug(
1210            Io::from(server),
1211            BytesCodec,
1212            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
1213                handled2.store(true, Relaxed);
1214                async move {
1215                    sleep(Millis(50)).await;
1216                    if let DispatchItem::Item(msg) = msg {
1217                        Ok::<_, ()>(Some(msg.freeze()))
1218                    } else if let DispatchItem::Disconnect(_) = msg {
1219                        Ok::<_, ()>(None)
1220                    } else {
1221                        panic!()
1222                    }
1223                }
1224            }),
1225        );
1226        client.close().await;
1227        spawn(async move {
1228            let _ = disp.await;
1229        });
1230        sleep(Millis(50)).await;
1231
1232        assert!(handled.load(Relaxed));
1233    }
1234}