use std::sync::{Condvar, Mutex};
/// A counting semaphore built from a mutex-protected permit count and a
/// condition variable.
///
/// [`Semaphore::wait`] takes a permit (blocking while none are available) and
/// [`Semaphore::release`] returns one.
#[derive(Debug)]
pub struct Semaphore {
    /// Number of permits currently available.
    _mutex: Mutex<u32>,
    /// Signalled by `release` so that threads blocked in `wait` re-check the count.
    _cv: Condvar,
}
impl Semaphore {
    /// Creates a semaphore holding `initial_value` permits.
    ///
    /// This is a `const fn`, so it can be used in `const`/`static` initializers.
    pub const fn new(initial_value: u32) -> Self {
        Self {
            _mutex: Mutex::new(initial_value),
            _cv: Condvar::new(),
        }
    }

    /// Acquires one permit, blocking the calling thread until one is available.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn wait(&self) {
        let guard = self._mutex.lock().unwrap();
        // `wait_while` re-checks the predicate after every wake-up (including
        // spurious ones) and returns with the lock still held, so the
        // availability check and the decrement happen under a single lock
        // acquisition instead of dropping and re-taking the mutex.
        let mut count = self._cv.wait_while(guard, |count| *count == 0).unwrap();
        *count -= 1;
    }

    /// Returns one permit to the semaphore and wakes one blocked waiter, if any.
    ///
    /// # Panics
    ///
    /// Panics if the permit count would overflow `u32`, or if the internal
    /// mutex is poisoned.
    pub fn release(&self) {
        let mut count = self._mutex.lock().unwrap();
        // `+= 1` would silently wrap to 0 in release builds, losing every permit.
        let next = count
            .checked_add(1)
            .expect("semaphore permit count overflowed u32");
        *count = next;
        // Exactly one permit was added, so at most one waiter can proceed;
        // waking all of them would only make the rest block again.
        self._cv.notify_one();
    }

    /// Returns the number of permits currently available.
    ///
    /// The value is a snapshot; other threads may change it immediately after
    /// this returns.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn get_current_value(&self) -> u32 {
        *self._mutex.lock().unwrap()
    }
}
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::{thread, time::Duration};

    use rand::Rng;

    use super::Semaphore;

    /// A single acquire/release pair on a one-permit semaphore must not block.
    #[test]
    fn wait_and_release() {
        let s = Semaphore::new(1);
        s.wait();
        s.release();
    }

    /// A second `wait` blocks until the current holder calls `release`.
    #[test]
    fn release_while_wait() {
        let sem = Semaphore::new(1);
        sem.wait();
        thread::scope(|s| {
            let waiter = s.spawn(|| {
                sem.wait();
            });
            // The only permit is held by this thread, so the waiter must still
            // be blocked after this delay.
            thread::sleep(Duration::from_millis(100));
            assert!(!waiter.is_finished());
            sem.release();
            // Joining (instead of sleeping and polling `is_finished`) proves the
            // waiter was woken without depending on scheduler timing.
            waiter.join().unwrap();
        });
        assert_eq!(0, sem.get_current_value());
        sem.release();
        assert_eq!(1, sem.get_current_value());
    }

    /// Runs `initial_count * 4` threads, each acquiring and releasing the
    /// semaphore 10 000 times with a short random hold time.
    ///
    /// Checks that no more than `initial_count` threads ever hold a permit at
    /// once, and that every permit has been returned at the end.
    fn stress(initial_count: u32) {
        let sem = Semaphore::new(initial_count);
        // Threads currently between `wait` and `release`.
        let holders = AtomicU32::new(0);
        // Highest value `holders` ever reached.
        let max_holders = AtomicU32::new(0);
        thread::scope(|scope| {
            for _ in 0..initial_count * 4 {
                scope.spawn(|| {
                    let mut rng = rand::thread_rng();
                    // Generate all hold times (1..=20 µs) up front, keeping RNG
                    // work out of the critical section.
                    let delays: Vec<Duration> = (0..10000)
                        .map(|_| (rng.gen::<f64>() * 20.0) as u64 + 1)
                        .map(Duration::from_micros)
                        .collect();
                    for d in delays {
                        sem.wait();
                        let now = holders.fetch_add(1, Ordering::SeqCst) + 1;
                        max_holders.fetch_max(now, Ordering::SeqCst);
                        thread::sleep(d);
                        holders.fetch_sub(1, Ordering::SeqCst);
                        sem.release();
                    }
                });
            }
        });
        assert!(max_holders.load(Ordering::SeqCst) <= initial_count);
        assert_eq!(initial_count, sem.get_current_value());
    }

    #[test]
    fn stress1() {
        stress(1);
    }

    #[test]
    fn stress2() {
        stress(2);
    }

    #[test]
    fn stress4() {
        stress(4);
    }

    #[test]
    fn stress8() {
        stress(8);
    }
}