use std::collections::VecDeque;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::time::{Duration, Instant};
use crate::sampling::SamplingParams;
/// Serializable point-in-time snapshot of a `BoundedQueue`'s size and
/// lifetime counters, as produced by `BoundedQueue::stats`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct QueueStats {
    // Number of items in the queue at snapshot time.
    pub len: usize,
    // Maximum number of items the queue can hold.
    pub capacity: usize,
    // len / capacity, in [0.0, 1.0].
    pub utilization: f32,
    // Lifetime count of successful enqueues.
    pub total_enqueued: u64,
    // Lifetime count of dequeues (pops plus drained items).
    pub total_dequeued: u64,
    // Lifetime count of items rejected because the queue was full
    // or a timed push expired.
    pub total_dropped: u64,
    // dropped / (enqueued + dropped); 0.0 when no push was ever attempted.
    pub drop_rate: f32,
}
/// Fixed-capacity FIFO queue with blocking and non-blocking push/pop.
///
/// Producers and consumers coordinate through two condition variables:
/// `not_empty` wakes blocked poppers after a push, `not_full` wakes blocked
/// pushers after a pop or drain. The lifetime counters are atomics so they
/// can be read (e.g. by `stats`) without taking the queue lock.
pub struct BoundedQueue<T> {
    // Items paired with their enqueue timestamp, guarded by a mutex.
    queue: Mutex<VecDeque<(T, Instant)>>,
    // Signalled after a push; waited on by `pop_timeout`.
    not_empty: Condvar,
    // Signalled after a pop/drain; waited on by `push_timeout`.
    not_full: Condvar,
    // Maximum item count; enforced by the push paths, not by the VecDeque.
    capacity: usize,
    // Total items successfully enqueued.
    pub total_enqueued: AtomicU64,
    // Total items successfully dequeued (pops plus drained items).
    pub total_dequeued: AtomicU64,
    // Total items rejected by a full `try_push` or an expired `push_timeout`.
    pub total_dropped: AtomicU64,
}
impl<T: Send> BoundedQueue<T> {
    /// Creates an empty queue that holds at most `capacity` items.
    ///
    /// # Panics
    /// Panics if `capacity` is 0 (a zero-capacity queue could never accept
    /// an item, so every push would silently count as a drop).
    pub fn new(capacity: usize) -> Self {
        assert!(capacity > 0, "queue capacity must be at least 1");
        Self {
            queue: Mutex::new(VecDeque::with_capacity(capacity)),
            not_empty: Condvar::new(),
            not_full: Condvar::new(),
            capacity,
            total_enqueued: AtomicU64::new(0),
            total_dequeued: AtomicU64::new(0),
            total_dropped: AtomicU64::new(0),
        }
    }

    /// Non-blocking push. Returns `true` on success; when the queue is full
    /// the item is dropped, `total_dropped` is incremented, and `false` is
    /// returned.
    pub fn try_push(&self, item: T) -> bool {
        let mut guard = self
            .queue
            .lock()
            .expect("queue mutex should not be poisoned");
        if guard.len() >= self.capacity {
            self.total_dropped.fetch_add(1, Ordering::Relaxed);
            return false;
        }
        guard.push_back((item, Instant::now()));
        self.total_enqueued.fetch_add(1, Ordering::Relaxed);
        self.not_empty.notify_one();
        true
    }

    /// Push that blocks for up to `timeout` waiting for free capacity.
    /// Returns `false` (and counts a drop) only if the deadline passes while
    /// the queue is still full.
    pub fn push_timeout(&self, item: T, timeout: Duration) -> bool {
        let deadline = Instant::now() + timeout;
        let mut guard = self
            .queue
            .lock()
            .expect("queue mutex should not be poisoned");
        loop {
            if guard.len() < self.capacity {
                guard.push_back((item, Instant::now()));
                self.total_enqueued.fetch_add(1, Ordering::Relaxed);
                self.not_empty.notify_one();
                return true;
            }
            // Sole failure exit: give up only once the deadline has passed
            // while the queue is (re-checked, under the lock) still full.
            let remaining = match deadline.checked_duration_since(Instant::now()) {
                Some(d) => d,
                None => {
                    self.total_dropped.fetch_add(1, Ordering::Relaxed);
                    return false;
                }
            };
            let (new_guard, _timeout_result) = self
                .not_full
                .wait_timeout(guard, remaining)
                .expect("queue condvar should not be poisoned");
            guard = new_guard;
            // BUGFIX: do not bail out merely because the wait timed out —
            // a consumer may free capacity at the same instant the timeout
            // fires. Loop back and re-check the predicate first (mirroring
            // `pop_timeout`); if the queue is still full, the
            // `checked_duration_since` above returns `None` and the item is
            // dropped there.
        }
    }

