Skip to main content

ntex_io/
ioref.rs

1use std::{any, fmt, hash, io, ptr};
2
3use ntex_bytes::{BytePage, BytePages, BytesMut};
4use ntex_codec::{Decoder, Encoder};
5use ntex_service::cfg::SharedCfg;
6use ntex_util::time::Seconds;
7
8use crate::ops::{Id, Iops, TimerHandle};
9use crate::{Decoded, Filter, FilterBuf, Flags, IoConfig, IoContext, IoRef, types};
10
11impl IoRef {
12    #[inline]
13    /// Gets the ID.
14    pub fn id(&self) -> Id {
15        self.0.id()
16    }
17
18    #[inline]
19    /// Gets the I/O tag.
20    pub fn tag(&self) -> &'static str {
21        self.0.tag()
22    }
23
24    #[doc(hidden)]
25    /// Gets the current state flags.
26    pub fn flags(&self) -> Flags {
27        self.0.flags.clone()
28    }
29
30    #[inline]
31    /// Gets the current filter.
32    pub(crate) fn filter(&self) -> &dyn Filter {
33        self.0.filter()
34    }
35
36    #[inline]
37    /// Gets the configuration.
38    pub fn cfg(&self) -> &IoConfig {
39        &self.0.cfg
40    }
41
42    #[inline]
43    /// Gets the shared configuration.
44    pub fn shared(&self) -> SharedCfg {
45        self.0.cfg.shared()
46    }
47
48    #[inline]
49    /// Checks whether the I/O stream is closed.
50    pub fn is_closed(&self) -> bool {
51        self.0.flags.is_closed()
52    }
53
54    #[inline]
55    /// Checks whether write back-pressure is enabled.
56    pub fn is_wr_backpressure(&self) -> bool {
57        self.0.flags.is_wr_backpressure()
58    }
59
60    /// Gracefully closes the connection.
61    ///
62    /// Initiates the I/O stream shutdown process.
63    pub fn close(&self) {
64        self.0.start_shutdown();
65    }
66
67    /// Force-closes the connection.
68    ///
69    /// The dispatcher does not wait for incomplete responses. The I/O stream is
70    /// terminated without any graceful period.
71    pub fn terminate(&self) {
72        log::trace!("{}: Terminate io stream object", self.tag());
73        self.0.terminate_connection(None);
74    }
75
76    #[doc(hidden)]
77    #[deprecated(since = "3.10.0", note = "use IoRef::terminate() instead")]
78    /// Force close connection
79    ///
80    /// Dispatcher does not wait for uncompleted responses. Io stream get terminated
81    /// without any graceful period.
82    pub fn force_close(&self) {
83        self.terminate();
84    }
85
86    #[doc(hidden)]
87    #[deprecated(since = "3.11.0", note = "use IoRef::close() instead")]
88    /// Gracefully shuts down the I/O stream.
89    pub fn wants_shutdown(&self) {
90        self.0.start_shutdown();
91    }
92
93    /// Queries filter-specific data.
94    pub fn query<T: 'static>(&self) -> types::QueryItem<T> {
95        types::QueryItem::new(self.filter().query(any::TypeId::of::<T>()))
96    }
97
98    #[inline]
99    /// Encodes the item into the write buffer.
100    pub fn encode<U>(&self, item: U::Item, codec: &U) -> Result<(), <U as Encoder>::Error>
101    where
102        U: Encoder,
103    {
104        self.with_write_buf(|buf| codec.encodev(item, buf))
105            .unwrap_or_else(|_| Ok(()))
106    }
107
108    #[inline]
109    /// Encodes the slice into the write buffer.
110    pub fn encode_slice(&self, src: &[u8]) -> io::Result<()> {
111        self.with_write_buf(|buf| buf.extend_from_slice(src))
112    }
113
114    #[inline]
115    /// Writes bytes to the write buffer.
116    pub fn encode_bytes<B>(&self, src: B) -> io::Result<()>
117    where
118        BytePage: From<B>,
119    {
120        self.with_write_buf(|buf| buf.append(src))
121    }
122
123    /// Attempts to decode a frame from the read buffer.
124    pub fn decode<U>(
125        &self,
126        codec: &U,
127    ) -> Result<Option<<U as Decoder>::Item>, <U as Decoder>::Error>
128    where
129        U: Decoder,
130    {
131        self.0.buffer.with_read_dst(self, |buf| {
132            let res = codec.decode(buf);
133            self.0.flags.unset_read_ready();
134            self.update_read_destination(buf);
135            res
136        })
137    }
138
139    /// Attempts to decode a frame from the read buffer.
140    pub fn decode_item<U>(
141        &self,
142        codec: &U,
143    ) -> Result<Decoded<<U as Decoder>::Item>, <U as Decoder>::Error>
144    where
145        U: Decoder,
146    {
147        self.0.buffer.with_read_dst(self, |buf| {
148            let len = buf.len();
149            let res = codec.decode(buf).map(|item| Decoded {
150                item,
151                remains: buf.len(),
152                consumed: len - buf.len(),
153            });
154            self.0.flags.unset_read_ready();
155            self.update_read_destination(buf);
156            res
157        })
158    }
159
160    /// Sends the write buffer to the I/O layer.
161    ///
162    /// Requires the underlying runtime to implement `.write()`;
163    /// otherwise, no action is taken.
164    pub fn send_buf(&self) -> io::Result<()> {
165        // try send bytes
166        self.consolidate_write_state(true);
167
168        if self.0.flags.is_stopping_any()
169            && let Some(err) = self.0.error.take()
170        {
171            Err(err)
172        } else {
173            Ok(())
174        }
175    }
176
    /// Handles a previously scheduled write request.
    ///
    /// Does nothing unless the `wr_send_scheduled` flag is set. Otherwise the
    /// flag is cleared and the buffered data is pushed out: through
    /// `call_write()` when the write task is paused, or by waking the write
    /// task when it is not.
    pub(crate) fn ops_send_buf(&self) {
        let st = &self.0;
        if st.flags.is_wr_send_scheduled() {
            st.flags.unset_wr_send_scheduled();

            if st.flags.is_write_paused() {
                // Call `Handle::write()` in place. If the write task is not
                // paused afterwards, an I/O write is pending and the write
                // task must be woken to complete it.
                if self.call_write() == WakeWriteTask::Yes {
                    st.wake_write_task();
                    st.flags.unset_write_paused();
                }
            } else {
                st.wake_write_task();
            }
        }
    }
