// ntex_io/ioref.rs

1use std::{any, fmt, hash, io};
2
3use ntex_bytes::{BytesVec, PoolRef};
4use ntex_codec::{Decoder, Encoder};
5use ntex_util::time::Seconds;
6
7use crate::{timer, types, Decoded, Filter, Flags, IoRef, OnDisconnect, WriteBuf};
8
impl IoRef {
    #[inline]
    #[doc(hidden)]
    /// Get current state flags
    pub fn flags(&self) -> Flags {
        self.0.flags.get()
    }

    #[inline]
    /// Get current filter
    pub(crate) fn filter(&self) -> &dyn Filter {
        self.0.filter()
    }

    #[inline]
    /// Get memory pool
    pub fn memory_pool(&self) -> PoolRef {
        self.0.pool.get()
    }

    #[inline]
    /// Check if io stream is closed
    ///
    /// Returns `true` once shutdown has been initiated (`IO_STOPPING`)
    /// or has already completed (`IO_STOPPED`).
    pub fn is_closed(&self) -> bool {
        self.0
            .flags
            .get()
            .intersects(Flags::IO_STOPPING | Flags::IO_STOPPED)
    }

    #[inline]
    /// Check if write back-pressure is enabled
    pub fn is_wr_backpressure(&self) -> bool {
        self.0.flags.get().contains(Flags::BUF_W_BACKPRESSURE)
    }

    #[inline]
    /// Wake dispatcher task
    pub fn wake(&self) {
        self.0.dispatch_task.wake();
    }

    #[inline]
    /// Gracefully close connection
    ///
    /// Notify dispatcher and initiate io stream shutdown process.
    pub fn close(&self) {
        self.0.insert_flags(Flags::DSP_STOP);
        self.0.init_shutdown();
    }

    #[inline]
    /// Force close connection
    ///
    /// Dispatcher does not wait for uncompleted responses. Io stream get terminated
    /// without any graceful period.
    pub fn force_close(&self) {
        log::trace!("{}: Force close io stream object", self.tag());
        self.0.insert_flags(
            Flags::DSP_STOP
                | Flags::IO_STOPPED
                | Flags::IO_STOPPING
                | Flags::IO_STOPPING_FILTERS,
        );
        // wake every task so each observes the stop flags immediately
        self.0.read_task.wake();
        self.0.write_task.wake();
        self.0.dispatch_task.wake();
    }

    #[inline]
    /// Gracefully shutdown io stream
    ///
    /// No-op if shutdown is already in progress or the stream is stopped.
    pub fn want_shutdown(&self) {
        if !self
            .0
            .flags
            .get()
            .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
        {
            log::trace!(
                "{}: Initiate io shutdown {:?}",
                self.tag(),
                self.0.flags.get()
            );
            self.0.insert_flags(Flags::IO_STOPPING_FILTERS);
            self.0.read_task.wake();
        }
    }

