ntex_io/
ioref.rs

1use std::{any, fmt, hash, io};
2
3use ntex_bytes::{BytesVec, PoolRef};
4use ntex_codec::{Decoder, Encoder};
5use ntex_util::time::Seconds;
6
7use crate::{
8    timer, types, Decoded, Filter, FilterCtx, Flags, IoRef, OnDisconnect, WriteBuf,
9};
10
11impl IoRef {
12    #[inline]
13    #[doc(hidden)]
14    /// Get current state flags
15    pub fn flags(&self) -> Flags {
16        self.0.flags.get()
17    }
18
19    #[inline]
20    /// Get current filter
21    pub(crate) fn filter(&self) -> &dyn Filter {
22        self.0.filter()
23    }
24
25    #[inline]
26    /// Get memory pool
27    pub fn memory_pool(&self) -> PoolRef {
28        self.0.pool.get()
29    }
30
31    #[inline]
32    /// Check if io stream is closed
33    pub fn is_closed(&self) -> bool {
34        self.0
35            .flags
36            .get()
37            .intersects(Flags::IO_STOPPING | Flags::IO_STOPPED)
38    }
39
40    #[inline]
41    /// Check if write back-pressure is enabled
42    pub fn is_wr_backpressure(&self) -> bool {
43        self.0.flags.get().contains(Flags::BUF_W_BACKPRESSURE)
44    }
45
46    #[inline]
47    /// Wake dispatcher task
48    pub fn wake(&self) {
49        self.0.dispatch_task.wake();
50    }
51
52    #[inline]
53    /// Gracefully close connection
54    ///
55    /// Notify dispatcher and initiate io stream shutdown process.
56    pub fn close(&self) {
57        self.0.insert_flags(Flags::DSP_STOP);
58        self.0.init_shutdown();
59    }
60
61    #[inline]
62    /// Force close connection
63    ///
64    /// Dispatcher does not wait for uncompleted responses. Io stream get terminated
65    /// without any graceful period.
66    pub fn force_close(&self) {
67        log::trace!("{}: Force close io stream object", self.tag());
68        self.0.insert_flags(
69            Flags::DSP_STOP
70                | Flags::IO_STOPPED
71                | Flags::IO_STOPPING
72                | Flags::IO_STOPPING_FILTERS,
73        );
74        self.0.read_task.wake();
75        self.0.write_task.wake();
76        self.0.dispatch_task.wake();
77    }
78
79    #[inline]
80    /// Gracefully shutdown io stream
81    pub fn want_shutdown(&self) {
82        if !self
83            .0
84            .flags
85            .get()
86            .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
87        {
88            log::trace!(
89                "{}: Initiate io shutdown {:?}",
90                self.tag(),
91                self.0.flags.get()
92            );
93            self.0.insert_flags(Flags::IO_STOPPING_FILTERS);
94            self.0.read_task.wake();
95        }
96    }
97
98    #[inline]
99    /// Query filter specific data
100    pub fn query<T: 'static>(&self) -> types::QueryItem<T> {
101        if let Some(item) = self.filter().query(any::TypeId::of::<T>()) {
102            types::QueryItem::new(item)
103        } else {
104            types::QueryItem::empty()
105        }
106    }
107
108    #[inline]
109    /// Encode and write item to a buffer and wake up write task
110    pub fn encode<U>(&self, item: U::Item, codec: &U) -> Result<(), <U as Encoder>::Error>
111    where
112        U: Encoder,
113    {
114        if !self.is_closed() {
115            self.with_write_buf(|buf| {
116                // make sure we've got room
117                self.memory_pool().resize_write_buf(buf);
118
119                // encode item and wake write task
120                codec.encode_vec(item, buf)
121            })
122            // .with_write_buf() could return io::Error<Result<(), U::Error>>,
123            // in that case mark io as failed
124            .unwrap_or_else(|err| {
125                log::trace!(
126                    "{}: Got io error while encoding, error: {:?}",
127                    self.tag(),
128                    err
129                );
130                self.0.io_stopped(Some(err));
131                Ok(())
132            })
133        } else {
134            log::trace!("{}: Io is closed/closing, skip frame encoding", self.tag());
135            Ok(())
136        }
137    }
138
139    #[inline]
140    /// Attempts to decode a frame from the read buffer
141    pub fn decode<U>(
142        &self,
143        codec: &U,
144    ) -> Result<Option<<U as Decoder>::Item>, <U as Decoder>::Error>
145    where
146        U: Decoder,
147    {
148        self.0
149            .buffer
150            .with_read_destination(self, |buf| codec.decode_vec(buf))
151    }
152
153    #[inline]
154    /// Attempts to decode a frame from the read buffer
155    pub fn decode_item<U>(
156        &self,
157        codec: &U,
158    ) -> Result<Decoded<<U as Decoder>::Item>, <U as Decoder>::Error>
159    where
160        U: Decoder,
161    {
162        self.0.buffer.with_read_destination(self, |buf| {
163            let len = buf.len();
164            codec.decode_vec(buf).map(|item| Decoded {
165                item,
166                remains: buf.len(),
167                consumed: len - buf.len(),
168            })
169        })
170    }
171
172    #[inline]
173    /// Write bytes to a buffer and wake up write task
174    pub fn write(&self, src: &[u8]) -> io::Result<()> {
175        self.with_write_buf(|buf| buf.extend_from_slice(src))
176    }
177
178    #[inline]
179    /// Get access to write buffer
180    pub fn with_buf<F, R>(&self, f: F) -> io::Result<R>
181    where
182        F: FnOnce(&WriteBuf<'_>) -> R,
183    {
184        let ctx = FilterCtx::new(self, &self.0.buffer);
185        let result = ctx.write_buf(f);
186        self.0.filter().process_write_buf(ctx)?;
187        Ok(result)
188    }
189
190    #[inline]
191    /// Get mut access to source write buffer
192    pub fn with_write_buf<F, R>(&self, f: F) -> io::Result<R>
193    where
194        F: FnOnce(&mut BytesVec) -> R,
195    {
196        if self.0.flags.get().contains(Flags::IO_STOPPED) {
197            Err(self.0.error_or_disconnected())
198        } else {
199            let result = self.0.buffer.with_write_source(self, f);
200            self.0
201                .filter()
202                .process_write_buf(FilterCtx::new(self, &self.0.buffer))?;
203            Ok(result)
204        }
205    }
206
207    #[doc(hidden)]
208    #[inline]
209    /// Get mut access to destination write buffer
210    pub fn with_write_dest_buf<F, R>(&self, f: F) -> R
211    where
212        F: FnOnce(Option<&mut BytesVec>) -> R,
213    {
214        self.0.buffer.with_write_destination(self, f)
215    }
216
217    #[inline]
218    /// Get mut access to source read buffer
219    pub fn with_read_buf<F, R>(&self, f: F) -> R
220    where
221        F: FnOnce(&mut BytesVec) -> R,
222    {
223        self.0.buffer.with_read_destination(self, f)
224    }
225
226    #[inline]
227    /// Wakeup dispatcher
228    pub fn notify_dispatcher(&self) {
229        self.0.dispatch_task.wake();
230        log::trace!("{}: Timer, notify dispatcher", self.tag());
231    }
232
233    #[inline]
234    /// Wakeup dispatcher and send keep-alive error
235    pub fn notify_timeout(&self) {
236        self.0.notify_timeout()
237    }
238
239    #[doc(hidden)]
240    #[inline]
241    /// Keep-alive error
242    pub fn is_keepalive(&self) -> bool {
243        self.0.flags.get().is_keepalive()
244    }
245
246    #[inline]
247    /// Wakeup dispatcher and send keep-alive error
248    pub fn notify_keepalive(&self) {
249        self.0.insert_flags(Flags::DSP_KEEPALIVE);
250        self.0.notify_timeout()
251    }
252
253    #[inline]
254    /// current timer handle
255    pub fn timer_handle(&self) -> timer::TimerHandle {
256        self.0.timeout.get()
257    }
258
259    #[inline]
260    /// Start timer
261    pub fn start_timer(&self, timeout: Seconds) -> timer::TimerHandle {
262        let cur_hnd = self.0.timeout.get();
263
264        if !timeout.is_zero() {
265            if cur_hnd.is_set() {
266                let hnd = timer::update(cur_hnd, timeout, self);
267                if hnd != cur_hnd {
268                    log::trace!("{}: Update timer {:?}", self.tag(), timeout);
269                    self.0.timeout.set(hnd);
270                }
271                hnd
272            } else {
273                log::trace!("{}: Start timer {:?}", self.tag(), timeout);
274                let hnd = timer::register(timeout, self);
275                self.0.timeout.set(hnd);
276                hnd
277            }
278        } else {
279            if cur_hnd.is_set() {
280                self.0.timeout.set(timer::TimerHandle::ZERO);
281                timer::unregister(cur_hnd, self);
282            }
283            timer::TimerHandle::ZERO
284        }
285    }
286
287    #[inline]
288    /// Stop timer
289    pub fn stop_timer(&self) {
290        let hnd = self.0.timeout.get();
291        if hnd.is_set() {
292            log::trace!("{}: Stop timer", self.tag());
293            self.0.timeout.set(timer::TimerHandle::ZERO);
294            timer::unregister(hnd, self)
295        }
296    }
297
298    #[inline]
299    /// Get tag
300    pub fn tag(&self) -> &'static str {
301        self.0.tag.get()
302    }
303
304    #[inline]
305    /// Set tag
306    pub fn set_tag(&self, tag: &'static str) {
307        self.0.tag.set(tag)
308    }
309
310    #[inline]
311    /// Notify when io stream get disconnected
312    pub fn on_disconnect(&self) -> OnDisconnect {
313        OnDisconnect::new(self.0.clone())
314    }
315}
316
317impl Eq for IoRef {}
318
319impl PartialEq for IoRef {
320    #[inline]
321    fn eq(&self, other: &Self) -> bool {
322        self.0.eq(&other.0)
323    }
324}
325
326impl hash::Hash for IoRef {
327    #[inline]
328    fn hash<H: hash::Hasher>(&self, state: &mut H) {
329        self.0.hash(state);
330    }
331}
332
333impl fmt::Debug for IoRef {
334    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
335        f.debug_struct("IoRef")
336            .field("state", self.0.as_ref())
337            .finish()
338    }
339}
340
#[cfg(test)]
mod tests {
    use std::cell::{Cell, RefCell};
    use std::{future::poll_fn, future::Future, pin::Pin, rc::Rc, task::Poll};

    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;
    use ntex_util::future::lazy;
    use ntex_util::time::{sleep, Millis};

