1use std::cell::{Cell, UnsafeCell};
2use std::future::{poll_fn, Future};
3use std::task::{Context, Poll};
4use std::{fmt, hash, io, marker, mem, ops, pin::Pin, ptr, rc::Rc};
5
6use ntex_bytes::{PoolId, PoolRef};
7use ntex_codec::{Decoder, Encoder};
8use ntex_util::{future::Either, task::LocalWaker, time::Seconds};
9
10use crate::buf::Stack;
11use crate::filter::{Base, Filter, Layer, NullFilter};
12use crate::flags::Flags;
13use crate::seal::{IoBoxed, Sealed};
14use crate::tasks::{ReadContext, WriteContext};
15use crate::timer::TimerHandle;
16use crate::{Decoded, FilterLayer, Handle, IoStatusUpdate, IoStream, RecvError};
17
/// Owned interface to an io stream with a filter stack of type `F`.
///
/// The `UnsafeCell<IoRef>` allows `take()`/`seal()`/`add_filter()` to swap the
/// inner `IoRef` in place through a shared reference; `PhantomData<F>` only
/// records the filter type at the type level (no data is stored for it).
pub struct Io<F = Base>(UnsafeCell<IoRef>, marker::PhantomData<F>);
20
/// Cheaply clonable (ref-counted) handle to the shared io state.
#[derive(Clone)]
pub struct IoRef(pub(super) Rc<IoState>);
23
/// Shared, single-threaded state behind every `Io`/`IoRef`.
///
/// All fields use interior mutability (`Cell`/`LocalWaker`) because the state
/// is accessed through `Rc` from the read task, write task and dispatcher.
pub(crate) struct IoState {
    // Tagged storage for the filter stack; see `FilterPtr` below.
    filter: FilterPtr,
    pub(super) flags: Cell<Flags>,
    // Memory pool used for buffer allocation/release.
    pub(super) pool: Cell<PoolRef>,
    pub(super) disconnect_timeout: Cell<Seconds>,
    // Last io error; `io::Error` is not `Clone`, so it is taken and restored
    // by `IoState::error()` / `Debug::fmt`.
    pub(super) error: Cell<Option<io::Error>>,
    // Wakers for the three cooperating tasks.
    pub(super) read_task: LocalWaker,
    pub(super) write_task: LocalWaker,
    pub(super) dispatch_task: LocalWaker,
    pub(super) buffer: Stack,
    // Driver handle returned by `IoStream::start`; dropped on stop.
    pub(super) handle: Cell<Option<Box<dyn Handle>>>,
    pub(super) timeout: Cell<TimerHandle>,
    // Tag used as a prefix in log messages.
    pub(super) tag: Cell<&'static str>,
    #[allow(clippy::box_collection)]
    // Wakers registered via `OnDisconnect`; indexed by `OnDisconnect::token`.
    pub(super) on_disconnect: Cell<Option<Box<Vec<LocalWaker>>>>,
}
40
// Default log tag for newly created io states.
const DEFAULT_TAG: &str = "IO";
42
43impl IoState {
44 pub(super) fn filter(&self) -> &dyn Filter {
45 self.filter.filter.get()
46 }
47
48 pub(super) fn insert_flags(&self, f: Flags) {
49 let mut flags = self.flags.get();
50 flags.insert(f);
51 self.flags.set(flags);
52 }
53
54 pub(super) fn remove_flags(&self, f: Flags) -> bool {
55 let mut flags = self.flags.get();
56 if flags.intersects(f) {
57 flags.remove(f);
58 self.flags.set(flags);
59 true
60 } else {
61 false
62 }
63 }
64
65 pub(super) fn notify_timeout(&self) {
66 let mut flags = self.flags.get();
67 if !flags.contains(Flags::DSP_TIMEOUT) {
68 flags.insert(Flags::DSP_TIMEOUT);
69 self.flags.set(flags);
70 self.dispatch_task.wake();
71 log::trace!("{}: Timer, notify dispatcher", self.tag.get());
72 }
73 }
74
75 pub(super) fn notify_disconnect(&self) {
76 if let Some(on_disconnect) = self.on_disconnect.take() {
77 for item in on_disconnect.into_iter() {
78 item.wake();
79 }
80 }
81 }
82
83 pub(super) fn error(&self) -> Option<io::Error> {
85 if let Some(err) = self.error.take() {
86 self.error
87 .set(Some(io::Error::new(err.kind(), format!("{}", err))));
88 Some(err)
89 } else {
90 None
91 }
92 }
93
94 pub(super) fn error_or_disconnected(&self) -> io::Error {
96 self.error()
97 .unwrap_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "Disconnected"))
98 }
99
100 pub(super) fn io_stopped(&self, err: Option<io::Error>) {
101 if !self.flags.get().contains(Flags::IO_STOPPED) {
102 if err.is_some() {
103 self.error.set(err);
104 }
105 self.read_task.wake();
106 self.write_task.wake();
107 self.dispatch_task.wake();
108 self.notify_disconnect();
109 self.handle.take();
110 self.insert_flags(
111 Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS,
112 );
113 }
114 }
115
116 pub(super) fn init_shutdown(&self) {
118 if !self
119 .flags
120 .get()
121 .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
122 {
123 log::trace!(
124 "{}: Initiate io shutdown {:?}",
125 self.tag.get(),
126 self.flags.get()
127 );
128 self.insert_flags(Flags::IO_STOPPING_FILTERS);
129 self.read_task.wake();
130 }
131 }
132}
133
// Equality is identity-based (see `PartialEq` below), hence reflexive: `Eq` holds.
impl Eq for IoState {}
135
136impl PartialEq for IoState {
137 #[inline]
138 fn eq(&self, other: &Self) -> bool {
139 ptr::eq(self, other)
140 }
141}
142
143impl hash::Hash for IoState {
144 #[inline]
145 fn hash<H: hash::Hasher>(&self, state: &mut H) {
146 (self as *const _ as usize).hash(state);
147 }
148}
149
impl Drop for IoState {
    #[inline]
    fn drop(&mut self) {
        // Return all buffers to the memory pool they were allocated from.
        self.buffer.release(self.pool.get());
    }
}
156
impl fmt::Debug for IoState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `io::Error` is not `Clone`, so temporarily move it out of the cell
        // to format it, then put it back before returning.
        let err = self.error.take();
        let res = f
            .debug_struct("IoState")
            .field("flags", &self.flags)
            .field("filter", &self.filter.is_set())
            .field("disconnect_timeout", &self.disconnect_timeout)
            .field("timeout", &self.timeout)
            .field("error", &err)
            .field("buffer", &self.buffer)
            .finish();
        self.error.set(err);
        res
    }
}
173
impl Io {
    /// Create io object from the given io stream, using the default memory pool.
    #[inline]
    pub fn new<I: IoStream>(io: I) -> Self {
        Self::with_memory_pool(io, PoolId::DEFAULT.pool_ref())
    }

