1use std::cell::{Cell, UnsafeCell};
2use std::future::{Future, poll_fn};
3use std::task::{Context, Poll};
4use std::{fmt, hash, io, marker, mem, ops, pin::Pin, ptr, rc::Rc};
5
6use ntex_codec::{Decoder, Encoder};
7use ntex_service::cfg::SharedCfg;
8use ntex_util::{future::Either, task::LocalWaker};
9
10use crate::buf::Stack;
11use crate::cfg::{BufConfig, IoConfig};
12use crate::filter::{Base, Filter, Layer, NullFilter};
13use crate::flags::Flags;
14use crate::seal::{IoBoxed, Sealed};
15use crate::timer::TimerHandle;
16use crate::{Decoded, FilterLayer, Handle, IoContext, IoStatusUpdate, IoStream, RecvError};
17
/// Interface object to the underlying io stream.
///
/// The shared [`IoRef`] is kept in an `UnsafeCell` so that it can be swapped
/// out through `&self` (see `take`). `F` is the type of the filter chain that
/// is currently installed; it defaults to [`Base`].
pub struct Io<F = Base>(UnsafeCell<IoRef>, marker::PhantomData<F>);
20
/// Reference-counted handle to the shared io state; cloning only bumps the
/// `Rc` counter.
#[derive(Clone)]
pub struct IoRef(pub(super) Rc<IoState>);
23
/// Shared state of a single io stream, referenced through [`IoRef`].
pub(crate) struct IoState {
    /// Installed filter chain, stored type-erased (see `FilterPtr`).
    filter: FilterPtr,
    /// Io configuration, kept as a `'static` reference.
    pub(super) cfg: Cell<&'static IoConfig>,
    /// Stream status flags (stopped/stopping, paused, buffer readiness, ...).
    pub(super) flags: Cell<Flags>,
    /// Error the stream was stopped with, if any.
    pub(super) error: Cell<Option<io::Error>>,
    /// Waker of the read task.
    pub(super) read_task: LocalWaker,
    /// Waker of the write task.
    pub(super) write_task: LocalWaker,
    /// Waker of the dispatcher, i.e. the consumer polling this io object.
    pub(super) dispatch_task: LocalWaker,
    /// Read/write buffers; a new layer is added for every added filter.
    pub(super) buffer: Stack,
    /// Handle returned by `IoStream::start`; dropped when the io is stopped.
    pub(super) handle: Cell<Option<Box<dyn Handle>>>,
    /// Handle of the currently registered timer.
    pub(super) timeout: Cell<TimerHandle>,
    /// Wakers of registered `OnDisconnect` futures, woken on disconnect.
    // NOTE(review): presumably boxed so the `Option` stays pointer-sized while
    // no `OnDisconnect` futures exist, hence the clippy allow — confirm.
    #[allow(clippy::box_collection)]
    pub(super) on_disconnect: Cell<Option<Box<Vec<LocalWaker>>>>,
}
38
impl IoState {
    /// Returns the currently installed filter as a trait object.
    pub(super) fn filter(&self) -> &dyn Filter {
        self.filter.filter.get()
    }

    /// Sets the bits of `f` in the stream flags.
    pub(super) fn insert_flags(&self, f: Flags) {
        let mut flags = self.flags.get();
        flags.insert(f);
        self.flags.set(flags);
    }

    /// Clears the bits of `f` in the stream flags.
    ///
    /// Returns `true` if any of those bits were set (and are now cleared);
    /// returns `false` and leaves the flags untouched otherwise.
    pub(super) fn remove_flags(&self, f: Flags) -> bool {
        let mut flags = self.flags.get();
        if flags.intersects(f) {
            flags.remove(f);
            self.flags.set(flags);
            true
        } else {
            false
        }
    }

    /// Records a dispatcher timeout (`DSP_TIMEOUT`) and wakes the dispatcher.
    ///
    /// Does nothing while a previous timeout is still flagged, so the
    /// dispatcher is woken once per unobserved timeout.
    pub(super) fn notify_timeout(&self) {
        let mut flags = self.flags.get();
        if !flags.contains(Flags::DSP_TIMEOUT) {
            flags.insert(Flags::DSP_TIMEOUT);
            self.flags.set(flags);
            self.dispatch_task.wake();
            log::trace!("{}: Timer, notify dispatcher", self.cfg.get().tag());
        }
    }

    /// Wakes every waker registered by `OnDisconnect` futures and drops the
    /// registration list.
    pub(super) fn notify_disconnect(&self) {
        if let Some(on_disconnect) = self.on_disconnect.take() {
            for item in on_disconnect.into_iter() {
                item.wake();
            }
        }
    }

