1use std::task::{Context, Poll, ready};
3use std::{cell::Cell, cell::RefCell, collections::VecDeque, future::Future, pin::Pin, rc::Rc};
4
5use ntex_codec::{Decoder, Encoder};
6use ntex_dispatcher::{Control, DispatchItem, Reason};
7use ntex_io::{Decoded, IoBoxed, IoRef, IoStatusUpdate, RecvError};
8use ntex_service::{IntoService, Pipeline, PipelineBinding, PipelineCall, Service};
9use ntex_util::channel::condition::Condition;
10use ntex_util::{future::Either, future::select, spawn, task::LocalWaker, time::Seconds};
11
12type Response<U> = <U as Encoder>::Item;
13type Queue<T, E> = RefCell<VecDeque<ServiceResult<Result<T, E>>>>;
14
pin_project_lite::pin_project! {
    /// Framed-io dispatcher: drives a codec-framed `IoBoxed` connection and
    /// feeds decoded items into a `Service`, encoding responses back out.
    ///
    /// Implemented as a `Future` (see the `Future` impl below) that resolves
    /// once the connection is shut down, yielding the first service error, if any.
    pub(crate) struct Dispatcher<S, U>
    where
        S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
        S: 'static,
        U: Encoder,
        U: Decoder,
        U: 'static,
    {
        // All state lives in the inner struct; pin_project keeps the
        // `Future` impl's projection boilerplate out of the way.
        inner: DispatcherInner<S, U>
    }
}
28
bitflags::bitflags! {
    #[derive(Copy, Clone, Eq, PartialEq, Debug)]
    struct Flags: u8 {
        /// Service readiness check (`poll_ready`) returned an error.
        const READY_ERR     = 0b0000_0001;
        /// Io reported an unrecoverable error while stopping.
        const IO_ERR        = 0b0000_0010;
        /// Keep-alive is enabled (non-zero keep-alive timeout configured).
        const KA_ENABLED    = 0b0000_0100;
        /// Keep-alive timer is currently running.
        const KA_TIMEOUT    = 0b0000_1000;
        /// Frame read-rate (slow read) timer is currently running.
        const READ_TIMEOUT  = 0b0001_0000;
        // NOTE(review): READY and READY_TASK are declared but not referenced
        // anywhere in this chunk — possibly used elsewhere; confirm.
        const READY         = 0b0010_0000;
        const READY_TASK    = 0b0100_0000;
        /// The parked response future must trigger stop handling on completion.
        const RESPONSE_STOP = 0b1000_0000;
    }
}
42
/// Mutable dispatcher state owned exclusively by the `Future::poll` loop.
struct DispatcherInner<S: Service<DispatchItem<U>>, U: Encoder + Decoder + 'static> {
    // Underlying boxed io object (read/write/timer access).
    io: IoBoxed,
    flags: Flags,
    codec: U,
    // Bound service pipeline used to process dispatch items.
    service: PipelineBinding<S, DispatchItem<U>>,
    // Current state-machine state (Processing/Backpressure/Stop/Shutdown).
    st: IoDispatcherState,
    // State shared with spawned response tasks (see `DispatcherState`).
    state: Rc<DispatcherState<S, U>>,
    // Bytes buffered at the previous read-rate tick / current tick;
    // used by `handle_timeout` to compute the incoming byte rate.
    read_remains: u32,
    read_remains_prev: u32,
    // Remaining budget for extending the read timer (Seconds::ZERO = unlimited).
    read_max_timeout: Seconds,
    // Configured keep-alive timeout (zero disables keep-alive).
    keepalive_timeout: Seconds,
}
55
/// State shared (via `Rc`) between the dispatcher future and the tasks it
/// spawns in `call_service` for out-of-band service calls.
struct DispatcherState<S: Service<DispatchItem<U>>, U: Encoder + Decoder + 'static> {
    // First service/encoder error observed; reported on shutdown.
    error: Cell<Option<IoDispatcherError<S::Error, <U as Encoder>::Error>>>,
    // Absolute index of the response queue's front slot; response indices
    // are compared against it with wrapping arithmetic in `handle_result`.
    base: Cell<usize>,
    // Ordered queue of in-flight responses, drained front-to-back so
    // responses are written in request order.
    queue: Queue<S::Response, S::Error>,
    // Wakes the dispatcher future from spawned response tasks.
    waker: LocalWaker,
    // Notified when the dispatcher stops; cancels spawned service calls.
    stopping: Condition,
    // At most one response future parked for in-order polling, plus the
    // absolute queue index it belongs to.
    response: Cell<ResponseCall<S, U>>,
    response_idx: Cell<usize>,
}
65
/// Slot for the single response future polled in-line by the dispatcher.
#[derive(Default)]
enum ResponseCall<S: Service<DispatchItem<U>>, U: Encoder + Decoder + 'static> {
    /// A pending service call to be polled on the dispatcher task.
    Call(PipelineCall<S, DispatchItem<U>>),
    /// The call was abandoned during stop; its queue slot still needs to be
    /// completed with `Ok(None)`.
    Canceled,
    /// No parked call.
    #[default]
    Empty,
}
73
/// One slot of the ordered response queue.
enum ServiceResult<T> {
    /// The service call has not produced a result yet.
    Pending,
    /// The result is available but waiting for earlier slots to drain.
    Ready(T),
}