    /// Non-blocking pop. Returns `None` when the queue is empty.
    pub fn pop(&self) -> Option<T> {
        let mut guard = self
            .queue
            .lock()
            .expect("queue mutex should not be poisoned");
        guard.pop_front().map(|(item, _enqueued_at)| {
            self.total_dequeued.fetch_add(1, Ordering::Relaxed);
            self.not_full.notify_one();
            item
        })
    }

    /// Pop that blocks for up to `timeout` waiting for an item. Returns
    /// `None` only if the deadline passes while the queue is still empty.
    pub fn pop_timeout(&self, timeout: Duration) -> Option<T> {
        let deadline = Instant::now() + timeout;
        let mut guard = self
            .queue
            .lock()
            .expect("queue mutex should not be poisoned");
        loop {
            if let Some((item, _enqueued_at)) = guard.pop_front() {
                self.total_dequeued.fetch_add(1, Ordering::Relaxed);
                self.not_full.notify_one();
                return Some(item);
            }
            // `?` returns `None` once the deadline has passed.
            let remaining = deadline.checked_duration_since(Instant::now())?;
            let (new_guard, timed_out) = self
                .not_empty
                .wait_timeout(guard, remaining)
                .expect("queue condvar should not be poisoned");
            guard = new_guard;
            // A timed-out wake still re-checks the queue (`is_empty` guard),
            // so an item pushed just as the timeout fired is not lost.
            if timed_out.timed_out() && guard.is_empty() {
                return None;
            }
        }
    }

    /// Current number of items (takes the lock).
    pub fn len(&self) -> usize {
        self.queue
            .lock()
            .expect("queue mutex should not be poisoned")
            .len()
    }

    /// `true` when the queue holds no items.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// `true` when the queue is at (or, defensively, above) capacity.
    pub fn is_full(&self) -> bool {
        self.len() >= self.capacity
    }

    /// Maximum number of items the queue can hold.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Fraction of capacity in use, in [0.0, 1.0].
    /// Safe division: `new` asserts `capacity > 0`.
    pub fn utilization(&self) -> f32 {
        self.len() as f32 / self.capacity as f32
    }

    /// Snapshot of the queue's size and lifetime counters. The counters are
    /// read individually with relaxed loads, so the snapshot is not a single
    /// atomic view, but each field is internally consistent.
    pub fn stats(&self) -> QueueStats {
        let len = self.len();
        let enqueued = self.total_enqueued.load(Ordering::Relaxed);
        let dropped = self.total_dropped.load(Ordering::Relaxed);
        let attempted = enqueued + dropped;
        let drop_rate = if attempted == 0 {
            0.0
        } else {
            dropped as f32 / attempted as f32
        };
        QueueStats {
            len,
            capacity: self.capacity,
            utilization: len as f32 / self.capacity as f32,
            total_enqueued: enqueued,
            total_dequeued: self.total_dequeued.load(Ordering::Relaxed),
            total_dropped: dropped,
            drop_rate,
        }
    }

    /// Removes and returns every queued item in FIFO order, counting them
    /// all as dequeued and waking every blocked pusher.
    pub fn drain(&self) -> Vec<T> {
        let mut guard = self
            .queue
            .lock()
            .expect("queue mutex should not be poisoned");
        let count = guard.len();
        let items: Vec<T> = guard.drain(..).map(|(item, _)| item).collect();
        self.total_dequeued
            .fetch_add(count as u64, Ordering::Relaxed);
        self.not_full.notify_all();
        items
    }
}
/// One queued inference request plus the channel used to hand back its result.
pub struct InferenceWorkItem {
    // Monotonically increasing id assigned by `InferenceQueue::submit`.
    pub id: u64,
    // Token ids of the prompt to run.
    pub prompt_tokens: Vec<u32>,
    // Upper bound on tokens to generate — presumably enforced by the
    // worker that consumes this item; not used in this file. TODO confirm.
    pub max_tokens: usize,
    // Sampling configuration (opaque here; see `crate::sampling::SamplingParams`).
    pub params: SamplingParams,
    // Creation time; basis for `wait_time` / `is_expired`.
    pub created_at: Instant,
    // Capacity-1 channel on which the consumer sends generated token ids.
    pub result_tx: std::sync::mpsc::SyncSender<Vec<u32>>,
}
impl InferenceWorkItem {
pub fn wait_time(&self) -> Duration {
self.created_at.elapsed()
}
pub fn is_expired(&self, ttl: Duration) -> bool {
self.wait_time() > ttl
}
}
/// Wrapper around a `BoundedQueue` of work items that assigns request ids.
pub struct InferenceQueue {
    // Arc-wrapped so the queue can be shared with a consumer thread —
    // NOTE(review): no clone of this Arc is visible in this file; confirm
    // against the worker that drains it.
    queue: Arc<BoundedQueue<InferenceWorkItem>>,
    // Next request id to hand out; starts at 1.
    next_id: AtomicU64,
}
impl InferenceQueue {
    /// Builds a queue that admits at most `capacity` pending work items.
    /// Request ids start at 1.
    pub fn new(capacity: usize) -> Self {
        Self {
            queue: Arc::new(BoundedQueue::new(capacity)),
            next_id: AtomicU64::new(1),
        }
    }