    /// Create io object from the given io stream and memory pool.
    #[inline]
    pub fn with_memory_pool<I: IoStream>(io: I, pool: PoolRef) -> Self {
        let inner = Rc::new(IoState {
            filter: FilterPtr::null(),
            pool: Cell::new(pool),
            // Write side starts paused until the driver is running.
            flags: Cell::new(Flags::WR_PAUSED),
            error: Cell::new(None),
            dispatch_task: LocalWaker::new(),
            read_task: LocalWaker::new(),
            write_task: LocalWaker::new(),
            buffer: Stack::new(),
            handle: Cell::new(None),
            timeout: Cell::new(TimerHandle::default()),
            disconnect_timeout: Cell::new(Seconds(1)),
            on_disconnect: Cell::new(None),
            tag: Cell::new(DEFAULT_TAG),
        });
        // Install the base filter before the io driver starts, so the read
        // and write tasks never observe the null filter.
        inner.filter.update(Base::new(IoRef(inner.clone())));

        let io_ref = IoRef(inner);

        // Start the stream driver; it may immediately hand back a handle.
        let hnd = io.start(ReadContext::new(&io_ref), WriteContext::new(&io_ref));
        io_ref.0.handle.set(hnd);

        Io(UnsafeCell::new(io_ref), marker::PhantomData)
    }
}
210
impl<F> Io<F> {
    /// Set memory pool for io buffers.
    #[inline]
    pub fn set_memory_pool(&self, pool: PoolRef) {
        self.st().buffer.set_memory_pool(pool);
        self.st().pool.set(pool);
    }

    /// Set io disconnect timeout.
    #[inline]
    pub fn set_disconnect_timeout(&self, timeout: Seconds) {
        self.st().disconnect_timeout.set(timeout);
    }

    /// Take ownership of the live io, leaving `self` holding a fresh,
    /// already-stopped placeholder state.
    #[inline]
    pub fn take(&self) -> Self {
        Self(UnsafeCell::new(self.take_io_ref()), marker::PhantomData)
    }

