1#[cfg_attr(all(doc, docsrs), doc(cfg(all())))]
2#[allow(unused_imports)]
3pub use std::os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
4use std::{io, os::fd::FromRawFd, pin::Pin, sync::Arc, task::Poll, time::Duration};
5
6use compio_log::{instrument, trace, warn};
7use crossbeam_queue::SegQueue;
8cfg_if::cfg_if! {
9 if #[cfg(feature = "io-uring-cqe32")] {
10 use io_uring::cqueue::Entry32 as CEntry;
11 } else {
12 use io_uring::cqueue::Entry as CEntry;
13 }
14}
15cfg_if::cfg_if! {
16 if #[cfg(feature = "io-uring-sqe128")] {
17 use io_uring::squeue::Entry128 as SEntry;
18 } else {
19 use io_uring::squeue::Entry as SEntry;
20 }
21}
22use io_uring::{
23 IoUring,
24 cqueue::more,
25 opcode::{AsyncCancel, PollAdd},
26 types::{Fd, SubmitArgs, Timespec},
27};
28pub(crate) use libc::{sockaddr_storage, socklen_t};
29#[cfg(io_uring)]
30use slab::Slab;
31
32use crate::{AsyncifyPool, BufferPool, Entry, Key, ProactorBuilder, syscall};
33
34pub(crate) mod op;
35
/// The entry created by an [`OpCode`], deciding how the driver executes it.
pub enum OpEntry {
    /// The operation is submitted to io_uring as a normal-size squeue entry.
    Submission(io_uring::squeue::Entry),
    #[cfg(feature = "io-uring-sqe128")]
    /// The operation is submitted to io_uring as a 128-byte squeue entry.
    Submission128(io_uring::squeue::Entry128),
    /// The operation has no uring opcode and is run on the blocking thread pool.
    Blocking,
}
46
47impl From<io_uring::squeue::Entry> for OpEntry {
48 fn from(value: io_uring::squeue::Entry) -> Self {
49 Self::Submission(value)
50 }
51}
52
53#[cfg(feature = "io-uring-sqe128")]
54impl From<io_uring::squeue::Entry128> for OpEntry {
55 fn from(value: io_uring::squeue::Entry128) -> Self {
56 Self::Submission128(value)
57 }
58}
59
/// Abstraction of io_uring operations.
pub trait OpCode {
    /// Create the submission entry for this operation, or report that it must
    /// run on the blocking pool.
    fn create_entry(self: Pin<&mut Self>) -> OpEntry;

    /// Run the operation synchronously. Only called when [`OpCode::create_entry`]
    /// returned [`OpEntry::Blocking`]; uring-backed operations keep the default,
    /// which panics.
    fn call_blocking(self: Pin<&mut Self>) -> io::Result<usize> {
        unreachable!("this operation is asynchronous")
    }

