use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::BinaryHeap;
/// Scheduling priority class of a request.
///
/// Every variant carries an explicit integer discriminant (`Low = 1` up to
/// `VIP = 4`). `Priority::from_u8` maps those numbers back to variants and
/// `RequestMetadata::effective_priority` uses the discriminant as the base
/// score, so a larger number means "more urgent".
///
/// The derived `Ord` compares variants by discriminant, so
/// `VIP > High > Normal > Low` even though `VIP` is declared first.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash)]
pub enum Priority {
/// Highest class (discriminant 4, weight 8).
VIP = 4,
/// Above-normal class (discriminant 3, weight 4).
High = 3,
/// Default class used by most tests in this file (discriminant 2, weight 2).
Normal = 2,
/// Lowest class (discriminant 1, weight 1).
Low = 1,
}
impl Priority {
pub fn weight(&self) -> u32 {
match self {
Priority::VIP => 8,
Priority::High => 4,
Priority::Normal => 2,
Priority::Low => 1,
}
}
pub fn from_u8(value: u8) -> Option<Self> {
match value {
1 => Some(Priority::Low),
2 => Some(Priority::Normal),
3 => Some(Priority::High),
4 => Some(Priority::VIP),
_ => None,
}
}
}
/// Descriptive and scheduling data attached to one queued request.
///
/// Instances are created with `RequestMetadata::new` and refined with the
/// `with_*` builder methods; all fields are public and may also be set
/// directly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestMetadata {
/// Identifier of the request; `PriorityQueue::remove_by_id` matches on it.
pub request_id: String,
/// Identifier of the submitting user (not interpreted in this file).
pub user_id: String,
/// Base priority class; `effective_priority` starts from its discriminant.
pub priority: Priority,
// NOTE: two fields share the next line. `deadline_secs` is an optional
// deadline in seconds; `effective_priority` compares it against the seconds
// elapsed since `created_at`, i.e. it is relative to creation time.
/// Creation time in milliseconds since the Unix epoch (set by `new`).
pub created_at: u64, pub deadline_secs: Option<u64>,
/// Estimated token count (defaults to 256 in `new`); summed by
/// `PriorityQueue::stats` to estimate the wait time.
pub estimated_tokens: u32,
/// Identifier of the target model (not interpreted in this file).
pub model_id: String,
/// Free-form tags appended via `with_tag` (not interpreted in this file).
pub tags: Vec<String>,
/// Retry counter; initialised to 0 by `new` and not modified in this file.
pub retry_count: u32,
/// Identifiers appended via `with_dependency` — presumably ids of requests
/// this one depends on; nothing in this file enforces them (TODO confirm).
pub dependencies: Vec<String>,
}
impl RequestMetadata {
    /// Creates metadata for a new request, stamped with the current time.
    ///
    /// Defaults: no deadline, `estimated_tokens = 256`, no tags, no
    /// dependencies and `retry_count = 0`.
    pub fn new(request_id: String, user_id: String, priority: Priority, model_id: String) -> Self {
        Self {
            request_id,
            user_id,
            priority,
            created_at: Self::current_timestamp(),
            deadline_secs: None,
            estimated_tokens: 256,
            model_id,
            tags: Vec::new(),
            retry_count: 0,
            dependencies: Vec::new(),
        }
    }

    /// Sets the deadline, in seconds counted from `created_at`.
    #[must_use]
    pub fn with_deadline(mut self, deadline_secs: u64) -> Self {
        self.deadline_secs = Some(deadline_secs);
        self
    }

    /// Overrides the estimated token count.
    #[must_use]
    pub fn with_estimated_tokens(mut self, tokens: u32) -> Self {
        self.estimated_tokens = tokens;
        self
    }

    /// Appends one tag.
    #[must_use]
    pub fn with_tag(mut self, tag: String) -> Self {
        self.tags.push(tag);
        self
    }

    /// Appends one dependency identifier.
    #[must_use]
    pub fn with_dependency(mut self, dep_id: String) -> Self {
        self.dependencies.push(dep_id);
        self
    }

