use core::hash::Hash;
use core::time::Duration;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
use clock_lib::{Clock, Monotonic, SystemClock};
use tokio::sync::Notify;
use crate::decision::Decision;
use crate::error::ThrottleError;
use crate::limiter::Limiter;
/// What [`Queue`] does when a caller arrives while it already holds `capacity` waiters.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Overflow {
    /// Refuse the newcomer with `ThrottleError::QueueFull`. This is the default policy.
    #[default]
    Reject,
    /// Evict the waiter that registered earliest and admit the newcomer; the evicted waiter's
    /// pending `acquire` returns `ThrottleError::QueueFull`.
    DropOldest,
    /// Evict the lowest-priority waiter (the most recent arrival among equals), but only when
    /// the newcomer's priority is strictly higher; otherwise refuse the newcomer with
    /// `ThrottleError::QueueFull`.
    DropLowestPriority,
}
/// Bookkeeping for one parked `acquire` call.
struct Waiter<K> {
    /// Arrival order, taken from a monotonically increasing counter (equal to the map id).
    seq: u64,
    /// Larger values are served first.
    priority: u32,
    /// Absolute deadline on the queue's millisecond time base; `None` waits indefinitely.
    deadline_ms: Option<u64>,
    /// Fairness key: among equal priorities, the key served least recently goes first.
    key: K,
    /// Raised when an overflow policy pushes this waiter out; shared with its `acquire` call.
    evicted: Arc<AtomicBool>,
}
/// Mutable queue bookkeeping, kept behind the queue's mutex.
struct State<K> {
    /// Registered waiters, keyed by id.
    waiters: HashMap<u64, Waiter<K>>,
    /// Bumped on every successful `serve`; its values act as recency stamps.
    service_seq: u64,
    /// Next waiter id / sequence number to hand out.
    next_seq: u64,
    /// Stamp of the last `serve` per key. Nothing in this module removes entries, so the map
    /// grows with the number of distinct keys ever served.
    last_served: HashMap<K, u64>,
}
impl<K: Eq + Hash + Clone> State<K> {
    /// Creates empty bookkeeping: no waiters, no service history, both counters at zero.
    fn new() -> Self {
        Self {
            waiters: HashMap::new(),
            last_served: HashMap::new(),
            next_seq: 0,
            service_seq: 0,
        }
    }

    /// Whether a waiter with `deadline_ms` is still live at `now_ms`.
    ///
    /// Waiters without a deadline are always live; otherwise they are live strictly before it.
    fn is_live(deadline_ms: Option<u64>, now_ms: u64) -> bool {
        match deadline_ms {
            Some(deadline) => now_ms < deadline,
            None => true,
        }
    }

    /// Drops every waiter whose deadline has been reached at `now_ms`.
    fn prune_expired(&mut self, now_ms: u64) {
        self.waiters
            .retain(|_, waiter| Self::is_live(waiter.deadline_ms, now_ms));
    }

    /// Id of the live waiter that should be served next at `now_ms`, if any.
    ///
    /// Ranking: highest `priority`; then the key served least recently (never-served keys have
    /// recency 0 and so go first); then the smallest `seq` (earliest arrival).
    fn winner(&self, now_ms: u64) -> Option<u64> {
        self.waiters
            .iter()
            .filter(|(_, waiter)| Self::is_live(waiter.deadline_ms, now_ms))
            .min_by_key(|(_, waiter)| {
                (
                    core::cmp::Reverse(waiter.priority),
                    self.recency(&waiter.key),
                    waiter.seq,
                )
            })
            .map(|(&id, _)| id)
    }

    /// Stamp of the most recent `serve` for `key`, or 0 if it was never served.
    fn recency(&self, key: &K) -> u64 {
        match self.last_served.get(key) {
            Some(&stamp) => stamp,
            None => 0,
        }
    }

    /// Removes waiter `id` and records its key as served under a fresh, strictly larger stamp.
    ///
    /// Unknown ids are ignored and do not advance the stamp.
    ///
    /// NOTE(review): `last_served` keeps one entry per distinct key ever served and is never
    /// pruned here, so its size tracks key cardinality.
    fn serve(&mut self, id: u64) {
        let Some(served) = self.waiters.remove(&id) else {
            return;
        };
        self.service_seq += 1;
        let stamp = self.service_seq;
        let _ = self.last_served.insert(served.key, stamp);
    }

