ntex_io/
io.rs

1use std::cell::{Cell, UnsafeCell};
2use std::future::{poll_fn, Future};
3use std::task::{Context, Poll};
4use std::{fmt, hash, io, marker, mem, ops, pin::Pin, ptr, rc::Rc};
5
6use ntex_bytes::{PoolId, PoolRef};
7use ntex_codec::{Decoder, Encoder};
8use ntex_util::{future::Either, task::LocalWaker, time::Seconds};
9
10use crate::buf::Stack;
11use crate::filter::{Base, Filter, Layer, NullFilter};
12use crate::flags::Flags;
13use crate::seal::{IoBoxed, Sealed};
14use crate::tasks::{ReadContext, WriteContext};
15use crate::timer::TimerHandle;
16use crate::{Decoded, FilterLayer, Handle, IoStatusUpdate, IoStream, RecvError};
17
/// Interface object to underlying io stream.
///
/// `F` is the filter type currently installed on the stream (`Base` by
/// default). The `IoRef` is kept in an `UnsafeCell` so that `take`, `seal`
/// and `add_filter` can swap the live state out through `&self`;
/// `PhantomData<F>` records the filter type without storing a value of it.
pub struct Io<F = Base>(UnsafeCell<IoRef>, marker::PhantomData<F>);
20
/// Reference-counted handle to the shared io state.
///
/// Cloning an `IoRef` only clones the inner `Rc`, so every clone observes
/// the same `IoState`.
#[derive(Clone)]
pub struct IoRef(pub(super) Rc<IoState>);
23
/// State shared between `Io`, its `IoRef` clones, the filter chain and the
/// read/write io tasks.
pub(crate) struct IoState {
    /// Type-erased storage for the installed filter chain.
    filter: FilterPtr,
    /// Io/dispatcher state bits.
    pub(super) flags: Cell<Flags>,
    /// Memory pool of the io buffers; buffers are released to it on drop.
    pub(super) pool: Cell<PoolRef>,
    /// Disconnect timeout (initialized to one second by the constructors).
    pub(super) disconnect_timeout: Cell<Seconds>,
    /// Recorded io error; set by `io_stopped`, read through `error()`.
    pub(super) error: Cell<Option<io::Error>>,
    /// Waker of the read io task.
    pub(super) read_task: LocalWaker,
    /// Waker of the write io task.
    pub(super) write_task: LocalWaker,
    /// Waker registered by the `poll_*` methods of `Io` (the dispatcher).
    pub(super) dispatch_task: LocalWaker,
    /// Layered io buffers; `Io::add_filter` adds a layer for buffering filters.
    pub(super) buffer: Stack,
    /// Handle returned by `IoStream::start`; dropped when io is stopped.
    pub(super) handle: Cell<Option<Box<dyn Handle>>>,
    /// Timer registration for this io.
    pub(super) timeout: Cell<TimerHandle>,
    /// Prefix used in trace log messages.
    pub(super) tag: Cell<&'static str>,
    /// Wakers of pending `OnDisconnect` futures, indexed by their token.
    // NOTE(review): boxing presumably keeps this rarely-populated `Option`
    // pointer-sized; confirm that is why `box_collection` is allowed.
    #[allow(clippy::box_collection)]
    pub(super) on_disconnect: Cell<Option<Box<Vec<LocalWaker>>>>,
}
40
/// Initial value of `IoState::tag`, the prefix of trace log messages.
const DEFAULT_TAG: &str = "IO";
42
43impl IoState {
44    pub(super) fn filter(&self) -> &dyn Filter {
45        self.filter.filter.get()
46    }
47
48    pub(super) fn insert_flags(&self, f: Flags) {
49        let mut flags = self.flags.get();
50        flags.insert(f);
51        self.flags.set(flags);
52    }
53
54    pub(super) fn remove_flags(&self, f: Flags) -> bool {
55        let mut flags = self.flags.get();
56        if flags.intersects(f) {
57            flags.remove(f);
58            self.flags.set(flags);
59            true
60        } else {
61            false
62        }
63    }
64
65    pub(super) fn notify_timeout(&self) {
66        let mut flags = self.flags.get();
67        if !flags.contains(Flags::DSP_TIMEOUT) {
68            flags.insert(Flags::DSP_TIMEOUT);
69            self.flags.set(flags);
70            self.dispatch_task.wake();
71            log::trace!("{}: Timer, notify dispatcher", self.tag.get());
72        }
73    }
74
75    pub(super) fn notify_disconnect(&self) {
76        if let Some(on_disconnect) = self.on_disconnect.take() {
77            for item in on_disconnect.into_iter() {
78                item.wake();
79            }
80        }
81    }
82
83    /// Get current io error
84    pub(super) fn error(&self) -> Option<io::Error> {
85        if let Some(err) = self.error.take() {
86            self.error
87                .set(Some(io::Error::new(err.kind(), format!("{}", err))));
88            Some(err)
89        } else {
90            None
91        }
92    }
93
94    /// Get current io result
95    pub(super) fn error_or_disconnected(&self) -> io::Error {
96        self.error()
97            .unwrap_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "Disconnected"))
98    }
99
100    pub(super) fn io_stopped(&self, err: Option<io::Error>) {
101        if !self.flags.get().contains(Flags::IO_STOPPED) {
102            if err.is_some() {
103                self.error.set(err);
104            }
105            self.read_task.wake();
106            self.write_task.wake();
107            self.dispatch_task.wake();
108            self.notify_disconnect();
109            self.handle.take();
110            self.insert_flags(
111                Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS,
112            );
113        }
114    }
115
116    /// Gracefully shutdown read and write io tasks
117    pub(super) fn init_shutdown(&self) {
118        if !self
119            .flags
120            .get()
121            .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
122        {
123            log::trace!(
124                "{}: Initiate io shutdown {:?}",
125                self.tag.get(),
126                self.flags.get()
127            );
128            self.insert_flags(Flags::IO_STOPPING_FILTERS);
129            self.read_task.wake();
130        }
131    }
132}
133
// Marker impl: `PartialEq` compares addresses with `ptr::eq`, which is a
// full equivalence relation, so `Eq` is valid.
impl Eq for IoState {}
135
136impl PartialEq for IoState {
137    #[inline]
138    fn eq(&self, other: &Self) -> bool {
139        ptr::eq(self, other)
140    }
141}
142
143impl hash::Hash for IoState {
144    #[inline]
145    fn hash<H: hash::Hasher>(&self, state: &mut H) {
146        (self as *const _ as usize).hash(state);
147    }
148}
149
150impl Drop for IoState {
151    #[inline]
152    fn drop(&mut self) {
153        self.buffer.release(self.pool.get());
154    }
155}
156
157impl fmt::Debug for IoState {
158    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
159        let err = self.error.take();
160        let res = f
161            .debug_struct("IoState")
162            .field("flags", &self.flags)
163            .field("filter", &self.filter.is_set())
164            .field("disconnect_timeout", &self.disconnect_timeout)
165            .field("timeout", &self.timeout)
166            .field("error", &err)
167            .field("buffer", &self.buffer)
168            .finish();
169        self.error.set(err);
170        res
171    }
172}
173
174impl Io {
175    #[inline]
176    /// Create `Io` instance
177    pub fn new<I: IoStream>(io: I) -> Self {
178        Self::with_memory_pool(io, PoolId::DEFAULT.pool_ref())
179    }
180
181    #[inline]
182    /// Create `Io` instance in specific memory pool.
183    pub fn with_memory_pool<I: IoStream>(io: I, pool: PoolRef) -> Self {
184        let inner = Rc::new(IoState {
185            filter: FilterPtr::null(),
186            pool: Cell::new(pool),
187            flags: Cell::new(Flags::WR_PAUSED),
188            error: Cell::new(None),
189            dispatch_task: LocalWaker::new(),
190            read_task: LocalWaker::new(),
191            write_task: LocalWaker::new(),
192            buffer: Stack::new(),
193            handle: Cell::new(None),
194            timeout: Cell::new(TimerHandle::default()),
195            disconnect_timeout: Cell::new(Seconds(1)),
196            on_disconnect: Cell::new(None),
197            tag: Cell::new(DEFAULT_TAG),
198        });
199        inner.filter.update(Base::new(IoRef(inner.clone())));
200
201        let io_ref = IoRef(inner);
202
203        // start io tasks
204        let hnd = io.start(ReadContext::new(&io_ref), WriteContext::new(&io_ref));
205        io_ref.0.handle.set(hnd);
206
207        Io(UnsafeCell::new(io_ref), marker::PhantomData)
208    }
209}
210
211impl<F> Io<F> {
212    #[inline]
213    /// Set memory pool
214    pub fn set_memory_pool(&self, pool: PoolRef) {
215        self.st().buffer.set_memory_pool(pool);
216        self.st().pool.set(pool);
217    }
218
219    #[inline]
220    /// Set io disconnect timeout in millis
221    pub fn set_disconnect_timeout(&self, timeout: Seconds) {
222        self.st().disconnect_timeout.set(timeout);
223    }
224
225    #[inline]
226    /// Clone current io object.
227    ///
228    /// Current io object becomes closed.
229    pub fn take(&self) -> Self {
230        Self(UnsafeCell::new(self.take_io_ref()), marker::PhantomData)
231    }
232
233    fn take_io_ref(&self) -> IoRef {
234        let inner = Rc::new(IoState {
235            filter: FilterPtr::null(),
236            pool: self.st().pool.clone(),
237            flags: Cell::new(
238                Flags::DSP_STOP
239                    | Flags::IO_STOPPED
240                    | Flags::IO_STOPPING
241                    | Flags::IO_STOPPING_FILTERS,
242            ),
243            error: Cell::new(None),
244            disconnect_timeout: Cell::new(Seconds(1)),
245            dispatch_task: LocalWaker::new(),
246            read_task: LocalWaker::new(),
247            write_task: LocalWaker::new(),
248            buffer: Stack::new(),
249            handle: Cell::new(None),
250            timeout: Cell::new(TimerHandle::default()),
251            on_disconnect: Cell::new(None),
252            tag: Cell::new(DEFAULT_TAG),
253        });
254        unsafe { mem::replace(&mut *self.0.get(), IoRef(inner)) }
255    }
256}
257
258impl<F> Io<F> {
259    #[inline]
260    #[doc(hidden)]
261    /// Get current state flags
262    pub fn flags(&self) -> Flags {
263        self.st().flags.get()
264    }
265
266    #[inline]
267    /// Get instance of `IoRef`
268    pub fn get_ref(&self) -> IoRef {
269        self.io_ref().clone()
270    }
271
272    fn st(&self) -> &IoState {
273        unsafe { &(*self.0.get()).0 }
274    }
275
276    fn io_ref(&self) -> &IoRef {
277        unsafe { &*self.0.get() }
278    }
279}
280
281impl<F: FilterLayer, T: Filter> Io<Layer<F, T>> {
282    #[inline]
283    /// Get referece to a filter
284    pub fn filter(&self) -> &F {
285        &self.st().filter.filter::<Layer<F, T>>().0
286    }
287}
288
289impl<F: Filter> Io<F> {
290    #[inline]
291    /// Convert current io stream into sealed version
292    pub fn seal(self) -> Io<Sealed> {
293        let state = self.take_io_ref();
294        state.0.filter.seal::<F>();
295
296        Io(UnsafeCell::new(state), marker::PhantomData)
297    }
298
299    #[inline]
300    /// Convert current io stream into boxed version
301    pub fn boxed(self) -> IoBoxed {
302        self.seal().into()
303    }
304
305    #[inline]
306    /// Map current filter with new one
307    pub fn add_filter<U>(self, nf: U) -> Io<Layer<U, F>>
308    where
309        U: FilterLayer,
310    {
311        let state = self.take_io_ref();
312
313        // add layer to buffers
314        if U::BUFFERS {
315            // Safety: .add_layer() only increases internal buffers
316            // there is no api that holds references into buffers storage
317            // all apis first removes buffer from storage and then work with it
318            unsafe { &mut *(Rc::as_ptr(&state.0) as *mut IoState) }
319                .buffer
320                .add_layer();
321        }
322
323        // replace current filter
324        state.0.filter.add_filter::<F, U>(nf);
325
326        Io(UnsafeCell::new(state), marker::PhantomData)
327    }
328}
329
330impl<F> Io<F> {
331    #[inline]
332    /// Read incoming io stream and decode codec item.
333    pub async fn recv<U>(
334        &self,
335        codec: &U,
336    ) -> Result<Option<U::Item>, Either<U::Error, io::Error>>
337    where
338        U: Decoder,
339    {
340        loop {
341            return match poll_fn(|cx| self.poll_recv(codec, cx)).await {
342                Ok(item) => Ok(Some(item)),
343                Err(RecvError::KeepAlive) => Err(Either::Right(io::Error::new(
344                    io::ErrorKind::TimedOut,
345                    "Timeout",
346                ))),
347                Err(RecvError::Stop) => Err(Either::Right(io::Error::new(
348                    io::ErrorKind::UnexpectedEof,
349                    "Dispatcher stopped",
350                ))),
351                Err(RecvError::WriteBackpressure) => {
352                    poll_fn(|cx| self.poll_flush(cx, false))
353                        .await
354                        .map_err(Either::Right)?;
355                    continue;
356                }
357                Err(RecvError::Decoder(err)) => Err(Either::Left(err)),
358                Err(RecvError::PeerGone(Some(err))) => Err(Either::Right(err)),
359                Err(RecvError::PeerGone(None)) => Ok(None),
360            };
361        }
362    }
363
364    #[inline]
365    /// Wait until read becomes ready.
366    pub async fn read_ready(&self) -> io::Result<Option<()>> {
367        poll_fn(|cx| self.poll_read_ready(cx)).await
368    }
369
370    #[inline]
371    /// Wait until io reads any data.
372    pub async fn read_notify(&self) -> io::Result<Option<()>> {
373        poll_fn(|cx| self.poll_read_notify(cx)).await
374    }
375
376    #[inline]
377    /// Pause read task
378    pub fn pause(&self) {
379        let st = self.st();
380        if !st.flags.get().contains(Flags::RD_PAUSED) {
381            st.read_task.wake();
382            st.insert_flags(Flags::RD_PAUSED);
383        }
384    }
385
386    #[inline]
387    /// Encode item, send to the peer. Fully flush write buffer.
388    pub async fn send<U>(
389        &self,
390        item: U::Item,
391        codec: &U,
392    ) -> Result<(), Either<U::Error, io::Error>>
393    where
394        U: Encoder,
395    {
396        self.encode(item, codec).map_err(Either::Left)?;
397
398        poll_fn(|cx| self.poll_flush(cx, true))
399            .await
400            .map_err(Either::Right)?;
401
402        Ok(())
403    }
404
405    #[inline]
406    /// Wake write task and instruct to flush data.
407    ///
408    /// This is async version of .poll_flush() method.
409    pub async fn flush(&self, full: bool) -> io::Result<()> {
410        poll_fn(|cx| self.poll_flush(cx, full)).await
411    }
412
413    #[inline]
414    /// Gracefully shutdown io stream
415    pub async fn shutdown(&self) -> io::Result<()> {
416        poll_fn(|cx| self.poll_shutdown(cx)).await
417    }
418
419    #[inline]
420    /// Polls for read readiness.
421    ///
422    /// If the io stream is not currently ready for reading,
423    /// this method will store a clone of the Waker from the provided Context.
424    /// When the io stream becomes ready for reading, Waker::wake will be called on the waker.
425    ///
426    /// Return value
427    /// The function returns:
428    ///
429    /// `Poll::Pending` if the io stream is not ready for reading.
430    /// `Poll::Ready(Ok(Some(()))))` if the io stream is ready for reading.
431    /// `Poll::Ready(Ok(None))` if io stream is disconnected
432    /// `Some(Poll::Ready(Err(e)))` if an error is encountered.
433    pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
434        let st = self.st();
435        let mut flags = st.flags.get();
436
437        if flags.is_stopped() {
438            Poll::Ready(Err(st.error_or_disconnected()))
439        } else {
440            st.dispatch_task.register(cx.waker());
441
442            let ready = flags.is_read_buf_ready();
443            if flags.cannot_read() {
444                flags.cleanup_read_flags();
445                st.read_task.wake();
446                st.flags.set(flags);
447                if ready {
448                    Poll::Ready(Ok(Some(())))
449                } else {
450                    Poll::Pending
451                }
452            } else if ready {
453                flags.remove(Flags::BUF_R_READY);
454                st.flags.set(flags);
455                Poll::Ready(Ok(Some(())))
456            } else {
457                Poll::Pending
458            }
459        }
460    }
461
462    #[inline]
463    /// Polls for any incoming data.
464    pub fn poll_read_notify(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
465        let ready = self.poll_read_ready(cx);
466
467        if ready.is_pending() {
468            let st = self.st();
469            if st.remove_flags(Flags::RD_NOTIFY) {
470                Poll::Ready(Ok(Some(())))
471            } else {
472                st.insert_flags(Flags::RD_NOTIFY);
473                Poll::Pending
474            }
475        } else {
476            ready
477        }
478    }
479
480    #[inline]
481    /// Decode codec item from incoming bytes stream.
482    ///
483    /// Wake read task and request to read more data if data is not enough for decoding.
484    /// If error get returned this method does not register waker for later wake up action.
485    pub fn poll_recv<U>(
486        &self,
487        codec: &U,
488        cx: &mut Context<'_>,
489    ) -> Poll<Result<U::Item, RecvError<U>>>
490    where
491        U: Decoder,
492    {
493        let decoded = self.poll_recv_decode(codec, cx)?;
494
495        if let Some(item) = decoded.item {
496            Poll::Ready(Ok(item))
497        } else {
498            Poll::Pending
499        }
500    }
501
502    #[doc(hidden)]
503    #[inline]
504    /// Decode codec item from incoming bytes stream.
505    ///
506    /// Wake read task and request to read more data if data is not enough for decoding.
507    /// If error get returned this method does not register waker for later wake up action.
508    pub fn poll_recv_decode<U>(
509        &self,
510        codec: &U,
511        cx: &mut Context<'_>,
512    ) -> Result<Decoded<U::Item>, RecvError<U>>
513    where
514        U: Decoder,
515    {
516        let decoded = self
517            .decode_item(codec)
518            .map_err(|err| RecvError::Decoder(err))?;
519
520        if decoded.item.is_some() {
521            Ok(decoded)
522        } else {
523            let st = self.st();
524            let flags = st.flags.get();
525            if flags.is_stopped() {
526                Err(RecvError::PeerGone(st.error()))
527            } else if flags.contains(Flags::DSP_STOP) {
528                st.remove_flags(Flags::DSP_STOP);
529                Err(RecvError::Stop)
530            } else if flags.contains(Flags::DSP_TIMEOUT) {
531                st.remove_flags(Flags::DSP_TIMEOUT);
532                Err(RecvError::KeepAlive)
533            } else if flags.contains(Flags::BUF_W_BACKPRESSURE) {
534                Err(RecvError::WriteBackpressure)
535            } else {
536                match self.poll_read_ready(cx) {
537                    Poll::Pending | Poll::Ready(Ok(Some(()))) => {
538                        if log::log_enabled!(log::Level::Trace) && decoded.remains != 0 {
539                            log::trace!(
540                                "{}: Not enough data to decode next frame",
541                                self.tag()
542                            );
543                        }
544                        Ok(decoded)
545                    }
546                    Poll::Ready(Err(e)) => Err(RecvError::PeerGone(Some(e))),
547                    Poll::Ready(Ok(None)) => Err(RecvError::PeerGone(None)),
548                }
549            }
550        }
551    }
552
553    #[inline]
554    /// Wake write task and instruct to flush data.
555    ///
556    /// If `full` is true then wake up dispatcher when all data is flushed
557    /// otherwise wake up when size of write buffer is lower than
558    /// buffer max size.
559    pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll<io::Result<()>> {
560        let st = self.st();
561        let flags = self.flags();
562
563        let len = st.buffer.write_destination_size();
564        if len > 0 {
565            if full {
566                st.insert_flags(Flags::BUF_W_MUST_FLUSH);
567                st.dispatch_task.register(cx.waker());
568                return if flags.is_stopped() {
569                    Poll::Ready(Err(st.error_or_disconnected()))
570                } else {
571                    Poll::Pending
572                };
573            } else if len >= st.pool.get().write_params_high() << 1 {
574                st.insert_flags(Flags::BUF_W_BACKPRESSURE);
575                st.dispatch_task.register(cx.waker());
576                return if flags.is_stopped() {
577                    Poll::Ready(Err(st.error_or_disconnected()))
578                } else {
579                    Poll::Pending
580                };
581            }
582        }
583        st.remove_flags(Flags::BUF_W_MUST_FLUSH | Flags::BUF_W_BACKPRESSURE);
584        Poll::Ready(Ok(()))
585    }
586
587    #[inline]
588    /// Gracefully shutdown io stream
589    pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
590        let st = self.st();
591        let flags = st.flags.get();
592
593        if flags.is_stopped() {
594            if let Some(err) = st.error() {
595                Poll::Ready(Err(err))
596            } else {
597                Poll::Ready(Ok(()))
598            }
599        } else {
600            if !flags.contains(Flags::IO_STOPPING_FILTERS) {
601                st.init_shutdown();
602            }
603
604            st.read_task.wake();
605            st.write_task.wake();
606            st.dispatch_task.register(cx.waker());
607            Poll::Pending
608        }
609    }
610
611    #[inline]
612    /// Pause read task
613    ///
614    /// Returns status updates
615    pub fn poll_read_pause(&self, cx: &mut Context<'_>) -> Poll<IoStatusUpdate> {
616        self.pause();
617        let result = self.poll_status_update(cx);
618        if !result.is_pending() {
619            self.st().dispatch_task.register(cx.waker());
620        }
621        result
622    }
623
624    #[inline]
625    /// Wait for status updates
626    pub fn poll_status_update(&self, cx: &mut Context<'_>) -> Poll<IoStatusUpdate> {
627        let st = self.st();
628        let flags = st.flags.get();
629        if flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING) {
630            Poll::Ready(IoStatusUpdate::PeerGone(st.error()))
631        } else if flags.contains(Flags::DSP_STOP) {
632            st.remove_flags(Flags::DSP_STOP);
633            Poll::Ready(IoStatusUpdate::Stop)
634        } else if flags.contains(Flags::DSP_TIMEOUT) {
635            st.remove_flags(Flags::DSP_TIMEOUT);
636            Poll::Ready(IoStatusUpdate::KeepAlive)
637        } else if flags.contains(Flags::BUF_W_BACKPRESSURE) {
638            Poll::Ready(IoStatusUpdate::WriteBackpressure)
639        } else {
640            st.dispatch_task.register(cx.waker());
641            Poll::Pending
642        }
643    }
644
645    #[inline]
646    /// Register dispatch task
647    pub fn poll_dispatch(&self, cx: &mut Context<'_>) {
648        self.st().dispatch_task.register(cx.waker());
649    }
650}
651
652impl<F> AsRef<IoRef> for Io<F> {
653    #[inline]
654    fn as_ref(&self) -> &IoRef {
655        self.io_ref()
656    }
657}
658
// Marker impl: equality delegates to the underlying `IoRef`s (see the
// `PartialEq` impl for `Io`).
impl<F> Eq for Io<F> {}
660
661impl<F> PartialEq for Io<F> {
662    #[inline]
663    fn eq(&self, other: &Self) -> bool {
664        self.io_ref().eq(other.io_ref())
665    }
666}
667
668impl<F> hash::Hash for Io<F> {
669    #[inline]
670    fn hash<H: hash::Hasher>(&self, state: &mut H) {
671        self.io_ref().hash(state);
672    }
673}
674
675impl<F> fmt::Debug for Io<F> {
676    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
677        f.debug_struct("Io").field("state", self.st()).finish()
678    }
679}
680
681impl<F> ops::Deref for Io<F> {
682    type Target = IoRef;
683
684    #[inline]
685    fn deref(&self) -> &Self::Target {
686        self.io_ref()
687    }
688}
689
690impl<F> Drop for Io<F> {
691    fn drop(&mut self) {
692        let st = self.st();
693        self.stop_timer();
694
695        if st.filter.is_set() {
696            // filter is unsafe and must be dropped explicitly,
697            // and wont be dropped without special attention
698            if !st.flags.get().is_stopped() {
699                log::trace!(
700                    "{}: Io is dropped, force stopping io streams {:?}",
701                    st.tag.get(),
702                    st.flags.get()
703                );
704            }
705
706            self.force_close();
707            st.filter.drop_filter::<F>();
708        }
709    }
710}
711
// Tag bits stored in the byte at `KIND_IDX` of `FilterPtr::data`; they
// record what the slot currently holds (all clear = empty slot).
/// Slot holds a `Sealed` value.
const KIND_SEALED: u8 = 0b01;
/// Slot holds a raw pointer to a boxed concrete filter.
const KIND_PTR: u8 = 0b10;
const KIND_MASK: u8 = 0b11;
const KIND_UNMASK: u8 = !KIND_MASK;
/// Same masks widened to `usize`, used to strip the tag from a pointer value.
const KIND_MASK_USIZE: usize = 0b11;
const KIND_UNMASK_USIZE: usize = !KIND_MASK_USIZE;
/// Size of `FilterPtr::data`: room for one `Sealed` value.
const SEALED_SIZE: usize = mem::size_of::<Sealed>();
/// Empty slot: all bytes zero, so no tag bit is set.
const NULL: [u8; SEALED_SIZE] = [0u8; SEALED_SIZE];

