// ntex_io/io.rs

use std::cell::{Cell, UnsafeCell};
use std::future::{poll_fn, Future};
use std::task::{Context, Poll};
use std::{fmt, hash, io, marker, mem, ops, pin::Pin, ptr, rc::Rc};

use ntex_bytes::{PoolId, PoolRef};
use ntex_codec::{Decoder, Encoder};
use ntex_util::{future::Either, task::LocalWaker, time::Seconds};

use crate::buf::Stack;
use crate::filter::{Base, Filter, Layer, NullFilter};
use crate::flags::Flags;
use crate::seal::{IoBoxed, Sealed};
use crate::tasks::{ReadContext, WriteContext};
use crate::timer::TimerHandle;
use crate::{Decoded, FilterLayer, Handle, IoStatusUpdate, IoStream, RecvError};

/// Interface object to underlying io stream
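///
/// A minimal usage sketch, based on the test transport from `crate::testing`
/// (mirroring the tests at the bottom of this file):
///
/// ```ignore
/// let (client, server) = IoTest::create();
/// let io = Io::new(server); // wraps the stream and starts the io tasks
/// ```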
pub struct Io<F = Base>(UnsafeCell<IoRef>, marker::PhantomData<F>);

#[derive(Clone)]
pub struct IoRef(pub(super) Rc<IoState>);

pub(crate) struct IoState {
    filter: FilterPtr,
    pub(super) flags: Cell<Flags>,
    pub(super) pool: Cell<PoolRef>,
    pub(super) disconnect_timeout: Cell<Seconds>,
    pub(super) error: Cell<Option<io::Error>>,
    pub(super) read_task: LocalWaker,
    pub(super) write_task: LocalWaker,
    pub(super) dispatch_task: LocalWaker,
    pub(super) buffer: Stack,
    pub(super) handle: Cell<Option<Box<dyn Handle>>>,
    pub(super) timeout: Cell<TimerHandle>,
    pub(super) tag: Cell<&'static str>,
    #[allow(clippy::box_collection)]
    pub(super) on_disconnect: Cell<Option<Box<Vec<LocalWaker>>>>,
}

const DEFAULT_TAG: &str = "IO";

impl IoState {
    pub(super) fn filter(&self) -> &dyn Filter {
        self.filter.filter.get()
    }

    pub(super) fn insert_flags(&self, f: Flags) {
        let mut flags = self.flags.get();
        flags.insert(f);
        self.flags.set(flags);
    }

    pub(super) fn remove_flags(&self, f: Flags) -> bool {
        let mut flags = self.flags.get();
        if flags.intersects(f) {
            flags.remove(f);
            self.flags.set(flags);
            true
        } else {
            false
        }
    }

    pub(super) fn notify_timeout(&self) {
        let mut flags = self.flags.get();
        if !flags.contains(Flags::DSP_TIMEOUT) {
            flags.insert(Flags::DSP_TIMEOUT);
            self.flags.set(flags);
            self.dispatch_task.wake();
            log::trace!("{}: Timer, notify dispatcher", self.tag.get());
        }
    }

    pub(super) fn notify_disconnect(&self) {
        if let Some(on_disconnect) = self.on_disconnect.take() {
            for item in on_disconnect.into_iter() {
                item.wake();
            }
        }
    }

    /// Get current io error
    pub(super) fn error(&self) -> Option<io::Error> {
        if let Some(err) = self.error.take() {
            self.error
                .set(Some(io::Error::new(err.kind(), format!("{}", err))));
            Some(err)
        } else {
            None
        }
    }

    /// Get current io result
    pub(super) fn error_or_disconnected(&self) -> io::Error {
        self.error()
            .unwrap_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "Disconnected"))
    }

    pub(super) fn io_stopped(&self, err: Option<io::Error>) {
        if err.is_some() {
            self.error.set(err);
        }
        self.read_task.wake();
        self.write_task.wake();
        self.dispatch_task.wake();
        self.notify_disconnect();
        self.handle.take();
        self.insert_flags(
            Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS,
        );
    }

    /// Gracefully shutdown read and write io tasks
    pub(super) fn init_shutdown(&self) {
        if !self
            .flags
            .get()
            .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
        {
            log::trace!(
                "{}: Initiate io shutdown {:?}",
                self.tag.get(),
                self.flags.get()
            );
            self.insert_flags(Flags::IO_STOPPING_FILTERS);
            self.read_task.wake();
        }
    }
}

impl Eq for IoState {}

impl PartialEq for IoState {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        ptr::eq(self, other)
    }
}

impl hash::Hash for IoState {
    #[inline]
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        (self as *const _ as usize).hash(state);
    }
}

impl Drop for IoState {
    #[inline]
    fn drop(&mut self) {
        self.buffer.release(self.pool.get());
    }
}

impl fmt::Debug for IoState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let err = self.error.take();
        let res = f
            .debug_struct("IoState")
            .field("flags", &self.flags)
            .field("filter", &self.filter.is_set())
            .field("disconnect_timeout", &self.disconnect_timeout)
            .field("timeout", &self.timeout)
            .field("error", &err)
            .field("buffer", &self.buffer)
            .finish();
        self.error.set(err);
        res
    }
}

impl Io {
    #[inline]
    /// Create `Io` instance
    pub fn new<I: IoStream>(io: I) -> Self {
        Self::with_memory_pool(io, PoolId::DEFAULT.pool_ref())
    }

    #[inline]
    /// Create `Io` instance in specific memory pool.
    pub fn with_memory_pool<I: IoStream>(io: I, pool: PoolRef) -> Self {
        let inner = Rc::new(IoState {
            filter: FilterPtr::null(),
            pool: Cell::new(pool),
            flags: Cell::new(Flags::WR_PAUSED),
            error: Cell::new(None),
            dispatch_task: LocalWaker::new(),
            read_task: LocalWaker::new(),
            write_task: LocalWaker::new(),
            buffer: Stack::new(),
            handle: Cell::new(None),
            timeout: Cell::new(TimerHandle::default()),
            disconnect_timeout: Cell::new(Seconds(1)),
            on_disconnect: Cell::new(None),
            tag: Cell::new(DEFAULT_TAG),
        });
        inner.filter.update(Base::new(IoRef(inner.clone())));

        let io_ref = IoRef(inner);

        // start io tasks
        let hnd = io.start(ReadContext::new(&io_ref), WriteContext::new(&io_ref));
        io_ref.0.handle.set(hnd);

        Io(UnsafeCell::new(io_ref), marker::PhantomData)
    }
}

impl<F> Io<F> {
    #[inline]
    /// Set memory pool
    pub fn set_memory_pool(&self, pool: PoolRef) {
        self.st().buffer.set_memory_pool(pool);
        self.st().pool.set(pool);
    }

    #[inline]
    /// Set io disconnect timeout in seconds
    pub fn set_disconnect_timeout(&self, timeout: Seconds) {
        self.st().disconnect_timeout.set(timeout);
    }

    #[inline]
    /// Clone current io object.
    ///
    /// Current io object becomes closed.
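    ///
    /// A short sketch, mirroring the `drop_filter` test below:
    ///
    /// ```ignore
    /// let io2 = io.take(); // `io2` now owns the stream, `io` is closed
    /// ```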
    pub fn take(&self) -> Self {
        Self(UnsafeCell::new(self.take_io_ref()), marker::PhantomData)
    }

    fn take_io_ref(&self) -> IoRef {
        let inner = Rc::new(IoState {
            filter: FilterPtr::null(),
            pool: self.st().pool.clone(),
            flags: Cell::new(
                Flags::DSP_STOP
                    | Flags::IO_STOPPED
                    | Flags::IO_STOPPING
                    | Flags::IO_STOPPING_FILTERS,
            ),
            error: Cell::new(None),
            disconnect_timeout: Cell::new(Seconds(1)),
            dispatch_task: LocalWaker::new(),
            read_task: LocalWaker::new(),
            write_task: LocalWaker::new(),
            buffer: Stack::new(),
            handle: Cell::new(None),
            timeout: Cell::new(TimerHandle::default()),
            on_disconnect: Cell::new(None),
            tag: Cell::new(DEFAULT_TAG),
        });
        unsafe { mem::replace(&mut *self.0.get(), IoRef(inner)) }
    }
}

impl<F> Io<F> {
    #[inline]
    #[doc(hidden)]
    /// Get current state flags
    pub fn flags(&self) -> Flags {
        self.st().flags.get()
    }

    #[inline]
    /// Get instance of `IoRef`
    pub fn get_ref(&self) -> IoRef {
        self.io_ref().clone()
    }

    fn st(&self) -> &IoState {
        unsafe { &(*self.0.get()).0 }
    }

    fn io_ref(&self) -> &IoRef {
        unsafe { &*self.0.get() }
    }
}

impl<F: FilterLayer, T: Filter> Io<Layer<F, T>> {
    #[inline]
    /// Get reference to a filter
    pub fn filter(&self) -> &F {
        &self.st().filter.filter::<Layer<F, T>>().0
    }
}

impl<F: Filter> Io<F> {
    #[inline]
    /// Convert current io stream into sealed version
    pub fn seal(self) -> Io<Sealed> {
        let state = self.take_io_ref();
        state.0.filter.seal::<F>();

        Io(UnsafeCell::new(state), marker::PhantomData)
    }

    #[inline]
    /// Convert current io stream into boxed version
    pub fn boxed(self) -> IoBoxed {
        self.seal().into()
    }

    #[inline]
    /// Add a new filter layer on top of the current filter
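    ///
    /// A sketch, using the `DropFilter` test layer defined in this file's tests:
    ///
    /// ```ignore
    /// let io = Io::new(server).add_filter(DropFilter { p: p.clone() });
    /// ```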
    pub fn add_filter<U>(self, nf: U) -> Io<Layer<U, F>>
    where
        U: FilterLayer,
    {
        let state = self.take_io_ref();

        // add layer to buffers
        if U::BUFFERS {
            // Safety: .add_layer() only increases the number of internal buffers;
            // no api holds references into the buffers storage, all apis first
            // remove a buffer from storage and then work with it
            unsafe { &mut *(Rc::as_ptr(&state.0) as *mut IoState) }
                .buffer
                .add_layer();
        }

        // replace current filter
        state.0.filter.add_filter::<F, U>(nf);

        Io(UnsafeCell::new(state), marker::PhantomData)
    }
}

impl<F> Io<F> {
    #[inline]
    /// Read incoming io stream and decode codec item.
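    ///
    /// A sketch of handling the possible outcomes (`BytesCodec` as in the tests):
    ///
    /// ```ignore
    /// match io.recv(&BytesCodec).await {
    ///     Ok(Some(frame)) => { /* decoded frame */ }
    ///     Ok(None) => { /* peer closed without an error */ }
    ///     Err(Either::Left(err)) => { /* decoder error */ }
    ///     Err(Either::Right(err)) => { /* io error, timeout or stop */ }
    /// }
    /// ```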
    pub async fn recv<U>(
        &self,
        codec: &U,
    ) -> Result<Option<U::Item>, Either<U::Error, io::Error>>
    where
        U: Decoder,
    {
        loop {
            return match poll_fn(|cx| self.poll_recv(codec, cx)).await {
                Ok(item) => Ok(Some(item)),
                Err(RecvError::KeepAlive) => Err(Either::Right(io::Error::new(
                    io::ErrorKind::TimedOut,
                    "Timeout",
                ))),
                Err(RecvError::Stop) => Err(Either::Right(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "Dispatcher stopped",
                ))),
                Err(RecvError::WriteBackpressure) => {
                    poll_fn(|cx| self.poll_flush(cx, false))
                        .await
                        .map_err(Either::Right)?;
                    continue;
                }
                Err(RecvError::Decoder(err)) => Err(Either::Left(err)),
                Err(RecvError::PeerGone(Some(err))) => Err(Either::Right(err)),
                Err(RecvError::PeerGone(None)) => Ok(None),
            };
        }
    }

    #[inline]
    /// Wait until read becomes ready.
    pub async fn read_ready(&self) -> io::Result<Option<()>> {
        poll_fn(|cx| self.poll_read_ready(cx)).await
    }

    #[inline]
    /// Wait until io reads any data.
    pub async fn read_notify(&self) -> io::Result<Option<()>> {
        poll_fn(|cx| self.poll_read_notify(cx)).await
    }

    #[inline]
    /// Pause read task
    pub fn pause(&self) {
        let st = self.st();
        if !st.flags.get().contains(Flags::RD_PAUSED) {
            st.read_task.wake();
            st.insert_flags(Flags::RD_PAUSED);
        }
    }

    #[inline]
    /// Encode item, send it to the peer and fully flush the write buffer.
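    ///
    /// A sketch, mirroring the `test_send` test below:
    ///
    /// ```ignore
    /// io.send(Bytes::from_static(b"data"), &BytesCodec).await?;
    /// ```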
    pub async fn send<U>(
        &self,
        item: U::Item,
        codec: &U,
    ) -> Result<(), Either<U::Error, io::Error>>
    where
        U: Encoder,
    {
        self.encode(item, codec).map_err(Either::Left)?;

        poll_fn(|cx| self.poll_flush(cx, true))
            .await
            .map_err(Either::Right)?;

        Ok(())
    }

    #[inline]
    /// Wake write task and instruct it to flush data.
    ///
    /// This is the async version of the `.poll_flush()` method.
    pub async fn flush(&self, full: bool) -> io::Result<()> {
        poll_fn(|cx| self.poll_flush(cx, full)).await
    }

    #[inline]
    /// Gracefully shutdown io stream
    pub async fn shutdown(&self) -> io::Result<()> {
        poll_fn(|cx| self.poll_shutdown(cx)).await
    }

    #[inline]
    /// Polls for read readiness.
    ///
    /// If the io stream is not currently ready for reading,
    /// this method will store a clone of the `Waker` from the provided `Context`.
    /// When the io stream becomes ready for reading, `Waker::wake` will be called on the waker.
    ///
    /// Return value
    ///
    /// The function returns:
    ///
    /// * `Poll::Pending` if the io stream is not ready for reading.
    /// * `Poll::Ready(Ok(Some(())))` if the io stream is ready for reading.
    /// * `Poll::Ready(Ok(None))` if the io stream is disconnected.
    /// * `Poll::Ready(Err(e))` if an error is encountered.
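    ///
    /// The async `read_ready()` method above is a thin wrapper over this poll method:
    ///
    /// ```ignore
    /// poll_fn(|cx| io.poll_read_ready(cx)).await
    /// ```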
    pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
        let st = self.st();
        let mut flags = st.flags.get();

        if flags.is_stopped() {
            Poll::Ready(Err(st.error_or_disconnected()))
        } else {
            st.dispatch_task.register(cx.waker());

            let ready = flags.is_read_buf_ready();
            if flags.cannot_read() {
                flags.cleanup_read_flags();
                st.read_task.wake();
                st.flags.set(flags);
                if ready {
                    Poll::Ready(Ok(Some(())))
                } else {
                    Poll::Pending
                }
            } else if ready {
                flags.remove(Flags::BUF_R_READY);
                st.flags.set(flags);
                Poll::Ready(Ok(Some(())))
            } else {
                Poll::Pending
            }
        }
    }

    #[inline]
    /// Polls for any incoming data.
    pub fn poll_read_notify(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
        let ready = self.poll_read_ready(cx);

        if ready.is_pending() {
            let st = self.st();
            if st.remove_flags(Flags::RD_NOTIFY) {
                Poll::Ready(Ok(Some(())))
            } else {
                st.insert_flags(Flags::RD_NOTIFY);
                Poll::Pending
            }
        } else {
            ready
        }
    }

    #[inline]
    /// Decode codec item from incoming bytes stream.
    ///
    /// Wake read task and request more data if the available data is not enough for decoding.
    /// If an error is returned, this method does not register the waker for a later wake-up.
    pub fn poll_recv<U>(
        &self,
        codec: &U,
        cx: &mut Context<'_>,
    ) -> Poll<Result<U::Item, RecvError<U>>>
    where
        U: Decoder,
    {
        let decoded = self.poll_recv_decode(codec, cx)?;

        if let Some(item) = decoded.item {
            Poll::Ready(Ok(item))
        } else {
            Poll::Pending
        }
    }

    #[doc(hidden)]
    #[inline]
    /// Decode codec item from incoming bytes stream.
    ///
    /// Wake read task and request more data if the available data is not enough for decoding.
    /// If an error is returned, this method does not register the waker for a later wake-up.
    pub fn poll_recv_decode<U>(
        &self,
        codec: &U,
        cx: &mut Context<'_>,
    ) -> Result<Decoded<U::Item>, RecvError<U>>
    where
        U: Decoder,
    {
        let decoded = self
            .decode_item(codec)
            .map_err(RecvError::Decoder)?;

        if decoded.item.is_some() {
            Ok(decoded)
        } else {
            let st = self.st();
            let flags = st.flags.get();
            if flags.is_stopped() {
                Err(RecvError::PeerGone(st.error()))
            } else if flags.contains(Flags::DSP_STOP) {
                st.remove_flags(Flags::DSP_STOP);
                Err(RecvError::Stop)
            } else if flags.contains(Flags::DSP_TIMEOUT) {
                st.remove_flags(Flags::DSP_TIMEOUT);
                Err(RecvError::KeepAlive)
            } else if flags.contains(Flags::BUF_W_BACKPRESSURE) {
                Err(RecvError::WriteBackpressure)
            } else {
                match self.poll_read_ready(cx) {
                    Poll::Pending | Poll::Ready(Ok(Some(()))) => {
                        if log::log_enabled!(log::Level::Debug) && decoded.remains != 0 {
                            log::debug!(
                                "{}: Not enough data to decode next frame",
                                self.tag()
                            );
                        }
                        Ok(decoded)
                    }
                    Poll::Ready(Err(e)) => Err(RecvError::PeerGone(Some(e))),
                    Poll::Ready(Ok(None)) => Err(RecvError::PeerGone(None)),
                }
            }
        }
    }

    #[inline]
    /// Wake write task and instruct it to flush data.
    ///
    /// If `full` is true, the dispatcher is woken up once all data is flushed;
    /// otherwise it is woken up once the size of the write buffer drops below
    /// the maximum buffer size.
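    ///
    /// The async `flush()` method wraps this poll method, e.g. for a full flush:
    ///
    /// ```ignore
    /// poll_fn(|cx| io.poll_flush(cx, true)).await
    /// ```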
    pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll<io::Result<()>> {
        let st = self.st();
        let flags = self.flags();

        let len = st.buffer.write_destination_size();
        if len > 0 {
            if full {
                st.insert_flags(Flags::BUF_W_MUST_FLUSH);
                st.dispatch_task.register(cx.waker());
                return if flags.is_stopped() {
                    Poll::Ready(Err(st.error_or_disconnected()))
                } else {
                    Poll::Pending
                };
            } else if len >= st.pool.get().write_params_high() << 1 {
                st.insert_flags(Flags::BUF_W_BACKPRESSURE);
                st.dispatch_task.register(cx.waker());
                return if flags.is_stopped() {
                    Poll::Ready(Err(st.error_or_disconnected()))
                } else {
                    Poll::Pending
                };
            }
        }
        st.remove_flags(Flags::BUF_W_MUST_FLUSH | Flags::BUF_W_BACKPRESSURE);
        Poll::Ready(Ok(()))
    }

    #[inline]
    /// Gracefully shutdown io stream
    pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let st = self.st();
        let flags = st.flags.get();

        if flags.is_stopped() {
            if let Some(err) = st.error() {
                Poll::Ready(Err(err))
            } else {
                Poll::Ready(Ok(()))
            }
        } else {
            if !flags.contains(Flags::IO_STOPPING_FILTERS) {
                st.init_shutdown();
            }

            st.read_task.wake();
            st.write_task.wake();
            st.dispatch_task.register(cx.waker());
            Poll::Pending
        }
    }

    #[inline]
    /// Pause read task
    ///
    /// Returns status updates
    pub fn poll_read_pause(&self, cx: &mut Context<'_>) -> Poll<IoStatusUpdate> {
        self.pause();
        let result = self.poll_status_update(cx);
        if !result.is_pending() {
            self.st().dispatch_task.register(cx.waker());
        }
        result
    }

    #[inline]
    /// Wait for status updates
    pub fn poll_status_update(&self, cx: &mut Context<'_>) -> Poll<IoStatusUpdate> {
        let st = self.st();
        let flags = st.flags.get();
        if flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING) {
            Poll::Ready(IoStatusUpdate::PeerGone(st.error()))
        } else if flags.contains(Flags::DSP_STOP) {
            st.remove_flags(Flags::DSP_STOP);
            Poll::Ready(IoStatusUpdate::Stop)
        } else if flags.contains(Flags::DSP_TIMEOUT) {
            st.remove_flags(Flags::DSP_TIMEOUT);
            Poll::Ready(IoStatusUpdate::KeepAlive)
        } else if flags.contains(Flags::BUF_W_BACKPRESSURE) {
            Poll::Ready(IoStatusUpdate::WriteBackpressure)
        } else {
            st.dispatch_task.register(cx.waker());
            Poll::Pending
        }
    }

    #[inline]
    /// Register dispatch task
    pub fn poll_dispatch(&self, cx: &mut Context<'_>) {
        self.st().dispatch_task.register(cx.waker());
    }
}

impl<F> AsRef<IoRef> for Io<F> {
    #[inline]
    fn as_ref(&self) -> &IoRef {
        self.io_ref()
    }
}

impl<F> Eq for Io<F> {}

impl<F> PartialEq for Io<F> {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.io_ref().eq(other.io_ref())
    }
}

impl<F> hash::Hash for Io<F> {
    #[inline]
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        self.io_ref().hash(state);
    }
}

impl<F> fmt::Debug for Io<F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Io").field("state", self.st()).finish()
    }
}

impl<F> ops::Deref for Io<F> {
    type Target = IoRef;

    #[inline]
    fn deref(&self) -> &Self::Target {
        self.io_ref()
    }
}

impl<F> Drop for Io<F> {
    fn drop(&mut self) {
        let st = self.st();
        self.stop_timer();

        if st.filter.is_set() {
            // the filter is stored unsafely and must be dropped explicitly;
            // it won't be dropped without special attention
            if !st.flags.get().is_stopped() {
                log::trace!(
                    "{}: Io is dropped, force stopping io streams {:?}",
                    st.tag.get(),
                    st.flags.get()
                );
            }

            self.force_close();
            st.filter.drop_filter::<F>();
        }
    }
}

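// `FilterPtr` keeps the active filter in type-erased form. The constants
// below implement a small tagging scheme: the low bits of the first
// (little-endian) or last (big-endian) byte of `data` record whether
// `data` currently holds a raw `Box` pointer to a concrete filter
// (KIND_PTR) or an inline `Sealed` value (KIND_SEALED).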
const KIND_SEALED: u8 = 0b01;
const KIND_PTR: u8 = 0b10;
const KIND_MASK: u8 = 0b11;
const KIND_UNMASK: u8 = !KIND_MASK;
const KIND_MASK_USIZE: usize = 0b11;
const KIND_UNMASK_USIZE: usize = !KIND_MASK_USIZE;
const SEALED_SIZE: usize = mem::size_of::<Sealed>();
const NULL: [u8; SEALED_SIZE] = [0u8; SEALED_SIZE];

#[cfg(target_endian = "little")]
const KIND_IDX: usize = 0;

#[cfg(target_endian = "big")]
const KIND_IDX: usize = SEALED_SIZE - 1;

struct FilterPtr {
    data: Cell<[u8; SEALED_SIZE]>,
    filter: Cell<&'static dyn Filter>,
}

impl FilterPtr {
    const fn null() -> Self {
        Self {
            data: Cell::new(NULL),
            filter: Cell::new(NullFilter::get()),
        }
    }

    fn update<F: Filter>(&self, filter: F) {
        if self.is_set() {
            panic!("Filter is set, must be dropped first");
        }

        let filter = Box::new(filter);
        let mut data = NULL;
        unsafe {
            let filter_ref: &'static dyn Filter = {
                let f: &dyn Filter = filter.as_ref();
                std::mem::transmute(f)
            };
            self.filter.set(filter_ref);

            let ptr = &mut data as *mut _ as *mut *mut F;
            ptr.write(Box::into_raw(filter));
            data[KIND_IDX] |= KIND_PTR;
            self.data.set(data);
        }
    }

    /// Get filter, panics if a plain filter pointer is not stored
    fn filter<F: Filter>(&self) -> &F {
        let data = self.data.get();
        if data[KIND_IDX] & KIND_PTR != 0 {
            let ptr = &data as *const _ as *const *mut F;
            unsafe {
                let p = (ptr.read() as *const _ as usize) & KIND_UNMASK_USIZE;
                (p as *const F as *mut F).as_ref().unwrap()
            }
        } else {
            panic!("Wrong filter item");
        }
    }

    /// Take the boxed filter, panics if a plain filter pointer is not stored
    fn take_filter<F>(&self) -> Box<F> {
        let mut data = self.data.get();
        if data[KIND_IDX] & KIND_PTR != 0 {
            data[KIND_IDX] &= KIND_UNMASK;
            let ptr = &mut data as *mut _ as *mut *mut F;
            unsafe { Box::from_raw(*ptr) }
        } else {
            panic!(
                "Wrong filter item {:?} expected: {:?}",
                data[KIND_IDX], KIND_PTR
            );
        }
    }

    /// Take the sealed filter, panics if it is not sealed
    fn take_sealed(&self) -> Sealed {
        let mut data = self.data.get();

        if data[KIND_IDX] & KIND_SEALED != 0 {
            data[KIND_IDX] &= KIND_UNMASK;
            let ptr = &mut data as *mut _ as *mut Sealed;
            unsafe { ptr.read() }
        } else {
            panic!(
                "Wrong filter item {:?} expected: {:?}",
                data[KIND_IDX], KIND_SEALED
            );
        }
    }

    fn is_set(&self) -> bool {
        self.data.get()[KIND_IDX] & KIND_MASK != 0
    }

    fn drop_filter<F>(&self) {
        let data = self.data.get();

        if data[KIND_IDX] & KIND_MASK != 0 {
            if data[KIND_IDX] & KIND_PTR != 0 {
                self.take_filter::<F>();
            } else if data[KIND_IDX] & KIND_SEALED != 0 {
                self.take_sealed();
            }
            self.data.set(NULL);
            self.filter.set(NullFilter::get());
        }
    }
}

impl FilterPtr {
    fn add_filter<F: Filter, T: FilterLayer>(&self, new: T) {
        let mut data = NULL;
        let filter = Box::new(Layer::new(new, *self.take_filter::<F>()));
        unsafe {
            let filter_ref: &'static dyn Filter = {
                let f: &dyn Filter = filter.as_ref();
                std::mem::transmute(f)
            };
            self.filter.set(filter_ref);

            let ptr = &mut data as *mut _ as *mut *mut Layer<T, F>;
            ptr.write(Box::into_raw(filter));
            data[KIND_IDX] |= KIND_PTR;
            self.data.set(data);
        }
    }

    fn seal<F: Filter>(&self) {
        let mut data = self.data.get();

        let filter = if data[KIND_IDX] & KIND_PTR != 0 {
            Sealed(Box::new(*self.take_filter::<F>()))
        } else if data[KIND_IDX] & KIND_SEALED != 0 {
            self.take_sealed()
        } else {
            panic!(
                "Wrong filter item {:?} expected: {:?}",
                data[KIND_IDX], KIND_PTR
            );
        };

        unsafe {
            let filter_ref: &'static dyn Filter = {
                let f: &dyn Filter = filter.0.as_ref();
                std::mem::transmute(f)
            };
            self.filter.set(filter_ref);

            let ptr = &mut data as *mut _ as *mut Sealed;
            ptr.write(filter);
            data[KIND_IDX] |= KIND_SEALED;
            self.data.set(data);
        }
    }
}

#[derive(Debug)]
/// `OnDisconnect` future resolves once the socket gets disconnected
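///
/// A usage sketch; the future is typically obtained from the io object
/// (e.g. via `IoRef::on_disconnect()`, defined elsewhere in this crate):
///
/// ```ignore
/// io.on_disconnect().await; // resolves once the peer is gone
/// ```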
#[must_use = "OnDisconnect does nothing unless polled"]
pub struct OnDisconnect {
    token: usize,
    inner: Rc<IoState>,
}

impl OnDisconnect {
    pub(super) fn new(inner: Rc<IoState>) -> Self {
        Self::new_inner(inner.flags.get().is_stopped(), inner)
    }

    fn new_inner(disconnected: bool, inner: Rc<IoState>) -> Self {
        let token = if disconnected {
            usize::MAX
        } else {
            let mut on_disconnect = inner.on_disconnect.take();
            let token = if let Some(ref mut on_disconnect) = on_disconnect {
                let token = on_disconnect.len();
                on_disconnect.push(LocalWaker::default());
                token
            } else {
                on_disconnect = Some(Box::new(vec![LocalWaker::default()]));
                0
            };
            inner.on_disconnect.set(on_disconnect);
            token
        };
        Self { token, inner }
    }

    #[inline]
    /// Check if connection is disconnected
    pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
        if self.token == usize::MAX || self.inner.flags.get().is_stopped() {
            Poll::Ready(())
        } else if let Some(on_disconnect) = self.inner.on_disconnect.take() {
            on_disconnect[self.token].register(cx.waker());
            self.inner.on_disconnect.set(Some(on_disconnect));
            Poll::Pending
        } else {
            Poll::Ready(())
        }
    }
}

impl Clone for OnDisconnect {
    fn clone(&self) -> Self {
        if self.token == usize::MAX {
            OnDisconnect::new_inner(true, self.inner.clone())
        } else {
            OnDisconnect::new_inner(false, self.inner.clone())
        }
    }
}

impl Future for OnDisconnect {
    type Output = ();

    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.poll_ready(cx)
    }
}

#[cfg(test)]
mod tests {
    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;

    use super::*;
    use crate::{testing::IoTest, ReadBuf, WriteBuf};

    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    #[ntex::test]
    async fn test_basics() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::new(server);
        assert!(server.eq(&server));
        assert!(server.io_ref().eq(server.io_ref()));

        assert!(format!("{:?}", Flags::IO_STOPPED).contains("IO_STOPPED"));
        assert!(Flags::IO_STOPPED == Flags::IO_STOPPED);
        assert!(Flags::IO_STOPPED != Flags::IO_STOPPING);
    }

    #[ntex::test]
    async fn test_recv() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::new(server);

        server.st().notify_timeout();
        let err = server.recv(&BytesCodec).await.err().unwrap();
        assert!(format!("{:?}", err).contains("Timeout"));

        server.st().insert_flags(Flags::DSP_STOP);
        let err = server.recv(&BytesCodec).await.err().unwrap();
        assert!(format!("{:?}", err).contains("Dispatcher stopped"));

        client.write(TEXT);
        server.st().insert_flags(Flags::BUF_W_BACKPRESSURE);
        let item = server.recv(&BytesCodec).await.ok().unwrap().unwrap();
        assert_eq!(item, TEXT);
    }

    #[ntex::test]
    async fn test_send() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::new(server);
        assert!(server.eq(&server));

        server
            .send(Bytes::from_static(BIN), &BytesCodec)
            .await
            .ok()
            .unwrap();
        let item = client.read_any();
        assert_eq!(item, TEXT);
    }

    #[derive(Debug)]
    struct DropFilter {
        p: Rc<Cell<usize>>,
    }

    impl Drop for DropFilter {
        fn drop(&mut self) {
            self.p.set(self.p.get() + 1);
        }
    }

    impl FilterLayer for DropFilter {
        const BUFFERS: bool = false;
        fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
            Ok(buf.nbytes())
        }
        fn process_write_buf(&self, _: &WriteBuf<'_>) -> io::Result<()> {
            Ok(())
        }
    }

    #[ntex::test]
    async fn drop_filter() {
        let p = Rc::new(Cell::new(0));

        let (client, server) = IoTest::create();
        let f = DropFilter { p: p.clone() };
        let _ = format!("{:?}", f);
        let io = Io::new(server).add_filter(f);

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        let io2 = io.take();
        let mut io3: crate::IoBoxed = io2.into();
        let io4 = io3.take();

        drop(io);
        drop(io3);
        drop(io4);

        assert_eq!(p.get(), 1);
    }
}