    /// Registers a new waiter; returns its id and the eviction flag shared with it.
    ///
    /// Ids come from a monotonically increasing counter and double as the waiter's `seq`.
    fn insert(
        &mut self,
        priority: u32,
        deadline_ms: Option<u64>,
        key: K,
    ) -> (u64, Arc<AtomicBool>) {
        let seq = self.next_seq;
        self.next_seq = seq + 1;
        let flag = Arc::new(AtomicBool::new(false));
        let waiter = Waiter {
            seq,
            priority,
            deadline_ms,
            key,
            evicted: Arc::clone(&flag),
        };
        let _ = self.waiters.insert(seq, waiter);
        (seq, flag)
    }

    /// Id of the waiter with the smallest `seq` (earliest arrival), if any.
    fn oldest(&self) -> Option<u64> {
        let (&id, _) = self.waiters.iter().min_by_key(|(_, waiter)| waiter.seq)?;
        Some(id)
    }

    /// Id and priority of the lowest-priority waiter; among equal priorities, the one with the
    /// largest `seq` (most recent arrival).
    fn weakest(&self) -> Option<(u64, u32)> {
        self.waiters
            .iter()
            .min_by_key(|(_, waiter)| (waiter.priority, core::cmp::Reverse(waiter.seq)))
            .map(|(&id, waiter)| (id, waiter.priority))
    }
}
/// An admission queue in front of a [`Limiter`].
///
/// Callers park in `acquire` until they are the top-ranked live waiter and the inner limiter
/// grants a unit. At most `capacity` waiters are held at once; the [`Overflow`] policy decides
/// what happens when a caller arrives at a full queue.
///
/// `K` is the fairness key (among equal priorities, the key served least recently goes first)
/// and `C` is the clock used to evaluate deadlines.
pub struct Queue<L, K = (), C = SystemClock>
where
    K: Eq + Hash + Clone + Send + Sync,
    C: Clock,
{
    /// Limiter that each admitted waiter takes one unit from.
    inner: L,
    /// Waiters and per-key service history. This is a std mutex; in this file it is only
    /// locked in synchronous sections, never across an `.await`.
    state: Mutex<State<K>>,
    /// Wakes parked waiters so they re-check eviction, deadlines and who is next.
    notify: Notify,
    /// Maximum number of waiters held at once (at least one).
    capacity: usize,
    /// Policy applied when a caller arrives at a full queue.
    overflow: Overflow,
    /// Time source for deadlines.
    clock: C,
    /// Origin of the queue's millisecond time base.
    epoch: Monotonic,
}
impl Queue<core::convert::Infallible, ()> {
    /// Starts configuring a queue with the builder's default settings
    /// (see [`QueueBuilder::new`]).
    ///
    /// The type parameters of this impl are placeholders; the real limiter and key types are
    /// chosen when [`QueueBuilder::build`] is called.
    #[must_use]
    pub fn builder() -> QueueBuilder {
        QueueBuilder::default()
    }
}
impl<L, K, C> Queue<L, K, C>
where
    L: Limiter,
    K: Eq + Hash + Clone + Send + Sync,
    C: Clock + Clone,
{
    /// How long a waiter that is not next in line parks before re-checking on its own.
    ///
    /// Normally it is woken much sooner by a notification or its deadline; this is only a
    /// safety net.
    const PARK_TIMEOUT: Duration = Duration::from_secs(3600);

    /// Builds a queue around `inner` and anchors its millisecond time base at `clock.now()`.
    ///
    /// A `capacity` of zero is raised to one.
    fn new(inner: L, capacity: usize, overflow: Overflow, clock: C) -> Self {
        let epoch = clock.now();
        Self {
            inner,
            state: Mutex::new(State::new()),
            notify: Notify::new(),
            capacity: capacity.max(1),
            overflow,
            clock,
            epoch,
        }
    }

    /// Rebuilds this queue on top of `clock`.
    ///
    /// The limiter, capacity and overflow policy are kept. Waiter bookkeeping (including the
    /// per-key service history) and the time epoch start afresh.
    #[must_use]
    pub fn with_clock<C2>(self, clock: C2) -> Queue<L, K, C2>
    where
        C2: Clock + Clone,
    {
        Queue::new(self.inner, self.capacity, self.overflow, clock)
    }

    /// Number of waiters currently registered.
    ///
    /// An expired waiter is only removed when its own `acquire` returns or when a new caller
    /// registers, so it may still be counted here.
    #[must_use]
    pub fn len(&self) -> usize {
        self.lock().waiters.len()
    }

    /// Returns `true` if no waiter is registered (same caveat as [`Queue::len`]).
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.lock().waiters.is_empty()
    }

    /// Maximum number of waiters held at once (always at least one).
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// The limiter this queue draws from.
    #[must_use]
    pub fn inner(&self) -> &L {
        &self.inner
    }

    /// Locks the waiter state, recovering the data if a previous holder panicked.
    #[inline]
    fn lock(&self) -> MutexGuard<'_, State<K>> {
        self.state.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Milliseconds since the queue's epoch on its own clock, saturating at `u64::MAX`.
    #[inline]
    fn now_ms(&self) -> u64 {
        let elapsed = self.clock.now().saturating_duration_since(self.epoch);
        u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX)
    }

    /// Admits a new waiter, applying the overflow policy if the queue is full.
    ///
    /// Expired waiters are pruned first, so they never count against capacity. Returns the new
    /// waiter's id and its eviction flag.
    ///
    /// # Errors
    ///
    /// [`ThrottleError::QueueFull`] if the queue is full and the policy admits no one. Under
    /// `Reject` this always happens when full. Under `DropLowestPriority` it happens unless
    /// `priority` is strictly higher than the weakest waiter's.
    fn register(
        &self,
        now_ms: u64,
        priority: u32,
        deadline_ms: Option<u64>,
        key: K,
    ) -> Result<(u64, Arc<AtomicBool>), ThrottleError> {
        let mut state = self.lock();
        state.prune_expired(now_ms);
        let victim = if state.waiters.len() < self.capacity {
            None
        } else {
            let candidate = match self.overflow {
                Overflow::Reject => None,
                Overflow::DropOldest => state.oldest(),
                Overflow::DropLowestPriority => state
                    .weakest()
                    .and_then(|(id, weakest)| (priority > weakest).then_some(id)),
            };
            Some(candidate.ok_or(ThrottleError::QueueFull)?)
        };
        if let Some(id) = victim {
            evict(&mut state, id);
        }
        let admitted = state.insert(priority, deadline_ms, key);
        drop(state);
        // Only an eviction changes the outcome for an already-parked waiter: the victim must see
        // its flag, and whoever ranked behind it may now be next. A plain insertion can only make
        // the newcomer the winner, and the newcomer checks that itself right away. Waking every
        // parked waiter here would just be an O(n) thundering herd of lock + O(n) winner scans.
        if victim.is_some() {
            self.notify.notify_waiters();
        }
        Ok(admitted)
    }

    /// Waits in the queue until this caller may take one unit from the inner limiter.
    ///
    /// Among live waiters, the highest `priority` is served first. Ties go to the `key` served
    /// least recently, then to the earliest arrival. Deadlines are evaluated on the queue's
    /// clock while waiting uses `tokio::time::sleep`; the two are assumed to advance together.
    ///
    /// # Errors
    ///
    /// - [`ThrottleError::QueueFull`] if the queue is full and the overflow policy refuses this
    ///   caller, or if this caller is later evicted to make room for someone else.
    /// - [`ThrottleError::DeadlineExceeded`] if `deadline` elapses before a unit is granted,
    ///   including a deadline that is already due when the call starts.
    /// - [`ThrottleError::CostExceedsCapacity`] if the inner limiter reports the request as
    ///   impossible.
    pub async fn acquire(
        &self,
        key: K,
        priority: u32,
        deadline: Option<Duration>,
    ) -> Result<(), ThrottleError> {
        let start_ms = self.now_ms();
        let deadline_ms = deadline
            .map(|d| start_ms.saturating_add(u64::try_from(d.as_millis()).unwrap_or(u64::MAX)));
        // A deadline that is already due (e.g. `Duration::ZERO` or a sub-millisecond value) can
        // never be met. Fail before registering so that such a call cannot evict a live waiter
        // under `DropOldest`/`DropLowestPriority` only to give up immediately afterwards.
        if deadline_ms.is_some_and(|d| d <= start_ms) {
            return Err(ThrottleError::DeadlineExceeded);
        }
        let (id, evicted) = self.register(start_ms, priority, deadline_ms, key)?;
        // Deregisters this waiter and wakes the others on every exit path, including
        // cancellation of this future.
        let _guard = LeaveGuard { queue: self, id };
        loop {
            // Register interest before inspecting state, so a notification sent between the
            // checks below and the `select!` is not lost.
            let notified = self.notify.notified();
            tokio::pin!(notified);
            let _ = notified.as_mut().enable();
            if evicted.load(Ordering::Acquire) {
                return Err(ThrottleError::QueueFull);
            }
            let now_ms = self.now_ms();
            if deadline_ms.is_some_and(|d| now_ms >= d) {
                return Err(ThrottleError::DeadlineExceeded);
            }
            let wait = {
                let mut state = self.lock();
                if state.winner(now_ms) == Some(id) {
                    match self.inner.acquire_cost(1) {
                        Decision::Acquired => {
                            state.serve(id);
                            // `_guard` wakes the remaining waiters as this function returns;
                            // notifying here as well would wake everyone twice.
                            return Ok(());
                        }
                        Decision::Impossible => {
                            return Err(ThrottleError::CostExceedsCapacity {
                                cost: 1,
                                capacity: self.inner.capacity(),
                            });
                        }
                        Decision::Retry { after } => after,
                    }
                } else {
                    Self::PARK_TIMEOUT
                }
            };
            let sleep_for = cap_to_deadline(wait, now_ms, deadline_ms);
            tokio::select! {
                () = notified.as_mut() => {}
                () = tokio::time::sleep(sleep_for) => {}
            }
        }
    }
}
/// Shortens `wait` so that sleeping for it does not run past `deadline_ms`.
///
/// `now_ms` and `deadline_ms` share one millisecond time base. A deadline that is already due
/// yields a zero duration; without a deadline, `wait` is returned unchanged.
fn cap_to_deadline(wait: Duration, now_ms: u64, deadline_ms: Option<u64>) -> Duration {
    deadline_ms.map_or(wait, |deadline| {
        let remaining = Duration::from_millis(deadline.saturating_sub(now_ms));
        wait.min(remaining)
    })
}
/// Removes waiter `id` from `state` and raises its `evicted` flag, so that its pending
/// `acquire` can tell it was pushed out. Does nothing if `id` is not registered.
fn evict<K: Eq + Hash + Clone>(state: &mut State<K>, id: u64) {
    let Some(victim) = state.waiters.remove(&id) else {
        return;
    };
    victim.evicted.store(true, Ordering::Release);
}
/// Drop guard held by an in-flight `Queue::acquire` call for waiter `id`.
///
/// Dropping it deregisters the waiter (a no-op if it was already served, evicted or pruned) and
/// wakes all parked waiters. Because it runs on drop, it also covers a cancelled `acquire`
/// future.
struct LeaveGuard<'a, L, K, C>
where
    L: Limiter,
    K: Eq + Hash + Clone + Send + Sync,
    C: Clock + Clone,
{
    /// Queue the waiter is registered in.
    queue: &'a Queue<L, K, C>,
    /// Id returned by registration.
    id: u64,
}
impl<L, K, C> Drop for LeaveGuard<'_, L, K, C>
where
    L: Limiter,
    K: Eq + Hash + Clone + Send + Sync,
    C: Clock + Clone,
{
    /// Deregisters the waiter, then wakes every parked waiter so they re-evaluate who is next.
    fn drop(&mut self) {
        let queue = self.queue;
        // `lock()` yields a temporary guard, so the mutex is released at the end of this
        // statement, before anyone is woken below.
        let _ = queue.lock().waiters.remove(&self.id);
        queue.notify.notify_waiters();
    }
}
/// Configuration for a [`Queue`]; obtain one from `Queue::builder()` or `QueueBuilder::new()`.
#[derive(Debug, Clone, Copy)]
pub struct QueueBuilder {
    /// Maximum number of waiters (kept at least one by the setter).
    capacity: usize,
    /// Policy applied when a caller arrives at a full queue.
    overflow: Overflow,
}
impl Default for QueueBuilder {
fn default() -> Self {
Self::new()
}
}
impl QueueBuilder {
    /// Creates a builder with room for 1024 waiters and the [`Overflow::Reject`] policy.
    #[must_use]
    pub fn new() -> Self {
        Self {
            overflow: Overflow::Reject,
            capacity: 1024,
        }
    }

