1#[cfg_attr(all(doc, docsrs), doc(cfg(all())))]
2#[allow(unused_imports)]
3pub use std::os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
4use std::{io, os::fd::FromRawFd, pin::Pin, sync::Arc, task::Poll, time::Duration};
5
6use compio_log::{instrument, trace, warn};
7use crossbeam_queue::SegQueue;
8cfg_if::cfg_if! {
9 if #[cfg(feature = "io-uring-cqe32")] {
10 use io_uring::cqueue::Entry32 as CEntry;
11 } else {
12 use io_uring::cqueue::Entry as CEntry;
13 }
14}
15cfg_if::cfg_if! {
16 if #[cfg(feature = "io-uring-sqe128")] {
17 use io_uring::squeue::Entry128 as SEntry;
18 } else {
19 use io_uring::squeue::Entry as SEntry;
20 }
21}
22use io_uring::{
23 IoUring,
24 cqueue::more,
25 opcode::{AsyncCancel, PollAdd},
26 types::{Fd, SubmitArgs, Timespec},
27};
28pub(crate) use libc::{sockaddr_storage, socklen_t};
29#[cfg(io_uring)]
30use slab::Slab;
31
32use crate::{AsyncifyPool, BufferPool, Entry, Key, ProactorBuilder, syscall};
33
34pub(crate) mod op;
35
/// The kind of entry an [`OpCode`] produces when submitted.
pub enum OpEntry {
    /// Push this submission-queue entry to the ring.
    Submission(io_uring::squeue::Entry),
    /// 128-byte SQE variant, available when the ring is built with big SQEs.
    #[cfg(feature = "io-uring-sqe128")]
    Submission128(io_uring::squeue::Entry128),
    /// No SQE: run the operation on the blocking thread pool instead.
    Blocking,
}
46
47impl From<io_uring::squeue::Entry> for OpEntry {
48 fn from(value: io_uring::squeue::Entry) -> Self {
49 Self::Submission(value)
50 }
51}
52
#[cfg(feature = "io-uring-sqe128")]
impl From<io_uring::squeue::Entry128> for OpEntry {
    /// Wrap a 128-byte submission-queue entry.
    fn from(value: io_uring::squeue::Entry128) -> Self {
        Self::Submission128(value)
    }
}
59
/// Abstraction of an io_uring operation.
pub trait OpCode {
    /// Create the submission entry for this operation, or declare it as a
    /// blocking operation via [`OpEntry::Blocking`].
    fn create_entry(self: Pin<&mut Self>) -> OpEntry;

    /// Execute the blocking form of this operation.
    ///
    /// Only meaningful for ops that returned [`OpEntry::Blocking`]; the
    /// default panics because purely asynchronous ops are never called here.
    fn call_blocking(self: Pin<&mut Self>) -> io::Result<usize> {
        unreachable!("this operation is asynchronous")
    }

