1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
use atomic_wait::{wait, wake_one};
use std::cell::UnsafeCell;
use std::ops::Deref;
use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering::{Acquire, Relaxed, Release, SeqCst};

/// a primitive to limit access
///
/// A counting semaphore guarding a single value of type `T`: at most
/// `capacity` threads may hold a `SemaphoreGuard` (and therefore a shared
/// `&T`) at the same time; additional callers block inside `acquire`.
pub struct Semaphore<T> {
    // Number of guards currently alive; stays in `0..=capacity`.
    // Also serves as the futex word that blocked acquirers sleep on.
    usage: AtomicU32,
    // Maximum number of concurrent guard holders.
    capacity: u32,
    // The guarded value; only ever exposed as `&T` through a guard's Deref.
    resource: UnsafeCell<T>,
}

/// RAII token proving its holder occupies one of the semaphore's slots.
/// Dereferences to the guarded resource; the slot is released (and a
/// waiter may be woken) when the guard is dropped.
pub struct SemaphoreGuard<'a, T> {
    semaphore: &'a Semaphore<T>,
}

impl<T> Semaphore<T> {
    /// Creates a semaphore guarding `resource` that allows at most
    /// `capacity` concurrent holders.
    pub fn new(resource: T, capacity: u32) -> Self {
        Self {
            usage: AtomicU32::new(0),
            capacity,
            resource: UnsafeCell::new(resource),
        }
    }

    /// Blocks until a slot is free, then returns a guard occupying one
    /// slot; the guard dereferences to the shared resource.
    pub fn acquire(&self) -> SemaphoreGuard<T> {
        let mut count = self.usage.load(Relaxed);

        loop {
            // `match` on a bool replaced with the idiomatic if/else
            // (clippy::match_bool); behavior is unchanged.
            if count < self.capacity {
                // Try to claim a slot. Acquire on success pairs with the
                // Release decrement in `SemaphoreGuard::drop`.
                match self
                    .usage
                    .compare_exchange(count, count + 1, Acquire, Relaxed)
                {
                    Ok(_) => return SemaphoreGuard { semaphore: self },
                    // Lost the race: retry with the freshly observed value.
                    Err(e) => count = e,
                }
            } else {
                // Semaphore looks full: sleep on the counter. `wait`
                // returns immediately if `usage` no longer equals `count`,
                // so a release that lands between our load and this call
                // cannot be missed (futex-style re-check).
                wait(&self.usage, count);
                count = self.usage.load(Relaxed);
            }
        }
    }
}

unsafe impl<T> Sync for Semaphore<T> where T: Send + Sync {}

impl<T> SemaphoreGuard<'_, T> {
    /// Snapshot of how many guards are currently alive on the parent
    /// semaphore, this one included.
    fn usage(&self) -> u32 {
        let counter = &self.semaphore.usage;
        counter.load(SeqCst)
    }
}

impl<T> Deref for SemaphoreGuard<'_, T> {
    type Target = T;

    /// Shared access to the guarded resource for as long as the guard
    /// (and therefore the occupied slot) stays alive.
    fn deref(&self) -> &Self::Target {
        // SAFETY: the semaphore never creates a `&mut T`, so handing out
        // aliasing shared references is sound; the guard's lifetime ties
        // this borrow to the semaphore, keeping the pointee alive.
        unsafe { &*self.semaphore.resource.get() }
    }
}

impl<T> Drop for SemaphoreGuard<'_, T> {
    /// Releases this guard's slot and wakes a potential waiter.
    fn drop(&mut self) {
        // Release pairs with the Acquire CAS in `Semaphore::acquire`.
        self.semaphore.usage.fetch_sub(1, Release);
        // Wake unconditionally. The previous condition — wake only when
        // `fetch_sub` returned `capacity`, i.e. when dropping from a full
        // semaphore — loses wakeups: if two guards drop while the
        // semaphore is full, only the first observes `capacity`, so a
        // single waiter is woken while the rest stay asleep even though
        // free slots exist. That is a deadlock (and the most plausible
        // cause of the CI hang noted in the test module). `wake_one` on a
        // futex with no waiters is cheap, so correctness wins here.
        wake_one(&self.semaphore.usage);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread::scope;

    // this test passes locally but it's causing github actions to deadlock :(
    // hence we will ignore it when running `cargo test`
    //
    // NOTE(review): the hang is consistent with a lost wakeup in
    // `SemaphoreGuard::drop`, which only calls `wake_one` when the counter
    // was exactly at `capacity` before the decrement; two guards dropping
    // concurrently from a full semaphore then wake only one of the
    // remaining waiters — worth confirming before re-enabling.
    #[ignore]
    #[test]
    fn test_semaphore_capacity() {
        // More threads than slots, so `acquire` must actually block.
        let capacity = 10;
        let thread_count = 25;
        let semaphore = Arc::new(Semaphore::new((), capacity));
        // Spin barrier: counts started threads, then is set to u32::MAX
        // to release them all at once.
        let barrier = Arc::new(AtomicU32::new(0));
        let (sender, receiver) = std::sync::mpsc::channel::<u32>();

        scope(|s| {
            for _ in 0..thread_count {
                let sender = sender.clone();
                let barrier = barrier.clone();
                let semaphore = semaphore.clone();

                s.spawn(move || {
                    // Announce this thread has started (before acquiring,
                    // so all threads register regardless of capacity).
                    // NOTE(review): Acquire ordering on an RMW write is
                    // unusual; Relaxed would likely suffice — confirm.
                    barrier.fetch_add(1, Acquire);
                    let guard = semaphore.acquire();
                    // Wait until the barrier is removed
                    while barrier.load(Relaxed) != u32::MAX {}
                    // Check how many threads are using the barrier
                    let val = guard.usage();
                    sender.send(val).unwrap();
                });
            }

            s.spawn(move || {
                // Wait for threads to start
                while barrier.load(Relaxed) != thread_count {}
                // Release all threads
                barrier.store(u32::MAX, Release);
            });
        });

        let mut results = Vec::new();
        for _ in 0..thread_count {
            results.push(receiver.recv().unwrap());
        }

        // At least one of the threads should see a usage value equal to
        // the capacity of the semaphore
        let max_value = results.iter().max().unwrap();
        assert_eq!(max_value, &capacity);
    }
}