1#![allow(clippy::let_underscore_future)]
3use std::task::{Context, Poll, ready};
4use std::{cell::Cell, future::Future, pin::Pin, rc::Rc};
5
6use ntex_codec::{Decoder, Encoder};
7use ntex_service::{IntoService, Pipeline, PipelineBinding, PipelineCall, Service};
8use ntex_util::{future::Either, spawn, time::Seconds};
9
10use crate::{Decoded, DispatchItem, IoBoxed, IoStatusUpdate, RecvError};
11
/// Item type accepted by the `Encoder` half of codec `U`; the dispatched
/// service answers every call with an `Option` of this type.
type Response<U> = <U as Encoder>::Item;
13
pin_project_lite::pin_project! {
    /// Future that drives a framed connection: frames decoded from the io
    /// object are passed to the service as `DispatchItem`s and the service
    /// responses are encoded back into the io object.
    ///
    /// The future resolves once the service shutdown has completed, with the
    /// stored service error if there is one and `Ok(())` otherwise.
    pub struct Dispatcher<S, U>
    where
        S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
        U: Encoder,
        U: Decoder,
        U: 'static,
    {
        // Whole dispatcher state; not marked `#[pin]`.
        inner: DispatcherInner<S, U>,
    }
}
27
28bitflags::bitflags! {
29 #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
30 struct Flags: u8 {
31 const READY_ERR = 0b0000001;
32 const IO_ERR = 0b0000010;
33 const KA_ENABLED = 0b0000100;
34 const KA_TIMEOUT = 0b0001000;
35 const READ_TIMEOUT = 0b0010000;
36 const IDLE = 0b0100000;
37 }
38}
39
/// Per-dispatcher state owned by the `Dispatcher` future.
struct DispatcherInner<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder + 'static,
{
    // Current state of the dispatcher state machine.
    st: DispatcherState,
    // Service error (from a call or a readiness check); returned when the
    // dispatcher future completes.
    error: Option<S::Error>,
    // State shared with spawned response tasks.
    shared: Rc<DispatcherShared<S, U>>,
    // Response future polled inline by the dispatcher (at most one at a time).
    response: Option<PipelineCall<S, DispatchItem<U>>>,
    // Number of buffered, not yet decoded bytes seen by the last decode pass
    // while the read-rate timer is active.
    read_remains: u32,
    // Value of `read_remains` at the previous read-rate check.
    read_remains_prev: u32,
    // Remaining budget for extending the read-rate timer.
    read_max_timeout: Seconds,
}
53
/// State shared between the dispatcher and the response tasks it spawns.
///
/// Interior mutability uses `Cell`; the value is shared through `Rc`.
pub(crate) struct DispatcherShared<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder,
{
    // Io object frames are read from and responses are written to.
    io: IoBoxed,
    // Codec used for both decoding requests and encoding responses.
    codec: U,
    // Bound service pipeline that handles dispatch items.
    service: PipelineBinding<S, DispatchItem<U>>,
    // Dispatcher status bits.
    flags: Cell<Flags>,
    // Pending encoder/service error, picked up by the dispatcher loop.
    error: Cell<Option<DispatcherError<S::Error, <U as Encoder>::Error>>>,
    // Number of service calls that have not produced a result yet.
    inflight: Cell<u32>,
}
66
/// States of the dispatcher state machine.
#[derive(Copy, Clone, Debug)]
enum DispatcherState {
    /// Decoding incoming frames and passing them to the service.
    Processing,
    /// Write back-pressure is active; the write buffer has to be flushed
    /// before processing resumes.
    Backpressure,
    /// Dispatcher is stopping: waits for in-flight calls, then shuts the io down.
    Stop,
    /// Io shutdown is done; waiting for the service shutdown to complete.
    Shutdown,
}
74
/// Error recorded while handling a service response.
#[derive(Debug)]
enum DispatcherError<S, U> {
    /// Encoding the service response failed.
    Encoder(U),
    /// The service call itself returned an error.
    Service(S),
}
80
/// Outcome of a service readiness check (see `DispatcherInner::poll_service`).
enum PollService<U: Encoder + Decoder> {
    /// Item to pass to the service with `call_nowait`.
    Item(DispatchItem<U>),
    /// Item to pass to the service with `call`.
    ItemWait(DispatchItem<U>),
    /// Dispatcher state has changed; run the dispatcher loop again.
    Continue,
    /// Service is ready and no error is pending.
    Ready,
}
87
88impl<S, U> From<Either<S, U>> for DispatcherError<S, U> {
89 fn from(err: Either<S, U>) -> Self {
90 match err {
91 Either::Left(err) => DispatcherError::Service(err),
92 Either::Right(err) => DispatcherError::Encoder(err),
93 }
94 }
95}
96
97impl<S, U> Dispatcher<S, U>
98where
99 S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
100 U: Decoder + Encoder + 'static,
101{
102 pub fn new<Io, F>(io: Io, codec: U, service: F) -> Dispatcher<S, U>
104 where
105 IoBoxed: From<Io>,
106 F: IntoService<S, DispatchItem<U>>,
107 {
108 let io = IoBoxed::from(io);
109 let flags = if io.cfg().keepalive_timeout().is_zero() {
110 Flags::empty()
111 } else {
112 Flags::KA_ENABLED
113 };
114
115 let shared = Rc::new(DispatcherShared {
116 io,
117 codec,
118 flags: Cell::new(flags),
119 error: Cell::new(None),
120 inflight: Cell::new(0),
121 service: Pipeline::new(service.into_service()).bind(),
122 });
123
124 Dispatcher {
125 inner: DispatcherInner {
126 shared,
127 response: None,
128 error: None,
129 read_remains: 0,
130 read_remains_prev: 0,
131 read_max_timeout: Seconds::ZERO,
132 st: DispatcherState::Processing,
133 },
134 }
135 }
136}
137
138impl<S, U> DispatcherShared<S, U>
139where
140 S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
141 U: Encoder + Decoder,
142{
143 fn handle_result(&self, item: Result<S::Response, S::Error>, io: &IoBoxed, wake: bool) {
144 match item {
145 Ok(Some(val)) => {
146 if let Err(err) = io.encode(val, &self.codec) {
147 self.error.set(Some(DispatcherError::Encoder(err)))
148 }
149 }
150 Err(err) => self.error.set(Some(DispatcherError::Service(err))),
151 Ok(None) => (),
152 }
153 let inflight = self.inflight.get() - 1;
154 self.inflight.set(inflight);
155 if inflight == 0 {
156 self.insert_flags(Flags::IDLE);
157 }
158 if wake {
159 io.wake();
160 }
161 }
162
163 fn contains(&self, f: Flags) -> bool {
164 self.flags.get().intersects(f)
165 }
166
167 fn insert_flags(&self, f: Flags) {
168 let mut flags = self.flags.get();
169 flags.insert(f);
170 self.flags.set(flags);
171 }
172
173 fn remove_flags(&self, f: Flags) -> bool {
174 let mut flags = self.flags.get();
175 if flags.intersects(f) {
176 flags.remove(f);
177 self.flags.set(flags);
178 true
179 } else {
180 false
181 }
182 }
183}
184
impl<S, U> Future for Dispatcher<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
    U: Decoder + Encoder + 'static,
{
    type Output = Result<(), S::Error>;

    /// Drive the dispatcher state machine.
    ///
    /// The stored response future (if any) is polled first, then the loop
    /// runs the handler for the current `DispatcherState` until one of them
    /// returns `Pending` or the `Shutdown` state completes.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.as_mut().project();
        let slf = &mut this.inner;

        // handle the response future that is polled inline by the dispatcher
        if let Some(fut) = slf.response.as_mut() {
            if let Poll::Ready(item) = Pin::new(fut).poll(cx) {
                slf.shared.handle_result(item, &slf.shared.io, false);
                slf.response = None;
            }
        }

        loop {
            match slf.st {
                DispatcherState::Processing => {
                    // `nowait` selects `call_nowait` vs `call` in `call_service`
                    let (item, nowait) = match ready!(slf.poll_service(cx)) {
                        PollService::Ready => {
                            // service is ready and no error is pending, decode next frame
                            match slf.shared.io.poll_recv_decode(&slf.shared.codec, cx) {
                                Ok(decoded) => {
                                    slf.update_timer(&decoded);
                                    if let Some(el) = decoded.item {
                                        (DispatchItem::Item(el), true)
                                    } else {
                                        // no complete frame is available
                                        return Poll::Pending;
                                    }
                                }
                                Err(RecvError::KeepAlive) => {
                                    // timer fired; `handle_timeout` decides whether
                                    // this is fatal or the timer got extended
                                    if let Err(err) = slf.handle_timeout() {
                                        slf.st = DispatcherState::Stop;
                                        (err, true)
                                    } else {
                                        continue;
                                    }
                                }
                                Err(RecvError::WriteBackpressure) => {
                                    slf.st = DispatcherState::Backpressure;
                                    (DispatchItem::WBackPressureEnabled, true)
                                }
                                Err(RecvError::Decoder(err)) => {
                                    log::trace!(
                                        "{}: Decoder error, stopping dispatcher: {:?}",
                                        slf.shared.io.tag(),
                                        err
                                    );
                                    slf.st = DispatcherState::Stop;
                                    (DispatchItem::DecoderError(err), true)
                                }
                                Err(RecvError::PeerGone(err)) => {
                                    log::trace!(
                                        "{}: Peer is gone, stopping dispatcher: {:?}",
                                        slf.shared.io.tag(),
                                        err
                                    );
                                    slf.st = DispatcherState::Stop;
                                    (DispatchItem::Disconnect(err), true)
                                }
                            }
                        }
                        PollService::Item(item) => (item, true),
                        PollService::ItemWait(item) => (item, false),
                        PollService::Continue => continue,
                    };

                    slf.call_service(cx, item, nowait);
                }
                DispatcherState::Backpressure => {
                    // NOTE(review): when `poll_service` returns an item it may have
                    // already changed `slf.st` (e.g. to `Stop`); the flush result
                    // below overwrites that state - confirm this is intended.
                    match ready!(slf.poll_service(cx)) {
                        PollService::Ready => (),
                        PollService::Item(item) => slf.call_service(cx, item, true),
                        PollService::ItemWait(item) => slf.call_service(cx, item, false),
                        PollService::Continue => continue,
                    };

                    // wait for the write buffer flush, then leave the back-pressure state
                    let item = if let Err(err) = ready!(slf.shared.io.poll_flush(cx, false))
                    {
                        slf.st = DispatcherState::Stop;
                        DispatchItem::Disconnect(Some(err))
                    } else {
                        slf.st = DispatcherState::Processing;
                        DispatchItem::WBackPressureDisabled
                    };
                    slf.call_service(cx, item, false);
                }
                DispatcherState::Stop => {
                    slf.shared.io.stop_timer();

                    // keep polling service readiness until it reports an error
                    if !slf.shared.contains(Flags::READY_ERR) {
                        if let Poll::Ready(res) = slf.shared.service.poll_ready(cx) {
                            if res.is_err() {
                                slf.shared.insert_flags(Flags::READY_ERR);
                            }
                        }
                    }

                    if slf.shared.inflight.get() == 0 {
                        // all service calls are completed, shut the io down
                        if slf.shared.io.poll_shutdown(cx).is_ready() {
                            slf.st = DispatcherState::Shutdown;
                            continue;
                        }
                    } else if !slf.shared.contains(Flags::IO_ERR) {
                        // calls are still in flight, watch the io status meanwhile
                        match ready!(slf.shared.io.poll_status_update(cx)) {
                            IoStatusUpdate::PeerGone(_) | IoStatusUpdate::KeepAlive => {
                                slf.shared.insert_flags(Flags::IO_ERR);
                                continue;
                            }
                            IoStatusUpdate::WriteBackpressure => {
                                if ready!(slf.shared.io.poll_flush(cx, true)).is_err() {
                                    slf.shared.insert_flags(Flags::IO_ERR);
                                }
                                continue;
                            }
                        }
                    } else {
                        // io already failed, only wait for in-flight calls
                        slf.shared.io.poll_dispatch(cx);
                    }
                    return Poll::Pending;
                }
                DispatcherState::Shutdown => {
                    // io is shut down, wait for the service shutdown
                    return if slf.shared.service.poll_shutdown(cx).is_ready() {
                        log::trace!(
                            "{}: Service shutdown is completed, stop",
                            slf.shared.io.tag()
                        );

                        // report the stored service error, if any
                        Poll::Ready(if let Some(err) = slf.error.take() {
                            Err(err)
                        } else {
                            Ok(())
                        })
                    } else {
                        Poll::Pending
                    };
                }
            }
        }
    }
}
335
336impl<S, U> DispatcherInner<S, U>
337where
338 S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
339 U: Decoder + Encoder + 'static,
340{
341 fn call_service(&mut self, cx: &mut Context<'_>, item: DispatchItem<U>, nowait: bool) {
342 let mut fut = if nowait {
343 self.shared.service.call_nowait(item)
344 } else {
345 self.shared.service.call(item)
346 };
347 let inflight = self.shared.inflight.get() + 1;
348 self.shared.inflight.set(inflight);
349 if inflight == 1 {
350 self.shared.remove_flags(Flags::IDLE);
351 }
352
353 if self.response.is_none() {
355 if let Poll::Ready(result) = Pin::new(&mut fut).poll(cx) {
356 self.shared.handle_result(result, &self.shared.io, false);
357 } else {
358 self.response = Some(fut);
359 }
360 } else {
361 let shared = self.shared.clone();
362 spawn(async move {
363 let result = fut.await;
364 shared.handle_result(result, &shared.io, true);
365 });
366 }
367 }
368
369 fn check_error(&mut self) -> PollService<U> {
370 if let Some(err) = self.shared.error.take() {
372 log::trace!(
373 "{}: Error occured, stopping dispatcher",
374 self.shared.io.tag()
375 );
376 self.st = DispatcherState::Stop;
377
378 match err {
379 DispatcherError::Encoder(err) => {
380 PollService::Item(DispatchItem::EncoderError(err))
381 }
382 DispatcherError::Service(err) => {
383 self.error = Some(err);
384 PollService::Continue
385 }
386 }
387 } else {
388 PollService::Ready
389 }
390 }
391
392 fn poll_service(&mut self, cx: &mut Context<'_>) -> Poll<PollService<U>> {
393 match self.shared.service.poll_ready(cx) {
395 Poll::Ready(Ok(_)) => Poll::Ready(self.check_error()),
396 Poll::Pending => {
398 log::trace!(
399 "{}: Service is not ready, register dispatcher",
400 self.shared.io.tag()
401 );
402
403 self.shared
405 .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
406 self.shared.io.stop_timer();
407
408 match ready!(self.shared.io.poll_read_pause(cx)) {
409 IoStatusUpdate::KeepAlive => {
410 log::trace!(
411 "{}: Keep-alive error, stopping dispatcher during pause",
412 self.shared.io.tag()
413 );
414 self.st = DispatcherState::Stop;
415 Poll::Ready(PollService::ItemWait(DispatchItem::KeepAliveTimeout))
416 }
417 IoStatusUpdate::PeerGone(err) => {
418 log::trace!(
419 "{}: Peer is gone during pause, stopping dispatcher: {:?}",
420 self.shared.io.tag(),
421 err
422 );
423 self.st = DispatcherState::Stop;
424 Poll::Ready(PollService::ItemWait(DispatchItem::Disconnect(err)))
425 }
426 IoStatusUpdate::WriteBackpressure => {
427 self.st = DispatcherState::Backpressure;
428 Poll::Ready(PollService::ItemWait(
429 DispatchItem::WBackPressureEnabled,
430 ))
431 }
432 }
433 }
434 Poll::Ready(Err(err)) => {
436 log::trace!(
437 "{}: Service readiness check failed, stopping",
438 self.shared.io.tag()
439 );
440 self.st = DispatcherState::Stop;
441 self.error = Some(err);
442 self.shared.insert_flags(Flags::READY_ERR);
443 Poll::Ready(PollService::Continue)
444 }
445 }
446 }
447
448 fn update_timer(&mut self, decoded: &Decoded<<U as Decoder>::Item>) {
449 if decoded.item.is_some() {
451 self.read_remains = 0;
452 self.shared
453 .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
454 } else if self.shared.contains(Flags::READ_TIMEOUT) {
455 self.read_remains = decoded.remains as u32;
457 } else if self.read_remains == 0 && decoded.remains == 0 {
458 if self.shared.contains(Flags::KA_ENABLED)
460 && !self.shared.contains(Flags::KA_TIMEOUT)
461 {
462 log::trace!(
463 "{}: Start keep-alive timer {:?}",
464 self.shared.io.tag(),
465 self.shared.io.cfg().keepalive_timeout()
466 );
467 self.shared.insert_flags(Flags::KA_TIMEOUT);
468 self.shared
469 .io
470 .start_timer(self.shared.io.cfg().keepalive_timeout());
471 }
472 } else if let Some(params) = self.shared.io.cfg().frame_read_rate() {
473 self.shared.insert_flags(Flags::READ_TIMEOUT);
476
477 self.read_remains = decoded.remains as u32;
478 self.read_remains_prev = 0;
479 self.read_max_timeout = params.max_timeout;
480 self.shared.io.start_timer(params.timeout);
481 }
482 }
483
484 fn handle_timeout(&mut self) -> Result<(), DispatchItem<U>> {
485 if self.shared.contains(Flags::READ_TIMEOUT) {
487 if let Some(params) = self.shared.io.cfg().frame_read_rate() {
488 let total = self.read_remains - self.read_remains_prev;
489
490 if total > params.rate {
492 self.read_remains_prev = self.read_remains;
493 self.read_remains = 0;
494
495 if !params.max_timeout.is_zero() {
496 self.read_max_timeout = Seconds(
497 self.read_max_timeout.0.saturating_sub(params.timeout.0),
498 );
499 }
500
501 if params.max_timeout.is_zero() || !self.read_max_timeout.is_zero() {
502 log::trace!(
503 "{}: Frame read rate {:?}, extend timer",
504 self.shared.io.tag(),
505 total
506 );
507 self.shared.io.start_timer(params.timeout);
508 return Ok(());
509 }
510 log::trace!(
511 "{}: Max payload timeout has been reached",
512 self.shared.io.tag()
513 );
514 }
515 Err(DispatchItem::ReadTimeout)
516 } else {
517 Ok(())
518 }
519 } else if self.shared.contains(Flags::KA_TIMEOUT | Flags::IDLE) {
520 log::trace!(
521 "{}: Keep-alive error, stopping dispatcher",
522 self.shared.io.tag()
523 );
524 Err(DispatchItem::KeepAliveTimeout)
525 } else {
526 Ok(())
527 }
528 }
529}
530
531#[cfg(test)]
532mod tests {
533 use std::sync::{Arc, Mutex, atomic::AtomicBool, atomic::Ordering::Relaxed};
534 use std::{cell::RefCell, io};
535
536 use ntex_bytes::{Bytes, BytesMut};
537 use ntex_codec::BytesCodec;
538 use ntex_service::{ServiceCtx, cfg::SharedCfg};
539 use ntex_util::{time::Millis, time::sleep};
540 use rand::Rng;
541
542 use super::*;
543 use crate::{Flags, Io, IoConfig, IoRef, testing::IoTest};
544
    /// Test handle wrapping an `IoRef` of the io object given to the dispatcher.
    pub(crate) struct State(IoRef);

    impl State {
        /// Current flags of the underlying io object.
        fn flags(&self) -> Flags {
            self.0.flags()
        }

        /// Borrow the wrapped `IoRef`.
        fn io(&self) -> &IoRef {
            &self.0
        }

        /// Close the underlying io object.
        fn close(&self) {
            self.0.close();
        }
    }
560
561 #[derive(Copy, Clone)]
562 struct BCodec(usize);
563
564 impl Encoder for BCodec {
565 type Item = Bytes;
566 type Error = io::Error;
567
568 fn encode(&self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
569 dst.extend_from_slice(&item[..]);
570 Ok(())
571 }
572 }
573
574 impl Decoder for BCodec {
575 type Item = BytesMut;
576 type Error = io::Error;
577
578 fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
579 if src.len() < self.0 {
580 Ok(None)
581 } else {
582 Ok(Some(src.split_to(self.0)))
583 }
584 }
585 }
586
587 impl<S, U> Dispatcher<S, U>
588 where
589 S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
590 U: Decoder + Encoder + 'static,
591 {
592 pub(crate) fn debug(io: Io, codec: U, service: S) -> (Self, State) {
594 let flags = if io.cfg().keepalive_timeout().is_zero() {
595 super::Flags::empty()
596 } else {
597 super::Flags::KA_ENABLED
598 };
599
600 let inner = State(io.get_ref());
601 io.start_timer(Seconds::ONE);
602
603 let shared = Rc::new(DispatcherShared {
604 codec,
605 io: io.into(),
606 flags: Cell::new(flags),
607 error: Cell::new(None),
608 inflight: Cell::new(0),
609 service: Pipeline::new(service).bind(),
610 });
611
612 (
613 Dispatcher {
614 inner: DispatcherInner {
615 shared,
616 error: None,
617 st: DispatcherState::Processing,
618 response: None,
619 read_remains: 0,
620 read_remains_prev: 0,
621 read_max_timeout: Seconds::ZERO,
622 },
623 },
624 inner,
625 )
626 }
627 }
628
    // Echo service with a 50ms delay: every request written by the client
    // must come back unchanged, and closing the client drops the server side.
    #[ntex::test]
    async fn basics() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, _) = Dispatcher::debug(
            Io::from(server),
            BytesCodec,
            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
                sleep(Millis(50)).await;
                if let DispatchItem::Item(msg) = msg {
                    Ok::<_, ()>(Some(msg.freeze()))
                } else {
                    panic!()
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        sleep(Millis(25)).await;
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        client.write("GET /test HTTP/1\r\n\r\n");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        client.close().await;
        assert!(client.is_server_dropped());

        // exercise the derived `Clone`/`Debug` impls of the dispatcher flags
        assert!(format!("{:?}", super::Flags::KA_TIMEOUT.clone()).contains("KA_TIMEOUT"));
    }
664
    // Data encoded directly through the `IoRef` (outside of the service) must
    // reach the client; closing the io from the state handle drops the server.
    #[ntex::test]
    async fn sink() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, st) = Dispatcher::debug(
            Io::from(server),
            BytesCodec,
            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
                if let DispatchItem::Item(msg) = msg {
                    Ok::<_, ()>(Some(msg.freeze()))
                } else if let DispatchItem::Disconnect(_) = msg {
                    Ok(None)
                } else {
                    panic!()
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        // write through the io reference, bypassing the service
        assert!(
            st.io()
                .encode(Bytes::from_static(b"test"), &BytesCodec)
                .is_ok()
        );
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        st.close();
        sleep(Millis(1500)).await;
        assert!(client.is_server_dropped());
    }
703
    // A service that always fails: data already placed in the write buffer is
    // still delivered to the client, after which the connection gets closed.
    #[ntex::test]
    async fn err_in_service() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, state) = Dispatcher::debug(
            Io::from(server),
            BytesCodec,
            ntex_service::fn_service(|_: DispatchItem<BytesCodec>| async move {
                Err::<Option<Bytes>, _>(())
            }),
        );
        // queue outgoing data before the dispatcher starts
        state
            .io()
            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
            .unwrap();
        spawn(async move {
            let _ = disp.await;
        });

        // raise the remote buffer capacity so the queued data can be written
        client.remote_buffer_cap(1024);
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        sleep(Millis(250)).await;
        assert!(client.is_closed());

        client.close().await;

        assert!(client.is_server_dropped());
    }
