Skip to main content

spargio_protocols/
lib.rs

1//! Protocol integration companion APIs for spargio runtimes.
2//!
3//! These helpers provide explicit blocking bridges intended for TLS/WS/QUIC
4//! ecosystem integrations that do not natively target spargio executors.
5
6use spargio::{RuntimeError, RuntimeHandle};
7use std::io;
8use std::time::Duration;
9
/// Options controlling how a protocol blocking bridge waits for its closure.
///
/// The default carries no timeout, so the bridge awaits the blocking task's
/// join handle without any time bound.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct BlockingOptions {
    /// Upper bound on how long to wait for the closure's result; `None` waits indefinitely.
    timeout: Option<Duration>,
}

impl BlockingOptions {
    /// Returns these options with the wait bounded by `timeout`.
    ///
    /// Any previously configured timeout is replaced.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Returns the configured timeout, or `None` if the wait is unbounded.
    #[must_use]
    pub fn timeout(self) -> Option<Duration> {
        self.timeout
    }
}
25
26pub async fn tls_blocking<T, F>(handle: &RuntimeHandle, f: F) -> io::Result<T>
27where
28    T: Send + 'static,
29    F: FnOnce() -> io::Result<T> + Send + 'static,
30{
31    tls_blocking_with_options(handle, BlockingOptions::default(), f).await
32}
33
34pub async fn tls_blocking_with_options<T, F>(
35    handle: &RuntimeHandle,
36    options: BlockingOptions,
37    f: F,
38) -> io::Result<T>
39where
40    T: Send + 'static,
41    F: FnOnce() -> io::Result<T> + Send + 'static,
42{
43    run_blocking(
44        handle,
45        options,
46        f,
47        "tls blocking task canceled",
48        "tls blocking task timed out",
49    )
50    .await
51}
52
53pub async fn ws_blocking<T, F>(handle: &RuntimeHandle, f: F) -> io::Result<T>
54where
55    T: Send + 'static,
56    F: FnOnce() -> io::Result<T> + Send + 'static,
57{
58    ws_blocking_with_options(handle, BlockingOptions::default(), f).await
59}
60
61pub async fn ws_blocking_with_options<T, F>(
62    handle: &RuntimeHandle,
63    options: BlockingOptions,
64    f: F,
65) -> io::Result<T>
66where
67    T: Send + 'static,
68    F: FnOnce() -> io::Result<T> + Send + 'static,
69{
70    run_blocking(
71        handle,
72        options,
73        f,
74        "ws blocking task canceled",
75        "ws blocking task timed out",
76    )
77    .await
78}
79
80pub async fn quic_blocking<T, F>(handle: &RuntimeHandle, f: F) -> io::Result<T>
81where
82    T: Send + 'static,
83    F: FnOnce() -> io::Result<T> + Send + 'static,
84{
85    quic_blocking_with_options(handle, BlockingOptions::default(), f).await
86}
87
88pub async fn quic_blocking_with_options<T, F>(
89    handle: &RuntimeHandle,
90    options: BlockingOptions,
91    f: F,
92) -> io::Result<T>
93where
94    T: Send + 'static,
95    F: FnOnce() -> io::Result<T> + Send + 'static,
96{
97    run_blocking(
98        handle,
99        options,
100        f,
101        "quic blocking task canceled",
102        "quic blocking task timed out",
103    )
104    .await
105}
106
107async fn run_blocking<T, F>(
108    handle: &RuntimeHandle,
109    options: BlockingOptions,
110    f: F,
111    canceled_msg: &'static str,
112    timeout_msg: &'static str,
113) -> io::Result<T>
114where
115    T: Send + 'static,
116    F: FnOnce() -> io::Result<T> + Send + 'static,
117{
118    let join = handle
119        .spawn_blocking(f)
120        .map_err(runtime_error_to_io_for_blocking)?;
121    let joined = match options.timeout() {
122        Some(duration) => match spargio::timeout(duration, join).await {
123            Ok(result) => result,
124            Err(_) => return Err(io::Error::new(io::ErrorKind::TimedOut, timeout_msg)),
125        },
126        None => join.await,
127    };
128    joined.map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, canceled_msg))?
129}
130
131fn runtime_error_to_io_for_blocking(err: RuntimeError) -> io::Error {
132    match err {
133        RuntimeError::InvalidConfig(msg) => io::Error::new(io::ErrorKind::InvalidInput, msg),
134        RuntimeError::ThreadSpawn(io) => io,
135        RuntimeError::InvalidShard(shard) => {
136            io::Error::new(io::ErrorKind::NotFound, format!("invalid shard {shard}"))
137        }
138        RuntimeError::Closed => io::Error::new(io::ErrorKind::BrokenPipe, "runtime closed"),
139        RuntimeError::Overloaded => io::Error::new(io::ErrorKind::WouldBlock, "runtime overloaded"),
140        RuntimeError::UnsupportedBackend(msg) => io::Error::new(io::ErrorKind::Unsupported, msg),
141        RuntimeError::IoUringInit(io) => io,
142    }
143}
144
#[cfg(all(feature = "uring-native", target_os = "linux"))]
pub mod io_compat {
    //! `futures::io` adapters over spargio's owned-buffer TCP stream.

