1use spargio::{RuntimeError, RuntimeHandle};
7use std::io;
8use std::time::Duration;
9
/// Options accepted by the `*_blocking_with_options` helpers.
///
/// The default value carries no timeout, so the helpers wait for the blocking closure
/// for as long as it takes.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct BlockingOptions {
    /// Upper bound on how long to wait for the blocking closure; `None` waits indefinitely.
    timeout: Option<Duration>,
}

impl BlockingOptions {
    /// Returns a copy of these options that stops waiting after `timeout`.
    ///
    /// `BlockingOptions` is `Copy`, so `opts.with_timeout(d);` would compile without any
    /// move error while leaving `opts` unchanged; `#[must_use]` turns that mistake into a
    /// compiler warning.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// The configured timeout, or `None` when the helpers should wait indefinitely.
    #[must_use]
    pub fn timeout(self) -> Option<Duration> {
        self.timeout
    }
}
25
26pub async fn tls_blocking<T, F>(handle: &RuntimeHandle, f: F) -> io::Result<T>
27where
28 T: Send + 'static,
29 F: FnOnce() -> io::Result<T> + Send + 'static,
30{
31 tls_blocking_with_options(handle, BlockingOptions::default(), f).await
32}
33
34pub async fn tls_blocking_with_options<T, F>(
35 handle: &RuntimeHandle,
36 options: BlockingOptions,
37 f: F,
38) -> io::Result<T>
39where
40 T: Send + 'static,
41 F: FnOnce() -> io::Result<T> + Send + 'static,
42{
43 run_blocking(
44 handle,
45 options,
46 f,
47 "tls blocking task canceled",
48 "tls blocking task timed out",
49 )
50 .await
51}
52
53pub async fn ws_blocking<T, F>(handle: &RuntimeHandle, f: F) -> io::Result<T>
54where
55 T: Send + 'static,
56 F: FnOnce() -> io::Result<T> + Send + 'static,
57{
58 ws_blocking_with_options(handle, BlockingOptions::default(), f).await
59}
60
61pub async fn ws_blocking_with_options<T, F>(
62 handle: &RuntimeHandle,
63 options: BlockingOptions,
64 f: F,
65) -> io::Result<T>
66where
67 T: Send + 'static,
68 F: FnOnce() -> io::Result<T> + Send + 'static,
69{
70 run_blocking(
71 handle,
72 options,
73 f,
74 "ws blocking task canceled",
75 "ws blocking task timed out",
76 )
77 .await
78}
79
80pub async fn quic_blocking<T, F>(handle: &RuntimeHandle, f: F) -> io::Result<T>
81where
82 T: Send + 'static,
83 F: FnOnce() -> io::Result<T> + Send + 'static,
84{
85 quic_blocking_with_options(handle, BlockingOptions::default(), f).await
86}
87
88pub async fn quic_blocking_with_options<T, F>(
89 handle: &RuntimeHandle,
90 options: BlockingOptions,
91 f: F,
92) -> io::Result<T>
93where
94 T: Send + 'static,
95 F: FnOnce() -> io::Result<T> + Send + 'static,
96{
97 run_blocking(
98 handle,
99 options,
100 f,
101 "quic blocking task canceled",
102 "quic blocking task timed out",
103 )
104 .await
105}
106
107async fn run_blocking<T, F>(
108 handle: &RuntimeHandle,
109 options: BlockingOptions,
110 f: F,
111 canceled_msg: &'static str,
112 timeout_msg: &'static str,
113) -> io::Result<T>
114where
115 T: Send + 'static,
116 F: FnOnce() -> io::Result<T> + Send + 'static,
117{
118 let join = handle
119 .spawn_blocking(f)
120 .map_err(runtime_error_to_io_for_blocking)?;
121 let joined = match options.timeout() {
122 Some(duration) => match spargio::timeout(duration, join).await {
123 Ok(result) => result,
124 Err(_) => return Err(io::Error::new(io::ErrorKind::TimedOut, timeout_msg)),
125 },
126 None => join.await,
127 };
128 joined.map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, canceled_msg))?
129}
130
131fn runtime_error_to_io_for_blocking(err: RuntimeError) -> io::Error {
132 match err {
133 RuntimeError::InvalidConfig(msg) => io::Error::new(io::ErrorKind::InvalidInput, msg),
134 RuntimeError::ThreadSpawn(io) => io,
135 RuntimeError::InvalidShard(shard) => {
136 io::Error::new(io::ErrorKind::NotFound, format!("invalid shard {shard}"))
137 }
138 RuntimeError::Closed => io::Error::new(io::ErrorKind::BrokenPipe, "runtime closed"),
139 RuntimeError::Overloaded => io::Error::new(io::ErrorKind::WouldBlock, "runtime overloaded"),
140 RuntimeError::UnsupportedBackend(msg) => io::Error::new(io::ErrorKind::Unsupported, msg),
141 RuntimeError::IoUringInit(io) => io,
142 }
143}
144
#[cfg(all(feature = "uring-native", target_os = "linux"))]
pub mod io_compat {
    //! Adapter that exposes a spargio [`TcpStream`] through the `futures::io`
    //! [`AsyncRead`] / [`AsyncWrite`] traits.

