use std::{
io,
sync::atomic::{AtomicU32, Ordering},
};
use libc::gettid;
use log::trace;
/// A re-entrant mutex built directly on a Linux futex.
///
/// A thread that already holds the lock may call [`PerThreadMutex::acquire`]
/// again without blocking; each call hands out a guard, and the lock is only
/// released once every guard obtained by the owning thread has been dropped.
/// The mutex protects no data of its own.
#[derive(Debug, Default)]
pub struct PerThreadMutex {
    /// Futex word: 0 while unlocked, 1 while locked.
    futex_word: AtomicU32,
    /// Thread id (from `gettid`) of the current owner, or 0 while unlocked.
    thread_id: AtomicU32,
    /// Number of outstanding guards held by the owning thread.
    acquisitions: AtomicU32,
}
impl PerThreadMutex {
    /// Acquires the mutex for the calling thread, blocking until it is available.
    ///
    /// The lock is re-entrant: if the calling thread (identified by `gettid`)
    /// already holds it, the acquisition counter is incremented and a new guard
    /// is returned immediately. The mutex is unlocked once every guard handed
    /// out to the owning thread has been dropped.
    ///
    /// # Panics
    ///
    /// Panics if the re-entrant acquisition counter would exceed `u32::MAX`,
    /// if the internal ownership bookkeeping is found inconsistent, or if
    /// `FUTEX_WAIT` fails with an error other than `EAGAIN` / `EINTR`.
    #[must_use = "if unused the guard is dropped immediately, releasing the acquisition"]
    pub fn acquire(&self) -> PerThreadMutexGuard<'_> {
        // The caller's thread id cannot change while we spin, so fetch it once.
        let thread_id = unsafe { gettid() } as u32;
        loop {
            // Fast path: take an unlocked mutex (futex word 0 -> 1).
            if self
                .futex_word
                .compare_exchange_weak(0, 1, Ordering::AcqRel, Ordering::Acquire)
                .is_ok()
            {
                assert_eq!(self.acquisitions.fetch_add(1, Ordering::AcqRel), 0);
                assert_eq!(
                    self.thread_id.compare_exchange(
                        0,
                        thread_id,
                        Ordering::AcqRel,
                        Ordering::Acquire
                    ),
                    Ok(0)
                );
                trace!("[{}] Acquired initial lock", thread_id);
                return PerThreadMutexGuard(self, thread_id);
            }

            // Only the owner writes its own id into `thread_id` (and clears it
            // again before unlocking), so reading our id means we already hold it.
            if self.thread_id.load(Ordering::Acquire) == thread_id {
                // Refuse to overflow *before* committing the increment: a wrapped
                // counter would corrupt the state the outstanding guards rely on.
                let count = self
                    .acquisitions
                    .fetch_update(Ordering::AcqRel, Ordering::Acquire, |n| n.checked_add(1))
                    .unwrap_or_else(|_| panic!("Acquisition counter overflowed"));
                trace!("[{}] Acquired lock number {}", thread_id, count + 1);
                return PerThreadMutexGuard(self, thread_id);
            }

            trace!("[{}] Thread is waiting", thread_id);
            self.wait_while_locked();
        }
    }

    /// Sleeps in `FUTEX_WAIT` while the futex word still reads 1 (locked).
    ///
    /// Returns after a wake-up, after a signal (`EINTR`), or immediately if the
    /// word no longer equals 1 (`EAGAIN`); in every case the caller retries.
    fn wait_while_locked(&self) {
        // SAFETY: `futex_word.as_ptr()` points to an aligned `u32` owned by
        // `self`, which outlives this call. The timeout and `uaddr2` arguments
        // are explicit null pointers (no timeout / unused) rather than 32-bit
        // integers, since the kernel reads them as pointer-sized values.
        let ret = unsafe {
            libc::syscall(
                libc::SYS_futex,
                self.futex_word.as_ptr(),
                libc::FUTEX_WAIT,
                1u32,
                std::ptr::null::<libc::timespec>(),
                std::ptr::null::<u32>(),
                0u32,
            )
        };
        if ret == 0 {
            return;
        }
        match io::Error::last_os_error().raw_os_error() {
            Some(libc::EINTR | libc::EAGAIN) => (),
            Some(libc::EACCES) => {
                unreachable!("Local variable is always readable")
            }
            Some(i) => unreachable!(
                "Only EAGAIN, EACCES, and EINTR are returned by FUTEX_WAIT; got {}",
                i
            ),
            None => unreachable!(),
        }
    }
}
/// Guard returned by [`PerThreadMutex::acquire`]; dropping it releases one
/// acquisition, and dropping the owning thread's last guard unlocks the mutex.
///
/// Field `.0` is the mutex the guard was taken from; `.1` is the `gettid` id
/// of the thread that acquired it.
///
/// NOTE(review): this type is auto-`Send` (it only holds a shared reference and
/// a `u32`), so nothing stops a guard from being dropped on a different thread
/// than the acquirer — consider a `PhantomData<*const ()>` marker to forbid it.
#[must_use = "if unused the guard is dropped immediately, releasing the acquisition"]
pub struct PerThreadMutexGuard<'a>(&'a PerThreadMutex, u32);
impl Drop for PerThreadMutexGuard<'_> {
    /// Releases one acquisition; the owner's last release unlocks the mutex and
    /// wakes one waiting thread.
    ///
    /// # Panics
    ///
    /// Panics if the acquisition counter is already zero or if the owner /
    /// futex-word bookkeeping is inconsistent with this guard.
    fn drop(&mut self) {
        let mutex = self.0;
        let thread_id = self.1;

        // Decrement without wrapping: a zero counter here means the bookkeeping
        // is already broken, and wrapping to u32::MAX would hide that.
        let previous = mutex
            .acquisitions
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |n| n.checked_sub(1))
            .expect("PerThreadMutexGuard dropped with no outstanding acquisitions");
        if previous != 1 {
            // Other guards of the owning thread are still alive.
            return;
        }

        // Clear ownership before publishing the unlock, so the next owner's
        // `0 -> tid` exchange on `thread_id` is guaranteed to see 0.
        assert_eq!(
            mutex
                .thread_id
                .compare_exchange(thread_id, 0, Ordering::AcqRel, Ordering::Acquire),
            Ok(thread_id)
        );
        assert_eq!(
            mutex
                .futex_word
                .compare_exchange(1, 0, Ordering::AcqRel, Ordering::Acquire),
            Ok(1)
        );
        trace!("[{}] Unlocking mutex", thread_id);

        // Wake a single waiter: the futex word is only ever 0 or 1, so at most
        // one thread can take the lock, and every woken thread either acquires
        // it or goes back to sleep only while the word is 1 (i.e. while some
        // other owner exists whose release will wake the next waiter).
        // SAFETY: `futex_word.as_ptr()` points to an aligned `u32` inside the
        // mutex, which outlives this guard; the remaining arguments are ignored
        // by FUTEX_WAKE and passed as null / zero.
        let woken = unsafe {
            libc::syscall(
                libc::SYS_futex,
                mutex.futex_word.as_ptr(),
                libc::FUTEX_WAKE,
                1i32,
                std::ptr::null::<libc::timespec>(),
                std::ptr::null::<u32>(),
                0u32,
            )
        };
        trace!("[{}] Number of waiters woken: {}", thread_id, woken);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use env_logger::try_init;
    use std::{sync::Arc, thread::spawn};

    /// Takes `mutex` `depth` times on the current thread, holding every guard
    /// until all acquisitions have succeeded, then releases them newest-first.
    fn hold_nested(mutex: &PerThreadMutex, depth: usize) {
        let mut guards: Vec<_> = (0..depth).map(|_| mutex.acquire()).collect();
        // Pop so guards are released in LIFO order, like stacked locals.
        while let Some(guard) = guards.pop() {
            drop(guard);
        }
    }

    /// Several threads contend for one mutex, each re-entering it a different
    /// number of times. The test only completes if no thread deadlocks, and it
    /// fails if any assertion inside `acquire` / `drop` panics in a worker.
    #[test]
    fn test_lock() {
        // `init` panics if a global logger is already installed (e.g. by another
        // test running in this process); `try_init` makes repeated setup harmless.
        let _ = try_init();
        let mutex = Arc::new(PerThreadMutex::default());
        let handles: Vec<_> = [3, 4, 2, 5, 1]
            .iter()
            .map(|&depth| {
                let mutex = Arc::clone(&mutex);
                spawn(move || hold_nested(&mutex, depth))
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
    }
}