
// ntex_io/io.rs

use std::cell::{Cell, UnsafeCell};
use std::future::{Future, poll_fn};
use std::task::{Context, Poll};
use std::{fmt, hash, io, marker, mem, ops, pin::Pin, ptr, rc::Rc};

use ntex_codec::{Decoder, Encoder};
use ntex_service::cfg::SharedCfg;
use ntex_util::{future::Either, task::LocalWaker};

use crate::buf::Stack;
use crate::cfg::{BufConfig, IoConfig};
use crate::filter::{Base, Filter, Layer};
use crate::filterptr::FilterPtr;
use crate::flags::Flags;
use crate::seal::{IoBoxed, Sealed};
use crate::timer::TimerHandle;
use crate::{Decoded, FilterLayer, Handle, IoContext, IoStatusUpdate, IoStream, RecvError};

/// Interface object to underlying io stream
pub struct Io<F = Base>(UnsafeCell<IoRef>, marker::PhantomData<F>);

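// Cloning an `IoRef` only bumps the `Rc` reference count; all clones share
// the same underlying `IoState`.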
#[derive(Clone)]
pub struct IoRef(pub(super) Rc<IoState>);

pub(crate) struct IoState {
    filter: FilterPtr,
    pub(super) cfg: Cell<&'static IoConfig>,
    pub(super) flags: Cell<Flags>,
    pub(super) error: Cell<Option<io::Error>>,
    pub(super) read_task: LocalWaker,
    pub(super) write_task: LocalWaker,
    pub(super) dispatch_task: LocalWaker,
    pub(super) buffer: Stack,
    pub(super) handle: Cell<Option<Box<dyn Handle>>>,
    pub(super) timeout: Cell<TimerHandle>,
    #[allow(clippy::box_collection)]
    pub(super) on_disconnect: Cell<Option<Box<Vec<LocalWaker>>>>,
}

impl IoState {
    pub(super) fn filter(&self) -> &dyn Filter {
        self.filter.get()
    }

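    // `flags` lives in a `Cell`, so every update below follows a
    // copy-modify-store pattern: read the bitflags value, mutate the
    // copy, then write it back.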
    pub(super) fn insert_flags(&self, f: Flags) {
        let mut flags = self.flags.get();
        flags.insert(f);
        self.flags.set(flags);
    }

    pub(super) fn remove_flags(&self, f: Flags) -> bool {
        let mut flags = self.flags.get();
        if flags.intersects(f) {
            flags.remove(f);
            self.flags.set(flags);
            true
        } else {
            false
        }
    }

    pub(super) fn notify_timeout(&self) {
        let mut flags = self.flags.get();
        if !flags.contains(Flags::DSP_TIMEOUT) {
            flags.insert(Flags::DSP_TIMEOUT);
            self.flags.set(flags);
            self.dispatch_task.wake();
            log::trace!("{}: Timer, notify dispatcher", self.cfg.get().tag());
        }
    }

    pub(super) fn notify_disconnect(&self) {
        if let Some(on_disconnect) = self.on_disconnect.take() {
            for item in on_disconnect.into_iter() {
                item.wake();
            }
        }
    }

    /// Get current io error
    pub(super) fn error(&self) -> Option<io::Error> {
        if let Some(err) = self.error.take() {
            self.error
                .set(Some(io::Error::new(err.kind(), format!("{err}"))));
            Some(err)
        } else {
            None
        }
    }

    /// Get the current io error, or a generic "Disconnected" error if none is set
    pub(super) fn error_or_disconnected(&self) -> io::Error {
        self.error()
            .unwrap_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "Disconnected"))
    }

    pub(super) fn io_stopped(&self, err: Option<io::Error>) {
        if !self.flags.get().is_stopped() {
            log::trace!(
                "{}: {:?} Io error {:?} flags: {:?}",
                self.cfg.get().tag(),
                ptr::from_ref(self),
                err,
                self.flags.get()
            );

            if err.is_some() {
                self.error.set(err);
            }
            self.read_task.wake();
            self.write_task.wake();
            self.notify_disconnect();
            self.handle.take();
            self.insert_flags(
                Flags::IO_STOPPED
                    | Flags::IO_STOPPING
                    | Flags::IO_STOPPING_FILTERS
                    | Flags::BUF_R_READY,
            );
            if !self.dispatch_task.wake_checked() {
                log::trace!(
                    "{}: {:?} Dispatcher is not registered, flags: {:?}",
                    self.cfg.get().tag(),
                    ptr::from_ref(self),
                    self.flags.get()
                );
            }
        }
    }

    /// Gracefully shutdown read and write io tasks
    pub(super) fn init_shutdown(&self) {
        if !self
            .flags
            .get()
            .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
        {
            log::trace!(
                "{}: Initiate io shutdown {:?}",
                self.cfg.get().tag(),
                self.flags.get()
            );
            self.insert_flags(Flags::IO_STOPPING_FILTERS);
            self.read_task.wake();
        }
    }

    #[inline]
    pub(super) fn read_buf(&self) -> &BufConfig {
        self.cfg.get().read_buf()
    }

    #[inline]
    pub(super) fn write_buf(&self) -> &BufConfig {
        self.cfg.get().write_buf()
    }
}

impl Eq for IoState {}

impl PartialEq for IoState {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        ptr::eq(self, other)
    }
}

impl hash::Hash for IoState {
    #[inline]
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        (ptr::from_ref(self) as usize).hash(state);
    }
}

impl fmt::Debug for IoState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let err = self.error.take();
        let res = f
            .debug_struct("IoState")
            .field("flags", &self.flags)
            .field("filter", &self.filter.is_set())
            .field("timeout", &self.timeout)
            .field("error", &err)
            .field("buffer", &self.buffer)
            .field("cfg", &self.cfg)
            .finish();
        self.error.set(err);
        res
    }
}

impl Io {
    #[inline]
    /// Create `Io` instance
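    ///
    /// A construction sketch (assuming some `stream` value that
    /// implements `IoStream`):
    ///
    /// ```ignore
    /// let io: Io = Io::new(stream, SharedCfg::default());
    /// ```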
    pub fn new<I: IoStream, T: Into<SharedCfg>>(io: I, cfg: T) -> Self {
        let inner = Rc::new(IoState {
            cfg: Cell::new(cfg.into().get::<IoConfig>().into_static()),
            filter: FilterPtr::null(),
            flags: Cell::new(Flags::WR_PAUSED),
            error: Cell::new(None),
            dispatch_task: LocalWaker::new(),
            read_task: LocalWaker::new(),
            write_task: LocalWaker::new(),
            buffer: Stack::new(),
            handle: Cell::new(None),
            timeout: Cell::new(TimerHandle::default()),
            on_disconnect: Cell::new(None),
        });
        inner.filter.set(Base::new(IoRef(inner.clone())));

        let io_ref = IoRef(inner);

        // start io tasks
        let hnd = io.start(IoContext::new(&io_ref));
        io_ref.0.handle.set(hnd);

        Io(UnsafeCell::new(io_ref), marker::PhantomData)
    }
}

impl<I: IoStream> From<I> for Io {
    #[inline]
    fn from(io: I) -> Io {
        Io::new(io, SharedCfg::default())
    }
}

impl<F> Io<F> {
    #[inline]
    #[must_use]
    /// Take the io object, leaving this handle closed.
    ///
    /// The returned `Io` owns the live stream; the current io object is
    /// replaced with a stopped placeholder and becomes closed.
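    ///
    /// A sketch of the intended pattern (hypothetical `io`):
    ///
    /// ```ignore
    /// let active = io.take(); // `active` now drives the stream
    /// drop(io);               // the original handle is already closed
    /// ```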
    pub fn take(&self) -> Self {
        Self(UnsafeCell::new(self.take_io_ref()), marker::PhantomData)
    }

    fn take_io_ref(&self) -> IoRef {
        let inner = Rc::new(IoState {
            cfg: Cell::new(SharedCfg::default().get::<IoConfig>().into_static()),
            filter: FilterPtr::null(),
            flags: Cell::new(
                Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS,
            ),
            error: Cell::new(None),
            dispatch_task: LocalWaker::new(),
            read_task: LocalWaker::new(),
            write_task: LocalWaker::new(),
            buffer: Stack::new(),
            handle: Cell::new(None),
            timeout: Cell::new(TimerHandle::default()),
            on_disconnect: Cell::new(None),
        });
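        // Safety: `Io` is not `Sync`, and the inner `IoRef` is only
        // replaced here while no borrow obtained via `io_ref()`/`st()`
        // is held across the call.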
        unsafe { mem::replace(&mut *self.0.get(), IoRef(inner)) }
    }

    #[inline]
    #[doc(hidden)]
    /// Get current state flags
    pub fn flags(&self) -> Flags {
        self.st().flags.get()
    }

    #[inline]
    /// Get instance of `IoRef`
    pub fn get_ref(&self) -> IoRef {
        self.io_ref().clone()
    }

    fn st(&self) -> &IoState {
        unsafe { &(*self.0.get()).0 }
    }

    fn io_ref(&self) -> &IoRef {
        unsafe { &*self.0.get() }
    }

    #[inline]
    /// Set shared io config
    pub fn set_config<T: Into<SharedCfg>>(&self, cfg: T) {
        self.st()
            .cfg
            .set(cfg.into().get::<IoConfig>().into_static());
    }
}

impl<F: FilterLayer, T: Filter> Io<Layer<F, T>> {
    #[inline]
    /// Get a reference to the filter
    pub fn filter(&self) -> &F {
        &self.st().filter.filter::<Layer<F, T>>().0
    }
}

impl<F: Filter> Io<F> {
    #[inline]
    /// Convert current io stream into sealed version
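    ///
    /// A sketch; sealing type-erases the concrete filter chain:
    ///
    /// ```ignore
    /// let sealed: Io<Sealed> = io.seal();
    /// ```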
    pub fn seal(self) -> Io<Sealed> {
        let state = self.take_io_ref();
        state.0.filter.seal::<F>();

        Io(UnsafeCell::new(state), marker::PhantomData)
    }

    #[inline]
    /// Convert current io stream into boxed version
    pub fn boxed(self) -> IoBoxed {
        self.seal().into()
    }

    #[inline]
    /// Add a new layer on top of the current filter
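    ///
    /// A layering sketch (assuming a hypothetical `MyFilter` type that
    /// implements `FilterLayer`):
    ///
    /// ```ignore
    /// let io: Io<Layer<MyFilter, Base>> = io.add_filter(MyFilter::new());
    /// ```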
    pub fn add_filter<U>(self, nf: U) -> Io<Layer<U, F>>
    where
        U: FilterLayer,
    {
        let state = self.take_io_ref();

        // add buffers layer
        // Safety: .add_layer() only increases internal buffers;
        // there is no api that holds references into the buffers storage.
        // All apis first remove the buffer from storage and then work with it.
        unsafe { &mut *(Rc::as_ptr(&state.0).cast_mut()) }
            .buffer
            .add_layer();

        // replace current filter
        state.0.filter.add_filter::<F, U>(nf);

        Io(UnsafeCell::new(state), marker::PhantomData)
    }

    /// Map the current filter into a new filter type
    pub fn map_filter<U, R>(self, f: U) -> Io<R>
    where
        U: FnOnce(F) -> R,
        R: Filter,
    {
        let state = self.take_io_ref();
        state.0.filter.map_filter::<F, U, R>(f);

        Io(UnsafeCell::new(state), marker::PhantomData)
    }
}

impl<F> Io<F> {
    #[inline]
    /// Read incoming io stream and decode codec item.
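    ///
    /// A receive-loop sketch using `BytesCodec` from `ntex_codec`:
    ///
    /// ```ignore
    /// while let Some(frame) = io.recv(&BytesCodec).await? {
    ///     // process `frame`
    /// }
    /// ```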
    pub async fn recv<U>(
        &self,
        codec: &U,
    ) -> Result<Option<U::Item>, Either<U::Error, io::Error>>
    where
        U: Decoder,
    {
        loop {
            return match poll_fn(|cx| self.poll_recv(codec, cx)).await {
                Ok(item) => Ok(Some(item)),
                Err(RecvError::KeepAlive) => Err(Either::Right(io::Error::new(
                    io::ErrorKind::TimedOut,
                    "Timeout",
                ))),
                Err(RecvError::WriteBackpressure) => {
                    poll_fn(|cx| self.poll_flush(cx, false))
                        .await
                        .map_err(Either::Right)?;
                    continue;
                }
                Err(RecvError::Decoder(err)) => Err(Either::Left(err)),
                Err(RecvError::PeerGone(Some(err))) => Err(Either::Right(err)),
                Err(RecvError::PeerGone(None)) => Ok(None),
            };
        }
    }

    #[inline]
    /// Wait until read becomes ready.
    pub async fn read_ready(&self) -> io::Result<Option<()>> {
        poll_fn(|cx| self.poll_read_ready(cx)).await
    }

    #[inline]
    /// Wait until io reads any data.
    pub async fn read_notify(&self) -> io::Result<Option<()>> {
        poll_fn(|cx| self.poll_read_notify(cx)).await
    }

    #[inline]
    /// Pause read task
    pub fn pause(&self) {
        let st = self.st();
        if !st.flags.get().contains(Flags::RD_PAUSED) {
            st.read_task.wake();
            st.insert_flags(Flags::RD_PAUSED);
        }
    }

    #[inline]
    /// Encode an item and send it to the peer, fully flushing the write buffer.
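    ///
    /// A send sketch with `BytesCodec`; the error type combines codec and
    /// io failures via `Either`:
    ///
    /// ```ignore
    /// io.send(Bytes::from_static(b"ping"), &BytesCodec).await?;
    /// ```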
    pub async fn send<U>(
        &self,
        item: U::Item,
        codec: &U,
    ) -> Result<(), Either<U::Error, io::Error>>
    where
        U: Encoder,
    {
        self.encode(item, codec).map_err(Either::Left)?;

        poll_fn(|cx| self.poll_flush(cx, true))
            .await
            .map_err(Either::Right)?;

        Ok(())
    }

    #[inline]
    /// Wake write task and instruct to flush data.
    ///
    /// This is the async version of the `poll_flush()` method.
    pub async fn flush(&self, full: bool) -> io::Result<()> {
        poll_fn(|cx| self.poll_flush(cx, full)).await
    }

    #[inline]
    /// Gracefully shutdown io stream
    pub async fn shutdown(&self) -> io::Result<()> {
        poll_fn(|cx| self.poll_shutdown(cx)).await
    }

    #[inline]
    /// Polls for read readiness.
    ///
    /// If the io stream is not currently ready for reading,
    /// this method will store a clone of the Waker from the provided Context.
    /// When the io stream becomes ready for reading, `Waker::wake()` will be called on the waker.
    /// Return value
    ///
    /// The function returns:
    ///
    /// * `Poll::Pending` if the io stream is not ready for reading.
    /// * `Poll::Ready(Ok(Some(())))` if the io stream is ready for reading.
    /// * `Poll::Ready(Ok(None))` if the io stream is disconnected.
    /// * `Poll::Ready(Err(e))` if an error is encountered.
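    ///
    /// The async `read_ready()` method is a thin wrapper over this, e.g.:
    ///
    /// ```ignore
    /// std::future::poll_fn(|cx| io.poll_read_ready(cx)).await?;
    /// ```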
    pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
        let st = self.st();
        let mut flags = st.flags.get();

        if flags.is_stopped() {
            Poll::Ready(Err(st.error_or_disconnected()))
        } else {
            st.dispatch_task.register(cx.waker());

            let ready = flags.is_read_buf_ready();
            if flags.cannot_read() {
                flags.cleanup_read_flags();
                st.read_task.wake();
                st.flags.set(flags);
                if ready {
                    Poll::Ready(Ok(Some(())))
                } else {
                    Poll::Pending
                }
            } else if ready {
                flags.remove(Flags::BUF_R_READY);
                st.flags.set(flags);
                Poll::Ready(Ok(Some(())))
            } else {
                Poll::Pending
            }
        }
    }

    #[inline]
    /// Polls for any incoming data.
    pub fn poll_read_notify(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
        let ready = self.poll_read_ready(cx);

        if ready.is_pending() {
            let st = self.st();
            if st.remove_flags(Flags::RD_NOTIFY) {
                Poll::Ready(Ok(Some(())))
            } else {
                st.insert_flags(Flags::RD_NOTIFY);
                Poll::Pending
            }
        } else {
            ready
        }
    }

    #[inline]
    /// Decode codec item from incoming bytes stream.
    ///
    /// Wakes the read task and requests more data when there is not enough to decode an item.
    /// If an error is returned, this method does not register the waker for a later wake-up.
    pub fn poll_recv<U>(
        &self,
        codec: &U,
        cx: &mut Context<'_>,
    ) -> Poll<Result<U::Item, RecvError<U>>>
    where
        U: Decoder,
    {
        let decoded = self.poll_recv_decode(codec, cx)?;

        if let Some(item) = decoded.item {
            Poll::Ready(Ok(item))
        } else {
            Poll::Pending
        }
    }

    #[doc(hidden)]
    #[inline]
    /// Decode codec item from incoming bytes stream.
    ///
    /// Wakes the read task and requests more data when there is not enough to decode an item.
    /// If an error is returned, this method does not register the waker for a later wake-up.
    pub fn poll_recv_decode<U>(
        &self,
        codec: &U,
        cx: &mut Context<'_>,
    ) -> Result<Decoded<U::Item>, RecvError<U>>
    where
        U: Decoder,
    {
        let decoded = self.decode_item(codec).map_err(RecvError::Decoder)?;

        if decoded.item.is_some() {
            Ok(decoded)
        } else {
            let st = self.st();
            let flags = st.flags.get();
            if flags.is_stopped() {
                Err(RecvError::PeerGone(st.error()))
            } else if flags.contains(Flags::DSP_TIMEOUT) {
                st.remove_flags(Flags::DSP_TIMEOUT);
                Err(RecvError::KeepAlive)
            } else if flags.contains(Flags::BUF_W_BACKPRESSURE) {
                Err(RecvError::WriteBackpressure)
            } else {
                match self.poll_read_ready(cx) {
                    Poll::Pending | Poll::Ready(Ok(Some(()))) => {
                        if log::log_enabled!(log::Level::Trace) && decoded.remains != 0 {
                            log::trace!(
                                "{}: Not enough data to decode next frame",
                                self.tag()
                            );
                        }
                        Ok(decoded)
                    }
                    Poll::Ready(Err(e)) => Err(RecvError::PeerGone(Some(e))),
                    Poll::Ready(Ok(None)) => Err(RecvError::PeerGone(None)),
                }
            }
        }
    }

    #[inline]
    /// Wake write task and instruct to flush data.
    ///
    /// If `full` is true, the dispatcher is woken once all data is flushed;
    /// otherwise it is woken once the size of the write buffer drops below
    /// the buffer max size.
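    ///
    /// The async `flush(full)` method awaits this via
    /// `poll_fn(|cx| self.poll_flush(cx, full))`.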
    pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll<io::Result<()>> {
        let st = self.st();
        let flags = self.flags();

        let len = st.buffer.write_destination_size();
        if len > 0 {
            if full {
                st.insert_flags(Flags::BUF_W_MUST_FLUSH);
                st.dispatch_task.register(cx.waker());
                return if flags.is_stopped() {
                    Poll::Ready(Err(st.error_or_disconnected()))
                } else {
                    Poll::Pending
                };
            } else if len >= st.write_buf().half {
                st.insert_flags(Flags::BUF_W_BACKPRESSURE);
                st.dispatch_task.register(cx.waker());
                return if flags.is_stopped() {
                    Poll::Ready(Err(st.error_or_disconnected()))
                } else {
                    Poll::Pending
                };
            }
        }
        if flags.is_stopped() {
            Poll::Ready(Err(st.error_or_disconnected()))
        } else {
            st.remove_flags(Flags::BUF_W_MUST_FLUSH | Flags::BUF_W_BACKPRESSURE);
            Poll::Ready(Ok(()))
        }
    }

    #[inline]
    /// Gracefully shutdown io stream
    pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let st = self.st();
        let flags = st.flags.get();

        if flags.is_stopped() {
            if let Some(err) = st.error() {
                Poll::Ready(Err(err))
            } else {
                Poll::Ready(Ok(()))
            }
        } else {
            if !flags.contains(Flags::IO_STOPPING_FILTERS) {
                st.init_shutdown();
            }

            st.read_task.wake();
            st.write_task.wake();
            st.dispatch_task.register(cx.waker());
            Poll::Pending
        }
    }

    #[inline]
    /// Pause read task
    ///
    /// Returns status updates
    pub fn poll_read_pause(&self, cx: &mut Context<'_>) -> Poll<IoStatusUpdate> {
        self.pause();
        let result = self.poll_status_update(cx);
        if !result.is_pending() {
            self.st().dispatch_task.register(cx.waker());
        }
        result
    }

    #[inline]
    /// Wait for status updates
    pub fn poll_status_update(&self, cx: &mut Context<'_>) -> Poll<IoStatusUpdate> {
        let st = self.st();
        let flags = st.flags.get();
        if flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING) {
            Poll::Ready(IoStatusUpdate::PeerGone(st.error()))
        } else if flags.contains(Flags::DSP_TIMEOUT) {
            st.remove_flags(Flags::DSP_TIMEOUT);
            Poll::Ready(IoStatusUpdate::KeepAlive)
        } else if flags.contains(Flags::BUF_W_BACKPRESSURE) {
            Poll::Ready(IoStatusUpdate::WriteBackpressure)
        } else {
            st.dispatch_task.register(cx.waker());
            Poll::Pending
        }
    }

    #[inline]
    /// Register dispatch task
    pub fn poll_dispatch(&self, cx: &mut Context<'_>) {
        self.st().dispatch_task.register(cx.waker());
    }
}

impl<F> AsRef<IoRef> for Io<F> {
    #[inline]
    fn as_ref(&self) -> &IoRef {
        self.io_ref()
    }
}

impl<F> Eq for Io<F> {}

impl<F> PartialEq for Io<F> {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.io_ref().eq(other.io_ref())
    }
}

impl<F> hash::Hash for Io<F> {
    #[inline]
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        self.io_ref().hash(state);
    }
}

impl<F> fmt::Debug for Io<F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Io").field("state", self.st()).finish()
    }
}

impl<F> ops::Deref for Io<F> {
    type Target = IoRef;

    #[inline]
    fn deref(&self) -> &Self::Target {
        self.io_ref()
    }
}

impl<F> Drop for Io<F> {
    fn drop(&mut self) {
        let st = self.st();
        self.stop_timer();

        if st.filter.is_set() {
            // the filter is unsafe and must be dropped explicitly;
            // it won't be dropped without special attention
            if !st.flags.get().is_stopped() {
                log::trace!(
                    "{}: Io is dropped, force stopping io streams {:?}",
                    st.cfg.get().tag(),
                    st.flags.get()
                );
            }

            self.force_close();
            st.filter.drop_filter::<F>();
        }
    }
}

#[derive(Debug)]
/// `OnDisconnect` future resolves when the socket gets disconnected
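///
/// Since it implements `Future`, a handle can simply be awaited (a sketch,
/// assuming an `OnDisconnect` value obtained from the io object):
///
/// ```ignore
/// on_disconnect.await; // resolves once the connection is gone
/// ```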
#[must_use = "OnDisconnect does nothing unless polled"]
pub struct OnDisconnect {
    token: usize,
    inner: Rc<IoState>,
}

impl OnDisconnect {
    pub(super) fn new(inner: Rc<IoState>) -> Self {
        Self::new_inner(inner.flags.get().is_stopped(), inner)
    }

    fn new_inner(disconnected: bool, inner: Rc<IoState>) -> Self {
        let token = if disconnected {
            usize::MAX
        } else {
            let mut on_disconnect = inner.on_disconnect.take();
            let token = if let Some(ref mut on_disconnect) = on_disconnect {
                let token = on_disconnect.len();
                on_disconnect.push(LocalWaker::default());
                token
            } else {
                on_disconnect = Some(Box::new(vec![LocalWaker::default()]));
                0
            };
            inner.on_disconnect.set(on_disconnect);
            token
        };
        Self { token, inner }
    }

    #[inline]
    /// Check if connection is disconnected
    pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
        if self.token == usize::MAX || self.inner.flags.get().is_stopped() {
            Poll::Ready(())
        } else if let Some(on_disconnect) = self.inner.on_disconnect.take() {
            on_disconnect[self.token].register(cx.waker());
            self.inner.on_disconnect.set(Some(on_disconnect));
            Poll::Pending
        } else {
            Poll::Ready(())
        }
    }
}

impl Clone for OnDisconnect {
    fn clone(&self) -> Self {
        OnDisconnect::new_inner(self.token == usize::MAX, self.inner.clone())
    }
}

impl Future for OnDisconnect {
    type Output = ();

    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.poll_ready(cx)
    }
}

#[cfg(test)]
mod tests {
    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;

    use super::*;
    use crate::testing::IoTest;

    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    #[ntex::test]
    async fn test_basics() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::from(server);
        assert!(server.eq(&server));
        assert!(server.io_ref().eq(server.io_ref()));

        assert!(format!("{:?}", Flags::IO_STOPPED).contains("IO_STOPPED"));
        assert!(Flags::IO_STOPPED == Flags::IO_STOPPED);
        assert!(Flags::IO_STOPPED != Flags::IO_STOPPING);
    }

    #[ntex::test]
    async fn test_recv() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::from(server);

        server.st().notify_timeout();
        let err = server.recv(&BytesCodec).await.err().unwrap();
        assert!(format!("{err:?}").contains("Timeout"));

        client.write(TEXT);
        server.st().insert_flags(Flags::BUF_W_BACKPRESSURE);
        let item = server.recv(&BytesCodec).await.ok().unwrap().unwrap();
        assert_eq!(item, TEXT);
    }

    #[ntex::test]
    async fn test_send() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::from(server);
        assert!(server.eq(&server));

        server
            .send(Bytes::from_static(BIN), &BytesCodec)
            .await
            .ok()
            .unwrap();
        let item = client.read_any();
        assert_eq!(item, TEXT);
    }
}