use std::sync::{Arc, Condvar, Mutex as StdMutex, MutexGuard as StdMutexGuard, RwLock as StdRwLock};
use std::sync::atomic::{AtomicI64, Ordering};
/// Go-style mutex that owns the data it protects.
///
/// Cloning is cheap (it clones the inner `Arc`), and every clone refers to the
/// same underlying lock and value.
pub struct Mutex<T: ?Sized> {
    inner: Arc<StdMutex<T>>,
}

// Hand-written instead of `#[derive(Clone)]`: the derive would add a `T: Clone`
// bound, but cloning only bumps the shared `Arc`, so it is valid for any `T`.
impl<T: ?Sized> Clone for Mutex<T> {
    fn clone(&self) -> Self {
        Mutex {
            inner: Arc::clone(&self.inner),
        }
    }
}

impl<T> Mutex<T> {
    /// Creates a new, unlocked mutex holding `v`.
    pub fn new(v: T) -> Self {
        Mutex {
            inner: Arc::new(StdMutex::new(v)),
        }
    }
}

#[allow(non_snake_case)] // method names mirror Go's sync.Mutex API
impl<T: ?Sized> Mutex<T> {
    /// Blocks until the lock is acquired; the lock is released when the
    /// returned guard is dropped.
    ///
    /// Go mutexes have no poisoning, so a lock poisoned by a thread that
    /// panicked while holding it is recovered instead of causing a panic here.
    pub fn Lock(&self) -> StdMutexGuard<'_, T> {
        self.inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Attempts to acquire the lock without blocking.
    ///
    /// Returns `None` only when the lock is currently held by someone else.
    pub fn TryLock(&self) -> Option<StdMutexGuard<'_, T>> {
        match self.inner.try_lock() {
            Ok(guard) => Some(guard),
            // Poisoned but free: recover it, consistent with `Lock`.
            Err(std::sync::TryLockError::Poisoned(poisoned)) => Some(poisoned.into_inner()),
            Err(std::sync::TryLockError::WouldBlock) => None,
        }
    }
}

impl Mutex<()> {
    /// Creates a mutex that guards no data (a plain lock).
    pub fn empty() -> Self {
        Mutex::new(())
    }
}
/// Go-style reader/writer lock that owns the data it protects.
///
/// Cloning is cheap (it clones the inner `Arc`), and every clone refers to the
/// same underlying lock and value.
pub struct RWMutex<T: ?Sized> {
    inner: Arc<StdRwLock<T>>,
}

// Hand-written instead of `#[derive(Clone)]`: the derive would add a `T: Clone`
// bound, but cloning only bumps the shared `Arc`, so it is valid for any `T`.
impl<T: ?Sized> Clone for RWMutex<T> {
    fn clone(&self) -> Self {
        RWMutex {
            inner: Arc::clone(&self.inner),
        }
    }
}

impl<T> RWMutex<T> {
    /// Creates a new, unlocked reader/writer lock holding `v`.
    pub fn new(v: T) -> Self {
        RWMutex {
            inner: Arc::new(StdRwLock::new(v)),
        }
    }
}

#[allow(non_snake_case)] // method names mirror Go's sync.RWMutex API
impl<T: ?Sized> RWMutex<T> {
    /// Blocks until exclusive (write) access is acquired.
    ///
    /// A lock poisoned by a panicking writer is recovered rather than
    /// propagated as a panic, since Go locks have no poisoning.
    pub fn Lock(&self) -> std::sync::RwLockWriteGuard<'_, T> {
        self.inner
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Blocks until shared (read) access is acquired; many readers may hold
    /// it at once.
    ///
    /// Poisoning is recovered in the same way as in `Lock`.
    pub fn RLock(&self) -> std::sync::RwLockReadGuard<'_, T> {
        self.inner
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
/// Go-style `sync.WaitGroup`: `Wait` blocks until the counter returns to zero.
///
/// Clones share the same counter, so one clone can be handed to each worker.
#[derive(Clone)]
pub struct WaitGroup {
    inner: Arc<WaitGroupInner>,
}

/// State shared by every clone of a `WaitGroup`.
struct WaitGroupInner {
    // Number of outstanding tasks.
    count: AtomicI64,
    // Held while waiters check `count` and while notifiers signal, so a
    // waiter cannot miss the wake-up between its check and `cv.wait`.
    mu: StdMutex<()>,
    cv: Condvar,
}

#[allow(non_snake_case)] // method names mirror Go's sync.WaitGroup API
impl WaitGroup {
    /// Creates a wait group with a counter of zero.
    pub fn new() -> Self {
        WaitGroup {
            inner: Arc::new(WaitGroupInner {
                count: AtomicI64::new(0),
                mu: StdMutex::new(()),
                cv: Condvar::new(),
            }),
        }
    }

