1#![allow(clippy::let_underscore_future)]
3use std::task::{ready, Context, Poll};
4use std::{cell::Cell, future::Future, pin::Pin, rc::Rc};
5
6use ntex_codec::{Decoder, Encoder};
7use ntex_service::{IntoService, Pipeline, PipelineBinding, PipelineCall, Service};
8use ntex_util::{future::Either, spawn, time::Seconds};
9
10use crate::{Decoded, DispatchItem, IoBoxed, IoStatusUpdate, RecvError};
11
/// Item type accepted by the codec's [`Encoder`] half.
///
/// The dispatched service returns `Option<Response<U>>`; a `Some` value is
/// encoded with the codec and written to the io stream.
type Response<U> = <U as Encoder>::Item;
13
/// Shared dispatcher configuration.
///
/// The settings live behind an `Rc` in `Cell`s, so cloning yields another
/// handle to the same values: a change made through any clone is visible
/// through all of them.
#[derive(Clone, Debug)]
pub struct DispatcherConfig(Rc<DispatcherConfigInner>);
17
/// Storage for the [`DispatcherConfig`] values.
///
/// Every field is a `Cell` so it can be updated through a shared reference.
#[derive(Debug)]
struct DispatcherConfigInner {
    // Keep-alive timeout; a zero value means keep-alive is not enabled when
    // a dispatcher is constructed.
    keepalive_timeout: Cell<Seconds>,
    // Timeout handed to the io object via `set_disconnect_timeout`.
    disconnect_timeout: Cell<Seconds>,
    // Whether frame read-rate checking is active (set when the configured
    // frame read timeout is non-zero).
    frame_read_enabled: Cell<bool>,
    // Amount of buffered data that must be exceeded within one timeout
    // period for the read timer to be extended.
    frame_read_rate: Cell<u16>,
    // Length of a single read-rate check period.
    frame_read_timeout: Cell<Seconds>,
    // Upper bound for the total time spent extending the read timer;
    // zero means no upper bound.
    frame_read_max_timeout: Cell<Seconds>,
}
27
28impl Default for DispatcherConfig {
29 fn default() -> Self {
30 DispatcherConfig(Rc::new(DispatcherConfigInner {
31 keepalive_timeout: Cell::new(Seconds(30)),
32 disconnect_timeout: Cell::new(Seconds(1)),
33 frame_read_rate: Cell::new(0),
34 frame_read_enabled: Cell::new(false),
35 frame_read_timeout: Cell::new(Seconds::ZERO),
36 frame_read_max_timeout: Cell::new(Seconds::ZERO),
37 }))
38 }
39}
40
41impl DispatcherConfig {
42 #[inline]
43 pub fn keepalive_timeout(&self) -> Seconds {
45 self.0.keepalive_timeout.get()
46 }
47
48 #[inline]
49 pub fn disconnect_timeout(&self) -> Seconds {
51 self.0.disconnect_timeout.get()
52 }
53
54 #[inline]
55 pub fn frame_read_rate(&self) -> Option<(Seconds, Seconds, u16)> {
57 if self.0.frame_read_enabled.get() {
58 Some((
59 self.0.frame_read_timeout.get(),
60 self.0.frame_read_max_timeout.get(),
61 self.0.frame_read_rate.get(),
62 ))
63 } else {
64 None
65 }
66 }
67
68 pub fn set_keepalive_timeout(&self, timeout: Seconds) -> &Self {
74 self.0.keepalive_timeout.set(timeout);
75 self
76 }
77
78 pub fn set_disconnect_timeout(&self, timeout: Seconds) -> &Self {
87 self.0.disconnect_timeout.set(timeout);
88 self
89 }
90
91 pub fn set_frame_read_rate(
99 &self,
100 timeout: Seconds,
101 max_timeout: Seconds,
102 rate: u16,
103 ) -> &Self {
104 self.0.frame_read_enabled.set(!timeout.is_zero());
105 self.0.frame_read_timeout.set(timeout);
106 self.0.frame_read_max_timeout.set(max_timeout);
107 self.0.frame_read_rate.set(rate);
108 self
109 }
110}
111
pin_project_lite::pin_project! {
    /// Future that reads frames from an io stream, decodes them with the
    /// codec `U` and passes them to the service `S` as [`DispatchItem`]s.
    ///
    /// Resolves once io and service have been shut down; the output is the
    /// service error that stopped the dispatcher, if there was one.
    pub struct Dispatcher<S, U>
    where
        S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
        U: Encoder,
        U: Decoder,
        U: 'static,
    {
        inner: DispatcherInner<S, U>,
    }
}
125
bitflags::bitflags! {
    // Internal status bits of a dispatcher.
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    struct Flags: u8 {
        // Service readiness check has failed; readiness is not polled again
        // while stopping.
        const READY_ERR = 0b0000001;
        // Io reported a terminal status (or a flush failed) while the
        // dispatcher was waiting for in-flight calls during stop.
        const IO_ERR = 0b0000010;
        // Keep-alive handling is enabled (non-zero keep-alive timeout at
        // construction time).
        const KA_ENABLED = 0b0000100;
        // The keep-alive timer has been started.
        const KA_TIMEOUT = 0b0001000;
        // The frame read-rate timer has been started.
        const READ_TIMEOUT = 0b0010000;
    }
}
136
/// Mutable state of a [`Dispatcher`], driven from its `poll` implementation.
struct DispatcherInner<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder + 'static,
{
    // Current phase of the dispatcher state machine.
    st: DispatcherState,
    // Service error that becomes the output of the dispatcher future.
    error: Option<S::Error>,
    // Internal status bits, see `Flags`.
    flags: Flags,
    // State that is also handed to spawned response tasks.
    shared: Rc<DispatcherShared<S, U>>,
    // Service call that is polled inline by the dispatcher itself; further
    // calls made while this slot is occupied are spawned as separate tasks.
    response: Option<PipelineCall<S, DispatchItem<U>>>,
    // Handle to the (shared) configuration.
    cfg: DispatcherConfig,
    // Size of the undecoded read buffer remainder as of the latest decode
    // attempt, tracked while the read-rate timer is active.
    read_remains: u32,
    // Value of `read_remains` at the previous read-rate check.
    read_remains_prev: u32,
    // Remaining budget for extending the read-rate timer.
    read_max_timeout: Seconds,
}
152
/// State shared between the dispatcher and the tasks it spawns to await
/// service responses.
pub(crate) struct DispatcherShared<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder,
{
    // Io object frames are read from and responses are written to.
    io: IoBoxed,
    // Codec used for decoding incoming and encoding outgoing items.
    codec: U,
    // Bound service pipeline that handles dispatch items.
    service: PipelineBinding<S, DispatchItem<U>>,
    // Latest service or encoder error recorded by `handle_result`; taken
    // (and cleared) by the dispatcher in `check_error`.
    error: Cell<Option<DispatcherError<S::Error, <U as Encoder>::Error>>>,
    // Number of service calls whose result has not been handled yet.
    inflight: Cell<u32>,
}
164
/// Phases of the dispatcher state machine.
#[derive(Copy, Clone, Debug)]
enum DispatcherState {
    // Reading and decoding frames and passing them to the service.
    Processing,
    // Write back-pressure is active; waiting for the write buffer to flush.
    Backpressure,
    // Stopping: waiting for in-flight service calls, then shutting down io.
    Stop,
    // Io shutdown is done; waiting for the service to shut down.
    Shutdown,
}
172
/// Error recorded while handling a service call result.
#[derive(Debug)]
enum DispatcherError<S, U> {
    // Encoding the service response failed.
    Encoder(U),
    // The service call itself returned an error.
    Service(S),
}
178
/// Outcome of checking service readiness (`DispatcherInner::poll_service`).
enum PollService<U: Encoder + Decoder> {
    // An item that has to be delivered to the service.
    Item(DispatchItem<U>),
    // Dispatcher state has changed; re-run the state machine loop.
    Continue,
    // Service is ready and no error is pending.
    Ready,
}
184
185impl<S, U> From<Either<S, U>> for DispatcherError<S, U> {
186 fn from(err: Either<S, U>) -> Self {
187 match err {
188 Either::Left(err) => DispatcherError::Service(err),
189 Either::Right(err) => DispatcherError::Encoder(err),
190 }
191 }
192}
193
194impl<S, U> Dispatcher<S, U>
195where
196 S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
197 U: Decoder + Encoder + 'static,
198{
199 pub fn new<Io, F>(
201 io: Io,
202 codec: U,
203 service: F,
204 cfg: &DispatcherConfig,
205 ) -> Dispatcher<S, U>
206 where
207 IoBoxed: From<Io>,
208 F: IntoService<S, DispatchItem<U>>,
209 {
210 let io = IoBoxed::from(io);
211 io.set_disconnect_timeout(cfg.disconnect_timeout());
212
213 let flags = if cfg.keepalive_timeout().is_zero() {
214 Flags::empty()
215 } else {
216 Flags::KA_ENABLED
217 };
218
219 let shared = Rc::new(DispatcherShared {
220 io,
221 codec,
222 error: Cell::new(None),
223 inflight: Cell::new(0),
224 service: Pipeline::new(service.into_service()).bind(),
225 });
226
227 Dispatcher {
228 inner: DispatcherInner {
229 shared,
230 flags,
231 cfg: cfg.clone(),
232 response: None,
233 error: None,
234 read_remains: 0,
235 read_remains_prev: 0,
236 read_max_timeout: Seconds::ZERO,
237 st: DispatcherState::Processing,
238 },
239 }
240 }
241}
242
243impl<S, U> DispatcherShared<S, U>
244where
245 S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
246 U: Encoder + Decoder,
247{
248 fn handle_result(&self, item: Result<S::Response, S::Error>, io: &IoBoxed, wake: bool) {
249 match item {
250 Ok(Some(val)) => {
251 if let Err(err) = io.encode(val, &self.codec) {
252 self.error.set(Some(DispatcherError::Encoder(err)))
253 }
254 }
255 Err(err) => self.error.set(Some(DispatcherError::Service(err))),
256 Ok(None) => (),
257 }
258 self.inflight.set(self.inflight.get() - 1);
259 if wake {
260 io.wake();
261 }
262 }
263}
264
impl<S, U> Future for Dispatcher<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
    U: Decoder + Encoder + 'static,
{
    /// `Ok(())` on a clean stop, or the service error that stopped the
    /// dispatcher.
    type Output = Result<(), S::Error>;

    /// Drives the dispatcher state machine.
    ///
    /// First polls the service call that is tracked inline (if any), then
    /// loops over the current [`DispatcherState`] until one of the polled
    /// operations is pending or the dispatcher has completed.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.as_mut().project();
        let slf = &mut this.inner;

        // handle the service response future that is polled inline
        if let Some(fut) = slf.response.as_mut() {
            if let Poll::Ready(item) = Pin::new(fut).poll(cx) {
                slf.shared.handle_result(item, &slf.shared.io, false);
                slf.response = None;
            }
        }

        loop {
            match slf.st {
                DispatcherState::Processing => {
                    // check service readiness first, then try to decode a frame
                    let item = match ready!(slf.poll_service(cx)) {
                        PollService::Ready => {
                            match slf.shared.io.poll_recv_decode(&slf.shared.codec, cx) {
                                Ok(decoded) => {
                                    slf.update_timer(&decoded);
                                    if let Some(el) = decoded.item {
                                        DispatchItem::Item(el)
                                    } else {
                                        // no complete frame available yet
                                        return Poll::Pending;
                                    }
                                }
                                Err(RecvError::KeepAlive) => {
                                    // timer expired: either stop with a timeout
                                    // item or keep going (timer was extended or
                                    // the expiry is not relevant)
                                    if let Err(err) = slf.handle_timeout() {
                                        slf.st = DispatcherState::Stop;
                                        err
                                    } else {
                                        continue;
                                    }
                                }
                                Err(RecvError::Stop) => {
                                    log::trace!(
                                        "{}: Dispatcher is instructed to stop",
                                        slf.shared.io.tag()
                                    );
                                    slf.st = DispatcherState::Stop;
                                    continue;
                                }
                                Err(RecvError::WriteBackpressure) => {
                                    // notify the service and wait for the flush
                                    slf.st = DispatcherState::Backpressure;
                                    DispatchItem::WBackPressureEnabled
                                }
                                Err(RecvError::Decoder(err)) => {
                                    log::trace!(
                                        "{}: Decoder error, stopping dispatcher: {:?}",
                                        slf.shared.io.tag(),
                                        err
                                    );
                                    slf.st = DispatcherState::Stop;
                                    DispatchItem::DecoderError(err)
                                }
                                Err(RecvError::PeerGone(err)) => {
                                    log::trace!(
                                        "{}: Peer is gone, stopping dispatcher: {:?}",
                                        slf.shared.io.tag(),
                                        err
                                    );
                                    slf.st = DispatcherState::Stop;
                                    DispatchItem::Disconnect(err)
                                }
                            }
                        }
                        PollService::Item(item) => item,
                        PollService::Continue => continue,
                    };

                    // pass the item to the service
                    slf.call_service(cx, item);
                }
                DispatcherState::Backpressure => {
                    // deliver a pending item (if any) before waiting for the flush
                    match ready!(slf.poll_service(cx)) {
                        PollService::Ready => (),
                        PollService::Item(item) => slf.call_service(cx, item),
                        PollService::Continue => continue,
                    };

                    // wait for the write buffer; a flush error stops the
                    // dispatcher, success resumes normal processing
                    let item = if let Err(err) = ready!(slf.shared.io.poll_flush(cx, false))
                    {
                        slf.st = DispatcherState::Stop;
                        DispatchItem::Disconnect(Some(err))
                    } else {
                        slf.st = DispatcherState::Processing;
                        DispatchItem::WBackPressureDisabled
                    };
                    slf.call_service(cx, item);
                }
                DispatcherState::Stop => {
                    slf.shared.io.stop_timer();

                    // keep polling service readiness until it has failed once
                    // NOTE(review): presumably the service relies on readiness
                    // polling to make progress on in-flight calls; confirm.
                    if !slf.flags.contains(Flags::READY_ERR) {
                        if let Poll::Ready(res) = slf.shared.service.poll_ready(cx) {
                            if res.is_err() {
                                slf.flags.insert(Flags::READY_ERR);
                            }
                        }
                    }

                    if slf.shared.inflight.get() == 0 {
                        // all service calls are handled, shut down io
                        if slf.shared.io.poll_shutdown(cx).is_ready() {
                            slf.st = DispatcherState::Shutdown;
                            continue;
                        }
                    } else if !slf.flags.contains(Flags::IO_ERR) {
                        // calls are still in flight: watch io status meanwhile
                        match ready!(slf.shared.io.poll_status_update(cx)) {
                            IoStatusUpdate::PeerGone(_)
                            | IoStatusUpdate::Stop
                            | IoStatusUpdate::KeepAlive => {
                                slf.flags.insert(Flags::IO_ERR);
                                continue;
                            }
                            IoStatusUpdate::WriteBackpressure => {
                                if ready!(slf.shared.io.poll_flush(cx, true)).is_err() {
                                    slf.flags.insert(Flags::IO_ERR);
                                }
                                continue;
                            }
                        }
                    } else {
                        // io already failed; only register for dispatch wake-ups
                        slf.shared.io.poll_dispatch(cx);
                    }
                    return Poll::Pending;
                }
                DispatcherState::Shutdown => {
                    // shut down the service, then resolve with the stored error
                    return if slf.shared.service.poll_shutdown(cx).is_ready() {
                        log::trace!(
                            "{}: Service shutdown is completed, stop",
                            slf.shared.io.tag()
                        );

                        Poll::Ready(if let Some(err) = slf.error.take() {
                            Err(err)
                        } else {
                            Ok(())
                        })
                    } else {
                        Poll::Pending
                    };
                }
            }
        }
    }
}
423
424impl<S, U> DispatcherInner<S, U>
425where
426 S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
427 U: Decoder + Encoder + 'static,
428{
429 fn call_service(&mut self, cx: &mut Context<'_>, item: DispatchItem<U>) {
430 let mut fut = self.shared.service.call_nowait(item);
431 self.shared.inflight.set(self.shared.inflight.get() + 1);
432
433 if self.response.is_none() {
435 if let Poll::Ready(result) = Pin::new(&mut fut).poll(cx) {
436 self.shared.handle_result(result, &self.shared.io, false);
437 } else {
438 self.response = Some(fut);
439 }
440 } else {
441 let shared = self.shared.clone();
442 spawn(async move {
443 let result = fut.await;
444 shared.handle_result(result, &shared.io, true);
445 });
446 }
447 }
448
449 fn check_error(&mut self) -> PollService<U> {
450 if let Some(err) = self.shared.error.take() {
452 log::trace!(
453 "{}: Error occured, stopping dispatcher",
454 self.shared.io.tag()
455 );
456 self.st = DispatcherState::Stop;
457
458 match err {
459 DispatcherError::Encoder(err) => {
460 PollService::Item(DispatchItem::EncoderError(err))
461 }
462 DispatcherError::Service(err) => {
463 self.error = Some(err);
464 PollService::Continue
465 }
466 }
467 } else {
468 PollService::Ready
469 }
470 }
471
472 fn poll_service(&mut self, cx: &mut Context<'_>) -> Poll<PollService<U>> {
473 match self.shared.service.poll_ready(cx) {
475 Poll::Ready(Ok(_)) => Poll::Ready(self.check_error()),
476 Poll::Pending => {
478 log::trace!(
479 "{}: Service is not ready, register dispatcher",
480 self.shared.io.tag()
481 );
482
483 self.flags.remove(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT);
485 self.shared.io.stop_timer();
486
487 match ready!(self.shared.io.poll_read_pause(cx)) {
488 IoStatusUpdate::KeepAlive => {
489 log::trace!(
490 "{}: Keep-alive error, stopping dispatcher during pause",
491 self.shared.io.tag()
492 );
493 self.st = DispatcherState::Stop;
494 Poll::Ready(PollService::Item(DispatchItem::KeepAliveTimeout))
495 }
496 IoStatusUpdate::Stop => {
497 log::trace!(
498 "{}: Dispatcher is instructed to stop during pause",
499 self.shared.io.tag()
500 );
501 self.st = DispatcherState::Stop;
502 Poll::Ready(PollService::Continue)
503 }
504 IoStatusUpdate::PeerGone(err) => {
505 log::trace!(
506 "{}: Peer is gone during pause, stopping dispatcher: {:?}",
507 self.shared.io.tag(),
508 err
509 );
510 self.st = DispatcherState::Stop;
511 Poll::Ready(PollService::Item(DispatchItem::Disconnect(err)))
512 }
513 IoStatusUpdate::WriteBackpressure => {
514 self.st = DispatcherState::Backpressure;
515 Poll::Ready(PollService::Item(DispatchItem::WBackPressureEnabled))
516 }
517 }
518 }
519 Poll::Ready(Err(err)) => {
521 log::trace!(
522 "{}: Service readiness check failed, stopping",
523 self.shared.io.tag()
524 );
525 self.st = DispatcherState::Stop;
526 self.error = Some(err);
527 self.flags.insert(Flags::READY_ERR);
528 Poll::Ready(PollService::Item(DispatchItem::Disconnect(None)))
529 }
530 }
531 }
532
533 fn update_timer(&mut self, decoded: &Decoded<<U as Decoder>::Item>) {
534 if decoded.item.is_some() {
536 self.read_remains = 0;
537 self.flags.remove(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT);
538 } else if self.flags.contains(Flags::READ_TIMEOUT) {
539 self.read_remains = decoded.remains as u32;
541 } else if self.read_remains == 0 && decoded.remains == 0 {
542 if self.flags.contains(Flags::KA_ENABLED)
544 && !self.flags.contains(Flags::KA_TIMEOUT)
545 {
546 log::trace!(
547 "{}: Start keep-alive timer {:?}",
548 self.shared.io.tag(),
549 self.cfg.keepalive_timeout()
550 );
551 self.flags.insert(Flags::KA_TIMEOUT);
552 self.shared.io.start_timer(self.cfg.keepalive_timeout());
553 }
554 } else if let Some((timeout, max, _)) = self.cfg.frame_read_rate() {
555 self.flags.insert(Flags::READ_TIMEOUT);
558
559 self.read_remains = decoded.remains as u32;
560 self.read_remains_prev = 0;
561 self.read_max_timeout = max;
562 self.shared.io.start_timer(timeout);
563 }
564 }
565
566 fn handle_timeout(&mut self) -> Result<(), DispatchItem<U>> {
567 if self.flags.contains(Flags::READ_TIMEOUT) {
569 if let Some((timeout, max, rate)) = self.cfg.frame_read_rate() {
570 let total = (self.read_remains - self.read_remains_prev)
571 .try_into()
572 .unwrap_or(u16::MAX);
573
574 if total > rate {
576 self.read_remains_prev = self.read_remains;
577 self.read_remains = 0;
578
579 if !max.is_zero() {
580 self.read_max_timeout =
581 Seconds(self.read_max_timeout.0.saturating_sub(timeout.0));
582 }
583
584 if max.is_zero() || !self.read_max_timeout.is_zero() {
585 log::trace!(
586 "{}: Frame read rate {:?}, extend timer",
587 self.shared.io.tag(),
588 total
589 );
590 self.shared.io.start_timer(timeout);
591 return Ok(());
592 }
593 log::trace!(
594 "{}: Max payload timeout has been reached",
595 self.shared.io.tag()
596 );
597 }
598 Err(DispatchItem::ReadTimeout)
599 } else {
600 Ok(())
601 }
602 } else if self.flags.contains(Flags::KA_TIMEOUT) {
603 log::trace!(
604 "{}: Keep-alive error, stopping dispatcher",
605 self.shared.io.tag()
606 );
607 Err(DispatchItem::KeepAliveTimeout)
608 } else {
609 Ok(())
610 }
611 }
612}
613
614#[cfg(test)]
615mod tests {
616 use std::sync::{atomic::AtomicBool, atomic::Ordering::Relaxed, Arc, Mutex};
617 use std::{cell::RefCell, io};
618
619 use ntex_bytes::{Bytes, BytesMut, PoolId, PoolRef};
620 use ntex_codec::BytesCodec;
621 use ntex_service::ServiceCtx;
622 use ntex_util::{time::sleep, time::Millis};
623 use rand::Rng;
624
625 use super::*;
626 use crate::{testing::IoTest, Flags, Io, IoRef, IoStream};
627
628 pub(crate) struct State(IoRef);
629
630 impl State {
631 fn flags(&self) -> Flags {
632 self.0.flags()
633 }
634
635 fn io(&self) -> &IoRef {
636 &self.0
637 }
638
639 fn close(&self) {
640 self.0 .0.insert_flags(Flags::DSP_STOP);
641 self.0 .0.dispatch_task.wake();
642 }
643
644 fn set_memory_pool(&self, pool: PoolRef) {
645 self.0 .0.pool.set(pool);
646 }
647 }
648
649 #[derive(Copy, Clone)]
650 struct BCodec(usize);
651
652 impl Encoder for BCodec {
653 type Item = Bytes;
654 type Error = io::Error;
655
656 fn encode(&self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
657 dst.extend_from_slice(&item[..]);
658 Ok(())
659 }
660 }
661
662 impl Decoder for BCodec {
663 type Item = BytesMut;
664 type Error = io::Error;
665
666 fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
667 if src.len() < self.0 {
668 Ok(None)
669 } else {
670 Ok(Some(src.split_to(self.0)))
671 }
672 }
673 }
674
675 impl<S, U> Dispatcher<S, U>
676 where
677 S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
678 U: Decoder + Encoder + 'static,
679 {
680 pub(crate) fn debug<T: IoStream>(io: T, codec: U, service: S) -> (Self, State) {
682 let cfg = DispatcherConfig::default()
683 .set_keepalive_timeout(Seconds(1))
684 .clone();
685 Self::debug_cfg(io, codec, service, cfg)
686 }
687
688 pub(crate) fn debug_cfg<T: IoStream>(
690 io: T,
691 codec: U,
692 service: S,
693 cfg: DispatcherConfig,
694 ) -> (Self, State) {
695 let state = Io::new(io);
696 state.set_disconnect_timeout(cfg.disconnect_timeout());
697 state.set_tag("DBG");
698
699 let flags = if cfg.keepalive_timeout().is_zero() {
700 super::Flags::empty()
701 } else {
702 super::Flags::KA_ENABLED
703 };
704
705 let inner = State(state.get_ref());
706 state.start_timer(Seconds::ONE);
707
708 let shared = Rc::new(DispatcherShared {
709 codec,
710 io: state.into(),
711 error: Cell::new(None),
712 inflight: Cell::new(0),
713 service: Pipeline::new(service).bind(),
714 });
715
716 (
717 Dispatcher {
718 inner: DispatcherInner {
719 error: None,
720 st: DispatcherState::Processing,
721 response: None,
722 read_remains: 0,
723 read_remains_prev: 0,
724 read_max_timeout: Seconds::ZERO,
725 shared,
726 cfg,
727 flags,
728 },
729 },
730 inner,
731 )
732 }
733 }
734
    /// Round trip: every received frame is echoed back by a service that
    /// takes 50ms per call, and the server side is dropped after the client
    /// closes.
    #[ntex::test]
    async fn test_basic() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, _) = Dispatcher::debug(
            server,
            BytesCodec,
            // echo service; any non-`Item` dispatch item fails the test
            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
                sleep(Millis(50)).await;
                if let DispatchItem::Item(msg) = msg {
                    Ok::<_, ()>(Some(msg.freeze()))
                } else {
                    panic!()
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        // first request was written before the dispatcher started
        sleep(Millis(25)).await;
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        // second request on the same connection
        client.write("GET /test HTTP/1\r\n\r\n");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        client.close().await;
        assert!(client.is_server_dropped());

        // exercise the derived `Clone`/`Debug` impls of the dispatcher flags
        assert!(format!("{:?}", super::Flags::KA_TIMEOUT.clone()).contains("KA_TIMEOUT"));
    }
770
    /// Data encoded directly through the io handle (outside of the service)
    /// reaches the client, and closing via the handle drops the server side.
    #[ntex::test]
    async fn test_sink() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, st) = Dispatcher::debug(
            server,
            BytesCodec,
            // echo service; any non-`Item` dispatch item fails the test
            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
                if let DispatchItem::Item(msg) = msg {
                    Ok::<_, ()>(Some(msg.freeze()))
                } else {
                    panic!()
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        // write through the io handle, bypassing the service
        assert!(st
            .io()
            .encode(Bytes::from_static(b"test"), &BytesCodec)
            .is_ok());
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // request a stop and give the dispatcher time to shut down
        st.close();
        sleep(Millis(1500)).await;
        assert!(client.is_server_dropped());
    }
806
    /// A service call that fails stops the dispatcher; data that was already
    /// encoded is still delivered to the client before the connection is
    /// closed.
    #[ntex::test]
    async fn test_err_in_service() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, state) = Dispatcher::debug(
            server,
            BytesCodec,
            // every call fails
            ntex_service::fn_service(|_: DispatchItem<BytesCodec>| async move {
                Err::<Option<Bytes>, _>(())
            }),
        );
        // queue outgoing data before the dispatcher runs
        state
            .io()
            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
            .unwrap();
        spawn(async move {
            let _ = disp.await;
        });

        // buffer should be flushed
        client.remote_buffer_cap(1024);
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        // write side must be closed, dispatcher waiting for read side to close
        sleep(Millis(250)).await;
        assert!(client.is_closed());

        // close read side
        client.close().await;

        // dispatcher is closed
        assert!(client.is_server_dropped());
    }
