1use std::collections::VecDeque;
14use std::sync::atomic::{AtomicU64, Ordering};
15use std::sync::{Arc, Condvar, Mutex};
16use std::time::{Duration, Instant};
17
18use crate::sampling::SamplingParams;
19
/// Point-in-time snapshot of a [`BoundedQueue`]'s size and lifetime counters.
///
/// Produced by [`BoundedQueue::stats`]; serializable so it can be exposed on
/// a metrics/health endpoint.
#[derive(Debug, Clone, serde::Serialize)]
pub struct QueueStats {
    /// Number of items currently waiting in the queue.
    pub len: usize,
    /// Maximum number of items the queue can hold.
    pub capacity: usize,
    /// `len / capacity`, in `[0.0, 1.0]`.
    pub utilization: f32,
    /// Total items successfully enqueued since creation (monotonic).
    pub total_enqueued: u64,
    /// Total items dequeued, including drained items (monotonic).
    pub total_dequeued: u64,
    /// Total push attempts rejected (full queue or timed-out push).
    pub total_dropped: u64,
    /// `total_dropped / (total_enqueued + total_dropped)`; 0.0 when there
    /// have been no push attempts.
    pub drop_rate: f32,
}
42
/// A fixed-capacity, thread-safe FIFO queue with blocking and non-blocking
/// push/pop, built on a `Mutex<VecDeque>` plus two condition variables.
///
/// Each stored item is paired with the `Instant` at which it was enqueued.
/// The counters are plain atomics updated with `Relaxed` ordering — they are
/// metrics only, not used for synchronization.
pub struct BoundedQueue<T> {
    /// Underlying buffer; each entry carries its enqueue timestamp.
    queue: Mutex<VecDeque<(T, Instant)>>,
    /// Signaled on push; consumers block on this in `pop_timeout`.
    not_empty: Condvar,
    /// Signaled on pop/drain; producers block on this in `push_timeout`.
    not_full: Condvar,
    /// Maximum number of items; enforced on every push path, never changes.
    capacity: usize,
    /// Count of successful pushes (monotonic).
    pub total_enqueued: AtomicU64,
    /// Count of dequeued items, including drained ones (monotonic).
    pub total_dequeued: AtomicU64,
    /// Count of rejected or timed-out pushes (monotonic).
    pub total_dropped: AtomicU64,
}
68
69impl<T: Send> BoundedQueue<T> {
70 pub fn new(capacity: usize) -> Self {
72 assert!(capacity > 0, "queue capacity must be at least 1");
73 Self {
74 queue: Mutex::new(VecDeque::with_capacity(capacity)),
75 not_empty: Condvar::new(),
76 not_full: Condvar::new(),
77 capacity,
78 total_enqueued: AtomicU64::new(0),
79 total_dequeued: AtomicU64::new(0),
80 total_dropped: AtomicU64::new(0),
81 }
82 }
83
84 pub fn try_push(&self, item: T) -> bool {
89 let mut guard = self
90 .queue
91 .lock()
92 .expect("queue mutex should not be poisoned");
93
94 if guard.len() >= self.capacity {
95 self.total_dropped.fetch_add(1, Ordering::Relaxed);
96 return false;
97 }
98
99 guard.push_back((item, Instant::now()));
100 self.total_enqueued.fetch_add(1, Ordering::Relaxed);
101 self.not_empty.notify_one();
102 true
103 }
104
105 pub fn push_timeout(&self, item: T, timeout: Duration) -> bool {
110 let deadline = Instant::now() + timeout;
111
112 let mut guard = self
113 .queue
114 .lock()
115 .expect("queue mutex should not be poisoned");
116
117 loop {
118 if guard.len() < self.capacity {
119 guard.push_back((item, Instant::now()));
120 self.total_enqueued.fetch_add(1, Ordering::Relaxed);
121 self.not_empty.notify_one();
122 return true;
123 }
124
125 let remaining = match deadline.checked_duration_since(Instant::now()) {
126 Some(d) => d,
127 None => {
128 self.total_dropped.fetch_add(1, Ordering::Relaxed);
129 return false;
130 }
131 };
132
133 let (new_guard, timed_out) = self
134 .not_full
135 .wait_timeout(guard, remaining)
136 .expect("queue condvar should not be poisoned");
137 guard = new_guard;
138
139 if timed_out.timed_out() {
140 self.total_dropped.fetch_add(1, Ordering::Relaxed);
141 return false;
142 }
143 }
144 }
145
146 pub fn pop(&self) -> Option<T> {
150 let mut guard = self
151 .queue
152 .lock()
153 .expect("queue mutex should not be poisoned");
154
155 guard.pop_front().map(|(item, _enqueued_at)| {
156 self.total_dequeued.fetch_add(1, Ordering::Relaxed);
157 self.not_full.notify_one();
158 item
159 })
160 }
161
162 pub fn pop_timeout(&self, timeout: Duration) -> Option<T> {
167 let deadline = Instant::now() + timeout;
168
169 let mut guard = self
170 .queue
171 .lock()
172 .expect("queue mutex should not be poisoned");
173
174 loop {
175 if let Some((item, _)) = guard.pop_front() {
176 self.total_dequeued.fetch_add(1, Ordering::Relaxed);
177 self.not_full.notify_one();
178 return Some(item);
179 }
180
181 let remaining = deadline.checked_duration_since(Instant::now())?;
182
183 let (new_guard, timed_out) = self
184 .not_empty
185 .wait_timeout(guard, remaining)
186 .expect("queue condvar should not be poisoned");
187 guard = new_guard;
188
189 if timed_out.timed_out() && guard.is_empty() {
190 return None;
191 }
192 }
193 }
194
195 pub fn len(&self) -> usize {
197 self.queue
198 .lock()
199 .expect("queue mutex should not be poisoned")
200 .len()
201 }
202
203 pub fn is_empty(&self) -> bool {
205 self.len() == 0
206 }
207
208 pub fn is_full(&self) -> bool {
210 self.len() >= self.capacity
211 }
212
213 pub fn capacity(&self) -> usize {
215 self.capacity
216 }
217
218 pub fn utilization(&self) -> f32 {
220 self.len() as f32 / self.capacity as f32
221 }
222
223 pub fn stats(&self) -> QueueStats {
225 let len = self.len();
226 let enqueued = self.total_enqueued.load(Ordering::Relaxed);
227 let dropped = self.total_dropped.load(Ordering::Relaxed);
228 let attempted = enqueued + dropped;
229 let drop_rate = if attempted == 0 {
230 0.0
231 } else {
232 dropped as f32 / attempted as f32
233 };
234
235 QueueStats {
236 len,
237 capacity: self.capacity,
238 utilization: len as f32 / self.capacity as f32,
239 total_enqueued: enqueued,
240 total_dequeued: self.total_dequeued.load(Ordering::Relaxed),
241 total_dropped: dropped,
242 drop_rate,
243 }
244 }
245
246 pub fn drain(&self) -> Vec<T> {
248 let mut guard = self
249 .queue
250 .lock()
251 .expect("queue mutex should not be poisoned");
252
253 let count = guard.len();
254 let items: Vec<T> = guard.drain(..).map(|(item, _)| item).collect();
255 self.total_dequeued
256 .fetch_add(count as u64, Ordering::Relaxed);
257 self.not_full.notify_all();
258 items
259 }
260}
261
/// A single queued inference request, together with the channel on which its
/// result is delivered back to the submitter.
pub struct InferenceWorkItem {
    /// Unique, monotonically assigned request id.
    pub id: u64,
    /// Token ids of the prompt to run.
    pub prompt_tokens: Vec<u32>,
    /// Maximum number of tokens to generate for this request.
    pub max_tokens: usize,
    /// Sampling configuration for generation.
    pub params: SamplingParams,
    /// Creation time; used to measure queue wait and expiry.
    pub created_at: Instant,
    /// Bounded (capacity-1) channel for sending the generated token ids
    /// back to the caller that holds the matching `Receiver`.
    pub result_tx: std::sync::mpsc::SyncSender<Vec<u32>>,
}
281
282impl InferenceWorkItem {
283 pub fn wait_time(&self) -> Duration {
285 self.created_at.elapsed()
286 }
287
288 pub fn is_expired(&self, ttl: Duration) -> bool {
290 self.wait_time() > ttl
291 }
292}
293
/// Front end for submitting inference work: wraps a shared [`BoundedQueue`]
/// of [`InferenceWorkItem`]s and hands out a per-request result receiver.
pub struct InferenceQueue {
    /// Shared work queue; `Arc` so worker threads can hold their own handle.
    queue: Arc<BoundedQueue<InferenceWorkItem>>,
    /// Source of unique request ids, starting at 1.
    next_id: AtomicU64,
}
307
308impl InferenceQueue {
309 pub fn new(capacity: usize) -> Self {
311 Self {
312 queue: Arc::new(BoundedQueue::new(capacity)),
313 next_id: AtomicU64::new(1),
314 }
315 }
316
317 pub fn submit(
322 &self,
323 prompt_tokens: Vec<u32>,
324 max_tokens: usize,
325 params: SamplingParams,
326 ) -> Option<std::sync::mpsc::Receiver<Vec<u32>>> {
327 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
328 let (tx, rx) = std::sync::mpsc::sync_channel(1);
329
330 let item = InferenceWorkItem {
331 id,
332 prompt_tokens,
333 max_tokens,
334 params,
335 created_at: Instant::now(),
336 result_tx: tx,
337 };
338
339 if self.queue.try_push(item) {
340 Some(rx)
341 } else {
342 None
343 }
344 }
345
346 pub fn queue_depth(&self) -> usize {
348 self.queue.len()
349 }
350
351 pub fn is_full(&self) -> bool {
353 self.queue.is_full()
354 }
355
356 pub fn stats(&self) -> QueueStats {
358 self.queue.stats()
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;

    /// Pushing below capacity succeeds and bumps the enqueue counter.
    #[test]
    fn test_bounded_queue_try_push() {
        let queue: BoundedQueue<u32> = BoundedQueue::new(4);
        assert!(queue.try_push(1));
        assert!(queue.try_push(2));
        assert_eq!(queue.len(), 2);
        assert_eq!(queue.total_enqueued.load(Ordering::Relaxed), 2);
    }

    /// A push against a full queue is rejected and counted as a drop.
    #[test]
    fn test_bounded_queue_try_push_full_returns_false() {
        let queue: BoundedQueue<u32> = BoundedQueue::new(2);
        assert!(queue.try_push(10));
        assert!(queue.try_push(20));
        assert!(!queue.try_push(30));
        assert_eq!(queue.total_dropped.load(Ordering::Relaxed), 1);
        assert_eq!(queue.len(), 2);
    }

    /// Popping an empty queue yields nothing.
    #[test]
    fn test_bounded_queue_pop_empty_returns_none() {
        let queue: BoundedQueue<u32> = BoundedQueue::new(4);
        assert_eq!(queue.pop(), None);
    }

    /// Items come back out in the order they went in.
    #[test]
    fn test_bounded_queue_fifo_order() {
        let queue: BoundedQueue<u32> = BoundedQueue::new(8);
        for value in 0..5u32 {
            assert!(queue.try_push(value));
        }
        for expected in 0..5u32 {
            assert_eq!(queue.pop(), Some(expected));
        }
        assert_eq!(queue.pop(), None);
    }

    /// The stats snapshot reflects pushes, pops, and utilization.
    #[test]
    fn test_bounded_queue_stats() {
        let queue: BoundedQueue<u32> = BoundedQueue::new(4);
        queue.try_push(1);
        queue.try_push(2);
        queue.pop();

        let snapshot = queue.stats();
        assert_eq!(snapshot.capacity, 4);
        assert_eq!(snapshot.len, 1);
        assert_eq!(snapshot.total_enqueued, 2);
        assert_eq!(snapshot.total_dequeued, 1);
        assert_eq!(snapshot.total_dropped, 0);
        assert!((snapshot.utilization - 0.25).abs() < f32::EPSILON);
    }

    /// Draining empties the queue and counts every item as dequeued.
    #[test]
    fn test_bounded_queue_drain() {
        let queue: BoundedQueue<u32> = BoundedQueue::new(8);
        for value in 0..4u32 {
            queue.try_push(value);
        }
        let drained = queue.drain();
        assert_eq!(drained, vec![0, 1, 2, 3]);
        assert_eq!(queue.len(), 0);
        assert_eq!(queue.total_dequeued.load(Ordering::Relaxed), 4);
    }

    /// Submissions hand back receivers and show up in the queue depth.
    #[test]
    fn test_inference_queue_submit_and_depth() {
        let inference = InferenceQueue::new(8);
        let _rx1 = inference
            .submit(vec![1, 2, 3], 16, SamplingParams::default())
            .expect("submit should succeed on an empty queue");
        let _rx2 = inference
            .submit(vec![4, 5, 6], 16, SamplingParams::default())
            .expect("second submit should succeed");

        assert_eq!(inference.queue_depth(), 2);
        assert!(!inference.is_full());
    }

    /// Submitting to a full inference queue returns `None`.
    #[test]
    fn test_inference_queue_full_returns_none() {
        let inference = InferenceQueue::new(2);

        let _rx1 = inference
            .submit(vec![1], 8, SamplingParams::default())
            .expect("first submit");
        let _rx2 = inference
            .submit(vec![2], 8, SamplingParams::default())
            .expect("second submit");

        assert!(inference.is_full());
        let rejected = inference.submit(vec![3], 8, SamplingParams::default());
        assert!(rejected.is_none(), "submit to a full queue must return None");
    }
}