// Module-wide allow for clippy's `let_underscore_future` lint.
// NOTE(review): no `let _ = <future>` without `.await` is visible in this
// file; confirm the allow is still needed.
#![allow(clippy::let_underscore_future)]
3use std::task::{ready, Context, Poll};
4use std::{cell::Cell, future::Future, pin::Pin, rc::Rc};
5
6use ntex_codec::{Decoder, Encoder};
7use ntex_service::{IntoService, Pipeline, PipelineBinding, PipelineCall, Service};
8use ntex_util::{future::Either, spawn, time::Seconds};
9
10use crate::{Decoded, DispatchItem, IoBoxed, IoStatusUpdate, RecvError};
11
/// Frame type accepted by the codec's `Encoder` half.
///
/// The service returns `Option<Response<U>>`: `Some(frame)` is encoded and
/// written to the io object, `None` sends nothing.
type Response<U> = <U as Encoder>::Item;
13
/// Shared dispatcher settings.
///
/// Clones share a single `Rc` allocation and the setters mutate `Cell`s, so a
/// change made through any clone is visible through all of them.
#[derive(Clone, Debug)]
pub struct DispatcherConfig(Rc<DispatcherConfigInner>);

#[derive(Debug)]
struct DispatcherConfigInner {
    // Keep-alive timeout; `Dispatcher::new` treats zero as "keep-alive disabled".
    keepalive_timeout: Cell<Seconds>,
    // Forwarded to the io object via `set_disconnect_timeout` on construction.
    disconnect_timeout: Cell<Seconds>,
    // Whether frame read-rate checking is active (non-zero read timeout).
    frame_read_enabled: Cell<bool>,
    // Minimum growth of buffered data per period needed to extend the read timer.
    frame_read_rate: Cell<u16>,
    // Length of one read-rate period.
    frame_read_timeout: Cell<Seconds>,
    // Budget for read-timer extensions; zero means unlimited.
    frame_read_max_timeout: Cell<Seconds>,
}
27
28impl Default for DispatcherConfig {
29 fn default() -> Self {
30 DispatcherConfig(Rc::new(DispatcherConfigInner {
31 keepalive_timeout: Cell::new(Seconds(30)),
32 disconnect_timeout: Cell::new(Seconds(1)),
33 frame_read_rate: Cell::new(0),
34 frame_read_enabled: Cell::new(false),
35 frame_read_timeout: Cell::new(Seconds::ZERO),
36 frame_read_max_timeout: Cell::new(Seconds::ZERO),
37 }))
38 }
39}
40
41impl DispatcherConfig {
42 #[inline]
43 pub fn keepalive_timeout(&self) -> Seconds {
45 self.0.keepalive_timeout.get()
46 }
47
48 #[inline]
49 pub fn disconnect_timeout(&self) -> Seconds {
51 self.0.disconnect_timeout.get()
52 }
53
54 #[inline]
55 pub fn frame_read_rate(&self) -> Option<(Seconds, Seconds, u16)> {
57 if self.0.frame_read_enabled.get() {
58 Some((
59 self.0.frame_read_timeout.get(),
60 self.0.frame_read_max_timeout.get(),
61 self.0.frame_read_rate.get(),
62 ))
63 } else {
64 None
65 }
66 }
67
68 pub fn set_keepalive_timeout(&self, timeout: Seconds) -> &Self {
74 self.0.keepalive_timeout.set(timeout);
75 self
76 }
77
78 pub fn set_disconnect_timeout(&self, timeout: Seconds) -> &Self {
87 self.0.disconnect_timeout.set(timeout);
88 self
89 }
90
91 pub fn set_frame_read_rate(
99 &self,
100 timeout: Seconds,
101 max_timeout: Seconds,
102 rate: u16,
103 ) -> &Self {
104 self.0.frame_read_enabled.set(!timeout.is_zero());
105 self.0.frame_read_timeout.set(timeout);
106 self.0.frame_read_max_timeout.set(max_timeout);
107 self.0.frame_read_rate.set(rate);
108 self
109 }
110}
111
pin_project_lite::pin_project! {
    /// Future that reads frames from an io object and decodes them with `U`.
    ///
    /// Decoded frames are passed to service `S`, and the service's responses
    /// are encoded back to the io object.
    ///
    /// Resolves to `Ok(())` once the service has shut down. If a service error
    /// stopped the dispatcher, it resolves to that error instead.
    pub struct Dispatcher<S, U>
    where
        S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
        U: Encoder,
        U: Decoder,
        U: 'static,
    {
        inner: DispatcherInner<S, U>,
    }
}
125
bitflags::bitflags! {
    /// Dispatcher-internal state bits, stored in `DispatcherShared::flags`.
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    struct Flags: u8 {
        /// The service readiness check returned an error; readiness is not
        /// polled again while stopping.
        const READY_ERR     = 0b0000001;
        /// The io object reported a stop/peer-gone/keep-alive condition or a
        /// flush failed while the dispatcher was stopping.
        const IO_ERR        = 0b0000010;
        /// Keep-alive handling is enabled (non-zero keep-alive timeout).
        const KA_ENABLED    = 0b0000100;
        /// The keep-alive timer has been started.
        const KA_TIMEOUT    = 0b0001000;
        /// The frame read-rate timer has been started.
        const READ_TIMEOUT  = 0b0010000;
        /// No service calls are in flight.
        const IDLE          = 0b0100000;
    }
}
137
/// Mutable dispatcher state owned by the `Dispatcher` future.
struct DispatcherInner<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder + 'static,
{
    // Current state of the dispatcher state machine.
    st: DispatcherState,
    // Service error returned from the future once shutdown completes.
    error: Option<S::Error>,
    // State shared with spawned response tasks.
    shared: Rc<DispatcherShared<S, U>>,
    // The single service call polled inline by the dispatcher; further
    // concurrent calls are spawned instead.
    response: Option<PipelineCall<S, DispatchItem<U>>>,
    cfg: DispatcherConfig,
    // Buffered-data size observed during the current read-rate period.
    read_remains: u32,
    // Buffered-data size at the end of the previous read-rate period.
    read_remains_prev: u32,
    // Remaining budget for read-timer extensions.
    read_max_timeout: Seconds,
}
152
/// State shared (via `Rc`) between the dispatcher and spawned response tasks.
pub(crate) struct DispatcherShared<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder,
{
    io: IoBoxed,
    codec: U,
    service: PipelineBinding<S, DispatchItem<U>>,
    flags: Cell<Flags>,
    // Most recent error recorded by `handle_result`; taken by `check_error`.
    error: Cell<Option<DispatcherError<S::Error, <U as Encoder>::Error>>>,
    // Number of service calls started but not yet handled.
    inflight: Cell<u32>,
}
165
/// States of the dispatcher state machine driven by `Dispatcher::poll`.
#[derive(Copy, Clone, Debug)]
enum DispatcherState {
    /// Reading frames and passing them to the service.
    Processing,
    /// Write back-pressure is active; waiting for the write buffer to flush.
    Backpressure,
    /// Stopping: waiting for in-flight calls, then shutting down the io.
    Stop,
    /// Waiting for the service to shut down.
    Shutdown,
}