    /// Returns the stored error, if any.
    ///
    /// `io::Error` is not `Clone`. The original error is returned and a new
    /// error with the same kind and message is stored in its place, so later
    /// calls keep reporting an error.
    pub(super) fn error(&self) -> Option<io::Error> {
        if let Some(err) = self.error.take() {
            self.error
                .set(Some(io::Error::new(err.kind(), format!("{err}"))));
            Some(err)
        } else {
            None
        }
    }

    /// Returns the stored error. Falls back to a `NotConnected`
    /// ("Disconnected") error if the stream has none.
    pub(super) fn error_or_disconnected(&self) -> io::Error {
        self.error()
            .unwrap_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "Disconnected"))
    }

    /// Marks the io stream as stopped. Only the first call has an effect.
    ///
    /// In order, this:
    /// - stores `err` (if `Some`);
    /// - wakes the read and write tasks and all `OnDisconnect` waiters;
    /// - drops the stream handle;
    /// - sets the stopped/stopping flags together with `BUF_R_READY`;
    /// - wakes the dispatcher.
    pub(super) fn io_stopped(&self, err: Option<io::Error>) {
        if !self.flags.get().is_stopped() {
            log::trace!(
                "{}: {} Io error {:?} flags: {:?}",
                self.cfg.get().tag(),
                self as *const _ as usize,
                err,
                self.flags.get()
            );

            // a previously stored error is kept if no new one is given
            if err.is_some() {
                self.error.set(err);
            }
            self.read_task.wake();
            self.write_task.wake();
            self.notify_disconnect();
            // drop the handle returned by `IoStream::start`
            self.handle.take();
            self.insert_flags(
                Flags::IO_STOPPED
                    | Flags::IO_STOPPING
                    | Flags::IO_STOPPING_FILTERS
                    | Flags::BUF_R_READY,
            );
            // `wake_checked` reports whether a dispatcher waker was woken
            if !self.dispatch_task.wake_checked() {
                log::trace!(
                    "{}: {} Dispatcher is not registered, flags: {:?}",
                    self.cfg.get().tag(),
                    self as *const _ as usize,
                    self.flags.get()
                );
            }
        }
    }

    /// Starts a graceful shutdown.
    ///
    /// Unless the stream is already stopped or stopping, this sets
    /// `IO_STOPPING_FILTERS` and wakes the read task.
    pub(super) fn init_shutdown(&self) {
        if !self
            .flags
            .get()
            .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
        {
            log::trace!(
                "{}: Initiate io shutdown {:?}",
                self.cfg.get().tag(),
                self.flags.get()
            );
            self.insert_flags(Flags::IO_STOPPING_FILTERS);
            self.read_task.wake();
        }
    }

    /// Read-buffer configuration of the current `IoConfig`.
    #[inline]
    pub(super) fn read_buf(&self) -> &BufConfig {
        self.cfg.get().read_buf()
    }

    /// Write-buffer configuration of the current `IoConfig`.
    #[inline]
    pub(super) fn write_buf(&self) -> &BufConfig {
        self.cfg.get().write_buf()
    }
}
157
158impl Eq for IoState {}
159
160impl PartialEq for IoState {
161 #[inline]
162 fn eq(&self, other: &Self) -> bool {
163 ptr::eq(self, other)
164 }
165}
166
167impl hash::Hash for IoState {
168 #[inline]
169 fn hash<H: hash::Hasher>(&self, state: &mut H) {
170 (self as *const _ as usize).hash(state);
171 }
172}
173
174impl fmt::Debug for IoState {
175 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176 let err = self.error.take();
177 let res = f
178 .debug_struct("IoState")
179 .field("flags", &self.flags)
180 .field("filter", &self.filter.is_set())
181 .field("timeout", &self.timeout)
182 .field("error", &err)
183 .field("buffer", &self.buffer)
184 .field("cfg", &self.cfg)
185 .finish();
186 self.error.set(err);
187 res
188 }
189}
190
191impl Io {
192 #[inline]
193 pub fn new<I: IoStream, T: Into<SharedCfg>>(io: I, cfg: T) -> Self {
195 let inner = Rc::new(IoState {
196 cfg: Cell::new(cfg.into().get::<IoConfig>().into_static()),
197 filter: FilterPtr::null(),
198 flags: Cell::new(Flags::WR_PAUSED),
199 error: Cell::new(None),
200 dispatch_task: LocalWaker::new(),
201 read_task: LocalWaker::new(),
202 write_task: LocalWaker::new(),
203 buffer: Stack::new(),
204 handle: Cell::new(None),
205 timeout: Cell::new(TimerHandle::default()),
206 on_disconnect: Cell::new(None),
207 });
208 inner.filter.update(Base::new(IoRef(inner.clone())));
209
210 let io_ref = IoRef(inner);
211
212 let hnd = io.start(IoContext::new(&io_ref));
214 io_ref.0.handle.set(hnd);
215
216 Io(UnsafeCell::new(io_ref), marker::PhantomData)
217 }
218}
219
220impl<I: IoStream> From<I> for Io {
221 #[inline]
222 fn from(io: I) -> Io {
223 Io::new(io, SharedCfg::default())
224 }
225}
226
impl<F> Io<F> {
    /// Moves the io state out of this object and returns it as a new `Io`.
    ///
    /// `self` is left with a fresh, already-stopped state without a filter,
    /// so dropping it afterwards does not touch the original stream.
    #[inline]
    pub fn take(&self) -> Self {
        Self(UnsafeCell::new(self.take_io_ref()), marker::PhantomData)
    }