843
    /// A failing readiness check stops the dispatcher, the dispatcher future
    /// resolves with that error, and `ready` is invoked exactly once.
    #[ntex::test]
    async fn test_err_in_service_ready() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        // counts `ready` invocations
        let counter = Rc::new(Cell::new(0));

        struct Srv(Rc<Cell<usize>>);

        impl Service<DispatchItem<BytesCodec>> for Srv {
            type Response = Option<Response<BytesCodec>>;
            type Error = &'static str;

            async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
                self.0.set(self.0.get() + 1);
                Err("test")
            }

            async fn call(
                &self,
                _: DispatchItem<BytesCodec>,
                _: ServiceCtx<'_, Self>,
            ) -> Result<Self::Response, Self::Error> {
                Ok(None)
            }
        }

        let (disp, state) = Dispatcher::debug(server, BytesCodec, Srv(counter.clone()));
        spawn(async move {
            // the readiness error is the dispatcher's result
            let res = disp.await;
            assert_eq!(res, Err("test"));
        });

        state
            .io()
            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
            .unwrap();

        // buffer should be flushed
        client.remote_buffer_cap(1024);
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        // write side must be closed, dispatcher waiting for read side to close
        sleep(Millis(250)).await;
        assert!(client.is_closed());

        // close read side
        client.close().await;
        assert!(client.is_server_dropped());

        // service must be checked for readiness only once
        assert_eq!(counter.get(), 1);
    }
