use crate::protocol::DownloadId;
use parking_lot::Mutex;
use std::collections::{BinaryHeap, HashMap};
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore};
pub use crate::protocol::DownloadPriority;
/// One pending `acquire` call stored in the waiting heap.
#[derive(Debug, Clone, Eq, PartialEq)]
struct QueueEntry {
    /// Download this entry belongs to.
    id: DownloadId,
    /// Scheduling priority; `set_priority` may rewrite it while the entry is queued.
    priority: DownloadPriority,
    /// Arrival counter taken from `PriorityQueue::sequence`; used as the FIFO
    /// tie-breaker between entries of equal priority.
    sequence: u64,
}
impl Ord for QueueEntry {
    /// Ranks entries for the max-heap `BinaryHeap`: a higher `priority` ranks first,
    /// and among equal priorities the smaller (older) `sequence` ranks first.
    ///
    /// `id` is deliberately not compared. Two entries that differ only by `id` compare
    /// `Equal` even though the derived `PartialEq` treats them as different. This is
    /// harmless as long as every entry has a distinct `sequence`.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let by_priority = self.priority.cmp(&other.priority);
        // The comparison is swapped on purpose so that the earlier arrival wins the tie.
        by_priority.then_with(|| other.sequence.cmp(&self.sequence))
    }
}
impl PartialOrd for QueueEntry {
    /// Delegates to the total order in `Ord::cmp`, so the two can never disagree.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let total = Ord::cmp(self, other);
        Some(total)
    }
}
/// Guard for an active download slot handed out by `PriorityQueue`.
///
/// Dropping it removes the download from the queue's active set and wakes waiters.
/// Fields are dropped in declaration order after `Drop::drop` runs. The semaphore
/// permit in `_permit` is therefore returned only after the drop body has already
/// notified waiters, and before the `queue` handle is released.
pub struct PriorityPermit {
    // Held only so that dropping it returns the semaphore permit.
    _permit: OwnedSemaphorePermit,
    // Download that owns this slot; removed from `active` on drop.
    id: DownloadId,
    // Queue whose bookkeeping is updated on drop.
    queue: Arc<PriorityQueue>,
}
impl Drop for PriorityPermit {
    /// Frees this download's slot in the active set, then wakes every pending waiter
    /// so the current queue head can re-check for capacity.
    ///
    /// The semaphore permit in `_permit` is released after this body returns,
    /// because struct fields are dropped only after `Drop::drop` finishes.
    fn drop(&mut self) {
        let owner = &self.queue;
        {
            let mut state = owner.inner.lock();
            state.active.remove(&self.id);
        }
        owner.notify.notify_waiters();
    }
}
/// Mutable queue state, always accessed through `PriorityQueue::inner`'s mutex.
struct PriorityQueueInner {
    /// Pending entries; the heap's maximum (see `QueueEntry`'s `Ord`) is admitted next.
    waiting: BinaryHeap<QueueEntry>,
    /// Downloads currently holding a slot, with their most recently set priority.
    active: HashMap<DownloadId, DownloadPriority>,
    /// Id -> priority index of `waiting`, so lookups need not scan the heap.
    waiting_priorities: HashMap<DownloadId, DownloadPriority>,
}
/// Admission queue that caps how many downloads run at once and admits waiting
/// downloads by priority, then by arrival order.
pub struct PriorityQueue {
    /// Semaphore whose permits are held by active downloads. `set_max_concurrent`
    /// only ever adds permits, so after the limit is lowered this can hold more
    /// permits than `max_concurrent`.
    semaphore: Arc<Semaphore>,
    /// Current concurrency limit; may be changed at runtime with `set_max_concurrent`.
    max_concurrent: AtomicUsize,
    /// Waiting and active bookkeeping.
    inner: Mutex<PriorityQueueInner>,
    /// Source of `QueueEntry::sequence` values.
    sequence: AtomicU64,
    /// Wakes pending `acquire` calls when the queue state changes.
    notify: Notify,
}
impl PriorityQueue {
pub fn new(max_concurrent: usize) -> Arc<Self> {
Arc::new(Self {
semaphore: Arc::new(Semaphore::new(max_concurrent)),
max_concurrent: AtomicUsize::new(max_concurrent),
inner: Mutex::new(PriorityQueueInner {
waiting: BinaryHeap::new(),
active: HashMap::new(),
waiting_priorities: HashMap::new(),
}),
sequence: AtomicU64::new(0),
notify: Notify::new(),
})
}
pub async fn acquire(
self: &Arc<Self>,
id: DownloadId,
priority: DownloadPriority,
) -> PriorityPermit {
let sequence = self.sequence.fetch_add(1, Ordering::Relaxed);
{
let mut inner = self.inner.lock();
inner.waiting.push(QueueEntry {
id,
priority,
sequence,
});
inner.waiting_priorities.insert(id, priority);
}
loop {
{
let inner = self.inner.lock();
if let Some(next) = inner.waiting.peek() {
if next.id == id
&& inner.active.len() < self.max_concurrent.load(Ordering::Relaxed)
{
drop(inner);
if let Ok(permit) = self.semaphore.clone().try_acquire_owned() {
let mut inner = self.inner.lock();
inner.waiting.pop();
inner.waiting_priorities.remove(&id);
inner.active.insert(id, priority);
return PriorityPermit {
_permit: permit,
id,
queue: Arc::clone(self),
};
}
}
}
}
self.notify.notified().await;
}
}
pub fn try_acquire(
self: &Arc<Self>,
id: DownloadId,
priority: DownloadPriority,
) -> Option<PriorityPermit> {
let mut inner = self.inner.lock();
if inner.active.len() >= self.max_concurrent.load(Ordering::Relaxed) {
return None;
}
if let Some(next) = inner.waiting.peek() {
if next.priority > priority {
return None; }
}
match self.semaphore.clone().try_acquire_owned() {
Ok(permit) => {
inner.active.insert(id, priority);
Some(PriorityPermit {
_permit: permit,
id,
queue: Arc::clone(self),
})
}
Err(_) => None,
}
}
pub fn set_priority(&self, id: DownloadId, new_priority: DownloadPriority) -> bool {
let mut inner = self.inner.lock();
if inner.waiting_priorities.contains_key(&id) {
let entries: Vec<_> = inner.waiting.drain().collect();
for entry in entries {
if entry.id == id {
inner.waiting.push(QueueEntry {
id: entry.id,
priority: new_priority,
sequence: entry.sequence,
});
} else {
inner.waiting.push(entry);
}
}
inner.waiting_priorities.insert(id, new_priority);
drop(inner);
self.notify.notify_waiters();
return true;
}
if let Some(priority) = inner.active.get_mut(&id) {
*priority = new_priority;
return true;
}
false
}
pub fn remove(&self, id: DownloadId) {
let mut inner = self.inner.lock();
inner.waiting_priorities.remove(&id);
let entries: Vec<_> = inner.waiting.drain().filter(|e| e.id != id).collect();
for entry in entries {
inner.waiting.push(entry);
}
}
pub fn get_priority(&self, id: DownloadId) -> Option<DownloadPriority> {
let inner = self.inner.lock();
inner
.waiting_priorities
.get(&id)
.or_else(|| inner.active.get(&id))
.copied()
}
pub fn set_max_concurrent(&self, max_concurrent: usize) {
let previous = self.max_concurrent.swap(max_concurrent, Ordering::Relaxed);
if max_concurrent > previous {
self.semaphore.add_permits(max_concurrent - previous);
}
self.notify.notify_waiters();
}
pub fn active_count(&self) -> usize {
self.inner.lock().active.len()
}
pub fn waiting_count(&self) -> usize {
self.inner.lock().waiting.len()
}
pub fn queue_position(&self, id: DownloadId) -> Option<usize> {
let inner = self.inner.lock();
if !inner.waiting_priorities.contains_key(&id) {
return None;
}
let mut sorted: Vec<_> = inner.waiting.iter().cloned().collect();
sorted.sort_by(|a, b| b.cmp(a)); sorted.iter().position(|e| e.id == id).map(|p| p + 1)
}
pub fn stats(&self) -> PriorityQueueStats {
let inner = self.inner.lock();
let mut by_priority = HashMap::new();
for priority in inner.waiting_priorities.values() {
*by_priority.entry(*priority).or_insert(0) += 1;
}
PriorityQueueStats {
active: inner.active.len(),
waiting: inner.waiting.len(),
waiting_by_priority: by_priority,
}
}
}
/// Snapshot of queue occupancy, as returned by `PriorityQueue::stats`.
#[derive(Debug, Clone)]
pub struct PriorityQueueStats {
    /// Number of entries in the active set.
    pub active: usize,
    /// Number of entries in the waiting heap.
    pub waiting: usize,
    /// Waiting downloads counted per priority, built from the id -> priority index.
    pub waiting_by_priority: HashMap<DownloadPriority, usize>,
}
#[cfg(test)]
mod tests {
    use super::*;

