/// A mutex backed by a `pthread_mutex_t` configured with the POSIX
/// `PTHREAD_PRIO_INHERIT` protocol (falling back to the default protocol when the
/// platform reports `ENOTSUP`; see [`PiMutex::new`]).
pub(crate) struct PiMutex<T> {
    /// The protected value; only reachable through a [`PiMutexGuard`].
    inner: std::cell::UnsafeCell<T>,
    /// The underlying pthread mutex.
    ///
    /// Boxed so its address never changes after `pthread_mutex_init`: POSIX leaves
    /// operations on a copied/moved mutex undefined, while `PiMutex` itself is moved
    /// freely (returned by value from `new`, placed into an `Arc`, ...).
    mutex: Box<std::cell::UnsafeCell<libc::pthread_mutex_t>>,
}

// SAFETY: every access to `inner` is serialised by the pthread mutex, so sending or
// sharing a `PiMutex` only ever hands `T` values from one thread to another;
// `T: Send` is therefore sufficient for both, mirroring `std::sync::Mutex`.
unsafe impl<T: Send> Send for PiMutex<T> {}
unsafe impl<T: Send> Sync for PiMutex<T> {}

impl<T> PiMutex<T> {
    /// Creates a new priority-inheritance mutex protecting `value`.
    ///
    /// If `pthread_mutexattr_setprotocol(PTHREAD_PRIO_INHERIT)` returns `ENOTSUP`, a
    /// warning is logged and the mutex is created without the PI protocol.
    ///
    /// # Panics
    ///
    /// Panics if initialising the attribute object, setting the protocol (for any
    /// error other than `ENOTSUP`), or initialising the mutex fails.
    pub(crate) fn new(value: T) -> Self {
        // SAFETY: `pthread_mutexattr_t` is a plain C struct, so all-zero bytes are a
        // valid placeholder; `pthread_mutexattr_init` overwrites it.
        let mut attr: libc::pthread_mutexattr_t = unsafe { std::mem::zeroed() };
        // SAFETY: `attr` is a live, writable local.
        let rc = unsafe { libc::pthread_mutexattr_init(&mut attr) };
        assert_eq!(rc, 0, "pthread_mutexattr_init failed: {rc}");

        // SAFETY: `attr` was successfully initialised above.
        let rc = unsafe {
            libc::pthread_mutexattr_setprotocol(&mut attr, libc::PTHREAD_PRIO_INHERIT)
        };
        if rc == libc::ENOTSUP {
            tracing::warn!(
                "PTHREAD_PRIO_INHERIT unsupported (errno {}); PiMutex degrading to non-PI protocol",
                rc
            );
        } else if rc != 0 {
            // Release the attribute object first so the failure path does not leak it.
            // SAFETY: `attr` is initialised and is not used again.
            unsafe { libc::pthread_mutexattr_destroy(&mut attr) };
            panic!("pthread_mutexattr_setprotocol(PTHREAD_PRIO_INHERIT) failed: {rc}");
        }

        // Allocate the mutex storage on the heap *before* initialising it, so the mutex
        // is initialised at the address it keeps for the whole life of this PiMutex.
        // SAFETY: all-zero bytes are a valid placeholder for `pthread_mutex_init` to
        // overwrite.
        let mutex = Box::new(std::cell::UnsafeCell::new(unsafe {
            std::mem::zeroed::<libc::pthread_mutex_t>()
        }));
        // SAFETY: `mutex.get()` points to writable heap storage; `attr` is initialised.
        let rc = unsafe { libc::pthread_mutex_init(mutex.get(), &attr) };
        // The attribute object is only needed during `pthread_mutex_init`.
        // SAFETY: `attr` is initialised and is not used again.
        unsafe { libc::pthread_mutexattr_destroy(&mut attr) };
        assert_eq!(rc, 0, "pthread_mutex_init failed: {rc}");

        PiMutex {
            inner: std::cell::UnsafeCell::new(value),
            mutex,
        }
    }

    /// Acquires the mutex and returns a guard that unlocks it when dropped.
    ///
    /// # Panics
    ///
    /// Panics if `pthread_mutex_lock` returns a non-zero error code.
    pub(crate) fn lock(&self) -> PiMutexGuard<'_, T> {
        // SAFETY: the mutex was initialised in `new` and, being boxed, has not moved.
        let rc = unsafe { libc::pthread_mutex_lock(self.mutex.get()) };
        assert_eq!(rc, 0, "pthread_mutex_lock failed: {rc}");
        PiMutexGuard {
            mutex: self,
            _not_send: std::marker::PhantomData,
        }
    }
}

impl<T> Drop for PiMutex<T> {
    fn drop(&mut self) {
        // `&mut self` means no `PiMutexGuard` still borrows this mutex. (A guard leaked
        // via `mem::forget` would leave it locked; POSIX leaves destroying a locked
        // mutex undefined. The return code is ignored because nothing useful can be
        // done about it in `drop`.)
        // SAFETY: the mutex was initialised in `new` and is never used after this call.
        unsafe {
            libc::pthread_mutex_destroy(self.mutex.get());
        }
    }
}

