extern crate interprocess_traits;
extern crate libc;
extern crate memfd;
extern crate thiserror;
#[cfg(test)]
extern crate sendfd;
use std::{
io, mem,
ops::Deref,
os::unix::io::{AsRawFd, IntoRawFd, RawFd},
ptr,
};
use interprocess_traits::ProcSync;
use libc::c_void;
use memfd::MemfdOptions;
/// Errors returned while creating, sizing, mapping or duplicating shared memory.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// Creating the memfd that backs the shared memory failed.
#[error("Could not create an in-memory file for shared memory")]
CreateMemfd(#[source] memfd::Error),
/// Setting the length of the freshly created memfd (`File::set_len`) failed.
#[error("Failed to set the length of the shared memory file")]
Truncate(#[source] io::Error),
/// Querying the size of the backing file failed (`File::metadata` when
/// creating new shared memory, `fstat` when mapping an existing descriptor).
#[error("Failed to retrieve the length of the shared memory file")]
GetMetadata(#[source] io::Error),
/// After resizing, the file reported a length different from the requested
/// one: `expected` is the requested length and `actual` the length read back
/// from the file's metadata, both in bytes.
#[error(
"Failed to truncate the in-memory file to {expected} bytes, the file is {actual}B long"
)]
Length { expected: usize, actual: usize },
/// `mmap` of the backing file failed.
#[error("Failed to map shared memory")]
Mmap(#[source] io::Error),
/// Duplicating the file descriptor (in `Shared::try_clone`) failed.
#[error("Failed to duplicate the file descriptor")]
Dup(#[source] io::Error),
}
/// An owned `mmap`ed region of `size` bytes starting at `ptr`, viewed as `T`.
///
/// Dropping it calls `munmap(ptr, size)`. It does not own or close the file
/// descriptor the region was mapped from.
struct MmapRegion<T> {
/// Start of the mapping as returned by `mmap`, cast to `*mut T`.
ptr: *mut T,
/// Length of the mapping in bytes (the length that was passed to `mmap`).
size: usize,
}
impl<T> MmapRegion<T> {
unsafe fn new(size: usize, fd: RawFd) -> Result<MmapRegion<T>, Error> {
let ptr = libc::mmap(
ptr::null_mut(),
size,
libc::PROT_READ | libc::PROT_WRITE, libc::MAP_SHARED_VALIDATE | libc::MAP_POPULATE,
fd,
0, );
if ptr == libc::MAP_FAILED {
return Err(Error::Mmap(std::io::Error::last_os_error()));
}
Ok(MmapRegion {
ptr: ptr as *mut T,
size,
})
}
}
impl<T> Drop for MmapRegion<T> {
    /// Unmaps the region. A failing `munmap` is ignored: `drop` has no way to
    /// report it.
    fn drop(&mut self) {
        let (addr, len) = (self.ptr.cast::<c_void>(), self.size);
        // SAFETY: `addr`/`len` are exactly the address and length of the
        // successful `mmap` performed in `MmapRegion::new`, and a region is
        // only ever unmapped here.
        unsafe {
            libc::munmap(addr, len);
        }
    }
}
/// A handle to memory mapped from a memfd-backed file, dereferencing to `T`.
///
/// Several handles can map the same memory: in one process via
/// `Shared::try_clone`, or in another process by passing the descriptor and
/// calling `Shared::from_raw_fd`. Only `&T` is handed out (via `Deref`), so
/// mutation has to go through interior mutability (e.g. atomics) or through
/// the raw pointer from `Shared::as_mut_ptr`.
pub struct Shared<T> {
/// Descriptor of the backing file; owned by this handle and closed on drop.
fd: RawFd,
/// The mapping itself; unmapped when this field is dropped.
region: MmapRegion<T>,
}
// SAFETY: a `Shared<T>` only exposes its pointee as `&T`, and the pointee may
// be reachable from other handles to the same mapping (see `try_clone`), so
// moving a handle to another thread amounts to sharing `&T` across threads,
// which is what `T: Sync` permits. Dropping the handle on another thread only
// closes the fd and unmaps; the `T` itself is never dropped or moved out.
unsafe impl<T: Sync> Send for Shared<T> {}
// SAFETY: `&Shared<T>` gives access to `&T` (via `Deref`), a raw pointer and
// the fd; concurrent `&T` access is covered by `T: Sync`.
unsafe impl<T: Sync> Sync for Shared<T> {}
/// Creates a memfd of exactly `size` bytes, seals its size and maps it shared
/// and read/write.
///
/// # Safety
/// The returned mapping does not yet contain a `T`. The caller must write a
/// valid `T` into it before the `Shared` is dereferenced (as `Shared::new`
/// does), or only ever treat it as untyped memory.
///
/// # Errors
/// [`Error::CreateMemfd`], [`Error::Truncate`], [`Error::GetMetadata`],
/// [`Error::Length`] or [`Error::Mmap`], depending on which step failed. No
/// descriptor is leaked on any of these paths.
///
/// # Panics
/// If `align_of::<T>()` exceeds the system page size (the mapping is only
/// page-aligned), or if adding the seals to the memfd fails.
unsafe fn create_shared<T>(size: usize) -> Result<Shared<T>, Error> {
    let page_size = libc::sysconf(libc::_SC_PAGE_SIZE) as usize;
    let requested_align = mem::align_of::<T>();
    if requested_align > page_size {
        panic!(
            "Page size {}B is too low for requested alignment {}",
            page_size, requested_align
        );
    }
    let memfd = MemfdOptions::new()
        .allow_sealing(true)
        .close_on_exec(true)
        .create("caring")
        .map_err(Error::CreateMemfd)?;
    let file = memfd.into_file();
    file.set_len(size as u64).map_err(Error::Truncate)?;
    let actual_size = file.metadata().map_err(Error::GetMetadata)?.len() as usize;
    if actual_size != size {
        return Err(Error::Length {
            expected: size,
            actual: actual_size,
        });
    }
    // Forbid any later resize of the file, and any further seal changes.
    let seals = libc::F_SEAL_SHRINK | libc::F_SEAL_GROW | libc::F_SEAL_SEAL;
    let rc = libc::fcntl(file.as_raw_fd(), libc::F_ADD_SEALS, seals);
    assert_eq!(rc, 0, "sealing failed on a memfd");
    // Map while `file` still owns the descriptor: if mmap fails, `?` returns
    // early and dropping `file` closes the fd instead of leaking it.
    let region = MmapRegion::new(size, file.as_raw_fd())?;
    // Nothing below can fail, so ownership of the fd can now move to `Shared`.
    let fd = file.into_raw_fd();
    Ok(Shared { fd, region })
}
impl<T> Shared<T> {
    /// Moves `val` into a freshly created shared mapping sized for one `T`.
    ///
    /// The stored value is never dropped by `Shared`: dropping a handle only
    /// closes its descriptor and unmaps the memory.
    ///
    /// # Errors
    /// Any [`Error`] from creating, sizing or mapping the backing memfd. In that
    /// case `val` is dropped normally.
    ///
    /// # Panics
    /// If `align_of::<T>()` exceeds the system page size.
    pub fn new(val: T) -> Result<Shared<T>, Error> {
        let byte_len = mem::size_of::<T>();
        // SAFETY: the mapping is initialised with a valid `T` right below,
        // before the handle is returned or dereferenced.
        let shared = unsafe { create_shared::<T>(byte_len)? };
        // SAFETY: the mapping is exactly `size_of::<T>()` bytes and
        // page-aligned, and `create_shared` panics if `T` needs more alignment,
        // so writing one `T` at its start is in bounds and aligned.
        unsafe { ptr::write_volatile(shared.region.ptr, val) };
        Ok(shared)
    }
}
impl Shared<c_void> {
    /// Creates an untyped shared mapping of `size` bytes.
    ///
    /// The memory is meant to be accessed through `Shared::as_mut_ptr`;
    /// `Shared::size` reports `size`.
    ///
    /// # Errors
    /// Any [`Error`] from creating, sizing or mapping the backing memfd.
    pub fn new_sized(size: usize) -> Result<Shared<c_void>, Error> {
        // SAFETY: `c_void` is an opaque byte type; callers work with the
        // region through raw pointers rather than a typed value.
        unsafe { create_shared::<c_void>(size) }
    }
}
impl<T> Shared<T> {
unsafe fn from_raw_fd_impl(fd: RawFd) -> Result<Shared<T>, Error> {
let mut statbuf = mem::zeroed::<libc::stat>();
if libc::fstat(fd, &mut statbuf) != 0 {
return Err(Error::GetMetadata(io::Error::last_os_error()));
}
assert_eq!(statbuf.st_mode & libc::S_IFMT, libc::S_IFREG);
let size = statbuf.st_size as usize;
let region = MmapRegion::new(size, fd)?;
Ok(Shared { fd, region })
}
pub fn try_clone(data: &Shared<T>) -> Result<Shared<T>, Error> {
unsafe {
let fd = libc::dup(data.as_raw_fd());
if fd == -1 {
return Err(Error::Dup(std::io::Error::last_os_error()));
}
Self::from_raw_fd_impl(fd)
}
}
pub fn as_mut_ptr(data: &Shared<T>) -> *mut T {
data.region.ptr
}
pub fn size(data: &Shared<T>) -> usize {
data.region.size
}
}
impl<T: ProcSync> Shared<T> {
pub unsafe fn from_raw_fd(fd: RawFd) -> Result<Shared<T>, Error> {
Self::from_raw_fd_impl(fd)
}
}
impl<T> Drop for Shared<T> {
    /// Closes the descriptor this handle owns.
    ///
    /// The mapping is released right afterwards, when the `region` field is
    /// dropped. The `T` stored in the mapping is never dropped.
    fn drop(&mut self) {
        let fd = self.fd;
        // SAFETY: the handle owns `fd`, and it is only closed here.
        // A failing close() cannot be reported from Drop, so the result is discarded.
        let _ = unsafe { libc::close(fd) };
    }
}
impl<T> Deref for Shared<T> {
    type Target = T;

    /// Borrows the value stored at the start of the shared mapping.
    fn deref(&self) -> &T {
        let start: *const T = self.region.ptr;
        // SAFETY: `start` comes from a successful `mmap` that stays mapped for
        // as long as `self` lives. The mapping is assumed to hold a valid `T`:
        // `Shared::new` writes one, and `from_raw_fd` callers must guarantee it.
        unsafe { &*start }
    }
}
impl<T> AsRawFd for Shared<T> {
fn as_raw_fd(&self) -> RawFd {
self.fd
}
}
impl<T> IntoRawFd for Shared<T> {
    /// Unmaps the memory and hands the descriptor to the caller, who becomes
    /// responsible for closing it.
    ///
    /// As with dropping a `Shared`, the `T` stored in the mapping is not dropped.
    fn into_raw_fd(self) -> RawFd {
        // ManuallyDrop suppresses `Shared`'s own Drop, so the fd stays open.
        let mut this = mem::ManuallyDrop::new(self);
        let fd = this.fd;
        // SAFETY: `region` is dropped exactly once here, which unmaps the
        // memory. It is never used again, because `this` itself is never dropped.
        unsafe { ptr::drop_in_place(&mut this.region) };
        fd
    }
}
#[cfg(test)]
mod tests {
    use crate::*;
    use std::{
        os::unix::net::UnixDatagram,
        process,
        sync::{
            atomic::{AtomicBool, AtomicUsize, Ordering},
            Arc,
        },
        thread,
    };

    use sendfd::{RecvWithFd, SendWithFd};

    // Writes the pattern 0, 1, 2, ... (truncated to u8) into the first `$size`
    // bytes of the `Shared` handle `$zone`, then reads it back.
    //
    // Both pointers are derived from `Shared::as_mut_ptr`. Casting the `&T`
    // returned by `Deref` to `*mut u8` and writing through it would be
    // undefined behaviour, because neither `[u8; N]` nor `c_void` has
    // interior mutability.
    //
    // Must be expanded inside an `unsafe` block.
    macro_rules! test_write_and_read {
        ($zone:expr, $size:expr) => {{
            let zone = $zone;
            let size = $size;
            let ptr_mut = Shared::as_mut_ptr(&zone) as *mut u8;
            for i in 0..size {
                ptr::write_volatile(ptr_mut.add(i), i as u8);
            }
            let ptr = ptr_mut as *const u8;
            for i in 0..size {
                assert_eq!(ptr::read_volatile(ptr.add(i)), i as u8);
            }
        }};
    }

    #[test]
    fn new_sized_allocates_properly() {
        // Not a multiple of the (typically 4 KiB) page size, so the last page
        // is only partially used.
        const SIZE: usize = 10 * 4 * 1024 + 1;
        let zone = Shared::new_sized(SIZE).unwrap();
        assert_eq!(Shared::size(&zone), SIZE);
        unsafe { test_write_and_read!(zone, SIZE) };
    }

    #[test]
    fn new_allocates_properly() {
        // Just under one (typically 4 KiB) page.
        const SIZE: usize = 4 * 1024 - 1;
        let zone = Shared::new([0u8; SIZE]).unwrap();
        assert_eq!(Shared::size(&zone), SIZE);
        unsafe { test_write_and_read!(zone, SIZE) };
    }

    // Runs `$incr` INCR times and `$decr` DECR times on two threads. Each
    // thread uses its own handle produced by `$clone_zone`. The final value is
    // then checked through `$zone_name`, against `$base_name + INCR - DECR`.
    macro_rules! test_sync_across_threads {
        ($zone_name:ident, $base_name:ident; $build_zone:stmt, $clone_zone:expr; $v:ident, $incr:expr, $decr:expr) => {{
            const $base_name: usize = 42;
            const INCR: usize = 9876500;
            const DECR: usize = 9012300;
            $build_zone
            let zone1 = $clone_zone;
            let zone2 = $clone_zone;
            let incr = thread::spawn(move || {
                let $v = &*zone1;
                for _ in 0..INCR {
                    $incr;
                }
            });
            let decr = thread::spawn(move || {
                let $v = &*zone2;
                for _ in 0..DECR {
                    $decr;
                }
            });
            incr.join().unwrap();
            decr.join().unwrap();
            assert_eq!($zone_name.load(Ordering::SeqCst), $base_name + INCR - DECR);
        }};
    }

    // Thread test where both threads share one `Shared` through an `Arc`.
    macro_rules! test_sync_across_threads_arc {
        ($v:ident, $incr:expr, $decr:expr) => {
            test_sync_across_threads!(
                zone, BASE;
                let zone = Arc::new(Shared::new(AtomicUsize::new(BASE)).unwrap()),
                zone.clone();
                $v, $incr, $decr
            )
        }
    }

    // Non-atomic load/store must lose updates, proving the test can detect races.
    #[test]
    #[should_panic]
    fn syncs_across_threads_test_can_fail() {
        test_sync_across_threads_arc!(
            v,
            v.store(
                v.load(Ordering::SeqCst).overflowing_add(1).0,
                Ordering::SeqCst
            ),
            v.store(
                v.load(Ordering::SeqCst).overflowing_sub(1).0,
                Ordering::SeqCst
            )
        );
    }

    #[test]
    fn syncs_across_threads() {
        test_sync_across_threads_arc!(
            v,
            v.fetch_add(1, Ordering::SeqCst),
            v.fetch_sub(1, Ordering::SeqCst)
        );
    }

    // Thread test where each thread owns a separate mapping obtained from
    // `Shared::try_clone`.
    macro_rules! test_sync_across_threads_different_shared {
        ($v:ident, $incr:expr, $decr:expr) => {{
            test_sync_across_threads!(
                zone, BASE;
                let zone = Shared::new(AtomicUsize::new(BASE)).unwrap(),
                Shared::try_clone(&zone).unwrap();
                $v, $incr, $decr
            )
        }};
    }

    #[test]
    #[should_panic]
    fn syncs_across_threads_different_shared_can_fail() {
        test_sync_across_threads_different_shared!(
            v,
            v.store(
                v.load(Ordering::SeqCst).overflowing_add(1).0,
                Ordering::SeqCst
            ),
            v.store(
                v.load(Ordering::SeqCst).overflowing_sub(1).0,
                Ordering::SeqCst
            )
        );
    }

    #[test]
    fn syncs_across_threads_different_shared() {
        test_sync_across_threads_different_shared!(
            v,
            v.fetch_add(1, Ordering::SeqCst),
            v.fetch_sub(1, Ordering::SeqCst)
        );
    }

    // Forks after creating the shared value, so both processes use the
    // inherited mapping. The child applies `$incr` INCR times and then raises
    // a flag. The parent applies `$decr` DECR times, waits for the flag and
    // checks the total.
    macro_rules! test_sync_across_processes_with_fork {
        ($v:ident, $incr:expr, $decr:expr) => {{
            const BASE: usize = 1337;
            const INCR: usize = 8901200;
            const DECR: usize = 8765400;
            let zone = Shared::new((AtomicUsize::new(BASE), AtomicBool::new(false))).unwrap();
            let ($v, child_complete) = &*zone;
            let child = || {
                for _ in 0..INCR {
                    $incr;
                }
                child_complete.store(true, Ordering::SeqCst);
            };
            let parent = || {
                for _ in 0..DECR {
                    $decr;
                }
                while !child_complete.load(Ordering::SeqCst) {
                    thread::yield_now();
                }
                assert_eq!(zone.0.load(Ordering::SeqCst), BASE + INCR - DECR);
            };
            unsafe {
                let pid = libc::fork();
                // A failed fork would otherwise be treated as the parent and
                // spin forever waiting for a child that does not exist.
                assert_ne!(pid, -1, "fork failed");
                if pid == 0 {
                    child();
                    process::exit(0);
                } else {
                    parent();
                    libc::waitpid(pid, ptr::null_mut(), 0);
                }
            }
        }};
    }

    #[test]
    #[should_panic]
    fn syncs_across_processes_with_fork_test_can_fail() {
        test_sync_across_processes_with_fork!(
            v,
            v.store(
                v.load(Ordering::SeqCst).overflowing_add(1).0,
                Ordering::SeqCst
            ),
            v.store(
                v.load(Ordering::SeqCst).overflowing_sub(1).0,
                Ordering::SeqCst
            )
        );
    }

    #[test]
    fn syncs_across_processes_with_fork() {
        test_sync_across_processes_with_fork!(
            v,
            v.fetch_add(1, Ordering::SeqCst),
            v.fetch_sub(1, Ordering::SeqCst)
        );
    }

    // Forks first. The child then creates the shared value and sends its
    // descriptor over a Unix datagram socket, and the parent maps the received
    // descriptor with `Shared::from_raw_fd`. Same counting protocol as the
    // fork test.
    macro_rules! test_sync_across_processes_after_socket_send {
        ($v:ident, $incr:expr, $decr:expr) => {{
            const BASE: usize = 10;
            const INCR: usize = 9000000;
            const DECR: usize = 8000000;
            let (send, receive) = UnixDatagram::pair().unwrap();
            let child = || {
                let zone = Shared::new((AtomicUsize::new(BASE), AtomicBool::new(false))).unwrap();
                send.send_with_fd(&[], &[zone.as_raw_fd()])
                    .expect("send should succeed");
                let ($v, child_complete) = &*zone;
                for _ in 0..INCR {
                    $incr;
                }
                child_complete.store(true, Ordering::SeqCst);
            };
            let parent = || {
                let mut fd = [0; 1];
                let (_, received_fds) = receive
                    .recv_with_fd(&mut [], &mut fd)
                    .expect("recv should succeed");
                // `fd` is pre-filled with 0. Without this check, a missing
                // descriptor would silently be treated as fd 0.
                assert_eq!(received_fds, 1, "expected exactly one file descriptor");
                let zone: Shared<(AtomicUsize, AtomicBool)> =
                    unsafe { Shared::from_raw_fd(fd[0]).unwrap() };
                let ($v, child_complete) = &*zone;
                for _ in 0..DECR {
                    $decr;
                }
                while !child_complete.load(Ordering::SeqCst) {
                    thread::yield_now();
                }
                assert_eq!($v.load(Ordering::SeqCst), BASE + INCR - DECR);
            };
            unsafe {
                let pid = libc::fork();
                // A failed fork would otherwise be treated as the parent and
                // block forever in recv.
                assert_ne!(pid, -1, "fork failed");
                if pid == 0 {
                    child();
                    process::exit(0);
                } else {
                    parent();
                    libc::waitpid(pid, ptr::null_mut(), 0);
                }
            }
        }};
    }

    #[test]
    #[should_panic]
    fn syncs_across_processes_after_socket_send_test_can_fail() {
        test_sync_across_processes_after_socket_send!(
            v,
            v.store(
                v.load(Ordering::SeqCst).overflowing_add(1).0,
                Ordering::SeqCst
            ),
            v.store(
                v.load(Ordering::SeqCst).overflowing_sub(1).0,
                Ordering::SeqCst
            )
        );
    }

    #[test]
    fn syncs_across_processes_after_socket_send() {
        test_sync_across_processes_after_socket_send!(
            v,
            v.fetch_add(1, Ordering::SeqCst),
            v.fetch_sub(1, Ordering::SeqCst)
        );
    }
}