oxify_connect_llm/priority_queue.rs

//! Priority queue for LLM request management
//!
//! This module provides a priority-based request queue system for managing LLM requests
//! with different urgency levels. It supports strict priority scheduling with FIFO ordering
//! within each priority level, backpressure via a bounded queue size, and queue statistics.
//!
//! # Example
//!
//! ```rust
//! use oxify_connect_llm::{PriorityQueueProvider, PriorityQueueConfig, RequestPriority, OpenAIProvider, LlmProvider, LlmRequest};
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let provider = OpenAIProvider::new("test-key".to_string(), "gpt-4".to_string());
//! let config = PriorityQueueConfig {
//!     max_queue_size: 100,
//!     max_workers: 5,
//! };
//! let queue_provider = PriorityQueueProvider::new(provider, config);
//!
//! // Submit a high-priority request
//! let request = LlmRequest {
//!     prompt: "Urgent query".to_string(),
//!     system_prompt: None,
//!     temperature: None,
//!     max_tokens: None,
//!     tools: vec![],
//!     images: vec![],
//! };
//! // let response = queue_provider.complete_with_priority(request, RequestPriority::High).await?;
//! # Ok(())
//! # }
//! ```

use crate::{LlmProvider, LlmRequest, LlmResponse, Result};
use async_trait::async_trait;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot, Mutex, Semaphore};

/// Request priority levels
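///
/// Priorities compare in the order `Low < Normal < High`; the queue always serves the
/// highest pending priority first. A minimal doctest sketch (assuming these types are
/// re-exported at the crate root, as in the module-level example above):
///
/// ```rust
/// use oxify_connect_llm::RequestPriority;
///
/// // Higher-priority variants compare greater, so they are dequeued first.
/// assert!(RequestPriority::High > RequestPriority::Normal);
/// assert!(RequestPriority::Normal > RequestPriority::Low);
/// // `Normal` is the default priority.
/// assert_eq!(RequestPriority::default(), RequestPriority::Normal);
/// ```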
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum RequestPriority {
    /// Low priority - background tasks
    Low = 0,
    /// Normal priority - standard requests
    #[default]
    Normal = 1,
    /// High priority - user-facing real-time requests
    High = 2,
}

/// Configuration for priority queue
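///
/// A minimal sketch of the defaults (see the `Default` impl below; the crate-root
/// re-export is assumed, as in the module-level example):
///
/// ```rust
/// use oxify_connect_llm::PriorityQueueConfig;
///
/// // Defaults: up to 1000 queued requests, served by at most 10 concurrent workers.
/// let config = PriorityQueueConfig::default();
/// assert_eq!(config.max_queue_size, 1000);
/// assert_eq!(config.max_workers, 10);
/// ```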
#[derive(Debug, Clone)]
pub struct PriorityQueueConfig {
    /// Maximum total queue size across all priorities
    pub max_queue_size: usize,
    /// Maximum number of concurrent workers
    pub max_workers: usize,
}

impl Default for PriorityQueueConfig {
    fn default() -> Self {
        Self {
            max_queue_size: 1000,
            max_workers: 10,
        }
    }
}

/// Statistics about priority queue operations
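///
/// Obtained via [`PriorityQueueProvider::stats`]. A brief usage sketch, mirroring the
/// module-level example (the `OpenAIProvider` arguments are placeholders):
///
/// ```rust
/// # use oxify_connect_llm::{PriorityQueueProvider, PriorityQueueConfig, OpenAIProvider};
/// # async fn example() {
/// let provider = OpenAIProvider::new("test-key".to_string(), "gpt-4".to_string());
/// let queue = PriorityQueueProvider::new(provider, PriorityQueueConfig::default());
/// let stats = queue.stats().await;
/// println!(
///     "queued={} processed={} rejected={}",
///     stats.queue_length, stats.total_processed, stats.total_rejected
/// );
/// # }
/// ```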
#[derive(Debug, Clone, Default)]
pub struct PriorityQueueStats {
    /// Number of requests currently in queue
    pub queue_length: usize,
    /// Number of high-priority requests in queue
    pub high_priority_count: usize,
    /// Number of normal-priority requests in queue
    pub normal_priority_count: usize,
    /// Number of low-priority requests in queue
    pub low_priority_count: usize,
    /// Total requests processed
    pub total_processed: usize,
    /// Total requests rejected (queue full)
    pub total_rejected: usize,
    /// Number of active workers
    pub active_workers: usize,
}

struct PriorityRequest {
    priority: RequestPriority,
    sequence: u64,
    request: LlmRequest,
    response_tx: oneshot::Sender<Result<LlmResponse>>,
}

impl PartialEq for PriorityRequest {
    fn eq(&self, other: &Self) -> bool {
        self.priority == other.priority && self.sequence == other.sequence
    }
}

impl Eq for PriorityRequest {}

impl PartialOrd for PriorityRequest {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for PriorityRequest {
    fn cmp(&self, other: &Self) -> Ordering {
        // BinaryHeap is a max-heap, so the greatest element is popped first: higher
        // priority compares greater, and within the same priority the sequence
        // comparison is reversed so that older (lower-sequence) requests compare
        // greater, giving FIFO order.
        match self.priority.cmp(&other.priority) {
            Ordering::Equal => other.sequence.cmp(&self.sequence), // Reversed for FIFO
            other => other,
        }
    }
}

struct QueueState {
    heap: BinaryHeap<PriorityRequest>,
    sequence: u64,
    stats: PriorityQueueStats,
    max_queue_size: usize,
}

impl QueueState {
    fn new(max_queue_size: usize) -> Self {
        Self {
            heap: BinaryHeap::new(),
            sequence: 0,
            stats: PriorityQueueStats::default(),
            max_queue_size,
        }
    }

    fn enqueue(
        &mut self,
        priority: RequestPriority,
        request: LlmRequest,
        response_tx: oneshot::Sender<Result<LlmResponse>>,
    ) -> bool {
        if self.heap.len() >= self.max_queue_size {
            // Reject if the queue is full; `response_tx` is dropped here, which the
            // caller observes as a closed response channel.
            self.stats.total_rejected += 1;
            return false;
        }

        let priority_req = PriorityRequest {
            priority,
            sequence: self.sequence,
            request,
            response_tx,
        };

        self.sequence += 1;
        self.heap.push(priority_req);
        self.update_priority_counts();
        true
    }

    fn dequeue(&mut self) -> Option<PriorityRequest> {
        let req = self.heap.pop();
        if req.is_some() {
            self.update_priority_counts();
        }
        req
    }

    fn update_priority_counts(&mut self) {
        self.stats.queue_length = self.heap.len();
        self.stats.high_priority_count = self
            .heap
            .iter()
            .filter(|r| r.priority == RequestPriority::High)
            .count();
        self.stats.normal_priority_count = self
            .heap
            .iter()
            .filter(|r| r.priority == RequestPriority::Normal)
            .count();
        self.stats.low_priority_count = self
            .heap
            .iter()
            .filter(|r| r.priority == RequestPriority::Low)
            .count();
    }
}

struct QueueWorker<P> {
    provider: Arc<P>,
    queue_state: Arc<Mutex<QueueState>>,
    semaphore: Arc<Semaphore>,
    rx: Arc<Mutex<mpsc::UnboundedReceiver<()>>>,
}

impl<P: LlmProvider + 'static> QueueWorker<P> {
    async fn run(self) {
        loop {
            // Wait for a notification that a request was enqueued
            {
                let mut rx = self.rx.lock().await;
                if rx.recv().await.is_none() {
                    break; // Channel closed: the provider was dropped
                }
            }

            // Acquire a worker permit (the semaphore is never closed, so this cannot fail)
            let permit = self.semaphore.clone().acquire_owned().await.unwrap();

            // Dequeue the highest-priority request
            let priority_req = {
                let mut state = self.queue_state.lock().await;
                state.dequeue()
            };

            if let Some(priority_req) = priority_req {
                let provider = Arc::clone(&self.provider);
                let queue_state = Arc::clone(&self.queue_state);

                tokio::spawn(async move {
                    let result = provider.complete(priority_req.request).await;

                    // Update stats
                    {
                        let mut state = queue_state.lock().await;
                        state.stats.total_processed += 1;
                    }

                    let _ = priority_req.response_tx.send(result);
                    drop(permit);
                });
            } else {
                drop(permit);
            }
        }
    }
}

/// Priority queue provider that wraps any LLM provider
pub struct PriorityQueueProvider<P> {
    tx: mpsc::UnboundedSender<(
        RequestPriority,
        LlmRequest,
        oneshot::Sender<Result<LlmResponse>>,
    )>,
    queue_state: Arc<Mutex<QueueState>>,
    _phantom: std::marker::PhantomData<P>,
}

impl<P: LlmProvider + 'static> PriorityQueueProvider<P> {
    /// Create a new priority queue provider
    pub fn new(provider: P, config: PriorityQueueConfig) -> Self {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let (notify_tx, notify_rx) = mpsc::unbounded_channel();
        let queue_state = Arc::new(Mutex::new(QueueState::new(config.max_queue_size)));
        let semaphore = Arc::new(Semaphore::new(config.max_workers));

        // Spawn the enqueue handler
        let queue_state_clone = Arc::clone(&queue_state);
        let notify_tx_clone = notify_tx.clone();
        tokio::spawn(async move {
            while let Some((priority, request, response_tx)) = rx.recv().await {
                let mut state = queue_state_clone.lock().await;
                if state.enqueue(priority, request, response_tx) {
                    let _ = notify_tx_clone.send(()); // Notify the worker
                }
            }
        });

        // Spawn the worker loop
        let worker = QueueWorker {
            provider: Arc::new(provider),
            queue_state: Arc::clone(&queue_state),
            semaphore,
            rx: Arc::new(Mutex::new(notify_rx)),
        };
        tokio::spawn(worker.run());

        Self {
            tx,
            queue_state,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Submit a request with the specified priority
    pub async fn complete_with_priority(
        &self,
        request: LlmRequest,
        priority: RequestPriority,
    ) -> Result<LlmResponse> {
        let (response_tx, response_rx) = oneshot::channel();

        self.tx
            .send((priority, request, response_tx))
            .map_err(|_| crate::LlmError::Other("Queue handler has stopped".to_string()))?;

        // If the queue was full, `response_tx` was dropped during enqueue, so this
        // await resolves to a closed-channel error.
        response_rx
            .await
            .map_err(|_| crate::LlmError::Other("Response channel closed".to_string()))?
    }

    /// Get current queue statistics
    pub async fn stats(&self) -> PriorityQueueStats {
        self.queue_state.lock().await.stats.clone()
    }
}

#[async_trait]
impl<P: LlmProvider + 'static> LlmProvider for PriorityQueueProvider<P> {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        // Default to Normal priority
        self.complete_with_priority(request, RequestPriority::Normal)
            .await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LlmResponse, Usage};

    struct MockProvider {
        delay_ms: u64,
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
            if self.delay_ms > 0 {
                tokio::time::sleep(tokio::time::Duration::from_millis(self.delay_ms)).await;
            }
            Ok(LlmResponse {
                content: format!("Response to: {}", request.prompt),
                model: "mock-model".to_string(),
                usage: Some(Usage {
                    prompt_tokens: 10,
                    completion_tokens: 20,
                    total_tokens: 30,
                }),
                tool_calls: vec![],
            })
        }
    }

    #[tokio::test]
    async fn test_priority_ordering() {
        assert!(RequestPriority::High > RequestPriority::Normal);
        assert!(RequestPriority::Normal > RequestPriority::Low);
    }
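
    // An illustrative test sketch (not part of the original suite): exercises
    // `QueueState` directly to confirm the dequeue order implied by the `Ord` impl
    // above, i.e. priority first, FIFO within a priority.
    #[tokio::test]
    async fn test_queue_state_priority_then_fifo_order() {
        let mut state = QueueState::new(10);
        let make_request = |prompt: &str| LlmRequest {
            prompt: prompt.to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: vec![],
            images: vec![],
        };

        let (tx1, _rx1) = oneshot::channel();
        let (tx2, _rx2) = oneshot::channel();
        let (tx3, _rx3) = oneshot::channel();

        // Enqueue order: Normal, Normal, High.
        assert!(state.enqueue(RequestPriority::Normal, make_request("first normal"), tx1));
        assert!(state.enqueue(RequestPriority::Normal, make_request("second normal"), tx2));
        assert!(state.enqueue(RequestPriority::High, make_request("high"), tx3));

        // Dequeue order: the High request jumps ahead; Normal requests stay FIFO.
        assert_eq!(state.dequeue().unwrap().request.prompt, "high");
        assert_eq!(state.dequeue().unwrap().request.prompt, "first normal");
        assert_eq!(state.dequeue().unwrap().request.prompt, "second normal");
        assert!(state.dequeue().is_none());
    }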

    #[tokio::test]
    async fn test_priority_queue_config_default() {
        let config = PriorityQueueConfig::default();
        assert_eq!(config.max_queue_size, 1000);
        assert_eq!(config.max_workers, 10);
    }

    #[tokio::test]
    async fn test_priority_queue_single_request() {
        let provider = MockProvider { delay_ms: 10 };
        let config = PriorityQueueConfig {
            max_queue_size: 100,
            max_workers: 5,
        };
        let queue_provider = PriorityQueueProvider::new(provider, config);

        let request = LlmRequest {
            prompt: "Test".to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: vec![],
            images: vec![],
        };

        let response = queue_provider
            .complete_with_priority(request, RequestPriority::High)
            .await
            .unwrap();
        assert_eq!(response.content, "Response to: Test");

        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;

        let stats = queue_provider.stats().await;
        assert_eq!(stats.total_processed, 1);
    }

    #[tokio::test]
    async fn test_priority_queue_priority_ordering() {
        let provider = MockProvider { delay_ms: 50 };
        let config = PriorityQueueConfig {
            max_queue_size: 100,
            max_workers: 1, // Single worker to ensure ordering
        };
        let queue_provider = Arc::new(PriorityQueueProvider::new(provider, config));

        let mut handles = vec![];

        // Submit a mix of priorities: Low first, then High, then Normal
        for (i, priority) in [
            (0, RequestPriority::Low),
            (1, RequestPriority::High),
            (2, RequestPriority::Normal),
        ]
        .iter()
        {
            let qp = Arc::clone(&queue_provider);
            let i = *i;
            let priority = *priority;
            let handle = tokio::spawn(async move {
                let request = LlmRequest {
                    prompt: format!("Request {}", i),
                    system_prompt: None,
                    temperature: None,
                    max_tokens: None,
                    tools: vec![],
                    images: vec![],
                };
                qp.complete_with_priority(request, priority).await
            });
            handles.push(handle);
            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        }

        for handle in handles {
            let _ = handle.await.unwrap();
        }

        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

        let stats = queue_provider.stats().await;
        assert_eq!(stats.total_processed, 3);
    }

    #[tokio::test]
    async fn test_priority_queue_stats() {
        let provider = MockProvider { delay_ms: 100 };
        let config = PriorityQueueConfig {
            max_queue_size: 100,
            max_workers: 1,
        };
        let queue_provider = Arc::new(PriorityQueueProvider::new(provider, config));

        // Submit multiple requests quickly
        let mut handles = vec![];
        for i in 0..5 {
            let qp = Arc::clone(&queue_provider);
            let handle = tokio::spawn(async move {
                let request = LlmRequest {
                    prompt: format!("Request {}", i),
                    system_prompt: None,
                    temperature: None,
                    max_tokens: None,
                    tools: vec![],
                    images: vec![],
                };
                qp.complete(request).await
            });
            handles.push(handle);
        }

        // Check that the queue has items (some may already have been processed)
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
        let stats = queue_provider.stats().await;
        assert!(stats.queue_length > 0 || stats.total_processed > 0);

        // Wait for completion
        for handle in handles {
            let _ = handle.await.unwrap();
        }

        tokio::time::sleep(tokio::time::Duration::from_millis(600)).await;

        let stats = queue_provider.stats().await;
        assert_eq!(stats.total_processed, 5);
        assert_eq!(stats.queue_length, 0);
    }
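
    // An illustrative test sketch (not part of the original suite): exercises
    // `QueueState` directly to confirm the backpressure behavior described in the
    // module docs: once the queue is full, further requests are rejected and counted
    // in `total_rejected`.
    #[tokio::test]
    async fn test_queue_state_rejects_when_full() {
        let mut state = QueueState::new(1);
        let make_request = || LlmRequest {
            prompt: "Test".to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: vec![],
            images: vec![],
        };

        let (tx1, _rx1) = oneshot::channel();
        let (tx2, _rx2) = oneshot::channel();

        // The first request fills the single-slot queue; the second is rejected.
        assert!(state.enqueue(RequestPriority::Normal, make_request(), tx1));
        assert!(!state.enqueue(RequestPriority::High, make_request(), tx2));

        assert_eq!(state.stats.queue_length, 1);
        assert_eq!(state.stats.total_rejected, 1);
    }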

    #[tokio::test]
    async fn test_priority_queue_default_priority() {
        let provider = MockProvider { delay_ms: 10 };
        let config = PriorityQueueConfig::default();
        let queue_provider = PriorityQueueProvider::new(provider, config);

        let request = LlmRequest {
            prompt: "Default priority".to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: vec![],
            images: vec![],
        };

        // Using LlmProvider trait (default priority)
        let response = queue_provider.complete(request).await.unwrap();
        assert_eq!(response.content, "Response to: Default priority");
    }
}