commonware_utils/
priority_set.rsuse std::{
cmp::Ordering,
collections::{BTreeSet, HashMap, HashSet},
hash::Hash,
};
/// An item paired with its priority, as stored in the ordered index of a
/// `PrioritySet`.
///
/// Entries compare by `priority` first and fall back to `item` on ties. That way,
/// two distinct items sharing a priority are both stored, ordered by item, instead
/// of colliding in a `BTreeSet`. The derived `PartialEq`/`Eq` compare both fields,
/// which is consistent with this ordering.
#[derive(Eq, PartialEq)]
struct Entry<I: Ord + Hash + Clone, P: Ord + Copy> {
    item: I,
    priority: P,
}

impl<I: Ord + Hash + Clone, P: Ord + Copy> Ord for Entry<I, P> {
    fn cmp(&self, other: &Self) -> Ordering {
        self.priority
            .cmp(&other.priority)
            .then_with(|| self.item.cmp(&other.item))
    }
}

impl<I: Ord + Hash + Clone, P: Ord + Copy> PartialOrd for Entry<I, P> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// A set of unique items, each tagged with a priority, iterable in ascending
/// `(priority, item)` order.
///
/// Two indexes are kept in sync by every mutating method:
/// - `entries` orders items by priority (ties broken by the item itself);
/// - `keys` maps each item to its current priority, so the matching entry in
///   `entries` can be located without a scan.
pub struct PrioritySet<I: Ord + Hash + Clone, P: Ord + Copy> {
    /// All items, sorted by `(priority, item)`.
    entries: BTreeSet<Entry<I, P>>,
    /// Current priority of every item in `entries`.
    keys: HashMap<I, P>,
}

impl<I: Ord + Hash + Clone, P: Ord + Copy> Default for PrioritySet<I, P> {
    /// Equivalent to [`PrioritySet::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl<I: Ord + Hash + Clone, P: Ord + Copy> PrioritySet<I, P> {
    /// Creates an empty set.
    pub fn new() -> Self {
        Self {
            entries: BTreeSet::new(),
            keys: HashMap::new(),
        }
    }

    /// Inserts `item` with `priority`.
    ///
    /// If `item` is already present, its priority is replaced and its position
    /// in iteration order is updated accordingly.
    pub fn put(&mut self, item: I, priority: P) {
        // A single `insert` both records the new priority and reports the old one.
        let entry = match self.keys.insert(item.clone(), priority) {
            Some(old_priority) => {
                // The ordered index is keyed by (priority, item), so the stale
                // entry must be removed before the re-prioritized one is added.
                let mut stale = Entry {
                    item,
                    priority: old_priority,
                };
                self.entries.remove(&stale);
                stale.priority = priority;
                stale
            }
            None => Entry { item, priority },
        };
        self.entries.insert(entry);
    }

    /// Returns the current priority of `item`, or `None` if it is not present.
    pub fn get(&self, item: &I) -> Option<P> {
        self.keys.get(item).copied()
    }

    /// Removes `item` from the set. Does nothing if `item` is not present.
    pub fn remove(&mut self, item: &I) {
        // `remove_entry` hands back the owned key, so no clone is needed to
        // build the lookup entry for the ordered index.
        if let Some((item, priority)) = self.keys.remove_entry(item) {
            self.entries.remove(&Entry { item, priority });
        }
    }

    /// Makes the set contain exactly the items in `keep`.
    ///
    /// - Items not listed in `keep` are removed.
    /// - Listed items that are already present keep their current priority.
    /// - Listed items that are missing are inserted with `default`.
    ///
    /// Duplicate items in `keep` are treated as a single item.
    pub fn reconcile(&mut self, keep: &[I], default: P) {
        let wanted: HashSet<&I> = keep.iter().collect();

        // Drop everything not wanted from both indexes, using the same predicate
        // so they stay in sync.
        self.entries.retain(|entry| wanted.contains(&entry.item));
        self.keys.retain(|item, _| wanted.contains(item));

        // Anything wanted but still absent is new and gets the default priority.
        for item in wanted {
            if !self.keys.contains_key(item) {
                self.put(item.clone(), default);
            }
        }
    }

