1#![deny(clippy::pedantic)]
3#![allow(clippy::cast_possible_truncation)]
4use std::task::{Context, Poll, ready};
5use std::{cell::Cell, fmt, future::Future, io, pin::Pin, rc::Rc};
6
7use ntex_codec::{Decoder, Encoder};
8use ntex_io::{Decoded, IoBoxed, IoStatusUpdate, RecvError};
9use ntex_service::{IntoService, Pipeline, PipelineBinding, PipelineCall, Service};
10use ntex_util::{future::Either, spawn, time::Seconds};
11
12type Response<U> = <U as Encoder>::Item;
13
14pub enum DispatchItem<U: Encoder + Decoder> {
16 Item(<U as Decoder>::Item),
18 Control(Control),
20 Stop(Reason<U>),
22}
23
/// Write back-pressure notifications delivered to the service.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Control {
    /// Write back-pressure has been enabled.
    WBackPressureEnabled,
    /// Write back-pressure has been disabled.
    WBackPressureDisabled,
}
32
33pub enum Reason<U: Encoder + Decoder> {
35 Io(Option<io::Error>),
37 Encoder(<U as Encoder>::Error),
39 Decoder(<U as Decoder>::Error),
41 KeepAliveTimeout,
43 ReadTimeout,
45}
46
pin_project_lite::pin_project! {
    /// Future that decodes frames from an io stream with the codec `U`, passes
    /// them to the service `S` as [`DispatchItem`]s and encodes the responses
    /// returned by the service back to the stream.
    pub struct Dispatcher<S, U>
    where
        S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
        U: Encoder,
        U: Decoder,
        U: 'static,
    {
        // All dispatcher state; not structurally pinned.
        inner: DispatcherInner<S, U>,
    }
}
60
bitflags::bitflags! {
    // Internal dispatcher flags, stored in `DispatcherShared::flags`.
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    struct Flags: u8 {
        // Service readiness check returned an error.
        const READY_ERR     = 0b000_0001;
        // Io failed (or was reported gone) while the dispatcher was stopping.
        const IO_ERR        = 0b000_0010;
        // Keep-alive timeout is configured (non-zero).
        const KA_ENABLED    = 0b000_0100;
        // Keep-alive timer has been started.
        const KA_TIMEOUT    = 0b000_1000;
        // Frame read-rate timer has been started.
        const READ_TIMEOUT  = 0b001_0000;
        // No service calls are in flight.
        const IDLE          = 0b010_0000;
    }
}
72
73struct DispatcherInner<S, U>
74where
75 S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
76 U: Encoder + Decoder + 'static,
77{
78 st: DispatcherState,
79 error: Option<S::Error>,
80 shared: Rc<DispatcherShared<S, U>>,
81 response: Option<PipelineCall<S, DispatchItem<U>>>,
82 read_remains: u32,
83 read_remains_prev: u32,
84 read_max_timeout: Seconds,
85}
86
87pub(crate) struct DispatcherShared<S, U>
88where
89 S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
90 U: Encoder + Decoder,
91{
92 io: IoBoxed,
93 codec: U,
94 service: PipelineBinding<S, DispatchItem<U>>,
95 flags: Cell<Flags>,
96 error: Cell<Option<DispatcherError<S::Error, <U as Encoder>::Error>>>,
97 inflight: Cell<u32>,
98}
99
/// States of the dispatcher state machine.
#[derive(Debug, Clone, Copy)]
enum DispatcherState {
    /// Frames are decoded and dispatched to the service.
    Processing,
    /// Write back-pressure is active; the dispatcher waits for a flush.
    Backpressure,
    /// The dispatcher is stopping and waits for in-flight calls and io shutdown.
    Stop,
    /// Io is shut down; the dispatcher waits for the service shutdown.
    Shutdown,
}
107
/// Error recorded while handling the result of a service call.
///
/// `S` is the service error type, `U` is the encoder error type.
#[derive(Debug)]
enum DispatcherError<S, U> {
    /// Encoding of a service response failed.
    Encoder(U),
    /// The service call returned an error.
    Service(S),
}
113
114enum PollService<U: Encoder + Decoder> {
115 Item(DispatchItem<U>),
116 ItemWait(DispatchItem<U>),
117 Continue,
118 Ready,
119}
120
121impl<S, U> From<Either<S, U>> for DispatcherError<S, U> {
122 fn from(err: Either<S, U>) -> Self {
123 match err {
124 Either::Left(err) => DispatcherError::Service(err),
125 Either::Right(err) => DispatcherError::Encoder(err),
126 }
127 }
128}
129
130impl<S, U> Dispatcher<S, U>
131where
132 S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
133 U: Decoder + Encoder + 'static,
134{
135 pub fn new<Io, F>(io: Io, codec: U, service: F) -> Dispatcher<S, U>
137 where
138 IoBoxed: From<Io>,
139 F: IntoService<S, DispatchItem<U>>,
140 {
141 let io = IoBoxed::from(io);
142 let flags = if io.cfg().keepalive_timeout().is_zero() {
143 Flags::empty()
144 } else {
145 Flags::KA_ENABLED
146 };
147
148 let shared = Rc::new(DispatcherShared {
149 io,
150 codec,
151 flags: Cell::new(flags),
152 error: Cell::new(None),
153 inflight: Cell::new(0),
154 service: Pipeline::new(service.into_service()).bind(),
155 });
156
157 Dispatcher {
158 inner: DispatcherInner {
159 shared,
160 response: None,
161 error: None,
162 read_remains: 0,
163 read_remains_prev: 0,
164 read_max_timeout: Seconds::ZERO,
165 st: DispatcherState::Processing,
166 },
167 }
168 }
169}
170
171impl<S, U> DispatcherShared<S, U>
172where
173 S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
174 U: Encoder + Decoder,
175{
176 fn handle_result(&self, item: Result<S::Response, S::Error>, io: &IoBoxed, wake: bool) {
177 match item {
178 Ok(Some(val)) => {
179 if let Err(err) = io.encode(val, &self.codec) {
180 self.error.set(Some(DispatcherError::Encoder(err)));
181 }
182 }
183 Err(err) => self.error.set(Some(DispatcherError::Service(err))),
184 Ok(None) => (),
185 }
186 let inflight = self.inflight.get() - 1;
187 self.inflight.set(inflight);
188 if inflight == 0 {
189 self.insert_flags(Flags::IDLE);
190 }
191 if wake {
192 io.wake();
193 }
194 }
195
196 fn contains(&self, f: Flags) -> bool {
197 self.flags.get().intersects(f)
198 }
199
200 fn insert_flags(&self, f: Flags) {
201 let mut flags = self.flags.get();
202 flags.insert(f);
203 self.flags.set(flags);
204 }
205
206 fn remove_flags(&self, f: Flags) -> bool {
207 let mut flags = self.flags.get();
208 if flags.intersects(f) {
209 flags.remove(f);
210 self.flags.set(flags);
211 true
212 } else {
213 false
214 }
215 }
216}
217
/// Drives the dispatcher state machine.
///
/// The future resolves once the dispatcher reaches the shutdown state and the
/// service shutdown completes: with `Err` if a service error was recorded,
/// with `Ok(())` otherwise.
impl<S, U> Future for Dispatcher<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
    U: Decoder + Encoder + 'static,
{
    type Output = Result<(), S::Error>;

    #[allow(clippy::too_many_lines)]
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.as_mut().project();
        let inner = this.inner;

        // Drive the service call stored by `call_service`, if there is one.
        if let Some(fut) = inner.response.as_mut()
            && let Poll::Ready(item) = Pin::new(fut).poll(cx)
        {
            inner.shared.handle_result(item, &inner.shared.io, false);
            inner.response = None;
        }

        loop {
            match inner.st {
                DispatcherState::Processing => {
                    // `nowait` selects between `call_nowait` and `call` in `call_service`.
                    let (item, nowait) = match ready!(inner.poll_service(cx)) {
                        PollService::Ready => {
                            // Service is ready, try to decode the next frame.
                            match inner.shared.io.poll_recv_decode(&inner.shared.codec, cx)
                            {
                                Ok(decoded) => {
                                    inner.update_timer(&decoded);
                                    if let Some(el) = decoded.item {
                                        (DispatchItem::Item(el), true)
                                    } else {
                                        // No complete frame is available yet.
                                        return Poll::Pending;
                                    }
                                }
                                Err(RecvError::KeepAlive) => {
                                    // A timer fired; `handle_timeout` decides whether
                                    // it is fatal or the timer was extended.
                                    if let Err(ctl) = inner.handle_timeout() {
                                        inner.st = DispatcherState::Stop;
                                        (DispatchItem::Stop(ctl), true)
                                    } else {
                                        continue;
                                    }
                                }
                                Err(RecvError::WriteBackpressure) => {
                                    inner.st = DispatcherState::Backpressure;
                                    (
                                        DispatchItem::Control(
                                            Control::WBackPressureEnabled,
                                        ),
                                        true,
                                    )
                                }
                                Err(RecvError::Decoder(err)) => {
                                    log::trace!(
                                        "{}: Decoder error, stopping dispatcher: {:?}",
                                        inner.shared.io.tag(),
                                        err
                                    );
                                    inner.st = DispatcherState::Stop;
                                    (DispatchItem::Stop(Reason::Decoder(err)), true)
                                }
                                Err(RecvError::PeerGone(err)) => {
                                    log::trace!(
                                        "{}: Peer is gone, stopping dispatcher: {:?}",
                                        inner.shared.io.tag(),
                                        err
                                    );
                                    inner.st = DispatcherState::Stop;
                                    (DispatchItem::Stop(Reason::Io(err)), true)
                                }
                            }
                        }
                        PollService::Item(item) => (item, true),
                        PollService::ItemWait(item) => (item, false),
                        PollService::Continue => continue,
                    };

                    inner.call_service(cx, item, nowait);
                }
                DispatcherState::Backpressure => {
                    // Deliver any item produced by the readiness check first.
                    match ready!(inner.poll_service(cx)) {
                        PollService::Ready => (),
                        PollService::Item(item) => inner.call_service(cx, item, true),
                        PollService::ItemWait(item) => inner.call_service(cx, item, false),
                        PollService::Continue => continue,
                    }

                    // Wait for the flush; on success resume processing and notify
                    // the service, on failure stop with the io error.
                    let item =
                        if let Err(err) = ready!(inner.shared.io.poll_flush(cx, false)) {
                            inner.st = DispatcherState::Stop;
                            DispatchItem::Stop(Reason::Io(Some(err)))
                        } else {
                            inner.st = DispatcherState::Processing;
                            DispatchItem::Control(Control::WBackPressureDisabled)
                        };
                    inner.call_service(cx, item, false);
                }
                DispatcherState::Stop => {
                    inner.shared.io.stop_timer();

                    // Keep polling readiness until it has failed once.
                    if !inner.shared.contains(Flags::READY_ERR)
                        && let Poll::Ready(res) = inner.shared.service.poll_ready(cx)
                        && res.is_err()
                    {
                        inner.shared.insert_flags(Flags::READY_ERR);
                    }

                    if inner.shared.inflight.get() == 0 {
                        // All service calls completed, shut the io down.
                        if inner.shared.io.poll_shutdown(cx).is_ready() {
                            inner.st = DispatcherState::Shutdown;
                            continue;
                        }
                    } else if !inner.shared.contains(Flags::IO_ERR) {
                        // Calls are still in flight; watch the io status meanwhile.
                        match ready!(inner.shared.io.poll_status_update(cx)) {
                            IoStatusUpdate::PeerGone(_) | IoStatusUpdate::KeepAlive => {
                                inner.shared.insert_flags(Flags::IO_ERR);
                                continue;
                            }
                            IoStatusUpdate::WriteBackpressure => {
                                if ready!(inner.shared.io.poll_flush(cx, true)).is_err() {
                                    inner.shared.insert_flags(Flags::IO_ERR);
                                }
                                continue;
                            }
                        }
                    } else {
                        inner.shared.io.poll_dispatch(cx);
                    }
                    return Poll::Pending;
                }
                DispatcherState::Shutdown => {
                    return if inner.shared.service.poll_shutdown(cx).is_ready() {
                        log::trace!(
                            "{}: Service shutdown is completed, stop",
                            inner.shared.io.tag()
                        );

                        // Report the recorded service error, if any.
                        Poll::Ready(if let Some(err) = inner.error.take() {
                            Err(err)
                        } else {
                            Ok(())
                        })
                    } else {
                        Poll::Pending
                    };
                }
            }
        }
    }
}
374
375impl<S, U> DispatcherInner<S, U>
376where
377 S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
378 U: Decoder + Encoder + 'static,
379{
380 fn call_service(&mut self, cx: &mut Context<'_>, item: DispatchItem<U>, nowait: bool) {
381 let mut fut = if nowait {
382 self.shared.service.call_nowait(item)
383 } else {
384 self.shared.service.call(item)
385 };
386 let inflight = self.shared.inflight.get() + 1;
387 self.shared.inflight.set(inflight);
388 if inflight == 1 {
389 self.shared.remove_flags(Flags::IDLE);
390 }
391
392 if self.response.is_none() {
394 if let Poll::Ready(result) = Pin::new(&mut fut).poll(cx) {
395 self.shared.handle_result(result, &self.shared.io, false);
396 } else {
397 self.response = Some(fut);
398 }
399 } else {
400 let shared = self.shared.clone();
401 spawn(async move {
402 let result = fut.await;
403 shared.handle_result(result, &shared.io, true);
404 });
405 }
406 }
407
408 fn check_error(&mut self) -> PollService<U> {
409 if let Some(err) = self.shared.error.take() {
411 log::trace!(
412 "{}: Error occured, stopping dispatcher",
413 self.shared.io.tag()
414 );
415 self.st = DispatcherState::Stop;
416
417 match err {
418 DispatcherError::Encoder(err) => {
419 PollService::Item(DispatchItem::Stop(Reason::Encoder(err)))
420 }
421 DispatcherError::Service(err) => {
422 self.error = Some(err);
423 PollService::Continue
424 }
425 }
426 } else {
427 PollService::Ready
428 }
429 }
430
431 fn poll_service(&mut self, cx: &mut Context<'_>) -> Poll<PollService<U>> {
432 match self.shared.service.poll_ready(cx) {
434 Poll::Ready(Ok(())) => Poll::Ready(self.check_error()),
435 Poll::Pending => {
437 log::trace!(
438 "{}: Service is not ready, register dispatcher",
439 self.shared.io.tag()
440 );
441
442 self.shared
444 .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
445 self.shared.io.stop_timer();
446
447 match ready!(self.shared.io.poll_read_pause(cx)) {
448 IoStatusUpdate::KeepAlive => {
449 if self.shared.contains(Flags::KA_ENABLED) {
450 log::trace!(
451 "{}: Keep-alive error, stopping dispatcher during pause",
452 self.shared.io.tag()
453 );
454 self.st = DispatcherState::Stop;
455 Poll::Ready(PollService::ItemWait(DispatchItem::Stop(
456 Reason::KeepAliveTimeout,
457 )))
458 } else {
459 Poll::Ready(PollService::Continue)
461 }
462 }
463 IoStatusUpdate::PeerGone(err) => {
464 log::trace!(
465 "{}: Peer is gone during pause, stopping dispatcher: {:?}",
466 self.shared.io.tag(),
467 err
468 );
469 self.st = DispatcherState::Stop;
470 Poll::Ready(PollService::ItemWait(DispatchItem::Stop(Reason::Io(
471 err,
472 ))))
473 }
474 IoStatusUpdate::WriteBackpressure => {
475 self.st = DispatcherState::Backpressure;
476 Poll::Ready(PollService::ItemWait(DispatchItem::Control(
477 Control::WBackPressureEnabled,
478 )))
479 }
480 }
481 }
482 Poll::Ready(Err(err)) => {
484 log::trace!(
485 "{}: Service readiness check failed, stopping",
486 self.shared.io.tag()
487 );
488 self.st = DispatcherState::Stop;
489 self.error = Some(err);
490 self.shared.insert_flags(Flags::READY_ERR);
491 Poll::Ready(PollService::Continue)
492 }
493 }
494 }
495
496 fn update_timer(&mut self, decoded: &Decoded<<U as Decoder>::Item>) {
497 if decoded.item.is_some() {
499 self.read_remains = 0;
500 self.shared
501 .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
502 } else if self.shared.contains(Flags::READ_TIMEOUT) {
503 self.read_remains = decoded.remains as u32;
505 } else if self.read_remains == 0 && decoded.remains == 0 {
506 if self.shared.contains(Flags::KA_ENABLED)
508 && !self.shared.contains(Flags::KA_TIMEOUT)
509 {
510 log::trace!(
511 "{}: Start keep-alive timer {:?}",
512 self.shared.io.tag(),
513 self.shared.io.cfg().keepalive_timeout()
514 );
515 self.shared.insert_flags(Flags::KA_TIMEOUT);
516 self.shared
517 .io
518 .start_timer(self.shared.io.cfg().keepalive_timeout());
519 }
520 } else if let Some(params) = self.shared.io.cfg().frame_read_rate() {
521 self.shared.insert_flags(Flags::READ_TIMEOUT);
524
525 self.read_remains = decoded.remains as u32;
526 self.read_remains_prev = 0;
527 self.read_max_timeout = params.max_timeout;
528 self.shared.io.start_timer(params.timeout);
529 }
530 }
531
532 fn handle_timeout(&mut self) -> Result<(), Reason<U>> {
533 if self.shared.contains(Flags::READ_TIMEOUT) {
535 if let Some(params) = self.shared.io.cfg().frame_read_rate() {
536 let total = self.read_remains - self.read_remains_prev;
537
538 if total > params.rate {
540 self.read_remains_prev = self.read_remains;
541 self.read_remains = 0;
542
543 if !params.max_timeout.is_zero() {
544 self.read_max_timeout = Seconds(
545 self.read_max_timeout.0.saturating_sub(params.timeout.0),
546 );
547 }
548
549 if params.max_timeout.is_zero() || !self.read_max_timeout.is_zero() {
550 log::trace!(
551 "{}: Frame read rate {:?}, extend timer",
552 self.shared.io.tag(),
553 total
554 );
555 self.shared.io.start_timer(params.timeout);
556 return Ok(());
557 }
558 log::trace!(
559 "{}: Max payload timeout has been reached",
560 self.shared.io.tag()
561 );
562 }
563 Err(Reason::ReadTimeout)
564 } else {
565 Ok(())
566 }
567 } else if self.shared.contains(Flags::KA_TIMEOUT | Flags::IDLE) {
568 log::trace!(
569 "{}: Keep-alive error, stopping dispatcher",
570 self.shared.io.tag()
571 );
572 Err(Reason::KeepAliveTimeout)
573 } else {
574 Ok(())
575 }
576 }
577}
578
579impl<U> fmt::Debug for DispatchItem<U>
580where
581 U: Encoder + Decoder,
582 <U as Decoder>::Item: fmt::Debug,
583{
584 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
585 match *self {
586 DispatchItem::Item(ref item) => {
587 write!(fmt, "DispatchItem::Item({item:?})")
588 }
589 DispatchItem::Control(ref e) => {
590 write!(fmt, "DispatchItem::Control({e:?})")
591 }
592 DispatchItem::Stop(ref e) => {
593 write!(fmt, "DispatchItem::Stop({e:?})")
594 }
595 }
596 }
597}
598
599impl<U> fmt::Debug for Reason<U>
600where
601 U: Encoder + Decoder,
602 <U as Encoder>::Error: fmt::Debug,
603 <U as Decoder>::Error: fmt::Debug,
604{
605 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
606 match *self {
607 Reason::Io(ref err) => {
608 write!(fmt, "Reason::Io({err:?})")
609 }
610 Reason::Encoder(ref err) => {
611 write!(fmt, "Reason::Encoder({err:?})")
612 }
613 Reason::Decoder(ref err) => {
614 write!(fmt, "Reason::Decoder({err:?})")
615 }
616 Reason::KeepAliveTimeout => {
617 write!(fmt, "Reason::KeepAliveTimeout")
618 }
619 Reason::ReadTimeout => {
620 write!(fmt, "Reason::ReadTimeout")
621 }
622 }
623 }
624}
625
// Tests for the dispatcher, driven through the in-memory `IoTest` transport.
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex, atomic::AtomicBool, atomic::Ordering::Relaxed};
    use std::{cell::RefCell, io};

    use ntex_bytes::{Bytes, BytesMut};
    use ntex_codec::BytesCodec;
    use ntex_io::{Flags, Io, IoConfig, IoRef, testing::IoTest};
    use ntex_service::{ServiceCtx, cfg::SharedCfg};
    use ntex_util::{time::Millis, time::sleep};
    use rand::Rng;

    use super::*;

    // Test handle around the `IoRef` of the io object given to `Dispatcher::debug`.
    pub(crate) struct State(IoRef);

    impl State {
        // Current io flags (`ntex_io::Flags`, not the dispatcher flags).
        fn flags(&self) -> Flags {
            self.0.flags()
        }

        fn io(&self) -> &IoRef {
            &self.0
        }

        fn close(&self) {
            self.0.close();
        }
    }

    // Codec with fixed-size frames: decodes exactly `self.0` bytes per item.
    #[derive(Copy, Clone)]
    struct BCodec(usize);

    impl Encoder for BCodec {
        type Item = Bytes;
        type Error = io::Error;

        // Appends the item bytes to `dst` unchanged.
        fn encode(&self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
            dst.extend_from_slice(&item[..]);
            Ok(())
        }
    }

    impl Decoder for BCodec {
        type Item = Bytes;
        type Error = io::Error;

        // Returns `None` until at least `self.0` bytes are buffered.
        fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
            if src.len() < self.0 {
                Ok(None)
            } else {
                Ok(Some(src.split_to(self.0)))
            }
        }
    }

    impl<S, U> Dispatcher<S, U>
    where
        S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
        U: Decoder + Encoder + 'static,
    {
        // Test-only constructor: same setup as `Dispatcher::new`, but it also
        // starts a one second io timer and returns a `State` handle to the io.
        pub(crate) fn debug(io: Io, codec: U, service: S) -> (Self, State) {
            let flags = if io.cfg().keepalive_timeout().is_zero() {
                super::Flags::empty()
            } else {
                super::Flags::KA_ENABLED
            };

            let inner = State(io.get_ref());
            io.start_timer(Seconds::ONE);

            let shared = Rc::new(DispatcherShared {
                codec,
                io: io.into(),
                flags: Cell::new(flags),
                error: Cell::new(None),
                inflight: Cell::new(0),
                service: Pipeline::new(service).bind(),
            });

            (
                Dispatcher {
                    inner: DispatcherInner {
                        shared,
                        error: None,
                        st: DispatcherState::Processing,
                        response: None,
                        read_remains: 0,
                        read_remains_prev: 0,
                        read_max_timeout: Seconds::ZERO,
                    },
                },
                inner,
            )
        }
    }

    // Echo service: every request is written back to the client.
    #[ntex::test]
    async fn basics() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, _) = Dispatcher::debug(
            Io::from(server),
            BytesCodec,
            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
                sleep(Millis(50)).await;
                if let DispatchItem::Item(msg) = msg {
                    Ok::<_, ()>(Some(msg))
                } else {
                    panic!()
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        sleep(Millis(25)).await;
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        client.write("GET /test HTTP/1\r\n\r\n");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        client.close().await;
        assert!(client.is_server_dropped());

        // Exercises the derived `Clone` and `Debug` impls of the dispatcher flags.
        assert!(format!("{:?}", super::Flags::KA_TIMEOUT.clone()).contains("KA_TIMEOUT"));
    }

    // Data encoded directly through the io handle reaches the client as well.
    #[ntex::test]
    async fn sink() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, st) = Dispatcher::debug(
            Io::from(server),
            BytesCodec,
            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
                if let DispatchItem::Item(msg) = msg {
                    Ok::<_, ()>(Some(msg))
                } else if let DispatchItem::Stop(_) = msg {
                    Ok(None)
                } else {
                    panic!()
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        assert!(
            st.io()
                .encode(Bytes::from_static(b"test"), &BytesCodec)
                .is_ok()
        );
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        st.close();
        sleep(Millis(1500)).await;
        assert!(client.is_server_dropped());
    }

    // A service call error stops the dispatcher; data already written to the
    // io is still delivered to the client.
    #[ntex::test]
    async fn err_in_service() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, state) = Dispatcher::debug(
            Io::from(server),
            BytesCodec,
            ntex_service::fn_service(|_: DispatchItem<BytesCodec>| async move {
                Err::<Option<Bytes>, _>(())
            }),
        );
        state
            .io()
            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
            .unwrap();
        spawn(async move {
            let _ = disp.await;
        });

        // Allow the client to receive the buffered data.
        client.remote_buffer_cap(1024);
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        sleep(Millis(250)).await;
        assert!(client.is_closed());

        client.close().await;

        assert!(client.is_server_dropped());
    }

    // A readiness error stops the dispatcher and is returned from the future;
    // `ready` is expected to be called exactly once.
    #[ntex::test]
    #[allow(clippy::items_after_statements)]
    async fn err_in_service_ready() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        let counter = Rc::new(Cell::new(0));

        // Service whose readiness check always fails; counts `ready` calls.
        struct Srv(Rc<Cell<usize>>);

        impl Service<DispatchItem<BytesCodec>> for Srv {
            type Response = Option<Response<BytesCodec>>;
            type Error = &'static str;

            async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
                self.0.set(self.0.get() + 1);
                Err("test")
            }

            async fn call(
                &self,
                _: DispatchItem<BytesCodec>,
                _: ServiceCtx<'_, Self>,
            ) -> Result<Self::Response, Self::Error> {
                Ok(None)
            }
        }

        let (disp, state) =
            Dispatcher::debug(Io::from(server), BytesCodec, Srv(counter.clone()));
        spawn(async move {
            let res = disp.await;
            assert_eq!(res, Err("test"));
        });

        state
            .io()
            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
            .unwrap();

        // Allow the client to receive the buffered data.
        client.remote_buffer_cap(1024);
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        sleep(Millis(250)).await;
        assert!(client.is_closed());

        client.close().await;
        assert!(client.is_server_dropped());

        assert_eq!(counter.get(), 1);
    }

    // The service receives back-pressure enabled/disabled notifications while
    // a large response drains to a slow client.
    #[ntex::test]
    async fn write_backpressure() {
        let (client, server) = IoTest::create();
        // The client accepts no data at first.
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        // Records the order of events seen by the service:
        // 0 = item, 1 = back-pressure enabled, 2 = back-pressure disabled.
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let io = Io::new(
            server,
            SharedCfg::new("TEST").add(
                IoConfig::new()
                    .set_read_buf(8 * 1024, 1024, 16)
                    .set_write_buf(16 * 1024, 1024, 16),
            ),
        );

        let (disp, state) = Dispatcher::debug(
            io,
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(_) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            // Respond with 64 KiB of random alphanumeric data.
                            let bytes = rand::rng()
                                .sample_iter(&rand::distr::Alphanumeric)
                                .take(65_536)
                                .map(char::from)
                                .collect::<String>();
                            return Ok::<_, ()>(Some(Bytes::from(bytes)));
                        }
                        DispatchItem::Control(Control::WBackPressureEnabled) => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        DispatchItem::Control(Control::WBackPressureDisabled) => {
                            data.lock().unwrap().borrow_mut().push(2);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );

        spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read_any();
        assert_eq!(buf, Bytes::from_static(b""));
        client.write("GET /test HTTP/1\r\n\r\n");
        sleep(Millis(25)).await;

        // Nothing has reached the client yet.
        assert_eq!(client.remote_buffer(|buf| buf.len()), 0);

        // The whole response is still in the write buffer.
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 65536);

        client.remote_buffer_cap(10240);
        sleep(Millis(50)).await;
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 55296);

        client.remote_buffer_cap(48056);
        sleep(Millis(50)).await;
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 7240);

        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]);
    }

    // Closing the io while the (in-flight limited) service is busy must still
    // let the dispatcher future complete.
    #[ntex::test]
    async fn disconnect_during_read_backpressure() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);

        let (disp, state) = Dispatcher::debug(
            Io::new(
                server,
                SharedCfg::new("TEST").add(
                    IoConfig::new()
                        .set_keepalive_timeout(Seconds::ZERO)
                        .set_read_buf(1024, 512, 16),
                ),
            ),
            BytesCodec,
            ntex_util::services::inflight::InFlightService::new(
                1,
                ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| async move {
                    if let DispatchItem::Item(_) = msg {
                        sleep(Millis(500)).await;
                        Ok::<_, ()>(None)
                    } else {
                        Ok(None)
                    }
                }),
            ),
        );

        // `tx` is signalled once the dispatcher future has completed.
        let (tx, rx) = ntex::channel::oneshot::channel();
        ntex::rt::spawn(async move {
            let _ = disp.await;
            let _ = tx.send(());
        });

        let bytes = rand::rng()
            .sample_iter(&rand::distr::Alphanumeric)
            .take(1024)
            .map(char::from)
            .collect::<String>();
        client.write(bytes.clone());
        sleep(Millis(25)).await;
        client.write(bytes);
        sleep(Millis(25)).await;

        state.close();
        let _ = rx.recv().await;
    }

    // Keep-alive timeout: the service sees one item and then the keep-alive stop.
    #[ntex::test]
    async fn keepalive() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        // Event log: 0 = item, 1 = keep-alive timeout stop.
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let cfg = SharedCfg::new("DBG").add(
            IoConfig::new()
                .set_disconnect_timeout(Seconds(1))
                .set_keepalive_timeout(Seconds(1)),
        );

        let (disp, state) = Dispatcher::debug(
            Io::new(server, cfg),
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes));
                        }
                        DispatchItem::Stop(Reason::KeepAliveTimeout) => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
        sleep(Millis(2000)).await;

        let flags = state.flags();
        assert!(flags.contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }

    // Keep-alive timeout still fires when a frame read rate is configured too.
    #[ntex::test]
    async fn keepalive2() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        // Event log: 0 = item, 1 = keep-alive timeout stop.
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let cfg = SharedCfg::new("DBG").add(
            IoConfig::new()
                .set_keepalive_timeout(Seconds(1))
                .set_frame_read_rate(Seconds(1), Seconds(2), 2),
        );

        let (disp, state) = Dispatcher::debug(
            Io::new(server, cfg),
            BCodec(8),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes));
                        }
                        DispatchItem::Stop(Reason::KeepAliveTimeout) => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        client.write("12345678");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"12345678"));
        sleep(Millis(2000)).await;

        let flags = state.flags();
        assert!(flags.contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }

    // Regular traffic within the keep-alive period keeps the connection open.
    #[ntex::test]
    async fn keepalive3() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        // Event log: 0 = item, 1 = keep-alive timeout stop.
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let cfg = SharedCfg::new("DBG").add(
            IoConfig::new()
                .set_keepalive_timeout(Seconds(2))
                .set_frame_read_rate(Seconds(1), Seconds(2), 2),
        );

        let (disp, _) = Dispatcher::debug(
            Io::new(server, cfg),
            BCodec(1),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes));
                        }
                        DispatchItem::Stop(Reason::KeepAliveTimeout) => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        client.write("1");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"1"));
        sleep(Millis(750)).await;

        client.write("2");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"2"));

        sleep(Millis(750)).await;
        client.write("3");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"3"));

        sleep(Millis(750)).await;
        assert!(!client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 0, 0]);
    }

    // A partial frame that arrives too slowly ends with a read timeout stop.
    #[ntex::test]
    async fn read_timeout() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        // Event log: 0 = item, 1 = read timeout stop.
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let io = Io::new(
            server,
            SharedCfg::new("TEST").add(
                IoConfig::new()
                    .set_keepalive_timeout(Seconds::ZERO)
                    .set_frame_read_rate(Seconds(1), Seconds(2), 2),
            ),
        );

        let (disp, state) = Dispatcher::debug(
            io,
            BCodec(8),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes));
                        }
                        DispatchItem::Stop(Reason::ReadTimeout) => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        // A complete 8-byte frame is echoed back.
        client.write("12345678");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"12345678"));

        // Then feed an incomplete frame a few bytes at a time.
        client.write("1");
        sleep(Millis(1000)).await;
        assert!(!state.flags().contains(Flags::IO_STOPPING));
        client.write("23");
        sleep(Millis(1000)).await;
        assert!(!state.flags().contains(Flags::IO_STOPPING));
        client.write("4");
        sleep(Millis(2000)).await;

        assert!(state.flags().contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }

    // A timeout notification raised by the service itself stops an idle
    // connection even though keep-alive is disabled in the configuration.
    #[ntex::test]
    async fn idle_timeout() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let io = Io::new(
            server,
            SharedCfg::new("DBG").add(IoConfig::new().set_keepalive_timeout(Seconds::ZERO)),
        );
        let ioref = io.get_ref();

        let (disp, state) = Dispatcher::debug(
            io,
            BCodec(1),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                // Every call schedules a timeout notification 500ms later.
                let ioref = ioref.clone();
                ntex::rt::spawn(async move {
                    sleep(Millis(500)).await;
                    ioref.notify_timeout();
                });
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes));
                        }
                        DispatchItem::Stop(Reason::ReadTimeout) => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        client.write("1");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"1"));

        sleep(Millis(1000)).await;
        assert!(state.flags().contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
    }

    // The service is still called when the client closes before the
    // dispatcher starts running.
    #[ntex::test]
    async fn unhandled_data() {
        let handled = Arc::new(AtomicBool::new(false));
        let handled2 = handled.clone();

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, _) = Dispatcher::debug(
            Io::from(server),
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                handled2.store(true, Relaxed);
                async move {
                    sleep(Millis(50)).await;
                    if let DispatchItem::Item(msg) = msg {
                        Ok::<_, ()>(Some(msg))
                    } else if let DispatchItem::Stop(_) = msg {
                        Ok::<_, ()>(None)
                    } else {
                        panic!()
                    }
                }
            }),
        );
        client.close().await;
        spawn(async move {
            let _ = disp.await;
        });
        sleep(Millis(50)).await;

        assert!(handled.load(Relaxed));
    }
}