impl<T> ServiceResult<T> {
    /// Move a finished result out of the slot, leaving `Pending` behind.
    ///
    /// Returns `None` when the slot holds no completed result.
    fn take(&mut self) -> Option<T> {
        match std::mem::replace(self, Self::Pending) {
            Self::Ready(value) => Some(value),
            Self::Pending => None,
        }
    }
}
88
/// Dispatcher state machine; transitions only move forward:
/// Processing -> (Backpressure <-> Processing) -> Stop -> Shutdown.
#[derive(Copy, Clone, Debug)]
enum IoDispatcherState {
    /// Normal operation: decode frames and call the service.
    Processing,
    /// Write buffer is full; wait for it to drain before reading more.
    Backpressure,
    /// Draining in-flight responses and shutting the io down.
    Stop,
    /// Waiting for the service shutdown to complete.
    Shutdown,
}
96
/// First error observed by the dispatcher: either the codec failed to
/// encode a response, or the service itself failed.
pub(crate) enum IoDispatcherError<S, U> {
    Encoder(U),
    Service(S),
}
101
/// Outcome of `DispatcherInner::poll_service`.
enum PollService<U: Encoder + Decoder> {
    /// Dispatch this item immediately (no-wait service call).
    Item(DispatchItem<U>),
    /// Dispatch this item via a waiting service call.
    ItemWait(DispatchItem<U>),
    /// State changed; re-run the dispatcher loop.
    Continue,
    /// Service is ready; proceed to decode input.
    Ready,
}
108
109impl<S, U> From<S> for IoDispatcherError<S, U> {
110 fn from(err: S) -> Self {
111 IoDispatcherError::Service(err)
112 }
113}
114
115impl<S, U> Dispatcher<S, U>
116where
117 S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
118 U: Decoder + Encoder + Clone + 'static,
119 <U as Encoder>::Item: 'static,
120{
121 pub(crate) fn new<F: IntoService<S, DispatchItem<U>>>(
123 io: IoBoxed,
124 codec: U,
125 service: F,
126 ) -> Self {
127 let state = Rc::new(DispatcherState {
128 error: Cell::new(None),
129 base: Cell::new(0),
130 queue: RefCell::new(VecDeque::new()),
131 waker: LocalWaker::default(),
132 response: Cell::new(ResponseCall::Empty),
133 response_idx: Cell::new(0),
134 stopping: Condition::new(),
135 });
136 let keepalive_timeout = io.cfg().keepalive_timeout();
137
138 Dispatcher {
139 inner: DispatcherInner {
140 io,
141 codec,
142 state,
143 keepalive_timeout,
144 flags: if keepalive_timeout.is_zero() {
145 Flags::KA_ENABLED
146 } else {
147 Flags::empty()
148 },
149 service: Pipeline::new(service.into_service()).bind(),
150 st: IoDispatcherState::Processing,
151 read_remains: 0,
152 read_remains_prev: 0,
153 read_max_timeout: Seconds::ZERO,
154 },
155 }
156 }
157
158 pub(crate) fn keepalive_timeout(mut self, timeout: Seconds) -> Self {
164 self.inner.keepalive_timeout = timeout;
165 if timeout.is_zero() {
166 self.inner.flags.remove(Flags::KA_ENABLED);
167 } else {
168 self.inner.flags.insert(Flags::KA_ENABLED);
169 }
170 self
171 }
172}
173
174impl<S, U> DispatcherState<S, U>
175where
176 S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
177 U: Encoder + Decoder,
178 <U as Encoder>::Item: 'static,
179{
180 fn handle_result(
181 &self,
182 item: Result<S::Response, S::Error>,
183 response_idx: usize,
184 io: &IoRef,
185 codec: &U,
186 stop: bool,
187 ) -> bool {
188 let mut queue = self.queue.borrow_mut();
189
190 if stop {
191 self.stopping.notify();
192
193 let resp = self.response.take();
195 if matches!(resp, ResponseCall::Call(_) | ResponseCall::Canceled) {
196 self.response.set(ResponseCall::Canceled);
197 }
198 }
199
200 let idx = response_idx.wrapping_sub(self.base.get());
201
202 if idx == 0 {
204 let _ = queue.pop_front();
205 self.base.set(self.base.get().wrapping_add(1));
206 match item {
207 Err(err) => {
208 self.error.set(Some(err.into()));
209 }
210 Ok(Some(item)) => {
211 if let Err(err) = io.encode(item, codec) {
212 self.error.set(Some(IoDispatcherError::Encoder(err)));
213 }
214 }
215 Ok(None) => (),
216 }
217
218 while let Some(item) = queue.front_mut().and_then(ServiceResult::take) {
220 let _ = queue.pop_front();
221 self.base.set(self.base.get().wrapping_add(1));
222 match item {
223 Err(err) => {
224 self.error.set(Some(err.into()));
225 }
226 Ok(Some(item)) => {
227 if let Err(err) = io.encode(item, codec) {
228 self.error.set(Some(IoDispatcherError::Encoder(err)));
229 }
230 }
231 Ok(None) => (),
232 }
233 }
234
235 queue.is_empty()
236 } else {
237 queue[idx] = ServiceResult::Ready(item);
238 false
239 }
240 }
241}
242
impl<S, U> Future for Dispatcher<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
    U: Decoder + Encoder + Clone + 'static,
    <U as Encoder>::Item: 'static,
{
    type Output = Result<(), S::Error>;

    #[allow(clippy::too_many_lines)]
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.as_mut().project();
        let inner = this.inner;

        // Re-register so spawned response tasks can wake this future.
        inner.state.waker.register(cx.waker());

        // Drive the single parked in-order response future, if any.
        match inner.state.response.take() {
            ResponseCall::Call(mut fut) => {
                if let Poll::Ready(item) = Pin::new(&mut fut).poll(cx) {
                    // RESPONSE_STOP marks that this response must trigger
                    // stop handling once it completes.
                    let stop = if inner.flags.contains(Flags::RESPONSE_STOP) {
                        inner.flags.remove(Flags::RESPONSE_STOP);
                        true
                    } else {
                        false
                    };
                    inner.state.handle_result(
                        item,
                        inner.state.response_idx.get(),
                        inner.io.as_ref(),
                        &inner.codec,
                        stop,
                    );
                } else {
                    // Still pending: park it again for the next poll.
                    inner.state.response.set(ResponseCall::Call(fut));
                }
            }
            ResponseCall::Canceled => {
                // The parked call was abandoned during stop: complete its
                // queue slot with `Ok(None)` so the ordered queue drains.
                inner.state.handle_result(
                    Ok(None),
                    inner.state.response_idx.get(),
                    inner.io.as_ref(),
                    &inner.codec,
                    true,
                );
            }
            ResponseCall::Empty => {}
        }

        // Main state machine loop.
        loop {
            match inner.st {
                IoDispatcherState::Processing => {
                    // Tuple: (item to dispatch, use call_nowait, stop after).
                    let (item, nowait, stop) = match ready!(inner.poll_service(cx)) {
                        PollService::Ready => {
                            // Service ready: try to decode the next frame.
                            match inner.io.poll_recv_decode(&inner.codec, cx) {
                                Ok(decoded) => {
                                    inner.update_timer(&decoded);
                                    if let Some(el) = decoded.item {
                                        (DispatchItem::Item(el), true, false)
                                    } else {
                                        // Not enough buffered data for a frame.
                                        return Poll::Pending;
                                    }
                                }
                                Err(RecvError::KeepAlive) => {
                                    // Timer fired: either extend the read
                                    // timer or stop with a timeout reason.
                                    if let Err(err) = inner.handle_timeout() {
                                        inner.stop();
                                        (DispatchItem::Stop(err), true, true)
                                    } else {
                                        continue;
                                    }
                                }
                                Err(RecvError::WriteBackpressure) => {
                                    inner.st = IoDispatcherState::Backpressure;
                                    (
                                        DispatchItem::Control(Control::WBackPressureEnabled),
                                        true,
                                        false,
                                    )
                                }
                                Err(RecvError::Decoder(err)) => {
                                    inner.stop();
                                    (DispatchItem::Stop(Reason::Decoder(err)), true, true)
                                }
                                Err(RecvError::PeerGone(err)) => {
                                    inner.stop();
                                    (DispatchItem::Stop(Reason::Io(err)), true, true)
                                }
                            }
                        }
                        PollService::Item(item) => (item, true, false),
                        PollService::ItemWait(item) => (item, false, false),
                        PollService::Continue => continue,
                    };

                    inner.call_service(cx, item, nowait, stop);
                }
                IoDispatcherState::Backpressure => {
                    // Keep dispatching service-originated items while the
                    // write buffer drains.
                    match ready!(inner.poll_service(cx)) {
                        PollService::Ready => (),
                        PollService::Item(item) => inner.call_service(cx, item, true, false),
                        PollService::ItemWait(item) => {
                            inner.call_service(cx, item, false, false);
                        }
                        PollService::Continue => continue,
                    }

                    let item = if let Err(err) = ready!(inner.io.poll_flush(cx, false)) {
                        inner.stop();
                        DispatchItem::Stop(Reason::Io(Some(err)))
                    } else {
                        // Buffer drained: resume normal processing.
                        inner.st = IoDispatcherState::Processing;
                        DispatchItem::Control(Control::WBackPressureDisabled)
                    };
                    inner.call_service(cx, item, false, false);
                }

                IoDispatcherState::Stop => {
                    inner.io.stop_timer();

                    // Latch a readiness error once so a failed service is
                    // not polled for readiness again while stopping.
                    if !inner.flags.contains(Flags::READY_ERR)
                        && let Poll::Ready(res) = inner.service.poll_ready(cx)
                        && res.is_err()
                    {
                        inner.flags.insert(Flags::READY_ERR);
                    }

                    if inner.state.queue.borrow().is_empty() {
                        // No in-flight responses remain: shut the io down.
                        if inner.io.poll_shutdown(cx).is_ready() {
                            log::trace!("{}: io shutdown completed", inner.io.tag());
                            inner.st = IoDispatcherState::Shutdown;
                            continue;
                        }
                    } else if !inner.flags.contains(Flags::IO_ERR) {
                        // Responses still pending: keep the io alive until
                        // they finish or the peer disappears.
                        match ready!(inner.io.poll_status_update(cx)) {
                            IoStatusUpdate::PeerGone(_) | IoStatusUpdate::KeepAlive => {
                                inner.flags.insert(Flags::IO_ERR);
                                continue;
                            }
                            IoStatusUpdate::WriteBackpressure => {
                                if ready!(inner.io.poll_flush(cx, true)).is_err() {
                                    inner.flags.insert(Flags::IO_ERR);
                                }
                                continue;
                            }
                        }
                    } else {
                        inner.io.poll_dispatch(cx);
                    }
                    return Poll::Pending;
                }
                IoDispatcherState::Shutdown => {
                    // Wait for service shutdown, then surface any latched
                    // service error as the dispatcher's result.
                    return if inner.service.poll_shutdown(cx).is_ready() {
                        log::trace!("{}: Service shutdown is completed, stop", inner.io.tag());

                        Poll::Ready(
                            if let Some(IoDispatcherError::Service(err)) =
                                inner.state.error.take()
                            {
                                Err(err)
                            } else {
                                Ok(())
                            },
                        )
                    } else {
                        Poll::Pending
                    };
                }
            }
        }
    }
}
418
impl<S, U> DispatcherInner<S, U>
where
    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
    U: Decoder + Encoder + Clone + 'static,
    <U as Encoder>::Item: 'static,
{
    /// Switch the state machine into the `Stop` state.
    fn stop(&mut self) {
        self.st = IoDispatcherState::Stop;
    }

    /// Call the service with `item`, preserving response order.
    ///
    /// Three cases:
    /// 1. A response future is already parked in `state.response`: run this
    ///    call on a spawned task; `stopping` cancels it on dispatcher stop.
    /// 2. The call completes synchronously: encode immediately when the
    ///    queue is empty, otherwise park the result in its queue slot.
    /// 3. The call is pending and nothing is parked: park it in
    ///    `state.response` for in-order polling by the dispatcher.
    fn call_service(
        &mut self,
        cx: &mut Context<'_>,
        item: DispatchItem<U>,
        nowait: bool,
        stop: bool,
    ) {
        let mut fut = if nowait {
            self.service.call_nowait(item)
        } else {
            self.service.call(item)
        };
        let mut queue = self.state.queue.borrow_mut();

        // Case 1: a call is already parked; peek without consuming it.
        let resp = self.state.response.take();
        if matches!(resp, ResponseCall::Call(_) | ResponseCall::Canceled) {
            self.state.response.set(resp);

            // Absolute index of the slot this spawned call will fill.
            let response_idx = self.state.base.get().wrapping_add(queue.len());
            queue.push_back(ServiceResult::Pending);

            let st = self.io.get_ref();
            let codec = self.codec.clone();
            let state = self.state.clone();

            spawn(async move {
                // Race the call against dispatcher stop; on stop the call
                // is dropped and its slot completes with Ok(None).
                let empty_q = match select(fut, state.stopping.wait()).await {
                    Either::Left(item) => {
                        state.handle_result(item, response_idx, &st, &codec, stop)
                    }
                    Either::Right(()) => {
                        state.handle_result(Ok(None), response_idx, &st, &codec, stop)
                    }
                };
                if empty_q || stop {
                    st.wake();
                }
            });
        } else if let Poll::Ready(res) = Pin::new(&mut fut).poll(cx) {
            // Case 2: synchronous completion.
            if queue.is_empty() {
                match res {
                    Err(err) => {
                        self.state.error.set(Some(err.into()));
                    }
                    Ok(Some(item)) => {
                        if let Err(err) = self.io.encode(item, &self.codec) {
                            self.state.error.set(Some(IoDispatcherError::Encoder(err)));
                        }
                    }
                    Ok(None) => (),
                }
            } else {
                if stop {
                    self.state.stopping.notify();
                }
                queue.push_back(ServiceResult::Ready(res));
                // NOTE(review): unlike the pending branch below, response_idx
                // is set AFTER the push here, i.e. one past the pushed slot.
                // It appears unused on this path (no future is parked), but
                // confirm the intended invariant.
                self.state.response_idx.set(self.state.base.get().wrapping_add(queue.len()));
            }
        } else {
            // Case 3: park the pending call for in-order polling.
            if stop {
                self.flags.insert(Flags::RESPONSE_STOP);
            }
            self.state.response.set(ResponseCall::Call(fut));
            self.state.response_idx.set(self.state.base.get().wrapping_add(queue.len()));
            queue.push_back(ServiceResult::Pending);
        }
    }

    /// Check the latched error, switching to `Stop` when one is present.
    /// Encoder errors are dispatched to the service as a `Stop` item;
    /// service errors stay latched for reporting at shutdown.
    fn check_error(&mut self) -> PollService<U> {
        if let Some(err) = self.state.error.take() {
            log::trace!("{}: Error occured, stopping dispatcher", self.io.tag());
            self.stop();
            match err {
                IoDispatcherError::Encoder(err) => {
                    PollService::Item(DispatchItem::Stop(Reason::Encoder(err)))
                }
                IoDispatcherError::Service(err) => {
                    // Put it back; the Shutdown state reports it.
                    self.state.error.set(Some(IoDispatcherError::Service(err)));
                    PollService::Continue
                }
            }
        } else {
            PollService::Ready
        }
    }

    /// Poll service readiness; pauses reading while the service is busy and
    /// translates io status changes observed during the pause.
    fn poll_service(&mut self, cx: &mut Context<'_>) -> Poll<PollService<U>> {
        match self.service.poll_ready(cx) {
            Poll::Ready(Ok(())) => Poll::Ready(self.check_error()),
            Poll::Pending => {
                log::trace!("{}: Service is not ready, pause read task", self.io.tag());

                // Timers are meaningless while reads are paused.
                self.flags.remove(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT);
                self.io.stop_timer();

                match ready!(self.io.poll_read_pause(cx)) {
                    IoStatusUpdate::KeepAlive => {
                        log::trace!(
                            "{}: Keep-alive error, stopping dispatcher during pause",
                            self.io.tag()
                        );
                        self.stop();
                        Poll::Ready(PollService::ItemWait(DispatchItem::Stop(
                            Reason::KeepAliveTimeout,
                        )))
                    }
                    IoStatusUpdate::PeerGone(err) => {
                        log::trace!(
                            "{}: Peer is gone during pause, stopping dispatcher: {:?}",
                            self.io.tag(),
                            err
                        );
                        self.stop();
                        Poll::Ready(PollService::ItemWait(DispatchItem::Stop(Reason::Io(err))))
                    }
                    IoStatusUpdate::WriteBackpressure => {
                        self.st = IoDispatcherState::Backpressure;
                        Poll::Ready(PollService::ItemWait(DispatchItem::Control(
                            Control::WBackPressureEnabled,
                        )))
                    }
                }
            }
            Poll::Ready(Err(err)) => {
                log::error!("{}: Service readiness check failed, stopping", self.io.tag());
                self.stop();
                self.flags.insert(Flags::READY_ERR);
                self.state.error.set(Some(IoDispatcherError::Service(err)));
                Poll::Ready(PollService::Continue)
            }
        }
    }

    /// Maintain keep-alive / frame-read-rate timers after a decode attempt.
    fn update_timer(&mut self, decoded: &Decoded<<U as Decoder>::Item>) {
        if decoded.item.is_some() {
            // Got a full frame: reset both timers.
            self.read_remains = 0;
            self.flags.remove(Flags::KA_TIMEOUT | Flags::READ_TIMEOUT);
        } else if self.flags.contains(Flags::READ_TIMEOUT) {
            // Read timer running: record buffered bytes so the next tick
            // can compute the incoming byte rate.
            self.read_remains = decoded.remains as u32;
        } else if self.read_remains == 0 && decoded.remains == 0 {
            // Idle connection (no partial frame): start keep-alive.
            if self.flags.contains(Flags::KA_ENABLED) && !self.flags.contains(Flags::KA_TIMEOUT)
            {
                log::trace!(
                    "{}: Start keep-alive timer {:?}",
                    self.io.tag(),
                    self.keepalive_timeout
                );
                self.flags.insert(Flags::KA_TIMEOUT);
                self.io.start_timer(self.keepalive_timeout);
            }
        } else if let Some(params) = self.io.cfg().frame_read_rate() {
            // Partial frame buffered: start slow-read protection.
            self.flags.insert(Flags::READ_TIMEOUT);

            self.read_remains = decoded.remains as u32;
            self.read_remains_prev = 0;
            self.read_max_timeout = params.max_timeout;
            self.io.start_timer(params.timeout);

            log::trace!("{}: Start frame read timer {:?}", self.io.tag(), params.timeout);
        }
    }

    /// Called when the io timer fires; either extends the read timer
    /// (sufficient read rate, budget remaining) or returns the timeout
    /// reason that should stop the dispatcher.
    fn handle_timeout(&mut self) -> Result<(), Reason<U>> {
        if self.flags.contains(Flags::READ_TIMEOUT) {
            if let Some(params) = self.io.cfg().frame_read_rate() {
                // Bytes received since the previous tick.
                let total = self.read_remains - self.read_remains_prev;

                if total > params.rate {
                    // Rate acceptable: extend, bounded by max_timeout when
                    // configured (zero max_timeout means unbounded).
                    self.read_remains_prev = self.read_remains;
                    self.read_remains = 0;

                    if !params.max_timeout.is_zero() {
                        self.read_max_timeout =
                            Seconds(self.read_max_timeout.0.saturating_sub(params.timeout.0));
                    }

                    if params.max_timeout.is_zero() || !self.read_max_timeout.is_zero() {
                        log::trace!(
                            "{}: Frame read rate {:?}, extend timer",
                            self.io.tag(),
                            total
                        );
                        self.io.start_timer(params.timeout);
                        return Ok(());
                    }
                }
                log::trace!("{}: Max payload timeout has been reached", self.io.tag());
                return Err(Reason::ReadTimeout);
            }
        } else if self.flags.contains(Flags::KA_TIMEOUT) {
            log::trace!("{}: Keep-alive error, stopping dispatcher", self.io.tag());
            return Err(Reason::KeepAliveTimeout);
        }
        Ok(())
    }
}
639
640#[cfg(test)]
641#[allow(clippy::items_after_statements)]
642mod tests {
643 use std::sync::{Arc, Mutex, atomic::AtomicBool, atomic::Ordering};
644 use std::{cell::Cell, io};
645
646 use ntex_bytes::{Bytes, BytesMut};
647 use ntex_codec::BytesCodec;
648 use ntex_io::{self as nio, IoConfig, testing::IoTest as Io};
649 use ntex_service::{ServiceCtx, cfg::SharedCfg};
650 use ntex_util::channel::condition::Condition;
651 use ntex_util::time::{Millis, sleep};
652 use rand::Rng;
653
654 use super::*;
655
    impl<S, U> Dispatcher<S, U>
    where
        S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
        S::Error: 'static,
        U: Decoder + Encoder + 'static,
        <U as Encoder>::Item: 'static,
    {
        /// Test constructor: like `new`, but takes a concrete `nio::Io` and
        /// also returns its `IoRef` so tests can write/close directly.
        pub(crate) fn new_debug<F: IntoService<S, DispatchItem<U>>>(
            io: nio::Io,
            codec: U,
            service: F,
        ) -> (Self, nio::IoRef) {
            let keepalive_timeout = io.cfg().keepalive_timeout();
            let rio = io.get_ref();

            let state = Rc::new(DispatcherState {
                error: Cell::new(None),
                base: Cell::new(0),
                waker: LocalWaker::default(),
                queue: RefCell::new(VecDeque::new()),
                stopping: Condition::new(),
                response: Cell::new(ResponseCall::Empty),
                response_idx: Cell::new(0),
            });

            (
                Dispatcher {
                    inner: DispatcherInner {
                        codec,
                        state,
                        keepalive_timeout,
                        service: Pipeline::new(service.into_service()).bind(),
                        io: IoBoxed::from(io),
                        st: IoDispatcherState::Processing,
                        // Keep-alive enabled only for a non-zero timeout.
                        flags: if keepalive_timeout.is_zero() {
                            Flags::empty()
                        } else {
                            Flags::KA_ENABLED
                        },
                        read_remains: 0,
                        read_remains_prev: 0,
                        read_max_timeout: Seconds::ZERO,
                    },
                },
                rio,
            )
        }
    }
705
    /// Echo service with a small delay: two frames must be echoed back in
    /// order and the server must drop once the client closes.
    #[ntex::test]
    async fn test_basic() {
        let (client, server) = Io::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, _) = Dispatcher::new_debug(
            nio::Io::new(server, SharedCfg::new("DBG")),
            BytesCodec,
            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
                sleep(Millis(50)).await;
                if let DispatchItem::Item(msg) = msg {
                    Ok::<_, ()>(Some(msg))
                } else {
                    panic!()
                }
            }),
        );
        ntex_util::spawn(async move {
            let _ = disp.await;
        });
        sleep(Millis(25)).await;
        client.write("GET /test HTTP/1\r\n\r\n");

        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        client.close().await;
        assert!(client.is_server_dropped());
    }