740
    // A service whose readiness check always fails: the dispatcher future must
    // resolve with that error, queued data is still delivered, and `ready` is
    // invoked exactly once.
    #[ntex::test]
    async fn err_in_service_ready() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        // counts `ready` invocations
        let counter = Rc::new(Cell::new(0));

        struct Srv(Rc<Cell<usize>>);

        impl Service<DispatchItem<BytesCodec>> for Srv {
            type Response = Option<Response<BytesCodec>>;
            type Error = &'static str;

            async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
                self.0.set(self.0.get() + 1);
                Err("test")
            }

            async fn call(
                &self,
                _: DispatchItem<BytesCodec>,
                _: ServiceCtx<'_, Self>,
            ) -> Result<Self::Response, Self::Error> {
                Ok(None)
            }
        }

        let (disp, state) =
            Dispatcher::debug(Io::from(server), BytesCodec, Srv(counter.clone()));
        spawn(async move {
            let res = disp.await;
            assert_eq!(res, Err("test"));
        });

        state
            .io()
            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
            .unwrap();

        // raise the remote buffer capacity so the queued data can be written
        client.remote_buffer_cap(1024);
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        sleep(Millis(250)).await;
        assert!(client.is_closed());

        client.close().await;
        assert!(client.is_server_dropped());

        // readiness must not be polled again after the failure
        assert_eq!(counter.get(), 1);
    }