    /// Record the completion result into the operation.
    ///
    /// The default is a no-op; implementors presumably store the kernel
    /// result for later retrieval — confirm against the `op` module.
    unsafe fn set_result(self: Pin<&mut Self>, _: usize) {}
}
80
/// Low-level io_uring driver.
pub(crate) struct Driver {
    /// The ring itself; SQE/CQE sizes are selected by feature flags.
    inner: IoUring<SEntry, CEntry>,
    /// Eventfd-backed notifier used to wake the ring from other threads.
    notifier: Notifier,
    /// Thread pool that runs [`OpEntry::Blocking`] operations.
    pool: AsyncifyPool,
    /// Completions produced by the thread pool, drained in `poll_blocking`.
    pool_completed: Arc<SegQueue<Entry>>,
    /// Allocates buffer-group ids for provided-buffer pools.
    #[cfg(io_uring)]
    buffer_group_ids: Slab<()>,
}
90
91impl Driver {
92 const CANCEL: u64 = u64::MAX;
93 const NOTIFY: u64 = u64::MAX - 1;
94
95 pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
96 instrument!(compio_log::Level::TRACE, "new", ?builder);
97 trace!("new iour driver");
98 let notifier = Notifier::new()?;
99 let mut io_uring_builder = IoUring::builder();
100 if let Some(sqpoll_idle) = builder.sqpoll_idle {
101 io_uring_builder.setup_sqpoll(sqpoll_idle.as_millis() as _);
102 }
103 if builder.coop_taskrun {
104 io_uring_builder.setup_coop_taskrun();
105 }
106 if builder.taskrun_flag {
107 io_uring_builder.setup_taskrun_flag();
108 }
109
110 let mut inner = io_uring_builder.build(builder.capacity)?;
111
112 if let Some(fd) = builder.eventfd {
113 inner.submitter().register_eventfd(fd)?;
114 }
115
116 #[allow(clippy::useless_conversion)]
117 unsafe {
118 inner
119 .submission()
120 .push(
121 &PollAdd::new(Fd(notifier.as_raw_fd()), libc::POLLIN as _)
122 .multi(true)
123 .build()
124 .user_data(Self::NOTIFY)
125 .into(),
126 )
127 .expect("the squeue sould not be full");
128 }
129 Ok(Self {
130 inner,
131 notifier,
132 pool: builder.create_or_get_thread_pool(),
133 pool_completed: Arc::new(SegQueue::new()),
134 #[cfg(io_uring)]
135 buffer_group_ids: Slab::new(),
136 })
137 }
138
139 fn submit_auto(&mut self, timeout: Option<Duration>) -> io::Result<()> {
141 instrument!(compio_log::Level::TRACE, "submit_auto", ?timeout);
142
143 let want_sqe = if self.inner.submission().taskrun() {
146 0
147 } else {
148 1
149 };
150
151 let res = {
152 if let Some(duration) = timeout {
154 let timespec = timespec(duration);
155 let args = SubmitArgs::new().timespec(×pec);
156 self.inner.submitter().submit_with_args(want_sqe, &args)
157 } else {
158 self.inner.submit_and_wait(want_sqe)
159 }
160 };
161 trace!("submit result: {res:?}");
162 match res {
163 Ok(_) => {
164 if self.inner.completion().is_empty() {
165 Err(io::ErrorKind::TimedOut.into())
166 } else {
167 Ok(())
168 }
169 }
170 Err(e) => match e.raw_os_error() {
171 Some(libc::ETIME) => Err(io::ErrorKind::TimedOut.into()),
172 Some(libc::EBUSY) | Some(libc::EAGAIN) => Err(io::ErrorKind::Interrupted.into()),
173 _ => Err(e),
174 },
175 }
176 }
177
178 fn poll_blocking(&mut self) {
179 if !self.pool_completed.is_empty() {
181 while let Some(entry) = self.pool_completed.pop() {
182 unsafe {
183 entry.notify();
184 }
185 }
186 }
187 }
188
189 fn poll_entries(&mut self) -> bool {
190 self.poll_blocking();
191
192 let mut cqueue = self.inner.completion();
193 cqueue.sync();
194 let has_entry = !cqueue.is_empty();
195 for entry in cqueue {
196 match entry.user_data() {
197 Self::CANCEL => {}
198 Self::NOTIFY => {
199 let flags = entry.flags();
200 debug_assert!(more(flags));
201 self.notifier.clear().expect("cannot clear notifier");
202 }
203 _ => unsafe {
204 create_entry(entry).notify();
205 },
206 }
207 }
208 has_entry
209 }
210
211 pub fn create_op<T: crate::sys::OpCode + 'static>(&self, op: T) -> Key<T> {
212 Key::new(self.as_raw_fd(), op)
213 }
214
215 pub fn attach(&mut self, _fd: RawFd) -> io::Result<()> {
216 Ok(())
217 }
218
219 pub fn cancel(&mut self, op: &mut Key<dyn crate::sys::OpCode>) {
220 instrument!(compio_log::Level::TRACE, "cancel", ?op);
221 trace!("cancel RawOp");
222 unsafe {
223 #[allow(clippy::useless_conversion)]
224 if self
225 .inner
226 .submission()
227 .push(
228 &AsyncCancel::new(op.user_data() as _)
229 .build()
230 .user_data(Self::CANCEL)
231 .into(),
232 )
233 .is_err()
234 {
235 warn!("could not push AsyncCancel entry");
236 }
237 }
238 }
239
240 fn push_raw(&mut self, entry: SEntry) -> io::Result<()> {
241 loop {
242 let mut squeue = self.inner.submission();
243 match unsafe { squeue.push(&entry) } {
244 Ok(()) => {
245 squeue.sync();
246 break Ok(());
247 }
248 Err(_) => {
249 drop(squeue);
250 self.poll_entries();
251 match self.submit_auto(Some(Duration::ZERO)) {
252 Ok(()) => {}
253 Err(e)
254 if matches!(
255 e.kind(),
256 io::ErrorKind::TimedOut | io::ErrorKind::Interrupted
257 ) => {}
258 Err(e) => return Err(e),
259 }
260 }
261 }
262 }
263 }
264
265 pub fn push(&mut self, op: &mut Key<dyn crate::sys::OpCode>) -> Poll<io::Result<usize>> {
266 instrument!(compio_log::Level::TRACE, "push", ?op);
267 let user_data = op.user_data();
268 let op_pin = op.as_op_pin();
269 trace!("push RawOp");
270 match op_pin.create_entry() {
271 OpEntry::Submission(entry) => {
272 #[allow(clippy::useless_conversion)]
273 self.push_raw(entry.user_data(user_data as _).into())?;
274 Poll::Pending
275 }
276 #[cfg(feature = "io-uring-sqe128")]
277 OpEntry::Submission128(entry) => {
278 self.push_raw(entry.user_data(user_data as _))?;
279 Poll::Pending
280 }
281 OpEntry::Blocking => loop {
282 if self.push_blocking(user_data) {
283 break Poll::Pending;
284 } else {
285 self.poll_blocking();
286 }
287 },
288 }
289 }
290
291 fn push_blocking(&mut self, user_data: usize) -> bool {
292 let handle = self.handle();
293 let completed = self.pool_completed.clone();
294 self.pool
295 .dispatch(move || {
296 let mut op = unsafe { Key::<dyn crate::sys::OpCode>::new_unchecked(user_data) };
297 let op_pin = op.as_op_pin();
298 let res = op_pin.call_blocking();
299 completed.push(Entry::new(user_data, res));
300 handle.notify().ok();
301 })
302 .is_ok()
303 }
304
305 pub unsafe fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
306 instrument!(compio_log::Level::TRACE, "poll", ?timeout);
307 trace!("start polling");
309
310 if !self.poll_entries() {
311 self.submit_auto(timeout)?;
312 self.poll_entries();
313 }
314
315 Ok(())
316 }
317
318 pub fn handle(&self) -> NotifyHandle {
319 self.notifier.handle()
320 }
321
322 #[cfg(io_uring)]
323 pub fn create_buffer_pool(
324 &mut self,
325 buffer_len: u16,
326 buffer_size: usize,
327 ) -> io::Result<BufferPool> {
328 let buffer_group = self.buffer_group_ids.insert(());
329 if buffer_group > u16::MAX as usize {
330 self.buffer_group_ids.remove(buffer_group);
331
332 return Err(io::Error::new(
333 io::ErrorKind::OutOfMemory,
334 "too many buffer pool allocated",
335 ));
336 }
337
338 let buf_ring = io_uring_buf_ring::IoUringBufRing::new(
339 &self.inner,
340 buffer_len,
341 buffer_group as _,
342 buffer_size,
343 )?;
344
345 #[cfg(fusion)]
346 {
347 Ok(BufferPool::new_io_uring(crate::IoUringBufferPool::new(
348 buf_ring,
349 )))
350 }
351 #[cfg(not(fusion))]
352 {
353 Ok(BufferPool::new(buf_ring))
354 }
355 }
356
357 #[cfg(not(io_uring))]
358 pub fn create_buffer_pool(
359 &mut self,
360 buffer_len: u16,
361 buffer_size: usize,
362 ) -> io::Result<BufferPool> {
363 Ok(BufferPool::new(buffer_len, buffer_size))
364 }
365
366 #[cfg(io_uring)]
370 pub unsafe fn release_buffer_pool(&mut self, buffer_pool: BufferPool) -> io::Result<()> {
371 #[cfg(fusion)]
372 let buffer_pool = buffer_pool.into_io_uring();
373
374 let buffer_group = buffer_pool.buffer_group();
375 buffer_pool.into_inner().release(&self.inner)?;
376 self.buffer_group_ids.remove(buffer_group as _);
377
378 Ok(())
379 }
380
381 #[cfg(not(io_uring))]
385 pub unsafe fn release_buffer_pool(&mut self, _: BufferPool) -> io::Result<()> {
386 Ok(())
387 }
388}
389
impl AsRawFd for Driver {
    /// Expose the ring's file descriptor.
    fn as_raw_fd(&self) -> RawFd {
        self.inner.as_raw_fd()
    }
}
395
396fn create_entry(cq_entry: CEntry) -> Entry {
397 let result = cq_entry.result();
398 let result = if result < 0 {
399 let result = if result == -libc::ECANCELED {
400 libc::ETIMEDOUT
401 } else {
402 -result
403 };
404 Err(io::Error::from_raw_os_error(result))
405 } else {
406 Ok(result as _)
407 };
408 let mut entry = Entry::new(cq_entry.user_data() as _, result);
409 entry.set_flags(cq_entry.flags());
410
411 entry
412}
413
414fn timespec(duration: std::time::Duration) -> Timespec {
415 Timespec::new()
416 .sec(duration.as_secs())
417 .nsec(duration.subsec_nanos())
418}
419
/// An eventfd-backed notifier used to wake the driver from other threads.
#[derive(Debug)]
struct Notifier {
    // Shared with `NotifyHandle`s so the fd stays alive while any handle does.
    fd: Arc<OwnedFd>,
}
424
425impl Notifier {
426 fn new() -> io::Result<Self> {
428 let fd = syscall!(libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK))?;
429 let fd = unsafe { OwnedFd::from_raw_fd(fd) };
430 Ok(Self { fd: Arc::new(fd) })
431 }
432
433 pub fn clear(&self) -> io::Result<()> {
434 loop {
435 let mut buffer = [0u64];
436 let res = syscall!(libc::read(
437 self.fd.as_raw_fd(),
438 buffer.as_mut_ptr().cast(),
439 std::mem::size_of::<u64>()
440 ));
441 match res {
442 Ok(len) => {
443 debug_assert_eq!(len, std::mem::size_of::<u64>() as _);
444 break Ok(());
445 }
446 Err(e) if e.kind() == io::ErrorKind::WouldBlock => break Ok(()),
448 Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
450 Err(e) => break Err(e),
451 }
452 }
453 }
454
455 pub fn handle(&self) -> NotifyHandle {
456 NotifyHandle::new(self.fd.clone())
457 }
458}
459
impl AsRawFd for Notifier {
    /// Expose the underlying eventfd descriptor.
    fn as_raw_fd(&self) -> RawFd {
        self.fd.as_raw_fd()
    }
}
465
/// A handle that wakes the driver's ring from any thread.
pub struct NotifyHandle {
    // Shared ownership of the notifier's eventfd.
    fd: Arc<OwnedFd>,
}
470
471impl NotifyHandle {
472 pub(crate) fn new(fd: Arc<OwnedFd>) -> Self {
473 Self { fd }
474 }
475
476 pub fn notify(&self) -> io::Result<()> {
478 let data = 1u64;
479 syscall!(libc::write(
480 self.fd.as_raw_fd(),
481 &data as *const _ as *const _,
482 std::mem::size_of::<u64>(),
483 ))?;
484 Ok(())
485 }
486}