    /// Swaps an empty, stopped `IoState` (default config, null filter) into
    /// `self` and returns the previous `IoRef`.
    fn take_io_ref(&self) -> IoRef {
        let inner = Rc::new(IoState {
            cfg: Cell::new(SharedCfg::default().get::<IoConfig>().into_static()),
            filter: FilterPtr::null(),
            flags: Cell::new(
                Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS,
            ),
            error: Cell::new(None),
            dispatch_task: LocalWaker::new(),
            read_task: LocalWaker::new(),
            write_task: LocalWaker::new(),
            buffer: Stack::new(),
            handle: Cell::new(None),
            timeout: Cell::new(TimerHandle::default()),
            on_disconnect: Cell::new(None),
        });
        // SAFETY: NOTE(review): mutates the cell contents through `&self`.
        // This is only sound while no reference previously obtained from
        // `st()`/`io_ref()` is still alive — confirm against callers.
        unsafe { mem::replace(&mut *self.0.get(), IoRef(inner)) }
    }

    /// Returns the current stream flags.
    #[inline]
    #[doc(hidden)]
    pub fn flags(&self) -> Flags {
        self.st().flags.get()
    }

    /// Returns a new `IoRef` handle to the shared io state.
    #[inline]
    pub fn get_ref(&self) -> IoRef {
        self.io_ref().clone()
    }

    /// Shared io state behind this object.
    fn st(&self) -> &IoState {
        // SAFETY: the cell is only written by `take_io_ref`; see the note there.
        unsafe { &(*self.0.get()).0 }
    }

    /// `IoRef` stored in this object.
    fn io_ref(&self) -> &IoRef {
        // SAFETY: the cell is only written by `take_io_ref`; see the note there.
        unsafe { &*self.0.get() }
    }

    /// Replaces the io configuration of this stream.
    #[inline]
    pub fn set_config<T: Into<SharedCfg>>(&self, cfg: T) {
        self.st()
            .cfg
            .set(cfg.into().get::<IoConfig>().into_static());
    }
}
284
285impl<F: FilterLayer, T: Filter> Io<Layer<F, T>> {
286 #[inline]
287 pub fn filter(&self) -> &F {
289 &self.st().filter.filter::<Layer<F, T>>().0
290 }
291}
292
impl<F: Filter> Io<F> {
    /// Seals the current filter chain, erasing its concrete type.
    ///
    /// The state is moved out of `self` first, so dropping the consumed
    /// object does not affect the stream.
    #[inline]
    pub fn seal(self) -> Io<Sealed> {
        let state = self.take_io_ref();
        state.0.filter.seal::<F>();

        Io(UnsafeCell::new(state), marker::PhantomData)
    }

    /// Seals the filter chain and converts the result into an [`IoBoxed`].
    #[inline]
    pub fn boxed(self) -> IoBoxed {
        self.seal().into()
    }

    /// Installs filter layer `nf` on top of the current filter chain.
    ///
    /// A matching buffer layer is added as well.
    #[inline]
    pub fn add_filter<U>(self, nf: U) -> Io<Layer<U, F>>
    where
        U: FilterLayer,
    {
        let state = self.take_io_ref();

        // add buffers layer
        // SAFETY: NOTE(review): this creates a `&mut IoState` while other `Rc`
        // clones of the state exist (the base filter holds one). It relies on
        // `Stack::add_layer` only growing internal storage and on no code
        // holding references into the buffer storage at this point. Confirm
        // against `Stack`.
        unsafe { &mut *(Rc::as_ptr(&state.0) as *mut IoState) }
            .buffer
            .add_layer();

        // replace current filter
        state.0.filter.add_filter::<F, U>(nf);

        Io(UnsafeCell::new(state), marker::PhantomData)
    }

