ureeves_userfaultfd/
lib.rs

1//! A Linux mechanism for handling page faults in user space.
2//!
3//! The main way to interact with this library is to create a `Uffd` object with a `UffdBuilder`,
4//! then use the methods of `Uffd` from a worker thread.
5//!
6//! See [`userfaultfd(2)`](http://man7.org/linux/man-pages/man2/userfaultfd.2.html) and
7//! [`ioctl_userfaultfd(2)`](http://man7.org/linux/man-pages/man2/ioctl_userfaultfd.2.html) for more
8//! details.
9
10mod builder;
11mod error;
12mod event;
13mod raw;
14
15pub use crate::builder::{FeatureFlags, UffdBuilder};
16pub use crate::error::{Error, Result};
17pub use crate::event::{Event, FaultKind, ReadWrite};
18
19use bitflags::bitflags;
20use libc::{self, c_void};
21use nix::errno::Errno;
22use nix::unistd::read;
23use std::mem;
24use std::os::fd::{AsFd, BorrowedFd};
25use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
26
27/// Represents an opaque buffer where userfaultfd events are stored.
28///
29/// This is used in conjunction with [`Uffd::read_events`].
30pub struct EventBuffer(Vec<raw::uffd_msg>);
31
32impl EventBuffer {
33    /// Creates a new buffer for `size` number of events.
34    ///
35    /// [`Uffd::read_events`] will read up to this many events at a time.
36    pub fn new(size: usize) -> Self {
37        Self(vec![unsafe { mem::zeroed() }; size])
38    }
39}
40
/// The userfaultfd object.
///
/// The userspace representation of the object is a file descriptor, so this type implements
/// `AsFd`, `AsRawFd`, `FromRawFd`, and `IntoRawFd`. These should be used with caution, but can be
/// essential for using functions like `poll` on a worker thread.
///
/// A `Uffd` owns its descriptor: dropping it closes the descriptor, while `into_raw_fd` hands
/// ownership (and the responsibility to close it) to the caller.
#[derive(Debug)]
pub struct Uffd {
    /// The owned userfaultfd file descriptor.
    fd: RawFd,
}
50
51impl Drop for Uffd {
52    fn drop(&mut self) {
53        unsafe { libc::close(self.fd) };
54    }
55}
56
57impl AsFd for Uffd {
58    fn as_fd(&self) -> BorrowedFd<'_> {
59        unsafe { BorrowedFd::borrow_raw(self.as_raw_fd()) }
60    }
61}
62
63impl AsRawFd for Uffd {
64    fn as_raw_fd(&self) -> RawFd {
65        self.fd
66    }
67}
68
69impl IntoRawFd for Uffd {
70    fn into_raw_fd(self) -> RawFd {
71        self.fd
72    }
73}
74
75impl FromRawFd for Uffd {
76    unsafe fn from_raw_fd(fd: RawFd) -> Self {
77        Uffd { fd }
78    }
79}
80
bitflags! {
    /// The registration mode used when registering an address range with `Uffd`.
    ///
    /// Modes are bit flags and may be combined with `|`; `Uffd::register` uses `MISSING` alone,
    /// while `Uffd::register_with_mode` accepts any combination.
    #[derive(Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
    pub struct RegisterMode: u64 {
        /// Registers the range for missing page faults.
        const MISSING = raw::UFFDIO_REGISTER_MODE_MISSING;
        /// Registers the range for write faults on pages protected with `Uffd::write_protect`.
        ///
        /// Only available with the `linux5_7` feature.
        #[cfg(feature = "linux5_7")]
        const WRITE_PROTECT = raw::UFFDIO_REGISTER_MODE_WP;
    }
}
92
93impl Uffd {
94    /// Register a memory address range with the userfaultfd object, and returns the `IoctlFlags`
95    /// that are available for the selected range.
96    ///
97    /// This method only registers the given range for missing page faults.
98    pub fn register(&self, start: *mut c_void, len: usize) -> Result<IoctlFlags> {
99        self.register_with_mode(start, len, RegisterMode::MISSING)
100    }
101
102    /// Register a memory address range with the userfaultfd object for the given mode and
103    /// returns the `IoctlFlags` that are available for the selected range.
104    pub fn register_with_mode(
105        &self,
106        start: *mut c_void,
107        len: usize,
108        mode: RegisterMode,
109    ) -> Result<IoctlFlags> {
110        let mut register = raw::uffdio_register {
111            range: raw::uffdio_range {
112                start: start as u64,
113                len: len as u64,
114            },
115            mode: mode.bits(),
116            ioctls: 0,
117        };
118        unsafe {
119            raw::register(self.as_raw_fd(), &mut register as *mut raw::uffdio_register)?;
120        }
121        Ok(IoctlFlags::from_bits_retain(register.ioctls))
122    }
123
124    /// Unregister a memory address range from the userfaultfd object.
125    pub fn unregister(&self, start: *mut c_void, len: usize) -> Result<()> {
126        let mut range = raw::uffdio_range {
127            start: start as u64,
128            len: len as u64,
129        };
130        unsafe {
131            raw::unregister(self.as_raw_fd(), &mut range as *mut raw::uffdio_range)?;
132        }
133        Ok(())
134    }
135
136    /// Atomically copy a continuous memory chunk into the userfaultfd-registered range, and return
137    /// the number of bytes that were successfully copied.
138    ///
139    /// If `wake` is `true`, wake up the thread waiting for page fault resolution on the memory
140    /// range.
141    pub unsafe fn copy(
142        &self,
143        src: *const c_void,
144        dst: *mut c_void,
145        len: usize,
146        wake: bool,
147    ) -> Result<usize> {
148        let mut copy = raw::uffdio_copy {
149            src: src as u64,
150            dst: dst as u64,
151            len: len as u64,
152            mode: if wake {
153                0
154            } else {
155                raw::UFFDIO_COPY_MODE_DONTWAKE
156            },
157            copy: 0,
158        };
159
160        let _ =
161            raw::copy(self.as_raw_fd(), &mut copy as *mut raw::uffdio_copy).map_err(|errno| {
162                match errno {
163                    Errno::EAGAIN => Error::PartiallyCopied(copy.copy as usize),
164                    _ => Error::CopyFailed(errno),
165                }
166            })?;
167        if copy.copy < 0 {
168            // shouldn't ever get here, as errno should be caught above
169            Err(Error::CopyFailed(Errno::from_i32(-copy.copy as i32)))
170        } else {
171            Ok(copy.copy as usize)
172        }
173    }
174
175    /// Zero out a memory address range registered with userfaultfd, and return the number of bytes
176    /// that were successfully zeroed.
177    ///
178    /// If `wake` is `true`, wake up the thread waiting for page fault resolution on the memory
179    /// address range.
180    pub unsafe fn zeropage(&self, start: *mut c_void, len: usize, wake: bool) -> Result<usize> {
181        let mut zeropage = raw::uffdio_zeropage {
182            range: raw::uffdio_range {
183                start: start as u64,
184                len: len as u64,
185            },
186            mode: if wake {
187                0
188            } else {
189                raw::UFFDIO_ZEROPAGE_MODE_DONTWAKE
190            },
191            zeropage: 0,
192        };
193
194        let _ = raw::zeropage(self.as_raw_fd(), &mut zeropage as &mut raw::uffdio_zeropage)
195            .map_err(Error::ZeropageFailed)?;
196        if zeropage.zeropage < 0 {
197            // shouldn't ever get here, as errno should be caught above
198            Err(Error::ZeropageFailed(Errno::from_i32(
199                -zeropage.zeropage as i32,
200            )))
201        } else {
202            Ok(zeropage.zeropage as usize)
203        }
204    }
205
206    /// Wake up the thread waiting for page fault resolution on the specified memory address range.
207    pub fn wake(&self, start: *mut c_void, len: usize) -> Result<()> {
208        let mut range = raw::uffdio_range {
209            start: start as u64,
210            len: len as u64,
211        };
212        unsafe {
213            raw::wake(self.as_raw_fd(), &mut range as *mut raw::uffdio_range)?;
214        }
215        Ok(())
216    }
217
218    /// Makes a range write-protected.
219    #[cfg(feature = "linux5_7")]
220    pub fn write_protect(&self, start: *mut c_void, len: usize) -> Result<()> {
221        let mut ioctl = raw::uffdio_writeprotect {
222            range: raw::uffdio_range {
223                start: start as u64,
224                len: len as u64,
225            },
226            mode: raw::UFFDIO_WRITEPROTECT_MODE_WP,
227        };
228
229        unsafe {
230            raw::write_protect(
231                self.as_raw_fd(),
232                &mut ioctl as *mut raw::uffdio_writeprotect,
233            )?;
234        }
235
236        Ok(())
237    }
238
239    /// Removes the write-protection for a range.
240    ///
241    /// If `wake` is `true`, wake up the thread waiting for page fault resolution on the memory
242    /// address range.
243    #[cfg(feature = "linux5_7")]
244    pub fn remove_write_protection(
245        &self,
246        start: *mut c_void,
247        len: usize,
248        wake: bool,
249    ) -> Result<()> {
250        let mut ioctl = raw::uffdio_writeprotect {
251            range: raw::uffdio_range {
252                start: start as u64,
253                len: len as u64,
254            },
255            mode: if wake {
256                0
257            } else {
258                raw::UFFDIO_WRITEPROTECT_MODE_DONTWAKE
259            },
260        };
261
262        unsafe {
263            raw::write_protect(
264                self.as_raw_fd(),
265                &mut ioctl as *mut raw::uffdio_writeprotect,
266            )?;
267        }
268
269        Ok(())
270    }
271
272    /// Read an `Event` from the userfaultfd object.
273    ///
274    /// If the `Uffd` object was created with `non_blocking` set to `false`, this will block until
275    /// an event is successfully read (returning `Some(event)`, or an error is returned.
276    ///
277    /// If `non_blocking` was `true`, this will immediately return `None` if no event is ready to
278    /// read.
279    ///
280    /// Note that while this method doesn't require a mutable reference to the `Uffd` object, it
281    /// does consume bytes (thread-safely) from the underlying file descriptor.
282    ///
283    /// # Examples
284    ///
285    /// ```rust
286    /// # use ureeves_userfaultfd::{Uffd, Result};
287    /// fn read_event(uffd: &Uffd) -> Result<()> {
288    ///     // Read a single event
289    ///     match uffd.read_event()? {
290    ///         Some(e) => {
291    ///             // Do something with the event
292    ///         },
293    ///         None => {
294    ///             // This was a non-blocking read and the descriptor was not ready for read
295    ///         },
296    ///     }
297    ///     Ok(())
298    /// }
299    /// ```
300    pub fn read_event(&self) -> Result<Option<Event>> {
301        let mut buf = [unsafe { std::mem::zeroed() }; 1];
302        let mut iter = self.read(&mut buf)?;
303        let event = iter.next().transpose()?;
304        assert!(iter.next().is_none());
305        Ok(event)
306    }
307
308    /// Read multiple events from the userfaultfd object using the given event buffer.
309    ///
310    /// If the `Uffd` object was created with `non_blocking` set to `false`, this will block until
311    /// an event is successfully read or an error is returned.
312    ///
313    /// If `non_blocking` was `true`, this will immediately return an empty iterator if the file
314    /// descriptor is not ready for reading.
315    ///
316    /// # Examples
317    ///
318    /// ```rust
319    /// # use ureeves_userfaultfd::{Uffd, EventBuffer};
320    /// fn read_events(uffd: &Uffd) -> ureeves_userfaultfd::Result<()> {
321    ///     // Read up to 100 events at a time
322    ///     let mut buf = EventBuffer::new(100);
323    ///     for event in uffd.read_events(&mut buf)? {
324    ///         let event = event?;
325    ///         // Do something with the event...
326    ///     }
327    ///     Ok(())
328    /// }
329    /// ```
330    pub fn read_events<'a>(
331        &self,
332        buf: &'a mut EventBuffer,
333    ) -> Result<impl Iterator<Item = Result<Event>> + 'a> {
334        self.read(&mut buf.0)
335    }
336
337    fn read<'a>(
338        &self,
339        msgs: &'a mut [raw::uffd_msg],
340    ) -> Result<impl Iterator<Item = Result<Event>> + 'a> {
341        const MSG_SIZE: usize = std::mem::size_of::<raw::uffd_msg>();
342
343        let buf = unsafe {
344            std::slice::from_raw_parts_mut(msgs.as_mut_ptr() as _, msgs.len() * MSG_SIZE)
345        };
346
347        let count = match read(self.as_raw_fd(), buf) {
348            Err(e) if e == Errno::EAGAIN => 0,
349            Err(e) => return Err(Error::SystemError(e)),
350            Ok(0) => return Err(Error::ReadEof),
351            Ok(bytes_read) => {
352                let remainder = bytes_read % MSG_SIZE;
353                if remainder != 0 {
354                    return Err(Error::IncompleteMsg {
355                        read: remainder,
356                        expected: MSG_SIZE,
357                    });
358                }
359
360                bytes_read / MSG_SIZE
361            }
362        };
363
364        Ok(msgs.iter().take(count).map(|msg| Event::from_uffd_msg(msg)))
365    }
366}
367
bitflags! {
    /// Used with `UffdBuilder` and `Uffd::register()` to determine which operations are available.
    ///
    /// Each flag is the bit `1 << n`, where `n` is the matching `raw::_UFFDIO_*` constant.
    #[derive(Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
    pub struct IoctlFlags: u64 {
        /// The `UFFDIO_REGISTER` operation is available.
        const REGISTER = 1 << raw::_UFFDIO_REGISTER;
        /// The `UFFDIO_UNREGISTER` operation is available.
        const UNREGISTER = 1 << raw::_UFFDIO_UNREGISTER;
        /// The `UFFDIO_WAKE` operation is available.
        const WAKE = 1 << raw::_UFFDIO_WAKE;
        /// The `UFFDIO_COPY` operation is available.
        const COPY = 1 << raw::_UFFDIO_COPY;
        /// The `UFFDIO_ZEROPAGE` operation is available.
        const ZEROPAGE = 1 << raw::_UFFDIO_ZEROPAGE;
        /// The `UFFDIO_WRITEPROTECT` operation is available (requires the `linux5_7` feature).
        #[cfg(feature = "linux5_7")]
        const WRITE_PROTECT = 1 << raw::_UFFDIO_WRITEPROTECT;
        /// The `UFFDIO_API` operation is available.
        const API = 1 << raw::_UFFDIO_API;

        /// Unknown ioctls flags are allowed to be robust to future kernel changes.
        const _ = !0;
    }
}
385
#[cfg(test)]
mod test {
    use super::*;
    use std::ptr;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    /// Returns the system page size.
    ///
    /// userfaultfd ranges must be page-aligned, so the tests must not assume 4 KiB pages.
    fn page_size() -> usize {
        // SAFETY: `sysconf` has no preconditions.
        let size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
        assert!(size > 0, "sysconf(_SC_PAGESIZE) failed");
        size as usize
    }

