#![deny(unsafe_op_in_unsafe_fn)]
use std::{
cell::Cell,
mem,
num::NonZeroUsize,
os::unix::io::{AsFd, BorrowedFd, OwnedFd},
ptr,
sync::{
mpsc::{channel, Sender},
LazyLock, OnceLock, RwLock,
},
thread,
};
use rustix::mm;
use tracing::{debug, instrument, trace};
/// Channel to a dedicated background thread that drops `InnerPool`s.
///
/// `Drop for Pool` sends its inner pool here, so the teardown (unmapping and closing
/// the fd) happens off the caller's thread. The thread is spawned lazily on first use
/// and exits once every sender is gone.
static DROP_THIS: LazyLock<Sender<InnerPool>> = LazyLock::new(|| {
    let (sender, receiver) = channel::<InnerPool>();
    let builder = thread::Builder::new().name("Shm dropping thread".to_owned());
    builder
        .spawn(move || {
            // Iterating a `Receiver` yields messages until the channel disconnects.
            for pool in receiver {
                profiling::scope!("dropping Pool");
                drop(pool);
            }
        })
        .unwrap();
    sender
});
thread_local! {
    /// Per-thread SIGBUS bookkeeping: the `MemMap` whose contents are currently being
    /// accessed on this thread (null when none), and whether a SIGBUS was caught
    /// inside that mapping during the access.
    static SIGBUS_GUARD: Cell<(*const MemMap, bool)> = const { Cell::new((ptr::null(), false)) };
}
/// The SIGBUS disposition that was in place before `place_sigbus_handler` installed
/// `sigbus_handler`; `reraise_sigbus` restores it.
static OLD_SIGBUS_HANDLER: OnceLock<libc::sigaction> = OnceLock::new();
/// Public handle to a shared-memory pool.
///
/// `inner` is an `Option` only so that `Drop` can move the pool out and send it to the
/// background dropping thread; `Pool::new` always sets it to `Some`.
#[derive(Debug)]
pub struct Pool {
    inner: Option<InnerPool>,
}
/// Pool state: the current memory mapping plus the file descriptor it maps.
#[derive(Debug)]
struct InnerPool {
    // Current mapping. The write lock is taken by `resize` and `with_data_mut`,
    // the read lock by `size` and `with_data`.
    map: RwLock<MemMap>,
    // Backing file descriptor; kept so the pool can be mapped again on resize.
    fd: OwnedFd,
}
// SAFETY: `InnerPool` is only `!Send`/`!Sync` because `MemMap` stores a raw pointer.
// That pointer refers to a process-wide shared mapping rather than thread-local data,
// and it is only dereferenced or handed out while the `RwLock` around the `MemMap` is
// held (or on drop, with exclusive ownership).
// NOTE(review): the pointers given to `with_data*` closures can escape them; callers
// must not use them after the closure returns.
unsafe impl Send for InnerPool {}
unsafe impl Sync for InnerPool {}
/// Reasons why [`Pool::resize`] can fail.
///
/// Derives `Debug` so callers can `unwrap`/`expect` a `Result<(), ResizeError>`, and
/// `PartialEq`/`Copy` so the variants can be compared and passed around freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResizeError {
    /// The requested size is smaller than the pool's current size.
    InvalidSize,
    /// Mapping the backing file descriptor at the new size failed.
    MremapFailed,
}
impl InnerPool {
    /// Maps `size` bytes of `fd` as a shared read/write region and wraps it in a pool.
    ///
    /// # Errors
    /// Returns `fd` back, unchanged, if the mapping cannot be created.
    #[instrument(level = "trace", skip_all, name = "wayland_shm")]
    pub fn new(fd: OwnedFd, size: NonZeroUsize) -> Result<InnerPool, OwnedFd> {
        let Ok(memmap) = MemMap::new(fd.as_fd(), size) else {
            return Err(fd);
        };
        trace!(fd = ?fd, size = ?size, "Creating new shm pool");
        Ok(InnerPool {
            map: RwLock::new(memmap),
            fd,
        })
    }