739
    /// A pending service call must be dropped (canceled) when the client
    /// disconnects; the `OnDrop` guard flips the flag when the last clone
    /// held by the service future is dropped.
    #[ntex::test]
    async fn test_drop_connection() {
        let (client, server) = Io::create();
        client.remote_buffer_cap(1024);
        client.write("test");

        #[derive(Clone)]
        struct OnDrop(Rc<Cell<bool>>);
        impl Drop for OnDrop {
            fn drop(&mut self) {
                // Strong count 2 == only the original + the clone inside the
                // in-flight call remain.
                if Rc::strong_count(&self.0) == 2 {
                    self.0.set(true);
                }
            }
        }
        let ops = Rc::new(Cell::new(false));
        let on_drop = OnDrop(ops.clone());

        let (disp, _) = Dispatcher::new_debug(
            nio::Io::new(server, SharedCfg::new("DBG")),
            BytesCodec,
            ntex_service::fn_service(async move |msg: DispatchItem<BytesCodec>| {
                let _on_drop = on_drop.clone();
                if let DispatchItem::Item(msg) = msg {
                    if msg == "test" {
                        // Long delay so the close arrives mid-call.
                        sleep(Millis(500)).await;
                    }
                    Ok::<_, ()>(Some(msg))
                } else {
                    Ok::<_, ()>(None)
                }
            }),
        );
        ntex_util::spawn(async move {
            let _ = disp.await;
        });
        sleep(Millis(25)).await;
        client.write("pl1");
        client.close().await;
        assert!(client.is_server_dropped());
        assert!(ops.get());
    }
