1#![allow(clippy::let_underscore_future)]
3use std::task::{Context, Poll, ready};
4use std::{cell::Cell, future::Future, pin::Pin, rc::Rc};
5
6use ntex_codec::{Decoder, Encoder};
7use ntex_service::{IntoService, Pipeline, PipelineBinding, PipelineCall, Service};
8use ntex_util::{future::Either, spawn, time::Seconds};
9
10use crate::{Decoded, DispatchItem, IoBoxed, IoStatusUpdate, RecvError};
11
12type Response<U> = <U as Encoder>::Item;
13
pin_project_lite::pin_project! {
    /// Future that drives a framed io object.
    ///
    /// Incoming bytes are decoded with the codec `U` and handed to the service `S`
    /// as `DispatchItem`s; every `Some(..)` response produced by the service is
    /// encoded with the same codec and written back to the io object.
    pub struct Dispatcher<S, U>
    where
        S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
        U: Encoder,
        U: Decoder,
        U: 'static,
    {
        // The whole dispatcher state is kept in a private struct.
        inner: DispatcherInner<S, U>,
    }
}
27
bitflags::bitflags! {
    /// Internal dispatcher flags, stored in `DispatcherShared::flags`.
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    struct Flags: u8 {
        // Service readiness check returned an error; readiness is not polled again
        // while stopping.
        const READY_ERR     = 0b0000001;
        // Io failure was observed while draining in-flight responses.
        const IO_ERR        = 0b0000010;
        // Keep-alive timeout is configured (non-zero) for the io object.
        const KA_ENABLED    = 0b0000100;
        // Keep-alive timer has been started.
        const KA_TIMEOUT    = 0b0001000;
        // Frame read-rate timer has been started.
        const READ_TIMEOUT  = 0b0010000;
        // No service calls are in flight.
        const IDLE          = 0b0100000;
    }
}
39
/// State owned exclusively by the `Dispatcher` future.
struct DispatcherInner<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder + 'static,
{
    // Current state of the dispatcher state machine.
    st: DispatcherState,
    // Service error that is returned as the future's output once shutdown completes.
    error: Option<S::Error>,
    // State shared with response tasks spawned by `call_service`.
    shared: Rc<DispatcherShared<S, U>>,
    // The one service call that is polled in place by the dispatcher itself.
    response: Option<PipelineCall<S, DispatchItem<U>>>,
    // Byte count reported by the decoder (`Decoded::remains`) for the current period.
    read_remains: u32,
    // Byte count recorded at the previous read-rate check.
    read_remains_prev: u32,
    // Remaining budget of the configured maximum frame read timeout.
    read_max_timeout: Seconds,
}
53
/// State shared (via `Rc`) between the dispatcher and spawned response tasks.
pub(crate) struct DispatcherShared<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder,
{
    // Io object frames are read from and responses are written to.
    io: IoBoxed,
    // Codec used for both decoding requests and encoding responses.
    codec: U,
    // Bound service pipeline that handles dispatch items.
    service: PipelineBinding<S, DispatchItem<U>>,
    // Dispatcher flags; `Cell` because the struct is shared through `Rc`.
    flags: Cell<Flags>,
    // Last encoder/service error recorded by `handle_result`, consumed by `check_error`.
    error: Cell<Option<DispatcherError<S::Error, <U as Encoder>::Error>>>,
    // Number of service calls that have been started but not yet completed.
    inflight: Cell<u32>,
}
66
/// States of the dispatcher state machine driven by `Dispatcher::poll`.
#[derive(Copy, Clone, Debug)]
enum DispatcherState {
    // Normal operation: decode frames and call the service.
    Processing,
    // Write back-pressure is active: wait for the write buffer to flush.
    Backpressure,
    // Stopping: wait for in-flight calls to finish, then shut the io down.
    Stop,
    // Io is shut down: wait for service shutdown and complete the future.
    Shutdown,
}
74
/// Error recorded while handling a service response.
#[derive(Debug)]
enum DispatcherError<S, U> {
    // Encoding a response item failed.
    Encoder(U),
    // The service call itself returned an error.
    Service(S),
}
80
/// Outcome of `DispatcherInner::poll_service` / `check_error`.
enum PollService<U: Encoder + Decoder> {
    // An item must be passed to the service right away (no decoding step).
    Item(DispatchItem<U>),
    // State changed; the caller should restart its state-machine loop.
    Continue,
    // Service is ready and no error is pending; decoding may proceed.
    Ready,
}
86
87impl<S, U> From<Either<S, U>> for DispatcherError<S, U> {
88 fn from(err: Either<S, U>) -> Self {
89 match err {
90 Either::Left(err) => DispatcherError::Service(err),
91 Either::Right(err) => DispatcherError::Encoder(err),
92 }
93 }
94}
95
96impl<S, U> Dispatcher<S, U>
97where
98 S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
99 U: Decoder + Encoder + 'static,
100{
101 pub fn new<Io, F>(io: Io, codec: U, service: F) -> Dispatcher<S, U>
103 where
104 IoBoxed: From<Io>,
105 F: IntoService<S, DispatchItem<U>>,
106 {
107 let io = IoBoxed::from(io);
108 let flags = if io.cfg().keepalive_timeout().is_zero() {
109 Flags::empty()
110 } else {
111 Flags::KA_ENABLED
112 };
113
114 let shared = Rc::new(DispatcherShared {
115 io,
116 codec,
117 flags: Cell::new(flags),
118 error: Cell::new(None),
119 inflight: Cell::new(0),
120 service: Pipeline::new(service.into_service()).bind(),
121 });
122
123 Dispatcher {
124 inner: DispatcherInner {
125 shared,
126 response: None,
127 error: None,
128 read_remains: 0,
129 read_remains_prev: 0,
130 read_max_timeout: Seconds::ZERO,
131 st: DispatcherState::Processing,
132 },
133 }
134 }
135}
136
137impl<S, U> DispatcherShared<S, U>
138where
139 S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
140 U: Encoder + Decoder,
141{
142 fn handle_result(&self, item: Result<S::Response, S::Error>, io: &IoBoxed, wake: bool) {
143 match item {
144 Ok(Some(val)) => {
145 if let Err(err) = io.encode(val, &self.codec) {
146 self.error.set(Some(DispatcherError::Encoder(err)))
147 }
148 }
149 Err(err) => self.error.set(Some(DispatcherError::Service(err))),
150 Ok(None) => (),
151 }
152 let inflight = self.inflight.get() - 1;
153 self.inflight.set(inflight);
154 if inflight == 0 {
155 self.insert_flags(Flags::IDLE);
156 }
157 if wake {
158 io.wake();
159 }
160 }
161
162 fn contains(&self, f: Flags) -> bool {
163 self.flags.get().intersects(f)
164 }
165
166 fn insert_flags(&self, f: Flags) {
167 let mut flags = self.flags.get();
168 flags.insert(f);
169 self.flags.set(flags);
170 }
171
172 fn remove_flags(&self, f: Flags) -> bool {
173 let mut flags = self.flags.get();
174 if flags.intersects(f) {
175 flags.remove(f);
176 self.flags.set(flags);
177 true
178 } else {
179 false
180 }
181 }
182}
183
184impl<S, U> Future for Dispatcher<S, U>
185where
186 S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
187 U: Decoder + Encoder + 'static,
188{
189 type Output = Result<(), S::Error>;
190
    /// Drives the dispatcher state machine.
    ///
    /// Resolves with `Err(..)` if a service error was recorded, otherwise `Ok(())`,
    /// once the `Shutdown` state observes that service shutdown is complete.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.as_mut().project();
        let slf = &mut this.inner;

        // Poll the in-place response future first, if there is one.
        if let Some(fut) = slf.response.as_mut() {
            if let Poll::Ready(item) = Pin::new(fut).poll(cx) {
                slf.shared.handle_result(item, &slf.shared.io, false);
                slf.response = None;
            }
        }

        loop {
            match slf.st {
                DispatcherState::Processing => {
                    // `ready!` returns `Poll::Pending` from `poll` if the service
                    // (and the read pause) is still pending.
                    let item = match ready!(slf.poll_service(cx)) {
                        PollService::Ready => {
                            // Service is ready: try to decode the next frame.
                            match slf.shared.io.poll_recv_decode(&slf.shared.codec, cx) {
                                Ok(decoded) => {
                                    slf.update_timer(&decoded);
                                    if let Some(el) = decoded.item {
                                        DispatchItem::Item(el)
                                    } else {
                                        // No complete frame available yet.
                                        return Poll::Pending;
                                    }
                                }
                                Err(RecvError::KeepAlive) => {
                                    // Timer fired: either stop with the returned item
                                    // or, if the timeout was absorbed, loop again.
                                    if let Err(err) = slf.handle_timeout() {
                                        slf.st = DispatcherState::Stop;
                                        err
                                    } else {
                                        continue;
                                    }
                                }
                                Err(RecvError::WriteBackpressure) => {
                                    // Switch to back-pressure handling and notify the service.
                                    slf.st = DispatcherState::Backpressure;
                                    DispatchItem::WBackPressureEnabled
                                }
                                Err(RecvError::Decoder(err)) => {
                                    log::trace!(
                                        "{}: Decoder error, stopping dispatcher: {:?}",
                                        slf.shared.io.tag(),
                                        err
                                    );
                                    slf.st = DispatcherState::Stop;
                                    DispatchItem::DecoderError(err)
                                }
                                Err(RecvError::PeerGone(err)) => {
                                    log::trace!(
                                        "{}: Peer is gone, stopping dispatcher: {:?}",
                                        slf.shared.io.tag(),
                                        err
                                    );
                                    slf.st = DispatcherState::Stop;
                                    DispatchItem::Disconnect(err)
                                }
                            }
                        }
                        PollService::Item(item) => item,
                        PollService::Continue => continue,
                    };

                    slf.call_service(cx, item);
                }
                DispatcherState::Backpressure => {
                    // Keep servicing control items while waiting for the flush.
                    match ready!(slf.poll_service(cx)) {
                        PollService::Ready => (),
                        PollService::Item(item) => slf.call_service(cx, item),
                        PollService::Continue => continue,
                    };

                    // Wait for the write buffer to flush; a flush error stops the
                    // dispatcher, success resumes normal processing.
                    let item = if let Err(err) = ready!(slf.shared.io.poll_flush(cx, false))
                    {
                        slf.st = DispatcherState::Stop;
                        DispatchItem::Disconnect(Some(err))
                    } else {
                        slf.st = DispatcherState::Processing;
                        DispatchItem::WBackPressureDisabled
                    };
                    slf.call_service(cx, item);
                }
                DispatcherState::Stop => {
                    slf.shared.io.stop_timer();

                    // Keep polling service readiness unless it already failed;
                    // a failure is only remembered via the READY_ERR flag.
                    if !slf.shared.contains(Flags::READY_ERR) {
                        if let Poll::Ready(res) = slf.shared.service.poll_ready(cx) {
                            if res.is_err() {
                                slf.shared.insert_flags(Flags::READY_ERR);
                            }
                        }
                    }

                    if slf.shared.inflight.get() == 0 {
                        // Nothing in flight: shut the io down, then move on.
                        if slf.shared.io.poll_shutdown(cx).is_ready() {
                            slf.st = DispatcherState::Shutdown;
                            continue;
                        }
                    } else if !slf.shared.contains(Flags::IO_ERR) {
                        // Calls are still in flight and no io error was seen yet:
                        // watch io status while waiting for them.
                        match ready!(slf.shared.io.poll_status_update(cx)) {
                            IoStatusUpdate::PeerGone(_) | IoStatusUpdate::KeepAlive => {
                                slf.shared.insert_flags(Flags::IO_ERR);
                                continue;
                            }
                            IoStatusUpdate::WriteBackpressure => {
                                if ready!(slf.shared.io.poll_flush(cx, true)).is_err() {
                                    slf.shared.insert_flags(Flags::IO_ERR);
                                }
                                continue;
                            }
                        }
                    } else {
                        // Io already failed; calls are still in flight.
                        slf.shared.io.poll_dispatch(cx);
                    }
                    return Poll::Pending;
                }
                DispatcherState::Shutdown => {
                    // Complete once the service reports that shutdown is done.
                    return if slf.shared.service.poll_shutdown(cx).is_ready() {
                        log::trace!(
                            "{}: Service shutdown is completed, stop",
                            slf.shared.io.tag()
                        );

                        Poll::Ready(if let Some(err) = slf.error.take() {
                            Err(err)
                        } else {
                            Ok(())
                        })
                    } else {
                        Poll::Pending
                    };
                }
            }
        }
    }