    /// Maps `len` bytes of private, anonymous, read/write memory for a test to register.
    ///
    /// Panics if `mmap` fails. `mmap` reports failure with `MAP_FAILED`, not a null pointer.
    fn map_anonymous(len: usize) -> *mut libc::c_void {
        // SAFETY: a fresh anonymous mapping at a kernel-chosen address touches no existing memory.
        let mapping = unsafe {
            libc::mmap(
                ptr::null_mut(),
                len,
                libc::PROT_READ | libc::PROT_WRITE,
                libc::MAP_PRIVATE | libc::MAP_ANON,
                -1,
                0,
            )
        };
        assert_ne!(mapping, libc::MAP_FAILED, "mmap of {} bytes failed", len);
        mapping
    }

    #[test]
    fn test_read_event() -> Result<()> {
        let page_size = page_size();

        unsafe {
            let uffd = UffdBuilder::new().close_on_exec(true).create()?;
            let mapping = map_anonymous(page_size);

            uffd.register(mapping, page_size)?;

            // The faulting write must happen on another thread, since this one services it.
            let ptr = mapping as usize;
            let faulter = thread::spawn(move || {
                let ptr = ptr as *mut u8;
                *ptr = 1;
            });

            match uffd.read_event()? {
                Some(Event::Pagefault {
                    rw: ReadWrite::Write,
                    addr,
                    ..
                }) => {
                    assert_eq!(addr, mapping);
                    uffd.zeropage(addr, page_size, true)?;
                }
                _ => panic!("unexpected event"),
            }

            faulter.join().expect("failed to join thread");

            uffd.unregister(mapping, page_size)?;

            assert_eq!(libc::munmap(mapping, page_size), 0);
        }

        Ok(())
    }

