/// io_uring-backed implementation, compiled only on Linux with the `io_uring` feature.
///
/// All ring operations run on one dedicated worker thread (`chromey-uring-worker`). Async
/// callers hand it an `IoTask` over an unbounded channel and await the result on a oneshot.
/// The worker processes one task at a time and keeps exactly one operation in flight on
/// the ring.
#[cfg(all(target_os = "linux", feature = "io_uring"))]
mod inner {
    use std::ffi::CString;
    use std::io;
    use std::net::SocketAddr;
    use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
    use tokio::sync::{mpsc, oneshot};

    /// Set once a worker thread is running and `URING_POOL` holds a sender to it.
    static URING_ENABLED: AtomicBool = AtomicBool::new(false);
    /// Set when the kernel also supports the `SOCKET` and `CONNECT` opcodes; otherwise
    /// `tcp_connect` keeps using the blocking-pool fallback.
    static URING_NET_ENABLED: AtomicBool = AtomicBool::new(false);
    /// Sender half of the worker's task queue.
    static URING_POOL: std::sync::OnceLock<mpsc::UnboundedSender<IoTask>> =
        std::sync::OnceLock::new();
    /// Source of per-operation `user_data` tags, so a completion can never be mistaken for
    /// the result of a different operation.
    static NEXT_USER_DATA: AtomicU64 = AtomicU64::new(1);

    /// Initial capacity floor and minimum growth step for `uring_read_to_end`.
    const MIN_READ_CHUNK: usize = 8 * 1024;

    /// A unit of work for the io_uring worker; `tx` carries the result back to the caller.
    enum IoTask {
        WriteFile {
            path: String,
            data: Vec<u8>,
            tx: oneshot::Sender<io::Result<()>>,
        },
        ReadFile {
            path: String,
            tx: oneshot::Sender<io::Result<Vec<u8>>>,
        },
        TcpConnect {
            addr: SocketAddr,
            tx: oneshot::Sender<io::Result<std::net::TcpStream>>,
        },
    }

    /// Builds a ring and checks that the kernel implements every opcode the helpers use.
    ///
    /// Successfully creating a ring does not imply that newer opcodes exist, so the kernel's
    /// opcode table is probed explicitly. Returns the ring plus whether the networking
    /// opcodes (`SOCKET`, `CONNECT`) are available, or `None` when io_uring must not be used
    /// at all.
    fn probe_io_uring() -> Option<(io_uring::IoUring, bool)> {
        let ring = match io_uring::IoUring::builder().build(64) {
            Ok(ring) => ring,
            Err(e) => {
                tracing::info!("chromey: io_uring unavailable ({}), using tokio::fs", e);
                return None;
            }
        };

        let mut probe = io_uring::Probe::new();
        if let Err(e) = ring.submitter().register_probe(&mut probe) {
            tracing::info!(
                "chromey: io_uring opcode probe failed ({}), using tokio::fs",
                e
            );
            return None;
        }

        let file_ops = [
            io_uring::opcode::OpenAt::CODE,
            io_uring::opcode::Read::CODE,
            io_uring::opcode::Write::CODE,
            io_uring::opcode::Close::CODE,
        ];
        if !file_ops.iter().all(|&op| probe.is_supported(op)) {
            tracing::info!("chromey: io_uring lacks required file opcodes, using tokio::fs");
            return None;
        }

        let net_supported = probe.is_supported(io_uring::opcode::Socket::CODE)
            && probe.is_supported(io_uring::opcode::Connect::CODE);
        tracing::info!("chromey: io_uring probe succeeded");
        Some((ring, net_supported))
    }

    /// Converts a raw CQE result (`-errno` on failure) into an `io::Result`.
    fn cvt(res: i32) -> io::Result<i32> {
        if res < 0 {
            Err(io::Error::from_raw_os_error(-res))
        } else {
            Ok(res)
        }
    }