195
196    /// Get access to filter buffer
197    pub fn with_buf<F, R>(&self, f: F) -> io::Result<R>
198    where
199        F: FnOnce(&mut FilterBuf<'_>) -> R,
200    {
201        let result = self.0.buffer.with_filter(self, |ctx| ctx.with_buffer(f));
202        self.consolidate_write_state(false);
203        Ok(result)
204    }
205
206    /// Get mut access to read buffer
207    pub fn with_read_buf<F, R>(&self, f: F) -> R
208    where
209        F: FnOnce(&mut BytesMut) -> R,
210    {
211        self.0.buffer.with_read_dst(self, |buf| {
212            let res = f(buf);
213            self.update_read_destination(buf);
214            res
215        })
216    }
217
218    /// Get mut access to source write buffer
219    pub fn with_write_buf<F, R>(&self, f: F) -> io::Result<R>
220    where
221        F: FnOnce(&mut BytePages) -> R,
222    {
223        let st = &self.0;
224
225        if st.flags.is_stopping_any() {
226            if st.flags.is_closed() {
227                Err(st.error_or_disconnected())
228            } else {
229                Err(io::Error::other("I/O stream is closing"))
230            }
231        } else {
232            let result = st.buffer.with_write_src(f);
233            self.consolidate_write_state(false);
234            Ok(result)
235        }
236    }
237
    /// Re-evaluates the write side after the write buffer may have changed.
    ///
    /// When the buffer is non-empty and the write task is paused, the data is
    /// either written in place (direct writes enabled and either `force` is
    /// set or the buffer reached the configured threshold) or a write is
    /// scheduled through `Iops`. Independently, write back-pressure is enabled
    /// and the dispatcher is woken when the buffer size requires it.
    pub(crate) fn consolidate_write_state(&self, force: bool) {
        let st = &self.0;

        // wake write task if needed
        let size = st.buffer.write_buf_size();

        #[cfg(feature = "trace")]
        log::trace!("{}: write-upd == buf:{size} flags:{:?}", st.tag(), st.flags);

        if size > 0 && st.flags.is_write_paused() {
            // The app encodes data in response to incoming data,
            // continuing to fill the write buffer until all data
            // has been processed. Only then can the runtime wake
            // the write task to send the buffered data.
            //
            // By that time, the buffer may have accumulated a large
            // amount of data, causing it to be sent in large bursts,
            // which introduces latency. To prevent this behavior and
            // flatten data delivery to the peer, IoRef can initiate
            // out-of-order writes based on a configured threshold.
            if st.flags.is_direct_wr_enabled()
                && (force || size >= st.cfg.write_buf_threshold())
            {
                // Send data in-place
                if self.call_write() == WakeWriteTask::Yes {
                    #[cfg(feature = "trace")]
                    log::trace!(
                        "{}: write-upd == schedule(more):{} flags:{:?}",
                        st.tag(),
                        st.buffer.write_buf_size(),
                        st.flags
                    );
                    if !st.flags.is_wr_send_scheduled() {
                        // More data needs to be sent
                        st.flags.set_wr_send_scheduled();
                        Iops::schedule_write(st.id());
                    }
                } else {
                    // The write task is paused again after the in-place
                    // write, so any previously scheduled send is cancelled.
                    st.flags.unset_wr_send_scheduled();
                }
            } else if !st.flags.is_wr_send_scheduled() {
                #[cfg(feature = "trace")]
                log::trace!("{}: write-upd == schedule(too small)", st.tag());
                st.flags.set_wr_send_scheduled();
                Iops::schedule_write(st.id());
            }
        }
        // Enable backpressure
        // NOTE: `size` was sampled before any in-place write above.
        if !st.flags.is_wr_backpressure() && st.is_wr_backpressure_needed(size) {
            st.flags.set_wr_backpressure();
            st.wake_dispatch_task();
        }
    }
291
    /// Re-evaluates the read side after the read buffer `buf` was accessed.
    ///
    /// While read back-pressure is set and still required for the current
    /// buffer length, nothing changes. Otherwise the relevant read flags are
    /// cleared and, if the read task is paused, it is woken and un-paused.
    fn update_read_destination(&self, buf: &mut BytesMut) {
        let st = &self.0;

        #[cfg(feature = "trace")]
        log::trace!(
            "{}: read-upd == buf:{} flags:{:?}",
            st.tag(),
            buf.len(),
            st.flags
        );

        if st.flags.is_rd_backpressure() {
            // back-pressure is still enabled
            if st.is_rd_backpressure_needed(buf.len()) {
                return;
            }
            // back-pressure is no longer needed: reset the read flags
            st.flags.unset_all_read_flags();
        } else {
            st.flags.unset_read_ready();
        }

        // resume the read task if it was paused
        if st.flags.is_read_paused() {
            st.wake_read_task();
            st.flags.unset_read_paused();
        }
    }
318
319    /// Make sure buffer has enough free space
320    pub fn resize_read_buf(&self, buf: &mut BytesMut) {
321        self.0.cfg.read_buf().resize(buf);
322    }
323
324    #[doc(hidden)]
325    #[deprecated(since = "3.10.3", note = "Use .notify_disapatcher()")]
326    /// Wakeup dispatcher
327    pub fn wake(&self) {
328        self.notify_dispatcher();
329    }
330
331    /// Wakeup dispatcher
332    pub fn notify_dispatcher(&self) {
333        log::trace!("{}: Timer, notify dispatcher", self.tag());
334        self.0.wake_dispatch_task();
335    }
336
337    /// Wakeup dispatcher and send keep-alive error
338    pub fn notify_timeout(&self) {
339        self.0.notify_timeout();
340    }
341
342    /// Current timer handle
343    pub fn timer_handle(&self) -> TimerHandle {
344        self.0.timeout.get()
345    }
346
347    /// Start timer
348    pub fn start_timer(&self, timeout: Seconds) -> TimerHandle {
349        let cur_hnd = self.0.timeout.get();
350
351        if timeout.is_zero() {
352            if cur_hnd.is_set() {
353                self.0.timeout.set(TimerHandle::ZERO);
354                cur_hnd.unregister(self);
355            }
356            TimerHandle::ZERO
357        } else if cur_hnd.is_set() {
358            let hnd = cur_hnd.update(timeout, self);
359            if hnd != cur_hnd {
360                log::trace!("{}: Update timer {:?}", self.tag(), timeout);
361                self.0.timeout.set(hnd);
362            }
363            hnd
364        } else {
365            log::trace!("{}: Start timer {:?}", self.tag(), timeout);
366            let hnd = TimerHandle::register(timeout, self);
367            self.0.timeout.set(hnd);
368            hnd
369        }
370    }
371
372    /// Stop timer
373    pub fn stop_timer(&self) {
374        let hnd = self.0.timeout.get();
375        if hnd.is_set() {
376            log::trace!("{}: Stop timer", self.tag());
377            self.0.timeout.set(TimerHandle::ZERO);
378            hnd.unregister(self);
379        }
380    }
381
382    /// Notify when io stream get disconnected
383    pub fn on_disconnect(&self) -> crate::OnDisconnect {
384        crate::OnDisconnect::new(self.0.clone())
385    }
386
    /// Calls the I/O handle's `write()` method in place, if a handle is
    /// currently available.
    ///
    /// The `write-paused` flag is cleared before the call, and the handle is
    /// taken out of `self.0.handle` for the duration of the call and put back
    /// afterwards.
    ///
    /// Returns `WakeWriteTask::Yes` when `write-paused` is not set once the
    /// call has completed, and `WakeWriteTask::No` when it is set (which
    /// includes the case where no handle was available and the flag was
    /// already set).
    fn call_write(&self) -> WakeWriteTask {
        if let Some(hnd) = self.0.handle.take() {
            self.0.flags.unset_write_paused();
            #[cfg(feature = "trace")]
            log::trace!(
                "{}: call-write ({}), flags:{:?}",
                self.tag(),
                self.0.buffer.write_buf_size(),
                self.0.flags
            );
            // SAFETY: NOTE(review): this reinterprets `&IoRef` as
            // `&IoContext`; it is sound only if `IoContext` is a
            // layout-compatible wrapper around `IoRef` (e.g.
            // `#[repr(transparent)]`) - confirm against its definition.
            let ctx = unsafe { &*(ptr::from_ref(self).cast::<IoContext>()) };
            hnd.write(ctx);
            self.0.handle.set(Some(hnd));
        }
        if self.0.flags.is_write_paused() {
            WakeWriteTask::No
        } else {
            WakeWriteTask::Yes
        }
    }
409
    /// Calls the I/O handle's `notify()` method, if a handle is currently
    /// available.
    ///
    /// The handle is taken out of `self.0.handle` for the duration of the
    /// call and put back afterwards.
    pub(crate) fn call_notify(&self) {
        if let Some(hnd) = self.0.handle.take() {
            // SAFETY: NOTE(review): this reinterprets `&IoRef` as
            // `&IoContext`; it is sound only if `IoContext` is a
            // layout-compatible wrapper around `IoRef` (e.g.
            // `#[repr(transparent)]`) - confirm against its definition.
            let ctx = unsafe { &*(ptr::from_ref(self).cast::<IoContext>()) };
            hnd.notify(ctx);
            self.0.handle.set(Some(hnd));
        }
    }
417}
418
/// Result of `IoRef::call_write()`, telling the caller whether the write
/// task still has work to do.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
enum WakeWriteTask {
    /// The `write-paused` flag is not set after the call.
    Yes,
    /// The `write-paused` flag is set after the call.
    No,
}
424
// Marker impl: equality of `IoRef` (see `PartialEq`) is declared to be a
// full equivalence relation.
impl Eq for IoRef {}
426
/// Two `IoRef` values are equal when their inner states compare equal.
impl PartialEq for IoRef {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        // Delegates to the equality of the wrapped state.
        self.0.eq(&other.0)
    }
}
433
/// Hashing delegates to the wrapped state, in line with the `PartialEq`
/// implementation, which also delegates to it.
impl hash::Hash for IoRef {
    #[inline]
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        self.0.hash(state);
    }
}
440
441impl fmt::Debug for IoRef {
442    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
443        f.debug_struct("IoRef")
444            .field("state", self.0.as_ref())
445            .finish()
446    }
447}
448
449#[cfg(test)]
450mod tests {
451    use std::cell::{Cell, RefCell};
452    use std::{future::Future, future::poll_fn, pin::Pin, rc::Rc, task::Poll};
453
454    use ntex_bytes::Bytes;
455    use ntex_codec::BytesCodec;
456    use ntex_util::{future::lazy, time::Millis, time::sleep};
457
458    use super::*;
459    use crate::{FilterCtx, Io, testing::IoTest};
460
461    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
462    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";
463
464    #[ntex::test]
465    async fn utils() {
466        let (client, server) = IoTest::create();
467        client.remote_buffer_cap(1024);
468        client.write(TEXT);
469
470        let state = Io::from(server);
471        assert_eq!(state.get_ref(), state.get_ref());
472
473        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
474        assert_eq!(msg, Bytes::from_static(BIN));
475        assert_eq!(state.get_ref(), state.as_ref().clone());
476        assert!(format!("{state:?}").find("Io {").is_some());
477        assert!(format!("{:?}", state.get_ref()).find("IoRef {").is_some());
478
479        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
480        assert!(res.is_pending());
481        client.write(TEXT);
482        sleep(Millis(50)).await;
483        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
484        if let Poll::Ready(msg) = res {
485            assert_eq!(msg.unwrap(), Bytes::from_static(BIN));
486        }
487
488        client.read_error(io::Error::other("err"));
489        let msg = state.recv(&BytesCodec).await;
490        assert!(msg.is_err());
491        assert!(state.flags().is_terminated());
492
493        let (client, server) = IoTest::create();
494        client.remote_buffer_cap(1024);
495        let state = Io::from(server);
496
497        client.read_error(io::Error::other("err"));
498        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
499        if let Poll::Ready(msg) = res {
500            assert!(msg.is_err());
501            assert!(state.flags().is_terminated());
502        }
503
504        let (client, server) = IoTest::create();
505        client.remote_buffer_cap(1024);
506        let state = Io::from(server);
507        state.encode_slice(b"test").unwrap();
508        let buf = client.read().await.unwrap();
509        assert_eq!(buf, Bytes::from_static(b"test"));
510
511        client.write(b"test");
512        state.read_ready().await.unwrap();
513        let buf = state.decode(&BytesCodec).unwrap().unwrap();
514        assert_eq!(buf, Bytes::from_static(b"test"));
515
516        client.write_error(io::Error::other("err"));
517        state
518            .send(Bytes::from_static(b"test"), &BytesCodec)
519            .await
520            .unwrap();
521        assert!(state.flags().is_terminated());
522
523        let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await;
524        assert!(res.is_err());
525
526        let (client, server) = IoTest::create();
527        client.remote_buffer_cap(1024);
528        let state = Io::from(server);
529        state.terminate();
530        assert!(state.flags().is_stopping());
531        assert!(state.flags().is_terminated());
532    }
533
    /// Checks that `OnDisconnect` waiters (and their clones) stay pending
    /// while the stream is open and resolve once the peer closes the
    /// connection or a read error occurs.
    #[ntex::test]
    #[allow(clippy::unit_cmp)]
    async fn on_disconnect() {
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        let mut waiter2 = waiter.clone();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter2).poll(cx)).await,
            Poll::Pending
        );
        client.close().await;
        assert_eq!(waiter.await, ());
        assert_eq!(waiter2.await, ());

        // A waiter created after the disconnect is ready immediately.
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Ready(())
        );

        // A read error also resolves the waiter.
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        client.read_error(io::Error::other("err"));
        assert_eq!(waiter.await, ());
    }
