use std::{
cmp::{Ordering, Reverse},
collections::{BinaryHeap, HashMap},
sync::{Arc, Mutex},
};
use tokio::sync::watch;
/// Scheduling priority of a queue entry.
///
/// Entries are ranked by `track` first and by `group` second. For both fields a larger
/// value means a more important entry (see the `Ord` impl for `Priority`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Priority {
    /// Primary ranking key; larger values rank higher.
    pub track: u8,
    /// Secondary ranking key, only consulted between equal tracks; larger values rank higher.
    pub group: u64,
}

impl Priority {
    /// Creates a priority from its `track` and `group` components.
    ///
    /// This is a `const fn`, so priorities can also be built in `const`/`static` items.
    pub const fn new(track: u8, group: u64) -> Self {
        Self { track, group }
    }
}
/// Orders priorities so that the more important one compares as *smaller*.
///
/// A higher `track` wins; between equal tracks a higher `group` wins. Sorting ascending
/// therefore puts the most important priority first.
impl Ord for Priority {
    fn cmp(&self, other: &Self) -> Ordering {
        // Swapping the operands reverses the natural order of both fields at once;
        // tuples compare lexicographically, so `track` dominates `group`.
        (other.track, other.group).cmp(&(self.track, self.group))
    }
}

/// Total order; delegates to [`Ord::cmp`].
impl PartialOrd for Priority {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// A queue entry: the owning handle's id together with its current priority.
///
/// Every comparison (`==`, `<`, `cmp`, ...) looks only at `priority`; `id` is ignored,
/// so items with equal priorities compare as equal in `binary_search` and in the
/// overflow heap.
#[derive(Debug, Clone)]
struct PriorityItem {
    /// Id of the handle that owns this entry.
    id: usize,
    /// Priority used for all comparisons.
    priority: Priority,
}

impl Ord for PriorityItem {
    fn cmp(&self, other: &Self) -> Ordering {
        Ord::cmp(&self.priority, &other.priority)
    }
}

impl PartialOrd for PriorityItem {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl PartialEq for PriorityItem {
    fn eq(&self, other: &Self) -> bool {
        self.priority.eq(&other.priority)
    }
}

impl Eq for PriorityItem {}
/// Shared handle to the queue state.
///
/// Cloning is cheap: every clone points at the same `Arc<Mutex<..>>`-guarded state.
#[derive(Clone, Default)]
pub struct PriorityQueue {
    state: Arc<Mutex<PriorityState>>,
}

impl PriorityQueue {
    /// Registers a new entry with `priority` and returns its handle.
    ///
    /// The handle reports the entry's rank and removes the entry when it is dropped.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn insert(&self, priority: Priority) -> PriorityHandle {
        let owner = self.clone();
        let mut guard = self.state.lock().unwrap();
        guard.insert(priority, owner)
    }
}
/// Maximum number of ranked entries held in `PriorityState::vec`; any further
/// entries wait in the overflow heap.
const MAX_VEC_SIZE: usize = 255;

/// Where an entry currently lives inside `PriorityState`.
enum Location {
    /// Ranked entry stored at this index of the sorted `vec`.
    Vec(usize),
    /// Unranked entry stored in the `overflow` heap.
    Overflow,
}

/// Mutable queue state, guarded by the mutex inside `PriorityQueue`.
#[derive(Default)]
struct PriorityState {
    /// Ranked entries kept in ascending `Ord` order, i.e. index 0 is the most important.
    vec: Vec<PriorityItem>,
    /// Entries that did not fit into `vec`; `Reverse` makes `peek`/`pop` yield the
    /// most important one.
    overflow: BinaryHeap<Reverse<PriorityItem>>,
    /// Per-id location plus the watch sender used to publish that entry's rank.
    indexes: HashMap<usize, (Location, watch::Sender<u8>)>,
    /// Id handed out to the next inserted entry.
    next_id: usize,
}
impl PriorityState {
    /// Registers a new entry with `priority` and returns the handle that owns it.
    ///
    /// `myself` is the queue that owns this state; it is stored in the handle so the
    /// handle can later re-prioritize or remove the entry.
    pub fn insert(&mut self, priority: Priority, myself: PriorityQueue) -> PriorityHandle {
        let id = self.next_id;
        self.next_id += 1;
        // Every entry starts out "unranked" (u8::MAX); `place` publishes its real rank.
        let (tx, rx) = watch::channel(u8::MAX);
        self.indexes.insert(id, (Location::Overflow, tx));
        self.place(PriorityItem { id, priority });
        PriorityHandle {
            id,
            priority,
            rx,
            queue: myself,
        }
    }

