1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
use atomic_wait::{wait, wake_one};
use std::cell::UnsafeCell;
use std::ops::Deref;
use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering::{Acquire, Relaxed, Release, SeqCst};
/// A counting semaphore that lets at most `capacity` guards borrow `resource` at once.
pub struct Semaphore<T> {
    /// Maximum number of permits that may be held simultaneously.
    capacity: u32,
    /// Number of permits currently held; also the word that blocked acquirers wait on.
    usage: AtomicU32,
    /// The guarded value; only ever exposed as `&T` through a guard.
    resource: UnsafeCell<T>,
}
/// A held permit of a [`Semaphore`], returned by `Semaphore::acquire`.
///
/// Dereferences to the shared resource; dropping the guard gives the permit back.
/// Discarding it immediately (e.g. `semaphore.acquire();`) releases the permit at once,
/// which is almost always a bug, hence `#[must_use]`.
#[must_use = "if unused the permit is released immediately"]
pub struct SemaphoreGuard<'a, T> {
    /// The semaphore this permit was taken from.
    semaphore: &'a Semaphore<T>,
}
impl<T> Semaphore<T> {
    /// Creates a semaphore guarding `resource` that hands out at most `capacity` permits
    /// at the same time. No permits are held initially.
    pub fn new(resource: T, capacity: u32) -> Self {
        let usage = AtomicU32::new(0);
        let resource = UnsafeCell::new(resource);
        Self {
            usage,
            capacity,
            resource,
        }
    }

    /// Blocks until a permit is free, claims it, and returns a guard holding it.
    ///
    /// The guard dereferences to `&T`. Up to `capacity` guards can exist at once, so this
    /// limits concurrency rather than granting exclusive access. If `capacity` is 0 this
    /// call never returns.
    pub fn acquire(&self) -> SemaphoreGuard<T> {
        let mut observed = self.usage.load(Relaxed);
        loop {
            if observed >= self.capacity {
                // Every permit is taken: block in `wait` while `usage` still holds the value
                // we saw, then re-read it. The wake-up may have been spurious, or another
                // thread may already have claimed the freed permit.
                wait(&self.usage, observed);
                observed = self.usage.load(Relaxed);
                continue;
            }
            // Try to claim one permit; if another thread changed `usage` first,
            // retry with the value it left behind.
            match self
                .usage
                .compare_exchange(observed, observed + 1, Acquire, Relaxed)
            {
                Ok(_) => break SemaphoreGuard { semaphore: self },
                Err(current) => observed = current,
            }
        }
    }
}
// SAFETY: in this file `resource` is only reached through `SemaphoreGuard::deref`, which hands
// out `&T`, so several threads may hold shared references to the same `T` at once. That is
// sound when `T: Sync`. `T: Send` is required as well, as a conservative extra bound.
unsafe impl<T: Send + Sync> Sync for Semaphore<T> {}
impl<T> SemaphoreGuard<'_, T> {
    /// Snapshot of how many permits of the owning semaphore are currently held.
    /// The value may be stale as soon as it is returned.
    fn usage(&self) -> u32 {
        let counter = &self.semaphore.usage;
        counter.load(SeqCst)
    }
}
impl<T> Deref for SemaphoreGuard<'_, T> {
    type Target = T;

    /// Shared access to the guarded resource; several guards may deref at the same time.
    fn deref(&self) -> &Self::Target {
        let cell = &self.semaphore.resource;
        // SAFETY: nothing in this file creates a `&mut T` from `resource`; only shared
        // references are handed out, and they cannot outlive this guard's borrow of the
        // semaphore.
        unsafe { &*cell.get() }
    }
}
impl<T> Drop for SemaphoreGuard<'_, T> {
    /// Returns this guard's permit and wakes one thread blocked in `Semaphore::acquire`.
    fn drop(&mut self) {
        let usage = &self.semaphore.usage;
        // Release pairs with the Acquire compare-exchange in `acquire`.
        usage.fetch_sub(1, Release);
        // Wake on every release, not only when the previous value was `capacity`.
        // With that condition, two guards dropped back-to-back from a full semaphore issue
        // only one wake-up, so a second sleeping waiter can stay parked in `wait` while
        // permits are free, possibly forever, since later drops may never see `capacity`
        // again. The cost is one `wake_one` call per release even when nobody waits.
        wake_one(usage);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicBool;
    use std::thread::{scope, yield_now};

    /// Checks that `capacity` permits can be held simultaneously and that every thread,
    /// including those that had to block in `acquire`, eventually gets a permit.
    ///
    /// More threads than `capacity` block in `acquire`, so a missed wake-up in
    /// `SemaphoreGuard::drop` shows up as a hang rather than a failed assertion.
    /// Ignored by default; run with `cargo test -- --ignored`.
    #[ignore]
    #[test]
    fn test_semaphore_capacity() {
        let capacity: u32 = 10;
        let thread_count: usize = 25;
        let semaphore = Semaphore::new((), capacity);
        let go = AtomicBool::new(false);
        let (sender, receiver) = std::sync::mpsc::channel::<u32>();
        scope(|s| {
            for _ in 0..thread_count {
                let sender = sender.clone();
                let (semaphore, go) = (&semaphore, &go);
                s.spawn(move || {
                    let guard = semaphore.acquire();
                    // Keep the permit until every permit has been handed out.
                    while !go.load(Acquire) {
                        yield_now();
                    }
                    sender.send(guard.usage()).unwrap();
                });
            }
            // Open the gate only once the semaphore is full. No guard is dropped before its
            // holder has read `usage`, so the earliest reading must observe `capacity`.
            while semaphore.usage.load(Acquire) != capacity {
                yield_now();
            }
            go.store(true, Release);
        });
        // Drop the original sender so the receiver's iterator ends once all clones are gone.
        drop(sender);
        let results: Vec<u32> = receiver.iter().collect();
        assert_eq!(results.len(), thread_count);
        assert!(results.iter().all(|&u| (1..=capacity).contains(&u)));
        assert_eq!(results.iter().max(), Some(&capacity));
    }
}