/// Error recorded while handling a service response.
#[derive(Debug)]
enum DispatcherError<S, U> {
    /// Encoding the service response failed.
    Encoder(U),
    /// The service call returned an error.
    Service(S),
}

/// Outcome of `DispatcherInner::poll_service`.
enum PollService<U: Encoder + Decoder> {
    /// An item that must be delivered to the service right away.
    Item(DispatchItem<U>),
    /// The state changed; re-run the state loop.
    Continue,
    /// The service is ready for the next frame.
    Ready,
}
185
186impl<S, U> From<Either<S, U>> for DispatcherError<S, U> {
187 fn from(err: Either<S, U>) -> Self {
188 match err {
189 Either::Left(err) => DispatcherError::Service(err),
190 Either::Right(err) => DispatcherError::Encoder(err),
191 }
192 }
193}
194
195impl<S, U> Dispatcher<S, U>
196where
197 S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
198 U: Decoder + Encoder + 'static,
199{
200 pub fn new<Io, F>(
202 io: Io,
203 codec: U,
204 service: F,
205 cfg: &DispatcherConfig,
206 ) -> Dispatcher<S, U>
207 where
208 IoBoxed: From<Io>,
209 F: IntoService<S, DispatchItem<U>>,
210 {
211 let io = IoBoxed::from(io);
212 io.set_disconnect_timeout(cfg.disconnect_timeout());
213
214 let flags = if cfg.keepalive_timeout().is_zero() {
215 Flags::empty()
216 } else {
217 Flags::KA_ENABLED
218 };
219
220 let shared = Rc::new(DispatcherShared {
221 io,
222 codec,
223 flags: Cell::new(flags),
224 error: Cell::new(None),
225 inflight: Cell::new(0),
226 service: Pipeline::new(service.into_service()).bind(),
227 });
228
229 Dispatcher {
230 inner: DispatcherInner {
231 shared,
232 cfg: cfg.clone(),
233 response: None,
234 error: None,
235 read_remains: 0,
236 read_remains_prev: 0,
237 read_max_timeout: Seconds::ZERO,
238 st: DispatcherState::Processing,
239 },
240 }
241 }
242}
243
244impl<S, U> DispatcherShared<S, U>
245where
246 S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
247 U: Encoder + Decoder,
248{
249 fn handle_result(&self, item: Result<S::Response, S::Error>, io: &IoBoxed, wake: bool) {
250 match item {
251 Ok(Some(val)) => {
252 if let Err(err) = io.encode(val, &self.codec) {
253 self.error.set(Some(DispatcherError::Encoder(err)))
254 }
255 }
256 Err(err) => self.error.set(Some(DispatcherError::Service(err))),
257 Ok(None) => (),
258 }
259 let inflight = self.inflight.get() - 1;
260 self.inflight.set(inflight);
261 if inflight == 0 {
262 self.insert_flags(Flags::IDLE);
263 }
264 if wake {
265 io.wake();
266 }
267 }
268
269 fn contains(&self, f: Flags) -> bool {
270 self.flags.get().intersects(f)
271 }
272
273 fn insert_flags(&self, f: Flags) {
274 let mut flags = self.flags.get();
275 flags.insert(f);
276 self.flags.set(flags);
277 }
278
279 fn remove_flags(&self, f: Flags) -> bool {
280 let mut flags = self.flags.get();
281 if flags.intersects(f) {
282 flags.remove(f);
283 self.flags.set(flags);
284 true
285 } else {
286 false
287 }
288 }
289}
290
291impl<S, U> Future for Dispatcher<S, U>
292where
293 S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
294 U: Decoder + Encoder + 'static,
295{
296 type Output = Result<(), S::Error>;
297
298 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
299 let mut this = self.as_mut().project();
300 let slf = &mut this.inner;
301
302 if let Some(fut) = slf.response.as_mut() {
304 if let Poll::Ready(item) = Pin::new(fut).poll(cx) {
305 slf.shared.handle_result(item, &slf.shared.io, false);
306 slf.response = None;
307 }
308 }
309
310 loop {
311 match slf.st {
312 DispatcherState::Processing => {
313 let item = match ready!(slf.poll_service(cx)) {
314 PollService::Ready => {
315 match slf.shared.io.poll_recv_decode(&slf.shared.codec, cx) {
317 Ok(decoded) => {
318 slf.update_timer(&decoded);
319 if let Some(el) = decoded.item {
320 DispatchItem::Item(el)
321 } else {
322 return Poll::Pending;
323 }
324 }
325 Err(RecvError::KeepAlive) => {
326 if let Err(err) = slf.handle_timeout() {
327 slf.st = DispatcherState::Stop;
328 err
329 } else {
330 continue;
331 }
332 }
333 Err(RecvError::Stop) => {
334 log::trace!(
335 "{}: Dispatcher is instructed to stop",
336 slf.shared.io.tag()
337 );
338 slf.st = DispatcherState::Stop;
339 continue;
340 }
341 Err(RecvError::WriteBackpressure) => {
342 slf.st = DispatcherState::Backpressure;
344 DispatchItem::WBackPressureEnabled
345 }
346 Err(RecvError::Decoder(err)) => {
347 log::trace!(
348 "{}: Decoder error, stopping dispatcher: {:?}",
349 slf.shared.io.tag(),
350 err
351 );
352 slf.st = DispatcherState::Stop;
353 DispatchItem::DecoderError(err)
354 }
355 Err(RecvError::PeerGone(err)) => {
356 log::trace!(
357 "{}: Peer is gone, stopping dispatcher: {:?}",
358 slf.shared.io.tag(),
359 err
360 );
361 slf.st = DispatcherState::Stop;
362 DispatchItem::Disconnect(err)
363 }
364 }
365 }
366 PollService::Item(item) => item,
367 PollService::Continue => continue,
368 };
369
370 slf.call_service(cx, item);
371 }
372 DispatcherState::Backpressure => {
374 match ready!(slf.poll_service(cx)) {
375 PollService::Ready => (),
376 PollService::Item(item) => slf.call_service(cx, item),
377 PollService::Continue => continue,
378 };
379
380 let item = if let Err(err) = ready!(slf.shared.io.poll_flush(cx, false))
381 {
382 slf.st = DispatcherState::Stop;
383 DispatchItem::Disconnect(Some(err))
384 } else {
385 slf.st = DispatcherState::Processing;
386 DispatchItem::WBackPressureDisabled
387 };
388 slf.call_service(cx, item);
389 }
390 DispatcherState::Stop => {
392 slf.shared.io.stop_timer();
393
394 if !slf.shared.contains(Flags::READY_ERR) {
396 if let Poll::Ready(res) = slf.shared.service.poll_ready(cx) {
397 if res.is_err() {
398 slf.shared.insert_flags(Flags::READY_ERR);
399 }
400 }
401 }
402
403 if slf.shared.inflight.get() == 0 {
404 if slf.shared.io.poll_shutdown(cx).is_ready() {
405 slf.st = DispatcherState::Shutdown;
406 continue;
407 }
408 } else if !slf.shared.contains(Flags::IO_ERR) {
409 match ready!(slf.shared.io.poll_status_update(cx)) {
410 IoStatusUpdate::PeerGone(_)
411 | IoStatusUpdate::Stop
412 | IoStatusUpdate::KeepAlive => {
413 slf.shared.insert_flags(Flags::IO_ERR);
414 continue;
415 }
416 IoStatusUpdate::WriteBackpressure => {
417 if ready!(slf.shared.io.poll_flush(cx, true)).is_err() {
418 slf.shared.insert_flags(Flags::IO_ERR);
419 }
420 continue;
421 }
422 }
423 } else {
424 slf.shared.io.poll_dispatch(cx);
425 }
426 return Poll::Pending;
427 }
428 DispatcherState::Shutdown => {
430 return if slf.shared.service.poll_shutdown(cx).is_ready() {
431 log::trace!(
432 "{}: Service shutdown is completed, stop",
433 slf.shared.io.tag()
434 );
435
436 Poll::Ready(if let Some(err) = slf.error.take() {
437 Err(err)
438 } else {
439 Ok(())
440 })
441 } else {
442 Poll::Pending
443 };
444 }
445 }
446 }
447 }
448}
449
450impl<S, U> DispatcherInner<S, U>
451where
452 S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
453 U: Decoder + Encoder + 'static,
454{
455 fn call_service(&mut self, cx: &mut Context<'_>, item: DispatchItem<U>) {
456 let mut fut = self.shared.service.call_nowait(item);
457 let inflight = self.shared.inflight.get() + 1;
458 self.shared.inflight.set(inflight);
459 if inflight == 1 {
460 self.shared.remove_flags(Flags::IDLE);
461 }
462
463 if self.response.is_none() {
465 if let Poll::Ready(result) = Pin::new(&mut fut).poll(cx) {
466 self.shared.handle_result(result, &self.shared.io, false);
467 } else {
468 self.response = Some(fut);
469 }
470 } else {
471 let shared = self.shared.clone();
472 spawn(async move {
473 let result = fut.await;
474 shared.handle_result(result, &shared.io, true);
475 });
476 }
477 }
478
479 fn check_error(&mut self) -> PollService<U> {
480 if let Some(err) = self.shared.error.take() {
482 log::trace!(
483 "{}: Error occured, stopping dispatcher",
484 self.shared.io.tag()
485 );
486 self.st = DispatcherState::Stop;
487
488 match err {
489 DispatcherError::Encoder(err) => {
490 PollService::Item(DispatchItem::EncoderError(err))
491 }
492 DispatcherError::Service(err) => {
493 self.error = Some(err);
494 PollService::Continue
495 }
496 }
497 } else {
498 PollService::Ready
499 }
500 }
501
502 fn poll_service(&mut self, cx: &mut Context<'_>) -> Poll<PollService<U>> {
503 match self.shared.service.poll_ready(cx) {
505 Poll::Ready(Ok(_)) => Poll::Ready(self.check_error()),
506 Poll::Pending => {
508 log::trace!(
509 "{}: Service is not ready, register dispatcher",
510 self.shared.io.tag()
511 );
512
513 self.shared
515 .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
516 self.shared.io.stop_timer();
517
518 match ready!(self.shared.io.poll_read_pause(cx)) {
519 IoStatusUpdate::KeepAlive => {
520 log::trace!(
521 "{}: Keep-alive error, stopping dispatcher during pause",
522 self.shared.io.tag()
523 );
524 self.st = DispatcherState::Stop;
525 Poll::Ready(PollService::Item(DispatchItem::KeepAliveTimeout))
526 }
527 IoStatusUpdate::Stop => {
528 log::trace!(
529 "{}: Dispatcher is instructed to stop during pause",
530 self.shared.io.tag()
531 );
532 self.st = DispatcherState::Stop;
533 Poll::Ready(PollService::Continue)
534 }
535 IoStatusUpdate::PeerGone(err) => {
536 log::trace!(
537 "{}: Peer is gone during pause, stopping dispatcher: {:?}",
538 self.shared.io.tag(),
539 err
540 );
541 self.st = DispatcherState::Stop;
542 Poll::Ready(PollService::Item(DispatchItem::Disconnect(err)))
543 }
544 IoStatusUpdate::WriteBackpressure => {
545 self.st = DispatcherState::Backpressure;
546 Poll::Ready(PollService::Item(DispatchItem::WBackPressureEnabled))
547 }
548 }
549 }
550 Poll::Ready(Err(err)) => {
552 log::trace!(
553 "{}: Service readiness check failed, stopping",
554 self.shared.io.tag()
555 );
556 self.st = DispatcherState::Stop;
557 self.error = Some(err);
558 self.shared.insert_flags(Flags::READY_ERR);
559 Poll::Ready(PollService::Item(DispatchItem::Disconnect(None)))
560 }
561 }
562 }
563
564 fn update_timer(&mut self, decoded: &Decoded<<U as Decoder>::Item>) {
565 if decoded.item.is_some() {
567 self.read_remains = 0;
568 self.shared
569 .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
570 } else if self.shared.contains(Flags::READ_TIMEOUT) {
571 self.read_remains = decoded.remains as u32;
573 } else if self.read_remains == 0 && decoded.remains == 0 {
574 if self.shared.contains(Flags::KA_ENABLED)
576 && !self.shared.contains(Flags::KA_TIMEOUT)
577 {
578 log::trace!(
579 "{}: Start keep-alive timer {:?}",
580 self.shared.io.tag(),
581 self.cfg.keepalive_timeout()
582 );
583 self.shared.insert_flags(Flags::KA_TIMEOUT);
584 self.shared.io.start_timer(self.cfg.keepalive_timeout());
585 }
586 } else if let Some((timeout, max, _)) = self.cfg.frame_read_rate() {
587 self.shared.insert_flags(Flags::READ_TIMEOUT);
590
591 self.read_remains = decoded.remains as u32;
592 self.read_remains_prev = 0;
593 self.read_max_timeout = max;
594 self.shared.io.start_timer(timeout);
595 }
596 }
597
598 fn handle_timeout(&mut self) -> Result<(), DispatchItem<U>> {
599 if self.shared.contains(Flags::READ_TIMEOUT) {
601 if let Some((timeout, max, rate)) = self.cfg.frame_read_rate() {
602 let total = (self.read_remains - self.read_remains_prev)
603 .try_into()
604 .unwrap_or(u16::MAX);
605
606 if total > rate {
608 self.read_remains_prev = self.read_remains;
609 self.read_remains = 0;
610
611 if !max.is_zero() {
612 self.read_max_timeout =
613 Seconds(self.read_max_timeout.0.saturating_sub(timeout.0));
614 }
615
616 if max.is_zero() || !self.read_max_timeout.is_zero() {
617 log::trace!(
618 "{}: Frame read rate {:?}, extend timer",
619 self.shared.io.tag(),
620 total
621 );
622 self.shared.io.start_timer(timeout);
623 return Ok(());
624 }
625 log::trace!(
626 "{}: Max payload timeout has been reached",
627 self.shared.io.tag()
628 );
629 }
630 Err(DispatchItem::ReadTimeout)
631 } else {
632 Ok(())
633 }
634 } else if self.shared.contains(Flags::KA_TIMEOUT | Flags::IDLE) {
635 log::trace!(
636 "{}: Keep-alive error, stopping dispatcher",
637 self.shared.io.tag()
638 );
639 Err(DispatchItem::KeepAliveTimeout)
640 } else {
641 Ok(())
642 }
643 }
644}
645
// Dispatcher tests. They drive a `Dispatcher` over the in-memory `IoTest`
// transport and use real timers (`sleep`), so assertions are timing-sensitive.
#[cfg(test)]
mod tests {
    use std::sync::{atomic::AtomicBool, atomic::Ordering::Relaxed, Arc, Mutex};
    use std::{cell::RefCell, io};

