ntex_io/
dispatcher.rs

1//! Framed transport dispatcher
2#![allow(clippy::let_underscore_future)]
3use std::task::{Context, Poll, ready};
4use std::{cell::Cell, future::Future, pin::Pin, rc::Rc};
5
6use ntex_codec::{Decoder, Encoder};
7use ntex_service::{IntoService, Pipeline, PipelineBinding, PipelineCall, Service};
8use ntex_util::{future::Either, spawn, time::Seconds};
9
10use crate::{Decoded, DispatchItem, IoBoxed, IoStatusUpdate, RecvError};
11
/// Item type returned (wrapped in `Option`) by the service and written back
/// to the io stream through the encoder half of codec `U`.
type Response<U> = <U as Encoder>::Item;
13
pin_project_lite::pin_project! {
    /// Dispatcher - is a future that reads frames from a byte stream
    /// and passes them to the service.
    ///
    /// Responses returned by the service as `Some(item)` are encoded with the
    /// same codec and written back to the io stream. The future resolves after
    /// the io stream has been shut down and the service has completed its own
    /// shutdown; it yields `Err` with the service error if one stopped the
    /// dispatcher.
    pub struct Dispatcher<S, U>
    where
        S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
        U: Encoder,
        U: Decoder,
        U: 'static,
    {
        inner: DispatcherInner<S, U>,
    }
}
27
28bitflags::bitflags! {
29    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
30    struct Flags: u8  {
31        const READY_ERR     = 0b0000001;
32        const IO_ERR        = 0b0000010;
33        const KA_ENABLED    = 0b0000100;
34        const KA_TIMEOUT    = 0b0001000;
35        const READ_TIMEOUT  = 0b0010000;
36        const IDLE          = 0b0100000;
37    }
38}
39
/// Mutable dispatcher state owned by the `Dispatcher` future.
struct DispatcherInner<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder + 'static,
{
    /// Current state of the dispatcher state machine.
    st: DispatcherState,
    /// Service error returned from the future once shutdown completes.
    error: Option<S::Error>,
    /// State shared with service calls spawned as separate tasks.
    shared: Rc<DispatcherShared<S, U>>,
    /// Pending service call polled directly by the dispatcher; while it is
    /// occupied, further calls are spawned instead.
    response: Option<PipelineCall<S, DispatchItem<U>>>,
    /// `Decoded::remains` recorded during the current frame read-rate period
    /// (presumably bytes buffered but not yet decoded — TODO confirm); reset
    /// to 0 whenever the read timer is extended.
    read_remains: u32,
    /// `read_remains` value captured at the previous read timer extension.
    read_remains_prev: u32,
    /// Remaining budget of the read-rate `max_timeout`.
    read_max_timeout: Seconds,
}
53
/// State shared (via `Rc`) between the dispatcher future and the service
/// calls it spawns as separate tasks.
pub(crate) struct DispatcherShared<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder,
{
    /// Io stream that frames are decoded from and responses are encoded into.
    io: IoBoxed,
    /// Codec used both for decoding incoming frames and encoding responses.
    codec: U,
    /// Bound service pipeline receiving every `DispatchItem`.
    service: PipelineBinding<S, DispatchItem<U>>,
    /// Dispatcher flag bits (see `Flags`).
    flags: Cell<Flags>,
    /// Error stored by `handle_result`; taken by `DispatcherInner::check_error`.
    error: Cell<Option<DispatcherError<S::Error, <U as Encoder>::Error>>>,
    /// Number of service calls started but not yet completed.
    inflight: Cell<u32>,
}
66
/// States of the state machine driven by `Dispatcher::poll`.
#[derive(Copy, Clone, Debug)]
enum DispatcherState {
    /// Reading and decoding frames and dispatching them to the service.
    Processing,

    /// Waiting for the io write buffer to be flushed.
    Backpressure,

    /// Draining in-flight service calls, then shutting down io.
    Stop,

