commonware_utils/
priority_set.rs1use std::{
2 cmp::Ordering,
3 collections::{BTreeSet, HashMap, HashSet},
4 hash::Hash,
5};
6
/// An item paired with the priority it is ordered by.
///
/// Equality is derived, so two entries are equal only when both the item and
/// the priority are equal.
#[derive(Eq, PartialEq)]
struct Entry<I, P>
where
    I: Ord + Hash + Clone,
    P: Ord + Copy,
{
    /// The stored item.
    item: I,
    /// The priority currently assigned to `item`.
    priority: P,
}
13
14impl<I: Ord + Hash + Clone, P: Ord + Copy> Ord for Entry<I, P> {
15 fn cmp(&self, other: &Self) -> Ordering {
16 match self.priority.cmp(&other.priority) {
17 Ordering::Equal => self.item.cmp(&other.item),
18 other => other,
19 }
20 }
21}
22
23impl<I: Ord + Hash + Clone, V: Ord + Copy> PartialOrd for Entry<I, V> {
24 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
25 Some(self.cmp(other))
26 }
27}
28
29pub struct PrioritySet<I: Ord + Hash + Clone, P: Ord + Copy> {
32 entries: BTreeSet<Entry<I, P>>,
33 keys: HashMap<I, P>,
34}
35
36impl<I: Ord + Hash + Clone, P: Ord + Copy> PrioritySet<I, P> {
37 #[allow(clippy::new_without_default)]
39 pub fn new() -> Self {
40 Self {
41 entries: BTreeSet::new(),
42 keys: HashMap::new(),
43 }
44 }
45
46 pub fn put(&mut self, item: I, priority: P) {
48 let entry = if let Some(old_priority) = self.keys.remove(&item) {
50 let mut old_entry = Entry {
52 item: item.clone(),
53 priority: old_priority,
54 };
55 self.entries.remove(&old_entry);
56
57 old_entry.priority = priority;
59 old_entry
60 } else {
61 Entry { item, priority }
62 };
63
64 self.keys.insert(entry.item.clone(), entry.priority);
66 self.entries.insert(entry);
67 }
68
69 pub fn get(&self, item: &I) -> Option<P> {
71 self.keys.get(item).cloned()
72 }
73
74 pub fn remove(&mut self, item: &I) {
76 let Some(entry) = self.keys.remove(item).map(|priority| Entry {
77 item: item.clone(),
78 priority,
79 }) else {
80 return;
81 };
82 self.entries.remove(&entry);
83 }
84
85 pub fn reconcile(&mut self, keep: &[I], default: P) {
88 let mut retained: HashSet<_> = keep.iter().collect();
90 let to_remove = self
91 .keys
92 .keys()
93 .filter(|item| !retained.remove(*item))
94 .cloned()
95 .collect::<Vec<_>>();
96 for item in to_remove {
97 let priority = self.keys.remove(&item).unwrap();
98 let entry = Entry { item, priority };
99 self.entries.remove(&entry);
100 }
101
102 for item in retained {
104 self.put(item.clone(), default);
105 }
106 }
107
108 pub fn iter(&self) -> impl Iterator<Item = (&I, &P)> {
110 self.entries
111 .iter()
112 .map(|entry| (&entry.item, &entry.priority))
113 }
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Collects the set's items in iteration order.
    fn items_in_order<'a>(set: &PrioritySet<&'a str, Duration>) -> Vec<&'a str> {
        set.iter().map(|(item, _)| *item).collect()
    }

    #[test]
    fn test_put_remove_and_iter() {
        let (first, second) = ("key1", "key2");
        let mut set = PrioritySet::new();
        set.put(first, Duration::from_secs(10));
        set.put(second, Duration::from_secs(5));

        // The item with the smaller duration is yielded first.
        assert_eq!(items_in_order(&set), vec![second, first]);

        set.remove(&first);
        assert_eq!(items_in_order(&set), vec![second]);

        // Removing an item that is no longer present changes nothing.
        set.remove(&first);
        assert_eq!(items_in_order(&set), vec![second]);
    }

    #[test]
    fn test_update() {
        let item = "key";
        let mut set = PrioritySet::new();

        set.put(item, Duration::from_secs(10));
        assert_eq!(set.get(&item), Some(Duration::from_secs(10)));

        // A second put for the same item overwrites its priority.
        set.put(item, Duration::from_secs(5));
        assert_eq!(set.get(&item), Some(Duration::from_secs(5)));

        let priorities: Vec<Duration> = set.iter().map(|(_, priority)| *priority).collect();
        assert_eq!(priorities, vec![Duration::from_secs(5)]);
    }

    #[test]
    fn test_reconcile() {
        let (kept, dropped, added) = ("key1", "key2", "key3");
        let mut set = PrioritySet::new();
        set.put(kept, Duration::from_secs(10));
        set.put(dropped, Duration::from_secs(5));

        set.reconcile(&[kept, added], Duration::from_secs(2));

        let contents: Vec<(&str, Duration)> = set
            .iter()
            .map(|(item, priority)| (*item, *priority))
            .collect();
        assert_eq!(contents.len(), 2);
        // The existing item keeps its priority; the new one gets the default.
        assert!(contents.contains(&(kept, Duration::from_secs(10))));
        assert!(contents.contains(&(added, Duration::from_secs(2))));
    }
}