ntex_io/
tasks.rs

1use std::{cell::Cell, fmt, future::poll_fn, io, task::Context, task::Poll};
2
3use ntex_bytes::{Buf, BufMut, BytesVec};
4use ntex_util::{future::lazy, future::select, future::Either, time::sleep, time::Sleep};
5
6use crate::{AsyncRead, AsyncWrite, FilterCtx, Flags, IoRef, IoTaskStatus, Readiness};
7
/// Context for io read task
///
/// Wraps the [`IoRef`] the read task operates on together with an optional
/// [`Sleep`] timer slot, which `shutdown_filters` uses to bound how long a
/// pending filter shutdown is allowed to take.
pub struct ReadContext(IoRef, Cell<Option<Sleep>>);
10
11impl fmt::Debug for ReadContext {
12    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
13        f.debug_struct("ReadContext").field("io", &self.0).finish()
14    }
15}
16
17impl ReadContext {
18    pub(crate) fn new(io: &IoRef) -> Self {
19        Self(io.clone(), Cell::new(None))
20    }
21
22    #[doc(hidden)]
23    #[inline]
24    /// Io tag
25    pub fn context(&self) -> IoContext {
26        IoContext::new(&self.0)
27    }
28
29    #[inline]
30    /// Io tag
31    pub fn tag(&self) -> &'static str {
32        self.0.tag()
33    }
34
35    /// Wait when io get closed or preparing for close
36    async fn wait_for_close(&self) {
37        poll_fn(|cx| {
38            let flags = self.0.flags();
39
40            if flags.intersects(Flags::IO_STOPPING | Flags::IO_STOPPED) {
41                Poll::Ready(())
42            } else {
43                self.0 .0.read_task.register(cx.waker());
44                if flags.contains(Flags::IO_STOPPING_FILTERS) {
45                    self.shutdown_filters(cx);
46                }
47                Poll::Pending
48            }
49        })
50        .await
51    }
52
53    /// Handle read io operations
54    pub async fn handle<T>(&self, io: &mut T)
55    where
56        T: AsyncRead,
57    {
58        let inner = &self.0 .0;
59
60        loop {
61            let result = poll_fn(|cx| self.0.filter().poll_read_ready(cx)).await;
62            if result == Readiness::Terminate {
63                log::trace!("{}: Read task is instructed to shutdown", self.tag());
64                break;
65            }
66
67            let mut buf = if inner.flags.get().is_read_buf_ready() {
68                // read buffer is still not read by dispatcher
69                // we cannot touch it
70                inner.pool.get().get_read_buf()
71            } else {
72                inner
73                    .buffer
74                    .get_read_source()
75                    .unwrap_or_else(|| inner.pool.get().get_read_buf())
76            };
77
78            // make sure we've got room
79            let (hw, lw) = self.0.memory_pool().read_params().unpack();
80            let remaining = buf.remaining_mut();
81            if remaining <= lw {
82                buf.reserve(hw - remaining);
83            }
84            let total = buf.len();
85
86            // call provided callback
87            let (buf, result) = match select(io.read(buf), self.wait_for_close()).await {
88                Either::Left(res) => res,
89                Either::Right(_) => {
90                    log::trace!("{}: Read io is closed, stop read task", self.tag());
91                    break;
92                }
93            };
94
95            // handle incoming data
96            let total2 = buf.len();
97            let nbytes = total2.saturating_sub(total);
98            let total = total2;
99
100            if let Some(mut first_buf) = inner.buffer.get_read_source() {
101                first_buf.extend_from_slice(&buf);
102                inner.buffer.set_read_source(&self.0, first_buf);
103            } else {
104                inner.buffer.set_read_source(&self.0, buf);
105            }
106
107            // handle buffer changes
108            if nbytes > 0 {
109                let filter = self.0.filter();
110                let res = match filter
111                    .process_read_buf(FilterCtx::new(&self.0, &inner.buffer), nbytes)
112                {
113                    Ok(status) => {
114                        if status.nbytes > 0 {
115                            // check read back-pressure
116                            if hw < inner.buffer.read_destination_size() {
117                                log::trace!(
118                                "{}: Io read buffer is too large {}, enable read back-pressure",
119                                self.0.tag(),
120                                total
121                            );
122                                inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL);
123                            } else {
124                                inner.insert_flags(Flags::BUF_R_READY);
125                            }
126                            log::trace!(
127                                "{}: New {} bytes available, wakeup dispatcher",
128                                self.0.tag(),
129                                nbytes
130                            );
131                            // dest buffer has new data, wake up dispatcher
132                            inner.dispatch_task.wake();
133                        } else if inner.flags.get().is_waiting_for_read() {
134                            // in case of "notify" we must wake up dispatch task
135                            // if we read any data from source
136                            inner.dispatch_task.wake();
137                        }
138
139                        // while reading, filter wrote some data
140                        // in that case filters need to process write buffers
141                        // and potentialy wake write task
142                        if status.need_write {
143                            filter.process_write_buf(FilterCtx::new(&self.0, &inner.buffer))
144                        } else {
145                            Ok(())
146                        }
147                    }
148                    Err(err) => Err(err),
149                };
150
151                if let Err(err) = res {
152                    inner.dispatch_task.wake();
153                    inner.io_stopped(Some(err));
154                    inner.insert_flags(Flags::BUF_R_READY);
155                }
156            }
157
158            match result {
159                Ok(0) => {
160                    log::trace!("{}: Tcp stream is disconnected", self.tag());
161                    inner.io_stopped(None);
162                    break;
163                }
164                Ok(_) => {
165                    if inner.flags.get().contains(Flags::IO_STOPPING_FILTERS) {
166                        lazy(|cx| self.shutdown_filters(cx)).await;
167                    }
168                }
169                Err(err) => {
170                    log::trace!("{}: Read task failed on io {:?}", self.tag(), err);
171                    inner.io_stopped(Some(err));
172                    break;
173                }
174            }
175        }
176    }
177
178    fn shutdown_filters(&self, cx: &mut Context<'_>) {
179        let st = &self.0 .0;
180        let filter = self.0.filter();
181
182        match filter.shutdown(FilterCtx::new(&self.0, &st.buffer)) {
183            Ok(Poll::Ready(())) => {
184                st.dispatch_task.wake();
185                st.insert_flags(Flags::IO_STOPPING);
186            }
187            Ok(Poll::Pending) => {
188                let flags = st.flags.get();
189
190                // check read buffer, if buffer is not consumed it is unlikely
191                // that filter will properly complete shutdown
192                if flags.contains(Flags::RD_PAUSED)
193                    || flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY)
194                {
195                    st.dispatch_task.wake();
196                    st.insert_flags(Flags::IO_STOPPING);
197                } else {
198                    // filter shutdown timeout
199                    let timeout = self
200                        .1
201                        .take()
202                        .unwrap_or_else(|| sleep(st.disconnect_timeout.get()));
203                    if timeout.poll_elapsed(cx).is_ready() {
204                        st.dispatch_task.wake();
205                        st.insert_flags(Flags::IO_STOPPING);
206                    } else {
207                        self.1.set(Some(timeout));
208                    }
209                }
210            }
211            Err(err) => {
212                st.io_stopped(Some(err));
213            }
214        }
215        if let Err(err) = filter.process_write_buf(FilterCtx::new(&self.0, &st.buffer)) {
216            st.io_stopped(Some(err));
217        }
218    }
219}
220
#[derive(Debug)]
/// Context for io write task
///
/// Thin wrapper around the [`IoRef`] the write task operates on.
pub struct WriteContext(IoRef);
224
#[derive(Debug)]
/// Context buf for io write task
///
/// Hands write buffers to the writer via `take` and accepts them back via `set`.
pub struct WriteContextBuf {
    // io object whose write destination buffer and memory pool are used
    io: IoRef,
    // buffer kept locally by `set`; `take` returns it before consulting the io
    buf: Option<BytesVec>,
}
231
232impl WriteContext {
233    pub(crate) fn new(io: &IoRef) -> Self {
234        Self(io.clone())
235    }
236
237    #[inline]
238    /// Io tag
239    pub fn tag(&self) -> &'static str {
240        self.0.tag()
241    }
242
243    /// Check readiness for write operations
244    async fn ready(&self) -> Readiness {
245        poll_fn(|cx| self.0.filter().poll_write_ready(cx)).await
246    }
247
248    /// Indicate that write io task is stopped
249    fn close(&self, err: Option<io::Error>) {
250        self.0 .0.io_stopped(err);
251    }
252
253    /// Check if io is closed
254    async fn when_stopped(&self) {
255        poll_fn(|cx| {
256            if self.0.flags().is_stopped() {
257                Poll::Ready(())
258            } else {
259                self.0 .0.write_task.register(cx.waker());
260                Poll::Pending
261            }
262        })
263        .await
264    }
265
266    /// Handle write io operations
267    pub async fn handle<T>(&self, io: &mut T)
268    where
269        T: AsyncWrite,
270    {
271        let mut buf = WriteContextBuf {
272            io: self.0.clone(),
273            buf: None,
274        };
275
276        loop {
277            match self.ready().await {
278                Readiness::Ready => {
279                    // write io stream
280                    match select(io.write(&mut buf), self.when_stopped()).await {
281                        Either::Left(Ok(_)) => continue,
282                        Either::Left(Err(e)) => self.close(Some(e)),
283                        Either::Right(_) => return,
284                    }
285                }
286                Readiness::Shutdown => {
287                    log::trace!("{}: Write task is instructed to shutdown", self.tag());
288
289                    let fut = async {
290                        // write io stream
291                        io.write(&mut buf).await?;
292                        io.flush().await?;
293                        io.shutdown().await?;
294                        Ok(())
295                    };
296                    match select(sleep(self.0 .0.disconnect_timeout.get()), fut).await {
297                        Either::Left(_) => self.close(None),
298                        Either::Right(res) => self.close(res.err()),
299                    }
300                }
301                Readiness::Terminate => {
302                    log::trace!("{}: Write task is instructed to terminate", self.tag());
303                    self.close(io.shutdown().await.err());
304                }
305            }
306            return;
307        }
308    }
309}
310
311impl WriteContextBuf {
    /// Give a write buffer back after a write attempt and update write flags.
    ///
    /// An empty `buf` is released to the memory pool. Otherwise it is merged
    /// with a locally kept buffer, or installed as the io write destination.
    /// Afterwards the total amount of pending data decides whether the write
    /// side is marked paused or write back-pressure is lifted.
    pub fn set(&mut self, mut buf: BytesVec) {
        if buf.is_empty() {
            // nothing left to write, return the buffer to the pool
            self.io.memory_pool().release_write_buf(buf);
        } else if let Some(b) = self.buf.take() {
            // a buffer is already kept locally: append its bytes after the
            // incoming ones, release it and keep the merged buffer
            buf.extend_from_slice(&b);
            self.io.memory_pool().release_write_buf(b);
            self.buf = Some(buf);
        } else if let Some(b) = self.io.0.buffer.set_write_destination(buf) {
            // write buffer is already set
            self.buf = Some(b);
        }

        // if write buffer is smaller than high watermark value, turn off back-pressure
        let inner = &self.io.0;
        // pending bytes = locally kept buffer + io write destination buffer
        let len = self.buf.as_ref().map(|b| b.len()).unwrap_or_default()
            + inner.buffer.write_destination_size();
        let mut flags = inner.flags.get();

        if len == 0 {
            // everything has been written
            if flags.is_waiting_for_write() {
                // presumably clears the "waiting for write" state; the
                // dispatcher is woken so it can observe the completed write
                flags.waiting_for_write_is_done();
                inner.dispatch_task.wake();
            }
            flags.insert(Flags::WR_PAUSED);
            inner.flags.set(flags);
        } else if flags.contains(Flags::BUF_W_BACKPRESSURE)
            && len < inner.pool.get().write_params_high() << 1
        {
            // pending data dropped below twice the high watermark
            flags.remove(Flags::BUF_W_BACKPRESSURE);
            inner.flags.set(flags);
            inner.dispatch_task.wake();
        }
    }
345
346    pub fn take(&mut self) -> Option<BytesVec> {
347        if let Some(buf) = self.buf.take() {
348            Some(buf)
349        } else {
350            self.io.0.buffer.get_write_destination()
351        }
352    }
353}
354
/// Context for io read and write tasks
///
/// Wraps an [`IoRef`] and exposes the readiness, buffer and shutdown
/// operations used by code that drives the io object.
pub struct IoContext(IoRef);
357
358impl fmt::Debug for IoContext {
359    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
360        f.debug_struct("IoContext").field("io", &self.0).finish()
361    }
362}
363
364impl IoContext {
365    pub(crate) fn new(io: &IoRef) -> Self {
366        Self(io.clone())
367    }
368
369    #[doc(hidden)]
370    #[inline]
371    pub fn id(&self) -> usize {
372        self.0 .0.as_ref() as *const _ as usize
373    }
374
375    #[inline]
376    /// Io tag
377    pub fn tag(&self) -> &'static str {
378        self.0.tag()
379    }
380
381    #[doc(hidden)]
382    /// Io flags
383    pub fn flags(&self) -> crate::flags::Flags {
384        self.0.flags()
385    }
386
387    #[inline]
388    /// Check readiness for read operations
389    pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Readiness> {
390        self.shutdown_filters();
391        self.0.filter().poll_read_ready(cx)
392    }
393
394    #[inline]
395    /// Check readiness for write operations
396    pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<Readiness> {
397        self.0.filter().poll_write_ready(cx)
398    }
399
400    #[inline]
401    /// Stop io
402    pub fn stop(&self, e: Option<io::Error>) {
403        self.0 .0.io_stopped(e);
404    }
405
406    #[deprecated(since = "2.14.1")]
407    #[doc(hidden)]
408    #[inline]
409    /// Initiate gracefully shutdown
410    pub fn init_shutdown(&self) {
411        self.0 .0.init_shutdown();
412    }
413
414    #[inline]
415    /// Check if Io stopped
416    pub fn is_stopped(&self) -> bool {
417        self.0.flags().is_stopped()
418    }
419
420    /// Wait when io get closed or preparing for close
421    pub fn shutdown(&self, flush: bool, cx: &mut Context<'_>) -> Poll<()> {
422        let st = &self.0 .0;
423
424        let flags = self.0.flags();
425        if !flags.intersects(Flags::IO_STOPPING | Flags::IO_STOPPED) {
426            st.write_task.register(cx.waker());
427            return Poll::Pending;
428        }
429
430        if flush && !flags.contains(Flags::IO_STOPPED) {
431            if flags.intersects(Flags::WR_PAUSED | Flags::IO_STOPPED) {
432                return Poll::Ready(());
433            }
434            st.insert_flags(Flags::WR_TASK_WAIT);
435            st.write_task.register(cx.waker());
436            Poll::Pending
437        } else {
438            Poll::Ready(())
439        }
440    }
441
442    /// Get read buffer
443    pub fn get_read_buf(&self) -> BytesVec {
444        let inner = &self.0 .0;
445
446        if inner.flags.get().is_read_buf_ready() {
447            // read buffer is still not read by dispatcher
448            // we cannot touch it
449            inner.pool.get().get_read_buf()
450        } else {
451            inner
452                .buffer
453                .get_read_source()
454                .unwrap_or_else(|| inner.pool.get().get_read_buf())
455        }
456    }
457
    /// Set read buffer
    ///
    /// Stores `buf` back as the io read source after a read attempt and, when
    /// `nbytes > 0`, lets the filter chain process the new data.
    ///
    /// * `nbytes` - number of bytes reported as newly read; `0` combined with
    ///   `Poll::Ready(Ok(_))` stops the io without an error
    /// * `buf` - buffer that was read into
    /// * `result` - outcome of the read operation
    ///
    /// Returns [`IoTaskStatus::Pause`] when the io got stopped (filter error,
    /// read error, zero-byte read) or the destination buffer reached the high
    /// watermark, and [`IoTaskStatus::Io`] otherwise.
    pub fn release_read_buf(
        &self,
        nbytes: usize,
        buf: BytesVec,
        result: Poll<io::Result<()>>,
    ) -> IoTaskStatus {
        let inner = &self.0 .0;
        // destination size before processing, used to detect new data
        let orig_size = inner.buffer.read_destination_size();
        let hw = self.0.memory_pool().read_params().unpack().0;

        // put the buffer back; if a source buffer is present, append to it
        if let Some(mut first_buf) = inner.buffer.get_read_source() {
            first_buf.extend_from_slice(&buf);
            inner.buffer.set_read_source(&self.0, first_buf);
        } else {
            inner.buffer.set_read_source(&self.0, buf);
        }

        // set when the destination buffer is at or above the high watermark
        let mut full = false;

        // handle buffer changes
        let st_res = if nbytes > 0 {
            match self
                .0
                .filter()
                .process_read_buf(FilterCtx::new(&self.0, &inner.buffer), nbytes)
            {
                Ok(status) => {
                    let buffer_size = inner.buffer.read_destination_size();
                    if buffer_size.saturating_sub(orig_size) > 0 {
                        // dest buffer has new data, wake up dispatcher
                        if buffer_size >= hw {
                            log::trace!(
                                "{}: Io read buffer is too large {}, enable read back-pressure",
                                self.tag(),
                                buffer_size
                            );
                            full = true;
                            inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL);
                        } else {
                            inner.insert_flags(Flags::BUF_R_READY);
                        }
                        log::trace!(
                            "{}: New {} bytes available, wakeup dispatcher",
                            self.tag(),
                            buffer_size
                        );
                        inner.dispatch_task.wake();
                    } else {
                        if buffer_size >= hw {
                            // read task is paused because of read back-pressure
                            // but there is no new data in top most read buffer
                            // so we need to wake up read task to read more data
                            // otherwise read task would sleep forever
                            full = true;
                            inner.read_task.wake();
                        }
                        if inner.flags.get().is_waiting_for_read() {
                            // in case of "notify" we must wake up dispatch task
                            // if we read any data from source
                            inner.dispatch_task.wake();
                        }
                    }

                    // while reading, filter wrote some data
                    // in that case filters need to process write buffers
                    // and potentially wake write task
                    if status.need_write {
                        self.0
                            .filter()
                            .process_write_buf(FilterCtx::new(&self.0, &inner.buffer))
                    } else {
                        Ok(())
                    }
                }
                Err(err) => Err(err),
            }
        } else {
            Ok(())
        };

        // combine the read result with the filter processing result
        match result {
            Poll::Ready(Ok(_)) => {
                if let Err(e) = st_res {
                    inner.io_stopped(Some(e));
                    IoTaskStatus::Pause
                } else if nbytes == 0 {
                    // successful read of zero bytes: stop io without an error
                    inner.io_stopped(None);
                    IoTaskStatus::Pause
                } else {
                    self.shutdown_filters();
                    if full {
                        IoTaskStatus::Pause
                    } else {
                        IoTaskStatus::Io
                    }
                }
            }
            Poll::Ready(Err(e)) => {
                inner.io_stopped(Some(e));
                IoTaskStatus::Pause
            }
            Poll::Pending => {
                if let Err(e) = st_res {
                    inner.io_stopped(Some(e));
                    IoTaskStatus::Pause
                } else {
                    self.shutdown_filters();
                    if full {
                        IoTaskStatus::Pause
                    } else {
                        IoTaskStatus::Io
                    }
                }
            }
        }
    }