    /// Sets the maximum number of waiters held at once; zero is raised to one.
    #[must_use]
    pub fn capacity(self, capacity: usize) -> Self {
        Self {
            capacity: capacity.max(1),
            ..self
        }
    }

    /// Sets the policy applied when a caller arrives at a full queue.
    #[must_use]
    pub fn overflow(self, overflow: Overflow) -> Self {
        Self { overflow, ..self }
    }

    /// Finishes configuration, wrapping `limiter` in a queue driven by the system clock.
    #[must_use]
    pub fn build<L, K>(self, limiter: L) -> Queue<L, K, SystemClock>
    where
        L: Limiter,
        K: Eq + Hash + Clone + Send + Sync,
    {
        let Self { capacity, overflow } = self;
        Queue::new(limiter, capacity, overflow, SystemClock::new())
    }
}
// Tests: async end-to-end behaviour against a real `Throttle`, plus synchronous checks of the
// `State` winner-selection rules.
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::{Overflow, Queue};
    use crate::throttle::Throttle;
    use core::time::Duration;
    use std::sync::Arc;

    // Compile-time check helper: only instantiable for `Send + Sync` types.
    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn test_queue_is_send_sync() {
        assert_send_sync::<Queue<Throttle, &'static str>>();
    }

    // With a token available, `acquire` returns at once and leaves no waiter behind.
    #[tokio::test]
    async fn test_immediate_acquire_when_token_is_free() {
        let queue: Queue<Throttle, ()> = Queue::builder().build(Throttle::per_second(10));
        assert!(queue.acquire((), 0, None).await.is_ok());
        assert!(queue.is_empty());
    }

    // Presumably `per_second(0)` makes a cost of 1 impossible, so the limiter reports
    // `Impossible` and the queue surfaces it as `CostExceedsCapacity`.
    #[tokio::test]
    async fn test_cost_exceeds_capacity_is_reported() {
        let queue: Queue<Throttle, ()> = Queue::builder().build(Throttle::per_second(0));
        let err = queue.acquire((), 0, Some(Duration::from_secs(1))).await;
        assert!(matches!(
            err,
            Err(crate::ThrottleError::CostExceedsCapacity { .. })
        ));
    }

    // Presumably one token per hour: the first acquire drains it, so the second can only time
    // out, and its waiter must be cleaned up afterwards.
    #[tokio::test]
    async fn test_deadline_exceeded_when_no_token_arrives() {
        let queue: Queue<Throttle, ()> =
            Queue::builder().build(Throttle::per_duration(1, Duration::from_secs(3600)));
        assert!(queue.acquire((), 0, None).await.is_ok());
        let err = queue.acquire((), 0, Some(Duration::from_millis(30))).await;
        assert!(matches!(err, Err(crate::ThrottleError::DeadlineExceeded)));
        assert!(queue.is_empty(), "the expired waiter is removed");
    }

    // Capacity 1 + `Reject`: once one waiter is parked, a second caller is refused.
    #[tokio::test]
    async fn test_reject_overflow_when_full() {
        let queue: Arc<Queue<Throttle, ()>> = Arc::new(
            Queue::builder()
                .capacity(1)
                .overflow(Overflow::Reject)
                .build(Throttle::per_duration(1, Duration::from_secs(3600))),
        );
        assert!(queue.acquire((), 0, None).await.is_ok());
        let q = Arc::clone(&queue);
        let parked = tokio::spawn(async move { q.acquire((), 0, None).await });
        // Yield until the spawned task has registered its waiter.
        while queue.is_empty() {
            tokio::task::yield_now().await;
        }
        let rejected = queue.acquire((), 0, Some(Duration::from_secs(1))).await;
        assert!(matches!(rejected, Err(crate::ThrottleError::QueueFull)));
        parked.abort();
    }

    // Capacity 1 + `DropOldest`: a second caller evicts the parked one, whose `acquire`
    // then fails with `QueueFull`.
    #[tokio::test]
    async fn test_drop_oldest_overflow_evicts_the_first_waiter() {
        let queue: Arc<Queue<Throttle, ()>> = Arc::new(
            Queue::builder()
                .capacity(1)
                .overflow(Overflow::DropOldest)
                .build(Throttle::per_duration(1, Duration::from_secs(3600))),
        );
        assert!(queue.acquire((), 0, None).await.is_ok());
        let q = Arc::clone(&queue);
        let first = tokio::spawn(async move { q.acquire((), 0, None).await });
        // Yield until the first waiter is registered, so it is the one evicted.
        while queue.is_empty() {
            tokio::task::yield_now().await;
        }
        let q = Arc::clone(&queue);
        let second = tokio::spawn(async move { q.acquire((), 0, None).await });
        let first_result = first.await.unwrap();
        assert!(matches!(first_result, Err(crate::ThrottleError::QueueFull)));
        second.abort();
    }

    // Three waiters queued while the limiter is empty must be served highest priority first.
    #[tokio::test]
    async fn test_priority_is_served_high_first() {
        use std::sync::atomic::{AtomicU32, Ordering};
        let queue: Arc<Queue<Throttle, ()>> = Arc::new(
            Queue::builder()
                .capacity(10)
                .build(Throttle::per_duration(1, Duration::from_millis(50))),
        );
        assert!(queue.acquire((), 0, None).await.is_ok());
        let order = Arc::new(std::sync::Mutex::new(Vec::new()));
        // NOTE(review): `started` is incremented below but never read; looks vestigial.
        let started = Arc::new(AtomicU32::new(0));
        let mut handles = Vec::new();
        for priority in [1u32, 5, 3] {
            let q = Arc::clone(&queue);
            let order = Arc::clone(&order);
            let started = Arc::clone(&started);
            handles.push(tokio::spawn(async move {
                let _ = started.fetch_add(1, Ordering::Relaxed);
                q.acquire((), priority, None).await.unwrap();
                order.lock().unwrap().push(priority);
            }));
        }
        // Wait until all three are registered before any token can be handed out.
        while queue.len() < 3 {
            tokio::task::yield_now().await;
        }
        for h in handles {
            h.await.unwrap();
        }
        assert_eq!(*order.lock().unwrap(), vec![5, 3, 1]);
    }

    // Equal priority: after key "a" is served, key "b" goes before a second "a" waiter even
    // though that "a" waiter arrived earlier.
    #[test]
    fn test_fair_winner_rotates_across_keys_at_equal_priority() {
        use super::{State, Waiter};
        use std::sync::atomic::AtomicBool;
        fn enqueue(state: &mut State<&'static str>, id: u64, priority: u32, key: &'static str) {
            let _ = state.waiters.insert(
                id,
                Waiter {
                    seq: id,
                    priority,
                    deadline_ms: None,
                    key,
                    evicted: Arc::new(AtomicBool::new(false)),
                },
            );
        }
        let mut state = State::<&'static str>::new();
        enqueue(&mut state, 0, 0, "a");
        enqueue(&mut state, 1, 0, "a");
        enqueue(&mut state, 2, 0, "b");
        assert_eq!(state.winner(0), Some(0));
        state.serve(0);
        assert_eq!(state.winner(0), Some(2));
        state.serve(2);
        assert_eq!(state.winner(0), Some(1));
    }

    // Priority is compared before key recency and arrival order.
    #[test]
    fn test_priority_beats_fairness_in_winner_selection() {
        use super::{State, Waiter};
        use std::sync::atomic::AtomicBool;
        let mut state = State::<&'static str>::new();
        let _ = state.waiters.insert(
            0,
            Waiter {
                seq: 0,
                priority: 1,
                deadline_ms: None,
                key: "a",
                evicted: Arc::new(AtomicBool::new(false)),
            },
        );
        let _ = state.waiters.insert(
            1,
            Waiter {
                seq: 1,
                priority: 9,
                deadline_ms: None,
                key: "b",
                evicted: Arc::new(AtomicBool::new(false)),
            },
        );
        assert_eq!(state.winner(0), Some(1));
    }

    // A waiter whose deadline (100) has passed at now = 200 is ignored even with higher
    // priority.
    #[test]
    fn test_winner_skips_expired_waiters() {
        use super::{State, Waiter};
        use std::sync::atomic::AtomicBool;
        let mut state = State::<&'static str>::new();
        let _ = state.waiters.insert(
            0,
            Waiter {
                seq: 0,
                priority: 9,
                deadline_ms: Some(100),
                key: "a",
                evicted: Arc::new(AtomicBool::new(false)),
            },
        );
        let _ = state.waiters.insert(
            1,
            Waiter {
                seq: 1,
                priority: 1,
                deadline_ms: None,
                key: "b",
                evicted: Arc::new(AtomicBool::new(false)),
            },
        );
        assert_eq!(state.winner(200), Some(1));
    }
}