783
    /// Three calls are blocked on one condition; once notified, their
    /// responses must still be written in request order.
    #[ntex::test]
    async fn test_ordering() {
        let (client, server) = Io::create();
        client.remote_buffer_cap(1024);
        client.write("test");

        let condition = Condition::new();
        let waiter = condition.wait();

        let (disp, _) = Dispatcher::new_debug(
            nio::Io::new(server, SharedCfg::new("DBG")),
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                let waiter = waiter.clone();
                async move {
                    waiter.await;
                    if let DispatchItem::Item(msg) = msg {
                        Ok::<_, ()>(Some(msg))
                    } else if matches!(msg, DispatchItem::Stop(Reason::Io(_))) {
                        Ok(None)
                    } else {
                        panic!()
                    }
                }
            }),
        );
        ntex_util::spawn(async move {
            let _ = disp.await;
        });
        sleep(Millis(50)).await;

        client.write("test");
        sleep(Millis(50)).await;
        client.write("test");
        sleep(Millis(50)).await;
        condition.notify();

        // All three echoes arrive back-to-back, in order.
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"testtesttest"));

        client.close().await;
        assert!(client.is_server_dropped());
    }
827
    /// On disconnect, the `Stop` notification must be processed before the
    /// canceled publish futures are dropped — both with and without an
    /// in-order response parked on the dispatcher ("1" triggers a delay).
    #[ntex::test]
    async fn test_disconnect_ordering() {
        #[derive(Debug, Copy, Clone, PartialEq, Eq)]
        enum Info {
            Publish,
            PublishDrop,
            Disconnect,
        }

        // Records drop order of in-flight publish calls.
        struct OnDrop(Rc<RefCell<Vec<Info>>>);
        impl Drop for OnDrop {
            fn drop(&mut self) {
                self.0.borrow_mut().push(Info::PublishDrop);
            }
        }

        let condition = Condition::new();
        let waiter = condition.wait();
        let ops = Rc::new(RefCell::new(Vec::new()));
        let ops2 = ops.clone();

        let run_server = async || -> Io {
            let (client, server) = Io::create();
            client.remote_buffer_cap(1024);

            let (disp, _) = Dispatcher::new_debug(
                nio::Io::new(server, SharedCfg::new("DBG")),
                BytesCodec,
                ntex_service::fn_service(async move |msg: DispatchItem<BytesCodec>| {
                    if let DispatchItem::Item(msg) = msg {
                        if msg == b"1" {
                            // Delay only; parks a response on the dispatcher.
                            sleep(Millis(75)).await;
                        } else {
                            ops2.borrow_mut().push(Info::Publish);
                            let on_drop = OnDrop(ops2.clone());
                            // Never notified: canceled on disconnect.
                            waiter.clone().await;
                            drop(on_drop);
                        }
                        Ok::<_, ()>(Some(msg))
                    } else if matches!(msg, DispatchItem::Stop(Reason::Io(_))) {
                        sleep(Millis(25)).await;
                        ops2.borrow_mut().push(Info::Disconnect);
                        Ok(None)
                    } else {
                        panic!()
                    }
                }),
            );
            ntex_util::spawn(async move {
                let _ = disp.await;
            });
            sleep(Millis(50)).await;

            client
        };
        let client = run_server.clone()().await;

        client.write("test");
        sleep(Millis(50)).await;
        client.write("test");
        sleep(Millis(50)).await;
        client.close().await;
        assert!(client.is_server_dropped());
        sleep(Millis(150)).await;

        // Disconnect handling completes before the publish futures drop.
        assert_eq!(
            &[
                Info::Publish,
                Info::Publish,
                Info::Disconnect,
                Info::PublishDrop,
                Info::PublishDrop
            ][..],
            &*ops.borrow()
        );

        // Same scenario, but with a parked in-order response ("1").
        ops.borrow_mut().clear();
        let client = run_server().await;

        client.write("1");
        sleep(Millis(50)).await;

        client.write("test");
        sleep(Millis(50)).await;
        client.write("test");
        sleep(Millis(50)).await;
        client.close().await;
        assert!(client.is_server_dropped());
        sleep(Millis(150)).await;

        assert_eq!(
            &[
                Info::Publish,
                Info::Publish,
                Info::Disconnect,
                Info::PublishDrop,
                Info::PublishDrop
            ][..],
            &*ops.borrow()
        );
    }