    /// Iterates over `(item, priority)` pairs in ascending priority order,
    /// breaking ties by item order.
    pub fn iter(&self) -> impl Iterator<Item = (&I, &P)> {
        self.entries
            .iter()
            .map(|entry| (&entry.item, &entry.priority))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_put_remove_and_iter() {
        let mut pq = PrioritySet::new();
        let key1 = "key1";
        let key2 = "key2";
        pq.put(key1, Duration::from_secs(10));
        pq.put(key2, Duration::from_secs(5));
        let entries: Vec<_> = pq.iter().collect();
        assert_eq!(entries.len(), 2);
        assert_eq!(*entries[0].0, key2);
        assert_eq!(*entries[1].0, key1);

        pq.remove(&key1);
        let entries: Vec<_> = pq.iter().collect();
        assert_eq!(entries.len(), 1);
        assert_eq!(*entries[0].0, key2);

        // Removing an absent item is a no-op.
        pq.remove(&key1);
        let entries: Vec<_> = pq.iter().collect();
        assert_eq!(entries.len(), 1);
        assert_eq!(*entries[0].0, key2);
    }

    #[test]
    fn test_update() {
        let mut pq = PrioritySet::new();
        let key = "key";
        pq.put(key, Duration::from_secs(10));
        assert_eq!(pq.get(&key).unwrap(), Duration::from_secs(10));
        pq.put(key, Duration::from_secs(5));
        assert_eq!(pq.get(&key).unwrap(), Duration::from_secs(5));
        let entries: Vec<_> = pq.iter().collect();
        assert_eq!(entries.len(), 1);
        assert_eq!(*entries[0].1, Duration::from_secs(5));
    }

    #[test]
    fn test_reconcile() {
        let mut pq = PrioritySet::new();
        let key1 = "key1";
        let key2 = "key2";
        pq.put(key1, Duration::from_secs(10));
        pq.put(key2, Duration::from_secs(5));
        let key3 = "key3";
        pq.reconcile(&[key1, key3], Duration::from_secs(2));
        let entries: Vec<_> = pq.iter().collect();
        assert_eq!(entries.len(), 2);
        assert!(entries
            .iter()
            .any(|e| *e.0 == key1 && *e.1 == Duration::from_secs(10)));
        assert!(entries
            .iter()
            .any(|e| *e.0 == key3 && *e.1 == Duration::from_secs(2)));
        // Items not listed in `keep` must be gone from the lookup index too.
        assert!(pq.get(&key2).is_none());
    }

    #[test]
    fn test_equal_priorities_ordered_by_item() {
        let mut pq = PrioritySet::new();
        pq.put("b", Duration::from_secs(1));
        pq.put("a", Duration::from_secs(1));
        pq.put("c", Duration::from_secs(0));
        let items: Vec<_> = pq.iter().map(|(item, _)| *item).collect();
        assert_eq!(items, vec!["c", "a", "b"]);
    }

    #[test]
    fn test_reconcile_empty_keep_clears_set() {
        let mut pq = PrioritySet::new();
        pq.put("key1", Duration::from_secs(1));
        pq.put("key2", Duration::from_secs(2));
        pq.reconcile(&[], Duration::from_secs(3));
        assert_eq!(pq.iter().count(), 0);
        assert!(pq.get(&"key1").is_none());
        assert!(pq.get(&"key2").is_none());
    }

    #[test]
    fn test_reconcile_duplicate_keep_inserts_once() {
        let mut pq = PrioritySet::new();
        pq.reconcile(&["key", "key"], Duration::from_secs(4));
        let entries: Vec<_> = pq.iter().collect();
        assert_eq!(entries.len(), 1);
        assert_eq!(*entries[0].0, "key");
        assert_eq!(*entries[0].1, Duration::from_secs(4));
    }
}