    use futures::io::{AsyncRead, AsyncWrite};
    use spargio::net::TcpStream;
    use std::future::Future;
    use std::io;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// In-flight owned-buffer receive; resolves to `(bytes received, buffer)`.
    type ReadOp = Pin<Box<dyn Future<Output = io::Result<(usize, Vec<u8>)>> + Send + 'static>>;
    /// In-flight owned-buffer send; resolves to the number of bytes written.
    type WriteOp = Pin<Box<dyn Future<Output = io::Result<usize>> + Send + 'static>>;

    /// `futures::io` wrapper around a spargio [`TcpStream`].
    ///
    /// Each read or write is issued as an owned-buffer operation (`recv_owned` / `send_owned`
    /// on a clone of the stream). The boxed operation is kept across polls until it completes.
    pub struct FuturesTcpStream {
        inner: TcpStream,
        /// Receive started by an earlier `poll_read` that has not completed yet.
        read_op: Option<ReadOp>,
        /// Bytes from a completed receive that did not fit into the caller's buffer. They are
        /// served by later reads before a new receive is issued. Without this, a receive
        /// sized for a larger earlier buffer would lose data when completed into a smaller one.
        read_leftover: Vec<u8>,
        /// Index of the first byte of `read_leftover` not yet handed to a caller.
        read_pos: usize,
        /// Send started by an earlier `poll_write` that has not completed yet.
        write_op: Option<WriteOp>,
    }