    /// Grows the pool to `newsize` bytes by mapping the file descriptor again.
    ///
    /// # Errors
    /// - [`ResizeError::InvalidSize`] if `newsize` is smaller than the current size.
    /// - [`ResizeError::MremapFailed`] if the new mapping cannot be created. In that
    ///   case the old mapping is already gone and the pool is left empty (size 0).
    pub fn resize(&self, newsize: NonZeroUsize) -> Result<(), ResizeError> {
        let mut guard = self.map.write().unwrap();
        let oldsize = guard.size();
        if oldsize > usize::from(newsize) {
            return Err(ResizeError::InvalidSize);
        }
        trace!(fd = ?self.fd, oldsize = oldsize, newsize = ?newsize, "Resizing shm pool");
        guard.remap(self.fd.as_fd(), newsize).map_err(|()| {
            debug!(fd = ?self.fd, oldsize = oldsize, newsize = ?newsize, "SHM pool resize failed");
            ResizeError::MremapFailed
        })
    }

    /// Current size of the mapping in bytes (0 after a failed resize).
    pub fn size(&self) -> usize {
        self.map.read().unwrap().size()
    }

    /// Runs `f` with a read-only pointer to the pool contents and their length.
    ///
    /// # Errors
    /// Returns `Err(())` if a SIGBUS was caught inside the pool while `f` ran; the
    /// pool's pages have then been replaced by anonymous memory.
    ///
    /// # Panics
    /// If called while another pool access is already in progress on this thread.
    #[instrument(level = "trace", skip_all, name = "wayland_shm")]
    pub fn with_data<T, F: FnOnce(*const u8, usize) -> T>(&self, f: F) -> Result<T, ()> {
        // SAFETY: installs the process-wide SIGBUS handler at most once (guarded by a
        // `OnceLock`); no other preconditions.
        unsafe { place_sigbus_handler() };
        let pool_guard = self.map.read().unwrap();
        trace!(fd = ?self.fd, "Buffer access on shm pool");
        self.access_guarded(&*pool_guard, || f(pool_guard.ptr as *const _, pool_guard.size))
    }

    /// Runs `f` with a mutable pointer to the pool contents and their length.
    ///
    /// # Errors
    /// Returns `Err(())` if a SIGBUS was caught inside the pool while `f` ran; the
    /// pool's pages have then been replaced by anonymous memory.
    ///
    /// # Panics
    /// If called while another pool access is already in progress on this thread.
    #[instrument(level = "trace", skip_all, name = "wayland_shm")]
    pub fn with_data_mut<T, F: FnOnce(*mut u8, usize) -> T>(&self, f: F) -> Result<T, ()> {
        // SAFETY: installs the process-wide SIGBUS handler at most once (guarded by a
        // `OnceLock`); no other preconditions.
        unsafe { place_sigbus_handler() };
        // The guard itself is only read, but the write lock is what gives `f` exclusive
        // access to the contents while it holds a `*mut` pointer.
        #[allow(clippy::readonly_write_lock)]
        let pool_guard = self.map.write().unwrap();
        trace!(fd = ?self.fd, "Mutable buffer access on shm pool");
        self.access_guarded(&*pool_guard, || f(pool_guard.ptr, pool_guard.size))
    }