    #[test]
    fn test_nonblocking_read_event() -> Result<()> {
        let page_size = page_size();

        unsafe {
            let uffd = UffdBuilder::new()
                .close_on_exec(true)
                .non_blocking(true)
                .create()?;
            let mapping = map_anonymous(page_size);

            uffd.register(mapping, page_size)?;

            assert!(uffd.read_event()?.is_none());

            let ptr = mapping as usize;
            let faulter = thread::spawn(move || {
                let ptr = ptr as *mut u8;
                *ptr = 1;
            });

            loop {
                match uffd.read_event()? {
                    Some(Event::Pagefault {
                        rw: ReadWrite::Write,
                        addr,
                        ..
                    }) => {
                        assert_eq!(addr, mapping);
                        uffd.zeropage(addr, page_size, true)?;
                        break;
                    }
                    Some(_) => panic!("unexpected event"),
                    None => thread::sleep(Duration::from_millis(50)),
                }
            }

            faulter.join().expect("failed to join thread");

            uffd.unregister(mapping, page_size)?;

            assert_eq!(libc::munmap(mapping, page_size), 0);
        }

        Ok(())
    }

    #[test]
    fn test_read_events() -> Result<()> {
        const MAX_THREADS: usize = 5;
        let page_size = page_size();
        let mem_size = page_size * MAX_THREADS;

        unsafe {
            let uffd = UffdBuilder::new().close_on_exec(true).create()?;
            let mapping = map_anonymous(mem_size);

            uffd.register(mapping, mem_size)?;

            // As accessing the memory will suspend each thread with a page fault event,
            // there is no way to signal that the operations the test thread is waiting on to
            // complete have been performed.
            //
            // Therefore, this is inherently racy. The best we can do is simply sleep-wait for
            // all threads to have signaled that the operation is *about to be performed*.
            // The flags are atomics because they are written and read from different threads.
            let ready: Arc<[AtomicBool; MAX_THREADS]> = Arc::new(Default::default());
            let mut threads = Vec::with_capacity(MAX_THREADS);
            for i in 0..MAX_THREADS {
                let ready = Arc::clone(&ready);
                let ptr = (mapping as *mut u8).add(page_size * i) as usize;
                threads.push(thread::spawn(move || {
                    ready[i].store(true, Ordering::SeqCst);
                    *(ptr as *mut u8) = 1;
                }));
            }

            loop {
                // Sleep even if all threads have "signaled", just in case any
                // thread is preempted prior to faulting the memory access.
                // Still, there's no guarantee that the call to `read_events` below will
                // read all the events at once, but this should be "good enough".
                let done = ready.iter().all(|flag| flag.load(Ordering::SeqCst));
                thread::sleep(Duration::from_millis(50));
                if done {
                    break;
                }
            }

            // Read all the events at once
            let mut buf = EventBuffer::new(MAX_THREADS);
            let mut iter = uffd.read_events(&mut buf)?;

            let mut handled = [false; MAX_THREADS];
            for _ in 0..MAX_THREADS {
                match iter
                    .next()
                    .transpose()?
                    .expect("failed to read all events; potential race condition was hit")
                {
                    Event::Pagefault {
                        rw: ReadWrite::Write,
                        addr,
                        ..
                    } => {
                        let index = (addr as usize - mapping as usize) / page_size;
                        assert!(!handled[index], "duplicate fault for page {}", index);
                        handled[index] = true;
                        uffd.zeropage(addr, page_size, true)?;
                    }
                    _ => panic!("unexpected event"),
                }
            }

            assert!(handled.iter().all(|&b| b));

            for faulter in threads {
                faulter.join().expect("failed to join thread");
            }

            uffd.unregister(mapping, mem_size)?;

            assert_eq!(libc::munmap(mapping, mem_size), 0);
        }

        Ok(())
    }

