1use std::cell::{Cell, UnsafeCell};
2use std::future::{poll_fn, Future};
3use std::task::{Context, Poll};
4use std::{fmt, hash, io, marker, mem, ops, pin::Pin, ptr, rc::Rc};
5
6use ntex_bytes::{PoolId, PoolRef};
7use ntex_codec::{Decoder, Encoder};
8use ntex_util::{future::Either, task::LocalWaker, time::Seconds};
9
10use crate::buf::Stack;
11use crate::filter::{Base, Filter, Layer, NullFilter};
12use crate::flags::Flags;
13use crate::seal::{IoBoxed, Sealed};
14use crate::tasks::{ReadContext, WriteContext};
15use crate::timer::TimerHandle;
16use crate::{Decoded, FilterLayer, Handle, IoStatusUpdate, IoStream, RecvError};
17
/// Interface object for an io stream; `F` is the filter-chain type.
/// Holds the shared state via an `UnsafeCell<IoRef>` so `take()` can swap
/// the inner handle out through a shared reference.
pub struct Io<F = Base>(UnsafeCell<IoRef>, marker::PhantomData<F>);

/// Cheaply clonable, shared handle to the io state.
#[derive(Clone)]
pub struct IoRef(pub(super) Rc<IoState>);
23
/// Shared state of a single io stream: filter chain, flags, buffers and
/// the wakers of the read / write / dispatcher tasks.
pub(crate) struct IoState {
    // type-erased storage for the filter chain (see FilterPtr below)
    filter: FilterPtr,
    pub(super) flags: Cell<Flags>,
    pub(super) pool: Cell<PoolRef>,
    pub(super) disconnect_timeout: Cell<Seconds>,
    // last io error; taken and re-created on read because io::Error is not Clone
    pub(super) error: Cell<Option<io::Error>>,
    pub(super) read_task: LocalWaker,
    pub(super) write_task: LocalWaker,
    pub(super) dispatch_task: LocalWaker,
    pub(super) buffer: Stack,
    // handle returned by IoStream::start; dropped when the io stops
    pub(super) handle: Cell<Option<Box<dyn Handle>>>,
    pub(super) timeout: Cell<TimerHandle>,
    // tag used as prefix in log messages
    pub(super) tag: Cell<&'static str>,
    #[allow(clippy::box_collection)]
    // wakers registered through OnDisconnect; NOTE(review): the Vec is
    // boxed deliberately (see the allow above), presumably to keep the
    // Cell payload pointer-sized — confirm
    pub(super) on_disconnect: Cell<Option<Box<Vec<LocalWaker>>>>,
}

const DEFAULT_TAG: &str = "IO";
42
43impl IoState {
44    pub(super) fn filter(&self) -> &dyn Filter {
45        self.filter.filter.get()
46    }
47
48    pub(super) fn insert_flags(&self, f: Flags) {
49        let mut flags = self.flags.get();
50        flags.insert(f);
51        self.flags.set(flags);
52    }
53
54    pub(super) fn remove_flags(&self, f: Flags) -> bool {
55        let mut flags = self.flags.get();
56        if flags.intersects(f) {
57            flags.remove(f);
58            self.flags.set(flags);
59            true
60        } else {
61            false
62        }
63    }
64
65    pub(super) fn notify_timeout(&self) {
66        let mut flags = self.flags.get();
67        if !flags.contains(Flags::DSP_TIMEOUT) {
68            flags.insert(Flags::DSP_TIMEOUT);
69            self.flags.set(flags);
70            self.dispatch_task.wake();
71            log::trace!("{}: Timer, notify dispatcher", self.tag.get());
72        }
73    }
74
75    pub(super) fn notify_disconnect(&self) {
76        if let Some(on_disconnect) = self.on_disconnect.take() {
77            for item in on_disconnect.into_iter() {
78                item.wake();
79            }
80        }
81    }
82
83    pub(super) fn error(&self) -> Option<io::Error> {
85        if let Some(err) = self.error.take() {
86            self.error
87                .set(Some(io::Error::new(err.kind(), format!("{}", err))));
88            Some(err)
89        } else {
90            None
91        }
92    }
93
94    pub(super) fn error_or_disconnected(&self) -> io::Error {
96        self.error()
97            .unwrap_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "Disconnected"))
98    }
99
100    pub(super) fn io_stopped(&self, err: Option<io::Error>) {
101        if !self.flags.get().is_stopped() {
102            log::trace!(
103                "{}: {} Io error {:?} flags: {:?}",
104                self.tag.get(),
105                self as *const _ as usize,
106                err,
107                self.flags.get()
108            );
109
110            if err.is_some() {
111                self.error.set(err);
112            }
113            self.read_task.wake();
114            self.write_task.wake();
115            self.notify_disconnect();
116            self.handle.take();
117            self.insert_flags(
118                Flags::IO_STOPPED
119                    | Flags::IO_STOPPING
120                    | Flags::IO_STOPPING_FILTERS
121                    | Flags::BUF_R_READY,
122            );
123            if !self.dispatch_task.wake_checked() {
124                log::trace!(
125                    "{}: {} Dispatcher is not registered, flags: {:?}",
126                    self.tag.get(),
127                    self as *const _ as usize,
128                    self.flags.get()
129                );
130            }
131        }
132    }
133
134    pub(super) fn init_shutdown(&self) {
136        if !self
137            .flags
138            .get()
139            .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
140        {
141            log::trace!(
142                "{}: Initiate io shutdown {:?}",
143                self.tag.get(),
144                self.flags.get()
145            );
146            self.insert_flags(Flags::IO_STOPPING_FILTERS);
147            self.read_task.wake();
148        }
149    }
150}
151
// Equality and hashing below are by object identity (address), not value.
impl Eq for IoState {}