    /// Swap the inner `IoRef` for a dummy stopped state and return the
    /// previous (live) one.
    fn take_io_ref(&self) -> IoRef {
        let inner = Rc::new(IoState {
            filter: FilterPtr::null(),
            pool: self.st().pool.clone(),
            // The placeholder is created in a fully stopped condition.
            flags: Cell::new(
                Flags::DSP_STOP
                    | Flags::IO_STOPPED
                    | Flags::IO_STOPPING
                    | Flags::IO_STOPPING_FILTERS,
            ),
            error: Cell::new(None),
            disconnect_timeout: Cell::new(Seconds(1)),
            dispatch_task: LocalWaker::new(),
            read_task: LocalWaker::new(),
            write_task: LocalWaker::new(),
            buffer: Stack::new(),
            handle: Cell::new(None),
            timeout: Cell::new(TimerHandle::default()),
            on_disconnect: Cell::new(None),
            tag: Cell::new(DEFAULT_TAG),
        });
        // SAFETY: `Io` is single-threaded and no `&IoRef` obtained from
        // `st()`/`io_ref()` is held across this call by its callers.
        unsafe { mem::replace(&mut *self.0.get(), IoRef(inner)) }
    }
}
257
impl<F> Io<F> {
    /// Current state flags (diagnostic accessor).
    #[inline]
    #[doc(hidden)]
    pub fn flags(&self) -> Flags {
        self.st().flags.get()
    }

    /// Clone a shareable reference to the io state.
    #[inline]
    pub fn get_ref(&self) -> IoRef {
        self.io_ref().clone()
    }

    // Shared access to the inner state. SAFETY relies on `take_io_ref` being
    // the only mutation of the `UnsafeCell`, on a single thread.
    fn st(&self) -> &IoState {
        unsafe { &(*self.0.get()).0 }
    }

    // Shared access to the inner `IoRef` (same safety argument as `st`).
    fn io_ref(&self) -> &IoRef {
        unsafe { &*self.0.get() }
    }
}
280
impl<F: FilterLayer, T: Filter> Io<F, T>> {
    /// Get top-most filter of the layered stack.
    #[inline]
    pub fn filter(&self) -> &F {
        &self.st().filter.filter::<Layer<F, T>>().0
    }
}
288
impl<F: Filter> Io<F> {
    /// Erase the concrete filter type, producing an `Io<Sealed>`.
    #[inline]
    pub fn seal(self) -> Io<Sealed> {
        let state = self.take_io_ref();
        state.0.filter.seal::<F>();

        Io(UnsafeCell::new(state), marker::PhantomData)
    }

    /// Seal and box the io object.
    #[inline]
    pub fn boxed(self) -> IoBoxed {
        self.seal().into()
    }

