1use std::cell::{Cell, UnsafeCell};
2use std::future::{poll_fn, Future};
3use std::task::{Context, Poll};
4use std::{fmt, hash, io, marker, mem, ops, pin::Pin, ptr, rc::Rc};
5
6use ntex_bytes::{PoolId, PoolRef};
7use ntex_codec::{Decoder, Encoder};
8use ntex_util::{future::Either, task::LocalWaker, time::Seconds};
9
10use crate::buf::Stack;
11use crate::filter::{Base, Filter, Layer, NullFilter};
12use crate::flags::Flags;
13use crate::seal::{IoBoxed, Sealed};
14use crate::tasks::{ReadContext, WriteContext};
15use crate::timer::TimerHandle;
16use crate::{Decoded, FilterLayer, Handle, IoStatusUpdate, IoStream, RecvError};
17
/// Handle for an io stream; generic over the installed filter type
/// (`Base` when no extra filter layers were added).
pub struct Io<F = Base>(UnsafeCell<IoRef>, marker::PhantomData<F>);
20
/// Cheaply cloneable (`Rc`-backed) reference to the shared io state.
#[derive(Clone)]
pub struct IoRef(pub(super) Rc<IoState>);
23
/// Shared state behind every `Io`/`IoRef` clone.
pub(crate) struct IoState {
    /// Installed filter chain, type-erased (see `FilterPtr`).
    filter: FilterPtr,
    /// Io status bit flags.
    pub(super) flags: Cell<Flags>,
    /// Memory pool used for buffer allocation.
    pub(super) pool: Cell<PoolRef>,
    /// Timeout applied while disconnecting.
    pub(super) disconnect_timeout: Cell<Seconds>,
    /// Last io error, if any; `error()` takes it and stores a re-created copy.
    pub(super) error: Cell<Option<io::Error>>,
    /// Waker for the read task.
    pub(super) read_task: LocalWaker,
    /// Waker for the write task.
    pub(super) write_task: LocalWaker,
    /// Waker for the dispatcher.
    pub(super) dispatch_task: LocalWaker,
    /// Read/write buffers; `add_filter` pushes an extra layer per filter.
    pub(super) buffer: Stack,
    /// Handle returned by `IoStream::start`; dropped on stop.
    pub(super) handle: Cell<Option<Box<dyn Handle>>>,
    /// Timer handle (keep-alive bookkeeping — see `timer` module).
    pub(super) timeout: Cell<TimerHandle>,
    /// Tag prefix used in log messages.
    pub(super) tag: Cell<&'static str>,
    /// Wakers registered through `OnDisconnect`, indexed by token.
    #[allow(clippy::box_collection)]
    pub(super) on_disconnect: Cell<Option<Box<Vec<LocalWaker>>>>,
}
40
/// Log tag used until a caller assigns a custom one.
const DEFAULT_TAG: &str = "IO";
42
impl IoState {
    /// Current filter chain head as a trait object.
    pub(super) fn filter(&self) -> &dyn Filter {
        self.filter.filter.get()
    }

    /// Add `f` to the flag set (read-modify-write of the `Cell`).
    pub(super) fn insert_flags(&self, f: Flags) {
        let mut flags = self.flags.get();
        flags.insert(f);
        self.flags.set(flags);
    }

    /// Remove `f` from the flag set; returns `true` if any of the bits
    /// were actually set before the call.
    pub(super) fn remove_flags(&self, f: Flags) -> bool {
        let mut flags = self.flags.get();
        if flags.intersects(f) {
            flags.remove(f);
            self.flags.set(flags);
            true
        } else {
            false
        }
    }

    /// Record a keep-alive timeout and wake the dispatcher.
    /// Repeated calls while `DSP_TIMEOUT` is already set are no-ops.
    pub(super) fn notify_timeout(&self) {
        let mut flags = self.flags.get();
        if !flags.contains(Flags::DSP_TIMEOUT) {
            flags.insert(Flags::DSP_TIMEOUT);
            self.flags.set(flags);
            self.dispatch_task.wake();
            log::trace!("{}: Timer, notify dispatcher", self.tag.get());
        }
    }

    /// Wake every waker registered via `OnDisconnect`, consuming the list.
    pub(super) fn notify_disconnect(&self) {
        if let Some(on_disconnect) = self.on_disconnect.take() {
            for item in on_disconnect.into_iter() {
                item.wake();
            }
        }
    }

    /// Take the stored io error, keeping a re-created copy (same kind,
    /// same display text) in the cell so later callers still see it.
    pub(super) fn error(&self) -> Option<io::Error> {
        if let Some(err) = self.error.take() {
            self.error
                .set(Some(io::Error::new(err.kind(), format!("{err}"))));
            Some(err)
        } else {
            None
        }
    }

