use std::{
    cmp::{Ordering, Reverse},
    collections::{BinaryHeap, HashMap},
    sync::{Arc, Mutex},
};
use tokio::sync::watch;
/// A single entry in the queue.
///
/// Equality and ordering consider only `(track, group)`; `id` merely identifies
/// the entry. The ordering is deliberately *reversed*: a higher `track` (and,
/// on ties, a higher `group`) compares as *smaller*, so an ascending sort
/// places the highest-priority entry first.
#[derive(Debug, Clone)]
struct PriorityItem {
    id: usize,
    track: u8,
    group: u64,
}

impl PriorityItem {
    /// The `(track, group)` pair that equality and ordering are based on.
    fn rank_key(&self) -> (u8, u64) {
        (self.track, self.group)
    }
}

impl PartialEq for PriorityItem {
    fn eq(&self, other: &Self) -> bool {
        self.rank_key() == other.rank_key()
    }
}

impl Eq for PriorityItem {}

impl PartialOrd for PriorityItem {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for PriorityItem {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands swapped on purpose: the larger (track, group) sorts first.
        other.rank_key().cmp(&self.rank_key())
    }
}
/// Cloneable handle to a shared priority queue; every clone refers to the
/// same underlying state.
///
/// Items rank by `track` first and `group` second, higher values first. Each
/// inserted item is tracked through the returned [`PriorityHandle`].
#[derive(Clone, Default)]
pub struct PriorityQueue {
    state: Arc<Mutex<PriorityState>>,
}

impl PriorityQueue {
    /// Inserts an item with the given `track` and `group` and returns the
    /// handle through which its priority can be observed.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub fn insert(&self, track: u8, group: u64) -> PriorityHandle {
        // The handle keeps its own reference to the shared state.
        let owner = Self::clone(self);
        let mut state = self.state.lock().unwrap();
        state.insert(track, group, owner)
    }
}
/// Capacity of the sorted vector. Its indices `0..=254` all fit in a `u8`,
/// which leaves `u8::MAX` (255) free to mean "in overflow".
const MAX_VEC_SIZE: usize = u8::MAX as usize;

/// Where an item currently lives inside the queue state.
enum Location {
    /// Index into the sorted vector; this index is also the published priority.
    Vec(usize),
    /// Waiting in the overflow heap; the published priority is `u8::MAX`.
    Overflow,
}
#[derive(Default)]
struct PriorityState {
vec: Vec<PriorityItem>,
overflow: BinaryHeap<PriorityItem>,
indexes: HashMap<usize, (Location, watch::Sender<u8>)>,
next_id: usize,
}
impl PriorityState {
pub fn insert(&mut self, track: u8, group: u64, myself: PriorityQueue) -> PriorityHandle {
let id = self.next_id;
self.next_id += 1;
let item = PriorityItem { track, group, id };
if self.vec.len() < MAX_VEC_SIZE {
let insert_pos = self.vec.binary_search(&item).unwrap_or_else(|pos| pos);
let initial_priority = insert_pos.try_into().unwrap_or(u8::MAX);
let (tx, rx) = watch::channel(initial_priority);
self.vec.insert(insert_pos, item);
self.indexes.insert(id, (Location::Vec(insert_pos), tx));
self.update_indices_from(insert_pos + 1);
return PriorityHandle { id, rx, queue: myself };
}
let lowest_in_vec = self.vec.last().unwrap();
if item > *lowest_in_vec {
let (tx, rx) = watch::channel(u8::MAX);
self.overflow.push(item);
self.indexes.insert(id, (Location::Overflow, tx));
return PriorityHandle { id, rx, queue: myself };
}
let removed = self.vec.pop().unwrap();
Self::update_location(&mut self.indexes, removed.id, Location::Overflow);
self.overflow.push(removed);
let insert_pos = self.vec.binary_search(&item).unwrap_or_else(|pos| pos);
let initial_priority = insert_pos.try_into().expect("only 255 items allowed");
let (tx, rx) = watch::channel(initial_priority);
self.vec.insert(insert_pos, item);
self.indexes.insert(id, (Location::Vec(insert_pos), tx));
self.update_indices_from(insert_pos + 1);
PriorityHandle { id, rx, queue: myself }
}
fn update_indices_from(&mut self, start: usize) {
for (idx, item) in self.vec.iter().enumerate().skip(start) {
Self::update_location(&mut self.indexes, item.id, Location::Vec(idx));
}
}
fn update_location(indexes: &mut HashMap<usize, (Location, watch::Sender<u8>)>, id: usize, location: Location) {
let (loc, tx) = indexes.get_mut(&id).expect("item not in indexes");
*loc = location;
let new_priority = match loc {
Location::Vec(idx) => (*idx).try_into().unwrap_or(u8::MAX),
Location::Overflow => u8::MAX,
};
let _ = tx.send_if_modified(|p| {
if *p != new_priority {
*p = new_priority;
true
} else {
false
}
});
}
fn remove(&mut self, id: usize) {
let (location, _) = self.indexes.remove(&id).expect("item not in indexes");
if let Location::Vec(pos) = location {
self.vec.remove(pos);
if let Some(overflow_item) = self.overflow.pop() {
let overflow_id = overflow_item.id;
self.vec.push(overflow_item);
Self::update_location(&mut self.indexes, overflow_id, Location::Vec(self.vec.len() - 1));
}
self.update_indices_from(pos);
} else {
let original_len = self.overflow.len();
self.overflow = self.overflow.drain().filter(|item| item.id != id).collect();
assert_eq!(self.overflow.len(), original_len - 1, "item not found in overflow heap");
}
}
}
/// Handle to one item in a `PriorityQueue`.
///
/// The handle observes the item's published priority: 0 is the highest, and
/// `u8::MAX` means the item is in overflow. Dropping the handle removes the
/// item from the queue.
pub struct PriorityHandle {
    id: usize,
    rx: watch::Receiver<u8>,
    queue: PriorityQueue,
}