    /// Registers `map` in this thread's `SIGBUS_GUARD`, runs `access`, then unregisters
    /// it and reports whether a SIGBUS hit the mapping in the meantime.
    ///
    /// The guard is cleared by an RAII helper, so it is also cleared if `access` panics.
    /// Otherwise a stale pointer to `map` would stay registered: later accesses on this
    /// thread would panic as "recursive", and the SIGBUS handler could dereference a
    /// dangling `MemMap`.
    fn access_guarded<T>(&self, map: &MemMap, access: impl FnOnce() -> T) -> Result<T, ()> {
        /// Clears this thread's `SIGBUS_GUARD` when dropped, including during unwinding.
        struct ClearGuardOnDrop;
        impl Drop for ClearGuardOnDrop {
            fn drop(&mut self) {
                SIGBUS_GUARD.with(|guard| guard.set((ptr::null(), false)));
            }
        }

        SIGBUS_GUARD.with(|guard| {
            let (p, _) = guard.get();
            if !p.is_null() {
                panic!("Recursive access to a SHM pool content is not supported.");
            }
            guard.set((map as *const MemMap, false))
        });
        // Created only after registration succeeded, so a recursive-access panic above
        // leaves the outer access's registration intact.
        let clear_on_exit = ClearGuardOnDrop;

        let t = access();

        // Read the flag before the guard is cleared.
        let (_, triggered) = SIGBUS_GUARD.with(|guard| guard.get());
        drop(clear_on_exit);
        if triggered {
            debug!(fd = ?self.fd, "SIGBUS caught on access on shm pool");
            Err(())
        } else {
            Ok(t)
        }
    }
}
impl Pool {
pub fn new(fd: OwnedFd, size: NonZeroUsize) -> Result<Self, OwnedFd> {
InnerPool::new(fd, size).map(|p| Self { inner: Some(p) })
}
pub fn resize(&self, newsize: NonZeroUsize) -> Result<(), ResizeError> {
self.inner.as_ref().unwrap().resize(newsize)
}
pub fn size(&self) -> usize {
self.inner.as_ref().unwrap().size()
}
pub fn with_data<T, F: FnOnce(*const u8, usize) -> T>(&self, f: F) -> Result<T, ()> {
self.inner.as_ref().unwrap().with_data(f)
}
pub fn with_data_mut<T, F: FnOnce(*mut u8, usize) -> T>(&self, f: F) -> Result<T, ()> {
self.inner.as_ref().unwrap().with_data_mut(f)
}
}
impl Drop for Pool {
    /// Hands the inner pool to the background dropping thread, presumably to keep the
    /// unmap/close work off the caller's thread.
    fn drop(&mut self) {
        let inner = self.inner.take().unwrap();
        // If the dropping thread is gone, the send error carries the pool back and it
        // is dropped right here instead.
        if let Err(unsent) = DROP_THIS.send(inner) {
            drop(unsent);
        }
    }
}
/// A raw `mmap`ed region.
///
/// Normally `ptr` is the start of a mapping of `size` bytes. After a failed `remap`,
/// `ptr` is null and `size` is 0. The mapping (if any) is unmapped on drop.
#[derive(Debug)]
struct MemMap {
    ptr: *mut u8,
    size: usize,
}
impl MemMap {
    /// Maps `size` bytes of `fd` as a shared read/write region.
    fn new(fd: BorrowedFd<'_>, size: NonZeroUsize) -> Result<MemMap, ()> {
        // SAFETY: `map` requests a fresh mapping (null address hint), so it cannot
        // overlap any existing Rust allocation.
        let ptr = unsafe { map(fd, size) }?;
        Ok(MemMap {
            ptr,
            size: size.get(),
        })
    }

    /// Replaces the mapping with a new mapping of `newsize` bytes of `fd`.
    ///
    /// The old mapping is unmapped first, best-effort (errors are ignored). If the new
    /// mapping fails, `self` is left empty (null `ptr`, `size` 0), and every later
    /// `remap` fails too.
    fn remap(&mut self, fd: BorrowedFd<'_>, newsize: NonZeroUsize) -> Result<(), ()> {
        if self.ptr.is_null() {
            return Err(());
        }
        // SAFETY: `ptr`/`size` describe the live mapping created by `map` (non-null
        // checked above), and `&mut self` means no shared borrow of this map exists.
        let _ = unsafe { unmap(self.ptr, self.size) };
        // SAFETY: fresh mapping with a null address hint; nothing can alias it yet.
        match unsafe { map(fd, newsize) } {
            Ok(ptr) => {
                self.ptr = ptr;
                self.size = usize::from(newsize);
                Ok(())
            }
            Err(()) => {
                self.ptr = ptr::null_mut();
                self.size = 0;
                Err(())
            }
        }
    }

    /// Size of the mapping in bytes (0 when empty).
    fn size(&self) -> usize {
        self.size
    }

    /// Whether `ptr` lies inside `[self.ptr, self.ptr + self.size)`.
    ///
    /// This runs inside the SIGBUS handler. `wrapping_add` computes the same end address
    /// as `add`, but carries no in-bounds precondition, so it is well-defined even for
    /// the null/size-0 state left by a failed `remap`, and needs no `unsafe`.
    fn contains(&self, ptr: *mut u8) -> bool {
        let end = self.ptr.wrapping_add(self.size);
        ptr >= self.ptr && ptr < end
    }

