use std::cell::{Cell, UnsafeCell};
use std::future::{Future, poll_fn};
use std::task::{Context, Poll};
use std::{fmt, hash, io, marker, mem, ops, pin::Pin, ptr, rc::Rc};

use ntex_codec::{Decoder, Encoder};
use ntex_service::cfg::SharedCfg;
use ntex_util::{future::Either, task::LocalWaker};

use crate::buf::Stack;
use crate::cfg::{BufConfig, IoConfig};
use crate::filter::{Base, Filter, Layer, NullFilter};
use crate::flags::Flags;
use crate::seal::{IoBoxed, Sealed};
use crate::timer::TimerHandle;
use crate::{Decoded, FilterLayer, Handle, IoContext, IoStatusUpdate, IoStream, RecvError};

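/// Interface object to the underlying io stream.
///
/// The type parameter `F` describes the filter chain applied on top of
/// the raw stream (`Base` by default).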
pub struct Io<F = Base>(UnsafeCell<IoRef>, marker::PhantomData<F>);

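/// Cheaply cloneable reference to the shared io state.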
#[derive(Clone)]
pub struct IoRef(pub(super) Rc<IoState>);

pub(crate) struct IoState {
    filter: FilterPtr,
    pub(super) cfg: Cell<&'static IoConfig>,
    pub(super) flags: Cell<Flags>,
    pub(super) error: Cell<Option<io::Error>>,
    pub(super) read_task: LocalWaker,
    pub(super) write_task: LocalWaker,
    pub(super) dispatch_task: LocalWaker,
    pub(super) buffer: Stack,
    pub(super) handle: Cell<Option<Box<dyn Handle>>>,
    pub(super) timeout: Cell<TimerHandle>,
    #[allow(clippy::box_collection)]
    pub(super) on_disconnect: Cell<Option<Box<Vec<LocalWaker>>>>,
}

impl IoState {
    pub(super) fn filter(&self) -> &dyn Filter {
        self.filter.filter.get()
    }

    pub(super) fn insert_flags(&self, f: Flags) {
        let mut flags = self.flags.get();
        flags.insert(f);
        self.flags.set(flags);
    }

    pub(super) fn remove_flags(&self, f: Flags) -> bool {
        let mut flags = self.flags.get();
        if flags.intersects(f) {
            flags.remove(f);
            self.flags.set(flags);
            true
        } else {
            false
        }
    }

    pub(super) fn notify_timeout(&self) {
        let mut flags = self.flags.get();
        if !flags.contains(Flags::DSP_TIMEOUT) {
            flags.insert(Flags::DSP_TIMEOUT);
            self.flags.set(flags);
            self.dispatch_task.wake();
            log::trace!("{}: Timer, notify dispatcher", self.cfg.get().tag());
        }
    }

    pub(super) fn notify_disconnect(&self) {
        if let Some(on_disconnect) = self.on_disconnect.take() {
            for item in on_disconnect.into_iter() {
                item.wake();
            }
        }
    }

    pub(super) fn error(&self) -> Option<io::Error> {
        // `io::Error` is not `Clone`, so a re-created copy is stored back;
        // later callers observe the same error kind and message
        if let Some(err) = self.error.take() {
            self.error
                .set(Some(io::Error::new(err.kind(), format!("{err}"))));
            Some(err)
        } else {
            None
        }
    }

    pub(super) fn error_or_disconnected(&self) -> io::Error {
        self.error()
            .unwrap_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "Disconnected"))
    }

    pub(super) fn io_stopped(&self, err: Option<io::Error>) {
        if !self.flags.get().is_stopped() {
            log::trace!(
                "{}: {} Io error {:?} flags: {:?}",
                self.cfg.get().tag(),
                self as *const _ as usize,
                err,
                self.flags.get()
            );

            if err.is_some() {
                self.error.set(err);
            }
            self.read_task.wake();
            self.write_task.wake();
            self.notify_disconnect();
            self.handle.take();
            self.insert_flags(
                Flags::IO_STOPPED
                    | Flags::IO_STOPPING
                    | Flags::IO_STOPPING_FILTERS
                    | Flags::BUF_R_READY,
            );
            if !self.dispatch_task.wake_checked() {
                log::trace!(
                    "{}: {} Dispatcher is not registered, flags: {:?}",
                    self.cfg.get().tag(),
                    self as *const _ as usize,
                    self.flags.get()
                );
            }
        }
    }

    pub(super) fn init_shutdown(&self) {
        if !self
            .flags
            .get()
            .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
        {
            log::trace!(
                "{}: Initiate io shutdown {:?}",
                self.cfg.get().tag(),
                self.flags.get()
            );
            self.insert_flags(Flags::IO_STOPPING_FILTERS);
            self.read_task.wake();
        }
    }

    #[inline]
    pub(super) fn read_buf(&self) -> &BufConfig {
        self.cfg.get().read_buf()
    }

    #[inline]
    pub(super) fn write_buf(&self) -> &BufConfig {
        self.cfg.get().write_buf()
    }
}

impl Eq for IoState {}

impl PartialEq for IoState {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        ptr::eq(self, other)
    }
}

impl hash::Hash for IoState {
    #[inline]
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        (self as *const _ as usize).hash(state);
    }
}

impl fmt::Debug for IoState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `io::Error` is not `Clone`: take it out for formatting and put
        // it back afterwards
        let err = self.error.take();
        let res = f
            .debug_struct("IoState")
            .field("flags", &self.flags)
            .field("filter", &self.filter.is_set())
            .field("timeout", &self.timeout)
            .field("error", &err)
            .field("buffer", &self.buffer)
            .field("cfg", &self.cfg)
            .finish();
        self.error.set(err);
        res
    }
}

impl Io {
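    /// Create an `Io` instance for the given `IoStream` with the provided
    /// configuration.
    ///
    /// A minimal usage sketch, assuming the in-memory `IoTest` stream from
    /// this crate's testing utilities (any `IoStream` works the same way):
    ///
    /// ```ignore
    /// use ntex_io::{Io, testing::IoTest};
    /// use ntex_service::cfg::SharedCfg;
    ///
    /// let (client, server) = IoTest::create();
    /// let io = Io::new(server, SharedCfg::default());
    /// ```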
    #[inline]
    pub fn new<I: IoStream, T: Into<SharedCfg>>(io: I, cfg: T) -> Self {
        let inner = Rc::new(IoState {
            cfg: Cell::new(cfg.into().get::<IoConfig>().into_static()),
            filter: FilterPtr::null(),
            flags: Cell::new(Flags::WR_PAUSED),
            error: Cell::new(None),
            dispatch_task: LocalWaker::new(),
            read_task: LocalWaker::new(),
            write_task: LocalWaker::new(),
            buffer: Stack::new(),
            handle: Cell::new(None),
            timeout: Cell::new(TimerHandle::default()),
            on_disconnect: Cell::new(None),
        });
        inner.filter.update(Base::new(IoRef(inner.clone())));

        let io_ref = IoRef(inner);

        let hnd = io.start(IoContext::new(&io_ref));
        io_ref.0.handle.set(hnd);

        Io(UnsafeCell::new(io_ref), marker::PhantomData)
    }
}

impl<I: IoStream> From<I> for Io {
    #[inline]
    fn from(io: I) -> Io {
        Io::new(io, SharedCfg::default())
    }
}

impl<F> Io<F> {
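    /// Take the io stream out of this handle.
    ///
    /// `self` is left with a fresh state that is already flagged as
    /// stopped, so further operations on it fail fast.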
    #[inline]
    pub fn take(&self) -> Self {
        Self(UnsafeCell::new(self.take_io_ref()), marker::PhantomData)
    }

    fn take_io_ref(&self) -> IoRef {
        let inner = Rc::new(IoState {
            cfg: Cell::new(SharedCfg::default().get::<IoConfig>().into_static()),
            filter: FilterPtr::null(),
            flags: Cell::new(
                Flags::DSP_STOP
                    | Flags::IO_STOPPED
                    | Flags::IO_STOPPING
                    | Flags::IO_STOPPING_FILTERS,
            ),
            error: Cell::new(None),
            dispatch_task: LocalWaker::new(),
            read_task: LocalWaker::new(),
            write_task: LocalWaker::new(),
            buffer: Stack::new(),
            handle: Cell::new(None),
            timeout: Cell::new(TimerHandle::default()),
            on_disconnect: Cell::new(None),
        });
        unsafe { mem::replace(&mut *self.0.get(), IoRef(inner)) }
    }

    #[inline]
    #[doc(hidden)]
    pub fn flags(&self) -> Flags {
        self.st().flags.get()
    }

    #[inline]
    pub fn get_ref(&self) -> IoRef {
        self.io_ref().clone()
    }

    fn st(&self) -> &IoState {
        unsafe { &(*self.0.get()).0 }
    }

    fn io_ref(&self) -> &IoRef {
        unsafe { &*self.0.get() }
    }

    #[inline]
    pub fn set_config<T: Into<SharedCfg>>(&self, cfg: T) {
        self.st()
            .cfg
            .set(cfg.into().get::<IoConfig>().into_static());
    }
}

impl<F: FilterLayer, T: Filter> Io<Layer<F, T>> {
    #[inline]
    pub fn filter(&self) -> &F {
        &self.st().filter.filter::<Layer<F, T>>().0
    }
}

impl<F: Filter> Io<F> {
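    /// Erase the concrete filter type, sealing the current filter chain
    /// behind a boxed trait object.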
    #[inline]
    pub fn seal(self) -> Io<Sealed> {
        let state = self.take_io_ref();
        state.0.filter.seal::<F>();

        Io(UnsafeCell::new(state), marker::PhantomData)
    }

    #[inline]
    pub fn boxed(self) -> IoBoxed {
        self.seal().into()
    }

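    /// Create a new `Io` with the given filter layered on top of the
    /// current filter chain.
    ///
    /// A layering sketch, assuming a user-defined `MyFilter` type that
    /// implements `FilterLayer`:
    ///
    /// ```ignore
    /// let io: Io<Layer<MyFilter, Base>> = Io::from(stream).add_filter(MyFilter::default());
    /// ```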
    #[inline]
    pub fn add_filter<U>(self, nf: U) -> Io<Layer<U, F>>
    where
        U: FilterLayer,
    {
        let state = self.take_io_ref();

        // Safety: `add_layer()` only appends to the internal buffer
        // storage; no api keeps references into that storage across calls
        unsafe { &mut *(Rc::as_ptr(&state.0) as *mut IoState) }
            .buffer
            .add_layer();

        state.0.filter.add_filter::<F, U>(nf);

        Io(UnsafeCell::new(state), marker::PhantomData)
    }

    pub fn map_filter<U, R>(self, f: U) -> Io<R>
    where
        U: FnOnce(F) -> R,
        R: Filter,
    {
        let state = self.take_io_ref();
        state.0.filter.map_filter::<F, U, R>(f);

        Io(UnsafeCell::new(state), marker::PhantomData)
    }
}

impl<F> Io<F> {
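    /// Receive the next item, decoded with the given codec.
    ///
    /// Resolves to `Ok(None)` when the peer disconnects without an error;
    /// keep-alive timeouts and dispatcher stops surface as `io::Error`s,
    /// and write backpressure is handled internally by flushing and
    /// retrying.
    ///
    /// A minimal sketch, assuming `BytesCodec` from `ntex-codec`:
    ///
    /// ```ignore
    /// use ntex_codec::BytesCodec;
    ///
    /// while let Some(frame) = io.recv(&BytesCodec).await.unwrap() {
    ///     println!("got {} bytes", frame.len());
    /// }
    /// ```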
    #[inline]
    pub async fn recv<U>(
        &self,
        codec: &U,
    ) -> Result<Option<U::Item>, Either<U::Error, io::Error>>
    where
        U: Decoder,
    {
        loop {
            return match poll_fn(|cx| self.poll_recv(codec, cx)).await {
                Ok(item) => Ok(Some(item)),
                Err(RecvError::KeepAlive) => Err(Either::Right(io::Error::new(
                    io::ErrorKind::TimedOut,
                    "Timeout",
                ))),
                Err(RecvError::Stop) => Err(Either::Right(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "Dispatcher stopped",
                ))),
                Err(RecvError::WriteBackpressure) => {
                    poll_fn(|cx| self.poll_flush(cx, false))
                        .await
                        .map_err(Either::Right)?;
                    continue;
                }
                Err(RecvError::Decoder(err)) => Err(Either::Left(err)),
                Err(RecvError::PeerGone(Some(err))) => Err(Either::Right(err)),
                Err(RecvError::PeerGone(None)) => Ok(None),
            };
        }
    }

    #[inline]
    pub async fn read_ready(&self) -> io::Result<Option<()>> {
        poll_fn(|cx| self.poll_read_ready(cx)).await
    }

    #[inline]
    pub async fn read_notify(&self) -> io::Result<Option<()>> {
        poll_fn(|cx| self.poll_read_notify(cx)).await
    }

    #[inline]
    pub fn pause(&self) {
        let st = self.st();
        if !st.flags.get().contains(Flags::RD_PAUSED) {
            st.read_task.wake();
            st.insert_flags(Flags::RD_PAUSED);
        }
    }

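    /// Encode an item with the given codec, then wait until the write
    /// buffer is fully flushed to the underlying stream.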
    #[inline]
    pub async fn send<U>(
        &self,
        item: U::Item,
        codec: &U,
    ) -> Result<(), Either<U::Error, io::Error>>
    where
        U: Encoder,
    {
        self.encode(item, codec).map_err(Either::Left)?;

        poll_fn(|cx| self.poll_flush(cx, true))
            .await
            .map_err(Either::Right)?;

        Ok(())
    }

    #[inline]
    pub async fn flush(&self, full: bool) -> io::Result<()> {
        poll_fn(|cx| self.poll_flush(cx, full)).await
    }

    #[inline]
    pub async fn shutdown(&self) -> io::Result<()> {
        poll_fn(|cx| self.poll_shutdown(cx)).await
    }

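    /// Poll read readiness.
    ///
    /// Registers the dispatcher waker; resolves with `Ok(Some(()))` once
    /// new data is available in the read buffer and with `Err(..)` once
    /// the stream has stopped.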
    #[inline]
    pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
        let st = self.st();
        let mut flags = st.flags.get();

        if flags.is_stopped() {
            Poll::Ready(Err(st.error_or_disconnected()))
        } else {
            st.dispatch_task.register(cx.waker());

            let ready = flags.is_read_buf_ready();
            if flags.cannot_read() {
                flags.cleanup_read_flags();
                st.read_task.wake();
                st.flags.set(flags);
                if ready {
                    Poll::Ready(Ok(Some(())))
                } else {
                    Poll::Pending
                }
            } else if ready {
                flags.remove(Flags::BUF_R_READY);
                st.flags.set(flags);
                Poll::Ready(Ok(Some(())))
            } else {
                Poll::Pending
            }
        }
    }

    #[inline]
    pub fn poll_read_notify(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
        let ready = self.poll_read_ready(cx);

        if ready.is_pending() {
            let st = self.st();
            if st.remove_flags(Flags::RD_NOTIFY) {
                Poll::Ready(Ok(Some(())))
            } else {
                st.insert_flags(Flags::RD_NOTIFY);
                Poll::Pending
            }
        } else {
            ready
        }
    }

    #[inline]
    pub fn poll_recv<U>(
        &self,
        codec: &U,
        cx: &mut Context<'_>,
    ) -> Poll<Result<U::Item, RecvError<U>>>
    where
        U: Decoder,
    {
        let decoded = self.poll_recv_decode(codec, cx)?;

        if let Some(item) = decoded.item {
            Poll::Ready(Ok(item))
        } else {
            Poll::Pending
        }
    }

    #[doc(hidden)]
    #[inline]
    pub fn poll_recv_decode<U>(
        &self,
        codec: &U,
        cx: &mut Context<'_>,
    ) -> Result<Decoded<U::Item>, RecvError<U>>
    where
        U: Decoder,
    {
        let decoded = self.decode_item(codec).map_err(RecvError::Decoder)?;

        if decoded.item.is_some() {
            Ok(decoded)
        } else {
            let st = self.st();
            let flags = st.flags.get();
            if flags.is_stopped() {
                Err(RecvError::PeerGone(st.error()))
            } else if flags.contains(Flags::DSP_STOP) {
                st.remove_flags(Flags::DSP_STOP);
                Err(RecvError::Stop)
            } else if flags.contains(Flags::DSP_TIMEOUT) {
                st.remove_flags(Flags::DSP_TIMEOUT);
                Err(RecvError::KeepAlive)
            } else if flags.contains(Flags::BUF_W_BACKPRESSURE) {
                Err(RecvError::WriteBackpressure)
            } else {
                match self.poll_read_ready(cx) {
                    Poll::Pending | Poll::Ready(Ok(Some(()))) => {
                        if log::log_enabled!(log::Level::Trace) && decoded.remains != 0 {
                            log::trace!(
                                "{}: Not enough data to decode next frame",
                                self.tag()
                            );
                        }
                        Ok(decoded)
                    }
                    Poll::Ready(Err(e)) => Err(RecvError::PeerGone(Some(e))),
                    Poll::Ready(Ok(None)) => Err(RecvError::PeerGone(None)),
                }
            }
        }
    }

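    /// Poll until the write buffer is flushed.
    ///
    /// With `full == true` this waits for the buffer to drain completely;
    /// otherwise only until it drops below the configured threshold
    /// (`write_buf().half`).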
    #[inline]
    pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll<io::Result<()>> {
        let st = self.st();
        let flags = self.flags();

        let len = st.buffer.write_destination_size();
        if len > 0 {
            if full {
                st.insert_flags(Flags::BUF_W_MUST_FLUSH);
                st.dispatch_task.register(cx.waker());
                return if flags.is_stopped() {
                    Poll::Ready(Err(st.error_or_disconnected()))
                } else {
                    Poll::Pending
                };
            } else if len >= st.write_buf().half {
                st.insert_flags(Flags::BUF_W_BACKPRESSURE);
                st.dispatch_task.register(cx.waker());
                return if flags.is_stopped() {
                    Poll::Ready(Err(st.error_or_disconnected()))
                } else {
                    Poll::Pending
                };
            }
        }
        if flags.is_stopped() {
            Poll::Ready(Err(st.error_or_disconnected()))
        } else {
            st.remove_flags(Flags::BUF_W_MUST_FLUSH | Flags::BUF_W_BACKPRESSURE);
            Poll::Ready(Ok(()))
        }
    }

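    /// Poll graceful shutdown: initiates filter shutdown if it has not
    /// started yet and resolves once the stream is fully stopped.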
    #[inline]
    pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let st = self.st();
        let flags = st.flags.get();

        if flags.is_stopped() {
            if let Some(err) = st.error() {
                Poll::Ready(Err(err))
            } else {
                Poll::Ready(Ok(()))
            }
        } else {
            if !flags.contains(Flags::IO_STOPPING_FILTERS) {
                st.init_shutdown();
            }

            st.read_task.wake();
            st.write_task.wake();
            st.dispatch_task.register(cx.waker());
            Poll::Pending
        }
    }

    #[inline]
    pub fn poll_read_pause(&self, cx: &mut Context<'_>) -> Poll<IoStatusUpdate> {
        self.pause();
        let result = self.poll_status_update(cx);
        if !result.is_pending() {
            self.st().dispatch_task.register(cx.waker());
        }
        result
    }

    #[inline]
    pub fn poll_status_update(&self, cx: &mut Context<'_>) -> Poll<IoStatusUpdate> {
        let st = self.st();
        let flags = st.flags.get();
        if flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING) {
            Poll::Ready(IoStatusUpdate::PeerGone(st.error()))
        } else if flags.contains(Flags::DSP_STOP) {
            st.remove_flags(Flags::DSP_STOP);
            Poll::Ready(IoStatusUpdate::Stop)
        } else if flags.contains(Flags::DSP_TIMEOUT) {
            st.remove_flags(Flags::DSP_TIMEOUT);
            Poll::Ready(IoStatusUpdate::KeepAlive)
        } else if flags.contains(Flags::BUF_W_BACKPRESSURE) {
            Poll::Ready(IoStatusUpdate::WriteBackpressure)
        } else {
            st.dispatch_task.register(cx.waker());
            Poll::Pending
        }
    }

    #[inline]
    pub fn poll_dispatch(&self, cx: &mut Context<'_>) {
        self.st().dispatch_task.register(cx.waker());
    }
}

impl<F> AsRef<IoRef> for Io<F> {
    #[inline]
    fn as_ref(&self) -> &IoRef {
        self.io_ref()
    }
}

impl<F> Eq for Io<F> {}

impl<F> PartialEq for Io<F> {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.io_ref().eq(other.io_ref())
    }
}

impl<F> hash::Hash for Io<F> {
    #[inline]
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        self.io_ref().hash(state);
    }
}

impl<F> fmt::Debug for Io<F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Io").field("state", self.st()).finish()
    }
}

impl<F> ops::Deref for Io<F> {
    type Target = IoRef;

    #[inline]
    fn deref(&self) -> &Self::Target {
        self.io_ref()
    }
}

impl<F> Drop for Io<F> {
    fn drop(&mut self) {
        let st = self.st();
        self.stop_timer();

        if st.filter.is_set() {
            // the filter is stored behind a raw pointer and must be
            // dropped explicitly
            if !st.flags.get().is_stopped() {
                log::trace!(
                    "{}: Io is dropped, force stopping io streams {:?}",
                    st.cfg.get().tag(),
                    st.flags.get()
                );
            }

            self.force_close();
            st.filter.drop_filter::<F>();
        }
    }
}

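// `FilterPtr` packs either a `Box<F>` (KIND_PTR) or an inline `Sealed`
// value (KIND_SEALED) into a `size_of::<Sealed>()`-byte buffer. The two
// low bits of the kind byte (first byte on little-endian, last on
// big-endian) hold the discriminant, which is why pointers read back
// from `data` are masked with `KIND_UNMASK_USIZE`.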
const KIND_SEALED: u8 = 0b01;
const KIND_PTR: u8 = 0b10;
const KIND_MASK: u8 = 0b11;
const KIND_UNMASK: u8 = !KIND_MASK;
const KIND_MASK_USIZE: usize = 0b11;
const KIND_UNMASK_USIZE: usize = !KIND_MASK_USIZE;
const SEALED_SIZE: usize = mem::size_of::<Sealed>();
const NULL: [u8; SEALED_SIZE] = [0u8; SEALED_SIZE];

#[cfg(target_endian = "little")]
const KIND_IDX: usize = 0;

#[cfg(target_endian = "big")]
const KIND_IDX: usize = SEALED_SIZE - 1;

struct FilterPtr {
    data: Cell<[u8; SEALED_SIZE]>,
    filter: Cell<&'static dyn Filter>,
}

impl FilterPtr {
    const fn null() -> Self {
        Self {
            data: Cell::new(NULL),
            filter: Cell::new(NullFilter::get()),
        }
    }

    fn update<F: Filter>(&self, filter: F) {
        if self.is_set() {
            panic!("Filter is set, must be dropped first");
        }

        let filter = Box::new(filter);
        let mut data = NULL;
        unsafe {
            let filter_ref: &'static dyn Filter = {
                let f: &dyn Filter = filter.as_ref();
                mem::transmute(f)
            };
            self.filter.set(filter_ref);

            let ptr = &mut data as *mut _ as *mut *mut F;
            ptr.write(Box::into_raw(filter));
            data[KIND_IDX] |= KIND_PTR;
            self.data.set(data);
        }
    }

    fn filter<F: Filter>(&self) -> &F {
        let data = self.data.get();
        if data[KIND_IDX] & KIND_PTR != 0 {
            let ptr = &data as *const _ as *const *mut F;
            unsafe {
                let p = (ptr.read() as *const _ as usize) & KIND_UNMASK_USIZE;
                (p as *const F as *mut F).as_ref().unwrap()
            }
        } else {
            panic!("Wrong filter item");
        }
    }

    fn take_filter<F>(&self) -> Box<F> {
        let mut data = self.data.get();
        if data[KIND_IDX] & KIND_PTR != 0 {
            data[KIND_IDX] &= KIND_UNMASK;
            let ptr = &mut data as *mut _ as *mut *mut F;
            unsafe { Box::from_raw(*ptr) }
        } else {
            panic!(
                "Wrong filter item {:?} expected: {:?}",
                data[KIND_IDX], KIND_PTR
            );
        }
    }

    fn take_sealed(&self) -> Sealed {
        let mut data = self.data.get();

        if data[KIND_IDX] & KIND_SEALED != 0 {
            data[KIND_IDX] &= KIND_UNMASK;
            let ptr = &mut data as *mut _ as *mut Sealed;
            unsafe { ptr.read() }
        } else {
            panic!(
                "Wrong filter item {:?} expected: {:?}",
                data[KIND_IDX], KIND_SEALED
            );
        }
    }

    fn is_set(&self) -> bool {
        self.data.get()[KIND_IDX] & KIND_MASK != 0
    }

    fn drop_filter<F>(&self) {
        let data = self.data.get();

        if data[KIND_IDX] & KIND_MASK != 0 {
            if data[KIND_IDX] & KIND_PTR != 0 {
                self.take_filter::<F>();
            } else if data[KIND_IDX] & KIND_SEALED != 0 {
                self.take_sealed();
            }
            self.data.set(NULL);
            self.filter.set(NullFilter::get());
        }
    }
}

impl FilterPtr {
    fn add_filter<F: Filter, T: FilterLayer>(&self, new: T) {
        let data = self.data.get();
        let filter = if data[KIND_IDX] & KIND_PTR != 0 {
            Box::new(Layer::new(new, *self.take_filter::<F>()))
        } else if data[KIND_IDX] & KIND_SEALED != 0 {
            let f = Box::new(Layer::new(new, self.take_sealed()));
            // in the sealed case `F` is `Sealed`, so in practice this
            // transmute is an identity cast
            unsafe { mem::transmute::<Box<Layer<T, Sealed>>, Box<Layer<T, F>>>(f) }
        } else {
            panic!(
                "Wrong filter item {:?} expected: {:?}",
                data[KIND_IDX], KIND_PTR
            );
        };

        let mut data = NULL;
        unsafe {
            let filter_ref: &'static dyn Filter = {
                let f: &dyn Filter = filter.as_ref();
                mem::transmute(f)
            };
            self.filter.set(filter_ref);

            let ptr = &mut data as *mut _ as *mut *mut Layer<T, F>;
            ptr.write(Box::into_raw(filter));
            data[KIND_IDX] |= KIND_PTR;
            self.data.set(data);
        }
    }

    fn map_filter<F: Filter, U, R>(&self, f: U)
    where
        U: FnOnce(F) -> R,
        R: Filter,
    {
        let mut data = NULL;
        let filter = Box::new(f(*self.take_filter::<F>()));
        unsafe {
            let filter_ref: &'static dyn Filter = {
                let f: &dyn Filter = filter.as_ref();
                mem::transmute(f)
            };
            self.filter.set(filter_ref);

            let ptr = &mut data as *mut _ as *mut *mut R;
            ptr.write(Box::into_raw(filter));
            data[KIND_IDX] |= KIND_PTR;
            self.data.set(data);
        }
    }

    fn seal<F: Filter>(&self) {
        let mut data = self.data.get();

        let filter = if data[KIND_IDX] & KIND_PTR != 0 {
            Sealed(Box::new(*self.take_filter::<F>()))
        } else if data[KIND_IDX] & KIND_SEALED != 0 {
            self.take_sealed()
        } else {
            panic!(
                "Wrong filter item {:?} expected: {:?}",
                data[KIND_IDX], KIND_PTR
            );
        };

        unsafe {
            let filter_ref: &'static dyn Filter = {
                let f: &dyn Filter = filter.0.as_ref();
                mem::transmute(f)
            };
            self.filter.set(filter_ref);

            let ptr = &mut data as *mut _ as *mut Sealed;
            ptr.write(filter);
            data[KIND_IDX] |= KIND_SEALED;
            self.data.set(data);
        }
    }
}

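/// Future that resolves once the io stream is disconnected.
///
/// A minimal usage sketch, assuming a handle that exposes an
/// `on_disconnect()` constructor (as the `IoRef` api does):
///
/// ```ignore
/// let on_disconnect = io.on_disconnect();
/// on_disconnect.await;
/// println!("peer disconnected");
/// ```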
#[derive(Debug)]
#[must_use = "OnDisconnect does nothing unless polled"]
pub struct OnDisconnect {
    token: usize,
    inner: Rc<IoState>,
}

impl OnDisconnect {
    pub(super) fn new(inner: Rc<IoState>) -> Self {
        Self::new_inner(inner.flags.get().is_stopped(), inner)
    }

    fn new_inner(disconnected: bool, inner: Rc<IoState>) -> Self {
        let token = if disconnected {
            usize::MAX
        } else {
            let mut on_disconnect = inner.on_disconnect.take();
            let token = if let Some(ref mut on_disconnect) = on_disconnect {
                let token = on_disconnect.len();
                on_disconnect.push(LocalWaker::default());
                token
            } else {
                on_disconnect = Some(Box::new(vec![LocalWaker::default()]));
                0
            };
            inner.on_disconnect.set(on_disconnect);
            token
        };
        Self { token, inner }
    }

    #[inline]
    pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
        if self.token == usize::MAX || self.inner.flags.get().is_stopped() {
            Poll::Ready(())
        } else if let Some(on_disconnect) = self.inner.on_disconnect.take() {
            on_disconnect[self.token].register(cx.waker());
            self.inner.on_disconnect.set(Some(on_disconnect));
            Poll::Pending
        } else {
            Poll::Ready(())
        }
    }
}

impl Clone for OnDisconnect {
    fn clone(&self) -> Self {
        // a clone needs its own waker slot unless already disconnected
        OnDisconnect::new_inner(self.token == usize::MAX, self.inner.clone())
    }
}

impl Future for OnDisconnect {
    type Output = ();

    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.poll_ready(cx)
    }
}

#[cfg(test)]
mod tests {
    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;

    use super::*;
    use crate::{ReadBuf, WriteBuf, testing::IoTest};

    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    #[ntex::test]
    async fn test_basics() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::from(server);
        assert!(server.eq(&server));
        assert!(server.io_ref().eq(server.io_ref()));

        assert!(format!("{:?}", Flags::IO_STOPPED).contains("IO_STOPPED"));
        assert!(Flags::IO_STOPPED == Flags::IO_STOPPED);
        assert!(Flags::IO_STOPPED != Flags::IO_STOPPING);
    }

    #[ntex::test]
    async fn test_recv() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::from(server);

        server.st().notify_timeout();
        let err = server.recv(&BytesCodec).await.err().unwrap();
        assert!(format!("{err:?}").contains("Timeout"));

        server.st().insert_flags(Flags::DSP_STOP);
        let err = server.recv(&BytesCodec).await.err().unwrap();
        assert!(format!("{err:?}").contains("Dispatcher stopped"));

        client.write(TEXT);
        server.st().insert_flags(Flags::BUF_W_BACKPRESSURE);
        let item = server.recv(&BytesCodec).await.ok().unwrap().unwrap();
        assert_eq!(item, TEXT);
    }

    #[ntex::test]
    async fn test_send() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::from(server);
        assert!(server.eq(&server));

        server
            .send(Bytes::from_static(BIN), &BytesCodec)
            .await
            .ok()
            .unwrap();
        let item = client.read_any();
        assert_eq!(item, TEXT);
    }

    #[derive(Debug)]
    struct DropFilter {
        p: Rc<Cell<usize>>,
    }

    impl Drop for DropFilter {
        fn drop(&mut self) {
            self.p.set(self.p.get() + 1);
        }
    }

    impl FilterLayer for DropFilter {
        fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
            if let Some(src) = buf.take_src() {
                let len = src.len();
                buf.set_dst(Some(src));
                Ok(len)
            } else {
                Ok(0)
            }
        }
        fn process_write_buf(&self, buf: &WriteBuf<'_>) -> io::Result<()> {
            if let Some(src) = buf.take_src() {
                buf.set_dst(Some(src));
            }
            Ok(())
        }
    }

    #[ntex::test]
    async fn drop_filter() {
        let p = Rc::new(Cell::new(0));

        let (client, server) = IoTest::create();
        let f = DropFilter { p: p.clone() };
        let _ = format!("{f:?}");
        let io = Io::from(server).add_filter(f);

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        let io2 = io.take();
        let mut io3: crate::IoBoxed = io2.into();
        let io4 = io3.take();

        drop(io);
        drop(io3);
        drop(io4);

        assert_eq!(p.get(), 1);
    }

    #[ntex::test]
    async fn test_take_sealed_filter() {
        let p = Rc::new(Cell::new(0));
        let f = DropFilter { p: p.clone() };

        let io = Io::from(IoTest::create().0).seal();
        let _io: Io<Layer<DropFilter, Sealed>> = io.add_filter(f);
    }
}