1use std::cell::{Cell, UnsafeCell};
2use std::future::{poll_fn, Future};
3use std::task::{Context, Poll};
4use std::{fmt, hash, io, marker, mem, ops, pin::Pin, ptr, rc::Rc};
5
6use ntex_bytes::{PoolId, PoolRef};
7use ntex_codec::{Decoder, Encoder};
8use ntex_util::{future::Either, task::LocalWaker, time::Seconds};
9
10use crate::buf::Stack;
11use crate::filter::{Base, Filter, Layer, NullFilter};
12use crate::flags::Flags;
13use crate::seal::{IoBoxed, Sealed};
14use crate::tasks::{ReadContext, WriteContext};
15use crate::timer::TimerHandle;
16use crate::{Decoded, FilterLayer, Handle, IoStatusUpdate, IoStream, RecvError};
17
/// Interface object to an underlying io stream.
///
/// The `IoRef` is kept in an `UnsafeCell` so that `&self` methods such as
/// `take` can swap it for a fresh placeholder state. `F` is only a type-level
/// marker for the filter stack (`PhantomData`); the filter itself is stored
/// inside the shared `IoState`.
pub struct Io<F = Base>(UnsafeCell<IoRef>, marker::PhantomData<F>);

/// Cheaply clonable, reference-counted handle to the shared io state.
#[derive(Clone)]
pub struct IoRef(pub(super) Rc<IoState>);
23
/// Shared state of an io object, referenced through `IoRef`.
pub(crate) struct IoState {
    /// Type-erased storage for the installed filter stack.
    filter: FilterPtr,
    /// State bit flags (stop, pause, buffer and dispatcher conditions).
    pub(super) flags: Cell<Flags>,
    /// Memory pool used by the buffer stack; passed to `buffer.release` on drop.
    pub(super) pool: Cell<PoolRef>,
    /// Disconnect timeout; constructors in this file initialize it to one second.
    pub(super) disconnect_timeout: Cell<Seconds>,
    /// Last io error. `IoState::error` hands it out and keeps a copy behind.
    pub(super) error: Cell<Option<io::Error>>,
    /// Waker for the read task.
    pub(super) read_task: LocalWaker,
    /// Waker for the write task.
    pub(super) write_task: LocalWaker,
    /// Waker for the dispatcher task.
    pub(super) dispatch_task: LocalWaker,
    /// Read/write buffer stack.
    pub(super) buffer: Stack,
    /// Handle returned by `IoStream::start`; dropped when the io is stopped.
    pub(super) handle: Cell<Option<Box<dyn Handle>>>,
    /// Timer registration handle.
    pub(super) timeout: Cell<TimerHandle>,
    /// Tag prefixed to log messages.
    pub(super) tag: Cell<&'static str>,
    /// Wakers registered by `OnDisconnect` futures; woken and cleared by
    /// `notify_disconnect`. Presumably boxed to keep the field pointer-sized
    /// while empty — confirm.
    #[allow(clippy::box_collection)]
    pub(super) on_disconnect: Cell<Option<Box<Vec<LocalWaker>>>>,
}
40
/// Log tag used until a custom tag is set.
const DEFAULT_TAG: &str = "IO";
42
43impl IoState {
44 pub(super) fn filter(&self) -> &dyn Filter {
45 self.filter.filter.get()
46 }
47
48 pub(super) fn insert_flags(&self, f: Flags) {
49 let mut flags = self.flags.get();
50 flags.insert(f);
51 self.flags.set(flags);
52 }
53
54 pub(super) fn remove_flags(&self, f: Flags) -> bool {
55 let mut flags = self.flags.get();
56 if flags.intersects(f) {
57 flags.remove(f);
58 self.flags.set(flags);
59 true
60 } else {
61 false
62 }
63 }
64
65 pub(super) fn notify_timeout(&self) {
66 let mut flags = self.flags.get();
67 if !flags.contains(Flags::DSP_TIMEOUT) {
68 flags.insert(Flags::DSP_TIMEOUT);
69 self.flags.set(flags);
70 self.dispatch_task.wake();
71 log::trace!("{}: Timer, notify dispatcher", self.tag.get());
72 }
73 }
74
75 pub(super) fn notify_disconnect(&self) {
76 if let Some(on_disconnect) = self.on_disconnect.take() {
77 for item in on_disconnect.into_iter() {
78 item.wake();
79 }
80 }
81 }
82
83 pub(super) fn error(&self) -> Option<io::Error> {
85 if let Some(err) = self.error.take() {
86 self.error
87 .set(Some(io::Error::new(err.kind(), format!("{}", err))));
88 Some(err)
89 } else {
90 None
91 }
92 }
93
94 pub(super) fn error_or_disconnected(&self) -> io::Error {
96 self.error()
97 .unwrap_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "Disconnected"))
98 }
99
100 pub(super) fn io_stopped(&self, err: Option<io::Error>) {
101 if err.is_some() {
102 self.error.set(err);
103 }
104 self.read_task.wake();
105 self.write_task.wake();
106 self.dispatch_task.wake();
107 self.notify_disconnect();
108 self.handle.take();
109 self.insert_flags(
110 Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS,
111 );
112 }
113
114 pub(super) fn init_shutdown(&self) {
116 if !self
117 .flags
118 .get()
119 .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
120 {
121 log::trace!(
122 "{}: Initiate io shutdown {:?}",
123 self.tag.get(),
124 self.flags.get()
125 );
126 self.insert_flags(Flags::IO_STOPPING_FILTERS);
127 self.read_task.wake();
128 }
129 }
130}
131
// Equality of `IoState` is pointer identity, which is a full equivalence relation.
impl Eq for IoState {}
133
134impl PartialEq for IoState {
135 #[inline]
136 fn eq(&self, other: &Self) -> bool {
137 ptr::eq(self, other)
138 }
139}
140
141impl hash::Hash for IoState {
142 #[inline]
143 fn hash<H: hash::Hasher>(&self, state: &mut H) {
144 (self as *const _ as usize).hash(state);
145 }
146}
147
148impl Drop for IoState {
149 #[inline]
150 fn drop(&mut self) {
151 self.buffer.release(self.pool.get());
152 }
153}
154
155impl fmt::Debug for IoState {
156 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157 let err = self.error.take();
158 let res = f
159 .debug_struct("IoState")
160 .field("flags", &self.flags)
161 .field("filter", &self.filter.is_set())
162 .field("disconnect_timeout", &self.disconnect_timeout)
163 .field("timeout", &self.timeout)
164 .field("error", &err)
165 .field("buffer", &self.buffer)
166 .finish();
167 self.error.set(err);
168 res
169 }
170}
171
172impl Io {
173 #[inline]
174 pub fn new<I: IoStream>(io: I) -> Self {
176 Self::with_memory_pool(io, PoolId::DEFAULT.pool_ref())
177 }
178
179 #[inline]
180 pub fn with_memory_pool<I: IoStream>(io: I, pool: PoolRef) -> Self {
182 let inner = Rc::new(IoState {
183 filter: FilterPtr::null(),
184 pool: Cell::new(pool),
185 flags: Cell::new(Flags::WR_PAUSED),
186 error: Cell::new(None),
187 dispatch_task: LocalWaker::new(),
188 read_task: LocalWaker::new(),
189 write_task: LocalWaker::new(),
190 buffer: Stack::new(),
191 handle: Cell::new(None),
192 timeout: Cell::new(TimerHandle::default()),
193 disconnect_timeout: Cell::new(Seconds(1)),
194 on_disconnect: Cell::new(None),
195 tag: Cell::new(DEFAULT_TAG),
196 });
197 inner.filter.update(Base::new(IoRef(inner.clone())));
198
199 let io_ref = IoRef(inner);
200
201 let hnd = io.start(ReadContext::new(&io_ref), WriteContext::new(&io_ref));
203 io_ref.0.handle.set(hnd);
204
205 Io(UnsafeCell::new(io_ref), marker::PhantomData)
206 }
207}
208
209impl<F> Io<F> {
210 #[inline]
211 pub fn set_memory_pool(&self, pool: PoolRef) {
213 self.st().buffer.set_memory_pool(pool);
214 self.st().pool.set(pool);
215 }
216
217 #[inline]
218 pub fn set_disconnect_timeout(&self, timeout: Seconds) {
220 self.st().disconnect_timeout.set(timeout);
221 }
222
223 #[inline]
224 pub fn take(&self) -> Self {
228 Self(UnsafeCell::new(self.take_io_ref()), marker::PhantomData)
229 }
230
231 fn take_io_ref(&self) -> IoRef {
232 let inner = Rc::new(IoState {
233 filter: FilterPtr::null(),
234 pool: self.st().pool.clone(),
235 flags: Cell::new(
236 Flags::DSP_STOP
237 | Flags::IO_STOPPED
238 | Flags::IO_STOPPING
239 | Flags::IO_STOPPING_FILTERS,
240 ),
241 error: Cell::new(None),
242 disconnect_timeout: Cell::new(Seconds(1)),
243 dispatch_task: LocalWaker::new(),
244 read_task: LocalWaker::new(),
245 write_task: LocalWaker::new(),
246 buffer: Stack::new(),
247 handle: Cell::new(None),
248 timeout: Cell::new(TimerHandle::default()),
249 on_disconnect: Cell::new(None),
250 tag: Cell::new(DEFAULT_TAG),
251 });
252 unsafe { mem::replace(&mut *self.0.get(), IoRef(inner)) }
253 }
254}
255
256impl<F> Io<F> {
257 #[inline]
258 #[doc(hidden)]
259 pub fn flags(&self) -> Flags {
261 self.st().flags.get()
262 }
263
264 #[inline]
265 pub fn get_ref(&self) -> IoRef {
267 self.io_ref().clone()
268 }
269
270 fn st(&self) -> &IoState {
271 unsafe { &(*self.0.get()).0 }
272 }
273
274 fn io_ref(&self) -> &IoRef {
275 unsafe { &*self.0.get() }
276 }
277}
278
279impl<F: FilterLayer, T: Filter> Io<Layer<F, T>> {
280 #[inline]
281 pub fn filter(&self) -> &F {
283 &self.st().filter.filter::<Layer<F, T>>().0
284 }
285}
286
impl<F: Filter> Io<F> {
    /// Converts this io into `Io<Sealed>`, erasing the concrete filter type.
    ///
    /// The state is moved out via `take_io_ref`, so the consumed `self` is
    /// dropped holding only the stopped placeholder state.
    #[inline]
    pub fn seal(self) -> Io<Sealed> {
        let state = self.take_io_ref();
        state.0.filter.seal::<F>();

        Io(UnsafeCell::new(state), marker::PhantomData)
    }