    /// Store a successful completion result into the operation.
    ///
    /// # Safety
    ///
    /// Intended for the driver only; the default is a no-op. NOTE(review): the
    /// exact caller contract is defined by the implementors elsewhere — confirm.
    unsafe fn set_result(self: Pin<&mut Self>, _: usize) {}
}
80
/// The io_uring-backed driver.
pub(crate) struct Driver {
    /// The ring itself; sqe/cqe entry sizes are picked by the
    /// `io-uring-sqe128` / `io-uring-cqe32` features via the SEntry/CEntry aliases.
    inner: IoUring<SEntry, CEntry>,
    /// eventfd wrapper used to wake this driver from other threads.
    notifier: Notifier,
    /// Thread pool executing `OpEntry::Blocking` operations.
    pool: AsyncifyPool,
    /// Results produced by the blocking pool, drained on every poll.
    pool_completed: Arc<SegQueue<Entry>>,
    /// Allocates buffer-group ids for registered buffer rings.
    #[cfg(io_uring)]
    buffer_group_ids: Slab<()>,
}
90
91impl Driver {
92 const CANCEL: u64 = u64::MAX;
93 const NOTIFY: u64 = u64::MAX - 1;
94
95 pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
96 instrument!(compio_log::Level::TRACE, "new", ?builder);
97 trace!("new iour driver");
98 let notifier = Notifier::new()?;
99 let mut io_uring_builder = IoUring::builder();
100 if let Some(sqpoll_idle) = builder.sqpoll_idle {
101 io_uring_builder.setup_sqpoll(sqpoll_idle.as_millis() as _);
102 }
103 if builder.coop_taskrun {
104 io_uring_builder.setup_coop_taskrun();
105 }
106 if builder.taskrun_flag {
107 io_uring_builder.setup_taskrun_flag();
108 }
109
110 let mut inner = io_uring_builder.build(builder.capacity)?;
111 #[allow(clippy::useless_conversion)]
112 unsafe {
113 inner
114 .submission()
115 .push(
116 &PollAdd::new(Fd(notifier.as_raw_fd()), libc::POLLIN as _)
117 .multi(true)
118 .build()
119 .user_data(Self::NOTIFY)
120 .into(),
121 )
122 .expect("the squeue sould not be full");
123 }
124 Ok(Self {
125 inner,
126 notifier,
127 pool: builder.create_or_get_thread_pool(),
128 pool_completed: Arc::new(SegQueue::new()),
129 #[cfg(io_uring)]
130 buffer_group_ids: Slab::new(),
131 })
132 }
133
134 fn submit_auto(&mut self, timeout: Option<Duration>) -> io::Result<()> {
136 instrument!(compio_log::Level::TRACE, "submit_auto", ?timeout);
137
138 let want_sqe = if self.inner.submission().taskrun() {
141 0
142 } else {
143 1
144 };
145
146 let res = {
147 if let Some(duration) = timeout {
149 let timespec = timespec(duration);
150 let args = SubmitArgs::new().timespec(×pec);
151 self.inner.submitter().submit_with_args(want_sqe, &args)
152 } else {
153 self.inner.submit_and_wait(want_sqe)
154 }
155 };
156 trace!("submit result: {res:?}");
157 match res {
158 Ok(_) => {
159 if self.inner.completion().is_empty() {
160 Err(io::ErrorKind::TimedOut.into())
161 } else {
162 Ok(())
163 }
164 }
165 Err(e) => match e.raw_os_error() {
166 Some(libc::ETIME) => Err(io::ErrorKind::TimedOut.into()),
167 Some(libc::EBUSY) | Some(libc::EAGAIN) => Err(io::ErrorKind::Interrupted.into()),
168 _ => Err(e),
169 },
170 }
171 }
172
173 fn poll_blocking(&mut self) {
174 if !self.pool_completed.is_empty() {
176 while let Some(entry) = self.pool_completed.pop() {
177 unsafe {
178 entry.notify();
179 }
180 }
181 }
182 }
183
184 fn poll_entries(&mut self) -> bool {
185 self.poll_blocking();
186
187 let mut cqueue = self.inner.completion();
188 cqueue.sync();
189 let has_entry = !cqueue.is_empty();
190 for entry in cqueue {
191 match entry.user_data() {
192 Self::CANCEL => {}
193 Self::NOTIFY => {
194 let flags = entry.flags();
195 debug_assert!(more(flags));
196 self.notifier.clear().expect("cannot clear notifier");
197 }
198 _ => unsafe {
199 create_entry(entry).notify();
200 },
201 }
202 }
203 has_entry
204 }
205
206 pub fn create_op<T: crate::sys::OpCode + 'static>(&self, op: T) -> Key<T> {
207 Key::new(self.as_raw_fd(), op)
208 }
209
210 pub fn attach(&mut self, _fd: RawFd) -> io::Result<()> {
211 Ok(())
212 }
213
214 pub fn cancel(&mut self, op: &mut Key<dyn crate::sys::OpCode>) {
215 instrument!(compio_log::Level::TRACE, "cancel", ?op);
216 trace!("cancel RawOp");
217 unsafe {
218 #[allow(clippy::useless_conversion)]
219 if self
220 .inner
221 .submission()
222 .push(
223 &AsyncCancel::new(op.user_data() as _)
224 .build()
225 .user_data(Self::CANCEL)
226 .into(),
227 )
228 .is_err()
229 {
230 warn!("could not push AsyncCancel entry");
231 }
232 }
233 }
234
235 fn push_raw(&mut self, entry: SEntry) -> io::Result<()> {
236 loop {
237 let mut squeue = self.inner.submission();
238 match unsafe { squeue.push(&entry) } {
239 Ok(()) => {
240 squeue.sync();
241 break Ok(());
242 }
243 Err(_) => {
244 drop(squeue);
245 self.poll_entries();
246 match self.submit_auto(Some(Duration::ZERO)) {
247 Ok(()) => {}
248 Err(e)
249 if matches!(
250 e.kind(),
251 io::ErrorKind::TimedOut | io::ErrorKind::Interrupted
252 ) => {}
253 Err(e) => return Err(e),
254 }
255 }
256 }
257 }
258 }
259
260 pub fn push(&mut self, op: &mut Key<dyn crate::sys::OpCode>) -> Poll<io::Result<usize>> {
261 instrument!(compio_log::Level::TRACE, "push", ?op);
262 let user_data = op.user_data();
263 let op_pin = op.as_op_pin();
264 trace!("push RawOp");
265 match op_pin.create_entry() {
266 OpEntry::Submission(entry) => {
267 #[allow(clippy::useless_conversion)]
268 self.push_raw(entry.user_data(user_data as _).into())?;
269 Poll::Pending
270 }
271 #[cfg(feature = "io-uring-sqe128")]
272 OpEntry::Submission128(entry) => {
273 self.push_raw(entry.user_data(user_data as _))?;
274 Poll::Pending
275 }
276 OpEntry::Blocking => loop {
277 if self.push_blocking(user_data) {
278 break Poll::Pending;
279 } else {
280 self.poll_blocking();
281 }
282 },
283 }
284 }
285
286 fn push_blocking(&mut self, user_data: usize) -> bool {
287 let handle = self.handle();
288 let completed = self.pool_completed.clone();
289 self.pool
290 .dispatch(move || {
291 let mut op = unsafe { Key::<dyn crate::sys::OpCode>::new_unchecked(user_data) };
292 let op_pin = op.as_op_pin();
293 let res = op_pin.call_blocking();
294 completed.push(Entry::new(user_data, res));
295 handle.notify().ok();
296 })
297 .is_ok()
298 }
299
300 pub unsafe fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
301 instrument!(compio_log::Level::TRACE, "poll", ?timeout);
302 trace!("start polling");
304
305 if !self.poll_entries() {
306 self.submit_auto(timeout)?;
307 self.poll_entries();
308 }
309
310 Ok(())
311 }
312
313 pub fn handle(&self) -> NotifyHandle {
314 self.notifier.handle()
315 }
316
317 #[cfg(io_uring)]
318 pub fn create_buffer_pool(
319 &mut self,
320 buffer_len: u16,
321 buffer_size: usize,
322 ) -> io::Result<BufferPool> {
323 let buffer_group = self.buffer_group_ids.insert(());
324 if buffer_group > u16::MAX as usize {
325 self.buffer_group_ids.remove(buffer_group);
326
327 return Err(io::Error::new(
328 io::ErrorKind::OutOfMemory,
329 "too many buffer pool allocated",
330 ));
331 }
332
333 let buf_ring = io_uring_buf_ring::IoUringBufRing::new(
334 &self.inner,
335 buffer_len,
336 buffer_group as _,
337 buffer_size,
338 )?;
339
340 #[cfg(fusion)]
341 {
342 Ok(BufferPool::new_io_uring(crate::IoUringBufferPool::new(
343 buf_ring,
344 )))
345 }
346 #[cfg(not(fusion))]
347 {
348 Ok(BufferPool::new(buf_ring))
349 }
350 }
351
352 #[cfg(not(io_uring))]
353 pub fn create_buffer_pool(
354 &mut self,
355 buffer_len: u16,
356 buffer_size: usize,
357 ) -> io::Result<BufferPool> {
358 Ok(BufferPool::new(buffer_len, buffer_size))
359 }
360
361 #[cfg(io_uring)]
365 pub unsafe fn release_buffer_pool(&mut self, buffer_pool: BufferPool) -> io::Result<()> {
366 #[cfg(fusion)]
367 let buffer_pool = buffer_pool.into_io_uring();
368
369 let buffer_group = buffer_pool.buffer_group();
370 buffer_pool.into_inner().release(&self.inner)?;
371 self.buffer_group_ids.remove(buffer_group as _);
372
373 Ok(())
374 }
375
376 #[cfg(not(io_uring))]
380 pub unsafe fn release_buffer_pool(&mut self, _: BufferPool) -> io::Result<()> {
381 Ok(())
382 }
383}
384
385impl AsRawFd for Driver {
386 fn as_raw_fd(&self) -> RawFd {
387 self.inner.as_raw_fd()
388 }
389}
390
391fn create_entry(cq_entry: CEntry) -> Entry {
392 let result = cq_entry.result();
393 let result = if result < 0 {
394 let result = if result == -libc::ECANCELED {
395 libc::ETIMEDOUT
396 } else {
397 -result
398 };
399 Err(io::Error::from_raw_os_error(result))
400 } else {
401 Ok(result as _)
402 };
403 let mut entry = Entry::new(cq_entry.user_data() as _, result);
404 entry.set_flags(cq_entry.flags());
405
406 entry
407}
408
409fn timespec(duration: std::time::Duration) -> Timespec {
410 Timespec::new()
411 .sec(duration.as_secs())
412 .nsec(duration.subsec_nanos())
413}
414
/// An eventfd-backed waker for the driver.
#[derive(Debug)]
struct Notifier {
    // Shared with every NotifyHandle so wakeups stay valid across threads.
    fd: Arc<OwnedFd>,
}
419
420impl Notifier {
421 fn new() -> io::Result<Self> {
423 let fd = syscall!(libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK))?;
424 let fd = unsafe { OwnedFd::from_raw_fd(fd) };
425 Ok(Self { fd: Arc::new(fd) })
426 }
427
428 pub fn clear(&self) -> io::Result<()> {
429 loop {
430 let mut buffer = [0u64];
431 let res = syscall!(libc::read(
432 self.fd.as_raw_fd(),
433 buffer.as_mut_ptr().cast(),
434 std::mem::size_of::<u64>()
435 ));
436 match res {
437 Ok(len) => {
438 debug_assert_eq!(len, std::mem::size_of::<u64>() as _);
439 break Ok(());
440 }
441 Err(e) if e.kind() == io::ErrorKind::WouldBlock => break Ok(()),
443 Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
445 Err(e) => break Err(e),
446 }
447 }
448 }
449
450 pub fn handle(&self) -> NotifyHandle {
451 NotifyHandle::new(self.fd.clone())
452 }
453}
454
455impl AsRawFd for Notifier {
456 fn as_raw_fd(&self) -> RawFd {
457 self.fd.as_raw_fd()
458 }
459}
460
/// A handle for waking the driver from other threads.
pub struct NotifyHandle {
    // Shares the driver's notifier eventfd.
    fd: Arc<OwnedFd>,
}
465
466impl NotifyHandle {
467 pub(crate) fn new(fd: Arc<OwnedFd>) -> Self {
468 Self { fd }
469 }
470
471 pub fn notify(&self) -> io::Result<()> {
473 let data = 1u64;
474 syscall!(libc::write(
475 self.fd.as_raw_fd(),
476 &data as *const _ as *const _,
477 std::mem::size_of::<u64>(),
478 ))?;
479 Ok(())
480 }
481}