    /// Waiting for the service to complete its shutdown.
    Shutdown,
}
74
/// Error reported by a completed service call.
///
/// `S` is the service error type, `U` the encoder error type.
#[derive(Debug)]
enum DispatcherError<S, U> {
    /// The service call itself failed.
    Service(S),
    /// Encoding the service response failed.
    Encoder(U),
}
80
81enum PollService<U: Encoder + Decoder> {
82    Item(DispatchItem<U>),
83    ItemWait(DispatchItem<U>),
84    Continue,
85    Ready,
86}
87
88impl<S, U> From<Either<S, U>> for DispatcherError<S, U> {
89    fn from(err: Either<S, U>) -> Self {
90        match err {
91            Either::Left(err) => DispatcherError::Service(err),
92            Either::Right(err) => DispatcherError::Encoder(err),
93        }
94    }
95}
96
97impl<S, U> Dispatcher<S, U>
98where
99    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
100    U: Decoder + Encoder + 'static,
101{
102    /// Construct new `Dispatcher` instance.
103    pub fn new<Io, F>(io: Io, codec: U, service: F) -> Dispatcher<S, U>
104    where
105        IoBoxed: From<Io>,
106        F: IntoService<S, DispatchItem<U>>,
107    {
108        let io = IoBoxed::from(io);
109        let flags = if io.cfg().keepalive_timeout().is_zero() {
110            Flags::empty()
111        } else {
112            Flags::KA_ENABLED
113        };
114
115        let shared = Rc::new(DispatcherShared {
116            io,
117            codec,
118            flags: Cell::new(flags),
119            error: Cell::new(None),
120            inflight: Cell::new(0),
121            service: Pipeline::new(service.into_service()).bind(),
122        });
123
124        Dispatcher {
125            inner: DispatcherInner {
126                shared,
127                response: None,
128                error: None,
129                read_remains: 0,
130                read_remains_prev: 0,
131                read_max_timeout: Seconds::ZERO,
132                st: DispatcherState::Processing,
133            },
134        }
135    }
136}
137
138impl<S, U> DispatcherShared<S, U>
139where
140    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
141    U: Encoder + Decoder,
142{
143    fn handle_result(&self, item: Result<S::Response, S::Error>, io: &IoBoxed, wake: bool) {
144        match item {
145            Ok(Some(val)) => {
146                if let Err(err) = io.encode(val, &self.codec) {
147                    self.error.set(Some(DispatcherError::Encoder(err)))
148                }
149            }
150            Err(err) => self.error.set(Some(DispatcherError::Service(err))),
151            Ok(None) => (),
152        }
153        let inflight = self.inflight.get() - 1;
154        self.inflight.set(inflight);
155        if inflight == 0 {
156            self.insert_flags(Flags::IDLE);
157        }
158        if wake {
159            io.wake();
160        }
161    }
162
163    fn contains(&self, f: Flags) -> bool {
164        self.flags.get().intersects(f)
165    }
166
167    fn insert_flags(&self, f: Flags) {
168        let mut flags = self.flags.get();
169        flags.insert(f);
170        self.flags.set(flags);
171    }
172
173    fn remove_flags(&self, f: Flags) -> bool {
174        let mut flags = self.flags.get();
175        if flags.intersects(f) {
176            flags.remove(f);
177            self.flags.set(flags);
178            true
179        } else {
180            false
181        }
182    }
183}
184
185impl<S, U> Future for Dispatcher<S, U>
186where
187    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
188    U: Decoder + Encoder + 'static,
189{
190    type Output = Result<(), S::Error>;
191
192    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
193        let mut this = self.as_mut().project();
194        let slf = &mut this.inner;
195
196        // handle service response future
197        if let Some(fut) = slf.response.as_mut() {
198            if let Poll::Ready(item) = Pin::new(fut).poll(cx) {
199                slf.shared.handle_result(item, &slf.shared.io, false);
200                slf.response = None;
201            }
202        }
203
204        loop {
205            match slf.st {
206                DispatcherState::Processing => {
207                    let (item, nowait) = match ready!(slf.poll_service(cx)) {
208                        PollService::Ready => {
209                            // decode incoming bytes if buffer is ready
210                            match slf.shared.io.poll_recv_decode(&slf.shared.codec, cx) {
211                                Ok(decoded) => {
212                                    slf.update_timer(&decoded);
213                                    if let Some(el) = decoded.item {
214                                        (DispatchItem::Item(el), true)
215                                    } else {
216                                        return Poll::Pending;
217                                    }
218                                }
219                                Err(RecvError::KeepAlive) => {
220                                    if let Err(err) = slf.handle_timeout() {
221                                        slf.st = DispatcherState::Stop;
222                                        (err, true)
223                                    } else {
224                                        continue;
225                                    }
226                                }
227                                Err(RecvError::WriteBackpressure) => {
228                                    // instruct write task to notify dispatcher when data is flushed
229                                    slf.st = DispatcherState::Backpressure;
230                                    (DispatchItem::WBackPressureEnabled, true)
231                                }
232                                Err(RecvError::Decoder(err)) => {
233                                    log::trace!(
234                                        "{}: Decoder error, stopping dispatcher: {:?}",
235                                        slf.shared.io.tag(),
236                                        err
237                                    );
238                                    slf.st = DispatcherState::Stop;
239                                    (DispatchItem::DecoderError(err), true)
240                                }
241                                Err(RecvError::PeerGone(err)) => {
242                                    log::trace!(
243                                        "{}: Peer is gone, stopping dispatcher: {:?}",
244                                        slf.shared.io.tag(),
245                                        err
246                                    );
247                                    slf.st = DispatcherState::Stop;
248                                    (DispatchItem::Disconnect(err), true)
249                                }
250                            }
251                        }
252                        PollService::Item(item) => (item, true),
253                        PollService::ItemWait(item) => (item, false),
254                        PollService::Continue => continue,
255                    };
256
257                    slf.call_service(cx, item, nowait);
258                }
259                // handle write back-pressure
260                DispatcherState::Backpressure => {
261                    match ready!(slf.poll_service(cx)) {
262                        PollService::Ready => (),
263                        PollService::Item(item) => slf.call_service(cx, item, true),
264                        PollService::ItemWait(item) => slf.call_service(cx, item, false),
265                        PollService::Continue => continue,
266                    };
267
268                    let item = if let Err(err) = ready!(slf.shared.io.poll_flush(cx, false))
269                    {
270                        slf.st = DispatcherState::Stop;
271                        DispatchItem::Disconnect(Some(err))
272                    } else {
273                        slf.st = DispatcherState::Processing;
274                        DispatchItem::WBackPressureDisabled
275                    };
276                    slf.call_service(cx, item, false);
277                }
278                // drain service responses and shutdown io
279                DispatcherState::Stop => {
280                    slf.shared.io.stop_timer();
281
282                    // service may relay on poll_ready for response results
283                    if !slf.shared.contains(Flags::READY_ERR) {
284                        if let Poll::Ready(res) = slf.shared.service.poll_ready(cx) {
285                            if res.is_err() {
286                                slf.shared.insert_flags(Flags::READY_ERR);
287                            }
288                        }
289                    }
290
291                    if slf.shared.inflight.get() == 0 {
292                        if slf.shared.io.poll_shutdown(cx).is_ready() {
293                            slf.st = DispatcherState::Shutdown;
294                            continue;
295                        }
296                    } else if !slf.shared.contains(Flags::IO_ERR) {
297                        match ready!(slf.shared.io.poll_status_update(cx)) {
298                            IoStatusUpdate::PeerGone(_) | IoStatusUpdate::KeepAlive => {
299                                slf.shared.insert_flags(Flags::IO_ERR);
300                                continue;
301                            }
302                            IoStatusUpdate::WriteBackpressure => {
303                                if ready!(slf.shared.io.poll_flush(cx, true)).is_err() {
304                                    slf.shared.insert_flags(Flags::IO_ERR);
305                                }
306                                continue;
307                            }
308                        }
309                    } else {
310                        slf.shared.io.poll_dispatch(cx);
311                    }
312                    return Poll::Pending;
313                }
314                // shutdown service
315                DispatcherState::Shutdown => {
316                    return if slf.shared.service.poll_shutdown(cx).is_ready() {
317                        log::trace!(
318                            "{}: Service shutdown is completed, stop",
319                            slf.shared.io.tag()
320                        );
321
322                        Poll::Ready(if let Some(err) = slf.error.take() {
323                            Err(err)
324                        } else {
325                            Ok(())
326                        })
327                    } else {
328                        Poll::Pending
329                    };
330                }
331            }
332        }
333    }
334}
335
336impl<S, U> DispatcherInner<S, U>
337where
338    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
339    U: Decoder + Encoder + 'static,
340{
341    fn call_service(&mut self, cx: &mut Context<'_>, item: DispatchItem<U>, nowait: bool) {
342        let mut fut = if nowait {
343            self.shared.service.call_nowait(item)
344        } else {
345            self.shared.service.call(item)
346        };
347        let inflight = self.shared.inflight.get() + 1;
348        self.shared.inflight.set(inflight);
349        if inflight == 1 {
350            self.shared.remove_flags(Flags::IDLE);
351        }
352
353        // optimize first call
354        if self.response.is_none() {
355            if let Poll::Ready(result) = Pin::new(&mut fut).poll(cx) {
356                self.shared.handle_result(result, &self.shared.io, false);
357            } else {
358                self.response = Some(fut);
359            }
360        } else {
361            let shared = self.shared.clone();
362            spawn(async move {
363                let result = fut.await;
364                shared.handle_result(result, &shared.io, true);
365            });
366        }
367    }
368
369    fn check_error(&mut self) -> PollService<U> {
370        // check for errors
371        if let Some(err) = self.shared.error.take() {
372            log::trace!(
373                "{}: Error occured, stopping dispatcher",
374                self.shared.io.tag()
375            );
376            self.st = DispatcherState::Stop;
377
378            match err {
379                DispatcherError::Encoder(err) => {
380                    PollService::Item(DispatchItem::EncoderError(err))
381                }
382                DispatcherError::Service(err) => {
383                    self.error = Some(err);
384                    PollService::Continue
385                }
386            }
387        } else {
388            PollService::Ready
389        }
390    }
391
392    fn poll_service(&mut self, cx: &mut Context<'_>) -> Poll<PollService<U>> {
393        // wait until service becomes ready
394        match self.shared.service.poll_ready(cx) {
395            Poll::Ready(Ok(_)) => Poll::Ready(self.check_error()),
396            // pause io read task
397            Poll::Pending => {
398                log::trace!(
399                    "{}: Service is not ready, register dispatcher",
400                    self.shared.io.tag()
401                );
402
403                // remove all timers
404                self.shared
405                    .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
406                self.shared.io.stop_timer();
407
408                match ready!(self.shared.io.poll_read_pause(cx)) {
409                    IoStatusUpdate::KeepAlive => {
410                        log::trace!(
411                            "{}: Keep-alive error, stopping dispatcher during pause",
412                            self.shared.io.tag()
413                        );
414                        self.st = DispatcherState::Stop;
415                        Poll::Ready(PollService::ItemWait(DispatchItem::KeepAliveTimeout))
416                    }
417                    IoStatusUpdate::PeerGone(err) => {
418                        log::trace!(
419                            "{}: Peer is gone during pause, stopping dispatcher: {:?}",
420                            self.shared.io.tag(),
421                            err
422                        );
423                        self.st = DispatcherState::Stop;
424                        Poll::Ready(PollService::ItemWait(DispatchItem::Disconnect(err)))
425                    }
426                    IoStatusUpdate::WriteBackpressure => {
427                        self.st = DispatcherState::Backpressure;
428                        Poll::Ready(PollService::ItemWait(
429                            DispatchItem::WBackPressureEnabled,
430                        ))
431                    }
432                }
433            }
434            // handle service readiness error
435            Poll::Ready(Err(err)) => {
436                log::trace!(
437                    "{}: Service readiness check failed, stopping",
438                    self.shared.io.tag()
439                );
440                self.st = DispatcherState::Stop;
441                self.error = Some(err);
442                self.shared.insert_flags(Flags::READY_ERR);
443                Poll::Ready(PollService::Continue)
444            }
445        }
446    }
447
448    fn update_timer(&mut self, decoded: &Decoded<<U as Decoder>::Item>) {
449        // got parsed frame
450        if decoded.item.is_some() {
451            self.read_remains = 0;
452            self.shared
453                .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
454        } else if self.shared.contains(Flags::READ_TIMEOUT) {
455            // received new data but not enough for parsing complete frame
456            self.read_remains = decoded.remains as u32;
457        } else if self.read_remains == 0 && decoded.remains == 0 {
458            // no new data, start keep-alive timer
459            if self.shared.contains(Flags::KA_ENABLED)
460                && !self.shared.contains(Flags::KA_TIMEOUT)
461            {
462                log::trace!(
463                    "{}: Start keep-alive timer {:?}",
464                    self.shared.io.tag(),
465                    self.shared.io.cfg().keepalive_timeout()
466                );
467                self.shared.insert_flags(Flags::KA_TIMEOUT);
468                self.shared
469                    .io
470                    .start_timer(self.shared.io.cfg().keepalive_timeout());
471            }
472        } else if let Some(params) = self.shared.io.cfg().frame_read_rate() {
473            // we got new data but not enough to parse single frame
474            // start read timer
475            self.shared.insert_flags(Flags::READ_TIMEOUT);
476
477            self.read_remains = decoded.remains as u32;
478            self.read_remains_prev = 0;
479            self.read_max_timeout = params.max_timeout;
480            self.shared.io.start_timer(params.timeout);
481        }
482    }
483
484    fn handle_timeout(&mut self) -> Result<(), DispatchItem<U>> {
485        // check read timer
486        if self.shared.contains(Flags::READ_TIMEOUT) {
487            if let Some(params) = self.shared.io.cfg().frame_read_rate() {
488                let total = self.read_remains - self.read_remains_prev;
489
490                // read rate, start timer for next period
491                if total > params.rate {
492                    self.read_remains_prev = self.read_remains;
493                    self.read_remains = 0;
494
495                    if !params.max_timeout.is_zero() {
496                        self.read_max_timeout = Seconds(
497                            self.read_max_timeout.0.saturating_sub(params.timeout.0),
498                        );
499                    }
500
501                    if params.max_timeout.is_zero() || !self.read_max_timeout.is_zero() {
502                        log::trace!(
503                            "{}: Frame read rate {:?}, extend timer",
504                            self.shared.io.tag(),
505                            total
506                        );
507                        self.shared.io.start_timer(params.timeout);
508                        return Ok(());
509                    }
510                    log::trace!(
511                        "{}: Max payload timeout has been reached",
512                        self.shared.io.tag()
513                    );
514                }
515                Err(DispatchItem::ReadTimeout)
516            } else {
517                Ok(())
518            }
519        } else if self.shared.contains(Flags::KA_TIMEOUT | Flags::IDLE) {
520            log::trace!(
521                "{}: Keep-alive error, stopping dispatcher",
522                self.shared.io.tag()
523            );
524            Err(DispatchItem::KeepAliveTimeout)
525        } else {
526            Ok(())
527        }
528    }
529}
530
531#[cfg(test)]
532mod tests {
533    use std::sync::{Arc, Mutex, atomic::AtomicBool, atomic::Ordering::Relaxed};
534    use std::{cell::RefCell, io};
535
536    use ntex_bytes::{Bytes, BytesMut};
537    use ntex_codec::BytesCodec;
538    use ntex_service::{ServiceCtx, cfg::SharedCfg};
539    use ntex_util::{time::Millis, time::sleep};
540    use rand::Rng;
541
542    use super::*;
543    use crate::{Flags, Io, IoConfig, IoRef, testing::IoTest};
544
545    pub(crate) struct State(IoRef);
546
547    impl State {
548        fn flags(&self) -> Flags {
549            self.0.flags()
550        }
551
552        fn io(&self) -> &IoRef {
553            &self.0
554        }
555
556        fn close(&self) {
557            self.0.close();
558        }
559    }
560
561    #[derive(Copy, Clone)]
562    struct BCodec(usize);
563
564    impl Encoder for BCodec {
565        type Item = Bytes;
566        type Error = io::Error;
567
568        fn encode(&self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
569            dst.extend_from_slice(&item[..]);
570            Ok(())
571        }
572    }
573
574    impl Decoder for BCodec {
575        type Item = BytesMut;
576        type Error = io::Error;
577
578        fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
579            if src.len() < self.0 {
580                Ok(None)
581            } else {
582                Ok(Some(src.split_to(self.0)))
583            }
584        }
585    }
586
587    impl<S, U> Dispatcher<S, U>
588    where
589        S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
590        U: Decoder + Encoder + 'static,
591    {
592        /// Construct new `Dispatcher` instance
593        pub(crate) fn debug(io: Io, codec: U, service: S) -> (Self, State) {
594            let flags = if io.cfg().keepalive_timeout().is_zero() {
595                super::Flags::empty()
596            } else {
597                super::Flags::KA_ENABLED
598            };
599
600            let inner = State(io.get_ref());
601            io.start_timer(Seconds::ONE);
602
603            let shared = Rc::new(DispatcherShared {
604                codec,
605                io: io.into(),
606                flags: Cell::new(flags),
607                error: Cell::new(None),
608                inflight: Cell::new(0),
609                service: Pipeline::new(service).bind(),
610            });
611
612            (
613                Dispatcher {
614                    inner: DispatcherInner {
615                        shared,
616                        error: None,
617                        st: DispatcherState::Processing,
618                        response: None,
619                        read_remains: 0,
620                        read_remains_prev: 0,
621                        read_max_timeout: Seconds::ZERO,
622                    },
623                },
624                inner,
625            )
626        }
627    }
628
629    #[ntex::test]
630    async fn basics() {
631        let (client, server) = IoTest::create();
632        client.remote_buffer_cap(1024);
633        client.write("GET /test HTTP/1\r\n\r\n");
634
635        let (disp, _) = Dispatcher::debug(
636            Io::from(server),
637            BytesCodec,
638            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
639                sleep(Millis(50)).await;
640                if let DispatchItem::Item(msg) = msg {
641                    Ok::<_, ()>(Some(msg.freeze()))
642                } else {
643                    panic!()
644                }
645            }),
646        );
647        spawn(async move {
648            let _ = disp.await;
649        });
650
651        sleep(Millis(25)).await;
652        let buf = client.read().await.unwrap();
653        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
654
655        client.write("GET /test HTTP/1\r\n\r\n");
656        let buf = client.read().await.unwrap();
657        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
658
659        client.close().await;
660        assert!(client.is_server_dropped());
661
662        assert!(format!("{:?}", super::Flags::KA_TIMEOUT.clone()).contains("KA_TIMEOUT"));
663    }
664
665    #[ntex::test]
666    async fn sink() {
667        let (client, server) = IoTest::create();
668        client.remote_buffer_cap(1024);
669        client.write("GET /test HTTP/1\r\n\r\n");
670
671        let (disp, st) = Dispatcher::debug(
672            Io::from(server),
673            BytesCodec,
674            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
675                if let DispatchItem::Item(msg) = msg {
676                    Ok::<_, ()>(Some(msg.freeze()))
677                } else if let DispatchItem::Disconnect(_) = msg {
678                    Ok(None)
679                } else {
680                    panic!()
681                }
682            }),
683        );
684        spawn(async move {
685            let _ = disp.await;
686        });
687
688        let buf = client.read().await.unwrap();
689        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
690
691        assert!(
692            st.io()
693                .encode(Bytes::from_static(b"test"), &BytesCodec)
694                .is_ok()
695        );
696        let buf = client.read().await.unwrap();
697        assert_eq!(buf, Bytes::from_static(b"test"));
698
699        st.close();
700        sleep(Millis(1500)).await;
701        assert!(client.is_server_dropped());
702    }
703
704    #[ntex::test]
705    async fn err_in_service() {
706        let (client, server) = IoTest::create();
707        client.remote_buffer_cap(0);
708        client.write("GET /test HTTP/1\r\n\r\n");
709
710        let (disp, state) = Dispatcher::debug(
711            Io::from(server),
712            BytesCodec,
713            ntex_service::fn_service(|_: DispatchItem<BytesCodec>| async move {
714                Err::<Option<Bytes>, _>(())
715            }),
716        );
717        state
718            .io()
719            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
720            .unwrap();
721        spawn(async move {
722            let _ = disp.await;
723        });
724
725        // buffer should be flushed
726        client.remote_buffer_cap(1024);
727        let buf = client.read().await.unwrap();
728        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
729
730        // write side must be closed, dispatcher waiting for read side to close
731        sleep(Millis(250)).await;
732        assert!(client.is_closed());
733
734        // close read side
735        client.close().await;
736
737        // dispatcher is closed
738        assert!(client.is_server_dropped());
739    }
740
741    #[ntex::test]
742    async fn err_in_service_ready() {
743        let (client, server) = IoTest::create();
744        client.remote_buffer_cap(0);
745        client.write("GET /test HTTP/1\r\n\r\n");
746
747        let counter = Rc::new(Cell::new(0));
748
749        struct Srv(Rc<Cell<usize>>);
750
751        impl Service<DispatchItem<BytesCodec>> for Srv {
752            type Response = Option<Response<BytesCodec>>;
753            type Error = &'static str;
754
755            async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
756                self.0.set(self.0.get() + 1);
757                Err("test")
758            }
759
760            async fn call(
761                &self,
762                _: DispatchItem<BytesCodec>,
763                _: ServiceCtx<'_, Self>,
764            ) -> Result<Self::Response, Self::Error> {
765                Ok(None)
766            }
767        }
768
769        let (disp, state) =
770            Dispatcher::debug(Io::from(server), BytesCodec, Srv(counter.clone()));
771        spawn(async move {
772            let res = disp.await;
773            assert_eq!(res, Err("test"));
774        });
775
776        state
777            .io()
778            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
779            .unwrap();
780
781        // buffer should be flushed
782        client.remote_buffer_cap(1024);
783        let buf = client.read().await.unwrap();
784        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
785
786        // write side must be closed, dispatcher waiting for read side to close
787        sleep(Millis(250)).await;
788        assert!(client.is_closed());
789
790        // close read side
791        client.close().await;
792        assert!(client.is_server_dropped());
793
794        // service must be checked for readiness only once
795        assert_eq!(counter.get(), 1);
796    }
797
798    #[ntex::test]
799    async fn write_backpressure() {
800        let (client, server) = IoTest::create();
801        // do not allow to write to socket
802        client.remote_buffer_cap(0);
803        client.write("GET /test HTTP/1\r\n\r\n");
804
805        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
806        let data2 = data.clone();
807
808        let io = Io::new(
809            server,
810            SharedCfg::new("TEST").add(
811                IoConfig::new()
812                    .set_read_buf(8 * 1024, 1024, 16)
813                    .set_write_buf(16 * 1024, 1024, 16),
814            ),
815        );
816
817        let (disp, state) = Dispatcher::debug(
818            io,
819            BytesCodec,
820            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
821                let data = data2.clone();
822                async move {
823                    match msg {
824                        DispatchItem::Item(_) => {
825                            data.lock().unwrap().borrow_mut().push(0);
826                            let bytes = rand::rng()
827                                .sample_iter(&rand::distr::Alphanumeric)
828                                .take(65_536)
829                                .map(char::from)
830                                .collect::<String>();
831                            return Ok::<_, ()>(Some(Bytes::from(bytes)));
832                        }
833                        DispatchItem::WBackPressureEnabled => {
834                            data.lock().unwrap().borrow_mut().push(1);
835                        }
836                        DispatchItem::WBackPressureDisabled => {
837                            data.lock().unwrap().borrow_mut().push(2);
838                        }
839                        _ => (),
840                    }
841                    Ok(None)
842                }
843            }),
844        );
845
846        spawn(async move {
847            let _ = disp.await;
848        });
849
850        let buf = client.read_any();
851        assert_eq!(buf, Bytes::from_static(b""));
852        client.write("GET /test HTTP/1\r\n\r\n");
853        sleep(Millis(25)).await;
854
855        // buf must be consumed
856        assert_eq!(client.remote_buffer(|buf| buf.len()), 0);
857
858        // response message
859        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 65536);
860
861        client.remote_buffer_cap(10240);
862        sleep(Millis(50)).await;
863        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 55296);
864
865        client.remote_buffer_cap(48056);
866        sleep(Millis(50)).await;
867        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 7240);
868
869        // backpressure disabled
870        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]);
871    }
872
    /// Closing the io while reading is being held back must still let the
    /// dispatcher future complete.
    ///
    /// The service is wrapped in an `InFlightService` with a limit of 1 and
    /// sleeps 500ms for every `Item`, so while the first item is processed the
    /// service is presumably reported as not ready and further input is left in
    /// the small read buffer (TODO confirm this is what produces the read
    /// backpressure exercised here).
    ///
    /// NOTE(review): there is no timeout around the final `rx.recv()`; if the
    /// dispatcher never terminates, this test hangs rather than failing.
    #[ntex::test]
    async fn disconnect_during_read_backpressure() {
        let (client, server) = IoTest::create();
        // remote buffer cap 0: nothing the server writes can be flushed to the client
        client.remote_buffer_cap(0);

        // Keep-alive disabled; small read buffer (presumably sized so that the
        // two 1024-byte writes below exceed it, confirm `set_read_buf` semantics).
        let (disp, state) = Dispatcher::debug(
            Io::new(
                server,
                SharedCfg::new("TEST").add(
                    IoConfig::new()
                        .set_keepalive_timeout(Seconds::ZERO)
                        .set_read_buf(1024, 512, 16),
                ),
            ),
            BytesCodec,
            ntex_util::services::inflight::InFlightService::new(
                1,
                ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| async move {
                    // Only data frames are slow; control messages return immediately.
                    if let DispatchItem::Item(_) = msg {
                        sleep(Millis(500)).await;
                        Ok::<_, ()>(None)
                    } else {
                        Ok(None)
                    }
                }),
            ),
        );

        // `tx` is signalled after `disp.await` returns, so `rx` tells the test
        // that the dispatcher future has completed.
        let (tx, rx) = ntex::channel::oneshot::channel();
        ntex::rt::spawn(async move {
            let _ = disp.await;
            let _ = tx.send(());
        });

        // 1024 random alphanumeric bytes, written twice 25ms apart.
        let bytes = rand::rng()
            .sample_iter(&rand::distr::Alphanumeric)
            .take(1024)
            .map(char::from)
            .collect::<String>();
        client.write(bytes.clone());
        sleep(Millis(25)).await;
        client.write(bytes);
        sleep(Millis(25)).await;

        // close read side; only ~50ms have passed, so the first item's 500ms
        // service call is presumably still in flight
        state.close();
        // Resolves once the spawned task sends (or drops `tx`), i.e. after the
        // dispatcher future has returned.
        let _ = rx.recv().await;
    }