    /// Replaces the current filter with the result of `f(current_filter)`.
    pub fn map_filter<U, R>(self, f: U) -> Io<R>
    where
        U: FnOnce(F) -> R,
        R: Filter,
    {
        let state = self.take_io_ref();
        state.0.filter.map_filter::<F, U, R>(f);

        Io(UnsafeCell::new(state), marker::PhantomData)
    }
}
343
344impl<F> Io<F> {
345 #[inline]
346 pub async fn recv<U>(
348 &self,
349 codec: &U,
350 ) -> Result<Option<U::Item>, Either<U::Error, io::Error>>
351 where
352 U: Decoder,
353 {
354 loop {
355 return match poll_fn(|cx| self.poll_recv(codec, cx)).await {
356 Ok(item) => Ok(Some(item)),
357 Err(RecvError::KeepAlive) => Err(Either::Right(io::Error::new(
358 io::ErrorKind::TimedOut,
359 "Timeout",
360 ))),
361 Err(RecvError::WriteBackpressure) => {
362 poll_fn(|cx| self.poll_flush(cx, false))
363 .await
364 .map_err(Either::Right)?;
365 continue;
366 }
367 Err(RecvError::Decoder(err)) => Err(Either::Left(err)),
368 Err(RecvError::PeerGone(Some(err))) => Err(Either::Right(err)),
369 Err(RecvError::PeerGone(None)) => Ok(None),
370 };
371 }
372 }
373
374 #[inline]
375 pub async fn read_ready(&self) -> io::Result<Option<()>> {
377 poll_fn(|cx| self.poll_read_ready(cx)).await
378 }
379
380 #[inline]
381 pub async fn read_notify(&self) -> io::Result<Option<()>> {
383 poll_fn(|cx| self.poll_read_notify(cx)).await
384 }
385
386 #[inline]
387 pub fn pause(&self) {
389 let st = self.st();
390 if !st.flags.get().contains(Flags::RD_PAUSED) {
391 st.read_task.wake();
392 st.insert_flags(Flags::RD_PAUSED);
393 }
394 }
395
396 #[inline]
397 pub async fn send<U>(
399 &self,
400 item: U::Item,
401 codec: &U,
402 ) -> Result<(), Either<U::Error, io::Error>>
403 where
404 U: Encoder,
405 {
406 self.encode(item, codec).map_err(Either::Left)?;
407
408 poll_fn(|cx| self.poll_flush(cx, true))
409 .await
410 .map_err(Either::Right)?;
411
412 Ok(())
413 }
414
415 #[inline]
416 pub async fn flush(&self, full: bool) -> io::Result<()> {
420 poll_fn(|cx| self.poll_flush(cx, full)).await
421 }
422
423 #[inline]
424 pub async fn shutdown(&self) -> io::Result<()> {
426 poll_fn(|cx| self.poll_shutdown(cx)).await
427 }
428
429 #[inline]
430 pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
444 let st = self.st();
445 let mut flags = st.flags.get();
446
447 if flags.is_stopped() {
448 Poll::Ready(Err(st.error_or_disconnected()))
449 } else {
450 st.dispatch_task.register(cx.waker());
451
452 let ready = flags.is_read_buf_ready();
453 if flags.cannot_read() {
454 flags.cleanup_read_flags();
455 st.read_task.wake();
456 st.flags.set(flags);
457 if ready {
458 Poll::Ready(Ok(Some(())))
459 } else {
460 Poll::Pending
461 }
462 } else if ready {
463 flags.remove(Flags::BUF_R_READY);
464 st.flags.set(flags);
465 Poll::Ready(Ok(Some(())))
466 } else {
467 Poll::Pending
468 }
469 }
470 }
471
472 #[inline]
473 pub fn poll_read_notify(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
475 let ready = self.poll_read_ready(cx);
476
477 if ready.is_pending() {
478 let st = self.st();
479 if st.remove_flags(Flags::RD_NOTIFY) {
480 Poll::Ready(Ok(Some(())))
481 } else {
482 st.insert_flags(Flags::RD_NOTIFY);
483 Poll::Pending
484 }
485 } else {
486 ready
487 }
488 }
489
490 #[inline]
491 pub fn poll_recv<U>(
496 &self,
497 codec: &U,
498 cx: &mut Context<'_>,
499 ) -> Poll<Result<U::Item, RecvError<U>>>
500 where
501 U: Decoder,
502 {
503 let decoded = self.poll_recv_decode(codec, cx)?;
504
505 if let Some(item) = decoded.item {
506 Poll::Ready(Ok(item))
507 } else {
508 Poll::Pending
509 }
510 }
511
512 #[doc(hidden)]
513 #[inline]
514 pub fn poll_recv_decode<U>(
519 &self,
520 codec: &U,
521 cx: &mut Context<'_>,
522 ) -> Result<Decoded<U::Item>, RecvError<U>>
523 where
524 U: Decoder,
525 {
526 let decoded = self
527 .decode_item(codec)
528 .map_err(|err| RecvError::Decoder(err))?;
529
530 if decoded.item.is_some() {
531 Ok(decoded)
532 } else {
533 let st = self.st();
534 let flags = st.flags.get();
535 if flags.is_stopped() {
536 Err(RecvError::PeerGone(st.error()))
537 } else if flags.contains(Flags::DSP_TIMEOUT) {
538 st.remove_flags(Flags::DSP_TIMEOUT);
539 Err(RecvError::KeepAlive)
540 } else if flags.contains(Flags::BUF_W_BACKPRESSURE) {
541 Err(RecvError::WriteBackpressure)
542 } else {
543 match self.poll_read_ready(cx) {
544 Poll::Pending | Poll::Ready(Ok(Some(()))) => {
545 if log::log_enabled!(log::Level::Trace) && decoded.remains != 0 {
546 log::trace!(
547 "{}: Not enough data to decode next frame",
548 self.tag()
549 );
550 }
551 Ok(decoded)
552 }
553 Poll::Ready(Err(e)) => Err(RecvError::PeerGone(Some(e))),
554 Poll::Ready(Ok(None)) => Err(RecvError::PeerGone(None)),
555 }
556 }
557 }
558 }
559
560 #[inline]
561 pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll<io::Result<()>> {
567 let st = self.st();
568 let flags = self.flags();
569
570 let len = st.buffer.write_destination_size();
571 if len > 0 {
572 if full {
573 st.insert_flags(Flags::BUF_W_MUST_FLUSH);
574 st.dispatch_task.register(cx.waker());
575 return if flags.is_stopped() {
576 Poll::Ready(Err(st.error_or_disconnected()))
577 } else {
578 Poll::Pending
579 };
580 } else if len >= st.write_buf().half {
581 st.insert_flags(Flags::BUF_W_BACKPRESSURE);
582 st.dispatch_task.register(cx.waker());
583 return if flags.is_stopped() {
584 Poll::Ready(Err(st.error_or_disconnected()))
585 } else {
586 Poll::Pending
587 };
588 }
589 }
590 if flags.is_stopped() {
591 Poll::Ready(Err(st.error_or_disconnected()))
592 } else {
593 st.remove_flags(Flags::BUF_W_MUST_FLUSH | Flags::BUF_W_BACKPRESSURE);
594 Poll::Ready(Ok(()))
595 }
596 }
597
598 #[inline]
599 pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
601 let st = self.st();
602 let flags = st.flags.get();
603
604 if flags.is_stopped() {
605 if let Some(err) = st.error() {
606 Poll::Ready(Err(err))
607 } else {
608 Poll::Ready(Ok(()))
609 }
610 } else {
611 if !flags.contains(Flags::IO_STOPPING_FILTERS) {
612 st.init_shutdown();
613 }
614
615 st.read_task.wake();
616 st.write_task.wake();
617 st.dispatch_task.register(cx.waker());
618 Poll::Pending
619 }
620 }
621
622 #[inline]
623 pub fn poll_read_pause(&self, cx: &mut Context<'_>) -> Poll<IoStatusUpdate> {
627 self.pause();
628 let result = self.poll_status_update(cx);
629 if !result.is_pending() {
630 self.st().dispatch_task.register(cx.waker());
631 }
632 result
633 }
634
635 #[inline]
636 pub fn poll_status_update(&self, cx: &mut Context<'_>) -> Poll<IoStatusUpdate> {
638 let st = self.st();
639 let flags = st.flags.get();
640 if flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING) {
641 Poll::Ready(IoStatusUpdate::PeerGone(st.error()))
642 } else if flags.contains(Flags::DSP_TIMEOUT) {
643 st.remove_flags(Flags::DSP_TIMEOUT);
644 Poll::Ready(IoStatusUpdate::KeepAlive)
645 } else if flags.contains(Flags::BUF_W_BACKPRESSURE) {
646 Poll::Ready(IoStatusUpdate::WriteBackpressure)
647 } else {
648 st.dispatch_task.register(cx.waker());
649 Poll::Pending
650 }
651 }
652
653 #[inline]
654 pub fn poll_dispatch(&self, cx: &mut Context<'_>) {
656 self.st().dispatch_task.register(cx.waker());
657 }
658}
659
660impl<F> AsRef<IoRef> for Io<F> {
661 #[inline]
662 fn as_ref(&self) -> &IoRef {
663 self.io_ref()
664 }
665}
666
667impl<F> Eq for Io<F> {}
668
669impl<F> PartialEq for Io<F> {
670 #[inline]
671 fn eq(&self, other: &Self) -> bool {
672 self.io_ref().eq(other.io_ref())
673 }
674}
675
676impl<F> hash::Hash for Io<F> {
677 #[inline]
678 fn hash<H: hash::Hasher>(&self, state: &mut H) {
679 self.io_ref().hash(state);
680 }
681}
682
683impl<F> fmt::Debug for Io<F> {
684 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
685 f.debug_struct("Io").field("state", self.st()).finish()
686 }
687}
688
689impl<F> ops::Deref for Io<F> {
690 type Target = IoRef;
691
692 #[inline]
693 fn deref(&self) -> &Self::Target {
694 self.io_ref()
695 }
696}
697
impl<F> Drop for Io<F> {
    /// Stops the timer. If this object still owns a filter, it also
    /// force-closes the stream and frees the filter.
    fn drop(&mut self) {
        let st = self.st();
        self.stop_timer();

        // The filter is stored type-erased in `FilterPtr` and is not freed
        // automatically. It has to be dropped explicitly with the concrete
        // filter type `F`. Objects whose state was moved out via `take()`,
        // `seal()`, `add_filter()` or `map_filter()` have a null filter and
        // skip this branch.
        if st.filter.is_set() {
            if !st.flags.get().is_stopped() {
                log::trace!(
                    "{}: Io is dropped, force stopping io streams {:?}",
                    st.cfg.get().tag(),
                    st.flags.get()
                );
            }

            self.force_close();
            st.filter.drop_filter::<F>();
        }
    }
}
719
// Tag bits stored in byte `KIND_IDX` of `FilterPtr::data`; they record what
// the storage currently holds. All-zero tag bits mean "no filter".
// NOTE(review): this relies on the stored pointers being at least 4-byte
// aligned, so the two low bits of their least significant byte are free —
// confirm for all filter types.
/// `data` holds an inline `Sealed` value.
const KIND_SEALED: u8 = 0b01;
/// `data` holds a raw `*mut F` to a boxed concrete filter.
const KIND_PTR: u8 = 0b10;
const KIND_MASK: u8 = 0b11;
const KIND_UNMASK: u8 = !KIND_MASK;
const KIND_MASK_USIZE: usize = 0b11;
const KIND_UNMASK_USIZE: usize = !KIND_MASK_USIZE;
/// Size of the inline storage: exactly one `Sealed` value.
const SEALED_SIZE: usize = mem::size_of::<Sealed>();
/// All-zero storage, i.e. no filter set.
const NULL: [u8; SEALED_SIZE] = [0u8; SEALED_SIZE];

