// commonware_utils/priority_set.rs

use std::{
    cmp::Ordering,
    collections::{BTreeSet, HashMap, HashSet},
    hash::Hash,
    mem,
};

/// A single `(item, priority)` pair stored in a `PrioritySet`.
///
/// Entries compare by priority first and fall back to the item itself, so two
/// distinct items that share a priority never compare as equal.
#[derive(PartialEq, Eq)]
struct Entry<I: Ord + Hash + Clone, P: Ord + Copy> {
    item: I,
    priority: P,
}

impl<I: Ord + Hash + Clone, P: Ord + Copy> Ord for Entry<I, P> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Only consult the item when the priorities tie.
        self.priority
            .cmp(&other.priority)
            .then_with(|| self.item.cmp(&other.item))
    }
}

impl<I: Ord + Hash + Clone, P: Ord + Copy> PartialOrd for Entry<I, P> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // The ordering is total, so defer to `Ord`.
        Some(self.cmp(other))
    }
}

29/// A set that offers efficient iteration over
30/// its elements in priority-ascending order.
31pub struct PrioritySet<I: Ord + Hash + Clone, P: Ord + Copy> {
32    entries: BTreeSet<Entry<I, P>>,
33    keys: HashMap<I, P>,
34}
35
36impl<I: Ord + Hash + Clone, P: Ord + Copy> PrioritySet<I, P> {
37    /// Create a new `PrioritySet`.
38    #[allow(clippy::new_without_default)]
39    pub fn new() -> Self {
40        Self {
41            entries: BTreeSet::new(),
42            keys: HashMap::new(),
43        }
44    }
45
46    /// Insert an item with a priority, overwriting the previous priority if it exists.
47    pub fn put(&mut self, item: I, priority: P) {
48        // Remove old entry, if it exists
49        let entry = if let Some(old_priority) = self.keys.remove(&item) {
50            // Remove the item from the old priority's set
51            let mut old_entry = Entry {
52                item: item.clone(),
53                priority: old_priority,
54            };
55            self.entries.remove(&old_entry);
56
57            // We reuse the entry to avoid another item clone
58            old_entry.priority = priority;
59            old_entry
60        } else {
61            Entry { item, priority }
62        };
63
64        // Insert the entry
65        self.keys.insert(entry.item.clone(), entry.priority);
66        self.entries.insert(entry);
67    }
68
69    /// Get the current priority of an item.
70    pub fn get(&self, item: &I) -> Option<P> {
71        self.keys.get(item).cloned()
72    }
73
74    /// Remove an item from the set.
75    ///
76    /// Returns `true` if the item was present.
77    pub fn remove(&mut self, item: &I) -> bool {
78        let Some(entry) = self.keys.remove(item).map(|priority| Entry {
79            item: item.clone(),
80            priority,
81        }) else {
82            return false;
83        };
84        assert!(self.entries.remove(&entry));
85        true
86    }
87
88    /// Remove all previously inserted items not included in `keep`
89    /// and add any items not yet seen with a priority of `initial`.
90    pub fn reconcile(&mut self, keep: &[I], default: P) {
91        // Remove items not in keep
92        let mut retained: HashSet<_> = keep.iter().collect();
93        let to_remove = self
94            .keys
95            .keys()
96            .filter(|item| !retained.remove(*item))
97            .cloned()
98            .collect::<Vec<_>>();
99        for item in to_remove {
100            let priority = self.keys.remove(&item).unwrap();
101            let entry = Entry { item, priority };
102            self.entries.remove(&entry);
103        }
104
105        // Add any items not yet removed with the initial priority
106        for item in retained {
107            self.put(item.clone(), default);
108        }
109    }
110
111    /// Retains only the items where the key satisfies the predicate.
112    pub fn retain(&mut self, predicate: impl Fn(&I) -> bool) {
113        self.entries.retain(|entry| predicate(&entry.item));
114        self.keys.retain(|key, _| predicate(key));
115    }
116
117    /// Returns `true` if the set contains the item.
118    pub fn contains(&self, item: &I) -> bool {
119        self.keys.contains_key(item)
120    }
121
122    /// Returns the item with the highest priority.
123    pub fn peek(&self) -> Option<(&I, &P)> {
124        self.entries
125            .iter()
126            .next()
127            .map(|entry| (&entry.item, &entry.priority))
128    }
129
130    /// Removes and returns the item with the highest priority.
131    pub fn pop(&mut self) -> Option<(I, P)> {
132        self.entries.pop_first().map(|entry| {
133            self.keys.remove(&entry.item);
134            (entry.item, entry.priority)
135        })
136    }
137
138    /// Remove all items from the set.
139    pub fn clear(&mut self) {
140        self.entries.clear();
141        self.keys.clear();
142    }
143
144    /// Returns an iterator over all items in the set in priority-ascending order.
145    pub fn iter(&self) -> impl Iterator<Item = (&I, &P)> {
146        self.entries
147            .iter()
148            .map(|entry| (&entry.item, &entry.priority))
149    }
150
151    /// Returns the number of items in the set.
152    pub fn len(&self) -> usize {
153        self.entries.len()
154    }
155
156    /// Returns `true` if the set is empty.
157    pub fn is_empty(&self) -> bool {
158        self.entries.is_empty()
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_put_remove_and_iter() {
        // Create a new PrioritySet
        let mut pq = PrioritySet::new();

        // Add items with different priorities
        let key1 = "key1";
        let key2 = "key2";
        pq.put(key1, Duration::from_secs(10));
        pq.put(key2, Duration::from_secs(5));

        // Verify iteration order
        let entries: Vec<_> = pq.iter().collect();
        assert_eq!(entries.len(), 2);
        assert_eq!(*entries[0].0, key2);
        assert_eq!(*entries[1].0, key1);

        // Remove existing item (reported as present)
        assert!(pq.remove(&key1));

        // Verify new iteration order
        let entries: Vec<_> = pq.iter().collect();
        assert_eq!(entries.len(), 1);
        assert_eq!(*entries[0].0, key2);

        // Remove non-existing item (reported as absent)
        assert!(!pq.remove(&key1));

        // Verify iteration order is still the same
        let entries: Vec<_> = pq.iter().collect();
        assert_eq!(entries.len(), 1);
        assert_eq!(*entries[0].0, key2);
    }

    #[test]
    fn test_update() {
        // Create a new PrioritySet
        let mut pq = PrioritySet::new();

        // Add an item with a priority and verify it can be retrieved
        let key = "key";
        pq.put(key, Duration::from_secs(10));
        assert_eq!(pq.get(&key).unwrap(), Duration::from_secs(10));

        // Update the priority and verify it has changed
        pq.put(key, Duration::from_secs(5));
        assert_eq!(pq.get(&key).unwrap(), Duration::from_secs(5));

        // Verify updated priority is in the iteration
        let entries: Vec<_> = pq.iter().collect();
        assert_eq!(entries.len(), 1);
        assert_eq!(*entries[0].1, Duration::from_secs(5));
    }

    #[test]
    fn test_update_reorders() {
        // Two items in ascending order
        let mut pq = PrioritySet::new();
        pq.put("a", Duration::from_secs(1));
        pq.put("b", Duration::from_secs(2));

        // Raising "a" past "b" moves it to the back without duplicating it
        pq.put("a", Duration::from_secs(3));
        let order: Vec<_> = pq.iter().map(|(item, _)| *item).collect();
        assert_eq!(order, vec!["b", "a"]);
        assert_eq!(pq.len(), 2);
    }

    #[test]
    fn test_equal_priorities_ordered_by_item() {
        // Items sharing a priority fall back to the item's own ordering
        let mut pq = PrioritySet::new();
        pq.put("b", Duration::from_secs(1));
        pq.put("a", Duration::from_secs(1));
        let order: Vec<_> = pq.iter().map(|(item, _)| *item).collect();
        assert_eq!(order, vec!["a", "b"]);
    }

    #[test]
    fn test_peek_and_pop() {
        let mut pq = PrioritySet::new();
        pq.put("slow", Duration::from_secs(10));
        pq.put("fast", Duration::from_secs(1));

        // Peeking returns the smallest priority value without removing it
        assert_eq!(pq.peek(), Some((&"fast", &Duration::from_secs(1))));
        assert_eq!(pq.len(), 2);

        // Popping yields items in priority-ascending order and forgets them
        assert_eq!(pq.pop(), Some(("fast", Duration::from_secs(1))));
        assert!(!pq.contains(&"fast"));
        assert_eq!(pq.get(&"fast"), None);
        assert_eq!(pq.pop(), Some(("slow", Duration::from_secs(10))));
        assert_eq!(pq.pop(), None);
        assert!(pq.is_empty());
    }

    #[test]
    fn test_reconcile() {
        // Create a new PrioritySet
        let mut pq = PrioritySet::new();

        // Add 2 items with different priorities
        let key1 = "key1";
        let key2 = "key2";
        pq.put(key1, Duration::from_secs(10));
        pq.put(key2, Duration::from_secs(5));

        // Introduce a new item and remove an existing one
        let key3 = "key3";
        pq.reconcile(&[key1, key3], Duration::from_secs(2));

        // Verify iteration over only the kept items
        let entries: Vec<_> = pq.iter().collect();
        assert_eq!(entries.len(), 2);
        assert!(entries
            .iter()
            .any(|e| *e.0 == key1 && *e.1 == Duration::from_secs(10)));
        assert!(entries
            .iter()
            .any(|e| *e.0 == key3 && *e.1 == Duration::from_secs(2)));
    }

    #[test]
    fn test_reconcile_duplicates_and_empty() {
        let mut pq = PrioritySet::new();
        pq.put("key1", Duration::from_secs(10));

        // Duplicates in `keep` are inserted only once, and unlisted items are dropped
        pq.reconcile(&["key2", "key2"], Duration::from_secs(1));
        assert_eq!(pq.len(), 1);
        assert_eq!(pq.get(&"key2"), Some(Duration::from_secs(1)));
        assert!(!pq.contains(&"key1"));

        // Reconciling against nothing empties the set
        pq.reconcile(&[], Duration::from_secs(1));
        assert!(pq.is_empty());
        assert!(pq.peek().is_none());
    }

    #[test]
    fn test_retain() {
        // Create a new PrioritySet
        let mut pq = PrioritySet::new();

        // Add items with different priorities
        pq.put("key1", Duration::from_secs(10));
        pq.put("key2", Duration::from_secs(5));
        pq.put("item3", Duration::from_secs(15));

        // Retain only items that start with "key"
        pq.retain(|key| key.starts_with("key"));

        // Verify that only "key1" and "key2" are present
        assert_eq!(pq.len(), 2);
        assert!(pq.contains(&"key1"));
        assert!(pq.contains(&"key2"));
        assert!(!pq.contains(&"item3"));

        // Verify iteration order
        let entries: Vec<_> = pq.iter().collect();
        assert_eq!(entries.len(), 2);
        assert_eq!(*entries[0].0, "key2");
        assert_eq!(*entries[1].0, "key1");
    }

    #[test]
    fn test_clear() {
        // Create a new PrioritySet
        let mut pq = PrioritySet::new();

        // Add some items
        pq.put("key1", Duration::from_secs(10));
        pq.put("key2", Duration::from_secs(5));

        // Clear the set
        pq.clear();

        // Verify the set is empty
        assert_eq!(pq.len(), 0);
        assert!(pq.is_empty());
        assert!(pq.iter().next().is_none());
        assert!(pq.peek().is_none());
        assert!(!pq.contains(&"key1"));
    }
}