1use std::{any, fmt, hash, io, ptr};
2
3use ntex_bytes::{BytePage, BytePages, BytesMut};
4use ntex_codec::{Decoder, Encoder};
5use ntex_service::cfg::SharedCfg;
6use ntex_util::time::Seconds;
7
8use crate::ops::{Id, Iops, TimerHandle};
9use crate::{Decoded, Filter, FilterBuf, Flags, IoConfig, IoContext, IoRef, types};
10
11impl IoRef {
12 #[inline]
13 pub fn id(&self) -> Id {
15 self.0.id()
16 }
17
18 #[inline]
19 pub fn tag(&self) -> &'static str {
21 self.0.tag()
22 }
23
24 #[doc(hidden)]
25 pub fn flags(&self) -> Flags {
27 self.0.flags.clone()
28 }
29
30 #[inline]
31 pub(crate) fn filter(&self) -> &dyn Filter {
33 self.0.filter()
34 }
35
36 #[inline]
37 pub fn cfg(&self) -> &IoConfig {
39 &self.0.cfg
40 }
41
42 #[inline]
43 pub fn shared(&self) -> SharedCfg {
45 self.0.cfg.shared()
46 }
47
48 #[inline]
49 pub fn is_closed(&self) -> bool {
51 self.0.flags.is_closed()
52 }
53
54 #[inline]
55 pub fn is_wr_backpressure(&self) -> bool {
57 self.0.flags.is_wr_backpressure()
58 }
59
60 pub fn close(&self) {
64 self.0.start_shutdown();
65 }
66
67 pub fn terminate(&self) {
72 log::trace!("{}: Terminate io stream object", self.tag());
73 self.0.terminate_connection(None);
74 }
75
76 #[doc(hidden)]
77 #[deprecated(since = "3.10.0", note = "use IoRef::terminate() instead")]
78 pub fn force_close(&self) {
83 self.terminate();
84 }
85
86 #[doc(hidden)]
87 #[deprecated(since = "3.11.0", note = "use IoRef::close() instead")]
88 pub fn wants_shutdown(&self) {
90 self.0.start_shutdown();
91 }
92
93 pub fn query<T: 'static>(&self) -> types::QueryItem<T> {
95 types::QueryItem::new(self.filter().query(any::TypeId::of::<T>()))
96 }
97
98 #[inline]
99 pub fn encode<U>(&self, item: U::Item, codec: &U) -> Result<(), <U as Encoder>::Error>
101 where
102 U: Encoder,
103 {
104 self.with_write_buf(|buf| codec.encodev(item, buf))
105 .unwrap_or_else(|_| Ok(()))
106 }
107
108 #[inline]
109 pub fn encode_slice(&self, src: &[u8]) -> io::Result<()> {
111 self.with_write_buf(|buf| buf.extend_from_slice(src))
112 }
113
114 #[inline]
115 pub fn encode_bytes<B>(&self, src: B) -> io::Result<()>
117 where
118 BytePage: From<B>,
119 {
120 self.with_write_buf(|buf| buf.append(src))
121 }
122
123 pub fn decode<U>(
125 &self,
126 codec: &U,
127 ) -> Result<Option<<U as Decoder>::Item>, <U as Decoder>::Error>
128 where
129 U: Decoder,
130 {
131 self.0.buffer.with_read_dst(self, |buf| {
132 let res = codec.decode(buf);
133 self.0.flags.unset_read_ready();
134 self.update_read_destination(buf);
135 res
136 })
137 }
138
139 pub fn decode_item<U>(
141 &self,
142 codec: &U,
143 ) -> Result<Decoded<<U as Decoder>::Item>, <U as Decoder>::Error>
144 where
145 U: Decoder,
146 {
147 self.0.buffer.with_read_dst(self, |buf| {
148 let len = buf.len();
149 let res = codec.decode(buf).map(|item| Decoded {
150 item,
151 remains: buf.len(),
152 consumed: len - buf.len(),
153 });
154 self.0.flags.unset_read_ready();
155 self.update_read_destination(buf);
156 res
157 })
158 }
159
160 pub fn send_buf(&self) -> io::Result<()> {
165 self.consolidate_write_state(true);
167
168 if self.0.flags.is_stopping_any()
169 && let Some(err) = self.0.error.take()
170 {
171 Err(err)
172 } else {
173 Ok(())
174 }
175 }
176
177 pub(crate) fn ops_send_buf(&self) {
178 let st = &self.0;
179 if st.flags.is_wr_send_scheduled() {
180 st.flags.unset_wr_send_scheduled();
181
182 if st.flags.is_write_paused() {
183 if self.call_write() == WakeWriteTask::Yes {
187 st.wake_write_task();
188 st.flags.unset_write_paused();
189 }
190 } else {
191 st.wake_write_task();
192 }
193 }
194 }
195
196 pub fn with_buf<F, R>(&self, f: F) -> io::Result<R>
198 where
199 F: FnOnce(&mut FilterBuf<'_>) -> R,
200 {
201 let result = self.0.buffer.with_filter(self, |ctx| ctx.with_buffer(f));
202 self.consolidate_write_state(false);
203 Ok(result)
204 }
205
206 pub fn with_read_buf<F, R>(&self, f: F) -> R
208 where
209 F: FnOnce(&mut BytesMut) -> R,
210 {
211 self.0.buffer.with_read_dst(self, |buf| {
212 let res = f(buf);
213 self.update_read_destination(buf);
214 res
215 })
216 }
217
218 pub fn with_write_buf<F, R>(&self, f: F) -> io::Result<R>
220 where
221 F: FnOnce(&mut BytePages) -> R,
222 {
223 let st = &self.0;
224
225 if st.flags.is_stopping_any() {
226 if st.flags.is_closed() {
227 Err(st.error_or_disconnected())
228 } else {
229 Err(io::Error::other("I/O stream is closing"))
230 }
231 } else {
232 let result = st.buffer.with_write_src(f);
233 self.consolidate_write_state(false);
234 Ok(result)
235 }
236 }
237
238 pub(crate) fn consolidate_write_state(&self, force: bool) {
239 let st = &self.0;
240
241 let size = st.buffer.write_buf_size();
243
244 #[cfg(feature = "trace")]
245 log::trace!("{}: write-upd == buf:{size} flags:{:?}", st.tag(), st.flags);
246
247 if size > 0 && st.flags.is_write_paused() {
248 if st.flags.is_direct_wr_enabled()
259 && (force || size >= st.cfg.write_buf_threshold())
260 {
261 if self.call_write() == WakeWriteTask::Yes {
263 #[cfg(feature = "trace")]
264 log::trace!(
265 "{}: write-upd == schedule(more):{} flags:{:?}",
266 st.tag(),
267 st.buffer.write_buf_size(),
268 st.flags
269 );
270 if !st.flags.is_wr_send_scheduled() {
271 st.flags.set_wr_send_scheduled();
273 Iops::schedule_write(st.id());
274 }
275 } else {
276 st.flags.unset_wr_send_scheduled();
277 }
278 } else if !st.flags.is_wr_send_scheduled() {
279 #[cfg(feature = "trace")]
280 log::trace!("{}: write-upd == schedule(too small)", st.tag());
281 st.flags.set_wr_send_scheduled();
282 Iops::schedule_write(st.id());
283 }
284 }
285 if !st.flags.is_wr_backpressure() && st.is_wr_backpressure_needed(size) {
287 st.flags.set_wr_backpressure();
288 st.wake_dispatch_task();
289 }
290 }
291
292 fn update_read_destination(&self, buf: &mut BytesMut) {
293 let st = &self.0;
294
295 #[cfg(feature = "trace")]
296 log::trace!(
297 "{}: read-upd == buf:{} flags:{:?}",
298 st.tag(),
299 buf.len(),
300 st.flags
301 );
302
303 if st.flags.is_rd_backpressure() {
304 if st.is_rd_backpressure_needed(buf.len()) {
306 return;
307 }
308 st.flags.unset_all_read_flags();
309 } else {
310 st.flags.unset_read_ready();
311 }
312
313 if st.flags.is_read_paused() {
314 st.wake_read_task();
315 st.flags.unset_read_paused();
316 }
317 }
318
319 pub fn resize_read_buf(&self, buf: &mut BytesMut) {
321 self.0.cfg.read_buf().resize(buf);
322 }
323
324 #[doc(hidden)]
325 #[deprecated(since = "3.10.3", note = "Use .notify_disapatcher()")]
326 pub fn wake(&self) {
328 self.notify_dispatcher();
329 }
330
331 pub fn notify_dispatcher(&self) {
333 log::trace!("{}: Timer, notify dispatcher", self.tag());
334 self.0.wake_dispatch_task();
335 }
336
337 pub fn notify_timeout(&self) {
339 self.0.notify_timeout();
340 }
341
342 pub fn timer_handle(&self) -> TimerHandle {
344 self.0.timeout.get()
345 }
346
347 pub fn start_timer(&self, timeout: Seconds) -> TimerHandle {
349 let cur_hnd = self.0.timeout.get();
350
351 if timeout.is_zero() {
352 if cur_hnd.is_set() {
353 self.0.timeout.set(TimerHandle::ZERO);
354 cur_hnd.unregister(self);
355 }
356 TimerHandle::ZERO
357 } else if cur_hnd.is_set() {
358 let hnd = cur_hnd.update(timeout, self);
359 if hnd != cur_hnd {
360 log::trace!("{}: Update timer {:?}", self.tag(), timeout);
361 self.0.timeout.set(hnd);
362 }
363 hnd
364 } else {
365 log::trace!("{}: Start timer {:?}", self.tag(), timeout);
366 let hnd = TimerHandle::register(timeout, self);
367 self.0.timeout.set(hnd);
368 hnd
369 }
370 }
371
372 pub fn stop_timer(&self) {
374 let hnd = self.0.timeout.get();
375 if hnd.is_set() {
376 log::trace!("{}: Stop timer", self.tag());
377 self.0.timeout.set(TimerHandle::ZERO);
378 hnd.unregister(self);
379 }
380 }
381
382 pub fn on_disconnect(&self) -> crate::OnDisconnect {
384 crate::OnDisconnect::new(self.0.clone())
385 }
386
387 fn call_write(&self) -> WakeWriteTask {
390 if let Some(hnd) = self.0.handle.take() {
391 self.0.flags.unset_write_paused();
392 #[cfg(feature = "trace")]
393 log::trace!(
394 "{}: call-write ({}), flags:{:?}",
395 self.tag(),
396 self.0.buffer.write_buf_size(),
397 self.0.flags
398 );
399 let ctx = unsafe { &*(ptr::from_ref(self).cast::<IoContext>()) };
400 hnd.write(ctx);
401 self.0.handle.set(Some(hnd));
402 }
403 if self.0.flags.is_write_paused() {
404 WakeWriteTask::No
405 } else {
406 WakeWriteTask::Yes
407 }
408 }
409
410 pub(crate) fn call_notify(&self) {
411 if let Some(hnd) = self.0.handle.take() {
412 let ctx = unsafe { &*(ptr::from_ref(self).cast::<IoContext>()) };
413 hnd.notify(ctx);
414 self.0.handle.set(Some(hnd));
415 }
416 }
417}
418
/// Outcome of a direct write attempt made by `IoRef::call_write`.
///
/// `Yes` means the write side is not paused after the attempt, so the write
/// task should be woken. `No` means the write side is paused.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum WakeWriteTask {
    /// The write side is still active; wake the write task.
    Yes,
    /// The write side is paused; no wake-up is needed.
    No,
}
424
425impl Eq for IoRef {}
426
427impl PartialEq for IoRef {
428 #[inline]
429 fn eq(&self, other: &Self) -> bool {
430 self.0.eq(&other.0)
431 }
432}
433
434impl hash::Hash for IoRef {
435 #[inline]
436 fn hash<H: hash::Hasher>(&self, state: &mut H) {
437 self.0.hash(state);
438 }
439}
440
441impl fmt::Debug for IoRef {
442 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
443 f.debug_struct("IoRef")
444 .field("state", self.0.as_ref())
445 .finish()
446 }
447}
448
#[cfg(test)]
mod tests {
    use std::cell::{Cell, RefCell};
    use std::{future::Future, future::poll_fn, pin::Pin, rc::Rc, task::Poll};

    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;
    use ntex_util::{future::lazy, time::Millis, time::sleep};

