ntex_io/
ioref.rs

1use std::{any, fmt, hash, io};
2
3use ntex_bytes::BytesVec;
4use ntex_codec::{Decoder, Encoder};
5use ntex_service::cfg::SharedCfg;
6use ntex_util::time::Seconds;
7
8use crate::{
9    Decoded, Filter, FilterCtx, Flags, IoConfig, IoRef, OnDisconnect, WriteBuf, timer,
10    types,
11};
12
13impl IoRef {
14    #[inline]
15    /// Get tag
16    pub fn tag(&self) -> &'static str {
17        self.0.cfg.get().tag()
18    }
19
20    #[inline]
21    #[doc(hidden)]
22    /// Get current state flags
23    pub fn flags(&self) -> Flags {
24        self.0.flags.get()
25    }
26
27    #[inline]
28    /// Get current filter
29    pub(crate) fn filter(&self) -> &dyn Filter {
30        self.0.filter()
31    }
32
33    #[inline]
34    /// Get configuration
35    pub fn cfg(&self) -> &IoConfig {
36        self.0.cfg.get()
37    }
38
39    #[inline]
40    /// Get shared configuration
41    pub fn shared(&self) -> SharedCfg {
42        self.0.cfg.get().config.shared()
43    }
44
45    #[inline]
46    /// Check if io stream is closed
47    pub fn is_closed(&self) -> bool {
48        self.0
49            .flags
50            .get()
51            .intersects(Flags::IO_STOPPING | Flags::IO_STOPPED)
52    }
53
54    #[inline]
55    /// Check if write back-pressure is enabled
56    pub fn is_wr_backpressure(&self) -> bool {
57        self.0.flags.get().contains(Flags::BUF_W_BACKPRESSURE)
58    }
59
60    #[inline]
61    /// Wake dispatcher task
62    pub fn wake(&self) {
63        self.0.dispatch_task.wake();
64    }
65
66    #[inline]
67    /// Gracefully close connection
68    ///
69    /// Notify dispatcher and initiate io stream shutdown process.
70    pub fn close(&self) {
71        self.0.insert_flags(Flags::DSP_STOP);
72        self.0.init_shutdown();
73    }
74
75    #[inline]
76    /// Force close connection
77    ///
78    /// Dispatcher does not wait for uncompleted responses. Io stream get terminated
79    /// without any graceful period.
80    pub fn force_close(&self) {
81        log::trace!("{}: Force close io stream object", self.tag());
82        self.0.insert_flags(
83            Flags::DSP_STOP
84                | Flags::IO_STOPPED
85                | Flags::IO_STOPPING
86                | Flags::IO_STOPPING_FILTERS,
87        );
88        self.0.read_task.wake();
89        self.0.write_task.wake();
90        self.0.dispatch_task.wake();
91    }
92
93    #[inline]
94    /// Gracefully shutdown io stream
95    pub fn want_shutdown(&self) {
96        if !self
97            .0
98            .flags
99            .get()
100            .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
101        {
102            log::trace!(
103                "{}: Initiate io shutdown {:?}",
104                self.tag(),
105                self.0.flags.get()
106            );
107            self.0.insert_flags(Flags::IO_STOPPING_FILTERS);
108            self.0.read_task.wake();
109        }
110    }
111
112    #[inline]
113    /// Query filter specific data
114    pub fn query<T: 'static>(&self) -> types::QueryItem<T> {
115        if let Some(item) = self.filter().query(any::TypeId::of::<T>()) {
116            types::QueryItem::new(item)
117        } else {
118            types::QueryItem::empty()
119        }
120    }
121
122    #[inline]
123    /// Encode and write item to a buffer and wake up write task
124    pub fn encode<U>(&self, item: U::Item, codec: &U) -> Result<(), <U as Encoder>::Error>
125    where
126        U: Encoder,
127    {
128        if !self.is_closed() {
129            self.with_write_buf(|buf| {
130                // make sure we've got room
131                self.cfg().write_buf().resize(buf);
132
133                // encode item and wake write task
134                codec.encode_vec(item, buf)
135            })
136            // .with_write_buf() could return io::Error<Result<(), U::Error>>,
137            // in that case mark io as failed
138            .unwrap_or_else(|err| {
139                log::trace!(
140                    "{}: Got io error while encoding, error: {:?}",
141                    self.tag(),
142                    err
143                );
144                self.0.io_stopped(Some(err));
145                Ok(())
146            })
147        } else {
148            log::trace!("{}: Io is closed/closing, skip frame encoding", self.tag());
149            Ok(())
150        }
151    }
152
153    #[inline]
154    /// Attempts to decode a frame from the read buffer
155    pub fn decode<U>(
156        &self,
157        codec: &U,
158    ) -> Result<Option<<U as Decoder>::Item>, <U as Decoder>::Error>
159    where
160        U: Decoder,
161    {
162        self.0
163            .buffer
164            .with_read_destination(self, |buf| codec.decode_vec(buf))
165    }
166
167    #[inline]
168    /// Attempts to decode a frame from the read buffer
169    pub fn decode_item<U>(
170        &self,
171        codec: &U,
172    ) -> Result<Decoded<<U as Decoder>::Item>, <U as Decoder>::Error>
173    where
174        U: Decoder,
175    {
176        self.0.buffer.with_read_destination(self, |buf| {
177            let len = buf.len();
178            codec.decode_vec(buf).map(|item| Decoded {
179                item,
180                remains: buf.len(),
181                consumed: len - buf.len(),
182            })
183        })
184    }
185
186    #[inline]
187    /// Write bytes to a buffer and wake up write task
188    pub fn write(&self, src: &[u8]) -> io::Result<()> {
189        self.with_write_buf(|buf| buf.extend_from_slice(src))
190    }
191
192    #[inline]
193    /// Get access to write buffer
194    pub fn with_buf<F, R>(&self, f: F) -> io::Result<R>
195    where
196        F: FnOnce(&WriteBuf<'_>) -> R,
197    {
198        let ctx = FilterCtx::new(self, &self.0.buffer);
199        let result = ctx.write_buf(f);
200        self.0.filter().process_write_buf(ctx)?;
201        Ok(result)
202    }
203
204    #[inline]
205    /// Get mut access to source write buffer
206    pub fn with_write_buf<F, R>(&self, f: F) -> io::Result<R>
207    where
208        F: FnOnce(&mut BytesVec) -> R,
209    {
210        if self.0.flags.get().contains(Flags::IO_STOPPED) {
211            Err(self.0.error_or_disconnected())
212        } else {
213            let result = self.0.buffer.with_write_source(self, f);
214            self.0
215                .filter()
216                .process_write_buf(FilterCtx::new(self, &self.0.buffer))?;
217            Ok(result)
218        }
219    }
220
221    #[doc(hidden)]
222    #[inline]
223    /// Get mut access to destination write buffer
224    pub fn with_write_dest_buf<F, R>(&self, f: F) -> R
225    where
226        F: FnOnce(Option<&mut BytesVec>) -> R,
227    {
228        self.0.buffer.with_write_destination(self, f)
229    }
230
231    #[inline]
232    /// Get mut access to source read buffer
233    pub fn with_read_buf<F, R>(&self, f: F) -> R
234    where
235        F: FnOnce(&mut BytesVec) -> R,
236    {
237        self.0.buffer.with_read_destination(self, f)
238    }
239
240    #[inline]
241    /// Wakeup dispatcher
242    pub fn notify_dispatcher(&self) {
243        self.0.dispatch_task.wake();
244        log::trace!("{}: Timer, notify dispatcher", self.tag());
245    }
246
247    #[inline]
248    /// Wakeup dispatcher and send keep-alive error
249    pub fn notify_timeout(&self) {
250        self.0.notify_timeout()
251    }
252
253    #[inline]
254    /// current timer handle
255    pub fn timer_handle(&self) -> timer::TimerHandle {
256        self.0.timeout.get()
257    }
258
259    #[inline]
260    /// Start timer
261    pub fn start_timer(&self, timeout: Seconds) -> timer::TimerHandle {
262        let cur_hnd = self.0.timeout.get();
263
264        if !timeout.is_zero() {
265            if cur_hnd.is_set() {
266                let hnd = timer::update(cur_hnd, timeout, self);
267                if hnd != cur_hnd {
268                    log::trace!("{}: Update timer {:?}", self.tag(), timeout);
269                    self.0.timeout.set(hnd);
270                }
271                hnd
272            } else {
273                log::trace!("{}: Start timer {:?}", self.tag(), timeout);
274                let hnd = timer::register(timeout, self);
275                self.0.timeout.set(hnd);
276                hnd
277            }
278        } else {
279            if cur_hnd.is_set() {
280                self.0.timeout.set(timer::TimerHandle::ZERO);
281                timer::unregister(cur_hnd, self);
282            }
283            timer::TimerHandle::ZERO
284        }
285    }
286
287    #[inline]
288    /// Stop timer
289    pub fn stop_timer(&self) {
290        let hnd = self.0.timeout.get();
291        if hnd.is_set() {
292            log::trace!("{}: Stop timer", self.tag());
293            self.0.timeout.set(timer::TimerHandle::ZERO);
294            timer::unregister(hnd, self)
295        }
296    }
297
298    #[inline]
299    /// Notify when io stream get disconnected
300    pub fn on_disconnect(&self) -> OnDisconnect {
301        OnDisconnect::new(self.0.clone())
302    }
303}
304
305impl Eq for IoRef {}
306
307impl PartialEq for IoRef {
308    #[inline]
309    fn eq(&self, other: &Self) -> bool {
310        self.0.eq(&other.0)
311    }
312}
313
314impl hash::Hash for IoRef {
315    #[inline]
316    fn hash<H: hash::Hasher>(&self, state: &mut H) {
317        self.0.hash(state);
318    }
319}
320
321impl fmt::Debug for IoRef {
322    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
323        f.debug_struct("IoRef")
324            .field("state", self.0.as_ref())
325            .finish()
326    }
327}
328
#[cfg(test)]
mod tests {
    use std::cell::{Cell, RefCell};
    use std::{future::Future, future::poll_fn, pin::Pin, rc::Rc, task::Poll};