569
    /// Checks that every write-buffer entry point returns an error once the
    /// stream has been closed by the peer.
    #[ntex::test]
    async fn write_to_closed_io() {
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        client.close().await;

        assert!(state.is_closed());
        assert!(state.encode_slice(TEXT.as_bytes()).is_err());
        assert!(state.encode_bytes(Bytes::from_static(BIN)).is_err());
        assert!(
            state
                .with_write_buf(|buf| buf.extend_from_slice(BIN))
                .is_err()
        );
    }
585
    /// Test filter that wraps another filter layer and records traffic
    /// statistics and call order into shared cells.
    #[derive(Debug)]
    struct Counter<F> {
        // Wrapped filter layer that all calls are forwarded to.
        layer: F,
        // Identifier pushed into the order vectors on every call.
        idx: usize,
        // Running total of new read bytes observed.
        in_bytes: Rc<Cell<usize>>,
        // Running total of write source buffer lengths observed.
        out_bytes: Rc<Cell<usize>>,
        // `idx` values in the order `process_read_buf` was called.
        read_order: Rc<RefCell<Vec<usize>>>,
        // `idx` values in the order `process_write_buf` was called.
        write_order: Rc<RefCell<Vec<usize>>>,
    }
595
    /// `Filter` implementation that records statistics and then delegates to
    /// the wrapped layer.
    impl<F: Filter> Filter for Counter<F> {
        /// Records the call order, forwards to the wrapped layer, then adds
        /// the context's new read byte count to `in_bytes`.
        fn process_read_buf(&self, ctx: &mut FilterCtx<'_>) -> io::Result<()> {
            self.read_order.borrow_mut().push(self.idx);
            let result = self.layer.process_read_buf(ctx);
            self.in_bytes
                .set(self.in_bytes.get() + ctx.new_read_bytes());
            result
        }

        /// Records the call order, adds the length of the write source buffer
        /// to `out_bytes`, then forwards to the wrapped layer.
        fn process_write_buf(&self, ctx: &mut FilterCtx<'_>) -> io::Result<()> {
            self.write_order.borrow_mut().push(self.idx);
            ctx.with_buffer(|buf| {
                buf.with_write_buffers(|src, _| {
                    self.out_bytes.set(self.out_bytes.get() + src.len());
                });
            });
            self.layer.process_write_buf(ctx)
        }

        // Remaining `Filter` methods are forwarded to `layer` by these macros.
        crate::forward_ready!(layer);
        crate::forward_query!(layer);
        crate::forward_shutdown!(layer);
    }