    // The heap relies on this ordering: Critical > High > Normal > Low.
    #[test]
    fn test_priority_ordering() {
        assert!(DownloadPriority::Critical > DownloadPriority::High);
        assert!(DownloadPriority::High > DownloadPriority::Normal);
        assert!(DownloadPriority::Normal > DownloadPriority::Low);
    }

    // Lower-case names parse into the matching variants.
    #[test]
    fn test_priority_from_str() {
        assert_eq!(
            "low".parse::<DownloadPriority>().unwrap(),
            DownloadPriority::Low
        );
        assert_eq!(
            "normal".parse::<DownloadPriority>().unwrap(),
            DownloadPriority::Normal
        );
        assert_eq!(
            "high".parse::<DownloadPriority>().unwrap(),
            DownloadPriority::High
        );
        assert_eq!(
            "critical".parse::<DownloadPriority>().unwrap(),
            DownloadPriority::Critical
        );
    }

    // Higher priority ranks first; within a priority, the smaller sequence ranks first.
    #[test]
    fn test_queue_entry_ordering() {
        let entry1 = QueueEntry {
            id: DownloadId::new(),
            priority: DownloadPriority::Normal,
            sequence: 1,
        };
        let entry2 = QueueEntry {
            id: DownloadId::new(),
            priority: DownloadPriority::High,
            sequence: 2,
        };
        let entry3 = QueueEntry {
            id: DownloadId::new(),
            priority: DownloadPriority::Normal,
            sequence: 0,
        };
        assert!(entry2 > entry1);
        assert!(entry3 > entry1);
    }