    use ntex_bytes::Bytes;
    use ntex_codec::BytesCodec;
    use ntex_util::future::lazy;
    use ntex_util::time::{Millis, sleep};

    use super::*;
    use crate::{FilterCtx, FilterReadStatus, Io, testing::IoTest};

    const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
    const TEXT: &str = "GET /test HTTP/1\r\n\r\n";

    /// Exercises the basic helpers: equality, `Debug` output, framed reads and
    /// writes, and how read/write errors and `force_close` update the flags.
    #[ntex::test]
    async fn utils() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        client.write(TEXT);

        let state = Io::from(server);
        assert_eq!(state.get_ref(), state.get_ref());

        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));
        assert_eq!(state.get_ref(), state.as_ref().clone());
        assert!(format!("{state:?}").contains("Io {"));
        assert!(format!("{:?}", state.get_ref()).contains("IoRef {"));

        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        assert!(res.is_pending());
        client.write(TEXT);
        sleep(Millis(50)).await;
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        // A pending result is tolerated here; the frame is only checked when ready.
        if let Poll::Ready(msg) = res {
            assert_eq!(msg.unwrap(), Bytes::from_static(BIN));
        }

        client.read_error(io::Error::other("err"));
        let msg = state.recv(&BytesCodec).await;
        assert!(msg.is_err());
        assert!(state.flags().contains(Flags::IO_STOPPED));

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);

        client.read_error(io::Error::other("err"));
        let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
        if let Poll::Ready(msg) = res {
            assert!(msg.is_err());
            assert!(state.flags().contains(Flags::IO_STOPPED));
            assert!(state.flags().contains(Flags::DSP_STOP));
        }

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);
        state.write(b"test").unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        client.write(b"test");
        state.read_ready().await.unwrap();
        let buf = state.decode(&BytesCodec).unwrap().unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        client.write_error(io::Error::other("err"));
        let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await;
        assert!(res.is_err());
        assert!(state.flags().contains(Flags::IO_STOPPED));

        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);
        let state = Io::from(server);
        state.force_close();
        assert!(state.flags().contains(Flags::DSP_STOP));
        assert!(state.flags().contains(Flags::IO_STOPPING));
    }

    /// Read readiness is reported once per batch of incoming data.
    #[ntex::test]
    async fn read_readiness() {
        let (client, server) = IoTest::create();
        client.remote_buffer_cap(1024);

        let io = Io::from(server);
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());

        client.write(TEXT);
        assert_eq!(io.read_ready().await.unwrap(), Some(()));
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());

        let item = io.with_read_buf(|buffer| buffer.split());
        assert_eq!(item, Bytes::from_static(BIN));

        client.write(TEXT);
        sleep(Millis(50)).await;
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_ready());
        assert!(lazy(|cx| io.poll_read_ready(cx)).await.is_pending());
    }

    /// `on_disconnect` futures (and their clones) resolve on peer close and on
    /// read errors, and resolve immediately once the io is already disconnected.
    #[ntex::test]
    // `assert_eq!(fut.await, ())` deliberately compares unit values: it asserts
    // that the disconnect future completes.
    #[allow(clippy::unit_cmp)]
    async fn on_disconnect() {
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        let mut waiter2 = waiter.clone();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter2).poll(cx)).await,
            Poll::Pending
        );
        client.close().await;
        assert_eq!(waiter.await, ());
        assert_eq!(waiter2.await, ());

        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Ready(())
        );

        let (client, server) = IoTest::create();
        let state = Io::from(server);
        let mut waiter = state.on_disconnect();
        assert_eq!(
            lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
            Poll::Pending
        );
        client.read_error(io::Error::other("err"));
        assert_eq!(waiter.await, ());
    }

    /// Writes into an io whose peer has closed must fail.
    #[ntex::test]
    async fn write_to_closed_io() {
        let (client, server) = IoTest::create();
        let state = Io::from(server);
        client.close().await;

        assert!(state.is_closed());
        assert!(state.write(TEXT.as_bytes()).is_err());
        assert!(
            state
                .with_write_buf(|buf| buf.extend_from_slice(BIN))
                .is_err()
        );
    }

    /// Test filter that wraps another filter and records, for every call, its
    /// own `idx` in the read/write order logs plus the number of bytes seen.
    #[derive(Debug)]
    struct Counter<F> {
        layer: F,
        idx: usize,
        in_bytes: Rc<Cell<usize>>,
        out_bytes: Rc<Cell<usize>>,
        read_order: Rc<RefCell<Vec<usize>>>,
        write_order: Rc<RefCell<Vec<usize>>>,
    }

    impl<F: Filter> Filter for Counter<F> {
        fn process_read_buf(
            &self,
            ctx: FilterCtx<'_>,
            nbytes: usize,
        ) -> io::Result<FilterReadStatus> {
            self.read_order.borrow_mut().push(self.idx);
            self.in_bytes.set(self.in_bytes.get() + nbytes);
            self.layer.process_read_buf(ctx, nbytes)
        }

        fn process_write_buf(&self, ctx: FilterCtx<'_>) -> io::Result<()> {
            self.write_order.borrow_mut().push(self.idx);
            self.out_bytes.set(
                self.out_bytes.get()
                    + ctx.write_buf(|buf| {
                        buf.with_src(|b| b.as_ref().map(|b| b.len()).unwrap_or_default())
                    }),
            );
            self.layer.process_write_buf(ctx)
        }

        crate::forward_ready!(layer);
        crate::forward_query!(layer);
        crate::forward_shutdown!(layer);
    }

    /// A single mapped filter sees all inbound and outbound bytes.
    #[ntex::test]
    async fn filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let io = Io::from(server).map_filter(|layer| Counter {
            layer,
            idx: 1,
            in_bytes: in_bytes.clone(),
            out_bytes: out_bytes.clone(),
            read_order: read_order.clone(),
            write_order: write_order.clone(),
        });

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        io.send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        assert_eq!(in_bytes.get(), BIN.len());
        assert_eq!(out_bytes.get(), 4);
    }

    /// Two stacked filters on a sealed io: both see the traffic, the outer
    /// layer (idx 1) runs before the inner one (idx 2), and dropping the io
    /// releases the filters' shared counters.
    #[ntex::test]
    async fn boxed_filter() {
        let in_bytes = Rc::new(Cell::new(0));
        let out_bytes = Rc::new(Cell::new(0));
        let read_order = Rc::new(RefCell::new(Vec::new()));
        let write_order = Rc::new(RefCell::new(Vec::new()));

        let (client, server) = IoTest::create();
        let state = Io::from(server)
            .map_filter(|layer| Counter {
                layer,
                idx: 2,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            })
            .map_filter(|layer| Counter {
                layer,
                idx: 1,
                in_bytes: in_bytes.clone(),
                out_bytes: out_bytes.clone(),
                read_order: read_order.clone(),
                write_order: write_order.clone(),
            });
        let state = state.seal();

        client.remote_buffer_cap(1024);
        client.write(TEXT);
        let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(BIN));

        state
            .send(Bytes::from_static(b"test"), &BytesCodec)
            .await
            .unwrap();
        let buf = client.read().await.unwrap();
        assert_eq!(buf, Bytes::from_static(b"test"));

        assert_eq!(in_bytes.get(), BIN.len() * 2);
        assert_eq!(out_bytes.get(), 8);
        assert_eq!(
            state.with_write_dest_buf(|b| b.map(|b| b.len()).unwrap_or(0)),
            0
        );

        // refs
        assert_eq!(Rc::strong_count(&in_bytes), 3);
        drop(state);
        assert_eq!(Rc::strong_count(&in_bytes), 1);
        assert_eq!(*read_order.borrow(), &[1, 2][..]);
        assert_eq!(*write_order.borrow(), &[1, 2][..]);
    }
}