921
922    #[ntex::test]
923    async fn keepalive() {
924        let (client, server) = IoTest::create();
925        client.remote_buffer_cap(1024);
926        client.write("GET /test HTTP/1\r\n\r\n");
927
928        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
929        let data2 = data.clone();
930
931        let cfg = SharedCfg::new("DBG").add(
932            IoConfig::new()
933                .set_disconnect_timeout(Seconds(1))
934                .set_keepalive_timeout(Seconds(1)),
935        );
936
937        let (disp, state) = Dispatcher::debug(
938            Io::new(server, cfg),
939            BytesCodec,
940            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
941                let data = data2.clone();
942                async move {
943                    match msg {
944                        DispatchItem::Item(bytes) => {
945                            data.lock().unwrap().borrow_mut().push(0);
946                            return Ok::<_, ()>(Some(bytes.freeze()));
947                        }
948                        DispatchItem::KeepAliveTimeout => {
949                            data.lock().unwrap().borrow_mut().push(1);
950                        }
951                        _ => (),
952                    }
953                    Ok(None)
954                }
955            }),
956        );
957        spawn(async move {
958            let _ = disp.await;
959        });
960
961        let buf = client.read().await.unwrap();
962        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
963        sleep(Millis(2000)).await;
964
965        // write side must be closed, dispatcher should fail with keep-alive
966        let flags = state.flags();
967        assert!(flags.contains(Flags::IO_STOPPING));
968        assert!(client.is_closed());
969        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
970    }
971
972    #[ntex::test]
973    async fn keepalive2() {
974        let (client, server) = IoTest::create();
975        client.remote_buffer_cap(1024);
976
977        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
978        let data2 = data.clone();
979
980        let cfg = SharedCfg::new("DBG").add(
981            IoConfig::new()
982                .set_keepalive_timeout(Seconds(1))
983                .set_frame_read_rate(Seconds(1), Seconds(2), 2),
984        );
985
986        let (disp, state) = Dispatcher::debug(
987            Io::new(server, cfg),
988            BCodec(8),
989            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
990                let data = data2.clone();
991                async move {
992                    match msg {
993                        DispatchItem::Item(bytes) => {
994                            data.lock().unwrap().borrow_mut().push(0);
995                            return Ok::<_, ()>(Some(bytes.freeze()));
996                        }
997                        DispatchItem::KeepAliveTimeout => {
998                            data.lock().unwrap().borrow_mut().push(1);
999                        }
1000                        _ => (),
1001                    }
1002                    Ok(None)
1003                }
1004            }),
1005        );
1006        spawn(async move {
1007            let _ = disp.await;
1008        });
1009
1010        client.write("12345678");
1011        let buf = client.read().await.unwrap();
1012        assert_eq!(buf, Bytes::from_static(b"12345678"));
1013        sleep(Millis(2000)).await;
1014
1015        // write side must be closed, dispatcher should fail with keep-alive
1016        let flags = state.flags();
1017        assert!(flags.contains(Flags::IO_STOPPING));
1018        assert!(client.is_closed());
1019        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
1020    }
1021
1022    /// Update keep-alive timer after receiving frame
1023    #[ntex::test]
1024    async fn keepalive3() {
1025        let (client, server) = IoTest::create();
1026        client.remote_buffer_cap(1024);
1027
1028        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
1029        let data2 = data.clone();
1030
1031        let cfg = SharedCfg::new("DBG").add(
1032            IoConfig::new()
1033                .set_keepalive_timeout(Seconds(2))
1034                .set_frame_read_rate(Seconds(1), Seconds(2), 2),
1035        );
1036
1037        let (disp, _) = Dispatcher::debug(
1038            Io::new(server, cfg),
1039            BCodec(1),
1040            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
1041                let data = data2.clone();
1042                async move {
1043                    match msg {
1044                        DispatchItem::Item(bytes) => {
1045                            data.lock().unwrap().borrow_mut().push(0);
1046                            return Ok::<_, ()>(Some(bytes.freeze()));
1047                        }
1048                        DispatchItem::KeepAliveTimeout => {
1049                            data.lock().unwrap().borrow_mut().push(1);
1050                        }
1051                        _ => (),
1052                    }
1053                    Ok(None)
1054                }
1055            }),
1056        );
1057        spawn(async move {
1058            let _ = disp.await;
1059        });
1060
1061        client.write("1");
1062        let buf = client.read().await.unwrap();
1063        assert_eq!(buf, Bytes::from_static(b"1"));
1064        sleep(Millis(750)).await;
1065
1066        client.write("2");
1067        let buf = client.read().await.unwrap();
1068        assert_eq!(buf, Bytes::from_static(b"2"));
1069
1070        sleep(Millis(750)).await;
1071        client.write("3");
1072        let buf = client.read().await.unwrap();
1073        assert_eq!(buf, Bytes::from_static(b"3"));
1074
1075        sleep(Millis(750)).await;
1076        assert!(!client.is_closed());
1077        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 0, 0]);
1078    }
1079
    /// Frame read-rate timeout.
    ///
    /// One complete frame is echoed first. Then a second frame is trickled in
    /// 1-2 bytes at a time (`"1"`, `"23"`, `"4"`, which is 4 bytes in total,
    /// assuming `BCodec(8)` decodes fixed 8-byte frames). The dispatcher must
    /// survive the first two 1s intervals, but it must eventually stop the io
    /// and deliver `ReadTimeout` to the service. Keep-alive is disabled
    /// (`Seconds::ZERO`), so the final shutdown cannot come from keep-alive.
    #[ntex::test]
    async fn read_timeout() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        // Events observed by the service: 0 = frame, 1 = read timeout.
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        // NOTE(review): `set_frame_read_rate` arguments presumably are
        // (initial timeout, max timeout, min bytes per period); confirm.
        let io = Io::new(
            server,
            SharedCfg::new("TEST").add(
                IoConfig::new()
                    .set_keepalive_timeout(Seconds::ZERO)
                    .set_frame_read_rate(Seconds(1), Seconds(2), 2),
            ),
        );

        let (disp, state) = Dispatcher::debug(
            io,
            BCodec(8),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let data = data2.clone();
                async move {
                    // Echo frames back; record frames and read timeouts.
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::ReadTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        // A complete frame is decoded and echoed immediately.
        client.write("12345678");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"12345678"));

        // Partial frame bytes trickle in; the io must still be alive after
        // each of the first two 1s intervals.
        client.write("1");
        sleep(Millis(1000)).await;
        assert!(!state.flags().contains(Flags::IO_STOPPING));
        client.write("23");
        sleep(Millis(1000)).await;
        assert!(!state.flags().contains(Flags::IO_STOPPING));
        client.write("4");
        sleep(Millis(2000)).await;

        // The second frame never completed: the io must be stopping, the client
        // side closed, and the service must have seen one frame followed by
        // `ReadTimeout` (keep-alive is disabled in this test).
        assert!(state.flags().contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }
