commonware_utils/
concurrency.rs1use crate::sync::Mutex;
4use core::{
5 hash::Hash,
6 num::NonZeroU32,
7 sync::atomic::{AtomicU32, Ordering},
8};
9use std::{collections::HashSet, sync::Arc};
10
/// Caps the number of reservations that may be outstanding at once.
///
/// Each successful [`Limiter::try_acquire`] increments a shared counter; the
/// returned [`Reservation`] decrements it again when dropped.
#[derive(Debug)]
pub struct Limiter {
    /// Maximum number of reservations that may be held concurrently.
    max: u32,
    /// Number of reservations currently held, shared with every outstanding
    /// [`Reservation`].
    current: Arc<AtomicU32>,
}
16
17impl Limiter {
18 pub fn new(max: NonZeroU32) -> Self {
20 Self {
21 max: max.get(),
22 current: Arc::new(AtomicU32::new(0)),
23 }
24 }
25
26 pub fn try_acquire(&self) -> Option<Reservation> {
28 self.current
29 .fetch_update(Ordering::AcqRel, Ordering::Relaxed, |current| {
30 (current < self.max).then_some(current + 1)
31 })
32 .map(|_| Reservation {
33 current: self.current.clone(),
34 })
35 .ok()
36 }
37}
38
/// A slot held in a [`Limiter`].
///
/// The slot is returned to the limiter when this value is dropped.
#[derive(Debug)]
#[must_use = "dropping a Reservation immediately releases its slot"]
pub struct Reservation {
    /// Counter shared with the [`Limiter`] that issued this reservation.
    current: Arc<AtomicU32>,
}
43
44impl Drop for Reservation {
45 fn drop(&mut self) {
46 self.current.fetch_sub(1, Ordering::AcqRel);
47 }
48}
49
50pub struct KeyedLimiter<K: Eq + Hash + Clone> {
52 max: u32,
53 current: Arc<Mutex<HashSet<K>>>,
54}
55
56impl<K: Eq + Hash + Clone> KeyedLimiter<K> {
57 pub fn new(max: NonZeroU32) -> Self {
59 Self {
60 max: max.get(),
61 current: Arc::new(Mutex::new(HashSet::new())),
62 }
63 }
64
65 pub fn try_acquire(&self, key: K) -> Option<KeyedReservation<K>> {
68 let mut current = self.current.lock();
69 if current.len() >= self.max as usize {
70 return None;
71 }
72 if !current.insert(key.clone()) {
73 return None;
74 }
75 drop(current);
76
77 Some(KeyedReservation {
78 key,
79 current: self.current.clone(),
80 })
81 }
82}
83
84pub struct KeyedReservation<K: Eq + Hash + Clone> {
86 key: K,
87 current: Arc<Mutex<HashSet<K>>>,
88}
89
90impl<K: Eq + Hash + Clone> Drop for KeyedReservation<K> {
91 fn drop(&mut self) {
92 self.current.lock().remove(&self.key);
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NZU32;

    #[test]
    fn allows_reservations_up_to_max() {
        let limiter = Limiter::new(NZU32!(2));

        let first = limiter
            .try_acquire()
            .expect("first reservation should succeed");
        let second = limiter
            .try_acquire()
            .expect("second reservation should succeed");

        assert!(limiter.try_acquire().is_none());

        drop(second);
        let third = limiter
            .try_acquire()
            .expect("reservation after drop should succeed");

        drop(third);
        drop(first);
    }

    #[test]
    fn allows_reservations_up_to_max_for_key() {
        let limiter = KeyedLimiter::new(NZU32!(2));

        let first = limiter
            .try_acquire(0)
            .expect("first reservation should succeed");
        let second = limiter
            .try_acquire(1)
            .expect("second reservation should succeed");
        assert!(limiter.try_acquire(2).is_none());

        drop(second);
        let third = limiter
            .try_acquire(2)
            .expect("third reservation should succeed");

        drop(third);
        drop(first);
    }

    #[test]
    fn blocks_conflicting_reservations_for_key() {
        let limiter = KeyedLimiter::new(NZU32!(2));

        let first = limiter
            .try_acquire(0)
            .expect("first reservation should succeed");
        assert!(limiter.try_acquire(0).is_none());

        drop(first);
        let second = limiter
            .try_acquire(0)
            .expect("second reservation should succeed");

        drop(second);
    }

    /// Many threads racing on `try_acquire` must never be granted more than
    /// `max` reservations at once.
    #[test]
    fn never_exceeds_max_under_contention() {
        const THREADS: usize = 8;
        let limiter = Limiter::new(NZU32!(3));
        let barrier = std::sync::Barrier::new(THREADS);

        let (limiter_ref, barrier_ref) = (&limiter, &barrier);
        let granted = std::thread::scope(|scope| {
            let handles: Vec<_> = (0..THREADS)
                .map(|_| {
                    scope.spawn(move || {
                        let reservation = limiter_ref.try_acquire();
                        // Hold any reservation until every thread has made its
                        // attempt, so no slot is released mid-race.
                        barrier_ref.wait();
                        reservation.is_some()
                    })
                })
                .collect();
            handles
                .into_iter()
                .map(|handle| handle.join().expect("worker should not panic"))
                .filter(|&granted| granted)
                .count()
        });
        assert_eq!(granted, 3);

        // Every reservation was dropped when its thread finished.
        let after = limiter
            .try_acquire()
            .expect("reservation after contention should succeed");
        drop(after);
    }
}