    /// Seals the io and converts it into an `IoBoxed`.
    #[inline]
    pub fn boxed(self) -> IoBoxed {
        self.seal().into()
    }

    /// Wraps the current filter stack with the new outer layer `nf`.
    ///
    /// When `U::BUFFERS` is true, a layer is also added to the buffer stack.
    #[inline]
    pub fn add_filter<U>(self, nf: U) -> Io<Layer<U, F>>
    where
        U: FilterLayer,
    {
        let state = self.take_io_ref();

        if U::BUFFERS {
            // NOTE(review): this creates a `&mut IoState` through the `Rc` while
            // other `Rc` clones exist (e.g. the one held by the `Base` filter);
            // soundness relies on nothing else touching the state during
            // `add_layer` — confirm.
            unsafe { &mut *(Rc::as_ptr(&state.0) as *mut IoState) }
                .buffer
                .add_layer();
        }

        state.0.filter.add_filter::<F, U>(nf);

        Io(UnsafeCell::new(state), marker::PhantomData)
    }
}
327
328impl<F> Io<F> {
329 #[inline]
330 pub async fn recv<U>(
332 &self,
333 codec: &U,
334 ) -> Result<Option<U::Item>, Either<U::Error, io::Error>>
335 where
336 U: Decoder,
337 {
338 loop {
339 return match poll_fn(|cx| self.poll_recv(codec, cx)).await {
340 Ok(item) => Ok(Some(item)),
341 Err(RecvError::KeepAlive) => Err(Either::Right(io::Error::new(
342 io::ErrorKind::TimedOut,
343 "Timeout",
344 ))),
345 Err(RecvError::Stop) => Err(Either::Right(io::Error::new(
346 io::ErrorKind::UnexpectedEof,
347 "Dispatcher stopped",
348 ))),
349 Err(RecvError::WriteBackpressure) => {
350 poll_fn(|cx| self.poll_flush(cx, false))
351 .await
352 .map_err(Either::Right)?;
353 continue;
354 }
355 Err(RecvError::Decoder(err)) => Err(Either::Left(err)),
356 Err(RecvError::PeerGone(Some(err))) => Err(Either::Right(err)),
357 Err(RecvError::PeerGone(None)) => Ok(None),
358 };
359 }
360 }
361
362 #[inline]
363 pub async fn read_ready(&self) -> io::Result<Option<()>> {
365 poll_fn(|cx| self.poll_read_ready(cx)).await
366 }
367
368 #[inline]
369 pub async fn read_notify(&self) -> io::Result<Option<()>> {
371 poll_fn(|cx| self.poll_read_notify(cx)).await
372 }
373
374 #[inline]
375 pub fn pause(&self) {
377 let st = self.st();
378 if !st.flags.get().contains(Flags::RD_PAUSED) {
379 st.read_task.wake();
380 st.insert_flags(Flags::RD_PAUSED);
381 }
382 }
383
384 #[inline]
385 pub async fn send<U>(
387 &self,
388 item: U::Item,
389 codec: &U,
390 ) -> Result<(), Either<U::Error, io::Error>>
391 where
392 U: Encoder,
393 {
394 self.encode(item, codec).map_err(Either::Left)?;
395
396 poll_fn(|cx| self.poll_flush(cx, true))
397 .await
398 .map_err(Either::Right)?;
399
400 Ok(())
401 }
402
403 #[inline]
404 pub async fn flush(&self, full: bool) -> io::Result<()> {
408 poll_fn(|cx| self.poll_flush(cx, full)).await
409 }
410
411 #[inline]
412 pub async fn shutdown(&self) -> io::Result<()> {
414 poll_fn(|cx| self.poll_shutdown(cx)).await
415 }
416
417 #[inline]
418 pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
432 let st = self.st();
433 let mut flags = st.flags.get();
434
435 if flags.is_stopped() {
436 Poll::Ready(Err(st.error_or_disconnected()))
437 } else {
438 st.dispatch_task.register(cx.waker());
439
440 let ready = flags.is_read_buf_ready();
441 if flags.cannot_read() {
442 flags.cleanup_read_flags();
443 st.read_task.wake();
444 st.flags.set(flags);
445 if ready {
446 Poll::Ready(Ok(Some(())))
447 } else {
448 Poll::Pending
449 }
450 } else if ready {
451 flags.remove(Flags::BUF_R_READY);
452 st.flags.set(flags);
453 Poll::Ready(Ok(Some(())))
454 } else {
455 Poll::Pending
456 }
457 }
458 }
459
460 #[inline]
461 pub fn poll_read_notify(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
463 let ready = self.poll_read_ready(cx);
464
465 if ready.is_pending() {
466 let st = self.st();
467 if st.remove_flags(Flags::RD_NOTIFY) {
468 Poll::Ready(Ok(Some(())))
469 } else {
470 st.insert_flags(Flags::RD_NOTIFY);
471 Poll::Pending
472 }
473 } else {
474 ready
475 }
476 }
477
478 #[inline]
479 pub fn poll_recv<U>(
484 &self,
485 codec: &U,
486 cx: &mut Context<'_>,
487 ) -> Poll<Result<U::Item, RecvError<U>>>
488 where
489 U: Decoder,
490 {
491 let decoded = self.poll_recv_decode(codec, cx)?;
492
493 if let Some(item) = decoded.item {
494 Poll::Ready(Ok(item))
495 } else {
496 Poll::Pending
497 }
498 }
499
500 #[doc(hidden)]
501 #[inline]
502 pub fn poll_recv_decode<U>(
507 &self,
508 codec: &U,
509 cx: &mut Context<'_>,
510 ) -> Result<Decoded<U::Item>, RecvError<U>>
511 where
512 U: Decoder,
513 {
514 let decoded = self
515 .decode_item(codec)
516 .map_err(|err| RecvError::Decoder(err))?;
517
518 if decoded.item.is_some() {
519 Ok(decoded)
520 } else {
521 let st = self.st();
522 let flags = st.flags.get();
523 if flags.is_stopped() {
524 Err(RecvError::PeerGone(st.error()))
525 } else if flags.contains(Flags::DSP_STOP) {
526 st.remove_flags(Flags::DSP_STOP);
527 Err(RecvError::Stop)
528 } else if flags.contains(Flags::DSP_TIMEOUT) {
529 st.remove_flags(Flags::DSP_TIMEOUT);
530 Err(RecvError::KeepAlive)
531 } else if flags.contains(Flags::BUF_W_BACKPRESSURE) {
532 Err(RecvError::WriteBackpressure)
533 } else {
534 match self.poll_read_ready(cx) {
535 Poll::Pending | Poll::Ready(Ok(Some(()))) => {
536 if log::log_enabled!(log::Level::Debug) && decoded.remains != 0 {
537 log::debug!(
538 "{}: Not enough data to decode next frame",
539 self.tag()
540 );
541 }
542 Ok(decoded)
543 }
544 Poll::Ready(Err(e)) => Err(RecvError::PeerGone(Some(e))),
545 Poll::Ready(Ok(None)) => Err(RecvError::PeerGone(None)),
546 }
547 }
548 }
549 }
550
551 #[inline]
552 pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll<io::Result<()>> {
558 let st = self.st();
559 let flags = self.flags();
560
561 let len = st.buffer.write_destination_size();
562 if len > 0 {
563 if full {
564 st.insert_flags(Flags::BUF_W_MUST_FLUSH);
565 st.dispatch_task.register(cx.waker());
566 return if flags.is_stopped() {
567 Poll::Ready(Err(st.error_or_disconnected()))
568 } else {
569 Poll::Pending
570 };
571 } else if len >= st.pool.get().write_params_high() << 1 {
572 st.insert_flags(Flags::BUF_W_BACKPRESSURE);
573 st.dispatch_task.register(cx.waker());
574 return if flags.is_stopped() {
575 Poll::Ready(Err(st.error_or_disconnected()))
576 } else {
577 Poll::Pending
578 };
579 }
580 }
581 st.remove_flags(Flags::BUF_W_MUST_FLUSH | Flags::BUF_W_BACKPRESSURE);
582 Poll::Ready(Ok(()))
583 }
584
585 #[inline]
586 pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
588 let st = self.st();
589 let flags = st.flags.get();
590
591 if flags.is_stopped() {
592 if let Some(err) = st.error() {
593 Poll::Ready(Err(err))
594 } else {
595 Poll::Ready(Ok(()))
596 }
597 } else {
598 if !flags.contains(Flags::IO_STOPPING_FILTERS) {
599 st.init_shutdown();
600 }
601
602 st.read_task.wake();
603 st.write_task.wake();
604 st.dispatch_task.register(cx.waker());
605 Poll::Pending
606 }
607 }
608
609 #[inline]
610 pub fn poll_read_pause(&self, cx: &mut Context<'_>) -> Poll<IoStatusUpdate> {
614 self.pause();
615 let result = self.poll_status_update(cx);
616 if !result.is_pending() {
617 self.st().dispatch_task.register(cx.waker());
618 }
619 result
620 }
621
622 #[inline]
623 pub fn poll_status_update(&self, cx: &mut Context<'_>) -> Poll<IoStatusUpdate> {
625 let st = self.st();
626 let flags = st.flags.get();
627 if flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING) {
628 Poll::Ready(IoStatusUpdate::PeerGone(st.error()))
629 } else if flags.contains(Flags::DSP_STOP) {
630 st.remove_flags(Flags::DSP_STOP);
631 Poll::Ready(IoStatusUpdate::Stop)
632 } else if flags.contains(Flags::DSP_TIMEOUT) {
633 st.remove_flags(Flags::DSP_TIMEOUT);
634 Poll::Ready(IoStatusUpdate::KeepAlive)
635 } else if flags.contains(Flags::BUF_W_BACKPRESSURE) {
636 Poll::Ready(IoStatusUpdate::WriteBackpressure)
637 } else {
638 st.dispatch_task.register(cx.waker());
639 Poll::Pending
640 }
641 }
642
643 #[inline]
644 pub fn poll_dispatch(&self, cx: &mut Context<'_>) {
646 self.st().dispatch_task.register(cx.waker());
647 }
648}
649
650impl<F> AsRef<IoRef> for Io<F> {
651 #[inline]
652 fn as_ref(&self) -> &IoRef {
653 self.io_ref()
654 }
655}
656
// `Io` equality delegates to the inner `IoRef`.
impl<F> Eq for Io<F> {}
658
659impl<F> PartialEq for Io<F> {
660 #[inline]
661 fn eq(&self, other: &Self) -> bool {
662 self.io_ref().eq(other.io_ref())
663 }
664}
665
666impl<F> hash::Hash for Io<F> {
667 #[inline]
668 fn hash<H: hash::Hasher>(&self, state: &mut H) {
669 self.io_ref().hash(state);
670 }
671}
672
673impl<F> fmt::Debug for Io<F> {
674 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
675 f.debug_struct("Io").field("state", self.st()).finish()
676 }
677}
678
679impl<F> ops::Deref for Io<F> {
680 type Target = IoRef;
681
682 #[inline]
683 fn deref(&self) -> &Self::Target {
684 self.io_ref()
685 }
686}
687
688impl<F> Drop for Io<F> {
689 fn drop(&mut self) {
690 let st = self.st();
691 self.stop_timer();
692
693 if st.filter.is_set() {
694 if !st.flags.get().is_stopped() {
697 log::trace!(
698 "{}: Io is dropped, force stopping io streams {:?}",
699 st.tag.get(),
700 st.flags.get()
701 );
702 }
703
704 self.force_close();
705 st.filter.drop_filter::<F>();
706 }
707 }
708}
709
710const KIND_SEALED: u8 = 0b01;
711const KIND_PTR: u8 = 0b10;
712const KIND_MASK: u8 = 0b11;
713const KIND_UNMASK: u8 = !KIND_MASK;
714const KIND_MASK_USIZE: usize = 0b11;
715const KIND_UNMASK_USIZE: usize = !KIND_MASK_USIZE;
716const SEALED_SIZE: usize = mem::size_of::<Sealed>();
717const NULL: [u8; SEALED_SIZE] = [0u8; SEALED_SIZE];
718
719#[cfg(target_endian = "little")]
720const KIND_IDX: usize = 0;
721
722#[cfg(target_endian = "big")]
723const KIND_IDX: usize = SEALED_SIZE - 1;
724
/// Type-erased storage for the io filter stack.
///
/// `data` is inline storage sized for a `Sealed`. It holds one of:
/// * a raw `*mut F` to a boxed concrete filter (tag `KIND_PTR`);
/// * a `Sealed` value (tag `KIND_SEALED`).
///
/// The tag bits are OR-ed into byte `KIND_IDX`. All-zero data (`NULL`) means no
/// filter is installed.
struct FilterPtr {
    data: Cell<[u8; SEALED_SIZE]>,
    /// Trait-object view of the filter owned by `data`, or `NullFilter`.
    /// The `'static` lifetime is fabricated with `transmute`; the reference is
    /// only valid while `data` keeps that filter alive.
    filter: Cell<&'static dyn Filter>,
}
729
730impl FilterPtr {
731 const fn null() -> Self {
732 Self {
733 data: Cell::new(NULL),
734 filter: Cell::new(NullFilter::get()),
735 }
736 }
737
738 fn update<F: Filter>(&self, filter: F) {
739 if self.is_set() {
740 panic!("Filter is set, must be dropped first");
741 }
742
743 let filter = Box::new(filter);
744 let mut data = NULL;
745 unsafe {
746 let filter_ref: &'static dyn Filter = {
747 let f: &dyn Filter = filter.as_ref();
748 std::mem::transmute(f)
749 };
750 self.filter.set(filter_ref);
751
752 let ptr = &mut data as *mut _ as *mut *mut F;
753 ptr.write(Box::into_raw(filter));
754 data[KIND_IDX] |= KIND_PTR;
755 self.data.set(data);
756 }
757 }
758
759 fn filter<F: Filter>(&self) -> &F {
761 let data = self.data.get();
762 if data[KIND_IDX] & KIND_PTR != 0 {
763 let ptr = &data as *const _ as *const *mut F;
764 unsafe {
765 let p = (ptr.read() as *const _ as usize) & KIND_UNMASK_USIZE;
766 (p as *const F as *mut F).as_ref().unwrap()
767 }
768 } else {
769 panic!("Wrong filter item");
770 }
771 }
772
773 fn take_filter<F>(&self) -> Box<F> {
775 let mut data = self.data.get();
776 if data[KIND_IDX] & KIND_PTR != 0 {
777 data[KIND_IDX] &= KIND_UNMASK;
778 let ptr = &mut data as *mut _ as *mut *mut F;
779 unsafe { Box::from_raw(*ptr) }
780 } else {
781 panic!(
782 "Wrong filter item {:?} expected: {:?}",
783 data[KIND_IDX], KIND_PTR
784 );
785 }
786 }
787
788 fn take_sealed(&self) -> Sealed {
790 let mut data = self.data.get();
791
792 if data[KIND_IDX] & KIND_SEALED != 0 {
793 data[KIND_IDX] &= KIND_UNMASK;
794 let ptr = &mut data as *mut _ as *mut Sealed;
795 unsafe { ptr.read() }
796 } else {
797 panic!(
798 "Wrong filter item {:?} expected: {:?}",
799 data[KIND_IDX], KIND_SEALED
800 );
801 }
802 }
803
804 fn is_set(&self) -> bool {
805 self.data.get()[KIND_IDX] & KIND_MASK != 0
806 }
807
808 fn drop_filter<F>(&self) {
809 let data = self.data.get();
810
811 if data[KIND_IDX] & KIND_MASK != 0 {
812 if data[KIND_IDX] & KIND_PTR != 0 {
813 self.take_filter::<F>();
814 } else if data[KIND_IDX] & KIND_SEALED != 0 {
815 self.take_sealed();
816 }
817 self.data.set(NULL);
818 self.filter.set(NullFilter::get());
819 }
820 }
821}
822
823impl FilterPtr {
824 fn add_filter<F: Filter, T: FilterLayer>(&self, new: T) {
825 let mut data = NULL;
826 let filter = Box::new(Layer::new(new, *self.take_filter::<F>()));
827 unsafe {
828 let filter_ref: &'static dyn Filter = {
829 let f: &dyn Filter = filter.as_ref();
830 std::mem::transmute(f)
831 };
832 self.filter.set(filter_ref);
833
834 let ptr = &mut data as *mut _ as *mut *mut Layer<T, F>;
835 ptr.write(Box::into_raw(filter));
836 data[KIND_IDX] |= KIND_PTR;
837 self.data.set(data);
838 }
839 }
840
841 fn seal<F: Filter>(&self) {
842 let mut data = self.data.get();
843
844 let filter = if data[KIND_IDX] & KIND_PTR != 0 {
845 Sealed(Box::new(*self.take_filter::<F>()))
846 } else if data[KIND_IDX] & KIND_SEALED != 0 {
847 self.take_sealed()
848 } else {
849 panic!(
850 "Wrong filter item {:?} expected: {:?}",
851 data[KIND_IDX], KIND_PTR
852 );
853 };
854
855 unsafe {
856 let filter_ref: &'static dyn Filter = {
857 let f: &dyn Filter = filter.0.as_ref();
858 std::mem::transmute(f)
859 };
860 self.filter.set(filter_ref);
861
862 let ptr = &mut data as *mut _ as *mut Sealed;
863 ptr.write(filter);
864 data[KIND_IDX] |= KIND_SEALED;
865 self.data.set(data);
866 }
867 }
868}
869
/// Future that resolves once the io is stopped.
#[derive(Debug)]
#[must_use = "OnDisconnect do nothing unless polled"]
pub struct OnDisconnect {
    /// Index of this future's waker slot in `IoState::on_disconnect`, or
    /// `usize::MAX` if the io was already disconnected when it was created.
    token: usize,
    inner: Rc<IoState>,
}
877
878impl OnDisconnect {
879 pub(super) fn new(inner: Rc<IoState>) -> Self {
880 Self::new_inner(inner.flags.get().is_stopped(), inner)
881 }
882
883 fn new_inner(disconnected: bool, inner: Rc<IoState>) -> Self {
884 let token = if disconnected {
885 usize::MAX
886 } else {
887 let mut on_disconnect = inner.on_disconnect.take();
888 let token = if let Some(ref mut on_disconnect) = on_disconnect {
889 let token = on_disconnect.len();
890 on_disconnect.push(LocalWaker::default());
891 token
892 } else {
893 on_disconnect = Some(Box::new(vec![LocalWaker::default()]));
894 0
895 };
896 inner.on_disconnect.set(on_disconnect);
897 token
898 };
899 Self { token, inner }
900 }
901
902 #[inline]
903 pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
905 if self.token == usize::MAX || self.inner.flags.get().is_stopped() {
906 Poll::Ready(())
907 } else if let Some(on_disconnect) = self.inner.on_disconnect.take() {
908 on_disconnect[self.token].register(cx.waker());
909 self.inner.on_disconnect.set(Some(on_disconnect));
910 Poll::Pending
911 } else {
912 Poll::Ready(())
913 }
914 }
915}
916
917impl Clone for OnDisconnect {
918 fn clone(&self) -> Self {
919 if self.token == usize::MAX {
920 OnDisconnect::new_inner(true, self.inner.clone())
921 } else {
922 OnDisconnect::new_inner(false, self.inner.clone())
923 }
924 }
925}
926
927impl Future for OnDisconnect {
928 type Output = ();
929
930 #[inline]
931 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
932 self.poll_ready(cx)
933 }
934}
935
#[cfg(test)]
mod tests {
    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;