    /// Stored io error, or a generic `NotConnected` error when none is set.
    pub(super) fn error_or_disconnected(&self) -> io::Error {
        self.error()
            .unwrap_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "Disconnected"))
    }

    /// Transition into the stopped state: store the error (if any), wake
    /// the read/write tasks and disconnect waiters, drop the driver
    /// handle, set the stop flags and finally wake the dispatcher.
    /// No-op if the stream is already stopped.
    pub(super) fn io_stopped(&self, err: Option<io::Error>) {
        if !self.flags.get().is_stopped() {
            log::trace!(
                "{}: {} Io error {:?} flags: {:?}",
                self.tag.get(),
                self as *const _ as usize,
                err,
                self.flags.get()
            );

            if err.is_some() {
                self.error.set(err);
            }
            self.read_task.wake();
            self.write_task.wake();
            self.notify_disconnect();
            self.handle.take();
            // BUF_R_READY is included so a pending read poll wakes up and
            // observes the stop.
            self.insert_flags(
                Flags::IO_STOPPED
                    | Flags::IO_STOPPING
                    | Flags::IO_STOPPING_FILTERS
                    | Flags::BUF_R_READY,
            );
            if !self.dispatch_task.wake_checked() {
                log::trace!(
                    "{}: {} Dispatcher is not registered, flags: {:?}",
                    self.tag.get(),
                    self as *const _ as usize,
                    self.flags.get()
                );
            }
        }
    }

    /// Start graceful shutdown of the filter chain, unless a shutdown or
    /// stop is already in progress.
    pub(super) fn init_shutdown(&self) {
        if !self
            .flags
            .get()
            .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
        {
            log::trace!(
                "{}: Initiate io shutdown {:?}",
                self.tag.get(),
                self.flags.get()
            );
            self.insert_flags(Flags::IO_STOPPING_FILTERS);
            self.read_task.wake();
        }
    }
}
151
// Identity-based equality (the `PartialEq` impl compares addresses) is
// reflexive, so `Eq` is sound.
impl Eq for IoState {}
153
154impl PartialEq for IoState {
155 #[inline]
156 fn eq(&self, other: &Self) -> bool {
157 ptr::eq(self, other)
158 }
159}
160
161impl hash::Hash for IoState {
162 #[inline]
163 fn hash<H: hash::Hasher>(&self, state: &mut H) {
164 (self as *const _ as usize).hash(state);
165 }
166}
167
impl Drop for IoState {
    #[inline]
    fn drop(&mut self) {
        // Return all buffers to the memory pool.
        self.buffer.release(self.pool.get());
    }
}
174
impl fmt::Debug for IoState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `error` is inside a `Cell`, so it has to be taken out to be
        // formatted and then put back afterwards.
        let err = self.error.take();
        let res = f
            .debug_struct("IoState")
            .field("flags", &self.flags)
            .field("filter", &self.filter.is_set())
            .field("disconnect_timeout", &self.disconnect_timeout)
            .field("timeout", &self.timeout)
            .field("error", &err)
            .field("buffer", &self.buffer)
            .finish();
        self.error.set(err);
        res
    }
}
191
impl Io {
    /// Create an io stream using the default memory pool.
    #[inline]
    pub fn new<I: IoStream>(io: I) -> Self {
        Self::with_memory_pool(io, PoolId::DEFAULT.pool_ref())
    }

    /// Create an io stream using the given memory pool.
    #[inline]
    pub fn with_memory_pool<I: IoStream>(io: I, pool: PoolRef) -> Self {
        let inner = Rc::new(IoState {
            filter: FilterPtr::null(),
            pool: Cell::new(pool),
            flags: Cell::new(Flags::WR_PAUSED),
            error: Cell::new(None),
            dispatch_task: LocalWaker::new(),
            read_task: LocalWaker::new(),
            write_task: LocalWaker::new(),
            buffer: Stack::new(),
            handle: Cell::new(None),
            timeout: Cell::new(TimerHandle::default()),
            disconnect_timeout: Cell::new(Seconds(1)),
            on_disconnect: Cell::new(None),
            tag: Cell::new(DEFAULT_TAG),
        });
        // Install the base filter; it needs a back-reference to the state.
        inner.filter.update(Base::new(IoRef(inner.clone())));

        let io_ref = IoRef(inner);

        // Start the stream driver; the returned handle (if any) is kept
        // alive in the state until the stream stops.
        let hnd = io.start(ReadContext::new(&io_ref), WriteContext::new(&io_ref));
        io_ref.0.handle.set(hnd);

        Io(UnsafeCell::new(io_ref), marker::PhantomData)
    }
}
228
impl<F> Io<F> {
    /// Set memory pool for both buffers and future allocations.
    #[inline]
    pub fn set_memory_pool(&self, pool: PoolRef) {
        self.st().buffer.set_memory_pool(pool);
        self.st().pool.set(pool);
    }

    /// Set io disconnect timeout.
    #[inline]
    pub fn set_disconnect_timeout(&self, timeout: Seconds) {
        self.st().disconnect_timeout.set(timeout);
    }

    /// Move the live io state into a new `Io` handle; `self` is left
    /// holding a fresh, already-stopped replacement state.
    #[inline]
    pub fn take(&self) -> Self {
        Self(UnsafeCell::new(self.take_io_ref()), marker::PhantomData)
    }