932
    /// Frames may also be written out-of-band through the returned `IoRef`
    /// while the dispatcher is running; `io.close()` stops the server.
    #[ntex::test]
    async fn test_sink() {
        let (client, server) = Io::create();
        client.remote_buffer_cap(1024);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, io) = Dispatcher::new_debug(
            nio::Io::new(server, SharedCfg::new("DBG")),
            BytesCodec,
            ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
                if let DispatchItem::Item(msg) = msg {
                    Ok::<_, ()>(Some(msg))
                } else if let DispatchItem::Stop(Reason::Io(_)) = msg {
                    Ok(None)
                } else {
                    panic!()
                }
            }),
        );
        ntex_util::spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        // Direct sink write, bypassing the service.
        assert!(io.encode(Bytes::from_static(b"test"), &BytesCodec).is_ok());
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        io.close();
        sleep(Millis(150)).await;
        assert!(client.is_server_dropped());
    }
967
    /// A service call error must flush pending writes and then close the
    /// connection.
    #[ntex::test]
    async fn test_err_in_service() {
        let (client, server) = Io::create();
        // Zero capacity forces the pre-written frame to stay buffered.
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        let (disp, io) = Dispatcher::new_debug(
            nio::Io::new(server, SharedCfg::new("DBG")),
            BytesCodec,
            ntex_service::fn_service(|_: DispatchItem<BytesCodec>| async move {
                Err::<Option<Bytes>, _>(())
            }),
        );
        ntex_util::spawn(async move {
            let _ = disp.await;
        });

        io.encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec).unwrap();

        // Buffered data is still delivered before the error closes the io.
        client.remote_buffer_cap(1024);
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        sleep(Millis(50)).await;
        assert!(client.is_closed());

        client.close().await;
        assert!(client.is_server_dropped());
    }