    /// Wrap the current filter stack with a new filter layer `U`.
    #[inline]
    pub fn add_filter<U>(self, nf: U) -> Io<Layer<U, F>>
    where
        U: FilterLayer,
    {
        let state = self.take_io_ref();

        // If the new layer owns its own buffers, grow the buffer stack first.
        if U::BUFFERS {
            // SAFETY: we hold the only live `Io` for this state (it was just
            // taken), so mutating through the Rc pointer is not observed
            // concurrently. TODO confirm no outstanding `IoRef` mutates here.
            unsafe { &mut *(Rc::as_ptr(&state.0) as *mut IoState) }
                .buffer
                .add_layer();
        }

        state.0.filter.add_filter::<F, U>(nf);

        Io(UnsafeCell::new(state), marker::PhantomData)
    }
}
329
impl<F> Io<F> {
    /// Read incoming io stream and decode one codec item.
    ///
    /// Returns `Ok(None)` when the peer is gone without error; write
    /// backpressure is handled internally by flushing and retrying.
    #[inline]
    pub async fn recv<U>(
        &self,
        codec: &U,
    ) -> Result<Option<U::Item>, Either<U::Error, io::Error>>
    where
        U: Decoder,
    {
        loop {
            return match poll_fn(|cx| self.poll_recv(codec, cx)).await {
                Ok(item) => Ok(Some(item)),
                Err(RecvError::KeepAlive) => Err(Either::Right(io::Error::new(
                    io::ErrorKind::TimedOut,
                    "Timeout",
                ))),
                Err(RecvError::Stop) => Err(Either::Right(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "Dispatcher stopped",
                ))),
                // Flush pending writes, then retry the receive.
                Err(RecvError::WriteBackpressure) => {
                    poll_fn(|cx| self.poll_flush(cx, false))
                        .await
                        .map_err(Either::Right)?;
                    continue;
                }
                Err(RecvError::Decoder(err)) => Err(Either::Left(err)),
                Err(RecvError::PeerGone(Some(err))) => Err(Either::Right(err)),
                Err(RecvError::PeerGone(None)) => Ok(None),
            };
        }
    }

    /// Wait until the read buffer becomes ready (async wrapper of
    /// `poll_read_ready`).
    #[inline]
    pub async fn read_ready(&self) -> io::Result<Option<()>> {
        poll_fn(|cx| self.poll_read_ready(cx)).await
    }

    /// Wait for a read notification (async wrapper of `poll_read_notify`).
    #[inline]
    pub async fn read_notify(&self) -> io::Result<Option<()>> {
        poll_fn(|cx| self.poll_read_notify(cx)).await
    }

    /// Pause the read task (sets `RD_PAUSED` and wakes the read task so it
    /// can observe the flag).
    #[inline]
    pub fn pause(&self) {
        let st = self.st();
        if !st.flags.get().contains(Flags::RD_PAUSED) {
            st.read_task.wake();
            st.insert_flags(Flags::RD_PAUSED);
        }
    }

    /// Encode item, write it to the output buffer and flush fully.
    #[inline]
    pub async fn send<U>(
        &self,
        item: U::Item,
        codec: &U,
    ) -> Result<(), Either<U::Error, io::Error>>
    where
        U: Encoder,
    {
        self.encode(item, codec).map_err(Either::Left)?;

        poll_fn(|cx| self.poll_flush(cx, true))
            .await
            .map_err(Either::Right)?;

        Ok(())
    }

    /// Flush the write buffer; `full` requires the buffer to drain completely.
    #[inline]
    pub async fn flush(&self, full: bool) -> io::Result<()> {
        poll_fn(|cx| self.poll_flush(cx, full)).await
    }

    /// Gracefully shut the io stream down (async wrapper of `poll_shutdown`).
    #[inline]
    pub async fn shutdown(&self) -> io::Result<()> {
        poll_fn(|cx| self.poll_shutdown(cx)).await
    }

    /// Poll for read readiness.
    ///
    /// Registers the dispatcher waker, clears read-blocking flags (waking the
    /// read task) when reading is currently not possible, and consumes the
    /// `BUF_R_READY` flag when data is available.
    #[inline]
    pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
        let st = self.st();
        let mut flags = st.flags.get();

        if flags.is_stopped() {
            Poll::Ready(Err(st.error_or_disconnected()))
        } else {
            st.dispatch_task.register(cx.waker());

            let ready = flags.is_read_buf_ready();
            if flags.cannot_read() {
                // Unblock the read task and report readiness if data was
                // already buffered.
                flags.cleanup_read_flags();
                st.read_task.wake();
                st.flags.set(flags);
                if ready {
                    Poll::Ready(Ok(Some(())))
                } else {
                    Poll::Pending
                }
            } else if ready {
                flags.remove(Flags::BUF_R_READY);
                st.flags.set(flags);
                Poll::Ready(Ok(Some(())))
            } else {
                Poll::Pending
            }
        }
    }

    /// Like `poll_read_ready`, but also completes when the `RD_NOTIFY` flag
    /// is set; arms `RD_NOTIFY` when pending.
    #[inline]
    pub fn poll_read_notify(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
        let ready = self.poll_read_ready(cx);

        if ready.is_pending() {
            let st = self.st();
            if st.remove_flags(Flags::RD_NOTIFY) {
                Poll::Ready(Ok(Some(())))
            } else {
                st.insert_flags(Flags::RD_NOTIFY);
                Poll::Pending
            }
        } else {
            ready
        }
    }

    /// Poll-based variant of `recv()`: decode one item or report the io
    /// condition via `RecvError`.
    #[inline]
    pub fn poll_recv<U>(
        &self,
        codec: &U,
        cx: &mut Context<'_>,
    ) -> Poll<Result<U::Item, RecvError<U>>>
    where
        U: Decoder,
    {
        let decoded = self.poll_recv_decode(codec, cx)?;

        if let Some(item) = decoded.item {
            Poll::Ready(Ok(item))
        } else {
            Poll::Pending
        }
    }

    /// Try to decode one frame; when the buffer has no complete frame,
    /// translate the current io state into the matching `RecvError`.
    #[doc(hidden)]
    #[inline]
    pub fn poll_recv_decode<U>(
        &self,
        codec: &U,
        cx: &mut Context<'_>,
    ) -> Result<Decoded<U::Item>, RecvError<U>>
    where
        U: Decoder,
    {
        let decoded = self
            .decode_item(codec)
            .map_err(|err| RecvError::Decoder(err))?;

        if decoded.item.is_some() {
            Ok(decoded)
        } else {
            let st = self.st();
            let flags = st.flags.get();
            if flags.is_stopped() {
                Err(RecvError::PeerGone(st.error()))
            } else if flags.contains(Flags::DSP_STOP) {
                st.remove_flags(Flags::DSP_STOP);
                Err(RecvError::Stop)
            } else if flags.contains(Flags::DSP_TIMEOUT) {
                st.remove_flags(Flags::DSP_TIMEOUT);
                Err(RecvError::KeepAlive)
            } else if flags.contains(Flags::BUF_W_BACKPRESSURE) {
                Err(RecvError::WriteBackpressure)
            } else {
                match self.poll_read_ready(cx) {
                    Poll::Pending | Poll::Ready(Ok(Some(()))) => {
                        if log::log_enabled!(log::Level::Trace) && decoded.remains != 0 {
                            log::trace!(
                                "{}: Not enough data to decode next frame",
                                self.tag()
                            );
                        }
                        Ok(decoded)
                    }
                    Poll::Ready(Err(e)) => Err(RecvError::PeerGone(Some(e))),
                    Poll::Ready(Ok(None)) => Err(RecvError::PeerGone(None)),
                }
            }
        }
    }

    /// Poll-flush the write buffer.
    ///
    /// With `full`, waits until the destination buffer is empty; otherwise
    /// only waits while the buffered size exceeds twice the pool's high
    /// watermark (backpressure).
    #[inline]
    pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll<io::Result<()>> {
        let st = self.st();
        let flags = self.flags();

        let len = st.buffer.write_destination_size();
        if len > 0 {
            if full {
                st.insert_flags(Flags::BUF_W_MUST_FLUSH);
                st.dispatch_task.register(cx.waker());
                return if flags.is_stopped() {
                    Poll::Ready(Err(st.error_or_disconnected()))
                } else {
                    Poll::Pending
                };
            } else if len >= st.pool.get().write_params_high() << 1 {
                st.insert_flags(Flags::BUF_W_BACKPRESSURE);
                st.dispatch_task.register(cx.waker());
                return if flags.is_stopped() {
                    Poll::Ready(Err(st.error_or_disconnected()))
                } else {
                    Poll::Pending
                };
            }
        }
        st.remove_flags(Flags::BUF_W_MUST_FLUSH | Flags::BUF_W_BACKPRESSURE);
        Poll::Ready(Ok(()))
    }

    /// Poll-based graceful shutdown: initiates filter shutdown once, wakes
    /// the io tasks and completes when the io is fully stopped.
    #[inline]
    pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let st = self.st();
        let flags = st.flags.get();

        if flags.is_stopped() {
            if let Some(err) = st.error() {
                Poll::Ready(Err(err))
            } else {
                Poll::Ready(Ok(()))
            }
        } else {
            if !flags.contains(Flags::IO_STOPPING_FILTERS) {
                st.init_shutdown();
            }

            st.read_task.wake();
            st.write_task.wake();
            st.dispatch_task.register(cx.waker());
            Poll::Pending
        }
    }

    /// Pause reading and poll for a status update in one step.
    #[inline]
    pub fn poll_read_pause(&self, cx: &mut Context<'_>) -> Poll<IoStatusUpdate> {
        self.pause();
        let result = self.poll_status_update(cx);
        if !result.is_pending() {
            // Keep the dispatcher registered even though we return Ready,
            // so later io events still wake it. TODO confirm intent.
            self.st().dispatch_task.register(cx.waker());
        }
        result
    }

    /// Poll for any dispatcher-relevant status change (peer gone, stop,
    /// keep-alive timeout, write backpressure).
    #[inline]
    pub fn poll_status_update(&self, cx: &mut Context<'_>) -> Poll<IoStatusUpdate> {
        let st = self.st();
        let flags = st.flags.get();
        if flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING) {
            Poll::Ready(IoStatusUpdate::PeerGone(st.error()))
        } else if flags.contains(Flags::DSP_STOP) {
            st.remove_flags(Flags::DSP_STOP);
            Poll::Ready(IoStatusUpdate::Stop)
        } else if flags.contains(Flags::DSP_TIMEOUT) {
            st.remove_flags(Flags::DSP_TIMEOUT);
            Poll::Ready(IoStatusUpdate::KeepAlive)
        } else if flags.contains(Flags::BUF_W_BACKPRESSURE) {
            Poll::Ready(IoStatusUpdate::WriteBackpressure)
        } else {
            st.dispatch_task.register(cx.waker());
            Poll::Pending
        }
    }

    /// Register the dispatcher task waker without checking any state.
    #[inline]
    pub fn poll_dispatch(&self, cx: &mut Context<'_>) {
        self.st().dispatch_task.register(cx.waker());
    }
}
651
impl<F> AsRef<IoRef> for Io<F> {
    /// Borrow the inner `IoRef` without cloning.
    #[inline]
    fn as_ref(&self) -> &IoRef {
        self.io_ref()
    }
}
658
// Equality delegates to `IoRef` (identity-based), hence reflexive.
impl<F> Eq for Io<F> {}
660
impl<F> PartialEq for Io<F> {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        // Delegates to `IoRef` equality (defined elsewhere in this crate).
        self.io_ref().eq(other.io_ref())
    }
}
667
impl<F> hash::Hash for Io<F> {
    #[inline]
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        // Delegates to `IoRef` hashing, consistent with `PartialEq` above.
        self.io_ref().hash(state);
    }
}
674
impl<F> fmt::Debug for Io<F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Io").field("state", self.st()).finish()
    }
}
680
// `Io` dereferences to `IoRef`, exposing all of `IoRef`'s methods directly.
impl<F> ops::Deref for Io<F> {
    type Target = IoRef;

