use std::collections::VecDeque;
/// Send priority of a queued packet.
///
/// A lower discriminant means more urgent: the queues in this module keep one
/// lane per level, indexed by `priority as usize`, and serve lane 0 (`Critical`)
/// first. The derived `Ord` follows the same order (`Critical < … < Low`).
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Priority {
    /// Most urgent level; `PriorityQueue::push` accepts it even past the byte limit.
    Critical = 0,
    /// Served after `Critical`.
    High = 1,
    /// The default level.
    Normal = 2,
    /// Least urgent level; also the fallback for unknown raw values.
    Low = 3,
}

impl Priority {
    /// Number of distinct priority levels (and therefore queue lanes).
    pub const COUNT: usize = 4;

    /// Converts a raw byte into a `Priority`.
    ///
    /// `0`, `1` and `2` map to `Critical`, `High` and `Normal`. Every other value,
    /// including out-of-range bytes, falls back to `Low`, so this never fails.
    #[inline]
    pub fn from_u8(v: u8) -> Self {
        // Only the first three levels have an exact mapping; the rest collapse to Low.
        const EXACT: [Priority; 3] = [Priority::Critical, Priority::High, Priority::Normal];
        EXACT.get(usize::from(v)).copied().unwrap_or(Self::Low)
    }

    /// Suggested transport channel index for packets of this priority.
    ///
    /// `Critical` and `Normal` share channel 0, `High` uses 1 and `Low` uses 2.
    /// NOTE(review): what each channel number means is defined by the caller's
    /// transport, which is not visible here — confirm the Critical/Normal sharing
    /// is intended.
    #[inline]
    pub fn suggested_channel(&self) -> u8 {
        match *self {
            Self::Critical | Self::Normal => 0,
            Self::High => 1,
            Self::Low => 2,
        }
    }
}