797
    // The service answers an item with a 64KiB response while the client
    // accepts nothing: the dispatcher must report back-pressure enabled and,
    // once the client drains enough data, back-pressure disabled.
    #[ntex::test]
    async fn write_backpressure() {
        let (client, server) = IoTest::create();
        // client does not accept any data yet
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        // records service events: 0 = item, 1 = back-pressure on, 2 = back-pressure off
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let io = Io::new(
            server,
            SharedCfg::new("TEST").add(
                IoConfig::new()
                    .set_read_buf(8 * 1024, 1024, 16)
                    .set_write_buf(16 * 1024, 1024, 16),
            ),
        );

        let (disp, state) = Dispatcher::debug(
            io,
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(_) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            let bytes = rand::rng()
                                .sample_iter(&rand::distr::Alphanumeric)
                                .take(65_536)
                                .map(char::from)
                                .collect::<String>();
                            return Ok::<_, ()>(Some(Bytes::from(bytes)));
                        }
                        DispatchItem::WBackPressureEnabled => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        DispatchItem::WBackPressureDisabled => {
                            data.lock().unwrap().borrow_mut().push(2);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );

        spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read_any();
        assert_eq!(buf, Bytes::from_static(b""));
        client.write("GET /test HTTP/1\r\n\r\n");
        sleep(Millis(25)).await;

        // nothing could be written to the client
        assert_eq!(client.remote_buffer(|buf| buf.len()), 0);

        // the whole response sits in the write buffer
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 65536);

        // let the client accept 10240 bytes
        client.remote_buffer_cap(10240);
        sleep(Millis(50)).await;
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 55296);

        client.remote_buffer_cap(48056);
        sleep(Millis(50)).await;
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 7240);

        // item, back-pressure enabled, back-pressure disabled
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]);
    }