    /// Swap the inner `IoRef` for a stopped placeholder and return the
    /// original one.
    fn take_io_ref(&self) -> IoRef {
        let inner = Rc::new(IoState {
            filter: FilterPtr::null(),
            pool: self.st().pool.clone(),
            // The replacement state is created pre-stopped so the old
            // handle behaves as a closed stream.
            flags: Cell::new(
                Flags::DSP_STOP
                    | Flags::IO_STOPPED
                    | Flags::IO_STOPPING
                    | Flags::IO_STOPPING_FILTERS,
            ),
            error: Cell::new(None),
            disconnect_timeout: Cell::new(Seconds(1)),
            dispatch_task: LocalWaker::new(),
            read_task: LocalWaker::new(),
            write_task: LocalWaker::new(),
            buffer: Stack::new(),
            handle: Cell::new(None),
            timeout: Cell::new(TimerHandle::default()),
            on_disconnect: Cell::new(None),
            tag: Cell::new(DEFAULT_TAG),
        });
        // SAFETY(review): `Io` is !Send/!Sync, and this assumes no `&IoRef`
        // borrow obtained from `io_ref()` is live across this call —
        // TODO confirm all call sites uphold that.
        unsafe { mem::replace(&mut *self.0.get(), IoRef(inner)) }
    }
}
275
impl<F> Io<F> {
    /// Current io state flags.
    #[inline]
    #[doc(hidden)]
    pub fn flags(&self) -> Flags {
        self.st().flags.get()
    }

    /// Clone of the inner `IoRef`.
    #[inline]
    pub fn get_ref(&self) -> IoRef {
        self.io_ref().clone()
    }

    /// Shared io state.
    // SAFETY(review): relies on `Io` being !Send/!Sync and on the
    // UnsafeCell contents only being replaced via `take_io_ref`.
    fn st(&self) -> &IoState {
        unsafe { &(*self.0.get()).0 }
    }

    /// Borrow the inner `IoRef` without cloning.
    fn io_ref(&self) -> &IoRef {
        unsafe { &*self.0.get() }
    }
}
298
impl<F: FilterLayer, T: Filter> Io<Layer<F, T>> {
    /// Reference to the topmost filter layer.
    #[inline]
    pub fn filter(&self) -> &F {
        &self.st().filter.filter::<Layer<F, T>>().0
    }
}
306
impl<F: Filter> Io<F> {
    /// Erase the concrete filter type, producing an `Io<Sealed>`.
    #[inline]
    pub fn seal(self) -> Io<Sealed> {
        let state = self.take_io_ref();
        state.0.filter.seal::<F>();

        Io(UnsafeCell::new(state), marker::PhantomData)
    }

    /// Seal the filter and box the io object.
    #[inline]
    pub fn boxed(self) -> IoBoxed {
        self.seal().into()
    }

    /// Push a new filter layer on top of the current filter chain.
    #[inline]
    pub fn add_filter<U>(self, nf: U) -> Io<Layer<U, F>>
    where
        U: FilterLayer,
    {
        let state = self.take_io_ref();

        // Add a buffer level for the new layer.
        // SAFETY(review): mutates through the Rc; assumes no other
        // reference into the state's buffer stack is live here — TODO
        // confirm.
        unsafe { &mut *(Rc::as_ptr(&state.0) as *mut IoState) }
            .buffer
            .add_layer();

        state.0.filter.add_filter::<F, U>(nf);

        Io(UnsafeCell::new(state), marker::PhantomData)
    }

