ntex_io/
io.rs

1use std::cell::{Cell, UnsafeCell};
2use std::future::{Future, poll_fn};
3use std::task::{Context, Poll};
4use std::{fmt, hash, io, marker, mem, ops, pin::Pin, ptr, rc::Rc};
5
6use ntex_codec::{Decoder, Encoder};
7use ntex_service::cfg::SharedCfg;
8use ntex_util::{future::Either, task::LocalWaker};
9
10use crate::buf::Stack;
11use crate::cfg::{BufConfig, IoConfig};
12use crate::filter::{Base, Filter, Layer, NullFilter};
13use crate::flags::Flags;
14use crate::seal::{IoBoxed, Sealed};
15use crate::timer::TimerHandle;
16use crate::{Decoded, FilterLayer, Handle, IoContext, IoStatusUpdate, IoStream, RecvError};
17
18/// Interface object to underlying io stream
19pub struct Io<F = Base>(UnsafeCell<IoRef>, marker::PhantomData<F>);
20
21#[derive(Clone)]
22pub struct IoRef(pub(super) Rc<IoState>);
23
24pub(crate) struct IoState {
25    filter: FilterPtr,
26    pub(super) cfg: Cell<&'static IoConfig>,
27    pub(super) flags: Cell<Flags>,
28    pub(super) error: Cell<Option<io::Error>>,
29    pub(super) read_task: LocalWaker,
30    pub(super) write_task: LocalWaker,
31    pub(super) dispatch_task: LocalWaker,
32    pub(super) buffer: Stack,
33    pub(super) handle: Cell<Option<Box<dyn Handle>>>,
34    pub(super) timeout: Cell<TimerHandle>,
35    #[allow(clippy::box_collection)]
36    pub(super) on_disconnect: Cell<Option<Box<Vec<LocalWaker>>>>,
37}
38
39impl IoState {
40    pub(super) fn filter(&self) -> &dyn Filter {
41        self.filter.filter.get()
42    }
43
44    pub(super) fn insert_flags(&self, f: Flags) {
45        let mut flags = self.flags.get();
46        flags.insert(f);
47        self.flags.set(flags);
48    }
49
50    pub(super) fn remove_flags(&self, f: Flags) -> bool {
51        let mut flags = self.flags.get();
52        if flags.intersects(f) {
53            flags.remove(f);
54            self.flags.set(flags);
55            true
56        } else {
57            false
58        }
59    }
60
61    pub(super) fn notify_timeout(&self) {
62        let mut flags = self.flags.get();
63        if !flags.contains(Flags::DSP_TIMEOUT) {
64            flags.insert(Flags::DSP_TIMEOUT);
65            self.flags.set(flags);
66            self.dispatch_task.wake();
67            log::trace!("{}: Timer, notify dispatcher", self.cfg.get().tag());
68        }
69    }
70
71    pub(super) fn notify_disconnect(&self) {
72        if let Some(on_disconnect) = self.on_disconnect.take() {
73            for item in on_disconnect.into_iter() {
74                item.wake();
75            }
76        }
77    }
78
79    /// Get current io error
80    pub(super) fn error(&self) -> Option<io::Error> {
81        if let Some(err) = self.error.take() {
82            self.error
83                .set(Some(io::Error::new(err.kind(), format!("{err}"))));
84            Some(err)
85        } else {
86            None
87        }
88    }
89
90    /// Get current io result
91    pub(super) fn error_or_disconnected(&self) -> io::Error {
92        self.error()
93            .unwrap_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "Disconnected"))
94    }
95
96    pub(super) fn io_stopped(&self, err: Option<io::Error>) {
97        if !self.flags.get().is_stopped() {
98            log::trace!(
99                "{}: {} Io error {:?} flags: {:?}",
100                self.cfg.get().tag(),
101                self as *const _ as usize,
102                err,
103                self.flags.get()
104            );
105
106            if err.is_some() {
107                self.error.set(err);
108            }
109            self.read_task.wake();
110            self.write_task.wake();
111            self.notify_disconnect();
112            self.handle.take();
113            self.insert_flags(
114                Flags::IO_STOPPED
115                    | Flags::IO_STOPPING
116                    | Flags::IO_STOPPING_FILTERS
117                    | Flags::BUF_R_READY,
118            );
119            if !self.dispatch_task.wake_checked() {
120                log::trace!(
121                    "{}: {} Dispatcher is not registered, flags: {:?}",
122                    self.cfg.get().tag(),
123                    self as *const _ as usize,
124                    self.flags.get()
125                );
126            }
127        }
128    }
129
130    /// Gracefully shutdown read and write io tasks
131    pub(super) fn init_shutdown(&self) {
132        if !self
133            .flags
134            .get()
135            .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
136        {
137            log::trace!(
138                "{}: Initiate io shutdown {:?}",
139                self.cfg.get().tag(),
140                self.flags.get()
141            );
142            self.insert_flags(Flags::IO_STOPPING_FILTERS);
143            self.read_task.wake();
144        }
145    }
146
147    #[inline]
148    pub(super) fn read_buf(&self) -> &BufConfig {
149        self.cfg.get().read_buf()
150    }
151
152    #[inline]
153    pub(super) fn write_buf(&self) -> &BufConfig {
154        self.cfg.get().write_buf()
155    }
156}
157
158impl Eq for IoState {}
159
160impl PartialEq for IoState {
161    #[inline]
162    fn eq(&self, other: &Self) -> bool {
163        ptr::eq(self, other)
164    }
165}
166
167impl hash::Hash for IoState {
168    #[inline]
169    fn hash<H: hash::Hasher>(&self, state: &mut H) {
170        (self as *const _ as usize).hash(state);
171    }
172}
173
174impl fmt::Debug for IoState {
175    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176        let err = self.error.take();
177        let res = f
178            .debug_struct("IoState")
179            .field("flags", &self.flags)
180            .field("filter", &self.filter.is_set())
181            .field("timeout", &self.timeout)
182            .field("error", &err)
183            .field("buffer", &self.buffer)
184            .field("cfg", &self.cfg)
185            .finish();
186        self.error.set(err);
187        res
188    }
189}
190
191impl Io {
192    #[inline]
193    /// Create `Io` instance
194    pub fn new<I: IoStream, T: Into<SharedCfg>>(io: I, cfg: T) -> Self {
195        let inner = Rc::new(IoState {
196            cfg: Cell::new(cfg.into().get::<IoConfig>().into_static()),
197            filter: FilterPtr::null(),
198            flags: Cell::new(Flags::WR_PAUSED),
199            error: Cell::new(None),
200            dispatch_task: LocalWaker::new(),
201            read_task: LocalWaker::new(),
202            write_task: LocalWaker::new(),
203            buffer: Stack::new(),
204            handle: Cell::new(None),
205            timeout: Cell::new(TimerHandle::default()),
206            on_disconnect: Cell::new(None),
207        });
208        inner.filter.update(Base::new(IoRef(inner.clone())));
209
210        let io_ref = IoRef(inner);
211
212        // start io tasks
213        let hnd = io.start(IoContext::new(&io_ref));
214        io_ref.0.handle.set(hnd);
215
216        Io(UnsafeCell::new(io_ref), marker::PhantomData)
217    }
218}
219
220impl<I: IoStream> From<I> for Io {
221    #[inline]
222    fn from(io: I) -> Io {
223        Io::new(io, SharedCfg::default())
224    }
225}
226
227impl<F> Io<F> {
228    #[inline]
229    /// Clone current io object.
230    ///
231    /// Current io object becomes closed.
232    pub fn take(&self) -> Self {
233        Self(UnsafeCell::new(self.take_io_ref()), marker::PhantomData)
234    }
235
236    fn take_io_ref(&self) -> IoRef {
237        let inner = Rc::new(IoState {
238            cfg: Cell::new(SharedCfg::default().get::<IoConfig>().into_static()),
239            filter: FilterPtr::null(),
240            flags: Cell::new(
241                Flags::DSP_STOP
242                    | Flags::IO_STOPPED
243                    | Flags::IO_STOPPING
244                    | Flags::IO_STOPPING_FILTERS,
245            ),
246            error: Cell::new(None),
247            dispatch_task: LocalWaker::new(),
248            read_task: LocalWaker::new(),
249            write_task: LocalWaker::new(),
250            buffer: Stack::new(),
251            handle: Cell::new(None),
252            timeout: Cell::new(TimerHandle::default()),
253            on_disconnect: Cell::new(None),
254        });
255        unsafe { mem::replace(&mut *self.0.get(), IoRef(inner)) }
256    }
257
258    #[inline]
259    #[doc(hidden)]
260    /// Get current state flags
261    pub fn flags(&self) -> Flags {
262        self.st().flags.get()
263    }
264
265    #[inline]
266    /// Get instance of `IoRef`
267    pub fn get_ref(&self) -> IoRef {
268        self.io_ref().clone()
269    }
270
271    fn st(&self) -> &IoState {
272        unsafe { &(*self.0.get()).0 }
273    }
274
275    fn io_ref(&self) -> &IoRef {
276        unsafe { &*self.0.get() }
277    }
278
279    #[inline]
280    /// Set shared io config
281    pub fn set_config<T: Into<SharedCfg>>(&self, cfg: T) {
282        self.st()
283            .cfg
284            .set(cfg.into().get::<IoConfig>().into_static());
285    }
286}
287
288impl<F: FilterLayer, T: Filter> Io<Layer<F, T>> {
289    #[inline]
290    /// Get referece to a filter
291    pub fn filter(&self) -> &F {
292        &self.st().filter.filter::<Layer<F, T>>().0
293    }
294}
295
296impl<F: Filter> Io<F> {
297    #[inline]
298    /// Convert current io stream into sealed version
299    pub fn seal(self) -> Io<Sealed> {
300        let state = self.take_io_ref();
301        state.0.filter.seal::<F>();
302
303        Io(UnsafeCell::new(state), marker::PhantomData)
304    }
305
306    #[inline]
307    /// Convert current io stream into boxed version
308    pub fn boxed(self) -> IoBoxed {
309        self.seal().into()
310    }
311
312    #[inline]
313    /// Add new layer current current filter
314    pub fn add_filter<U>(self, nf: U) -> Io<Layer<U, F>>
315    where
316        U: FilterLayer,
317    {
318        let state = self.take_io_ref();
319
320        // add buffers layer
321        // Safety: .add_layer() only increases internal buffers
322        // there is no api that holds references into buffers storage
323        // all apis first removes buffer from storage and then work with it
324        unsafe { &mut *(Rc::as_ptr(&state.0) as *mut IoState) }
325            .buffer
326            .add_layer();
327
328        // replace current filter
329        state.0.filter.add_filter::<F, U>(nf);
330
331        Io(UnsafeCell::new(state), marker::PhantomData)
332    }
333
334    /// Map layer
335    pub fn map_filter<U, R>(self, f: U) -> Io<R>
336    where
337        U: FnOnce(F) -> R,
338        R: Filter,
339    {
340        let state = self.take_io_ref();
341        state.0.filter.map_filter::<F, U, R>(f);
342
343        Io(UnsafeCell::new(state), marker::PhantomData)
344    }
345}
346
347impl<F> Io<F> {
348    #[inline]
349    /// Read incoming io stream and decode codec item.
350    pub async fn recv<U>(
351        &self,
352        codec: &U,
353    ) -> Result<Option<U::Item>, Either<U::Error, io::Error>>
354    where
355        U: Decoder,
356    {
357        loop {
358            return match poll_fn(|cx| self.poll_recv(codec, cx)).await {
359                Ok(item) => Ok(Some(item)),
360                Err(RecvError::KeepAlive) => Err(Either::Right(io::Error::new(
361                    io::ErrorKind::TimedOut,
362                    "Timeout",
363                ))),
364                Err(RecvError::Stop) => Err(Either::Right(io::Error::new(
365                    io::ErrorKind::UnexpectedEof,
366                    "Dispatcher stopped",
367                ))),
368                Err(RecvError::WriteBackpressure) => {
369                    poll_fn(|cx| self.poll_flush(cx, false))
370                        .await
371                        .map_err(Either::Right)?;
372                    continue;
373                }
374                Err(RecvError::Decoder(err)) => Err(Either::Left(err)),
375                Err(RecvError::PeerGone(Some(err))) => Err(Either::Right(err)),
376                Err(RecvError::PeerGone(None)) => Ok(None),
377            };
378        }
379    }
380
381    #[inline]
382    /// Wait until read becomes ready.
383    pub async fn read_ready(&self) -> io::Result<Option<()>> {
384        poll_fn(|cx| self.poll_read_ready(cx)).await
385    }
386
387    #[inline]
388    /// Wait until io reads any data.
389    pub async fn read_notify(&self) -> io::Result<Option<()>> {
390        poll_fn(|cx| self.poll_read_notify(cx)).await
391    }
392
393    #[inline]
394    /// Pause read task
395    pub fn pause(&self) {
396        let st = self.st();
397        if !st.flags.get().contains(Flags::RD_PAUSED) {
398            st.read_task.wake();
399            st.insert_flags(Flags::RD_PAUSED);
400        }
401    }
402
403    #[inline]
404    /// Encode item, send to the peer. Fully flush write buffer.
405    pub async fn send<U>(
406        &self,
407        item: U::Item,
408        codec: &U,
409    ) -> Result<(), Either<U::Error, io::Error>>
410    where
411        U: Encoder,
412    {
413        self.encode(item, codec).map_err(Either::Left)?;
414
415        poll_fn(|cx| self.poll_flush(cx, true))
416            .await
417            .map_err(Either::Right)?;
418
419        Ok(())
420    }
421
422    #[inline]
423    /// Wake write task and instruct to flush data.
424    ///
425    /// This is async version of .poll_flush() method.
426    pub async fn flush(&self, full: bool) -> io::Result<()> {
427        poll_fn(|cx| self.poll_flush(cx, full)).await
428    }
429
430    #[inline]
431    /// Gracefully shutdown io stream
432    pub async fn shutdown(&self) -> io::Result<()> {
433        poll_fn(|cx| self.poll_shutdown(cx)).await
434    }
435
436    #[inline]
437    /// Polls for read readiness.
438    ///
439    /// If the io stream is not currently ready for reading,
440    /// this method will store a clone of the Waker from the provided Context.
441    /// When the io stream becomes ready for reading, Waker::wake will be called on the waker.
442    ///
443    /// Return value
444    /// The function returns:
445    ///
446    /// `Poll::Pending` if the io stream is not ready for reading.
447    /// `Poll::Ready(Ok(Some(()))))` if the io stream is ready for reading.
448    /// `Poll::Ready(Ok(None))` if io stream is disconnected
449    /// `Some(Poll::Ready(Err(e)))` if an error is encountered.
450    pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
451        let st = self.st();
452        let mut flags = st.flags.get();
453
454        if flags.is_stopped() {
455            Poll::Ready(Err(st.error_or_disconnected()))
456        } else {
457            st.dispatch_task.register(cx.waker());
458
459            let ready = flags.is_read_buf_ready();
460            if flags.cannot_read() {
461                flags.cleanup_read_flags();
462                st.read_task.wake();
463                st.flags.set(flags);
464                if ready {
465                    Poll::Ready(Ok(Some(())))
466                } else {
467                    Poll::Pending
468                }
469            } else if ready {
470                flags.remove(Flags::BUF_R_READY);
471                st.flags.set(flags);
472                Poll::Ready(Ok(Some(())))
473            } else {
474                Poll::Pending
475            }
476        }
477    }
478
479    #[inline]
480    /// Polls for any incoming data.
481    pub fn poll_read_notify(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
482        let ready = self.poll_read_ready(cx);
483
484        if ready.is_pending() {
485            let st = self.st();
486            if st.remove_flags(Flags::RD_NOTIFY) {
487                Poll::Ready(Ok(Some(())))
488            } else {
489                st.insert_flags(Flags::RD_NOTIFY);
490                Poll::Pending
491            }
492        } else {
493            ready
494        }
495    }
496
497    #[inline]
498    /// Decode codec item from incoming bytes stream.
499    ///
500    /// Wake read task and request to read more data if data is not enough for decoding.
501    /// If error get returned this method does not register waker for later wake up action.
502    pub fn poll_recv<U>(
503        &self,
504        codec: &U,
505        cx: &mut Context<'_>,
506    ) -> Poll<Result<U::Item, RecvError<U>>>
507    where
508        U: Decoder,
509    {
510        let decoded = self.poll_recv_decode(codec, cx)?;
511
512        if let Some(item) = decoded.item {
513            Poll::Ready(Ok(item))
514        } else {
515            Poll::Pending
516        }
517    }
518
519    #[doc(hidden)]
520    #[inline]
521    /// Decode codec item from incoming bytes stream.
522    ///
523    /// Wake read task and request to read more data if data is not enough for decoding.
524    /// If error get returned this method does not register waker for later wake up action.
525    pub fn poll_recv_decode<U>(
526        &self,
527        codec: &U,
528        cx: &mut Context<'_>,
529    ) -> Result<Decoded<U::Item>, RecvError<U>>
530    where
531        U: Decoder,
532    {
533        let decoded = self
534            .decode_item(codec)
535            .map_err(|err| RecvError::Decoder(err))?;
536
537        if decoded.item.is_some() {
538            Ok(decoded)
539        } else {
540            let st = self.st();
541            let flags = st.flags.get();
542            if flags.is_stopped() {
543                Err(RecvError::PeerGone(st.error()))
544            } else if flags.contains(Flags::DSP_STOP) {
545                st.remove_flags(Flags::DSP_STOP);
546                Err(RecvError::Stop)
547            } else if flags.contains(Flags::DSP_TIMEOUT) {
548                st.remove_flags(Flags::DSP_TIMEOUT);
549                Err(RecvError::KeepAlive)
550            } else if flags.contains(Flags::BUF_W_BACKPRESSURE) {
551                Err(RecvError::WriteBackpressure)
552            } else {
553                match self.poll_read_ready(cx) {
554                    Poll::Pending | Poll::Ready(Ok(Some(()))) => {
555                        if log::log_enabled!(log::Level::Trace) && decoded.remains != 0 {
556                            log::trace!(
557                                "{}: Not enough data to decode next frame",
558                                self.tag()
559                            );
560                        }
561                        Ok(decoded)
562                    }
563                    Poll::Ready(Err(e)) => Err(RecvError::PeerGone(Some(e))),
564                    Poll::Ready(Ok(None)) => Err(RecvError::PeerGone(None)),
565                }
566            }
567        }
568    }
569
570    #[inline]
571    /// Wake write task and instruct to flush data.
572    ///
573    /// If `full` is true then wake up dispatcher when all data is flushed
574    /// otherwise wake up when size of write buffer is lower than
575    /// buffer max size.
576    pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll<io::Result<()>> {
577        let st = self.st();
578        let flags = self.flags();
579
580        let len = st.buffer.write_destination_size();
581        if len > 0 {
582            if full {
583                st.insert_flags(Flags::BUF_W_MUST_FLUSH);
584                st.dispatch_task.register(cx.waker());
585                return if flags.is_stopped() {
586                    Poll::Ready(Err(st.error_or_disconnected()))
587                } else {
588                    Poll::Pending
589                };
590            } else if len >= st.write_buf().half {
591                st.insert_flags(Flags::BUF_W_BACKPRESSURE);
592                st.dispatch_task.register(cx.waker());
593                return if flags.is_stopped() {
594                    Poll::Ready(Err(st.error_or_disconnected()))
595                } else {
596                    Poll::Pending
597                };
598            }
599        }
600        if flags.is_stopped() {
601            Poll::Ready(Err(st.error_or_disconnected()))
602        } else {
603            st.remove_flags(Flags::BUF_W_MUST_FLUSH | Flags::BUF_W_BACKPRESSURE);
604            Poll::Ready(Ok(()))
605        }
606    }
607
608    #[inline]
609    /// Gracefully shutdown io stream
610    pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
611        let st = self.st();
612        let flags = st.flags.get();
613
614        if flags.is_stopped() {
615            if let Some(err) = st.error() {
616                Poll::Ready(Err(err))
617            } else {
618                Poll::Ready(Ok(()))
619            }
620        } else {
621            if !flags.contains(Flags::IO_STOPPING_FILTERS) {
622                st.init_shutdown();
623            }
624
625            st.read_task.wake();
626            st.write_task.wake();
627            st.dispatch_task.register(cx.waker());
628            Poll::Pending
629        }
630    }
631
632    #[inline]
633    /// Pause read task
634    ///
635    /// Returns status updates
636    pub fn poll_read_pause(&self, cx: &mut Context<'_>) -> Poll<IoStatusUpdate> {
637        self.pause();
638        let result = self.poll_status_update(cx);
639        if !result.is_pending() {
640            self.st().dispatch_task.register(cx.waker());
641        }
642        result
643    }
644
645    #[inline]
646    /// Wait for status updates
647    pub fn poll_status_update(&self, cx: &mut Context<'_>) -> Poll<IoStatusUpdate> {
648        let st = self.st();
649        let flags = st.flags.get();
650        if flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING) {
651            Poll::Ready(IoStatusUpdate::PeerGone(st.error()))
652        } else if flags.contains(Flags::DSP_STOP) {
653            st.remove_flags(Flags::DSP_STOP);
654            Poll::Ready(IoStatusUpdate::Stop)
655        } else if flags.contains(Flags::DSP_TIMEOUT) {
656            st.remove_flags(Flags::DSP_TIMEOUT);
657            Poll::Ready(IoStatusUpdate::KeepAlive)
658        } else if flags.contains(Flags::BUF_W_BACKPRESSURE) {
659            Poll::Ready(IoStatusUpdate::WriteBackpressure)
660        } else {
661            st.dispatch_task.register(cx.waker());
662            Poll::Pending
663        }
664    }
665
666    #[inline]
667    /// Register dispatch task
668    pub fn poll_dispatch(&self, cx: &mut Context<'_>) {
669        self.st().dispatch_task.register(cx.waker());
670    }
671}
672
673impl<F> AsRef<IoRef> for Io<F> {
674    #[inline]
675    fn as_ref(&self) -> &IoRef {
676        self.io_ref()
677    }
678}
679
680impl<F> Eq for Io<F> {}
681
682impl<F> PartialEq for Io<F> {
683    #[inline]
684    fn eq(&self, other: &Self) -> bool {
685        self.io_ref().eq(other.io_ref())
686    }
687}
688
689impl<F> hash::Hash for Io<F> {
690    #[inline]
691    fn hash<H: hash::Hasher>(&self, state: &mut H) {
692        self.io_ref().hash(state);
693    }
694}
695
696impl<F> fmt::Debug for Io<F> {
697    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
698        f.debug_struct("Io").field("state", self.st()).finish()
699    }
700}
701
702impl<F> ops::Deref for Io<F> {
703    type Target = IoRef;
704
705    #[inline]
706    fn deref(&self) -> &Self::Target {
707        self.io_ref()
708    }
709}
710
711impl<F> Drop for Io<F> {
712    fn drop(&mut self) {
713        let st = self.st();
714        self.stop_timer();
715
716        if st.filter.is_set() {
717            // filter is unsafe and must be dropped explicitly,
718            // and wont be dropped without special attention
719            if !st.flags.get().is_stopped() {
720                log::trace!(
721                    "{}: Io is dropped, force stopping io streams {:?}",
722                    st.cfg.get().tag(),
723                    st.flags.get()
724                );
725            }
726
727            self.force_close();
728            st.filter.drop_filter::<F>();
729        }
730    }
731}
732
733const KIND_SEALED: u8 = 0b01;
734const KIND_PTR: u8 = 0b10;
735const KIND_MASK: u8 = 0b11;
736const KIND_UNMASK: u8 = !KIND_MASK;
737const KIND_MASK_USIZE: usize = 0b11;
738const KIND_UNMASK_USIZE: usize = !KIND_MASK_USIZE;
739const SEALED_SIZE: usize = mem::size_of::<Sealed>();
740const NULL: [u8; SEALED_SIZE] = [0u8; SEALED_SIZE];
741
742#[cfg(target_endian = "little")]
743const KIND_IDX: usize = 0;
744
745#[cfg(target_endian = "big")]
746const KIND_IDX: usize = SEALED_SIZE - 1;
747
748struct FilterPtr {
749    data: Cell<[u8; SEALED_SIZE]>,
750    filter: Cell<&'static dyn Filter>,
751}
752
753impl FilterPtr {
754    const fn null() -> Self {
755        Self {
756            data: Cell::new(NULL),
757            filter: Cell::new(NullFilter::get()),
758        }
759    }
760
761    fn update<F: Filter>(&self, filter: F) {
762        if self.is_set() {
763            panic!("Filter is set, must be dropped first");
764        }
765
766        let filter = Box::new(filter);
767        let mut data = NULL;
768        unsafe {
769            let filter_ref: &'static dyn Filter = {
770                let f: &dyn Filter = filter.as_ref();
771                mem::transmute(f)
772            };
773            self.filter.set(filter_ref);
774
775            let ptr = &mut data as *mut _ as *mut *mut F;
776            ptr.write(Box::into_raw(filter));
777            data[KIND_IDX] |= KIND_PTR;
778            self.data.set(data);
779        }
780    }
781
782    /// Get filter, panic if it is not filter
783    fn filter<F: Filter>(&self) -> &F {
784        let data = self.data.get();
785        if data[KIND_IDX] & KIND_PTR != 0 {
786            let ptr = &data as *const _ as *const *mut F;
787            unsafe {
788                let p = (ptr.read() as *const _ as usize) & KIND_UNMASK_USIZE;
789                (p as *const F as *mut F).as_ref().unwrap()
790            }
791        } else {
792            panic!("Wrong filter item");
793        }
794    }
795
796    /// Get filter, panic if it is not set
797    fn take_filter<F>(&self) -> Box<F> {
798        let mut data = self.data.get();
799        if data[KIND_IDX] & KIND_PTR != 0 {
800            data[KIND_IDX] &= KIND_UNMASK;
801            let ptr = &mut data as *mut _ as *mut *mut F;
802            unsafe { Box::from_raw(*ptr) }
803        } else {
804            panic!(
805                "Wrong filter item {:?} expected: {:?}",
806                data[KIND_IDX], KIND_PTR
807            );
808        }
809    }
810
811    /// Get sealed, panic if it is already sealed
812    fn take_sealed(&self) -> Sealed {
813        let mut data = self.data.get();
814
815        if data[KIND_IDX] & KIND_SEALED != 0 {
816            data[KIND_IDX] &= KIND_UNMASK;
817            let ptr = &mut data as *mut _ as *mut Sealed;
818            unsafe { ptr.read() }
819        } else {
820            panic!(
821                "Wrong filter item {:?} expected: {:?}",
822                data[KIND_IDX], KIND_SEALED
823            );
824        }
825    }
826
827    fn is_set(&self) -> bool {
828        self.data.get()[KIND_IDX] & KIND_MASK != 0
829    }
830
831    fn drop_filter<F>(&self) {
832        let data = self.data.get();
833
834        if data[KIND_IDX] & KIND_MASK != 0 {
835            if data[KIND_IDX] & KIND_PTR != 0 {
836                self.take_filter::<F>();
837            } else if data[KIND_IDX] & KIND_SEALED != 0 {
838                self.take_sealed();
839            }
840            self.data.set(NULL);
841            self.filter.set(NullFilter::get());
842        }
843    }
844}
845
846impl FilterPtr {
847    fn add_filter<F: Filter, T: FilterLayer>(&self, new: T) {
848        let data = self.data.get();
849        let filter = if data[KIND_IDX] & KIND_PTR != 0 {
850            Box::new(Layer::new(new, *self.take_filter::<F>()))
851        } else if data[KIND_IDX] & KIND_SEALED != 0 {
852            let f = Box::new(Layer::new(new, self.take_sealed()));
853            // SAFETY: If filter is marked as Sealed, then F = Sealed
854            unsafe { mem::transmute::<Box<Layer<T, Sealed>>, Box<Layer<T, F>>>(f) }
855        } else {
856            panic!(
857                "Wrong filter item {:?} expected: {:?}",
858                data[KIND_IDX], KIND_PTR
859            );
860        };
861
862        let mut data = NULL;
863        unsafe {
864            let filter_ref: &'static dyn Filter = {
865                let f: &dyn Filter = filter.as_ref();
866                mem::transmute(f)
867            };
868            self.filter.set(filter_ref);
869
870            let ptr = &mut data as *mut _ as *mut *mut Layer<T, F>;
871            ptr.write(Box::into_raw(filter));
872            data[KIND_IDX] |= KIND_PTR;
873            self.data.set(data);
874        }
875    }
876
877    fn map_filter<F: Filter, U, R>(&self, f: U)
878    where
879        U: FnOnce(F) -> R,
880        R: Filter,
881    {
882        let mut data = NULL;
883        let filter = Box::new(f(*self.take_filter::<F>()));
884        unsafe {
885            let filter_ref: &'static dyn Filter = {
886                let f: &dyn Filter = filter.as_ref();
887                mem::transmute(f)
888            };
889            self.filter.set(filter_ref);
890
891            let ptr = &mut data as *mut _ as *mut *mut R;
892            ptr.write(Box::into_raw(filter));
893            data[KIND_IDX] |= KIND_PTR;
894            self.data.set(data);
895        }
896    }
897
898    fn seal<F: Filter>(&self) {
899        let mut data = self.data.get();
900
901        let filter = if data[KIND_IDX] & KIND_PTR != 0 {
902            Sealed(Box::new(*self.take_filter::<F>()))
903        } else if data[KIND_IDX] & KIND_SEALED != 0 {
904            self.take_sealed()
905        } else {
906            panic!(
907                "Wrong filter item {:?} expected: {:?}",
908                data[KIND_IDX], KIND_PTR
909            );
910        };
911
912        unsafe {
913            let filter_ref: &'static dyn Filter = {
914                let f: &dyn Filter = filter.0.as_ref();
915                mem::transmute(f)
916            };
917            self.filter.set(filter_ref);
918
919            let ptr = &mut data as *mut _ as *mut Sealed;
920            ptr.write(filter);
921            data[KIND_IDX] |= KIND_SEALED;
922            self.data.set(data);
923        }
924    }
925}
926
927#[derive(Debug)]
928/// OnDisconnect future resolves when socket get disconnected
929#[must_use = "OnDisconnect do nothing unless polled"]
930pub struct OnDisconnect {
931    token: usize,
932    inner: Rc<IoState>,
933}
934
935impl OnDisconnect {
936    pub(super) fn new(inner: Rc<IoState>) -> Self {
937        Self::new_inner(inner.flags.get().is_stopped(), inner)
938    }
939
940    fn new_inner(disconnected: bool, inner: Rc<IoState>) -> Self {
941        let token = if disconnected {
942            usize::MAX
943        } else {
944            let mut on_disconnect = inner.on_disconnect.take();
945            let token = if let Some(ref mut on_disconnect) = on_disconnect {
946                let token = on_disconnect.len();
947                on_disconnect.push(LocalWaker::default());
948                token
949            } else {
950                on_disconnect = Some(Box::new(vec![LocalWaker::default()]));
951                0
952            };
953            inner.on_disconnect.set(on_disconnect);
954            token
955        };
956        Self { token, inner }
957    }
958
959    #[inline]
960    /// Check if connection is disconnected
961    pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
962        if self.token == usize::MAX || self.inner.flags.get().is_stopped() {
963            Poll::Ready(())
964        } else if let Some(on_disconnect) = self.inner.on_disconnect.take() {
965            on_disconnect[self.token].register(cx.waker());
966            self.inner.on_disconnect.set(Some(on_disconnect));
967            Poll::Pending
968        } else {
969            Poll::Ready(())
970        }
971    }
972}
973
974impl Clone for OnDisconnect {
975    fn clone(&self) -> Self {
976        if self.token == usize::MAX {
977            OnDisconnect::new_inner(true, self.inner.clone())
978        } else {
979            OnDisconnect::new_inner(false, self.inner.clone())
980        }
981    }
982}
983
984impl Future for OnDisconnect {
985    type Output = ();
986
987    #[inline]
988    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
989        self.poll_ready(cx)
990    }
991}
992
#[cfg(test)]
mod tests {
    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;

