Skip to main content

ntex_io/
ioref.rs

1use std::{any, fmt, hash, io, ptr};
2
3use ntex_bytes::{BytePage, BytePages, BytesMut};
4use ntex_codec::{Decoder, Encoder};
5use ntex_service::cfg::SharedCfg;
6use ntex_util::time::Seconds;
7
8use crate::ops::{Id, Iops, TimerHandle};
9use crate::{Decoded, Filter, FilterBuf, Flags, IoConfig, IoContext, IoRef, types};
10
11impl IoRef {
12    #[inline]
13    /// Gets the ID.
14    pub fn id(&self) -> Id {
15        self.0.id()
16    }
17
18    #[inline]
19    /// Gets the I/O tag.
20    pub fn tag(&self) -> &'static str {
21        self.0.tag()
22    }
23
24    #[doc(hidden)]
25    /// Gets the current state flags.
26    pub fn flags(&self) -> Flags {
27        self.0.flags.clone()
28    }
29
30    #[inline]
31    /// Gets the current filter.
32    pub(crate) fn filter(&self) -> &dyn Filter {
33        self.0.filter()
34    }
35
36    #[inline]
37    /// Gets the configuration.
38    pub fn cfg(&self) -> &IoConfig {
39        &self.0.cfg
40    }
41
42    #[inline]
43    /// Gets the shared configuration.
44    pub fn shared(&self) -> SharedCfg {
45        self.0.cfg.shared()
46    }
47
48    #[inline]
49    /// Checks whether the I/O stream is closed.
50    pub fn is_closed(&self) -> bool {
51        self.0.flags.is_closed()
52    }
53
54    #[inline]
55    /// Checks whether write back-pressure is enabled.
56    pub fn is_wr_backpressure(&self) -> bool {
57        self.0.flags.is_wr_backpressure()
58    }
59
60    /// Gracefully closes the connection.
61    ///
62    /// Initiates the I/O stream shutdown process.
63    pub fn close(&self) {
64        self.0.start_shutdown();
65    }
66
67    /// Force-closes the connection.
68    ///
69    /// The dispatcher does not wait for incomplete responses. The I/O stream is
70    /// terminated without any graceful period.
71    pub fn terminate(&self) {
72        log::trace!("{}: Terminate io stream object", self.tag());
73        self.0.terminate_connection(None);
74    }
75
76    #[doc(hidden)]
77    #[deprecated(since = "3.10.0", note = "use IoRef::terminate() instead")]
78    /// Force close connection
79    ///
80    /// Dispatcher does not wait for uncompleted responses. Io stream get terminated
81    /// without any graceful period.
82    pub fn force_close(&self) {
83        self.terminate();
84    }
85
86    #[doc(hidden)]
87    #[deprecated(since = "3.11.0", note = "use IoRef::close() instead")]
88    /// Gracefully shuts down the I/O stream.
89    pub fn wants_shutdown(&self) {
90        self.0.start_shutdown();
91    }
92
93    /// Queries filter-specific data.
94    pub fn query<T: 'static>(&self) -> types::QueryItem<T> {
95        types::QueryItem::new(self.filter().query(any::TypeId::of::<T>()))
96    }
97
98    #[inline]
99    /// Encodes the item into the write buffer.
100    pub fn encode<U>(&self, item: U::Item, codec: &U) -> Result<(), <U as Encoder>::Error>
101    where
102        U: Encoder,
103    {
104        self.with_write_buf(|buf| codec.encodev(item, buf))
105            .unwrap_or_else(|_| Ok(()))
106    }
107
108    #[inline]
109    /// Encodes the slice into the write buffer.
110    pub fn encode_slice(&self, src: &[u8]) -> io::Result<()> {
111        self.with_write_buf(|buf| buf.extend_from_slice(src))
112    }
113
114    #[inline]
115    /// Writes bytes to the write buffer.
116    pub fn encode_bytes<B>(&self, src: B) -> io::Result<()>
117    where
118        BytePage: From<B>,
119    {
120        self.with_write_buf(|buf| buf.append(src))
121    }
122
123    /// Attempts to decode a frame from the read buffer.
124    pub fn decode<U>(
125        &self,
126        codec: &U,
127    ) -> Result<Option<<U as Decoder>::Item>, <U as Decoder>::Error>
128    where
129        U: Decoder,
130    {
131        self.0.buffer.with_read_dst(self, |buf| {
132            let res = codec.decode(buf);
133            self.0.flags.unset_read_ready();
134            self.update_read_destination(buf);
135            res
136        })
137    }
138
139    /// Attempts to decode a frame from the read buffer.
140    pub fn decode_item<U>(
141        &self,
142        codec: &U,
143    ) -> Result<Decoded<<U as Decoder>::Item>, <U as Decoder>::Error>
144    where
145        U: Decoder,
146    {
147        self.0.buffer.with_read_dst(self, |buf| {
148            let len = buf.len();
149            let res = codec.decode(buf).map(|item| Decoded {
150                item,
151                remains: buf.len(),
152                consumed: len - buf.len(),
153            });
154            self.0.flags.unset_read_ready();
155            self.update_read_destination(buf);
156            res
157        })
158    }
159
160    /// Sends the write buffer to the I/O layer.
161    ///
162    /// Requires the underlying runtime to implement `.write()`;
163    /// otherwise, no action is taken.
164    pub fn send_buf(&self) -> io::Result<()> {
165        // try send bytes
166        self.consolidate_write_state(true);
167
168        if self.0.flags.is_stopping_any()
169            && let Some(err) = self.0.error.take()
170        {
171            Err(err)
172        } else {
173            Ok(())
174        }
175    }
176
    /// Handles a previously scheduled send of the write buffer.
    ///
    /// Does nothing unless the `wr_send_scheduled` flag is set. Otherwise the
    /// flag is cleared and, depending on whether the write task is paused,
    /// either the handle's write is invoked directly or the write task is woken.
    pub(crate) fn ops_send_buf(&self) {
        let st = &self.0;
        #[cfg(feature = "trace")]
        log::trace!(
            "{}: ops-send == buf:{} flags:{:?}",
            st.tag(),
            st.buffer.write_buf_size(),
            st.flags
        );

        if st.flags.is_wr_send_scheduled() {
            st.flags.unset_wr_send_scheduled();

            if st.flags.is_write_paused() {
                // Call `Handle::write()`.
                // If the write task is not paused afterwards, an io write is
                // pending and the write task must be woken for io completion.
                if self.call_write() == WakeWriteTask::Yes {
                    st.wake_write_task();
                    st.flags.unset_write_paused();
                }
            } else {
                // Write task is not paused: just wake it.
                st.wake_write_task();
            }
        }
    }