    /// Replace the current filter with the result of `f(current)`.
    pub fn map_filter<U, R>(self, f: U) -> Io<R>
    where
        U: FnOnce(F) -> R,
        R: Filter,
    {
        let state = self.take_io_ref();
        state.0.filter.map_filter::<F, U, R>(f);

        Io(UnsafeCell::new(state), marker::PhantomData)
    }
}
357
impl<F> Io<F> {
    /// Receive the next item decoded by `codec`.
    ///
    /// `Ok(None)` means the peer disconnected without an error. Keep-alive
    /// timeouts and dispatcher stop requests are mapped to `io::Error`s;
    /// on write backpressure the write buffer is flushed and the receive
    /// is retried.
    #[inline]
    pub async fn recv<U>(
        &self,
        codec: &U,
    ) -> Result<Option<U::Item>, Either<U::Error, io::Error>>
    where
        U: Decoder,
    {
        loop {
            return match poll_fn(|cx| self.poll_recv(codec, cx)).await {
                Ok(item) => Ok(Some(item)),
                Err(RecvError::KeepAlive) => Err(Either::Right(io::Error::new(
                    io::ErrorKind::TimedOut,
                    "Timeout",
                ))),
                Err(RecvError::Stop) => Err(Either::Right(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "Dispatcher stopped",
                ))),
                Err(RecvError::WriteBackpressure) => {
                    // Drain the write buffer, then try receiving again.
                    poll_fn(|cx| self.poll_flush(cx, false))
                        .await
                        .map_err(Either::Right)?;
                    continue;
                }
                Err(RecvError::Decoder(err)) => Err(Either::Left(err)),
                Err(RecvError::PeerGone(Some(err))) => Err(Either::Right(err)),
                Err(RecvError::PeerGone(None)) => Ok(None),
            };
        }
    }

    /// Async wrapper over [`Self::poll_read_ready`].
    #[inline]
    pub async fn read_ready(&self) -> io::Result<Option<()>> {
        poll_fn(|cx| self.poll_read_ready(cx)).await
    }

    /// Async wrapper over [`Self::poll_read_notify`].
    #[inline]
    pub async fn read_notify(&self) -> io::Result<Option<()>> {
        poll_fn(|cx| self.poll_read_notify(cx)).await
    }

    /// Pause read task: set `RD_PAUSED` and wake the read task so it can
    /// observe the flag.
    #[inline]
    pub fn pause(&self) {
        let st = self.st();
        if !st.flags.get().contains(Flags::RD_PAUSED) {
            st.read_task.wake();
            st.insert_flags(Flags::RD_PAUSED);
        }
    }

    /// Encode `item` with `codec` and fully flush the write buffer.
    #[inline]
    pub async fn send<U>(
        &self,
        item: U::Item,
        codec: &U,
    ) -> Result<(), Either<U::Error, io::Error>>
    where
        U: Encoder,
    {
        self.encode(item, codec).map_err(Either::Left)?;

        poll_fn(|cx| self.poll_flush(cx, true))
            .await
            .map_err(Either::Right)?;

        Ok(())
    }

    /// Flush the write buffer. With `full == true` waits until the buffer
    /// is completely drained; otherwise only until backpressure releases.
    #[inline]
    pub async fn flush(&self, full: bool) -> io::Result<()> {
        poll_fn(|cx| self.poll_flush(cx, full)).await
    }

    /// Gracefully shut down the io stream (async wrapper over
    /// [`Self::poll_shutdown`]).
    #[inline]
    pub async fn shutdown(&self) -> io::Result<()> {
        poll_fn(|cx| self.poll_shutdown(cx)).await
    }

    /// Poll for read readiness.
    ///
    /// * `Ready(Ok(Some(())))` — new data is available in the read buffer
    /// * `Ready(Err(e))` — io stream is stopped
    /// * `Pending` — dispatcher waker registered, no data yet
    ///
    /// When reading is currently blocked (e.g. paused), the read flags
    /// are reset and the read task is woken so reading can resume.
    #[inline]
    pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
        let st = self.st();
        let mut flags = st.flags.get();

        if flags.is_stopped() {
            Poll::Ready(Err(st.error_or_disconnected()))
        } else {
            st.dispatch_task.register(cx.waker());

            let ready = flags.is_read_buf_ready();
            if flags.cannot_read() {
                // Reading is blocked: clear the read flags and restart
                // the read task, reporting any already-buffered data.
                flags.cleanup_read_flags();
                st.read_task.wake();
                st.flags.set(flags);
                if ready {
                    Poll::Ready(Ok(Some(())))
                } else {
                    Poll::Pending
                }
            } else if ready {
                // Consume the readiness notification.
                flags.remove(Flags::BUF_R_READY);
                st.flags.set(flags);
                Poll::Ready(Ok(Some(())))
            } else {
                Poll::Pending
            }
        }
    }

    /// Like [`Self::poll_read_ready`], but when that is pending this also
    /// checks `RD_NOTIFY`: if the flag was set since the previous call it
    /// reports readiness, otherwise it arms the flag and stays pending.
    #[inline]
    pub fn poll_read_notify(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
        let ready = self.poll_read_ready(cx);

        if ready.is_pending() {
            let st = self.st();
            if st.remove_flags(Flags::RD_NOTIFY) {
                Poll::Ready(Ok(Some(())))
            } else {
                st.insert_flags(Flags::RD_NOTIFY);
                Poll::Pending
            }
        } else {
            ready
        }
    }

    /// Poll-decode the next item from the read buffer using `codec`;
    /// pending until a complete item is available.
    #[inline]
    pub fn poll_recv<U>(
        &self,
        codec: &U,
        cx: &mut Context<'_>,
    ) -> Poll<Result<U::Item, RecvError<U>>>
    where
        U: Decoder,
    {
        let decoded = self.poll_recv_decode(codec, cx)?;

        if let Some(item) = decoded.item {
            Poll::Ready(Ok(item))
        } else {
            Poll::Pending
        }
    }

    /// Attempt to decode one item; when nothing decodes, surface stop /
    /// timeout / backpressure conditions (in that priority order) or
    /// register for read readiness and return the partial decode state.
    #[doc(hidden)]
    #[inline]
    pub fn poll_recv_decode<U>(
        &self,
        codec: &U,
        cx: &mut Context<'_>,
    ) -> Result<Decoded<U::Item>, RecvError<U>>
    where
        U: Decoder,
    {
        let decoded = self
            .decode_item(codec)
            .map_err(|err| RecvError::Decoder(err))?;

        if decoded.item.is_some() {
            Ok(decoded)
        } else {
            let st = self.st();
            let flags = st.flags.get();
            if flags.is_stopped() {
                Err(RecvError::PeerGone(st.error()))
            } else if flags.contains(Flags::DSP_STOP) {
                st.remove_flags(Flags::DSP_STOP);
                Err(RecvError::Stop)
            } else if flags.contains(Flags::DSP_TIMEOUT) {
                st.remove_flags(Flags::DSP_TIMEOUT);
                Err(RecvError::KeepAlive)
            } else if flags.contains(Flags::BUF_W_BACKPRESSURE) {
                Err(RecvError::WriteBackpressure)
            } else {
                match self.poll_read_ready(cx) {
                    Poll::Pending | Poll::Ready(Ok(Some(()))) => {
                        if log::log_enabled!(log::Level::Trace) && decoded.remains != 0 {
                            log::trace!(
                                "{}: Not enough data to decode next frame",
                                self.tag()
                            );
                        }
                        Ok(decoded)
                    }
                    Poll::Ready(Err(e)) => Err(RecvError::PeerGone(Some(e))),
                    Poll::Ready(Ok(None)) => Err(RecvError::PeerGone(None)),
                }
            }
        }
    }

    /// Poll-flush the write buffer.
    ///
    /// With `full == true` waits until the buffer is empty; otherwise
    /// only until the buffered amount drops below twice the pool's high
    /// watermark. Errors out if the stream stops while waiting.
    #[inline]
    pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll<io::Result<()>> {
        let st = self.st();
        let flags = self.flags();

        let len = st.buffer.write_destination_size();
        if len > 0 {
            if full {
                st.insert_flags(Flags::BUF_W_MUST_FLUSH);
                st.dispatch_task.register(cx.waker());
                return if flags.is_stopped() {
                    Poll::Ready(Err(st.error_or_disconnected()))
                } else {
                    Poll::Pending
                };
            } else if len >= st.pool.get().write_params_high() << 1 {
                // Backpressure threshold: 2x the pool's high watermark.
                st.insert_flags(Flags::BUF_W_BACKPRESSURE);
                st.dispatch_task.register(cx.waker());
                return if flags.is_stopped() {
                    Poll::Ready(Err(st.error_or_disconnected()))
                } else {
                    Poll::Pending
                };
            }
        }
        st.remove_flags(Flags::BUF_W_MUST_FLUSH | Flags::BUF_W_BACKPRESSURE);
        Poll::Ready(Ok(()))
    }

    /// Poll-shutdown the io stream: initiates filter shutdown once, keeps
    /// the tasks running, and resolves (with the stored error, if any)
    /// when the stream has fully stopped.
    #[inline]
    pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let st = self.st();
        let flags = st.flags.get();

        if flags.is_stopped() {
            if let Some(err) = st.error() {
                Poll::Ready(Err(err))
            } else {
                Poll::Ready(Ok(()))
            }
        } else {
            if !flags.contains(Flags::IO_STOPPING_FILTERS) {
                st.init_shutdown();
            }

            st.read_task.wake();
            st.write_task.wake();
            st.dispatch_task.register(cx.waker());
            Poll::Pending
        }
    }

    /// Pause reading and poll for a status update; the dispatcher waker
    /// is (re-)registered even when a status update is returned.
    #[inline]
    pub fn poll_read_pause(&self, cx: &mut Context<'_>) -> Poll<IoStatusUpdate> {
        self.pause();
        let result = self.poll_status_update(cx);
        if !result.is_pending() {
            self.st().dispatch_task.register(cx.waker());
        }
        result
    }

    /// Poll for a dispatcher-level status change: peer gone, stop
    /// request, keep-alive timeout or write backpressure (checked in
    /// that order).
    #[inline]
    pub fn poll_status_update(&self, cx: &mut Context<'_>) -> Poll<IoStatusUpdate> {
        let st = self.st();
        let flags = st.flags.get();
        if flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING) {
            Poll::Ready(IoStatusUpdate::PeerGone(st.error()))
        } else if flags.contains(Flags::DSP_STOP) {
            st.remove_flags(Flags::DSP_STOP);
            Poll::Ready(IoStatusUpdate::Stop)
        } else if flags.contains(Flags::DSP_TIMEOUT) {
            st.remove_flags(Flags::DSP_TIMEOUT);
            Poll::Ready(IoStatusUpdate::KeepAlive)
        } else if flags.contains(Flags::BUF_W_BACKPRESSURE) {
            Poll::Ready(IoStatusUpdate::WriteBackpressure)
        } else {
            st.dispatch_task.register(cx.waker());
            Poll::Pending
        }
    }

    /// Register the dispatcher waker without checking any status.
    #[inline]
    pub fn poll_dispatch(&self, cx: &mut Context<'_>) {
        self.st().dispatch_task.register(cx.waker());
    }
}
679
impl<F> AsRef<IoRef> for Io<F> {
    /// Borrow the inner `IoRef`.
    #[inline]
    fn as_ref(&self) -> &IoRef {
        self.io_ref()
    }
}
686
// Equality delegates to `IoRef` (identity of the shared state).
impl<F> Eq for Io<F> {}
688
impl<F> PartialEq for Io<F> {
    /// Two handles are equal when they share the same `IoState`.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.io_ref().eq(other.io_ref())
    }
}
695
impl<F> hash::Hash for Io<F> {
    /// Hash delegates to the inner `IoRef` (state identity).
    #[inline]
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        self.io_ref().hash(state);
    }
}
702
impl<F> fmt::Debug for Io<F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Io").field("state", self.st()).finish()
    }
}
708
// `Io` dereferences to `IoRef`, exposing its methods directly.
impl<F> ops::Deref for Io<F> {
    type Target = IoRef;