impl PartialEq for IoState {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        // two IoStates are equal only if they are the same allocation
        ptr::eq(self, other)
    }
}

impl hash::Hash for IoState {
    #[inline]
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        // hash by address, consistent with the pointer-identity PartialEq
        (self as *const _ as usize).hash(state);
    }
}

impl Drop for IoState {
    #[inline]
    fn drop(&mut self) {
        // return all buffers to the memory pool
        self.buffer.release(self.pool.get());
    }
}

impl fmt::Debug for IoState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // io::Error is not Clone: take it out for formatting, put it back after
        let err = self.error.take();
        let res = f
            .debug_struct("IoState")
            .field("flags", &self.flags)
            .field("filter", &self.filter.is_set())
            .field("disconnect_timeout", &self.disconnect_timeout)
            .field("timeout", &self.timeout)
            .field("error", &err)
            .field("buffer", &self.buffer)
            .finish();
        self.error.set(err);
        res
    }
}
191
192impl Io {
193    #[inline]
194    pub fn new<I: IoStream>(io: I) -> Self {
196        Self::with_memory_pool(io, PoolId::DEFAULT.pool_ref())
197    }
198
199    #[inline]
200    pub fn with_memory_pool<I: IoStream>(io: I, pool: PoolRef) -> Self {
202        let inner = Rc::new(IoState {
203            filter: FilterPtr::null(),
204            pool: Cell::new(pool),
205            flags: Cell::new(Flags::WR_PAUSED),
206            error: Cell::new(None),
207            dispatch_task: LocalWaker::new(),
208            read_task: LocalWaker::new(),
209            write_task: LocalWaker::new(),
210            buffer: Stack::new(),
211            handle: Cell::new(None),
212            timeout: Cell::new(TimerHandle::default()),
213            disconnect_timeout: Cell::new(Seconds(1)),
214            on_disconnect: Cell::new(None),
215            tag: Cell::new(DEFAULT_TAG),
216        });
217        inner.filter.update(Base::new(IoRef(inner.clone())));
218
219        let io_ref = IoRef(inner);
220
221        let hnd = io.start(ReadContext::new(&io_ref), WriteContext::new(&io_ref));
223        io_ref.0.handle.set(hnd);
224
225        Io(UnsafeCell::new(io_ref), marker::PhantomData)
226    }
227}
228
impl<F> Io<F> {
    #[inline]
    /// Changes the memory pool used for io buffers.
    pub fn set_memory_pool(&self, pool: PoolRef) {
        self.st().buffer.set_memory_pool(pool);
        self.st().pool.set(pool);
    }

    #[inline]
    /// Sets the timeout used while disconnecting.
    pub fn set_disconnect_timeout(&self, timeout: Seconds) {
        self.st().disconnect_timeout.set(timeout);
    }

    #[inline]
    /// Moves the io stream out of `self`, leaving an inert, already-stopped
    /// placeholder behind.
    pub fn take(&self) -> Self {
        Self(UnsafeCell::new(self.take_io_ref()), marker::PhantomData)
    }

    /// Swaps the inner `IoRef` with a fresh state marked fully stopped and
    /// returns the original one.
    fn take_io_ref(&self) -> IoRef {
        let inner = Rc::new(IoState {
            filter: FilterPtr::null(),
            pool: self.st().pool.clone(),
            // placeholder starts fully stopped so nothing ever polls it
            flags: Cell::new(
                Flags::DSP_STOP
                    | Flags::IO_STOPPED
                    | Flags::IO_STOPPING
                    | Flags::IO_STOPPING_FILTERS,
            ),
            error: Cell::new(None),
            disconnect_timeout: Cell::new(Seconds(1)),
            dispatch_task: LocalWaker::new(),
            read_task: LocalWaker::new(),
            write_task: LocalWaker::new(),
            buffer: Stack::new(),
            handle: Cell::new(None),
            timeout: Cell::new(TimerHandle::default()),
            on_disconnect: Cell::new(None),
            tag: Cell::new(DEFAULT_TAG),
        });
        // SAFETY(review): writes through the UnsafeCell while only &self is
        // held; assumes no reference returned by st()/io_ref() is alive at
        // this point — confirm this single-threaded usage invariant
        unsafe { mem::replace(&mut *self.0.get(), IoRef(inner)) }
    }
}
275
impl<F> Io<F> {
    #[inline]
    #[doc(hidden)]
    /// Returns the current state flags.
    pub fn flags(&self) -> Flags {
        self.st().flags.get()
    }