    /// Enqueues a request and hands back the receiver on which the result
    /// will arrive, or `None` when the queue is at capacity (the request is
    /// dropped and counted against the queue's drop statistics).
    pub fn submit(
        &self,
        prompt_tokens: Vec<u32>,
        max_tokens: usize,
        params: SamplingParams,
    ) -> Option<std::sync::mpsc::Receiver<Vec<u32>>> {
        // Capacity-1 rendezvous channel: exactly one result per request.
        let (result_tx, result_rx) = std::sync::mpsc::sync_channel(1);
        let work = InferenceWorkItem {
            id: self.next_id.fetch_add(1, Ordering::Relaxed),
            prompt_tokens,
            max_tokens,
            params,
            created_at: Instant::now(),
            result_tx,
        };
        self.queue.try_push(work).then_some(result_rx)
    }

    /// Number of requests currently waiting.
    pub fn queue_depth(&self) -> usize {
        self.queue.len()
    }

    /// `true` when no further request can be accepted right now.
    pub fn is_full(&self) -> bool {
        self.queue.is_full()
    }

    /// Snapshot of the underlying queue's statistics.
    pub fn stats(&self) -> QueueStats {
        self.queue.stats()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;

    #[test]
    fn test_bounded_queue_try_push() {
        let queue: BoundedQueue<u32> = BoundedQueue::new(4);
        assert!(queue.try_push(1));
        assert!(queue.try_push(2));
        assert_eq!(queue.len(), 2);
        assert_eq!(queue.total_enqueued.load(Ordering::Relaxed), 2);
    }

    #[test]
    fn test_bounded_queue_try_push_full_returns_false() {
        let queue: BoundedQueue<u32> = BoundedQueue::new(2);
        assert!(queue.try_push(10));
        assert!(queue.try_push(20));
        // A third push exceeds capacity and must be rejected and counted.
        assert!(!queue.try_push(30));
        assert_eq!(queue.total_dropped.load(Ordering::Relaxed), 1);
        assert_eq!(queue.len(), 2);
    }

    #[test]
    fn test_bounded_queue_pop_empty_returns_none() {
        let queue: BoundedQueue<u32> = BoundedQueue::new(4);
        assert_eq!(queue.pop(), None);
    }

    #[test]
    fn test_bounded_queue_fifo_order() {
        let queue: BoundedQueue<u32> = BoundedQueue::new(8);
        for value in 0..5u32 {
            assert!(queue.try_push(value));
        }
        // Popping until empty must reproduce insertion order.
        let popped: Vec<u32> = std::iter::from_fn(|| queue.pop()).collect();
        assert_eq!(popped, vec![0, 1, 2, 3, 4]);
        assert_eq!(queue.pop(), None);
    }

    #[test]
    fn test_bounded_queue_stats() {
        let queue: BoundedQueue<u32> = BoundedQueue::new(4);
        queue.try_push(1);
        queue.try_push(2);
        queue.pop();
        let snapshot = queue.stats();
        assert_eq!(snapshot.capacity, 4);
        assert_eq!(snapshot.len, 1);
        assert_eq!(snapshot.total_enqueued, 2);
        assert_eq!(snapshot.total_dequeued, 1);
        assert_eq!(snapshot.total_dropped, 0);
        assert!((snapshot.utilization - 0.25).abs() < f32::EPSILON);
    }

    #[test]
    fn test_bounded_queue_drain() {
        let queue: BoundedQueue<u32> = BoundedQueue::new(8);
        for value in 0..4u32 {
            queue.try_push(value);
        }
        let drained = queue.drain();
        assert_eq!(drained, vec![0, 1, 2, 3]);
        assert_eq!(queue.len(), 0);
        assert_eq!(queue.total_dequeued.load(Ordering::Relaxed), 4);
    }

    #[test]
    fn test_inference_queue_submit_and_depth() {
        let inference = InferenceQueue::new(8);
        let _first_rx = inference
            .submit(vec![1, 2, 3], 16, SamplingParams::default())
            .expect("submit should succeed on an empty queue");
        let _second_rx = inference
            .submit(vec![4, 5, 6], 16, SamplingParams::default())
            .expect("second submit should succeed");
        assert_eq!(inference.queue_depth(), 2);
        assert!(!inference.is_full());
    }

    #[test]
    fn test_inference_queue_full_returns_none() {
        let inference = InferenceQueue::new(2);
        let _first_rx = inference
            .submit(vec![1], 8, SamplingParams::default())
            .expect("first submit");
        let _second_rx = inference
            .submit(vec![2], 8, SamplingParams::default())
            .expect("second submit");
        assert!(inference.is_full());
        let rejected = inference.submit(vec![3], 8, SamplingParams::default());
        assert!(rejected.is_none(), "submit to a full queue must return None");
    }
}