// Index of the tagged byte. On little-endian targets byte 0 is the least
// significant byte of the first stored pointer, so the tag lives in that
// pointer's low bits, which are assumed to be zero due to alignment.
#[cfg(target_endian = "little")]
const KIND_IDX: usize = 0;

// On big-endian targets the last byte of the buffer is used instead.
#[cfg(target_endian = "big")]
const KIND_IDX: usize = SEALED_SIZE - 1;
726
/// Type-erased storage for the io filter.
///
/// `data` holds either a raw pointer to a boxed concrete filter (`KIND_PTR`)
/// or a `Sealed` value (`KIND_SEALED`), told apart by the tag bits at
/// `KIND_IDX`; an all-zero buffer means no filter is stored.
struct FilterPtr {
    data: Cell<[u8; SEALED_SIZE]>,
    /// Trait-object view of whatever `data` holds (`NullFilter` when empty),
    /// with its lifetime erased to `'static`.
    filter: Cell<&'static dyn Filter>,
}
731
732impl FilterPtr {
733    const fn null() -> Self {
734        Self {
735            data: Cell::new(NULL),
736            filter: Cell::new(NullFilter::get()),
737        }
738    }
739
740    fn update<F: Filter>(&self, filter: F) {
741        if self.is_set() {
742            panic!("Filter is set, must be dropped first");
743        }
744
745        let filter = Box::new(filter);
746        let mut data = NULL;
747        unsafe {
748            let filter_ref: &'static dyn Filter = {
749                let f: &dyn Filter = filter.as_ref();
750                std::mem::transmute(f)
751            };
752            self.filter.set(filter_ref);
753
754            let ptr = &mut data as *mut _ as *mut *mut F;
755            ptr.write(Box::into_raw(filter));
756            data[KIND_IDX] |= KIND_PTR;
757            self.data.set(data);
758        }
759    }
760
761    /// Get filter, panic if it is not filter
762    fn filter<F: Filter>(&self) -> &F {
763        let data = self.data.get();
764        if data[KIND_IDX] & KIND_PTR != 0 {
765            let ptr = &data as *const _ as *const *mut F;
766            unsafe {
767                let p = (ptr.read() as *const _ as usize) & KIND_UNMASK_USIZE;
768                (p as *const F as *mut F).as_ref().unwrap()
769            }
770        } else {
771            panic!("Wrong filter item");
772        }
773    }
774
775    /// Get filter, panic if it is not set
776    fn take_filter<F>(&self) -> Box<F> {
777        let mut data = self.data.get();
778        if data[KIND_IDX] & KIND_PTR != 0 {
779            data[KIND_IDX] &= KIND_UNMASK;
780            let ptr = &mut data as *mut _ as *mut *mut F;
781            unsafe { Box::from_raw(*ptr) }
782        } else {
783            panic!(
784                "Wrong filter item {:?} expected: {:?}",
785                data[KIND_IDX], KIND_PTR
786            );
787        }
788    }
789
790    /// Get sealed, panic if it is already sealed
791    fn take_sealed(&self) -> Sealed {
792        let mut data = self.data.get();
793
794        if data[KIND_IDX] & KIND_SEALED != 0 {
795            data[KIND_IDX] &= KIND_UNMASK;
796            let ptr = &mut data as *mut _ as *mut Sealed;
797            unsafe { ptr.read() }
798        } else {
799            panic!(
800                "Wrong filter item {:?} expected: {:?}",
801                data[KIND_IDX], KIND_SEALED
802            );
803        }
804    }
805
806    fn is_set(&self) -> bool {
807        self.data.get()[KIND_IDX] & KIND_MASK != 0
808    }
809
810    fn drop_filter<F>(&self) {
811        let data = self.data.get();
812
813        if data[KIND_IDX] & KIND_MASK != 0 {
814            if data[KIND_IDX] & KIND_PTR != 0 {
815                self.take_filter::<F>();
816            } else if data[KIND_IDX] & KIND_SEALED != 0 {
817                self.take_sealed();
818            }
819            self.data.set(NULL);
820            self.filter.set(NullFilter::get());
821        }
822    }
823}
824
825impl FilterPtr {
826    fn add_filter<F: Filter, T: FilterLayer>(&self, new: T) {
827        let mut data = NULL;
828        let filter = Box::new(Layer::new(new, *self.take_filter::<F>()));
829        unsafe {
830            let filter_ref: &'static dyn Filter = {
831                let f: &dyn Filter = filter.as_ref();
832                std::mem::transmute(f)
833            };
834            self.filter.set(filter_ref);
835
836            let ptr = &mut data as *mut _ as *mut *mut Layer<T, F>;
837            ptr.write(Box::into_raw(filter));
838            data[KIND_IDX] |= KIND_PTR;
839            self.data.set(data);
840        }
841    }
842
843    fn seal<F: Filter>(&self) {
844        let mut data = self.data.get();
845
846        let filter = if data[KIND_IDX] & KIND_PTR != 0 {
847            Sealed(Box::new(*self.take_filter::<F>()))
848        } else if data[KIND_IDX] & KIND_SEALED != 0 {
849            self.take_sealed()
850        } else {
851            panic!(
852                "Wrong filter item {:?} expected: {:?}",
853                data[KIND_IDX], KIND_PTR
854            );
855        };
856
857        unsafe {
858            let filter_ref: &'static dyn Filter = {
859                let f: &dyn Filter = filter.0.as_ref();
860                std::mem::transmute(f)
861            };
862            self.filter.set(filter_ref);
863
864            let ptr = &mut data as *mut _ as *mut Sealed;
865            ptr.write(filter);
866            data[KIND_IDX] |= KIND_SEALED;
867            self.data.set(data);
868        }
869    }
870}
871
#[derive(Debug)]
/// Future that resolves when the socket gets disconnected.
///
/// `token` is the index of this future's waker slot in the state's
/// `on_disconnect` list, or `usize::MAX` when the io was already stopped at
/// creation time.
#[must_use = "OnDisconnect do nothing unless polled"]
pub struct OnDisconnect {
    token: usize,
    inner: Rc<IoState>,
}
879
880impl OnDisconnect {
881    pub(super) fn new(inner: Rc<IoState>) -> Self {
882        Self::new_inner(inner.flags.get().is_stopped(), inner)
883    }
884
885    fn new_inner(disconnected: bool, inner: Rc<IoState>) -> Self {
886        let token = if disconnected {
887            usize::MAX
888        } else {
889            let mut on_disconnect = inner.on_disconnect.take();
890            let token = if let Some(ref mut on_disconnect) = on_disconnect {
891                let token = on_disconnect.len();
892                on_disconnect.push(LocalWaker::default());
893                token
894            } else {
895                on_disconnect = Some(Box::new(vec![LocalWaker::default()]));
896                0
897            };
898            inner.on_disconnect.set(on_disconnect);
899            token
900        };
901        Self { token, inner }
902    }
903
904    #[inline]
905    /// Check if connection is disconnected
906    pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
907        if self.token == usize::MAX || self.inner.flags.get().is_stopped() {
908            Poll::Ready(())
909        } else if let Some(on_disconnect) = self.inner.on_disconnect.take() {
910            on_disconnect[self.token].register(cx.waker());
911            self.inner.on_disconnect.set(Some(on_disconnect));
912            Poll::Pending
913        } else {
914            Poll::Ready(())
915        }
916    }
917}
918
919impl Clone for OnDisconnect {
920    fn clone(&self) -> Self {
921        if self.token == usize::MAX {
922            OnDisconnect::new_inner(true, self.inner.clone())
923        } else {
924            OnDisconnect::new_inner(false, self.inner.clone())
925        }
926    }
927}
928
929impl Future for OnDisconnect {
930    type Output = ();
931
932    #[inline]
933    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
934        self.poll_ready(cx)
935    }
936}
937
#[cfg(test)]
mod tests {
    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;

