1use std::sync::{Arc, Condvar, Mutex};
4
/// Counting semaphore backed by a mutex-protected permit counter and a condition variable.
///
/// The counter and condition variable live behind an [`Arc`], so cloning a `Semaphore`
/// produces another handle to the *same* pool of permits rather than an independent copy.
#[derive(Clone, Debug)]
pub struct Semaphore {
    /// Number of permits currently available, paired with the condition variable used to
    /// wake threads that are waiting for a permit.
    state: Arc<(Mutex<usize>, Condvar)>,
}
13
/// Permit handed out by [`Semaphore::acquire`].
///
/// The permit is given back to the semaphore when the guard is dropped, so the guard must
/// be kept alive (bound to a named variable) for as long as the permit is needed.
#[must_use = "the permit is released as soon as the guard is dropped"]
#[derive(Debug)]
pub struct SemaphoreGuard {
    /// Shared state of the semaphore this permit was taken from.
    state: Arc<(Mutex<usize>, Condvar)>,
}
21
22impl Semaphore {
23 pub fn new(permits: usize) -> Self {
25 Self {
26 state: Arc::new((Mutex::new(permits), Condvar::new())),
27 }
28 }
29
30 pub fn acquire(&self) -> SemaphoreGuard {
34 let (lock, cvar) = &*self.state;
35 let mut available = lock.lock().unwrap();
36
37 while *available == 0 {
39 available = cvar.wait(available).unwrap();
40 }
41
42 *available -= 1;
44
45 SemaphoreGuard {
46 state: Arc::clone(&self.state),
47 }
48 }
49}
50
51impl Drop for SemaphoreGuard {
52 fn drop(&mut self) {
53 let (lock, cvar) = &*self.state;
54 let mut available = lock.lock().unwrap();
55 *available += 1;
56 cvar.notify_one();
57 }
58}
59
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;
    use std::time::Duration;

    /// Runs ten workers against a two-permit semaphore and checks that no more than two of
    /// them were ever inside the guarded section at the same time.
    #[test]
    fn test_semaphore_limits_concurrency() {
        let sem = Semaphore::new(2);
        // Number of workers currently inside the guarded section.
        let counter = Arc::new(AtomicUsize::new(0));
        // Highest value `counter` was observed to reach.
        let max_concurrent = Arc::new(AtomicUsize::new(0));

        let workers: Vec<_> = (0..10)
            .map(|_| {
                let sem = sem.clone();
                let in_flight = Arc::clone(&counter);
                let peak = Arc::clone(&max_concurrent);

                thread::spawn(move || {
                    let _permit = sem.acquire();

                    let now_running = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                    peak.fetch_max(now_running, Ordering::SeqCst);

                    // Hold the permit briefly so the workers overlap.
                    thread::sleep(Duration::from_millis(10));

                    // Leave the section before `_permit` is dropped at the end of the closure.
                    in_flight.fetch_sub(1, Ordering::SeqCst);
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        assert!(max_concurrent.load(Ordering::SeqCst) <= 2);
    }
}