impl Drop for PriorityHandle {
    /// Removes this handle's item from the shared queue state.
    fn drop(&mut self) {
        match self.queue.state.lock() {
            Ok(mut state) => state.remove(self.id),
            // The mutex was poisoned by a panic while the state was locked.
            // If this drop is itself part of unwinding, panicking again would
            // abort the process, so skip the cleanup in that case only.
            Err(_) if std::thread::panicking() => {}
            Err(poisoned) => panic!("priority queue state poisoned: {}", poisoned),
        }
    }
}

impl PriorityHandle {
    /// Returns the item's current priority and marks it as seen, so a later
    /// call to `next` waits for a newer value.
    pub fn current(&mut self) -> u8 {
        *self.rx.borrow_and_update()
    }

    /// Waits until the item's published priority changes, then returns the new
    /// value and marks it as seen.
    pub async fn next(&mut self) -> u8 {
        // The sender lives in the queue state until this handle's `Drop`
        // removes it, so `changed` is not expected to fail while `self` is alive.
        let _ = self.rx.changed().await;
        *self.rx.borrow_and_update()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_single_item() {
        let queue = PriorityQueue::default();
        let mut handle = queue.insert(100, 5);
        assert_eq!(handle.current(), 0);
    }

    #[test]
    fn test_track_priority_ordering() {
        let queue = PriorityQueue::default();
        let mut low = queue.insert(50, 0);
        let mut high = queue.insert(255, 0);
        let mut mid = queue.insert(100, 0);
        assert_eq!(high.current(), 0);
        assert_eq!(mid.current(), 1);
        assert_eq!(low.current(), 2);
    }

    #[test]
    fn test_group_priority_on_same_track() {
        let queue = PriorityQueue::default();
        let mut group10 = queue.insert(100, 10);
        let mut group5 = queue.insert(100, 5);
        let mut group1 = queue.insert(100, 1);
        assert_eq!(group10.current(), 0);
        assert_eq!(group5.current(), 1);
        assert_eq!(group1.current(), 2);
    }

    #[test]
    fn test_track_priority_overrides_group() {
        let queue = PriorityQueue::default();
        let mut low_track_high_group = queue.insert(50, 1000);
        let mut high_track_low_group = queue.insert(255, 1);
        assert_eq!(high_track_low_group.current(), 0);
        assert_eq!(low_track_high_group.current(), 1);
    }

    #[test]
    fn test_removal_on_drop() {
        let queue = PriorityQueue::default();
        let mut first = queue.insert(255, 0);
        let mut second = queue.insert(100, 0);
        let mut third = queue.insert(50, 0);
        assert_eq!(first.current(), 0);
        assert_eq!(second.current(), 1);
        assert_eq!(third.current(), 2);
        drop(second);
        assert_eq!(first.current(), 0);
        assert_eq!(third.current(), 1);
    }

    #[test]
    fn test_removal_of_highest_priority() {
        let queue = PriorityQueue::default();
        let mut first = queue.insert(255, 0);
        let mut second = queue.insert(100, 0);
        assert_eq!(first.current(), 0);
        assert_eq!(second.current(), 1);
        drop(first);
        assert_eq!(second.current(), 0);
    }

    #[test]
    fn test_removal_of_lowest_priority() {
        let queue = PriorityQueue::default();
        let mut first = queue.insert(255, 0);
        let mut second = queue.insert(100, 0);
        assert_eq!(first.current(), 0);
        assert_eq!(second.current(), 1);
        drop(second);
        assert_eq!(first.current(), 0);
    }

    #[test]
    fn test_many_items_with_same_priority() {
        // Same track with strictly decreasing groups: every insert lands at the
        // tail, so the i-th handle must sit exactly at index i.
        let queue = PriorityQueue::default();
        let mut handles: Vec<_> = (0..10).rev().map(|i| queue.insert(100, i)).collect();
        for (expected, handle) in (0u8..).zip(handles.iter_mut()) {
            assert_eq!(handle.current(), expected);
        }
    }

    #[test]
    fn test_max_priority_value_overflow() {
        // Strictly decreasing groups: the first MAX_VEC_SIZE handles fill the
        // sorted vector in order, and the remaining ones go to overflow (u8::MAX).
        let queue = PriorityQueue::default();
        let mut handles: Vec<_> = (0..300).rev().map(|i| queue.insert(100, i)).collect();
        for (idx, handle) in handles.iter_mut().enumerate() {
            let expected = u8::try_from(idx).unwrap_or(u8::MAX);
            assert_eq!(handle.current(), expected, "handle {}", idx);
        }
        let low_priority_count = handles
            .iter_mut()
            .map(|handle| handle.current())
            .filter(|&priority| priority == u8::MAX)
            .count();
        assert_eq!(low_priority_count, 45, "Exactly 45 items should overflow (300-255)");
    }

    #[test]
    fn test_complex_ordering() {
        let queue = PriorityQueue::default();
        let mut high_track_high_group = queue.insert(255, 10);
        let mut high_track_low_group = queue.insert(255, 1);
        let mut mid_track_high_group = queue.insert(100, 5);
        let mut mid_track_low_group = queue.insert(100, 1);
        let mut low_track_high_group = queue.insert(50, 100);
        assert_eq!(high_track_high_group.current(), 0);
        assert_eq!(high_track_low_group.current(), 1);
        assert_eq!(mid_track_high_group.current(), 2);
        assert_eq!(mid_track_low_group.current(), 3);
        assert_eq!(low_track_high_group.current(), 4);
    }

    #[tokio::test]
    async fn test_watch_notification_on_overflow_promotion() {
        let queue = PriorityQueue::default();
        let mut fillers: Vec<_> = (0..255).rev().map(|i| queue.insert(100, i + 100)).collect();
        let mut overflow_item = queue.insert(100, 50);
        assert_eq!(overflow_item.current(), u8::MAX);
        let task = tokio::spawn(async move { overflow_item.next().await });
        // Yield so the spawned task can start waiting before the change is published.
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        fillers.remove(0);
        let result = task.await.unwrap();
        assert!(result < u8::MAX, "Should be promoted from overflow");
    }

    #[test]
    fn test_interleaved_insertions_and_removals() {
        let queue = PriorityQueue::default();
        let mut h1 = queue.insert(200, 0);
        let h2 = queue.insert(150, 0);
        let mut h3 = queue.insert(100, 0);
        assert_eq!(h1.current(), 0);
        drop(h2);
        assert_eq!(h1.current(), 0);
        assert!(h3.current() < 2);
        let mut h4 = queue.insert(250, 0);
        assert_eq!(h4.current(), 0);
        assert_eq!(h1.current(), 1);
        drop(h4);
        assert_eq!(h1.current(), 0);
    }

    #[test]
    fn test_same_track_and_group() {
        // Equal items may be placed in any relative order, but together they
        // must occupy exactly the first three slots.
        let queue = PriorityQueue::default();
        let mut h1 = queue.insert(100, 5);
        let mut h2 = queue.insert(100, 5);
        let mut h3 = queue.insert(100, 5);
        let mut indices = [h1.current(), h2.current(), h3.current()];
        indices.sort_unstable();
        assert_eq!(indices, [0, 1, 2]);
    }

    #[test]
    fn test_removal_updates_siblings() {
        let queue = PriorityQueue::default();
        let mut root = queue.insert(255, 0);
        let left = queue.insert(100, 0);
        let mut right = queue.insert(100, 0);
        assert_eq!(root.current(), 0);
        drop(left);
        assert_eq!(root.current(), 0);
        assert_eq!(right.current(), 1);
    }

    #[test]
    fn test_heap_property_maintained() {
        let queue = PriorityQueue::default();
        let mut handles = vec![
            queue.insert(100, 5),
            queue.insert(200, 3),
            queue.insert(50, 10),
            queue.insert(200, 8),
            queue.insert(100, 1),
        ];
        assert_eq!(handles[3].current(), 0);
        drop(handles.remove(3));
        assert_eq!(handles[1].current(), 0);
    }

    #[tokio::test]
    async fn test_notification_on_demotion_to_overflow() {
        let queue = PriorityQueue::default();
        let _fillers: Vec<_> = (0..254).map(|i| queue.insert(100, i + 100)).collect();
        let mut at_edge = queue.insert(100, 50);
        assert_eq!(at_edge.current(), 254);
        let task = tokio::spawn(async move { at_edge.next().await });
        // Yield so the spawned task can start waiting before the change is published.
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        let _high = queue.insert(255, 1000);
        let new_priority = task.await.unwrap();
        assert_eq!(new_priority, u8::MAX, "Should be demoted to overflow");
    }

    #[test]
    fn test_empty_after_all_removed() {
        let queue = PriorityQueue::default();
        let h1 = queue.insert(100, 0);
        let h2 = queue.insert(200, 0);
        let h3 = queue.insert(50, 0);
        drop(h1);
        drop(h2);
        drop(h3);
        let mut h4 = queue.insert(100, 0);
        assert_eq!(h4.current(), 0);
    }
}