    /// Replaces the pages of this mapping with private anonymous memory at the same
    /// address, so later reads and writes no longer touch the backing file.
    fn nullify(&self) -> Result<(), ()> {
        // SAFETY: `ptr`/`size` describe this struct's own mapping. Overwriting it with
        // `MAP_FIXED` keeps the address range valid for any outstanding pointers.
        unsafe { nullify_map(self.ptr, self.size) }
    }
}
impl Drop for MemMap {
    /// Unmaps the region, best-effort; an empty map (after a failed remap) is skipped.
    fn drop(&mut self) {
        if self.ptr.is_null() {
            return;
        }
        // SAFETY: non-null `ptr`/`size` always describe a mapping created by `map`,
        // and we own it exclusively during drop.
        let _ = unsafe { unmap(self.ptr, self.size) };
    }
}
/// Maps `size` bytes of `fd`, from offset 0, as a shared read/write region at an
/// address chosen by the kernel.
///
/// # Safety
/// The caller becomes responsible for the returned mapping and must release it with
/// `unmap` using the same `size`.
unsafe fn map(fd: BorrowedFd<'_>, size: NonZeroUsize) -> Result<*mut u8, ()> {
    let protection = mm::ProtFlags::READ | mm::ProtFlags::WRITE;
    // SAFETY: a null address hint lets the kernel pick a fresh range, so the new
    // mapping cannot clobber existing memory.
    let mapped = unsafe {
        mm::mmap(
            ptr::null_mut(),
            size.get(),
            protection,
            mm::MapFlags::SHARED,
            fd,
            0,
        )
    };
    match mapped {
        Ok(addr) => Ok(addr.cast::<u8>()),
        Err(_) => Err(()),
    }
}
/// Unmaps `size` bytes starting at `ptr`.
///
/// # Safety
/// `ptr`/`size` must describe a mapping created by `map`, and nothing may access it
/// afterwards.
#[profiling::function]
unsafe fn unmap(ptr: *mut u8, size: usize) -> Result<(), ()> {
    // SAFETY: forwarded from this function's contract.
    unsafe { mm::munmap(ptr.cast(), size) }.map_err(|_| ())
}
/// Replaces `size` bytes at `ptr` with a private, anonymous read/write mapping at the
/// same address (`MAP_FIXED`).
///
/// # Safety
/// `ptr`/`size` must describe a mapping owned by the caller. Whatever was mapped there
/// is discarded.
unsafe fn nullify_map(ptr: *mut u8, size: usize) -> Result<(), ()> {
    let target = ptr.cast::<std::ffi::c_void>();
    let protection = mm::ProtFlags::READ | mm::ProtFlags::WRITE;
    let flags = mm::MapFlags::PRIVATE | mm::MapFlags::FIXED;
    // SAFETY: forwarded from this function's contract; `MAP_FIXED` only replaces the
    // caller-owned range.
    match unsafe { mm::mmap_anonymous(target, size, protection, flags) } {
        Ok(_) => Ok(()),
        Err(_) => Err(()),
    }
}
/// Installs `sigbus_handler` as the process-wide SIGBUS handler, once.
///
/// The previous disposition is saved in `OLD_SIGBUS_HANDLER` so that `reraise_sigbus`
/// can restore it. Later calls are no-ops.
///
/// # Panics
/// If `sigaction` fails.
///
/// # Safety
/// Replaces the process's SIGBUS disposition; any handler installed before the first
/// call is only reachable again through `reraise_sigbus`.
unsafe fn place_sigbus_handler() {
    let _ = OLD_SIGBUS_HANDLER.get_or_init(|| {
        // SAFETY: `libc::sigaction` is a plain C struct for which all-zero is a valid
        // value (empty mask, no flags).
        let mut new_action: libc::sigaction = unsafe { mem::zeroed() };
        new_action.sa_sigaction = sigbus_handler as _;
        // SA_SIGINFO: the handler needs `siginfo_t` to read the faulting address.
        // SA_NODEFER: keep SIGBUS unblocked while the handler runs.
        new_action.sa_flags = libc::SA_SIGINFO | libc::SA_NODEFER;

        // SAFETY: same as above; overwritten by the kernel on success.
        let mut previous: libc::sigaction = unsafe { mem::zeroed() };
        // SAFETY: both pointers refer to live, properly initialised `sigaction` values.
        let rc = unsafe { libc::sigaction(libc::SIGBUS, &new_action, &mut previous) };
        if rc == -1 {
            let e = rustix::io::Errno::from_raw_os_error(errno::errno().0);
            panic!("sigaction failed for SIGBUS handler: {:?}", e);
        }
        previous
    });
}
/// Restores the SIGBUS disposition saved by `place_sigbus_handler` and raises SIGBUS
/// again, so the previous handler (or the default action) handles the fault.
///
/// # Safety
/// `place_sigbus_handler` must have run before this is called (the saved disposition is
/// unwrapped). It is meant to be called from `sigbus_handler`.
unsafe fn reraise_sigbus() {
    let saved = OLD_SIGBUS_HANDLER.get().unwrap();
    // SAFETY: `saved` is a `sigaction` previously returned by the kernel; a null
    // `oldact` pointer is permitted.
    unsafe { libc::sigaction(libc::SIGBUS, saved, ptr::null_mut()) };
    // SAFETY: `raise` only takes a signal number; it has no pointer arguments.
    unsafe { libc::raise(libc::SIGBUS) };
}
/// Process-wide SIGBUS handler installed by `place_sigbus_handler`.
///
/// If this thread has a pool access registered in `SIGBUS_GUARD` and the faulting
/// address lies inside that pool's mapping, the handler does three things:
/// - records the fault in the guard's flag, which makes the accessor return `Err(())`;
/// - replaces the pool's pages with anonymous memory (`nullify`), presumably so that
///   the faulting instruction succeeds when it is re-executed after the handler
///   returns;
/// - if nullifying fails, falls back to `reraise_sigbus`.
///
/// In every other case, the previous SIGBUS disposition is restored and the signal is
/// re-raised.
extern "C" fn sigbus_handler(_signum: libc::c_int, info: *mut libc::siginfo_t, _context: *mut libc::c_void) {
    // SAFETY: the kernel passes a valid `siginfo_t` to SA_SIGINFO handlers.
    let faulty_ptr = unsafe { siginfo_si_addr(info) } as *mut u8;
    SIGBUS_GUARD.with(|guard| {
        let (memmap, _) = guard.get();
        // SAFETY: the guard holds either null or a pointer to the `MemMap` of the access
        // currently in progress on this thread.
        // NOTE(review): this relies on the guard being cleared when that access ends.
        match unsafe { memmap.as_ref() }.map(|m| (m, m.contains(faulty_ptr))) {
            Some((m, true)) => {
                // Fault inside the pool being accessed: remember it for the accessor.
                guard.set((memmap, true));
                if m.nullify().is_err() {
                    // Cannot make the range accessible again; hand the fault on.
                    unsafe { reraise_sigbus() }
                }
            }
            _ => {
                // No pool access on this thread, or the fault is outside it.
                unsafe { reraise_sigbus() }
            }
        }
    });
}
/// Reads the `si_addr` field of a `siginfo_t` (Linux/Android).
///
/// Reinterprets the start of the struct through a minimal `repr(C)` mirror: three
/// leading `c_int` fields, then the address pointer. `repr(C)` adds any alignment
/// padding before the pointer field, as C would.
/// NOTE(review): this assumes the kernel places `si_addr` right after those three ints
/// on every supported architecture, presumably because `libc::siginfo_t` does not
/// expose `si_addr` as a plain field on these targets. TODO confirm per arch.
///
/// # Safety
/// `info` must point to a valid `siginfo_t` for a signal that fills in `si_addr`, such
/// as the SIGBUS passed to `sigbus_handler`.
#[cfg(any(target_os = "linux", target_os = "android"))]
unsafe fn siginfo_si_addr(info: *mut libc::siginfo_t) -> *mut libc::c_void {
    #[repr(C)]
    #[allow(non_camel_case_types)]
    struct siginfo_t {
        // Leading three int fields (presumably si_signo, si_errno, si_code).
        a: [libc::c_int; 3],
        si_addr: *mut libc::c_void,
    }
    // SAFETY: forwarded from this function's contract; the mirror only covers a prefix
    // of the real struct.
    unsafe { (*(info as *const siginfo_t)).si_addr }
}
/// Reads the `si_addr` field of a `siginfo_t` on targets where `libc` exposes it as a
/// plain struct field.
///
/// # Safety
/// `info` must point to a valid `siginfo_t` for a signal that fills in `si_addr`.
#[cfg(not(any(target_os = "linux", target_os = "android")))]
unsafe fn siginfo_si_addr(info: *mut libc::siginfo_t) -> *mut libc::c_void {
    // SAFETY: forwarded from this function's contract.
    unsafe { (*info).si_addr as _ }
}