use super::events::Timestamped;
use log::{debug, trace};
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::fmt::Debug;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering as AtomicOrdering};
use std::sync::Arc;
use tokio::sync::{Mutex, Notify};
/// Internal heap entry wrapping a shared event.
///
/// Ordering is defined by the event's timestamp (see the `Ord` impl below),
/// inverted so that `BinaryHeap` — a max-heap — yields the *earliest*
/// timestamp first.
///
/// Fix: no trait bounds on the struct definition itself. Bounds on a struct
/// are unnecessary noise; they belong on the impl blocks that need them.
#[derive(Clone)]
struct PriorityQueueEvent<T> {
    // Shared ownership: the queue never clones the event payload itself.
    event: Arc<T>,
}
impl<T> PriorityQueueEvent<T>
where
T: Timestamped + Clone + Send + Sync + 'static,
{
fn new(event: Arc<T>) -> Self {
Self { event }
}
}
impl<T> PartialEq for PriorityQueueEvent<T>
where
    T: Timestamped + Clone + Send + Sync + 'static,
{
    /// Two entries are equal when their events carry the same timestamp;
    /// the payloads themselves are never compared.
    fn eq(&self, other: &Self) -> bool {
        let lhs = self.event.timestamp();
        let rhs = other.event.timestamp();
        lhs == rhs
    }
}
// `Eq` is sound here: timestamps form a total order, so timestamp equality
// is reflexive, symmetric, and transitive.
impl<T> Eq for PriorityQueueEvent<T> where T: Timestamped + Clone + Send + Sync + 'static {}
impl<T> PartialOrd for PriorityQueueEvent<T>
where
    T: Timestamped + Clone + Send + Sync + 'static,
{
    /// Delegates to the total order defined by the `Ord` impl.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl<T> Ord for PriorityQueueEvent<T>
where
    T: Timestamped + Clone + Send + Sync + 'static,
{
    /// Reversed timestamp order: `BinaryHeap` is a max-heap, so inverting
    /// the natural comparison makes `pop()` return the earliest event.
    fn cmp(&self, other: &Self) -> Ordering {
        self.event
            .timestamp()
            .cmp(&other.event.timestamp())
            .reverse()
    }
}
/// Lock-free counters describing queue activity.
///
/// All fields use relaxed atomics: they are monitoring data, not
/// synchronization state, so slight staleness between fields is acceptable.
#[derive(Debug)]
pub struct PriorityQueueMetrics {
/// Events successfully pushed onto the heap (both enqueue paths).
pub total_enqueued: AtomicU64,
/// Events popped via `try_dequeue`/`dequeue` (drained events excluded).
pub total_dequeued: AtomicU64,
/// Heap length as of the last operation that updated it (a gauge).
pub current_depth: AtomicUsize,
/// High-water mark of `current_depth`.
pub max_depth_seen: AtomicUsize,
/// Events rejected by the non-blocking `enqueue` because the queue was full.
pub drops_due_to_capacity: AtomicU64,
/// Times `enqueue_wait` found the queue full and had to park.
pub blocked_enqueue_count: AtomicU64,
}
impl PriorityQueueMetrics {
    /// Copies every counter into a plain-value snapshot.
    ///
    /// Each field is loaded independently with relaxed ordering, so the
    /// snapshot is not guaranteed to be a single consistent point in time.
    pub fn snapshot(&self) -> MetricsSnapshot {
        let total_enqueued = self.total_enqueued.load(AtomicOrdering::Relaxed);
        let total_dequeued = self.total_dequeued.load(AtomicOrdering::Relaxed);
        let current_depth = self.current_depth.load(AtomicOrdering::Relaxed);
        let max_depth_seen = self.max_depth_seen.load(AtomicOrdering::Relaxed);
        let drops_due_to_capacity = self.drops_due_to_capacity.load(AtomicOrdering::Relaxed);
        let blocked_enqueue_count = self.blocked_enqueue_count.load(AtomicOrdering::Relaxed);
        MetricsSnapshot {
            total_enqueued,
            total_dequeued,
            current_depth,
            max_depth_seen,
            drops_due_to_capacity,
            blocked_enqueue_count,
        }
    }
}
/// Plain-value copy of [`PriorityQueueMetrics`] taken at a moment in time.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MetricsSnapshot {
/// Events successfully pushed onto the heap.
pub total_enqueued: u64,
/// Events popped via the dequeue paths.
pub total_dequeued: u64,
/// Heap length as of the last metric update.
pub current_depth: usize,
/// High-water mark of the heap length.
pub max_depth_seen: usize,
/// Events rejected by the non-blocking enqueue path.
pub drops_due_to_capacity: u64,
/// Times a blocking enqueue had to wait for capacity.
pub blocked_enqueue_count: u64,
}
impl Default for PriorityQueueMetrics {
    /// All counters start at zero.
    fn default() -> Self {
        Self {
            total_enqueued: AtomicU64::default(),
            total_dequeued: AtomicU64::default(),
            current_depth: AtomicUsize::default(),
            max_depth_seen: AtomicUsize::default(),
            drops_due_to_capacity: AtomicU64::default(),
            blocked_enqueue_count: AtomicU64::default(),
        }
    }
}
/// Timestamp-ordered async queue: earliest event comes out first.
///
/// Cloning the queue (see the manual `Clone` impl) shares the same heap,
/// notifier, and metrics, so clones are handles onto one underlying queue.
pub struct PriorityQueue<T>
where
T: Timestamped + Clone + Send + Sync + 'static,
{
/// Min-by-timestamp heap (comparison is inverted in `Ord` because
/// `BinaryHeap` is a max-heap).
heap: Arc<Mutex<BinaryHeap<PriorityQueueEvent<T>>>>,
/// Single notifier shared by blocked consumers (`dequeue`) and blocked
/// producers (`enqueue_wait`); every waiter re-checks state after waking.
notify: Arc<Notify>,
/// Maximum number of events held at once.
max_capacity: usize,
/// Activity counters, shared across clones.
metrics: Arc<PriorityQueueMetrics>,
}
impl<T> PriorityQueue<T>
where
T: Timestamped + Clone + Send + Sync + Debug + 'static,
{
pub fn new(max_capacity: usize) -> Self {
Self {
heap: Arc::new(Mutex::new(BinaryHeap::new())),
notify: Arc::new(Notify::new()),
max_capacity,
metrics: Arc::new(PriorityQueueMetrics::default()),
}
}
/// Attempts to enqueue without blocking.
///
/// Returns `true` on success. When the queue is already at `max_capacity`
/// the event is dropped, `drops_due_to_capacity` is incremented, and
/// `false` is returned.
pub async fn enqueue(&self, event: Arc<T>) -> bool {
    let mut heap = self.heap.lock().await;
    if heap.len() >= self.max_capacity {
        let previous = self
            .metrics
            .drops_due_to_capacity
            .fetch_add(1, AtomicOrdering::Relaxed);
        let total_drops = previous + 1;
        // Log loudly on the first drop and every 100th thereafter;
        // everything else is trace-level to avoid spam under sustained
        // overload.
        if total_drops == 1 || total_drops % 100 == 0 {
            debug!(
                "Priority queue at capacity ({}); {} events dropped so far",
                self.max_capacity, total_drops
            );
        } else {
            trace!(
                "Priority queue drop (capacity {}): {:?}",
                self.max_capacity,
                event
            );
        }
        drop(heap);
        return false;
    }
    heap.push(PriorityQueueEvent::new(event));
    self.metrics
        .total_enqueued
        .fetch_add(1, AtomicOrdering::Relaxed);
    let current_depth = heap.len();
    self.metrics
        .current_depth
        .store(current_depth, AtomicOrdering::Relaxed);
    // Fix: `fetch_max` replaces the previous hand-rolled
    // compare-exchange loop; it atomically keeps the high-water mark
    // monotonic with less code and no retry logic.
    self.metrics
        .max_depth_seen
        .fetch_max(current_depth, AtomicOrdering::Relaxed);
    drop(heap);
    // Wake one waiter (typically a consumer parked in `dequeue`).
    self.notify.notify_one();
    true
}
/// Enqueues, asynchronously waiting for capacity if the queue is full.
///
/// Lost-wakeup safety: the `Notified` future is created before the lock
/// is taken and interest is registered with `enable()` *before* the lock
/// is released, so a `notify_one` fired between unlock and `.await` is
/// still observed.
pub async fn enqueue_wait(&self, event: Arc<T>) {
    loop {
        let notified = self.notify.notified();
        tokio::pin!(notified);
        let mut heap = self.heap.lock().await;
        if heap.len() < self.max_capacity {
            heap.push(PriorityQueueEvent::new(event));
            self.metrics
                .total_enqueued
                .fetch_add(1, AtomicOrdering::Relaxed);
            let current_depth = heap.len();
            self.metrics
                .current_depth
                .store(current_depth, AtomicOrdering::Relaxed);
            // Fix: atomic high-water-mark update via `fetch_max`
            // (replaces a manual compare-exchange retry loop).
            self.metrics
                .max_depth_seen
                .fetch_max(current_depth, AtomicOrdering::Relaxed);
            drop(heap);
            // Wake one waiter (typically a consumer parked in `dequeue`).
            self.notify.notify_one();
            return;
        }
        let blocked_count = self
            .metrics
            .blocked_enqueue_count
            .fetch_add(1, AtomicOrdering::Relaxed)
            + 1;
        // Throttled logging: first block and every 100th thereafter.
        if blocked_count == 1 || blocked_count % 100 == 0 {
            debug!(
                "Priority queue enqueue blocked (capacity {}); blocked {} times so far",
                self.max_capacity, blocked_count
            );
        }
        // Register for wakeups while still holding the lock, then release
        // it and park until capacity is freed.
        notified.as_mut().enable();
        drop(heap);
        notified.await;
    }
}
/// Pops the earliest event if one is available; returns `None` on an
/// empty queue. Never waits beyond the internal lock.
pub async fn try_dequeue(&self) -> Option<Arc<T>> {
    let mut heap = self.heap.lock().await;
    match heap.pop() {
        Some(entry) => {
            self.metrics
                .total_dequeued
                .fetch_add(1, AtomicOrdering::Relaxed);
            self.metrics
                .current_depth
                .store(heap.len(), AtomicOrdering::Relaxed);
            drop(heap);
            // A slot was freed: wake one waiter (e.g. a producer parked
            // in `enqueue_wait`).
            self.notify.notify_one();
            Some(entry.event)
        }
        None => None,
    }
}
/// Pops the earliest event, asynchronously waiting until one arrives.
///
/// Lost-wakeup safety: the `Notified` future is created before the lock
/// is taken and enabled (`enable()`) before the lock is dropped, so a
/// `notify_one` fired by a producer between unlock and `.await` is still
/// observed.
pub async fn dequeue(&self) -> Arc<T> {
loop {
let notified = self.notify.notified();
tokio::pin!(notified);
let mut heap = self.heap.lock().await;
if let Some(pq_event) = heap.pop() {
let event = pq_event.event;
self.metrics
.total_dequeued
.fetch_add(1, AtomicOrdering::Relaxed);
self.metrics
.current_depth
.store(heap.len(), AtomicOrdering::Relaxed);
drop(heap);
// A slot was freed: pass the notification on (may wake a producer
// parked in `enqueue_wait`).
self.notify.notify_one();
return event;
}
// Queue empty: register interest while holding the lock, release it,
// then park until a producer enqueues.
notified.as_mut().enable();
drop(heap);
notified.await;
}
}
/// Current number of queued events (takes the heap lock briefly).
pub async fn depth(&self) -> usize {
    self.heap.lock().await.len()
}
/// `true` when no events are queued (takes the heap lock briefly).
pub async fn is_empty(&self) -> bool {
    self.heap.lock().await.is_empty()
}
pub async fn metrics(&self) -> MetricsSnapshot {
self.metrics.snapshot()
}
/// Resets the cumulative counters to zero.
///
/// Fix: `current_depth` and `max_depth_seen` are gauges, not cumulative
/// counters — zeroing them while events are still queued left the metrics
/// inconsistent with reality until the next enqueue/dequeue. They are now
/// re-synchronized with the actual heap length instead.
pub async fn reset_metrics(&self) {
    // Read the live depth under the lock so the gauge stays truthful.
    let depth = self.heap.lock().await.len();
    self.metrics
        .total_enqueued
        .store(0, AtomicOrdering::Relaxed);
    self.metrics
        .total_dequeued
        .store(0, AtomicOrdering::Relaxed);
    self.metrics
        .current_depth
        .store(depth, AtomicOrdering::Relaxed);
    // The high-water mark can never be below the current depth.
    self.metrics
        .max_depth_seen
        .store(depth, AtomicOrdering::Relaxed);
    self.metrics
        .drops_due_to_capacity
        .store(0, AtomicOrdering::Relaxed);
    self.metrics
        .blocked_enqueue_count
        .store(0, AtomicOrdering::Relaxed);
}
/// Removes and returns all queued events (in arbitrary heap order).
///
/// `total_dequeued` is deliberately left untouched: drained events are
/// not counted as consumed.
pub async fn drain(&self) -> Vec<Arc<T>> {
    let mut heap = self.heap.lock().await;
    let events: Vec<Arc<T>> = heap.drain().map(|pq_event| pq_event.event).collect();
    self.metrics.current_depth.store(0, AtomicOrdering::Relaxed);
    drop(heap);
    // Fix: draining frees capacity but previously issued no notification,
    // so producers parked in `enqueue_wait` could stay blocked. Wake every
    // waiter; each re-checks state under the lock before proceeding.
    self.notify.notify_waiters();
    debug!("Drained {} events from priority queue", events.len());
    events
}
}
impl<T> Clone for PriorityQueue<T>
where
T: Timestamped + Clone + Send + Sync + 'static,
{
fn clone(&self) -> Self {
Self {
heap: Arc::clone(&self.heap),
notify: Arc::clone(&self.notify),
max_capacity: self.max_capacity,
metrics: Arc::clone(&self.metrics),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use chrono::Utc;
// Minimal event type for exercising the queue in tests.
#[derive(Debug, Clone)]
struct TestEvent {
// Used by assertions to identify which event came out.
id: String,
timestamp: chrono::DateTime<Utc>,
}
impl Timestamped for TestEvent {
fn timestamp(&self) -> chrono::DateTime<Utc> {
self.timestamp
}
}
// Convenience constructor for an `Arc`-wrapped test event.
fn create_test_event(id: &str, timestamp: chrono::DateTime<Utc>) -> Arc<TestEvent> {
Arc::new(TestEvent {
id: id.to_string(),
timestamp,
})
}
// Events must come out earliest-timestamp-first regardless of insert order.
#[tokio::test]
async fn test_priority_queue_ordering() {
    let queue = PriorityQueue::new(100);
    let base = Utc::now();
    let middle = create_test_event("event1", base);
    let earliest = create_test_event("event2", base - chrono::Duration::seconds(5));
    let latest = create_test_event("event3", base + chrono::Duration::seconds(5));
    // Insert out of order.
    queue.enqueue(middle).await;
    queue.enqueue(latest).await;
    queue.enqueue(earliest).await;
    assert_eq!(queue.try_dequeue().await.unwrap().id, "event2");
    assert_eq!(queue.try_dequeue().await.unwrap().id, "event1");
    assert_eq!(queue.try_dequeue().await.unwrap().id, "event3");
}
// The non-blocking enqueue must reject events beyond capacity and count drops.
#[tokio::test]
async fn test_priority_queue_capacity() {
    let queue = PriorityQueue::new(2);
    let now = Utc::now();
    // First two fit; the third must be rejected.
    assert!(queue.enqueue(create_test_event("event1", now)).await);
    assert!(queue.enqueue(create_test_event("event2", now)).await);
    assert!(!queue.enqueue(create_test_event("event3", now)).await);
    let snapshot = queue.metrics().await;
    assert_eq!(snapshot.drops_due_to_capacity, 1);
    assert_eq!(snapshot.total_enqueued, 2);
}
// Enqueue/dequeue operations must keep the metric counters in sync.
#[tokio::test]
async fn test_priority_queue_metrics() {
    let queue = PriorityQueue::new(100);
    let now = Utc::now();
    queue.enqueue(create_test_event("event1", now)).await;
    queue.enqueue(create_test_event("event2", now)).await;
    let after_enqueues = queue.metrics().await;
    assert_eq!(after_enqueues.total_enqueued, 2);
    assert_eq!(after_enqueues.current_depth, 2);
    assert_eq!(after_enqueues.max_depth_seen, 2);
    // One dequeue decrements the depth and bumps the dequeue counter.
    queue.try_dequeue().await;
    let after_dequeue = queue.metrics().await;
    assert_eq!(after_dequeue.total_dequeued, 1);
    assert_eq!(after_dequeue.current_depth, 1);
}
// `dequeue` must park until a producer task supplies an event.
#[tokio::test]
async fn test_blocking_dequeue() {
let pq = PriorityQueue::new(100);
let pq_clone = pq.clone();
tokio::spawn(async move {
// Yield so the consumer below reaches `dequeue` (and parks) first.
tokio::task::yield_now().await;
let event = create_test_event("event1", Utc::now());
pq_clone.enqueue(event).await;
});
let event = pq.dequeue().await;
assert_eq!(event.id, "event1");
}
// `drain` must hand back every queued event and leave the queue empty.
#[tokio::test]
async fn test_drain() {
    let queue = PriorityQueue::new(100);
    let now = Utc::now();
    for name in ["event1", "event2", "event3"] {
        queue.enqueue(create_test_event(name, now)).await;
    }
    assert_eq!(queue.drain().await.len(), 3);
    assert!(queue.is_empty().await);
}
// `enqueue_wait` on a full queue must park, then complete once a slot frees.
#[tokio::test]
async fn test_enqueue_wait_blocks_when_full() {
let pq = PriorityQueue::new(2);
let now = Utc::now();
pq.enqueue_wait(create_test_event("event1", now)).await;
pq.enqueue_wait(create_test_event("event2", now)).await;
assert_eq!(pq.depth().await, 2);
let pq_clone = pq.clone();
let event3 = create_test_event("event3", now);
let enqueue_task = tokio::spawn(async move {
pq_clone.enqueue_wait(event3).await;
"enqueued"
});
// Give the spawned producer a chance to run; it should still be parked.
tokio::task::yield_now().await;
assert!(!enqueue_task.is_finished());
// Freeing one slot must unblock it promptly.
pq.try_dequeue().await;
let result =
tokio::time::timeout(tokio::time::Duration::from_millis(100), enqueue_task).await;
assert!(result.is_ok(), "enqueue_wait should have unblocked");
assert_eq!(pq.depth().await, 2); }
// A dequeue must notify a producer parked in `enqueue_wait`.
#[tokio::test]
async fn test_enqueue_wait_unblocks_on_dequeue() {
let pq = PriorityQueue::new(1);
let now = Utc::now();
pq.enqueue_wait(create_test_event("event1", now)).await;
let pq_clone = pq.clone();
let event2 = create_test_event("event2", now);
let enqueue_task = tokio::spawn(async move {
pq_clone.enqueue_wait(event2).await;
});
// Let the spawned task reach the full-queue wait.
tokio::task::yield_now().await;
let dequeued = pq.try_dequeue().await;
assert!(dequeued.is_some());
let result =
tokio::time::timeout(tokio::time::Duration::from_millis(100), enqueue_task).await;
assert!(result.is_ok(), "enqueue_wait should have been notified");
assert_eq!(pq.depth().await, 1);
}
// Several parked producers must be released one at a time, one per freed slot.
#[tokio::test]
async fn test_enqueue_wait_multiple_waiters() {
let pq = PriorityQueue::new(1);
let now = Utc::now();
pq.enqueue_wait(create_test_event("event1", now)).await;
let mut tasks = vec![];
for i in 2..=4 {
let pq_clone = pq.clone();
let event = create_test_event(&format!("event{i}"), now);
let task = tokio::spawn(async move {
pq_clone.enqueue_wait(event).await;
i
});
tasks.push(task);
}
// Let all three producers park on the full queue.
tokio::task::yield_now().await;
for expected_id in 2..=4 {
// Each dequeue frees exactly one slot, so exactly one more task finishes.
pq.try_dequeue().await;
tokio::task::yield_now().await;
let completed_count = tasks.iter().filter(|t| t.is_finished()).count();
assert_eq!(
completed_count,
expected_id - 1,
"Expected {} tasks to complete",
expected_id - 1
);
}
for task in tasks {
tokio::time::timeout(tokio::time::Duration::from_millis(100), task)
.await
.expect("Task should complete")
.expect("Task should not panic");
}
}
// Blocked waits must be counted in `blocked_enqueue_count`.
#[tokio::test]
async fn test_enqueue_wait_metrics() {
let pq = PriorityQueue::new(2);
let now = Utc::now();
pq.enqueue_wait(create_test_event("event1", now)).await;
pq.enqueue_wait(create_test_event("event2", now)).await;
// Zero the counters so only the blocked enqueue below is measured.
pq.reset_metrics().await;
let pq_clone = pq.clone();
let event3 = create_test_event("event3", now);
let enqueue_task = tokio::spawn(async move {
pq_clone.enqueue_wait(event3).await;
});
tokio::task::yield_now().await;
let metrics = pq.metrics().await;
assert!(
metrics.blocked_enqueue_count >= 1,
"Should have blocked at least once, got {}",
metrics.blocked_enqueue_count
);
pq.try_dequeue().await;
tokio::time::timeout(tokio::time::Duration::from_millis(100), enqueue_task)
.await
.expect("Task should complete")
.expect("Task should not panic");
let final_metrics = pq.metrics().await;
assert_eq!(final_metrics.total_enqueued, 1);
assert!(final_metrics.blocked_enqueue_count >= 1);
}
// Contrast the two producer paths: `enqueue` drops when full, `enqueue_wait` waits.
#[tokio::test]
async fn test_enqueue_wait_vs_enqueue_behavior() {
let pq = PriorityQueue::new(2);
let now = Utc::now();
pq.enqueue(create_test_event("event1", now)).await;
pq.enqueue(create_test_event("event2", now)).await;
let result = pq.enqueue(create_test_event("event3", now)).await;
assert!(
!result,
"Non-blocking enqueue should return false when full"
);
let metrics = pq.metrics().await;
assert_eq!(metrics.drops_due_to_capacity, 1);
let pq_clone = pq.clone();
let event4 = create_test_event("event4", now);
let enqueue_task = tokio::spawn(async move {
pq_clone.enqueue_wait(event4).await;
});
// Free one slot; the blocking producer must then complete.
pq.try_dequeue().await;
tokio::time::timeout(tokio::time::Duration::from_millis(100), enqueue_task)
.await
.expect("enqueue_wait should complete")
.expect("Task should not panic");
assert_eq!(pq.depth().await, 2);
}
}