1use crate::protocol::DownloadId;
8use parking_lot::Mutex;
9use std::collections::{BinaryHeap, HashMap};
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::sync::Arc;
12use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore};
13
14pub use crate::protocol::DownloadPriority;
16
17#[derive(Debug, Clone, Eq, PartialEq)]
19struct QueueEntry {
20 id: DownloadId,
21 priority: DownloadPriority,
22 sequence: u64,
24}
25
26impl Ord for QueueEntry {
27 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
28 match self.priority.cmp(&other.priority) {
30 std::cmp::Ordering::Equal => other.sequence.cmp(&self.sequence), other => other,
32 }
33 }
34}
35
36impl PartialOrd for QueueEntry {
37 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
38 Some(self.cmp(other))
39 }
40}
41
42pub struct PriorityPermit {
45 _permit: OwnedSemaphorePermit,
46 id: DownloadId,
47 queue: Arc<PriorityQueue>,
48}
49
50impl Drop for PriorityPermit {
51 fn drop(&mut self) {
52 self.queue.inner.lock().active.remove(&self.id);
54 self.queue.notify.notify_waiters();
56 }
57}
58
59struct PriorityQueueInner {
61 waiting: BinaryHeap<QueueEntry>,
63 active: HashMap<DownloadId, DownloadPriority>,
65 waiting_priorities: HashMap<DownloadId, DownloadPriority>,
67}
68
69pub struct PriorityQueue {
75 semaphore: Arc<Semaphore>,
77 inner: Mutex<PriorityQueueInner>,
79 sequence: AtomicU64,
81 notify: Notify,
83}
84
85impl PriorityQueue {
86 pub fn new(max_concurrent: usize) -> Arc<Self> {
88 Arc::new(Self {
89 semaphore: Arc::new(Semaphore::new(max_concurrent)),
90 inner: Mutex::new(PriorityQueueInner {
91 waiting: BinaryHeap::new(),
92 active: HashMap::new(),
93 waiting_priorities: HashMap::new(),
94 }),
95 sequence: AtomicU64::new(0),
96 notify: Notify::new(),
97 })
98 }
99
100 pub async fn acquire(
116 self: &Arc<Self>,
117 id: DownloadId,
118 priority: DownloadPriority,
119 ) -> PriorityPermit {
120 let sequence = self.sequence.fetch_add(1, Ordering::Relaxed);
122 {
123 let mut inner = self.inner.lock();
124 inner.waiting.push(QueueEntry {
125 id,
126 priority,
127 sequence,
128 });
129 inner.waiting_priorities.insert(id, priority);
130 }
131
132 loop {
133 {
135 let inner = self.inner.lock();
136 if let Some(next) = inner.waiting.peek() {
137 if next.id == id {
138 drop(inner); if let Ok(permit) = self.semaphore.clone().try_acquire_owned() {
143 let mut inner = self.inner.lock();
145 inner.waiting.pop();
146 inner.waiting_priorities.remove(&id);
147 inner.active.insert(id, priority);
148
149 return PriorityPermit {
150 _permit: permit,
151 id,
152 queue: Arc::clone(self),
153 };
154 }
155 }
156 }
157 }
158
159 self.notify.notified().await;
161 }
162 }
163
164 pub fn try_acquire(
186 self: &Arc<Self>,
187 id: DownloadId,
188 priority: DownloadPriority,
189 ) -> Option<PriorityPermit> {
190 let mut inner = self.inner.lock();
191
192 if let Some(next) = inner.waiting.peek() {
194 if next.priority > priority {
195 return None; }
197 }
198
199 match self.semaphore.clone().try_acquire_owned() {
201 Ok(permit) => {
202 inner.active.insert(id, priority);
203 Some(PriorityPermit {
204 _permit: permit,
205 id,
206 queue: Arc::clone(self),
207 })
208 }
209 Err(_) => None,
210 }
211 }
212
213 pub fn set_priority(&self, id: DownloadId, new_priority: DownloadPriority) -> bool {
218 let mut inner = self.inner.lock();
219
220 if inner.waiting_priorities.contains_key(&id) {
222 let entries: Vec<_> = inner.waiting.drain().collect();
224 for entry in entries {
225 if entry.id == id {
226 inner.waiting.push(QueueEntry {
227 id: entry.id,
228 priority: new_priority,
229 sequence: entry.sequence,
230 });
231 } else {
232 inner.waiting.push(entry);
233 }
234 }
235 inner.waiting_priorities.insert(id, new_priority);
236 drop(inner);
237
238 self.notify.notify_waiters();
240 return true;
241 }
242
243 if let Some(priority) = inner.active.get_mut(&id) {
245 *priority = new_priority;
246 return true;
247 }
248
249 false
250 }
251
252 pub fn remove(&self, id: DownloadId) {
256 let mut inner = self.inner.lock();
257 inner.waiting_priorities.remove(&id);
258 let entries: Vec<_> = inner.waiting.drain().filter(|e| e.id != id).collect();
260 for entry in entries {
261 inner.waiting.push(entry);
262 }
263 }
264
265 pub fn get_priority(&self, id: DownloadId) -> Option<DownloadPriority> {
267 let inner = self.inner.lock();
268 inner
269 .waiting_priorities
270 .get(&id)
271 .or_else(|| inner.active.get(&id))
272 .copied()
273 }
274
275 pub fn active_count(&self) -> usize {
277 self.inner.lock().active.len()
278 }
279
280 pub fn waiting_count(&self) -> usize {
282 self.inner.lock().waiting.len()
283 }
284
285 pub fn queue_position(&self, id: DownloadId) -> Option<usize> {
287 let inner = self.inner.lock();
288 if !inner.waiting_priorities.contains_key(&id) {
289 return None;
290 }
291 let mut sorted: Vec<_> = inner.waiting.iter().cloned().collect();
293 sorted.sort_by(|a, b| b.cmp(a)); sorted.iter().position(|e| e.id == id).map(|p| p + 1)
295 }
296
297 pub fn stats(&self) -> PriorityQueueStats {
299 let inner = self.inner.lock();
300 let mut by_priority = HashMap::new();
301 for priority in inner.waiting_priorities.values() {
302 *by_priority.entry(*priority).or_insert(0) += 1;
303 }
304 PriorityQueueStats {
305 active: inner.active.len(),
306 waiting: inner.waiting.len(),
307 waiting_by_priority: by_priority,
308 }
309 }
310}
311
312#[derive(Debug, Clone)]
314pub struct PriorityQueueStats {
315 pub active: usize,
317 pub waiting: usize,
319 pub waiting_by_priority: HashMap<DownloadPriority, usize>,
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserts a waiting entry straight into the queue's internal state, bypassing
    /// `acquire`, so synchronous tests can set up a known heap.
    fn push_waiting(
        queue: &PriorityQueue,
        id: DownloadId,
        priority: DownloadPriority,
        sequence: u64,
    ) {
        let mut inner = queue.inner.lock();
        inner.waiting.push(QueueEntry {
            id,
            priority,
            sequence,
        });
        inner.waiting_priorities.insert(id, priority);
    }

    #[test]
    fn test_priority_ordering() {
        assert!(DownloadPriority::Critical > DownloadPriority::High);
        assert!(DownloadPriority::High > DownloadPriority::Normal);
        assert!(DownloadPriority::Normal > DownloadPriority::Low);
    }

    #[test]
    fn test_priority_from_str() {
        assert_eq!(
            "low".parse::<DownloadPriority>().unwrap(),
            DownloadPriority::Low
        );
        assert_eq!(
            "normal".parse::<DownloadPriority>().unwrap(),
            DownloadPriority::Normal
        );
        assert_eq!(
            "high".parse::<DownloadPriority>().unwrap(),
            DownloadPriority::High
        );
        assert_eq!(
            "critical".parse::<DownloadPriority>().unwrap(),
            DownloadPriority::Critical
        );
    }

    #[test]
    fn test_queue_entry_ordering() {
        let entry1 = QueueEntry {
            id: DownloadId::new(),
            priority: DownloadPriority::Normal,
            sequence: 1,
        };
        let entry2 = QueueEntry {
            id: DownloadId::new(),
            priority: DownloadPriority::High,
            sequence: 2,
        };
        let entry3 = QueueEntry {
            id: DownloadId::new(),
            priority: DownloadPriority::Normal,
            sequence: 0,
        };

        // Higher priority wins regardless of sequence.
        assert!(entry2 > entry1);
        // Equal priority: the earlier sequence ranks higher (FIFO).
        assert!(entry3 > entry1);
    }

    #[tokio::test]
    async fn test_priority_queue_basic() {
        let queue = PriorityQueue::new(2);
        let id1 = DownloadId::new();
        let id2 = DownloadId::new();

        let permit1 = queue.acquire(id1, DownloadPriority::Normal).await;
        let permit2 = queue.acquire(id2, DownloadPriority::Normal).await;

        assert_eq!(queue.active_count(), 2);
        assert_eq!(queue.waiting_count(), 0);

        drop(permit1);
        drop(permit2);

        assert_eq!(queue.active_count(), 0);
    }

    #[tokio::test]
    async fn test_priority_queue_priority_ordering() {
        let queue = PriorityQueue::new(1);
        let id_low = DownloadId::new();
        let id_high = DownloadId::new();

        // Occupy the only slot so both later requests have to queue.
        let permit1 = queue
            .acquire(DownloadId::new(), DownloadPriority::Normal)
            .await;

        let queue_clone = queue.clone();
        let low_handle =
            tokio::spawn(async move { queue_clone.acquire(id_low, DownloadPriority::Low).await });

        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let queue_clone = queue.clone();
        let high_handle =
            tokio::spawn(async move { queue_clone.acquire(id_high, DownloadPriority::High).await });

        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        assert_eq!(queue.waiting_count(), 2);

        drop(permit1);

        // The high-priority request must win although it was queued later.
        let high_permit = tokio::time::timeout(std::time::Duration::from_millis(100), high_handle)
            .await
            .expect("timeout")
            .expect("join error");

        assert_eq!(queue.active_count(), 1);
        assert_eq!(queue.waiting_count(), 1);

        drop(high_permit);

        let _low_permit = tokio::time::timeout(std::time::Duration::from_millis(100), low_handle)
            .await
            .expect("timeout")
            .expect("join error");

        assert_eq!(queue.active_count(), 1);
        assert_eq!(queue.waiting_count(), 0);
    }

    #[test]
    fn test_set_priority() {
        let queue = PriorityQueue::new(1);
        let id = DownloadId::new();

        push_waiting(&queue, id, DownloadPriority::Low, 0);

        assert_eq!(queue.get_priority(id), Some(DownloadPriority::Low));

        assert!(queue.set_priority(id, DownloadPriority::High));

        assert_eq!(queue.get_priority(id), Some(DownloadPriority::High));
        // The heap entry itself must be re-keyed, not just the lookup map.
        let head_priority = queue.inner.lock().waiting.peek().map(|entry| entry.priority);
        assert_eq!(head_priority, Some(DownloadPriority::High));

        // Unknown ids are reported as not found.
        assert!(!queue.set_priority(DownloadId::new(), DownloadPriority::Low));
    }

    #[test]
    fn test_remove() {
        let queue = PriorityQueue::new(1);
        let id = DownloadId::new();

        push_waiting(&queue, id, DownloadPriority::Normal, 0);

        assert_eq!(queue.waiting_count(), 1);

        queue.remove(id);

        assert_eq!(queue.waiting_count(), 0);
        assert_eq!(queue.get_priority(id), None);
    }

    #[test]
    fn test_queue_position() {
        let queue = PriorityQueue::new(1);
        let low = DownloadId::new();
        let normal_first = DownloadId::new();
        let high = DownloadId::new();
        let normal_second = DownloadId::new();

        push_waiting(&queue, low, DownloadPriority::Low, 0);
        push_waiting(&queue, normal_first, DownloadPriority::Normal, 1);
        push_waiting(&queue, high, DownloadPriority::High, 2);
        push_waiting(&queue, normal_second, DownloadPriority::Normal, 3);

        assert_eq!(queue.queue_position(high), Some(1));
        assert_eq!(queue.queue_position(normal_first), Some(2));
        assert_eq!(queue.queue_position(normal_second), Some(3));
        assert_eq!(queue.queue_position(low), Some(4));
        assert_eq!(queue.queue_position(DownloadId::new()), None);
    }

    #[test]
    fn test_stats() {
        let queue = PriorityQueue::new(2);

        push_waiting(&queue, DownloadId::new(), DownloadPriority::Low, 0);
        push_waiting(&queue, DownloadId::new(), DownloadPriority::Normal, 1);
        push_waiting(&queue, DownloadId::new(), DownloadPriority::High, 2);
        push_waiting(&queue, DownloadId::new(), DownloadPriority::High, 3);

        let stats = queue.stats();
        assert_eq!(stats.waiting, 4);
        assert_eq!(stats.active, 0);
        assert_eq!(
            stats.waiting_by_priority.get(&DownloadPriority::Low),
            Some(&1)
        );
        assert_eq!(
            stats.waiting_by_priority.get(&DownloadPriority::Normal),
            Some(&1)
        );
        assert_eq!(
            stats.waiting_by_priority.get(&DownloadPriority::High),
            Some(&2)
        );
        assert_eq!(
            stats.waiting_by_priority.get(&DownloadPriority::Critical),
            None
        );
    }
}