use crate::{Error, RefCounter, Result};
use std::ffi::CString;
use std::sync::atomic::{AtomicBool, Ordering};
/// Cross-process reference counter and lock backed by a System V semaphore set.
///
/// Each instance remembers whether *it* currently holds a reference and/or the
/// lock, so repeated `acquire` / `try_lock` calls on one instance count only once.
#[derive(Debug)]
pub struct LinuxRefCounter {
    /// Identifier returned by `semget` for the two-semaphore set.
    sem_id: i32,
    /// Whether this instance has incremented the shared count (semaphore 0).
    has_reference: AtomicBool,
    /// Whether this instance currently holds the lock (semaphore 1).
    has_lock: AtomicBool,
    /// Name the counter was created from. Kept for diagnostics (it appears in the
    /// `Debug` output) but not read by any method, hence the lint allowance.
    #[allow(dead_code)]
    name: String,
}
impl LinuxRefCounter {
pub fn new(name: &str) -> Result<Self> {
let key = Self::name_to_key(name)?;
let sem_id = unsafe {
libc::semget(
key,
2,
libc::IPC_CREAT | libc::IPC_EXCL | 0o666,
)
};
let sem_id = if sem_id < 0 {
let errno = std::io::Error::last_os_error();
if errno.raw_os_error() == Some(libc::EEXIST) {
let id = unsafe { libc::semget(key, 2, 0o666) };
if id < 0 {
return Err(Error::RefCounterInit(format!(
"Failed to open semaphore: {}",
std::io::Error::last_os_error()
)));
}
id
} else {
return Err(Error::RefCounterInit(format!(
"Failed to create semaphore: {}",
errno
)));
}
} else {
let arg = libc::semun { val: 1 };
unsafe {
if libc::semctl(sem_id, 1, libc::SETVAL, arg) < 0 {
return Err(Error::RefCounterInit(format!(
"Failed to initialize lock semaphore: {}",
std::io::Error::last_os_error()
)));
}
}
sem_id
};
Ok(Self {
sem_id,
has_reference: AtomicBool::new(false),
has_lock: AtomicBool::new(false),
name: name.to_string(),
})
}
fn name_to_key(name: &str) -> Result<libc::key_t> {
let path = CString::new("/tmp").unwrap();
let proj_id = name.bytes().fold(1u8, |acc, b| acc.wrapping_add(b)) as i32;
let key = unsafe { libc::ftok(path.as_ptr(), proj_id) };
if key < 0 {
return Err(Error::RefCounterInit(format!(
"ftok failed: {}",
std::io::Error::last_os_error()
)));
}
Ok(key)
}
fn sem_op(&self, sem_num: u16, op: i16, flags: i16) -> Result<()> {
let mut sop = libc::sembuf {
sem_num,
sem_op: op,
sem_flg: flags,
};
let result = unsafe { libc::semop(self.sem_id, &mut sop, 1) };
if result < 0 {
return Err(Error::Platform(format!(
"semop failed: {}",
std::io::Error::last_os_error()
)));
}
Ok(())
}
fn sem_val(&self, sem_num: i32) -> Result<i32> {
let val = unsafe { libc::semctl(self.sem_id, sem_num, libc::GETVAL) };
if val < 0 {
return Err(Error::Count(format!(
"semctl GETVAL failed: {}",
std::io::Error::last_os_error()
)));
}
Ok(val)
}
}
impl RefCounter for LinuxRefCounter {
fn acquire(&self) -> Result<u32> {
if self.has_reference.load(Ordering::SeqCst) {
return self.count();
}
self.sem_op(0, 1, libc::SEM_UNDO as i16)?;
self.has_reference.store(true, Ordering::SeqCst);
self.count()
}
fn release(&self) -> Result<u32> {
if !self.has_reference.load(Ordering::SeqCst) {
return self.count();
}
self.sem_op(0, -1, libc::SEM_UNDO as i16)?;
self.has_reference.store(false, Ordering::SeqCst);
self.count()
}
fn count(&self) -> Result<u32> {
let val = self.sem_val(0)?;
Ok(val as u32)
}
fn try_lock(&self) -> Result<bool> {
if self.has_lock.load(Ordering::SeqCst) {
return Ok(true);
}
let result = self.sem_op(1, -1, (libc::IPC_NOWAIT | libc::SEM_UNDO) as i16);
match result {
Ok(()) => {
self.has_lock.store(true, Ordering::SeqCst);
Ok(true)
}
Err(_) => Ok(false), }
}
fn unlock(&self) -> Result<()> {
if !self.has_lock.load(Ordering::SeqCst) {
return Ok(());
}
self.sem_op(1, 1, libc::SEM_UNDO as i16)?;
self.has_lock.store(false, Ordering::SeqCst);
Ok(())
}
}
impl Drop for LinuxRefCounter {
    /// Gives the cross-process lock back if this instance still holds it.
    ///
    /// A held reference is *not* released here. `acquire` increments with
    /// `SEM_UNDO`, so that increment is reverted only when the process exits.
    fn drop(&mut self) {
        let holds_lock = self.has_lock.load(Ordering::SeqCst);
        if !holds_lock {
            return;
        }
        // Best effort: `drop` has no way to report a failure.
        let _ = self.unlock();
    }
}
// C-layout mirror of `union semun` (the optional fourth argument to `semctl`).
// NOTE(review): not referenced anywhere in the visible source. `new` passes
// `libc::semun` from the local `libc` module, which has identical members. That
// is why `dead_code` is allowed here; this looks like a leftover and is a
// candidate for removal.
#[repr(C)]
#[allow(dead_code)]
union libc_semun {
    // Value for SETVAL.
    val: i32,
    // Buffer for IPC_STAT / IPC_SET.
    buf: *mut libc::semid_ds,
    // Array for GETALL / SETALL.
    array: *mut u16,
}
// Local module named `libc` that re-exports everything from the external `libc`
// crate and adds `semun`. Because a local item shadows the extern-crate name,
// paths such as `libc::semget` elsewhere in this file resolve through here.
mod libc {
    pub use ::libc::*;
    // Optional fourth argument to `semctl`. semctl(2) documents that the calling
    // program must define this union itself. In this file it is built with `val`
    // for SETVAL in `LinuxRefCounter::new`.
    #[repr(C)]
    pub union semun {
        pub val: i32,
        pub buf: *mut semid_ds,
        pub array: *mut u16,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // These tests use real System V semaphores keyed off `/tmp`, so they need a
    // host where SysV IPC is available.

    /// Repeated `acquire` on one instance counts once; `release` returns to zero.
    #[test]
    fn test_linux_ref_counter_basic() {
        let rc = LinuxRefCounter::new("test-procref-basic").unwrap();
        assert_eq!(rc.count().unwrap(), 0);

        // A second acquire on the same instance must not bump the count again.
        for _ in 0..2 {
            assert_eq!(rc.acquire().unwrap(), 1);
        }

        assert_eq!(rc.release().unwrap(), 0);
    }

    /// `try_lock` keeps succeeding for the instance that already holds the lock.
    #[test]
    fn test_linux_lock() {
        let rc = LinuxRefCounter::new("test-procref-lock").unwrap();
        for _ in 0..2 {
            assert!(rc.try_lock().unwrap());
        }
        rc.unlock().unwrap();
    }
}