// Byte that carries the tag bits: the first byte on little-endian targets,
// the last byte of the storage on big-endian targets.
#[cfg(target_endian = "little")]
const KIND_IDX: usize = 0;

#[cfg(target_endian = "big")]
const KIND_IDX: usize = SEALED_SIZE - 1;
734
/// Type-erased storage for the io filter chain.
///
/// `data` holds one of:
/// - a raw pointer to a boxed concrete filter (tag `KIND_PTR`);
/// - an inline `Sealed` value (tag `KIND_SEALED`);
/// - all zeros (`NULL`) when no filter is set.
///
/// `filter` caches a `&dyn Filter` view of the stored filter, or
/// `NullFilter` when unset.
struct FilterPtr {
    data: Cell<[u8; SEALED_SIZE]>,
    filter: Cell<&'static dyn Filter>,
}
739
740impl FilterPtr {
741 const fn null() -> Self {
742 Self {
743 data: Cell::new(NULL),
744 filter: Cell::new(NullFilter::get()),
745 }
746 }
747
748 fn update<F: Filter>(&self, filter: F) {
749 if self.is_set() {
750 panic!("Filter is set, must be dropped first");
751 }
752
753 let filter = Box::new(filter);
754 let mut data = NULL;
755 unsafe {
756 let filter_ref: &'static dyn Filter = {
757 let f: &dyn Filter = filter.as_ref();
758 mem::transmute(f)
759 };
760 self.filter.set(filter_ref);
761
762 let ptr = &mut data as *mut _ as *mut *mut F;
763 ptr.write(Box::into_raw(filter));
764 data[KIND_IDX] |= KIND_PTR;
765 self.data.set(data);
766 }
767 }
768
769 fn filter<F: Filter>(&self) -> &F {
771 let data = self.data.get();
772 if data[KIND_IDX] & KIND_PTR != 0 {
773 let ptr = &data as *const _ as *const *mut F;
774 unsafe {
775 let p = (ptr.read() as *const _ as usize) & KIND_UNMASK_USIZE;
776 (p as *const F as *mut F).as_ref().unwrap()
777 }
778 } else {
779 panic!("Wrong filter item");
780 }
781 }
782
783 fn take_filter<F>(&self) -> Box<F> {
785 let mut data = self.data.get();
786 if data[KIND_IDX] & KIND_PTR != 0 {
787 data[KIND_IDX] &= KIND_UNMASK;
788 let ptr = &mut data as *mut _ as *mut *mut F;
789 unsafe { Box::from_raw(*ptr) }
790 } else {
791 panic!(
792 "Wrong filter item {:?} expected: {:?}",
793 data[KIND_IDX], KIND_PTR
794 );
795 }
796 }
797
798 fn take_sealed(&self) -> Sealed {
800 let mut data = self.data.get();
801
802 if data[KIND_IDX] & KIND_SEALED != 0 {
803 data[KIND_IDX] &= KIND_UNMASK;
804 let ptr = &mut data as *mut _ as *mut Sealed;
805 unsafe { ptr.read() }
806 } else {
807 panic!(
808 "Wrong filter item {:?} expected: {:?}",
809 data[KIND_IDX], KIND_SEALED
810 );
811 }
812 }
813
814 fn is_set(&self) -> bool {
815 self.data.get()[KIND_IDX] & KIND_MASK != 0
816 }
817
818 fn drop_filter<F>(&self) {
819 let data = self.data.get();
820
821 if data[KIND_IDX] & KIND_MASK != 0 {
822 if data[KIND_IDX] & KIND_PTR != 0 {
823 self.take_filter::<F>();
824 } else if data[KIND_IDX] & KIND_SEALED != 0 {
825 self.take_sealed();
826 }
827 self.data.set(NULL);
828 self.filter.set(NullFilter::get());
829 }
830 }
831}
832
833impl FilterPtr {
834 fn add_filter<F: Filter, T: FilterLayer>(&self, new: T) {
835 let data = self.data.get();
836 let filter = if data[KIND_IDX] & KIND_PTR != 0 {
837 Box::new(Layer::new(new, *self.take_filter::<F>()))
838 } else if data[KIND_IDX] & KIND_SEALED != 0 {
839 let f = Box::new(Layer::new(new, self.take_sealed()));
840 unsafe { mem::transmute::<Box<Layer<T, Sealed>>, Box<Layer<T, F>>>(f) }
842 } else {
843 panic!(
844 "Wrong filter item {:?} expected: {:?}",
845 data[KIND_IDX], KIND_PTR
846 );
847 };
848
849 let mut data = NULL;
850 unsafe {
851 let filter_ref: &'static dyn Filter = {
852 let f: &dyn Filter = filter.as_ref();
853 mem::transmute(f)
854 };
855 self.filter.set(filter_ref);
856
857 let ptr = &mut data as *mut _ as *mut *mut Layer<T, F>;
858 ptr.write(Box::into_raw(filter));
859 data[KIND_IDX] |= KIND_PTR;
860 self.data.set(data);
861 }
862 }
863
864 fn map_filter<F: Filter, U, R>(&self, f: U)
865 where
866 U: FnOnce(F) -> R,
867 R: Filter,
868 {
869 let mut data = NULL;
870 let filter = Box::new(f(*self.take_filter::<F>()));
871 unsafe {
872 let filter_ref: &'static dyn Filter = {
873 let f: &dyn Filter = filter.as_ref();
874 mem::transmute(f)
875 };
876 self.filter.set(filter_ref);
877
878 let ptr = &mut data as *mut _ as *mut *mut R;
879 ptr.write(Box::into_raw(filter));
880 data[KIND_IDX] |= KIND_PTR;
881 self.data.set(data);
882 }
883 }
884
885 fn seal<F: Filter>(&self) {
886 let mut data = self.data.get();
887
888 let filter = if data[KIND_IDX] & KIND_PTR != 0 {
889 Sealed(Box::new(*self.take_filter::<F>()))
890 } else if data[KIND_IDX] & KIND_SEALED != 0 {
891 self.take_sealed()
892 } else {
893 panic!(
894 "Wrong filter item {:?} expected: {:?}",
895 data[KIND_IDX], KIND_PTR
896 );
897 };
898
899 unsafe {
900 let filter_ref: &'static dyn Filter = {
901 let f: &dyn Filter = filter.0.as_ref();
902 mem::transmute(f)
903 };
904 self.filter.set(filter_ref);
905
906 let ptr = &mut data as *mut _ as *mut Sealed;
907 ptr.write(filter);
908 data[KIND_IDX] |= KIND_SEALED;
909 self.data.set(data);
910 }
911 }
912}
913
/// Future that resolves once the io stream is stopped (disconnected).
#[derive(Debug)]
#[must_use = "OnDisconnect do nothing unless polled"]
pub struct OnDisconnect {
    // Index of this future's waker slot in `IoState::on_disconnect`, or
    // `usize::MAX` if it was created for an already stopped stream.
    token: usize,
    inner: Rc<IoState>,
}
921
922impl OnDisconnect {
923 pub(super) fn new(inner: Rc<IoState>) -> Self {
924 Self::new_inner(inner.flags.get().is_stopped(), inner)
925 }
926
927 fn new_inner(disconnected: bool, inner: Rc<IoState>) -> Self {
928 let token = if disconnected {
929 usize::MAX
930 } else {
931 let mut on_disconnect = inner.on_disconnect.take();
932 let token = if let Some(ref mut on_disconnect) = on_disconnect {
933 let token = on_disconnect.len();
934 on_disconnect.push(LocalWaker::default());
935 token
936 } else {
937 on_disconnect = Some(Box::new(vec![LocalWaker::default()]));
938 0
939 };
940 inner.on_disconnect.set(on_disconnect);
941 token
942 };
943 Self { token, inner }
944 }
945
946 #[inline]
947 pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
949 if self.token == usize::MAX || self.inner.flags.get().is_stopped() {
950 Poll::Ready(())
951 } else if let Some(on_disconnect) = self.inner.on_disconnect.take() {
952 on_disconnect[self.token].register(cx.waker());
953 self.inner.on_disconnect.set(Some(on_disconnect));
954 Poll::Pending
955 } else {
956 Poll::Ready(())
957 }
958 }
959}
960
961impl Clone for OnDisconnect {
962 fn clone(&self) -> Self {
963 if self.token == usize::MAX {
964 OnDisconnect::new_inner(true, self.inner.clone())
965 } else {
966 OnDisconnect::new_inner(false, self.inner.clone())
967 }
968 }
969}
970
971impl Future for OnDisconnect {
972 type Output = ();
973
974 #[inline]
975 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
976 self.poll_ready(cx)
977 }
978}
979
#[cfg(test)]
mod tests {
    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;