203
204    /// Get access to filter buffer
205    pub fn with_buf<F, R>(&self, f: F) -> io::Result<R>
206    where
207        F: FnOnce(&mut FilterBuf<'_>) -> R,
208    {
209        let result = self.0.buffer.with_filter(self, |ctx| ctx.with_buffer(f));
210        self.consolidate_write_state(false);
211        Ok(result)
212    }
213
214    /// Get mut access to read buffer
215    pub fn with_read_buf<F, R>(&self, f: F) -> R
216    where
217        F: FnOnce(&mut BytesMut) -> R,
218    {
219        self.0.buffer.with_read_dst(self, |buf| {
220            let res = f(buf);
221            self.update_read_destination(buf);
222            res
223        })
224    }
225
226    /// Get mut access to source write buffer
227    pub fn with_write_buf<F, R>(&self, f: F) -> io::Result<R>
228    where
229        F: FnOnce(&mut BytePages) -> R,
230    {
231        let st = &self.0;
232
233        if st.flags.is_stopping_any() {
234            if st.flags.is_closed() {
235                Err(st.error_or_disconnected())
236            } else {
237                Err(io::Error::other("I/O stream is closing"))
238            }
239        } else {
240            let result = st.buffer.with_write_src(f);
241            self.consolidate_write_state(false);
242            Ok(result)
243        }
244    }
245
    /// Reconciles write-related flags with the current write buffer size.
    ///
    /// If the buffer is non-empty and the write task is paused, data is
    /// either written in place (direct writes enabled and `force` is set or
    /// the buffer reached the configured threshold) or a write is scheduled.
    /// Afterwards, write back-pressure is enabled if needed and the
    /// dispatcher is woken.
    pub(crate) fn consolidate_write_state(&self, force: bool) {
        let st = &self.0;

        // wake write task if needed
        let size = st.buffer.write_buf_size();

        #[cfg(feature = "trace")]
        log::trace!("{}: write-upd == buf:{size} flags:{:?}", st.tag(), st.flags);

        if size > 0 && st.flags.is_write_paused() {
            // The app encodes data in response to incoming data,
            // continuing to fill the write buffer until all data
            // has been processed. Only then can the runtime wake
            // the write task to send the buffered data.
            //
            // By that time, the buffer may have accumulated a large
            // amount of data, causing it to be sent in large bursts,
            // which introduces latency. To prevent this behavior and
            // flatten data delivery to the peer, IoRef can initiate
            // out-of-order writes based on a configured threshold.
            if st.flags.is_direct_wr_enabled()
                && (force || size >= st.cfg.write_buf_threshold())
            {
                // Send data in-place
                if self.call_write() == WakeWriteTask::Yes {
                    #[cfg(feature = "trace")]
                    log::trace!(
                        "{}: write-upd == schedule(more):{} flags:{:?}",
                        st.tag(),
                        st.buffer.write_buf_size(),
                        st.flags
                    );
                    if !st.flags.is_wr_send_scheduled() {
                        // More data needs to be sent
                        st.flags.set_wr_send_scheduled();
                        Iops::schedule_write(st.id());
                    }
                } else {
                    // Write task is paused again: drop any scheduled send.
                    st.flags.unset_wr_send_scheduled();
                }
            } else if !st.flags.is_wr_send_scheduled() {
                #[cfg(feature = "trace")]
                log::trace!("{}: write-upd == schedule(too small)", st.tag());
                st.flags.set_wr_send_scheduled();
                Iops::schedule_write(st.id());
            }
        }
        // Enable backpressure
        if !st.flags.is_wr_backpressure() && st.is_wr_backpressure_needed(size) {
            st.flags.set_wr_backpressure();
            st.wake_dispatch_task();
        }
    }
299
    /// Refreshes read-related flags after the read buffer has been accessed.
    ///
    /// While read back-pressure is set and still needed for `buf.len()`,
    /// nothing changes. Otherwise the read flags are cleared and, if the read
    /// task is paused, it is woken and un-paused.
    fn update_read_destination(&self, buf: &mut BytesMut) {
        let st = &self.0;

        #[cfg(feature = "trace")]
        log::trace!(
            "{}: read-upd == buf:{} flags:{:?}",
            st.tag(),
            buf.len(),
            st.flags
        );

        if st.flags.is_rd_backpressure() {
            // back-pressure is still enabled
            if st.is_rd_backpressure_needed(buf.len()) {
                return;
            }
            // back-pressure is no longer needed: reset all read flags
            st.flags.unset_all_read_flags();
        } else {
            st.flags.unset_read_ready();
        }

        if st.flags.is_read_paused() {
            st.wake_read_task();
            st.flags.unset_read_paused();
        }
    }
326
327    /// Make sure buffer has enough free space
328    pub fn resize_read_buf(&self, buf: &mut BytesMut) {
329        self.0.cfg.read_buf().resize(buf);
330    }
331
332    #[doc(hidden)]
333    #[deprecated(since = "3.10.3", note = "Use .notify_disapatcher()")]
334    /// Wakeup dispatcher
335    pub fn wake(&self) {
336        self.notify_dispatcher();
337    }
338
339    /// Wakeup dispatcher
340    pub fn notify_dispatcher(&self) {
341        log::trace!("{}: Timer, notify dispatcher", self.tag());
342        self.0.wake_dispatch_task();
343    }
344
345    /// Wakeup dispatcher and send keep-alive error
346    pub fn notify_timeout(&self) {
347        self.0.notify_timeout();
348    }
349
350    /// Current timer handle
351    pub fn timer_handle(&self) -> TimerHandle {
352        self.0.timeout.get()
353    }
354
355    /// Start timer
356    pub fn start_timer(&self, timeout: Seconds) -> TimerHandle {
357        let cur_hnd = self.0.timeout.get();
358
359        if timeout.is_zero() {
360            if cur_hnd.is_set() {
361                self.0.timeout.set(TimerHandle::ZERO);
362                cur_hnd.unregister(self);
363            }
364            TimerHandle::ZERO
365        } else if cur_hnd.is_set() {
366            let hnd = cur_hnd.update(timeout, self);
367            if hnd != cur_hnd {
368                log::trace!("{}: Update timer {:?}", self.tag(), timeout);
369                self.0.timeout.set(hnd);
370            }
371            hnd
372        } else {
373            log::trace!("{}: Start timer {:?}", self.tag(), timeout);
374            let hnd = TimerHandle::register(timeout, self);
375            self.0.timeout.set(hnd);
376            hnd
377        }
378    }
379
380    /// Stop timer
381    pub fn stop_timer(&self) {
382        let hnd = self.0.timeout.get();
383        if hnd.is_set() {
384            log::trace!("{}: Stop timer", self.tag());
385            self.0.timeout.set(TimerHandle::ZERO);
386            hnd.unregister(self);
387        }
388    }
389
390    /// Notify when io stream get disconnected
391    pub fn on_disconnect(&self) -> crate::OnDisconnect {
392        crate::OnDisconnect::new(self.0.clone())
393    }
394
    /// Invokes the I/O handle's `write` method, if a handle is installed.
    ///
    /// The handle is taken out of its slot for the duration of the call and
    /// put back afterwards; the `write-paused` flag is cleared before the call.
    ///
    /// Returns [`WakeWriteTask::No`] if `write-paused` is set after the call,
    /// and [`WakeWriteTask::Yes`] otherwise.
    fn call_write(&self) -> WakeWriteTask {
        if let Some(hnd) = self.0.handle.take() {
            self.0.flags.unset_write_paused();
            #[cfg(feature = "trace")]
            log::trace!(
                "{}: call-write ({}), flags:{:?}",
                self.tag(),
                self.0.buffer.write_buf_size(),
                self.0.flags
            );
            // NOTE(review): reinterprets `&IoRef` as `&IoContext`; presumably
            // `IoContext` is a layout-compatible wrapper around `IoRef` —
            // confirm at the `IoContext` definition.
            let ctx = unsafe { &*(ptr::from_ref(self).cast::<IoContext>()) };
            hnd.write(ctx);
            self.0.handle.set(Some(hnd));
        }
        if self.0.flags.is_write_paused() {
            WakeWriteTask::No
        } else {
            WakeWriteTask::Yes
        }
    }
417
    /// Invokes the I/O handle's `notify` method, if a handle is installed.
    ///
    /// The handle is taken out of its slot for the duration of the call and
    /// put back afterwards.
    pub(crate) fn call_notify(&self) {
        if let Some(hnd) = self.0.handle.take() {
            // NOTE(review): reinterprets `&IoRef` as `&IoContext`; presumably
            // `IoContext` is a layout-compatible wrapper around `IoRef` —
            // confirm at the `IoContext` definition.
            let ctx = unsafe { &*(ptr::from_ref(self).cast::<IoContext>()) };
            hnd.notify(ctx);
            self.0.handle.set(Some(hnd));
        }
    }
425}
426
/// Outcome of a direct write attempt, as reported by `IoRef::call_write`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WakeWriteTask {
    /// The `write-paused` flag is not set after the write call.
    Yes,
    /// The `write-paused` flag is set after the write call.
    No,
}
432
// Marker impl; the comparison itself is provided by the `PartialEq` impl.
impl Eq for IoRef {}
434
/// Two `IoRef`s are equal when their inner state handles compare equal.
impl PartialEq for IoRef {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        // Delegates to the equality of the wrapped state.
        self.0.eq(&other.0)
    }
}
441
/// Hashes an `IoRef` by hashing its inner state handle.
impl hash::Hash for IoRef {
    #[inline]
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        // Delegates to the hash of the wrapped state.
        self.0.hash(state);
    }
}
448
449impl fmt::Debug for IoRef {
450    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
451        f.debug_struct("IoRef")
452            .field("state", self.0.as_ref())
453            .finish()
454    }
455}
456
457#[cfg(test)]
458mod tests {
459    use std::cell::{Cell, RefCell};
460    use std::{future::Future, future::poll_fn, pin::Pin, rc::Rc, task::Poll};
461
462    use ntex_bytes::Bytes;
463    use ntex_codec::BytesCodec;
464    use ntex_util::{future::lazy, time::Millis, time::sleep};
465
466    use super::*;
467    use crate::{FilterCtx, Io, testing::IoTest};
468
469    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
470    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";
471
472    #[ntex::test]
473    async fn utils() {
474        let (client, server) = IoTest::create();
475        client.remote_buffer_cap(1024);
476        client.write(TEXT);
477
478        let state = Io::from(server);
479        assert_eq!(state.get_ref(), state.get_ref());
480
481        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
482        assert_eq!(msg, Bytes::from_static(BIN));
483        assert_eq!(state.get_ref(), state.as_ref().clone());
484        assert!(format!("{state:?}").find("Io {").is_some());
485        assert!(format!("{:?}", state.get_ref()).find("IoRef {").is_some());
486
487        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
488        assert!(res.is_pending());
489        client.write(TEXT);
490        sleep(Millis(50)).await;
491        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
492        if let Poll::Ready(msg) = res {
493            assert_eq!(msg.unwrap(), Bytes::from_static(BIN));
494        }
495
496        client.read_error(io::Error::other("err"));
497        let msg = state.recv(&BytesCodec).await;
498        assert!(msg.is_err());
499        assert!(state.flags().is_terminated());
500
501        let (client, server) = IoTest::create();
502        client.remote_buffer_cap(1024);
503        let state = Io::from(server);
504
505        client.read_error(io::Error::other("err"));
506        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
507        if let Poll::Ready(msg) = res {
508            assert!(msg.is_err());
509            assert!(state.flags().is_terminated());
510        }
511
512        let (client, server) = IoTest::create();
513        client.remote_buffer_cap(1024);
514        let state = Io::from(server);
515        state.encode_slice(b"test").unwrap();
516        let buf = client.read().await.unwrap();
517        assert_eq!(buf, Bytes::from_static(b"test"));
518
519        client.write(b"test");
520        state.read_ready().await.unwrap();
521        let buf = state.decode(&BytesCodec).unwrap().unwrap();
522        assert_eq!(buf, Bytes::from_static(b"test"));
523
524        client.write_error(io::Error::other("err"));
525        state
526            .send(Bytes::from_static(b"test"), &BytesCodec)
527            .await
528            .unwrap();
529        assert!(state.flags().is_terminated());
530
531        let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await;
532        assert!(res.is_err());
533
534        let (client, server) = IoTest::create();
535        client.remote_buffer_cap(1024);
536        let state = Io::from(server);
537        state.terminate();
538        assert!(state.flags().is_stopping());
539        assert!(state.flags().is_terminated());
540    }
541
    // `unit_cmp` is allowed because the assertions below compare awaited
    // values against `()`.
    #[ntex::test]
    #[allow(clippy::unit_cmp)]
    async fn on_disconnect() {
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        // Waiters (and their clones) are pending while the peer is connected.
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        let mut waiter2 = waiter.clone();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter2).poll(cx)).await,
            Poll::Pending
        );
        // Both waiters complete after the client closes.
        client.close().await;
        assert_eq!(waiter.await, ());
        assert_eq!(waiter2.await, ());

        // A waiter created after the disconnect is ready immediately.
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Ready(())
        );

        // A read error also completes the waiter.
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        client.read_error(io::Error::other("err"));
        assert_eq!(waiter.await, ());
    }
