Skip to main content

ntex_io/
ioref.rs

1use std::{any, fmt, hash, io};
2
3use ntex_bytes::BytesMut;
4use ntex_codec::{Decoder, Encoder};
5use ntex_service::cfg::SharedCfg;
6use ntex_util::time::Seconds;
7
8use crate::{
9    Decoded, Filter, FilterCtx, Flags, IoConfig, IoRef, OnDisconnect, WriteBuf, timer,
10    types,
11};
12
13impl IoRef {
14    #[inline]
15    /// Get tag
16    pub fn tag(&self) -> &'static str {
17        self.0.cfg.tag()
18    }
19
20    #[inline]
21    #[doc(hidden)]
22    /// Get current state flags
23    pub fn flags(&self) -> Flags {
24        self.0.flags.get()
25    }
26
27    #[inline]
28    /// Get current filter
29    pub(crate) fn filter(&self) -> &dyn Filter {
30        self.0.filter()
31    }
32
33    #[inline]
34    /// Get configuration
35    pub fn cfg(&self) -> &IoConfig {
36        &self.0.cfg
37    }
38
39    #[inline]
40    /// Get shared configuration
41    pub fn shared(&self) -> SharedCfg {
42        self.0.cfg.shared()
43    }
44
45    #[inline]
46    /// Check if io stream is closed
47    pub fn is_closed(&self) -> bool {
48        self.0
49            .flags
50            .get()
51            .intersects(Flags::IO_STOPPING | Flags::IO_STOPPED)
52    }
53
54    #[inline]
55    /// Check if write back-pressure is enabled
56    pub fn is_wr_backpressure(&self) -> bool {
57        self.0.flags.get().contains(Flags::BUF_W_BACKPRESSURE)
58    }
59
60    #[inline]
61    /// Wake dispatcher task
62    pub fn wake(&self) {
63        self.0.dispatch_task.wake();
64    }
65
66    #[inline]
67    /// Gracefully close connection
68    ///
69    /// Initiate io stream shutdown process.
70    pub fn close(&self) {
71        self.0.init_shutdown();
72    }
73
74    #[inline]
75    /// Force close connection
76    ///
77    /// Dispatcher does not wait for uncompleted responses. Io stream get terminated
78    /// without any graceful period.
79    pub fn force_close(&self) {
80        log::trace!("{}: Force close io stream object", self.tag());
81        self.0.insert_flags(
82            Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS,
83        );
84        self.0.read_task.wake();
85        self.0.write_task.wake();
86        self.0.dispatch_task.wake();
87    }
88
89    #[inline]
90    /// Gracefully shutdown io stream
91    pub fn want_shutdown(&self) {
92        if !self
93            .0
94            .flags
95            .get()
96            .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
97        {
98            log::trace!(
99                "{}: Initiate io shutdown {:?}",
100                self.tag(),
101                self.0.flags.get()
102            );
103            self.0.insert_flags(Flags::IO_STOPPING_FILTERS);
104            self.0.read_task.wake();
105        }
106    }
107
108    #[inline]
109    /// Query filter specific data
110    pub fn query<T: 'static>(&self) -> types::QueryItem<T> {
111        if let Some(item) = self.filter().query(any::TypeId::of::<T>()) {
112            types::QueryItem::new(item)
113        } else {
114            types::QueryItem::empty()
115        }
116    }
117
118    #[inline]
119    /// Encode and write item to a buffer and wake up write task
120    pub fn encode<U>(&self, item: U::Item, codec: &U) -> Result<(), <U as Encoder>::Error>
121    where
122        U: Encoder,
123    {
124        if self.is_closed() {
125            log::trace!("{}: Io is closed/closing, skip frame encoding", self.tag());
126            Ok(())
127        } else {
128            self.with_write_buf(|buf| {
129                // make sure we've got room
130                self.cfg().write_buf().resize(buf);
131
132                // encode item and wake write task
133                codec.encode(item, buf)
134            })
135            // .with_write_buf() could return io::Error<Result<(), U::Error>>,
136            // in that case mark io as failed
137            .unwrap_or_else(|err| {
138                log::trace!(
139                    "{}: Got io error while encoding, error: {:?}",
140                    self.tag(),
141                    err
142                );
143                self.0.io_stopped(Some(err));
144                Ok(())
145            })
146        }
147    }
148
149    #[inline]
150    /// Attempts to decode a frame from the read buffer
151    pub fn decode<U>(
152        &self,
153        codec: &U,
154    ) -> Result<Option<<U as Decoder>::Item>, <U as Decoder>::Error>
155    where
156        U: Decoder,
157    {
158        self.0
159            .buffer
160            .with_read_destination(self, |buf| codec.decode(buf))
161    }
162
163    #[inline]
164    /// Attempts to decode a frame from the read buffer
165    pub fn decode_item<U>(
166        &self,
167        codec: &U,
168    ) -> Result<Decoded<<U as Decoder>::Item>, <U as Decoder>::Error>
169    where
170        U: Decoder,
171    {
172        self.0.buffer.with_read_destination(self, |buf| {
173            let len = buf.len();
174            codec.decode(buf).map(|item| Decoded {
175                item,
176                remains: buf.len(),
177                consumed: len - buf.len(),
178            })
179        })
180    }
181
182    #[inline]
183    /// Write bytes to a buffer and wake up write task
184    pub fn write(&self, src: &[u8]) -> io::Result<()> {
185        self.with_write_buf(|buf| buf.extend_from_slice(src))
186    }
187
188    #[inline]
189    /// Get access to write buffer
190    pub fn with_buf<F, R>(&self, f: F) -> io::Result<R>
191    where
192        F: FnOnce(&WriteBuf<'_>) -> R,
193    {
194        let ctx = FilterCtx::new(self, &self.0.buffer);
195        let result = ctx.write_buf(f);
196        self.0.filter().process_write_buf(ctx)?;
197        Ok(result)
198    }
199
200    #[inline]
201    /// Get mut access to source write buffer
202    pub fn with_write_buf<F, R>(&self, f: F) -> io::Result<R>
203    where
204        F: FnOnce(&mut BytesMut) -> R,
205    {
206        if self.0.flags.get().contains(Flags::IO_STOPPED) {
207            Err(self.0.error_or_disconnected())
208        } else {
209            let result = self.0.buffer.with_write_source(self, f);
210            self.0
211                .filter()
212                .process_write_buf(FilterCtx::new(self, &self.0.buffer))?;
213            Ok(result)
214        }
215    }
216
217    #[doc(hidden)]
218    #[inline]
219    /// Get mut access to destination write buffer
220    pub fn with_write_dest_buf<F, R>(&self, f: F) -> R
221    where
222        F: FnOnce(Option<&mut BytesMut>) -> R,
223    {
224        self.0.buffer.with_write_destination(self, f)
225    }
226
227    #[inline]
228    /// Get mut access to source read buffer
229    pub fn with_read_buf<F, R>(&self, f: F) -> R
230    where
231        F: FnOnce(&mut BytesMut) -> R,
232    {
233        self.0.buffer.with_read_destination(self, f)
234    }
235
236    #[inline]
237    /// Wakeup dispatcher
238    pub fn notify_dispatcher(&self) {
239        self.0.dispatch_task.wake();
240        log::trace!("{}: Timer, notify dispatcher", self.tag());
241    }
242
243    #[inline]
244    /// Wakeup dispatcher and send keep-alive error
245    pub fn notify_timeout(&self) {
246        self.0.notify_timeout();
247    }
248
249    #[inline]
250    /// current timer handle
251    pub fn timer_handle(&self) -> timer::TimerHandle {
252        self.0.timeout.get()
253    }
254
255    #[inline]
256    /// Start timer
257    pub fn start_timer(&self, timeout: Seconds) -> timer::TimerHandle {
258        let cur_hnd = self.0.timeout.get();
259
260        if timeout.is_zero() {
261            if cur_hnd.is_set() {
262                self.0.timeout.set(timer::TimerHandle::ZERO);
263                timer::unregister(cur_hnd, self);
264            }
265            timer::TimerHandle::ZERO
266        } else if cur_hnd.is_set() {
267            let hnd = timer::update(cur_hnd, timeout, self);
268            if hnd != cur_hnd {
269                log::trace!("{}: Update timer {:?}", self.tag(), timeout);
270                self.0.timeout.set(hnd);
271            }
272            hnd
273        } else {
274            log::trace!("{}: Start timer {:?}", self.tag(), timeout);
275            let hnd = timer::register(timeout, self);
276            self.0.timeout.set(hnd);
277            hnd
278        }
279    }
280
281    #[inline]
282    /// Stop timer
283    pub fn stop_timer(&self) {
284        let hnd = self.0.timeout.get();
285        if hnd.is_set() {
286            log::trace!("{}: Stop timer", self.tag());
287            self.0.timeout.set(timer::TimerHandle::ZERO);
288            timer::unregister(hnd, self);
289        }
290    }
291
292    #[inline]
293    /// Notify when io stream get disconnected
294    pub fn on_disconnect(&self) -> OnDisconnect {
295        OnDisconnect::new(self.0.clone())
296    }
297}
298
299impl Eq for IoRef {}
300
301impl PartialEq for IoRef {
302    #[inline]
303    fn eq(&self, other: &Self) -> bool {
304        self.0.eq(&other.0)
305    }
306}
307
308impl hash::Hash for IoRef {
309    #[inline]
310    fn hash<H: hash::Hasher>(&self, state: &mut H) {
311        self.0.hash(state);
312    }
313}
314
315impl fmt::Debug for IoRef {
316    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
317        f.debug_struct("IoRef")
318            .field("state", self.0.as_ref())
319            .finish()
320    }
321}
322
#[cfg(test)]
mod tests {
    //! Tests for `IoRef`/`Io` driven through the in-memory `IoTest` transport.