331}
332
333impl<S, U> DispatcherInner<S, U>
334where
335 S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
336 U: Decoder + Encoder + 'static,
337{
338 fn call_service(&mut self, cx: &mut Context<'_>, item: DispatchItem<U>) {
339 let mut fut = self.shared.service.call_nowait(item);
340 let inflight = self.shared.inflight.get() + 1;
341 self.shared.inflight.set(inflight);
342 if inflight == 1 {
343 self.shared.remove_flags(Flags::IDLE);
344 }
345
346 if self.response.is_none() {
348 if let Poll::Ready(result) = Pin::new(&mut fut).poll(cx) {
349 self.shared.handle_result(result, &self.shared.io, false);
350 } else {
351 self.response = Some(fut);
352 }
353 } else {
354 let shared = self.shared.clone();
355 spawn(async move {
356 let result = fut.await;
357 shared.handle_result(result, &shared.io, true);
358 });
359 }
360 }
361
362 fn check_error(&mut self) -> PollService<U> {
363 if let Some(err) = self.shared.error.take() {
365 log::trace!(
366 "{}: Error occured, stopping dispatcher",
367 self.shared.io.tag()
368 );
369 self.st = DispatcherState::Stop;
370
371 match err {
372 DispatcherError::Encoder(err) => {
373 PollService::Item(DispatchItem::EncoderError(err))
374 }
375 DispatcherError::Service(err) => {
376 self.error = Some(err);
377 PollService::Continue
378 }
379 }
380 } else {
381 PollService::Ready
382 }
383 }
384
    /// Polls service readiness and translates the outcome into a `PollService`.
    ///
    /// * Service ready: returns whatever `check_error` reports.
    /// * Service pending: timers and timeout flags are cleared and reading is
    ///   paused; returns `Poll::Pending` unless the io reports a status change.
    /// * Readiness error: the dispatcher is stopped, the error is stored and a
    ///   `Disconnect(None)` item is emitted.
    fn poll_service(&mut self, cx: &mut Context<'_>) -> Poll<PollService<U>> {
        match self.shared.service.poll_ready(cx) {
            // Service is ready; surface any error recorded by earlier responses.
            Poll::Ready(Ok(_)) => Poll::Ready(self.check_error()),
            Poll::Pending => {
                log::trace!(
                    "{}: Service is not ready, register dispatcher",
                    self.shared.io.tag()
                );

                // Timeout flags are cleared and the timer is stopped while the
                // service is not ready.
                self.shared
                    .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
                self.shared.io.stop_timer();

                // `ready!` propagates `Poll::Pending` if the pause poll is pending.
                match ready!(self.shared.io.poll_read_pause(cx)) {
                    IoStatusUpdate::KeepAlive => {
                        log::trace!(
                            "{}: Keep-alive error, stopping dispatcher during pause",
                            self.shared.io.tag()
                        );
                        self.st = DispatcherState::Stop;
                        Poll::Ready(PollService::Item(DispatchItem::KeepAliveTimeout))
                    }
                    IoStatusUpdate::PeerGone(err) => {
                        log::trace!(
                            "{}: Peer is gone during pause, stopping dispatcher: {:?}",
                            self.shared.io.tag(),
                            err
                        );
                        self.st = DispatcherState::Stop;
                        Poll::Ready(PollService::Item(DispatchItem::Disconnect(err)))
                    }
                    IoStatusUpdate::WriteBackpressure => {
                        self.st = DispatcherState::Backpressure;
                        Poll::Ready(PollService::Item(DispatchItem::WBackPressureEnabled))
                    }
                }
            }
            Poll::Ready(Err(err)) => {
                log::trace!(
                    "{}: Service readiness check failed, stopping",
                    self.shared.io.tag()
                );
                // Remember the error for the future's output and mark readiness
                // as failed so the `Stop` state does not poll it again.
                self.st = DispatcherState::Stop;
                self.error = Some(err);
                self.shared.insert_flags(Flags::READY_ERR);
                Poll::Ready(PollService::Item(DispatchItem::Disconnect(None)))
            }
        }
    }
438
    /// Updates timers and timeout flags after a decode attempt.
    ///
    /// The branches are checked in order; only the first matching one runs.
    fn update_timer(&mut self, decoded: &Decoded<<U as Decoder>::Item>) {
        if decoded.item.is_some() {
            // A frame was decoded: reset read tracking and clear timeout flags.
            self.read_remains = 0;
            self.shared
                .remove_flags(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT | Flags::IDLE);
        } else if self.shared.contains(Flags::READ_TIMEOUT) {
            // Read-rate timer already running: just record the current byte count.
            // NOTE(review): `as u32` truncates if `remains` is wider than 32 bits —
            // confirm the field type.
            self.read_remains = decoded.remains as u32;
        } else if self.read_remains == 0 && decoded.remains == 0 {
            // No bytes are being tracked and the decoder reports none remaining:
            // start the keep-alive timer if enabled and not yet started.
            if self.shared.contains(Flags::KA_ENABLED)
                && !self.shared.contains(Flags::KA_TIMEOUT)
            {
                log::trace!(
                    "{}: Start keep-alive timer {:?}",
                    self.shared.io.tag(),
                    self.shared.io.cfg().keepalive_timeout()
                );
                self.shared.insert_flags(Flags::KA_TIMEOUT);
                self.shared
                    .io
                    .start_timer(self.shared.io.cfg().keepalive_timeout());
            }
        } else if let Some(params) = self.shared.io.cfg().frame_read_rate() {
            // A partial frame is pending and a read rate is configured:
            // start the read-rate timer and initialise its bookkeeping.
            self.shared.insert_flags(Flags::READ_TIMEOUT);

            self.read_remains = decoded.remains as u32;
            self.read_remains_prev = 0;
            self.read_max_timeout = params.max_timeout;
            self.shared.io.start_timer(params.timeout);
        }
    }
474
475 fn handle_timeout(&mut self) -> Result<(), DispatchItem<U>> {
476 if self.shared.contains(Flags::READ_TIMEOUT) {
478 if let Some(params) = self.shared.io.cfg().frame_read_rate() {
479 let total = self.read_remains - self.read_remains_prev;
480
481 if total > params.rate {
483 self.read_remains_prev = self.read_remains;
484 self.read_remains = 0;
485
486 if !params.max_timeout.is_zero() {
487 self.read_max_timeout = Seconds(
488 self.read_max_timeout.0.saturating_sub(params.timeout.0),
489 );
490 }
491
492 if params.max_timeout.is_zero() || !self.read_max_timeout.is_zero() {
493 log::trace!(
494 "{}: Frame read rate {:?}, extend timer",
495 self.shared.io.tag(),
496 total
497 );
498 self.shared.io.start_timer(params.timeout);
499 return Ok(());
500 }
501 log::trace!(
502 "{}: Max payload timeout has been reached",
503 self.shared.io.tag()
504 );
505 }
506 Err(DispatchItem::ReadTimeout)
507 } else {
508 Ok(())
509 }
510 } else if self.shared.contains(Flags::KA_TIMEOUT | Flags::IDLE) {
511 log::trace!(
512 "{}: Keep-alive error, stopping dispatcher",
513 self.shared.io.tag()
514 );
515 Err(DispatchItem::KeepAliveTimeout)
516 } else {
517 Ok(())
518 }
519 }
520}
521
522#[cfg(test)]
523mod tests {
524 use std::sync::{Arc, Mutex, atomic::AtomicBool, atomic::Ordering::Relaxed};
525 use std::{cell::RefCell, io};
526
527 use ntex_bytes::{Bytes, BytesMut};
528 use ntex_codec::BytesCodec;
529 use ntex_service::{ServiceCtx, cfg::SharedCfg};
530 use ntex_util::{time::Millis, time::sleep};
531 use rand::Rng;
532
533 use super::*;
534 use crate::{Flags, Io, IoConfig, IoRef, testing::IoTest};
535
536 pub(crate) struct State(IoRef);
537
538 impl State {
539 fn flags(&self) -> Flags {
540 self.0.flags()
541 }
542
543 fn io(&self) -> &IoRef {
544 &self.0
545 }
546
547 fn close(&self) {
548 self.0.close();
549 }
550 }
551
    /// Test codec; the wrapped `usize` is the exact length of every decoded frame
    /// (see the `Decoder` impl).
    #[derive(Copy, Clone)]
    struct BCodec(usize);
554
555 impl Encoder for BCodec {
556 type Item = Bytes;
557 type Error = io::Error;
558
559 fn encode(&self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
560 dst.extend_from_slice(&item[..]);
561 Ok(())
562 }
563 }
564
565 impl Decoder for BCodec {
566 type Item = BytesMut;
567 type Error = io::Error;
568
569 fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
570 if src.len() < self.0 {
571 Ok(None)
572 } else {
573 Ok(Some(src.split_to(self.0)))
574 }
575 }
576 }
577
578 impl<S, U> Dispatcher<S, U>
579 where
580 S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
581 U: Decoder + Encoder + 'static,
582 {
583 pub(crate) fn debug(io: Io, codec: U, service: S) -> (Self, State) {
585 let flags = if io.cfg().keepalive_timeout().is_zero() {
586 super::Flags::empty()
587 } else {
588 super::Flags::KA_ENABLED
589 };
590
591 let inner = State(io.get_ref());
592 io.start_timer(Seconds::ONE);
593
594 let shared = Rc::new(DispatcherShared {
595 codec,
596 io: io.into(),
597 flags: Cell::new(flags),
598 error: Cell::new(None),
599 inflight: Cell::new(0),
600 service: Pipeline::new(service).bind(),
601 });
602
603 (
604 Dispatcher {
605 inner: DispatcherInner {
606 shared,
607 error: None,
608 st: DispatcherState::Processing,
609 response: None,
610 read_remains: 0,
611 read_remains_prev: 0,
612 read_max_timeout: Seconds::ZERO,
613 },
614 },
615 inner,
616 )
617 }
618 }
619
    // Echo round-trip: every decoded item is returned by the service after a short
    // delay and must arrive back at the client unchanged.
    #[ntex::test]
    async fn test_basic() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, _) = Dispatcher::debug(
            Io::from(server),
            BytesCodec,
            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
                sleep(Millis(50)).await;
                if let DispatchItem::Item(msg) = msg {
                    Ok::<_, ()>(Some(msg.freeze()))
                } else {
                    panic!()
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        sleep(Millis(25)).await;
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        // A second request on the same connection is echoed as well.
        client.write("GET /test HTTP/1\r\n\r\n");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        client.close().await;
        assert!(client.is_server_dropped());

        // Exercises the derived `Clone`/`Debug` impls of the dispatcher flags.
        assert!(format!("{:?}", super::Flags::KA_TIMEOUT.clone()).contains("KA_TIMEOUT"));
    }
655
    // Data encoded directly through the io reference (outside the service) must
    // reach the client; closing the io must lead to the server side being dropped.
    #[ntex::test]
    async fn test_sink() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, st) = Dispatcher::debug(
            Io::from(server),
            BytesCodec,
            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
                if let DispatchItem::Item(msg) = msg {
                    Ok::<_, ()>(Some(msg.freeze()))
                } else if let DispatchItem::Disconnect(_) = msg {
                    Ok(None)
                } else {
                    panic!()
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        // Write through the io reference, bypassing the service.
        assert!(
            st.io()
                .encode(Bytes::from_static(b"test"), &BytesCodec)
                .is_ok()
        );
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        st.close();
        sleep(Millis(1500)).await;
        assert!(client.is_server_dropped());
    }
694
    // A service that always fails: data already queued for writing must still be
    // delivered, after which the connection is closed.
    #[ntex::test]
    async fn test_err_in_service() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, state) = Dispatcher::debug(
            Io::from(server),
            BytesCodec,
            ntex_service::fn_service(|_: DispatchItem<BytesCodec>| async move {
                Err::<Option<Bytes>, _>(())
            }),
        );
        // Queue outgoing data before the dispatcher starts running.
        state
            .io()
            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
            .unwrap();
        spawn(async move {
            let _ = disp.await;
        });

        // Raise the client's remote buffer capacity from 0, then read the queued data.
        client.remote_buffer_cap(1024);
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        sleep(Millis(250)).await;
        assert!(client.is_closed());

        client.close().await;

        assert!(client.is_server_dropped());
    }