577
578    #[ntex::test]
579    async fn write_to_closed_io() {
580        let (client, server) = IoTest::create();
581        let state = Io::from(server);
582        client.close().await;
583
584        assert!(state.is_closed());
585        assert!(state.encode_slice(TEXT.as_bytes()).is_err());
586        assert!(state.encode_bytes(Bytes::from_static(BIN)).is_err());
587        assert!(
588            state
589                .with_write_buf(|buf| buf.extend_from_slice(BIN))
590                .is_err()
591        );
592    }
593
    /// Test filter that wraps another layer and records traffic statistics
    /// into shared cells.
    #[derive(Debug)]
    struct Counter<F> {
        // Wrapped filter layer that calls are forwarded to.
        layer: F,
        // Identifier pushed into the order logs on every call.
        idx: usize,
        // Running total of new read bytes reported by the filter context.
        in_bytes: Rc<Cell<usize>>,
        // Running total of write-source lengths seen in `process_write_buf`.
        out_bytes: Rc<Cell<usize>>,
        // Sequence of `idx` values in `process_read_buf` call order.
        read_order: Rc<RefCell<Vec<usize>>>,
        // Sequence of `idx` values in `process_write_buf` call order.
        write_order: Rc<RefCell<Vec<usize>>>,
    }
603
604    impl<F: Filter> Filter for Counter<F> {
605        fn process_read_buf(&self, ctx: &mut FilterCtx<'_>) -> io::Result<()> {
606            self.read_order.borrow_mut().push(self.idx);
607            let result = self.layer.process_read_buf(ctx);
608            self.in_bytes
609                .set(self.in_bytes.get() + ctx.new_read_bytes());
610            result
611        }
612
613        fn process_write_buf(&self, ctx: &mut FilterCtx<'_>) -> io::Result<()> {
614            self.write_order.borrow_mut().push(self.idx);
615            ctx.with_buffer(|buf| {
616                buf.with_write_buffers(|src, _| {
617                    self.out_bytes.set(self.out_bytes.get() + src.len());
618                });
619            });
620            self.layer.process_write_buf(ctx)
621        }
622
623        crate::forward_ready!(layer);
624        crate::forward_query!(layer);
625        crate::forward_shutdown!(layer);
626    }
627
    // Checks that a single `Counter` filter layer observes the bytes read
    // and written through a sealed `Io`.
    #[ntex::test]
    async fn filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let io = Io::from(server)
            .map_filter(|layer| Counter {
                layer,
                idx: 1,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            })
            .seal();

        // First inbound message.
        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        // Outbound message reaches the client unchanged.
        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // Second inbound message.
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        // Two inbound messages were counted; the write counter is expected
        // to total 8 for the single 4-byte send.
        assert_eq!(in_bytes.get(), BIN.len() * 2);
        assert_eq!(out_bytes.get(), 8);
    }
