// commonware_utils/concurrency.rs

use core::{
    hash::Hash,
    num::NonZeroU32,
    sync::atomic::{AtomicU32, Ordering},
};
use std::{
    collections::HashSet,
    sync::{Arc, Mutex, PoisonError},
};

/// Caps the number of [`Reservation`]s that may be outstanding at the same time.
///
/// Each successful [`Limiter::try_acquire`] increments a shared counter; dropping the returned
/// [`Reservation`] decrements it again, freeing the slot for another caller.
#[derive(Debug)]
pub struct Limiter {
    /// Maximum number of reservations that may be held simultaneously.
    max: u32,
    /// Number of reservations currently held; shared with every live [`Reservation`].
    current: Arc<AtomicU32>,
}

impl Limiter {
    /// Creates a limiter that allows at most `max` concurrent reservations.
    pub fn new(max: NonZeroU32) -> Self {
        Self {
            max: max.get(),
            current: Arc::new(AtomicU32::new(0)),
        }
    }

    /// Attempts to take one slot, returning immediately instead of waiting for capacity.
    ///
    /// Returns `None` if `max` reservations are already outstanding. Otherwise the slot stays
    /// taken until the returned [`Reservation`] is dropped.
    #[must_use = "dropping the reservation immediately releases the slot"]
    pub fn try_acquire(&self) -> Option<Reservation> {
        // The bound check and the increment are one atomic read-modify-write, so concurrent
        // callers can never push the counter past `max` (and `current + 1` cannot overflow,
        // since `current < max <= u32::MAX`).
        self.current
            .fetch_update(Ordering::AcqRel, Ordering::Relaxed, |current| {
                (current < self.max).then_some(current + 1)
            })
            .map(|_| Reservation {
                current: Arc::clone(&self.current),
            })
            .ok()
    }
}

/// A slot taken from a [`Limiter`]; the slot is released when this value is dropped.
#[must_use = "the slot is released as soon as the reservation is dropped"]
#[derive(Debug)]
pub struct Reservation {
    /// Counter shared with the issuing [`Limiter`].
    current: Arc<AtomicU32>,
}

impl Drop for Reservation {
    fn drop(&mut self) {
        // Give back the slot taken by the matching increment in `Limiter::try_acquire`.
        self.current.fetch_sub(1, Ordering::AcqRel);
    }
}

/// Caps the number of outstanding [`KeyedReservation`]s and allows at most one per key.
///
/// Reserved keys are tracked in a shared set; dropping a [`KeyedReservation`] removes its key,
/// which frees both that key and one unit of capacity.
#[derive(Debug)]
pub struct KeyedLimiter<K: Eq + Hash + Clone> {
    /// Maximum number of keys that may be reserved simultaneously.
    max: u32,
    /// Keys currently reserved; shared with every live [`KeyedReservation`].
    current: Arc<Mutex<HashSet<K>>>,
}

impl<K: Eq + Hash + Clone> KeyedLimiter<K> {
    /// Creates a limiter that allows at most `max` distinct keys to be reserved at once.
    pub fn new(max: NonZeroU32) -> Self {
        Self {
            max: max.get(),
            current: Arc::new(Mutex::new(HashSet::new())),
        }
    }

    /// Attempts to reserve `key`, returning immediately instead of waiting for capacity.
    ///
    /// Returns `None` if `max` keys are already reserved or if `key` itself is already reserved.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned, i.e. an earlier holder panicked while the set
    /// was locked (for example inside `K`'s `Hash`, `Eq`, or `Clone` implementation). In that
    /// case the set may no longer reflect the live reservations, so no new ones are granted.
    #[must_use = "dropping the reservation immediately releases the key"]
    pub fn try_acquire(&self, key: K) -> Option<KeyedReservation<K>> {
        // Scope the guard so the lock is released before the reservation is built.
        {
            let mut reserved = self
                .current
                .lock()
                .expect("keyed limiter mutex poisoned");
            // `u32 -> usize` is lossless on the 32/64-bit targets this crate supports.
            // Short-circuiting keeps the original order: the capacity check runs before insert.
            if reserved.len() >= self.max as usize || !reserved.insert(key.clone()) {
                return None;
            }
        }

        Some(KeyedReservation {
            key,
            current: Arc::clone(&self.current),
        })
    }
}

/// A key reserved in a [`KeyedLimiter`]; the key is released when this value is dropped.
#[must_use = "the key is released as soon as the reservation is dropped"]
#[derive(Debug)]
pub struct KeyedReservation<K: Eq + Hash + Clone> {
    /// The reserved key, removed from the shared set on drop.
    key: K,
    /// Set shared with the issuing [`KeyedLimiter`].
    current: Arc<Mutex<HashSet<K>>>,
}

impl<K: Eq + Hash + Clone> Drop for KeyedReservation<K> {
    fn drop(&mut self) {
        // Recover the guard even if the mutex is poisoned. Panicking in `drop` aborts the
        // process when it happens during unwinding, and removing this reservation's key only
        // ever frees capacity.
        self.current
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
            .remove(&self.key);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::NZU32;

    /// A saturated `Limiter` refuses new slots until an outstanding reservation goes away.
    #[test]
    fn allows_reservations_up_to_max() {
        let limiter = Limiter::new(NZU32!(2));

        let _held_first = limiter
            .try_acquire()
            .expect("first reservation should succeed");
        {
            let _held_second = limiter
                .try_acquire()
                .expect("second reservation should succeed");
            assert!(limiter.try_acquire().is_none());
        }

        // The inner scope ended, so the second slot is free again.
        let _held_third = limiter
            .try_acquire()
            .expect("reservation after drop should succeed");
    }

    /// A saturated `KeyedLimiter` refuses fresh keys until a reserved key is released.
    #[test]
    fn allows_reservations_up_to_max_for_key() {
        let limiter = KeyedLimiter::new(NZU32!(2));

        let _key_zero = limiter
            .try_acquire(0)
            .expect("first reservation should succeed");
        {
            let _key_one = limiter
                .try_acquire(1)
                .expect("second reservation should succeed");
            assert!(limiter.try_acquire(2).is_none());
        }

        // Releasing key 1 made room for key 2.
        let _key_two = limiter
            .try_acquire(2)
            .expect("third reservation should succeed");
    }

    /// The same key cannot be reserved twice while the first reservation is alive.
    #[test]
    fn blocks_conflicting_reservations_for_key() {
        let limiter = KeyedLimiter::new(NZU32!(2));

        {
            let _original = limiter
                .try_acquire(0)
                .expect("first reservation should succeed");
            assert!(limiter.try_acquire(0).is_none());
        }

        // Once the original reservation is gone, the key can be taken again.
        let _reacquired = limiter
            .try_acquire(0)
            .expect("second reservation should succeed");
    }
}