619
    /// Installs a single `Counter` filter on a sealed `Io` and checks the
    /// byte counts it observes for two reads and one write.
    #[ntex::test]
    async fn filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let io = Io::from(server)
            .map_filter(|layer| Counter {
                layer,
                idx: 1,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            })
            .seal();

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        assert_eq!(in_bytes.get(), BIN.len() * 2);
        // NOTE(review): 8 is twice the 4-byte payload; presumably the write
        // path invokes the filter twice for this send - confirm.
        assert_eq!(out_bytes.get(), 8);
    }
657
    /// Stacks two `Counter` filters on a sealed `Io` and checks the observed
    /// byte counts, the filter call order and that dropping the `Io`
    /// releases the filters.
    #[ntex::test]
    async fn boxed_filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let state = Io::from(server)
            .map_filter(|layer| Counter {
                layer,
                idx: 2,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            })
            .map_filter(|layer| Counter {
                layer,
                idx: 1,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            });
        let state = state.seal();

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        state
            .send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // Both filters share the same counters, so each total is accumulated
        // across the two layers.
        assert_eq!(in_bytes.get(), BIN.len() * 2);
        assert_eq!(out_bytes.get(), 16);
        assert_eq!(state.0.buffer.with_write_dst(|b| b.len()), 0);

        // refs: one local handle plus one clone held by each of the two filters
        assert_eq!(Rc::strong_count(&in_bytes), 3);
        drop(state);
        assert_eq!(Rc::strong_count(&in_bytes), 1);
        // The filter added last (idx 1) is invoked before the one added first.
        assert_eq!(*read_order.borrow(), &[1, 2][..]);
        assert_eq!(*write_order.borrow(), &[1, 2, 1, 2, 1, 2][..]);
    }
708}