use parking_lot::{Condvar, Mutex};
use std::time::Duration;
/// Bookkeeping shared by every thread that uses a `ReadWriteLock`.
///
/// All fields are only read or written while the owning `Mutex` is held.
/// `Default` yields the idle state: every counter zero and the flag cleared.
#[derive(Debug, Default)]
struct LockState {
// Number of callers currently inside a read section.
readers: i32,
// Number of callers currently inside a write section; the admission loops
// only let a writer in while this is zero.
writers: i32,
// Writers that have registered their intent but have not been admitted yet.
waiting_writers: i32,
// Set whenever a writer registers; cleared when a writer releases (or gives
// up after a timeout) and no other writer is waiting. Readers are held back
// while this is set and `waiting_writers > 0`.
write_preferred: bool,
}
/// A write-preferring reader/writer lock built from a mutex and two condition
/// variables.
///
/// The lock does not own any data: callers pass closures to `read` / `write`
/// (or hold a `ReadGuard` / `WriteGuard`) and the lock only guarantees that
/// write sections are exclusive while read sections may overlap each other.
///
/// `Debug` is derived so the lock can be embedded in other `Debug` types and
/// inspected in logs; every field already implements it.
#[derive(Debug)]
pub struct ReadWriteLock {
    /// Counters and flags describing who holds or waits for the lock.
    state: Mutex<LockState>,
    /// Readers park here while a writer is active or queued.
    read_cond: Condvar,
    /// Writers park here while readers or another writer hold the lock.
    write_cond: Condvar,
}
impl ReadWriteLock {
pub fn new() -> Self {
Self {
state: Mutex::new(LockState::default()),
read_cond: Condvar::new(),
write_cond: Condvar::new(),
}
}
pub fn read<F, T>(&self, f: F) -> T
where
F: FnOnce() -> T,
{
let mut state = self.state.lock();
while state.writers > 0 || (state.write_preferred && state.waiting_writers > 0) {
self.read_cond.wait(&mut state);
}
state.readers += 1;
drop(state);
let result = f();
let mut state = self.state.lock();
state.readers -= 1;
if state.readers == 0 {
self.write_cond.notify_all();
self.read_cond.notify_all();
}
result
}
pub fn write<F, T>(&self, f: F) -> T
where
F: FnOnce() -> T,
{
let mut state = self.state.lock();
state.waiting_writers += 1;
state.write_preferred = true;
while state.readers > 0 || state.writers > 0 {
self.write_cond.wait(&mut state);
}
state.waiting_writers -= 1;
state.writers += 1;
drop(state);
let result = f();
let mut state = self.state.lock();
state.writers -= 1;
if state.waiting_writers == 0 {
state.write_preferred = false;
}
self.write_cond.notify_all();
self.read_cond.notify_all();
result
}
pub fn try_read_timeout<F, T>(&self, f: F, timeout: Duration) -> Option<T>
where
F: FnOnce() -> T,
{
let mut state = self.state.lock();
let deadline = std::time::Instant::now() + timeout;
while state.writers > 0 || (state.write_preferred && state.waiting_writers > 0) {
if self.read_cond.wait_until(&mut state, deadline).timed_out() {
return None;
}
}
state.readers += 1;
drop(state);
let result = f();
let mut state = self.state.lock();
state.readers -= 1;
if state.readers == 0 {
self.write_cond.notify_all();
self.read_cond.notify_all();
}
Some(result)
}
pub fn try_write_timeout<F, T>(&self, f: F, timeout: Duration) -> Option<T>
where
F: FnOnce() -> T,
{
let mut state = self.state.lock();
state.waiting_writers += 1;
state.write_preferred = true;
let deadline = std::time::Instant::now() + timeout;
while state.readers > 0 || state.writers > 0 {
if self.write_cond.wait_until(&mut state, deadline).timed_out() {
state.waiting_writers -= 1;
if state.waiting_writers == 0 {
state.write_preferred = false;
}
return None;
}
}
state.waiting_writers -= 1;
state.writers += 1;
drop(state);
let result = f();
let mut state = self.state.lock();
state.writers -= 1;
if state.waiting_writers == 0 {
state.write_preferred = false;
}
self.write_cond.notify_all();
self.read_cond.notify_all();
Some(result)
}
pub fn state(&self) -> LockStateInfo {
let state = self.state.lock();
LockStateInfo {
readers: state.readers,
writers: state.writers,
waiting_writers: state.waiting_writers,
write_preferred: state.write_preferred,
}
}
}
impl Default for ReadWriteLock {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time copy of a `ReadWriteLock`'s counters, as returned by
/// `ReadWriteLock::state`.
///
/// The snapshot is plain data: it is `Copy` and can be compared with `==`,
/// which makes it convenient to assert on in tests and diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LockStateInfo {
    /// Callers inside a read section when the snapshot was taken.
    pub readers: i32,
    /// Callers inside a write section when the snapshot was taken.
    pub writers: i32,
    /// Writers that had registered but were not yet admitted.
    pub waiting_writers: i32,
    /// Whether the writer-priority flag was set.
    pub write_preferred: bool,
}
/// RAII handle for a read section of a `ReadWriteLock`.
///
/// The section lasts as long as the guard is alive; dropping the guard leaves
/// the section. Binding it to `_` or discarding the value therefore releases
/// immediately, which is why the type is `#[must_use]`.
// NOTE(review): presumably silences the unused-item lint for builds in which
// nothing constructs the guard — confirm it is still needed.
#[allow(dead_code)]
#[must_use = "the read section ends as soon as the guard is dropped"]
pub struct ReadGuard<'a> {
    /// Lock whose reader count this guard decrements on drop.
    lock: &'a ReadWriteLock,
}
#[allow(dead_code)]
impl<'a> ReadGuard<'a> {
pub fn new(lock: &'a ReadWriteLock) -> Self {
let mut state = lock.state.lock();
while state.writers > 0 || (state.write_preferred && state.waiting_writers > 0) {
lock.read_cond.wait(&mut state);
}
state.readers += 1;
Self { lock }
}
}
impl<'a> Drop for ReadGuard<'a> {
    /// Leaves the read section; the last reader out wakes every waiter.
    fn drop(&mut self) {
        let owner = self.lock;
        let mut shared = owner.state.lock();
        shared.readers -= 1;
        if shared.readers != 0 {
            // Other readers are still inside; nobody can make progress yet.
            return;
        }
        owner.write_cond.notify_all();
        owner.read_cond.notify_all();
    }
}
/// RAII handle for an exclusive write section of a `ReadWriteLock`.
///
/// The section lasts as long as the guard is alive; dropping the guard leaves
/// the section and wakes waiters. Binding it to `_` or discarding the value
/// therefore releases immediately, which is why the type is `#[must_use]`.
// NOTE(review): presumably silences the unused-item lint for builds in which
// nothing constructs the guard — confirm it is still needed.
#[allow(dead_code)]
#[must_use = "the write section ends as soon as the guard is dropped"]
pub struct WriteGuard<'a> {
    /// Lock whose writer count this guard decrements on drop.
    lock: &'a ReadWriteLock,
}
#[allow(dead_code)]
impl<'a> WriteGuard<'a> {
pub fn new(lock: &'a ReadWriteLock) -> Self {
let mut state = lock.state.lock();
state.waiting_writers += 1;
state.write_preferred = true;
while state.readers > 0 || state.writers > 0 {
lock.write_cond.wait(&mut state);
}
state.waiting_writers -= 1;
state.writers += 1;
Self { lock }
}
}
impl<'a> Drop for WriteGuard<'a> {
    /// Leaves the write section and wakes every waiting writer and reader.
    fn drop(&mut self) {
        let owner = self.lock;
        let mut shared = owner.state.lock();
        shared.writers -= 1;
        // With no writer left in the queue, readers no longer need to yield.
        let queue_empty = shared.waiting_writers == 0;
        if queue_empty {
            shared.write_preferred = false;
        }
        owner.write_cond.notify_all();
        owner.read_cond.notify_all();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
    use std::thread;

    /// `read` and `write` hand back whatever the closure returns.
    #[test]
    fn test_read_write_lock_basic() {
        let rw = ReadWriteLock::new();

        let from_read = rw.read(|| 42);
        assert_eq!(from_read, 42);

        let from_write = rw.write(|| 100);
        assert_eq!(from_write, 100);
    }

    /// Ten threads each enter a read section once; every one must complete.
    #[test]
    fn test_concurrent_reads() {
        let shared_lock = Arc::new(ReadWriteLock::new());
        let hits = Arc::new(AtomicU32::new(0));

        let workers: Vec<_> = (0..10)
            .map(|_| {
                let shared_lock = Arc::clone(&shared_lock);
                let hits = Arc::clone(&hits);
                thread::spawn(move || {
                    shared_lock.read(|| {
                        hits.fetch_add(1, Ordering::SeqCst);
                        thread::sleep(Duration::from_millis(10));
                    });
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        assert_eq!(hits.load(Ordering::SeqCst), 10);
    }

    /// A freshly created lock reports no readers and no writers.
    #[test]
    fn test_state_info() {
        let rw = ReadWriteLock::new();
        let snapshot = rw.state();
        assert_eq!(snapshot.readers, 0);
        assert_eq!(snapshot.writers, 0);
    }
}