    /// Returns `Some(index)` if `id` currently holds a `vec` slot, `None` if it is in
    /// `overflow`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is not registered.
    fn vec_index(&self, id: usize) -> Option<usize> {
        match self.indexes.get(&id).expect("item not in indexes") {
            (Location::Vec(idx), _) => Some(*idx),
            (Location::Overflow, _) => None,
        }
    }

    /// Publishes the rank of every `vec` entry at index `start` or later.
    fn update_indices_from(&mut self, start: usize) {
        for (idx, item) in self.vec.iter().enumerate().skip(start) {
            Self::update_location(&mut self.indexes, item.id, Location::Vec(idx));
        }
    }

    /// Records `location` for `id` and publishes the matching rank to its handle.
    ///
    /// Receivers are only notified when the rank differs from the last published one.
    /// Takes `indexes` explicitly (instead of `&mut self`) so callers can iterate
    /// `self.vec` while updating.
    ///
    /// # Panics
    ///
    /// Panics if `id` is not registered.
    fn update_location(
        indexes: &mut HashMap<usize, (Location, watch::Sender<u8>)>,
        id: usize,
        location: Location,
    ) {
        let (loc, tx) = indexes.get_mut(&id).expect("item not in indexes");
        *loc = location;
        let new_priority = match loc {
            // Indices that do not fit in a u8 saturate to the "unranked" value.
            Location::Vec(idx) => (*idx).try_into().unwrap_or(u8::MAX),
            Location::Overflow => u8::MAX,
        };
        tx.send_if_modified(|p| {
            if *p != new_priority {
                *p = new_priority;
                true
            } else {
                false
            }
        });
    }

    /// Inserts `item` (whose id is already registered) into `vec` or `overflow`.
    ///
    /// Keeps two invariants: `overflow` is only non-empty while `vec` is full, and no
    /// `overflow` entry outranks any `vec` entry. On a tie the incoming item gets the
    /// `vec` slot.
    fn place(&mut self, item: PriorityItem) {
        let id = item.id;
        if self.vec.len() < MAX_VEC_SIZE {
            // With a free slot, `overflow` can only be non-empty because a `vec` entry was
            // just taken out (see `set_priority`). If the best overflow entry outranks
            // `item`, that entry takes the slot and `item` goes to overflow instead.
            if let Some(Reverse(top)) = self.overflow.peek()
                && *top < item
            {
                let Reverse(promoted) = self.overflow.pop().unwrap();
                self.overflow.push(Reverse(item));
                Self::update_location(&mut self.indexes, id, Location::Overflow);
                let insert_pos = self.vec.binary_search(&promoted).unwrap_or_else(|pos| pos);
                let promoted_id = promoted.id;
                self.vec.insert(insert_pos, promoted);
                Self::update_location(&mut self.indexes, promoted_id, Location::Vec(insert_pos));
                self.update_indices_from(insert_pos + 1);
                return;
            }
            let insert_pos = self.vec.binary_search(&item).unwrap_or_else(|pos| pos);
            self.vec.insert(insert_pos, item);
            Self::update_location(&mut self.indexes, id, Location::Vec(insert_pos));
            self.update_indices_from(insert_pos + 1);
            return;
        }
        // `vec` is full: `item` either stays out or evicts the lowest-ranked `vec` entry.
        let lowest_in_vec = self.vec.last().expect("a full vec is non-empty");
        if item > *lowest_in_vec {
            self.overflow.push(Reverse(item));
            Self::update_location(&mut self.indexes, id, Location::Overflow);
            return;
        }
        let removed = self.vec.pop().unwrap();
        Self::update_location(&mut self.indexes, removed.id, Location::Overflow);
        self.overflow.push(Reverse(removed));
        let insert_pos = self.vec.binary_search(&item).unwrap_or_else(|pos| pos);
        self.vec.insert(insert_pos, item);
        Self::update_location(&mut self.indexes, id, Location::Vec(insert_pos));
        self.update_indices_from(insert_pos + 1);
    }

