// commonware_utils/priority_set.rs
use std::{
    cmp::Ordering,
    collections::{BTreeSet, HashMap, HashSet},
    hash::Hash,
};

/// One `(item, priority)` pair as stored in a `PrioritySet`'s ordered `BTreeSet`.
///
/// Equality (derived) compares both fields. Ordering compares `priority` first
/// and falls back to `item`, so entries that share a priority still have a
/// total, deterministic order that agrees with equality.
#[derive(Eq, PartialEq)]
struct Entry<I: Ord + Hash + Clone, P: Ord + Copy> {
    item: I,
    priority: P,
}

impl<I: Ord + Hash + Clone, P: Ord + Copy> Ord for Entry<I, P> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Priority decides; the item is consulted only to break ties.
        self.priority
            .cmp(&other.priority)
            .then_with(|| self.item.cmp(&other.item))
    }
}

impl<I: Ord + Hash + Clone, P: Ord + Copy> PartialOrd for Entry<I, P> {
    /// Delegates to [`Ord::cmp`], so every pair of entries is comparable.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

29pub struct PrioritySet<I: Ord + Hash + Clone, P: Ord + Copy> {
32 entries: BTreeSet<Entry<I, P>>,
33 keys: HashMap<I, P>,
34}
35
36impl<I: Ord + Hash + Clone, P: Ord + Copy> PrioritySet<I, P> {
37 #[allow(clippy::new_without_default)]
39 pub fn new() -> Self {
40 Self {
41 entries: BTreeSet::new(),
42 keys: HashMap::new(),
43 }
44 }
45
46 pub fn put(&mut self, item: I, priority: P) {
48 let entry = if let Some(old_priority) = self.keys.remove(&item) {
50 let mut old_entry = Entry {
52 item: item.clone(),
53 priority: old_priority,
54 };
55 self.entries.remove(&old_entry);
56
57 old_entry.priority = priority;
59 old_entry
60 } else {
61 Entry { item, priority }
62 };
63
64 self.keys.insert(entry.item.clone(), entry.priority);
66 self.entries.insert(entry);
67 }
68
69 pub fn get(&self, item: &I) -> Option<P> {
71 self.keys.get(item).cloned()
72 }
73
74 pub fn remove(&mut self, item: &I) -> bool {
78 let Some(entry) = self.keys.remove(item).map(|priority| Entry {
79 item: item.clone(),
80 priority,
81 }) else {
82 return false;
83 };
84 assert!(self.entries.remove(&entry));
85 true
86 }
87
88 pub fn reconcile(&mut self, keep: &[I], default: P) {
91 let mut retained: HashSet<_> = keep.iter().collect();
93 let to_remove = self
94 .keys
95 .keys()
96 .filter(|item| !retained.remove(*item))
97 .cloned()
98 .collect::<Vec<_>>();
99 for item in to_remove {
100 let priority = self.keys.remove(&item).unwrap();
101 let entry = Entry { item, priority };
102 self.entries.remove(&entry);
103 }
104
105 for item in retained {
107 self.put(item.clone(), default);
108 }
109 }
110
111 pub fn retain(&mut self, predicate: impl Fn(&I) -> bool) {
113 self.entries.retain(|entry| predicate(&entry.item));
114 self.keys.retain(|key, _| predicate(key));
115 }
116
117 pub fn contains(&self, item: &I) -> bool {
119 self.keys.contains_key(item)
120 }
121
122 pub fn peek(&self) -> Option<(&I, &P)> {
124 self.entries
125 .iter()
126 .next()
127 .map(|entry| (&entry.item, &entry.priority))
128 }
129
130 pub fn pop(&mut self) -> Option<(I, P)> {
132 self.entries.pop_first().map(|entry| {
133 self.keys.remove(&entry.item);
134 (entry.item, entry.priority)
135 })
136 }
137
138 pub fn clear(&mut self) {
140 self.entries.clear();
141 self.keys.clear();
142 }
143
144 pub fn iter(&self) -> impl Iterator<Item = (&I, &P)> {
146 self.entries
147 .iter()
148 .map(|entry| (&entry.item, &entry.priority))
149 }
150
151 pub fn len(&self) -> usize {
153 self.entries.len()
154 }
155
156 pub fn is_empty(&self) -> bool {
158 self.entries.is_empty()
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_put_remove_and_iter() {
        let mut pq = PrioritySet::new();

        // Insert two items; iteration is in ascending priority order.
        let key1 = "key1";
        let key2 = "key2";
        pq.put(key1, Duration::from_secs(10));
        pq.put(key2, Duration::from_secs(5));

        let entries: Vec<_> = pq.iter().collect();
        assert_eq!(entries.len(), 2);
        assert_eq!(*entries[0].0, key2);
        assert_eq!(*entries[1].0, key1);

        // Removing a present item reports success.
        assert!(pq.remove(&key1));

        let entries: Vec<_> = pq.iter().collect();
        assert_eq!(entries.len(), 1);
        assert_eq!(*entries[0].0, key2);

        // Removing it again is a no-op that reports absence.
        assert!(!pq.remove(&key1));

        let entries: Vec<_> = pq.iter().collect();
        assert_eq!(entries.len(), 1);
        assert_eq!(*entries[0].0, key2);
    }

    #[test]
    fn test_update() {
        let mut pq = PrioritySet::new();

        let key = "key";
        pq.put(key, Duration::from_secs(10));
        assert_eq!(pq.get(&key).unwrap(), Duration::from_secs(10));

        // Re-putting the same item replaces its priority instead of duplicating it.
        pq.put(key, Duration::from_secs(5));
        assert_eq!(pq.get(&key).unwrap(), Duration::from_secs(5));

        let entries: Vec<_> = pq.iter().collect();
        assert_eq!(entries.len(), 1);
        assert_eq!(*entries[0].1, Duration::from_secs(5));
    }

    #[test]
    fn test_reconcile() {
        let mut pq = PrioritySet::new();

        let key1 = "key1";
        let key2 = "key2";
        pq.put(key1, Duration::from_secs(10));
        pq.put(key2, Duration::from_secs(5));

        // key1 is kept with its priority, key2 is dropped, key3 gets the default.
        let key3 = "key3";
        pq.reconcile(&[key1, key3], Duration::from_secs(2));

        let entries: Vec<_> = pq.iter().collect();
        assert_eq!(entries.len(), 2);
        assert!(entries
            .iter()
            .any(|e| *e.0 == key1 && *e.1 == Duration::from_secs(10)));
        assert!(entries
            .iter()
            .any(|e| *e.0 == key3 && *e.1 == Duration::from_secs(2)));

        // The dropped item is gone from the lookup index too.
        assert!(!pq.contains(&key2));
        assert_eq!(pq.get(&key2), None);
    }

    #[test]
    fn test_retain() {
        let mut pq = PrioritySet::new();

        pq.put("key1", Duration::from_secs(10));
        pq.put("key2", Duration::from_secs(5));
        pq.put("item3", Duration::from_secs(15));

        pq.retain(|key| key.starts_with("key"));

        assert_eq!(pq.len(), 2);
        assert!(pq.contains(&"key1"));
        assert!(pq.contains(&"key2"));
        assert!(!pq.contains(&"item3"));
        assert_eq!(pq.get(&"item3"), None);

        let entries: Vec<_> = pq.iter().collect();
        assert_eq!(entries.len(), 2);
        assert_eq!(*entries[0].0, "key2");
        assert_eq!(*entries[1].0, "key1");
    }

    #[test]
    fn test_clear() {
        let mut pq = PrioritySet::new();

        pq.put("key1", Duration::from_secs(10));
        pq.put("key2", Duration::from_secs(5));

        pq.clear();

        assert_eq!(pq.len(), 0);
        assert!(pq.is_empty());
        assert!(pq.iter().next().is_none());
        assert!(pq.peek().is_none());
        assert!(!pq.contains(&"key1"));
    }

    #[test]
    fn test_peek_and_pop() {
        let mut pq = PrioritySet::new();

        pq.put("key1", Duration::from_secs(10));
        pq.put("key2", Duration::from_secs(5));

        // The lowest priority is visible without being removed.
        assert_eq!(pq.peek(), Some((&"key2", &Duration::from_secs(5))));
        assert_eq!(pq.len(), 2);

        // Popping drains in ascending priority order and updates lookups.
        assert_eq!(pq.pop(), Some(("key2", Duration::from_secs(5))));
        assert!(!pq.contains(&"key2"));
        assert_eq!(pq.pop(), Some(("key1", Duration::from_secs(10))));
        assert_eq!(pq.pop(), None);
        assert!(pq.is_empty());
    }

    #[test]
    fn test_equal_priorities_ordered_by_item() {
        let mut pq = PrioritySet::new();

        pq.put("b", Duration::from_secs(1));
        pq.put("a", Duration::from_secs(1));
        pq.put("c", Duration::from_secs(0));

        // Ties on priority are broken by the item's own ordering.
        let items: Vec<_> = pq.iter().map(|(item, _)| *item).collect();
        assert_eq!(items, vec!["c", "a", "b"]);
    }
}