#[cfg_attr(all(doc, docsrs), doc(cfg(all())))]
#[allow(unused_imports)]
pub use std::os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
use std::{io, os::fd::FromRawFd, pin::Pin, sync::Arc, task::Poll, time::Duration};

use compio_log::{instrument, trace, warn};
use crossbeam_queue::SegQueue;
cfg_if::cfg_if! {
    if #[cfg(feature = "io-uring-cqe32")] {
        use io_uring::cqueue::Entry32 as CEntry;
    } else {
        use io_uring::cqueue::Entry as CEntry;
    }
}
cfg_if::cfg_if! {
    if #[cfg(feature = "io-uring-sqe128")] {
        use io_uring::squeue::Entry128 as SEntry;
    } else {
        use io_uring::squeue::Entry as SEntry;
    }
}
use io_uring::{
    IoUring,
    cqueue::more,
    opcode::{AsyncCancel, PollAdd},
    types::{Fd, SubmitArgs, Timespec},
};
use slab::Slab;

use crate::{AsyncifyPool, BufferPool, DriverType, Entry, Key, ProactorBuilder, syscall};

pub(crate) mod op;

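/// The entry produced by an [`OpCode`]: either an io-uring submission entry
/// or a request to run the operation on the blocking thread pool.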
pub enum OpEntry {
    Submission(io_uring::squeue::Entry),
    #[cfg(feature = "io-uring-sqe128")]
    Submission128(io_uring::squeue::Entry128),
    Blocking,
}

impl From<io_uring::squeue::Entry> for OpEntry {
    fn from(value: io_uring::squeue::Entry) -> Self {
        Self::Submission(value)
    }
}

#[cfg(feature = "io-uring-sqe128")]
impl From<io_uring::squeue::Entry128> for OpEntry {
    fn from(value: io_uring::squeue::Entry128) -> Self {
        Self::Submission128(value)
    }
}

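/// An operation that can be driven by the io-uring driver.
///
/// A minimal sketch of an implementor, assuming an operation that needs no
/// result post-processing (the `Nop` type below is illustrative, not part of
/// this crate):
///
/// ```ignore
/// use std::pin::Pin;
///
/// struct Nop;
///
/// impl OpCode for Nop {
///     fn create_entry(self: Pin<&mut Self>) -> OpEntry {
///         // `io_uring::opcode::Nop` builds a plain submission entry.
///         io_uring::opcode::Nop::new().build().into()
///     }
/// }
/// ```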
pub trait OpCode {
    /// Create the submission entry for this operation, or return
    /// [`OpEntry::Blocking`] to run it on the blocking thread pool.
    fn create_entry(self: Pin<&mut Self>) -> OpEntry;

    /// Execute the operation synchronously. Only called for operations that
    /// return [`OpEntry::Blocking`] from [`OpCode::create_entry`].
    fn call_blocking(self: Pin<&mut Self>) -> io::Result<usize> {
        unreachable!("this operation is asynchronous")
    }

    /// Record the result of the operation. The default implementation does
    /// nothing.
    ///
    /// # Safety
    ///
    /// Should only be called by the driver, with the result taken from this
    /// operation's completion entry.
    unsafe fn set_result(self: Pin<&mut Self>, _: usize) {}
}

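/// Low-level driver backed by io-uring.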
pub(crate) struct Driver {
    /// The io-uring instance.
    inner: IoUring<SEntry, CEntry>,
    /// Eventfd used to wake the driver from other threads.
    notifier: Notifier,
    /// Thread pool that runs blocking operations.
    pool: AsyncifyPool,
    /// Completions produced by the thread pool.
    pool_completed: Arc<SegQueue<Entry>>,
    /// Allocator for provided-buffer group ids.
    buffer_group_ids: Slab<()>,
    /// Whether the multishot `PollAdd` watching the notifier must be (re)armed.
    need_push_notifier: bool,
}

impl Driver {
    // Reserved user_data values for driver-internal entries.
    const CANCEL: u64 = u64::MAX;
    const NOTIFY: u64 = u64::MAX - 1;

    pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
        instrument!(compio_log::Level::TRACE, "new", ?builder);
        trace!("new iour driver");
        let notifier = Notifier::new()?;
        let mut io_uring_builder = IoUring::builder();
        if let Some(sqpoll_idle) = builder.sqpoll_idle {
            io_uring_builder.setup_sqpoll(sqpoll_idle.as_millis() as _);
        }
        if builder.coop_taskrun {
            io_uring_builder.setup_coop_taskrun();
        }
        if builder.taskrun_flag {
            io_uring_builder.setup_taskrun_flag();
        }

        let inner = io_uring_builder.build(builder.capacity)?;

        if let Some(fd) = builder.eventfd {
            inner.submitter().register_eventfd(fd)?;
        }

        Ok(Self {
            inner,
            notifier,
            pool: builder.create_or_get_thread_pool(),
            pool_completed: Arc::new(SegQueue::new()),
            buffer_group_ids: Slab::new(),
            need_push_notifier: true,
        })
    }

    pub fn driver_type(&self) -> DriverType {
        DriverType::IoUring
    }

    fn submit_auto(&mut self, timeout: Option<Duration>) -> io::Result<()> {
        instrument!(compio_log::Level::TRACE, "submit_auto", ?timeout);

        let want_sqe = if self.inner.submission().taskrun() {
            // The taskrun flag indicates completions are already pending, so
            // don't block waiting for a new one.
            0
        } else {
            1
        };

        let res = {
            if let Some(duration) = timeout {
                let timespec = timespec(duration);
                let args = SubmitArgs::new().timespec(&timespec);
                self.inner.submitter().submit_with_args(want_sqe, &args)
            } else {
                self.inner.submit_and_wait(want_sqe)
            }
        };
        trace!("submit result: {res:?}");
        match res {
            Ok(_) => {
                if self.inner.completion().is_empty() {
                    Err(io::ErrorKind::TimedOut.into())
                } else {
                    Ok(())
                }
            }
            Err(e) => match e.raw_os_error() {
                Some(libc::ETIME) => Err(io::ErrorKind::TimedOut.into()),
                Some(libc::EBUSY) | Some(libc::EAGAIN) => Err(io::ErrorKind::Interrupted.into()),
                _ => Err(e),
            },
        }
    }

    fn poll_blocking(&mut self) {
        // Fast path: skip the pop loop when nothing has completed.
        if !self.pool_completed.is_empty() {
            while let Some(entry) = self.pool_completed.pop() {
                unsafe {
                    entry.notify();
                }
            }
        }
    }

    fn poll_entries(&mut self) -> bool {
        self.poll_blocking();

        let mut cqueue = self.inner.completion();
        cqueue.sync();
        let has_entry = !cqueue.is_empty();
        for entry in cqueue {
            match entry.user_data() {
                Self::CANCEL => {}
                Self::NOTIFY => {
                    let flags = entry.flags();
                    // If the kernel dropped the multishot poll (no MORE flag),
                    // it has to be armed again on the next `poll`.
                    if !more(flags) {
                        self.need_push_notifier = true;
                    }
                    self.notifier.clear().expect("cannot clear notifier");
                }
                _ => unsafe {
                    create_entry(entry).notify();
                },
            }
        }
        has_entry
    }

    pub fn create_op<T: crate::sys::OpCode + 'static>(&self, op: T) -> Key<T> {
        Key::new(self.as_raw_fd(), op)
    }

    pub fn attach(&mut self, _fd: RawFd) -> io::Result<()> {
        // io-uring needs no per-fd registration; this is a no-op kept for API
        // parity with the other drivers.
        Ok(())
    }

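    /// Request cancellation of an in-flight operation by pushing an
    /// `AsyncCancel` entry. Best effort: if the submission queue is full, the
    /// request is dropped with a warning.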
    pub fn cancel(&mut self, op: &mut Key<dyn crate::sys::OpCode>) {
        instrument!(compio_log::Level::TRACE, "cancel", ?op);
        trace!("cancel RawOp");
        unsafe {
            #[allow(clippy::useless_conversion)]
            if self
                .inner
                .submission()
                .push(
                    &AsyncCancel::new(op.user_data() as _)
                        .build()
                        .user_data(Self::CANCEL)
                        .into(),
                )
                .is_err()
            {
                warn!("could not push AsyncCancel entry");
            }
        }
    }

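    /// Push a raw submission entry, reaping completions and flushing the ring
    /// whenever the submission queue is full.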
    fn push_raw(&mut self, entry: SEntry) -> io::Result<()> {
        loop {
            let mut squeue = self.inner.submission();
            match unsafe { squeue.push(&entry) } {
                Ok(()) => {
                    squeue.sync();
                    break Ok(());
                }
                Err(_) => {
                    drop(squeue);
                    self.poll_entries();
                    match self.submit_auto(Some(Duration::ZERO)) {
                        Ok(()) => {}
                        Err(e)
                            if matches!(
                                e.kind(),
                                io::ErrorKind::TimedOut | io::ErrorKind::Interrupted
                            ) => {}
                        Err(e) => return Err(e),
                    }
                }
            }
        }
    }

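    /// Submit an operation: io-uring operations are queued as submission
    /// entries, while blocking operations are dispatched to the thread pool.
    /// On success this returns `Poll::Pending`; the result is delivered later
    /// through the completion path.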
    pub fn push(&mut self, op: &mut Key<dyn crate::sys::OpCode>) -> Poll<io::Result<usize>> {
        instrument!(compio_log::Level::TRACE, "push", ?op);
        let user_data = op.user_data();
        let op_pin = op.as_op_pin();
        trace!("push RawOp");
        match op_pin.create_entry() {
            OpEntry::Submission(entry) => {
                #[allow(clippy::useless_conversion)]
                self.push_raw(entry.user_data(user_data as _).into())?;
                Poll::Pending
            }
            #[cfg(feature = "io-uring-sqe128")]
            OpEntry::Submission128(entry) => {
                self.push_raw(entry.user_data(user_data as _))?;
                Poll::Pending
            }
            OpEntry::Blocking => loop {
                // If the thread pool refuses the task, drain finished blocking
                // operations and retry.
                if self.push_blocking(user_data) {
                    break Poll::Pending;
                } else {
                    self.poll_blocking();
                }
            },
        }
    }

    fn push_blocking(&mut self, user_data: usize) -> bool {
        let handle = self.handle();
        let completed = self.pool_completed.clone();
        self.pool
            .dispatch(move || {
                // Rebuild the key from its user_data, run the blocking call,
                // queue the completion, and wake the driver.
                let mut op = unsafe { Key::<dyn crate::sys::OpCode>::new_unchecked(user_data) };
                let op_pin = op.as_op_pin();
                let res = op_pin.call_blocking();
                completed.push(Entry::new(user_data, res));
                handle.notify().ok();
            })
            .is_ok()
    }

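    /// Poll the driver: arm the notifier if needed, reap available
    /// completions, and otherwise submit and wait for completions for at most
    /// `timeout`.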
    pub unsafe fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
        instrument!(compio_log::Level::TRACE, "poll", ?timeout);
        trace!("start polling");

        if self.need_push_notifier {
            // (Re)arm the multishot poll watching the notifier eventfd.
            #[allow(clippy::useless_conversion)]
            self.push_raw(
                PollAdd::new(Fd(self.notifier.as_raw_fd()), libc::POLLIN as _)
                    .multi(true)
                    .build()
                    .user_data(Self::NOTIFY)
                    .into(),
            )?;
            self.need_push_notifier = false;
        }

        if !self.poll_entries() {
            self.submit_auto(timeout)?;
            self.poll_entries();
        }

        Ok(())
    }

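    /// Get a [`NotifyHandle`] that can wake this driver from another thread.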
    pub fn handle(&self) -> NotifyHandle {
        self.notifier.handle()
    }

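    /// Register a provided-buffer ring holding `buffer_len` buffers of
    /// `buffer_size` bytes each, under a freshly allocated buffer group id.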
    pub fn create_buffer_pool(
        &mut self,
        buffer_len: u16,
        buffer_size: usize,
    ) -> io::Result<BufferPool> {
        let buffer_group = self.buffer_group_ids.insert(());
        if buffer_group > u16::MAX as usize {
            self.buffer_group_ids.remove(buffer_group);

            return Err(io::Error::new(
                io::ErrorKind::OutOfMemory,
                "too many buffer pools allocated",
            ));
        }

        let buf_ring = io_uring_buf_ring::IoUringBufRing::new(
            &self.inner,
            buffer_len,
            buffer_group as _,
            buffer_size,
        )?;

        #[cfg(fusion)]
        {
            Ok(BufferPool::new_io_uring(crate::IoUringBufferPool::new(
                buf_ring,
            )))
        }
        #[cfg(not(fusion))]
        {
            Ok(BufferPool::new(buf_ring))
        }
    }

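    /// Unregister the buffer ring of `buffer_pool` and free its buffer group
    /// id.
    ///
    /// # Safety
    ///
    /// The buffer pool must have been created by this driver, and no in-flight
    /// operation may still reference buffers from it.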
    pub unsafe fn release_buffer_pool(&mut self, buffer_pool: BufferPool) -> io::Result<()> {
        #[cfg(fusion)]
        let buffer_pool = buffer_pool.into_io_uring();

        let buffer_group = buffer_pool.buffer_group();
        buffer_pool.into_inner().release(&self.inner)?;
        self.buffer_group_ids.remove(buffer_group as _);

        Ok(())
    }
}

impl AsRawFd for Driver {
    fn as_raw_fd(&self) -> RawFd {
        self.inner.as_raw_fd()
    }
}

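/// Convert a completion queue entry into a driver [`Entry`], turning negative
/// results into `io::Error` and mapping `ECANCELED` to `ETIMEDOUT`.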
fn create_entry(cq_entry: CEntry) -> Entry {
    let result = cq_entry.result();
    let result = if result < 0 {
        let result = if result == -libc::ECANCELED {
            libc::ETIMEDOUT
        } else {
            -result
        };
        Err(io::Error::from_raw_os_error(result))
    } else {
        Ok(result as _)
    };
    let mut entry = Entry::new(cq_entry.user_data() as _, result);
    entry.set_flags(cq_entry.flags());

    entry
}

fn timespec(duration: std::time::Duration) -> Timespec {
    Timespec::new()
        .sec(duration.as_secs())
        .nsec(duration.subsec_nanos())
}

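/// An eventfd-based notifier used to wake the driver from other threads.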
#[derive(Debug)]
struct Notifier {
    fd: Arc<OwnedFd>,
}

impl Notifier {
    fn new() -> io::Result<Self> {
        let fd = syscall!(libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK))?;
        let fd = unsafe { OwnedFd::from_raw_fd(fd) };
        Ok(Self { fd: Arc::new(fd) })
    }

    pub fn clear(&self) -> io::Result<()> {
        loop {
            let mut buffer = [0u64];
            let res = syscall!(libc::read(
                self.fd.as_raw_fd(),
                buffer.as_mut_ptr().cast(),
                std::mem::size_of::<u64>()
            ));
            match res {
                Ok(len) => {
                    debug_assert_eq!(len, std::mem::size_of::<u64>() as _);
                    break Ok(());
                }
                // The eventfd is non-blocking: `WouldBlock` means the counter
                // is already zero.
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => break Ok(()),
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => break Err(e),
            }
        }
    }

    pub fn handle(&self) -> NotifyHandle {
        NotifyHandle::new(self.fd.clone())
    }
}

impl AsRawFd for Notifier {
    fn as_raw_fd(&self) -> RawFd {
        self.fd.as_raw_fd()
    }
}

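/// A handle that wakes the driver by writing to its notifier eventfd; it can
/// be sent to other threads.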
pub struct NotifyHandle {
    fd: Arc<OwnedFd>,
}

impl NotifyHandle {
    pub(crate) fn new(fd: Arc<OwnedFd>) -> Self {
        Self { fd }
    }

    /// Wake the driver by writing to its eventfd.
    pub fn notify(&self) -> io::Result<()> {
        let data = 1u64;
        syscall!(libc::write(
            self.fd.as_raw_fd(),
            &data as *const _ as *const _,
            std::mem::size_of::<u64>(),
        ))?;
        Ok(())
    }
}
470}