1000
    /// A `ready()` error must stop the dispatcher; the readiness check is
    /// latched (READY_ERR) so the failing service is polled exactly once.
    #[ntex::test]
    async fn test_err_in_service_ready() {
        // Counts how many times ready() is invoked.
        struct Srv(Rc<Cell<usize>>);

        impl Service<DispatchItem<BytesCodec>> for Srv {
            type Response = Option<Response<BytesCodec>>;
            type Error = ();

            async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), ()> {
                self.0.set(self.0.get() + 1);
                Err(())
            }

            async fn call(
                &self,
                _: DispatchItem<BytesCodec>,
                _: ServiceCtx<'_, Self>,
            ) -> Result<Option<Response<BytesCodec>>, ()> {
                Ok(None)
            }
        }

        let (client, server) = Io::create();
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        let counter = Rc::new(Cell::new(0));

        let (disp, io) = Dispatcher::new_debug(
            nio::Io::new(server, SharedCfg::new("DBG")),
            BytesCodec,
            Srv(counter.clone()),
        );
        io.encode(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), &BytesCodec).unwrap();
        ntex_util::spawn(async move {
            let _ = disp.await;
        });

        // Pending writes are still flushed before the close.
        client.remote_buffer_cap(1024);
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

        sleep(Millis(50)).await;
        assert!(client.is_closed());

        client.close().await;
        assert!(client.is_server_dropped());

        // Readiness must have been polled exactly once.
        assert_eq!(counter.get(), 1);
    }
