atomr_infer_remote_core/
queue.rs1use std::collections::BinaryHeap;
6
7use atomr_infer_core::batch::ExecuteBatch;
8use atomr_infer_core::tokens::TokenChunk;
9
/// Scheduling priority of a queued request.
///
/// Variants are declared from lowest to highest, so the derived `Ord` ranks
/// `Background < Normal < High < Critical`. Do not reorder the variants:
/// that would silently change which requests are scheduled first.
///
/// `Hash` is derived alongside `Eq` so a priority can key a `HashMap`/`HashSet`
/// (e.g. per-priority counters).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Priority {
    /// Lowest rank.
    Background,
    /// Ranked above `Background`.
    Normal,
    /// Ranked above `Normal`.
    High,
    /// Highest rank.
    Critical,
}
18
19#[derive(Debug)]
20pub struct PriorityRequest {
21 pub priority: Priority,
22 pub arrival_seq: u64,
25 pub batch: ExecuteBatch,
26 pub output: tokio::sync::mpsc::Sender<Result<TokenChunk, atomr_infer_core::error::InferenceError>>,
29}
30
31impl PartialEq for PriorityRequest {
35 fn eq(&self, o: &Self) -> bool {
36 self.priority == o.priority && self.arrival_seq == o.arrival_seq
37 }
38}
39impl Eq for PriorityRequest {}
40impl PartialOrd for PriorityRequest {
41 fn partial_cmp(&self, o: &Self) -> Option<std::cmp::Ordering> {
42 Some(self.cmp(o))
43 }
44}
45impl Ord for PriorityRequest {
46 fn cmp(&self, o: &Self) -> std::cmp::Ordering {
47 self.priority
48 .cmp(&o.priority)
49 .then_with(|| o.arrival_seq.cmp(&self.arrival_seq))
50 }
51}
52
53pub struct RequestQueue {
54 inner: BinaryHeap<PriorityRequest>,
55 capacity: usize,
56 next_seq: u64,
57}
58
59impl RequestQueue {
60 pub fn new(capacity: usize) -> Self {
61 Self {
62 inner: BinaryHeap::new(),
63 capacity,
64 next_seq: 0,
65 }
66 }
67
68 pub fn len(&self) -> usize {
69 self.inner.len()
70 }
71 pub fn is_empty(&self) -> bool {
72 self.inner.is_empty()
73 }
74 pub fn is_full(&self) -> bool {
75 self.inner.len() >= self.capacity
76 }
77
78 #[allow(clippy::result_large_err)] pub fn push(&mut self, mut req: PriorityRequest) -> Result<(), PriorityRequest> {
84 if self.is_full() {
85 return Err(req);
86 }
87 req.arrival_seq = self.next_seq;
88 self.next_seq += 1;
89 self.inner.push(req);
90 Ok(())
91 }
92
93 pub fn pop(&mut self) -> Option<PriorityRequest> {
94 self.inner.pop()
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request with the given priority and a minimal placeholder batch.
    /// The receiver is dropped immediately; these tests never send on `output`.
    fn req(priority: Priority) -> PriorityRequest {
        let (tx, _rx) = tokio::sync::mpsc::channel(1);
        PriorityRequest {
            priority,
            arrival_seq: 0,
            batch: ExecuteBatch {
                request_id: "r".into(),
                model: "m".into(),
                messages: vec![],
                sampling: Default::default(),
                stream: false,
                estimated_tokens: 1,
            },
            output: tx,
        }
    }

    #[test]
    fn priority_first_then_fifo() {
        let mut q = RequestQueue::new(8);
        q.push(req(Priority::Normal)).unwrap();
        q.push(req(Priority::Normal)).unwrap();
        q.push(req(Priority::High)).unwrap();

        let first = q.pop().unwrap();
        assert_eq!(first.priority, Priority::High);
        assert_eq!(first.arrival_seq, 2);

        // Equal priorities drain in push order.
        let second = q.pop().unwrap();
        assert_eq!(second.priority, Priority::Normal);
        assert_eq!(second.arrival_seq, 0);

        let third = q.pop().unwrap();
        assert_eq!(third.priority, Priority::Normal);
        assert_eq!(third.arrival_seq, 1);

        assert!(q.pop().is_none());
        assert!(q.is_empty());
    }

    #[test]
    fn push_overwrites_caller_arrival_seq() {
        let mut q = RequestQueue::new(2);
        let mut r = req(Priority::Normal);
        r.arrival_seq = 42;
        q.push(r).unwrap();
        assert_eq!(q.pop().unwrap().arrival_seq, 0);
    }

    #[test]
    fn full_queue_rejects() {
        let mut q = RequestQueue::new(1);
        assert!(q.push(req(Priority::Normal)).is_ok());
        assert!(q.is_full());

        // The rejected request is handed back, and the queue is left unchanged.
        let rejected = q.push(req(Priority::Critical)).unwrap_err();
        assert_eq!(rejected.priority, Priority::Critical);
        assert_eq!(q.len(), 1);
        assert_eq!(q.pop().unwrap().priority, Priority::Normal);
    }

    #[test]
    fn zero_capacity_rejects_everything() {
        let mut q = RequestQueue::new(0);
        assert!(q.is_full());
        assert!(q.push(req(Priority::Critical)).is_err());
        assert!(q.is_empty());
    }
}