    #[inline]
    /// Returns a cloned handle to the shared io state.
    pub fn get_ref(&self) -> IoRef {
        self.io_ref().clone()
    }

    // Shared access to the inner IoState.
    fn st(&self) -> &IoState {
        // SAFETY(review): hands out a shared reference into the UnsafeCell;
        // take_io_ref() is the only writer — confirm the two never overlap
        unsafe { &(*self.0.get()).0 }
    }

    // Shared access to the inner IoRef.
    fn io_ref(&self) -> &IoRef {
        unsafe { &*self.0.get() }
    }
}
298
impl<F: FilterLayer, T: Filter> Io<Layer<F, T>> {
    #[inline]
    /// Returns a reference to the topmost filter layer.
    pub fn filter(&self) -> &F {
        &self.st().filter.filter::<Layer<F, T>>().0
    }
}
306
impl<F: Filter> Io<F> {
    #[inline]
    /// Converts the typed filter chain into the type-erased `Sealed` form.
    pub fn seal(self) -> Io<Sealed> {
        // moving the IoRef out leaves `self` inert for its Drop impl
        let state = self.take_io_ref();
        state.0.filter.seal::<F>();

        Io(UnsafeCell::new(state), marker::PhantomData)
    }

    #[inline]
    /// Seals the filter chain and boxes the io stream.
    pub fn boxed(self) -> IoBoxed {
        self.seal().into()
    }

    #[inline]
    /// Pushes a new filter layer on top of the current chain.
    pub fn add_filter<U>(self, nf: U) -> Io<Layer<U, F>>
    where
        U: FilterLayer,
    {
        let state = self.take_io_ref();

        // NOTE(review): mutates the Rc-shared state through a raw pointer
        // to add a buffer layer; assumes no aliasing &mut exists — confirm
        unsafe { &mut *(Rc::as_ptr(&state.0) as *mut IoState) }
            .buffer
            .add_layer();

        state.0.filter.add_filter::<F, U>(nf);

        Io(UnsafeCell::new(state), marker::PhantomData)
    }