impl Default for Priority {
    /// Packets are `Normal` priority unless stated otherwise.
    fn default() -> Self {
        Priority::Normal
    }
}
/// A packet waiting in one of the queues, together with the metadata that was
/// supplied when it was pushed.
#[derive(Debug)]
pub struct QueuedPacket {
/// Peer identifier supplied to `push`; the queues store it but never read it.
/// NOTE(review): presumably the destination peer — confirm with callers.
pub peer_id: u16,
/// Channel index supplied to `push`; stored as-is and not interpreted by the queues.
pub channel: u8,
/// Payload bytes. `data.len()` is what the queues count against their byte total.
pub data: Vec<u8>,
/// Priority lane the packet was queued on.
pub priority: Priority,
}
/// Strict-priority packet queue: one FIFO lane per `Priority`, always served
/// from the most urgent non-empty lane, with a byte limit on non-critical traffic.
pub struct PriorityQueue {
/// One FIFO per priority level, indexed by `Priority as usize` (0 = `Critical`).
queues: [VecDeque<QueuedPacket>; Priority::COUNT],
/// Sum of `data.len()` over all currently queued packets.
total_size: usize,
/// Byte limit that `push` enforces for every priority except `Critical`.
max_size: usize,
}
impl PriorityQueue {
pub fn new() -> Self {
Self::with_capacity(1024)
}
pub fn with_capacity(max_size: usize) -> Self {
Self {
queues: [
VecDeque::with_capacity(64), VecDeque::with_capacity(128), VecDeque::with_capacity(64), VecDeque::with_capacity(32), ],
total_size: 0,
max_size,
}
}
#[inline]
pub fn push(&mut self, priority: Priority, peer_id: u16, channel: u8, data: Vec<u8>) -> bool {
let packet_size = data.len();
if priority != Priority::Critical && self.total_size + packet_size > self.max_size {
return false;
}
self.total_size += packet_size;
self.queues[priority as usize].push_back(QueuedPacket {
peer_id,
channel,
data,
priority,
});
true
}
#[inline]
pub fn pop(&mut self) -> Option<QueuedPacket> {
for queue in &mut self.queues {
if let Some(packet) = queue.pop_front() {
self.total_size = self.total_size.saturating_sub(packet.data.len());
return Some(packet);
}
}
None
}
#[inline]
pub fn peek(&self) -> Option<&QueuedPacket> {
for queue in &self.queues {
if let Some(packet) = queue.front() {
return Some(packet);
}
}
None
}
#[inline]
pub fn is_empty(&self) -> bool {
self.queues.iter().all(|q| q.is_empty())
}
#[inline]
pub fn len(&self) -> usize {
self.queues.iter().map(|q| q.len()).sum()
}
#[inline]
pub fn total_size(&self) -> usize {
self.total_size
}
#[inline]
pub fn count_at(&self, priority: Priority) -> usize {
self.queues[priority as usize].len()
}
pub fn drop_below(&mut self, priority: Priority) {
for p in (priority as usize + 1)..Priority::COUNT {
for packet in self.queues[p].drain(..) {
self.total_size = self.total_size.saturating_sub(packet.data.len());
}
}
}
pub fn clear(&mut self) {
for queue in &mut self.queues {
queue.clear();
}
self.total_size = 0;
}
pub fn drain_budget(&mut self, mut budget: usize) -> Vec<QueuedPacket> {
let mut result = Vec::new();
for queue in &mut self.queues {
while let Some(packet) = queue.front() {
if packet.data.len() > budget && !result.is_empty() {
break;
}
let packet = queue.pop_front().unwrap();
budget = budget.saturating_sub(packet.data.len());
self.total_size = self.total_size.saturating_sub(packet.data.len());
result.push(packet);
if budget == 0 {
return result;
}
}
}
result
}
}
impl Default for PriorityQueue {
fn default() -> Self {
Self::new()
}
}
/// Weighted variant of `PriorityQueue`: each priority lane may be served a limited
/// number of times per cycle, so less urgent lanes still get served while more
/// urgent ones are backlogged.
pub struct WeightedQueue {
/// Underlying lanes and byte accounting.
inner: PriorityQueue,
/// Maximum pops per lane per cycle, indexed by `Priority as usize`.
weights: [usize; Priority::COUNT],
/// Pops already taken from each lane in the current cycle.
counters: [usize; Priority::COUNT],
}
impl WeightedQueue {
    /// Per-lane quotas used by `new`, indexed by `Priority as usize`
    /// (Critical, High, Normal, Low).
    const DEFAULT_WEIGHTS: [usize; Priority::COUNT] = [8, 4, 2, 1];

    /// Creates a weighted queue with weights `[8, 4, 2, 1]`, backed by a
    /// `PriorityQueue::new()` (default byte limit).
    pub fn new() -> Self {
        Self::with_weights(Self::DEFAULT_WEIGHTS)
    }

    /// Creates a weighted queue with custom per-lane quotas.
    ///
    /// `weights[i]` is how many packets lane `i` (`Priority as usize`) may yield
    /// per cycle. A weight of 0 means the lane is only served when a new cycle
    /// starts and it is the most urgent non-empty lane.
    pub fn with_weights(weights: [usize; Priority::COUNT]) -> Self {
        Self {
            inner: PriorityQueue::new(),
            weights,
            counters: [0; Priority::COUNT],
        }
    }

    /// Enqueues a packet; same admission rules as `PriorityQueue::push`.
    #[inline]
    pub fn push(&mut self, priority: Priority, peer_id: u16, channel: u8, data: Vec<u8>) -> bool {
        self.inner.push(priority, peer_id, channel, data)
    }