    #[inline]
    fn deref(&self) -> &Self::Target {
        self.io_ref()
    }
}
717
impl<F> Drop for Io<F> {
    fn drop(&mut self) {
        let st = self.st();
        self.stop_timer();

        // If the filter chain is still installed, force-close the stream
        // and release the boxed filter.
        if st.filter.is_set() {
            if !st.flags.get().is_stopped() {
                log::trace!(
                    "{}: Io is dropped, force stopping io streams {:?}",
                    st.tag.get(),
                    st.flags.get()
                );
            }

            self.force_close();
            st.filter.drop_filter::<F>();
        }
    }
}
739
// Tag bits stored in the low bits of one byte of `FilterPtr::data`,
// identifying what the inline buffer currently holds.
const KIND_SEALED: u8 = 0b01; // buffer holds an inline `Sealed` value
const KIND_PTR: u8 = 0b10; // buffer holds a `*mut F` from `Box::into_raw`
const KIND_MASK: u8 = 0b11;
const KIND_UNMASK: u8 = !KIND_MASK;
const KIND_MASK_USIZE: usize = 0b11;
const KIND_UNMASK_USIZE: usize = !KIND_MASK_USIZE;
const SEALED_SIZE: usize = mem::size_of::<Sealed>();
const NULL: [u8; SEALED_SIZE] = [0u8; SEALED_SIZE];