    /// Computes the priority score used for scheduling at the current time.
    ///
    /// The score is built from:
    /// * the discriminant of `priority` (1..=4) as the base;
    /// * an aging bonus of +1 for every full 10 seconds since `created_at`;
    /// * a deadline floor: with fewer than 10 seconds left (or the deadline
    ///   already passed) the score is at least `VIP + 10`; with fewer than
    ///   30 seconds left it is at least `VIP + 5`.
    ///
    /// The deadline acts as a *floor*, never a cap: a request that has aged
    /// past the deadline score keeps its higher aged score instead of being
    /// demoted when its deadline approaches.
    ///
    /// This reads the system clock, so repeated calls may return increasing
    /// values for the same request.
    pub fn effective_priority(&self) -> i32 {
        const AGING_STEP_SECS: u64 = 10;
        const URGENT_WINDOW_SECS: u64 = 10;
        const SOON_WINDOW_SECS: u64 = 30;
        const URGENT_BONUS: i32 = 10;
        const SOON_BONUS: i32 = 5;

        let age_secs = self.age_ms() / 1000;

        // Clamp before narrowing so an absurdly old timestamp cannot wrap
        // into a negative bonus.
        let aging_bonus = (age_secs / AGING_STEP_SECS).min(i32::MAX as u64) as i32;
        let aged = (self.priority as i32).saturating_add(aging_bonus);

        let top_class = Priority::VIP as i32;
        let deadline_floor = self.deadline_secs.and_then(|deadline_secs| {
            // Saturates at 0 once the deadline has passed.
            let remaining_secs = deadline_secs.saturating_sub(age_secs);
            if remaining_secs < URGENT_WINDOW_SECS {
                Some(top_class + URGENT_BONUS)
            } else if remaining_secs < SOON_WINDOW_SECS {
                Some(top_class + SOON_BONUS)
            } else {
                None
            }
        });

        match deadline_floor {
            Some(floor) => aged.max(floor),
            None => aged,
        }
    }

    /// Current wall-clock time in milliseconds since the Unix epoch.
    ///
    /// Falls back to 0 if the system clock reports a time before the epoch.
    fn current_timestamp() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|elapsed| elapsed.as_millis() as u64)
            .unwrap_or(0)
    }

    /// Milliseconds elapsed since `created_at`; 0 if `created_at` lies in the
    /// future relative to the system clock.
    pub fn age_ms(&self) -> u64 {
        Self::current_timestamp().saturating_sub(self.created_at)
    }
}
/// Internal heap entry: a request's metadata plus its insertion sequence number.
#[derive(Debug, Clone)]
struct QueuedRequest {
// The request as supplied to `PriorityQueue::push`.
metadata: RequestMetadata,
// Value of the queue's `sequence_counter` at push time; `Ord::cmp` uses it
// (reversed) as the tie-breaker so equal-priority entries pop in FIFO order.
sequence: u64,
}
/// Two heap entries are equal only when they are the same enqueued entry.
///
/// Equality is keyed on `sequence` because `Ord::cmp` for this type ends with
/// a comparison of `sequence`: `cmp` can only return `Equal` for entries that
/// share a sequence number. Comparing `request_id` here would make `==`
/// disagree with `cmp` (two entries with the same id but different sequence
/// numbers would be "equal" yet ordered), which breaks the consistency that
/// `PartialEq`, `PartialOrd` and `Ord` are required to maintain.
impl PartialEq for QueuedRequest {
    fn eq(&self, other: &Self) -> bool {
        self.sequence == other.sequence
    }
}
// Marker impl: `Eq` is a supertrait of `Ord`, which `BinaryHeap<QueuedRequest>` requires.
impl Eq for QueuedRequest {}
/// Partial ordering that always defers to the total order defined by `Ord`,
/// so it never returns `None`.
impl PartialOrd for QueuedRequest {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// Total order used by the max-heap: a "greater" entry is popped first.
///
/// Entries are ranked by `RequestMetadata::effective_priority`; ties go to the
/// entry with the *smaller* sequence number, which yields FIFO order among
/// requests of equal priority.
///
/// NOTE(review): `effective_priority` reads the system clock on every call,
/// so the outcome of a comparison can change while entries sit in the heap,
/// and the two sides are evaluated at slightly different instants. The
/// `BinaryHeap` documentation treats a changing relative order as a logic
/// error; consider snapshotting the score if this becomes a problem.
impl Ord for QueuedRequest {
    fn cmp(&self, other: &Self) -> Ordering {
        let mine = self.metadata.effective_priority();
        let theirs = other.metadata.effective_priority();
        // Operands are swapped on purpose: the earlier push (lower sequence)
        // must compare as greater so it leaves the max-heap first.
        mine.cmp(&theirs)
            .then_with(|| other.sequence.cmp(&self.sequence))
    }
}
/// Aggregate snapshot of a queue, as produced by `PriorityQueue::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueueStats {
/// Number of requests currently queued.
pub queued_count: usize,
/// Sum of `Priority::weight` over all queued requests.
pub total_weight: u32,
/// Estimated time in milliseconds to work through the queued
/// `estimated_tokens`, assuming a throughput of 50 tokens per second.
pub estimated_wait_ms: u64,
}
/// Priority queue of requests backed by a binary max-heap.
///
/// The entry that compares greatest under `QueuedRequest`'s `Ord` impl
/// (highest effective priority, then earliest push) is popped first.
#[derive(Debug)]
pub struct PriorityQueue {
// Heap of queued entries; `pop`/`peek` operate on its greatest element.
heap: BinaryHeap<QueuedRequest>,
// Next sequence number to hand out; advanced (with wrap-around) on every push.
sequence_counter: u64,
}
impl PriorityQueue {
    /// Throughput assumed by [`PriorityQueue::stats`] when estimating the
    /// wait time, in tokens per second.
    const TOKENS_PER_SECOND: u64 = 50;

