1use std::task::{Context, Poll, ready};
2use std::{any, cell::RefCell, cmp, future::poll_fn, io, mem, pin::Pin, rc::Rc, rc::Weak};
3
4use ntex_bytes::{BufMut, BytesVec};
5use ntex_io::{
6 Filter, Handle, Io, IoBoxed, IoContext, IoStream, IoTaskStatus, Readiness, types,
7};
8use ntex_util::time::Millis;
9use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10use tokio::net::TcpStream;
11
12impl IoStream for crate::TcpStream {
13 fn start(self, ctx: IoContext) -> Option<Box<dyn Handle>> {
14 let io = Rc::new(RefCell::new(self.0));
15 tokio::task::spawn_local(run(io.clone(), ctx));
16 Some(Box::new(HandleWrapper(io)))
17 }
18}
19
20#[cfg(unix)]
21impl IoStream for crate::UnixStream {
22 fn start(self, ctx: IoContext) -> Option<Box<dyn Handle>> {
23 let io = Rc::new(RefCell::new(self.0));
24 tokio::task::spawn_local(run(io.clone(), ctx));
25 None
26 }
27}
28
/// Query handle for a TCP stream.
///
/// It shares ownership of the stream with the driver task spawned in
/// `IoStream::start` for `crate::TcpStream`.
struct HandleWrapper(Rc<RefCell<TcpStream>>);
30
31impl Handle for HandleWrapper {
32 fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
33 if id == any::TypeId::of::<types::PeerAddr>() {
34 if let Ok(addr) = self.0.borrow().peer_addr() {
35 return Some(Box::new(types::PeerAddr(addr)));
36 }
37 } else if id == any::TypeId::of::<SocketOptions>() {
38 return Some(Box::new(SocketOptions(Rc::downgrade(&self.0))));
39 }
40 None
41 }
42}
43
/// Outcome of the main I/O loop (`turn`); it decides how `run` shuts the stream down.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum Status {
    /// Graceful shutdown: `run` passes `flush = true` to `IoContext::shutdown`.
    Shutdown,
    /// Termination: `run` passes `flush = false`. A write reporting this also ends
    /// the shutdown wait early.
    Terminate,
}
49
50async fn run<T>(io: Rc<RefCell<T>>, ctx: IoContext)
51where
52 T: AsyncRead + AsyncWrite + Unpin,
53{
54 let st = poll_fn(|cx| turn(&mut *io.borrow_mut(), &ctx, cx)).await;
55
56 log::trace!("{}: Shuting down io {:?}", ctx.tag(), ctx.is_stopped());
57 if !ctx.is_stopped() {
58 let flush = st == Status::Shutdown;
59 let _ = poll_fn(|cx| {
60 if write(&mut *io.borrow_mut(), &ctx, cx) == Poll::Ready(Status::Terminate) {
61 Poll::Ready(())
62 } else {
63 ctx.shutdown(flush, cx)
64 }
65 })
66 .await;
67 }
68
69 let _ = poll_fn(|cx| Pin::new(&mut *io.borrow_mut()).poll_shutdown(cx)).await;
70
71 log::trace!("{}: Shutdown complete", ctx.tag());
72 if !ctx.is_stopped() {
73 ctx.stop(None);
74 }
75}
76
77fn turn<T>(io: &mut T, ctx: &IoContext, cx: &mut Context<'_>) -> Poll<Status>
78where
79 T: AsyncRead + AsyncWrite + Unpin,
80{
81 let read = match ctx.poll_read_ready(cx) {
82 Poll::Ready(Readiness::Ready) => read(io, ctx, cx),
83 Poll::Ready(Readiness::Shutdown) | Poll::Ready(Readiness::Terminate) => {
84 Poll::Ready(())
85 }
86 Poll::Pending => Poll::Pending,
87 };
88
89 let write = match ctx.poll_write_ready(cx) {
90 Poll::Ready(Readiness::Ready) => write(io, ctx, cx),
91 Poll::Ready(Readiness::Shutdown) => Poll::Ready(Status::Shutdown),
92 Poll::Ready(Readiness::Terminate) => Poll::Ready(Status::Terminate),
93 Poll::Pending => Poll::Pending,
94 };
95
96 if read.is_pending() && write.is_pending() {
97 Poll::Pending
98 } else if write.is_ready() {
99 write
100 } else {
101 Poll::Ready(Status::Terminate)
102 }
103}
104
105fn write<T>(io: &mut T, ctx: &IoContext, cx: &mut Context<'_>) -> Poll<Status>
106where
107 T: AsyncRead + AsyncWrite + Unpin,
108{
109 if let Some(mut buf) = ctx.get_write_buf() {
110 let result = write_io(io, &mut buf, cx);
111 if ctx.release_write_buf(buf, result) == IoTaskStatus::Stop {
112 Poll::Ready(Status::Terminate)
113 } else {
114 Poll::Pending
115 }
116 } else {
117 Poll::Pending
118 }
119}
120
121fn read<T: AsyncRead + Unpin>(
122 io: &mut T,
123 ctx: &IoContext,
124 cx: &mut Context<'_>,
125) -> Poll<()> {
126 let mut buf = ctx.get_read_buf();
127
128 let mut n = 0;
130 loop {
131 ctx.resize_read_buf(&mut buf);
132 let result = match read_buf(Pin::new(&mut *io), cx, &mut buf) {
133 Poll::Pending => {
134 if n > 0 {
135 Poll::Ready(Ok(()))
136 } else {
137 Poll::Pending
138 }
139 }
140 Poll::Ready(Ok(0)) => Poll::Ready(Err(None)),
141 Poll::Ready(Ok(size)) => {
142 n += size;
143 continue;
144 }
145 Poll::Ready(Err(err)) => Poll::Ready(Err(Some(err))),
146 };
147
148 return if matches!(ctx.release_read_buf(n, buf, result), IoTaskStatus::Stop) {
149 Poll::Ready(())
150 } else {
151 Poll::Pending
152 };
153 }
154}
155
/// Performs a single `poll_read` from `io` directly into the spare capacity of `buf`.
///
/// Returns `Ready(Ok(n))` with the number of bytes appended to `buf`,
/// `Ready(Err(_))` on an I/O error, or `Pending` if `io` is not readable yet.
///
/// NOTE(review): if `buf` has no spare capacity, `poll_read` receives an empty
/// `ReadBuf` and may report `Ready(Ok(0))`, which `read` treats as EOF. This
/// presumably relies on `IoContext::resize_read_buf` always leaving room; confirm.
fn read_buf<T: AsyncRead>(
    io: Pin<&mut T>,
    cx: &mut Context<'_>,
    buf: &mut BytesVec,
) -> Poll<io::Result<usize>> {
    let n = {
        // SAFETY: `chunk_mut` exposes the spare, possibly uninitialized, capacity of
        // `buf`. Viewing it as `[MaybeUninit<u8>]` lets `ReadBuf` track which bytes
        // become initialized without anything reading uninitialized memory. This
        // assumes the slice type returned by `chunk_mut` has the same layout as
        // `[MaybeUninit<u8>]` (a transparent wrapper); confirm in ntex_bytes.
        let dst =
            unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [mem::MaybeUninit<u8>]) };
        let mut buf = ReadBuf::uninit(dst);
        let ptr = buf.filled().as_ptr();
        if io.poll_read(cx, &mut buf)?.is_pending() {
            return Poll::Pending;
        }

        // Guard against a reader swapping the `ReadBuf`'s backing storage. If that
        // happened, `filled().len()` would not describe bytes written into `buf`.
        assert_eq!(ptr, buf.filled().as_ptr());
        buf.filled().len()
    };

    // SAFETY: `ReadBuf::filled` only covers bytes the reader initialized. The
    // pointer check above ensures those bytes sit at the start of `buf`'s spare
    // capacity, so advancing by `n` exposes only initialized data.
    unsafe {
        buf.advance_mut(n);
    }

    Poll::Ready(Ok(n))
}
183
184fn write_io<T: AsyncRead + AsyncWrite + Unpin>(
186 io: &mut T,
187 buf: &mut BytesVec,
188 cx: &mut Context<'_>,
189) -> Poll<io::Result<usize>> {
190 let len = buf.len();
191
192 if len != 0 {
193 let mut written = 0;
196 while let Poll::Ready(n) = Pin::new(&mut *io).poll_write(cx, &buf[written..])? {
197 if n == 0 {
198 return Poll::Ready(Err(io::Error::new(
199 io::ErrorKind::WriteZero,
200 "failed to write frame to transport",
201 )));
202 } else {
203 written += n;
204 if written == len {
205 break;
206 }
207 }
208 }
209 if written > 0 {
213 let _ = Pin::new(&mut *io).poll_flush(cx)?;
214 Poll::Ready(Ok(written))
215 } else {
216 Poll::Pending
217 }
218 } else {
219 Poll::Pending
220 }
221}
222
/// Adapter that exposes an ntex `IoBoxed` through tokio's `AsyncRead` /
/// `AsyncWrite` traits.
pub struct TokioIoBoxed(IoBoxed);
224
225impl std::ops::Deref for TokioIoBoxed {
226 type Target = IoBoxed;
227
228 #[inline]
229 fn deref(&self) -> &Self::Target {
230 &self.0
231 }
232}
233
234impl From<IoBoxed> for TokioIoBoxed {
235 fn from(io: IoBoxed) -> TokioIoBoxed {
236 TokioIoBoxed(io)
237 }
238}
239
240impl<F: Filter> From<Io<F>> for TokioIoBoxed {
241 fn from(io: Io<F>) -> TokioIoBoxed {
242 TokioIoBoxed(IoBoxed::from(io))
243 }
244}
245
246impl AsyncRead for TokioIoBoxed {
247 fn poll_read(
248 self: Pin<&mut Self>,
249 cx: &mut Context<'_>,
250 buf: &mut ReadBuf<'_>,
251 ) -> Poll<io::Result<()>> {
252 let len = self.0.with_read_buf(|src| {
253 let len = cmp::min(src.len(), buf.remaining());
254 buf.put_slice(&src.split_to(len));
255 len
256 });
257
258 if len == 0 {
259 match ready!(self.0.poll_read_ready(cx)) {
260 Ok(Some(())) => Poll::Pending,
261 Err(e) => Poll::Ready(Err(e)),
262 Ok(None) => Poll::Ready(Ok(())),
263 }
264 } else {
265 Poll::Ready(Ok(()))
266 }
267 }
268}
269
270impl AsyncWrite for TokioIoBoxed {
271 fn poll_write(
272 self: Pin<&mut Self>,
273 _: &mut Context<'_>,
274 buf: &[u8],
275 ) -> Poll<io::Result<usize>> {
276 Poll::Ready(self.0.write(buf).map(|_| buf.len()))
277 }
278
279 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
280 self.as_ref().0.poll_flush(cx, false)
281 }
282
283 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
284 self.as_ref().0.poll_shutdown(cx)
285 }
286}
287
/// Socket-option setters for a TCP stream, obtained by querying the stream's
/// handle for `SocketOptions`.
///
/// Holds only a weak reference to the stream. Once every strong reference is
/// gone, each setter fails with `io::ErrorKind::NotConnected`.
pub struct SocketOptions(Weak<RefCell<TcpStream>>);
290
291impl SocketOptions {
292 pub fn set_linger(&self, dur: Option<Millis>) -> io::Result<()> {
293 self.try_self()
294 .and_then(|s| s.borrow().set_linger(dur.map(|d| d.into())))
295 }
296
297 pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
298 self.try_self().and_then(|s| s.borrow().set_ttl(ttl))
299 }
300
301 fn try_self(&self) -> io::Result<Rc<RefCell<TcpStream>>> {
302 self.0
303 .upgrade()
304 .ok_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "socket is gone"))
305 }
306}