    /// Replaces the current filter with `f(current)`.
    pub fn map_filter<U, R>(self, f: U) -> Io<R>
    where
        U: FnOnce(F) -> R,
        R: Filter,
    {
        let state = self.take_io_ref();
        state.0.filter.map_filter::<F, U, R>(f);

        Io(UnsafeCell::new(state), marker::PhantomData)
    }
}
357
impl<F> Io<F> {
    /// Receives the next item decoded by `codec`.
    ///
    /// `Ok(None)` means the peer disconnected without an error. Decoder
    /// failures map to `Either::Left`, io failures to `Either::Right`;
    /// write backpressure is resolved internally by flushing and retrying.
    #[inline]
    pub async fn recv<U>(
        &self,
        codec: &U,
    ) -> Result<Option<U::Item>, Either<U::Error, io::Error>>
    where
        U: Decoder,
    {
        loop {
            return match poll_fn(|cx| self.poll_recv(codec, cx)).await {
                Ok(item) => Ok(Some(item)),
                Err(RecvError::KeepAlive) => Err(Either::Right(io::Error::new(
                    io::ErrorKind::TimedOut,
                    "Timeout",
                ))),
                Err(RecvError::Stop) => Err(Either::Right(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "Dispatcher stopped",
                ))),
                // flush the write buffer, then retry the receive
                Err(RecvError::WriteBackpressure) => {
                    poll_fn(|cx| self.poll_flush(cx, false))
                        .await
                        .map_err(Either::Right)?;
                    continue;
                }
                Err(RecvError::Decoder(err)) => Err(Either::Left(err)),
                Err(RecvError::PeerGone(Some(err))) => Err(Either::Right(err)),
                Err(RecvError::PeerGone(None)) => Ok(None),
            };
        }
    }

    /// Waits until the read buffer becomes ready.
    #[inline]
    pub async fn read_ready(&self) -> io::Result<Option<()>> {
        poll_fn(|cx| self.poll_read_ready(cx)).await
    }

    /// Waits for a read notification (see `poll_read_notify`).
    #[inline]
    pub async fn read_notify(&self) -> io::Result<Option<()>> {
        poll_fn(|cx| self.poll_read_notify(cx)).await
    }

    /// Pauses the read task; no-op when reading is already paused.
    #[inline]
    pub fn pause(&self) {
        let st = self.st();
        if !st.flags.get().contains(Flags::RD_PAUSED) {
            st.read_task.wake();
            st.insert_flags(Flags::RD_PAUSED);
        }
    }

    /// Encodes `item` with `codec` and fully flushes the write buffer.
    #[inline]
    pub async fn send<U>(
        &self,
        item: U::Item,
        codec: &U,
    ) -> Result<(), Either<U::Error, io::Error>>
    where
        U: Encoder,
    {
        self.encode(item, codec).map_err(Either::Left)?;

        poll_fn(|cx| self.poll_flush(cx, true))
            .await
            .map_err(Either::Right)?;

        Ok(())
    }

    /// Flushes the write buffer; with `full` set, waits until it is empty.
    #[inline]
    pub async fn flush(&self, full: bool) -> io::Result<()> {
        poll_fn(|cx| self.poll_flush(cx, full)).await
    }

    /// Gracefully shuts down the io stream.
    #[inline]
    pub async fn shutdown(&self) -> io::Result<()> {
        poll_fn(|cx| self.poll_shutdown(cx)).await
    }

    /// Polls read readiness, registering the dispatcher waker.
    ///
    /// `Ready(Ok(Some(())))` signals that buffered data may be available;
    /// `Ready(Err(..))` that the stream is stopped; `Pending` otherwise.
    #[inline]
    pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
        let st = self.st();
        let mut flags = st.flags.get();

        if flags.is_stopped() {
            Poll::Ready(Err(st.error_or_disconnected()))
        } else {
            st.dispatch_task.register(cx.waker());

            let ready = flags.is_read_buf_ready();
            if flags.cannot_read() {
                // reading was blocked: clear the blocking flags and wake
                // the read task so it resumes filling the buffer
                flags.cleanup_read_flags();
                st.read_task.wake();
                st.flags.set(flags);
                if ready {
                    Poll::Ready(Ok(Some(())))
                } else {
                    Poll::Pending
                }
            } else if ready {
                // consume the readiness notification
                flags.remove(Flags::BUF_R_READY);
                st.flags.set(flags);
                Poll::Ready(Ok(Some(())))
            } else {
                Poll::Pending
            }
        }
    }

    /// Like `poll_read_ready`, but also resolves on the `RD_NOTIFY` flag
    /// (a notification that arrives without new buffered data).
    #[inline]
    pub fn poll_read_notify(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
        let ready = self.poll_read_ready(cx);

        if ready.is_pending() {
            let st = self.st();
            if st.remove_flags(Flags::RD_NOTIFY) {
                Poll::Ready(Ok(Some(())))
            } else {
                st.insert_flags(Flags::RD_NOTIFY);
                Poll::Pending
            }
        } else {
            ready
        }
    }

    /// Polls for the next decodable item; `Pending` while more data is
    /// needed and no dispatcher event occurred.
    #[inline]
    pub fn poll_recv<U>(
        &self,
        codec: &U,
        cx: &mut Context<'_>,
    ) -> Poll<Result<U::Item, RecvError<U>>>
    where
        U: Decoder,
    {
        let decoded = self.poll_recv_decode(codec, cx)?;

        if let Some(item) = decoded.item {
            Poll::Ready(Ok(item))
        } else {
            Poll::Pending
        }
    }

    /// Attempts to decode an item from the read buffer, translating the
    /// dispatcher/state flags into the corresponding `RecvError`s.
    #[doc(hidden)]
    #[inline]
    pub fn poll_recv_decode<U>(
        &self,
        codec: &U,
        cx: &mut Context<'_>,
    ) -> Result<Decoded<U::Item>, RecvError<U>>
    where
        U: Decoder,
    {
        let decoded = self
            .decode_item(codec)
            .map_err(|err| RecvError::Decoder(err))?;

        if decoded.item.is_some() {
            Ok(decoded)
        } else {
            let st = self.st();
            let flags = st.flags.get();
            if flags.is_stopped() {
                Err(RecvError::PeerGone(st.error()))
            } else if flags.contains(Flags::DSP_STOP) {
                // flag is consumed: the Stop error is reported only once
                st.remove_flags(Flags::DSP_STOP);
                Err(RecvError::Stop)
            } else if flags.contains(Flags::DSP_TIMEOUT) {
                st.remove_flags(Flags::DSP_TIMEOUT);
                Err(RecvError::KeepAlive)
            } else if flags.contains(Flags::BUF_W_BACKPRESSURE) {
                Err(RecvError::WriteBackpressure)
            } else {
                match self.poll_read_ready(cx) {
                    Poll::Pending | Poll::Ready(Ok(Some(()))) => {
                        if log::log_enabled!(log::Level::Trace) && decoded.remains != 0 {
                            log::trace!(
                                "{}: Not enough data to decode next frame",
                                self.tag()
                            );
                        }
                        Ok(decoded)
                    }
                    Poll::Ready(Err(e)) => Err(RecvError::PeerGone(Some(e))),
                    Poll::Ready(Ok(None)) => Err(RecvError::PeerGone(None)),
                }
            }
        }
    }

    /// Polls write-buffer flush progress.
    ///
    /// With `full` set, resolves only when the buffer is completely empty;
    /// otherwise resolves unless the buffered size reaches twice the
    /// pool's high-water mark (backpressure).
    #[inline]
    pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll<io::Result<()>> {
        let st = self.st();
        let flags = self.flags();

        let len = st.buffer.write_destination_size();
        if len > 0 {
            if full {
                st.insert_flags(Flags::BUF_W_MUST_FLUSH);
                st.dispatch_task.register(cx.waker());
                return if flags.is_stopped() {
                    Poll::Ready(Err(st.error_or_disconnected()))
                } else {
                    Poll::Pending
                };
            } else if len >= st.pool.get().write_params_high() << 1 {
                st.insert_flags(Flags::BUF_W_BACKPRESSURE);
                st.dispatch_task.register(cx.waker());
                return if flags.is_stopped() {
                    Poll::Ready(Err(st.error_or_disconnected()))
                } else {
                    Poll::Pending
                };
            }
        }
        st.remove_flags(Flags::BUF_W_MUST_FLUSH | Flags::BUF_W_BACKPRESSURE);
        Poll::Ready(Ok(()))
    }

    /// Polls graceful shutdown; initiates it on the first call if it has
    /// not been started yet.
    #[inline]
    pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let st = self.st();
        let flags = st.flags.get();

        if flags.is_stopped() {
            if let Some(err) = st.error() {
                Poll::Ready(Err(err))
            } else {
                Poll::Ready(Ok(()))
            }
        } else {
            if !flags.contains(Flags::IO_STOPPING_FILTERS) {
                st.init_shutdown();
            }

            st.read_task.wake();
            st.write_task.wake();
            st.dispatch_task.register(cx.waker());
            Poll::Pending
        }
    }

    /// Pauses reading, then polls for a dispatcher status update.
    #[inline]
    pub fn poll_read_pause(&self, cx: &mut Context<'_>) -> Poll<IoStatusUpdate> {
        self.pause();
        let result = self.poll_status_update(cx);
        if !result.is_pending() {
            // NOTE(review): the waker is re-registered even on Ready —
            // presumably so later status changes still wake the caller;
            // confirm this is intentional
            self.st().dispatch_task.register(cx.waker());
        }
        result
    }

    /// Polls for dispatcher-relevant state changes: peer gone, stop,
    /// keep-alive timeout or write backpressure.
    #[inline]
    pub fn poll_status_update(&self, cx: &mut Context<'_>) -> Poll<IoStatusUpdate> {
        let st = self.st();
        let flags = st.flags.get();
        if flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING) {
            Poll::Ready(IoStatusUpdate::PeerGone(st.error()))
        } else if flags.contains(Flags::DSP_STOP) {
            st.remove_flags(Flags::DSP_STOP);
            Poll::Ready(IoStatusUpdate::Stop)
        } else if flags.contains(Flags::DSP_TIMEOUT) {
            st.remove_flags(Flags::DSP_TIMEOUT);
            Poll::Ready(IoStatusUpdate::KeepAlive)
        } else if flags.contains(Flags::BUF_W_BACKPRESSURE) {
            Poll::Ready(IoStatusUpdate::WriteBackpressure)
        } else {
            st.dispatch_task.register(cx.waker());
            Poll::Pending
        }
    }

    /// Registers the dispatcher waker without checking any state.
    #[inline]
    pub fn poll_dispatch(&self, cx: &mut Context<'_>) {
        self.st().dispatch_task.register(cx.waker());
    }
}
679
impl<F> AsRef<IoRef> for Io<F> {
    #[inline]
    fn as_ref(&self) -> &IoRef {
        self.io_ref()
    }
}

