1#[cfg_attr(all(doc, docsrs), doc(cfg(all())))]
2#[allow(unused_imports)]
3pub use std::os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
4use std::{
5 io,
6 os::fd::FromRawFd,
7 pin::Pin,
8 sync::Arc,
9 task::{Poll, Wake, Waker},
10 time::Duration,
11};
12
13use compio_log::{instrument, trace, warn};
14use crossbeam_queue::SegQueue;
15cfg_if::cfg_if! {
16 if #[cfg(feature = "io-uring-cqe32")] {
17 use io_uring::cqueue::Entry32 as CEntry;
18 } else {
19 use io_uring::cqueue::Entry as CEntry;
20 }
21}
22cfg_if::cfg_if! {
23 if #[cfg(feature = "io-uring-sqe128")] {
24 use io_uring::squeue::Entry128 as SEntry;
25 } else {
26 use io_uring::squeue::Entry as SEntry;
27 }
28}
29use io_uring::{
30 IoUring,
31 cqueue::more,
32 opcode::{AsyncCancel, PollAdd},
33 types::{Fd, SubmitArgs, Timespec},
34};
35use slab::Slab;
36
37use crate::{AsyncifyPool, BufferPool, DriverType, Entry, Key, ProactorBuilder, syscall};
38
39pub(crate) mod op;
40
/// Returns whether the kernel supports the io_uring opcode `code`.
///
/// The probe is registered once against a small throwaway ring and cached in
/// a process-wide static. Any error while creating the ring or registering
/// the probe makes this return `false` (`unwrap_or_default`).
pub(crate) fn is_op_supported(code: u8) -> bool {
    // With the `once_cell_try` feature, use std's `OnceLock` (its
    // `get_or_try_init` is a nightly feature); otherwise fall back to the
    // `once_cell` crate, whose `OnceCell` provides the same API on stable.
    #[cfg(feature = "once_cell_try")]
    use std::sync::OnceLock;

    #[cfg(not(feature = "once_cell_try"))]
    use once_cell::sync::OnceCell as OnceLock;

    // Kernel opcode support cannot change during the process lifetime, so
    // probing once is enough.
    static PROBE: OnceLock<io_uring::Probe> = OnceLock::new();

    PROBE
        .get_or_try_init(|| {
            let mut probe = io_uring::Probe::new();

            // A tiny (2-entry) ring is sufficient just to ask the kernel
            // which opcodes it supports.
            io_uring::IoUring::new(2)?
                .submitter()
                .register_probe(&mut probe)?;

            std::io::Result::Ok(probe)
        })
        .map(|probe| probe.is_supported(code))
        .unwrap_or_default()
}
63
/// The entry produced by an [`OpCode`] when it is submitted to the driver.
pub enum OpEntry {
    /// An asynchronous operation: push this entry to the submission queue.
    Submission(io_uring::squeue::Entry),
    /// Same as [`OpEntry::Submission`], but a 128-byte SQE (only available
    /// with the `io-uring-sqe128` feature).
    #[cfg(feature = "io-uring-sqe128")]
    Submission128(io_uring::squeue::Entry128),
    /// A blocking operation: run [`OpCode::call_blocking`] on the thread pool.
    Blocking,
}
74
75impl From<io_uring::squeue::Entry> for OpEntry {
76 fn from(value: io_uring::squeue::Entry) -> Self {
77 Self::Submission(value)
78 }
79}
80
#[cfg(feature = "io-uring-sqe128")]
impl From<io_uring::squeue::Entry128> for OpEntry {
    /// Wraps a 128-byte submission-queue entry.
    fn from(entry: io_uring::squeue::Entry128) -> Self {
        OpEntry::Submission128(entry)
    }
}
87
/// Abstraction of io-uring operations.
pub trait OpCode {
    /// Create the submission entry for this operation.
    fn create_entry(self: Pin<&mut Self>) -> OpEntry;

    /// Run the operation synchronously. Only called when
    /// [`create_entry`](OpCode::create_entry) returned [`OpEntry::Blocking`];
    /// the default implementation panics for asynchronous operations.
    fn call_blocking(self: Pin<&mut Self>) -> io::Result<usize> {
        unreachable!("this operation is asynchronous")
    }