731
    // A service whose readiness check fails: the dispatcher future must resolve
    // with that error and `ready` must have been invoked exactly once.
    #[ntex::test]
    async fn test_err_in_service_ready() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        // Counts how many times `Srv::ready` runs.
        let counter = Rc::new(Cell::new(0));

        struct Srv(Rc<Cell<usize>>);

        impl Service<DispatchItem<BytesCodec>> for Srv {
            type Response = Option<Response<BytesCodec>>;
            type Error = &'static str;

            async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
                self.0.set(self.0.get() + 1);
                Err("test")
            }

            async fn call(
                &self,
                _: DispatchItem<BytesCodec>,
                _: ServiceCtx<'_, Self>,
            ) -> Result<Self::Response, Self::Error> {
                Ok(None)
            }
        }

        let (disp, state) =
            Dispatcher::debug(Io::from(server), BytesCodec, Srv(counter.clone()));
        spawn(async move {
            let res = disp.await;
            assert_eq!(res, Err("test"));
        });

        state
            .io()
            .encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec)
            .unwrap();

        // Raise the client's remote buffer capacity from 0, then read the queued data.
        client.remote_buffer_cap(1024);
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        sleep(Millis(250)).await;
        assert!(client.is_closed());

        client.close().await;
        assert!(client.is_server_dropped());

        // Readiness must not be polled again after it failed.
        assert_eq!(counter.get(), 1);
    }
