1#![allow(clippy::let_underscore_future)]
3use std::task::{Context, Poll, ready};
4use std::{cell::Cell, fmt, future::Future, io, pin::Pin, rc::Rc};
5
6use ntex_codec::{Decoder, Encoder};
7use ntex_io::{Decoded, IoBoxed, IoStatusUpdate, RecvError};
8use ntex_service::{IntoService, Pipeline, PipelineBinding, PipelineCall, Service};
9use ntex_util::{future::Either, spawn, time::Seconds};
10
/// Response type produced by the dispatcher's service: the codec's encoder item.
type Response<U> = <U as Encoder>::Item;
12
/// Message delivered by the dispatcher to its service.
pub enum DispatchItem<U: Encoder + Decoder> {
    /// A frame decoded from the input stream by the codec.
    Item(<U as Decoder>::Item),
    /// Write back-pressure state change notification.
    Control(Control),
    /// The dispatcher is stopping; carries the reason.
    Stop(Reason<U>),
}
22
/// Write back-pressure notifications delivered via [`DispatchItem::Control`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Control {
    /// The I/O layer reported write back-pressure; the dispatcher waits for a flush.
    WBackPressureEnabled,
    /// The write buffer has been flushed; the dispatcher resumes processing.
    WBackPressureDisabled,
}
31
/// Reason the dispatcher is stopping, delivered via [`DispatchItem::Stop`].
pub enum Reason<U: Encoder + Decoder> {
    /// I/O failure or peer disconnect; `None` when the I/O layer reported no error.
    Io(Option<io::Error>),
    /// Encoding a service response failed.
    Encoder(<U as Encoder>::Error),
    /// Decoding incoming data failed.
    Decoder(<U as Decoder>::Error),
    /// The keep-alive timer expired.
    KeepAliveTimeout,
    /// The frame read rate was too low, or the maximum frame read time was exceeded.
    ReadTimeout,
}
45
pin_project_lite::pin_project! {
    /// Future that reads frames from an I/O stream, decodes them with a codec,
    /// passes them to a service and encodes the service's responses back.
    ///
    /// Resolves once both the I/O and the service are shut down, returning the
    /// service error that stopped the dispatcher, if any.
    pub struct Dispatcher<S, U>
    where
        S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
        U: Encoder,
        U: Decoder,
        U: 'static,
    {
        // All mutable dispatcher state; nothing here needs structural pinning.
        inner: DispatcherInner<S, U>,
    }
}
59
60bitflags::bitflags! {
61 #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
62 struct Flags: u8 {
63 const READY_ERR = 0b0000001;
64 const IO_ERR = 0b0000010;
65 const KA_ENABLED = 0b0000100;
66 const KA_TIMEOUT = 0b0001000;
67 const READ_TIMEOUT = 0b0010000;
68 const IDLE = 0b0100000;
69 }
70}
71
/// Mutable dispatcher state owned by the [`Dispatcher`] future.
struct DispatcherInner<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder + 'static,
{
    // Current state-machine state.
    st: DispatcherState,
    // Service error to return from `poll` once shutdown completes.
    error: Option<S::Error>,
    // State shared with spawned service-call tasks.
    shared: Rc<DispatcherShared<S, U>>,
    // At most one service call polled inline by the dispatcher; others are spawned.
    response: Option<PipelineCall<S, DispatchItem<U>>>,
    // Buffered-but-undecoded byte count (`Decoded::remains`) seen in the
    // current read-rate period.
    read_remains: u32,
    // `read_remains` value recorded at the previous read-rate timer extension.
    read_remains_prev: u32,
    // Remaining budget before the read-rate maximum timeout is reached.
    read_max_timeout: Seconds,
}
85
/// State shared between the dispatcher future and spawned service-call tasks.
pub(crate) struct DispatcherShared<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder,
{
    // Underlying I/O stream.
    io: IoBoxed,
    // Codec used to decode incoming frames and encode service responses.
    codec: U,
    // Service bound into a pipeline; receives every `DispatchItem`.
    service: PipelineBinding<S, DispatchItem<U>>,
    // Dispatcher state bits (see `Flags`).
    flags: Cell<Flags>,
    // Error recorded by a completed service call, consumed by `check_error`.
    error: Cell<Option<DispatcherError<S::Error, <U as Encoder>::Error>>>,
    // Number of service calls started but not yet completed.
    inflight: Cell<u32>,
}
98
/// States of the state machine driven by `Dispatcher::poll`.
#[derive(Copy, Clone, Debug)]
enum DispatcherState {
    // Decoding frames and forwarding them to the service.
    Processing,
    // Waiting for the write buffer to be flushed.
    Backpressure,
    // No new frames are decoded; waiting for in-flight calls, then shutting down I/O.
    Stop,
    // I/O shut down; waiting for the service to shut down.
    Shutdown,
}
106
/// Error produced by a completed service call.
#[derive(Debug)]
enum DispatcherError<S, U> {
    // Encoding the service response failed.
    Encoder(U),
    // The service call itself returned an error.
    Service(S),
}
112
/// Outcome of `DispatcherInner::poll_service`.
enum PollService<U: Encoder + Decoder> {
    // Item to deliver via `call_nowait`; produced after `poll_ready` returned `Ok`.
    Item(DispatchItem<U>),
    // Item to deliver via `call`; produced while the service was not ready.
    ItemWait(DispatchItem<U>),
    // State may have changed; re-run the dispatcher loop.
    Continue,
    // Service is ready and no error is pending.
    Ready,
}
119
120impl<S, U> From<Either<S, U>> for DispatcherError<S, U> {
121 fn from(err: Either<S, U>) -> Self {
122 match err {
123 Either::Left(err) => DispatcherError::Service(err),
124 Either::Right(err) => DispatcherError::Encoder(err),
125 }
126 }
127}
128
129impl<S, U> Dispatcher<S, U>
130where
131 S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
132 U: Decoder + Encoder + 'static,
133{
134 pub fn new<Io, F>(io: Io, codec: U, service: F) -> Dispatcher<S, U>
136 where
137 IoBoxed: From<Io>,
138 F: IntoService<S, DispatchItem<U>>,
139 {
140 let io = IoBoxed::from(io);
141 let flags = if io.cfg().keepalive_timeout().is_zero() {
142 Flags::empty()
143 } else {
144 Flags::KA_ENABLED
145 };
146
147 let shared = Rc::new(DispatcherShared {
148 io,
149 codec,
150 flags: Cell::new(flags),
151 error: Cell::new(None),
152 inflight: Cell::new(0),
153 service: Pipeline::new(service.into_service()).bind(),
154 });
155
156 Dispatcher {
157 inner: DispatcherInner {
158 shared,
159 response: None,
160 error: None,
161 read_remains: 0,
162 read_remains_prev: 0,
163 read_max_timeout: Seconds::ZERO,
164 st: DispatcherState::Processing,
165 },
166 }
167 }
168}
169
170impl<S, U> DispatcherShared<S, U>
171where
172 S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
173 U: Encoder + Decoder,
174{
175 fn handle_result(&self, item: Result<S::Response, S::Error>, io: &IoBoxed, wake: bool) {
176 match item {
177 Ok(Some(val)) => {
178 if let Err(err) = io.encode(val, &self.codec) {
179 self.error.set(Some(DispatcherError::Encoder(err)))
180 }
181 }
182 Err(err) => self.error.set(Some(DispatcherError::Service(err))),
183 Ok(None) => (),
184 }
185 let inflight = self.inflight.get() - 1;
186 self.inflight.set(inflight);
187 if inflight == 0 {
188 self.insert_flags(Flags::IDLE);
189 }
190 if wake {
191 io.wake();
192 }
193 }
194
195 fn contains(&self, f: Flags) -> bool {
196 self.flags.get().intersects(f)
197 }
198
199 fn insert_flags(&self, f: Flags) {
200 let mut flags = self.flags.get();
201 flags.insert(f);
202 self.flags.set(flags);
203 }
204
205 fn remove_flags(&self, f: Flags) -> bool {
206 let mut flags = self.flags.get();
207 if flags.intersects(f) {
208 flags.remove(f);
209 self.flags.set(flags);
210 true
211 } else {
212 false
213 }
214 }
215}
216
217impl<S, U> Future for Dispatcher<S, U>
218where
219 S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
220 U: Decoder + Encoder + 'static,
221{
222 type Output = Result<(), S::Error>;
223
224 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
225 let mut this = self.as_mut().project();
226 let slf = &mut this.inner;
227
228 if let Some(fut) = slf.response.as_mut()
230 && let Poll::Ready(item) = Pin::new(fut).poll(cx)
231 {
232 slf.shared.handle_result(item, &slf.shared.io, false);
233 slf.response = None;
234 }
235
236 loop {
237 match slf.st {
238 DispatcherState::Processing => {
239 let (item, nowait) = match ready!(slf.poll_service(cx)) {
240 PollService::Ready => {
241 match slf.shared.io.poll_recv_decode(&slf.shared.codec, cx) {
243 Ok(decoded) => {
244 slf.update_timer(&decoded);
245 if let Some(el) = decoded.item {
246 (DispatchItem::Item(el), true)
247 } else {
248 return Poll::Pending;
249 }
250 }
251 Err(RecvError::KeepAlive) => {
252 if let Err(ctl) = slf.handle_timeout() {
253 slf.st = DispatcherState::Stop;
254 (DispatchItem::Stop(ctl), true)
255 } else {
256 continue;
257 }
258 }
259 Err(RecvError::WriteBackpressure) => {
260 slf.st = DispatcherState::Backpressure;
262 (
263 DispatchItem::Control(
264 Control::WBackPressureEnabled,
265 ),
266 true,
267 )
268 }
269 Err(RecvError::Decoder(err)) => {
270 log::trace!(
271 "{}: Decoder error, stopping dispatcher: {:?}",
272 slf.shared.io.tag(),
273 err
274 );
275 slf.st = DispatcherState::Stop;
276 (DispatchItem::Stop(Reason::Decoder(err)), true)
277 }
278 Err(RecvError::PeerGone(err)) => {
279 log::trace!(
280 "{}: Peer is gone, stopping dispatcher: {:?}",
281 slf.shared.io.tag(),
282 err
283 );
284 slf.st = DispatcherState::Stop;
285 (DispatchItem::Stop(Reason::Io(err)), true)
286 }
287 }
288 }
289 PollService::Item(item) => (item, true),
290 PollService::ItemWait(item) => (item, false),
291 PollService::Continue => continue,
292 };
293
294 slf.call_service(cx, item, nowait);
295 }
296 DispatcherState::Backpressure => {
298 match ready!(slf.poll_service(cx)) {
299 PollService::Ready => (),
300 PollService::Item(item) => slf.call_service(cx, item, true),
301 PollService::ItemWait(item) => slf.call_service(cx, item, false),
302 PollService::Continue => continue,
303 };
304
305 let item = if let Err(err) = ready!(slf.shared.io.poll_flush(cx, false))
306 {
307 slf.st = DispatcherState::Stop;
308 DispatchItem::Stop(Reason::Io(Some(err)))
309 } else {
310 slf.st = DispatcherState::Processing;
311 DispatchItem::Control(Control::WBackPressureDisabled)
312 };
313 slf.call_service(cx, item, false);
314 }
315 DispatcherState::Stop => {
317 slf.shared.io.stop_timer();
318
319 if !slf.shared.contains(Flags::READY_ERR)
321 && let Poll::Ready(res) = slf.shared.service.poll_ready(cx)
322 && res.is_err()
323 {
324 slf.shared.insert_flags(Flags::READY_ERR);
325 }
326
327 if slf.shared.inflight.get() == 0 {
328 if slf.shared.io.poll_shutdown(cx).is_ready() {
329 slf.st = DispatcherState::Shutdown;
330 continue;
331 }
332 } else if !slf.shared.contains(Flags::IO_ERR) {
333 match ready!(slf.shared.io.poll_status_update(cx)) {
334 IoStatusUpdate::PeerGone(_) | IoStatusUpdate::KeepAlive => {
335 slf.shared.insert_flags(Flags::IO_ERR);
336 continue;
337 }
338 IoStatusUpdate::WriteBackpressure => {
339 if ready!(slf.shared.io.poll_flush(cx, true)).is_err() {
340 slf.shared.insert_flags(Flags::IO_ERR);
341 }
342 continue;
343 }
344 }
345 } else {
346 slf.shared.io.poll_dispatch(cx);
347 }
348 return Poll::Pending;
349 }
350 DispatcherState::Shutdown => {
352 return if slf.shared.service.poll_shutdown(cx).is_ready() {
353 log::trace!(
354 "{}: Service shutdown is completed, stop",
355 slf.shared.io.tag()
356 );
357
358 Poll::Ready(if let Some(err) = slf.error.take() {
359 Err(err)
360 } else {
361 Ok(())
362 })
363 } else {
364 Poll::Pending
365 };
366 }
367 }
368 }
369 }
370}
371
372impl<S, U> DispatcherInner<S, U>
373where
374 S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
375 U: Decoder + Encoder + 'static,
376{
377 fn call_service(&mut self, cx: &mut Context<'_>, item: DispatchItem<U>, nowait: bool) {
378 let mut fut = if nowait {
379 self.shared.service.call_nowait(item)
380 } else {
381 self.shared.service.call(item)
382 };
383 let inflight = self.shared.inflight.get() + 1;
384 self.shared.inflight.set(inflight);
385 if inflight == 1 {
386 self.shared.remove_flags(Flags::IDLE);
387 }
388
389 if self.response.is_none() {
391 if let Poll::Ready(result) = Pin::new(&mut fut).poll(cx) {
392 self.shared.handle_result(result, &self.shared.io, false);
393 } else {
394 self.response = Some(fut);
395 }
396 } else {
397 let shared = self.shared.clone();
398 spawn(async move {
399 let result = fut.await;
400 shared.handle_result(result, &shared.io, true);
401 });
402 }
403 }
404
405 fn check_error(&mut self) -> PollService<U> {
406 if let Some(err) = self.shared.error.take() {
408 log::trace!(
409 "{}: Error occured, stopping dispatcher",
410 self.shared.io.tag()
411 );
412 self.st = DispatcherState::Stop;
413
414 match err {
415 DispatcherError::Encoder(err) => {
416 PollService::Item(DispatchItem::Stop(Reason::Encoder(err)))
417 }
418 DispatcherError::Service(err) => {
419 self.error = Some(err);
420 PollService::Continue
421 }
422 }
423 } else {
424 PollService::Ready
425 }
426 }
427
428 fn poll_service(&mut self, cx: &mut Context<'_>) -> Poll<PollService<U>> {
429 match self.shared.service.poll_ready(cx) {
431 Poll::Ready(Ok(_)) => Poll::Ready(self.check_error()),
432 Poll::Pending => {
434 log::trace!(
435 "{}: Service is not ready, register dispatcher",
436 self.shared.io.tag()
437 );
438
439 self.shared
441 .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
442 self.shared.io.stop_timer();
443
444 match ready!(self.shared.io.poll_read_pause(cx)) {
445 IoStatusUpdate::KeepAlive => {
446 log::trace!(
447 "{}: Keep-alive error, stopping dispatcher during pause",
448 self.shared.io.tag()
449 );
450 self.st = DispatcherState::Stop;
451 Poll::Ready(PollService::ItemWait(DispatchItem::Stop(
452 Reason::KeepAliveTimeout,
453 )))
454 }
455 IoStatusUpdate::PeerGone(err) => {
456 log::trace!(
457 "{}: Peer is gone during pause, stopping dispatcher: {:?}",
458 self.shared.io.tag(),
459 err
460 );
461 self.st = DispatcherState::Stop;
462 Poll::Ready(PollService::ItemWait(DispatchItem::Stop(Reason::Io(
463 err,
464 ))))
465 }
466 IoStatusUpdate::WriteBackpressure => {
467 self.st = DispatcherState::Backpressure;
468 Poll::Ready(PollService::ItemWait(DispatchItem::Control(
469 Control::WBackPressureEnabled,
470 )))
471 }
472 }
473 }
474 Poll::Ready(Err(err)) => {
476 log::trace!(
477 "{}: Service readiness check failed, stopping",
478 self.shared.io.tag()
479 );
480 self.st = DispatcherState::Stop;
481 self.error = Some(err);
482 self.shared.insert_flags(Flags::READY_ERR);
483 Poll::Ready(PollService::Continue)
484 }
485 }
486 }
487
488 fn update_timer(&mut self, decoded: &Decoded<<U as Decoder>::Item>) {
489 if decoded.item.is_some() {
491 self.read_remains = 0;
492 self.shared
493 .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
494 } else if self.shared.contains(Flags::READ_TIMEOUT) {
495 self.read_remains = decoded.remains as u32;
497 } else if self.read_remains == 0 && decoded.remains == 0 {
498 if self.shared.contains(Flags::KA_ENABLED)
500 && !self.shared.contains(Flags::KA_TIMEOUT)
501 {
502 log::trace!(
503 "{}: Start keep-alive timer {:?}",
504 self.shared.io.tag(),
505 self.shared.io.cfg().keepalive_timeout()
506 );
507 self.shared.insert_flags(Flags::KA_TIMEOUT);
508 self.shared
509 .io
510 .start_timer(self.shared.io.cfg().keepalive_timeout());
511 }
512 } else if let Some(params) = self.shared.io.cfg().frame_read_rate() {
513 self.shared.insert_flags(Flags::READ_TIMEOUT);
516
517 self.read_remains = decoded.remains as u32;
518 self.read_remains_prev = 0;
519 self.read_max_timeout = params.max_timeout;
520 self.shared.io.start_timer(params.timeout);
521 }
522 }
523
524 fn handle_timeout(&mut self) -> Result<(), Reason<U>> {
525 if self.shared.contains(Flags::READ_TIMEOUT) {
527 if let Some(params) = self.shared.io.cfg().frame_read_rate() {
528 let total = self.read_remains - self.read_remains_prev;
529
530 if total > params.rate {
532 self.read_remains_prev = self.read_remains;
533 self.read_remains = 0;
534
535 if !params.max_timeout.is_zero() {
536 self.read_max_timeout = Seconds(
537 self.read_max_timeout.0.saturating_sub(params.timeout.0),
538 );
539 }
540
541 if params.max_timeout.is_zero() || !self.read_max_timeout.is_zero() {
542 log::trace!(
543 "{}: Frame read rate {:?}, extend timer",
544 self.shared.io.tag(),
545 total
546 );
547 self.shared.io.start_timer(params.timeout);
548 return Ok(());
549 }
550 log::trace!(
551 "{}: Max payload timeout has been reached",
552 self.shared.io.tag()
553 );
554 }
555 Err(Reason::ReadTimeout)
556 } else {
557 Ok(())
558 }
559 } else if self.shared.contains(Flags::KA_TIMEOUT | Flags::IDLE) {
560 log::trace!(
561 "{}: Keep-alive error, stopping dispatcher",
562 self.shared.io.tag()
563 );
564 Err(Reason::KeepAliveTimeout)
565 } else {
566 Ok(())
567 }
568 }
569}
570
571impl<U> fmt::Debug for DispatchItem<U>
572where
573 U: Encoder + Decoder,
574 <U as Decoder>::Item: fmt::Debug,
575{
576 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
577 match *self {
578 DispatchItem::Item(ref item) => {
579 write!(fmt, "DispatchItem::Item({item:?})")
580 }
581 DispatchItem::Control(ref e) => {
582 write!(fmt, "DispatchItem::Control({e:?})")
583 }
584 DispatchItem::Stop(ref e) => {
585 write!(fmt, "DispatchItem::Stop({e:?})")
586 }
587 }
588 }
589}
590
591impl<U> fmt::Debug for Reason<U>
592where
593 U: Encoder + Decoder,
594 <U as Encoder>::Error: fmt::Debug,
595 <U as Decoder>::Error: fmt::Debug,
596{
597 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
598 match *self {
599 Reason::Io(ref err) => {
600 write!(fmt, "Reason::Io({err:?})")
601 }
602 Reason::Encoder(ref err) => {
603 write!(fmt, "Reason::Encoder({err:?})")
604 }
605 Reason::Decoder(ref err) => {
606 write!(fmt, "Reason::Decoder({err:?})")
607 }
608 Reason::KeepAliveTimeout => {
609 write!(fmt, "Reason::KeepAliveTimeout")
610 }
611 Reason::ReadTimeout => {
612 write!(fmt, "Reason::ReadTimeout")
613 }
614 }
615 }
616}
617
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex, atomic::AtomicBool, atomic::Ordering::Relaxed};
    use std::{cell::RefCell, io};

    use ntex_bytes::{Bytes, BytesMut};
    use ntex_codec::BytesCodec;
    use ntex_io::{Flags, Io, IoConfig, IoRef, testing::IoTest};
    use ntex_service::{ServiceCtx, cfg::SharedCfg};
    use ntex_util::{time::Millis, time::sleep};
    use rand::Rng;

    use super::*;

    /// Handle to the dispatcher's I/O returned by `Dispatcher::debug`; lets tests
    /// inspect I/O flags, write through the I/O directly and close it.
    pub(crate) struct State(IoRef);

    impl State {
        fn flags(&self) -> Flags {
            self.0.flags()
        }

        fn io(&self) -> &IoRef {
            &self.0
        }

        fn close(&self) {
            self.0.close();
        }
    }

    /// Test codec: decodes fixed-size frames of `self.0` bytes, encodes bytes verbatim.
    #[derive(Copy, Clone)]
    struct BCodec(usize);

    impl Encoder for BCodec {
        type Item = Bytes;
        type Error = io::Error;

        fn encode(&self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
            dst.extend_from_slice(&item[..]);
            Ok(())
        }
    }

    impl Decoder for BCodec {
        type Item = Bytes;
        type Error = io::Error;

        fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
            // wait until a whole frame is buffered
            if src.len() < self.0 {
                Ok(None)
            } else {
                Ok(Some(src.split_to(self.0)))
            }
        }
    }

    impl<S, U> Dispatcher<S, U>
    where
        S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
        U: Decoder + Encoder + 'static,
    {
        /// Test-only constructor mirroring `Dispatcher::new`.
        ///
        /// Unlike `new`, it takes an `Io` and a ready service, starts a one-second
        /// I/O timer, and also returns a `State` handle to the same I/O.
        pub(crate) fn debug(io: Io, codec: U, service: S) -> (Self, State) {
            let flags = if io.cfg().keepalive_timeout().is_zero() {
                super::Flags::empty()
            } else {
                super::Flags::KA_ENABLED
            };

            let inner = State(io.get_ref());
            io.start_timer(Seconds::ONE);

            let shared = Rc::new(DispatcherShared {
                codec,
                io: io.into(),
                flags: Cell::new(flags),
                error: Cell::new(None),
                inflight: Cell::new(0),
                service: Pipeline::new(service).bind(),
            });

            (
                Dispatcher {
                    inner: DispatcherInner {
                        shared,
                        error: None,
                        st: DispatcherState::Processing,
                        response: None,
                        read_remains: 0,
                        read_remains_prev: 0,
                        read_max_timeout: Seconds::ZERO,
                    },
                },
                inner,
            )
        }
    }

    /// Echo service: every request is written back to the client.
    #[ntex::test]
    async fn basics() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, _) = Dispatcher::debug(
            Io::from(server),
            BytesCodec,
            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
                sleep(Millis(50)).await;
                if let DispatchItem::Item(msg) = msg {
                    Ok::<_, ()>(Some(msg))
                } else {
                    panic!()
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        sleep(Millis(25)).await;
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        client.write("GET /test HTTP/1\r\n\r\n");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        client.close().await;
        assert!(client.is_server_dropped());

        // dispatcher flags have a readable Debug representation
        assert!(format!("{:?}", super::Flags::KA_TIMEOUT.clone()).contains("KA_TIMEOUT"));
    }

    /// Data encoded directly through the I/O handle reaches the client, and
    /// closing the I/O drops the server side.
    #[ntex::test]
    async fn sink() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, st) = Dispatcher::debug(
            Io::from(server),
            BytesCodec,
            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
                if let DispatchItem::Item(msg) = msg {
                    Ok::<_, ()>(Some(msg))
                } else if let DispatchItem::Stop(_) = msg {
                    Ok(None)
                } else {
                    panic!()
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        // write through the I/O handle, bypassing the service
        assert!(
            st.io()
                .encode(Bytes::from_static(b"test"), &BytesCodec)
                .is_ok()
        );
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        st.close();
        sleep(Millis(1500)).await;
        assert!(client.is_server_dropped());
    }

    /// A service error stops the dispatcher after pending writes are flushed.
    #[ntex::test]
    async fn err_in_service() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, state) = Dispatcher::debug(
            Io::from(server),
            BytesCodec,
            ntex_service::fn_service(|_: DispatchItem<BytesCodec>| async move {
                Err::<Option<Bytes>, _>(())
            }),
        );
        state
            .io()
            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
            .unwrap();
        spawn(async move {
            let _ = disp.await;
        });

        // allow writes: the pending buffer reaches the client
        client.remote_buffer_cap(1024);
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        // the service error closed the connection
        sleep(Millis(250)).await;
        assert!(client.is_closed());

        client.close().await;

        assert!(client.is_server_dropped());
    }

    /// A readiness error stops the dispatcher: `ready` is called exactly once and
    /// the dispatcher future resolves with that error.
    #[ntex::test]
    async fn err_in_service_ready() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        let counter = Rc::new(Cell::new(0));

        // service whose readiness check always fails and counts invocations
        struct Srv(Rc<Cell<usize>>);

        impl Service<DispatchItem<BytesCodec>> for Srv {
            type Response = Option<Response<BytesCodec>>;
            type Error = &'static str;

            async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
                self.0.set(self.0.get() + 1);
                Err("test")
            }

            async fn call(
                &self,
                _: DispatchItem<BytesCodec>,
                _: ServiceCtx<'_, Self>,
            ) -> Result<Self::Response, Self::Error> {
                Ok(None)
            }
        }

        let (disp, state) =
            Dispatcher::debug(Io::from(server), BytesCodec, Srv(counter.clone()));
        spawn(async move {
            let res = disp.await;
            assert_eq!(res, Err("test"));
        });

        state
            .io()
            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
            .unwrap();

        // allow writes: the pending buffer reaches the client
        client.remote_buffer_cap(1024);
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        sleep(Millis(250)).await;
        assert!(client.is_closed());

        client.close().await;
        assert!(client.is_server_dropped());

        // readiness is not re-checked after it failed
        assert_eq!(counter.get(), 1);
    }

    /// Large responses trigger write back-pressure; the service observes an item,
    /// then back-pressure enabled, then back-pressure disabled.
    #[ntex::test]
    async fn write_backpressure() {
        let (client, server) = IoTest::create();
        // nothing can be written to the client yet
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let io = Io::new(
            server,
            SharedCfg::new("TEST").add(
                IoConfig::new()
                    .set_read_buf(8 * 1024, 1024, 16)
                    .set_write_buf(16 * 1024, 1024, 16),
            ),
        );

        let (disp, state) = Dispatcher::debug(
            io,
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(_) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            // 64 KiB response, larger than the write buffer limit
                            let bytes = rand::rng()
                                .sample_iter(&rand::distr::Alphanumeric)
                                .take(65_536)
                                .map(char::from)
                                .collect::<String>();
                            return Ok::<_, ()>(Some(Bytes::from(bytes)));
                        }
                        DispatchItem::Control(Control::WBackPressureEnabled) => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        DispatchItem::Control(Control::WBackPressureDisabled) => {
                            data.lock().unwrap().borrow_mut().push(2);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );

        spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read_any();
        assert_eq!(buf, Bytes::from_static(b""));
        client.write("GET /test HTTP/1\r\n\r\n");
        sleep(Millis(25)).await;

        // remote buffer capacity is 0, so nothing reached the client
        assert_eq!(client.remote_buffer(|buf| buf.len()), 0);

        // the whole 64 KiB response is held in the write buffer
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 65536);

        // let the client accept data in two steps and watch the buffer drain
        client.remote_buffer_cap(10240);
        sleep(Millis(50)).await;
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 55296);

        client.remote_buffer_cap(48056);
        sleep(Millis(50)).await;
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 7240);

        // item, back-pressure enabled, back-pressure disabled
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]);
    }

    /// Closing the connection while the in-flight limit is reached still lets the
    /// dispatcher future complete.
    #[ntex::test]
    async fn disconnect_during_read_backpressure() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);

        let (disp, state) = Dispatcher::debug(
            Io::new(
                server,
                SharedCfg::new("TEST").add(
                    IoConfig::new()
                        .set_keepalive_timeout(Seconds::ZERO)
                        .set_read_buf(1024, 512, 16),
                ),
            ),
            BytesCodec,
            ntex_util::services::inflight::InFlightService::new(
                1,
                ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| async move {
                    if let DispatchItem::Item(_) = msg {
                        sleep(Millis(500)).await;
                        Ok::<_, ()>(None)
                    } else {
                        Ok(None)
                    }
                }),
            ),
        );

        let (tx, rx) = ntex::channel::oneshot::channel();
        ntex::rt::spawn(async move {
            let _ = disp.await;
            let _ = tx.send(());
        });

        let bytes = rand::rng()
            .sample_iter(&rand::distr::Alphanumeric)
            .take(1024)
            .map(char::from)
            .collect::<String>();
        client.write(bytes.clone());
        sleep(Millis(25)).await;
        client.write(bytes);
        sleep(Millis(25)).await;

        // close while the first item is still being processed
        state.close();
        let _ = rx.recv().await;
    }

    /// Without further requests, the keep-alive timeout stops the dispatcher and
    /// the service receives `Stop(KeepAliveTimeout)`.
    #[ntex::test]
    async fn keepalive() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let cfg = SharedCfg::new("DBG").add(
            IoConfig::new()
                .set_disconnect_timeout(Seconds(1))
                .set_keepalive_timeout(Seconds(1)),
        );

        let (disp, state) = Dispatcher::debug(
            Io::new(server, cfg),
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes));
                        }
                        DispatchItem::Stop(Reason::KeepAliveTimeout) => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
        sleep(Millis(2000)).await;

        // keep-alive expired: the I/O is stopping and the client is closed
        let flags = state.flags();
        assert!(flags.contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }

    /// Same as `keepalive`, but with a frame read rate configured as well.
    #[ntex::test]
    async fn keepalive2() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let cfg = SharedCfg::new("DBG").add(
            IoConfig::new()
                .set_keepalive_timeout(Seconds(1))
                .set_frame_read_rate(Seconds(1), Seconds(2), 2),
        );

        let (disp, state) = Dispatcher::debug(
            Io::new(server, cfg),
            BCodec(8),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes));
                        }
                        DispatchItem::Stop(Reason::KeepAliveTimeout) => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        client.write("12345678");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"12345678"));
        sleep(Millis(2000)).await;

        // keep-alive expired: the I/O is stopping and the client is closed
        let flags = state.flags();
        assert!(flags.contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }

    /// Frames arriving more often than the keep-alive timeout keep the
    /// connection open.
    #[ntex::test]
    async fn keepalive3() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let cfg = SharedCfg::new("DBG").add(
            IoConfig::new()
                .set_keepalive_timeout(Seconds(2))
                .set_frame_read_rate(Seconds(1), Seconds(2), 2),
        );

        let (disp, _) = Dispatcher::debug(
            Io::new(server, cfg),
            BCodec(1),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes));
                        }
                        DispatchItem::Stop(Reason::KeepAliveTimeout) => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        client.write("1");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"1"));
        sleep(Millis(750)).await;

        client.write("2");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"2"));

        sleep(Millis(750)).await;
        client.write("3");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"3"));

        sleep(Millis(750)).await;
        // three items handled, no keep-alive stop
        assert!(!client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 0, 0]);
    }

    /// A partial frame that arrives too slowly stops the dispatcher with
    /// `Stop(ReadTimeout)`.
    #[ntex::test]
    async fn read_timeout() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let io = Io::new(
            server,
            SharedCfg::new("TEST").add(
                IoConfig::new()
                    .set_keepalive_timeout(Seconds::ZERO)
                    .set_frame_read_rate(Seconds(1), Seconds(2), 2),
            ),
        );

        let (disp, state) = Dispatcher::debug(
            io,
            BCodec(8),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes));
                        }
                        DispatchItem::Stop(Reason::ReadTimeout) => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        client.write("12345678");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"12345678"));

        // feed a second frame byte by byte, slower than the configured rate
        client.write("1");
        sleep(Millis(1000)).await;
        assert!(!state.flags().contains(Flags::IO_STOPPING));
        client.write("23");
        sleep(Millis(1000)).await;
        assert!(!state.flags().contains(Flags::IO_STOPPING));
        client.write("4");
        sleep(Millis(2000)).await;

        // read rate too low: the dispatcher stopped with ReadTimeout
        assert!(state.flags().contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }

    /// With keep-alive disabled, a timeout notification while no calls are in
    /// flight stops the dispatcher.
    #[ntex::test]
    async fn idle_timeout() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let io = Io::new(
            server,
            SharedCfg::new("DBG").add(IoConfig::new().set_keepalive_timeout(Seconds::ZERO)),
        );
        let ioref = io.get_ref();

        let (disp, state) = Dispatcher::debug(
            io,
            BCodec(1),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                // fire a timeout notification shortly after every call
                let ioref = ioref.clone();
                ntex::rt::spawn(async move {
                    sleep(Millis(500)).await;
                    ioref.notify_timeout();
                });
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes));
                        }
                        DispatchItem::Stop(Reason::ReadTimeout) => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        client.write("1");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"1"));

        sleep(Millis(1000)).await;
        assert!(state.flags().contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
    }

    /// The service is invoked even when the client closes before the dispatcher
    /// starts running.
    #[ntex::test]
    async fn unhandled_data() {
        let handled = Arc::new(AtomicBool::new(false));
        let handled2 = handled.clone();

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, _) = Dispatcher::debug(
            Io::from(server),
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                handled2.store(true, Relaxed);
                async move {
                    sleep(Millis(50)).await;
                    if let DispatchItem::Item(msg) = msg {
                        Ok::<_, ()>(Some(msg))
                    } else if let DispatchItem::Stop(_) = msg {
                        Ok::<_, ()>(None)
                    } else {
                        panic!()
                    }
                }
            }),
        );
        client.close().await;
        spawn(async move {
            let _ = disp.await;
        });
        sleep(Millis(50)).await;

        assert!(handled.load(Relaxed));
    }
}