// Equality and hashing delegate to the inner IoRef.
impl<F> Eq for Io<F> {}

impl<F> PartialEq for Io<F> {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.io_ref().eq(other.io_ref())
    }
}

impl<F> hash::Hash for Io<F> {
    #[inline]
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        self.io_ref().hash(state);
    }
}

impl<F> fmt::Debug for Io<F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Io").field("state", self.st()).finish()
    }
}

// Deref to IoRef exposes the shared-state API directly on Io.
impl<F> ops::Deref for Io<F> {
    type Target = IoRef;

    #[inline]
    fn deref(&self) -> &Self::Target {
        self.io_ref()
    }
}
717
impl<F> Drop for Io<F> {
    fn drop(&mut self) {
        let st = self.st();
        self.stop_timer();

        // a set filter means this Io still owns the stream (take()/seal()
        // leave a null filter behind): stop the io and release the filter
        if st.filter.is_set() {
            if !st.flags.get().is_stopped() {
                log::trace!(
                    "{}: Io is dropped, force stopping io streams {:?}",
                    st.tag.get(),
                    st.flags.get()
                );
            }

            self.force_close();
            // release the boxed/sealed filter so it is not leaked
            st.filter.drop_filter::<F>();
        }
    }
}
739
// Tag bits stored in the low bits of the filter storage (see KIND_IDX).
const KIND_SEALED: u8 = 0b01;
const KIND_PTR: u8 = 0b10;
const KIND_MASK: u8 = 0b11;
const KIND_UNMASK: u8 = !KIND_MASK;
const KIND_MASK_USIZE: usize = 0b11;
const KIND_UNMASK_USIZE: usize = !KIND_MASK_USIZE;
const SEALED_SIZE: usize = mem::size_of::<Sealed>();
const NULL: [u8; SEALED_SIZE] = [0u8; SEALED_SIZE];