    use super::*;
    use crate::{testing::IoTest, ReadBuf, WriteBuf};

    // Same request line, as bytes and as text.
    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    /// Equality of `Io`/`IoRef` with itself and basic `Flags` behaviour.
    #[ntex::test]
    async fn test_basics() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::new(server);
        assert!(server.eq(&server));
        assert!(server.io_ref().eq(server.io_ref()));

        assert!(format!("{:?}", Flags::IO_STOPPED).contains("IO_STOPPED"));
        assert!(Flags::IO_STOPPED == Flags::IO_STOPPED);
        assert!(Flags::IO_STOPPED != Flags::IO_STOPPING);
    }

    /// `recv` error mapping for keep-alive and stop, then a successful read
    /// while write backpressure is flagged.
    #[ntex::test]
    async fn test_recv() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::new(server);

        // DSP_TIMEOUT -> `TimedOut` error with message "Timeout".
        server.st().notify_timeout();
        let err = server.recv(&BytesCodec).await.err().unwrap();
        assert!(format!("{:?}", err).contains("Timeout"));

        // DSP_STOP -> `UnexpectedEof` error with message "Dispatcher stopped".
        server.st().insert_flags(Flags::DSP_STOP);
        let err = server.recv(&BytesCodec).await.err().unwrap();
        assert!(format!("{:?}", err).contains("Dispatcher stopped"));

        // If nothing is decoded yet, `recv` sees `WriteBackpressure`, flushes
        // (nothing is queued, so the flag is cleared) and retries.
        client.write(TEXT);
        server.st().insert_flags(Flags::BUF_W_BACKPRESSURE);
        let item = server.recv(&BytesCodec).await.ok().unwrap().unwrap();
        assert_eq!(item, TEXT);
    }

    /// `send` encodes and fully flushes data to the peer.
    #[ntex::test]
    async fn test_send() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::new(server);
        assert!(server.eq(&server));

        server
            .send(Bytes::from_static(BIN), &BytesCodec)
            .await
            .ok()
            .unwrap();
        let item = client.read_any();
        assert_eq!(item, TEXT);
    }

    /// Pass-through filter that counts how many times it is dropped.
    #[derive(Debug)]
    struct DropFilter {
        p: Rc<Cell<usize>>,
    }

    impl Drop for DropFilter {
        fn drop(&mut self) {
            self.p.set(self.p.get() + 1);
        }
    }

    impl FilterLayer for DropFilter {
        const BUFFERS: bool = false;
        fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
            Ok(buf.nbytes())
        }
        fn process_write_buf(&self, _: &WriteBuf<'_>) -> io::Result<()> {
            Ok(())
        }
    }

    /// A filter survives `take()`, boxing and another `take()`, and is
    /// dropped exactly once when the last owning `Io` goes away.
    #[ntex::test]
    async fn drop_filter() {
        let p = Rc::new(Cell::new(0));

        let (client, server) = IoTest::create();
        let f = DropFilter { p: p.clone() };
        let _ = format!("{:?}", f);
        let io = Io::new(server).add_filter(f);

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // Move the live state: io -> io2 -> io3 (boxed) -> io4.
        let io2 = io.take();
        let mut io3: crate::IoBoxed = io2.into();
        let io4 = io3.take();

        drop(io);
        drop(io3);
        drop(io4);

        assert_eq!(p.get(), 1);
    }
}