    #[inline]
    /// Query filter specific data
    ///
    /// Returns an empty item when the filter has no data for type `T`.
    pub fn query<T: 'static>(&self) -> types::QueryItem<T> {
        if let Some(item) = self.filter().query(any::TypeId::of::<T>()) {
            types::QueryItem::new(item)
        } else {
            types::QueryItem::empty()
        }
    }

    #[inline]
    /// Encode and write item to a buffer and wake up write task
    ///
    /// # Errors
    ///
    /// Returns the codec's error if encoding the item fails. An io error
    /// from the write path is *not* returned to the caller: it marks the
    /// stream as stopped and `Ok(())` is returned instead. Encoding is
    /// skipped entirely (returning `Ok(())`) when the stream is already
    /// closed or closing.
    pub fn encode<U>(&self, item: U::Item, codec: &U) -> Result<(), <U as Encoder>::Error>
    where
        U: Encoder,
    {
        if !self.is_closed() {
            self.with_write_buf(|buf| {
                // make sure we've got room
                self.memory_pool().resize_write_buf(buf);

                // encode item and wake write task
                codec.encode_vec(item, buf)
            })
            // .with_write_buf() could return io::Error<Result<(), U::Error>>,
            // in that case mark io as failed
            .unwrap_or_else(|err| {
                log::trace!(
                    "{}: Got io error while encoding, error: {:?}",
                    self.tag(),
                    err
                );
                self.0.io_stopped(Some(err));
                Ok(())
            })
        } else {
            log::trace!("{}: Io is closed/closing, skip frame encoding", self.tag());
            Ok(())
        }
    }

    #[inline]
    /// Attempts to decode a frame from the read buffer
    pub fn decode<U>(
        &self,
        codec: &U,
    ) -> Result<Option<<U as Decoder>::Item>, <U as Decoder>::Error>
    where
        U: Decoder,
    {
        self.0
            .buffer
            .with_read_destination(self, |buf| codec.decode_vec(buf))
    }

    #[inline]
    /// Attempts to decode a frame from the read buffer
    ///
    /// Unlike `decode`, the returned `Decoded` also reports how many bytes
    /// the codec consumed and how many remain buffered.
    pub fn decode_item<U>(
        &self,
        codec: &U,
    ) -> Result<Decoded<<U as Decoder>::Item>, <U as Decoder>::Error>
    where
        U: Decoder,
    {
        self.0.buffer.with_read_destination(self, |buf| {
            // buffer length before decoding; the difference is what was consumed
            let len = buf.len();
            codec.decode_vec(buf).map(|item| Decoded {
                item,
                remains: buf.len(),
                consumed: len - buf.len(),
            })
        })
    }

    #[inline]
    /// Write bytes to a buffer and wake up write task
    ///
    /// # Errors
    ///
    /// Fails under the same conditions as `with_write_buf`: the stream is
    /// stopped, or the filter fails while processing the write buffer.
    pub fn write(&self, src: &[u8]) -> io::Result<()> {
        self.with_write_buf(|buf| buf.extend_from_slice(src))
    }

    #[inline]
    /// Get access to write buffer
    ///
    /// Runs `f` against the write buffer, then lets the filter process the
    /// buffer. If the filter fails, its error is returned and `f`'s result
    /// is dropped.
    pub fn with_buf<F, R>(&self, f: F) -> io::Result<R>
    where
        F: FnOnce(&WriteBuf<'_>) -> R,
    {
        let result = self.0.buffer.write_buf(self, 0, f);
        self.0.filter().process_write_buf(self, &self.0.buffer, 0)?;
        Ok(result)
    }

    #[inline]
    /// Get mut access to source write buffer
    ///
    /// # Errors
    ///
    /// Fails immediately when the stream is stopped (`IO_STOPPED`), or if
    /// the filter fails while processing the write buffer afterwards.
    pub fn with_write_buf<F, R>(&self, f: F) -> io::Result<R>
    where
        F: FnOnce(&mut BytesVec) -> R,
    {
        if self.0.flags.get().contains(Flags::IO_STOPPED) {
            Err(self.0.error_or_disconnected())
        } else {
            let result = self.0.buffer.with_write_source(self, f);
            self.0.filter().process_write_buf(self, &self.0.buffer, 0)?;
            Ok(result)
        }
    }

    #[doc(hidden)]
    #[inline]
    /// Get mut access to destination write buffer
    pub fn with_write_dest_buf<F, R>(&self, f: F) -> R
    where
        F: FnOnce(Option<&mut BytesVec>) -> R,
    {
        self.0.buffer.with_write_destination(self, f)
    }

    #[inline]
    /// Get mut access to source read buffer
    // NOTE(review): despite "source" in the doc, this delegates to
    // `with_read_destination` — the same buffer `decode` reads from; confirm
    // whether the doc or the call is the intended one.
    pub fn with_read_buf<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut BytesVec) -> R,
    {
        self.0.buffer.with_read_destination(self, f)
    }

    #[inline]
    /// Wakeup dispatcher
    pub fn notify_dispatcher(&self) {
        self.0.dispatch_task.wake();
        log::trace!("{}: Timer, notify dispatcher", self.tag());
    }

    #[inline]
    /// Wakeup dispatcher and send keep-alive error
    pub fn notify_timeout(&self) {
        self.0.notify_timeout()
    }

    #[inline]
    /// Get current timer handle
    pub fn timer_handle(&self) -> timer::TimerHandle {
        self.0.timeout.get()
    }

    #[inline]
    /// Start timer
    ///
    /// Registers a new timer, or updates the currently active one. A zero
    /// `timeout` cancels any active timer instead, and `TimerHandle::ZERO`
    /// is returned.
    pub fn start_timer(&self, timeout: Seconds) -> timer::TimerHandle {
        let cur_hnd = self.0.timeout.get();

        if !timeout.is_zero() {
            if cur_hnd.is_set() {
                let hnd = timer::update(cur_hnd, timeout, self);
                if hnd != cur_hnd {
                    log::debug!("{}: Update timer {:?}", self.tag(), timeout);
                    self.0.timeout.set(hnd);
                }
                hnd
            } else {
                log::debug!("{}: Start timer {:?}", self.tag(), timeout);
                let hnd = timer::register(timeout, self);
                self.0.timeout.set(hnd);
                hnd
            }
        } else {
            if cur_hnd.is_set() {
                self.0.timeout.set(timer::TimerHandle::ZERO);
                timer::unregister(cur_hnd, self);
            }
            timer::TimerHandle::ZERO
        }
    }

    #[inline]
    /// Stop timer
    ///
    /// No-op when no timer is currently registered.
    pub fn stop_timer(&self) {
        let hnd = self.0.timeout.get();
        if hnd.is_set() {
            log::debug!("{}: Stop timer", self.tag());
            self.0.timeout.set(timer::TimerHandle::ZERO);
            timer::unregister(hnd, self)
        }
    }

    #[inline]
    /// Get tag
    pub fn tag(&self) -> &'static str {
        self.0.tag.get()
    }

    #[inline]
    /// Set tag
    pub fn set_tag(&self, tag: &'static str) {
        self.0.tag.set(tag)
    }

    #[inline]
    /// Notify when io stream get disconnected
    pub fn on_disconnect(&self) -> OnDisconnect {
        OnDisconnect::new(self.0.clone())
    }
}
297
298impl Eq for IoRef {}
299
impl PartialEq for IoRef {
    #[inline]
    /// Two `IoRef`s compare equal when their inner shared state does.
    fn eq(&self, other: &Self) -> bool {
        self.0.eq(&other.0)
    }
}
306
impl hash::Hash for IoRef {
    #[inline]
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        // delegate to the inner shared state so Hash stays consistent with Eq
        self.0.hash(state);
    }
}
313
314impl fmt::Debug for IoRef {
315    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
316        f.debug_struct("IoRef")
317            .field("state", self.0.as_ref())
318            .finish()
319    }
320}
321
#[cfg(test)]
mod tests {
    use std::cell::{Cell, RefCell};
    use std::{future::poll_fn, future::Future, pin::Pin, rc::Rc, task::Poll};

    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;
    use ntex_util::future::lazy;
    use ntex_util::time::{sleep, Millis};