872
    // The service is limited to one in-flight request and is slow (500ms per
    // item) while the client keeps writing; closing the io in that situation
    // must still let the dispatcher future complete.
    #[ntex::test]
    async fn disconnect_during_read_backpressure() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);

        let (disp, state) = Dispatcher::debug(
            Io::new(
                server,
                SharedCfg::new("TEST").add(
                    IoConfig::new()
                        .set_keepalive_timeout(Seconds::ZERO)
                        .set_read_buf(1024, 512, 16),
                ),
            ),
            BytesCodec,
            ntex_util::services::inflight::InFlightService::new(
                1,
                ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| async move {
                    if let DispatchItem::Item(_) = msg {
                        sleep(Millis(500)).await;
                        Ok::<_, ()>(None)
                    } else {
                        Ok(None)
                    }
                }),
            ),
        );

        // signals completion of the dispatcher future
        let (tx, rx) = ntex::channel::oneshot::channel();
        ntex::rt::spawn(async move {
            let _ = disp.await;
            let _ = tx.send(());
        });

        let bytes = rand::rng()
            .sample_iter(&rand::distr::Alphanumeric)
            .take(1024)
            .map(char::from)
            .collect::<String>();
        client.write(bytes.clone());
        sleep(Millis(25)).await;
        client.write(bytes);
        sleep(Millis(25)).await;

        // close the io; the test passes when the dispatcher completes
        state.close();
        let _ = rx.recv().await;
    }