    /// Removes `id`'s item from `vec` or `overflow` and returns it.
    ///
    /// If the item came from `vec`, the ranks of the entries after it are re-published.
    /// The `indexes` entry for `id` is left as is; callers either re-`place` the item
    /// or drop the entry.
    ///
    /// # Panics
    ///
    /// Panics if `id` is not registered.
    fn extract(&mut self, id: usize) -> PriorityItem {
        match self.vec_index(id) {
            Some(idx) => {
                let item = self.vec.remove(idx);
                self.update_indices_from(idx);
                item
            }
            None => self.take_from_overflow(id),
        }
    }

    /// Removes and returns the overflow entry with `id`, leaving the heap valid.
    ///
    /// # Panics
    ///
    /// Panics if no overflow entry has that id (the heap is left intact).
    fn take_from_overflow(&mut self, id: usize) -> PriorityItem {
        let mut found = None;
        // One pass over the heap instead of draining and re-pushing every entry.
        // Ids come from an incrementing counter, so at most one entry matches.
        self.overflow.retain(|Reverse(entry)| {
            if entry.id == id {
                found = Some(entry.clone());
                false
            } else {
                true
            }
        });
        found.expect("item not found in overflow heap")
    }

    /// Changes the priority of `id` and moves it to the matching position.
    ///
    /// For a ranked entry, the shifted neighbours are published only after the final
    /// layout is known, so a handle whose rank ends up unchanged is not woken.
    fn set_priority(&mut self, id: usize, new_priority: Priority) {
        match self.vec_index(id) {
            Some(old_idx) => {
                // Take the item out WITHOUT re-publishing the ranks after `old_idx`.
                // Publishing now would send shifted neighbours an intermediate rank
                // and then (after `place`) their original one, a spurious wake-up.
                let mut item = self.vec.remove(old_idx);
                item.priority = new_priority;
                // `vec` now has a free slot, so `place` never takes the eviction path.
                self.place(item);
                // `place` published everything after its insert position; this pass
                // covers the entries between `old_idx` and that position. Unchanged
                // ranks are no-ops thanks to `send_if_modified`.
                self.update_indices_from(old_idx);
            }
            None => {
                let mut item = self.take_from_overflow(id);
                item.priority = new_priority;
                self.place(item);
            }
        }
    }

    /// Forgets `id` entirely. If it held a `vec` slot, the best overflow entry is
    /// promoted into the freed slot.
    fn remove(&mut self, id: usize) {
        let was_in_vec = matches!(self.indexes.get(&id), Some((Location::Vec(_), _)));
        self.extract(id);
        self.indexes.remove(&id);
        // No overflow entry outranks a `vec` entry, so the promoted one goes last.
        if was_in_vec && let Some(Reverse(overflow_item)) = self.overflow.pop() {
            let overflow_id = overflow_item.id;
            self.vec.push(overflow_item);
            Self::update_location(&mut self.indexes, overflow_id, Location::Vec(self.vec.len() - 1));
        }
    }
}
/// Owning handle for one queue entry.
///
/// The handle observes the entry's rank through a watch channel and removes the entry
/// from its queue when dropped.
pub struct PriorityHandle {
    /// Key of this entry in the queue's `indexes` map.
    id: usize,
    /// Priority most recently requested through this handle.
    priority: Priority,
    /// Receives the entry's rank (`u8::MAX` while unranked).
    rx: watch::Receiver<u8>,
    /// Queue this entry belongs to.
    queue: PriorityQueue,
}

impl Drop for PriorityHandle {
    fn drop(&mut self) {
        // A poisoned mutex means an earlier lock holder panicked mid-update, so the state
        // may be inconsistent anyway. Skip cleanup rather than panic here: panicking in
        // `drop` while the thread is already unwinding aborts the process.
        if let Ok(mut state) = self.queue.state.lock() {
            state.remove(self.id);
        }
    }
}
impl PriorityHandle {
    /// Returns the entry's current rank and marks it as seen.
    ///
    /// `0` is the most important ranked entry; `u8::MAX` means the entry is currently
    /// unranked (held in the overflow heap).
    pub fn current(&mut self) -> u8 {
        let seen = self.rx.borrow_and_update();
        *seen
    }

