use std::{
convert::TryFrom,
ffi::CString,
io,
mem::ManuallyDrop,
num::NonZeroU8,
os::{
raw::{c_int, c_short, c_ushort},
unix::{ffi::OsStrExt, io::RawFd},
},
path::Path,
};
pub use libc;
pub use nix::sys::stat::Mode;
// Numeric `semctl` commands and the `sem_flg` undo bit, hard-coded here instead of being
// taken from `libc`.
// NOTE(review): these numbers appear to match Linux's <sys/sem.h> (where the undo bit is
// called `SEM_UNDO`); other Unixes use different values — confirm before porting.

/// `semctl` command: read the value of one semaphore (used by `get_val_unchecked`).
const GETVAL: c_int = 12;
/// `semctl` command: copy every semaphore value of the set into a caller buffer (used by `get_all`).
const GETALL: c_int = 13;
/// `semctl` command: overwrite the value of one semaphore (used by `set_val_unchecked`).
const SETVAL: c_int = 16;

/// `sem_flg` bit toggled by `SemOp::undo`.
const IPC_UNDO: c_int = 0x1000;
/// A System V IPC key naming a semaphore set.
///
/// Build one from a filesystem path with [`Key::new`], from an open file descriptor with
/// [`Key::new_fd`], or use [`Key::private`] for the special `IPC_PRIVATE` key.
#[derive(Debug, Clone, Copy)]
pub struct Key {
    // Raw key value handed to `semget`.
    k: libc::key_t,
}

impl Key {
    /// Derives a key from `path` and the non-zero project identifier `id` using `ftok(3)`.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidData` error if `path` contains an interior NUL byte, and otherwise
    /// the OS error reported when `ftok` returns `-1`.
    pub fn new(path: &Path, id: NonZeroU8) -> io::Result<Key> {
        let c_path = match CString::new(path.as_os_str().as_bytes()) {
            Ok(c_path) => c_path,
            Err(_) => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "null byte in path passed to Key::new",
                ));
            }
        };
        let proj_id = c_int::from(id.get());
        // SAFETY: `c_path` is NUL-terminated and outlives the call.
        match unsafe { libc::ftok(c_path.as_ptr(), proj_id) } {
            -1 => Err(io::Error::last_os_error()),
            k => Ok(Key { k }),
        }
    }

    /// Derives a key from the file `fd` refers to by passing `/proc/self/fd/<fd>` to
    /// [`Key::new`].
    ///
    /// This relies on procfs being available at `/proc`.
    pub fn new_fd(fd: RawFd, id: NonZeroU8) -> io::Result<Key> {
        let proc_path = format!("/proc/self/fd/{}", fd);
        Key::new(Path::new(&proc_path), id)
    }

    /// Returns the special `IPC_PRIVATE` key.
    pub fn private() -> Key {
        Key { k: libc::IPC_PRIVATE }
    }
}
/// A single semaphore operation.
///
/// `repr(transparent)` over `libc::sembuf`, so a `&[SemOp]` can be passed to `semop`
/// directly. Build one with [`Sem::op`] and friends, then adjust flags with
/// [`SemOp::wait`] / [`SemOp::undo`].
#[derive(Clone, Copy, Debug)]
#[repr(transparent)]
pub struct SemOp(libc::sembuf);

impl SemOp {
    /// Returns a copy of `self` whose `sem_flg` has `flag` set when `on` is true and
    /// cleared otherwise; `sem_num` and `sem_op` are untouched.
    fn with_flag(self, flag: c_short, on: bool) -> SemOp {
        let mut updated = self;
        if on {
            updated.0.sem_flg |= flag;
        } else {
            updated.0.sem_flg &= !flag;
        }
        updated
    }

    /// Controls blocking: `wait(false)` sets `IPC_NOWAIT` on the operation, while
    /// `wait(true)` clears it.
    pub fn wait(self, wait: bool) -> SemOp {
        self.with_flag(libc::IPC_NOWAIT as c_short, !wait)
    }

