1use std::{any, fmt, hash, io, ptr};
2
3use ntex_bytes::{BytePage, BytePages, BytesMut};
4use ntex_codec::{Decoder, Encoder};
5use ntex_service::cfg::SharedCfg;
6use ntex_util::time::Seconds;
7
8use crate::ops::{Id, Iops, TimerHandle};
9use crate::{Decoded, Filter, FilterBuf, Flags, IoConfig, IoContext, IoRef, types};
10
11impl IoRef {
12 #[inline]
13 pub fn id(&self) -> Id {
15 self.0.id()
16 }
17
18 #[inline]
19 pub fn tag(&self) -> &'static str {
21 self.0.tag()
22 }
23
24 #[doc(hidden)]
25 pub fn flags(&self) -> Flags {
27 self.0.flags.clone()
28 }
29
30 #[inline]
31 pub(crate) fn filter(&self) -> &dyn Filter {
33 self.0.filter()
34 }
35
36 #[inline]
37 pub fn cfg(&self) -> &IoConfig {
39 &self.0.cfg
40 }
41
42 #[inline]
43 pub fn shared(&self) -> SharedCfg {
45 self.0.cfg.shared()
46 }
47
48 #[inline]
49 pub fn is_closed(&self) -> bool {
51 self.0.flags.is_closed()
52 }
53
54 #[inline]
55 pub fn is_wr_backpressure(&self) -> bool {
57 self.0.flags.is_wr_backpressure()
58 }
59
60 pub fn close(&self) {
64 self.0.start_shutdown();
65 }
66
67 pub fn terminate(&self) {
72 log::trace!("{}: Terminate io stream object", self.tag());
73 self.0.terminate_connection(None);
74 }
75
76 #[doc(hidden)]
77 #[deprecated(since = "3.10.0", note = "use IoRef::terminate() instead")]
78 pub fn force_close(&self) {
83 self.terminate();
84 }
85
86 #[doc(hidden)]
87 #[deprecated(since = "3.11.0", note = "use IoRef::close() instead")]
88 pub fn wants_shutdown(&self) {
90 self.0.start_shutdown();
91 }
92
93 pub fn query<T: 'static>(&self) -> types::QueryItem<T> {
95 types::QueryItem::new(self.filter().query(any::TypeId::of::<T>()))
96 }
97
98 #[inline]
99 pub fn encode<U>(&self, item: U::Item, codec: &U) -> Result<(), <U as Encoder>::Error>
101 where
102 U: Encoder,
103 {
104 self.with_write_buf(|buf| codec.encodev(item, buf))
105 .unwrap_or_else(|_| Ok(()))
106 }
107
108 #[inline]
109 pub fn encode_slice(&self, src: &[u8]) -> io::Result<()> {
111 self.with_write_buf(|buf| buf.extend_from_slice(src))
112 }
113
114 #[inline]
115 pub fn encode_bytes<B>(&self, src: B) -> io::Result<()>
117 where
118 BytePage: From<B>,
119 {
120 self.with_write_buf(|buf| buf.append(src))
121 }
122
123 pub fn decode<U>(
125 &self,
126 codec: &U,
127 ) -> Result<Option<<U as Decoder>::Item>, <U as Decoder>::Error>
128 where
129 U: Decoder,
130 {
131 self.0.buffer.with_read_dst(self, |buf| {
132 let res = codec.decode(buf);
133 self.0.flags.unset_read_ready();
134 self.update_read_destination(buf);
135 res
136 })
137 }
138
139 pub fn decode_item<U>(
141 &self,
142 codec: &U,
143 ) -> Result<Decoded<<U as Decoder>::Item>, <U as Decoder>::Error>
144 where
145 U: Decoder,
146 {
147 self.0.buffer.with_read_dst(self, |buf| {
148 let len = buf.len();
149 let res = codec.decode(buf).map(|item| Decoded {
150 item,
151 remains: buf.len(),
152 consumed: len - buf.len(),
153 });
154 self.0.flags.unset_read_ready();
155 self.update_read_destination(buf);
156 res
157 })
158 }
159
160 pub fn send_buf(&self) -> io::Result<()> {
165 self.consolidate_write_state(true);
167
168 if self.0.flags.is_stopping_any()
169 && let Some(err) = self.0.error.take()
170 {
171 Err(err)
172 } else {
173 Ok(())
174 }
175 }
176
177 pub(crate) fn ops_send_buf(&self) {
178 let st = &self.0;
179 #[cfg(feature = "trace")]
180 log::trace!(
181 "{}: ops-send == buf:{} flags:{:?}",
182 st.tag(),
183 st.buffer.write_buf_size(),
184 st.flags
185 );
186
187 if st.flags.is_wr_send_scheduled() {
188 st.flags.unset_wr_send_scheduled();
189
190 if st.flags.is_write_paused() {
191 if self.call_write() == WakeWriteTask::Yes {
195 st.wake_write_task();
196 st.flags.unset_write_paused();
197 }
198 } else {
199 st.wake_write_task();
200 }
201 }
202 }
203
204 pub fn with_buf<F, R>(&self, f: F) -> io::Result<R>
206 where
207 F: FnOnce(&mut FilterBuf<'_>) -> R,
208 {
209 let result = self.0.buffer.with_filter(self, |ctx| ctx.with_buffer(f));
210 self.consolidate_write_state(false);
211 Ok(result)
212 }
213
214 pub fn with_read_buf<F, R>(&self, f: F) -> R
216 where
217 F: FnOnce(&mut BytesMut) -> R,
218 {
219 self.0.buffer.with_read_dst(self, |buf| {
220 let res = f(buf);
221 self.update_read_destination(buf);
222 res
223 })
224 }
225
226 pub fn with_write_buf<F, R>(&self, f: F) -> io::Result<R>
228 where
229 F: FnOnce(&mut BytePages) -> R,
230 {
231 let st = &self.0;
232
233 if st.flags.is_stopping_any() {
234 if st.flags.is_closed() {
235 Err(st.error_or_disconnected())
236 } else {
237 Err(io::Error::other("I/O stream is closing"))
238 }
239 } else {
240 let result = st.buffer.with_write_src(f);
241 self.consolidate_write_state(false);
242 Ok(result)
243 }
244 }
245
246 pub(crate) fn consolidate_write_state(&self, force: bool) {
247 let st = &self.0;
248
249 let size = st.buffer.write_buf_size();
251
252 #[cfg(feature = "trace")]
253 log::trace!("{}: write-upd == buf:{size} flags:{:?}", st.tag(), st.flags);
254
255 if size > 0 && st.flags.is_write_paused() {
256 if st.flags.is_direct_wr_enabled()
267 && (force || size >= st.cfg.write_buf_threshold())
268 {
269 if self.call_write() == WakeWriteTask::Yes {
271 #[cfg(feature = "trace")]
272 log::trace!(
273 "{}: write-upd == schedule(more):{} flags:{:?}",
274 st.tag(),
275 st.buffer.write_buf_size(),
276 st.flags
277 );
278 if !st.flags.is_wr_send_scheduled() {
279 st.flags.set_wr_send_scheduled();
281 Iops::schedule_write(st.id());
282 }
283 } else {
284 st.flags.unset_wr_send_scheduled();
285 }
286 } else if !st.flags.is_wr_send_scheduled() {
287 #[cfg(feature = "trace")]
288 log::trace!("{}: write-upd == schedule(too small)", st.tag());
289 st.flags.set_wr_send_scheduled();
290 Iops::schedule_write(st.id());
291 }
292 }
293 if !st.flags.is_wr_backpressure() && st.is_wr_backpressure_needed(size) {
295 st.flags.set_wr_backpressure();
296 st.wake_dispatch_task();
297 }
298 }
299
300 fn update_read_destination(&self, buf: &mut BytesMut) {
301 let st = &self.0;
302
303 #[cfg(feature = "trace")]
304 log::trace!(
305 "{}: read-upd == buf:{} flags:{:?}",
306 st.tag(),
307 buf.len(),
308 st.flags
309 );
310
311 if st.flags.is_rd_backpressure() {
312 if st.is_rd_backpressure_needed(buf.len()) {
314 return;
315 }
316 st.flags.unset_all_read_flags();
317 } else {
318 st.flags.unset_read_ready();
319 }
320
321 if st.flags.is_read_paused() {
322 st.wake_read_task();
323 st.flags.unset_read_paused();
324 }
325 }
326
327 pub fn resize_read_buf(&self, buf: &mut BytesMut) {
329 self.0.cfg.read_buf().resize(buf);
330 }
331
332 #[doc(hidden)]
333 #[deprecated(since = "3.10.3", note = "Use .notify_disapatcher()")]
334 pub fn wake(&self) {
336 self.notify_dispatcher();
337 }
338
339 pub fn notify_dispatcher(&self) {
341 log::trace!("{}: Timer, notify dispatcher", self.tag());
342 self.0.wake_dispatch_task();
343 }
344
345 pub fn notify_timeout(&self) {
347 self.0.notify_timeout();
348 }
349
350 pub fn timer_handle(&self) -> TimerHandle {
352 self.0.timeout.get()
353 }
354
355 pub fn start_timer(&self, timeout: Seconds) -> TimerHandle {
357 let cur_hnd = self.0.timeout.get();
358
359 if timeout.is_zero() {
360 if cur_hnd.is_set() {
361 self.0.timeout.set(TimerHandle::ZERO);
362 cur_hnd.unregister(self);
363 }
364 TimerHandle::ZERO
365 } else if cur_hnd.is_set() {
366 let hnd = cur_hnd.update(timeout, self);
367 if hnd != cur_hnd {
368 log::trace!("{}: Update timer {:?}", self.tag(), timeout);
369 self.0.timeout.set(hnd);
370 }
371 hnd
372 } else {
373 log::trace!("{}: Start timer {:?}", self.tag(), timeout);
374 let hnd = TimerHandle::register(timeout, self);
375 self.0.timeout.set(hnd);
376 hnd
377 }
378 }
379
380 pub fn stop_timer(&self) {
382 let hnd = self.0.timeout.get();
383 if hnd.is_set() {
384 log::trace!("{}: Stop timer", self.tag());
385 self.0.timeout.set(TimerHandle::ZERO);
386 hnd.unregister(self);
387 }
388 }
389
390 pub fn on_disconnect(&self) -> crate::OnDisconnect {
392 crate::OnDisconnect::new(self.0.clone())
393 }
394
395 fn call_write(&self) -> WakeWriteTask {
398 if let Some(hnd) = self.0.handle.take() {
399 self.0.flags.unset_write_paused();
400 #[cfg(feature = "trace")]
401 log::trace!(
402 "{}: call-write ({}), flags:{:?}",
403 self.tag(),
404 self.0.buffer.write_buf_size(),
405 self.0.flags
406 );
407 let ctx = unsafe { &*(ptr::from_ref(self).cast::<IoContext>()) };
408 hnd.write(ctx);
409 self.0.handle.set(Some(hnd));
410 }
411 if self.0.flags.is_write_paused() {
412 WakeWriteTask::No
413 } else {
414 WakeWriteTask::Yes
415 }
416 }
417
418 pub(crate) fn call_notify(&self) {
419 if let Some(hnd) = self.0.handle.take() {
420 let ctx = unsafe { &*(ptr::from_ref(self).cast::<IoContext>()) };
421 hnd.notify(ctx);
422 self.0.handle.set(Some(hnd));
423 }
424 }
425}
426
/// Result of a direct write attempt: whether the write task still needs to run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WakeWriteTask {
    /// The write side is not paused after the attempt; wake the write task.
    Yes,
    /// The write side is paused again; there is nothing to wake.
    No,
}
432
433impl Eq for IoRef {}
434
435impl PartialEq for IoRef {
436 #[inline]
437 fn eq(&self, other: &Self) -> bool {
438 self.0.eq(&other.0)
439 }
440}
441
442impl hash::Hash for IoRef {
443 #[inline]
444 fn hash<H: hash::Hasher>(&self, state: &mut H) {
445 self.0.hash(state);
446 }
447}
448
449impl fmt::Debug for IoRef {
450 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
451 f.debug_struct("IoRef")
452 .field("state", self.0.as_ref())
453 .finish()
454 }
455}
456
#[cfg(test)]
mod tests {
    use std::cell::{Cell, RefCell};
    use std::{future::Future, future::poll_fn, pin::Pin, rc::Rc, task::Poll};

    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;
    use ntex_util::{future::lazy, time::Millis, time::sleep};

