1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
use std::sync::Arc;

use crate::locking::{Condvar, Mutex};

/// A lightweight, real-time safe counting semaphore.
///
/// At most [`capacity`](Semaphore::capacity) permissions can be held at the same time.
/// Each held permission is represented by a [`SemaphoreGuard`] and is handed back to the
/// semaphore when that guard is dropped.
pub struct Semaphore {
    /// Shared state; every outstanding guard holds its own reference to it.
    inner: Arc<SemaphoreInner>,
}

impl Semaphore {
    /// Creates a new semaphore with the given capacity
    pub fn new(capacity: usize) -> Self {
        Self {
            inner: SemaphoreInner {
                permissions: <_>::default(),
                capacity,
                cv: Condvar::new(),
            }
            .into(),
        }
    }
    /// Tries to acquire permission, returns None if failed
    pub fn try_acquire(&self) -> Option<SemaphoreGuard> {
        let mut count = self.inner.permissions.lock();
        if *count == self.inner.capacity {
            return None;
        }
        *count += 1;
        Some(SemaphoreGuard {
            inner: self.inner.clone(),
        })
    }
    /// Acquires permission, blocks until it is available
    pub fn acquire(&self) -> SemaphoreGuard {
        let mut count = self.inner.permissions.lock();
        while *count == self.inner.capacity {
            self.inner.cv.wait(&mut count);
        }
        *count += 1;
        SemaphoreGuard {
            inner: self.inner.clone(),
        }
    }
    /// Returns the capacity of the semaphore
    pub fn capacity(&self) -> usize {
        self.inner.capacity
    }
    /// Returns the number of available permissions
    pub fn available(&self) -> usize {
        self.inner.capacity - *self.inner.permissions.lock()
    }
    /// Returns the number of used permissions
    pub fn used(&self) -> usize {
        *self.inner.permissions.lock()
    }
    /// For tests only
    #[allow(dead_code)]
    fn is_poisoned(&self) -> bool {
        *self.inner.permissions.lock() > self.inner.capacity
    }
}

/// State shared between a [`Semaphore`] and all of the guards it has issued.
struct SemaphoreInner {
    /// Upper bound on the number of simultaneously held permissions.
    capacity: usize,
    /// Number of permissions currently held.
    permissions: Mutex<usize>,
    /// Notified whenever a permission is released, to wake a blocked `acquire`.
    cv: Condvar,
}

impl SemaphoreInner {
    /// Returns one permission to the pool and notifies one thread waiting on `cv`.
    fn release(&self) {
        let mut in_use = self.permissions.lock();
        *in_use -= 1;
        // Exactly one permission became free, so waking a single waiter is sufficient.
        self.cv.notify_one();
        // The lock is held across the notification and released only here.
        drop(in_use);
    }
}

// The type name intentionally mirrors std's `MutexGuard` naming, even though it repeats
// the module name.
#[allow(clippy::module_name_repetitions)]
/// A guard representing one permission acquired from a [`Semaphore`].
///
/// The permission is released when the guard is dropped. The guard holds its own
/// reference to the semaphore's shared state, so it stays valid even if the
/// [`Semaphore`] value itself is dropped first.
#[must_use = "the permission is released as soon as the guard is dropped"]
pub struct SemaphoreGuard {
    inner: Arc<SemaphoreInner>,
}

impl Drop for SemaphoreGuard {
    fn drop(&mut self) {
        self.inner.release();
    }
}

#[cfg(test)]
mod test {
    use std::time::{Duration, Instant};

    use super::*;

    #[test]
    fn test_semaphore() {
        let sem = Semaphore::new(2);
        assert_eq!(sem.capacity(), 2);
        assert_eq!(sem.available(), 2);
        assert_eq!(sem.used(), 0);
        let g1 = sem.acquire();
        assert_eq!(sem.available(), 1);
        assert_eq!(sem.used(), 1);
        let g2 = sem.acquire();
        assert_eq!(sem.available(), 0);
        assert_eq!(sem.used(), 2);
        let g3 = sem.try_acquire();
        assert!(g3.is_none());
        drop(g1);
        assert_eq!(sem.available(), 1);
        assert_eq!(sem.used(), 1);
        let g4 = sem.acquire();
        assert_eq!(sem.available(), 0);
        assert_eq!(sem.used(), 2);
        // Releasing every guard must bring the semaphore back to its initial state.
        drop(g2);
        drop(g4);
        assert_eq!(sem.available(), 2);
        assert_eq!(sem.used(), 0);
    }

    #[test]
    fn test_semaphore_try_acquire() {
        let sem = Semaphore::new(1);
        let guard = sem.try_acquire();
        assert!(guard.is_some());
        assert_eq!(sem.used(), 1);
        assert!(sem.try_acquire().is_none());
        drop(guard);
        // The temporary guard is dropped at the end of this statement.
        assert!(sem.try_acquire().is_some());
        assert_eq!(sem.used(), 0);
    }

    #[test]
    fn test_semaphore_zero_capacity() {
        let sem = Semaphore::new(0);
        assert_eq!(sem.available(), 0);
        assert!(sem.try_acquire().is_none());
    }

    #[test]
    fn test_semaphore_multithread() {
        let start = Instant::now();
        let sem = Semaphore::new(10);
        let mut tasks = Vec::with_capacity(100);
        for _ in 0..100 {
            let perm = sem.acquire();
            tasks.push(std::thread::spawn(move || {
                let _perm = perm;
                std::thread::sleep(Duration::from_millis(1));
            }));
        }
        // Keep checking the counter while the workers are still releasing their permits.
        while tasks.iter().any(|task| !task.is_finished()) {
            assert!(!sem.is_poisoned(), "Semaphore is poisoned");
            std::hint::spin_loop();
        }
        for task in tasks {
            task.join().expect("worker thread panicked");
        }
        assert_eq!(sem.used(), 0);
        // At most 10 workers (each sleeping >= 1 ms) run at once, so 100 workers need at
        // least 10 ms in total. `as_millis()` truncates, so the old strict `> 10` check
        // on it actually demanded 11 ms and could fail on a fast machine.
        assert!(start.elapsed() >= Duration::from_millis(10));
    }
}