    #[inline]
    fn deref(&self) -> &Self::Target {
        self.io_ref()
    }
}
689
impl<F> Drop for Io<F> {
    fn drop(&mut self) {
        let st = self.st();
        // `stop_timer`/`force_close` are provided by `IoRef` via Deref
        // (defined elsewhere in this crate).
        self.stop_timer();

        if st.filter.is_set() {
            // Dropping a live (not yet stopped) io is unusual; leave a trace.
            if !st.flags.get().is_stopped() {
                log::trace!(
                    "{}: Io is dropped, force stopping io streams {:?}",
                    st.tag.get(),
                    st.flags.get()
                );
            }

            self.force_close();
            // Free the boxed filter stack stored in the tagged pointer.
            st.filter.drop_filter::<F>();
        }
    }
}
711
// `FilterPtr` stores either a thin `Box<F>` pointer or a `Sealed` (boxed
// trait object) in one `[u8; SEALED_SIZE]` array, discriminated by tag bits
// kept in the pointer's least-significant byte (boxes are aligned, so the
// low bits of the address are free).
const KIND_SEALED: u8 = 0b01;
const KIND_PTR: u8 = 0b10;
const KIND_MASK: u8 = 0b11;
const KIND_UNMASK: u8 = !KIND_MASK;
const KIND_MASK_USIZE: usize = 0b11;
const KIND_UNMASK_USIZE: usize = !KIND_MASK_USIZE;
const SEALED_SIZE: usize = mem::size_of::<Sealed>();
const NULL: [u8; SEALED_SIZE] = [0u8; SEALED_SIZE];

