use std::collections::BinaryHeap;
use inference_core::batch::ExecuteBatch;
use inference_core::tokens::TokenChunk;
/// Scheduling priority attached to a `PriorityRequest`.
///
/// The derived `Ord` follows declaration order, so
/// `Background < Normal < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Priority {
    /// Lowest priority; compares below every other variant.
    Background,
    /// Compares above `Background` and below `High`.
    Normal,
    /// Compares above `Normal` and below `Critical`.
    High,
    /// Highest priority; compares above every other variant.
    Critical,
}
#[derive(Debug)]
pub struct PriorityRequest {
pub priority: Priority,
pub arrival_seq: u64,
pub batch: ExecuteBatch,
pub output: tokio::sync::mpsc::Sender<Result<TokenChunk, inference_core::error::InferenceError>>,
}
/// Equality looks only at the ordering keys, `priority` and `arrival_seq`;
/// the `batch` and `output` fields are ignored.
impl PartialEq for PriorityRequest {
    fn eq(&self, other: &Self) -> bool {
        // Compare both keys at once as a pair.
        (self.priority, self.arrival_seq) == (other.priority, other.arrival_seq)
    }
}
// Marker impl: the `PartialEq` above compares a `Priority` (which derives
// `Eq`) and a `u64`, so equality is total. `Eq` is also a supertrait of
// `Ord`, which `BinaryHeap` requires of its elements.
impl Eq for PriorityRequest {}
/// Partial ordering that simply defers to the total order from `Ord`, so the
/// two can never disagree and the result is always `Some`.
impl PartialOrd for PriorityRequest {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// Total order used by the queue's `BinaryHeap` (a max-heap, so the greatest
/// element is popped first).
///
/// Requests are ranked by `priority`; among equal priorities the one with the
/// smaller `arrival_seq` compares greater.
impl Ord for PriorityRequest {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match self.priority.cmp(&other.priority) {
            // Same priority: compare sequence numbers with the operands
            // swapped so the earlier arrival ranks higher.
            std::cmp::Ordering::Equal => other.arrival_seq.cmp(&self.arrival_seq),
            by_priority => by_priority,
        }
    }
}
/// A bounded priority queue of [`PriorityRequest`]s.
///
/// Backed by a `BinaryHeap`, so the greatest request under
/// `PriorityRequest`'s `Ord` impl is the one returned by `pop`.
pub struct RequestQueue {
    /// Max-heap holding the pending requests.
    inner: BinaryHeap<PriorityRequest>,
    /// Maximum number of requests held at once; `push` rejects beyond this.
    capacity: usize,
    /// Sequence number that the next accepted request will be stamped with.
    next_seq: u64,
}
impl RequestQueue {
    /// Creates an empty queue that accepts at most `capacity` requests.
    ///
    /// `capacity` is a logical limit enforced by [`RequestQueue::push`]; the
    /// underlying heap is not preallocated. A capacity of `0` yields a queue
    /// that rejects every push.
    pub fn new(capacity: usize) -> Self {
        let inner = BinaryHeap::new();
        Self {
            inner,
            capacity,
            next_seq: 0,
        }
    }

    /// Number of requests currently queued.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns `true` when no requests are queued.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Returns `true` when the queue holds `capacity` (or more) requests and
    /// the next [`RequestQueue::push`] would be rejected.
    pub fn is_full(&self) -> bool {
        self.capacity <= self.inner.len()
    }

    /// Enqueues `req`, stamping it with the next arrival sequence number.
    ///
    /// Any `arrival_seq` supplied by the caller is overwritten on success.
    ///
    /// # Errors
    ///
    /// When the queue is full, `req` is handed back unchanged as `Err(req)`
    /// and no sequence number is consumed.
    // The `Err` variant carries the rejected request itself, which is what
    // trips `result_large_err`; hence the allowance.
    #[allow(clippy::result_large_err)]
    pub fn push(&mut self, mut req: PriorityRequest) -> Result<(), PriorityRequest> {
        if self.is_full() {
            Err(req)
        } else {
            let seq = self.next_seq;
            self.next_seq = seq + 1;
            req.arrival_seq = seq;
            self.inner.push(req);
            Ok(())
        }
    }

    /// Removes and returns the greatest queued request (highest priority,
    /// earliest arrival among equals), or `None` if the queue is empty.
    pub fn pop(&mut self) -> Option<PriorityRequest> {
        self.inner.pop()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal request at the given priority.
    ///
    /// The channel's receiving half is dropped when this helper returns;
    /// these tests never send on `output`.
    fn req(priority: Priority) -> PriorityRequest {
        let (output, _receiver) = tokio::sync::mpsc::channel(1);
        let batch = ExecuteBatch {
            request_id: "r".into(),
            model: "m".into(),
            messages: vec![],
            sampling: Default::default(),
            stream: false,
            estimated_tokens: 1,
        };
        PriorityRequest {
            priority,
            arrival_seq: 0,
            batch,
            output,
        }
    }

    /// A later `High` request overtakes earlier `Normal` ones, and among the
    /// `Normal` requests the first one pushed (sequence 0) comes out first.
    #[test]
    fn priority_first_then_fifo() {
        let mut queue = RequestQueue::new(8);
        for priority in [Priority::Normal, Priority::Normal, Priority::High] {
            queue.push(req(priority)).unwrap();
        }

        let head = queue.pop().unwrap();
        assert_eq!(head.priority, Priority::High);

        let next = queue.pop().unwrap();
        assert_eq!(next.priority, Priority::Normal);
        assert_eq!(next.arrival_seq, 0);
    }

    /// A queue of capacity 1 accepts one request and rejects the second.
    #[test]
    fn full_queue_rejects() {
        let mut queue = RequestQueue::new(1);

        let accepted = queue.push(req(Priority::Normal));
        assert!(accepted.is_ok());

        let rejected = queue.push(req(Priority::Normal));
        assert!(rejected.is_err());
    }
}