// ntex_io/io.rs

use std::cell::{Cell, UnsafeCell};
use std::future::{Future, poll_fn};
use std::task::{Context, Poll};
use std::{fmt, hash, io, marker, mem, ops, pin::Pin, ptr, rc::Rc};

use ntex_codec::{Decoder, Encoder};
use ntex_service::cfg::SharedCfg;
use ntex_util::{future::Either, task::LocalWaker};

use crate::buf::Stack;
use crate::cfg::{BufConfig, IoConfig};
use crate::filter::{Base, Filter, Layer, NullFilter};
use crate::flags::Flags;
use crate::seal::{IoBoxed, Sealed};
use crate::timer::TimerHandle;
use crate::{Decoded, FilterLayer, Handle, IoContext, IoStatusUpdate, IoStream, RecvError};

/// Interface object to underlying io stream
pub struct Io<F = Base>(UnsafeCell<IoRef>, marker::PhantomData<F>);

#[derive(Clone)]
pub struct IoRef(pub(super) Rc<IoState>);

pub(crate) struct IoState {
    filter: FilterPtr,
    pub(super) cfg: Cell<&'static IoConfig>,
    pub(super) flags: Cell<Flags>,
    pub(super) error: Cell<Option<io::Error>>,
    pub(super) read_task: LocalWaker,
    pub(super) write_task: LocalWaker,
    pub(super) dispatch_task: LocalWaker,
    pub(super) buffer: Stack,
    pub(super) handle: Cell<Option<Box<dyn Handle>>>,
    pub(super) timeout: Cell<TimerHandle>,
    #[allow(clippy::box_collection)]
    pub(super) on_disconnect: Cell<Option<Box<Vec<LocalWaker>>>>,
}

impl IoState {
    pub(super) fn filter(&self) -> &dyn Filter {
        self.filter.filter.get()
    }

    pub(super) fn insert_flags(&self, f: Flags) {
        let mut flags = self.flags.get();
        flags.insert(f);
        self.flags.set(flags);
    }

    pub(super) fn remove_flags(&self, f: Flags) -> bool {
        let mut flags = self.flags.get();
        if flags.intersects(f) {
            flags.remove(f);
            self.flags.set(flags);
            true
        } else {
            false
        }
    }

    pub(super) fn notify_timeout(&self) {
        let mut flags = self.flags.get();
        if !flags.contains(Flags::DSP_TIMEOUT) {
            flags.insert(Flags::DSP_TIMEOUT);
            self.flags.set(flags);
            self.dispatch_task.wake();
            log::trace!("{}: Timer, notify dispatcher", self.cfg.get().tag());
        }
    }

    pub(super) fn notify_disconnect(&self) {
        if let Some(on_disconnect) = self.on_disconnect.take() {
            for item in on_disconnect.into_iter() {
                item.wake();
            }
        }
    }

    /// Get current io error
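    ///
    /// The taken error is replaced with a copy (same kind and message),
    /// so subsequent callers can still observe it.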
    pub(super) fn error(&self) -> Option<io::Error> {
        if let Some(err) = self.error.take() {
            self.error
                .set(Some(io::Error::new(err.kind(), format!("{err}"))));
            Some(err)
        } else {
            None
        }
    }

    /// Get current io result
    pub(super) fn error_or_disconnected(&self) -> io::Error {
        self.error()
            .unwrap_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "Disconnected"))
    }

    pub(super) fn io_stopped(&self, err: Option<io::Error>) {
        if !self.flags.get().is_stopped() {
            log::trace!(
                "{}: {} Io error {:?} flags: {:?}",
                self.cfg.get().tag(),
                self as *const _ as usize,
                err,
                self.flags.get()
            );

            if err.is_some() {
                self.error.set(err);
            }
            self.read_task.wake();
            self.write_task.wake();
            self.notify_disconnect();
            self.handle.take();
            self.insert_flags(
                Flags::IO_STOPPED
                    | Flags::IO_STOPPING
                    | Flags::IO_STOPPING_FILTERS
                    | Flags::BUF_R_READY,
            );
            if !self.dispatch_task.wake_checked() {
                log::trace!(
                    "{}: {} Dispatcher is not registered, flags: {:?}",
                    self.cfg.get().tag(),
                    self as *const _ as usize,
                    self.flags.get()
                );
            }
        }
    }

    /// Gracefully shutdown read and write io tasks
    pub(super) fn init_shutdown(&self) {
        if !self
            .flags
            .get()
            .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
        {
            log::trace!(
                "{}: Initiate io shutdown {:?}",
                self.cfg.get().tag(),
                self.flags.get()
            );
            self.insert_flags(Flags::IO_STOPPING_FILTERS);
            self.read_task.wake();
        }
    }

    #[inline]
    pub(super) fn read_buf(&self) -> &BufConfig {
        self.cfg.get().read_buf()
    }

    #[inline]
    pub(super) fn write_buf(&self) -> &BufConfig {
        self.cfg.get().write_buf()
    }
}

impl Eq for IoState {}

impl PartialEq for IoState {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        ptr::eq(self, other)
    }
}

impl hash::Hash for IoState {
    #[inline]
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        (self as *const _ as usize).hash(state);
    }
}

impl fmt::Debug for IoState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let err = self.error.take();
        let res = f
            .debug_struct("IoState")
            .field("flags", &self.flags)
            .field("filter", &self.filter.is_set())
            .field("timeout", &self.timeout)
            .field("error", &err)
            .field("buffer", &self.buffer)
            .field("cfg", &self.cfg)
            .finish();
        self.error.set(err);
        res
    }
}

impl Io {
    #[inline]
    /// Create `Io` instance
    pub fn new<I: IoStream, T: Into<SharedCfg>>(io: I, cfg: T) -> Self {
        let inner = Rc::new(IoState {
            cfg: Cell::new(cfg.into().get::<IoConfig>().into_static()),
            filter: FilterPtr::null(),
            flags: Cell::new(Flags::WR_PAUSED),
            error: Cell::new(None),
            dispatch_task: LocalWaker::new(),
            read_task: LocalWaker::new(),
            write_task: LocalWaker::new(),
            buffer: Stack::new(),
            handle: Cell::new(None),
            timeout: Cell::new(TimerHandle::default()),
            on_disconnect: Cell::new(None),
        });
        inner.filter.update(Base::new(IoRef(inner.clone())));

        let io_ref = IoRef(inner);

        // start io tasks
        let hnd = io.start(IoContext::new(&io_ref));
        io_ref.0.handle.set(hnd);

        Io(UnsafeCell::new(io_ref), marker::PhantomData)
    }
}

impl<I: IoStream> From<I> for Io {
    #[inline]
    fn from(io: I) -> Io {
        Io::new(io, SharedCfg::default())
    }
}

impl<F> Io<F> {
    #[inline]
    /// Take the io object.
    ///
    /// The current io object is replaced with a stopped placeholder and becomes closed.
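    ///
    /// A minimal sketch (`IoTest` is this crate's testing stream, used the
    /// same way in the tests at the bottom of this file):
    ///
    /// ```ignore
    /// let (client, server) = IoTest::create();
    /// let io = Io::from(server);
    /// let io2 = io.take(); // `io` is now closed, `io2` owns the stream
    /// ```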
    pub fn take(&self) -> Self {
        Self(UnsafeCell::new(self.take_io_ref()), marker::PhantomData)
    }

    fn take_io_ref(&self) -> IoRef {
        let inner = Rc::new(IoState {
            cfg: Cell::new(SharedCfg::default().get::<IoConfig>().into_static()),
            filter: FilterPtr::null(),
            flags: Cell::new(
                Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS,
            ),
            error: Cell::new(None),
            dispatch_task: LocalWaker::new(),
            read_task: LocalWaker::new(),
            write_task: LocalWaker::new(),
            buffer: Stack::new(),
            handle: Cell::new(None),
            timeout: Cell::new(TimerHandle::default()),
            on_disconnect: Cell::new(None),
        });
        unsafe { mem::replace(&mut *self.0.get(), IoRef(inner)) }
    }

    #[inline]
    #[doc(hidden)]
    /// Get current state flags
    pub fn flags(&self) -> Flags {
        self.st().flags.get()
    }

    #[inline]
    /// Get instance of `IoRef`
    pub fn get_ref(&self) -> IoRef {
        self.io_ref().clone()
    }

    fn st(&self) -> &IoState {
        unsafe { &(*self.0.get()).0 }
    }

    fn io_ref(&self) -> &IoRef {
        unsafe { &*self.0.get() }
    }

    #[inline]
    /// Set shared io config
    pub fn set_config<T: Into<SharedCfg>>(&self, cfg: T) {
        self.st()
            .cfg
            .set(cfg.into().get::<IoConfig>().into_static());
    }
}

impl<F: FilterLayer, T: Filter> Io<Layer<F, T>> {
    #[inline]
    /// Get reference to the filter
    pub fn filter(&self) -> &F {
        &self.st().filter.filter::<Layer<F, T>>().0
    }
}

impl<F: Filter> Io<F> {
    #[inline]
    /// Convert current io stream into sealed version
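    ///
    /// Sealing erases the concrete filter type, as in this minimal sketch:
    ///
    /// ```ignore
    /// let io: Io<Sealed> = Io::from(server).seal();
    /// ```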
    pub fn seal(self) -> Io<Sealed> {
        let state = self.take_io_ref();
        state.0.filter.seal::<F>();

        Io(UnsafeCell::new(state), marker::PhantomData)
    }

    #[inline]
    /// Convert current io stream into boxed version
    pub fn boxed(self) -> IoBoxed {
        self.seal().into()
    }

    #[inline]
    /// Add a new filter layer on top of the current filter
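    ///
    /// A sketch with a hypothetical `MyFilter` that implements `FilterLayer`
    /// (see the `DropFilter` test filter at the bottom of this file for a
    /// complete implementation):
    ///
    /// ```ignore
    /// let io: Io<Layer<MyFilter, Base>> = Io::from(server).add_filter(my_filter);
    /// ```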
    pub fn add_filter<U>(self, nf: U) -> Io<Layer<U, F>>
    where
        U: FilterLayer,
    {
        let state = self.take_io_ref();

        // add buffers layer
        // Safety: .add_layer() only grows the internal buffers storage;
        // no api hands out references into that storage, every api first
        // removes a buffer from storage and then works with it
        unsafe { &mut *(Rc::as_ptr(&state.0) as *mut IoState) }
            .buffer
            .add_layer();

        // replace current filter
        state.0.filter.add_filter::<F, U>(nf);

        Io(UnsafeCell::new(state), marker::PhantomData)
    }
    /// Map the current filter into a new filter type
    pub fn map_filter<U, R>(self, f: U) -> Io<R>
    where
        U: FnOnce(F) -> R,
        R: Filter,
    {
        let state = self.take_io_ref();
        state.0.filter.map_filter::<F, U, R>(f);

        Io(UnsafeCell::new(state), marker::PhantomData)
    }
}

impl<F> Io<F> {
    #[inline]
    /// Read incoming io stream and decode codec item.
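    ///
    /// A minimal sketch using this crate's test stream and `BytesCodec`
    /// (both are exercised the same way in the tests below):
    ///
    /// ```ignore
    /// let (client, server) = IoTest::create();
    /// let io = Io::from(server);
    /// client.write("hello");
    /// let item = io.recv(&BytesCodec).await.unwrap().unwrap();
    /// ```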
    pub async fn recv<U>(
        &self,
        codec: &U,
    ) -> Result<Option<U::Item>, Either<U::Error, io::Error>>
    where
        U: Decoder,
    {
        loop {
            return match poll_fn(|cx| self.poll_recv(codec, cx)).await {
                Ok(item) => Ok(Some(item)),
                Err(RecvError::KeepAlive) => Err(Either::Right(io::Error::new(
                    io::ErrorKind::TimedOut,
                    "Timeout",
                ))),
                Err(RecvError::WriteBackpressure) => {
                    poll_fn(|cx| self.poll_flush(cx, false))
                        .await
                        .map_err(Either::Right)?;
                    continue;
                }
                Err(RecvError::Decoder(err)) => Err(Either::Left(err)),
                Err(RecvError::PeerGone(Some(err))) => Err(Either::Right(err)),
                Err(RecvError::PeerGone(None)) => Ok(None),
            };
        }
    }

    #[inline]
    /// Wait until read becomes ready.
    pub async fn read_ready(&self) -> io::Result<Option<()>> {
        poll_fn(|cx| self.poll_read_ready(cx)).await
    }

    #[inline]
    /// Wait until io reads any data.
    pub async fn read_notify(&self) -> io::Result<Option<()>> {
        poll_fn(|cx| self.poll_read_notify(cx)).await
    }

    #[inline]
    /// Pause read task
    pub fn pause(&self) {
        let st = self.st();
        if !st.flags.get().contains(Flags::RD_PAUSED) {
            st.read_task.wake();
            st.insert_flags(Flags::RD_PAUSED);
        }
    }

    #[inline]
    /// Encode item and send it to the peer, fully flushing the write buffer.
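    ///
    /// A minimal sketch, mirroring the `test_send` test below:
    ///
    /// ```ignore
    /// io.send(Bytes::from_static(b"data"), &BytesCodec).await.unwrap();
    /// ```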
    pub async fn send<U>(
        &self,
        item: U::Item,
        codec: &U,
    ) -> Result<(), Either<U::Error, io::Error>>
    where
        U: Encoder,
    {
        self.encode(item, codec).map_err(Either::Left)?;

        poll_fn(|cx| self.poll_flush(cx, true))
            .await
            .map_err(Either::Right)?;

        Ok(())
    }

    #[inline]
    /// Wake write task and instruct it to flush data.
    ///
    /// This is the async version of the `.poll_flush()` method.
    pub async fn flush(&self, full: bool) -> io::Result<()> {
        poll_fn(|cx| self.poll_flush(cx, full)).await
    }

    #[inline]
    /// Gracefully shutdown io stream
    pub async fn shutdown(&self) -> io::Result<()> {
        poll_fn(|cx| self.poll_shutdown(cx)).await
    }

    #[inline]
    /// Polls for read readiness.
    ///
    /// If the io stream is not currently ready for reading,
    /// this method will store a clone of the `Waker` from the provided `Context`.
    /// When the io stream becomes ready for reading, `Waker::wake` will be called on the waker.
    ///
    /// Return value
    /// The function returns:
    ///
    /// `Poll::Pending` if the io stream is not ready for reading.
    /// `Poll::Ready(Ok(Some(())))` if the io stream is ready for reading.
    /// `Poll::Ready(Ok(None))` if the io stream is disconnected.
    /// `Poll::Ready(Err(e))` if an error is encountered.
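    ///
    /// This is the building block behind `read_ready()`; a typical manual use:
    ///
    /// ```ignore
    /// std::future::poll_fn(|cx| io.poll_read_ready(cx)).await?;
    /// ```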
    pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
        let st = self.st();
        let mut flags = st.flags.get();

        if flags.is_stopped() {
            Poll::Ready(Err(st.error_or_disconnected()))
        } else {
            st.dispatch_task.register(cx.waker());

            let ready = flags.is_read_buf_ready();
            if flags.cannot_read() {
                flags.cleanup_read_flags();
                st.read_task.wake();
                st.flags.set(flags);
                if ready {
                    Poll::Ready(Ok(Some(())))
                } else {
                    Poll::Pending
                }
            } else if ready {
                flags.remove(Flags::BUF_R_READY);
                st.flags.set(flags);
                Poll::Ready(Ok(Some(())))
            } else {
                Poll::Pending
            }
        }
    }

    #[inline]
    /// Polls for any incoming data.
    pub fn poll_read_notify(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
        let ready = self.poll_read_ready(cx);

        if ready.is_pending() {
            let st = self.st();
            if st.remove_flags(Flags::RD_NOTIFY) {
                Poll::Ready(Ok(Some(())))
            } else {
                st.insert_flags(Flags::RD_NOTIFY);
                Poll::Pending
            }
        } else {
            ready
        }
    }

    #[inline]
    /// Decode codec item from incoming bytes stream.
    ///
    /// Wakes the read task and requests more data if the buffered bytes are not enough for decoding.
    /// If an error is returned, this method does not register the waker for a later wake-up.
    pub fn poll_recv<U>(
        &self,
        codec: &U,
        cx: &mut Context<'_>,
    ) -> Poll<Result<U::Item, RecvError<U>>>
    where
        U: Decoder,
    {
        let decoded = self.poll_recv_decode(codec, cx)?;

        if let Some(item) = decoded.item {
            Poll::Ready(Ok(item))
        } else {
            Poll::Pending
        }
    }

    #[doc(hidden)]
    #[inline]
    /// Decode codec item from incoming bytes stream.
    ///
    /// Wakes the read task and requests more data if the buffered bytes are not enough for decoding.
    /// If an error is returned, this method does not register the waker for a later wake-up.
    pub fn poll_recv_decode<U>(
        &self,
        codec: &U,
        cx: &mut Context<'_>,
    ) -> Result<Decoded<U::Item>, RecvError<U>>
    where
        U: Decoder,
    {
        let decoded = self
            .decode_item(codec)
            .map_err(RecvError::Decoder)?;

        if decoded.item.is_some() {
            Ok(decoded)
        } else {
            let st = self.st();
            let flags = st.flags.get();
            if flags.is_stopped() {
                Err(RecvError::PeerGone(st.error()))
            } else if flags.contains(Flags::DSP_TIMEOUT) {
                st.remove_flags(Flags::DSP_TIMEOUT);
                Err(RecvError::KeepAlive)
            } else if flags.contains(Flags::BUF_W_BACKPRESSURE) {
                Err(RecvError::WriteBackpressure)
            } else {
                match self.poll_read_ready(cx) {
                    Poll::Pending | Poll::Ready(Ok(Some(()))) => {
                        if log::log_enabled!(log::Level::Trace) && decoded.remains != 0 {
                            log::trace!(
                                "{}: Not enough data to decode next frame",
                                self.tag()
                            );
                        }
                        Ok(decoded)
                    }
                    Poll::Ready(Err(e)) => Err(RecvError::PeerGone(Some(e))),
                    Poll::Ready(Ok(None)) => Err(RecvError::PeerGone(None)),
                }
            }
        }
    }

    #[inline]
    /// Wake write task and instruct it to flush data.
    ///
    /// If `full` is true, the dispatcher is woken up once all data is
    /// flushed; otherwise it is woken up once the write buffer size drops
    /// below the buffer's max size.
    pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll<io::Result<()>> {
        let st = self.st();
        let flags = self.flags();

        let len = st.buffer.write_destination_size();
        if len > 0 {
            if full {
                st.insert_flags(Flags::BUF_W_MUST_FLUSH);
                st.dispatch_task.register(cx.waker());
                return if flags.is_stopped() {
                    Poll::Ready(Err(st.error_or_disconnected()))
                } else {
                    Poll::Pending
                };
            } else if len >= st.write_buf().half {
                st.insert_flags(Flags::BUF_W_BACKPRESSURE);
                st.dispatch_task.register(cx.waker());
                return if flags.is_stopped() {
                    Poll::Ready(Err(st.error_or_disconnected()))
                } else {
                    Poll::Pending
                };
            }
        }
        if flags.is_stopped() {
            Poll::Ready(Err(st.error_or_disconnected()))
        } else {
            st.remove_flags(Flags::BUF_W_MUST_FLUSH | Flags::BUF_W_BACKPRESSURE);
            Poll::Ready(Ok(()))
        }
    }

    #[inline]
    /// Gracefully shutdown io stream
    pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let st = self.st();
        let flags = st.flags.get();

        if flags.is_stopped() {
            if let Some(err) = st.error() {
                Poll::Ready(Err(err))
            } else {
                Poll::Ready(Ok(()))
            }
        } else {
            if !flags.contains(Flags::IO_STOPPING_FILTERS) {
                st.init_shutdown();
            }

            st.read_task.wake();
            st.write_task.wake();
            st.dispatch_task.register(cx.waker());
            Poll::Pending
        }
    }

    #[inline]
    /// Pause read task
    ///
    /// Returns status updates
    pub fn poll_read_pause(&self, cx: &mut Context<'_>) -> Poll<IoStatusUpdate> {
        self.pause();
        let result = self.poll_status_update(cx);
        if !result.is_pending() {
            self.st().dispatch_task.register(cx.waker());
        }
        result
    }

    #[inline]
    /// Wait for status updates
    pub fn poll_status_update(&self, cx: &mut Context<'_>) -> Poll<IoStatusUpdate> {
        let st = self.st();
        let flags = st.flags.get();
        if flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING) {
            Poll::Ready(IoStatusUpdate::PeerGone(st.error()))
        } else if flags.contains(Flags::DSP_TIMEOUT) {
            st.remove_flags(Flags::DSP_TIMEOUT);
            Poll::Ready(IoStatusUpdate::KeepAlive)
        } else if flags.contains(Flags::BUF_W_BACKPRESSURE) {
            Poll::Ready(IoStatusUpdate::WriteBackpressure)
        } else {
            st.dispatch_task.register(cx.waker());
            Poll::Pending
        }
    }

    #[inline]
    /// Register dispatch task
    pub fn poll_dispatch(&self, cx: &mut Context<'_>) {
        self.st().dispatch_task.register(cx.waker());
    }
}

impl<F> AsRef<IoRef> for Io<F> {
    #[inline]
    fn as_ref(&self) -> &IoRef {
        self.io_ref()
    }
}

impl<F> Eq for Io<F> {}

impl<F> PartialEq for Io<F> {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.io_ref().eq(other.io_ref())
    }
}

impl<F> hash::Hash for Io<F> {
    #[inline]
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        self.io_ref().hash(state);
    }
}

impl<F> fmt::Debug for Io<F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Io").field("state", self.st()).finish()
    }
}

impl<F> ops::Deref for Io<F> {
    type Target = IoRef;

    #[inline]
    fn deref(&self) -> &Self::Target {
        self.io_ref()
    }
}

impl<F> Drop for Io<F> {
    fn drop(&mut self) {
        let st = self.st();
        self.stop_timer();

        if st.filter.is_set() {
            // filter is stored unsafely and must be dropped explicitly;
            // it won't be dropped without special attention
            if !st.flags.get().is_stopped() {
                log::trace!(
                    "{}: Io is dropped, force stopping io streams {:?}",
                    st.cfg.get().tag(),
                    st.flags.get()
                );
            }

            self.force_close();
            st.filter.drop_filter::<F>();
        }
    }
}

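// `FilterPtr` keeps the current filter in `SEALED_SIZE` bytes of inline
// storage and tags it with two kind bits stored in the byte at `KIND_IDX`:
// `KIND_PTR` marks a boxed plain filter, `KIND_SEALED` an inline `Sealed`
// trait object. On little-endian targets that byte overlaps the low byte
// of the stored pointer, so real pointer values are recovered by masking
// the tag off with the `KIND_UNMASK*` constants (relying on filter boxes
// being aligned to more than the two tag bits).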
const KIND_SEALED: u8 = 0b01;
const KIND_PTR: u8 = 0b10;
const KIND_MASK: u8 = 0b11;
const KIND_UNMASK: u8 = !KIND_MASK;
const KIND_MASK_USIZE: usize = 0b11;
const KIND_UNMASK_USIZE: usize = !KIND_MASK_USIZE;
const SEALED_SIZE: usize = mem::size_of::<Sealed>();
const NULL: [u8; SEALED_SIZE] = [0u8; SEALED_SIZE];

#[cfg(target_endian = "little")]
const KIND_IDX: usize = 0;

#[cfg(target_endian = "big")]
const KIND_IDX: usize = SEALED_SIZE - 1;

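/// Tagged inline storage for the io filter chain.
///
/// `data` holds either a tagged `Box` pointer or an inline `Sealed` value,
/// while `filter` caches a `&'static dyn Filter` view into that storage
/// (the lifetime is unsafely extended; the `NullFilter` placeholder is
/// used when nothing is set).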
struct FilterPtr {
    data: Cell<[u8; SEALED_SIZE]>,
    filter: Cell<&'static dyn Filter>,
}

impl FilterPtr {
    const fn null() -> Self {
        Self {
            data: Cell::new(NULL),
            filter: Cell::new(NullFilter::get()),
        }
    }

    fn update<F: Filter>(&self, filter: F) {
        if self.is_set() {
            panic!("Filter is set, must be dropped first");
        }

        let filter = Box::new(filter);
        let mut data = NULL;
        unsafe {
            let filter_ref: &'static dyn Filter = {
                let f: &dyn Filter = filter.as_ref();
                mem::transmute(f)
            };
            self.filter.set(filter_ref);

            let ptr = &mut data as *mut _ as *mut *mut F;
            ptr.write(Box::into_raw(filter));
            data[KIND_IDX] |= KIND_PTR;
            self.data.set(data);
        }
    }

    /// Get the filter, panics if the stored item is not a plain filter
    fn filter<F: Filter>(&self) -> &F {
        let data = self.data.get();
        if data[KIND_IDX] & KIND_PTR != 0 {
            let ptr = &data as *const _ as *const *mut F;
            unsafe {
                let p = (ptr.read() as *const _ as usize) & KIND_UNMASK_USIZE;
                (p as *const F as *mut F).as_ref().unwrap()
            }
        } else {
            panic!("Wrong filter item");
        }
    }

    /// Take the boxed filter, panics if the stored item is not a plain filter
    fn take_filter<F>(&self) -> Box<F> {
        let mut data = self.data.get();
        if data[KIND_IDX] & KIND_PTR != 0 {
            data[KIND_IDX] &= KIND_UNMASK;
            let ptr = &mut data as *mut _ as *mut *mut F;
            unsafe { Box::from_raw(*ptr) }
        } else {
            panic!(
                "Wrong filter item {:?} expected: {:?}",
                data[KIND_IDX], KIND_PTR
            );
        }
    }

    /// Take the sealed filter, panics if the stored item is not sealed
    fn take_sealed(&self) -> Sealed {
        let mut data = self.data.get();

        if data[KIND_IDX] & KIND_SEALED != 0 {
            data[KIND_IDX] &= KIND_UNMASK;
            let ptr = &mut data as *mut _ as *mut Sealed;
            unsafe { ptr.read() }
        } else {
            panic!(
                "Wrong filter item {:?} expected: {:?}",
                data[KIND_IDX], KIND_SEALED
            );
        }
    }

    fn is_set(&self) -> bool {
        self.data.get()[KIND_IDX] & KIND_MASK != 0
    }

    fn drop_filter<F>(&self) {
        let data = self.data.get();

        if data[KIND_IDX] & KIND_MASK != 0 {
            if data[KIND_IDX] & KIND_PTR != 0 {
                self.take_filter::<F>();
            } else if data[KIND_IDX] & KIND_SEALED != 0 {
                self.take_sealed();
            }
            self.data.set(NULL);
            self.filter.set(NullFilter::get());
        }
    }
}

impl FilterPtr {
    fn add_filter<F: Filter, T: FilterLayer>(&self, new: T) {
        let data = self.data.get();
        let filter = if data[KIND_IDX] & KIND_PTR != 0 {
            Box::new(Layer::new(new, *self.take_filter::<F>()))
        } else if data[KIND_IDX] & KIND_SEALED != 0 {
            let f = Box::new(Layer::new(new, self.take_sealed()));
            // SAFETY: If filter is marked as Sealed, then F = Sealed
            unsafe { mem::transmute::<Box<Layer<T, Sealed>>, Box<Layer<T, F>>>(f) }
        } else {
            panic!(
                "Wrong filter item {:?} expected: {:?}",
                data[KIND_IDX], KIND_PTR
            );
        };

        let mut data = NULL;
        unsafe {
            let filter_ref: &'static dyn Filter = {
                let f: &dyn Filter = filter.as_ref();
                mem::transmute(f)
            };
            self.filter.set(filter_ref);

            let ptr = &mut data as *mut _ as *mut *mut Layer<T, F>;
            ptr.write(Box::into_raw(filter));
            data[KIND_IDX] |= KIND_PTR;
            self.data.set(data);
        }
    }

    fn map_filter<F: Filter, U, R>(&self, f: U)
    where
        U: FnOnce(F) -> R,
        R: Filter,
    {
        let mut data = NULL;
        let filter = Box::new(f(*self.take_filter::<F>()));
        unsafe {
            let filter_ref: &'static dyn Filter = {
                let f: &dyn Filter = filter.as_ref();
                mem::transmute(f)
            };
            self.filter.set(filter_ref);

            let ptr = &mut data as *mut _ as *mut *mut R;
            ptr.write(Box::into_raw(filter));
            data[KIND_IDX] |= KIND_PTR;
            self.data.set(data);
        }
    }

    fn seal<F: Filter>(&self) {
        let mut data = self.data.get();

        let filter = if data[KIND_IDX] & KIND_PTR != 0 {
            Sealed(Box::new(*self.take_filter::<F>()))
        } else if data[KIND_IDX] & KIND_SEALED != 0 {
            self.take_sealed()
        } else {
            panic!(
                "Wrong filter item {:?} expected: {:?}",
                data[KIND_IDX], KIND_PTR
            );
        };

        unsafe {
            let filter_ref: &'static dyn Filter = {
                let f: &dyn Filter = filter.0.as_ref();
                mem::transmute(f)
            };
            self.filter.set(filter_ref);

            let ptr = &mut data as *mut _ as *mut Sealed;
            ptr.write(filter);
            data[KIND_IDX] |= KIND_SEALED;
            self.data.set(data);
        }
    }
}

#[derive(Debug)]
/// OnDisconnect future resolves once the socket gets disconnected
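///
/// A sketch, assuming the future was obtained from the io object (e.g. via
/// an `on_disconnect()` accessor defined elsewhere in this crate):
///
/// ```ignore
/// io.on_disconnect().await; // resolves once the peer is gone
/// ```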
#[must_use = "OnDisconnect does nothing unless polled"]
pub struct OnDisconnect {
    token: usize,
    inner: Rc<IoState>,
}

impl OnDisconnect {
    pub(super) fn new(inner: Rc<IoState>) -> Self {
        Self::new_inner(inner.flags.get().is_stopped(), inner)
    }

    fn new_inner(disconnected: bool, inner: Rc<IoState>) -> Self {
        let token = if disconnected {
            usize::MAX
        } else {
            let mut on_disconnect = inner.on_disconnect.take();
            let token = if let Some(ref mut on_disconnect) = on_disconnect {
                let token = on_disconnect.len();
                on_disconnect.push(LocalWaker::default());
                token
            } else {
                on_disconnect = Some(Box::new(vec![LocalWaker::default()]));
                0
            };
            inner.on_disconnect.set(on_disconnect);
            token
        };
        Self { token, inner }
    }

    #[inline]
    /// Check if connection is disconnected
    pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
        if self.token == usize::MAX || self.inner.flags.get().is_stopped() {
            Poll::Ready(())
        } else if let Some(on_disconnect) = self.inner.on_disconnect.take() {
            on_disconnect[self.token].register(cx.waker());
            self.inner.on_disconnect.set(Some(on_disconnect));
            Poll::Pending
        } else {
            Poll::Ready(())
        }
    }
}

impl Clone for OnDisconnect {
    fn clone(&self) -> Self {
        OnDisconnect::new_inner(self.token == usize::MAX, self.inner.clone())
    }
}

impl Future for OnDisconnect {
    type Output = ();

    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.poll_ready(cx)
    }
}

#[cfg(test)]
mod tests {
    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;

    use super::*;
    use crate::{ReadBuf, WriteBuf, testing::IoTest};

    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    #[ntex::test]
    async fn test_basics() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::from(server);
        assert!(server.eq(&server));
        assert!(server.io_ref().eq(server.io_ref()));

        assert!(format!("{:?}", Flags::IO_STOPPED).contains("IO_STOPPED"));
        assert!(Flags::IO_STOPPED == Flags::IO_STOPPED);
        assert!(Flags::IO_STOPPED != Flags::IO_STOPPING);
    }

    #[ntex::test]
    async fn test_recv() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::from(server);

        server.st().notify_timeout();
        let err = server.recv(&BytesCodec).await.err().unwrap();
        assert!(format!("{err:?}").contains("Timeout"));

        client.write(TEXT);
        server.st().insert_flags(Flags::BUF_W_BACKPRESSURE);
        let item = server.recv(&BytesCodec).await.ok().unwrap().unwrap();
        assert_eq!(item, TEXT);
    }

    #[ntex::test]
    async fn test_send() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::from(server);
        assert!(server.eq(&server));

        server
            .send(Bytes::from_static(BIN), &BytesCodec)
            .await
            .ok()
            .unwrap();
        let item = client.read_any();
        assert_eq!(item, TEXT);
    }

    #[derive(Debug)]
    struct DropFilter {
        p: Rc<Cell<usize>>,
    }

    impl Drop for DropFilter {
        fn drop(&mut self) {
            self.p.set(self.p.get() + 1);
        }
    }

    impl FilterLayer for DropFilter {
        fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
            if let Some(src) = buf.take_src() {
                let len = src.len();
                buf.set_dst(Some(src));
                Ok(len)
            } else {
                Ok(0)
            }
        }
        fn process_write_buf(&self, buf: &WriteBuf<'_>) -> io::Result<()> {
            if let Some(src) = buf.take_src() {
                buf.set_dst(Some(src));
            }
            Ok(())
        }
    }

    #[ntex::test]
    async fn drop_filter() {
        let p = Rc::new(Cell::new(0));

        let (client, server) = IoTest::create();
        let f = DropFilter { p: p.clone() };
        let _ = format!("{f:?}");
        let io = Io::from(server).add_filter(f);

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        let io2 = io.take();
        let mut io3: crate::IoBoxed = io2.into();
        let io4 = io3.take();

        drop(io);
        drop(io3);
        drop(io4);

        assert_eq!(p.get(), 1);
    }

    #[ntex::test]
    async fn test_take_sealed_filter() {
        let p = Rc::new(Cell::new(0));
        let f = DropFilter { p: p.clone() };

        let io = Io::from(IoTest::create().0).seal();
        let _io: Io<Layer<DropFilter, Sealed>> = io.add_filter(f);
    }
}