ntex_net/uring/
driver.rs

1use std::cell::{Cell, UnsafeCell};
2use std::os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd};
3use std::{cmp, collections::VecDeque, fmt, io, mem, net, ptr, rc::Rc, sync::Arc};
4
5#[cfg(unix)]
6use std::os::unix::net::UnixStream as OsUnixStream;
7
8use io_uring::cqueue::{self, Entry as CEntry, more};
9use io_uring::opcode::{AsyncCancel, PollAdd};
10use io_uring::squeue::{Entry as SEntry, SubmissionQueue};
11use io_uring::{IoUring, Probe, types::Fd};
12use ntex_io::Io;
13use ntex_rt::{DriverType, Notify, PollResult, Runtime, syscall};
14use ntex_service::cfg::SharedCfg;
15use socket2::{Protocol, SockAddr, Socket, Type};
16
17use super::{TcpStream, UnixStream, stream::StreamOps};
18use crate::channel::Receiver;
19
/// Receiver of io-uring completion events for one registered handler.
///
/// The driver's completion loop strips the handler index from each
/// completion's `user_data` and passes the remaining bits on as `id`.
pub trait Handler {
    /// Operation is completed
    ///
    /// `flags` are the completion-entry flags. `result` is `Ok(n)` for a
    /// non-negative kernel result and an OS error built from the negated
    /// value otherwise.
    fn completed(&mut self, id: usize, flags: u32, result: io::Result<usize>);

    /// Operation is canceled
    ///
    /// Called instead of `completed` when the kernel result is `-ECANCELED`.
    fn canceled(&mut self, id: usize);

    /// Driver turn is completed
    ///
    /// Called once after a batch of completions has been dispatched, and only
    /// for handlers that received at least one event in that batch.
    fn tick(&mut self);