1055
    /// A 64 KiB response against a 32 KiB write buffer must raise
    /// WBackPressureEnabled, and draining the remote buffer must raise
    /// WBackPressureDisabled — recorded as [item=0, enabled=1, disabled=2].
    #[ntex::test]
    async fn test_write_backpressure() {
        let (client, server) = Io::create();
        // Remote cannot accept any bytes initially.
        client.remote_buffer_cap(0);
        client.write("GET /test HTTP/1\r\n\r\n");

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let config = SharedCfg::new("DBG").add(
            IoConfig::new().set_read_buf(8 * 1024, 1024, 16).set_write_buf(32 * 1024, 1024, 16),
        );

        let (disp, io) = Dispatcher::new_debug(
            nio::Io::new(server, config),
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(_) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            // 64 KiB random payload: twice the write buffer.
                            let bytes = rand::thread_rng()
                                .sample_iter(&rand::distributions::Alphanumeric)
                                .take(65_536)
                                .map(char::from)
                                .collect::<String>();
                            return Ok::<_, ()>(Some(Bytes::from(bytes)));
                        }
                        DispatchItem::Control(Control::WBackPressureEnabled) => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        DispatchItem::Control(Control::WBackPressureDisabled) => {
                            data.lock().unwrap().borrow_mut().push(2);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );

        ntex_util::spawn(async move {
            let _ = disp.await;
        });

        let buf = client.read_any();
        assert_eq!(buf, Bytes::from_static(b""));
        client.write("GET /test HTTP/1\r\n\r\n");
        sleep(Millis(25)).await;

        // The whole response is stuck in the local write buffer.
        assert_eq!(client.remote_buffer(|buf| buf.len()), 0);
        assert_eq!(io.with_write_buf(|buf| buf.len()).unwrap(), 65536);

        // Drain in two steps; backpressure is released after the second.
        client.remote_buffer_cap(10240);
        sleep(Millis(50)).await;
        assert_eq!(io.with_write_buf(|buf| buf.len()).unwrap(), 55296);

        client.remote_buffer_cap(45056);
        sleep(Millis(50)).await;
        assert_eq!(io.with_write_buf(|buf| buf.len()).unwrap(), 10240);

        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]);
    }
1125
    /// Closing the io from inside a service call while another call is
    /// still sleeping must still wake the dispatcher and complete its
    /// shutdown (signalled via the oneshot channel).
    #[ntex::test]
    async fn test_shutdown_dispatcher_waker() {
        let (client, server) = Io::create();
        let server = nio::Io::new(server, SharedCfg::new("DBG"));
        client.remote_buffer_cap(1024);

        let flag = Rc::new(Cell::new(true));
        let flag2 = flag.clone();
        let server_ref = server.get_ref();

        let (disp, _io) = Dispatcher::new_debug(
            server,
            BytesCodec,
            ntex_service::fn_service(async move |item: DispatchItem<BytesCodec>| {
                let first = flag2.get();
                flag2.set(false);
                if let DispatchItem::Item(b) = item {
                    // Every item after the first blocks for a long time.
                    if !first {
                        sleep(Millis(500)).await;
                    }
                    Ok(Some(b))
                } else {
                    // Stop/control item: close the io from the service.
                    server_ref.close();
                    Ok::<_, ()>(None)
                }
            }),
        );
        let (tx, rx) = ntex_util::channel::oneshot::channel();
        ntex_util::spawn(async move {
            let _ = disp.await;
            let _ = tx.send(());
        });

        // First frame completes immediately.
        client.write(b"msg1");
        sleep(Millis(25)).await;

        // Second frame is in-flight (sleeping) when the client closes.
        client.write(b"msg2");

        sleep(Millis(150)).await;
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"msg1"));

        client.close().await;
        // Dispatcher future must resolve despite the parked call.
        let _ = rx.recv().await;
    }
1175
    /// With a 2s keep-alive and traffic every 750ms, the keep-alive timer
    /// must keep being reset: no timeout item, connection stays open.
    #[ntex::test]
    async fn test_keepalive() {
        let (client, server) = Io::create();
        client.remote_buffer_cap(1024);

        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        let (disp, _) = Dispatcher::new_debug(
            nio::Io::new(server, SharedCfg::new("DBG")),
            BytesCodec,
            ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            return Ok::<_, ()>(Some(bytes));
                        }
                        DispatchItem::Stop(Reason::KeepAliveTimeout) => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        ntex_util::spawn(async move {
            let _ = disp.keepalive_timeout(Seconds(2)).await;
        });

        client.write("1");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"1"));
        sleep(Millis(750)).await;

        client.write("2");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"2"));

        sleep(Millis(750)).await;
        client.write("3");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"3"));

        sleep(Millis(750)).await;
        // Only items (0) were seen; no keep-alive timeout (1).
        assert!(!client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 0, 0]);
    }
1227
    /// Test codec producing fixed-length frames of exactly `.0` bytes.
    #[derive(Debug, Copy, Clone)]
    struct BytesLenCodec(usize);