    /// Pushes `entry` tagged with a fresh `user_data`, submits it and blocks until its own
    /// completion arrives. Returns the raw CQE result (`-errno` on failure).
    ///
    /// `op` only names the operation in the "SQ full" error message.
    ///
    /// # Safety
    ///
    /// Every pointer stored in `entry` must stay valid until this function returns. When it
    /// returns `Ok`, the kernel has finished the operation. NOTE(review): if
    /// `io_uring_enter` itself fails with anything other than `EINTR`, the entry may still be
    /// queued when this returns. That is a known limitation for a ring that is presumably
    /// already broken.
    unsafe fn run_op(
        ring: &mut io_uring::IoUring,
        entry: io_uring::squeue::Entry,
        op: &str,
    ) -> io::Result<i32> {
        let tag = NEXT_USER_DATA.fetch_add(1, Ordering::Relaxed);
        let entry = entry.user_data(tag);
        // SAFETY: the caller guarantees the memory referenced by `entry` outlives this call,
        // and `wait_for` does not return `Ok` before the operation has completed.
        unsafe {
            ring.submission().push(&entry).map_err(|_| {
                io::Error::new(io::ErrorKind::Other, format!("io_uring: SQ full on {op}"))
            })?;
        }
        wait_for(ring, tag)
    }

    /// Submits pending entries and blocks until the completion tagged `tag` is reaped.
    ///
    /// `submit_and_wait` can return without any completion posted, e.g. when the wait is
    /// interrupted after the entry was already submitted. This keeps waiting instead of
    /// abandoning an operation whose buffers the kernel may still be using. Completions
    /// carrying another tag are leftovers from an earlier failed call and are discarded.
    fn wait_for(ring: &mut io_uring::IoUring, tag: u64) -> io::Result<i32> {
        loop {
            let mut result = None;
            for cqe in ring.completion() {
                if cqe.user_data() == tag {
                    result = Some(cqe.result());
                    break;
                }
                tracing::debug!(
                    "chromey: discarding stale io_uring completion (user_data={:#x}, result={})",
                    cqe.user_data(),
                    cqe.result()
                );
            }
            if let Some(res) = result {
                return Ok(res);
            }
            match ring.submit_and_wait(1) {
                Ok(_) => {}
                // A signal interrupted the syscall; just wait again.
                Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
                Err(e) => return Err(e),
            }
        }
    }

    /// Closes `fd` through the ring.
    fn uring_close(ring: &mut io_uring::IoUring, fd: i32) -> io::Result<()> {
        let entry = io_uring::opcode::Close::new(io_uring::types::Fd(fd)).build();
        // SAFETY: a close entry references no user memory.
        unsafe { run_op(ring, entry, "close") }
            .and_then(cvt)
            .map(|_| ())
    }

    /// Opens `path` (relative paths resolve against the current directory) and returns the
    /// new descriptor.
    ///
    /// `O_CLOEXEC` is always added, as `std::fs::File` does, so the descriptor cannot leak
    /// into child processes spawned while it is open.
    fn uring_open(
        ring: &mut io_uring::IoUring,
        path: &str,
        flags: libc::c_int,
        mode: libc::mode_t,
    ) -> io::Result<i32> {
        let c_path =
            CString::new(path).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
        let entry =
            io_uring::opcode::OpenAt::new(io_uring::types::Fd(libc::AT_FDCWD), c_path.as_ptr())
                .flags(flags | libc::O_CLOEXEC)
                .mode(mode)
                .build();
        // SAFETY: `c_path` stays alive until after `run_op` returns.
        unsafe { run_op(ring, entry, "open") }.and_then(cvt)
    }

    /// Creates or truncates `path` and writes all of `data` to it.
    ///
    /// New files get mode `0o666` before the umask is applied, the same default as
    /// `std::fs::write` / `tokio::fs::write`, so both code paths create identical
    /// permissions. The descriptor is closed even when the write fails, and the write error
    /// takes precedence over a close error.
    fn uring_write_file(ring: &mut io_uring::IoUring, path: &str, data: &[u8]) -> io::Result<()> {
        let fd = uring_open(
            ring,
            path,
            libc::O_WRONLY | libc::O_CREAT | libc::O_TRUNC,
            0o666,
        )?;
        let write_result = uring_write_all(ring, fd, data);
        let close_result = uring_close(ring, fd);
        write_result.and(close_result)
    }

