// ntex_io/ioref.rs — `IoRef` inherent methods and trait implementations.

1use std::{any, fmt, hash, io};
2
3use ntex_bytes::BytesVec;
4use ntex_codec::{Decoder, Encoder};
5use ntex_service::cfg::SharedCfg;
6use ntex_util::time::Seconds;
7
8use crate::{
9    Decoded, Filter, FilterCtx, Flags, IoConfig, IoRef, OnDisconnect, WriteBuf, timer,
10    types,
11};
12
impl IoRef {
    #[inline]
    /// Get tag
    ///
    /// The tag comes from the current io configuration and is used as a
    /// prefix in the log messages emitted by this impl.
    pub fn tag(&self) -> &'static str {
        self.0.cfg.get().tag()
    }

    #[inline]
    #[doc(hidden)]
    /// Get current state flags
    pub fn flags(&self) -> Flags {
        self.0.flags.get()
    }

    #[inline]
    /// Get current filter
    pub(crate) fn filter(&self) -> &dyn Filter {
        self.0.filter()
    }

    #[inline]
    /// Get configuration
    pub fn cfg(&self) -> &IoConfig {
        self.0.cfg.get()
    }

    #[inline]
    /// Get shared configuration
    pub fn shared(&self) -> SharedCfg {
        self.0.cfg.get().config.shared()
    }

    #[inline]
    /// Check if io stream is closed
    ///
    /// Returns `true` when the stream is shutting down or already stopped.
    pub fn is_closed(&self) -> bool {
        self.0
            .flags
            .get()
            .intersects(Flags::IO_STOPPING | Flags::IO_STOPPED)
    }

    #[inline]
    /// Check if write back-pressure is enabled
    pub fn is_wr_backpressure(&self) -> bool {
        self.0.flags.get().contains(Flags::BUF_W_BACKPRESSURE)
    }

    #[inline]
    /// Wake dispatcher task
    pub fn wake(&self) {
        self.0.dispatch_task.wake();
    }

    #[inline]
    /// Gracefully close connection
    ///
    /// Initiate io stream shutdown process.
    pub fn close(&self) {
        self.0.init_shutdown();
    }

    #[inline]
    /// Force close connection
    ///
    /// Dispatcher does not wait for uncompleted responses. Io stream get terminated
    /// without any graceful period.
    pub fn force_close(&self) {
        log::trace!("{}: Force close io stream object", self.tag());
        // Mark every shutdown stage at once, then wake all three tasks so
        // they observe the stopped state immediately.
        self.0.insert_flags(
            Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS,
        );
        self.0.read_task.wake();
        self.0.write_task.wake();
        self.0.dispatch_task.wake();
    }

    #[inline]
    /// Gracefully shutdown io stream
    ///
    /// No-op when shutdown is already in progress (any stage) or completed.
    pub fn want_shutdown(&self) {
        if !self
            .0
            .flags
            .get()
            .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
        {
            log::trace!(
                "{}: Initiate io shutdown {:?}",
                self.tag(),
                self.0.flags.get()
            );
            // Start with filter shutdown; the read task drives the rest of
            // the shutdown process once woken.
            self.0.insert_flags(Flags::IO_STOPPING_FILTERS);
            self.0.read_task.wake();
        }
    }

    #[inline]
    /// Query filter specific data
    ///
    /// Asks the filter chain for a value of type `T`; an empty item is
    /// returned when no filter provides one.
    pub fn query<T: 'static>(&self) -> types::QueryItem<T> {
        if let Some(item) = self.filter().query(any::TypeId::of::<T>()) {
            types::QueryItem::new(item)
        } else {
            types::QueryItem::empty()
        }
    }

    #[inline]
    /// Encode and write item to a buffer and wake up write task
    ///
    /// If the stream is closed/closing the item is silently dropped. Io
    /// errors raised while writing are recorded on the io state and
    /// reported as `Ok(())` here; only codec errors are returned.
    pub fn encode<U>(&self, item: U::Item, codec: &U) -> Result<(), <U as Encoder>::Error>
    where
        U: Encoder,
    {
        if !self.is_closed() {
            self.with_write_buf(|buf| {
                // make sure we've got room
                self.cfg().write_buf().resize(buf);

                // encode item and wake write task
                codec.encode_vec(item, buf)
            })
            // .with_write_buf() returns Err(io::Error) when io is stopped;
            // in that case mark io as failed and swallow the error here
            .unwrap_or_else(|err| {
                log::trace!(
                    "{}: Got io error while encoding, error: {:?}",
                    self.tag(),
                    err
                );
                self.0.io_stopped(Some(err));
                Ok(())
            })
        } else {
            log::trace!("{}: Io is closed/closing, skip frame encoding", self.tag());
            Ok(())
        }
    }

    #[inline]
    /// Attempts to decode a frame from the read buffer
    pub fn decode<U>(
        &self,
        codec: &U,
    ) -> Result<Option<<U as Decoder>::Item>, <U as Decoder>::Error>
    where
        U: Decoder,
    {
        self.0
            .buffer
            .with_read_destination(self, |buf| codec.decode_vec(buf))
    }

    #[inline]
    /// Attempts to decode a frame from the read buffer
    ///
    /// Unlike [`IoRef::decode`], also reports how many bytes were consumed
    /// by the codec and how many remain in the buffer.
    pub fn decode_item<U>(
        &self,
        codec: &U,
    ) -> Result<Decoded<<U as Decoder>::Item>, <U as Decoder>::Error>
    where
        U: Decoder,
    {
        self.0.buffer.with_read_destination(self, |buf| {
            // buffer length before decoding, used to compute `consumed`
            let len = buf.len();
            codec.decode_vec(buf).map(|item| Decoded {
                item,
                remains: buf.len(),
                consumed: len - buf.len(),
            })
        })
    }

    #[inline]
    /// Write bytes to a buffer and wake up write task
    pub fn write(&self, src: &[u8]) -> io::Result<()> {
        self.with_write_buf(|buf| buf.extend_from_slice(src))
    }

    #[inline]
    /// Get access to write buffer
    pub fn with_buf<F, R>(&self, f: F) -> io::Result<R>
    where
        F: FnOnce(&WriteBuf<'_>) -> R,
    {
        let ctx = FilterCtx::new(self, &self.0.buffer);
        let result = ctx.write_buf(f);
        // run the filter chain over the (possibly modified) write buffer
        self.0.filter().process_write_buf(ctx)?;
        Ok(result)
    }

    #[inline]
    /// Get mut access to source write buffer
    ///
    /// Returns an error if the io stream is already stopped.
    pub fn with_write_buf<F, R>(&self, f: F) -> io::Result<R>
    where
        F: FnOnce(&mut BytesVec) -> R,
    {
        if self.0.flags.get().contains(Flags::IO_STOPPED) {
            Err(self.0.error_or_disconnected())
        } else {
            let result = self.0.buffer.with_write_source(self, f);
            // push new data through the filter chain towards the destination buffer
            self.0
                .filter()
                .process_write_buf(FilterCtx::new(self, &self.0.buffer))?;
            Ok(result)
        }
    }

    #[doc(hidden)]
    #[inline]
    /// Get mut access to destination write buffer
    pub fn with_write_dest_buf<F, R>(&self, f: F) -> R
    where
        F: FnOnce(Option<&mut BytesVec>) -> R,
    {
        self.0.buffer.with_write_destination(self, f)
    }

    #[inline]
    /// Get mut access to source read buffer
    pub fn with_read_buf<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut BytesVec) -> R,
    {
        self.0.buffer.with_read_destination(self, f)
    }

    #[inline]
    /// Wakeup dispatcher
    pub fn notify_dispatcher(&self) {
        self.0.dispatch_task.wake();
        log::trace!("{}: Timer, notify dispatcher", self.tag());
    }

    #[inline]
    /// Wakeup dispatcher and send keep-alive error
    pub fn notify_timeout(&self) {
        self.0.notify_timeout()
    }

    #[inline]
    /// current timer handle
    pub fn timer_handle(&self) -> timer::TimerHandle {
        self.0.timeout.get()
    }

    #[inline]
    /// Start timer
    ///
    /// A zero `timeout` cancels any active timer. Otherwise the existing
    /// timer is updated in place (if one is set) or a new one is
    /// registered. Returns the resulting handle (`ZERO` when no timer is
    /// active).
    pub fn start_timer(&self, timeout: Seconds) -> timer::TimerHandle {
        let cur_hnd = self.0.timeout.get();

        if !timeout.is_zero() {
            if cur_hnd.is_set() {
                let hnd = timer::update(cur_hnd, timeout, self);
                // only store the handle when the timer actually moved
                if hnd != cur_hnd {
                    log::trace!("{}: Update timer {:?}", self.tag(), timeout);
                    self.0.timeout.set(hnd);
                }
                hnd
            } else {
                log::trace!("{}: Start timer {:?}", self.tag(), timeout);
                let hnd = timer::register(timeout, self);
                self.0.timeout.set(hnd);
                hnd
            }
        } else {
            // zero timeout: drop the currently registered timer, if any
            if cur_hnd.is_set() {
                self.0.timeout.set(timer::TimerHandle::ZERO);
                timer::unregister(cur_hnd, self);
            }
            timer::TimerHandle::ZERO
        }
    }

    #[inline]
    /// Stop timer
    pub fn stop_timer(&self) {
        let hnd = self.0.timeout.get();
        if hnd.is_set() {
            log::trace!("{}: Stop timer", self.tag());
            self.0.timeout.set(timer::TimerHandle::ZERO);
            timer::unregister(hnd, self)
        }
    }

    #[inline]
    /// Notify when io stream get disconnected
    pub fn on_disconnect(&self) -> OnDisconnect {
        OnDisconnect::new(self.0.clone())
    }
}
300
// Marker impl: `IoRef` equality (defined via `PartialEq` below) is a full equivalence relation.
impl Eq for IoRef {}
302
303impl PartialEq for IoRef {
304    #[inline]
305    fn eq(&self, other: &Self) -> bool {
306        self.0.eq(&other.0)
307    }
308}
309
impl hash::Hash for IoRef {
    #[inline]
    /// Hashes the shared inner state, consistent with the `PartialEq` impl
    /// which also delegates to `self.0`.
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        self.0.hash(state);
    }
}
316
impl fmt::Debug for IoRef {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Renders as `IoRef { state: .. }`; the `utils` test below asserts
        // on the "IoRef {" prefix of this output.
        f.debug_struct("IoRef")
            .field("state", self.0.as_ref())
            .finish()
    }
}
324
#[cfg(test)]
mod tests {
    use std::cell::{Cell, RefCell};
    use std::{future::Future, future::poll_fn, pin::Pin, rc::Rc, task::Poll};

    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;
    use ntex_util::future::lazy;
    use ntex_util::time::{Millis, sleep};