575
576    /// Get write buffer
577    pub fn get_write_buf(&self) -> Option<BytesVec> {
578        self.0 .0.buffer.get_write_destination().and_then(|buf| {
579            if buf.is_empty() {
580                None
581            } else {
582                Some(buf)
583            }
584        })
585    }
586
    /// Set write buffer
    ///
    /// Hands `buf` back after a write attempt and updates the write state.
    ///
    /// * `buf` - buffer that was being written
    /// * `result` - outcome of the write: `Poll::Ready(Ok(n))` means `n` bytes
    ///   of `buf` were written (`n == 0` is reported as a `WriteZero` error),
    ///   `Poll::Pending` means nothing was consumed
    ///
    /// Returns [`IoTaskStatus::Pause`] when no data is left to write, on error,
    /// or when the io is stopped; [`IoTaskStatus::Io`] otherwise.
    pub fn release_write_buf(
        &self,
        mut buf: BytesVec,
        result: Poll<io::Result<usize>>,
    ) -> IoTaskStatus {
        // translate the write outcome into "bytes of `buf` still unwritten"
        let result = match result {
            Poll::Ready(Ok(0)) => {
                log::trace!("{}: Disconnected during flush", self.tag());
                Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write frame to transport",
                ))
            }
            Poll::Ready(Ok(n)) => {
                if n == buf.len() {
                    buf.clear();
                    Ok(0)
                } else {
                    // partial write: drop the written prefix
                    buf.advance(n);
                    Ok(buf.len())
                }
            }
            Poll::Ready(Err(e)) => Err(e),
            Poll::Pending => Ok(buf.len()),
        };

        let inner = &self.0 .0;

        // set buffer back
        let result = match result {
            Ok(0) => {
                // `buf` is fully written; what remains is whatever the io
                // write destination holds
                self.0.memory_pool().release_write_buf(buf);
                Ok(inner.buffer.write_destination_size())
            }
            Ok(_) => {
                // append the destination buffer's bytes after the unwritten
                // remainder, then install the merged buffer
                if let Some(b) = inner.buffer.get_write_destination() {
                    buf.extend_from_slice(&b);
                    self.0.memory_pool().release_write_buf(b);
                }
                let l = buf.len();
                inner.buffer.set_write_destination(buf);
                Ok(l)
            }
            Err(e) => Err(e),
        };

        match result {
            Ok(0) => {
                let mut flags = inner.flags.get();

                // all data has been written
                flags.insert(Flags::WR_PAUSED);

                // wake whoever is waiting for the write to complete
                if flags.is_task_waiting_for_write() {
                    flags.task_waiting_for_write_is_done();
                    inner.write_task.wake();
                }

                if flags.is_waiting_for_write() {
                    flags.waiting_for_write_is_done();
                    inner.dispatch_task.wake();
                }
                inner.flags.set(flags);
                IoTaskStatus::Pause
            }
            Ok(len) => {
                // if write buffer is smaller than high watermark value, turn off back-pressure
                if inner.flags.get().contains(Flags::BUF_W_BACKPRESSURE)
                    && len < inner.pool.get().write_params_high() << 1
                {
                    inner.remove_flags(Flags::BUF_W_BACKPRESSURE);
                    inner.dispatch_task.wake();
                }
                if self.is_stopped() {
                    IoTaskStatus::Pause
                } else {
                    IoTaskStatus::Io
                }
            }
            Err(e) => {
                inner.io_stopped(Some(e));
                IoTaskStatus::Pause
            }
        }
    }