    /// Adds `delta` (which may be negative) to the counter.
    ///
    /// When the counter reaches zero, every thread blocked in `Wait` is
    /// released. This includes a negative `delta`, not just `Done`.
    ///
    /// # Panics
    ///
    /// Panics if the counter becomes negative, as Go's `WaitGroup.Add` does.
    pub fn Add(&self, delta: crate::types::int64) {
        // `fetch_add` returns the previous value; wrapping mirrors the atomic's
        // own overflow behaviour.
        let now = self
            .inner
            .count
            .fetch_add(delta, Ordering::SeqCst)
            .wrapping_add(delta);
        if now < 0 {
            panic!("sync: negative WaitGroup counter");
        }
        if now == 0 {
            // Notify under the lock: a waiter holds `mu` from its count check
            // until `cv.wait` releases it, so the wake-up cannot be lost.
            let _g = self
                .inner
                .mu
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            self.inner.cv.notify_all();
        }
    }

    /// Decrements the counter by one; equivalent to `Add(-1)`.
    pub fn Done(&self) {
        self.Add(-1);
    }

    /// Blocks until the counter is zero (returns at once if it already is).
    pub fn Wait(&self) {
        let mut g = self
            .inner
            .mu
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        while self.inner.count.load(Ordering::SeqCst) > 0 {
            g = self
                .inner
                .cv
                .wait(g)
                .unwrap_or_else(std::sync::PoisonError::into_inner);
        }
    }
}

impl Default for WaitGroup {
    fn default() -> Self {
        Self::new()
    }
}
/// Go-style `sync.Once`: runs a function at most once.
///
/// Clones share the same state, so a `Do` on any clone counts for all of them.
#[derive(Clone)]
pub struct Once {
    inner: Arc<std::sync::Once>,
}

#[allow(non_snake_case)] // `Do` mirrors Go's sync.Once API
impl Once {
    /// Creates a `Once` whose function has not run yet.
    pub fn new() -> Self {
        Once {
            inner: Arc::new(std::sync::Once::new()),
        }
    }

    /// Runs `f` only if no earlier `Do` on this `Once` (or a clone) has run.
    ///
    /// Concurrent callers block until the first call finishes. As in Go, if
    /// `f` panics the `Once` still counts as done. The panic propagates to that
    /// caller, and later calls return without running their function. They do
    /// not panic on a poisoned `std::sync::Once`.
    pub fn Do<F: FnOnce()>(&self, f: F) {
        self.inner.call_once_force(move |state| {
            // A poisoned state means an earlier `f` panicked; Go treats that call
            // as completed, so skip `f` and let this call mark the Once done.
            if !state.is_poisoned() {
                f();
            }
        });
    }
}

impl Default for Once {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mutex_guards_data() {
        let mu = Mutex::new(0i64);
        {
            let mut g = mu.Lock();
            *g += 42;
        }
        assert_eq!(*mu.Lock(), 42);
    }

    #[test]
    fn mutex_try_lock() {
        let mu = Mutex::new(1i64);
        {
            let _g = mu.Lock();
            // Held above, so a non-blocking attempt must fail.
            assert!(mu.TryLock().is_none());
        }
        // Released once the guard is dropped, so it must now succeed.
        assert_eq!(mu.TryLock().map(|g| *g), Some(1));
    }

    #[test]
    fn rwmutex_many_readers() {
        let rw = RWMutex::new(5i64);
        let r1 = rw.RLock();
        let r2 = rw.RLock();
        assert_eq!(*r1 + *r2, 10);
    }

    #[test]
    fn waitgroup_blocks_until_done() {
        let wg = WaitGroup::new();
        let finished = Arc::new(AtomicI64::new(0));
        wg.Add(3);
        for _ in 0..3 {
            let w = wg.clone();
            let f = finished.clone();
            std::thread::spawn(move || {
                std::thread::sleep(std::time::Duration::from_millis(10));
                f.fetch_add(1, Ordering::SeqCst);
                w.Done();
            });
        }
        wg.Wait();
        // Each worker bumps `finished` before `Done`, so all three increments
        // must be visible once `Wait` returns.
        assert_eq!(finished.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn once_runs_body_once() {
        let once = Once::new();
        let counter = Arc::new(AtomicI64::new(0));
        for _ in 0..5 {
            let o = once.clone();
            let c = counter.clone();
            o.Do(move || {
                c.fetch_add(1, Ordering::SeqCst);
            });
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}