// Byte index carrying the tag bits: the least significant byte of the
// first pointer-sized word, which depends on endianness.
#[cfg(target_endian = "little")]
const KIND_IDX: usize = 0;

#[cfg(target_endian = "big")]
const KIND_IDX: usize = SEALED_SIZE - 1;
754
/// Type-erased filter storage: `data` inlines either a raw `Box` pointer
/// (tagged `KIND_PTR`) or a `Sealed` value (tagged `KIND_SEALED`), while
/// `filter` keeps a trait-object alias used for dynamic dispatch.
struct FilterPtr {
    data: Cell<[u8; SEALED_SIZE]>,
    filter: Cell<&'static dyn Filter>,
}
759
impl FilterPtr {
    /// Empty slot: zeroed data and the shared `NullFilter`.
    const fn null() -> Self {
        Self {
            data: Cell::new(NULL),
            filter: Cell::new(NullFilter::get()),
        }
    }

    /// Store a boxed filter, keeping a type-erased alias in `self.filter`.
    ///
    /// Panics if a filter is already installed.
    fn update<F: Filter>(&self, filter: F) {
        if self.is_set() {
            panic!("Filter is set, must be dropped first");
        }

        let filter = Box::new(filter);
        let mut data = NULL;
        unsafe {
            // SAFETY(review): the 'static lifetime is fabricated; the
            // reference is only valid while the box stored below lives,
            // which `drop_filter` maintains by resetting both cells
            // together.
            let filter_ref: &'static dyn Filter = {
                let f: &dyn Filter = filter.as_ref();
                mem::transmute(f)
            };
            self.filter.set(filter_ref);

            // Store the raw box pointer inline and tag its low bits with
            // KIND_PTR (assumes the allocation leaves those bits free —
            // TODO confirm alignment guarantee).
            let ptr = &mut data as *mut _ as *mut *mut F;
            ptr.write(Box::into_raw(filter));
            data[KIND_IDX] |= KIND_PTR;
            self.data.set(data);
        }
    }

    /// Borrow the stored filter as its concrete type `F`.
    ///
    /// Panics if the slot does not hold a boxed pointer.
    fn filter<F: Filter>(&self) -> &F {
        let data = self.data.get();
        if data[KIND_IDX] & KIND_PTR != 0 {
            let ptr = &data as *const _ as *const *mut F;
            unsafe {
                // Mask the tag bits back out of the pointer value.
                let p = (ptr.read() as *const _ as usize) & KIND_UNMASK_USIZE;
                (p as *const F as *mut F).as_ref().unwrap()
            }
        } else {
            panic!("Wrong filter item");
        }
    }

    /// Take the stored boxed filter out (caller becomes the owner).
    ///
    /// Note: `self.data` itself is left unchanged here; callers overwrite
    /// or clear it afterwards.
    fn take_filter<F>(&self) -> Box<F> {
        let mut data = self.data.get();
        if data[KIND_IDX] & KIND_PTR != 0 {
            // Clear the tag bits in the local copy to restore the pointer.
            data[KIND_IDX] &= KIND_UNMASK;
            let ptr = &mut data as *mut _ as *mut *mut F;
            unsafe { Box::from_raw(*ptr) }
        } else {
            panic!(
                "Wrong filter item {:?} expected: {:?}",
                data[KIND_IDX], KIND_PTR
            );
        }
    }

    /// Take the inline `Sealed` value out (caller becomes the owner).
    fn take_sealed(&self) -> Sealed {
        let mut data = self.data.get();

        if data[KIND_IDX] & KIND_SEALED != 0 {
            // Clear the tag bits in the local copy before reading the
            // value's original byte representation.
            data[KIND_IDX] &= KIND_UNMASK;
            let ptr = &mut data as *mut _ as *mut Sealed;
            unsafe { ptr.read() }
        } else {
            panic!(
                "Wrong filter item {:?} expected: {:?}",
                data[KIND_IDX], KIND_SEALED
            );
        }
    }

    /// Whether a filter (boxed or sealed) is currently installed.
    fn is_set(&self) -> bool {
        self.data.get()[KIND_IDX] & KIND_MASK != 0
    }

    /// Drop the stored filter (whichever representation it has) and reset
    /// the slot to the null state.
    fn drop_filter<F>(&self) {
        let data = self.data.get();

        if data[KIND_IDX] & KIND_MASK != 0 {
            if data[KIND_IDX] & KIND_PTR != 0 {
                self.take_filter::<F>();
            } else if data[KIND_IDX] & KIND_SEALED != 0 {
                self.take_sealed();
            }
            self.data.set(NULL);
            self.filter.set(NullFilter::get());
        }
    }
}
852
impl FilterPtr {
    /// Wrap the currently stored filter in a new layer `T` and store the
    /// resulting `Layer<T, F>` as a fresh boxed filter.
    fn add_filter<F: Filter, T: FilterLayer>(&self, new: T) {
        let data = self.data.get();
        let filter = if data[KIND_IDX] & KIND_PTR != 0 {
            Box::new(Layer::new(new, *self.take_filter::<F>()))
        } else if data[KIND_IDX] & KIND_SEALED != 0 {
            let f = Box::new(Layer::new(new, self.take_sealed()));
            // SAFETY(review): at runtime F is `Sealed` in this branch, so
            // the transmute only renames the type parameter — TODO confirm
            // this invariant at the call sites.
            unsafe { mem::transmute::<Box<Layer<T, Sealed>>, Box<Layer<T, F>>>(f) }
        } else {
            panic!(
                "Wrong filter item {:?} expected: {:?}",
                data[KIND_IDX], KIND_PTR
            );
        };

        let mut data = NULL;
        unsafe {
            // Same fabricated-'static aliasing scheme as `update`.
            let filter_ref: &'static dyn Filter = {
                let f: &dyn Filter = filter.as_ref();
                mem::transmute(f)
            };
            self.filter.set(filter_ref);

            let ptr = &mut data as *mut _ as *mut *mut Layer<T, F>;
            ptr.write(Box::into_raw(filter));
            data[KIND_IDX] |= KIND_PTR;
            self.data.set(data);
        }
    }

    /// Replace the stored filter `F` with `f(F)`, storing the result as a
    /// fresh boxed filter of type `R`.
    fn map_filter<F: Filter, U, R>(&self, f: U)
    where
        U: FnOnce(F) -> R,
        R: Filter,
    {
        let mut data = NULL;
        let filter = Box::new(f(*self.take_filter::<F>()));
        unsafe {
            // Same fabricated-'static aliasing scheme as `update`.
            let filter_ref: &'static dyn Filter = {
                let f: &dyn Filter = filter.as_ref();
                mem::transmute(f)
            };
            self.filter.set(filter_ref);

            let ptr = &mut data as *mut _ as *mut *mut R;
            ptr.write(Box::into_raw(filter));
            data[KIND_IDX] |= KIND_PTR;
            self.data.set(data);
        }
    }

    /// Convert the stored filter into the type-erased `Sealed` form,
    /// stored inline in `data` (no-op conversion if already sealed).
    fn seal<F: Filter>(&self) {
        let mut data = self.data.get();

        let filter = if data[KIND_IDX] & KIND_PTR != 0 {
            Sealed(Box::new(*self.take_filter::<F>()))
        } else if data[KIND_IDX] & KIND_SEALED != 0 {
            self.take_sealed()
        } else {
            panic!(
                "Wrong filter item {:?} expected: {:?}",
                data[KIND_IDX], KIND_PTR
            );
        };

        unsafe {
            // Same fabricated-'static aliasing scheme as `update`.
            let filter_ref: &'static dyn Filter = {
                let f: &dyn Filter = filter.0.as_ref();
                mem::transmute(f)
            };
            self.filter.set(filter_ref);

            // Write the Sealed value inline and tag it.
            let ptr = &mut data as *mut _ as *mut Sealed;
            ptr.write(filter);
            data[KIND_IDX] |= KIND_SEALED;
            self.data.set(data);
        }
    }
}
933
/// Future that resolves once the peer disconnects.
///
/// `token` indexes this waiter's slot in `IoState::on_disconnect`;
/// `usize::MAX` marks an already-disconnected waiter.
#[derive(Debug)]
#[must_use = "OnDisconnect do nothing unless polled"]
pub struct OnDisconnect {
    token: usize,
    inner: Rc<IoState>,
}
941
impl OnDisconnect {
    /// Create a waiter, resolving immediately if the stream has stopped.
    pub(super) fn new(inner: Rc<IoState>) -> Self {
        Self::new_inner(inner.flags.get().is_stopped(), inner)
    }