    use super::*;
    use crate::{testing::IoTest, FilterCtx, FilterReadStatus, Io};

    /// Sample request payload as raw bytes.
    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    /// The same payload as `BIN`, as a string.
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    #[ntex::test]
    async fn utils() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write(TEXT);

        let state = Io::new(server);
        assert_eq!(state.get_ref(), state.get_ref());

        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));
        assert_eq!(state.get_ref(), state.as_ref().clone());
        assert!(format!("{state:?}").contains("Io {"));
        assert!(format!("{:?}", state.get_ref()).contains("IoRef {"));

        // nothing is buffered yet, so polling must not yield a frame
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        assert!(res.is_pending());
        client.write(TEXT);
        sleep(Millis(50)).await;
        // timing dependent: only verify the frame if it has already arrived
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        if let Poll::Ready(msg) = res {
            assert_eq!(msg.unwrap(), Bytes::from_static(BIN));
        }

        // a read error surfaces from recv and stops the io
        client.read_error(io::Error::other("err"));
        let msg = state.recv(&BytesCodec).await;
        assert!(msg.is_err());
        assert!(state.flags().contains(Flags::IO_STOPPED));

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::new(server);

        // timing dependent: only verify the error if it has already been observed
        client.read_error(io::Error::other("err"));
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        if let Poll::Ready(msg) = res {
            assert!(msg.is_err());
            assert!(state.flags().contains(Flags::IO_STOPPED));
            assert!(state.flags().contains(Flags::DSP_STOP));
        }

        // plain write / decode round trip
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::new(server);
        state.write(b"test").unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        client.write(b"test");
        state.read_ready().await.unwrap();
        let buf = state.decode(&BytesCodec).unwrap().unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // a write error surfaces from send and stops the io
        client.write_error(io::Error::other("err"));
        let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await;
        assert!(res.is_err());
        assert!(state.flags().contains(Flags::IO_STOPPED));

        // force close sets both dispatcher and io stop flags immediately
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::new(server);
        state.force_close();
        assert!(state.flags().contains(Flags::DSP_STOP));
        assert!(state.flags().contains(Flags::IO_STOPPING));
    }

    #[ntex::test]
    async fn read_readiness() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let io = Io::new(server);
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());

        client.write(TEXT);
        assert_eq!(io.read_ready().await.unwrap(), Some(()));
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());

        let item = io.with_read_buf(|buffer| buffer.split());
        assert_eq!(item, Bytes::from_static(BIN));

        client.write(TEXT);
        sleep(Millis(50)).await;
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_ready());
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());
    }

    #[ntex::test]
    #[allow(clippy::unit_cmp)] // the waiters resolve to `()`; comparing documents that
    async fn on_disconnect() {
        let (client, server) = IoTest::create();
        let state = Io::new(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        let mut waiter2 = waiter.clone();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter2).poll(cx)).await,
            Poll::Pending
        );
        client.close().await;
        assert_eq!(waiter.await, ());
        assert_eq!(waiter2.await, ());

        // a waiter created after disconnect resolves immediately
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Ready(())
        );

        // a read error also counts as disconnect
        let (client, server) = IoTest::create();
        let state = Io::new(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        client.read_error(io::Error::other("err"));
        assert_eq!(waiter.await, ());
    }

    #[ntex::test]
    async fn write_to_closed_io() {
        let (client, server) = IoTest::create();
        let state = Io::new(server);
        client.close().await;

        assert!(state.is_closed());
        assert!(state.write(TEXT.as_bytes()).is_err());
        assert!(state
            .with_write_buf(|buf| buf.extend_from_slice(BIN))
            .is_err());
    }

    /// Test filter that records the order in which layers are invoked and counts
    /// the bytes passing through, then forwards everything to the wrapped layer.
    #[derive(Debug)]
    struct Counter<F> {
        layer: F,
        idx: usize,
        in_bytes: Rc<Cell<usize>>,
        out_bytes: Rc<Cell<usize>>,
        read_order: Rc<RefCell<Vec<usize>>>,
        write_order: Rc<RefCell<Vec<usize>>>,
    }

    impl<F: Filter> Filter for Counter<F> {
        fn process_read_buf(
            &self,
            ctx: FilterCtx<'_>,
            nbytes: usize,
        ) -> io::Result<FilterReadStatus> {
            self.read_order.borrow_mut().push(self.idx);
            self.in_bytes.set(self.in_bytes.get() + nbytes);
            self.layer.process_read_buf(ctx, nbytes)
        }

        fn process_write_buf(&self, ctx: FilterCtx<'_>) -> io::Result<()> {
            self.write_order.borrow_mut().push(self.idx);
            self.out_bytes.set(
                self.out_bytes.get()
                    + ctx.write_buf(|buf| {
                        buf.with_src(|b| b.as_ref().map(|b| b.len()).unwrap_or_default())
                    }),
            );
            self.layer.process_write_buf(ctx)
        }

        crate::forward_ready!(layer);
        crate::forward_query!(layer);
        crate::forward_shutdown!(layer);
    }

    #[ntex::test]
    async fn filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let io = Io::new(server).map_filter(|layer| Counter {
            layer,
            idx: 1,
            in_bytes: in_bytes.clone(),
            out_bytes: out_bytes.clone(),
            read_order: read_order.clone(),
            write_order: write_order.clone(),
        });

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        assert_eq!(in_bytes.get(), BIN.len());
        assert_eq!(out_bytes.get(), 4);
    }

    #[ntex::test]
    async fn boxed_filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let state = Io::new(server)
            .map_filter(|layer| Counter {
                layer,
                idx: 2,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            })
            .map_filter(|layer| Counter {
                layer,
                idx: 1,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            });
        let state = state.seal();

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        state
            .send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // both layers see every byte in each direction
        assert_eq!(in_bytes.get(), BIN.len() * 2);
        assert_eq!(out_bytes.get(), 8);
        assert_eq!(
            state.with_write_dest_buf(|b| b.map(|b| b.len()).unwrap_or(0)),
            0
        );

        // refs: dropping the sealed io must release both filter layers
        assert_eq!(Rc::strong_count(&in_bytes), 3);
        drop(state);
        assert_eq!(Rc::strong_count(&in_bytes), 1);
        assert_eq!(*read_order.borrow(), &[1, 2][..]);
        assert_eq!(*write_order.borrow(), &[1, 2][..]);
    }
}