    /// Pops the next packet according to the per-lane weights.
    ///
    /// The most urgent non-empty lane that still has quota left in the current
    /// cycle is served. When no non-empty lane has quota left, all counters reset
    /// (starting a new cycle) and the most urgent non-empty lane is served. That
    /// reset also happens when every lane is empty, in which case `None` is returned.
    pub fn pop_weighted(&mut self) -> Option<QueuedPacket> {
        let lane = match self.lane_with_quota() {
            Some(lane) => lane,
            None => {
                self.counters = [0; Priority::COUNT];
                self.first_non_empty_lane()?
            }
        };
        self.take_from(lane)
    }

    /// `true` when no packets are queued.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Total number of queued packets across all lanes.
    #[inline]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Index of the most urgent non-empty lane that has not used its quota this cycle.
    fn lane_with_quota(&self) -> Option<usize> {
        (0..Priority::COUNT)
            .find(|&i| !self.inner.queues[i].is_empty() && self.counters[i] < self.weights[i])
    }

    /// Index of the most urgent non-empty lane, ignoring quotas.
    fn first_non_empty_lane(&self) -> Option<usize> {
        self.inner.queues.iter().position(|lane| !lane.is_empty())
    }

    /// Pops the head of `lane`, charging one unit of its quota and keeping the
    /// inner byte total in sync.
    fn take_from(&mut self, lane: usize) -> Option<QueuedPacket> {
        let packet = self.inner.queues[lane].pop_front()?;
        self.counters[lane] += 1;
        self.inner.total_size = self.inner.total_size.saturating_sub(packet.data.len());
        Some(packet)
    }
}
impl Default for WeightedQueue {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_priority_from_u8_and_default() {
        assert_eq!(Priority::from_u8(0), Priority::Critical);
        assert_eq!(Priority::from_u8(1), Priority::High);
        assert_eq!(Priority::from_u8(2), Priority::Normal);
        assert_eq!(Priority::from_u8(3), Priority::Low);
        // Unknown raw values fall back to Low.
        assert_eq!(Priority::from_u8(200), Priority::Low);
        assert_eq!(Priority::default(), Priority::Normal);
    }

    #[test]
    fn test_priority_ordering() {
        let mut queue = PriorityQueue::new();
        queue.push(Priority::Low, 1, 0, vec![1]);
        queue.push(Priority::Critical, 1, 0, vec![2]);
        queue.push(Priority::Normal, 1, 0, vec![3]);
        queue.push(Priority::High, 1, 0, vec![4]);
        assert_eq!(queue.pop().unwrap().data, vec![2]);
        assert_eq!(queue.pop().unwrap().data, vec![4]);
        assert_eq!(queue.pop().unwrap().data, vec![3]);
        assert_eq!(queue.pop().unwrap().data, vec![1]);
        assert!(queue.pop().is_none());
    }

    #[test]
    fn test_same_priority_fifo() {
        let mut queue = PriorityQueue::new();
        queue.push(Priority::Normal, 1, 0, vec![1]);
        queue.push(Priority::Normal, 1, 0, vec![2]);
        queue.push(Priority::Normal, 1, 0, vec![3]);
        assert_eq!(queue.pop().unwrap().data, vec![1]);
        assert_eq!(queue.pop().unwrap().data, vec![2]);
        assert_eq!(queue.pop().unwrap().data, vec![3]);
    }

    #[test]
    fn test_critical_bypass_capacity() {
        let mut queue = PriorityQueue::with_capacity(10);
        assert!(queue.push(Priority::Normal, 1, 0, vec![0; 10]));
        assert!(!queue.push(Priority::Normal, 1, 0, vec![0; 5]));
        assert!(queue.push(Priority::Critical, 1, 0, vec![0; 100]));
        // The rejected push must not be counted; the accepted critical one must.
        assert_eq!(queue.total_size(), 110);
        assert_eq!(queue.len(), 2);
    }