    use super::*;
    use crate::{testing::IoTest, ReadBuf, WriteBuf};

    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    /// Equality is reflexive for `Io` and `IoRef`; `Flags` supports `Debug`/`Eq`.
    #[ntex::test]
    async fn test_basics() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let io = Io::new(server);
        assert!(io.eq(&io));
        assert!(io.io_ref().eq(io.io_ref()));

        let stopped = Flags::IO_STOPPED;
        assert!(format!("{:?}", stopped).contains("IO_STOPPED"));
        assert!(stopped == Flags::IO_STOPPED);
        assert!(stopped != Flags::IO_STOPPING);
    }

    /// `recv` maps dispatcher conditions to errors and waits out back-pressure.
    #[ntex::test]
    async fn test_recv() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let io = Io::new(server);

        // A pending keep-alive timeout surfaces as an error.
        io.st().notify_timeout();
        let timeout_err = io.recv(&BytesCodec).await.err().unwrap();
        assert!(format!("{:?}", timeout_err).contains("Timeout"));

        // A dispatcher stop request surfaces as an error.
        io.st().insert_flags(Flags::DSP_STOP);
        let stop_err = io.recv(&BytesCodec).await.err().unwrap();
        assert!(format!("{:?}", stop_err).contains("Dispatcher stopped"));

        // Write back-pressure is flushed away, then the frame is decoded.
        client.write(TEXT);
        io.st().insert_flags(Flags::BUF_W_BACKPRESSURE);
        let frame = io.recv(&BytesCodec).await.ok().unwrap().unwrap();
        assert_eq!(frame, TEXT);
    }

    /// `send` encodes and flushes so the peer can read the bytes.
    #[ntex::test]
    async fn test_send() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let io = Io::new(server);
        assert!(io.eq(&io));

        let sent = io.send(Bytes::from_static(BIN), &BytesCodec).await;
        sent.ok().unwrap();
        assert_eq!(client.read_any(), TEXT);
    }

    /// Filter layer that counts how many times it has been dropped.
    #[derive(Debug)]
    struct DropFilter {
        p: Rc<Cell<usize>>,
    }

    impl Drop for DropFilter {
        fn drop(&mut self) {
            let drops = self.p.get();
            self.p.set(drops + 1);
        }
    }

    impl FilterLayer for DropFilter {
        const BUFFERS: bool = false;

        fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
            Ok(buf.nbytes())
        }

        fn process_write_buf(&self, _: &WriteBuf<'_>) -> io::Result<()> {
            Ok(())
        }
    }

    /// A filter moved through `take` and boxing is dropped exactly once.
    #[ntex::test]
    async fn drop_filter() {
        let drops = Rc::new(Cell::new(0));

        let (client, server) = IoTest::create();
        let filter = DropFilter { p: drops.clone() };
        let _ = format!("{:?}", filter);
        let io = Io::new(server).add_filter(filter);

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        let taken = io.take();
        let mut boxed: crate::IoBoxed = taken.into();
        let retaken = boxed.take();

        drop(io);
        drop(boxed);
        drop(retaken);

        assert_eq!(drops.get(), 1);
    }
}