673
674    fn shutdown_filters(&self) {
675        let io = &self.0;
676        let st = &self.0 .0;
677        let flags = st.flags.get();
678        if flags.contains(Flags::IO_STOPPING_FILTERS)
679            && !flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING)
680        {
681            match io.filter().shutdown(FilterCtx::new(io, &st.buffer)) {
682                Ok(Poll::Ready(())) => {
683                    st.write_task.wake();
684                    st.dispatch_task.wake();
685                    st.insert_flags(Flags::IO_STOPPING);
686                }
687                Ok(Poll::Pending) => {
688                    // check read buffer, if buffer is not consumed it is unlikely
689                    // that filter will properly complete shutdown
690                    let flags = st.flags.get();
691                    if flags.contains(Flags::RD_PAUSED)
692                        || flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY)
693                    {
694                        st.write_task.wake();
695                        st.dispatch_task.wake();
696                        st.insert_flags(Flags::IO_STOPPING);
697                    }
698                }
699                Err(err) => {
700                    st.io_stopped(Some(err));
701                }
702            }
703            if let Err(err) = io
704                .filter()
705                .process_write_buf(FilterCtx::new(io, &st.buffer))
706            {
707                st.io_stopped(Some(err));
708            }
709        }
710    }
711}
712
713impl Clone for IoContext {
714    fn clone(&self) -> Self {
715        Self(self.0.clone())
716    }
717}