665
    // Checks two stacked `Counter` layers on a sealed `Io`: byte counters,
    // call order, and that dropping the `Io` releases the filters.
    #[ntex::test]
    async fn boxed_filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        // Layer with idx 2 is added first, layer with idx 1 wraps it.
        let state = Io::from(server)
            .map_filter(|layer| Counter {
                layer,
                idx: 2,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            })
            .map_filter(|layer| Counter {
                layer,
                idx: 1,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            });
        let state = state.seal();

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        state
            .send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // Both layers share the same counters, so each total is doubled
        // relative to the single-layer `filter` test.
        assert_eq!(in_bytes.get(), BIN.len() * 2);
        assert_eq!(out_bytes.get(), 16);
        // Nothing is left in the destination write buffer.
        assert_eq!(state.0.buffer.with_write_dst(|b| b.len()), 0);

        // refs: one local handle plus one per filter layer; dropping the io
        // releases the two held by the filters.
        assert_eq!(Rc::strong_count(&in_bytes), 3);
        drop(state);
        assert_eq!(Rc::strong_count(&in_bytes), 1);
        // The outer layer (idx 1) is always invoked before the inner one (idx 2).
        assert_eq!(*read_order.borrow(), &[1, 2][..]);
        assert_eq!(*write_order.borrow(), &[1, 2, 1, 2, 1, 2][..]);
    }
716}