compio_driver/iour/
mod.rs1#[cfg_attr(all(doc, docsrs), doc(cfg(all())))]
2#[allow(unused_imports)]
3pub use std::os::fd::{AsRawFd, OwnedFd, RawFd};
4use std::{io, os::fd::FromRawFd, pin::Pin, sync::Arc, task::Poll, time::Duration};
5
6use compio_log::{instrument, trace, warn};
7use crossbeam_queue::SegQueue;
8cfg_if::cfg_if! {
9 if #[cfg(feature = "io-uring-cqe32")] {
10 use io_uring::cqueue::Entry32 as CEntry;
11 } else {
12 use io_uring::cqueue::Entry as CEntry;
13 }
14}
15cfg_if::cfg_if! {
16 if #[cfg(feature = "io-uring-sqe128")] {
17 use io_uring::squeue::Entry128 as SEntry;
18 } else {
19 use io_uring::squeue::Entry as SEntry;
20 }
21}
22use io_uring::{
23 IoUring,
24 cqueue::more,
25 opcode::{AsyncCancel, PollAdd},
26 types::{Fd, SubmitArgs, Timespec},
27};
28pub(crate) use libc::{sockaddr_storage, socklen_t};
29
30use crate::{AsyncifyPool, Entry, Key, ProactorBuilder, syscall};
31
32pub(crate) mod op;
33
34pub enum OpEntry {
36 Submission(io_uring::squeue::Entry),
38 #[cfg(feature = "io-uring-sqe128")]
39 Submission128(io_uring::squeue::Entry128),
41 Blocking,
43}
44
45impl From<io_uring::squeue::Entry> for OpEntry {
46 fn from(value: io_uring::squeue::Entry) -> Self {
47 Self::Submission(value)
48 }
49}
50
/// Lets a 128-byte `squeue::Entry128` be returned directly where an [`OpEntry`] is expected.
#[cfg(feature = "io-uring-sqe128")]
impl From<io_uring::squeue::Entry128> for OpEntry {
    fn from(entry: io_uring::squeue::Entry128) -> Self {
        OpEntry::Submission128(entry)
    }
}
57
58pub trait OpCode {
60 fn create_entry(self: Pin<&mut Self>) -> OpEntry;
62
63 fn call_blocking(self: Pin<&mut Self>) -> io::Result<usize> {
66 unreachable!("this operation is asynchronous")
67 }
68
69 unsafe fn set_result(self: Pin<&mut Self>, _: usize) {}
77}
78
79pub(crate) struct Driver {
81 inner: IoUring<SEntry, CEntry>,
82 notifier: Notifier,
83 pool: AsyncifyPool,
84 pool_completed: Arc<SegQueue<Entry>>,
85}
86
87impl Driver {
88 const CANCEL: u64 = u64::MAX;
89 const NOTIFY: u64 = u64::MAX - 1;
90
91 pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
92 instrument!(compio_log::Level::TRACE, "new", ?builder);
93 trace!("new iour driver");
94 let notifier = Notifier::new()?;
95 let mut io_uring_builder = IoUring::builder();
96 if let Some(sqpoll_idle) = builder.sqpoll_idle {
97 io_uring_builder.setup_sqpoll(sqpoll_idle.as_millis() as _);
98 }
99 let mut inner = io_uring_builder.build(builder.capacity)?;
100 #[allow(clippy::useless_conversion)]
101 unsafe {
102 inner
103 .submission()
104 .push(
105 &PollAdd::new(Fd(notifier.as_raw_fd()), libc::POLLIN as _)
106 .multi(true)
107 .build()
108 .user_data(Self::NOTIFY)
109 .into(),
110 )
111 .expect("the squeue sould not be full");
112 }
113 Ok(Self {
114 inner,
115 notifier,
116 pool: builder.create_or_get_thread_pool(),
117 pool_completed: Arc::new(SegQueue::new()),
118 })
119 }
120
121 fn submit_auto(&mut self, timeout: Option<Duration>) -> io::Result<()> {
123 instrument!(compio_log::Level::TRACE, "submit_auto", ?timeout);
124 let res = {
125 if let Some(duration) = timeout {
127 let timespec = timespec(duration);
128 let args = SubmitArgs::new().timespec(×pec);
129 self.inner.submitter().submit_with_args(1, &args)
130 } else {
131 self.inner.submit_and_wait(1)
132 }
133 };
134 trace!("submit result: {res:?}");
135 match res {
136 Ok(_) => {
137 if self.inner.completion().is_empty() {
138 Err(io::ErrorKind::TimedOut.into())
139 } else {
140 Ok(())
141 }
142 }
143 Err(e) => match e.raw_os_error() {
144 Some(libc::ETIME) => Err(io::ErrorKind::TimedOut.into()),
145 Some(libc::EBUSY) | Some(libc::EAGAIN) => Err(io::ErrorKind::Interrupted.into()),
146 _ => Err(e),
147 },
148 }
149 }
150
151 fn poll_blocking(&mut self) {
152 if !self.pool_completed.is_empty() {
154 while let Some(entry) = self.pool_completed.pop() {
155 unsafe {
156 entry.notify();
157 }
158 }
159 }
160 }
161
162 fn poll_entries(&mut self) -> bool {
163 self.poll_blocking();
164
165 let mut cqueue = self.inner.completion();
166 cqueue.sync();
167 let has_entry = !cqueue.is_empty();
168 for entry in cqueue {
169 match entry.user_data() {
170 Self::CANCEL => {}
171 Self::NOTIFY => {
172 let flags = entry.flags();
173 debug_assert!(more(flags));
174 self.notifier.clear().expect("cannot clear notifier");
175 }
176 _ => unsafe {
177 create_entry(entry).notify();
178 },
179 }
180 }
181 has_entry
182 }
183
184 pub fn create_op<T: crate::sys::OpCode + 'static>(&self, op: T) -> Key<T> {
185 Key::new(self.as_raw_fd(), op)
186 }
187
188 pub fn attach(&mut self, _fd: RawFd) -> io::Result<()> {
189 Ok(())
190 }
191
192 pub fn cancel(&mut self, op: &mut Key<dyn crate::sys::OpCode>) {
193 instrument!(compio_log::Level::TRACE, "cancel", ?op);
194 trace!("cancel RawOp");
195 unsafe {
196 #[allow(clippy::useless_conversion)]
197 if self
198 .inner
199 .submission()
200 .push(
201 &AsyncCancel::new(op.user_data() as _)
202 .build()
203 .user_data(Self::CANCEL)
204 .into(),
205 )
206 .is_err()
207 {
208 warn!("could not push AsyncCancel entry");
209 }
210 }
211 }
212
213 fn push_raw(&mut self, entry: SEntry) -> io::Result<()> {
214 loop {
215 let mut squeue = self.inner.submission();
216 match unsafe { squeue.push(&entry) } {
217 Ok(()) => {
218 squeue.sync();
219 break Ok(());
220 }
221 Err(_) => {
222 drop(squeue);
223 self.poll_entries();
224 match self.submit_auto(Some(Duration::ZERO)) {
225 Ok(()) => {}
226 Err(e)
227 if matches!(
228 e.kind(),
229 io::ErrorKind::TimedOut | io::ErrorKind::Interrupted
230 ) => {}
231 Err(e) => return Err(e),
232 }
233 }
234 }
235 }
236 }
237
238 pub fn push(&mut self, op: &mut Key<dyn crate::sys::OpCode>) -> Poll<io::Result<usize>> {
239 instrument!(compio_log::Level::TRACE, "push", ?op);
240 let user_data = op.user_data();
241 let op_pin = op.as_op_pin();
242 trace!("push RawOp");
243 match op_pin.create_entry() {
244 OpEntry::Submission(entry) => {
245 #[allow(clippy::useless_conversion)]
246 self.push_raw(entry.user_data(user_data as _).into())?;
247 Poll::Pending
248 }
249 #[cfg(feature = "io-uring-sqe128")]
250 OpEntry::Submission128(entry) => {
251 self.push_raw(entry.user_data(user_data as _))?;
252 Poll::Pending
253 }
254 OpEntry::Blocking => loop {
255 if self.push_blocking(user_data)? {
256 break Poll::Pending;
257 } else {
258 self.poll_blocking();
259 }
260 },
261 }
262 }
263
264 fn push_blocking(&mut self, user_data: usize) -> io::Result<bool> {
265 let handle = self.handle()?;
266 let completed = self.pool_completed.clone();
267 let is_ok = self
268 .pool
269 .dispatch(move || {
270 let mut op = unsafe { Key::<dyn crate::sys::OpCode>::new_unchecked(user_data) };
271 let op_pin = op.as_op_pin();
272 let res = op_pin.call_blocking();
273 completed.push(Entry::new(user_data, res));
274 handle.notify().ok();
275 })
276 .is_ok();
277 Ok(is_ok)
278 }
279
280 pub unsafe fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
281 instrument!(compio_log::Level::TRACE, "poll", ?timeout);
282 trace!("start polling");
284
285 if !self.poll_entries() {
286 self.submit_auto(timeout)?;
287 self.poll_entries();
288 }
289
290 Ok(())
291 }
292
293 pub fn handle(&self) -> io::Result<NotifyHandle> {
294 self.notifier.handle()
295 }
296}
297
298impl AsRawFd for Driver {
299 fn as_raw_fd(&self) -> RawFd {
300 self.inner.as_raw_fd()
301 }
302}
303
304fn create_entry(cq_entry: CEntry) -> Entry {
305 let result = cq_entry.result();
306 let result = if result < 0 {
307 let result = if result == -libc::ECANCELED {
308 libc::ETIMEDOUT
309 } else {
310 -result
311 };
312 Err(io::Error::from_raw_os_error(result))
313 } else {
314 Ok(result as _)
315 };
316 let mut entry = Entry::new(cq_entry.user_data() as _, result);
317 entry.set_flags(cq_entry.flags());
318
319 entry
320}
321
322fn timespec(duration: std::time::Duration) -> Timespec {
323 Timespec::new()
324 .sec(duration.as_secs())
325 .nsec(duration.subsec_nanos())
326}
327
/// Owner of the eventfd used to wake the driver.
///
/// The descriptor is reference-counted so that [`NotifyHandle`]s can share it.
#[derive(Debug)]
struct Notifier {
    /// Shared eventfd. It is closed when the last clone is dropped.
    fd: Arc<OwnedFd>,
}
332
333impl Notifier {
334 fn new() -> io::Result<Self> {
336 let fd = syscall!(libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK))?;
337 let fd = unsafe { OwnedFd::from_raw_fd(fd) };
338 Ok(Self { fd: Arc::new(fd) })
339 }
340
341 pub fn clear(&self) -> io::Result<()> {
342 loop {
343 let mut buffer = [0u64];
344 let res = syscall!(libc::read(
345 self.fd.as_raw_fd(),
346 buffer.as_mut_ptr().cast(),
347 std::mem::size_of::<u64>()
348 ));
349 match res {
350 Ok(len) => {
351 debug_assert_eq!(len, std::mem::size_of::<u64>() as _);
352 break Ok(());
353 }
354 Err(e) if e.kind() == io::ErrorKind::WouldBlock => break Ok(()),
356 Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
358 Err(e) => break Err(e),
359 }
360 }
361 }
362
363 pub fn handle(&self) -> io::Result<NotifyHandle> {
364 Ok(NotifyHandle::new(self.fd.clone()))
365 }
366}
367
368impl AsRawFd for Notifier {
369 fn as_raw_fd(&self) -> RawFd {
370 self.fd.as_raw_fd()
371 }
372}
373
/// Handle used to wake the driver, possibly from another thread (the driver moves one into
/// each thread-pool task).
pub struct NotifyHandle {
    /// Descriptor shared with the driver's notifier.
    fd: Arc<OwnedFd>,
}
378
379impl NotifyHandle {
380 pub(crate) fn new(fd: Arc<OwnedFd>) -> Self {
381 Self { fd }
382 }
383
384 pub fn notify(&self) -> io::Result<()> {
386 let data = 1u64;
387 syscall!(libc::write(
388 self.fd.as_raw_fd(),
389 &data as *const _ as *const _,
390 std::mem::size_of::<u64>(),
391 ))?;
392 Ok(())
393 }
394}