1230
1231 impl Encoder for BytesLenCodec {
1232 type Item = Bytes;
1233 type Error = io::Error;
1234
1235 #[inline]
1236 fn encode(&self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
1237 dst.extend_from_slice(&item[..]);
1238 Ok(())
1239 }
1240 }
1241
1242 impl Decoder for BytesLenCodec {
1243 type Item = Bytes;
1244 type Error = io::Error;
1245
1246 fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
1247 if src.len() >= self.0 {
1248 Ok(Some(src.split_to(self.0)))
1249 } else {
1250 Ok(None)
1251 }
1252 }
1253 }
1254
    /// A frame delivered slowly — but within the configured frame-read rate —
    /// must not be misreported as a keepalive error afterwards: keepalive is
    /// disabled (`Seconds(0)`), the frame completes, and after a further 2 s
    /// wait the event log must contain only the item event.
    #[ntex::test]
    async fn test_no_keepalive_err_after_frame_timeout() {
        let (client, server) = Io::create();
        client.remote_buffer_cap(1024);

        // Event log: 0 = item handled, 1 = keepalive timeout (must not appear).
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        // Keepalive off; frame-read-rate args presumably (timeout, max timeout,
        // min bytes per period) — TODO confirm against IoConfig docs.
        let config = SharedCfg::new("BDG").add(
            IoConfig::new().set_keepalive_timeout(Seconds(0)).set_frame_read_rate(
                Seconds(1),
                Seconds(2),
                2,
            ),
        );

        let (disp, _) = Dispatcher::new_debug(
            nio::Io::new(server, config),
            BytesLenCodec(2),
            ntex_service::fn_service(move |msg: DispatchItem<BytesLenCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            // Echo the decoded frame back to the client.
                            return Ok::<_, ()>(Some(bytes));
                        }
                        DispatchItem::Stop(Reason::KeepAliveTimeout) => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        ntex_util::spawn(async move {
            let _ = disp.await;
        });

        // Deliver the 2-byte frame in two halves, 250 ms apart — slow, but
        // within the allowed frame read rate.
        client.write("1");
        sleep(Millis(250)).await;
        client.write("2");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"12"));
        // Idle long past the (disabled) keepalive window.
        sleep(Millis(2000)).await;

        // Only the item event was recorded — no spurious keepalive timeout.
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0]);
    }
1305
    /// The frame-read-rate guard fires when an 8-byte frame trickles in too
    /// slowly: partial writes that keep up with the rate are tolerated, but
    /// once progress falls below it the dispatcher reports
    /// `Stop(Reason::ReadTimeout)` and shuts the connection down.
    #[ntex::test]
    async fn test_read_timeout() {
        let (client, server) = Io::create();
        client.remote_buffer_cap(1024);

        // Event log: 0 = item handled, 1 = read timeout.
        let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
        let data2 = data.clone();

        // Keepalive off; frame-read-rate args presumably (timeout, max timeout,
        // min bytes per period) — TODO confirm against IoConfig docs.
        let config = SharedCfg::new("DBG").add(
            IoConfig::new().set_keepalive_timeout(Seconds::ZERO).set_frame_read_rate(
                Seconds(1),
                Seconds(2),
                2,
            ),
        );

        let (disp, state) = Dispatcher::new_debug(
            nio::Io::new(server, config),
            BytesLenCodec(8),
            ntex_service::fn_service(move |msg: DispatchItem<BytesLenCodec>| {
                let data = data2.clone();
                async move {
                    match msg {
                        DispatchItem::Item(bytes) => {
                            data.lock().unwrap().borrow_mut().push(0);
                            // Echo the decoded frame back to the client.
                            return Ok::<_, ()>(Some(bytes));
                        }
                        DispatchItem::Stop(Reason::ReadTimeout) => {
                            data.lock().unwrap().borrow_mut().push(1);
                        }
                        _ => (),
                    }
                    Ok(None)
                }
            }),
        );
        ntex_util::spawn(async move {
            let _ = disp.await;
        });

        // A complete frame sent at once decodes and echoes normally.
        client.write("12345678");
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"12345678"));

        // Now trickle a second frame: while enough bytes arrive per interval
        // the io must NOT be stopping...
        client.write("1");
        sleep(Millis(1000)).await;
        assert!(!state.flags().contains(nio::Flags::IO_STOPPING));
        client.write("23");
        sleep(Millis(1000)).await;
        assert!(!state.flags().contains(nio::Flags::IO_STOPPING));
        // ...but a single byte followed by 2 s of silence violates the rate.
        client.write("4");
        sleep(Millis(2000)).await;

        // Read timeout fired: io is stopping, client sees the close, and the
        // log holds one item event followed by one read-timeout event.
        assert!(state.flags().contains(nio::Flags::IO_STOPPING));
        assert!(client.is_closed());
        assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
    }
1364
    /// When the peer closes while a service call is still in flight, the
    /// dispatcher must cancel (drop) that pending future rather than wait for
    /// it: the handler parks on an effectively-infinite sleep, and the test
    /// observes the drop of its captured `OnDrop` guard via an atomic flag.
    #[ntex::test]
    async fn cancel_on_stop() {
        // Guard whose Drop sets the shared flag — proves the future was dropped.
        #[derive(Clone)]
        struct OnDrop(Arc<AtomicBool>);
        impl Drop for OnDrop {
            fn drop(&mut self) {
                self.0.store(true, Ordering::Relaxed);
            }
        }

        let (client, server) = Io::create();
        client.remote_buffer_cap(1024);

        let data = Arc::new(AtomicBool::new(false));
        let data2 = OnDrop(data.clone());

        let config = SharedCfg::new("DBG").add(
            IoConfig::new().set_keepalive_timeout(Seconds(0)).set_frame_read_rate(
                Seconds(1),
                Seconds(2),
                2,
            ),
        );

        let (disp, _) = Dispatcher::new_debug(
            nio::Io::new(server, config),
            BytesLenCodec(2),
            ntex_service::fn_service(move |msg: DispatchItem<BytesLenCodec>| {
                // Each call clones the guard; dropping the future before the
                // sleep finishes drops this clone and flips the flag.
                let data = data2.clone();
                async move {
                    if let DispatchItem::Item(bytes) = msg {
                        // Effectively never completes within the test
                        // (999_999 ms). NOTE(review): odd digit grouping
                        // `99_9999` — consider `999_999` for readability.
                        sleep(Millis(99_9999)).await;
                        drop(data);
                        return Ok::<_, ()>(Some(bytes));
                    }
                    Ok(None)
                }
            }),
        );
        ntex_util::spawn(async move {
            let _ = disp.await;
        });

        // One incomplete frame (codec wants 2 bytes) then close immediately.
        client.write("1");
        client.close().await;
        sleep(Millis(250)).await;

        // The in-flight service future was cancelled, dropping the guard.
        assert!(&data.load(Ordering::Relaxed));
    }
1415}