    use futures::io::{AsyncRead, AsyncWrite};
    use spargio::net::TcpStream;
    use std::future::Future;
    use std::io;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Boxed in-flight receive; resolves to `(bytes_received, buffer)`.
    type ReadOp = Pin<Box<dyn Future<Output = io::Result<(usize, Vec<u8>)>> + Send + 'static>>;
    /// Boxed in-flight send; resolves to the number of payload bytes written.
    type WriteOp = Pin<Box<dyn Future<Output = io::Result<usize>> + Send + 'static>>;

    /// Wraps a spargio [`TcpStream`] so it implements [`AsyncRead`] and
    /// [`AsyncWrite`], letting `futures::io`-based protocol crates drive it.
    ///
    /// Each read or write is issued as an owned-buffer operation on a clone of
    /// the stream. The operation is stored here until it completes, so repeated
    /// polls resume the same operation instead of submitting a new one.
    pub struct FuturesTcpStream {
        inner: TcpStream,
        /// Receive currently in flight, if any.
        read_op: Option<ReadOp>,
        /// Bytes from a completed receive that did not fit into the caller's
        /// buffer. Subsequent reads return these before a new receive starts.
        read_spill: Vec<u8>,
        /// Offset of the first not-yet-returned byte in `read_spill`.
        read_spill_pos: usize,
        /// Send currently in flight, if any.
        write_op: Option<WriteOp>,
    }