921
    // With a 1s keep-alive timeout and an idle client, the service must get a
    // `KeepAliveTimeout` item after the echoed request and the io must stop.
    #[ntex::test]
    async fn keepalive() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        // records service events: 0 = item, 1 = keep-alive timeout
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let cfg = SharedCfg::new("DBG").add(
            IoConfig::new()
                .set_disconnect_timeout(Seconds(1))
                .set_keepalive_timeout(Seconds(1)),
        );

        let (disp, state) = Dispatcher::debug(
            Io::new(server, cfg),
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::KeepAliveTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
        // stay idle for longer than the keep-alive timeout
        sleep(Millis(2000)).await;

        let flags = state.flags();
        assert!(flags.contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }
971
    // Same as `keepalive`, but with a frame read rate configured as well and
    // a fixed-size (8 byte) frame codec: an idle connection must still end
    // with a `KeepAliveTimeout` item.
    #[ntex::test]
    async fn keepalive2() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        // records service events: 0 = item, 1 = keep-alive timeout
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let cfg = SharedCfg::new("DBG").add(
            IoConfig::new()
                .set_keepalive_timeout(Seconds(1))
                .set_frame_read_rate(Seconds(1), Seconds(2), 2),
        );

        let (disp, state) = Dispatcher::debug(
            Io::new(server, cfg),
            BCodec(8),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::KeepAliveTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        // exactly one complete frame
        client.write("12345678");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"12345678"));
        // stay idle for longer than the keep-alive timeout
        sleep(Millis(2000)).await;

        let flags = state.flags();
        assert!(flags.contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }
