use std::collections::BinaryHeap;
use std::sync::{Arc, Condvar, Mutex};
/// Status reported alongside the payload returned by a pop on a [`BoundedPriorityQueue`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PopResult {
    /// Nothing was available although producers were still active
    /// (the blocking `pop_large` waits instead of reporting this).
    Empty,
    /// The queue is drained and every producer has called `mark_completed`.
    Completed,
    /// An item was removed and is returned with this status.
    Normal,
}
/// One heap slot: the payload plus the `(priority, cost)` key it is ordered by.
#[derive(Debug, Clone)]
struct QueueEntry<T> {
    priority: usize,
    cost: usize,
    data: T,
}

impl<T> QueueEntry<T> {
    /// Ordering key: compared by `priority` first, then by `cost`.
    /// `BinaryHeap` is a max-heap, so the entry with the largest key is popped first.
    fn key(&self) -> (usize, usize) {
        (self.priority, self.cost)
    }
}

// Equality is defined on the same key as `Ord`, so that `a == b` holds exactly
// when `a.cmp(&b) == Ordering::Equal`, which is what `Ord` requires. A derived
// `PartialEq` would also compare `data` and break that contract for entries
// that share a key but carry different payloads.
impl<T> PartialEq for QueueEntry<T> {
    fn eq(&self, other: &Self) -> bool {
        self.key() == other.key()
    }
}

impl<T> Eq for QueueEntry<T> {}