    use ntex_bytes::{Bytes, BytesMut, PoolId, PoolRef};
    use ntex_codec::BytesCodec;
    use ntex_service::ServiceCtx;
    use ntex_util::{time::sleep, time::Millis};
    use rand::Rng;

    use super::*;
    // The explicit `Flags` import (io flags) shadows the dispatcher's private
    // `Flags` from `super::*`; the latter is spelled `super::Flags` below.
    use crate::{testing::IoTest, Flags, Io, IoRef};

    // Test-side handle to the dispatcher's io state.
    pub(crate) struct State(IoRef);

    impl State {
        fn flags(&self) -> Flags {
            self.0.flags()
        }

        fn io(&self) -> &IoRef {
            &self.0
        }

        // Asks the dispatcher to stop: sets `DSP_STOP` and wakes the dispatch task.
        fn close(&self) {
            self.0 .0.insert_flags(Flags::DSP_STOP);
            self.0 .0.dispatch_task.wake();
        }

        fn set_memory_pool(&self, pool: PoolRef) {
            self.0 .0.pool.set(pool);
        }
    }

    // Fixed-size codec: decodes frames of exactly `self.0` bytes and encodes
    // `Bytes` verbatim.
    #[derive(Copy, Clone)]
    struct BCodec(usize);

    impl Encoder for BCodec {
        type Item = Bytes;
        type Error = io::Error;