788
    // Write back-pressure notifications: the service answers an item with 64 KiB
    // of data and records the order of Item (0), WBackPressureEnabled (1) and
    // WBackPressureDisabled (2) events.
    #[ntex::test]
    async fn test_write_backpressure() {
        let (client, server) = IoTest::create();
        // The client's remote buffer capacity starts at zero.
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let io = Io::new(
            server,
            SharedCfg::new("TEST").add(
                IoConfig::new()
                    .set_read_buf(8 * 1024, 1024, 16)
                    .set_write_buf(16 * 1024, 1024, 16),
            ),
        );

        let (disp, state) = Dispatcher::debug(
            io,
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(_) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            let bytes = rand::rng()
                                .sample_iter(&rand::distr::Alphanumeric)
                                .take(65_536)
                                .map(char::from)
                                .collect::<String>();
                            return Ok::<_, ()>(Some(Bytes::from(bytes)));
                        }
                        DispatchItem::WBackPressureEnabled => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        DispatchItem::WBackPressureDisabled => {
                            data.lock().unwrap().borrow_mut().push(2);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );

        spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read_any();
        assert_eq!(buf, Bytes::from_static(b""));
        client.write("GET /test HTTP/1\r\n\r\n");
        sleep(Millis(25)).await;

        // Nothing is in the client's remote buffer yet.
        assert_eq!(client.remote_buffer(|buf| buf.len()), 0);

        // The whole response is still held in the write buffer.
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 65536);

        // Raising the capacity to 10240 drains 10240 bytes: 65536 - 10240 = 55296.
        client.remote_buffer_cap(10240);
        sleep(Millis(50)).await;
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 55296);

        // Raising the capacity again drains another 48056 bytes: 55296 - 48056 = 7240.
        client.remote_buffer_cap(48056);
        sleep(Millis(50)).await;
        assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 7240);

        // Events were observed in the order item, enabled, disabled.
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]);
    }