    impl std::fmt::Debug for FuturesTcpStream {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("FuturesTcpStream")
                .field("fd", &self.inner.as_raw_fd())
                .field("session_shard", &self.inner.session_shard())
                .finish()
        }
    }

    impl FuturesTcpStream {
        /// Wraps `inner` with no operation in flight and nothing buffered.
        pub fn new(inner: TcpStream) -> Self {
            Self {
                inner,
                read_op: None,
                read_spill: Vec::new(),
                read_spill_pos: 0,
                write_op: None,
            }
        }

        /// Returns a reference to the wrapped stream.
        pub fn get_ref(&self) -> &TcpStream {
            &self.inner
        }

        /// Consumes the adapter and returns the wrapped stream.
        ///
        /// Any in-flight operation and any received-but-unread bytes held by
        /// the adapter are dropped.
        pub fn into_inner(self) -> TcpStream {
            self.inner
        }

        /// Copies as many spilled bytes as fit into `dst` and returns the count.
        ///
        /// The spill buffer is released once it has been fully consumed.
        fn take_spilled(&mut self, dst: &mut [u8]) -> usize {
            let pending = &self.read_spill[self.read_spill_pos..];
            let n = pending.len().min(dst.len());
            dst[..n].copy_from_slice(&pending[..n]);
            self.read_spill_pos += n;
            if self.read_spill_pos >= self.read_spill.len() {
                self.read_spill = Vec::new();
                self.read_spill_pos = 0;
            }
            n
        }
    }

    impl Unpin for FuturesTcpStream {}

    impl AsyncRead for FuturesTcpStream {
        /// Reads into `buf`. Bytes left over from an earlier receive are
        /// returned first; otherwise a receive is started or resumed.
        ///
        /// The receive is sized by the buffer passed on the poll that starts
        /// it. If a later poll supplies a smaller buffer, the surplus bytes are
        /// kept and returned by the next reads instead of being discarded.
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            if buf.is_empty() {
                return Poll::Ready(Ok(0));
            }

            // `FuturesTcpStream` is `Unpin`, so plain mutable access is fine.
            let this = self.get_mut();

            if this.read_spill_pos < this.read_spill.len() {
                return Poll::Ready(Ok(this.take_spilled(buf)));
            }

            if this.read_op.is_none() {
                let inner = this.inner.clone();
                let want = buf.len();
                this.read_op = Some(Box::pin(
                    async move { inner.recv_owned(vec![0u8; want]).await },
                ));
            }

            let result = match this
                .read_op
                .as_mut()
                .expect("read op set")
                .as_mut()
                .poll(cx)
            {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(result) => result,
            };
            this.read_op = None;

            let (got, mut payload) = result?;
            // Never trust `got` beyond the buffer we actually own.
            payload.truncate(got);
            let copied = payload.len().min(buf.len());
            buf[..copied].copy_from_slice(&payload[..copied]);
            if copied < payload.len() {
                // Keep the surplus for the next read rather than dropping it.
                this.read_spill = payload;
                this.read_spill_pos = copied;
            }
            Poll::Ready(Ok(copied))
        }
    }

    impl AsyncWrite for FuturesTcpStream {
        /// Writes a copy of `buf` via an owned-buffer send.
        ///
        /// If a previous call returned `Poll::Pending`, that already-started
        /// send is resumed and `buf` is ignored. Callers are therefore expected
        /// to retry with the same bytes after `Pending`.
        fn poll_write(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            if buf.is_empty() {
                return Poll::Ready(Ok(0));
            }

            if self.write_op.is_none() {
                let inner = self.inner.clone();
                let payload = buf.to_vec();
                let payload_len = payload.len();
                self.write_op = Some(Box::pin(async move {
                    let (written, _) = inner.send_owned(payload).await?;
                    // Never report more bytes than the caller handed us.
                    Ok(written.min(payload_len))
                }));
            }

            match self
                .write_op
                .as_mut()
                .expect("write op set")
                .as_mut()
                .poll(cx)
            {
                Poll::Pending => Poll::Pending,
                Poll::Ready(result) => {
                    self.write_op = None;
                    Poll::Ready(result)
                }
            }
        }

        /// The adapter does no write buffering, so this returns `Ok(())`
        /// immediately. It does not drive an in-flight write.
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        /// Returns `Ok(())` immediately without shutting down the socket.
        ///
        /// NOTE(review): the peer presumably sees no write-side shutdown until
        /// every `TcpStream` clone is dropped. Consider a real shutdown if
        /// `TcpStream` exposes one.
        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }
}
273
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use std::time::Duration;

    /// Each protocol bridge runs its closure and returns the closure's value.
    #[test]
    fn protocol_blocking_helpers_execute_closure() {
        let rt = spargio::Runtime::builder()
            .shards(1)
            .build()
            .expect("runtime");
        let handle = rt.handle();

        let tls = block_on(tls_blocking(&handle, || Ok::<_, io::Error>(11usize))).expect("tls");
        let ws = block_on(ws_blocking(&handle, || Ok::<_, io::Error>(22usize))).expect("ws");
        let quic =
            block_on(quic_blocking(&handle, || Ok::<_, io::Error>(33usize))).expect("quic");

        assert_eq!(tls + ws + quic, 66);
    }

    /// A closure that outlives the configured timeout yields `TimedOut`.
    #[test]
    fn blocking_timeout_returns_timed_out() {
        let rt = spargio::Runtime::builder()
            .shards(1)
            .build()
            .expect("runtime");
        let handle = rt.handle();
        let options = BlockingOptions::default().with_timeout(Duration::from_millis(5));
        let slow = || {
            std::thread::sleep(Duration::from_millis(30));
            Ok::<(), io::Error>(())
        };

        let err = block_on(tls_blocking_with_options(&handle, options, slow)).expect_err("timeout");

        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
    }
}
318}