899
    /// A 64 KiB response against a client that initially accepts nothing
    /// triggers write back-pressure: the service sees the item (0), then
    /// back-pressure enabled (1) and finally disabled (2) once enough of the
    /// write buffer has drained.
    #[ntex::test]
    async fn test_write_backpressure() {
        let (client, server) = IoTest::create();
        // do not allow to write to socket
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        // records which dispatch items the service has seen, in order
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let (disp, state) = Dispatcher::debug(
            server,
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(_) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            // respond with 65536 random alphanumeric bytes
                            let bytes = rand::rng()
                                .sample_iter(&rand::distr::Alphanumeric)
                                .take(65_536)
                                .map(char::from)
                                .collect::<String>();
                            return Ok::<_, ()>(Some(Bytes::from(bytes)));
                        }
                        DispatchItem::WBackPressureEnabled => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        DispatchItem::WBackPressureDisabled => {
                            data.lock().unwrap().borrow_mut().push(2);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        // use a dedicated pool with explicit read/write buffer parameters
        let pool = PoolId::P10.pool_ref();
        pool.set_read_params(8 * 1024, 1024);
        pool.set_write_params(16 * 1024, 1024);
        state.set_memory_pool(pool);

        spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read_any();
        assert_eq!(buf, Bytes::from_static(b""));
        client.write("GET /test HTTP/1\r\n\r\n");
        sleep(Millis(25)).await;

        // buf must be consumed
        assert_eq!(client.remote_buffer(|buf| buf.len()), 0);

        // response message: the whole payload is still in the write buffer
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 65536);

        // let 10240 bytes through: 65536 - 10240 = 55296 remain
        client.remote_buffer_cap(10240);
        sleep(Millis(50)).await;
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 55296);

        // let 45056 more through: 55296 - 45056 = 10240 remain
        client.remote_buffer_cap(45056);
        sleep(Millis(50)).await;
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 10240);

        // backpressure disabled
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]);
    }