// Index of the byte that carries the tag bits: the least-significant byte of
// the first pointer word depends on endianness.
#[cfg(target_endian = "little")]
const KIND_IDX: usize = 0;

#[cfg(target_endian = "big")]
const KIND_IDX: usize = SEALED_SIZE - 1;
726
// Tagged storage for the filter stack: `data` holds the raw (tagged) pointer
// bytes, while `filter` caches a trait-object reference to the same boxed
// filter for fast dynamic dispatch. The 'static lifetime is a lie upheld by
// `FilterPtr` keeping the box alive as long as the reference is reachable.
struct FilterPtr {
    data: Cell<[u8; SEALED_SIZE]>,
    filter: Cell<&'static dyn Filter>,
}
731
impl FilterPtr {
    // Empty storage: zeroed data (no tag bits set) and the null filter.
    const fn null() -> Self {
        Self {
            data: Cell::new(NULL),
            filter: Cell::new(NullFilter::get()),
        }
    }

    // Store a concrete filter `F`, boxed, tagged as KIND_PTR.
    // Panics if a filter is already stored (would leak/alias otherwise).
    fn update<F: Filter>(&self, filter: F) {
        if self.is_set() {
            panic!("Filter is set, must be dropped first");
        }

        let filter = Box::new(filter);
        let mut data = NULL;
        unsafe {
            // SAFETY: the reference points into the box we are about to leak
            // via `Box::into_raw`; it stays valid until `take_filter`/
            // `drop_filter` reclaims the box.
            let filter_ref: &'static dyn Filter = {
                let f: &dyn Filter = filter.as_ref();
                std::mem::transmute(f)
            };
            self.filter.set(filter_ref);

            // Write the thin pointer into the byte array, then set the tag
            // bit in the pointer's low byte (alignment keeps it zero).
            let ptr = &mut data as *mut _ as *mut *mut F;
            ptr.write(Box::into_raw(filter));
            data[KIND_IDX] |= KIND_PTR;
            self.data.set(data);
        }
    }

    // Borrow the stored filter as its concrete type `F`.
    // The caller must pass the same `F` that was stored; this is upheld by
    // `Io<F>`'s type parameter tracking the filter type.
    fn filter<F: Filter>(&self) -> &F {
        let data = self.data.get();
        if data[KIND_IDX] & KIND_PTR != 0 {
            let ptr = &data as *const _ as *const *mut F;
            unsafe {
                // SAFETY: mask off the tag bits to recover the real address;
                // the box is alive as long as `data` holds KIND_PTR.
                let p = (ptr.read() as *const _ as usize) & KIND_UNMASK_USIZE;
                (p as *const F as *mut F).as_ref().unwrap()
            }
        } else {
            panic!("Wrong filter item");
        }
    }

    // Move the stored `Box<F>` out. Note: `self.data` is NOT cleared here;
    // every caller either resets it (`drop_filter`) or overwrites it
    // (`add_filter`, `seal`) immediately afterwards.
    fn take_filter<F>(&self) -> Box<F> {
        let mut data = self.data.get();
        if data[KIND_IDX] & KIND_PTR != 0 {
            // Clear tag bits in the local copy so the pointer is valid.
            data[KIND_IDX] &= KIND_UNMASK;
            let ptr = &mut data as *mut _ as *mut *mut F;
            unsafe { Box::from_raw(*ptr) }
        } else {
            panic!(
                "Wrong filter item {:?} expected: {:?}",
                data[KIND_IDX], KIND_PTR
            );
        }
    }

    // Move the stored `Sealed` (fat pointer) out; same caveat as
    // `take_filter` about `self.data` not being cleared here.
    fn take_sealed(&self) -> Sealed {
        let mut data = self.data.get();

        if data[KIND_IDX] & KIND_SEALED != 0 {
            data[KIND_IDX] &= KIND_UNMASK;
            let ptr = &mut data as *mut _ as *mut Sealed;
            unsafe { ptr.read() }
        } else {
            panic!(
                "Wrong filter item {:?} expected: {:?}",
                data[KIND_IDX], KIND_SEALED
            );
        }
    }

    // A filter is stored iff either tag bit is set.
    fn is_set(&self) -> bool {
        self.data.get()[KIND_IDX] & KIND_MASK != 0
    }

    // Drop whatever is stored (concrete `F` or `Sealed`) and reset to null.
    fn drop_filter<F>(&self) {
        let data = self.data.get();

        if data[KIND_IDX] & KIND_MASK != 0 {
            if data[KIND_IDX] & KIND_PTR != 0 {
                self.take_filter::<F>();
            } else if data[KIND_IDX] & KIND_SEALED != 0 {
                self.take_sealed();
            }
            self.data.set(NULL);
            self.filter.set(NullFilter::get());
        }
    }
}
824
impl FilterPtr {
    // Wrap the currently stored filter `F` in a new layer `T`, storing the
    // combined `Layer<T, F>` back (tagged as KIND_PTR).
    fn add_filter<F: Filter, T: FilterLayer>(&self, new: T) {
        let mut data = NULL;
        let filter = Box::new(Layer::new(new, *self.take_filter::<F>()));
        unsafe {
            // SAFETY: same lifetime-extension scheme as `update`: the
            // reference stays valid while the leaked box is stored in `data`.
            let filter_ref: &'static dyn Filter = {
                let f: &dyn Filter = filter.as_ref();
                std::mem::transmute(f)
            };
            self.filter.set(filter_ref);

            let ptr = &mut data as *mut _ as *mut *mut Layer<T, F>;
            ptr.write(Box::into_raw(filter));
            data[KIND_IDX] |= KIND_PTR;
            self.data.set(data);
        }
    }

    // Convert the stored filter into a type-erased `Sealed` in place.
    // Re-sealing an already sealed filter is a no-op round-trip.
    fn seal<F: Filter>(&self) {
        let mut data = self.data.get();

        let filter = if data[KIND_IDX] & KIND_PTR != 0 {
            Sealed(Box::new(*self.take_filter::<F>()))
        } else if data[KIND_IDX] & KIND_SEALED != 0 {
            self.take_sealed()
        } else {
            panic!(
                "Wrong filter item {:?} expected: {:?}",
                data[KIND_IDX], KIND_PTR
            );
        };

        unsafe {
            // SAFETY: same lifetime-extension scheme as `update`.
            let filter_ref: &'static dyn Filter = {
                let f: &dyn Filter = filter.0.as_ref();
                std::mem::transmute(f)
            };
            self.filter.set(filter_ref);

            // Write the fat pointer, then tag it; the vtable pointer's low
            // byte (where the tag lands) is alignment-zero.
            let ptr = &mut data as *mut _ as *mut Sealed;
            ptr.write(filter);
            data[KIND_IDX] |= KIND_SEALED;
            self.data.set(data);
        }
    }
}
871
/// Future that resolves when the io stream gets disconnected.
///
/// `token` indexes this waiter's `LocalWaker` in `IoState::on_disconnect`;
/// `usize::MAX` means "already disconnected at creation time".
#[derive(Debug)]
#[must_use = "OnDisconnect do nothing unless polled"]
pub struct OnDisconnect {
    token: usize,
    inner: Rc<IoState>,
}
879
880impl OnDisconnect {
881 pub(super) fn new(inner: Rc<IoState>) -> Self {
882 Self::new_inner(inner.flags.get().is_stopped(), inner)
883 }
884
885 fn new_inner(disconnected: bool, inner: Rc<IoState>) -> Self {
886 let token = if disconnected {
887 usize::MAX
888 } else {
889 let mut on_disconnect = inner.on_disconnect.take();
890 let token = if let Some(ref mut on_disconnect) = on_disconnect {
891 let token = on_disconnect.len();
892 on_disconnect.push(LocalWaker::default());
893 token
894 } else {
895 on_disconnect = Some(Box::new(vec![LocalWaker::default()]));
896 0
897 };
898 inner.on_disconnect.set(on_disconnect);
899 token
900 };
901 Self { token, inner }
902 }
903
904 #[inline]
905 pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
907 if self.token == usize::MAX || self.inner.flags.get().is_stopped() {
908 Poll::Ready(())
909 } else if let Some(on_disconnect) = self.inner.on_disconnect.take() {
910 on_disconnect[self.token].register(cx.waker());
911 self.inner.on_disconnect.set(Some(on_disconnect));
912 Poll::Pending
913 } else {
914 Poll::Ready(())
915 }
916 }
917}
918
919impl Clone for OnDisconnect {
920 fn clone(&self) -> Self {
921 if self.token == usize::MAX {
922 OnDisconnect::new_inner(true, self.inner.clone())
923 } else {
924 OnDisconnect::new_inner(false, self.inner.clone())
925 }
926 }
927}
928
impl Future for OnDisconnect {
    type Output = ();

    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `OnDisconnect` is Unpin-friendly: polling only touches Cells.
        self.poll_ready(cx)
    }
}
937
#[cfg(test)]
mod tests {
    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;