    use super::*;
    use crate::{FilterCtx, FilterReadStatus, Io, testing::IoTest};

    // Same payload as bytes and as text; BytesCodec passes it through verbatim.
    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    #[ntex::test]
    async fn utils() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write(TEXT);

        let state = Io::from(server);
        assert_eq!(state.get_ref(), state.get_ref());

        // basic recv + equality + Debug formatting of Io and IoRef
        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));
        assert_eq!(state.get_ref(), state.as_ref().clone());
        assert!(format!("{state:?}").find("Io {").is_some());
        assert!(format!("{:?}", state.get_ref()).find("IoRef {").is_some());

        // poll_recv: pending until data arrives, then yields the frame
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        assert!(res.is_pending());
        client.write(TEXT);
        sleep(Millis(50)).await;
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        if let Poll::Ready(msg) = res {
            assert_eq!(msg.unwrap(), Bytes::from_static(BIN));
        }

        // read error surfaces through recv() and stops the io stream
        client.read_error(io::Error::other("err"));
        let msg = state.recv(&BytesCodec).await;
        assert!(msg.is_err());
        assert!(state.flags().contains(Flags::IO_STOPPED));

        // read error surfaces through poll_recv() as well
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);

        client.read_error(io::Error::other("err"));
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        if let Poll::Ready(msg) = res {
            assert!(msg.is_err());
            assert!(state.flags().contains(Flags::IO_STOPPED));
        }

        // write()/decode() roundtrip over a fresh connection
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);
        state.write(b"test").unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        client.write(b"test");
        state.read_ready().await.unwrap();
        let buf = state.decode(&BytesCodec).unwrap().unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // write error fails send() and stops the io stream
        client.write_error(io::Error::other("err"));
        let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await;
        assert!(res.is_err());
        assert!(state.flags().contains(Flags::IO_STOPPED));

        // force_close() sets both stop flags immediately
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);
        state.force_close();
        assert!(state.flags().contains(Flags::IO_STOPPED));
        assert!(state.flags().contains(Flags::IO_STOPPING));
    }

    #[ntex::test]
    async fn read_readiness() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let io = Io::from(server);
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());

        // readiness fires once per batch of new data
        client.write(TEXT);
        assert_eq!(io.read_ready().await.unwrap(), Some(()));
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());

        // drain the read buffer via with_read_buf
        let item = io.with_read_buf(|buffer| buffer.split());
        assert_eq!(item, Bytes::from_static(BIN));

        client.write(TEXT);
        sleep(Millis(50)).await;
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_ready());
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());
    }

    #[ntex::test]
    #[allow(clippy::unit_cmp)]
    async fn on_disconnect() {
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        // cloned waiters resolve independently
        let mut waiter2 = waiter.clone();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter2).poll(cx)).await,
            Poll::Pending
        );
        client.close().await;
        assert_eq!(waiter.await, ());
        assert_eq!(waiter2.await, ());

        // a waiter created after disconnect resolves immediately
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Ready(())
        );

        // read error also counts as disconnect
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        client.read_error(io::Error::other("err"));
        assert_eq!(waiter.await, ());
    }

    #[ntex::test]
    async fn write_to_closed_io() {
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        client.close().await;

        // both write() and with_write_buf() must reject a closed stream
        assert!(state.is_closed());
        assert!(state.write(TEXT.as_bytes()).is_err());
        assert!(
            state
                .with_write_buf(|buf| buf.extend_from_slice(BIN))
                .is_err()
        );
    }

    /// Test filter that counts bytes flowing in each direction and records
    /// the order in which filters in the chain are invoked.
    #[derive(Debug)]
    struct Counter<F> {
        layer: F,
        // position of this filter in the chain, pushed into *_order lists
        idx: usize,
        in_bytes: Rc<Cell<usize>>,
        out_bytes: Rc<Cell<usize>>,
        read_order: Rc<RefCell<Vec<usize>>>,
        write_order: Rc<RefCell<Vec<usize>>>,
    }

    impl<F: Filter> Filter for Counter<F> {
        fn process_read_buf(
            &self,
            ctx: FilterCtx<'_>,
            nbytes: usize,
        ) -> io::Result<FilterReadStatus> {
            // record invocation order and accumulate incoming byte count,
            // then delegate to the wrapped layer
            self.read_order.borrow_mut().push(self.idx);
            self.in_bytes.set(self.in_bytes.get() + nbytes);
            self.layer.process_read_buf(ctx, nbytes)
        }

        fn process_write_buf(&self, ctx: FilterCtx<'_>) -> io::Result<()> {
            self.write_order.borrow_mut().push(self.idx);
            // accumulate the current source write-buffer length (0 when absent)
            self.out_bytes.set(
                self.out_bytes.get()
                    + ctx.write_buf(|buf| {
                        buf.with_src(|b| b.as_ref().map(|b| b.len()).unwrap_or_default())
                    }),
            );
            self.layer.process_write_buf(ctx)
        }

        crate::forward_ready!(layer);
        crate::forward_query!(layer);
        crate::forward_shutdown!(layer);
    }

    #[ntex::test]
    async fn filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        // single Counter filter layered over the test io
        let (client, server) = IoTest::create();
        let io = Io::from(server).map_filter(|layer| Counter {
            layer,
            idx: 1,
            in_bytes: in_bytes.clone(),
            out_bytes: out_bytes.clone(),
            read_order: read_order.clone(),
            write_order: write_order.clone(),
        });

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // every byte passed the single filter exactly once
        assert_eq!(in_bytes.get(), BIN.len());
        assert_eq!(out_bytes.get(), 4);
    }

    #[ntex::test]
    async fn boxed_filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        // two Counter filters stacked, then sealed into a boxed filter
        let (client, server) = IoTest::create();
        let state = Io::from(server)
            .map_filter(|layer| Counter {
                layer,
                idx: 2,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            })
            .map_filter(|layer| Counter {
                layer,
                idx: 1,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            });
        let state = state.seal();

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        state
            .send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // bytes are counted once per filter (two filters -> doubled totals)
        assert_eq!(in_bytes.get(), BIN.len() * 2);
        assert_eq!(out_bytes.get(), 8);
        assert_eq!(
            state.with_write_dest_buf(|b| b.map(|b| b.len()).unwrap_or(0)),
            0
        );

        // refs
        assert_eq!(Rc::strong_count(&in_bytes), 3);
        drop(state);
        assert_eq!(Rc::strong_count(&in_bytes), 1);
        assert_eq!(*read_order.borrow(), &[1, 2][..]);
        assert_eq!(*write_order.borrow(), &[1, 2][..]);
    }
}