    // Two slots can both be taken and are released when the permits drop.
    #[tokio::test]
    async fn test_priority_queue_basic() {
        let queue = PriorityQueue::new(2);
        let id1 = DownloadId::new();
        let id2 = DownloadId::new();
        let permit1 = queue.clone().acquire(id1, DownloadPriority::Normal).await;
        let permit2 = queue.clone().acquire(id2, DownloadPriority::Normal).await;
        assert_eq!(queue.active_count(), 2);
        drop(permit1);
        drop(permit2);
        assert_eq!(queue.active_count(), 0);
    }

    // With the only slot busy, a High waiter enqueued after a Low one is admitted first.
    #[tokio::test]
    async fn test_priority_queue_priority_ordering() {
        let queue = PriorityQueue::new(1);
        let id_low = DownloadId::new();
        let id_high = DownloadId::new();
        // Occupy the single slot so both spawned acquires must queue.
        let permit1 = queue
            .clone()
            .acquire(DownloadId::new(), DownloadPriority::Normal)
            .await;
        let queue_clone = queue.clone();
        let low_handle =
            tokio::spawn(async move { queue_clone.acquire(id_low, DownloadPriority::Low).await });
        // Timing-based: assumes 10 ms is enough for the spawned task to enqueue itself.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        let queue_clone = queue.clone();
        let high_handle =
            tokio::spawn(async move { queue_clone.acquire(id_high, DownloadPriority::High).await });
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        assert_eq!(queue.waiting_count(), 2);
        // Freeing the slot should admit the High waiter while Low keeps waiting.
        drop(permit1);
        let high_permit = tokio::time::timeout(std::time::Duration::from_millis(100), high_handle)
            .await
            .expect("timeout")
            .expect("join error");
        assert_eq!(queue.active_count(), 1);
        assert_eq!(queue.waiting_count(), 1);
        drop(high_permit);
        let _low_permit = tokio::time::timeout(std::time::Duration::from_millis(100), low_handle)
            .await
            .expect("timeout")
            .expect("join error");
        assert_eq!(queue.active_count(), 1);
        assert_eq!(queue.waiting_count(), 0);
    }

    // Seeds the waiting state directly (bypassing `acquire`), then reprioritizes it.
    #[test]
    fn test_set_priority() {
        let queue = PriorityQueue::new(1);
        let id = DownloadId::new();
        {
            let mut inner = queue.inner.lock();
            inner.waiting.push(QueueEntry {
                id,
                priority: DownloadPriority::Low,
                sequence: 0,
            });
            inner.waiting_priorities.insert(id, DownloadPriority::Low);
        }
        assert_eq!(queue.get_priority(id), Some(DownloadPriority::Low));
        assert!(queue.set_priority(id, DownloadPriority::High));
        assert_eq!(queue.get_priority(id), Some(DownloadPriority::High));
    }

    // `remove` drops both the heap entry and the id -> priority index entry.
    #[test]
    fn test_remove() {
        let queue = PriorityQueue::new(1);
        let id = DownloadId::new();
        {
            let mut inner = queue.inner.lock();
            inner.waiting.push(QueueEntry {
                id,
                priority: DownloadPriority::Normal,
                sequence: 0,
            });
            inner
                .waiting_priorities
                .insert(id, DownloadPriority::Normal);
        }
        assert_eq!(queue.waiting_count(), 1);
        queue.remove(id);
        assert_eq!(queue.waiting_count(), 0);
        assert_eq!(queue.get_priority(id), None);
    }

    // Seeds three waiting entries (Low, Normal, High) and checks the aggregate counts.
    #[test]
    fn test_stats() {
        let queue = PriorityQueue::new(2);
        {
            let mut inner = queue.inner.lock();
            for i in 0..3 {
                let id = DownloadId::new();
                let priority = match i % 3 {
                    0 => DownloadPriority::Low,
                    1 => DownloadPriority::Normal,
                    _ => DownloadPriority::High,
                };
                inner.waiting.push(QueueEntry {
                    id,
                    priority,
                    sequence: i,
                });
                inner.waiting_priorities.insert(id, priority);
            }
        }
        let stats = queue.stats();
        assert_eq!(stats.waiting, 3);
        assert_eq!(stats.active, 0);
    }
}