        fn encode(&self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
            dst.extend_from_slice(&item[..]);
            Ok(())
        }
    }

    impl Decoder for BCodec {
        type Item = BytesMut;
        type Error = io::Error;

        fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
            if src.len() < self.0 {
                Ok(None)
            } else {
                Ok(Some(src.split_to(self.0)))
            }
        }
    }

    impl<S, U> Dispatcher<S, U>
    where
        S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
        U: Decoder + Encoder + 'static,
    {
        // Builds a dispatcher with a 1 second keep-alive timeout, plus a
        // `State` handle for inspecting the io.
        pub(crate) fn debug(io: Io, codec: U, service: S) -> (Self, State) {
            let cfg = DispatcherConfig::default()
                .set_keepalive_timeout(Seconds(1))
                .clone();
            Self::debug_cfg(io, codec, service, cfg)
        }

        // Mirrors `Dispatcher::new`, but also tags the io as "DBG", starts a
        // 1 second io timer up front and returns a `State` handle.
        pub(crate) fn debug_cfg(
            state: Io,
            codec: U,
            service: S,
            cfg: DispatcherConfig,
        ) -> (Self, State) {
            state.set_disconnect_timeout(cfg.disconnect_timeout());
            state.set_tag("DBG");

            let flags = if cfg.keepalive_timeout().is_zero() {
                super::Flags::empty()
            } else {
                super::Flags::KA_ENABLED
            };

            let inner = State(state.get_ref());
            state.start_timer(Seconds::ONE);

            let shared = Rc::new(DispatcherShared {
                codec,
                io: state.into(),
                flags: Cell::new(flags),
                error: Cell::new(None),
                inflight: Cell::new(0),
                service: Pipeline::new(service).bind(),
            });

            (
                Dispatcher {
                    inner: DispatcherInner {
                        cfg,
                        shared,
                        error: None,
                        st: DispatcherState::Processing,
                        response: None,
                        read_remains: 0,
                        read_remains_prev: 0,
                        read_max_timeout: Seconds::ZERO,
                    },
                },
                inner,
            )
        }
    }

    // Echo service with a 50ms delay: both requests are echoed back and the
    // server side is dropped after the client closes.
    #[ntex::test]
    async fn test_basic() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, _) = Dispatcher::debug(
            Io::new(server),
            BytesCodec,
            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
                sleep(Millis(50)).await;
                if let DispatchItem::Item(msg) = msg {
                    Ok::<_, ()>(Some(msg.freeze()))
                } else {
                    panic!()
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        sleep(Millis(25)).await;
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        client.write("GET /test HTTP/1\r\n\r\n");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        client.close().await;
        assert!(client.is_server_dropped());

        assert!(format!("{:?}", super::Flags::KA_TIMEOUT.clone()).contains("KA_TIMEOUT"));
    }

    // Frames encoded directly through the io handle (outside the service)
    // reach the client; `State::close` eventually drops the server side.
    #[ntex::test]
    async fn test_sink() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, st) = Dispatcher::debug(
            Io::new(server),
            BytesCodec,
            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
                if let DispatchItem::Item(msg) = msg {
                    Ok::<_, ()>(Some(msg.freeze()))
                } else {
                    panic!()
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        assert!(st
            .io()
            .encode(Bytes::from_static(b"test"), &BytesCodec)
            .is_ok());
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        st.close();
        sleep(Millis(1500)).await;
        assert!(client.is_server_dropped());
    }

    // A service call error stops the dispatcher; data already queued for
    // writing is still delivered before the connection is closed.
    #[ntex::test]
    async fn test_err_in_service() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, state) = Dispatcher::debug(
            Io::new(server),
            BytesCodec,
            ntex_service::fn_service(|_: DispatchItem<BytesCodec>| async move {
                Err::<Option<Bytes>, _>(())
            }),
        );
        state
            .io()
            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
            .unwrap();
        spawn(async move {
            let _ = disp.await;
        });

        // allow the queued write to reach the client
        client.remote_buffer_cap(1024);
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        sleep(Millis(250)).await;
        assert!(client.is_closed());

        client.close().await;

        assert!(client.is_server_dropped());
    }

    // A readiness error stops the dispatcher with that error. `READY_ERR`
    // prevents the `Stop` state from polling readiness again, so `ready` runs
    // exactly once.
    #[ntex::test]
    async fn test_err_in_service_ready() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        let counter = Rc::new(Cell::new(0));

        struct Srv(Rc<Cell<usize>>);

        impl Service<DispatchItem<BytesCodec>> for Srv {
            type Response = Option<Response<BytesCodec>>;
            type Error = &'static str;

            async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
                self.0.set(self.0.get() + 1);
                Err("test")
            }

            async fn call(
                &self,
                _: DispatchItem<BytesCodec>,
                _: ServiceCtx<'_, Self>,
            ) -> Result<Self::Response, Self::Error> {
                Ok(None)
            }
        }

        let (disp, state) =
            Dispatcher::debug(Io::new(server), BytesCodec, Srv(counter.clone()));
        spawn(async move {
            let res = disp.await;
            assert_eq!(res, Err("test"));
        });

        state
            .io()
            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
            .unwrap();

        // allow the queued write to reach the client
        client.remote_buffer_cap(1024);
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        sleep(Millis(250)).await;
        assert!(client.is_closed());

        client.close().await;
        assert!(client.is_server_dropped());

        assert_eq!(counter.get(), 1);
    }

    // Large (64KiB) responses with a small write pool trigger write
    // back-pressure. The service sees Item, then WBackPressureEnabled, then
    // WBackPressureDisabled, while the write buffer drains as the client
    // buffer capacity grows.
    #[ntex::test]
    async fn test_write_backpressure() {
        let (client, server) = IoTest::create();
        // do not allow writes to the client yet
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let (disp, state) = Dispatcher::debug(
            Io::new(server),
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(_) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            let bytes = rand::rng()
                                .sample_iter(&rand::distr::Alphanumeric)
                                .take(65_536)
                                .map(char::from)
                                .collect::<String>();
                            return Ok::<_, ()>(Some(Bytes::from(bytes)));
                        }
                        DispatchItem::WBackPressureEnabled => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        DispatchItem::WBackPressureDisabled => {
                            data.lock().unwrap().borrow_mut().push(2);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        let pool = PoolId::P10.pool_ref();
        pool.set_read_params(8 * 1024, 1024);
        pool.set_write_params(16 * 1024, 1024);
        state.set_memory_pool(pool);

        spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read_any();
        assert_eq!(buf, Bytes::from_static(b""));
        client.write("GET /test HTTP/1\r\n\r\n");
        sleep(Millis(25)).await;

        // nothing has been written to the client yet
        assert_eq!(client.remote_buffer(|buf| buf.len()), 0);

        // the whole 64KiB response is still buffered
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 65536);

        // 65536 - 10240 = 55296 bytes remain
        client.remote_buffer_cap(10240);
        sleep(Millis(50)).await;
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 55296);

        // 55296 - 45056 = 10240 bytes remain
        client.remote_buffer_cap(45056);
        sleep(Millis(50)).await;
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 10240);

        // Item, WBackPressureEnabled, WBackPressureDisabled
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]);
    }

    // With an in-flight limit of 1 and a slow service, stopping the
    // dispatcher while more input is pending must still let the dispatcher
    // future complete.
    #[ntex::test]
    async fn test_disconnect_during_read_backpressure() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);

        let (disp, state) = Dispatcher::debug(
            Io::new(server),
            BytesCodec,
            ntex_util::services::inflight::InFlightService::new(
                1,
                ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| async move {
                    if let DispatchItem::Item(_) = msg {
                        sleep(Millis(500)).await;
                        Ok::<_, ()>(None)
                    } else {
                        Ok(None)
                    }
                }),
            ),
        );
        disp.inner.cfg.set_keepalive_timeout(Seconds::ZERO);
        let pool = PoolId::P10.pool_ref();
        pool.set_read_params(1024, 512);
        state.set_memory_pool(pool);

        let (tx, rx) = ntex::channel::oneshot::channel();
        ntex::rt::spawn(async move {
            let _ = disp.await;
            let _ = tx.send(());
        });

        let bytes = rand::rng()
            .sample_iter(&rand::distr::Alphanumeric)
            .take(1024)
            .map(char::from)
            .collect::<String>();
        client.write(bytes.clone());
        sleep(Millis(25)).await;
        client.write(bytes);
        sleep(Millis(25)).await;

        // stop the dispatcher and wait for the dispatcher future to finish
        state.close();
        let _ = rx.recv().await;
    }

    // After one echoed request and no further input, the keep-alive timeout
    // delivers KeepAliveTimeout to the service and the connection is closed.
    #[ntex::test]
    async fn test_keepalive() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let cfg = DispatcherConfig::default()
            .set_disconnect_timeout(Seconds(1))
            .set_keepalive_timeout(Seconds(1))
            .clone();

        let (disp, state) = Dispatcher::debug_cfg(
            Io::new(server),
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::KeepAliveTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
            cfg,
        );
        spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
        sleep(Millis(2000)).await;

        // the dispatcher has stopped the io
        let flags = state.flags();
        assert!(flags.contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }

    // Same as `test_keepalive`, with frame read-rate checking also configured
    // and a fixed 8-byte codec.
    #[ntex::test]
    async fn test_keepalive2() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let cfg = DispatcherConfig::default()
            .set_keepalive_timeout(Seconds(1))
            .set_frame_read_rate(Seconds(1), Seconds(2), 2)
            .clone();

        let (disp, state) = Dispatcher::debug_cfg(
            Io::new(server),
            BCodec(8),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::KeepAliveTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
            cfg,
        );
        spawn(async move {
            let _ = disp.await;
        });

        client.write("12345678");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"12345678"));
        sleep(Millis(2000)).await;

        // the dispatcher has stopped the io
        let flags = state.flags();
        assert!(flags.contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }

    // Each decoded frame resets the keep-alive timer: frames every 750ms with
    // a 2s keep-alive timeout keep the connection open.
    #[ntex::test]
    async fn test_keepalive3() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let cfg = DispatcherConfig::default()
            .set_keepalive_timeout(Seconds(2))
            .set_frame_read_rate(Seconds(1), Seconds(2), 2)
            .clone();

        let (disp, _) = Dispatcher::debug_cfg(
            Io::new(server),
            BCodec(1),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::KeepAliveTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
            cfg,
        );
        spawn(async move {
            let _ = disp.await;
        });

        client.write("1");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"1"));
        sleep(Millis(750)).await;

        client.write("2");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"2"));

        sleep(Millis(750)).await;
        client.write("3");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"3"));

        sleep(Millis(750)).await;
        assert!(!client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 0, 0]);
    }

    // A partial 8-byte frame is fed slowly ("1", "23", "4") with a read rate
    // of 2 per 1s period and a 2s max budget. The dispatcher eventually
    // reports ReadTimeout and closes the connection.
    #[ntex::test]
    async fn test_read_timeout() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let (disp, state) = Dispatcher::debug(
            Io::new(server),
            BCodec(8),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::ReadTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            disp.inner
                .cfg
                .set_keepalive_timeout(Seconds::ZERO)
                .set_frame_read_rate(Seconds(1), Seconds(2), 2);
            let _ = disp.await;
        });

        client.write("12345678");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"12345678"));

        client.write("1");
        sleep(Millis(1000)).await;
        assert!(!state.flags().contains(Flags::IO_STOPPING));
        client.write("23");
        sleep(Millis(1000)).await;
        assert!(!state.flags().contains(Flags::IO_STOPPING));
        client.write("4");
        sleep(Millis(2000)).await;

        // the read timeout stopped the io
        assert!(state.flags().contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }

    // With keep-alive disabled, a timeout notification that arrives while no
    // calls are in flight (`IDLE`) still stops the dispatcher.
    #[ntex::test]
    async fn test_idle_timeout() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let io = Io::new(server);
        let ioref = io.get_ref();

        let (disp, state) = Dispatcher::debug_cfg(
            io,
            BCodec(1),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let ioref = ioref.clone();
                ntex::rt::spawn(async move {
                    sleep(Millis(500)).await;
                    ioref.notify_timeout();
                });
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::ReadTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
            DispatcherConfig::default()
                .set_keepalive_timeout(Seconds::ZERO)
                .clone(),
        );
        spawn(async move {
            let _ = disp.await;
        });

        client.write("1");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"1"));

        sleep(Millis(1000)).await;
        assert!(state.flags().contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
    }

    // The client closes before the dispatcher starts; the buffered request
    // and/or the disconnect must still reach the service.
    #[ntex::test]
    async fn test_unhandled_data() {
        let handled = Arc::new(AtomicBool::new(false));
        let handled2 = handled.clone();

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, _) = Dispatcher::debug(
            Io::new(server),
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                handled2.store(true, Relaxed);
                async move {
                    sleep(Millis(50)).await;
                    if let DispatchItem::Item(msg) = msg {
                        Ok::<_, ()>(Some(msg.freeze()))
                    } else if let DispatchItem::Disconnect(_) = msg {
                        Ok::<_, ()>(None)
                    } else {
                        panic!()
                    }
                }
            }),
        );
        client.close().await;
        spawn(async move {
            let _ = disp.await;
        });
        sleep(Millis(50)).await;

        assert!(handled.load(Relaxed));
    }
}