1139
    /// Manually triggered timeout.
    ///
    /// Each service call schedules `IoRef::notify_timeout()` 500ms later. With
    /// keep-alive disabled, the test expects the io to be stopping and the
    /// client closed within 1s after the first echoed frame.
    ///
    /// NOTE(review): `data` records frames (0) and `ReadTimeout` (1) but is
    /// never asserted; confirm which event `notify_timeout` should produce and
    /// add an assertion.
    #[ntex::test]
    async fn idle_timeout() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let io = Io::new(
            server,
            SharedCfg::new("DBG").add(IoConfig::new().set_keepalive_timeout(Seconds::ZERO)),
        );
        // Grab a handle before `io` is moved into the dispatcher.
        let ioref = io.get_ref();

        let (disp, state) = Dispatcher::debug(
            io,
            BCodec(1),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                // Scheduled for every service call (not only `Item`), firing
                // 500ms after the call.
                let ioref = ioref.clone();
                ntex::rt::spawn(async move {
                    sleep(Millis(500)).await;
                    ioref.notify_timeout();
                });
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::ReadTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        client.write("1");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"1"));

        // Long enough for the 500ms notify task to fire.
        sleep(Millis(1000)).await;
        assert!(state.flags().contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
    }
1191
1192    #[ntex::test]
1193    async fn unhandled_data() {
1194        let handled = Arc::new(AtomicBool::new(false));
1195        let handled2 = handled.clone();
1196
1197        let (client, server) = IoTest::create();
1198        client.remote_buffer_cap(1024);
1199        client.write("GET /test HTTP/1\r\n\r\n");
1200
1201        let (disp, _) = Dispatcher::debug(
1202            Io::from(server),
1203            BytesCodec,
1204            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
1205                handled2.store(true, Relaxed);
1206                async move {
1207                    sleep(Millis(50)).await;
1208                    if let DispatchItem::Item(msg) = msg {
1209                        Ok::<_, ()>(Some(msg.freeze()))
1210                    } else if let DispatchItem::Disconnect(_) = msg {
1211                        Ok::<_, ()>(None)
1212                    } else {
1213                        panic!()
1214                    }
1215                }
1216            }),
1217        );
1218        client.close().await;
1219        spawn(async move {
1220            let _ = disp.await;
1221        });
1222        sleep(Millis(50)).await;
1223
1224        assert!(handled.load(Relaxed));
1225    }
1226}