    /// Sets (`undo == true`) or clears (`undo == false`) the `IPC_UNDO` bit.
    pub fn undo(self, undo: bool) -> SemOp {
        self.with_flag(IPC_UNDO as c_short, undo)
    }
}
/// A System V semaphore set with one extra, hidden semaphore appended at the end.
///
/// The set is allocated with `nsems + 1` semaphores: indices `0..nsems` belong to the user,
/// and index `nsems` is used as a reference count of live handles. The set is removed once
/// that count drops to zero on `Drop`.
#[derive(Debug)]
pub struct Semaphore {
    // Identifier returned by `semget`.
    id: c_int,
    // Number of user-visible semaphores (the hidden refcount is not included).
    nsems: c_int,
}
/// Whether [`Semaphore::create`] insists on creating a brand-new set.
///
/// `Yes` adds `IPC_EXCL` to the `semget` flags, so creation fails when a set already exists
/// for the key. `No` leaves it out, so an existing set for the key is reused.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum Exclusive {
    /// Require a fresh set (`IPC_CREAT | IPC_EXCL`).
    Yes,
    /// Create the set if needed, otherwise attach to the existing one (`IPC_CREAT`).
    No,
}
impl Semaphore {
fn adjust_refcount(&self, by: c_short) -> io::Result<()> {
unsafe {
self.op_unchecked(&[SemOp(libc::sembuf {
sem_num: self.nsems as c_ushort,
sem_op: by,
sem_flg: 0,
})])
}
}
fn new(key: Key, nsems: usize, flags: c_int) -> io::Result<Semaphore> {
let nsems = c_int::try_from(nsems).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"trying to allocate too many semaphores",
)
})?;
let id = unsafe { libc::semget(key.k, nsems + 1, flags) };
if id == -1 {
Err(io::Error::last_os_error())
} else {
let sem = ManuallyDrop::new(Semaphore { id, nsems });
sem.adjust_refcount(1)?;
Ok(ManuallyDrop::into_inner(sem))
}
}
pub fn create(
key: Key,
nsems: usize,
exclusive: Exclusive,
mode: Mode,
) -> io::Result<Semaphore> {
let flags = (mode.bits() & 0b111_111_111) as i32
| libc::IPC_CREAT
| match exclusive {
Exclusive::Yes => libc::IPC_EXCL,
Exclusive::No => 0,
};
Semaphore::new(key, nsems, flags)
}
pub fn open(key: Key, nsems: usize) -> io::Result<Semaphore> {
Semaphore::new(key, nsems, 0)
}
pub fn try_clone(&self) -> io::Result<Semaphore> {
self.adjust_refcount(1)?;
Ok(Semaphore {
id: self.id,
nsems: self.nsems,
})
}
pub fn op(&self, ops: &[SemOp]) -> io::Result<()> {
if ops.iter().any(|o| o.0.sem_num as c_int >= self.nsems) {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"tried to update a non-existing semaphore",
));
}
unsafe { self.op_unchecked(ops) }
}
unsafe fn op_unchecked(&self, ops: &[SemOp]) -> io::Result<()> {
let res = libc::semop(
self.id,
ops.as_ptr() as *mut SemOp as *mut libc::sembuf,
ops.len(),
);
if res == -1 {
Err(io::Error::last_os_error())
} else {
Ok(())
}
}
pub fn get_val(&self, sem: usize) -> io::Result<c_int> {
if sem >= self.nsems as usize {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"tried to get the value of a non-existing semaphore",
));
}
unsafe { self.get_val_unchecked(sem as c_int) }
}
unsafe fn get_val_unchecked(&self, sem: c_int) -> io::Result<c_int> {
let res = libc::semctl(self.id, sem, GETVAL);
if res == -1 {
Err(io::Error::last_os_error())
} else {
Ok(res)
}
}
pub fn get_all(&self) -> io::Result<Vec<c_ushort>> {
let mut vec = vec![0; 1 + self.nsems as usize];
let res = unsafe { libc::semctl(self.id, 0, GETALL, vec.as_mut_slice()) };
if res == -1 {
Err(io::Error::last_os_error())
} else {
vec.pop();
Ok(vec)
}
}
pub fn set_val(&self, sem: usize, val: c_int) -> io::Result<()> {
if sem > self.nsems as usize {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"tried to set the value of a non-existing semaphore",
));
}
unsafe { self.set_val_unchecked(sem as c_int, val) }
}
pub unsafe fn set_val_unchecked(&self, sem: c_int, val: c_int) -> io::Result<()> {
let res = libc::semctl(self.id, sem, SETVAL, val);
if res != 0 {
Err(io::Error::last_os_error())
} else {
Ok(())
}
}
pub fn at(&self, idx: c_ushort) -> Sem {
assert!(
(idx as c_int) < self.nsems,
"trying to get a non-existing semaphore"
);
Sem(idx)
}
}
impl Drop for Semaphore {
    /// Releases this handle's reference and removes the whole set from the system when
    /// the hidden refcount reads zero afterwards.
    ///
    /// All errors are ignored because `drop` has no way to report them.
    fn drop(&mut self) {
        let _ = self.adjust_refcount(-1);

        // SAFETY: index `nsems` is the hidden refcount semaphore allocated by `new`.
        let remaining = unsafe { self.get_val_unchecked(self.nsems) };
        if remaining.ok() == Some(0) {
            // NOTE(review): the read above and this removal are not atomic. Another process
            // could attach in between and then find the set removed underneath it.
            // SAFETY: plain FFI call; `IPC_RMID` takes no further argument.
            let _ = unsafe { libc::semctl(self.id, 0, libc::IPC_RMID) };
        }
    }
}
/// Index of one user semaphore within a [`Semaphore`] set, used to build [`SemOp`]s.
///
/// Obtained from [`Semaphore::at`], which checks the index against the set size.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Sem(c_ushort);

