Skip to main content

commonware_utils/
concurrency.rs

1//! Utilities for managing concurrency.
2
3use crate::sync::Mutex;
4use core::{
5    hash::Hash,
6    num::NonZeroU32,
7    sync::atomic::{AtomicU32, Ordering},
8};
9use std::{collections::HashSet, sync::Arc};
10
/// Limit the concurrency of some operation without blocking.
///
/// A [Limiter] hands out at most `max` live [Reservation]s at a time. Acquisition never waits:
/// [Limiter::try_acquire] either succeeds immediately or returns `None`. A slot is returned to
/// the limiter when its [Reservation] is dropped.
#[derive(Debug)]
pub struct Limiter {
    /// Maximum number of reservations that may be held at once.
    max: u32,
    /// Number of reservations currently held. Shared with every outstanding [Reservation] so
    /// that dropping one can release its slot.
    current: Arc<AtomicU32>,
}

impl Limiter {
    /// Create a limiter that allows up to `max` concurrent reservations.
    pub fn new(max: NonZeroU32) -> Self {
        Self {
            max: max.get(),
            current: Arc::new(AtomicU32::new(0)),
        }
    }

    /// Attempt to reserve a slot. Returns `None` when the limiter is saturated.
    ///
    /// The slot stays reserved for as long as the returned [Reservation] is alive; discarding
    /// the result releases the slot immediately.
    #[must_use = "the slot is released as soon as the reservation is dropped"]
    pub fn try_acquire(&self) -> Option<Reservation> {
        // Only increment while strictly below `max`, so the counter never exceeds the limit even
        // under contention. `current + 1` cannot overflow because `current < max <= u32::MAX`.
        self.current
            .fetch_update(Ordering::AcqRel, Ordering::Relaxed, |current| {
                (current < self.max).then_some(current + 1)
            })
            .map(|_| Reservation {
                current: self.current.clone(),
            })
            .ok()
    }
}

/// A reservation for a slot in the [Limiter].
///
/// Dropping the reservation returns its slot to the limiter.
#[derive(Debug)]
#[must_use = "the slot is released as soon as the reservation is dropped"]
pub struct Reservation {
    /// Counter shared with the [Limiter] that issued this reservation.
    current: Arc<AtomicU32>,
}

impl Drop for Reservation {
    fn drop(&mut self) {
        // Every live reservation corresponds to exactly one successful increment in
        // `Limiter::try_acquire`, so this decrement cannot underflow.
        self.current.fetch_sub(1, Ordering::AcqRel);
    }
}
49
50/// Limit the concurrency of some keyed operation without blocking.
51pub struct KeyedLimiter<K: Eq + Hash + Clone> {
52    max: u32,
53    current: Arc<Mutex<HashSet<K>>>,
54}
55
56impl<K: Eq + Hash + Clone> KeyedLimiter<K> {
57    /// Create a limiter that allows up to `max` concurrent reservations.
58    pub fn new(max: NonZeroU32) -> Self {
59        Self {
60            max: max.get(),
61            current: Arc::new(Mutex::new(HashSet::new())),
62        }
63    }
64
65    /// Attempt to reserve a slot for a given key. Returns `None` when the limiter is saturated or
66    /// the key is already reserved.
67    pub fn try_acquire(&self, key: K) -> Option<KeyedReservation<K>> {
68        let mut current = self.current.lock();
69        if current.len() >= self.max as usize {
70            return None;
71        }
72        if !current.insert(key.clone()) {
73            return None;
74        }
75        drop(current);
76
77        Some(KeyedReservation {
78            key,
79            current: self.current.clone(),
80        })
81    }
82}
83
84/// A reservation for a slot in the [KeyedLimiter].
85pub struct KeyedReservation<K: Eq + Hash + Clone> {
86    key: K,
87    current: Arc<Mutex<HashSet<K>>>,
88}
89
90impl<K: Eq + Hash + Clone> Drop for KeyedReservation<K> {
91    fn drop(&mut self) {
92        self.current.lock().remove(&self.key);
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NZU32;

    #[test]
    fn allows_reservations_up_to_max() {
        let limiter = Limiter::new(NZU32!(2));

        let first = limiter
            .try_acquire()
            .expect("first reservation should succeed");
        let second = limiter
            .try_acquire()
            .expect("second reservation should succeed");

        assert!(limiter.try_acquire().is_none());

        drop(second);
        let third = limiter
            .try_acquire()
            .expect("reservation after drop should succeed");

        drop(third);
        drop(first);
    }

    #[test]
    fn releases_reservation_dropped_on_another_thread() {
        let limiter = Limiter::new(NZU32!(1));

        let reservation = limiter
            .try_acquire()
            .expect("first reservation should succeed");
        assert!(limiter.try_acquire().is_none());

        // The slot must be returned no matter which thread drops the reservation.
        std::thread::spawn(move || drop(reservation))
            .join()
            .expect("thread should not panic");

        assert!(limiter.try_acquire().is_some());
    }

    #[test]
    fn allows_reservations_up_to_max_for_key() {
        let limiter = KeyedLimiter::new(NZU32!(2));

        let first = limiter
            .try_acquire(0)
            .expect("first reservation should succeed");
        let second = limiter
            .try_acquire(1)
            .expect("second reservation should succeed");
        assert!(limiter.try_acquire(2).is_none());

        drop(second);
        let third = limiter
            .try_acquire(2)
            .expect("third reservation should succeed");

        drop(third);
        drop(first);
    }

    #[test]
    fn blocks_conflicting_reservations_for_key() {
        let limiter = KeyedLimiter::new(NZU32!(2));

        let first = limiter
            .try_acquire(0)
            .expect("first reservation should succeed");
        assert!(limiter.try_acquire(0).is_none());

        drop(first);
        let second = limiter
            .try_acquire(0)
            .expect("second reservation should succeed");

        drop(second);
    }

    #[test]
    fn rejected_duplicate_key_does_not_consume_capacity() {
        let limiter = KeyedLimiter::new(NZU32!(2));

        let first = limiter
            .try_acquire(0)
            .expect("first reservation should succeed");
        // Repeated rejections for a held key must not eat into the remaining capacity.
        assert!(limiter.try_acquire(0).is_none());
        assert!(limiter.try_acquire(0).is_none());

        let second = limiter
            .try_acquire(1)
            .expect("second distinct key should still fit");
        assert!(limiter.try_acquire(2).is_none());

        drop(second);
        drop(first);
    }
}