    use super::*;
    use crate::{FilterCtx, Io, testing::IoTest};

    /// Test payload as raw bytes.
    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    /// The same payload as `BIN`, as a string slice.
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    /// Exercises the basic `IoRef`/`Io` API: equality, `Debug`, receive and
    /// decode paths, error propagation, and termination.
    #[ntex::test]
    async fn utils() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write(TEXT);

        // Equality and Debug output of the stream and its reference.
        let state = Io::from(server);
        assert_eq!(state.get_ref(), state.get_ref());

        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));
        assert_eq!(state.get_ref(), state.as_ref().clone());
        assert!(format!("{state:?}").find("Io {").is_some());
        assert!(format!("{:?}", state.get_ref()).find("IoRef {").is_some());

        // With no pending data, poll_recv is pending; after the client writes,
        // the next poll (if ready) yields the payload.
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        assert!(res.is_pending());
        client.write(TEXT);
        sleep(Millis(50)).await;
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        if let Poll::Ready(msg) = res {
            assert_eq!(msg.unwrap(), Bytes::from_static(BIN));
        }

        // A read error surfaces through recv() and terminates the stream.
        client.read_error(io::Error::other("err"));
        let msg = state.recv(&BytesCodec).await;
        assert!(msg.is_err());
        assert!(state.flags().is_terminated());

        // The same, but observed through poll_recv on a fresh stream.
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);

        client.read_error(io::Error::other("err"));
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        if let Poll::Ready(msg) = res {
            assert!(msg.is_err());
            assert!(state.flags().is_terminated());
        }

        // encode_slice() data reaches the peer.
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);
        state.encode_slice(b"test").unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // Data written by the peer can be decoded once read_ready() resolves.
        client.write(b"test");
        state.read_ready().await.unwrap();
        let buf = state.decode(&BytesCodec).unwrap().unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // A write error: this send still returns Ok, but the stream ends up
        // terminated, and a later send fails.
        client.write_error(io::Error::other("err"));
        state
            .send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        assert!(state.flags().is_terminated());

        let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await;
        assert!(res.is_err());

        // terminate() sets both the stopping and terminated flags.
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);
        state.terminate();
        assert!(state.flags().is_stopping());
        assert!(state.flags().is_terminated());
    }

    /// Disconnect waiters stay pending while the stream is alive and resolve
    /// on peer close or on a read error.
    #[ntex::test]
    // The waiters resolve to `()`, so `assert_eq!(waiter.await, ())` compares
    // unit values on purpose.
    #[allow(clippy::unit_cmp)]
    async fn on_disconnect() {
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        // A cloned waiter behaves like the original.
        let mut waiter2 = waiter.clone();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter2).poll(cx)).await,
            Poll::Pending
        );
        client.close().await;
        assert_eq!(waiter.await, ());
        assert_eq!(waiter2.await, ());

        // A waiter created after the disconnect is immediately ready.
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Ready(())
        );

        // A read error also resolves a pending waiter.
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        client.read_error(io::Error::other("err"));
        assert_eq!(waiter.await, ());
    }

    /// Every write entry point fails once the peer has closed the stream.
    #[ntex::test]
    async fn write_to_closed_io() {
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        client.close().await;

        assert!(state.is_closed());
        assert!(state.encode_slice(TEXT.as_bytes()).is_err());
        assert!(state.encode_bytes(Bytes::from_static(BIN)).is_err());
        assert!(
            state
                .with_write_buf(|buf| buf.extend_from_slice(BIN))
                .is_err()
        );
    }

    /// Pass-through filter that counts bytes in each direction and records the
    /// order in which filter layers are invoked (by `idx`).
    #[derive(Debug)]
    struct Counter<F> {
        layer: F,
        idx: usize,
        in_bytes: Rc<Cell<usize>>,
        out_bytes: Rc<Cell<usize>>,
        read_order: Rc<RefCell<Vec<usize>>>,
        write_order: Rc<RefCell<Vec<usize>>>,
    }

    impl<F: Filter> Filter for Counter<F> {
        /// Records this layer, delegates to the inner layer, then adds the
        /// newly read byte count.
        fn process_read_buf(&self, ctx: &mut FilterCtx<'_>) -> io::Result<()> {
            self.read_order.borrow_mut().push(self.idx);
            let result = self.layer.process_read_buf(ctx);
            self.in_bytes
                .set(self.in_bytes.get() + ctx.new_read_bytes());
            result
        }

        /// Records this layer and adds the pending write-source length before
        /// delegating to the inner layer.
        fn process_write_buf(&self, ctx: &mut FilterCtx<'_>) -> io::Result<()> {
            self.write_order.borrow_mut().push(self.idx);
            ctx.with_buffer(|buf| {
                buf.with_write_buffers(|src, _| {
                    self.out_bytes.set(self.out_bytes.get() + src.len());
                });
            });
            self.layer.process_write_buf(ctx)
        }

        crate::forward_ready!(layer);
        crate::forward_query!(layer);
        crate::forward_shutdown!(layer);
    }

    /// A single sealed `Counter` layer sees all read and written bytes.
    #[ntex::test]
    async fn filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let io = Io::from(server)
            .map_filter(|layer| Counter {
                layer,
                idx: 1,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            })
            .seal();

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        assert_eq!(in_bytes.get(), BIN.len() * 2);
        assert_eq!(out_bytes.get(), 8);
    }

    /// Two stacked `Counter` layers: checks byte totals, call order, that the
    /// write buffer is drained, and that dropping the stream releases the
    /// filters' `Rc` clones.
    #[ntex::test]
    async fn boxed_filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let state = Io::from(server)
            .map_filter(|layer| Counter {
                layer,
                idx: 2,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            })
            .map_filter(|layer| Counter {
                layer,
                idx: 1,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            });
        let state = state.seal();

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        state
            .send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        assert_eq!(in_bytes.get(), BIN.len() * 2);
        assert_eq!(out_bytes.get(), 16);
        assert_eq!(state.0.buffer.with_write_dst(|b| b.len()), 0);

        // One `Rc` held here plus one per filter layer; only ours remains
        // after the stream is dropped.
        assert_eq!(Rc::strong_count(&in_bytes), 3);
        drop(state);
        assert_eq!(Rc::strong_count(&in_bytes), 1);
        assert_eq!(*read_order.borrow(), &[1, 2][..]);
        assert_eq!(*write_order.borrow(), &[1, 2, 1, 2, 1, 2][..]);
    }
}
708}