    /// Creates an empty queue whose sequence numbers start at 0.
    pub fn new() -> Self {
        Self {
            heap: BinaryHeap::new(),
            sequence_counter: 0,
        }
    }

    /// Enqueues a request, tagging it with the next sequence number so that
    /// requests of equal priority are later popped in insertion order.
    pub fn push(&mut self, metadata: RequestMetadata) {
        let sequence = self.sequence_counter;
        // Wrap instead of panicking if the counter is ever exhausted.
        self.sequence_counter = sequence.wrapping_add(1);
        self.heap.push(QueuedRequest { metadata, sequence });
    }

    /// Removes and returns the request at the top of the heap, or `None` if
    /// the queue is empty.
    pub fn pop(&mut self) -> Option<RequestMetadata> {
        self.heap.pop().map(|entry| entry.metadata)
    }

    /// Returns the request at the top of the heap without removing it.
    pub fn peek(&self) -> Option<&RequestMetadata> {
        self.heap.peek().map(|entry| &entry.metadata)
    }

    /// Returns `true` when no requests are queued.
    pub fn is_empty(&self) -> bool {
        self.heap.is_empty()
    }

    /// Number of queued requests.
    pub fn len(&self) -> usize {
        self.heap.len()
    }

    /// Computes aggregate statistics over all queued requests in one pass.
    ///
    /// * `total_weight` is the sum of the priority weights, saturating at
    ///   `u32::MAX`.
    /// * `estimated_wait_ms` is the total of `estimated_tokens` converted to
    ///   milliseconds at [`Self::TOKENS_PER_SECOND`]. Tokens are accumulated
    ///   in `u64` with integer arithmetic, so large queues neither overflow
    ///   nor suffer floating-point truncation.
    pub fn stats(&self) -> QueueStats {
        let (total_weight, total_tokens) =
            self.heap
                .iter()
                .fold((0u32, 0u64), |(weight, tokens), entry| {
                    (
                        weight.saturating_add(entry.metadata.priority.weight()),
                        tokens.saturating_add(u64::from(entry.metadata.estimated_tokens)),
                    )
                });

        QueueStats {
            queued_count: self.heap.len(),
            total_weight,
            estimated_wait_ms: total_tokens.saturating_mul(1000) / Self::TOKENS_PER_SECOND,
        }
    }

    /// Removes every queued request whose `request_id` equals `request_id`
    /// and returns the metadata of one of them (`None` if there was no
    /// match).
    ///
    /// If several entries share the id, all of them are removed; which one is
    /// returned follows the heap's internal storage order.
    ///
    /// The heap is left untouched when nothing matches. Otherwise it is
    /// rebuilt from the remaining entries in a single O(n) heapify instead of
    /// re-pushing every element.
    pub fn remove_by_id(&mut self, request_id: &str) -> Option<RequestMetadata> {
        let is_match = |entry: &QueuedRequest| entry.metadata.request_id == request_id;

        if !self.heap.iter().any(|entry| is_match(entry)) {
            return None;
        }

        let entries = std::mem::take(&mut self.heap).into_vec();
        let (mut matched, kept): (Vec<QueuedRequest>, Vec<QueuedRequest>) =
            entries.into_iter().partition(|entry| is_match(entry));
        self.heap = BinaryHeap::from(kept);

        matched.pop().map(|entry| entry.metadata)
    }

    /// Iterates over the queued requests in arbitrary (heap storage) order,
    /// not in priority order.
    pub fn iter(&self) -> impl Iterator<Item = &RequestMetadata> {
        self.heap.iter().map(|entry| &entry.metadata)
    }

