ntex_io/
tasks.rs

1use std::{cell::Cell, fmt, io, task::Context, task::Poll};
2
3use ntex_bytes::{Buf, BytesVec};
4use ntex_util::time::{Sleep, sleep};
5
6use crate::{FilterCtx, Flags, IoRef, IoTaskStatus, Readiness};
7
8/// Context for io read task
9pub struct IoContext(IoRef, Cell<Option<Sleep>>);
10
11impl fmt::Debug for IoContext {
12    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
13        f.debug_struct("IoContext").field("io", &self.0).finish()
14    }
15}
16
17impl IoContext {
18    pub(crate) fn new(io: &IoRef) -> Self {
19        Self(io.clone(), Cell::new(None))
20    }
21
22    #[doc(hidden)]
23    #[inline]
24    pub fn id(&self) -> usize {
25        self.0.0.as_ref() as *const _ as usize
26    }
27
28    #[inline]
29    /// Io tag
30    pub fn tag(&self) -> &'static str {
31        self.0.tag()
32    }
33
34    #[inline]
35    #[doc(hidden)]
36    /// Io flags
37    pub fn flags(&self) -> crate::flags::Flags {
38        self.0.flags()
39    }
40
41    #[inline]
42    /// Check readiness for read operations
43    pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Readiness> {
44        self.shutdown_filters(cx);
45        self.0.filter().poll_read_ready(cx)
46    }
47
48    #[inline]
49    /// Check readiness for write operations
50    pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<Readiness> {
51        self.0.filter().poll_write_ready(cx)
52    }
53
54    #[inline]
55    /// Stop io
56    pub fn stop(&self, e: Option<io::Error>) {
57        self.0.0.io_stopped(e);
58    }
59
60    #[inline]
61    /// Check if Io stopped
62    pub fn is_stopped(&self) -> bool {
63        self.0.flags().is_stopped()
64    }
65
66    /// Wait when io get closed or preparing for close
67    pub fn shutdown(&self, flush: bool, cx: &mut Context<'_>) -> Poll<()> {
68        let st = &self.0.0;
69        let flags = self.0.flags();
70
71        if flush && !flags.contains(Flags::IO_STOPPED) {
72            if flags.intersects(Flags::WR_PAUSED | Flags::IO_STOPPED) {
73                return Poll::Ready(());
74            }
75            st.insert_flags(Flags::WR_TASK_WAIT);
76            st.write_task.register(cx.waker());
77            Poll::Pending
78        } else if !flags.intersects(Flags::IO_STOPPING | Flags::IO_STOPPED) {
79            st.write_task.register(cx.waker());
80            Poll::Pending
81        } else {
82            Poll::Ready(())
83        }
84    }
85
86    /// Get read buffer
87    pub fn get_read_buf(&self) -> BytesVec {
88        let inner = &self.0.0;
89
90        if inner.flags.get().is_read_buf_ready() {
91            // read buffer is still not read by dispatcher
92            // we cannot touch it
93            inner.read_buf().get()
94        } else {
95            inner
96                .buffer
97                .get_read_source()
98                .unwrap_or_else(|| inner.read_buf().get())
99        }
100    }
101
102    /// Resize read buffer
103    pub fn resize_read_buf(&self, buf: &mut BytesVec) {
104        self.0.0.read_buf().resize(buf);
105    }
106
107    /// Set read buffer
108    pub fn release_read_buf(
109        &self,
110        nbytes: usize,
111        buf: BytesVec,
112        result: Poll<Result<(), Option<io::Error>>>,
113    ) -> IoTaskStatus {
114        let inner = &self.0.0;
115        let orig_size = inner.buffer.read_destination_size();
116        let hw = self.0.cfg().read_buf().high;
117
118        if let Some(mut first_buf) = inner.buffer.get_read_source() {
119            first_buf.extend_from_slice(&buf);
120            inner.buffer.set_read_source(&self.0, first_buf);
121        } else {
122            inner.buffer.set_read_source(&self.0, buf);
123        }
124
125        let mut full = false;
126
127        // handle buffer changes
128        let st_res = if nbytes > 0 {
129            match self
130                .0
131                .filter()
132                .process_read_buf(FilterCtx::new(&self.0, &inner.buffer), nbytes)
133            {
134                Ok(status) => {
135                    let buffer_size = inner.buffer.read_destination_size();
136                    if buffer_size.saturating_sub(orig_size) > 0 {
137                        // dest buffer has new data, wake up dispatcher
138                        if buffer_size >= hw {
139                            log::trace!(
140                                "{}: Io read buffer is too large {}, enable read back-pressure",
141                                self.tag(),
142                                buffer_size
143                            );
144                            full = true;
145                            inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL);
146                        } else {
147                            inner.insert_flags(Flags::BUF_R_READY);
148                        }
149                        log::trace!(
150                            "{}: New {} bytes available, wakeup dispatcher",
151                            self.tag(),
152                            buffer_size
153                        );
154                        inner.dispatch_task.wake();
155                    } else {
156                        if buffer_size >= hw {
157                            // read task is paused because of read back-pressure
158                            // but there is no new data in top most read buffer
159                            // so we need to wake up read task to read more data
160                            // otherwise read task would sleep forever
161                            full = true;
162                            inner.read_task.wake();
163                        }
164                        if inner.flags.get().is_waiting_for_read() {
165                            // in case of "notify" we must wake up dispatch task
166                            // if we read any data from source
167                            inner.dispatch_task.wake();
168                        }
169                    }
170
171                    // while reading, filter wrote some data
172                    // in that case filters need to process write buffers
173                    // and potentialy wake write task
174                    if status.need_write {
175                        self.0
176                            .filter()
177                            .process_write_buf(FilterCtx::new(&self.0, &inner.buffer))
178                    } else {
179                        Ok(())
180                    }
181                }
182                Err(err) => Err(err),
183            }
184        } else {
185            Ok(())
186        };
187
188        match result {
189            Poll::Ready(Ok(_)) => {
190                if let Err(e) = st_res {
191                    inner.io_stopped(Some(e));
192                    IoTaskStatus::Stop
193                } else if nbytes == 0 {
194                    inner.io_stopped(None);
195                    IoTaskStatus::Stop
196                } else if full {
197                    IoTaskStatus::Pause
198                } else {
199                    IoTaskStatus::Io
200                }
201            }
202            Poll::Ready(Err(e)) => {
203                inner.io_stopped(e);
204                IoTaskStatus::Stop
205            }
206            Poll::Pending => {
207                if let Err(e) = st_res {
208                    inner.io_stopped(Some(e));
209                    IoTaskStatus::Stop
210                } else if full {
211                    IoTaskStatus::Pause
212                } else {
213                    IoTaskStatus::Io
214                }
215            }
216        }
217    }
218
219    #[inline]
220    /// Get write buffer
221    pub fn get_write_buf(&self) -> Option<BytesVec> {
222        self.0
223            .0
224            .buffer
225            .get_write_destination()
226            .and_then(|buf| if buf.is_empty() { None } else { Some(buf) })
227    }
228
229    /// Set write buffer
230    pub fn release_write_buf(
231        &self,
232        mut buf: BytesVec,
233        result: Poll<io::Result<usize>>,
234    ) -> IoTaskStatus {
235        let result = match result {
236            Poll::Ready(Ok(0)) => {
237                log::trace!("{}: Disconnected during flush", self.tag());
238                Err(io::Error::new(
239                    io::ErrorKind::WriteZero,
240                    "failed to write frame to transport",
241                ))
242            }
243            Poll::Ready(Ok(n)) => {
244                if n == buf.len() {
245                    buf.clear();
246                    Ok(0)
247                } else {
248                    buf.advance(n);
249                    Ok(buf.len())
250                }
251            }
252            Poll::Ready(Err(e)) => Err(e),
253            Poll::Pending => Ok(buf.len()),
254        };
255
256        let inner = &self.0.0;
257
258        // set buffer back
259        let result = match result {
260            Ok(0) => {
261                self.0.cfg().write_buf().release(buf);
262                Ok(inner.buffer.write_destination_size())
263            }
264            Ok(_) => {
265                if let Some(b) = inner.buffer.get_write_destination() {
266                    buf.extend_from_slice(&b);
267                    self.0.cfg().write_buf().release(b);
268                }
269                let l = buf.len();
270                inner.buffer.set_write_destination(buf);
271                Ok(l)
272            }
273            Err(e) => Err(e),
274        };
275
276        match result {
277            Ok(0) => {
278                let mut flags = inner.flags.get();
279
280                // all data has been written
281                flags.insert(Flags::WR_PAUSED);
282
283                if flags.is_task_waiting_for_write() {
284                    flags.task_waiting_for_write_is_done();
285                    inner.write_task.wake();
286                }
287
288                if flags.is_waiting_for_write() {
289                    flags.waiting_for_write_is_done();
290                    inner.dispatch_task.wake();
291                }
292                inner.flags.set(flags);
293                if self.is_stopped() {
294                    IoTaskStatus::Stop
295                } else {
296                    IoTaskStatus::Pause
297                }
298            }
299            Ok(len) => {
300                // if write buffer is smaller than high watermark value, turn off back-pressure
301                if inner.flags.get().contains(Flags::BUF_W_BACKPRESSURE)
302                    && len < inner.write_buf().half
303                {
304                    inner.remove_flags(Flags::BUF_W_BACKPRESSURE);
305                    inner.dispatch_task.wake();
306                }
307                IoTaskStatus::Io
308            }
309            Err(e) => {
310                inner.io_stopped(Some(e));
311                IoTaskStatus::Stop
312            }
313        }
314    }
315
316    fn shutdown_filters(&self, cx: &mut Context<'_>) {
317        let io = &self.0;
318        let st = &self.0.0;
319        let flags = st.flags.get();
320        if flags.contains(Flags::IO_STOPPING_FILTERS)
321            && !flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING)
322        {
323            match io.filter().shutdown(FilterCtx::new(io, &st.buffer)) {
324                Ok(Poll::Ready(())) => {
325                    st.dispatch_task.wake();
326                    st.insert_flags(Flags::IO_STOPPING);
327                }
328                Ok(Poll::Pending) => {
329                    // check read buffer, if buffer is not consumed it is unlikely
330                    // that filter will properly complete shutdown
331                    let flags = st.flags.get();
332                    if flags.contains(Flags::RD_PAUSED)
333                        || flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY)
334                    {
335                        st.dispatch_task.wake();
336                        st.insert_flags(Flags::IO_STOPPING);
337                    } else {
338                        // filter shutdown timeout
339                        let timeout = self
340                            .1
341                            .take()
342                            .unwrap_or_else(|| sleep(io.cfg().disconnect_timeout()));
343                        if timeout.poll_elapsed(cx).is_ready() {
344                            st.dispatch_task.wake();
345                            st.insert_flags(Flags::IO_STOPPING);
346                        } else {
347                            self.1.set(Some(timeout));
348                        }
349                    }
350                }
351                Err(err) => {
352                    st.io_stopped(Some(err));
353                }
354            }
355            if let Err(err) = io
356                .filter()
357                .process_write_buf(FilterCtx::new(io, &st.buffer))
358            {
359                st.io_stopped(Some(err));
360            }
361        }
362    }
363}
364
365impl Clone for IoContext {
366    fn clone(&self) -> Self {
367        Self(self.0.clone(), Cell::new(None))
368    }
369}