    use std::cell::{Cell, RefCell};
    use std::{future::Future, future::poll_fn, pin::Pin, rc::Rc, task::Poll};

    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;
    use ntex_util::future::lazy;
    use ntex_util::time::{Millis, sleep};

    use super::*;
    use crate::{FilterCtx, FilterReadStatus, Io, testing::IoTest};

    /// Request payload as bytes; must stay identical to [`TEXT`].
    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    /// Request payload as text; must stay identical to [`BIN`].
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    /// Covers equality, `Debug` output, receive/send paths, error propagation
    /// and `force_close` across several fresh `IoTest` pairs.
    #[ntex::test]
    async fn utils() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write(TEXT);

        let state = Io::from(server);
        assert_eq!(state.get_ref(), state.get_ref());

        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));
        assert_eq!(state.get_ref(), state.as_ref().clone());
        assert!(format!("{state:?}").contains("Io {"));
        assert!(format!("{:?}", state.get_ref()).contains("IoRef {"));

        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        assert!(res.is_pending());
        client.write(TEXT);
        sleep(Millis(50)).await;
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        // NOTE: a still-pending poll is tolerated here; only a ready result is checked.
        if let Poll::Ready(msg) = res {
            assert_eq!(msg.unwrap(), Bytes::from_static(BIN));
        }

        // a read error surfaces through `recv` and stops the io
        client.read_error(io::Error::other("err"));
        let msg = state.recv(&BytesCodec).await;
        assert!(msg.is_err());
        assert!(state.flags().contains(Flags::IO_STOPPED));

        // same read error, observed through `poll_recv`
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);

        client.read_error(io::Error::other("err"));
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        if let Poll::Ready(msg) = res {
            assert!(msg.is_err());
            assert!(state.flags().contains(Flags::IO_STOPPED));
        }

        // raw write / decode round trip, then a write error on send
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);
        state.write(b"test").unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        client.write(b"test");
        state.read_ready().await.unwrap();
        let buf = state.decode(&BytesCodec).unwrap().unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        client.write_error(io::Error::other("err"));
        let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await;
        assert!(res.is_err());
        assert!(state.flags().contains(Flags::IO_STOPPED));

        // force close sets both stop flags immediately
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);
        state.force_close();
        assert!(state.flags().contains(Flags::IO_STOPPED));
        assert!(state.flags().contains(Flags::IO_STOPPING));
    }

    /// Checks that read readiness is reported once per batch of incoming data.
    #[ntex::test]
    async fn read_readiness() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let io = Io::from(server);
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());

        client.write(TEXT);
        assert_eq!(io.read_ready().await.unwrap(), Some(()));
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());

        let item = io.with_read_buf(BytesMut::take);
        assert_eq!(item, Bytes::from_static(BIN));

        client.write(TEXT);
        sleep(Millis(50)).await;
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_ready());
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());
    }

    /// Checks that `on_disconnect` waiters (and their clones) resolve after the
    /// peer closes or a read error occurs, and that a waiter created after the
    /// disconnect is immediately ready.
    #[ntex::test]
    // The awaited waiters yield `()`; comparing against `()` is intentional so
    // the assertions also pin the output type.
    #[allow(clippy::unit_cmp)]
    async fn on_disconnect() {
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        let mut waiter2 = waiter.clone();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter2).poll(cx)).await,
            Poll::Pending
        );
        client.close().await;
        assert_eq!(waiter.await, ());
        assert_eq!(waiter2.await, ());

        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Ready(())
        );

        let (client, server) = IoTest::create();
        let state = Io::from(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        client.read_error(io::Error::other("err"));
        assert_eq!(waiter.await, ());
    }

    /// Checks that writes are rejected once the peer has closed the stream.
    #[ntex::test]
    async fn write_to_closed_io() {
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        client.close().await;

        assert!(state.is_closed());
        assert!(state.write(TEXT.as_bytes()).is_err());
        assert!(
            state
                .with_write_buf(|buf| buf.extend_from_slice(BIN))
                .is_err()
        );
    }

    /// Test filter that wraps another filter, recording its own `idx` on every
    /// read/write pass and accumulating the observed byte counts.
    #[derive(Debug)]
    struct Counter<F> {
        /// Wrapped filter that every call is forwarded to.
        layer: F,
        /// Identifier pushed to the order logs on each call.
        idx: usize,
        /// Running total of `nbytes` seen by `process_read_buf`.
        in_bytes: Rc<Cell<usize>>,
        /// Running total of source write buffer lengths seen by `process_write_buf`.
        out_bytes: Rc<Cell<usize>>,
        /// Order in which filters handled reads.
        read_order: Rc<RefCell<Vec<usize>>>,
        /// Order in which filters handled writes.
        write_order: Rc<RefCell<Vec<usize>>>,
    }

    impl<F: Filter> Filter for Counter<F> {
        fn process_read_buf(
            &self,
            ctx: FilterCtx<'_>,
            nbytes: usize,
        ) -> io::Result<FilterReadStatus> {
            self.read_order.borrow_mut().push(self.idx);
            self.in_bytes.set(self.in_bytes.get() + nbytes);
            self.layer.process_read_buf(ctx, nbytes)
        }

        fn process_write_buf(&self, ctx: FilterCtx<'_>) -> io::Result<()> {
            self.write_order.borrow_mut().push(self.idx);
            // add the current length of the source buffer (0 when there is none)
            self.out_bytes.set(
                self.out_bytes.get()
                    + ctx.write_buf(|buf| {
                        buf.with_src(|b| b.as_ref().map(BytesMut::len).unwrap_or_default())
                    }),
            );
            self.layer.process_write_buf(ctx)
        }

        crate::forward_ready!(layer);
        crate::forward_query!(layer);
        crate::forward_shutdown!(layer);
    }

    /// Checks byte accounting with a single `Counter` layer.
    #[ntex::test]
    async fn filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let io = Io::from(server).map_filter(|layer| Counter {
            layer,
            idx: 1,
            in_bytes: in_bytes.clone(),
            out_bytes: out_bytes.clone(),
            read_order: read_order.clone(),
            write_order: write_order.clone(),
        });

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        assert_eq!(in_bytes.get(), BIN.len());
        assert_eq!(out_bytes.get(), 4);
    }

    /// Checks byte accounting, call order and reference release with two
    /// stacked `Counter` layers behind a sealed (boxed) filter.
    #[ntex::test]
    async fn boxed_filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let state = Io::from(server)
            .map_filter(|layer| Counter {
                layer,
                idx: 2,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            })
            .map_filter(|layer| Counter {
                layer,
                idx: 1,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            });
        let state = state.seal();

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        state
            .send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // both layers see the same traffic, so the totals double
        assert_eq!(in_bytes.get(), BIN.len() * 2);
        assert_eq!(out_bytes.get(), 8);
        assert_eq!(state.with_write_dest_buf(|b| b.map_or(0, |b| b.len())), 0);

        // refs: one local handle plus one per layer, released when the io is dropped
        assert_eq!(Rc::strong_count(&in_bytes), 3);
        drop(state);
        assert_eq!(Rc::strong_count(&in_bytes), 1);
        assert_eq!(*read_order.borrow(), &[1, 2][..]);
        assert_eq!(*write_order.borrow(), &[1, 2][..]);
    }
}
595}