    #[test]
    fn test_size_accounting_and_peek() {
        let mut queue = PriorityQueue::with_capacity(100);
        assert!(queue.push(Priority::High, 1, 0, vec![0; 10]));
        assert!(queue.push(Priority::Low, 2, 1, vec![0; 20]));
        assert_eq!(queue.total_size(), 30);
        assert_eq!(queue.count_at(Priority::High), 1);
        assert_eq!(queue.count_at(Priority::Low), 1);
        // peek returns the next packet without removing it.
        assert_eq!(queue.peek().unwrap().peer_id, 1);
        assert_eq!(queue.len(), 2);
        let popped = queue.pop().unwrap();
        assert_eq!(popped.peer_id, 1);
        assert_eq!(queue.total_size(), 20);
        queue.clear();
        assert!(queue.is_empty());
        assert_eq!(queue.total_size(), 0);
        assert!(queue.pop().is_none());
    }

    #[test]
    fn test_drop_below() {
        let mut queue = PriorityQueue::new();
        queue.push(Priority::Critical, 1, 0, vec![1]);
        queue.push(Priority::High, 1, 0, vec![2]);
        queue.push(Priority::Normal, 1, 0, vec![3]);
        queue.push(Priority::Low, 1, 0, vec![4]);
        queue.drop_below(Priority::High);
        assert_eq!(queue.len(), 2);
        // Dropped packets must be removed from the byte total.
        assert_eq!(queue.total_size(), 2);
        assert_eq!(queue.pop().unwrap().priority, Priority::Critical);
        assert_eq!(queue.pop().unwrap().priority, Priority::High);
    }

    #[test]
    fn test_drain_budget() {
        let mut queue = PriorityQueue::new();
        queue.push(Priority::Critical, 1, 0, vec![0; 10]);
        queue.push(Priority::High, 1, 0, vec![0; 15]);
        queue.push(Priority::Normal, 1, 0, vec![0; 30]);
        let packets = queue.drain_budget(25);
        assert_eq!(packets.len(), 2);
        assert_eq!(packets[0].priority, Priority::Critical);
        assert_eq!(packets[1].priority, Priority::High);
        assert_eq!(queue.len(), 1);
        assert_eq!(queue.total_size(), 30);
    }

    #[test]
    fn test_drain_budget_always_makes_progress() {
        let mut queue = PriorityQueue::new();
        queue.push(Priority::Normal, 1, 0, vec![0; 50]);
        queue.push(Priority::Normal, 1, 0, vec![0; 1]);
        // The oversized head is still taken so callers never stall.
        let packets = queue.drain_budget(10);
        assert_eq!(packets.len(), 1);
        assert_eq!(packets[0].data.len(), 50);
        assert_eq!(queue.len(), 1);
    }

    #[test]
    fn test_weighted_queue() {
        let mut queue = WeightedQueue::new();
        for _ in 0..10 {
            assert!(queue.push(Priority::Critical, 1, 0, vec![1]));
            assert!(queue.push(Priority::High, 1, 0, vec![2]));
            assert!(queue.push(Priority::Normal, 1, 0, vec![3]));
            assert!(queue.push(Priority::Low, 1, 0, vec![4]));
        }
        // One full cycle with default weights [8, 4, 2, 1] is 15 pops.
        let order: Vec<Priority> = (0..15)
            .map(|_| queue.pop_weighted().unwrap().priority)
            .collect();
        let mut counts = [0usize; Priority::COUNT];
        for p in &order {
            counts[*p as usize] += 1;
        }
        assert_eq!(counts, [8, 4, 2, 1]);
        // Within a cycle, more urgent lanes are served first.
        assert!(order.windows(2).all(|w| w[0] <= w[1]));
        // Quotas exhausted: a new cycle starts with Critical again.
        assert_eq!(queue.pop_weighted().unwrap().priority, Priority::Critical);
        // Everything else still drains, and the queue ends up empty.
        let mut remaining = 0;
        while queue.pop_weighted().is_some() {
            remaining += 1;
        }
        assert_eq!(remaining, 40 - 16);
        assert!(queue.is_empty());
    }
}