863
    // Closing the io while the (in-flight limited) service is busy must still let
    // the dispatcher future complete.
    #[ntex::test]
    async fn test_disconnect_during_read_backpressure() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(0);

        let (disp, state) = Dispatcher::debug(
            Io::new(
                server,
                SharedCfg::new("TEST").add(
                    IoConfig::new()
                        .set_keepalive_timeout(Seconds::ZERO)
                        .set_read_buf(1024, 512, 16),
                ),
            ),
            BytesCodec,
            // `InFlightService` is constructed with a limit of 1.
            ntex_util::services::inflight::InFlightService::new(
                1,
                ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| async move {
                    if let DispatchItem::Item(_) = msg {
                        sleep(Millis(500)).await;
                        Ok::<_, ()>(None)
                    } else {
                        Ok(None)
                    }
                }),
            ),
        );

        // Signals completion of the dispatcher future.
        let (tx, rx) = ntex::channel::oneshot::channel();
        ntex::rt::spawn(async move {
            let _ = disp.await;
            let _ = tx.send(());
        });

        let bytes = rand::rng()
            .sample_iter(&rand::distr::Alphanumeric)
            .take(1024)
            .map(char::from)
            .collect::<String>();
        client.write(bytes.clone());
        sleep(Millis(25)).await;
        client.write(bytes);
        sleep(Millis(25)).await;

        state.close();
        // Wait until the dispatcher task has finished.
        let _ = rx.recv().await;
    }