    impl std::fmt::Debug for FuturesTcpStream {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("FuturesTcpStream")
                .field("fd", &self.inner.as_raw_fd())
                .field("session_shard", &self.inner.session_shard())
                .field("buffered_read", &self.buffered_read().len())
                .finish()
        }
    }

    impl FuturesTcpStream {
        /// Wraps `inner` with no operations in flight and nothing buffered.
        pub fn new(inner: TcpStream) -> Self {
            Self {
                inner,
                read_op: None,
                read_leftover: Vec::new(),
                read_pos: 0,
                write_op: None,
            }
        }

        /// Borrows the wrapped stream.
        pub fn get_ref(&self) -> &TcpStream {
            &self.inner
        }

        /// Consumes the adapter and returns the wrapped stream.
        ///
        /// Any in-flight read or write operation is dropped. Bytes still held in the
        /// read-ahead buffer (see [`Self::buffered_read`]) are discarded, so drain them first
        /// if they matter.
        pub fn into_inner(self) -> TcpStream {
            self.inner
        }

        /// Bytes already received from the socket but not yet returned by `poll_read`.
        pub fn buffered_read(&self) -> &[u8] {
            &self.read_leftover[self.read_pos..]
        }

        /// Copies as many read-ahead bytes as fit into `buf` and returns how many were copied.
        fn take_leftover(&mut self, buf: &mut [u8]) -> usize {
            let pending = &self.read_leftover[self.read_pos..];
            let n = pending.len().min(buf.len());
            buf[..n].copy_from_slice(&pending[..n]);
            self.read_pos += n;
            if self.read_pos == self.read_leftover.len() {
                // Fully consumed: release the allocation instead of keeping a dead buffer.
                self.read_leftover = Vec::new();
                self.read_pos = 0;
            }
            n
        }

        /// Drives any in-flight write to completion and surfaces its error, if any.
        ///
        /// This keeps flush/close from reporting success while submitted bytes are still
        /// pending. It also clears the op so a later `poll_write` starts afresh.
        fn poll_write_op_done(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            let result = match self.write_op.as_mut() {
                None => return Poll::Ready(Ok(())),
                Some(op) => match op.as_mut().poll(cx) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(result) => result,
                },
            };
            self.write_op = None;
            Poll::Ready(result.map(|_| ()))
        }
    }

    // The fields are only ever accessed through `&mut Self` and are never pin-projected, so
    // treating the adapter as `Unpin` is sound. This lets the poll methods use `Pin::get_mut`.
    impl Unpin for FuturesTcpStream {}

    impl AsyncRead for FuturesTcpStream {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            if buf.is_empty() {
                return Poll::Ready(Ok(0));
            }
            let this = self.get_mut();

            // Serve bytes left over from an earlier oversized receive before touching the socket.
            if this.read_pos < this.read_leftover.len() {
                return Poll::Ready(Ok(this.take_leftover(buf)));
            }

            if this.read_op.is_none() {
                let inner = this.inner.clone();
                let want = buf.len();
                this.read_op = Some(Box::pin(async move {
                    inner.recv_owned(vec![0u8; want]).await
                }));
            }

            let result = match this
                .read_op
                .as_mut()
                .expect("read op set")
                .as_mut()
                .poll(cx)
            {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(result) => result,
            };
            this.read_op = None;

            let (got, mut payload) = result?;
            // Never expose bytes beyond what was actually received.
            payload.truncate(got);
            // The receive may have been sized for a larger buffer from an earlier poll. Keep
            // whatever does not fit in `buf` instead of dropping it.
            this.read_leftover = payload;
            this.read_pos = 0;
            Poll::Ready(Ok(this.take_leftover(buf)))
        }
    }

    impl AsyncWrite for FuturesTcpStream {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            if buf.is_empty() {
                return Poll::Ready(Ok(0));
            }
            let this = self.get_mut();

            if this.write_op.is_none() {
                let inner = this.inner.clone();
                let payload = buf.to_vec();
                let payload_len = payload.len();
                this.write_op = Some(Box::pin(async move {
                    let (written, _) = inner.send_owned(payload).await?;
                    Ok(written.min(payload_len))
                }));
            }

            // NOTE: an op started by an earlier call that returned `Pending` is resumed here.
            // Its count refers to that op's payload, so callers are expected to retry with the
            // same bytes, as `AsyncWrite` users conventionally do.
            let result = match this
                .write_op
                .as_mut()
                .expect("write op set")
                .as_mut()
                .poll(cx)
            {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(result) => result,
            };
            this.write_op = None;
            Poll::Ready(result)
        }

        /// Completes any in-flight write. Nothing else is buffered by this adapter.
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            self.get_mut().poll_write_op_done(cx)
        }

        /// Completes any in-flight write. Does not shut down the socket.
        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            self.get_mut().poll_write_op_done(cx)
        }
    }
}
273
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use std::time::Duration;

    #[test]
    fn protocol_blocking_helpers_execute_closure() {
        let rt = spargio::Runtime::builder()
            .shards(1)
            .build()
            .expect("runtime");
        let handle = rt.handle();

        let tls = block_on(async { tls_blocking(&handle, || Ok::<_, io::Error>(11usize)).await })
            .expect("tls");
        let ws = block_on(async { ws_blocking(&handle, || Ok::<_, io::Error>(22usize)).await })
            .expect("ws");
        let quic = block_on(async { quic_blocking(&handle, || Ok::<_, io::Error>(33usize)).await })
            .expect("quic");

        assert_eq!(tls + ws + quic, 66);
    }

    #[test]
    fn blocking_timeout_returns_timed_out() {
        let rt = spargio::Runtime::builder()
            .shards(1)
            .build()
            .expect("runtime");
        let err = block_on(async {
            tls_blocking_with_options(
                &rt.handle(),
                BlockingOptions::default().with_timeout(Duration::from_millis(5)),
                || {
                    std::thread::sleep(Duration::from_millis(30));
                    Ok::<(), io::Error>(())
                },
            )
            .await
            .expect_err("timeout")
        });
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
    }

    /// Errors returned by the closure itself must reach the caller unchanged.
    #[test]
    fn blocking_helpers_propagate_closure_errors() {
        let rt = spargio::Runtime::builder()
            .shards(1)
            .build()
            .expect("runtime");
        let handle = rt.handle();
        let err = block_on(async {
            ws_blocking(&handle, || {
                Err::<(), io::Error>(io::Error::new(io::ErrorKind::PermissionDenied, "denied"))
            })
            .await
            .expect_err("closure error")
        });
        assert_eq!(err.kind(), io::ErrorKind::PermissionDenied);
    }

    /// A configured timeout that is not reached must still yield the closure's value.
    #[test]
    fn blocking_timeout_not_reached_returns_value() {
        let rt = spargio::Runtime::builder()
            .shards(1)
            .build()
            .expect("runtime");
        let value = block_on(async {
            quic_blocking_with_options(
                &rt.handle(),
                BlockingOptions::default().with_timeout(Duration::from_secs(5)),
                || Ok::<_, io::Error>(7usize),
            )
            .await
        })
        .expect("value within timeout");
        assert_eq!(value, 7);
    }
}