1#![allow(clippy::let_underscore_future)]
3use std::task::{Context, Poll, ready};
4use std::{cell::Cell, future::Future, pin::Pin, rc::Rc};
5
6use ntex_codec::{Decoder, Encoder};
7use ntex_service::{IntoService, Pipeline, PipelineBinding, PipelineCall, Service};
8use ntex_util::{future::Either, spawn, time::Seconds};
9
10use crate::{Decoded, DispatchItem, IoBoxed, IoStatusUpdate, RecvError};
11
12type Response<U> = <U as Encoder>::Item;
13
pin_project_lite::pin_project! {
    /// Future that reads and decodes frames from an io object, passes them to a service as
    /// `DispatchItem`s, and encodes the service responses back into the io.
    ///
    /// Resolves after the io has been shut down and the service's `poll_shutdown` has
    /// completed. It yields the service error recorded while running (if any).
    pub struct Dispatcher<S, U>
    where
        S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
        // NOTE(review): U's bounds are presumably split into separate predicates because the
        // pin_project_lite where-clause syntax accepts one bound per predicate — confirm
        // before merging them.
        U: Encoder,
        U: Decoder,
        U: 'static,
    {
        // All dispatcher state lives here.
        inner: DispatcherInner<S, U>,
    }
}
27
28bitflags::bitflags! {
29 #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
30 struct Flags: u8 {
31 const READY_ERR = 0b0000001;
32 const IO_ERR = 0b0000010;
33 const KA_ENABLED = 0b0000100;
34 const KA_TIMEOUT = 0b0001000;
35 const READ_TIMEOUT = 0b0010000;
36 const IDLE = 0b0100000;
37 }
38}
39
/// Private state of a `Dispatcher`.
struct DispatcherInner<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder + 'static,
{
    // Current state of the dispatcher state machine.
    st: DispatcherState,
    // Service error returned from the dispatcher future once shutdown completes.
    error: Option<S::Error>,
    // State shared with service-call tasks spawned by `call_service`.
    shared: Rc<DispatcherShared<S, U>>,
    // Service call polled inline by the dispatcher; further concurrent calls are spawned.
    response: Option<PipelineCall<S, DispatchItem<U>>>,
    // `decoded.remains` as seen by the last decode while the read-rate timer is active.
    read_remains: u32,
    // Value of `read_remains` at the time the read-rate timer was last extended.
    read_remains_prev: u32,
    // Remaining read-rate budget; reduced by the timer period on every extension.
    read_max_timeout: Seconds,
}
53
/// Dispatcher state shared (via `Rc`) between the dispatcher future and the tasks that
/// complete service calls.
pub(crate) struct DispatcherShared<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder,
{
    // Io object frames are read from and responses are encoded into.
    io: IoBoxed,
    // Codec used for both decoding requests and encoding responses.
    codec: U,
    // Service bound into a pipeline; called via `call_nowait`.
    service: PipelineBinding<S, DispatchItem<U>>,
    // Dispatcher status bits (see `Flags`).
    flags: Cell<Flags>,
    // Latest error recorded by `handle_result`; consumed by `DispatcherInner::check_error`.
    error: Cell<Option<DispatcherError<S::Error, <U as Encoder>::Error>>>,
    // Number of service calls started but not yet completed.
    inflight: Cell<u32>,
}
66
/// States of the dispatcher state machine driven by `Dispatcher::poll`.
#[derive(Copy, Clone, Debug)]
enum DispatcherState {
    // Decoding incoming frames and calling the service.
    Processing,
    // Write back-pressure is active; waiting for the write buffer to flush.
    Backpressure,
    // Stopping: waiting for in-flight calls to finish, then for io shutdown.
    Stop,
    // Io is shut down; waiting for the service to shut down.
    Shutdown,
}
74
/// Error recorded when a service call completes.
#[derive(Debug)]
enum DispatcherError<S, U> {
    // Encoding the service response failed.
    Encoder(U),
    // The service call returned an error.
    Service(S),
}
80
/// Outcome of `DispatcherInner::poll_service`.
enum PollService<U: Encoder + Decoder> {
    // An item that must be delivered to the service (e.g. a control event).
    Item(DispatchItem<U>),
    // Dispatcher state changed; re-run the state machine loop.
    Continue,
    // Service is ready and no error is pending.
    Ready,
}
86
87impl<S, U> From<Either<S, U>> for DispatcherError<S, U> {
88 fn from(err: Either<S, U>) -> Self {
89 match err {
90 Either::Left(err) => DispatcherError::Service(err),
91 Either::Right(err) => DispatcherError::Encoder(err),
92 }
93 }
94}
95
96impl<S, U> Dispatcher<S, U>
97where
98 S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
99 U: Decoder + Encoder + 'static,
100{
101 pub fn new<Io, F>(io: Io, codec: U, service: F) -> Dispatcher<S, U>
103 where
104 IoBoxed: From<Io>,
105 F: IntoService<S, DispatchItem<U>>,
106 {
107 let io = IoBoxed::from(io);
108 let flags = if io.cfg().keepalive_timeout().is_zero() {
109 Flags::empty()
110 } else {
111 Flags::KA_ENABLED
112 };
113
114 let shared = Rc::new(DispatcherShared {
115 io,
116 codec,
117 flags: Cell::new(flags),
118 error: Cell::new(None),
119 inflight: Cell::new(0),
120 service: Pipeline::new(service.into_service()).bind(),
121 });
122
123 Dispatcher {
124 inner: DispatcherInner {
125 shared,
126 response: None,
127 error: None,
128 read_remains: 0,
129 read_remains_prev: 0,
130 read_max_timeout: Seconds::ZERO,
131 st: DispatcherState::Processing,
132 },
133 }
134 }
135}
136
137impl<S, U> DispatcherShared<S, U>
138where
139 S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
140 U: Encoder + Decoder,
141{
142 fn handle_result(&self, item: Result<S::Response, S::Error>, io: &IoBoxed, wake: bool) {
143 match item {
144 Ok(Some(val)) => {
145 if let Err(err) = io.encode(val, &self.codec) {
146 self.error.set(Some(DispatcherError::Encoder(err)))
147 }
148 }
149 Err(err) => self.error.set(Some(DispatcherError::Service(err))),
150 Ok(None) => (),
151 }
152 let inflight = self.inflight.get() - 1;
153 self.inflight.set(inflight);
154 if inflight == 0 {
155 self.insert_flags(Flags::IDLE);
156 }
157 if wake {
158 io.wake();
159 }
160 }
161
162 fn contains(&self, f: Flags) -> bool {
163 self.flags.get().intersects(f)
164 }
165
166 fn insert_flags(&self, f: Flags) {
167 let mut flags = self.flags.get();
168 flags.insert(f);
169 self.flags.set(flags);
170 }
171
172 fn remove_flags(&self, f: Flags) -> bool {
173 let mut flags = self.flags.get();
174 if flags.intersects(f) {
175 flags.remove(f);
176 self.flags.set(flags);
177 true
178 } else {
179 false
180 }
181 }
182}
183
/// Drives the dispatcher state machine.
///
/// Resolves with `Ok(())` after the service has shut down, or with the service error recorded
/// while running.
impl<S, U> Future for Dispatcher<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
    U: Decoder + Encoder + 'static,
{
    type Output = Result<(), S::Error>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.as_mut().project();
        let slf = &mut this.inner;

        // Drive the service call that the dispatcher polls inline, if there is one.
        if let Some(fut) = slf.response.as_mut() {
            if let Poll::Ready(item) = Pin::new(fut).poll(cx) {
                slf.shared.handle_result(item, &slf.shared.io, false);
                slf.response = None;
            }
        }

        loop {
            match slf.st {
                DispatcherState::Processing => {
                    // Check service readiness first; it may also yield an item to dispatch
                    // (e.g. a disconnect) or a state change.
                    let item = match ready!(slf.poll_service(cx)) {
                        PollService::Ready => {
                            // Service is ready: try to decode the next frame from the read buffer.
                            match slf.shared.io.poll_recv_decode(&slf.shared.codec, cx) {
                                Ok(decoded) => {
                                    slf.update_timer(&decoded);
                                    if let Some(el) = decoded.item {
                                        DispatchItem::Item(el)
                                    } else {
                                        // Not enough data for a whole frame yet.
                                        return Poll::Pending;
                                    }
                                }
                                Err(RecvError::KeepAlive) => {
                                    // A timer fired; `handle_timeout` decides whether it is fatal.
                                    if let Err(err) = slf.handle_timeout() {
                                        slf.st = DispatcherState::Stop;
                                        err
                                    } else {
                                        continue;
                                    }
                                }
                                Err(RecvError::Stop) => {
                                    log::trace!(
                                        "{}: Dispatcher is instructed to stop",
                                        slf.shared.io.tag()
                                    );
                                    slf.st = DispatcherState::Stop;
                                    continue;
                                }
                                Err(RecvError::WriteBackpressure) => {
                                    // Stop reading until the write buffer has been flushed.
                                    slf.st = DispatcherState::Backpressure;
                                    DispatchItem::WBackPressureEnabled
                                }
                                Err(RecvError::Decoder(err)) => {
                                    log::trace!(
                                        "{}: Decoder error, stopping dispatcher: {:?}",
                                        slf.shared.io.tag(),
                                        err
                                    );
                                    slf.st = DispatcherState::Stop;
                                    DispatchItem::DecoderError(err)
                                }
                                Err(RecvError::PeerGone(err)) => {
                                    log::trace!(
                                        "{}: Peer is gone, stopping dispatcher: {:?}",
                                        slf.shared.io.tag(),
                                        err
                                    );
                                    slf.st = DispatcherState::Stop;
                                    DispatchItem::Disconnect(err)
                                }
                            }
                        }
                        PollService::Item(item) => item,
                        PollService::Continue => continue,
                    };

                    slf.call_service(cx, item);
                }
                DispatcherState::Backpressure => {
                    // Keep servicing readiness (and any items it yields) while waiting for flush.
                    match ready!(slf.poll_service(cx)) {
                        PollService::Ready => (),
                        PollService::Item(item) => slf.call_service(cx, item),
                        PollService::Continue => continue,
                    };

                    // Once the write buffer is flushed, resume processing; a flush error stops
                    // the dispatcher and is reported to the service as a disconnect.
                    let item = if let Err(err) = ready!(slf.shared.io.poll_flush(cx, false))
                    {
                        slf.st = DispatcherState::Stop;
                        DispatchItem::Disconnect(Some(err))
                    } else {
                        slf.st = DispatcherState::Processing;
                        DispatchItem::WBackPressureDisabled
                    };
                    slf.call_service(cx, item);
                }
                DispatcherState::Stop => {
                    // No more timeouts are handled while stopping.
                    slf.shared.io.stop_timer();

                    // Keep polling service readiness while stopping (presumably some services
                    // rely on poll_ready to make progress); a readiness error is latched once
                    // in READY_ERR and readiness is not polled again.
                    if !slf.shared.contains(Flags::READY_ERR) {
                        if let Poll::Ready(res) = slf.shared.service.poll_ready(cx) {
                            if res.is_err() {
                                slf.shared.insert_flags(Flags::READY_ERR);
                            }
                        }
                    }

                    if slf.shared.inflight.get() == 0 {
                        // All service calls have completed: shut down the io.
                        if slf.shared.io.poll_shutdown(cx).is_ready() {
                            slf.st = DispatcherState::Shutdown;
                            continue;
                        }
                    } else if !slf.shared.contains(Flags::IO_ERR) {
                        // Calls are still in flight: keep watching io status until it fails.
                        match ready!(slf.shared.io.poll_status_update(cx)) {
                            IoStatusUpdate::PeerGone(_)
                            | IoStatusUpdate::Stop
                            | IoStatusUpdate::KeepAlive => {
                                slf.shared.insert_flags(Flags::IO_ERR);
                                continue;
                            }
                            IoStatusUpdate::WriteBackpressure => {
                                if ready!(slf.shared.io.poll_flush(cx, true)).is_err() {
                                    slf.shared.insert_flags(Flags::IO_ERR);
                                }
                                continue;
                            }
                        }
                    } else {
                        // NOTE(review): presumably registers the dispatcher waker with the io
                        // so it is polled again when in-flight calls complete — confirm.
                        slf.shared.io.poll_dispatch(cx);
                    }
                    return Poll::Pending;
                }
                DispatcherState::Shutdown => {
                    // Wait for the service to shut down, then report the stored error, if any.
                    return if slf.shared.service.poll_shutdown(cx).is_ready() {
                        log::trace!(
                            "{}: Service shutdown is completed, stop",
                            slf.shared.io.tag()
                        );

                        Poll::Ready(if let Some(err) = slf.error.take() {
                            Err(err)
                        } else {
                            Ok(())
                        })
                    } else {
                        Poll::Pending
                    };
                }
            }
        }
    }
}
342
343impl<S, U> DispatcherInner<S, U>
344where
345 S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
346 U: Decoder + Encoder + 'static,
347{
348 fn call_service(&mut self, cx: &mut Context<'_>, item: DispatchItem<U>) {
349 let mut fut = self.shared.service.call_nowait(item);
350 let inflight = self.shared.inflight.get() + 1;
351 self.shared.inflight.set(inflight);
352 if inflight == 1 {
353 self.shared.remove_flags(Flags::IDLE);
354 }
355
356 if self.response.is_none() {
358 if let Poll::Ready(result) = Pin::new(&mut fut).poll(cx) {
359 self.shared.handle_result(result, &self.shared.io, false);
360 } else {
361 self.response = Some(fut);
362 }
363 } else {
364 let shared = self.shared.clone();
365 spawn(async move {
366 let result = fut.await;
367 shared.handle_result(result, &shared.io, true);
368 });
369 }
370 }
371
372 fn check_error(&mut self) -> PollService<U> {
373 if let Some(err) = self.shared.error.take() {
375 log::trace!(
376 "{}: Error occured, stopping dispatcher",
377 self.shared.io.tag()
378 );
379 self.st = DispatcherState::Stop;
380
381 match err {
382 DispatcherError::Encoder(err) => {
383 PollService::Item(DispatchItem::EncoderError(err))
384 }
385 DispatcherError::Service(err) => {
386 self.error = Some(err);
387 PollService::Continue
388 }
389 }
390 } else {
391 PollService::Ready
392 }
393 }
394
395 fn poll_service(&mut self, cx: &mut Context<'_>) -> Poll<PollService<U>> {
396 match self.shared.service.poll_ready(cx) {
398 Poll::Ready(Ok(_)) => Poll::Ready(self.check_error()),
399 Poll::Pending => {
401 log::trace!(
402 "{}: Service is not ready, register dispatcher",
403 self.shared.io.tag()
404 );
405
406 self.shared
408 .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
409 self.shared.io.stop_timer();
410
411 match ready!(self.shared.io.poll_read_pause(cx)) {
412 IoStatusUpdate::KeepAlive => {
413 log::trace!(
414 "{}: Keep-alive error, stopping dispatcher during pause",
415 self.shared.io.tag()
416 );
417 self.st = DispatcherState::Stop;
418 Poll::Ready(PollService::Item(DispatchItem::KeepAliveTimeout))
419 }
420 IoStatusUpdate::Stop => {
421 log::trace!(
422 "{}: Dispatcher is instructed to stop during pause",
423 self.shared.io.tag()
424 );
425 self.st = DispatcherState::Stop;
426 Poll::Ready(PollService::Continue)
427 }
428 IoStatusUpdate::PeerGone(err) => {
429 log::trace!(
430 "{}: Peer is gone during pause, stopping dispatcher: {:?}",
431 self.shared.io.tag(),
432 err
433 );
434 self.st = DispatcherState::Stop;
435 Poll::Ready(PollService::Item(DispatchItem::Disconnect(err)))
436 }
437 IoStatusUpdate::WriteBackpressure => {
438 self.st = DispatcherState::Backpressure;
439 Poll::Ready(PollService::Item(DispatchItem::WBackPressureEnabled))
440 }
441 }
442 }
443 Poll::Ready(Err(err)) => {
445 log::trace!(
446 "{}: Service readiness check failed, stopping",
447 self.shared.io.tag()
448 );
449 self.st = DispatcherState::Stop;
450 self.error = Some(err);
451 self.shared.insert_flags(Flags::READY_ERR);
452 Poll::Ready(PollService::Item(DispatchItem::Disconnect(None)))
453 }
454 }
455 }
456
457 fn update_timer(&mut self, decoded: &Decoded<<U as Decoder>::Item>) {
458 if decoded.item.is_some() {
460 self.read_remains = 0;
461 self.shared
462 .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
463 } else if self.shared.contains(Flags::READ_TIMEOUT) {
464 self.read_remains = decoded.remains as u32;
466 } else if self.read_remains == 0 && decoded.remains == 0 {
467 if self.shared.contains(Flags::KA_ENABLED)
469 && !self.shared.contains(Flags::KA_TIMEOUT)
470 {
471 log::trace!(
472 "{}: Start keep-alive timer {:?}",
473 self.shared.io.tag(),
474 self.shared.io.cfg().keepalive_timeout()
475 );
476 self.shared.insert_flags(Flags::KA_TIMEOUT);
477 self.shared
478 .io
479 .start_timer(self.shared.io.cfg().keepalive_timeout());
480 }
481 } else if let Some(params) = self.shared.io.cfg().frame_read_rate() {
482 self.shared.insert_flags(Flags::READ_TIMEOUT);
485
486 self.read_remains = decoded.remains as u32;
487 self.read_remains_prev = 0;
488 self.read_max_timeout = params.max_timeout;
489 self.shared.io.start_timer(params.timeout);
490 }
491 }
492
493 fn handle_timeout(&mut self) -> Result<(), DispatchItem<U>> {
494 if self.shared.contains(Flags::READ_TIMEOUT) {
496 if let Some(params) = self.shared.io.cfg().frame_read_rate() {
497 let total = self.read_remains - self.read_remains_prev;
498
499 if total > params.rate {
501 self.read_remains_prev = self.read_remains;
502 self.read_remains = 0;
503
504 if !params.max_timeout.is_zero() {
505 self.read_max_timeout = Seconds(
506 self.read_max_timeout.0.saturating_sub(params.timeout.0),
507 );
508 }
509
510 if params.max_timeout.is_zero() || !self.read_max_timeout.is_zero() {
511 log::trace!(
512 "{}: Frame read rate {:?}, extend timer",
513 self.shared.io.tag(),
514 total
515 );
516 self.shared.io.start_timer(params.timeout);
517 return Ok(());
518 }
519 log::trace!(
520 "{}: Max payload timeout has been reached",
521 self.shared.io.tag()
522 );
523 }
524 Err(DispatchItem::ReadTimeout)
525 } else {
526 Ok(())
527 }
528 } else if self.shared.contains(Flags::KA_TIMEOUT | Flags::IDLE) {
529 log::trace!(
530 "{}: Keep-alive error, stopping dispatcher",
531 self.shared.io.tag()
532 );
533 Err(DispatchItem::KeepAliveTimeout)
534 } else {
535 Ok(())
536 }
537 }
538}
539
// Dispatcher tests driven through the in-memory `IoTest` transport.
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex, atomic::AtomicBool, atomic::Ordering::Relaxed};
    use std::{cell::RefCell, io};

    use ntex_bytes::{Bytes, BytesMut};
    use ntex_codec::BytesCodec;
    use ntex_service::{ServiceCtx, cfg::SharedCfg};
    use ntex_util::{time::Millis, time::sleep};
    use rand::Rng;

    use super::*;
    use crate::{Flags, Io, IoConfig, IoRef, testing::IoTest};

    // Test handle to the dispatcher's io, returned by `Dispatcher::debug`.
    pub(crate) struct State(IoRef);

    impl State {
        // Io-level flags (crate `Flags`, not the dispatcher's private flags).
        fn flags(&self) -> Flags {
            self.0.flags()
        }

        fn io(&self) -> &IoRef {
            &self.0
        }

        // Ask the dispatcher to stop: set DSP_STOP and wake the dispatcher task.
        fn close(&self) {
            self.0.0.insert_flags(Flags::DSP_STOP);
            self.0.0.dispatch_task.wake();
        }
    }

    // Codec producing fixed-size frames of `self.0` bytes; encodes bytes verbatim.
    #[derive(Copy, Clone)]
    struct BCodec(usize);

    impl Encoder for BCodec {
        type Item = Bytes;
        type Error = io::Error;

        fn encode(&self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
            dst.extend_from_slice(&item[..]);
            Ok(())
        }
    }

    impl Decoder for BCodec {
        type Item = BytesMut;
        type Error = io::Error;

        fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
            if src.len() < self.0 {
                Ok(None)
            } else {
                Ok(Some(src.split_to(self.0)))
            }
        }
    }

    impl<S, U> Dispatcher<S, U>
    where
        S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
        U: Decoder + Encoder + 'static,
    {
        // Like `Dispatcher::new`, but also returns a `State` handle to the io and starts a
        // one-second io timer.
        pub(crate) fn debug(io: Io, codec: U, service: S) -> (Self, State) {
            let flags = if io.cfg().keepalive_timeout().is_zero() {
                super::Flags::empty()
            } else {
                super::Flags::KA_ENABLED
            };

            let inner = State(io.get_ref());
            io.start_timer(Seconds::ONE);

            let shared = Rc::new(DispatcherShared {
                codec,
                io: io.into(),
                flags: Cell::new(flags),
                error: Cell::new(None),
                inflight: Cell::new(0),
                service: Pipeline::new(service).bind(),
            });

            (
                Dispatcher {
                    inner: DispatcherInner {
                        shared,
                        error: None,
                        st: DispatcherState::Processing,
                        response: None,
                        read_remains: 0,
                        read_remains_prev: 0,
                        read_max_timeout: Seconds::ZERO,
                    },
                },
                inner,
            )
        }
    }

    // Echo service with a 50ms delay: each request is written back unchanged.
    #[ntex::test]
    async fn test_basic() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, _) = Dispatcher::debug(
            Io::from(server),
            BytesCodec,
            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
                sleep(Millis(50)).await;
                if let DispatchItem::Item(msg) = msg {
                    Ok::<_, ()>(Some(msg.freeze()))
                } else {
                    panic!()
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        sleep(Millis(25)).await;
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        client.write("GET /test HTTP/1\r\n\r\n");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        client.close().await;
        assert!(client.is_server_dropped());

        assert!(format!("{:?}", super::Flags::KA_TIMEOUT.clone()).contains("KA_TIMEOUT"));
    }

    // Data encoded directly through the io handle (outside the service) reaches the client,
    // and `State::close` shuts the dispatcher down.
    #[ntex::test]
    async fn test_sink() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, st) = Dispatcher::debug(
            Io::from(server),
            BytesCodec,
            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
                if let DispatchItem::Item(msg) = msg {
                    Ok::<_, ()>(Some(msg.freeze()))
                } else {
                    panic!()
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        assert!(
            st.io()
                .encode(Bytes::from_static(b"test"), &BytesCodec)
                .is_ok()
        );
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        st.close();
        sleep(Millis(1500)).await;
        assert!(client.is_server_dropped());
    }

    // A service call error stops the dispatcher; already-buffered output is still delivered.
    #[ntex::test]
    async fn test_err_in_service() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, state) = Dispatcher::debug(
            Io::from(server),
            BytesCodec,
            ntex_service::fn_service(|_: DispatchItem<BytesCodec>| async move {
                Err::<Option<Bytes>, _>(())
            }),
        );
        state
            .io()
            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
            .unwrap();
        spawn(async move {
            let _ = disp.await;
        });

        // Open the client's receive buffer so the pre-encoded data can be written out.
        client.remote_buffer_cap(1024);
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        sleep(Millis(250)).await;
        assert!(client.is_closed());

        client.close().await;

        assert!(client.is_server_dropped());
    }

    // A readiness error stops the dispatcher, is returned from the dispatcher future, and
    // `ready` is invoked only once (READY_ERR suppresses further readiness polls).
    #[ntex::test]
    async fn test_err_in_service_ready() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        let counter = Rc::new(Cell::new(0));

        // Service whose readiness check always fails; counts readiness calls.
        struct Srv(Rc<Cell<usize>>);

        impl Service<DispatchItem<BytesCodec>> for Srv {
            type Response = Option<Response<BytesCodec>>;
            type Error = &'static str;

            async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
                self.0.set(self.0.get() + 1);
                Err("test")
            }

            async fn call(
                &self,
                _: DispatchItem<BytesCodec>,
                _: ServiceCtx<'_, Self>,
            ) -> Result<Self::Response, Self::Error> {
                Ok(None)
            }
        }

        let (disp, state) =
            Dispatcher::debug(Io::from(server), BytesCodec, Srv(counter.clone()));
        spawn(async move {
            let res = disp.await;
            assert_eq!(res, Err("test"));
        });

        state
            .io()
            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
            .unwrap();

        client.remote_buffer_cap(1024);
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        sleep(Millis(250)).await;
        assert!(client.is_closed());

        client.close().await;
        assert!(client.is_server_dropped());

        assert_eq!(counter.get(), 1);
    }

    // Large responses to a client with a small receive buffer trigger write back-pressure.
    // The service sees the events in order: Item (0), WBackPressureEnabled (1),
    // WBackPressureDisabled (2).
    #[ntex::test]
    async fn test_write_backpressure() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let io = Io::new(
            server,
            SharedCfg::new("TEST").add(
                IoConfig::new()
                    .set_read_buf(8 * 1024, 1024, 16)
                    .set_write_buf(16 * 1024, 1024, 16),
            ),
        );

        let (disp, state) = Dispatcher::debug(
            io,
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(_) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            // Respond with 65_536 random alphanumeric bytes.
                            let bytes = rand::rng()
                                .sample_iter(&rand::distr::Alphanumeric)
                                .take(65_536)
                                .map(char::from)
                                .collect::<String>();
                            return Ok::<_, ()>(Some(Bytes::from(bytes)));
                        }
                        DispatchItem::WBackPressureEnabled => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        DispatchItem::WBackPressureDisabled => {
                            data.lock().unwrap().borrow_mut().push(2);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );

        spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read_any();
        assert_eq!(buf, Bytes::from_static(b""));
        client.write("GET /test HTTP/1\r\n\r\n");
        sleep(Millis(25)).await;

        // Nothing reaches the client while its receive capacity is 0.
        assert_eq!(client.remote_buffer(|buf| buf.len()), 0);

        // The whole response is still in the server write buffer.
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 65536);

        // Raising the client capacity drains the write buffer step by step.
        client.remote_buffer_cap(10240);
        sleep(Millis(50)).await;
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 55296);

        client.remote_buffer_cap(48056);
        sleep(Millis(50)).await;
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 7240);

        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]);
    }

    // Closing the io while the service (limited to one in-flight call) is busy must still
    // let the dispatcher future complete.
    #[ntex::test]
    async fn test_disconnect_during_read_backpressure() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);

        let (disp, state) = Dispatcher::debug(
            Io::new(
                server,
                SharedCfg::new("TEST").add(
                    IoConfig::new()
                        .set_keepalive_timeout(Seconds::ZERO)
                        .set_read_buf(1024, 512, 16),
                ),
            ),
            BytesCodec,
            ntex_util::services::inflight::InFlightService::new(
                1,
                ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| async move {
                    if let DispatchItem::Item(_) = msg {
                        sleep(Millis(500)).await;
                        Ok::<_, ()>(None)
                    } else {
                        Ok(None)
                    }
                }),
            ),
        );

        let (tx, rx) = ntex::channel::oneshot::channel();
        ntex::rt::spawn(async move {
            let _ = disp.await;
            let _ = tx.send(());
        });

        let bytes = rand::rng()
            .sample_iter(&rand::distr::Alphanumeric)
            .take(1024)
            .map(char::from)
            .collect::<String>();
        client.write(bytes.clone());
        sleep(Millis(25)).await;
        client.write(bytes);
        sleep(Millis(25)).await;

        state.close();
        // Completes only once the dispatcher future has resolved.
        let _ = rx.recv().await;
    }

    // With a 1s keep-alive and no further requests, the service receives Item (0) and then
    // KeepAliveTimeout (1), and the connection is closed.
    #[ntex::test]
    async fn test_keepalive() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let cfg = SharedCfg::new("DBG").add(
            IoConfig::new()
                .set_disconnect_timeout(Seconds(1))
                .set_keepalive_timeout(Seconds(1)),
        );

        let (disp, state) = Dispatcher::debug(
            Io::new(server, cfg),
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::KeepAliveTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
        sleep(Millis(2000)).await;

        let flags = state.flags();
        assert!(flags.contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }

    // Same as `test_keepalive`, with a frame read-rate configured and 8-byte frames.
    #[ntex::test]
    async fn test_keepalive2() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let cfg = SharedCfg::new("DBG").add(
            IoConfig::new()
                .set_keepalive_timeout(Seconds(1))
                .set_frame_read_rate(Seconds(1), Seconds(2), 2),
        );

        let (disp, state) = Dispatcher::debug(
            Io::new(server, cfg),
            BCodec(8),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::KeepAliveTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        client.write("12345678");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"12345678"));
        sleep(Millis(2000)).await;

        let flags = state.flags();
        assert!(flags.contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }

    // With a 2s keep-alive, single-byte frames arriving every 750ms keep the connection open.
    #[ntex::test]
    async fn test_keepalive3() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let cfg = SharedCfg::new("DBG").add(
            IoConfig::new()
                .set_keepalive_timeout(Seconds(2))
                .set_frame_read_rate(Seconds(1), Seconds(2), 2),
        );

        let (disp, _) = Dispatcher::debug(
            Io::new(server, cfg),
            BCodec(1),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::KeepAliveTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        client.write("1");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"1"));
        sleep(Millis(750)).await;

        client.write("2");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"2"));

        sleep(Millis(750)).await;
        client.write("3");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"3"));

        sleep(Millis(750)).await;
        assert!(!client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 0, 0]);
    }

    // Frame read-rate (1s period, 2s max, rate 2): a partial 8-byte frame trickling in too
    // slowly ends in ReadTimeout (1) after the first full Item (0).
    #[ntex::test]
    async fn test_read_timeout() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let io = Io::new(
            server,
            SharedCfg::new("TEST").add(
                IoConfig::new()
                    .set_keepalive_timeout(Seconds::ZERO)
                    .set_frame_read_rate(Seconds(1), Seconds(2), 2),
            ),
        );

        let (disp, state) = Dispatcher::debug(
            io,
            BCodec(8),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::ReadTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        client.write("12345678");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"12345678"));

        client.write("1");
        sleep(Millis(1000)).await;
        assert!(!state.flags().contains(Flags::IO_STOPPING));
        client.write("23");
        sleep(Millis(1000)).await;
        assert!(!state.flags().contains(Flags::IO_STOPPING));
        client.write("4");
        sleep(Millis(2000)).await;

        assert!(state.flags().contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }

    // Keep-alive disabled: a timer notification while the dispatcher is idle (IDLE flag set)
    // still stops it, via the IDLE branch of `handle_timeout`.
    #[ntex::test]
    async fn test_idle_timeout() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let io = Io::new(
            server,
            SharedCfg::new("DBG").add(IoConfig::new().set_keepalive_timeout(Seconds::ZERO)),
        );
        let ioref = io.get_ref();

        let (disp, state) = Dispatcher::debug(
            io,
            BCodec(1),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let ioref = ioref.clone();
                // Fire a timeout notification 500ms after each service call.
                ntex::rt::spawn(async move {
                    sleep(Millis(500)).await;
                    ioref.notify_timeout();
                });
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::ReadTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        client.write("1");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"1"));

        sleep(Millis(1000)).await;
        assert!(state.flags().contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
    }

    // Even when the client closes before the dispatcher starts, the service is still called
    // (with either the buffered Item or a Disconnect).
    #[ntex::test]
    async fn test_unhandled_data() {
        let handled = Arc::new(AtomicBool::new(false));
        let handled2 = handled.clone();

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, _) = Dispatcher::debug(
            Io::from(server),
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                handled2.store(true, Relaxed);
                async move {
                    sleep(Millis(50)).await;
                    if let DispatchItem::Item(msg) = msg {
                        Ok::<_, ()>(Some(msg.freeze()))
                    } else if let DispatchItem::Disconnect(_) = msg {
                        Ok::<_, ()>(None)
                    } else {
                        panic!()
                    }
                }
            }),
        );
        client.close().await;
        spawn(async move {
            let _ = disp.await;
        });
        sleep(Millis(50)).await;

        assert!(handled.load(Relaxed));
    }
}