912
    // Keep-alive timeout: after one echoed item and two seconds of inactivity the
    // service must have seen KeepAliveTimeout and the io must be stopping.
    #[ntex::test]
    async fn test_keepalive() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        // Records events: 0 = item, 1 = keep-alive timeout.
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let cfg = SharedCfg::new("DBG").add(
            IoConfig::new()
                .set_disconnect_timeout(Seconds(1))
                .set_keepalive_timeout(Seconds(1)),
        );

        let (disp, state) = Dispatcher::debug(
            Io::new(server, cfg),
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::KeepAliveTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
        sleep(Millis(2000)).await;

        // Write side must be closed, dispatcher should fail with keep-alive.
        let flags = state.flags();
        assert!(flags.contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }
962
    // Keep-alive timeout with a frame read rate also configured: a complete
    // 8-byte frame is echoed, then inactivity must produce KeepAliveTimeout.
    #[ntex::test]
    async fn test_keepalive2() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        // Records events: 0 = item, 1 = keep-alive timeout.
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let cfg = SharedCfg::new("DBG").add(
            IoConfig::new()
                .set_keepalive_timeout(Seconds(1))
                .set_frame_read_rate(Seconds(1), Seconds(2), 2),
        );

        let (disp, state) = Dispatcher::debug(
            Io::new(server, cfg),
            BCodec(8),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::KeepAliveTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        client.write("12345678");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"12345678"));
        sleep(Millis(2000)).await;

        // Write side must be closed, dispatcher should fail with keep-alive.
        let flags = state.flags();
        assert!(flags.contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }
1012
    // The keep-alive timer must be refreshed by every received frame: three frames
    // sent 750 ms apart (keep-alive is 2 s) must not trigger a timeout.
    #[ntex::test]
    async fn test_keepalive3() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        // Records events: 0 = item, 1 = keep-alive timeout.
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let cfg = SharedCfg::new("DBG").add(
            IoConfig::new()
                .set_keepalive_timeout(Seconds(2))
                .set_frame_read_rate(Seconds(1), Seconds(2), 2),
        );

        let (disp, _) = Dispatcher::debug(
            Io::new(server, cfg),
            BCodec(1),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::KeepAliveTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        client.write("1");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"1"));
        sleep(Millis(750)).await;

        client.write("2");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"2"));

        sleep(Millis(750)).await;
        client.write("3");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"3"));

        sleep(Millis(750)).await;
        // Connection is still open and only items were delivered.
        assert!(!client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 0, 0]);
    }
1070
    // Frame read-rate timeout: after one complete 8-byte frame, a second frame is
    // trickled in too slowly and the service must receive ReadTimeout.
    #[ntex::test]
    async fn test_read_timeout() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        // Records events: 0 = item, 1 = read timeout.
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let io = Io::new(
            server,
            SharedCfg::new("TEST").add(
                IoConfig::new()
                    .set_keepalive_timeout(Seconds::ZERO)
                    .set_frame_read_rate(Seconds(1), Seconds(2), 2),
            ),
        );

        let (disp, state) = Dispatcher::debug(
            io,
            BCodec(8),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::ReadTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        client.write("12345678");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"12345678"));

        // Send a partial frame in small pieces; the io must stay open for the
        // first two periods.
        client.write("1");
        sleep(Millis(1000)).await;
        assert!(!state.flags().contains(Flags::IO_STOPPING));
        client.write("23");
        sleep(Millis(1000)).await;
        assert!(!state.flags().contains(Flags::IO_STOPPING));
        client.write("4");
        sleep(Millis(2000)).await;

        // Write side must be closed, dispatcher should fail with read timeout.
        assert!(state.flags().contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }
1130
    // With keep-alive disabled, a timeout notification injected through the io
    // reference 500 ms after a service call must stop the io.
    #[ntex::test]
    async fn test_idle_timeout() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let io = Io::new(
            server,
            SharedCfg::new("DBG").add(IoConfig::new().set_keepalive_timeout(Seconds::ZERO)),
        );
        let ioref = io.get_ref();

        let (disp, state) = Dispatcher::debug(
            io,
            BCodec(1),
            ntex_service::fn_service(move |msg: DispatchItem<BCodec>| {
                // Every service call schedules a delayed timeout notification.
                let ioref = ioref.clone();
                ntex::rt::spawn(async move {
                    sleep(Millis(500)).await;
                    ioref.notify_timeout();
                });
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes.freeze()));
                        }
                        DispatchItem::ReadTimeout => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        spawn(async move {
            let _ = disp.await;
        });

        client.write("1");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"1"));

        sleep(Millis(1000)).await;
        assert!(state.flags().contains(Flags::IO_STOPPING));
        assert!(client.is_closed());
    }
1182
    // Data written before the client closed must still reach the service even
    // though the dispatcher is only started after the close.
    #[ntex::test]
    async fn test_unhandled_data() {
        // Set to `true` as soon as the service is invoked.
        let handled = Arc::new(AtomicBool::new(false));
        let handled2 = handled.clone();

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, _) = Dispatcher::debug(
            Io::from(server),
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                handled2.store(true, Relaxed);
                async move {
                    sleep(Millis(50)).await;
                    if let DispatchItem::Item(msg) = msg {
                        Ok::<_, ()>(Some(msg.freeze()))
                    } else if let DispatchItem::Disconnect(_) = msg {
                        Ok::<_, ()>(None)
                    } else {
                        panic!()
                    }
                }
            }),
        );
        // Close the client before the dispatcher is spawned.
        client.close().await;
        spawn(async move {
            let _ = disp.await;
        });
        sleep(Millis(50)).await;

        assert!(handled.load(Relaxed));
    }
1217}