Skip to main content

chromiumoxide/
uring_fs.rs

1//! Async file I/O with optional io_uring acceleration for chromey.
2//!
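//! On Linux with the `io_uring` feature, once [`init`] has succeeded, file
//! operations are dispatched to a dedicated worker thread that drives a raw
//! io_uring ring. Otherwise (other platforms, the feature disabled, `init` not
//! yet called, or io_uring unavailable at runtime) operations fall back to
//! `tokio::fs`.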
6//!
7//! The worker also supports TCP connect via io_uring Socket + Connect opcodes,
8//! used for the initial CDP WebSocket connection.
9
10// ── io_uring implementation ─────────────────────────────────────────────────
11
#[cfg(all(target_os = "linux", feature = "io_uring"))]
mod inner {
    use std::ffi::CString;
    use std::io;
    use std::net::SocketAddr;
    use std::os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd};
    use std::sync::atomic::{AtomicBool, Ordering};
    use tokio::sync::{mpsc, oneshot};

    use io_uring::types::Fd;
    use io_uring::{opcode, squeue, IoUring};

    /// Submission-queue size requested for the worker's ring. Every helper below keeps at most
    /// one operation in flight, so this only has to be non-zero.
    const RING_ENTRIES: u32 = 64;

    /// Minimum number of bytes a read buffer grows by once its size hint has been filled.
    const MIN_READ_GROWTH: usize = 8 * 1024;

    /// Opcodes the file helpers rely on; io_uring is not used at all unless all are supported.
    const FILE_OPCODES: [u8; 4] = [
        opcode::OpenAt::CODE,
        opcode::Read::CODE,
        opcode::Write::CODE,
        opcode::Close::CODE,
    ];

    /// Opcodes `uring_tcp_connect` relies on. `Socket` needs a newer kernel (5.19) than the
    /// file opcodes, so TCP connect is gated separately from file I/O.
    const NET_OPCODES: [u8; 3] = [
        opcode::Socket::CODE,
        opcode::Connect::CODE,
        opcode::PollAdd::CODE,
    ];

    // `user_data` tags for each SQE kind. With a single operation in flight they are not needed
    // to match completions, but they make ring traces easier to read.
    const UD_OPEN: u64 = 0x0BE4;
    const UD_READ: u64 = 0x4EAD;
    const UD_WRITE: u64 = 0x1417E;
    const UD_CLOSE: u64 = 0xC105E;
    const UD_SOCKET: u64 = 0x50CE7;
    const UD_CONNECT: u64 = 0xC044;
    const UD_POLL: u64 = 0x9011;

    /// Set once a worker sender has been installed in `URING_POOL`.
    static URING_ENABLED: AtomicBool = AtomicBool::new(false);
    /// Set when the probed kernel also supports every opcode in `NET_OPCODES`.
    static URING_NET_ENABLED: AtomicBool = AtomicBool::new(false);
    /// Sender half of the channel feeding the single io_uring worker thread.
    static URING_POOL: std::sync::OnceLock<mpsc::UnboundedSender<IoTask>> =
        std::sync::OnceLock::new();

    /// A unit of work for the worker thread, carrying the channel its result is sent back on.
    enum IoTask {
        WriteFile {
            path: String,
            data: Vec<u8>,
            tx: oneshot::Sender<io::Result<()>>,
        },
        ReadFile {
            path: String,
            tx: oneshot::Sender<io::Result<Vec<u8>>>,
        },
        TcpConnect {
            addr: SocketAddr,
            tx: oneshot::Sender<io::Result<std::net::TcpStream>>,
        },
    }

    /// Builds the worker's ring and checks which operations the running kernel supports.
    ///
    /// Returns `None` (callers then keep using the tokio fallbacks) when io_uring is unavailable,
    /// the opcode probe fails, or any file opcode is missing. Otherwise returns the ring plus
    /// whether TCP connect can also be routed through it.
    fn probe_io_uring() -> Option<(IoUring, bool)> {
        let ring = match IoUring::builder().build(RING_ENTRIES) {
            Ok(ring) => ring,
            Err(e) => {
                tracing::info!("chromey: io_uring unavailable ({}), using tokio::fs", e);
                return None;
            }
        };

        let mut probe = io_uring::Probe::new();
        if let Err(e) = ring.submitter().register_probe(&mut probe) {
            tracing::info!("chromey: io_uring opcode probe failed ({}), using tokio::fs", e);
            return None;
        }
        if !FILE_OPCODES.iter().all(|&code| probe.is_supported(code)) {
            tracing::info!("chromey: io_uring lacks required file opcodes, using tokio::fs");
            return None;
        }

        let net_supported = NET_OPCODES.iter().all(|&code| probe.is_supported(code));
        tracing::info!("chromey: io_uring probe succeeded");
        if !net_supported {
            tracing::info!("chromey: io_uring socket opcodes unsupported, tcp connect uses std");
        }
        Some((ring, net_supported))
    }

    /// Converts a raw CQE result (`-errno` on failure) into an `io::Result`.
    fn cqe_result(res: i32) -> io::Result<i32> {
        if res < 0 {
            Err(io::Error::from_raw_os_error(-res))
        } else {
            Ok(res)
        }
    }

    /// Submits everything queued in the SQ and blocks until one completion is available,
    /// returning its raw result.
    ///
    /// This never returns `Ok` before the in-flight operation has completed. The kernel may
    /// still be reading or writing caller-owned memory until then, and an unreaped CQE would be
    /// mistaken for the next operation's result.
    fn submit_and_reap(ring: &mut IoUring) -> io::Result<i32> {
        loop {
            match ring.submit_and_wait(1) {
                Ok(_) => {}
                // An interrupted wait leaves the SQE either still queued or already submitted;
                // calling again submits whatever is left and resumes waiting.
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e),
            }
            // `submit_and_wait` can report success once the SQE is submitted even if the wait
            // itself was cut short, so keep waiting until the completion is actually visible.
            if let Some(cqe) = ring.completion().next() {
                return Ok(cqe.result());
            }
        }
    }

    /// Pushes a single SQE, submits it and blocks until its completion arrives, returning the
    /// raw CQE result.
    ///
    /// # Safety
    /// Every buffer or pointer referenced by `entry` must stay valid until this returns.
    unsafe fn submit_one(ring: &mut IoUring, entry: &squeue::Entry, op: &str) -> io::Result<i32> {
        // SAFETY: the pointer-validity requirement is forwarded to our caller (see `# Safety`).
        unsafe { ring.submission().push(entry) }.map_err(|_| {
            io::Error::new(io::ErrorKind::Other, format!("io_uring: SQ full on {}", op))
        })?;
        submit_and_reap(ring)
    }

    /// Closes `fd` through the ring.
    fn uring_close(ring: &mut IoUring, fd: RawFd) -> io::Result<()> {
        let close_e = opcode::Close::new(Fd(fd)).build().user_data(UD_CLOSE);
        // SAFETY: Close references no user memory.
        cqe_result(unsafe { submit_one(ring, &close_e, "close") }?)?;
        Ok(())
    }

    /// Opens `path` (relative paths resolve against the current directory) and returns the fd.
    ///
    /// `O_CLOEXEC` is always added so the descriptor is not inherited by processes spawned
    /// concurrently from other threads.
    fn uring_open(ring: &mut IoUring, path: &str, flags: i32, mode: u32) -> io::Result<RawFd> {
        let c_path =
            CString::new(path).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
        let open_e = opcode::OpenAt::new(Fd(libc::AT_FDCWD), c_path.as_ptr())
            .flags(flags | libc::O_CLOEXEC)
            .mode(mode)
            .build()
            .user_data(UD_OPEN);
        // SAFETY: `c_path` outlives this call, which only returns after the open completed.
        cqe_result(unsafe { submit_one(ring, &open_e, "open") }?)
    }

    /// Creates or truncates `path` with mode `0o644` (before umask) and writes all of `data`.
    ///
    /// The file is always closed. If both the write and the close fail, the write error wins
    /// because it is the more specific one.
    fn uring_write_file(ring: &mut IoUring, path: &str, data: &[u8]) -> io::Result<()> {
        let fd = uring_open(
            ring,
            path,
            libc::O_WRONLY | libc::O_CREAT | libc::O_TRUNC,
            0o644,
        )?;
        let write_result = uring_write_all(ring, fd, data);
        let close_result = uring_close(ring, fd);
        write_result.and(close_result)
    }

    /// Writes every byte of `data` to `fd` starting at offset 0, looping over short writes.
    /// Each write is capped at `u32::MAX` bytes, the largest length an SQE can carry.
    fn uring_write_all(ring: &mut IoUring, fd: RawFd, data: &[u8]) -> io::Result<()> {
        let mut written = 0usize;
        while written < data.len() {
            let remaining = &data[written..];
            let chunk_len = u32::try_from(remaining.len()).unwrap_or(u32::MAX);
            let write_e = opcode::Write::new(Fd(fd), remaining.as_ptr(), chunk_len)
                .offset(written as u64)
                .build()
                .user_data(UD_WRITE);
            // SAFETY: `remaining` borrows `data`, which outlives this call.
            let n = cqe_result(unsafe { submit_one(ring, &write_e, "write") }?)?;
            if n == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "io_uring: write returned 0",
                ));
            }
            written += n as usize;
        }
        Ok(())
    }

    /// Reads the whole file at `path`.
    ///
    /// The metadata size is only a capacity hint. Reading continues until EOF, so files that
    /// grow, shrink or report size 0 (e.g. procfs) behave like `tokio::fs::read`.
    fn uring_read_file(ring: &mut IoUring, path: &str) -> io::Result<Vec<u8>> {
        // A failed stat is not fatal here: the open below reports the authoritative error.
        let size_hint = std::fs::metadata(path)
            .ok()
            .and_then(|meta| usize::try_from(meta.len()).ok())
            .unwrap_or(0);

        let fd = uring_open(ring, path, libc::O_RDONLY, 0)?;
        let read_result = uring_read_to_end(ring, fd, size_hint);
        let close_result = uring_close(ring, fd);
        let data = read_result?;
        close_result?;
        Ok(data)
    }

    /// Reads `fd` from offset 0 until a read returns 0, growing the buffer as needed.
    fn uring_read_to_end(ring: &mut IoUring, fd: RawFd, size_hint: usize) -> io::Result<Vec<u8>> {
        // One spare byte past the hint lets an accurately-sized file hit EOF without a regrow.
        let mut buf = vec![0u8; size_hint.saturating_add(1)];
        let mut filled = 0usize;
        loop {
            if filled == buf.len() {
                let grow = buf.len().max(MIN_READ_GROWTH);
                buf.resize(buf.len().saturating_add(grow), 0);
            }
            let spare = &mut buf[filled..];
            let chunk_len = u32::try_from(spare.len()).unwrap_or(u32::MAX);
            let read_e = opcode::Read::new(Fd(fd), spare.as_mut_ptr(), chunk_len)
                .offset(filled as u64)
                .build()
                .user_data(UD_READ);
            // SAFETY: `spare` is a live region of `buf` of at least `chunk_len` bytes, and `buf`
            // is neither moved nor resized until `submit_one` has reaped the read.
            let n = cqe_result(unsafe { submit_one(ring, &read_e, "read") }?)?;
            if n == 0 {
                break;
            }
            filled += n as usize;
        }
        buf.truncate(filled);
        Ok(buf)
    }

    /// A C socket address owned by the caller's frame, so the pointer handed to the Connect op
    /// stays valid until that operation completes.
    enum RawSockAddr {
        V4(libc::sockaddr_in),
        V6(libc::sockaddr_in6),
    }

    impl RawSockAddr {
        fn new(addr: SocketAddr) -> Self {
            match addr {
                SocketAddr::V4(v4) => Self::V4(libc::sockaddr_in {
                    sin_family: libc::AF_INET as libc::sa_family_t,
                    sin_port: v4.port().to_be(),
                    sin_addr: libc::in_addr {
                        // `octets()` is already in network order; keep the bytes as-is.
                        s_addr: u32::from_ne_bytes(v4.ip().octets()),
                    },
                    sin_zero: [0; 8],
                }),
                SocketAddr::V6(v6) => Self::V6(libc::sockaddr_in6 {
                    sin6_family: libc::AF_INET6 as libc::sa_family_t,
                    sin6_port: v6.port().to_be(),
                    sin6_flowinfo: v6.flowinfo(),
                    sin6_addr: libc::in6_addr {
                        s6_addr: v6.ip().octets(),
                    },
                    sin6_scope_id: v6.scope_id(),
                }),
            }
        }

        /// Socket domain matching this address family.
        fn domain(&self) -> i32 {
            match self {
                Self::V4(_) => libc::AF_INET,
                Self::V6(_) => libc::AF_INET6,
            }
        }

        /// Pointer/length pair borrowed from `self`; valid only while `self` is alive.
        fn as_ptr_len(&self) -> (*const libc::sockaddr, libc::socklen_t) {
            match self {
                Self::V4(sa) => (
                    (sa as *const libc::sockaddr_in).cast(),
                    std::mem::size_of::<libc::sockaddr_in>() as libc::socklen_t,
                ),
                Self::V6(sa) => (
                    (sa as *const libc::sockaddr_in6).cast(),
                    std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t,
                ),
            }
        }
    }

    /// Blocks until `fd` reports writability (or an error/hangup condition).
    fn uring_wait_writable(ring: &mut IoUring, fd: RawFd) -> io::Result<()> {
        let poll_e = opcode::PollAdd::new(Fd(fd), libc::POLLOUT as u32)
            .build()
            .user_data(UD_POLL);
        // SAFETY: PollAdd references no user memory.
        cqe_result(unsafe { submit_one(ring, &poll_e, "poll") }?)?;
        Ok(())
    }

    /// Creates a non-blocking, close-on-exec TCP socket and connects it to `addr` via io_uring.
    ///
    /// Only returns once the connection is established. An `EINPROGRESS` result is followed by
    /// a writability poll and an `SO_ERROR` check. The socket is closed on every error path.
    fn uring_tcp_connect(ring: &mut IoUring, addr: SocketAddr) -> io::Result<std::net::TcpStream> {
        let sockaddr = RawSockAddr::new(addr);

        let socket_e = opcode::Socket::new(
            sockaddr.domain(),
            libc::SOCK_STREAM | libc::SOCK_NONBLOCK | libc::SOCK_CLOEXEC,
            0,
        )
        .build()
        .user_data(UD_SOCKET);
        // SAFETY: Socket references no user memory.
        let fd = cqe_result(unsafe { submit_one(ring, &socket_e, "socket") }?)?;
        // SAFETY: a successful Socket completion yields a fresh descriptor nothing else owns.
        // Wrapping it immediately guarantees it is closed if anything below fails.
        let socket = unsafe { OwnedFd::from_raw_fd(fd) };

        let (sa_ptr, sa_len) = sockaddr.as_ptr_len();
        let connect_e = opcode::Connect::new(Fd(socket.as_raw_fd()), sa_ptr, sa_len)
            .build()
            .user_data(UD_CONNECT);
        // SAFETY: `sockaddr` lives until the end of this function, after the connect completed.
        let res = unsafe { submit_one(ring, &connect_e, "connect") }?;

        let stream = std::net::TcpStream::from(socket);
        if res == -libc::EINPROGRESS {
            // The handshake is still running: wait for it and surface its outcome.
            uring_wait_writable(ring, stream.as_raw_fd())?;
            if let Some(err) = stream.take_error()? {
                return Err(err);
            }
        } else {
            cqe_result(res)?;
        }
        Ok(stream)
    }

    /// Serves tasks one at a time until every sender has been dropped. The installed sender
    /// lives in a static, so only the worker of an `init` call that lost a race ever exits.
    /// The ring is dropped when this returns.
    fn worker_loop(mut rx: mpsc::UnboundedReceiver<IoTask>, mut ring: IoUring) {
        while let Some(task) = rx.blocking_recv() {
            // A send error only means the requester stopped waiting; nothing to do.
            match task {
                IoTask::WriteFile { path, data, tx } => {
                    let _ = tx.send(uring_write_file(&mut ring, &path, &data));
                }
                IoTask::ReadFile { path, tx } => {
                    let _ = tx.send(uring_read_file(&mut ring, &path));
                }
                IoTask::TcpConnect { addr, tx } => {
                    let _ = tx.send(uring_tcp_connect(&mut ring, addr));
                }
            }
        }
    }

    // ── Public API ──────────────────────────────────────────────────────────

    /// Probes io_uring and, if usable, starts the worker thread. Returns whether io_uring is
    /// enabled.
    ///
    /// Idempotent and safe to call concurrently. If two calls race, both report `true`. The
    /// losing call's worker exits as soon as its unused sender is dropped.
    pub fn init() -> bool {
        if URING_ENABLED.load(Ordering::Acquire) {
            return true;
        }
        let Some((ring, net_supported)) = probe_io_uring() else {
            return false;
        };

        let (tx, rx) = mpsc::unbounded_channel();
        let spawned = std::thread::Builder::new()
            .name("chromey-uring-worker".into())
            .spawn(move || worker_loop(rx, ring));
        if let Err(e) = spawned {
            tracing::warn!("Failed to spawn chromey io_uring worker: {}", e);
            return false;
        }

        // If a concurrent `init` already installed its worker, `set` hands our sender back and
        // dropping it here shuts our redundant worker down. Either way a worker is installed.
        let _ = URING_POOL.set(tx);
        URING_NET_ENABLED.store(net_supported, Ordering::Release);
        URING_ENABLED.store(true, Ordering::Release);
        true
    }

    /// The worker's sender, if `init` has enabled io_uring.
    fn worker() -> Option<&'static mpsc::UnboundedSender<IoTask>> {
        if URING_ENABLED.load(Ordering::Acquire) {
            URING_POOL.get()
        } else {
            None
        }
    }

    /// Sends a pre-built `IoTask` to the worker and awaits its result.
    ///
    /// Returns `BrokenPipe` if the worker has gone away before accepting or answering the task.
    async fn await_worker<T>(
        sender: &mpsc::UnboundedSender<IoTask>,
        task: IoTask,
        rx: oneshot::Receiver<io::Result<T>>,
    ) -> io::Result<T> {
        if sender.send(task).is_err() {
            return Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "chromey io_uring worker channel closed",
            ));
        }
        rx.await.unwrap_or_else(|_| {
            Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "chromey io_uring worker dropped the response",
            ))
        })
    }

    /// Writes `data` to `path`, creating the file or truncating an existing one.
    ///
    /// Uses the io_uring worker once [`init`] has enabled it, otherwise `tokio::fs::write`.
    pub async fn write_file(path: String, data: Vec<u8>) -> io::Result<()> {
        if let Some(sender) = worker() {
            let (tx, rx) = oneshot::channel();
            return await_worker(sender, IoTask::WriteFile { path, data, tx }, rx).await;
        }
        tokio::fs::write(path, data).await
    }

    /// Reads the entire file at `path`.
    ///
    /// Uses the io_uring worker once [`init`] has enabled it, otherwise `tokio::fs::read`.
    pub async fn read_file(path: String) -> io::Result<Vec<u8>> {
        if let Some(sender) = worker() {
            let (tx, rx) = oneshot::channel();
            return await_worker(sender, IoTask::ReadFile { path, tx }, rx).await;
        }
        tokio::fs::read(path).await
    }

    /// Opens a TCP connection to `addr`.
    ///
    /// Uses io_uring only when [`init`] enabled it and the kernel supports the socket opcodes.
    /// That path returns a stream already in non-blocking mode. The `spawn_blocking` fallback
    /// returns a blocking `std` stream.
    pub async fn tcp_connect(addr: SocketAddr) -> io::Result<std::net::TcpStream> {
        if URING_NET_ENABLED.load(Ordering::Acquire) {
            if let Some(sender) = worker() {
                let (tx, rx) = oneshot::channel();
                return await_worker(sender, IoTask::TcpConnect { addr, tx }, rx).await;
            }
        }
        tokio::task::spawn_blocking(move || std::net::TcpStream::connect(addr))
            .await
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
    }

    /// Whether [`init`] has enabled the io_uring worker.
    pub fn is_enabled() -> bool {
        URING_ENABLED.load(Ordering::Acquire)
    }
}
382
383// ── Fallback (non-Linux / no io_uring feature) ──────────────────────────────
384
385#[cfg(not(all(target_os = "linux", feature = "io_uring")))]
386mod inner {
387    use std::io;
388    use std::net::SocketAddr;
389
390    pub fn init() -> bool {
391        false
392    }
393
394    pub async fn write_file(path: String, data: Vec<u8>) -> io::Result<()> {
395        tokio::fs::write(&path, &data).await
396    }
397
398    pub async fn read_file(path: String) -> io::Result<Vec<u8>> {
399        tokio::fs::read(&path).await
400    }
401
402    pub async fn tcp_connect(addr: SocketAddr) -> io::Result<std::net::TcpStream> {
403        tokio::task::spawn_blocking(move || std::net::TcpStream::connect(addr))
404            .await
405            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
406    }
407
408    pub fn is_enabled() -> bool {
409        false
410    }
411}
412
413// ── Re-exports ──────────────────────────────────────────────────────────────
414
415pub use inner::init;
416pub use inner::is_enabled;
417pub use inner::read_file;
418pub use inner::tcp_connect;
419pub use inner::write_file;
420
421// ── Tests ───────────────────────────────────────────────────────────────────
422
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a temp-file path unique to this process and test, so concurrent test runs
    /// (e.g. with different feature sets) never clobber each other's files.
    fn temp_path(name: &str) -> String {
        std::env::temp_dir()
            .join(format!("chromey_uring_test_{}_{}", std::process::id(), name))
            .display()
            .to_string()
    }

    /// Writes `payload` through the public API, reads it back, removes the file, then asserts
    /// the bytes match. Removal happens before the assert so a failure leaves no litter.
    async fn roundtrip(name: &str, payload: Vec<u8>) {
        let path = temp_path(name);
        write_file(path.clone(), payload.clone()).await.unwrap();
        let read_back = read_file(path.clone()).await;
        let _ = tokio::fs::remove_file(&path).await;
        assert_eq!(read_back.unwrap(), payload);
    }

    #[tokio::test]
    async fn test_write_read_roundtrip() {
        // Opt into io_uring (when available) so the accelerated path is exercised
        // deterministically instead of depending on test execution order.
        init();
        roundtrip("roundtrip", b"chromey uring test".to_vec()).await;
    }

    #[tokio::test]
    async fn test_write_read_empty() {
        init();
        roundtrip("empty", Vec::new()).await;
    }

    #[tokio::test]
    async fn test_write_read_large() {
        init();
        let payload: Vec<u8> = (0..300_000u32).map(|i| (i % 251) as u8).collect();
        roundtrip("large", payload).await;
    }

    #[tokio::test]
    async fn test_read_missing_file() {
        init();
        let err = read_file(temp_path("missing")).await.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::NotFound);
    }

    #[tokio::test]
    async fn test_tcp_connect_loopback() {
        init();
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let accept = tokio::spawn(async move { listener.accept().await });
        let connect = tokio::spawn(async move { tcp_connect(addr).await });

        let (a, c) = tokio::join!(accept, connect);
        assert!(a.unwrap().is_ok());
        // The stream must be fully connected, not merely have a connect in flight.
        let stream = c.unwrap().unwrap();
        assert_eq!(stream.peer_addr().unwrap(), addr);
    }

    #[tokio::test]
    async fn test_tcp_connect_refused() {
        init();
        let addr: std::net::SocketAddr = "127.0.0.1:1".parse().unwrap();
        assert!(tcp_connect(addr).await.is_err());
    }

    #[tokio::test]
    async fn test_init_idempotent() {
        let r1 = init();
        let r2 = init();
        assert_eq!(r1, r2);
        assert_eq!(r2, is_enabled());
    }
}
467}