    use super::*;
    use crate::{ReadBuf, WriteBuf, testing::IoTest};

    // The same request, as raw bytes and as text.
    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    // Identity equality of `Io`/`IoRef` and basic `Flags` behavior.
    #[ntex::test]
    async fn test_basics() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::from(server);
        assert!(server.eq(&server));
        assert!(server.io_ref().eq(server.io_ref()));

        assert!(format!("{:?}", Flags::IO_STOPPED).contains("IO_STOPPED"));
        assert!(Flags::IO_STOPPED == Flags::IO_STOPPED);
        assert!(Flags::IO_STOPPED != Flags::IO_STOPPING);
    }

    // `recv` reports a pending dispatcher timeout as a "Timeout" error.
    // With write backpressure flagged, it flushes and still returns the item.
    #[ntex::test]
    async fn test_recv() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::from(server);

        server.st().notify_timeout();
        let err = server.recv(&BytesCodec).await.err().unwrap();
        assert!(format!("{err:?}").contains("Timeout"));

        client.write(TEXT);
        server.st().insert_flags(Flags::BUF_W_BACKPRESSURE);
        let item = server.recv(&BytesCodec).await.ok().unwrap().unwrap();
        assert_eq!(item, TEXT);
    }

    // `send` encodes the item and flushes it to the peer.
    #[ntex::test]
    async fn test_send() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::from(server);
        assert!(server.eq(&server));

        server
            .send(Bytes::from_static(BIN), &BytesCodec)
            .await
            .ok()
            .unwrap();
        let item = client.read_any();
        assert_eq!(item, TEXT);
    }

    // Pass-through filter layer that counts how often it is dropped.
    #[derive(Debug)]
    struct DropFilter {
        p: Rc<Cell<usize>>,
    }

    impl Drop for DropFilter {
        fn drop(&mut self) {
            self.p.set(self.p.get() + 1);
        }
    }

    impl FilterLayer for DropFilter {
        // forwards all read data unchanged
        fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
            if let Some(src) = buf.take_src() {
                let len = src.len();
                buf.set_dst(Some(src));
                Ok(len)
            } else {
                Ok(0)
            }
        }
        // forwards all write data unchanged
        fn process_write_buf(&self, buf: &WriteBuf<'_>) -> io::Result<()> {
            if let Some(src) = buf.take_src() {
                buf.set_dst(Some(src));
            }
            Ok(())
        }
    }

    // The filter must be dropped exactly once. This holds even though the io
    // object is moved around with `take()` and converted into `IoBoxed`
    // before all handles are dropped.
    #[ntex::test]
    async fn drop_filter() {
        let p = Rc::new(Cell::new(0));

        let (client, server) = IoTest::create();
        let f = DropFilter { p: p.clone() };
        let _ = format!("{f:?}");
        let io = Io::from(server).add_filter(f);

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        let io2 = io.take();
        let mut io3: crate::IoBoxed = io2.into();
        let io4 = io3.take();

        drop(io);
        drop(io3);
        drop(io4);

        assert_eq!(p.get(), 1);
    }

    // Adding a filter on top of an already sealed filter chain must work
    // (the sealed branch of `FilterPtr::add_filter`).
    #[ntex::test]
    async fn test_take_sealed_filter() {
        let p = Rc::new(Cell::new(0));
        let f = DropFilter { p: p.clone() };

        let io = Io::from(IoTest::create().0).seal();
        let _io: Io<Layer<DropFilter, Sealed>> = io.add_filter(f);
    }
}