    #[cfg(feature = "linux5_7")]
    #[test]
    fn test_write_protect() -> Result<()> {
        let page_size = page_size();

        unsafe {
            let uffd = UffdBuilder::new()
                .require_features(FeatureFlags::PAGEFAULT_FLAG_WP)
                .close_on_exec(true)
                .create()?;
            let mapping = map_anonymous(page_size);

            // This test uses both missing and write-protect modes for a reason.
            // The `uffdio_writeprotect` ioctl can only be used on a range *after*
            // the missing fault is handled, it seems. This means we either need to
            // read/write the page *before* we protect it or handle the missing
            // page fault by changing the protection level *after* we zero the page.
            assert!(uffd
                .register_with_mode(
                    mapping,
                    page_size,
                    RegisterMode::MISSING | RegisterMode::WRITE_PROTECT
                )?
                .contains(IoctlFlags::WRITE_PROTECT));

            let ptr = mapping as usize;
            let faulter = thread::spawn(move || {
                let ptr = ptr as *mut u8;
                *ptr = 1;
                *ptr = 2;
            });

            loop {
                match uffd.read_event()? {
                    Some(Event::Pagefault {
                        kind,
                        rw: ReadWrite::Write,
                        addr,
                        ..
                    }) => match kind {
                        FaultKind::WriteProtected => {
                            assert_eq!(addr, mapping);
                            assert_eq!(*(addr as *const u8), 0);
                            // Remove the protection and wake the page
                            uffd.remove_write_protection(mapping, page_size, true)?;
                            break;
                        }
                        FaultKind::Missing => {
                            assert_eq!(addr, mapping);
                            uffd.zeropage(mapping, page_size, false)?;

                            // Technically, we already know it was a write that triggered
                            // the missing page fault, so there's little point in immediately
                            // write-protecting the page to cause another fault; in the real
                            // world, a missing fault with `rw` being `ReadWrite::Write` would
                            // be enough to mark the page as "dirty". For this test, however,
                            // we do it this way to ensure a write-protected fault is read.
                            assert_eq!(*(addr as *const u8), 0);
                            uffd.write_protect(mapping, page_size)?;
                            uffd.wake(mapping, page_size)?;
                        }
                    },
                    _ => panic!("unexpected event"),
                }
            }

            faulter.join().expect("failed to join thread");

            assert_eq!(*(mapping as *const u8), 2);

            uffd.unregister(mapping, page_size)?;

            assert_eq!(libc::munmap(mapping, page_size), 0);
        }

        Ok(())
    }
}