    /// Register a new waker slot; `usize::MAX` token marks an
    /// already-disconnected waiter that needs no slot.
    fn new_inner(disconnected: bool, inner: Rc<IoState>) -> Self {
        let token = if disconnected {
            usize::MAX
        } else {
            // The list lives in a `Cell`, so take it out, push a fresh
            // waker slot (creating the list on first use), then put it
            // back.
            let mut on_disconnect = inner.on_disconnect.take();
            let token = if let Some(ref mut on_disconnect) = on_disconnect {
                let token = on_disconnect.len();
                on_disconnect.push(LocalWaker::default());
                token
            } else {
                on_disconnect = Some(Box::new(vec![LocalWaker::default()]));
                0
            };
            inner.on_disconnect.set(on_disconnect);
            token
        };
        Self { token, inner }
    }

    /// Poll for disconnect: ready once the stream is stopped or the waker
    /// list has been consumed by `notify_disconnect`; otherwise registers
    /// this waiter's waker and stays pending.
    #[inline]
    pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
        if self.token == usize::MAX || self.inner.flags.get().is_stopped() {
            Poll::Ready(())
        } else if let Some(on_disconnect) = self.inner.on_disconnect.take() {
            on_disconnect[self.token].register(cx.waker());
            self.inner.on_disconnect.set(Some(on_disconnect));
            Poll::Pending
        } else {
            Poll::Ready(())
        }
    }
}
980
981impl Clone for OnDisconnect {
982 fn clone(&self) -> Self {
983 if self.token == usize::MAX {
984 OnDisconnect::new_inner(true, self.inner.clone())
985 } else {
986 OnDisconnect::new_inner(false, self.inner.clone())
987 }
988 }
989}
990
impl Future for OnDisconnect {
    type Output = ();

