#[cfg_attr(all(doc, docsrs), doc(cfg(all())))]
#[allow(unused_imports)]
pub use std::os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
use std::{io, os::fd::FromRawFd, pin::Pin, sync::Arc, task::Poll, time::Duration};

use compio_log::{instrument, trace, warn};
use crossbeam_queue::SegQueue;
cfg_if::cfg_if! {
    if #[cfg(feature = "io-uring-cqe32")] {
        use io_uring::cqueue::Entry32 as CEntry;
    } else {
        use io_uring::cqueue::Entry as CEntry;
    }
}
cfg_if::cfg_if! {
    if #[cfg(feature = "io-uring-sqe128")] {
        use io_uring::squeue::Entry128 as SEntry;
    } else {
        use io_uring::squeue::Entry as SEntry;
    }
}
use io_uring::{
    IoUring,
    cqueue::more,
    opcode::{AsyncCancel, PollAdd},
    types::{Fd, SubmitArgs, Timespec},
};
use slab::Slab;

use crate::{AsyncifyPool, BufferPool, DriverType, Entry, Key, ProactorBuilder, syscall};

pub(crate) mod op;

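/// Returns whether the running kernel supports the given io-uring opcode. The
/// probe is registered once and cached; a probe failure is treated as "not
/// supported".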
pub(crate) fn is_op_supported(code: u8) -> bool {
    #[cfg(feature = "once_cell_try")]
    use std::sync::OnceLock;

    #[cfg(not(feature = "once_cell_try"))]
    use once_cell::sync::OnceCell as OnceLock;

    static PROBE: OnceLock<io_uring::Probe> = OnceLock::new();

    PROBE
        .get_or_try_init(|| {
            let mut probe = io_uring::Probe::new();

            io_uring::IoUring::new(2)?
                .submitter()
                .register_probe(&mut probe)?;

            std::io::Result::Ok(probe)
        })
        .map(|probe| probe.is_supported(code))
        .unwrap_or_default()
}

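/// The entry produced by [`OpCode::create_entry`], deciding how the driver
/// executes an operation.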
pub enum OpEntry {
    /// The operation is submitted as a normal io-uring submission entry.
    Submission(io_uring::squeue::Entry),
    /// The operation is submitted as a 128-byte submission entry.
    #[cfg(feature = "io-uring-sqe128")]
    Submission128(io_uring::squeue::Entry128),
    /// The operation is blocking and must run on the thread pool.
    Blocking,
}

impl From<io_uring::squeue::Entry> for OpEntry {
    fn from(value: io_uring::squeue::Entry) -> Self {
        Self::Submission(value)
    }
}

#[cfg(feature = "io-uring-sqe128")]
impl From<io_uring::squeue::Entry128> for OpEntry {
    fn from(value: io_uring::squeue::Entry128) -> Self {
        Self::Submission128(value)
    }
}

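/// Abstraction of io-uring operations.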
pub trait OpCode {
    /// Create the submission entry, or report that the operation must run on
    /// the blocking thread pool.
    fn create_entry(self: Pin<&mut Self>) -> OpEntry;

    /// Run the operation synchronously. Only called when
    /// [`OpCode::create_entry`] returns [`OpEntry::Blocking`].
    fn call_blocking(self: Pin<&mut Self>) -> io::Result<usize> {
        unreachable!("this operation is asynchronous")
    }

    /// Hook for operations that need to record a successful completion result;
    /// the default implementation does nothing.
    unsafe fn set_result(self: Pin<&mut Self>, _: usize) {}
}

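/// Low-level driver backed by io-uring.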
pub(crate) struct Driver {
    inner: IoUring<SEntry, CEntry>,
    notifier: Notifier,
    pool: AsyncifyPool,
    pool_completed: Arc<SegQueue<Entry>>,
    buffer_group_ids: Slab<()>,
    need_push_notifier: bool,
}

impl Driver {
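    // Sentinel `user_data` values that never belong to a user operation:
    // `CANCEL` marks `AsyncCancel` entries and `NOTIFY` marks the multishot
    // `PollAdd` on the notifier eventfd.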
    const CANCEL: u64 = u64::MAX;
    const NOTIFY: u64 = u64::MAX - 1;

    pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
        instrument!(compio_log::Level::TRACE, "new", ?builder);
        trace!("new iour driver");
        let notifier = Notifier::new()?;
        let mut io_uring_builder = IoUring::builder();
        if let Some(sqpoll_idle) = builder.sqpoll_idle {
            io_uring_builder.setup_sqpoll(sqpoll_idle.as_millis() as _);
        }
        if builder.coop_taskrun {
            io_uring_builder.setup_coop_taskrun();
        }
        if builder.taskrun_flag {
            io_uring_builder.setup_taskrun_flag();
        }

        let inner = io_uring_builder.build(builder.capacity)?;

        let submitter = inner.submitter();

        if let Some(fd) = builder.eventfd {
            submitter.register_eventfd(fd)?;
        }

        Ok(Self {
            inner,
            notifier,
            pool: builder.create_or_get_thread_pool(),
            pool_completed: Arc::new(SegQueue::new()),
            buffer_group_ids: Slab::new(),
            need_push_notifier: true,
        })
    }

    pub fn driver_type(&self) -> DriverType {
        DriverType::IoUring
    }

    fn submit_auto(&mut self, timeout: Option<Duration>) -> io::Result<()> {
        instrument!(compio_log::Level::TRACE, "submit_auto", ?timeout);

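        // `want` for the submit call: when the taskrun flag reports completions
        // already pending, don't wait for a new one; otherwise wait for one.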
        let want_sqe = if self.inner.submission().taskrun() {
            0
        } else {
            1
        };

        let res = {
            if let Some(duration) = timeout {
                let timespec = timespec(duration);
                let args = SubmitArgs::new().timespec(&timespec);
                self.inner.submitter().submit_with_args(want_sqe, &args)
            } else {
                self.inner.submit_and_wait(want_sqe)
            }
        };
        trace!("submit result: {res:?}");
        match res {
            Ok(_) => {
                if self.inner.completion().is_empty() {
                    Err(io::ErrorKind::TimedOut.into())
                } else {
                    Ok(())
                }
            }
            Err(e) => match e.raw_os_error() {
                Some(libc::ETIME) => Err(io::ErrorKind::TimedOut.into()),
                Some(libc::EBUSY) | Some(libc::EAGAIN) => Err(io::ErrorKind::Interrupted.into()),
                _ => Err(e),
            },
        }
    }

    fn poll_blocking(&mut self) {
        if !self.pool_completed.is_empty() {
            while let Some(entry) = self.pool_completed.pop() {
                unsafe {
                    entry.notify();
                }
            }
        }
    }

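    /// Drain the completion queue, dispatching each completion to its
    /// operation. Returns whether any entry was observed.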
    fn poll_entries(&mut self) -> bool {
        self.poll_blocking();

        let mut cqueue = self.inner.completion();
        cqueue.sync();
        let has_entry = !cqueue.is_empty();
        for entry in cqueue {
            match entry.user_data() {
                Self::CANCEL => {}
                Self::NOTIFY => {
                    let flags = entry.flags();
                    if !more(flags) {
                        self.need_push_notifier = true;
                    }
                    self.notifier.clear().expect("cannot clear notifier");
                }
                _ => unsafe {
                    create_entry(entry).notify();
                },
            }
        }
        has_entry
    }

    pub fn create_op<T: crate::sys::OpCode + 'static>(&self, op: T) -> Key<T> {
        Key::new(self.as_raw_fd(), op)
    }

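    /// Attaching a file descriptor is a no-op for the io-uring driver.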
    pub fn attach(&mut self, _fd: RawFd) -> io::Result<()> {
        Ok(())
    }

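    /// Request cancellation of an in-flight operation by pushing an
    /// `AsyncCancel` entry; if the submission queue is full, the failure is
    /// only logged and cancellation stays best-effort.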
    pub fn cancel(&mut self, op: &mut Key<dyn crate::sys::OpCode>) {
        instrument!(compio_log::Level::TRACE, "cancel", ?op);
        trace!("cancel RawOp");
        unsafe {
            #[allow(clippy::useless_conversion)]
            if self
                .inner
                .submission()
                .push(
                    &AsyncCancel::new(op.user_data() as _)
                        .build()
                        .user_data(Self::CANCEL)
                        .into(),
                )
                .is_err()
            {
                warn!("could not push AsyncCancel entry");
            }
        }
    }

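    /// Push a raw submission entry, retrying after draining completions and
    /// submitting with a zero timeout whenever the submission queue is full.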
    fn push_raw(&mut self, entry: SEntry) -> io::Result<()> {
        loop {
            let mut squeue = self.inner.submission();
            match unsafe { squeue.push(&entry) } {
                Ok(()) => {
                    squeue.sync();
                    break Ok(());
                }
                Err(_) => {
                    drop(squeue);
                    self.poll_entries();
                    match self.submit_auto(Some(Duration::ZERO)) {
                        Ok(()) => {}
                        Err(e)
                            if matches!(
                                e.kind(),
                                io::ErrorKind::TimedOut | io::ErrorKind::Interrupted
                            ) => {}
                        Err(e) => return Err(e),
                    }
                }
            }
        }
    }

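    /// Submit an operation. Async operations become submission entries tagged
    /// with their `user_data`; blocking operations are dispatched to the thread
    /// pool. Returns `Poll::Pending` once the operation has been queued.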
    pub fn push(&mut self, op: &mut Key<dyn crate::sys::OpCode>) -> Poll<io::Result<usize>> {
        instrument!(compio_log::Level::TRACE, "push", ?op);
        let user_data = op.user_data();
        let op_pin = op.as_op_pin();
        trace!("push RawOp");
        match op_pin.create_entry() {
            OpEntry::Submission(entry) => {
                #[allow(clippy::useless_conversion)]
                self.push_raw(entry.user_data(user_data as _).into())?;
                Poll::Pending
            }
            #[cfg(feature = "io-uring-sqe128")]
            OpEntry::Submission128(entry) => {
                self.push_raw(entry.user_data(user_data as _))?;
                Poll::Pending
            }
            OpEntry::Blocking => loop {
                if self.push_blocking(user_data) {
                    break Poll::Pending;
                } else {
                    self.poll_blocking();
                }
            },
        }
    }

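    /// Dispatch a blocking operation to the thread pool; its result is queued
    /// on `pool_completed` and the driver is woken through the notifier.
    /// Returns `false` if the pool refuses the job.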
    fn push_blocking(&mut self, user_data: usize) -> bool {
        let handle = self.handle();
        let completed = self.pool_completed.clone();
        self.pool
            .dispatch(move || {
                let mut op = unsafe { Key::<dyn crate::sys::OpCode>::new_unchecked(user_data) };
                let op_pin = op.as_op_pin();
                let res = op_pin.call_blocking();
                completed.push(Entry::new(user_data, res));
                handle.notify().ok();
            })
            .is_ok()
    }

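    /// Wait for completions: re-arm the multishot `PollAdd` on the notifier
    /// eventfd when it is no longer active, then submit and drain the queues.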
    pub unsafe fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
        instrument!(compio_log::Level::TRACE, "poll", ?timeout);
        trace!("start polling");

        if self.need_push_notifier {
            #[allow(clippy::useless_conversion)]
            self.push_raw(
                PollAdd::new(Fd(self.notifier.as_raw_fd()), libc::POLLIN as _)
                    .multi(true)
                    .build()
                    .user_data(Self::NOTIFY)
                    .into(),
            )?;
            self.need_push_notifier = false;
        }

        if !self.poll_entries() {
            self.submit_auto(timeout)?;
            self.poll_entries();
        }

        Ok(())
    }

    pub fn handle(&self) -> NotifyHandle {
        self.notifier.handle()
    }

    pub fn create_buffer_pool(
        &mut self,
        buffer_len: u16,
        buffer_size: usize,
    ) -> io::Result<BufferPool> {
        let buffer_group = self.buffer_group_ids.insert(());
        if buffer_group > u16::MAX as usize {
            self.buffer_group_ids.remove(buffer_group);

            return Err(io::Error::new(
                io::ErrorKind::OutOfMemory,
                "too many buffer pools allocated",
            ));
        }

        let buf_ring = io_uring_buf_ring::IoUringBufRing::new_with_flags(
            &self.inner,
            buffer_len,
            buffer_group as _,
            buffer_size,
            0,
        )?;

        #[cfg(fusion)]
        {
            Ok(BufferPool::new_io_uring(crate::IoUringBufferPool::new(
                buf_ring,
            )))
        }
        #[cfg(not(fusion))]
        {
            Ok(BufferPool::new(buf_ring))
        }
    }

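    /// # Safety
    ///
    /// The buffer pool must have been created by this driver.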
    pub unsafe fn release_buffer_pool(&mut self, buffer_pool: BufferPool) -> io::Result<()> {
        #[cfg(fusion)]
        let buffer_pool = buffer_pool.into_io_uring();

        let buffer_group = buffer_pool.buffer_group();
        buffer_pool.into_inner().release(&self.inner)?;
        self.buffer_group_ids.remove(buffer_group as _);

        Ok(())
    }
}

impl AsRawFd for Driver {
    fn as_raw_fd(&self) -> RawFd {
        self.inner.as_raw_fd()
    }
}

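/// Convert a completion entry into an [`Entry`], turning negative results into
/// `io::Error`s and mapping `ECANCELED` to `ETIMEDOUT`.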
fn create_entry(cq_entry: CEntry) -> Entry {
    let result = cq_entry.result();
    let result = if result < 0 {
        let result = if result == -libc::ECANCELED {
            libc::ETIMEDOUT
        } else {
            -result
        };
        Err(io::Error::from_raw_os_error(result))
    } else {
        Ok(result as _)
    };
    let mut entry = Entry::new(cq_entry.user_data() as _, result);
    entry.set_flags(cq_entry.flags());

    entry
}

fn timespec(duration: std::time::Duration) -> Timespec {
    Timespec::new()
        .sec(duration.as_secs())
        .nsec(duration.subsec_nanos())
}

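/// An eventfd-based notifier used to wake the driver from other threads.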
#[derive(Debug)]
struct Notifier {
    fd: Arc<OwnedFd>,
}

impl Notifier {
    fn new() -> io::Result<Self> {
        let fd = syscall!(libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK))?;
        let fd = unsafe { OwnedFd::from_raw_fd(fd) };
        Ok(Self { fd: Arc::new(fd) })
    }

    pub fn clear(&self) -> io::Result<()> {
        loop {
            let mut buffer = [0u64];
            let res = syscall!(libc::read(
                self.fd.as_raw_fd(),
                buffer.as_mut_ptr().cast(),
                std::mem::size_of::<u64>()
            ));
            match res {
                Ok(len) => {
                    debug_assert_eq!(len, std::mem::size_of::<u64>() as _);
                    break Ok(());
                }
                // Nothing to clear: the eventfd counter is already zero.
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => break Ok(()),
                // Interrupted by a signal; retry the read.
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => break Err(e),
            }
        }
    }

    pub fn handle(&self) -> NotifyHandle {
        NotifyHandle::new(self.fd.clone())
    }
}

impl AsRawFd for Notifier {
    fn as_raw_fd(&self) -> RawFd {
        self.fd.as_raw_fd()
    }
}

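/// A handle that wakes the driver from other threads by writing to its eventfd.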
pub struct NotifyHandle {
    fd: Arc<OwnedFd>,
}

impl NotifyHandle {
    pub(crate) fn new(fd: Arc<OwnedFd>) -> Self {
        Self { fd }
    }

    /// Wake up the driver.
    pub fn notify(&self) -> io::Result<()> {
        let data = 1u64;
        syscall!(libc::write(
            self.fd.as_raw_fd(),
            &data as *const _ as *const _,
            std::mem::size_of::<u64>(),
        ))?;
        Ok(())
    }
}