mondayio_compat/
safe_wrapper.rs1use std::{cell::UnsafeCell, io};
2
3use monoio::{
4 buf::IoBufMut,
5 io::{AsyncReadRent, AsyncWriteRent, AsyncWriteRentExt, Split},
6 BufResult,
7};
8
9use crate::{box_future::MaybeArmedBoxFuture, buf::Buf};
10
11pub struct StreamWrapper<T> {
15 stream: UnsafeCell<T>,
16 read_buf: Option<Buf>,
17 write_buf: Option<Buf>,
18
19 read_fut: MaybeArmedBoxFuture<BufResult<usize, Buf>>,
20 write_fut: MaybeArmedBoxFuture<BufResult<usize, Buf>>,
21 flush_fut: MaybeArmedBoxFuture<io::Result<()>>,
22 shutdown_fut: MaybeArmedBoxFuture<io::Result<()>>,
23}
24
25unsafe impl<T: Split> Split for StreamWrapper<T> {}
26
27impl<T> StreamWrapper<T> {
28 pub fn into_inner(self) -> T {
30 self.stream.into_inner()
31 }
32
33 pub fn new_with_buffer_size(stream: T, read_buffer: usize, write_buffer: usize) -> Self {
35 let r_buf = Buf::new(read_buffer);
36 let w_buf = Buf::new(write_buffer);
37
38 Self {
39 stream: UnsafeCell::new(stream),
40 read_buf: Some(r_buf),
41 write_buf: Some(w_buf),
42 read_fut: Default::default(),
43 write_fut: Default::default(),
44 flush_fut: Default::default(),
45 shutdown_fut: Default::default(),
46 }
47 }
48
49 pub fn new(stream: T) -> Self {
51 const DEFAULT_READ_BUFFER: usize = 8 * 1024;
52 const DEFAULT_WRITE_BUFFER: usize = 8 * 1024;
53 Self::new_with_buffer_size(stream, DEFAULT_READ_BUFFER, DEFAULT_WRITE_BUFFER)
54 }
55}
56
57impl<T: AsyncReadRent + Unpin + 'static> tokio::io::AsyncRead for StreamWrapper<T> {
58 fn poll_read(
59 self: std::pin::Pin<&mut Self>,
60 cx: &mut std::task::Context<'_>,
61 buf: &mut tokio::io::ReadBuf<'_>,
62 ) -> std::task::Poll<std::io::Result<()>> {
63 let this = self.get_mut();
64
65 loop {
66 if !this.read_fut.armed() {
68 let read_buf_mut = unsafe { this.read_buf.as_mut().unwrap_unchecked() };
70 if !read_buf_mut.is_empty() {
71 let our_buf = read_buf_mut.buf_to_read(buf.remaining());
73 let our_buf_len = our_buf.len();
74 buf.put_slice(our_buf);
75 unsafe { read_buf_mut.advance_offset(our_buf_len) };
76 return std::task::Poll::Ready(Ok(()));
77 }
78
79 let buf = unsafe { this.read_buf.take().unwrap_unchecked() };
81 let stream = unsafe { &mut *this.stream.get() };
83 this.read_fut.arm_future(AsyncReadRent::read(stream, buf));
84 }
85
86 let (ret, buf) = match this.read_fut.poll(cx) {
88 std::task::Poll::Ready(out) => out,
89 std::task::Poll::Pending => {
90 return std::task::Poll::Pending;
91 }
92 };
93 this.read_buf = Some(buf);
94 if ret? == 0 {
95 return std::task::Poll::Ready(Ok(()));
97 }
98 }
99 }
100}
101
102impl<T: AsyncWriteRent + Unpin + 'static> tokio::io::AsyncWrite for StreamWrapper<T> {
103 fn poll_write(
104 self: std::pin::Pin<&mut Self>,
105 cx: &mut std::task::Context<'_>,
106 buf: &[u8],
107 ) -> std::task::Poll<Result<usize, std::io::Error>> {
108 if buf.is_empty() {
109 return std::task::Poll::Ready(Ok(0));
110 }
111 let this = self.get_mut();
112
113 if this.write_fut.armed() {
117 let (ret, mut owned_buf) = match this.write_fut.poll(cx) {
118 std::task::Poll::Ready(r) => r,
119 std::task::Poll::Pending => {
120 return std::task::Poll::Pending;
121 }
122 };
123 unsafe { owned_buf.set_init(0) };
125 this.write_buf = Some(owned_buf);
126 if ret.is_err() {
127 return std::task::Poll::Ready(ret);
128 }
129 }
130
131 let mut owned_buf = unsafe { this.write_buf.take().unwrap_unchecked() };
139 let owned_buf_mut = owned_buf.buf_to_write();
140 let len = buf.len().min(owned_buf_mut.len());
141 unsafe { std::ptr::copy_nonoverlapping(buf.as_ptr(), owned_buf_mut.as_mut_ptr(), len) };
144 unsafe { owned_buf.set_init(len) };
145
146 let stream = unsafe { &mut *this.stream.get() };
148 this.write_fut
149 .arm_future(AsyncWriteRentExt::write_all(stream, owned_buf));
150 match this.write_fut.poll(cx) {
151 std::task::Poll::Ready((ret, mut buf)) => {
152 unsafe { buf.set_init(0) };
153 this.write_buf = Some(buf);
154 if ret.is_err() {
155 return std::task::Poll::Ready(ret);
156 }
157 }
158 std::task::Poll::Pending => (),
159 }
160 std::task::Poll::Ready(Ok(len))
163 }
164
165 fn poll_flush(
166 self: std::pin::Pin<&mut Self>,
167 cx: &mut std::task::Context<'_>,
168 ) -> std::task::Poll<Result<(), std::io::Error>> {
169 let this = self.get_mut();
170
171 if this.write_fut.armed() {
172 match this.write_fut.poll(cx) {
173 std::task::Poll::Ready((ret, mut buf)) => {
174 unsafe { buf.set_init(0) };
175 this.write_buf = Some(buf);
176 if let Err(e) = ret {
177 return std::task::Poll::Ready(Err(e));
178 }
179 }
180 std::task::Poll::Pending => return std::task::Poll::Pending,
181 }
182 }
183
184 if !this.flush_fut.armed() {
185 let stream = unsafe { &mut *this.stream.get() };
186 this.flush_fut.arm_future(stream.flush());
187 }
188 this.flush_fut.poll(cx)
189 }
190
191 fn poll_shutdown(
192 self: std::pin::Pin<&mut Self>,
193 cx: &mut std::task::Context<'_>,
194 ) -> std::task::Poll<Result<(), std::io::Error>> {
195 let this = self.get_mut();
196
197 if this.write_fut.armed() {
198 match this.write_fut.poll(cx) {
199 std::task::Poll::Ready((ret, mut buf)) => {
200 unsafe { buf.set_init(0) };
201 this.write_buf = Some(buf);
202 if let Err(e) = ret {
203 return std::task::Poll::Ready(Err(e));
204 }
205 }
206 std::task::Poll::Pending => return std::task::Poll::Pending,
207 }
208 }
209
210 if !this.shutdown_fut.armed() {
211 let stream = unsafe { &mut *this.stream.get() };
212 this.shutdown_fut.arm_future(stream.shutdown());
213 }
214 this.shutdown_fut.poll(cx)
215 }
216}