    use super::*;
    use crate::{ReadBuf, WriteBuf, testing::IoTest};

    // The same request line as raw bytes and as text; used both as input
    // and as expected output of the round-trip tests.
    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    /// Equality of `Io`/`IoRef` with themselves, plus `Flags` Debug/PartialEq.
    #[ntex::test]
    async fn test_basics() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::from(server);
        assert!(server.eq(&server));
        assert!(server.io_ref().eq(server.io_ref()));

        assert!(format!("{:?}", Flags::IO_STOPPED).contains("IO_STOPPED"));
        assert!(Flags::IO_STOPPED == Flags::IO_STOPPED);
        assert!(Flags::IO_STOPPED != Flags::IO_STOPPING);
    }

    /// `recv` reports injected timeout and dispatcher-stop conditions as
    /// errors, then decodes data normally.
    #[ntex::test]
    async fn test_recv() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::from(server);

        // Simulate a timeout directly on the shared state.
        server.st().notify_timeout();
        let err = server.recv(&BytesCodec).await.err().unwrap();
        assert!(format!("{err:?}").contains("Timeout"));

        // Simulate the dispatcher being stopped.
        server.st().insert_flags(Flags::DSP_STOP);
        let err = server.recv(&BytesCodec).await.err().unwrap();
        assert!(format!("{err:?}").contains("Dispatcher stopped"));

        // Data is still decoded with BUF_W_BACKPRESSURE set.
        // NOTE(review): this step presumably relies on the previous recv
        // calls having cleared the timeout / DSP_STOP state — confirm.
        client.write(TEXT);
        server.st().insert_flags(Flags::BUF_W_BACKPRESSURE);
        let item = server.recv(&BytesCodec).await.ok().unwrap().unwrap();
        assert_eq!(item, TEXT);
    }

    /// `send` encodes a frame that the peer can read back unchanged.
    #[ntex::test]
    async fn test_send() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::from(server);
        assert!(server.eq(&server));

        server
            .send(Bytes::from_static(BIN), &BytesCodec)
            .await
            .ok()
            .unwrap();
        let item = client.read_any();
        assert_eq!(item, TEXT);
    }

    /// Pass-through filter layer that increments `p` when dropped, so tests
    /// can count how many times a filter instance is destroyed.
    #[derive(Debug)]
    struct DropFilter {
        p: Rc<Cell<usize>>,
    }

    impl Drop for DropFilter {
        fn drop(&mut self) {
            self.p.set(self.p.get() + 1);
        }
    }

    impl FilterLayer for DropFilter {
        // Move the whole source buffer to the destination untouched and
        // report its length (0 when there is no source buffer).
        fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
            if let Some(src) = buf.take_src() {
                let len = src.len();
                buf.set_dst(Some(src));
                Ok(len)
            } else {
                Ok(0)
            }
        }
        // Move the source write buffer to the destination untouched.
        fn process_write_buf(&self, buf: &WriteBuf<'_>) -> io::Result<()> {
            if let Some(src) = buf.take_src() {
                buf.set_dst(Some(src));
            }
            Ok(())
        }
    }

    /// The filter passes data both ways, and after `take()`/boxing and
    /// dropping every handle it has been dropped exactly once.
    #[ntex::test]
    async fn drop_filter() {
        let p = Rc::new(Cell::new(0));

        let (client, server) = IoTest::create();
        let f = DropFilter { p: p.clone() };
        let _ = format!("{f:?}");
        let io = Io::from(server).add_filter(f);

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // Move the io through several handles (plain -> boxed -> taken).
        let io2 = io.take();
        let mut io3: crate::IoBoxed = io2.into();
        let io4 = io3.take();

        drop(io);
        drop(io3);
        drop(io4);

        // The filter instance must have been dropped exactly once.
        assert_eq!(p.get(), 1);
    }

    /// Adding a filter on top of a sealed io yields `Layer<DropFilter, Sealed>`.
    // NOTE(review): `p` is never asserted on, so the drop behaviour of the
    // layered sealed filter is not checked here.
    #[ntex::test]
    async fn test_take_sealed_filter() {
        let p = Rc::new(Cell::new(0));
        let f = DropFilter { p: p.clone() };

        let io = Io::from(IoTest::create().0).seal();
        let _io: Io<Layer<DropFilter, Sealed>> = io.add_filter(f);
    }
}