1use crate::{LlmProvider, LlmRequest, LlmResponse, Result};
35use async_trait::async_trait;
36use std::cmp::Ordering;
37use std::collections::BinaryHeap;
38use std::sync::Arc;
39use tokio::sync::{mpsc, oneshot, Mutex, Semaphore};
40
/// Priority level attached to a queued LLM request.
///
/// Discriminant values define the ordering used by the queue's max-heap via
/// the derived `Ord`: `High > Normal > Low`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum RequestPriority {
    /// Served last.
    Low = 0,
    /// Default priority; used by `LlmProvider::complete`.
    #[default]
    Normal = 1,
    /// Served first.
    High = 2,
}
52
/// Configuration for [`PriorityQueueProvider`].
#[derive(Debug, Clone)]
pub struct PriorityQueueConfig {
    /// Maximum number of pending requests; further submissions are rejected.
    pub max_queue_size: usize,
    /// Maximum number of requests processed concurrently.
    pub max_workers: usize,
}
61
62impl Default for PriorityQueueConfig {
63 fn default() -> Self {
64 Self {
65 max_queue_size: 1000,
66 max_workers: 10,
67 }
68 }
69}
70
/// Snapshot of queue activity, readable via [`PriorityQueueProvider::stats`].
#[derive(Debug, Clone, Default)]
pub struct PriorityQueueStats {
    /// Requests currently waiting in the heap.
    pub queue_length: usize,
    /// Waiting requests at `High` priority.
    pub high_priority_count: usize,
    /// Waiting requests at `Normal` priority.
    pub normal_priority_count: usize,
    /// Waiting requests at `Low` priority.
    pub low_priority_count: usize,
    /// Requests completed by the underlying provider (success or error).
    pub total_processed: usize,
    /// Requests rejected because the queue was full.
    pub total_rejected: usize,
    // NOTE(review): never written by the visible code — always 0 here; confirm
    // whether it is updated elsewhere or should be removed.
    pub active_workers: usize,
}
89
/// A queued request plus the channel used to deliver its response.
struct PriorityRequest {
    /// Ordering key (primary): higher priority pops first.
    priority: RequestPriority,
    /// Ordering key (tie-break): lower sequence (older) pops first.
    sequence: u64,
    request: LlmRequest,
    /// Completing side of the caller's `complete_with_priority` await.
    response_tx: oneshot::Sender<Result<LlmResponse>>,
}
96
97impl PartialEq for PriorityRequest {
98 fn eq(&self, other: &Self) -> bool {
99 self.priority == other.priority && self.sequence == other.sequence
100 }
101}
102
// Sound: `PartialEq` compares only `priority` and `sequence`, both of which are `Eq`.
impl Eq for PriorityRequest {}
104
105impl PartialOrd for PriorityRequest {
106 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
107 Some(self.cmp(other))
108 }
109}
110
111impl Ord for PriorityRequest {
112 fn cmp(&self, other: &Self) -> Ordering {
113 match self.priority.cmp(&other.priority) {
115 Ordering::Equal => other.sequence.cmp(&self.sequence), other => other,
117 }
118 }
119}
120
/// Mutable queue state, shared behind `Arc<Mutex<..>>` between the enqueue
/// task, the dispatch worker, and `stats()` readers.
struct QueueState {
    /// Max-heap ordered by (`priority`, reversed `sequence`).
    heap: BinaryHeap<PriorityRequest>,
    /// Monotonically increasing counter assigned to each accepted request.
    sequence: u64,
    stats: PriorityQueueStats,
    /// Capacity bound copied from `PriorityQueueConfig::max_queue_size`.
    max_queue_size: usize,
}
127
128impl QueueState {
129 fn new(max_queue_size: usize) -> Self {
130 Self {
131 heap: BinaryHeap::new(),
132 sequence: 0,
133 stats: PriorityQueueStats::default(),
134 max_queue_size,
135 }
136 }
137
138 fn enqueue(
139 &mut self,
140 priority: RequestPriority,
141 request: LlmRequest,
142 response_tx: oneshot::Sender<Result<LlmResponse>>,
143 ) -> bool {
144 if self.heap.len() >= self.max_queue_size {
145 self.stats.total_rejected += 1;
147 return false;
148 }
149
150 let priority_req = PriorityRequest {
151 priority,
152 sequence: self.sequence,
153 request,
154 response_tx,
155 };
156
157 self.sequence += 1;
158 self.heap.push(priority_req);
159 self.update_priority_counts();
160 true
161 }
162
163 fn dequeue(&mut self) -> Option<PriorityRequest> {
164 let req = self.heap.pop();
165 if req.is_some() {
166 self.update_priority_counts();
167 }
168 req
169 }
170
171 fn update_priority_counts(&mut self) {
172 self.stats.queue_length = self.heap.len();
173 self.stats.high_priority_count = self
174 .heap
175 .iter()
176 .filter(|r| r.priority == RequestPriority::High)
177 .count();
178 self.stats.normal_priority_count = self
179 .heap
180 .iter()
181 .filter(|r| r.priority == RequestPriority::Normal)
182 .count();
183 self.stats.low_priority_count = self
184 .heap
185 .iter()
186 .filter(|r| r.priority == RequestPriority::Low)
187 .count();
188 }
189}
190
/// Background dispatcher that drains the queue and forwards requests to the
/// wrapped provider.
struct QueueWorker<P> {
    provider: Arc<P>,
    queue_state: Arc<Mutex<QueueState>>,
    /// Caps in-flight requests at `max_workers`.
    semaphore: Arc<Semaphore>,
    /// One `()` message arrives per successfully enqueued request.
    rx: Arc<Mutex<mpsc::UnboundedReceiver<()>>>,
}
197
impl<P: LlmProvider + 'static> QueueWorker<P> {
    /// Dispatch loop: waits for one notification per enqueued request, pops
    /// the highest-priority entry, and processes it on a spawned task.
    /// Exits when every notification sender has been dropped.
    async fn run(self) {
        loop {
            // Block until the enqueue task signals a newly queued request.
            // `None` means all senders are gone (provider dropped) — shut down.
            {
                let mut rx = self.rx.lock().await;
                if rx.recv().await.is_none() {
                    break; }
            }

            // Acquire a permit *before* dequeuing so at most `max_workers`
            // requests are in flight at once. The semaphore is never closed
            // in this module, so `unwrap` cannot fail here.
            let permit = self.semaphore.clone().acquire_owned().await.unwrap();

            // Hold the state lock only for the pop, not across the provider call.
            let priority_req = {
                let mut state = self.queue_state.lock().await;
                state.dequeue()
            };

            if let Some(priority_req) = priority_req {
                let provider = Arc::clone(&self.provider);
                let queue_state = Arc::clone(&self.queue_state);

                // Run the completion off the dispatch loop so a slow request
                // does not block dequeuing (and re-ordering) of later ones.
                tokio::spawn(async move {
                    let result = provider.complete(priority_req.request).await;

                    {
                        let mut state = queue_state.lock().await;
                        state.stats.total_processed += 1;
                    }

                    // The caller may have dropped its receiver; ignore failure.
                    let _ = priority_req.response_tx.send(result);
                    // Release the concurrency slot only after completion.
                    drop(permit);
                });
            } else {
                // Notifications are sent only on successful enqueue, so this
                // branch should be unreachable in practice; release the permit.
                drop(permit);
            }
        }
    }
}
240
/// Wraps an [`LlmProvider`] with a bounded priority queue and a concurrency
/// limit; requests are served in `High > Normal > Low` order, FIFO within
/// each priority level.
pub struct PriorityQueueProvider<P> {
    /// Submission channel consumed by the background enqueue task.
    tx: mpsc::UnboundedSender<(
        RequestPriority,
        LlmRequest,
        oneshot::Sender<Result<LlmResponse>>,
    )>,
    /// Shared with the background tasks; read by [`Self::stats`].
    queue_state: Arc<Mutex<QueueState>>,
    // The provider itself lives inside the worker task; `P` appears here only
    // to keep the wrapper generic over it.
    _phantom: std::marker::PhantomData<P>,
}
251
252impl<P: LlmProvider + 'static> PriorityQueueProvider<P> {
253 pub fn new(provider: P, config: PriorityQueueConfig) -> Self {
255 let (tx, mut rx) = mpsc::unbounded_channel();
256 let (notify_tx, notify_rx) = mpsc::unbounded_channel();
257 let queue_state = Arc::new(Mutex::new(QueueState::new(config.max_queue_size)));
258 let semaphore = Arc::new(Semaphore::new(config.max_workers));
259
260 let queue_state_clone = Arc::clone(&queue_state);
262 let notify_tx_clone = notify_tx.clone();
263 tokio::spawn(async move {
264 while let Some((priority, request, response_tx)) = rx.recv().await {
265 let mut state = queue_state_clone.lock().await;
266 if state.enqueue(priority, request, response_tx) {
267 let _ = notify_tx_clone.send(()); }
269 }
270 });
271
272 let worker = QueueWorker {
274 provider: Arc::new(provider),
275 queue_state: Arc::clone(&queue_state),
276 semaphore,
277 rx: Arc::new(Mutex::new(notify_rx)),
278 };
279 tokio::spawn(worker.run());
280
281 Self {
282 tx,
283 queue_state,
284 _phantom: std::marker::PhantomData,
285 }
286 }
287
288 pub async fn complete_with_priority(
290 &self,
291 request: LlmRequest,
292 priority: RequestPriority,
293 ) -> Result<LlmResponse> {
294 let (response_tx, response_rx) = oneshot::channel();
295
296 self.tx
297 .send((priority, request, response_tx))
298 .map_err(|_| crate::LlmError::Other("Queue handler has stopped".to_string()))?;
299
300 response_rx
301 .await
302 .map_err(|_| crate::LlmError::Other("Response channel closed".to_string()))?
303 }
304
305 pub async fn stats(&self) -> PriorityQueueStats {
307 self.queue_state.lock().await.stats.clone()
308 }
309}
310
#[async_trait]
impl<P: LlmProvider + 'static> LlmProvider for PriorityQueueProvider<P> {
    /// Delegates to [`PriorityQueueProvider::complete_with_priority`] with
    /// `RequestPriority::Normal`.
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        self.complete_with_priority(request, RequestPriority::Normal)
            .await
    }
}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LlmResponse, Usage};

    /// Deterministic provider that echoes the prompt after an optional delay.
    struct MockProvider {
        // Simulated provider latency; 0 disables the sleep entirely.
        delay_ms: u64,
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
            if self.delay_ms > 0 {
                tokio::time::sleep(tokio::time::Duration::from_millis(self.delay_ms)).await;
            }
            Ok(LlmResponse {
                content: format!("Response to: {}", request.prompt),
                model: "mock-model".to_string(),
                usage: Some(Usage {
                    prompt_tokens: 10,
                    completion_tokens: 20,
                    total_tokens: 30,
                }),
                tool_calls: vec![],
            })
        }
    }

    /// The derived `Ord` on `RequestPriority` must rank High > Normal > Low.
    #[tokio::test]
    async fn test_priority_ordering() {
        assert!(RequestPriority::High > RequestPriority::Normal);
        assert!(RequestPriority::Normal > RequestPriority::Low);
    }

    #[tokio::test]
    async fn test_priority_queue_config_default() {
        let config = PriorityQueueConfig::default();
        assert_eq!(config.max_queue_size, 1000);
        assert_eq!(config.max_workers, 10);
    }

    /// End-to-end round trip of a single high-priority request.
    #[tokio::test]
    async fn test_priority_queue_single_request() {
        let provider = MockProvider { delay_ms: 10 };
        let config = PriorityQueueConfig {
            max_queue_size: 100,
            max_workers: 5,
        };
        let queue_provider = PriorityQueueProvider::new(provider, config);

        let request = LlmRequest {
            prompt: "Test".to_string(),
            system_prompt: None,
            tem perature: None,
            max_tokens: None,
            tools: vec![],
            images: vec![],
        };

        let response = queue_provider
            .complete_with_priority(request, RequestPriority::High)
            .await
            .unwrap();
        assert_eq!(response.content, "Response to: Test");

        // Give the spawned completion task time to bump `total_processed`.
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;

        let stats = queue_provider.stats().await;
        assert_eq!(stats.total_processed, 1);
    }

    /// Submits Low/High/Normal with a single worker and verifies all complete.
    // NOTE(review): this only asserts the processed count, not the actual
    // completion order — it does not really test priority ordering.
    #[tokio::test]
    async fn test_priority_queue_priority_ordering() {
        let provider = MockProvider { delay_ms: 50 };
        let config = PriorityQueueConfig {
            max_queue_size: 100,
            // One worker forces sequential processing so queued requests
            // actually contend on priority.
            max_workers: 1, };
        let queue_provider = Arc::new(PriorityQueueProvider::new(provider, config));

        let mut handles = vec![];

        for (i, priority) in [
            (0, RequestPriority::Low),
            (1, RequestPriority::High),
            (2, RequestPriority::Normal),
        ]
        .iter()
        {
            let qp = Arc::clone(&queue_provider);
            let i = *i;
            let priority = *priority;
            let handle = tokio::spawn(async move {
                let request = LlmRequest {
                    prompt: format!("Request {}", i),
                    system_prompt: None,
                    temperature: None,
                    max_tokens: None,
                    tools: vec![],
                    images: vec![],
                };
                qp.complete_with_priority(request, priority).await
            });
            handles.push(handle);
            // Stagger submissions so they arrive in a deterministic order.
            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        }

        for handle in handles {
            let _ = handle.await.unwrap();
        }

        // Allow the last spawned completion task to record its stats.
        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

        let stats = queue_provider.stats().await;
        assert_eq!(stats.total_processed, 3);
    }

    /// Stats should show in-flight activity mid-run and a drained queue after.
    #[tokio::test]
    async fn test_priority_queue_stats() {
        let provider = MockProvider { delay_ms: 100 };
        let config = PriorityQueueConfig {
            max_queue_size: 100,
            max_workers: 1,
        };
        let queue_provider = Arc::new(PriorityQueueProvider::new(provider, config));

        let mut handles = vec![];
        for i in 0..5 {
            let qp = Arc::clone(&queue_provider);
            let handle = tokio::spawn(async move {
                let request = LlmRequest {
                    prompt: format!("Request {}", i),
                    system_prompt: None,
                    temperature: None,
                    max_tokens: None,
                    tools: vec![],
                    images: vec![],
                };
                qp.complete(request).await
            });
            handles.push(handle);
        }

        // Mid-run: with a 100ms provider and one worker, at least one request
        // should be queued or already processed by now.
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
        let stats = queue_provider.stats().await;
        assert!(stats.queue_length > 0 || stats.total_processed > 0);

        for handle in handles {
            let _ = handle.await.unwrap();
        }

        // 5 requests x 100ms, serialized — wait for the stats updates to land.
        tokio::time::sleep(tokio::time::Duration::from_millis(600)).await;

        let stats = queue_provider.stats().await;
        assert_eq!(stats.total_processed, 5);
        assert_eq!(stats.queue_length, 0);
    }

    /// `complete` (trait method) must route through Normal priority and work
    /// with the default configuration.
    #[tokio::test]
    async fn test_priority_queue_default_priority() {
        let provider = MockProvider { delay_ms: 10 };
        let config = PriorityQueueConfig::default();
        let queue_provider = PriorityQueueProvider::new(provider, config);

        let request = LlmRequest {
            prompt: "Default priority".to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: vec![],
            images: vec![],
        };

        let response = queue_provider.complete(request).await.unwrap();
        assert_eq!(response.content, "Response to: Default priority");
    }
}