commonware_utils/
priority_set.rs

1use std::{
2    cmp::Ordering,
3    collections::{BTreeSet, HashMap, HashSet},
4    hash::Hash,
5};
6
/// An entry in the `PrioritySet`.
///
/// Entries are ordered by `priority` first and by `item` second. Because the
/// item breaks ties, two distinct items that share a priority still occupy
/// distinct slots in the backing `BTreeSet` and iterate in a deterministic
/// order.
#[derive(Eq, PartialEq)]
struct Entry<I: Ord + Hash + Clone, P: Ord + Copy> {
    item: I,
    priority: P,
}

impl<I: Ord + Hash + Clone, P: Ord + Copy> Ord for Entry<I, P> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Priority decides; the item only breaks ties.
        self.priority
            .cmp(&other.priority)
            .then_with(|| self.item.cmp(&other.item))
    }
}

// Delegates to `Ord` so the partial and total orderings can never disagree.
impl<I: Ord + Hash + Clone, P: Ord + Copy> PartialOrd for Entry<I, P> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
28
29/// A set that offers efficient iteration over
30/// its elements in priority-ascending order.
31pub struct PrioritySet<I: Ord + Hash + Clone, P: Ord + Copy> {
32    entries: BTreeSet<Entry<I, P>>,
33    keys: HashMap<I, P>,
34}
35
36impl<I: Ord + Hash + Clone, P: Ord + Copy> PrioritySet<I, P> {
37    /// Create a new `PrioritySet`.
38    #[allow(clippy::new_without_default)]
39    pub fn new() -> Self {
40        Self {
41            entries: BTreeSet::new(),
42            keys: HashMap::new(),
43        }
44    }
45
46    /// Insert an item with a priority, overwriting the previous priority if it exists.
47    pub fn put(&mut self, item: I, priority: P) {
48        // Remove old entry, if it exists
49        let entry = if let Some(old_priority) = self.keys.remove(&item) {
50            // Remove the item from the old priority's set
51            let mut old_entry = Entry {
52                item: item.clone(),
53                priority: old_priority,
54            };
55            self.entries.remove(&old_entry);
56
57            // We reuse the entry to avoid another item clone
58            old_entry.priority = priority;
59            old_entry
60        } else {
61            Entry { item, priority }
62        };
63
64        // Insert the entry
65        self.keys.insert(entry.item.clone(), entry.priority);
66        self.entries.insert(entry);
67    }
68
69    /// Get the current priority of an item.
70    pub fn get(&self, item: &I) -> Option<P> {
71        self.keys.get(item).cloned()
72    }
73
74    /// Remove an item from the set.
75    pub fn remove(&mut self, item: &I) {
76        let Some(entry) = self.keys.remove(item).map(|priority| Entry {
77            item: item.clone(),
78            priority,
79        }) else {
80            return;
81        };
82        self.entries.remove(&entry);
83    }
84
85    /// Remove all previously inserted items not included in `keep`
86    /// and add any items not yet seen with a priority of `initial`.
87    pub fn reconcile(&mut self, keep: &[I], default: P) {
88        // Remove items not in keep
89        let mut retained: HashSet<_> = keep.iter().collect();
90        let to_remove = self
91            .keys
92            .keys()
93            .filter(|item| !retained.remove(*item))
94            .cloned()
95            .collect::<Vec<_>>();
96        for item in to_remove {
97            let priority = self.keys.remove(&item).unwrap();
98            let entry = Entry { item, priority };
99            self.entries.remove(&entry);
100        }
101
102        // Add any items not yet removed with the initial priority
103        for item in retained {
104            self.put(item.clone(), default);
105        }
106    }
107
108    /// Returns an iterator over all items in the set in priority-ascending order.
109    pub fn iter(&self) -> impl Iterator<Item = (&I, &P)> {
110        self.entries
111            .iter()
112            .map(|entry| (&entry.item, &entry.priority))
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Shorthand for a whole-second `Duration`.
    fn secs(s: u64) -> Duration {
        Duration::from_secs(s)
    }

    /// Snapshot the set as owned `(item, priority)` pairs, in iteration order.
    fn snapshot(pq: &PrioritySet<&'static str, Duration>) -> Vec<(&'static str, Duration)> {
        pq.iter().map(|(item, priority)| (*item, *priority)).collect()
    }

    #[test]
    fn test_put_remove_and_iter() {
        let mut pq = PrioritySet::new();
        let key1 = "key1";
        let key2 = "key2";
        pq.put(key1, secs(10));
        pq.put(key2, secs(5));

        // Lower priority iterates first.
        assert_eq!(snapshot(&pq), vec![(key2, secs(5)), (key1, secs(10))]);

        // Removing a present item drops it from both iteration and lookup.
        pq.remove(&key1);
        assert_eq!(snapshot(&pq), vec![(key2, secs(5))]);
        assert_eq!(pq.get(&key1), None);

        // Removing an absent item is a no-op.
        pq.remove(&key1);
        assert_eq!(snapshot(&pq), vec![(key2, secs(5))]);
    }

    #[test]
    fn test_update() {
        let mut pq = PrioritySet::new();
        let key = "key";
        pq.put(key, secs(10));
        assert_eq!(pq.get(&key), Some(secs(10)));

        // Re-putting replaces the priority rather than adding a second entry.
        pq.put(key, secs(5));
        assert_eq!(pq.get(&key), Some(secs(5)));
        assert_eq!(snapshot(&pq), vec![(key, secs(5))]);

        // Re-putting with an unchanged priority leaves a single entry.
        pq.put(key, secs(5));
        assert_eq!(snapshot(&pq), vec![(key, secs(5))]);
    }

    #[test]
    fn test_update_reorders() {
        let mut pq = PrioritySet::new();
        pq.put("a", secs(1));
        pq.put("b", secs(2));

        // Raising "a" above "b" moves it to the back.
        pq.put("a", secs(3));
        assert_eq!(snapshot(&pq), vec![("b", secs(2)), ("a", secs(3))]);
    }

    #[test]
    fn test_equal_priorities_ordered_by_item() {
        let mut pq = PrioritySet::new();
        pq.put("b", secs(1));
        pq.put("a", secs(1));
        pq.put("c", secs(0));
        assert_eq!(snapshot(&pq), vec![("c", secs(0)), ("a", secs(1)), ("b", secs(1))]);
    }

    #[test]
    fn test_reconcile() {
        let mut pq = PrioritySet::new();
        let key1 = "key1";
        let key2 = "key2";
        pq.put(key1, secs(10));
        pq.put(key2, secs(5));

        // Introduce a new item and drop an existing one.
        let key3 = "key3";
        pq.reconcile(&[key1, key3], secs(2));

        // key2 is gone, key1 keeps its priority, key3 gets the default.
        assert_eq!(snapshot(&pq), vec![(key3, secs(2)), (key1, secs(10))]);
        assert_eq!(pq.get(&key2), None);
    }

    #[test]
    fn test_reconcile_duplicates_and_empty() {
        let mut pq = PrioritySet::new();
        pq.put("a", secs(4));

        // Duplicates in `keep` add the new item only once.
        pq.reconcile(&["b", "b", "a"], secs(1));
        assert_eq!(snapshot(&pq), vec![("b", secs(1)), ("a", secs(4))]);

        // An empty `keep` clears the set.
        pq.reconcile(&[], secs(1));
        assert!(snapshot(&pq).is_empty());
        assert_eq!(pq.get(&"a"), None);
    }
}