    use super::*;
    use crate::{FilterCtx, Io, testing::IoTest};

    // Test payload: a minimal HTTP request, as raw bytes and as text.
    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    /// Covers general `Io`/`IoRef` helpers: equality and `Debug` output,
    /// `recv`/`poll_recv`, `encode_slice`/`decode`, read- and write-error
    /// propagation, and `terminate`.
    #[ntex::test]
    async fn utils() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write(TEXT);

        let state = Io::from(server);
        assert_eq!(state.get_ref(), state.get_ref());

        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));
        assert_eq!(state.get_ref(), state.as_ref().clone());
        assert!(format!("{state:?}").find("Io {").is_some());
        assert!(format!("{:?}", state.get_ref()).find("IoRef {").is_some());

        // The previous message was consumed, so polling now is pending until
        // the client writes more data.
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        assert!(res.is_pending());
        client.write(TEXT);
        sleep(Millis(50)).await;
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        if let Poll::Ready(msg) = res {
            assert_eq!(msg.unwrap(), Bytes::from_static(BIN));
        }

        // A read error surfaces from `recv` and terminates the stream.
        client.read_error(io::Error::other("err"));
        let msg = state.recv(&BytesCodec).await;
        assert!(msg.is_err());
        assert!(state.flags().is_terminated());

        // Same scenario via `poll_recv` on a fresh stream.
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);

        client.read_error(io::Error::other("err"));
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        if let Poll::Ready(msg) = res {
            assert!(msg.is_err());
            assert!(state.flags().is_terminated());
        }

        // Round-trip raw bytes through `encode_slice` and `decode`.
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);
        state.encode_slice(b"test").unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        client.write(b"test");
        state.read_ready().await.unwrap();
        let buf = state.decode(&BytesCodec).unwrap().unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // After a write error the first `send` still returns Ok but the
        // stream ends up terminated; a subsequent `send` fails.
        client.write_error(io::Error::other("err"));
        state
            .send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        assert!(state.flags().is_terminated());

        let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await;
        assert!(res.is_err());

        // `terminate` sets both the stopping and the terminated flags.
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);
        state.terminate();
        assert!(state.flags().is_stopping());
        assert!(state.flags().is_terminated());
    }

    /// `on_disconnect` futures (and their clones) stay pending while the
    /// stream is open and resolve on peer close or read error; a future
    /// created after the disconnect is ready immediately.
    #[ntex::test]
    // `assert_eq!` on `()` values below is intentional.
    #[allow(clippy::unit_cmp)]
    async fn on_disconnect() {
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        let mut waiter2 = waiter.clone();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter2).poll(cx)).await,
            Poll::Pending
        );
        client.close().await;
        assert_eq!(waiter.await, ());
        assert_eq!(waiter2.await, ());

        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Ready(())
        );

        let (client, server) = IoTest::create();
        let state = Io::from(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        client.read_error(io::Error::other("err"));
        assert_eq!(waiter.await, ());
    }

    /// Once the peer has closed, every write helper reports an error.
    #[ntex::test]
    async fn write_to_closed_io() {
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        client.close().await;

        assert!(state.is_closed());
        assert!(state.encode_slice(TEXT.as_bytes()).is_err());
        assert!(state.encode_bytes(Bytes::from_static(BIN)).is_err());
        assert!(
            state
                .with_write_buf(|buf| buf.extend_from_slice(BIN))
                .is_err()
        );
    }

    /// Test filter that wraps an inner layer, records the order in which
    /// layers are visited (by `idx`) and counts bytes seen on the read and
    /// write paths.
    #[derive(Debug)]
    struct Counter<F> {
        layer: F,
        // Identifier pushed into the order logs on every call.
        idx: usize,
        in_bytes: Rc<Cell<usize>>,
        out_bytes: Rc<Cell<usize>>,
        read_order: Rc<RefCell<Vec<usize>>>,
        write_order: Rc<RefCell<Vec<usize>>>,
    }

    impl<F: Filter> Filter for Counter<F> {
        fn process_read_buf(&self, ctx: &mut FilterCtx<'_>) -> io::Result<()> {
            self.read_order.borrow_mut().push(self.idx);
            let result = self.layer.process_read_buf(ctx);
            // Count the newly read bytes reported by the context after the
            // inner layer has processed the read buffer.
            self.in_bytes
                .set(self.in_bytes.get() + ctx.new_read_bytes());
            result
        }

        fn process_write_buf(&self, ctx: &mut FilterCtx<'_>) -> io::Result<()> {
            self.write_order.borrow_mut().push(self.idx);
            // Count the pending write bytes visible to this layer before
            // delegating to the inner layer.
            ctx.with_buffer(|buf| {
                buf.with_write_buffers(|src, _| {
                    self.out_bytes.set(self.out_bytes.get() + src.len());
                });
            });
            self.layer.process_write_buf(ctx)
        }

        // Presumably forward the remaining `Filter` methods to `layer`.
        crate::forward_ready!(layer);
        crate::forward_query!(layer);
        crate::forward_shutdown!(layer);
    }

    /// A single sealed `Counter` layer sees traffic on both the read and the
    /// write path.
    #[ntex::test]
    async fn filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let io = Io::from(server)
            .map_filter(|layer| Counter {
                layer,
                idx: 1,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            })
            .seal();

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        assert_eq!(in_bytes.get(), BIN.len() * 2);
        assert_eq!(out_bytes.get(), 8);
    }

    /// Two stacked `Counter` layers (idx 2 wraps the base filter, idx 1 wraps
    /// idx 2): checks byte counts, layer visiting order and that dropping the
    /// sealed stream releases both layers.
    #[ntex::test]
    async fn boxed_filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let state = Io::from(server)
            .map_filter(|layer| Counter {
                layer,
                idx: 2,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            })
            .map_filter(|layer| Counter {
                layer,
                idx: 1,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            });
        let state = state.seal();

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        state
            .send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        assert_eq!(in_bytes.get(), BIN.len() * 2);
        assert_eq!(out_bytes.get(), 16);
        // Everything sent has left the write destination buffer.
        assert_eq!(state.0.buffer.with_write_dst(|b| b.len()), 0);

        // The local handle plus one clone held by each of the two layers.
        assert_eq!(Rc::strong_count(&in_bytes), 3);
        drop(state);
        assert_eq!(Rc::strong_count(&in_bytes), 1);
        // The outer layer (idx 1) is always visited before the inner (idx 2).
        assert_eq!(*read_order.borrow(), &[1, 2][..]);
        assert_eq!(*write_order.borrow(), &[1, 2, 1, 2, 1, 2][..]);
    }
}