    /// Empties the queue, yielding the removed requests in arbitrary order.
    ///
    /// Per `BinaryHeap::drain`, entries not yet yielded are still removed if
    /// the iterator is dropped early.
    pub fn drain(&mut self) -> impl Iterator<Item = RequestMetadata> + '_ {
        self.heap.drain().map(|entry| entry.metadata)
    }
}
impl Default for PriorityQueue {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request for `"model1"` with the default metadata from
    /// `RequestMetadata::new`.
    fn request(id: &str, user: &str, priority: Priority) -> RequestMetadata {
        RequestMetadata::new(
            id.to_string(),
            user.to_string(),
            priority,
            "model1".to_string(),
        )
    }

    /// Higher priority classes are popped before lower ones.
    #[test]
    fn test_priority_ordering() {
        let mut queue = PriorityQueue::new();
        queue.push(request("req1", "user1", Priority::Normal));
        queue.push(request("req2", "user2", Priority::VIP));
        queue.push(request("req3", "user3", Priority::Low));
        queue.push(request("req4", "user4", Priority::High));

        for expected in ["req2", "req4", "req1", "req3"] {
            assert_eq!(queue.pop().unwrap().request_id, expected);
        }
        assert!(queue.is_empty());
    }

    /// Requests of equal priority are popped in insertion order.
    #[test]
    fn test_fifo_ordering_same_priority() {
        let mut queue = PriorityQueue::new();
        queue.push(request("req1", "user1", Priority::Normal));
        queue.push(request("req2", "user2", Priority::Normal));
        queue.push(request("req3", "user3", Priority::Normal));

        for expected in ["req1", "req2", "req3"] {
            assert_eq!(queue.pop().unwrap().request_id, expected);
        }
    }

    /// A request with an imminent deadline is popped first.
    #[test]
    fn test_deadline_handling() {
        let mut queue = PriorityQueue::new();
        queue.push(request("req1", "user1", Priority::Low));

        let mut urgent_req = request("req2", "user2", Priority::Normal);
        urgent_req.deadline_secs = Some(5);
        queue.push(urgent_req);

        let first = queue.pop().unwrap();
        assert_eq!(first.request_id, "req2");
    }

    /// `stats` reports the entry count and the summed priority weights.
    #[test]
    fn test_queue_stats() {
        let mut queue = PriorityQueue::new();
        queue.push(request("req1", "user1", Priority::VIP));
        queue.push(request("req2", "user2", Priority::High));

        let stats = queue.stats();
        assert_eq!(stats.queued_count, 2);
        assert_eq!(stats.total_weight, 8 + 4);
    }

    /// Removing by id returns the request and keeps the others in order.
    #[test]
    fn test_remove_by_id() {
        let mut queue = PriorityQueue::new();
        queue.push(request("req1", "user1", Priority::Normal));
        queue.push(request("req2", "user2", Priority::Normal));
        queue.push(request("req3", "user3", Priority::Normal));

        let removed = queue.remove_by_id("req2");
        assert!(removed.is_some());
        assert_eq!(removed.unwrap().request_id, "req2");
        assert_eq!(queue.len(), 2);
        assert_eq!(queue.pop().unwrap().request_id, "req1");
        assert_eq!(queue.pop().unwrap().request_id, "req3");
    }

    /// A new queue is empty and has nothing to peek at.
    #[test]
    fn test_empty_queue() {
        let queue = PriorityQueue::new();
        assert!(queue.is_empty());
        assert_eq!(queue.len(), 0);
        assert!(queue.peek().is_none());
    }

    /// Weights follow the 8 / 4 / 2 / 1 table.
    #[test]
    fn test_priority_weight() {
        assert_eq!(Priority::VIP.weight(), 8);
        assert_eq!(Priority::High.weight(), 4);
        assert_eq!(Priority::Normal.weight(), 2);
        assert_eq!(Priority::Low.weight(), 1);
    }

    /// The `with_*` builder methods set their respective fields.
    #[test]
    fn test_request_builder() {
        let req = request("req1", "user1", Priority::Normal)
            .with_deadline(60)
            .with_estimated_tokens(512)
            .with_tag("batch".to_string())
            .with_dependency("dep1".to_string());

        assert_eq!(req.deadline_secs, Some(60));
        assert_eq!(req.estimated_tokens, 512);
        assert_eq!(req.tags, vec!["batch".to_string()]);
        assert_eq!(req.dependencies, vec!["dep1".to_string()]);
    }
}