1021
    // Frames arriving every 750ms with a 2s keep-alive timeout: the
    // connection must stay open and no `KeepAliveTimeout` item is dispatched.
    #[ntex::test]
    async fn keepalive3() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        // records service events: 0 = item, 1 = keep-alive timeout
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let cfg = SharedCfg::new("DBG").add(
            IoConfig::new()
                .set_keepalive_timeout(Seconds(2))
                .set_frame_read_rate(Seconds(1), Seconds(2), 2),
        );

        let (disp, _) = Dispatcher::debug(
            Io::new(server, cfg),
            BCodec(1),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::KeepAliveTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        client.write("1");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"1"));
        sleep(Millis(750)).await;

        client.write("2");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"2"));

        sleep(Millis(750)).await;
        client.write("3");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"3"));

        sleep(Millis(750)).await;
        // only items were dispatched and the connection is still open
        assert!(!client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 0, 0]);
    }
1079
    // Frame read-rate handling with keep-alive disabled and 8 byte frames:
    // a partial frame trickles in over several timer periods; the connection
    // survives the first two checks and is then stopped with `ReadTimeout`.
    #[ntex::test]
    async fn read_timeout() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        // records service events: 0 = item, 1 = read timeout
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let io = Io::new(
            server,
            SharedCfg::new("TEST").add(
                IoConfig::new()
                    .set_keepalive_timeout(Seconds::ZERO)
                    .set_frame_read_rate(Seconds(1), Seconds(2), 2),
            ),
        );

        let (disp, state) = Dispatcher::debug(
            io,
            BCodec(8),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::ReadTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        // one complete frame
        client.write("12345678");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"12345678"));

        // partial frame, delivered in small pieces
        client.write("1");
        sleep(Millis(1000)).await;
        assert!(!state.flags().contains(Flags::IO_STOPPING));
        client.write("23");
        sleep(Millis(1000)).await;
        assert!(!state.flags().contains(Flags::IO_STOPPING));
        client.write("4");
        sleep(Millis(2000)).await;

        // by now the read timeout must have stopped the connection
        assert!(state.flags().contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }
1139
    // Keep-alive is disabled, but a timeout notification is triggered
    // manually on the io reference 500ms after each service call; the
    // dispatcher must then stop the connection.
    #[ntex::test]
    async fn idle_timeout() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let io = Io::new(
            server,
            SharedCfg::new("DBG").add(IoConfig::new().set_keepalive_timeout(Seconds::ZERO)),
        );
        let ioref = io.get_ref();

        let (disp, state) = Dispatcher::debug(
            io,
            BCodec(1),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                // schedule a manual timeout notification for this call
                let ioref = ioref.clone();
                ntex::rt::spawn(async move {
                    sleep(Millis(500)).await;
                    ioref.notify_timeout();
                });
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::ReadTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        client.write("1");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"1"));

        sleep(Millis(1000)).await;
        assert!(state.flags().contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
    }
1191
    // The client writes a request and closes before the dispatcher is even
    // spawned; the service must still be invoked.
    #[ntex::test]
    async fn unhandled_data() {
        // set as soon as the service is called
        let handled = Arc::new(AtomicBool::new(false));
        let handled2 = handled.clone();

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, _) = Dispatcher::debug(
            Io::from(server),
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                handled2.store(true, Relaxed);
                async move {
                    sleep(Millis(50)).await;
                    if let DispatchItem::Item(msg) = msg {
                        Ok::<_, ()>(Some(msg.freeze()))
                    } else if let DispatchItem::Disconnect(_) = msg {
                        Ok::<_, ()>(None)
                    } else {
                        panic!()
                    }
                }
            }),
        );
        // close the client before the dispatcher starts running
        client.close().await;
        spawn(async move {
            let _ = disp.await;
        });
        sleep(Millis(50)).await;

        assert!(handled.load(Relaxed));
    }
1226}