969
    /// The dispatcher must still complete when it is told to stop while the
    /// service (limited to one in-flight call, each item taking 500ms) is
    /// not ready and more data keeps arriving.
    #[ntex::test]
    async fn test_disconnect_during_read_backpressure() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);

        let (disp, state) = Dispatcher::debug(
            server,
            BytesCodec,
            ntex_util::services::inflight::InFlightService::new(
                1,
                ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| async move {
                    if let DispatchItem::Item(_) = msg {
                        sleep(Millis(500)).await;
                        Ok::<_, ()>(None)
                    } else {
                        Ok(None)
                    }
                }),
            ),
        );
        // zero out the keep-alive timeout in the dispatcher's configuration
        disp.inner.cfg.set_keepalive_timeout(Seconds::ZERO);
        let pool = PoolId::P10.pool_ref();
        pool.set_read_params(1024, 512);
        state.set_memory_pool(pool);

        // signal completion of the dispatcher through a oneshot channel
        let (tx, rx) = ntex::channel::oneshot::channel();
        ntex::rt::spawn(async move {
            let _ = disp.await;
            let _ = tx.send(());
        });

        // send two 1024 byte chunks, 25ms apart
        let bytes = rand::rng()
            .sample_iter(&rand::distr::Alphanumeric)
            .take(1024)
            .map(char::from)
            .collect::<String>();
        client.write(bytes.clone());
        sleep(Millis(25)).await;
        client.write(bytes);
        sleep(Millis(25)).await;

        // close read side
        state.close();
        // wait until the dispatcher future has finished
        let _ = rx.recv().await;
    }