impl Sem {
    /// Builds an operation adding `v` to this semaphore, with no flags set.
    ///
    /// `v` may be positive, negative, or zero.
    pub fn op(&self, v: c_short) -> SemOp {
        SemOp(libc::sembuf {
            sem_num: self.0,
            sem_op: v,
            sem_flg: 0,
        })
    }

    /// Builds an operation incrementing this semaphore by `v`.
    ///
    /// # Panics
    ///
    /// If `v` is not strictly positive.
    pub fn add(&self, v: c_short) -> SemOp {
        assert!(v > 0, "trying to add a non-positive value to a semaphore");
        self.op(v)
    }

    /// Builds an operation decrementing this semaphore by `v`.
    ///
    /// # Panics
    ///
    /// If `v` is not strictly positive. Because `v > 0`, negating it cannot overflow.
    pub fn remove(&self, v: c_short) -> SemOp {
        assert!(v > 0, "trying to remove a non-positive value from a semaphore");
        self.op(-v)
    }

    /// Builds an operation with `sem_op == 0` (wait until the semaphore is zero).
    pub fn wait_zero(&self) -> SemOp {
        self.op(0)
    }
}
// Integration-style tests exercising semaphore sets across threads and forked processes.
// They use the external `caring`, `sendfd`, `tempfile` and `nix` crates.
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        os::unix::{
            io::{AsRawFd, IntoRawFd},
            net,
        },
        process, ptr,
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        time::Duration,
    };
    use caring::Shared;
    use sendfd::{RecvWithFd, SendWithFd};

    // Number of hand-offs each ping-pong test performs.
    const RUNS: usize = 1000;

    // Ping-pong side one: wait on semaphore 0, check the shared flag was `true` and flip it
    // to `false`, then post semaphore 1. Paired with `loop_two!`, the two sides strictly
    // alternate, so the flag assertion catches any overlap of their critical sections.
    macro_rules! loop_one {
        ($s:ident, $v:ident) => {
            for _ in 0..RUNS {
                $s.op(&[$s.at(0).remove(1)]).unwrap();
                assert_eq!($v.swap(false, Ordering::Relaxed), true);
                $s.op(&[$s.at(1).add(1)]).unwrap();
            }
        };
    }

    // Ping-pong side two: post semaphore 0 once to let `loop_one!` start, then repeatedly
    // wait on semaphore 1, check the flag was `false` and flip it to `true`, and post 0.
    macro_rules! loop_two {
        ($s:ident, $v:ident) => {
            $s.op(&[$s.at(0).add(1)]).unwrap();
            for _ in 0..RUNS {
                $s.op(&[$s.at(1).remove(1)]).unwrap();
                assert_eq!($v.swap(true, Ordering::Relaxed), false);
                $s.op(&[$s.at(0).add(1)]).unwrap();
            }
        };
    }

    // Two threads sharing one `IPC_PRIVATE` set through an `Arc`.
    #[test]
    fn across_threads_private() {
        let s = Arc::new(
            Semaphore::create(
                Key::private(),
                2,
                Exclusive::Yes,
                Mode::from_bits(0o600).unwrap(),
            )
            .unwrap(),
        );
        let v = Arc::new(AtomicBool::new(true));
        {
            let s = s.clone();
            let v = v.clone();
            // NOTE(review): the JoinHandle is discarded, so a panic in this thread is not
            // reported as such; the main thread would most likely block in `loop_two!`.
            std::thread::spawn(move || loop_one!(s, v));
        }
        loop_two!(s, v);
    }

    // Two threads holding independent handles obtained from the same path-derived key:
    // one via `create`, the other via `open`.
    #[test]
    fn across_threads_named() {
        let f = tempfile::NamedTempFile::new().unwrap();
        let s = Semaphore::create(
            Key::new(f.as_ref(), NonZeroU8::new(b'0').expect("non-zero")).expect("key"),
            2,
            Exclusive::Yes,
            Mode::from_bits(0o600).unwrap(),
        )
        .unwrap();
        let v = Arc::new(AtomicBool::new(true));
        {
            let v = v.clone();
            let s = Semaphore::open(
                Key::new(f.as_ref(), NonZeroU8::new(b'0').unwrap()).unwrap(),
                2,
            )
            .unwrap();
            std::thread::spawn(move || loop_one!(s, v));
        }
        loop_two!(s, v);
    }

    // Handles produced by `try_clone`. The first clone is dropped immediately, exercising a
    // refcount increment/decrement that must not remove the set while `s` is still alive.
    #[test]
    fn across_threads_cloned() {
        let f = tempfile::NamedTempFile::new().unwrap();
        let s = Semaphore::create(
            Key::new(f.as_ref(), NonZeroU8::new(b'0').expect("non-zero")).expect("key"),
            2,
            Exclusive::Yes,
            Mode::from_bits(0o600).unwrap(),
        )
        .unwrap();
        s.try_clone().unwrap();
        let v = Arc::new(AtomicBool::new(true));
        {
            let v = v.clone();
            let s = s.try_clone().unwrap();
            std::thread::spawn(move || loop_one!(s, v));
        }
        loop_two!(s, v);
    }

    // Parent and forked child both call `create` with `Exclusive::No` on the same key, so
    // whichever runs second attaches to the set created by the first. The flag lives in a
    // `caring::Shared` region reached through an inherited fd (presumably a shared memory
    // mapping, so both processes see the same `AtomicBool`).
    #[test]
    fn across_processes_named() {
        let f = tempfile::NamedTempFile::new().unwrap();
        let k = Key::new(f.as_ref(), NonZeroU8::new(b'0').unwrap()).unwrap();
        let v_fd = Shared::new(AtomicBool::new(true)).unwrap().into_raw_fd();
        let v_fd2 = nix::unistd::dup(v_fd).unwrap();
        let child = || {
            let s =
                Semaphore::create(k, 2, Exclusive::No, Mode::from_bits(0o600).unwrap()).unwrap();
            let v: Shared<AtomicBool> = unsafe { Shared::from_raw_fd(v_fd) }.unwrap();
            loop_one!(s, v);
        };
        let parent = || {
            let s =
                Semaphore::create(k, 2, Exclusive::No, Mode::from_bits(0o600).unwrap()).unwrap();
            let v: Shared<AtomicBool> = unsafe { Shared::from_raw_fd(v_fd2) }.unwrap();
            loop_two!(s, v);
        };
        unsafe {
            let pid = libc::fork();
            assert!(pid != -1);
            if pid == 0 {
                child();
                process::exit(0);
            } else {
                parent();
                // NOTE(review): the child's exit status is discarded (null status pointer),
                // so a child failure after the parent finishes would go unnoticed.
                libc::waitpid(pid, ptr::null_mut(), 0);
            }
        }
    }

    // The parent keys the set on a shared-memory fd and sends that fd (plus the flag's fd)
    // to the forked child over a Unix datagram socket. The child derives the same key from
    // the received fd via `/proc/self/fd`.
    #[test]
    fn across_processes_sendfd_shmem() {
        let (l, r) = net::UnixDatagram::pair().unwrap();
        let parent = || {
            let s_fd = Shared::new(0u8).unwrap().into_raw_fd();
            let s = Semaphore::create(
                Key::new_fd(s_fd, NonZeroU8::new(b'0').unwrap()).unwrap(),
                2,
                Exclusive::Yes,
                Mode::from_bits(0o600).unwrap(),
            )
            .unwrap();
            let v = Shared::new(AtomicBool::new(true)).unwrap();
            l.send_with_fd(b"", &[s_fd, v.as_raw_fd()]).unwrap();
            loop_two!(s, v);
        };
        let child = || {
            let mut recv_bytes = [0; 128];
            let mut recv_fds = [0, 0];
            r.recv_with_fd(&mut recv_bytes, &mut recv_fds).unwrap();
            let [s_fd, v_fd] = recv_fds;
            let s = Semaphore::open(Key::new_fd(s_fd, NonZeroU8::new(b'0').unwrap()).unwrap(), 2)
                .unwrap();
            let v: Shared<AtomicBool> = unsafe { Shared::from_raw_fd(v_fd) }.unwrap();
            loop_one!(s, v);
        };
        unsafe {
            let pid = libc::fork();
            assert!(pid != -1);
            if pid == 0 {
                child();
                process::exit(0);
            } else {
                parent();
                // NOTE(review): child exit status is not checked here either.
                libc::waitpid(pid, ptr::null_mut(), 0);
            }
        }
    }

    // A key built from a path and one built from `/proc/self/fd/<fd>` of the same file must
    // refer to the same set: `open` (no `IPC_CREAT`) only succeeds if the set exists.
    #[test]
    fn with_different_names() {
        let f = tempfile::NamedTempFile::new().expect("creating temp file");
        let s = Semaphore::create(
            Key::new(f.as_ref(), NonZeroU8::new(b'0').expect("non-zero")).expect("key 1"),
            2,
            Exclusive::Yes,
            Mode::from_bits(0o600).expect("mode from bits"),
        )
        .expect("creating first semaphore");
        let v = Arc::new(AtomicBool::new(true));
        {
            let v = v.clone();
            let s = Semaphore::open(
                Key::new_fd(f.as_raw_fd(), NonZeroU8::new(b'0').expect("non-zero 2"))
                    .expect("key 2"),
                2,
            )
            .expect("creating second semaphore");
            std::thread::spawn(move || loop_one!(s, v));
        }
        loop_two!(s, v);
    }

    // The same fd number refers to different files in parent and child (the parent `dup2`s
    // another file onto it after forking), so the two `Exclusive::Yes` creates must not
    // collide. Relies on sleeps so the child's set exists when the parent creates its own.
    #[test]
    fn same_name_different_files() {
        let fd = tempfile::tempfile().unwrap().into_raw_fd();
        let fd2 = tempfile::tempfile().unwrap().into_raw_fd();
        let child = || {
            println!("child's fd: {}", fd);
            let _s = Semaphore::create(
                Key::new_fd(fd, NonZeroU8::new(b'0').unwrap()).unwrap(),
                2,
                Exclusive::Yes,
                Mode::from_bits(0o600).unwrap(),
            )
            .unwrap();
            std::thread::sleep(Duration::from_secs(2));
        };
        let parent = || {
            std::thread::sleep(Duration::from_secs(1));
            assert_eq!(nix::unistd::dup2(fd2, fd).unwrap(), fd);
            println!("parent's fd: {}", fd);
            let _s = Semaphore::create(
                Key::new_fd(fd, NonZeroU8::new(b'0').unwrap()).unwrap(),
                2,
                Exclusive::Yes,
                Mode::from_bits(0o600).unwrap(),
            )
            .unwrap();
        };
        unsafe {
            let pid = libc::fork();
            assert!(pid != -1);
            if pid == 0 {
                child();
                process::exit(0);
            } else {
                parent();
                // NOTE(review): child exit status is not checked.
                libc::waitpid(pid, ptr::null_mut(), 0);
            }
        }
    }
}