    /// Record the successful result of the operation.
    ///
    /// # Safety
    /// The caller must guarantee the value is the valid result of this
    /// operation's completion. The default implementation ignores it.
    unsafe fn set_result(self: Pin<&mut Self>, _: usize) {}
}
108
/// Low-level io-uring driver.
pub(crate) struct Driver {
    // The ring itself; SQE/CQE entry sizes are selected by crate features.
    inner: IoUring<SEntry, CEntry>,
    // Eventfd-based notifier used to wake the ring from other threads.
    notifier: Notifier,
    // Thread pool that executes `OpEntry::Blocking` operations.
    pool: AsyncifyPool,
    // Completions produced by the thread pool, drained in `poll_blocking`.
    pool_completed: Arc<SegQueue<Entry>>,
    // Allocates buffer-group ids for registered buffer rings.
    buffer_group_ids: Slab<()>,
    // Whether the multishot PollAdd for the notifier must be (re)submitted.
    need_push_notifier: bool,
}
118
119impl Driver {
120 const CANCEL: u64 = u64::MAX;
121 const NOTIFY: u64 = u64::MAX - 1;
122
123 pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
124 instrument!(compio_log::Level::TRACE, "new", ?builder);
125 trace!("new iour driver");
126 let notifier = Notifier::new()?;
127 let mut io_uring_builder = IoUring::builder();
128 if let Some(sqpoll_idle) = builder.sqpoll_idle {
129 io_uring_builder.setup_sqpoll(sqpoll_idle.as_millis() as _);
130 }
131 if builder.coop_taskrun {
132 io_uring_builder.setup_coop_taskrun();
133 }
134 if builder.taskrun_flag {
135 io_uring_builder.setup_taskrun_flag();
136 }
137
138 let inner = io_uring_builder.build(builder.capacity)?;
139
140 let submitter = inner.submitter();
141
142 if let Some(fd) = builder.eventfd {
143 submitter.register_eventfd(fd)?;
144 }
145
146 Ok(Self {
147 inner,
148 notifier,
149 pool: builder.create_or_get_thread_pool(),
150 pool_completed: Arc::new(SegQueue::new()),
151 buffer_group_ids: Slab::new(),
152 need_push_notifier: true,
153 })
154 }
155
156 pub fn driver_type(&self) -> DriverType {
157 DriverType::IoUring
158 }
159
160 fn submit_auto(&mut self, timeout: Option<Duration>) -> io::Result<()> {
162 instrument!(compio_log::Level::TRACE, "submit_auto", ?timeout);
163
164 let want_sqe = if self.inner.submission().taskrun() {
167 0
168 } else {
169 1
170 };
171
172 let res = {
173 if let Some(duration) = timeout {
175 let timespec = timespec(duration);
176 let args = SubmitArgs::new().timespec(×pec);
177 self.inner.submitter().submit_with_args(want_sqe, &args)
178 } else {
179 self.inner.submit_and_wait(want_sqe)
180 }
181 };
182 trace!("submit result: {res:?}");
183 match res {
184 Ok(_) => {
185 if self.inner.completion().is_empty() {
186 Err(io::ErrorKind::TimedOut.into())
187 } else {
188 Ok(())
189 }
190 }
191 Err(e) => match e.raw_os_error() {
192 Some(libc::ETIME) => Err(io::ErrorKind::TimedOut.into()),
193 Some(libc::EBUSY) | Some(libc::EAGAIN) => Err(io::ErrorKind::Interrupted.into()),
194 _ => Err(e),
195 },
196 }
197 }
198
199 fn poll_blocking(&mut self) {
200 if !self.pool_completed.is_empty() {
202 while let Some(entry) = self.pool_completed.pop() {
203 unsafe {
204 entry.notify();
205 }
206 }
207 }
208 }
209
210 fn poll_entries(&mut self) -> bool {
211 self.poll_blocking();
212
213 let mut cqueue = self.inner.completion();
214 cqueue.sync();
215 let has_entry = !cqueue.is_empty();
216 for entry in cqueue {
217 match entry.user_data() {
218 Self::CANCEL => {}
219 Self::NOTIFY => {
220 let flags = entry.flags();
221 if !more(flags) {
222 self.need_push_notifier = true;
223 }
224 self.notifier.clear().expect("cannot clear notifier");
225 }
226 _ => unsafe {
227 create_entry(entry).notify();
228 },
229 }
230 }
231 has_entry
232 }
233
234 pub fn create_op<T: crate::sys::OpCode + 'static>(&self, op: T) -> Key<T> {
235 Key::new(self.as_raw_fd(), op)
236 }
237
238 pub fn attach(&mut self, _fd: RawFd) -> io::Result<()> {
239 Ok(())
240 }
241
242 pub fn cancel(&mut self, op: &mut Key<dyn crate::sys::OpCode>) {
243 instrument!(compio_log::Level::TRACE, "cancel", ?op);
244 trace!("cancel RawOp");
245 unsafe {
246 #[allow(clippy::useless_conversion)]
247 if self
248 .inner
249 .submission()
250 .push(
251 &AsyncCancel::new(op.user_data() as _)
252 .build()
253 .user_data(Self::CANCEL)
254 .into(),
255 )
256 .is_err()
257 {
258 warn!("could not push AsyncCancel entry");
259 }
260 }
261 }
262
263 fn push_raw(&mut self, entry: SEntry) -> io::Result<()> {
264 loop {
265 let mut squeue = self.inner.submission();
266 match unsafe { squeue.push(&entry) } {
267 Ok(()) => {
268 squeue.sync();
269 break Ok(());
270 }
271 Err(_) => {
272 drop(squeue);
273 self.poll_entries();
274 match self.submit_auto(Some(Duration::ZERO)) {
275 Ok(()) => {}
276 Err(e)
277 if matches!(
278 e.kind(),
279 io::ErrorKind::TimedOut | io::ErrorKind::Interrupted
280 ) => {}
281 Err(e) => return Err(e),
282 }
283 }
284 }
285 }
286 }
287
288 pub fn push(&mut self, op: &mut Key<dyn crate::sys::OpCode>) -> Poll<io::Result<usize>> {
289 instrument!(compio_log::Level::TRACE, "push", ?op);
290 let user_data = op.user_data();
291 let op_pin = op.as_op_pin();
292 trace!("push RawOp");
293 match op_pin.create_entry() {
294 OpEntry::Submission(entry) => {
295 #[allow(clippy::useless_conversion)]
296 self.push_raw(entry.user_data(user_data as _).into())?;
297 Poll::Pending
298 }
299 #[cfg(feature = "io-uring-sqe128")]
300 OpEntry::Submission128(entry) => {
301 self.push_raw(entry.user_data(user_data as _))?;
302 Poll::Pending
303 }
304 OpEntry::Blocking => loop {
305 if self.push_blocking(user_data) {
306 break Poll::Pending;
307 } else {
308 self.poll_blocking();
309 }
310 },
311 }
312 }
313
314 fn push_blocking(&mut self, user_data: usize) -> bool {
315 let waker = self.waker();
316 let completed = self.pool_completed.clone();
317 self.pool
318 .dispatch(move || {
319 let mut op = unsafe { Key::<dyn crate::sys::OpCode>::new_unchecked(user_data) };
320 let op_pin = op.as_op_pin();
321 let res = op_pin.call_blocking();
322 completed.push(Entry::new(user_data, res));
323 waker.wake();
324 })
325 .is_ok()
326 }
327
328 pub fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
329 instrument!(compio_log::Level::TRACE, "poll", ?timeout);
330 trace!("start polling");
332
333 if self.need_push_notifier {
334 #[allow(clippy::useless_conversion)]
335 self.push_raw(
336 PollAdd::new(Fd(self.notifier.as_raw_fd()), libc::POLLIN as _)
337 .multi(true)
338 .build()
339 .user_data(Self::NOTIFY)
340 .into(),
341 )?;
342 self.need_push_notifier = false;
343 }
344
345 if !self.poll_entries() {
346 self.submit_auto(timeout)?;
347 self.poll_entries();
348 }
349
350 Ok(())
351 }
352
353 pub fn waker(&self) -> Waker {
354 self.notifier.waker()
355 }
356
357 pub fn create_buffer_pool(
358 &mut self,
359 buffer_len: u16,
360 buffer_size: usize,
361 ) -> io::Result<BufferPool> {
362 let buffer_group = self.buffer_group_ids.insert(());
363 if buffer_group > u16::MAX as usize {
364 self.buffer_group_ids.remove(buffer_group);
365
366 return Err(io::Error::new(
367 io::ErrorKind::OutOfMemory,
368 "too many buffer pool allocated",
369 ));
370 }
371
372 let buf_ring = io_uring_buf_ring::IoUringBufRing::new_with_flags(
373 &self.inner,
374 buffer_len,
375 buffer_group as _,
376 buffer_size,
377 0,
378 )?;
379
380 #[cfg(fusion)]
381 {
382 Ok(BufferPool::new_io_uring(crate::IoUringBufferPool::new(
383 buf_ring,
384 )))
385 }
386 #[cfg(not(fusion))]
387 {
388 Ok(BufferPool::new(buf_ring))
389 }
390 }
391
392 pub unsafe fn release_buffer_pool(&mut self, buffer_pool: BufferPool) -> io::Result<()> {
396 #[cfg(fusion)]
397 let buffer_pool = buffer_pool.into_io_uring();
398
399 let buffer_group = buffer_pool.buffer_group();
400 unsafe { buffer_pool.into_inner().release(&self.inner)? };
401 self.buffer_group_ids.remove(buffer_group as _);
402
403 Ok(())
404 }
405}
406
impl AsRawFd for Driver {
    /// Expose the ring's file descriptor (used as the key namespace for ops).
    fn as_raw_fd(&self) -> RawFd {
        self.inner.as_raw_fd()
    }
}
412
413fn create_entry(cq_entry: CEntry) -> Entry {
414 let result = cq_entry.result();
415 let result = if result < 0 {
416 let result = if result == -libc::ECANCELED {
417 libc::ETIMEDOUT
418 } else {
419 -result
420 };
421 Err(io::Error::from_raw_os_error(result))
422 } else {
423 Ok(result as _)
424 };
425 let mut entry = Entry::new(cq_entry.user_data() as _, result);
426 entry.set_flags(cq_entry.flags());
427
428 entry
429}
430
431fn timespec(duration: std::time::Duration) -> Timespec {
432 Timespec::new()
433 .sec(duration.as_secs())
434 .nsec(duration.subsec_nanos())
435}
436
/// An eventfd-backed notifier used to wake the driver from other threads.
#[derive(Debug)]
struct Notifier {
    // Shared so that `Waker`s handed out can outlive a single poll iteration.
    notify: Arc<Notify>,
}
441
442impl Notifier {
443 fn new() -> io::Result<Self> {
445 let fd = syscall!(libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK))?;
446 let fd = unsafe { OwnedFd::from_raw_fd(fd) };
447 Ok(Self {
448 notify: Arc::new(Notify::new(fd)),
449 })
450 }
451
452 pub fn clear(&self) -> io::Result<()> {
453 loop {
454 let mut buffer = [0u64];
455 let res = syscall!(libc::read(
456 self.as_raw_fd(),
457 buffer.as_mut_ptr().cast(),
458 std::mem::size_of::<u64>()
459 ));
460 match res {
461 Ok(len) => {
462 debug_assert_eq!(len, std::mem::size_of::<u64>() as _);
463 break Ok(());
464 }
465 Err(e) if e.kind() == io::ErrorKind::WouldBlock => break Ok(()),
467 Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
469 Err(e) => break Err(e),
470 }
471 }
472 }
473
474 pub fn waker(&self) -> Waker {
475 Waker::from(self.notify.clone())
476 }
477}
478
impl AsRawFd for Notifier {
    /// The underlying eventfd descriptor (polled by the driver's `PollAdd`).
    fn as_raw_fd(&self) -> RawFd {
        self.notify.fd.as_raw_fd()
    }
}
484
/// Inner state shared between [`Notifier`] and the wakers it hands out.
#[derive(Debug)]
struct Notify {
    // Owned eventfd; writing to it wakes the polling thread.
    fd: OwnedFd,
}
490
491impl Notify {
492 pub(crate) fn new(fd: OwnedFd) -> Self {
493 Self { fd }
494 }
495
496 pub fn notify(&self) -> io::Result<()> {
498 let data = 1u64;
499 syscall!(libc::write(
500 self.fd.as_raw_fd(),
501 &data as *const _ as *const _,
502 std::mem::size_of::<u64>(),
503 ))?;
504 Ok(())
505 }
506}
507
508impl Wake for Notify {
509 fn wake(self: Arc<Self>) {
510 self.wake_by_ref();
511 }
512
513 fn wake_by_ref(self: &Arc<Self>) {
514 self.notify().ok();
515 }
516}