1015
    /// With a one second keep-alive timeout and an idle client, the service
    /// sees the item (0) followed by `KeepAliveTimeout` (1) and the
    /// connection is closed.
    #[ntex::test]
    async fn test_keepalive() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        // records which dispatch items the service has seen, in order
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let cfg = DispatcherConfig::default()
            .set_disconnect_timeout(Seconds(1))
            .set_keepalive_timeout(Seconds(1))
            .clone();

        let (disp, state) = Dispatcher::debug_cfg(
            server,
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::KeepAliveTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
            cfg,
        );
        spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
        // stay idle for longer than the keep-alive timeout
        sleep(Millis(2000)).await;

        // write side must be closed, dispatcher should fail with keep-alive
        let flags = state.flags();
        assert!(flags.contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }
1065
    /// Same as `test_keepalive`, but with frame read-rate checking also
    /// configured and a fixed-size (8 byte) frame codec: after one complete
    /// frame and an idle period the keep-alive timeout still fires.
    #[ntex::test]
    async fn test_keepalive2() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        // records which dispatch items the service has seen, in order
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let cfg = DispatcherConfig::default()
            .set_keepalive_timeout(Seconds(1))
            .set_frame_read_rate(Seconds(1), Seconds(2), 2)
            .clone();

        let (disp, state) = Dispatcher::debug_cfg(
            server,
            BCodec(8),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::KeepAliveTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
            cfg,
        );
        spawn(async move {
            let _ = disp.await;
        });

        // exactly one complete frame
        client.write("12345678");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"12345678"));
        // stay idle for longer than the keep-alive timeout
        sleep(Millis(2000)).await;

        // write side must be closed, dispatcher should fail with keep-alive
        let flags = state.flags();
        assert!(flags.contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }
1115
    /// Update keep-alive timer after receiving frame: with a two second
    /// keep-alive and a frame arriving every 750ms the connection must stay
    /// open and the service must only see items.
    #[ntex::test]
    async fn test_keepalive3() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        // records which dispatch items the service has seen, in order
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let cfg = DispatcherConfig::default()
            .set_keepalive_timeout(Seconds(2))
            .set_frame_read_rate(Seconds(1), Seconds(2), 2)
            .clone();

        let (disp, _) = Dispatcher::debug_cfg(
            server,
            // every single byte is a complete frame
            BCodec(1),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::KeepAliveTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
            cfg,
        );
        spawn(async move {
            let _ = disp.await;
        });

        client.write("1");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"1"));
        sleep(Millis(750)).await;

        client.write("2");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"2"));

        sleep(Millis(750)).await;
        client.write("3");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"3"));

        // more than two seconds have passed in total, yet no timeout fired
        sleep(Millis(750)).await;
        assert!(!client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 0, 0]);
    }
1173
    /// Frame read-rate handling: after one complete 8 byte frame the client
    /// trickles a partial frame. The connection survives the first two
    /// one-second periods and is then stopped, with the service seeing the
    /// item (0) followed by `ReadTimeout` (1).
    #[ntex::test]
    async fn test_read_timeout() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        // records which dispatch items the service has seen, in order
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let (disp, state) = Dispatcher::debug(
            server,
            BCodec(8),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::ReadTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            // reconfigure through the dispatcher's shared config handle:
            // zero keep-alive timeout, read rate of 2 per 1s period, 2s max
            disp.inner
                .cfg
                .set_keepalive_timeout(Seconds::ZERO)
                .set_frame_read_rate(Seconds(1), Seconds(2), 2);
            let _ = disp.await;
        });

        // one complete frame
        client.write("12345678");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"12345678"));

        // partial frame, one byte in the first second
        client.write("1");
        sleep(Millis(1000)).await;
        assert!(!state.flags().contains(Flags::IO_STOPPING));
        // two more bytes in the second period
        client.write("23");
        sleep(Millis(1000)).await;
        assert!(!state.flags().contains(Flags::IO_STOPPING));
        // a single byte afterwards
        client.write("4");
        sleep(Millis(2000)).await;

        // write side must be closed, dispatcher should fail with keep-alive
        assert!(state.flags().contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }
1228
    /// Data that was written before the client closed the connection, and
    /// before the dispatcher was even started, must still reach the service.
    #[ntex::test]
    async fn test_unhandled_data() {
        // set to `true` as soon as the service is invoked
        let handled = Arc::new(AtomicBool::new(false));
        let handled2 = handled.clone();

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, _) = Dispatcher::debug(
            server,
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                handled2.store(true, Relaxed);
                async move {
                    sleep(Millis(50)).await;
                    // only items and the disconnect notification are expected
                    if let DispatchItem::Item(msg) = msg {
                        Ok::<_, ()>(Some(msg.freeze()))
                    } else if let DispatchItem::Disconnect(_) = msg {
                        Ok::<_, ()>(None)
                    } else {
                        panic!()
                    }
                }
            }),
        );
        // close the client first, start the dispatcher afterwards
        client.close().await;
        spawn(async move {
            let _ = disp.await;
        });
        sleep(Millis(50)).await;

        assert!(handled.load(Relaxed));
    }
1263}