Skip to main content

ntex_net/uring/
driver.rs

1use std::cell::{Cell, UnsafeCell};
2use std::os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd};
3use std::{cmp, collections::VecDeque, fmt, io, mem, net, ptr, rc::Rc, sync::Arc};
4
5#[cfg(unix)]
6use std::os::unix::net::UnixStream as OsUnixStream;
7
8use ntex_io::Io;
9use ntex_io_uring::cqueue::{self, Entry as CEntry, more};
10use ntex_io_uring::opcode::{AsyncCancel, PollAdd};
11use ntex_io_uring::squeue::{Entry as SEntry, SubmissionQueue};
12use ntex_io_uring::{IoUring, Probe, types::Fd};
13use ntex_rt::{DriverType, Notify, PollResult, Runtime, syscall};
14use ntex_service::cfg::SharedCfg;
15use socket2::{Protocol, SockAddr, Socket, Type};
16
17use super::{TcpStream, UnixStream, stream::StreamOps};
18use crate::channel::Receiver;
19
/// Callbacks through which the driver reports ring events to a registered handler.
pub trait Handler {
    /// Operation is completed.
    ///
    /// `id` is the operation id taken from the completion user data, `flags`
    /// are the completion-entry flags, and `result` is `Err` when the kernel
    /// reported a negative result and `Ok` with the result value otherwise.
    fn completed(&mut self, id: usize, flags: u32, result: io::Result<usize>);

    /// Operation is canceled (its completion result was `ECANCELED`).
    fn canceled(&mut self, id: usize);

    /// Driver turn is completed.
    ///
    /// Only invoked for handlers that received at least one event in the turn.
    fn tick(&mut self);

