ntex_tokio/
io.rs

1use std::task::{Context, Poll, ready};
2use std::{any, cell::RefCell, cmp, future::poll_fn, io, mem, pin::Pin, rc::Rc, rc::Weak};
3
4use ntex_bytes::{BufMut, BytesVec};
5use ntex_io::{
6    Filter, Handle, Io, IoBoxed, IoContext, IoStream, IoTaskStatus, Readiness, types,
7};
8use ntex_util::time::Millis;
9use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10use tokio::net::TcpStream;
11
12impl IoStream for crate::TcpStream {
13    fn start(self, ctx: IoContext) -> Option<Box<dyn Handle>> {
14        let io = Rc::new(RefCell::new(self.0));
15        tokio::task::spawn_local(run(io.clone(), ctx));
16        Some(Box::new(HandleWrapper(io)))
17    }
18}
19
20#[cfg(unix)]
21impl IoStream for crate::UnixStream {
22    fn start(self, ctx: IoContext) -> Option<Box<dyn Handle>> {
23        let io = Rc::new(RefCell::new(self.0));
24        tokio::task::spawn_local(run(io.clone(), ctx));
25        None
26    }
27}
28
29struct HandleWrapper(Rc<RefCell<TcpStream>>);
30
31impl Handle for HandleWrapper {
32    fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
33        if id == any::TypeId::of::<types::PeerAddr>() {
34            if let Ok(addr) = self.0.borrow().peer_addr() {
35                return Some(Box::new(types::PeerAddr(addr)));
36            }
37        } else if id == any::TypeId::of::<SocketOptions>() {
38            return Some(Box::new(SocketOptions(Rc::downgrade(&self.0))));
39        }
40        None
41    }
42}
43
/// Final outcome of the I/O loop; `run` uses it to choose how to shut down.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum Status {
    /// Graceful end: `run` requests a flushing shutdown (`flush = true`).
    Shutdown,
    /// Abrupt end: `run` requests a shutdown without flushing.
    Terminate,
}
49
50async fn run<T>(io: Rc<RefCell<T>>, ctx: IoContext)
51where
52    T: AsyncRead + AsyncWrite + Unpin,
53{
54    let st = poll_fn(|cx| turn(&mut *io.borrow_mut(), &ctx, cx)).await;
55
56    log::trace!("{}: Shuting down io {:?}", ctx.tag(), ctx.is_stopped());
57    if !ctx.is_stopped() {
58        let flush = st == Status::Shutdown;
59        let _ = poll_fn(|cx| {
60            if write(&mut *io.borrow_mut(), &ctx, cx) == Poll::Ready(Status::Terminate) {
61                Poll::Ready(())
62            } else {
63                ctx.shutdown(flush, cx)
64            }
65        })
66        .await;
67    }
68
69    let _ = poll_fn(|cx| Pin::new(&mut *io.borrow_mut()).poll_shutdown(cx)).await;
70
71    log::trace!("{}: Shutdown complete", ctx.tag());
72    if !ctx.is_stopped() {
73        ctx.stop(None);
74    }
75}
76
77fn turn<T>(io: &mut T, ctx: &IoContext, cx: &mut Context<'_>) -> Poll<Status>
78where
79    T: AsyncRead + AsyncWrite + Unpin,
80{
81    let read = match ctx.poll_read_ready(cx) {
82        Poll::Ready(Readiness::Ready) => read(io, ctx, cx),
83        Poll::Ready(Readiness::Shutdown) | Poll::Ready(Readiness::Terminate) => {
84            Poll::Ready(())
85        }
86        Poll::Pending => Poll::Pending,
87    };
88
89    let write = match ctx.poll_write_ready(cx) {
90        Poll::Ready(Readiness::Ready) => write(io, ctx, cx),
91        Poll::Ready(Readiness::Shutdown) => Poll::Ready(Status::Shutdown),
92        Poll::Ready(Readiness::Terminate) => Poll::Ready(Status::Terminate),
93        Poll::Pending => Poll::Pending,
94    };
95
96    if read.is_pending() && write.is_pending() {
97        Poll::Pending
98    } else if write.is_ready() {
99        write
100    } else {
101        Poll::Ready(Status::Terminate)
102    }
103}
104
105fn write<T>(io: &mut T, ctx: &IoContext, cx: &mut Context<'_>) -> Poll<Status>
106where
107    T: AsyncRead + AsyncWrite + Unpin,
108{
109    if let Some(mut buf) = ctx.get_write_buf() {
110        let result = write_io(io, &mut buf, cx);
111        if ctx.release_write_buf(buf, result) == IoTaskStatus::Stop {
112            Poll::Ready(Status::Terminate)
113        } else {
114            Poll::Pending
115        }
116    } else {
117        Poll::Pending
118    }
119}
120
121fn read<T: AsyncRead + Unpin>(
122    io: &mut T,
123    ctx: &IoContext,
124    cx: &mut Context<'_>,
125) -> Poll<()> {
126    let mut buf = ctx.get_read_buf();
127
128    // read data from socket
129    let mut n = 0;
130    loop {
131        ctx.resize_read_buf(&mut buf);
132        let result = match read_buf(Pin::new(&mut *io), cx, &mut buf) {
133            Poll::Pending => {
134                if n > 0 {
135                    Poll::Ready(Ok(()))
136                } else {
137                    Poll::Pending
138                }
139            }
140            Poll::Ready(Ok(0)) => Poll::Ready(Err(None)),
141            Poll::Ready(Ok(size)) => {
142                n += size;
143                continue;
144            }
145            Poll::Ready(Err(err)) => Poll::Ready(Err(Some(err))),
146        };
147
148        return if matches!(ctx.release_read_buf(n, buf, result), IoTaskStatus::Stop) {
149            Poll::Ready(())
150        } else {
151            Poll::Pending
152        };
153    }
154}
155
156fn read_buf<T: AsyncRead>(
157    io: Pin<&mut T>,
158    cx: &mut Context<'_>,
159    buf: &mut BytesVec,
160) -> Poll<io::Result<usize>> {
161    let n = {
162        let dst =
163            unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [mem::MaybeUninit<u8>]) };
164        let mut buf = ReadBuf::uninit(dst);
165        let ptr = buf.filled().as_ptr();
166        if io.poll_read(cx, &mut buf)?.is_pending() {
167            return Poll::Pending;
168        }
169
170        // Ensure the pointer does not change from under us
171        assert_eq!(ptr, buf.filled().as_ptr());
172        buf.filled().len()
173    };
174
175    // Safety: This is guaranteed to be the number of initialized (and read)
176    // bytes due to the invariants provided by `ReadBuf::filled`.
177    unsafe {
178        buf.advance_mut(n);
179    }
180
181    Poll::Ready(Ok(n))
182}
183
184/// Flush write buffer to underlying I/O stream.
185fn write_io<T: AsyncRead + AsyncWrite + Unpin>(
186    io: &mut T,
187    buf: &mut BytesVec,
188    cx: &mut Context<'_>,
189) -> Poll<io::Result<usize>> {
190    let len = buf.len();
191
192    if len != 0 {
193        // log::trace!("Flushing framed transport: {len:?}");
194
195        let mut written = 0;
196        while let Poll::Ready(n) = Pin::new(&mut *io).poll_write(cx, &buf[written..])? {
197            if n == 0 {
198                return Poll::Ready(Err(io::Error::new(
199                    io::ErrorKind::WriteZero,
200                    "failed to write frame to transport",
201                )));
202            } else {
203                written += n;
204                if written == len {
205                    break;
206                }
207            }
208        }
209        // log::trace!("flushed {written} bytes");
210
211        // flush
212        if written > 0 {
213            let _ = Pin::new(&mut *io).poll_flush(cx)?;
214            Poll::Ready(Ok(written))
215        } else {
216            Poll::Pending
217        }
218    } else {
219        Poll::Pending
220    }
221}
222
223pub struct TokioIoBoxed(IoBoxed);
224
225impl std::ops::Deref for TokioIoBoxed {
226    type Target = IoBoxed;
227
228    #[inline]
229    fn deref(&self) -> &Self::Target {
230        &self.0
231    }
232}
233
234impl From<IoBoxed> for TokioIoBoxed {
235    fn from(io: IoBoxed) -> TokioIoBoxed {
236        TokioIoBoxed(io)
237    }
238}
239
240impl<F: Filter> From<Io<F>> for TokioIoBoxed {
241    fn from(io: Io<F>) -> TokioIoBoxed {
242        TokioIoBoxed(IoBoxed::from(io))
243    }
244}
245
246impl AsyncRead for TokioIoBoxed {
247    fn poll_read(
248        self: Pin<&mut Self>,
249        cx: &mut Context<'_>,
250        buf: &mut ReadBuf<'_>,
251    ) -> Poll<io::Result<()>> {
252        let len = self.0.with_read_buf(|src| {
253            let len = cmp::min(src.len(), buf.remaining());
254            buf.put_slice(&src.split_to(len));
255            len
256        });
257
258        if len == 0 {
259            match ready!(self.0.poll_read_ready(cx)) {
260                Ok(Some(())) => Poll::Pending,
261                Err(e) => Poll::Ready(Err(e)),
262                Ok(None) => Poll::Ready(Ok(())),
263            }
264        } else {
265            Poll::Ready(Ok(()))
266        }
267    }
268}
269
270impl AsyncWrite for TokioIoBoxed {
271    fn poll_write(
272        self: Pin<&mut Self>,
273        _: &mut Context<'_>,
274        buf: &[u8],
275    ) -> Poll<io::Result<usize>> {
276        Poll::Ready(self.0.write(buf).map(|_| buf.len()))
277    }
278
279    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
280        self.as_ref().0.poll_flush(cx, false)
281    }
282
283    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
284        self.as_ref().0.poll_shutdown(cx)
285    }
286}
287
288/// Query TCP Io connections for a handle to set socket options
289pub struct SocketOptions(Weak<RefCell<TcpStream>>);
290
291impl SocketOptions {
292    pub fn set_linger(&self, dur: Option<Millis>) -> io::Result<()> {
293        self.try_self()
294            .and_then(|s| s.borrow().set_linger(dur.map(|d| d.into())))
295    }
296
297    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
298        self.try_self().and_then(|s| s.borrow().set_ttl(ttl))
299    }
300
301    fn try_self(&self) -> io::Result<Rc<RefCell<TcpStream>>> {
302        self.0
303            .upgrade()
304            .ok_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "socket is gone"))
305    }
306}