monoio_compat/
safe_wrapper.rs

1use std::{cell::UnsafeCell, io};
2
3use monoio::{
4    buf::IoBufMut,
5    io::{AsyncReadRent, AsyncWriteRent, AsyncWriteRentExt, Split},
6    BufResult,
7};
8
9use crate::{box_future::MaybeArmedBoxFuture, buf::Buf};
10
11/// A wrapper for stream with ownership that impl AsyncReadRent and AsyncWriteRent.
12/// The Wrapper will impl tokio AsyncRead and AsyncWrite.
13/// Mainly used for compatible.
14pub struct StreamWrapper<T> {
15    stream: UnsafeCell<T>,
16    read_buf: Option<Buf>,
17    write_buf: Option<Buf>,
18
19    read_fut: MaybeArmedBoxFuture<BufResult<usize, Buf>>,
20    write_fut: MaybeArmedBoxFuture<BufResult<usize, Buf>>,
21    flush_fut: MaybeArmedBoxFuture<io::Result<()>>,
22    shutdown_fut: MaybeArmedBoxFuture<io::Result<()>>,
23}
24
25unsafe impl<T: Split> Split for StreamWrapper<T> {}
26
27impl<T> StreamWrapper<T> {
28    /// Consume self and get inner T.
29    pub fn into_inner(self) -> T {
30        self.stream.into_inner()
31    }
32
33    /// Creates a new `TcpStreamCompat` from a monoio `TcpStream` or `UnixStream`.
34    pub fn new_with_buffer_size(stream: T, read_buffer: usize, write_buffer: usize) -> Self {
35        let r_buf = Buf::new(read_buffer);
36        let w_buf = Buf::new(write_buffer);
37
38        Self {
39            stream: UnsafeCell::new(stream),
40            read_buf: Some(r_buf),
41            write_buf: Some(w_buf),
42            read_fut: Default::default(),
43            write_fut: Default::default(),
44            flush_fut: Default::default(),
45            shutdown_fut: Default::default(),
46        }
47    }
48
49    /// Creates a new `TcpStreamCompat` from a monoio `TcpStream` or `UnixStream`.
50    pub fn new(stream: T) -> Self {
51        const DEFAULT_READ_BUFFER: usize = 8 * 1024;
52        const DEFAULT_WRITE_BUFFER: usize = 8 * 1024;
53        Self::new_with_buffer_size(stream, DEFAULT_READ_BUFFER, DEFAULT_WRITE_BUFFER)
54    }
55}
56
57impl<T: AsyncReadRent + Unpin + 'static> tokio::io::AsyncRead for StreamWrapper<T> {
58    fn poll_read(
59        self: std::pin::Pin<&mut Self>,
60        cx: &mut std::task::Context<'_>,
61        buf: &mut tokio::io::ReadBuf<'_>,
62    ) -> std::task::Poll<std::io::Result<()>> {
63        let this = self.get_mut();
64
65        loop {
66            // if the future not armed, this means maybe buffer has data.
67            if !this.read_fut.armed() {
68                // if there is some data left in our buf, copy it and return.
69                let read_buf_mut = unsafe { this.read_buf.as_mut().unwrap_unchecked() };
70                if !read_buf_mut.is_empty() {
71                    // copy directly from inner buf to buf
72                    let our_buf = read_buf_mut.buf_to_read(buf.remaining());
73                    let our_buf_len = our_buf.len();
74                    buf.put_slice(our_buf);
75                    unsafe { read_buf_mut.advance_offset(our_buf_len) };
76                    return std::task::Poll::Ready(Ok(()));
77                }
78
79                // there is no data in buffer. we will construct the future
80                let buf = unsafe { this.read_buf.take().unwrap_unchecked() };
81                // we must leak the stream
82                let stream = unsafe { &mut *this.stream.get() };
83                this.read_fut.arm_future(AsyncReadRent::read(stream, buf));
84            }
85
86            // the future slot is armed now. we will poll it.
87            let (ret, buf) = match this.read_fut.poll(cx) {
88                std::task::Poll::Ready(out) => out,
89                std::task::Poll::Pending => {
90                    return std::task::Poll::Pending;
91                }
92            };
93            this.read_buf = Some(buf);
94            if ret? == 0 {
95                // on eof, return directly; otherwise goto next loop.
96                return std::task::Poll::Ready(Ok(()));
97            }
98        }
99    }
100}
101
102impl<T: AsyncWriteRent + Unpin + 'static> tokio::io::AsyncWrite for StreamWrapper<T> {
103    fn poll_write(
104        self: std::pin::Pin<&mut Self>,
105        cx: &mut std::task::Context<'_>,
106        buf: &[u8],
107    ) -> std::task::Poll<Result<usize, std::io::Error>> {
108        if buf.is_empty() {
109            return std::task::Poll::Ready(Ok(0));
110        }
111        let this = self.get_mut();
112
113        // if there is some future armed, we must poll it until ready.
114        // if it returns error, we will return it;
115        // if it returns ok, we ignore it.
116        if this.write_fut.armed() {
117            let (ret, mut owned_buf) = match this.write_fut.poll(cx) {
118                std::task::Poll::Ready(r) => r,
119                std::task::Poll::Pending => {
120                    return std::task::Poll::Pending;
121                }
122            };
123            // clear the buffer
124            unsafe { owned_buf.set_init(0) };
125            this.write_buf = Some(owned_buf);
126            if ret.is_err() {
127                return std::task::Poll::Ready(ret);
128            }
129        }
130
131        // now we should arm it again.
132        // we will copy the data and return Ready.
133        // Though return Ready does not mean really ready, but this helps preventing
134        // poll_write different data.
135
136        // # Safety
137        // We always make sure the write_buf is Some.
138        let mut owned_buf = unsafe { this.write_buf.take().unwrap_unchecked() };
139        let owned_buf_mut = owned_buf.buf_to_write();
140        let len = buf.len().min(owned_buf_mut.len());
141        // # Safety
142        // We can make sure the buf and buf_mut_slice have len size data.
143        unsafe { std::ptr::copy_nonoverlapping(buf.as_ptr(), owned_buf_mut.as_mut_ptr(), len) };
144        unsafe { owned_buf.set_init(len) };
145
146        // we must leak the stream
147        let stream = unsafe { &mut *this.stream.get() };
148        this.write_fut
149            .arm_future(AsyncWriteRentExt::write_all(stream, owned_buf));
150        match this.write_fut.poll(cx) {
151            std::task::Poll::Ready((ret, mut buf)) => {
152                unsafe { buf.set_init(0) };
153                this.write_buf = Some(buf);
154                if ret.is_err() {
155                    return std::task::Poll::Ready(ret);
156                }
157            }
158            std::task::Poll::Pending => (),
159        }
160        // if there is no error, no matter it is sending or sent, we will
161        // return Ready.
162        std::task::Poll::Ready(Ok(len))
163    }
164
165    fn poll_flush(
166        self: std::pin::Pin<&mut Self>,
167        cx: &mut std::task::Context<'_>,
168    ) -> std::task::Poll<Result<(), std::io::Error>> {
169        let this = self.get_mut();
170
171        if this.write_fut.armed() {
172            match this.write_fut.poll(cx) {
173                std::task::Poll::Ready((ret, mut buf)) => {
174                    unsafe { buf.set_init(0) };
175                    this.write_buf = Some(buf);
176                    if let Err(e) = ret {
177                        return std::task::Poll::Ready(Err(e));
178                    }
179                }
180                std::task::Poll::Pending => return std::task::Poll::Pending,
181            }
182        }
183
184        if !this.flush_fut.armed() {
185            let stream = unsafe { &mut *this.stream.get() };
186            this.flush_fut.arm_future(stream.flush());
187        }
188        this.flush_fut.poll(cx)
189    }
190
191    fn poll_shutdown(
192        self: std::pin::Pin<&mut Self>,
193        cx: &mut std::task::Context<'_>,
194    ) -> std::task::Poll<Result<(), std::io::Error>> {
195        let this = self.get_mut();
196
197        if this.write_fut.armed() {
198            match this.write_fut.poll(cx) {
199                std::task::Poll::Ready((ret, mut buf)) => {
200                    unsafe { buf.set_init(0) };
201                    this.write_buf = Some(buf);
202                    if let Err(e) = ret {
203                        return std::task::Poll::Ready(Err(e));
204                    }
205                }
206                std::task::Poll::Pending => return std::task::Poll::Pending,
207            }
208        }
209
210        if !this.shutdown_fut.armed() {
211            let stream = unsafe { &mut *this.stream.get() };
212            this.shutdown_fut.arm_future(stream.shutdown());
213        }
214        this.shutdown_fut.poll(cx)
215    }
216}