    /// Waits until a new rank has been published since the last observation, then
    /// returns it (marking it as seen).
    pub async fn next(&mut self) -> u8 {
        // `changed` only errors once the sender is gone; the last published rank is
        // still the best answer then, so the error is ignored.
        let _ = self.rx.changed().await;
        self.current()
    }

    /// Moves the entry to `new_track`, keeping its group. Does nothing if the track is
    /// unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the queue's mutex is poisoned.
    pub fn set_track(&mut self, new_track: u8) {
        if self.priority.track != new_track {
            self.priority.track = new_track;
            let updated = self.priority;
            self.queue.state.lock().unwrap().set_priority(self.id, updated);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_single_item() {
        let queue = PriorityQueue::default();
        let mut handle = queue.insert(Priority::new(100, 5));
        assert_eq!(handle.current(), 0);
    }

    #[test]
    fn test_track_priority_ordering() {
        let queue = PriorityQueue::default();
        let mut low = queue.insert(Priority::new(50, 0));
        let mut high = queue.insert(Priority::new(255, 0));
        let mut mid = queue.insert(Priority::new(100, 0));
        assert_eq!(high.current(), 0);
        assert_eq!(mid.current(), 1);
        assert_eq!(low.current(), 2);
    }

    #[test]
    fn test_group_priority_on_same_track() {
        let queue = PriorityQueue::default();
        let mut group10 = queue.insert(Priority::new(100, 10));
        let mut group5 = queue.insert(Priority::new(100, 5));
        let mut group1 = queue.insert(Priority::new(100, 1));
        assert_eq!(group10.current(), 0);
        assert_eq!(group5.current(), 1);
        assert_eq!(group1.current(), 2);
    }

    #[test]
    fn test_track_priority_overrides_group() {
        let queue = PriorityQueue::default();
        let mut low_track_high_group = queue.insert(Priority::new(50, 1000));
        let mut high_track_low_group = queue.insert(Priority::new(255, 1));
        assert_eq!(high_track_low_group.current(), 0);
        assert_eq!(low_track_high_group.current(), 1);
    }

    #[test]
    fn test_removal_on_drop() {
        let queue = PriorityQueue::default();
        let mut first = queue.insert(Priority::new(255, 0));
        let mut second = queue.insert(Priority::new(100, 0));
        let mut third = queue.insert(Priority::new(50, 0));
        assert_eq!(first.current(), 0);
        assert_eq!(second.current(), 1);
        assert_eq!(third.current(), 2);
        drop(second);
        assert_eq!(first.current(), 0);
        assert_eq!(third.current(), 1);
    }

    #[test]
    fn test_removal_of_highest_priority() {
        let queue = PriorityQueue::default();
        let mut first = queue.insert(Priority::new(255, 0));
        let mut second = queue.insert(Priority::new(100, 0));
        assert_eq!(first.current(), 0);
        assert_eq!(second.current(), 1);
        drop(first);
        assert_eq!(second.current(), 0);
    }

    #[test]
    fn test_removal_of_lowest_priority() {
        let queue = PriorityQueue::default();
        let mut first = queue.insert(Priority::new(255, 0));
        let mut second = queue.insert(Priority::new(100, 0));
        assert_eq!(first.current(), 0);
        assert_eq!(second.current(), 1);
        drop(second);
        assert_eq!(first.current(), 0);
    }

    /// Same track for every entry; only the group differs.
    #[test]
    fn test_many_items_with_same_priority() {
        let queue = PriorityQueue::default();
        let mut handles: Vec<_> = (0..10)
            .rev()
            .map(|i| queue.insert(Priority::new(100, i)))
            .collect();
        assert_eq!(handles[0].current(), 0);
        for handle in handles.iter_mut() {
            assert!(handle.current() < 10);
        }
    }

    #[test]
    fn test_max_priority_value_overflow() {
        let queue = PriorityQueue::default();
        let mut handles: Vec<_> = (0..300)
            .rev()
            .map(|i| queue.insert(Priority::new(100, i)))
            .collect();
        assert_eq!(handles[0].current(), 0);
        let mut low_priority_count = 0;
        for handle in handles.iter_mut() {
            if handle.current() == u8::MAX {
                low_priority_count += 1;
            }
        }
        assert!(low_priority_count > 0, "Should have some items beyond u8::MAX index");
        assert_eq!(low_priority_count, 45, "Exactly 45 items should overflow (300-255)");
    }

    #[test]
    fn test_complex_ordering() {
        let queue = PriorityQueue::default();
        let mut high_track_high_group = queue.insert(Priority::new(255, 10));
        let mut high_track_low_group = queue.insert(Priority::new(255, 1));
        let mut mid_track_high_group = queue.insert(Priority::new(100, 5));
        let mut mid_track_low_group = queue.insert(Priority::new(100, 1));
        let mut low_track_high_group = queue.insert(Priority::new(50, 100));
        assert_eq!(high_track_high_group.current(), 0);
        assert_eq!(high_track_low_group.current(), 1);
        assert_eq!(mid_track_high_group.current(), 2);
        assert_eq!(mid_track_low_group.current(), 3);
        assert_eq!(low_track_high_group.current(), 4);
    }

    #[tokio::test]
    async fn test_watch_notification_on_overflow_promotion() {
        let queue = PriorityQueue::default();
        let mut fillers: Vec<_> = (0..255)
            .rev()
            .map(|i| queue.insert(Priority::new(100, i + 100)))
            .collect();
        let mut overflow_item = queue.insert(Priority::new(100, 50));
        assert_eq!(overflow_item.current(), u8::MAX);
        let task = tokio::spawn(async move { overflow_item.next().await });
        // Give the task a chance to start waiting. `changed()` is version-based, so the
        // notification is not lost even if the task only runs later; no sleep needed.
        tokio::task::yield_now().await;
        drop(fillers.remove(0));
        let result = task.await.unwrap();
        assert!(result < u8::MAX, "Should be promoted from overflow");
    }

    #[test]
    fn test_interleaved_insertions_and_removals() {
        let queue = PriorityQueue::default();
        let mut h1 = queue.insert(Priority::new(200, 0));
        let h2 = queue.insert(Priority::new(150, 0));
        let mut h3 = queue.insert(Priority::new(100, 0));
        assert_eq!(h1.current(), 0);
        drop(h2);
        assert_eq!(h1.current(), 0);
        assert!(h3.current() < 2);
        let mut h4 = queue.insert(Priority::new(250, 0));
        assert_eq!(h4.current(), 0);
        assert_eq!(h1.current(), 1);
        drop(h4);
        assert_eq!(h1.current(), 0);
    }

    #[test]
    fn test_same_track_and_group() {
        let queue = PriorityQueue::default();
        let mut h1 = queue.insert(Priority::new(100, 5));
        let mut h2 = queue.insert(Priority::new(100, 5));
        let mut h3 = queue.insert(Priority::new(100, 5));
        let indices = [h1.current(), h2.current(), h3.current()];
        assert_eq!(indices.len(), 3);
        assert!(indices.contains(&0));
        assert!(indices.contains(&1));
        assert!(indices.contains(&2));
    }

    #[test]
    fn test_removal_updates_siblings() {
        let queue = PriorityQueue::default();
        let mut root = queue.insert(Priority::new(255, 0));
        let left = queue.insert(Priority::new(100, 0));
        let mut right = queue.insert(Priority::new(100, 0));
        assert_eq!(root.current(), 0);
        drop(left);
        assert_eq!(root.current(), 0);
        assert_eq!(right.current(), 1);
    }

    /// Removing the top entry hands rank 0 to the next-best entry.
    #[test]
    fn test_heap_property_maintained() {
        let queue = PriorityQueue::default();
        let mut handles = vec![
            queue.insert(Priority::new(100, 5)),
            queue.insert(Priority::new(200, 3)),
            queue.insert(Priority::new(50, 10)),
            queue.insert(Priority::new(200, 8)),
            queue.insert(Priority::new(100, 1)),
        ];
        assert_eq!(handles[3].current(), 0);
        drop(handles.remove(3));
        assert_eq!(handles[1].current(), 0);
    }

    #[tokio::test]
    async fn test_notification_on_demotion_to_overflow() {
        let queue = PriorityQueue::default();
        let _fillers: Vec<_> = (0..254)
            .map(|i| queue.insert(Priority::new(100, i + 100)))
            .collect();
        let mut at_edge = queue.insert(Priority::new(100, 50));
        assert_eq!(at_edge.current(), 254);
        let task = tokio::spawn(async move { at_edge.next().await });
        tokio::task::yield_now().await;
        let _high = queue.insert(Priority::new(255, 1000));
        let new_priority = task.await.unwrap();
        assert_eq!(new_priority, u8::MAX, "Should be demoted to overflow");
    }

    #[test]
    fn test_empty_after_all_removed() {
        let queue = PriorityQueue::default();
        let h1 = queue.insert(Priority::new(100, 0));
        let h2 = queue.insert(Priority::new(200, 0));
        let h3 = queue.insert(Priority::new(50, 0));
        drop(h1);
        drop(h2);
        drop(h3);
        let mut h4 = queue.insert(Priority::new(100, 0));
        assert_eq!(h4.current(), 0);
    }

    #[test]
    fn test_set_track_reorders() {
        let queue = PriorityQueue::default();
        let mut s1_g1 = queue.insert(Priority::new(255, 1));
        let mut s1_g2 = queue.insert(Priority::new(255, 2));
        let mut s2_g1 = queue.insert(Priority::new(55, 1));
        let mut s2_g2 = queue.insert(Priority::new(55, 2));
        assert_eq!(s1_g2.current(), 0);
        assert_eq!(s1_g1.current(), 1);
        assert_eq!(s2_g2.current(), 2);
        assert_eq!(s2_g1.current(), 3);
        s1_g1.set_track(55);
        s1_g2.set_track(55);
        s2_g1.set_track(255);
        s2_g2.set_track(255);
        assert_eq!(s2_g2.current(), 0);
        assert_eq!(s2_g1.current(), 1);
        assert_eq!(s1_g2.current(), 2);
        assert_eq!(s1_g1.current(), 3);
    }

    #[tokio::test]
    async fn test_set_track_notifies_other_handles() {
        let queue = PriorityQueue::default();
        let mut h_high = queue.insert(Priority::new(255, 1));
        let mut h_low = queue.insert(Priority::new(50, 1));
        assert_eq!(h_low.current(), 1);
        let task = tokio::spawn(async move { h_low.next().await });
        tokio::task::yield_now().await;
        h_high.set_track(10);
        let new_priority = task.await.unwrap();
        assert_eq!(new_priority, 0, "h_low should be promoted to the top");
    }

    #[test]
    fn test_set_track_self() {
        let queue = PriorityQueue::default();
        let mut h_high = queue.insert(Priority::new(255, 1));
        let mut h_mid = queue.insert(Priority::new(100, 1));
        let mut h_low = queue.insert(Priority::new(50, 1));
        assert_eq!(h_high.current(), 0);
        assert_eq!(h_mid.current(), 1);
        assert_eq!(h_low.current(), 2);
        h_high.set_track(10);
        assert_eq!(h_mid.current(), 0);
        assert_eq!(h_low.current(), 1);
        assert_eq!(h_high.current(), 2);
    }

    #[test]
    fn test_set_track_swaps_demoted_vec_item_with_overflow() {
        let queue = PriorityQueue::default();
        let mut fillers: Vec<_> = (1..=255u64)
            .map(|g| queue.insert(Priority::new(100, g)))
            .collect();
        let mut top = queue.insert(Priority::new(200, 0));
        assert_eq!(top.current(), 0);
        assert_eq!(fillers[0].current(), u8::MAX, "f1 was kicked into overflow");
        top.set_track(0);
        assert!(fillers[0].current() < u8::MAX, "f1 should be promoted back into vec");
        assert_eq!(top.current(), u8::MAX, "demoted top should land in overflow");
    }

    #[test]
    fn test_set_track_lowered_within_vec_no_overflow_disruption() {
        let queue = PriorityQueue::default();
        let mut a = queue.insert(Priority::new(200, 0));
        let mut b = queue.insert(Priority::new(100, 0));
        let mut c = queue.insert(Priority::new(50, 0));
        assert_eq!(a.current(), 0);
        assert_eq!(b.current(), 1);
        assert_eq!(c.current(), 2);
        a.set_track(75);
        assert_eq!(b.current(), 0);
        assert_eq!(a.current(), 1);
        assert_eq!(c.current(), 2);
    }

    #[test]
    fn test_remove_promotes_highest_priority_overflow_item() {
        let queue = PriorityQueue::default();
        let fillers: Vec<_> = (100..355u64)
            .map(|g| queue.insert(Priority::new(200, g)))
            .collect();
        let mut low = queue.insert(Priority::new(100, 1));
        let mut mid = queue.insert(Priority::new(100, 2));
        let mut high = queue.insert(Priority::new(100, 3));
        assert_eq!(low.current(), u8::MAX);
        assert_eq!(mid.current(), u8::MAX);
        assert_eq!(high.current(), u8::MAX);
        drop(fillers);
        assert_eq!(
            high.current(),
            0,
            "highest-priority overflow item should land at index 0"
        );
        assert_eq!(mid.current(), 1);
        assert_eq!(low.current(), 2);
    }

    #[tokio::test]
    async fn test_set_track_notifies_swapped_overflow_item() {
        let queue = PriorityQueue::default();
        let mut fillers: Vec<_> = (1..=255u64)
            .map(|g| queue.insert(Priority::new(100, g)))
            .collect();
        let mut top = queue.insert(Priority::new(200, 0));
        assert_eq!(top.current(), 0);
        let mut f1 = fillers.remove(0);
        assert_eq!(f1.current(), u8::MAX);
        let task = tokio::spawn(async move { f1.next().await });
        tokio::task::yield_now().await;
        top.set_track(0);
        let promoted = task.await.unwrap();
        assert!(promoted < u8::MAX, "f1 should be notified of promotion");
    }

    #[test]
    fn test_set_track_promotes_overflow_item_into_vec() {
        let queue = PriorityQueue::default();
        let mut fillers: Vec<_> = (0..255u64)
            .map(|g| queue.insert(Priority::new(100, g)))
            .collect();
        let mut extra = queue.insert(Priority::new(50, 0));
        assert_eq!(extra.current(), u8::MAX);
        extra.set_track(255);
        assert_eq!(extra.current(), 0);
        // The former lowest-ranked filler (group 0) made room and is now unranked.
        assert_eq!(fillers[0].current(), u8::MAX);
        assert_eq!(fillers[1].current(), 254);
    }

    #[test]
    fn test_drop_overflow_item_keeps_vec_ranks() {
        let queue = PriorityQueue::default();
        let mut fillers: Vec<_> = (0..255u64)
            .rev()
            .map(|g| queue.insert(Priority::new(100, g + 10)))
            .collect();
        let a = queue.insert(Priority::new(100, 5));
        let mut b = queue.insert(Priority::new(100, 4));
        drop(a);
        assert_eq!(b.current(), u8::MAX);
        assert_eq!(fillers[0].current(), 0);
        // Dropping the lowest-ranked filler frees a slot for the remaining overflow item.
        drop(fillers.pop());
        assert_eq!(b.current(), 254);
    }
}