    /// Writes all of `data` to `fd` starting at file offset 0, resubmitting after short
    /// writes. A zero-byte write is reported as `WriteZero`.
    fn uring_write_all(ring: &mut io_uring::IoUring, fd: i32, data: &[u8]) -> io::Result<()> {
        let mut written = 0usize;
        while written < data.len() {
            let remaining = &data[written..];
            // One SQE can describe at most u32::MAX bytes; the loop submits the rest.
            let len = u32::try_from(remaining.len()).unwrap_or(u32::MAX);
            let entry =
                io_uring::opcode::Write::new(io_uring::types::Fd(fd), remaining.as_ptr(), len)
                    .offset(written as u64)
                    .build();
            // SAFETY: `remaining` borrows `data`, which outlives `run_op`.
            let n = unsafe { run_op(ring, entry, "write") }.and_then(cvt)?;
            if n == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "io_uring: write returned 0",
                ));
            }
            written += n as usize;
        }
        Ok(())
    }

    /// Reads the whole file at `path`, like `std::fs::read`.
    ///
    /// The reported file size is only used as a capacity hint. Files may change size
    /// between the stat and the read, and pseudo-files (e.g. procfs) report 0 while having
    /// content, so data is read until EOF.
    fn uring_read_file(ring: &mut io_uring::IoUring, path: &str) -> io::Result<Vec<u8>> {
        let fd = uring_open(ring, path, libc::O_RDONLY, 0)?;
        let size_hint = std::fs::metadata(path)
            .ok()
            .and_then(|meta| usize::try_from(meta.len()).ok())
            .unwrap_or(0);
        let read_result = uring_read_to_end(ring, fd, size_hint);
        let close_result = uring_close(ring, fd);
        let buf = read_result?;
        close_result?;
        Ok(buf)
    }

    /// Reads from `fd` (starting at offset 0) until a read returns 0 bytes.
    fn uring_read_to_end(
        ring: &mut io_uring::IoUring,
        fd: i32,
        size_hint: usize,
    ) -> io::Result<Vec<u8>> {
        let mut buf: Vec<u8> = Vec::new();
        // +1 leaves room for the final zero-length read that detects EOF. A bogus hint
        // (an absurdly large reported size) just fails this reservation; the loop then grows
        // the buffer on demand instead of aborting on allocation failure.
        let _ = buf.try_reserve_exact(size_hint.saturating_add(1).max(MIN_READ_CHUNK));
        loop {
            if buf.len() == buf.capacity() {
                buf.try_reserve(MIN_READ_CHUNK)
                    .map_err(|e| io::Error::new(io::ErrorKind::OutOfMemory, e))?;
            }
            let filled = buf.len();
            let spare = buf.spare_capacity_mut();
            let len = u32::try_from(spare.len()).unwrap_or(u32::MAX);
            let entry = io_uring::opcode::Read::new(
                io_uring::types::Fd(fd),
                spare.as_mut_ptr().cast::<u8>(),
                len,
            )
            .offset(filled as u64)
            .build();
            // SAFETY: the spare capacity is not touched again until `run_op` returns.
            let n = unsafe { run_op(ring, entry, "read") }.and_then(cvt)? as usize;
            if n == 0 {
                break;
            }
            // SAFETY: the kernel initialized exactly `n` (<= `len`) bytes of spare capacity.
            unsafe { buf.set_len(filled + n) };
        }
        Ok(buf)
    }

    /// A `SocketAddr` converted to the C layout the kernel expects.
    ///
    /// Callers keep it in a named local so the pointer handed to `CONNECT` stays valid
    /// until the operation completes.
    enum RawSockAddr {
        V4(libc::sockaddr_in),
        V6(libc::sockaddr_in6),
    }

    impl RawSockAddr {
        fn new(addr: &SocketAddr) -> Self {
            match addr {
                SocketAddr::V4(v4) => Self::V4(libc::sockaddr_in {
                    sin_family: libc::AF_INET as libc::sa_family_t,
                    sin_port: v4.port().to_be(),
                    sin_addr: libc::in_addr {
                        s_addr: u32::from_ne_bytes(v4.ip().octets()),
                    },
                    sin_zero: [0; 8],
                }),
                SocketAddr::V6(v6) => Self::V6(libc::sockaddr_in6 {
                    sin6_family: libc::AF_INET6 as libc::sa_family_t,
                    sin6_port: v6.port().to_be(),
                    sin6_flowinfo: v6.flowinfo(),
                    sin6_addr: libc::in6_addr {
                        s6_addr: v6.ip().octets(),
                    },
                    sin6_scope_id: v6.scope_id(),
                }),
            }
        }

        /// Pointer/length pair for the kernel; valid only while `self` is alive and unmoved.
        fn as_ptr_len(&self) -> (*const libc::sockaddr, libc::socklen_t) {
            match self {
                Self::V4(sin) => (
                    (sin as *const libc::sockaddr_in).cast(),
                    std::mem::size_of::<libc::sockaddr_in>() as libc::socklen_t,
                ),
                Self::V6(sin6) => (
                    (sin6 as *const libc::sockaddr_in6).cast(),
                    std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t,
                ),
            }
        }
    }

    /// Creates a TCP socket and connects it to `addr` through the ring.
    ///
    /// The socket is created in blocking mode with `SOCK_CLOEXEC`, matching the stream that
    /// `std::net::TcpStream::connect` returns on the fallback path. The call succeeds only
    /// once the connection is established. The socket is closed on every error path.
    fn uring_tcp_connect(
        ring: &mut io_uring::IoUring,
        addr: SocketAddr,
    ) -> io::Result<std::net::TcpStream> {
        use std::os::unix::io::FromRawFd;

        let domain = match addr {
            SocketAddr::V4(_) => libc::AF_INET,
            SocketAddr::V6(_) => libc::AF_INET6,
        };
        let socket_e =
            io_uring::opcode::Socket::new(domain, libc::SOCK_STREAM | libc::SOCK_CLOEXEC, 0)
                .build();
        // SAFETY: a socket entry references no user memory.
        let fd = unsafe { run_op(ring, socket_e, "socket") }.and_then(cvt)?;
        // SAFETY: `fd` is a freshly created socket that nothing else owns. Wrapping it right
        // away means dropping `stream` closes it on any error below.
        let stream = unsafe { std::net::TcpStream::from_raw_fd(fd) };

        let raw_addr = RawSockAddr::new(&addr);
        let (sa_ptr, sa_len) = raw_addr.as_ptr_len();
        let connect_e =
            io_uring::opcode::Connect::new(io_uring::types::Fd(fd), sa_ptr, sa_len).build();
        // SAFETY: `raw_addr` lives in this frame until after `run_op` returns.
        unsafe { run_op(ring, connect_e, "connect") }.and_then(cvt)?;
        Ok(stream)
    }

    /// Body of the dedicated worker thread: serves tasks until every sender is dropped.
    /// The ring is torn down when this returns.
    fn worker_loop(mut rx: mpsc::UnboundedReceiver<IoTask>, mut ring: io_uring::IoUring) {
        while let Some(task) = rx.blocking_recv() {
            // A failed send only means the caller stopped waiting; its result is dropped.
            match task {
                IoTask::WriteFile { path, data, tx } => {
                    let _ = tx.send(uring_write_file(&mut ring, &path, &data));
                }
                IoTask::ReadFile { path, tx } => {
                    let _ = tx.send(uring_read_file(&mut ring, &path));
                }
                IoTask::TcpConnect { addr, tx } => {
                    let _ = tx.send(uring_tcp_connect(&mut ring, addr));
                }
            }
        }
    }

    /// Probes io_uring and, if usable, starts the worker thread.
    ///
    /// Returns whether io_uring is enabled. Repeated calls are cheap once enabled. If two
    /// callers race, the loser's sender is discarded (its worker then exits on the closed
    /// channel), and both report the installed worker as enabled.
    pub fn init() -> bool {
        if URING_ENABLED.load(Ordering::Acquire) {
            return true;
        }
        let Some((ring, net_supported)) = probe_io_uring() else {
            return false;
        };
        let (tx, rx) = mpsc::unbounded_channel();
        let spawned = std::thread::Builder::new()
            .name("chromey-uring-worker".into())
            .spawn(move || worker_loop(rx, ring));
        if let Err(e) = spawned {
            tracing::warn!("Failed to spawn chromey io_uring worker: {}", e);
            return false;
        }
        // Either our sender is installed, or a concurrent init() installed its own first.
        // In both cases a live worker is reachable through URING_POOL.
        let _ = URING_POOL.set(tx);
        URING_NET_ENABLED.store(net_supported, Ordering::Release);
        URING_ENABLED.store(true, Ordering::Release);
        true
    }

    /// The worker's sender, if io_uring has been enabled.
    fn worker() -> Option<&'static mpsc::UnboundedSender<IoTask>> {
        if URING_ENABLED.load(Ordering::Acquire) {
            URING_POOL.get()
        } else {
            None
        }
    }

    /// Sends `task` to the worker and awaits its reply on `rx`. A closed channel or dropped
    /// reply is reported as `BrokenPipe`.
    async fn await_worker<T>(
        sender: &mpsc::UnboundedSender<IoTask>,
        task: IoTask,
        rx: oneshot::Receiver<io::Result<T>>,
    ) -> io::Result<T> {
        if sender.send(task).is_err() {
            return Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "chromey io_uring worker channel closed",
            ));
        }
        rx.await.unwrap_or_else(|_| {
            Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "chromey io_uring worker dropped the response",
            ))
        })
    }

    /// Writes `data` to `path` (create/truncate), via io_uring when enabled, else `tokio::fs`.
    pub async fn write_file(path: String, data: Vec<u8>) -> io::Result<()> {
        if let Some(sender) = worker() {
            let (tx, rx) = oneshot::channel();
            return await_worker(sender, IoTask::WriteFile { path, data, tx }, rx).await;
        }
        tokio::fs::write(path, data).await
    }

    /// Reads the whole file at `path`, via io_uring when enabled, else `tokio::fs`.
    pub async fn read_file(path: String) -> io::Result<Vec<u8>> {
        if let Some(sender) = worker() {
            let (tx, rx) = oneshot::channel();
            return await_worker(sender, IoTask::ReadFile { path, tx }, rx).await;
        }
        tokio::fs::read(path).await
    }

    /// Connects a blocking `std::net::TcpStream` to `addr`.
    ///
    /// Uses io_uring when it is enabled and the kernel supports its socket opcodes.
    /// Otherwise the connect runs on tokio's blocking pool.
    pub async fn tcp_connect(addr: SocketAddr) -> io::Result<std::net::TcpStream> {
        if URING_NET_ENABLED.load(Ordering::Acquire) {
            if let Some(sender) = worker() {
                let (tx, rx) = oneshot::channel();
                return await_worker(sender, IoTask::TcpConnect { addr, tx }, rx).await;
            }
        }
        tokio::task::spawn_blocking(move || std::net::TcpStream::connect(addr))
            .await
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
    }

    /// Whether `init` has enabled the io_uring worker.
    pub fn is_enabled() -> bool {
        URING_ENABLED.load(Ordering::Acquire)
    }
}
382
383#[cfg(not(all(target_os = "linux", feature = "io_uring")))]
386mod inner {
387 use std::io;
388 use std::net::SocketAddr;
389
390 pub fn init() -> bool {
391 false
392 }
393
394 pub async fn write_file(path: String, data: Vec<u8>) -> io::Result<()> {
395 tokio::fs::write(&path, &data).await
396 }
397
398 pub async fn read_file(path: String) -> io::Result<Vec<u8>> {
399 tokio::fs::read(&path).await
400 }
401
402 pub async fn tcp_connect(addr: SocketAddr) -> io::Result<std::net::TcpStream> {
403 tokio::task::spawn_blocking(move || std::net::TcpStream::connect(addr))
404 .await
405 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
406 }
407
408 pub fn is_enabled() -> bool {
409 false
410 }
411}
412
413pub use inner::init;
416pub use inner::is_enabled;
417pub use inner::read_file;
418pub use inner::tcp_connect;
419pub use inner::write_file;
420
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-process temp path, so concurrent test runs cannot clobber each other's files.
    fn temp_path(tag: &str) -> String {
        std::env::temp_dir()
            .join(format!("chromey_uring_test_{}_{}", tag, std::process::id()))
            .display()
            .to_string()
    }

    #[tokio::test]
    async fn test_write_read_roundtrip() {
        // Enable io_uring when available so the ring path is exercised, not only the fallback.
        init();
        let path = temp_path("roundtrip");
        let payload = b"chromey uring test".to_vec();

        write_file(path.clone(), payload.clone()).await.unwrap();
        let read_back = read_file(path.clone()).await.unwrap();
        assert_eq!(read_back, payload);

        let _ = tokio::fs::remove_file(&path).await;
    }

    #[tokio::test]
    async fn test_write_read_various_sizes() {
        init();
        let path = temp_path("sizes");
        for len in [0usize, 1, 64 * 1024 + 7] {
            let payload: Vec<u8> = (0..len).map(|i| (i % 251) as u8).collect();
            write_file(path.clone(), payload.clone()).await.unwrap();
            assert_eq!(read_file(path.clone()).await.unwrap(), payload, "len {len}");
        }
        let _ = tokio::fs::remove_file(&path).await;
    }

    #[cfg(target_os = "linux")]
    #[tokio::test]
    async fn test_read_file_reads_size_zero_pseudo_files() {
        init();
        // procfs reports a size of 0 for files that do have content; reads must go to EOF.
        let status = read_file("/proc/self/status".to_string()).await.unwrap();
        assert!(!status.is_empty());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_write_file_mode_matches_std() {
        use std::os::unix::fs::PermissionsExt;

        init();
        let ours = temp_path("mode_ours");
        let theirs = temp_path("mode_std");
        // The create mode only applies to new files, so start from scratch.
        let _ = std::fs::remove_file(&ours);
        let _ = std::fs::remove_file(&theirs);

        write_file(ours.clone(), b"x".to_vec()).await.unwrap();
        std::fs::write(&theirs, b"x").unwrap();
        let mode = |p: &str| std::fs::metadata(p).unwrap().permissions().mode() & 0o777;
        assert_eq!(mode(&ours), mode(&theirs));

        let _ = std::fs::remove_file(&ours);
        let _ = std::fs::remove_file(&theirs);
    }

    #[tokio::test]
    async fn test_tcp_connect_loopback() {
        init();
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let accept = tokio::spawn(async move { listener.accept().await });
        let connect = tokio::spawn(async move { tcp_connect(addr).await });

        let (a, c) = tokio::join!(accept, connect);
        assert!(a.unwrap().is_ok());
        let stream = c.unwrap().unwrap();
        // The stream must be fully connected, not merely "in progress".
        assert_eq!(stream.peer_addr().unwrap(), addr);
    }

    #[tokio::test]
    async fn test_tcp_connect_refused() {
        init();
        let addr: std::net::SocketAddr = "127.0.0.1:1".parse().unwrap();
        assert!(tcp_connect(addr).await.is_err());
    }

    #[tokio::test]
    async fn test_init_idempotent() {
        let r1 = init();
        let r2 = init();
        assert_eq!(r1, r2);
        if r1 {
            assert!(is_enabled());
        }
    }
}