// Index of the byte carrying the kind tag: the least-significant byte of
// the stored pointer, which depends on endianness.
#[cfg(target_endian = "little")]
const KIND_IDX: usize = 0;

#[cfg(target_endian = "big")]
const KIND_IDX: usize = SEALED_SIZE - 1;

/// Type-erased storage for the filter chain.
///
/// `data` holds either a raw `Box` pointer (tagged `KIND_PTR`) or an
/// inline `Sealed` value (tagged `KIND_SEALED`); the tag is OR-ed into
/// the pointer's low bits. `filter` caches a `dyn Filter` reference for
/// cheap dispatch. NOTE(review): the `'static` lifetime on `filter` is
/// fabricated via transmute and is upheld only by keeping it in sync
/// with `data` — see the unsafe blocks below.
struct FilterPtr {
    data: Cell<[u8; SEALED_SIZE]>,
    filter: Cell<&'static dyn Filter>,
}
759
impl FilterPtr {
    /// Empty storage pointing at the no-op `NullFilter`.
    const fn null() -> Self {
        Self {
            data: Cell::new(NULL),
            filter: Cell::new(NullFilter::get()),
        }
    }

    /// Boxes `filter`, stores the raw pointer tagged `KIND_PTR` and caches
    /// a `dyn Filter` reference for dispatch.
    ///
    /// # Panics
    /// Panics if a filter is already stored.
    fn update<F: Filter>(&self, filter: F) {
        if self.is_set() {
            panic!("Filter is set, must be dropped first");
        }

        let filter = Box::new(filter);
        let mut data = NULL;
        unsafe {
            // SAFETY(review): fabricates 'static from the box's borrow;
            // sound only while `self.filter` never outlives the boxed
            // filter whose pointer is stored in `data`
            let filter_ref: &'static dyn Filter = {
                let f: &dyn Filter = filter.as_ref();
                std::mem::transmute(f)
            };
            self.filter.set(filter_ref);

            let ptr = &mut data as *mut _ as *mut *mut F;
            ptr.write(Box::into_raw(filter));
            // tag the pointer's low bits; assumes the Box allocation is
            // aligned so these bits are zero — TODO confirm
            data[KIND_IDX] |= KIND_PTR;
            self.data.set(data);
        }
    }

    /// Borrows the stored filter as the concrete type `F`.
    ///
    /// # Panics
    /// Panics if the storage does not hold a `KIND_PTR` entry.
    fn filter<F: Filter>(&self) -> &F {
        let data = self.data.get();
        if data[KIND_IDX] & KIND_PTR != 0 {
            let ptr = &data as *const _ as *const *mut F;
            unsafe {
                // strip the kind tag bits before dereferencing
                let p = (ptr.read() as *const _ as usize) & KIND_UNMASK_USIZE;
                (p as *const F as *mut F).as_ref().unwrap()
            }
        } else {
            panic!("Wrong filter item");
        }
    }

    /// Moves the boxed filter out, clearing the kind tag from the local
    /// copy first. The storage itself is left untouched — callers are
    /// expected to overwrite it.
    ///
    /// # Panics
    /// Panics if the storage does not hold a `KIND_PTR` entry.
    fn take_filter<F>(&self) -> Box<F> {
        let mut data = self.data.get();
        if data[KIND_IDX] & KIND_PTR != 0 {
            data[KIND_IDX] &= KIND_UNMASK;
            let ptr = &mut data as *mut _ as *mut *mut F;
            unsafe { Box::from_raw(*ptr) }
        } else {
            panic!(
                "Wrong filter item {:?} expected: {:?}",
                data[KIND_IDX], KIND_PTR
            );
        }
    }

    /// Moves the inline `Sealed` value out, clearing the kind tag from the
    /// local copy first. The storage itself is left untouched.
    ///
    /// # Panics
    /// Panics if the storage does not hold a `KIND_SEALED` entry.
    fn take_sealed(&self) -> Sealed {
        let mut data = self.data.get();

        if data[KIND_IDX] & KIND_SEALED != 0 {
            data[KIND_IDX] &= KIND_UNMASK;
            let ptr = &mut data as *mut _ as *mut Sealed;
            unsafe { ptr.read() }
        } else {
            panic!(
                "Wrong filter item {:?} expected: {:?}",
                data[KIND_IDX], KIND_SEALED
            );
        }
    }

    /// `true` when any filter (boxed or sealed) is stored.
    fn is_set(&self) -> bool {
        self.data.get()[KIND_IDX] & KIND_MASK != 0
    }

    /// Drops whatever is stored and resets to the null filter.
    fn drop_filter<F>(&self) {
        let data = self.data.get();

        if data[KIND_IDX] & KIND_MASK != 0 {
            if data[KIND_IDX] & KIND_PTR != 0 {
                self.take_filter::<F>();
            } else if data[KIND_IDX] & KIND_SEALED != 0 {
                self.take_sealed();
            }
            self.data.set(NULL);
            self.filter.set(NullFilter::get());
        }
    }
}
852
impl FilterPtr {
    /// Wraps the current filter into a new `Layer` with `new` on top and
    /// stores the result back as a tagged pointer.
    fn add_filter<F: Filter, T: FilterLayer>(&self, new: T) {
        let mut data = NULL;
        let filter = Box::new(Layer::new(new, *self.take_filter::<F>()));
        unsafe {
            // SAFETY(review): same fabricated-'static scheme as update();
            // sound only while the cached ref never outlives the box
            let filter_ref: &'static dyn Filter = {
                let f: &dyn Filter = filter.as_ref();
                std::mem::transmute(f)
            };
            self.filter.set(filter_ref);

            let ptr = &mut data as *mut _ as *mut *mut Layer<T, F>;
            ptr.write(Box::into_raw(filter));
            data[KIND_IDX] |= KIND_PTR;
            self.data.set(data);
        }
    }

    /// Replaces the current filter with `f(current)`.
    fn map_filter<F: Filter, U, R>(&self, f: U)
    where
        U: FnOnce(F) -> R,
        R: Filter,
    {
        let mut data = NULL;
        let filter = Box::new(f(*self.take_filter::<F>()));
        unsafe {
            let filter_ref: &'static dyn Filter = {
                let f: &dyn Filter = filter.as_ref();
                std::mem::transmute(f)
            };
            self.filter.set(filter_ref);

            let ptr = &mut data as *mut _ as *mut *mut R;
            ptr.write(Box::into_raw(filter));
            data[KIND_IDX] |= KIND_PTR;
            self.data.set(data);
        }
    }

    /// Converts the stored filter into the type-erased `Sealed` form,
    /// stored inline in `data` and tagged `KIND_SEALED`.
    ///
    /// # Panics
    /// Panics if nothing is stored.
    fn seal<F: Filter>(&self) {
        let mut data = self.data.get();

        let filter = if data[KIND_IDX] & KIND_PTR != 0 {
            Sealed(Box::new(*self.take_filter::<F>()))
        } else if data[KIND_IDX] & KIND_SEALED != 0 {
            // already sealed: re-store below to refresh the cached reference
            self.take_sealed()
        } else {
            panic!(
                "Wrong filter item {:?} expected: {:?}",
                data[KIND_IDX], KIND_PTR
            );
        };

        unsafe {
            let filter_ref: &'static dyn Filter = {
                let f: &dyn Filter = filter.0.as_ref();
                std::mem::transmute(f)
            };
            self.filter.set(filter_ref);

            // overwrite the old bytes with the Sealed value, then tag it
            let ptr = &mut data as *mut _ as *mut Sealed;
            ptr.write(filter);
            data[KIND_IDX] |= KIND_SEALED;
            self.data.set(data);
        }
    }
}
920
/// Future that resolves once the io stream has been disconnected.
#[derive(Debug)]
#[must_use = "OnDisconnect do nothing unless polled"]
pub struct OnDisconnect {
    // index into IoState::on_disconnect wakers; usize::MAX means the io
    // was already disconnected when this future was created
    token: usize,
    inner: Rc<IoState>,
}
928
929impl OnDisconnect {
930    pub(super) fn new(inner: Rc<IoState>) -> Self {
931        Self::new_inner(inner.flags.get().is_stopped(), inner)
932    }
933
934    fn new_inner(disconnected: bool, inner: Rc<IoState>) -> Self {
935        let token = if disconnected {
936            usize::MAX
937        } else {
938            let mut on_disconnect = inner.on_disconnect.take();
939            let token = if let Some(ref mut on_disconnect) = on_disconnect {
940                let token = on_disconnect.len();
941                on_disconnect.push(LocalWaker::default());
942                token
943            } else {
944                on_disconnect = Some(Box::new(vec![LocalWaker::default()]));
945                0
946            };
947            inner.on_disconnect.set(on_disconnect);
948            token
949        };
950        Self { token, inner }
951    }
952
953    #[inline]
954    pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
956        if self.token == usize::MAX || self.inner.flags.get().is_stopped() {
957            Poll::Ready(())
958        } else if let Some(on_disconnect) = self.inner.on_disconnect.take() {
959            on_disconnect[self.token].register(cx.waker());
960            self.inner.on_disconnect.set(Some(on_disconnect));
961            Poll::Pending
962        } else {
963            Poll::Ready(())
964        }
965    }
966}
967
968impl Clone for OnDisconnect {
969    fn clone(&self) -> Self {
970        if self.token == usize::MAX {
971            OnDisconnect::new_inner(true, self.inner.clone())
972        } else {
973            OnDisconnect::new_inner(false, self.inner.clone())
974        }
975    }
976}
977
// Future adapter: polling delegates to poll_ready (no pinned state used).
impl Future for OnDisconnect {
    type Output = ();

    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.poll_ready(cx)
    }
}
986
#[cfg(test)]
mod tests {
    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;

