// ntex_io/tasks.rs

use std::{cell::Cell, fmt, future::poll_fn, io, task::Context, task::Poll};

use ntex_bytes::{Buf, BufMut, BytesVec};
use ntex_util::{future::lazy, future::select, future::Either, time::sleep, time::Sleep};

use crate::{AsyncRead, AsyncWrite, Flags, IoRef, IoTaskStatus, Readiness};

8/// Context for io read task
9pub struct ReadContext(IoRef, Cell<Option<Sleep>>);
10
11impl fmt::Debug for ReadContext {
12    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
13        f.debug_struct("ReadContext").field("io", &self.0).finish()
14    }
15}
16
17impl ReadContext {
18    pub(crate) fn new(io: &IoRef) -> Self {
19        Self(io.clone(), Cell::new(None))
20    }
21
22    #[doc(hidden)]
23    #[inline]
24    /// Io tag
25    pub fn context(&self) -> IoContext {
26        IoContext::new(&self.0)
27    }
28
29    #[inline]
30    /// Io tag
31    pub fn tag(&self) -> &'static str {
32        self.0.tag()
33    }
34
35    /// Wait when io get closed or preparing for close
36    async fn wait_for_close(&self) {
37        poll_fn(|cx| {
38            let flags = self.0.flags();
39
40            if flags.intersects(Flags::IO_STOPPING | Flags::IO_STOPPED) {
41                Poll::Ready(())
42            } else {
43                self.0 .0.read_task.register(cx.waker());
44                if flags.contains(Flags::IO_STOPPING_FILTERS) {
45                    self.shutdown_filters(cx);
46                }
47                Poll::Pending
48            }
49        })
50        .await
51    }
52
53    /// Handle read io operations
54    pub async fn handle<T>(&self, io: &mut T)
55    where
56        T: AsyncRead,
57    {
58        let inner = &self.0 .0;
59
60        loop {
61            let result = poll_fn(|cx| self.0.filter().poll_read_ready(cx)).await;
62            if result == Readiness::Terminate {
63                log::trace!("{}: Read task is instructed to shutdown", self.tag());
64                break;
65            }
66
67            let mut buf = if inner.flags.get().is_read_buf_ready() {
68                // read buffer is still not read by dispatcher
69                // we cannot touch it
70                inner.pool.get().get_read_buf()
71            } else {
72                inner
73                    .buffer
74                    .get_read_source()
75                    .unwrap_or_else(|| inner.pool.get().get_read_buf())
76            };
77
78            // make sure we've got room
79            let (hw, lw) = self.0.memory_pool().read_params().unpack();
80            let remaining = buf.remaining_mut();
81            if remaining <= lw {
82                buf.reserve(hw - remaining);
83            }
84            let total = buf.len();
85
86            // call provided callback
87            let (buf, result) = match select(io.read(buf), self.wait_for_close()).await {
88                Either::Left(res) => res,
89                Either::Right(_) => {
90                    log::trace!("{}: Read io is closed, stop read task", self.tag());
91                    break;
92                }
93            };
94
95            // handle incoming data
96            let total2 = buf.len();
97            let nbytes = total2.saturating_sub(total);
98            let total = total2;
99
100            if let Some(mut first_buf) = inner.buffer.get_read_source() {
101                first_buf.extend_from_slice(&buf);
102                inner.buffer.set_read_source(&self.0, first_buf);
103            } else {
104                inner.buffer.set_read_source(&self.0, buf);
105            }
106
107            // handle buffer changes
108            if nbytes > 0 {
109                let filter = self.0.filter();
110                let res = match filter.process_read_buf(&self.0, &inner.buffer, 0, nbytes) {
111                    Ok(status) => {
112                        if status.nbytes > 0 {
113                            // check read back-pressure
114                            if hw < inner.buffer.read_destination_size() {
115                                log::trace!(
116                                "{}: Io read buffer is too large {}, enable read back-pressure",
117                                self.0.tag(),
118                                total
119                            );
120                                inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL);
121                            } else {
122                                inner.insert_flags(Flags::BUF_R_READY);
123                            }
124                            log::trace!(
125                                "{}: New {} bytes available, wakeup dispatcher",
126                                self.0.tag(),
127                                nbytes
128                            );
129                            // dest buffer has new data, wake up dispatcher
130                            inner.dispatch_task.wake();
131                        } else if inner.flags.get().is_waiting_for_read() {
132                            // in case of "notify" we must wake up dispatch task
133                            // if we read any data from source
134                            inner.dispatch_task.wake();
135                        }
136
137                        // while reading, filter wrote some data
138                        // in that case filters need to process write buffers
139                        // and potentialy wake write task
140                        if status.need_write {
141                            filter.process_write_buf(&self.0, &inner.buffer, 0)
142                        } else {
143                            Ok(())
144                        }
145                    }
146                    Err(err) => Err(err),
147                };
148
149                if let Err(err) = res {
150                    inner.dispatch_task.wake();
151                    inner.io_stopped(Some(err));
152                    inner.insert_flags(Flags::BUF_R_READY);
153                }
154            }
155
156            match result {
157                Ok(0) => {
158                    log::trace!("{}: Tcp stream is disconnected", self.tag());
159                    inner.io_stopped(None);
160                    break;
161                }
162                Ok(_) => {
163                    if inner.flags.get().contains(Flags::IO_STOPPING_FILTERS) {
164                        lazy(|cx| self.shutdown_filters(cx)).await;
165                    }
166                }
167                Err(err) => {
168                    log::trace!("{}: Read task failed on io {:?}", self.tag(), err);
169                    inner.io_stopped(Some(err));
170                    break;
171                }
172            }
173        }
174    }
175
176    fn shutdown_filters(&self, cx: &mut Context<'_>) {
177        let st = &self.0 .0;
178        let filter = self.0.filter();
179
180        match filter.shutdown(&self.0, &st.buffer, 0) {
181            Ok(Poll::Ready(())) => {
182                st.dispatch_task.wake();
183                st.insert_flags(Flags::IO_STOPPING);
184            }
185            Ok(Poll::Pending) => {
186                let flags = st.flags.get();
187
188                // check read buffer, if buffer is not consumed it is unlikely
189                // that filter will properly complete shutdown
190                if flags.contains(Flags::RD_PAUSED)
191                    || flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY)
192                {
193                    st.dispatch_task.wake();
194                    st.insert_flags(Flags::IO_STOPPING);
195                } else {
196                    // filter shutdown timeout
197                    let timeout = self
198                        .1
199                        .take()
200                        .unwrap_or_else(|| sleep(st.disconnect_timeout.get()));
201                    if timeout.poll_elapsed(cx).is_ready() {
202                        st.dispatch_task.wake();
203                        st.insert_flags(Flags::IO_STOPPING);
204                    } else {
205                        self.1.set(Some(timeout));
206                    }
207                }
208            }
209            Err(err) => {
210                st.io_stopped(Some(err));
211            }
212        }
213        if let Err(err) = filter.process_write_buf(&self.0, &st.buffer, 0) {
214            st.io_stopped(Some(err));
215        }
216    }
217}
218
219#[derive(Debug)]
220/// Context for io write task
221pub struct WriteContext(IoRef);
222
223#[derive(Debug)]
224/// Context buf for io write task
225pub struct WriteContextBuf {
226    io: IoRef,
227    buf: Option<BytesVec>,
228}
229
230impl WriteContext {
231    pub(crate) fn new(io: &IoRef) -> Self {
232        Self(io.clone())
233    }
234
235    #[inline]
236    /// Io tag
237    pub fn tag(&self) -> &'static str {
238        self.0.tag()
239    }
240
241    /// Check readiness for write operations
242    async fn ready(&self) -> Readiness {
243        poll_fn(|cx| self.0.filter().poll_write_ready(cx)).await
244    }
245
246    /// Indicate that write io task is stopped
247    fn close(&self, err: Option<io::Error>) {
248        self.0 .0.io_stopped(err);
249    }
250
251    /// Check if io is closed
252    async fn when_stopped(&self) {
253        poll_fn(|cx| {
254            if self.0.flags().is_stopped() {
255                Poll::Ready(())
256            } else {
257                self.0 .0.write_task.register(cx.waker());
258                Poll::Pending
259            }
260        })
261        .await
262    }
263
264    /// Handle write io operations
265    pub async fn handle<T>(&self, io: &mut T)
266    where
267        T: AsyncWrite,
268    {
269        let mut buf = WriteContextBuf {
270            io: self.0.clone(),
271            buf: None,
272        };
273
274        loop {
275            match self.ready().await {
276                Readiness::Ready => {
277                    // write io stream
278                    match select(io.write(&mut buf), self.when_stopped()).await {
279                        Either::Left(Ok(_)) => continue,
280                        Either::Left(Err(e)) => self.close(Some(e)),
281                        Either::Right(_) => return,
282                    }
283                }
284                Readiness::Shutdown => {
285                    log::trace!("{}: Write task is instructed to shutdown", self.tag());
286
287                    let fut = async {
288                        // write io stream
289                        io.write(&mut buf).await?;
290                        io.flush().await?;
291                        io.shutdown().await?;
292                        Ok(())
293                    };
294                    match select(sleep(self.0 .0.disconnect_timeout.get()), fut).await {
295                        Either::Left(_) => self.close(None),
296                        Either::Right(res) => self.close(res.err()),
297                    }
298                }
299                Readiness::Terminate => {
300                    log::trace!("{}: Write task is instructed to terminate", self.tag());
301                    self.close(io.shutdown().await.err());
302                }
303            }
304            return;
305        }
306    }
307}
308
309impl WriteContextBuf {
310    pub fn set(&mut self, mut buf: BytesVec) {
311        if buf.is_empty() {
312            self.io.memory_pool().release_write_buf(buf);
313        } else if let Some(b) = self.buf.take() {
314            buf.extend_from_slice(&b);
315            self.io.memory_pool().release_write_buf(b);
316            self.buf = Some(buf);
317        } else if let Some(b) = self.io.0.buffer.set_write_destination(buf) {
318            // write buffer is already set
319            self.buf = Some(b);
320        }
321
322        // if write buffer is smaller than high watermark value, turn off back-pressure
323        let inner = &self.io.0;
324        let len = self.buf.as_ref().map(|b| b.len()).unwrap_or_default()
325            + inner.buffer.write_destination_size();
326        let mut flags = inner.flags.get();
327
328        if len == 0 {
329            if flags.is_waiting_for_write() {
330                flags.waiting_for_write_is_done();
331                inner.dispatch_task.wake();
332            }
333            flags.insert(Flags::WR_PAUSED);
334            inner.flags.set(flags);
335        } else if flags.contains(Flags::BUF_W_BACKPRESSURE)
336            && len < inner.pool.get().write_params_high() << 1
337        {
338            flags.remove(Flags::BUF_W_BACKPRESSURE);
339            inner.flags.set(flags);
340            inner.dispatch_task.wake();
341        }
342    }
343
344    pub fn take(&mut self) -> Option<BytesVec> {
345        if let Some(buf) = self.buf.take() {
346            Some(buf)
347        } else {
348            self.io.0.buffer.get_write_destination()
349        }
350    }
351}
352
353/// Context for io read task
354pub struct IoContext(IoRef);
355
356impl fmt::Debug for IoContext {
357    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
358        f.debug_struct("IoContext").field("io", &self.0).finish()
359    }
360}
361
362impl IoContext {
363    pub(crate) fn new(io: &IoRef) -> Self {
364        Self(io.clone())
365    }
366
367    #[inline]
368    /// Io tag
369    pub fn tag(&self) -> &'static str {
370        self.0.tag()
371    }
372
373    #[doc(hidden)]
374    /// Io flags
375    pub fn flags(&self) -> crate::flags::Flags {
376        self.0.flags()
377    }
378
379    #[inline]
380    /// Check readiness for read operations
381    pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Readiness> {
382        self.shutdown_filters();
383        self.0.filter().poll_read_ready(cx)
384    }
385
386    #[inline]
387    /// Check readiness for write operations
388    pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<Readiness> {
389        self.0.filter().poll_write_ready(cx)
390    }
391
392    #[inline]
393    /// Get io error
394    pub fn stopped(&self, e: Option<io::Error>) {
395        self.0 .0.io_stopped(e);
396    }
397
398    #[inline]
399    /// Check if Io stopped
400    pub fn is_stopped(&self) -> bool {
401        self.0.flags().is_stopped()
402    }
403
404    /// Wait when io get closed or preparing for close
405    pub async fn shutdown(&self, flush_buf: bool) {
406        let st = &self.0 .0;
407        let mut timeout = None;
408
409        poll_fn(|cx| {
410            let flags = self.0.flags();
411            if flags.intersects(Flags::IO_STOPPING | Flags::IO_STOPPED) {
412                Poll::Ready(())
413            } else {
414                st.write_task.register(cx.waker());
415                if flags.contains(Flags::IO_STOPPING_FILTERS) {
416                    if timeout.is_none() {
417                        timeout = Some(sleep(st.disconnect_timeout.get()));
418                    }
419                    if timeout.as_ref().unwrap().poll_elapsed(cx).is_ready() {
420                        st.dispatch_task.wake();
421                        st.insert_flags(Flags::IO_STOPPING);
422                        return Poll::Ready(());
423                    }
424                }
425                Poll::Pending
426            }
427        })
428        .await;
429
430        if flush_buf && !st.flags.get().contains(Flags::WR_PAUSED) {
431            st.insert_flags(Flags::WR_TASK_WAIT);
432
433            poll_fn(|cx| {
434                let flags = st.flags.get();
435                if flags.intersects(Flags::WR_PAUSED | Flags::IO_STOPPED) {
436                    Poll::Ready(())
437                } else {
438                    st.write_task.register(cx.waker());
439                    if timeout.is_none() {
440                        timeout = Some(sleep(st.disconnect_timeout.get()));
441                    }
442                    if timeout.as_ref().unwrap().poll_elapsed(cx).is_ready() {
443                        Poll::Ready(())
444                    } else {
445                        Poll::Pending
446                    }
447                }
448            })
449            .await;
450        }
451    }
452
453    /// Get read buffer
454    pub fn get_read_buf(&self) -> (BytesVec, usize, usize) {
455        let inner = &self.0 .0;
456
457        let buf = if inner.flags.get().is_read_buf_ready() {
458            // read buffer is still not read by dispatcher
459            // we cannot touch it
460            inner.pool.get().get_read_buf()
461        } else {
462            inner
463                .buffer
464                .get_read_source()
465                .unwrap_or_else(|| inner.pool.get().get_read_buf())
466        };
467
468        // make sure we've got room
469        let (hw, lw) = self.0.memory_pool().read_params().unpack();
470        (buf, hw, lw)
471    }
472
473    /// Set read buffer
474    pub fn release_read_buf(
475        &self,
476        nbytes: usize,
477        buf: BytesVec,
478        result: Poll<io::Result<()>>,
479    ) -> IoTaskStatus {
480        let inner = &self.0 .0;
481        let orig_size = inner.buffer.read_destination_size();
482        let hw = self.0.memory_pool().read_params().unpack().0;
483
484        if let Some(mut first_buf) = inner.buffer.get_read_source() {
485            first_buf.extend_from_slice(&buf);
486            inner.buffer.set_read_source(&self.0, first_buf);
487        } else {
488            inner.buffer.set_read_source(&self.0, buf);
489        }
490
491        // handle buffer changes
492        let st_res = if nbytes > 0 {
493            match self
494                .0
495                .filter()
496                .process_read_buf(&self.0, &inner.buffer, 0, nbytes)
497            {
498                Ok(status) => {
499                    let buffer_size = inner.buffer.read_destination_size();
500                    if buffer_size.saturating_sub(orig_size) > 0 {
501                        // dest buffer has new data, wake up dispatcher
502                        if buffer_size >= hw {
503                            log::trace!(
504                                "{}: Io read buffer is too large {}, enable read back-pressure",
505                                self.tag(),
506                                buffer_size
507                            );
508                            inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL);
509                        } else {
510                            inner.insert_flags(Flags::BUF_R_READY);
511                        }
512                        log::trace!(
513                            "{}: New {} bytes available, wakeup dispatcher",
514                            self.tag(),
515                            buffer_size
516                        );
517                        inner.dispatch_task.wake();
518                    } else {
519                        if buffer_size >= hw {
520                            // read task is paused because of read back-pressure
521                            // but there is no new data in top most read buffer
522                            // so we need to wake up read task to read more data
523                            // otherwise read task would sleep forever
524                            inner.read_task.wake();
525                        }
526                        if inner.flags.get().is_waiting_for_read() {
527                            // in case of "notify" we must wake up dispatch task
528                            // if we read any data from source
529                            inner.dispatch_task.wake();
530                        }
531                    }
532
533                    // while reading, filter wrote some data
534                    // in that case filters need to process write buffers
535                    // and potentialy wake write task
536                    if status.need_write {
537                        self.0.filter().process_write_buf(&self.0, &inner.buffer, 0)
538                    } else {
539                        Ok(())
540                    }
541                }
542                Err(err) => {
543                    inner.insert_flags(Flags::BUF_R_READY);
544                    Err(err)
545                }
546            }
547        } else {
548            Ok(())
549        };
550
551        match result {
552            Poll::Ready(Ok(_)) => {
553                if let Err(e) = st_res {
554                    inner.io_stopped(Some(e));
555                    IoTaskStatus::Pause
556                } else if nbytes == 0 {
557                    inner.io_stopped(None);
558                    IoTaskStatus::Pause
559                } else {
560                    IoTaskStatus::Io
561                }
562            }
563            Poll::Ready(Err(e)) => {
564                inner.io_stopped(Some(e));
565                IoTaskStatus::Pause
566            }
567            Poll::Pending => {
568                if let Err(e) = st_res {
569                    inner.io_stopped(Some(e));
570                    IoTaskStatus::Pause
571                } else {
572                    self.shutdown_filters();
573                    IoTaskStatus::Io
574                }
575            }
576        }
577    }
578
579    /// Get write buffer
580    pub fn get_write_buf(&self) -> Option<BytesVec> {
581        self.0 .0.buffer.get_write_destination().and_then(|buf| {
582            if buf.is_empty() {
583                None
584            } else {
585                Some(buf)
586            }
587        })
588    }
589
590    /// Set write buffer
591    pub fn release_write_buf(
592        &self,
593        mut buf: BytesVec,
594        result: Poll<io::Result<usize>>,
595    ) -> IoTaskStatus {
596        let result = match result {
597            Poll::Ready(Ok(0)) => {
598                log::trace!("{}: Disconnected during flush", self.tag());
599                Err(io::Error::new(
600                    io::ErrorKind::WriteZero,
601                    "failed to write frame to transport",
602                ))
603            }
604            Poll::Ready(Ok(n)) => {
605                if n == buf.len() {
606                    buf.clear();
607                    Ok(0)
608                } else {
609                    buf.advance(n);
610                    Ok(buf.len())
611                }
612            }
613            Poll::Ready(Err(e)) => Err(e),
614            Poll::Pending => Ok(buf.len()),
615        };
616
617        let inner = &self.0 .0;
618
619        // set buffer back
620        let result = match result {
621            Ok(0) => {
622                self.0.memory_pool().release_write_buf(buf);
623                Ok(inner.buffer.write_destination_size())
624            }
625            Ok(_) => {
626                if let Some(b) = inner.buffer.get_write_destination() {
627                    buf.extend_from_slice(&b);
628                    self.0.memory_pool().release_write_buf(b);
629                }
630                let l = buf.len();
631                inner.buffer.set_write_destination(buf);
632                Ok(l)
633            }
634            Err(e) => Err(e),
635        };
636
637        match result {
638            Ok(0) => {
639                let mut flags = inner.flags.get();
640
641                // all data has been written
642                flags.insert(Flags::WR_PAUSED);
643
644                if flags.is_task_waiting_for_write() {
645                    flags.task_waiting_for_write_is_done();
646                    inner.write_task.wake();
647                }
648
649                if flags.is_waiting_for_write() {
650                    flags.waiting_for_write_is_done();
651                    inner.dispatch_task.wake();
652                }
653                inner.flags.set(flags);
654                IoTaskStatus::Pause
655            }
656            Ok(len) => {
657                // if write buffer is smaller than high watermark value, turn off back-pressure
658                if inner.flags.get().contains(Flags::BUF_W_BACKPRESSURE)
659                    && len < inner.pool.get().write_params_high() << 1
660                {
661                    inner.remove_flags(Flags::BUF_W_BACKPRESSURE);
662                    inner.dispatch_task.wake();
663                }
664                IoTaskStatus::Io
665            }
666            Err(e) => {
667                inner.io_stopped(Some(e));
668                IoTaskStatus::Pause
669            }
670        }
671    }
672
673    fn shutdown_filters(&self) {
674        let io = &self.0;
675        let st = &self.0 .0;
676        let flags = st.flags.get();
677        if flags.contains(Flags::IO_STOPPING_FILTERS)
678            && !flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING)
679        {
680            match io.filter().shutdown(io, &st.buffer, 0) {
681                Ok(Poll::Ready(())) => {
682                    st.dispatch_task.wake();
683                    st.insert_flags(Flags::IO_STOPPING);
684                }
685                Ok(Poll::Pending) => {
686                    // check read buffer, if buffer is not consumed it is unlikely
687                    // that filter will properly complete shutdown
688                    let flags = st.flags.get();
689                    if flags.contains(Flags::RD_PAUSED)
690                        || flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY)
691                    {
692                        st.dispatch_task.wake();
693                        st.insert_flags(Flags::IO_STOPPING);
694                    }
695                }
696                Err(err) => {
697                    st.io_stopped(Some(err));
698                }
699            }
700            if let Err(err) = io.filter().process_write_buf(io, &st.buffer, 0) {
701                st.io_stopped(Some(err));
702            }
703        }
704    }
705}
706
707impl Clone for IoContext {
708    fn clone(&self) -> Self {
709        Self(self.0.clone())
710    }
711}