    /// Cleanup before drop
    fn cleanup(&mut self);
}
33
34pub struct DriverApi {
35    batch: u64,
36    inner: Rc<DriverInner>,
37}
38
39impl DriverApi {
40    #[inline]
41    /// Check if kernel ver 6.1 or greater
42    pub fn is_new(&self) -> bool {
43        self.inner.flags.get().contains(Flags::NEW)
44    }
45
46    fn submit_inner<F>(&self, f: F)
47    where
48        F: FnOnce(&mut SEntry),
49    {
50        unsafe {
51            let changes = &mut *self.inner.changes.get();
52            let sq = self.inner.ring.submission();
53            if !changes.is_empty() || sq.is_full() {
54                changes.push_back(mem::MaybeUninit::uninit());
55                let entry = changes.back_mut().unwrap();
56                ptr::write_bytes(entry.as_mut_ptr(), 0, 1);
57                f(entry.assume_init_mut());
58            } else {
59                sq.push_inline(f).expect("Queue size is checked");
60            }
61        }
62    }
63
64    #[inline]
65    /// Submit request to the driver.
66    pub fn submit(&self, id: u32, entry: SEntry) {
67        self.submit_inner(|en| {
68            *en = entry;
69            en.set_user_data(u64::from(id) | self.batch);
70        });
71    }
72
73    #[inline]
74    /// Submit request to the driver.
75    pub fn submit_inline<F>(&self, id: u32, f: F)
76    where
77        F: FnOnce(&mut SEntry),
78    {
79        self.submit_inner(|en| {
80            f(en);
81            en.set_user_data(u64::from(id) | self.batch);
82        });
83    }
84
85    #[inline]
86    /// Attempt to cancel an already issued request.
87    pub fn cancel(&self, id: u32) {
88        self.submit_inner(|en| {
89            *en = AsyncCancel::new(u64::from(id) | self.batch)
90                .build()
91                .user_data(Driver::CANCEL);
92        });
93    }
94
95    /// Get whether a specific io-uring opcode is supported.
96    pub fn is_supported(&self, opcode: u8) -> bool {
97        self.inner.probe.is_supported(opcode)
98    }
99}
100
/// Low-level driver of io-uring.
pub struct Driver {
    /// Raw file descriptor of the ring, captured when the driver is created.
    fd: RawFd,
    /// Index that will be assigned to the next registered handler.
    hid: Cell<u64>,
    /// Eventfd-based notifier whose descriptor is polled through the ring.
    notifier: Notifier,
    /// Registered handlers; taken out of the cell while completions are dispatched.
    #[allow(clippy::box_collection)]
    handlers: Cell<Option<Box<Vec<HandlerItem>>>>,
    /// Ring state shared with every `DriverApi` created by `register`.
    inner: Rc<DriverInner>,
}
110
111struct HandlerItem {
112    hnd: Box<dyn Handler>,
113    modified: bool,
114}
115
116impl HandlerItem {
117    fn tick(&mut self) {
118        if self.modified {
119            self.modified = false;
120            self.hnd.tick();
121        }
122    }
123}
124
bitflags::bitflags! {
    // Driver state flags stored in `DriverInner::flags`.
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    struct Flags: u8 {
        // Set when the ring could be built with the single-issuer setup flag.
        const NEW      = 0b0000_0001;
        // NOTE(review): not referenced in the visible part of this file.
        const NOTIFIER = 0b0000_0010;
    }
}
132
/// Ring state shared between the `Driver` and every `DriverApi`.
struct DriverInner {
    /// Opcode probe registered against the ring at construction time.
    probe: Probe,
    flags: Cell<Flags>,
    ring: IoUring<SEntry, CEntry>,
    /// Entries that could not be pushed to the submission queue right away;
    /// every stored entry is initialized before it becomes visible here.
    changes: UnsafeCell<VecDeque<mem::MaybeUninit<SEntry>>>,
}
139
140impl Driver {
141    const NOTIFY: u64 = u64::MAX;
142    const CANCEL: u64 = u64::MAX - 1;
143    const BATCH: u64 = 48;
144    const BATCH_MASK: u64 = 0xFFFF_0000_0000_0000;
145    const DATA_MASK: u64 = 0x0000_FFFF_FFFF_FFFF;
146
147    /// Create io-uring driver
148    pub fn new(capacity: u32) -> io::Result<Self> {
149        // Create ring
150        let (new, ring) = if let Ok(ring) = IoUring::builder()
151            .setup_coop_taskrun()
152            .setup_single_issuer()
153            .setup_defer_taskrun()
154            .build(capacity)
155        {
156            log::info!(
157                "New io-uring driver with single-issuer, coop-taskrun, defer-taskrun"
158            );
159            (true, ring)
160        } else if let Ok(ring) = IoUring::builder().setup_single_issuer().build(capacity) {
161            log::info!("New io-uring driver with single-issuer");
162            (true, ring)
163        } else {
164            let ring = IoUring::builder().build(capacity)?;
165            log::info!("New io-uring driver");
166            (false, ring)
167        };
168
169        let mut probe = Probe::new();
170        ring.submitter().register_probe(&mut probe)?;
171
172        // Remote notifier
173        let notifier = Notifier::new()?;
174        unsafe {
175            let sq = ring.submission();
176            sq.push(
177                &PollAdd::new(Fd(notifier.as_raw_fd()), libc::POLLIN as _)
178                    .multi(true)
179                    .build()
180                    .user_data(Self::NOTIFY),
181            )
182            .expect("the squeue sould not be full");
183            sq.sync();
184        }
185
186        let fd = ring.as_raw_fd();
187        let inner = Rc::new(DriverInner {
188            ring,
189            probe,
190            flags: Cell::new(if new { Flags::NEW } else { Flags::empty() }),
191            changes: UnsafeCell::new(VecDeque::with_capacity(32)),
192        });
193
194        Ok(Self {
195            fd,
196            inner,
197            notifier,
198            hid: Cell::new(0),
199            handlers: Cell::new(Some(Box::new(Vec::new()))),
200        })
201    }
202
203    /// Driver type
204    pub const fn tp(&self) -> DriverType {
205        DriverType::IoUring
206    }
207
208    /// Register updates handler
209    pub fn register<F>(&self, f: F)
210    where
211        F: FnOnce(DriverApi) -> Box<dyn Handler>,
212    {
213        let id = self.hid.get();
214        let mut handlers = self.handlers.take().unwrap_or_default();
215        handlers.push(HandlerItem {
216            hnd: f(DriverApi {
217                batch: id << Self::BATCH,
218                inner: self.inner.clone(),
219            }),
220            modified: false,
221        });
222        self.handlers.set(Some(handlers));
223        self.hid.set(id + 1);
224    }
225
226    fn apply_changes(&self, sq: SubmissionQueue<'_, SEntry>) -> bool {
227        unsafe {
228            let changes = &mut *self.inner.changes.get();
229            if changes.is_empty() {
230                false
231            } else {
232                let num = cmp::min(changes.len(), sq.capacity() - sq.len());
233                let (s1, s2) = changes.as_slices();
234                let s1_num = cmp::min(s1.len(), num);
235                if s1_num > 0 {
236                    // safety: "changes" contains only initialized entries
237                    sq.push_multiple(
238                        ((&raw const s1[0..s1_num]) as *const [SEntry])
239                            .as_ref()
240                            .unwrap(),
241                    )
242                    .unwrap();
243                } else if !s2.is_empty() {
244                    let s2_num = cmp::min(s2.len(), num - s1_num);
245                    if s2_num > 0 {
246                        sq.push_multiple(
247                            ((&raw const s2[0..s2_num]) as *const [SEntry])
248                                .as_ref()
249                                .unwrap(),
250                        )
251                        .unwrap();
252                    }
253                }
254                changes.drain(0..num);
255
256                !changes.is_empty()
257            }
258        }
259    }
260
261    /// Handle ring completions, forward changes to specific handler
262    fn poll_completions(
263        &self,
264        cq: &mut cqueue::CompletionQueue<'_, CEntry>,
265        sq: SubmissionQueue<'_, SEntry>,
266    ) {
267        cq.sync();
268
269        if !cqueue::CompletionQueue::<'_, _>::is_empty(cq) {
270            let mut handlers = self.handlers.take().unwrap();
271            for entry in cq {
272                let user_data = entry.user_data();
273                match user_data {
274                    Self::CANCEL => {}
275                    Self::NOTIFY => {
276                        let flags = entry.flags();
277                        self.notifier.clear().expect("cannot clear notifier");
278
279                        // re-submit notifier fd
280                        if !more(flags) {
281                            unsafe {
282                                sq.push(
283                                    &PollAdd::new(
284                                        Fd(self.notifier.as_raw_fd()),
285                                        libc::POLLIN as _,
286                                    )
287                                    .multi(true)
288                                    .build()
289                                    .user_data(Self::NOTIFY),
290                                )
291                            }
292                            .expect("the squeue sould not be full");
293                        }
294                    }
295                    _ => {
296                        let batch =
297                            ((user_data & Self::BATCH_MASK) >> Self::BATCH) as usize;
298                        let user_data = (user_data & Self::DATA_MASK) as usize;
299
300                        let result = entry.result();
301                        if result == -libc::ECANCELED {
302                            handlers[batch].modified = true;
303                            handlers[batch].hnd.canceled(user_data);
304                        } else {
305                            let result = if result < 0 {
306                                Err(io::Error::from_raw_os_error(-result))
307                            } else {
308                                #[allow(clippy::cast_sign_loss)]
309                                Ok(result as _)
310                            };
311                            handlers[batch].modified = true;
312                            handlers[batch]
313                                .hnd
314                                .completed(user_data, entry.flags(), result);
315                        }
316                    }
317                }
318            }
319            for h in handlers.iter_mut() {
320                h.tick();
321            }
322            self.handlers.set(Some(handlers));
323        }
324    }
325}
326
impl AsRawFd for Driver {
    /// Return the ring file descriptor captured when the driver was created.
    fn as_raw_fd(&self) -> RawFd {
        self.fd
    }
}
332
333impl crate::Reactor for Driver {
334    fn tcp_connect(&self, addr: net::SocketAddr, cfg: SharedCfg) -> Receiver<Io> {
335        let addr = SockAddr::from(addr);
336        let result = Socket::new(addr.domain(), Type::STREAM, Some(Protocol::TCP))
337            .and_then(crate::helpers::prep_socket)
338            .map(move |sock| (addr, sock));
339
340        match result {
341            Err(err) => Receiver::new(Err(err)),
342            Ok((addr, sock)) => {
343                super::connect::ConnectOps::get(self).connect(sock, addr, cfg)
344            }
345        }
346    }
347
348    fn unix_connect(&self, addr: std::path::PathBuf, cfg: SharedCfg) -> Receiver<Io> {
349        let result = SockAddr::unix(addr).and_then(|addr| {
350            Socket::new(addr.domain(), Type::STREAM, None)
351                .and_then(crate::helpers::prep_socket)
352                .map(move |sock| (addr, sock))
353        });
354
355        match result {
356            Err(err) => Receiver::new(Err(err)),
357            Ok((addr, sock)) => {
358                super::connect::ConnectOps::get(self).connect(sock, addr, cfg)
359            }
360        }
361    }
362
363    fn from_tcp_stream(&self, stream: net::TcpStream, cfg: SharedCfg) -> io::Result<Io> {
364        stream.set_nodelay(true)?;
365
366        Ok(Io::new(
367            TcpStream(
368                crate::helpers::prep_socket(Socket::from(stream))?,
369                StreamOps::get(self),
370            ),
371            cfg,
372        ))
373    }
374
375    #[cfg(unix)]
376    fn from_unix_stream(&self, stream: OsUnixStream, cfg: SharedCfg) -> io::Result<Io> {
377        Ok(Io::new(
378            UnixStream(
379                crate::helpers::prep_socket(Socket::from(stream))?,
380                StreamOps::get(self),
381            ),
382            cfg,
383        ))
384    }
385}
386
387impl ntex_rt::Driver for Driver {
388    /// Poll the driver and handle completed operations.
389    fn run(&self, rt: &Runtime) -> io::Result<()> {
390        let ring = &self.inner.ring;
391        let sq = ring.submission();
392        let mut cq = unsafe { ring.completion_shared() };
393        let submitter = ring.submitter();
394        loop {
395            self.poll_completions(&mut cq, sq);
396
397            let more_tasks = match rt.poll() {
398                PollResult::Pending => false,
399                PollResult::PollAgain => true,
400                PollResult::Ready => return Ok(()),
401            };
402            let more_changes = self.apply_changes(sq);
403
404            // squeue has to sync after we apply all changes
405            // otherwise ring wont see any change in submit call
406            sq.sync();
407
408            let result = if more_changes || more_tasks {
409                submitter.submit()
410            } else {
411                submitter.submit_and_wait(1)
412            };
413
414            if let Err(e) = result {
415                match e.raw_os_error() {
416                    Some(libc::ETIME | libc::EBUSY | libc::EAGAIN | libc::EINTR) => {
417                        log::info!("Ring submit interrupted, {e:?}");
418                    }
419                    _ => return Err(e),
420                }
421            }
422        }
423    }
424
425    /// Get notification handle
426    fn handle(&self) -> Box<dyn Notify> {
427        Box::new(self.notifier.handle())
428    }
429
430    fn clear(&self) {
431        for mut h in self.handlers.take().unwrap().into_iter() {
432            h.hnd.cleanup();
433        }
434    }
435}
436
437#[derive(Debug)]
438pub(crate) struct Notifier {
439    fd: Arc<OwnedFd>,
440}
441
442impl Notifier {
443    /// Create a new notifier.
444    pub(crate) fn new() -> io::Result<Self> {
445        let fd = syscall!(libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK))?;
446        let fd = unsafe { OwnedFd::from_raw_fd(fd) };
447        Ok(Self { fd: Arc::new(fd) })
448    }
449
450    pub(crate) fn clear(&self) -> io::Result<()> {
451        loop {
452            let mut buffer = [0u64];
453            let res = syscall!(libc::read(
454                self.fd.as_raw_fd(),
455                buffer.as_mut_ptr().cast(),
456                mem::size_of::<u64>()
457            ));
458            #[allow(clippy::cast_possible_wrap)]
459            match res {
460                Ok(len) => {
461                    debug_assert_eq!(len, mem::size_of::<u64>() as isize);
462                    break Ok(());
463                }
464                // Clear the next time
465                Err(e) if e.kind() == io::ErrorKind::WouldBlock => break Ok(()),
466                // Just like read_exact
467                Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
468                Err(e) => break Err(e),
469            }
470        }
471    }
472
473    pub(crate) fn handle(&self) -> NotifyHandle {
474        NotifyHandle::new(self.fd.clone())
475    }
476}
477
478impl AsRawFd for Notifier {
479    fn as_raw_fd(&self) -> RawFd {
480        self.fd.as_raw_fd()
481    }
482}
483
484#[derive(Clone, Debug)]
485/// A notify handle to the driver.
486pub(crate) struct NotifyHandle {
487    fd: Arc<OwnedFd>,
488}
489
490impl NotifyHandle {
491    pub(crate) fn new(fd: Arc<OwnedFd>) -> Self {
492        Self { fd }
493    }
494}
495
496impl Notify for NotifyHandle {
497    /// Notify the driver.
498    fn notify(&self) -> io::Result<()> {
499        let data = 1u64;
500        syscall!(libc::write(
501            self.fd.as_raw_fd(),
502            (&raw const data).cast(),
503            std::mem::size_of::<u64>(),
504        ))?;
505        Ok(())
506    }
507}
508
509impl fmt::Debug for Driver {
510    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
511        f.debug_struct("Driver")
512            .field("fd", &self.fd)
513            .field("hid", &self.hid)
514            .field("nodifier", &self.notifier)
515            .finish()
516    }
517}
518
519impl fmt::Debug for DriverApi {
520    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
521        f.debug_struct("DriverApi")
522            .field("batch", &self.batch)
523            .finish()
524    }
525}