    /// Delegates to `poll_ready`; no pinning-sensitive state is involved.
    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.poll_ready(cx)
    }
}
999
#[cfg(test)]
mod tests {
    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;

    use super::*;
    use crate::{testing::IoTest, ReadBuf, WriteBuf};

    // Same payload as raw bytes and as a string.
    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    // Identity/equality semantics and Flags formatting sanity checks.
    #[ntex::test]
    async fn test_basics() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::new(server);
        assert!(server.eq(&server));
        assert!(server.io_ref().eq(server.io_ref()));

        assert!(format!("{:?}", Flags::IO_STOPPED).contains("IO_STOPPED"));
        assert!(Flags::IO_STOPPED == Flags::IO_STOPPED);
        assert!(Flags::IO_STOPPED != Flags::IO_STOPPING);
    }

    // recv() error mapping: timeout, dispatcher stop, and recovery after
    // write backpressure.
    #[ntex::test]
    async fn test_recv() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::new(server);

        server.st().notify_timeout();
        let err = server.recv(&BytesCodec).await.err().unwrap();
        assert!(format!("{err:?}").contains("Timeout"));

        server.st().insert_flags(Flags::DSP_STOP);
        let err = server.recv(&BytesCodec).await.err().unwrap();
        assert!(format!("{err:?}").contains("Dispatcher stopped"));

        client.write(TEXT);
        server.st().insert_flags(Flags::BUF_W_BACKPRESSURE);
        let item = server.recv(&BytesCodec).await.ok().unwrap().unwrap();
        assert_eq!(item, TEXT);
    }

    // send() delivers encoded bytes to the peer.
    #[ntex::test]
    async fn test_send() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let server = Io::new(server);
        assert!(server.eq(&server));

        server
            .send(Bytes::from_static(BIN), &BytesCodec)
            .await
            .ok()
            .unwrap();
        let item = client.read_any();
        assert_eq!(item, TEXT);
    }

    // Pass-through filter that counts how many times it was dropped.
    #[derive(Debug)]
    struct DropFilter {
        p: Rc<Cell<usize>>,
    }

    impl Drop for DropFilter {
        fn drop(&mut self) {
            self.p.set(self.p.get() + 1);
        }
    }

    impl FilterLayer for DropFilter {
        fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
            if let Some(src) = buf.take_src() {
                let len = src.len();
                buf.set_dst(Some(src));
                Ok(len)
            } else {
                Ok(0)
            }
        }
        fn process_write_buf(&self, buf: &WriteBuf<'_>) -> io::Result<()> {
            if let Some(src) = buf.take_src() {
                buf.set_dst(Some(src));
            }
            Ok(())
        }
    }

    // The filter must be dropped exactly once even after take()/seal()
    // moves of the io object.
    #[ntex::test]
    async fn drop_filter() {
        let p = Rc::new(Cell::new(0));

        let (client, server) = IoTest::create();
        let f = DropFilter { p: p.clone() };
        let _ = format!("{f:?}");
        let io = Io::new(server).add_filter(f);

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        let io2 = io.take();
        let mut io3: crate::IoBoxed = io2.into();
        let io4 = io3.take();

        drop(io);
        drop(io3);
        drop(io4);

        assert_eq!(p.get(), 1);
    }

    // add_filter on top of an already-sealed io must type-check and run.
    #[ntex::test]
    async fn test_take_sealed_filter() {
        let p = Rc::new(Cell::new(0));
        let f = DropFilter { p: p.clone() };

        let io = Io::new(IoTest::create().0).seal();
        let _io: Io<Layer<DropFilter, Sealed>> = io.add_filter(f);
    }
}
1131}