impl<T> PartialOrd for QueueEntry<T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl<T> Ord for QueueEntry<T> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.key().cmp(&other.key())
    }
}
/// Mutable queue state shared by every handle; it lives inside the queue's `Mutex`.
struct QueueState<T> {
    /// Pending entries, ordered by their `(priority, cost)` key.
    queue: BinaryHeap<QueueEntry<T>>,
    /// Sum of the costs of the entries currently queued.
    current_cost: usize,
    /// Cost threshold used to decide whether `emplace` must wait.
    max_cost: usize,
    /// Producers that have not yet called `mark_completed`.
    n_producers: usize,
}
/// A blocking, cost-bounded priority queue shared between producer and consumer
/// threads. Every clone of a handle refers to the same underlying queue.
pub struct BoundedPriorityQueue<T> {
    /// Shared state plus two condition variables:
    /// the first wakes consumers waiting for items (or completion),
    /// the second wakes producers waiting for cost to be released.
    state: Arc<(
        Mutex<QueueState<T>>,
        Condvar,
        Condvar,
    )>,
}
impl<T> BoundedPriorityQueue<T>
where
    T: Clone + Eq,
{
    /// Creates an empty queue.
    ///
    /// * `n_producers` - how many producers will call [`mark_completed`](Self::mark_completed)
    ///   before consumers are told the stream is finished.
    /// * `max_cost` - outstanding-cost threshold; `emplace` waits while the queued
    ///   cost is at or above it.
    pub fn new(n_producers: usize, max_cost: usize) -> Self {
        BoundedPriorityQueue {
            state: Arc::new((
                Mutex::new(QueueState {
                    queue: BinaryHeap::new(),
                    n_producers,
                    current_cost: 0,
                    max_cost,
                }),
                // Wakes consumers: items became available or the last producer finished.
                Condvar::new(),
                // Wakes producers: a pop released some cost.
                Condvar::new(),
            )),
        }
    }

    /// Pushes `data` with the given `priority` and `cost`, blocking while the
    /// queued cost is at or above `max_cost`.
    ///
    /// The limit is checked before insertion, so a single item may push the total
    /// above `max_cost`. This also lets an item costlier than `max_cost` be admitted
    /// at all. A queue with zero outstanding cost always admits an item, so a
    /// `max_cost` of 0 means "one costed item at a time" rather than a deadlock.
    pub fn emplace(&self, data: T, priority: usize, cost: usize) {
        let (mutex, cv_empty, cv_full) = &*self.state;
        let guard = mutex.lock().unwrap();
        // For max_cost > 0, `current_cost >= max_cost` already implies `current_cost > 0`;
        // the extra test only matters for max_cost == 0, which would otherwise never admit.
        let mut state = cv_full
            .wait_while(guard, |s| s.current_cost > 0 && s.current_cost >= s.max_cost)
            .unwrap();
        let was_empty = state.queue.is_empty();
        state.queue.push(QueueEntry { priority, cost, data });
        state.current_cost += cost;
        // Consumers only wait while the queue is empty, so they need waking
        // only on the empty -> non-empty transition.
        if was_empty {
            cv_empty.notify_all();
        }
    }

    /// Pushes `n_items` clones of `data` at `priority` with zero cost.
    ///
    /// Zero-cost items never wait for capacity and do not change the queued cost.
    pub fn emplace_many_no_cost(&self, data: T, priority: usize, n_items: usize) {
        let (mutex, cv_empty, _) = &*self.state;
        let mut state = mutex.lock().unwrap();
        state.queue.extend((0..n_items).map(|_| QueueEntry {
            priority,
            cost: 0,
            data: data.clone(),
        }));
        if n_items > 0 {
            cv_empty.notify_all();
        }
    }

    /// Removes and returns the entry with the largest `(priority, cost)` key.
    ///
    /// Blocks while the queue is empty and at least one producer is still active.
    /// Returns `(PopResult::Normal, Some(data))` when an item is removed, or
    /// `(PopResult::Completed, None)` once the queue is empty and every producer
    /// has called `mark_completed`. Because it waits, this call never returns
    /// `PopResult::Empty`.
    pub fn pop_large(&self) -> (PopResult, Option<T>) {
        let (mutex, cv_empty, cv_full) = &*self.state;
        let guard = mutex.lock().unwrap();
        let mut state = cv_empty
            .wait_while(guard, |s| s.queue.is_empty() && s.n_producers > 0)
            .unwrap();
        // The wait ends on an empty queue only once no producers remain.
        let entry = match state.queue.pop() {
            Some(entry) => entry,
            None => return (PopResult::Completed, None),
        };
        state.current_cost -= entry.cost;
        // Only released cost can unblock a producer waiting in `emplace`.
        if entry.cost > 0 {
            cv_full.notify_all();
        }
        (PopResult::Normal, Some(entry.data))
    }

    /// Records that one producer has finished. When the last one finishes,
    /// blocked consumers are woken so they can observe completion.
    ///
    /// Calls beyond the `n_producers` given to `new` are ignored.
    pub fn mark_completed(&self) {
        let (mutex, cv_empty, _) = &*self.state;
        let mut state = mutex.lock().unwrap();
        // A plain `-= 1` on 0 would panic while holding the lock (poisoning the
        // mutex) in debug builds, or wrap to usize::MAX in release builds and
        // leave every consumer blocked forever.
        state.n_producers = state.n_producers.saturating_sub(1);
        if state.n_producers == 0 {
            cv_empty.notify_all();
        }
    }

    /// Returns `true` if no entries are currently queued.
    pub fn is_empty(&self) -> bool {
        let (mutex, _, _) = &*self.state;
        let state = mutex.lock().unwrap();
        state.queue.is_empty()
    }

    /// Returns `true` once the queue is empty and every producer has finished.
    pub fn is_completed(&self) -> bool {
        let (mutex, _, _) = &*self.state;
        let state = mutex.lock().unwrap();
        state.queue.is_empty() && state.n_producers == 0
    }

    /// Returns `(number of queued entries, total queued cost)`.
    pub fn get_size(&self) -> (usize, usize) {
        let (mutex, _, _) = &*self.state;
        let state = mutex.lock().unwrap();
        (state.queue.len(), state.current_cost)
    }
}
impl<T> Clone for BoundedPriorityQueue<T> {
fn clone(&self) -> Self {
BoundedPriorityQueue {
state: Arc::clone(&self.state),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// Pops once, asserts that an item was returned, and hands back its payload.
    fn pop_item(q: &BoundedPriorityQueue<String>) -> Option<String> {
        let (status, payload) = q.pop_large();
        assert_eq!(status, PopResult::Normal);
        payload
    }

    #[test]
    fn test_basic_operations() {
        let queue = BoundedPriorityQueue::<String>::new(1, 1000);
        for &(name, priority) in &[("task1", 100usize), ("task2", 200), ("task3", 150)] {
            queue.emplace(name.to_string(), priority, 50);
        }
        // Items come back in descending priority order.
        for expected in &["task2", "task3", "task1"] {
            assert_eq!(pop_item(&queue), Some(expected.to_string()));
        }
    }

    #[test]
    fn test_completion_signaling() {
        let queue = BoundedPriorityQueue::<String>::new(1, 1000);
        queue.mark_completed();
        assert_eq!(queue.pop_large(), (PopResult::Completed, None));
    }

    #[test]
    fn test_emplace_many_no_cost() {
        let queue = BoundedPriorityQueue::<String>::new(1, 1000);
        queue.emplace_many_no_cost("sync".to_string(), 500, 3);
        for _ in 0..3 {
            assert_eq!(pop_item(&queue), Some("sync".to_string()));
        }
    }

    #[test]
    fn test_multi_threaded() {
        let queue = BoundedPriorityQueue::<String>::new(2, 1000);
        let producers: Vec<_> = [("p1", 100usize), ("p2", 200usize)]
            .iter()
            .map(|&(tag, base)| {
                let handle = queue.clone();
                thread::spawn(move || {
                    for i in 0..5 {
                        handle.emplace(format!("{}-{}", tag, i), base + i, 10);
                        thread::sleep(Duration::from_millis(1));
                    }
                    handle.mark_completed();
                })
            })
            .collect();

        let mut received = 0;
        loop {
            match queue.pop_large() {
                (PopResult::Normal, Some(_)) => received += 1,
                (PopResult::Empty, None) => thread::sleep(Duration::from_millis(1)),
                (PopResult::Completed, None) => break,
                _ => panic!("Unexpected queue state"),
            }
        }

        for producer in producers {
            producer.join().unwrap();
        }
        assert_eq!(received, 10);
    }

    #[test]
    fn test_capacity_limiting() {
        let queue = BoundedPriorityQueue::<String>::new(1, 100);
        let blocked = queue.clone();
        queue.emplace("task1".to_string(), 100, 50);
        queue.emplace("task2".to_string(), 100, 50);
        // The queue is now at its cost limit, so this producer waits until a pop.
        let producer = thread::spawn(move || blocked.emplace("task3".to_string(), 100, 50));
        thread::sleep(Duration::from_millis(10));
        assert_eq!(queue.get_size(), (2, 100));
        let _ = queue.pop_large();
        producer.join().unwrap();
        assert_eq!(queue.get_size(), (2, 100));
    }
}