    /// Cleanup before drop
    fn cleanup(&mut self);
}
33
34pub struct DriverApi {
35    batch: u64,
36    inner: Rc<DriverInner>,
37}
38
39impl DriverApi {
40    #[inline]
41    /// Check if kernel ver 6.1 or greater
42    pub fn is_new(&self) -> bool {
43        self.inner.flags.get().contains(Flags::NEW)
44    }
45
46    fn submit_inner<F>(&self, f: F)
47    where
48        F: FnOnce(&mut SEntry),
49    {
50        unsafe {
51            let changes = &mut *self.inner.changes.get();
52            let sq = self.inner.ring.submission();
53            if !changes.is_empty() || sq.is_full() {
54                changes.push_back(mem::MaybeUninit::uninit());
55                let entry = changes.back_mut().unwrap();
56                ptr::write_bytes(entry.as_mut_ptr(), 0, 1);
57                f(entry.assume_init_mut());
58            } else {
59                sq.push_inline(f).expect("Queue size is checked");
60            }
61        }
62    }
63
64    #[inline]
65    /// Submit request to the driver.
66    pub fn submit(&self, id: u32, entry: SEntry) {
67        self.submit_inner(|en| {
68            *en = entry;
69            en.set_user_data(id as u64 | self.batch);
70        });
71    }
72
73    #[inline]
74    /// Submit request to the driver.
75    pub fn submit_inline<F>(&self, id: u32, f: F)
76    where
77        F: FnOnce(&mut SEntry),
78    {
79        self.submit_inner(|en| {
80            f(en);
81            en.set_user_data(id as u64 | self.batch);
82        });
83    }
84
85    #[inline]
86    /// Attempt to cancel an already issued request.
87    pub fn cancel(&self, id: u32) {
88        self.submit_inner(|en| {
89            *en = AsyncCancel::new(id as u64 | self.batch)
90                .build()
91                .user_data(Driver::CANCEL);
92        });
93    }
94
95    /// Get whether a specific io-uring opcode is supported.
96    pub fn is_supported(&self, opcode: u8) -> bool {
97        self.inner.probe.is_supported(opcode)
98    }
99}
100
/// Low-level driver of io-uring.
pub struct Driver {
    /// Raw file descriptor of the ring, captured when the driver is built.
    fd: RawFd,
    /// Index that the next registered handler will receive.
    hid: Cell<u64>,
    /// Notifier whose descriptor is polled through the ring.
    notifier: Notifier,
    /// Registered handlers; the list is taken out of the cell while
    /// completions are being dispatched and put back afterwards.
    #[allow(clippy::box_collection)]
    handlers: Cell<Option<Box<Vec<HandlerItem>>>>,
    /// Ring state shared with every `DriverApi` handed to a handler.
    inner: Rc<DriverInner>,
}
110
111struct HandlerItem {
112    hnd: Box<dyn Handler>,
113    modified: bool,
114}
115
116impl HandlerItem {
117    fn tick(&mut self) {
118        if self.modified {
119            self.modified = false;
120            self.hnd.tick();
121        }
122    }
123}
124
// Driver state bits stored in `DriverInner::flags`.
bitflags::bitflags! {
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    struct Flags: u8 {
        // Set when the ring was created with at least the single-issuer setup.
        const NEW      = 0b0000_0001;
        // NOTE(review): not referenced in the visible part of this file.
        const NOTIFIER = 0b0000_0010;
    }
}
132
/// Ring state shared between `Driver` and every `DriverApi`.
struct DriverInner {
    /// Opcode support table registered against the ring at construction.
    probe: Probe,
    /// Driver state bits.
    flags: Cell<Flags>,
    /// The io-uring instance.
    ring: IoUring<SEntry, CEntry>,
    /// Entries that could not be placed into the submission queue directly;
    /// they are moved into the queue on a later driver turn.
    changes: UnsafeCell<VecDeque<mem::MaybeUninit<SEntry>>>,
}
139
140impl Driver {
    /// `user_data` tag of the multishot poll on the notifier descriptor.
    const NOTIFY: u64 = u64::MAX;
    /// `user_data` tag attached to cancel requests; their completions are ignored.
    const CANCEL: u64 = u64::MAX - 1;
    /// Bit offset of the handler index inside `user_data`.
    const BATCH: u64 = 48;
    /// Selects the handler index (upper 16 bits) from `user_data`.
    const BATCH_MASK: u64 = 0xFFFF_0000_0000_0000;
    /// Selects the operation id (lower 48 bits) from `user_data`.
    const DATA_MASK: u64 = 0x0000_FFFF_FFFF_FFFF;
146
147    /// Create io-uring driver
148    pub fn new(capacity: u32) -> io::Result<Self> {
149        // Create ring
150        let (new, ring) = if let Ok(ring) = IoUring::builder()
151            .setup_coop_taskrun()
152            .setup_single_issuer()
153            .setup_defer_taskrun()
154            .build(capacity)
155        {
156            log::info!(
157                "New io-uring driver with single-issuer, coop-taskrun, defer-taskrun"
158            );
159            (true, ring)
160        } else if let Ok(ring) = IoUring::builder().setup_single_issuer().build(capacity) {
161            log::info!("New io-uring driver with single-issuer");
162            (true, ring)
163        } else {
164            let ring = IoUring::builder().build(capacity)?;
165            log::info!("New io-uring driver");
166            (false, ring)
167        };
168
169        let mut probe = Probe::new();
170        ring.submitter().register_probe(&mut probe)?;
171
172        // Remote notifier
173        let notifier = Notifier::new()?;
174        unsafe {
175            let sq = ring.submission();
176            sq.push(
177                &PollAdd::new(Fd(notifier.as_raw_fd()), libc::POLLIN as _)
178                    .multi(true)
179                    .build()
180                    .user_data(Self::NOTIFY),
181            )
182            .expect("the squeue sould not be full");
183            sq.sync();
184        }
185
186        let fd = ring.as_raw_fd();
187        let inner = Rc::new(DriverInner {
188            ring,
189            probe,
190            flags: Cell::new(if new { Flags::NEW } else { Flags::empty() }),
191            changes: UnsafeCell::new(VecDeque::with_capacity(32)),
192        });
193
194        Ok(Self {
195            fd,
196            inner,
197            notifier,
198            hid: Cell::new(0),
199            handlers: Cell::new(Some(Box::new(Vec::new()))),
200        })
201    }
202
    /// Driver type
    ///
    /// Always reports `DriverType::IoUring`.
    pub const fn tp(&self) -> DriverType {
        DriverType::IoUring
    }
207
208    /// Register updates handler
209    pub fn register<F>(&self, f: F)
210    where
211        F: FnOnce(DriverApi) -> Box<dyn Handler>,
212    {
213        let id = self.hid.get();
214        let mut handlers = self.handlers.take().unwrap_or_default();
215        handlers.push(HandlerItem {
216            hnd: f(DriverApi {
217                batch: id << Self::BATCH,
218                inner: self.inner.clone(),
219            }),
220            modified: false,
221        });
222        self.handlers.set(Some(handlers));
223        self.hid.set(id + 1);
224    }
225
226    fn apply_changes(&self, sq: SubmissionQueue<'_, SEntry>) -> bool {
227        unsafe {
228            let changes = &mut *self.inner.changes.get();
229            if changes.is_empty() {
230                false
231            } else {
232                let num = cmp::min(changes.len(), sq.capacity() - sq.len());
233                let (s1, s2) = changes.as_slices();
234                let s1_num = cmp::min(s1.len(), num);
235                if s1_num > 0 {
236                    // safety: "changes" contains only initialized entries
237                    sq.push_multiple(mem::transmute::<
238                        &[mem::MaybeUninit<SEntry>],
239                        &[SEntry],
240                    >(&s1[0..s1_num]))
241                        .unwrap();
242                } else if !s2.is_empty() {
243                    let s2_num = cmp::min(s2.len(), num - s1_num);
244                    if s2_num > 0 {
245                        sq.push_multiple(mem::transmute::<
246                            &[mem::MaybeUninit<SEntry>],
247                            &[SEntry],
248                        >(&s2[0..s2_num]))
249                            .unwrap();
250                    }
251                }
252                changes.drain(0..num);
253
254                !changes.is_empty()
255            }
256        }
257    }
258
259    /// Handle ring completions, forward changes to specific handler
260    fn poll_completions(
261        &self,
262        cq: &mut cqueue::CompletionQueue<'_, CEntry>,
263        sq: &SubmissionQueue<'_, SEntry>,
264    ) {
265        cq.sync();
266
267        if !cqueue::CompletionQueue::<'_, _>::is_empty(cq) {
268            let mut handlers = self.handlers.take().unwrap();
269            for entry in cq {
270                let user_data = entry.user_data();
271                match user_data {
272                    Self::CANCEL => {}
273                    Self::NOTIFY => {
274                        let flags = entry.flags();
275                        self.notifier.clear().expect("cannot clear notifier");
276
277                        // re-submit notifier fd
278                        if !more(flags) {
279                            unsafe {
280                                sq.push(
281                                    &PollAdd::new(
282                                        Fd(self.notifier.as_raw_fd()),
283                                        libc::POLLIN as _,
284                                    )
285                                    .multi(true)
286                                    .build()
287                                    .user_data(Self::NOTIFY),
288                                )
289                            }
290                            .expect("the squeue sould not be full");
291                        }
292                    }
293                    _ => {
294                        let batch =
295                            ((user_data & Self::BATCH_MASK) >> Self::BATCH) as usize;
296                        let user_data = (user_data & Self::DATA_MASK) as usize;
297
298                        let result = entry.result();
299                        if result == -libc::ECANCELED {
300                            handlers[batch].modified = true;
301                            handlers[batch].hnd.canceled(user_data);
302                        } else {
303                            let result = if result < 0 {
304                                Err(io::Error::from_raw_os_error(-result))
305                            } else {
306                                Ok(result as _)
307                            };
308                            handlers[batch].modified = true;
309                            handlers[batch]
310                                .hnd
311                                .completed(user_data, entry.flags(), result);
312                        }
313                    }
314                }
315            }
316            for h in handlers.iter_mut() {
317                h.tick();
318            }
319            self.handlers.set(Some(handlers));
320        }
321    }
322}
323
/// Exposes the ring descriptor stored in `Driver::fd`.
impl AsRawFd for Driver {
    fn as_raw_fd(&self) -> RawFd {
        self.fd
    }
}
329
330impl crate::Reactor for Driver {
331    fn tcp_connect(&self, addr: net::SocketAddr, cfg: SharedCfg) -> Receiver<Io> {
332        let addr = SockAddr::from(addr);
333        let result = Socket::new(addr.domain(), Type::STREAM, Some(Protocol::TCP))
334            .and_then(crate::helpers::prep_socket)
335            .map(move |sock| (addr, sock));
336
337        match result {
338            Err(err) => Receiver::new(Err(err)),
339            Ok((addr, sock)) => {
340                super::connect::ConnectOps::get(self).connect(sock, addr, cfg)
341            }
342        }
343    }
344
345    fn unix_connect(&self, addr: std::path::PathBuf, cfg: SharedCfg) -> Receiver<Io> {
346        let result = SockAddr::unix(addr).and_then(|addr| {
347            Socket::new(addr.domain(), Type::STREAM, None)
348                .and_then(crate::helpers::prep_socket)
349                .map(move |sock| (addr, sock))
350        });
351
352        match result {
353            Err(err) => Receiver::new(Err(err)),
354            Ok((addr, sock)) => {
355                super::connect::ConnectOps::get(self).connect(sock, addr, cfg)
356            }
357        }
358    }
359
360    fn from_tcp_stream(&self, stream: net::TcpStream, cfg: SharedCfg) -> io::Result<Io> {
361        stream.set_nodelay(true)?;
362
363        Ok(Io::new(
364            TcpStream(
365                crate::helpers::prep_socket(Socket::from(stream))?,
366                StreamOps::get(self),
367            ),
368            cfg,
369        ))
370    }
371
372    #[cfg(unix)]
373    fn from_unix_stream(&self, stream: OsUnixStream, cfg: SharedCfg) -> io::Result<Io> {
374        Ok(Io::new(
375            UnixStream(
376                crate::helpers::prep_socket(Socket::from(stream))?,
377                StreamOps::get(self),
378            ),
379            cfg,
380        ))
381    }
382}
383
384impl ntex_rt::Driver for Driver {
385    /// Poll the driver and handle completed operations.
386    fn run(&self, rt: &Runtime) -> io::Result<()> {
387        let ring = &self.inner.ring;
388        let sq = ring.submission();
389        let mut cq = unsafe { ring.completion_shared() };
390        let submitter = ring.submitter();
391        loop {
392            self.poll_completions(&mut cq, &sq);
393
394            let more_tasks = match rt.poll() {
395                PollResult::Pending => false,
396                PollResult::PollAgain => true,
397                PollResult::Ready => return Ok(()),
398            };
399            let more_changes = self.apply_changes(sq);
400
401            // squeue has to sync after we apply all changes
402            // otherwise ring wont see any change in submit call
403            sq.sync();
404
405            let result = if more_changes || more_tasks {
406                submitter.submit()
407            } else {
408                submitter.submit_and_wait(1)
409            };
410
411            if let Err(e) = result {
412                match e.raw_os_error() {
413                    Some(libc::ETIME) | Some(libc::EBUSY) | Some(libc::EAGAIN)
414                    | Some(libc::EINTR) => {
415                        log::info!("Ring submit interrupted, {:?}", e);
416                    }
417                    _ => return Err(e),
418                }
419            }
420        }
421    }
422
423    /// Get notification handle
424    fn handle(&self) -> Box<dyn Notify> {
425        Box::new(self.notifier.handle())
426    }
427
428    fn clear(&self) {
429        for mut h in self.handlers.take().unwrap().into_iter() {
430            h.hnd.cleanup()
431        }
432    }
433}
434
435#[derive(Debug)]
436pub(crate) struct Notifier {
437    fd: Arc<OwnedFd>,
438}
439
440impl Notifier {
441    /// Create a new notifier.
442    pub(crate) fn new() -> io::Result<Self> {
443        let fd = syscall!(libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK))?;
444        let fd = unsafe { OwnedFd::from_raw_fd(fd) };
445        Ok(Self { fd: Arc::new(fd) })
446    }
447
448    pub(crate) fn clear(&self) -> io::Result<()> {
449        loop {
450            let mut buffer = [0u64];
451            let res = syscall!(libc::read(
452                self.fd.as_raw_fd(),
453                buffer.as_mut_ptr().cast(),
454                mem::size_of::<u64>()
455            ));
456            match res {
457                Ok(len) => {
458                    debug_assert_eq!(len, mem::size_of::<u64>() as isize);
459                    break Ok(());
460                }
461                // Clear the next time
462                Err(e) if e.kind() == io::ErrorKind::WouldBlock => break Ok(()),
463                // Just like read_exact
464                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
465                Err(e) => break Err(e),
466            }
467        }
468    }
469
470    pub(crate) fn handle(&self) -> NotifyHandle {
471        NotifyHandle::new(self.fd.clone())
472    }
473}
474
475impl AsRawFd for Notifier {
476    fn as_raw_fd(&self) -> RawFd {
477        self.fd.as_raw_fd()
478    }
479}
480
481#[derive(Clone, Debug)]
482/// A notify handle to the driver.
483pub(crate) struct NotifyHandle {
484    fd: Arc<OwnedFd>,
485}
486
487impl NotifyHandle {
488    pub(crate) fn new(fd: Arc<OwnedFd>) -> Self {
489        Self { fd }
490    }
491}
492
493impl Notify for NotifyHandle {
494    /// Notify the driver.
495    fn notify(&self) -> io::Result<()> {
496        let data = 1u64;
497        syscall!(libc::write(
498            self.fd.as_raw_fd(),
499            &data as *const _ as *const _,
500            std::mem::size_of::<u64>(),
501        ))?;
502        Ok(())
503    }
504}
505
506impl fmt::Debug for Driver {
507    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
508        f.debug_struct("Driver")
509            .field("fd", &self.fd)
510            .field("hid", &self.hid)
511            .field("nodifier", &self.notifier)
512            .finish()
513    }
514}
515
516impl fmt::Debug for DriverApi {
517    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
518        f.debug_struct("DriverApi")
519            .field("batch", &self.batch)
520            .finish()
521    }
522}