/// Guard returned by [`PiMutex::lock`]; dereferences to the protected value and
/// unlocks the mutex when dropped.
#[must_use = "if unused the PiMutex will immediately unlock"]
pub(crate) struct PiMutexGuard<'a, T> {
    mutex: &'a PiMutex<T>,
    /// Raw-pointer marker that makes the guard `!Send` (and `!Sync`, re-enabled
    /// below for `T: Sync`). A pthread mutex must be unlocked by the thread that
    /// locked it, so the guard must never be dropped on another thread. This matches
    /// `std::sync::MutexGuard`.
    _not_send: std::marker::PhantomData<*const ()>,
}

// SAFETY: a shared `&PiMutexGuard` only exposes `&T` (through `Deref`), which may be
// used from several threads exactly when `T: Sync`.
unsafe impl<T: Sync> Sync for PiMutexGuard<'_, T> {}
impl<T> std::ops::Deref for PiMutexGuard<'_, T> {
    type Target = T;

    /// Shared access to the protected value while this guard holds the lock.
    fn deref(&self) -> &T {
        let cell = &self.mutex.inner;
        // SAFETY: a guard is only created after `pthread_mutex_lock` succeeds and
        // unlocks on drop, so the lock is held for as long as this borrow lives.
        unsafe { &*cell.get() }
    }
}
impl<T> std::ops::DerefMut for PiMutexGuard<'_, T> {
    /// Exclusive access to the protected value while this guard holds the lock.
    fn deref_mut(&mut self) -> &mut T {
        let slot: *mut T = self.mutex.inner.get();
        // SAFETY: the lock is held for the guard's lifetime, and `&mut self` prevents
        // this guard from handing out any other reference while this one is alive.
        unsafe { &mut *slot }
    }
}
impl<T> Drop for PiMutexGuard<'_, T> {
    /// Unlocks the mutex.
    ///
    /// If `pthread_mutex_unlock` fails, the error code is printed to stderr and the
    /// process is aborted rather than panicking from `drop`, because the lock state
    /// can no longer be trusted.
    fn drop(&mut self) {
        let raw = self.mutex.mutex.get();
        // SAFETY: this guard exists only because `pthread_mutex_lock` on `raw` succeeded,
        // and `drop` runs at most once per guard.
        let rc = unsafe { libc::pthread_mutex_unlock(raw) };
        if rc == 0 {
            return;
        }
        eprintln!("pthread_mutex_unlock failed: {rc}");
        // SAFETY: `abort` has no preconditions and never returns.
        unsafe { libc::abort() }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Spawns `threads` workers that each increment a shared counter `iters` times
    /// under the lock, then asserts that no increment was lost.
    fn assert_concurrent_increments(threads: u64, iters: u64) {
        let m = Arc::new(PiMutex::new(0u64));
        let workers: Vec<_> = (0..threads)
            .map(|_| {
                let m = Arc::clone(&m);
                std::thread::spawn(move || {
                    for _ in 0..iters {
                        *m.lock() += 1;
                    }
                })
            })
            .collect();
        for w in workers {
            w.join().expect("worker thread panicked");
        }
        assert_eq!(*m.lock(), threads * iters);
    }

    #[test]
    fn pi_mutex_lock_unlock() {
        let m = PiMutex::new(42u32);
        {
            let mut guard = m.lock();
            assert_eq!(*guard, 42);
            *guard = 99;
        }
        assert_eq!(*m.lock(), 99);
    }

    #[test]
    fn pi_mutex_cross_thread() {
        let m = Arc::new(PiMutex::new(0u32));
        let m2 = m.clone();
        let handle = std::thread::spawn(move || {
            *m2.lock() += 1;
        });
        handle.join().unwrap();
        assert_eq!(*m.lock(), 1);
    }

    #[test]
    fn pi_mutex_concurrent_increment() {
        assert_concurrent_increments(8, 1000);
    }

    #[test]
    fn pi_mutex_contention_10_threads_increments_correctly() {
        assert_concurrent_increments(10, 1000);
    }

    /// The mutex must keep working after the `PiMutex` value is relocated following
    /// construction (into a `Box`, back out, then into an `Arc` shared by threads).
    #[test]
    fn pi_mutex_usable_after_being_moved() {
        let boxed = Box::new(PiMutex::new(String::from("a")));
        let moved = *boxed;
        moved.lock().push('b');
        let shared = Arc::new(moved);
        let worker = Arc::clone(&shared);
        std::thread::spawn(move || {
            worker.lock().push('c');
        })
        .join()
        .expect("worker thread panicked");
        assert_eq!(shared.lock().as_str(), "abc");
    }

    #[test]
    fn pi_mutex_protocol_is_inherit() {
        unsafe {
            let mut attr: libc::pthread_mutexattr_t = std::mem::zeroed();
            assert_eq!(libc::pthread_mutexattr_init(&mut attr), 0);
            assert_eq!(
                libc::pthread_mutexattr_setprotocol(&mut attr, libc::PTHREAD_PRIO_INHERIT),
                0,
            );
            let mut protocol: libc::c_int = 0;
            assert_eq!(libc::pthread_mutexattr_getprotocol(&attr, &mut protocol), 0);
            assert_eq!(protocol, libc::PTHREAD_PRIO_INHERIT);
            libc::pthread_mutexattr_destroy(&mut attr);
        }
    }
}