    use super::*;
    use crate::{testing::IoTest, FilterLayer, Io, ReadBuf};

    // same payload as bytes and as &str, for write()/read() convenience
    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    #[ntex::test]
    async fn utils() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write(TEXT);

        let state = Io::new(server);
        assert_eq!(state.get_ref(), state.get_ref());

        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));
        assert_eq!(state.get_ref(), state.as_ref().clone());
        // Debug impls for Io and IoRef
        assert!(format!("{:?}", state).find("Io {").is_some());
        assert!(format!("{:?}", state.get_ref()).find("IoRef {").is_some());

        // poll_recv: pending with no data, ready after the peer writes
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        assert!(res.is_pending());
        client.write(TEXT);
        sleep(Millis(50)).await;
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        if let Poll::Ready(msg) = res {
            assert_eq!(msg.unwrap(), Bytes::from_static(BIN));
        }

        // read error marks the stream as stopped
        client.read_error(io::Error::new(io::ErrorKind::Other, "err"));
        let msg = state.recv(&BytesCodec).await;
        assert!(msg.is_err());
        assert!(state.flags().contains(Flags::IO_STOPPED));

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::new(server);

        client.read_error(io::Error::new(io::ErrorKind::Other, "err"));
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        if let Poll::Ready(msg) = res {
            assert!(msg.is_err());
            assert!(state.flags().contains(Flags::IO_STOPPED));
            assert!(state.flags().contains(Flags::DSP_STOP));
        }

        // write()/decode() round-trip
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::new(server);
        state.write(b"test").unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        client.write(b"test");
        state.read_ready().await.unwrap();
        let buf = state.decode(&BytesCodec).unwrap().unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // write error surfaces from send() and stops the stream
        client.write_error(io::Error::new(io::ErrorKind::Other, "err"));
        let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await;
        assert!(res.is_err());
        assert!(state.flags().contains(Flags::IO_STOPPED));

        // force_close sets both dispatcher-stop and io-stopping flags
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::new(server);
        state.force_close();
        assert!(state.flags().contains(Flags::DSP_STOP));
        assert!(state.flags().contains(Flags::IO_STOPPING));
    }

    #[ntex::test]
    async fn read_readiness() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let io = Io::new(server);
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());

        // readiness fires once per batch of new data
        client.write(TEXT);
        assert_eq!(io.read_ready().await.unwrap(), Some(()));
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());

        let item = io.with_read_buf(|buffer| buffer.split());
        assert_eq!(item, Bytes::from_static(BIN));

        client.write(TEXT);
        sleep(Millis(50)).await;
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_ready());
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());
    }

    #[ntex::test]
    #[allow(clippy::unit_cmp)]
    async fn on_disconnect() {
        let (client, server) = IoTest::create();
        let state = Io::new(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        // clones of a waiter resolve independently
        let mut waiter2 = waiter.clone();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter2).poll(cx)).await,
            Poll::Pending
        );
        client.close().await;
        assert_eq!(waiter.await, ());
        assert_eq!(waiter2.await, ());

        // a waiter created after disconnect resolves immediately
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Ready(())
        );

        // read errors also resolve disconnect waiters
        let (client, server) = IoTest::create();
        let state = Io::new(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        client.read_error(io::Error::new(io::ErrorKind::Other, "err"));
        assert_eq!(waiter.await, ());
    }

    #[ntex::test]
    async fn write_to_closed_io() {
        let (client, server) = IoTest::create();
        let state = Io::new(server);
        client.close().await;

        // both write paths must fail once the peer is gone
        assert!(state.is_closed());
        assert!(state.write(TEXT.as_bytes()).is_err());
        assert!(state
            .with_write_buf(|buf| buf.extend_from_slice(BIN))
            .is_err());
    }

    /// Test filter layer that counts bytes passing through and records the
    /// order in which layers are traversed on the read and write paths.
    #[derive(Debug)]
    struct Counter {
        idx: usize,
        in_bytes: Rc<Cell<usize>>,
        out_bytes: Rc<Cell<usize>>,
        read_order: Rc<RefCell<Vec<usize>>>,
        write_order: Rc<RefCell<Vec<usize>>>,
    }

    impl FilterLayer for Counter {
        const BUFFERS: bool = false;

        fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
            // record traversal order and accumulate incoming byte count
            self.read_order.borrow_mut().push(self.idx);
            self.in_bytes.set(self.in_bytes.get() + buf.nbytes());
            Ok(buf.nbytes())
        }

        fn process_write_buf(&self, buf: &WriteBuf<'_>) -> io::Result<()> {
            // record traversal order and accumulate outgoing byte count
            self.write_order.borrow_mut().push(self.idx);
            self.out_bytes
                .set(self.out_bytes.get() + buf.with_dst(|b| b.len()));
            Ok(())
        }
    }

    #[ntex::test]
    async fn filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let counter = Counter {
            idx: 1,
            in_bytes: in_bytes.clone(),
            out_bytes: out_bytes.clone(),
            read_order: read_order.clone(),
            write_order: write_order.clone(),
        };
        let _ = format!("{:?}", counter);
        let io = Io::new(server).add_filter(counter);

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // a single filter layer sees each byte exactly once
        assert_eq!(in_bytes.get(), BIN.len());
        assert_eq!(out_bytes.get(), 4);
    }

    #[ntex::test]
    async fn boxed_filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let state = Io::new(server)
            .add_filter(Counter {
                idx: 1,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            })
            .add_filter(Counter {
                idx: 2,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            });
        let state = state.seal();

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        state
            .send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        // two layers: every byte is counted twice, and the destination
        // write buffer must be fully drained afterwards
        assert_eq!(in_bytes.get(), BIN.len() * 2);
        assert_eq!(out_bytes.get(), 8);
        assert_eq!(
            state.with_write_dest_buf(|b| b.map(|b| b.len()).unwrap_or(0)),
            0
        );

        // refs
        assert_eq!(Rc::strong_count(&in_bytes), 3);
        drop(state);
        assert_eq!(Rc::strong_count(&in_bytes), 1);
        // read path runs layers outside-in, write path inside-out
        assert_eq!(*read_order.borrow(), &[1, 2][..]);
        assert_eq!(*write_order.borrow(), &[2, 1][..]);
    }
}