    use super::*;
    use crate::{testing::IoTest, ReadBuf, WriteBuf};

    // BIN and TEXT are the same payload in byte and str form
    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    #[ntex::test]
    async fn test_basics() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::new(server);
        // identity-based equality
        assert!(server.eq(&server));
        assert!(server.io_ref().eq(server.io_ref()));

        assert!(format!("{:?}", Flags::IO_STOPPED).contains("IO_STOPPED"));
        assert!(Flags::IO_STOPPED == Flags::IO_STOPPED);
        assert!(Flags::IO_STOPPED != Flags::IO_STOPPING);
    }

    #[ntex::test]
    async fn test_recv() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::new(server);

        // keep-alive timeout surfaces as a TimedOut io error
        server.st().notify_timeout();
        let err = server.recv(&BytesCodec).await.err().unwrap();
        assert!(format!("{:?}", err).contains("Timeout"));

        // DSP_STOP surfaces as "Dispatcher stopped"
        server.st().insert_flags(Flags::DSP_STOP);
        let err = server.recv(&BytesCodec).await.err().unwrap();
        assert!(format!("{:?}", err).contains("Dispatcher stopped"));

        // write backpressure is resolved internally; the item still arrives
        client.write(TEXT);
        server.st().insert_flags(Flags::BUF_W_BACKPRESSURE);
        let item = server.recv(&BytesCodec).await.ok().unwrap().unwrap();
        assert_eq!(item, TEXT);
    }

    #[ntex::test]
    async fn test_send() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::new(server);
        assert!(server.eq(&server));

        server
            .send(Bytes::from_static(BIN), &BytesCodec)
            .await
            .ok()
            .unwrap();
        let item = client.read_any();
        assert_eq!(item, TEXT);
    }

    // Filter whose Drop increments a shared counter, used to verify that
    // the filter is dropped exactly once across take()/seal()/drop.
    #[derive(Debug)]
    struct DropFilter {
        p: Rc<Cell<usize>>,
    }

    impl Drop for DropFilter {
        fn drop(&mut self) {
            self.p.set(self.p.get() + 1);
        }
    }

    impl FilterLayer for DropFilter {
        // pass-through read: move the source buffer to the destination
        fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
            if let Some(src) = buf.take_src() {
                let len = src.len();
                buf.set_dst(Some(src));
                Ok(len)
            } else {
                Ok(0)
            }
        }
        // pass-through write
        fn process_write_buf(&self, buf: &WriteBuf<'_>) -> io::Result<()> {
            if let Some(src) = buf.take_src() {
                buf.set_dst(Some(src));
            }
            Ok(())
        }
    }

    #[ntex::test]
    async fn drop_filter() {
        let p = Rc::new(Cell::new(0));

        let (client, server) = IoTest::create();
        let f = DropFilter { p: p.clone() };
        let _ = format!("{:?}", f);
        let io = Io::new(server).add_filter(f);

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // move the stream through take() and sealing; the filter must
        // travel with it and be dropped exactly once at the end
        let io2 = io.take();
        let mut io3: crate::IoBoxed = io2.into();
        let io4 = io3.take();

        drop(io);
        drop(io3);
        drop(io4);

        assert_eq!(p.get(), 1);
    }
}