ntex_io/
tasks.rs

1use std::{cell::Cell, fmt, future::poll_fn, io, task::Context, task::Poll};
2
3use ntex_bytes::{Buf, BufMut, BytesVec};
4use ntex_util::{future::lazy, future::select, future::Either, time::sleep, time::Sleep};
5
6use crate::{AsyncRead, AsyncWrite, FilterCtx, Flags, IoRef, IoTaskStatus, Readiness};
7
8/// Context for io read task
9pub struct ReadContext(IoRef, Cell<Option<Sleep>>);
10
11impl fmt::Debug for ReadContext {
12    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
13        f.debug_struct("ReadContext").field("io", &self.0).finish()
14    }
15}
16
17impl ReadContext {
18    pub(crate) fn new(io: &IoRef) -> Self {
19        Self(io.clone(), Cell::new(None))
20    }
21
22    #[doc(hidden)]
23    #[inline]
24    /// Io tag
25    pub fn context(&self) -> IoContext {
26        IoContext::new(&self.0)
27    }
28
29    #[inline]
30    /// Io tag
31    pub fn tag(&self) -> &'static str {
32        self.0.tag()
33    }
34
35    /// Wait when io get closed or preparing for close
36    async fn wait_for_close(&self) {
37        poll_fn(|cx| {
38            let flags = self.0.flags();
39
40            if flags.intersects(Flags::IO_STOPPING | Flags::IO_STOPPED) {
41                Poll::Ready(())
42            } else {
43                self.0 .0.read_task.register(cx.waker());
44                if flags.contains(Flags::IO_STOPPING_FILTERS) {
45                    self.shutdown_filters(cx);
46                }
47                Poll::Pending
48            }
49        })
50        .await
51    }
52
53    /// Handle read io operations
54    pub async fn handle<T>(&self, io: &mut T)
55    where
56        T: AsyncRead,
57    {
58        let inner = &self.0 .0;
59
60        loop {
61            let result = poll_fn(|cx| self.0.filter().poll_read_ready(cx)).await;
62            if result == Readiness::Terminate {
63                log::trace!("{}: Read task is instructed to shutdown", self.tag());
64                break;
65            }
66
67            let mut buf = if inner.flags.get().is_read_buf_ready() {
68                // read buffer is still not read by dispatcher
69                // we cannot touch it
70                inner.pool.get().get_read_buf()
71            } else {
72                inner
73                    .buffer
74                    .get_read_source()
75                    .unwrap_or_else(|| inner.pool.get().get_read_buf())
76            };
77
78            // make sure we've got room
79            let (hw, lw) = self.0.memory_pool().read_params().unpack();
80            let remaining = buf.remaining_mut();
81            if remaining <= lw {
82                buf.reserve(hw - remaining);
83            }
84            let total = buf.len();
85
86            // call provided callback
87            let (buf, result) = match select(io.read(buf), self.wait_for_close()).await {
88                Either::Left(res) => res,
89                Either::Right(_) => {
90                    log::trace!("{}: Read io is closed, stop read task", self.tag());
91                    break;
92                }
93            };
94
95            // handle incoming data
96            let total2 = buf.len();
97            let nbytes = total2.saturating_sub(total);
98            let total = total2;
99
100            if let Some(mut first_buf) = inner.buffer.get_read_source() {
101                first_buf.extend_from_slice(&buf);
102                inner.buffer.set_read_source(&self.0, first_buf);
103            } else {
104                inner.buffer.set_read_source(&self.0, buf);
105            }
106
107            // handle buffer changes
108            if nbytes > 0 {
109                let filter = self.0.filter();
110                let res = match filter
111                    .process_read_buf(FilterCtx::new(&self.0, &inner.buffer), nbytes)
112                {
113                    Ok(status) => {
114                        if status.nbytes > 0 {
115                            // check read back-pressure
116                            if hw < inner.buffer.read_destination_size() {
117                                log::trace!(
118                                "{}: Io read buffer is too large {}, enable read back-pressure",
119                                self.0.tag(),
120                                total
121                            );
122                                inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL);
123                            } else {
124                                inner.insert_flags(Flags::BUF_R_READY);
125                            }
126                            log::trace!(
127                                "{}: New {} bytes available, wakeup dispatcher",
128                                self.0.tag(),
129                                nbytes
130                            );
131                            // dest buffer has new data, wake up dispatcher
132                            inner.dispatch_task.wake();
133                        } else if inner.flags.get().is_waiting_for_read() {
134                            // in case of "notify" we must wake up dispatch task
135                            // if we read any data from source
136                            inner.dispatch_task.wake();
137                        }
138
139                        // while reading, filter wrote some data
140                        // in that case filters need to process write buffers
141                        // and potentialy wake write task
142                        if status.need_write {
143                            filter.process_write_buf(FilterCtx::new(&self.0, &inner.buffer))
144                        } else {
145                            Ok(())
146                        }
147                    }
148                    Err(err) => Err(err),
149                };
150
151                if let Err(err) = res {
152                    inner.dispatch_task.wake();
153                    inner.io_stopped(Some(err));
154                    inner.insert_flags(Flags::BUF_R_READY);
155                }
156            }
157
158            match result {
159                Ok(0) => {
160                    log::trace!("{}: Tcp stream is disconnected", self.tag());
161                    inner.io_stopped(None);
162                    break;
163                }
164                Ok(_) => {
165                    if inner.flags.get().contains(Flags::IO_STOPPING_FILTERS) {
166                        lazy(|cx| self.shutdown_filters(cx)).await;
167                    }
168                }
169                Err(err) => {
170                    log::trace!("{}: Read task failed on io {:?}", self.tag(), err);
171                    inner.io_stopped(Some(err));
172                    break;
173                }
174            }
175        }
176    }
177
178    fn shutdown_filters(&self, cx: &mut Context<'_>) {
179        let st = &self.0 .0;
180        let filter = self.0.filter();
181
182        match filter.shutdown(FilterCtx::new(&self.0, &st.buffer)) {
183            Ok(Poll::Ready(())) => {
184                st.dispatch_task.wake();
185                st.insert_flags(Flags::IO_STOPPING);
186            }
187            Ok(Poll::Pending) => {
188                let flags = st.flags.get();
189
190                // check read buffer, if buffer is not consumed it is unlikely
191                // that filter will properly complete shutdown
192                if flags.contains(Flags::RD_PAUSED)
193                    || flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY)
194                {
195                    st.dispatch_task.wake();
196                    st.insert_flags(Flags::IO_STOPPING);
197                } else {
198                    // filter shutdown timeout
199                    let timeout = self
200                        .1
201                        .take()
202                        .unwrap_or_else(|| sleep(st.disconnect_timeout.get()));
203                    if timeout.poll_elapsed(cx).is_ready() {
204                        st.dispatch_task.wake();
205                        st.insert_flags(Flags::IO_STOPPING);
206                    } else {
207                        self.1.set(Some(timeout));
208                    }
209                }
210            }
211            Err(err) => {
212                st.io_stopped(Some(err));
213            }
214        }
215        if let Err(err) = filter.process_write_buf(FilterCtx::new(&self.0, &st.buffer)) {
216            st.io_stopped(Some(err));
217        }
218    }
219}
220
221#[derive(Debug)]
222/// Context for io write task
223pub struct WriteContext(IoRef);
224
225#[derive(Debug)]
226/// Context buf for io write task
227pub struct WriteContextBuf {
228    io: IoRef,
229    buf: Option<BytesVec>,
230}
231
232impl WriteContext {
233    pub(crate) fn new(io: &IoRef) -> Self {
234        Self(io.clone())
235    }
236
237    #[inline]
238    /// Io tag
239    pub fn tag(&self) -> &'static str {
240        self.0.tag()
241    }
242
243    /// Check readiness for write operations
244    async fn ready(&self) -> Readiness {
245        poll_fn(|cx| self.0.filter().poll_write_ready(cx)).await
246    }
247
248    /// Indicate that write io task is stopped
249    fn close(&self, err: Option<io::Error>) {
250        self.0 .0.io_stopped(err);
251    }
252
253    /// Check if io is closed
254    async fn when_stopped(&self) {
255        poll_fn(|cx| {
256            if self.0.flags().is_stopped() {
257                Poll::Ready(())
258            } else {
259                self.0 .0.write_task.register(cx.waker());
260                Poll::Pending
261            }
262        })
263        .await
264    }
265
266    /// Handle write io operations
267    pub async fn handle<T>(&self, io: &mut T)
268    where
269        T: AsyncWrite,
270    {
271        let mut buf = WriteContextBuf {
272            io: self.0.clone(),
273            buf: None,
274        };
275
276        loop {
277            match self.ready().await {
278                Readiness::Ready => {
279                    // write io stream
280                    match select(io.write(&mut buf), self.when_stopped()).await {
281                        Either::Left(Ok(_)) => continue,
282                        Either::Left(Err(e)) => self.close(Some(e)),
283                        Either::Right(_) => return,
284                    }
285                }
286                Readiness::Shutdown => {
287                    log::trace!("{}: Write task is instructed to shutdown", self.tag());
288
289                    let fut = async {
290                        // write io stream
291                        io.write(&mut buf).await?;
292                        io.flush().await?;
293                        io.shutdown().await?;
294                        Ok(())
295                    };
296                    match select(sleep(self.0 .0.disconnect_timeout.get()), fut).await {
297                        Either::Left(_) => self.close(None),
298                        Either::Right(res) => self.close(res.err()),
299                    }
300                }
301                Readiness::Terminate => {
302                    log::trace!("{}: Write task is instructed to terminate", self.tag());
303                    self.close(io.shutdown().await.err());
304                }
305            }
306            return;
307        }
308    }
309}
310
311impl WriteContextBuf {
312    pub fn set(&mut self, mut buf: BytesVec) {
313        if buf.is_empty() {
314            self.io.memory_pool().release_write_buf(buf);
315        } else if let Some(b) = self.buf.take() {
316            buf.extend_from_slice(&b);
317            self.io.memory_pool().release_write_buf(b);
318            self.buf = Some(buf);
319        } else if let Some(b) = self.io.0.buffer.set_write_destination(buf) {
320            // write buffer is already set
321            self.buf = Some(b);
322        }
323
324        // if write buffer is smaller than high watermark value, turn off back-pressure
325        let inner = &self.io.0;
326        let len = self.buf.as_ref().map(|b| b.len()).unwrap_or_default()
327            + inner.buffer.write_destination_size();
328        let mut flags = inner.flags.get();
329
330        if len == 0 {
331            if flags.is_waiting_for_write() {
332                flags.waiting_for_write_is_done();
333                inner.dispatch_task.wake();
334            }
335            flags.insert(Flags::WR_PAUSED);
336            inner.flags.set(flags);
337        } else if flags.contains(Flags::BUF_W_BACKPRESSURE)
338            && len < inner.pool.get().write_params_high() << 1
339        {
340            flags.remove(Flags::BUF_W_BACKPRESSURE);
341            inner.flags.set(flags);
342            inner.dispatch_task.wake();
343        }
344    }
345
346    pub fn take(&mut self) -> Option<BytesVec> {
347        if let Some(buf) = self.buf.take() {
348            Some(buf)
349        } else {
350            self.io.0.buffer.get_write_destination()
351        }
352    }
353}
354
355/// Context for io read task
356pub struct IoContext(IoRef);
357
358impl fmt::Debug for IoContext {
359    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
360        f.debug_struct("IoContext").field("io", &self.0).finish()
361    }
362}
363
364impl IoContext {
365    pub(crate) fn new(io: &IoRef) -> Self {
366        Self(io.clone())
367    }
368
369    #[doc(hidden)]
370    #[inline]
371    pub fn id(&self) -> usize {
372        self.0 .0.as_ref() as *const _ as usize
373    }
374
375    #[inline]
376    /// Io tag
377    pub fn tag(&self) -> &'static str {
378        self.0.tag()
379    }
380
381    #[doc(hidden)]
382    /// Io flags
383    pub fn flags(&self) -> crate::flags::Flags {
384        self.0.flags()
385    }
386
387    #[inline]
388    /// Check readiness for read operations
389    pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Readiness> {
390        self.shutdown_filters();
391        self.0.filter().poll_read_ready(cx)
392    }
393
394    #[inline]
395    /// Check readiness for write operations
396    pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<Readiness> {
397        self.0.filter().poll_write_ready(cx)
398    }
399
400    #[inline]
401    /// Stop io
402    pub fn stop(&self, e: Option<io::Error>) {
403        self.0 .0.io_stopped(e);
404    }
405
406    #[inline]
407    /// Initiate gracefully shutdown
408    pub fn init_shutdown(&self) {
409        self.0 .0.init_shutdown();
410    }
411
412    #[inline]
413    /// Check if Io stopped
414    pub fn is_stopped(&self) -> bool {
415        self.0.flags().is_stopped()
416    }
417
418    /// Wait when io get closed or preparing for close
419    pub fn shutdown(&self, flush: bool, cx: &mut Context<'_>) -> Poll<()> {
420        let st = &self.0 .0;
421
422        let flags = self.0.flags();
423        if !flags.intersects(Flags::IO_STOPPING | Flags::IO_STOPPED) {
424            st.write_task.register(cx.waker());
425            return Poll::Pending;
426        }
427
428        if flush && !flags.contains(Flags::IO_STOPPED) {
429            if flags.intersects(Flags::WR_PAUSED | Flags::IO_STOPPED) {
430                return Poll::Ready(());
431            }
432            st.insert_flags(Flags::WR_TASK_WAIT);
433            st.write_task.register(cx.waker());
434            Poll::Pending
435        } else {
436            Poll::Ready(())
437        }
438    }
439
440    /// Get read buffer
441    pub fn get_read_buf(&self) -> BytesVec {
442        let inner = &self.0 .0;
443
444        if inner.flags.get().is_read_buf_ready() {
445            // read buffer is still not read by dispatcher
446            // we cannot touch it
447            inner.pool.get().get_read_buf()
448        } else {
449            inner
450                .buffer
451                .get_read_source()
452                .unwrap_or_else(|| inner.pool.get().get_read_buf())
453        }
454    }
455
456    /// Set read buffer
457    pub fn release_read_buf(
458        &self,
459        nbytes: usize,
460        buf: BytesVec,
461        result: Poll<io::Result<()>>,
462    ) -> IoTaskStatus {
463        let inner = &self.0 .0;
464        let orig_size = inner.buffer.read_destination_size();
465        let hw = self.0.memory_pool().read_params().unpack().0;
466
467        if let Some(mut first_buf) = inner.buffer.get_read_source() {
468            first_buf.extend_from_slice(&buf);
469            inner.buffer.set_read_source(&self.0, first_buf);
470        } else {
471            inner.buffer.set_read_source(&self.0, buf);
472        }
473
474        let mut full = false;
475
476        // handle buffer changes
477        let st_res = if nbytes > 0 {
478            match self
479                .0
480                .filter()
481                .process_read_buf(FilterCtx::new(&self.0, &inner.buffer), nbytes)
482            {
483                Ok(status) => {
484                    let buffer_size = inner.buffer.read_destination_size();
485                    if buffer_size.saturating_sub(orig_size) > 0 {
486                        // dest buffer has new data, wake up dispatcher
487                        if buffer_size >= hw {
488                            log::trace!(
489                                "{}: Io read buffer is too large {}, enable read back-pressure",
490                                self.tag(),
491                                buffer_size
492                            );
493                            full = true;
494                            inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL);
495                        } else {
496                            inner.insert_flags(Flags::BUF_R_READY);
497                        }
498                        log::trace!(
499                            "{}: New {} bytes available, wakeup dispatcher",
500                            self.tag(),
501                            buffer_size
502                        );
503                        inner.dispatch_task.wake();
504                    } else {
505                        if buffer_size >= hw {
506                            // read task is paused because of read back-pressure
507                            // but there is no new data in top most read buffer
508                            // so we need to wake up read task to read more data
509                            // otherwise read task would sleep forever
510                            full = true;
511                            inner.read_task.wake();
512                        }
513                        if inner.flags.get().is_waiting_for_read() {
514                            // in case of "notify" we must wake up dispatch task
515                            // if we read any data from source
516                            inner.dispatch_task.wake();
517                        }
518                    }
519
520                    // while reading, filter wrote some data
521                    // in that case filters need to process write buffers
522                    // and potentialy wake write task
523                    if status.need_write {
524                        self.0
525                            .filter()
526                            .process_write_buf(FilterCtx::new(&self.0, &inner.buffer))
527                    } else {
528                        Ok(())
529                    }
530                }
531                Err(err) => Err(err),
532            }
533        } else {
534            Ok(())
535        };
536
537        match result {
538            Poll::Ready(Ok(_)) => {
539                if let Err(e) = st_res {
540                    inner.io_stopped(Some(e));
541                    IoTaskStatus::Pause
542                } else if nbytes == 0 {
543                    inner.io_stopped(None);
544                    IoTaskStatus::Pause
545                } else {
546                    self.shutdown_filters();
547                    if full {
548                        IoTaskStatus::Pause
549                    } else {
550                        IoTaskStatus::Io
551                    }
552                }
553            }
554            Poll::Ready(Err(e)) => {
555                inner.io_stopped(Some(e));
556                IoTaskStatus::Pause
557            }
558            Poll::Pending => {
559                if let Err(e) = st_res {
560                    inner.io_stopped(Some(e));
561                    IoTaskStatus::Pause
562                } else {
563                    self.shutdown_filters();
564                    if full {
565                        IoTaskStatus::Pause
566                    } else {
567                        IoTaskStatus::Io
568                    }
569                }
570            }
571        }
572    }
573
574    /// Get write buffer
575    pub fn get_write_buf(&self) -> Option<BytesVec> {
576        self.0 .0.buffer.get_write_destination().and_then(|buf| {
577            if buf.is_empty() {
578                None
579            } else {
580                Some(buf)
581            }
582        })
583    }
584
585    /// Set write buffer
586    pub fn release_write_buf(
587        &self,
588        mut buf: BytesVec,
589        result: Poll<io::Result<usize>>,
590    ) -> IoTaskStatus {
591        let result = match result {
592            Poll::Ready(Ok(0)) => {
593                log::trace!("{}: Disconnected during flush", self.tag());
594                Err(io::Error::new(
595                    io::ErrorKind::WriteZero,
596                    "failed to write frame to transport",
597                ))
598            }
599            Poll::Ready(Ok(n)) => {
600                if n == buf.len() {
601                    buf.clear();
602                    Ok(0)
603                } else {
604                    buf.advance(n);
605                    Ok(buf.len())
606                }
607            }
608            Poll::Ready(Err(e)) => Err(e),
609            Poll::Pending => Ok(buf.len()),
610        };
611
612        let inner = &self.0 .0;
613
614        // set buffer back
615        let result = match result {
616            Ok(0) => {
617                self.0.memory_pool().release_write_buf(buf);
618                Ok(inner.buffer.write_destination_size())
619            }
620            Ok(_) => {
621                if let Some(b) = inner.buffer.get_write_destination() {
622                    buf.extend_from_slice(&b);
623                    self.0.memory_pool().release_write_buf(b);
624                }
625                let l = buf.len();
626                inner.buffer.set_write_destination(buf);
627                Ok(l)
628            }
629            Err(e) => Err(e),
630        };
631
632        match result {
633            Ok(0) => {
634                let mut flags = inner.flags.get();
635
636                // all data has been written
637                flags.insert(Flags::WR_PAUSED);
638
639                if flags.is_task_waiting_for_write() {
640                    flags.task_waiting_for_write_is_done();
641                    inner.write_task.wake();
642                }
643
644                if flags.is_waiting_for_write() {
645                    flags.waiting_for_write_is_done();
646                    inner.dispatch_task.wake();
647                }
648                inner.flags.set(flags);
649                IoTaskStatus::Pause
650            }
651            Ok(len) => {
652                // if write buffer is smaller than high watermark value, turn off back-pressure
653                if inner.flags.get().contains(Flags::BUF_W_BACKPRESSURE)
654                    && len < inner.pool.get().write_params_high() << 1
655                {
656                    inner.remove_flags(Flags::BUF_W_BACKPRESSURE);
657                    inner.dispatch_task.wake();
658                }
659                if self.is_stopped() {
660                    IoTaskStatus::Pause
661                } else {
662                    IoTaskStatus::Io
663                }
664            }
665            Err(e) => {
666                inner.io_stopped(Some(e));
667                IoTaskStatus::Pause
668            }
669        }
670    }
671
672    fn shutdown_filters(&self) {
673        let io = &self.0;
674        let st = &self.0 .0;
675        let flags = st.flags.get();
676        if flags.contains(Flags::IO_STOPPING_FILTERS)
677            && !flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING)
678        {
679            match io.filter().shutdown(FilterCtx::new(io, &st.buffer)) {
680                Ok(Poll::Ready(())) => {
681                    st.write_task.wake();
682                    st.dispatch_task.wake();
683                    st.insert_flags(Flags::IO_STOPPING);
684                }
685                Ok(Poll::Pending) => {
686                    // check read buffer, if buffer is not consumed it is unlikely
687                    // that filter will properly complete shutdown
688                    let flags = st.flags.get();
689                    if flags.contains(Flags::RD_PAUSED)
690                        || flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY)
691                    {
692                        st.write_task.wake();
693                        st.dispatch_task.wake();
694                        st.insert_flags(Flags::IO_STOPPING);
695                    }
696                }
697                Err(err) => {
698                    st.io_stopped(Some(err));
699                }
700            }
701            if let Err(err) = io
702                .filter()
703                .process_write_buf(FilterCtx::new(io, &st.buffer))
704            {
705                st.io_stopped(Some(err));
706            }
707        }
708    }
709}
710
711impl Clone for IoContext {
712    fn clone(&self) -> Self {
713        Self(self.0.clone())
714    }
715}