Skip to main content

ntex_io/
ctx.rs

1use std::{fmt, io, task::Context, task::Poll};
2
3use ntex_bytes::{BytePages, BytesMut};
4use ntex_util::time::sleep;
5
6use crate::{Flags, Id, IoRef, IoTaskStatus, Readiness, io::IoState};
7
8/// Context for io read task
9pub struct IoContext(IoRef);
10
11impl fmt::Debug for IoContext {
12    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
13        f.debug_struct("IoContext").field("io", &self.0).finish()
14    }
15}
16
17impl IoContext {
18    pub(crate) fn new(io: IoRef) -> Self {
19        Self(io)
20    }
21
22    fn st(&self) -> &IoState {
23        &self.0.0
24    }
25
26    #[doc(hidden)]
27    #[inline]
28    pub fn id(&self) -> Id {
29        self.0.id()
30    }
31
32    #[inline]
33    /// Gets the I/O tag.
34    pub fn tag(&self) -> &'static str {
35        self.0.tag()
36    }
37
38    #[doc(hidden)]
39    /// Gets the flags.
40    pub fn flags(&self) -> Flags {
41        self.0.flags()
42    }
43
44    #[inline]
45    /// Checks readiness for read operations.
46    pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Readiness> {
47        self.shutdown_filters(cx);
48        self.0.filter().poll_read_ready(cx)
49    }
50
51    #[inline]
52    /// Checks readiness for write operations.
53    pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<Readiness> {
54        self.0.filter().poll_write_ready(cx)
55    }
56
57    /// Stops the I/O stream.
58    pub fn stop(&self, e: Option<io::Error>) {
59        self.st().terminate_connection(e);
60    }
61
62    /// Checks if the I/O stream is stopped.
63    pub fn is_stopped(&self) -> bool {
64        self.st().flags.is_closed()
65    }
66
67    /// Gets the read buffer.
68    pub fn get_read_buf(&self) -> BytesMut {
69        let st = self.st();
70
71        if st.flags.is_read_ready() {
72            // The dispatcher has not consumed the read buffer yet,
73            // so we must not modify it.
74            st.get_read_buf()
75        } else if let Some(mut buf) = st.buffer.get_read_buf() {
76            self.0.resize_read_buf(&mut buf);
77            buf
78        } else {
79            st.get_read_buf()
80        }
81    }
82
83    /// Resizes the read buffer.
84    pub fn resize_read_buf(&self, buf: &mut BytesMut) {
85        self.0.resize_read_buf(buf);
86    }
87
88    /// Updates the read status.
89    ///
90    /// Returns `Ok(Some(buf))` containing the read buffer.
91    /// `Ok(None)` indicates that the connection has been disconnected.
92    pub fn update_read_status(
93        &self,
94        buf: BytesMut,
95        status: io::Result<usize>,
96    ) -> IoTaskStatus {
97        let st = self.st();
98        let orig = st.buffer.read_dst_size();
99
100        #[cfg(feature = "trace")]
101        log::trace!(
102            "{}: read-status == {status:?} orig:{orig:?} flags:{:?}",
103            st.tag(),
104            st.flags
105        );
106
107        // release read buffer
108        st.buffer.set_read_buf(buf, self.0.cfg());
109
110        // process read buf
111        let result = status.and_then(|nbytes| {
112            if nbytes == 0 {
113                return Ok(());
114            }
115            st.buffer.process_read_buf(&self.0, nbytes).map(|status| {
116                let size = st.buffer.read_dst_size();
117
118                // The destination read buffer has new data, wake up the dispatcher
119                if size > orig {
120                    if st.is_rd_backpressure_needed(size) {
121                        log::trace!("{}: Read buf({size}), enable back-pressure", st.tag());
122                        st.flags.set_read_ready_and_backpressure();
123                    } else {
124                        st.flags.set_read_ready();
125                    }
126                    #[cfg(feature = "trace")]
127                    log::trace!("{}: New {size} bytes available", st.tag());
128                    st.wake_dispatch_task();
129                }
130
131                if st.flags.is_read_notify() {
132                    // If the "notify" flag is set, we must wake the
133                    // dispatcher task whenever data is read from the source.
134                    st.wake_dispatch_task();
135                    st.flags.set_read_notifed();
136                }
137
138                // Check if the filter wrote data during buffer processing
139                if status.wants_write {
140                    if let Err(err) = st.buffer.process_write_buf_force(&self.0) {
141                        st.terminate_connection(Some(err));
142                    } else {
143                        self.0.consolidate_write_state(false);
144                    }
145                }
146
147                // Check whether the filter notifies about readiness changes
148                if status.notify {
149                    self.0.call_notify();
150                }
151            })
152        });
153
154        if let Err(err) = result {
155            st.terminate_connection(Some(err));
156            IoTaskStatus::Stop
157        } else if st.flags.is_closed() {
158            IoTaskStatus::Stop
159        } else if st.flags.is_read_paused_or_backpressure() {
160            IoTaskStatus::Pause
161        } else {
162            IoTaskStatus::Io
163        }
164    }
165
166    /// Gets the write buffer.
167    pub fn with_write_buf<F, R>(&self, f: F) -> R
168    where
169        F: FnOnce(&mut BytePages) -> R,
170    {
171        // Write buffer processing may be delayed
172        if let Err(e) = self.st().buffer.process_write_buf(&self.0) {
173            self.st().terminate_connection(Some(e));
174        }
175
176        self.st().buffer.with_write_dst(|buffer| f(buffer))
177    }
178
179    /// Updates the write status.
180    ///
181    /// `Ok(true)` indicates that one or more bytes were successfully written
182    /// to the I/O stream.
183    pub fn update_write_status(&self, status: io::Result<bool>) -> IoTaskStatus {
184        let st = &self.st();
185
186        #[cfg(feature = "trace")]
187        log::trace!(
188            "{}: write-status == {status:?} buf:{} flags:{:?}",
189            st.tag(),
190            st.buffer.write_buf_size(),
191            st.flags
192        );
193
194        match status {
195            Ok(written) => {
196                let len = st.buffer.write_buf_size();
197                // Full flush is active
198                if st.flags.is_write_flush() {
199                    // The write buffer must be fully written
200                    if len == 0 {
201                        st.wake_dispatch_task();
202                    }
203                } else if st.flags.is_wr_backpressure()
204                    && st.should_disable_wr_backpressure(len)
205                {
206                    // Write backpressure is active and write buffer is below threshold
207                    st.wake_dispatch_task();
208                }
209
210                // Write notify is enabled
211                if written && st.flags.is_write_notify() {
212                    st.flags.unset_write_notify();
213                    st.wake_read_task();
214                    st.wake_write_task();
215                }
216
217                if st.flags.is_closed() {
218                    IoTaskStatus::Stop
219                } else if len == 0 {
220                    // All data has been written, pause the write task.
221                    st.flags.set_write_paused();
222                    if st.flags.is_stopping_filters() {
223                        st.wake_read_task();
224                    }
225                    IoTaskStatus::Pause
226                } else {
227                    st.flags.unset_write_paused();
228                    IoTaskStatus::Io
229                }
230            }
231            Err(err) => {
232                st.terminate_connection(Some(err));
233                IoTaskStatus::Stop
234            }
235        }
236    }
237
238    /// Waits for the I/O stream to close or begin closing.
239    pub fn shutdown(&self, flush: bool, cx: &mut Context<'_>) -> Poll<()> {
240        let st = self.st();
241        if flush && !st.flags.is_stopping() {
242            if st.flags.is_write_paused() {
243                return Poll::Ready(());
244            }
245            st.flags.set_write_notify();
246            st.read_task.register(cx.waker());
247            st.write_task.register(cx.waker());
248            Poll::Pending
249        } else if !st.flags.is_closed() {
250            st.read_task.register(cx.waker());
251            st.write_task.register(cx.waker());
252            Poll::Pending
253        } else {
254            Poll::Ready(())
255        }
256    }
257
258    fn shutdown_filters(&self, cx: &mut Context<'_>) {
259        let st = &self.st();
260        if !st.flags.is_shutting_down_filters() {
261            return;
262        }
263
264        // process filter shutdown
265        let ready = match st.buffer.process_shutdown(&self.0) {
266            Ok(Poll::Ready(())) => true,
267            Ok(Poll::Pending) => false,
268            Err(err) => {
269                st.terminate_connection(Some(err));
270                return;
271            }
272        };
273        self.0.consolidate_write_state(true);
274
275        #[cfg(feature = "trace")]
276        log::trace!(
277            "{}: shutdown filters, done:{ready:?} wr-buf:{:?}, flags:{:?}",
278            st.tag(),
279            st.buffer.write_buf_size(),
280            st.flags,
281        );
282
283        // filters are shutdown and write task is paused
284        if ready && st.flags.is_write_paused() && !st.flags.is_wr_send_scheduled() {
285            st.filters_stopped();
286        } else if st.flags.is_read_paused() || st.flags.is_read_ready_and_backpressure() {
287            // if read buffer is not consumed it is unlikely
288            // that filter will properly complete shutdown
289            st.filters_stopped();
290        } else {
291            // filter shutdown timeout
292            let timeout = st
293                .shutdown_timeout
294                .take()
295                .unwrap_or_else(|| sleep(st.cfg.disconnect_timeout()));
296            if timeout.poll_elapsed(cx).is_ready() {
297                st.filters_stopped();
298            } else {
299                st.shutdown_timeout.set(Some(timeout));
300            }
301        }
302    }
303
304    /// Notifies read tasks.
305    pub fn notify(&self) {
306        self.0.0.wake_read_task();
307    }
308}
309
310impl Clone for IoContext {
311    fn clone(&self) -> Self {
312        Self(self.0.clone())
313    }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Io, testing::IoTest};

    /// Smoke test: a context built from a fresh `Io` exposes its flags,
    /// a non-default id, and a `Debug` representation naming the type.
    #[ntex::test]
    async fn ctx_basics() {
        let (_, server) = IoTest::create();

        let stream = Io::from(server);
        let ctx = IoContext::new(stream.get_ref());
        let _ = ctx.flags();
        assert!(ctx.id() != Id::default());

        let rendered = format!("{ctx:?}");
        assert!(rendered.contains("IoContext"));
    }
}