    use super::*;
    use crate::{testing::IoTest, ReadBuf, WriteBuf};

    // Same payload as bytes and as str, for write/compare symmetry.
    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    #[ntex::test]
    async fn test_basics() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::new(server);
        assert!(server.eq(&server));
        assert!(server.io_ref().eq(server.io_ref()));

        assert!(format!("{:?}", Flags::IO_STOPPED).contains("IO_STOPPED"));
        assert!(Flags::IO_STOPPED == Flags::IO_STOPPED);
        assert!(Flags::IO_STOPPED != Flags::IO_STOPPING);
    }

    #[ntex::test]
    async fn test_recv() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::new(server);

        // Keep-alive timeout surfaces as a Timeout error.
        server.st().notify_timeout();
        let err = server.recv(&BytesCodec).await.err().unwrap();
        assert!(format!("{:?}", err).contains("Timeout"));

        // DSP_STOP surfaces as "Dispatcher stopped".
        server.st().insert_flags(Flags::DSP_STOP);
        let err = server.recv(&BytesCodec).await.err().unwrap();
        assert!(format!("{:?}", err).contains("Dispatcher stopped"));

        // Backpressure is flushed internally; recv still yields the item.
        client.write(TEXT);
        server.st().insert_flags(Flags::BUF_W_BACKPRESSURE);
        let item = server.recv(&BytesCodec).await.ok().unwrap().unwrap();
        assert_eq!(item, TEXT);
    }

    #[ntex::test]
    async fn test_send() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::new(server);
        assert!(server.eq(&server));

        server
            .send(Bytes::from_static(BIN), &BytesCodec)
            .await
            .ok()
            .unwrap();
        let item = client.read_any();
        assert_eq!(item, TEXT);
    }

    // Filter that counts how many times it is dropped, to verify the
    // tagged-pointer storage frees the filter exactly once.
    #[derive(Debug)]
    struct DropFilter {
        p: Rc<Cell<usize>>,
    }

    impl Drop for DropFilter {
        fn drop(&mut self) {
            self.p.set(self.p.get() + 1);
        }
    }

    impl FilterLayer for DropFilter {
        const BUFFERS: bool = false;
        fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
            Ok(buf.nbytes())
        }
        fn process_write_buf(&self, _: &WriteBuf<'_>) -> io::Result<()> {
            Ok(())
        }
    }

    #[ntex::test]
    async fn drop_filter() {
        let p = Rc::new(Cell::new(0));

        let (client, server) = IoTest::create();
        let f = DropFilter { p: p.clone() };
        let _ = format!("{:?}", f);
        let io = Io::new(server).add_filter(f);

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // take()/seal()/take() pass ownership of the filter along; the
        // original handles must not drop it again.
        let io2 = io.take();
        let mut io3: crate::IoBoxed = io2.into();
        let io4 = io3.take();

        drop(io);
        drop(io3);
        drop(io4);

        assert_eq!(p.get(), 1);
    }
}