1use core::hash::Hash;
22use core::time::Duration;
23use std::collections::HashMap;
24use std::sync::atomic::{AtomicBool, Ordering};
25use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
26
27use clock_lib::{Clock, Monotonic, SystemClock};
28use event_listener::Event;
29
30use crate::decision::Decision;
31use crate::error::ThrottleError;
32use crate::limiter::Limiter;
33
/// What a [`Queue`] does when a new waiter arrives while it is already at capacity.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Overflow {
    /// Refuse the newcomer with `ThrottleError::QueueFull`. This is the default policy.
    #[default]
    Reject,
    /// Evict the earliest-registered waiter (its `acquire` then fails with
    /// `ThrottleError::QueueFull`) and admit the newcomer.
    DropOldest,
    /// Evict the lowest-priority waiter (the most recently registered one among equals),
    /// but only when the newcomer's priority is strictly higher; otherwise the newcomer
    /// is refused with `ThrottleError::QueueFull`.
    DropLowestPriority,
}
49
/// A caller currently parked in [`Queue::acquire`].
struct Waiter<K> {
    /// Priority; higher values are served first.
    priority: u32,
    /// Registration order, assigned from `State::next_seq` (lower means earlier).
    seq: u64,
    /// Deadline in milliseconds since the queue's epoch; `None` means no deadline.
    deadline_ms: Option<u64>,
    /// Fairness key: among equal priorities, the least recently served key goes first.
    key: K,
    /// Raised when an overflow policy evicts this waiter; shared with the waiting task.
    evicted: Arc<AtomicBool>,
}
64
65struct State<K> {
67 waiters: HashMap<u64, Waiter<K>>,
68 service_seq: u64,
70 next_seq: u64,
72 last_served: HashMap<K, u64>,
74}
75
76impl<K: Eq + Hash + Clone> State<K> {
77 fn new() -> Self {
78 Self {
79 waiters: HashMap::new(),
80 service_seq: 0,
81 next_seq: 0,
82 last_served: HashMap::new(),
83 }
84 }
85
86 fn prune_expired(&mut self, now_ms: u64) {
89 self.waiters
90 .retain(|_, w| w.deadline_ms.is_none_or(|d| now_ms < d));
91 }
92
93 fn winner(&self, now_ms: u64) -> Option<u64> {
96 self.waiters
97 .iter()
98 .filter(|(_, w)| w.deadline_ms.is_none_or(|d| now_ms < d))
99 .min_by(|(_, a), (_, b)| {
100 b.priority
101 .cmp(&a.priority) .then_with(|| self.recency(&a.key).cmp(&self.recency(&b.key)))
103 .then_with(|| a.seq.cmp(&b.seq))
104 })
105 .map(|(&id, _)| id)
106 }
107
108 fn recency(&self, key: &K) -> u64 {
110 self.last_served.get(key).copied().unwrap_or(0)
111 }
112
113 fn serve(&mut self, id: u64) {
115 if let Some(w) = self.waiters.remove(&id) {
116 self.service_seq += 1;
117 let _ = self.last_served.insert(w.key, self.service_seq);
118 }
119 }
120
121 fn insert(
123 &mut self,
124 priority: u32,
125 deadline_ms: Option<u64>,
126 key: K,
127 ) -> (u64, Arc<AtomicBool>) {
128 let id = self.next_seq;
129 self.next_seq += 1;
130 let evicted = Arc::new(AtomicBool::new(false));
131 let _ = self.waiters.insert(
132 id,
133 Waiter {
134 seq: id,
135 priority,
136 deadline_ms,
137 key,
138 evicted: Arc::clone(&evicted),
139 },
140 );
141 (id, evicted)
142 }
143
144 fn oldest(&self) -> Option<u64> {
146 self.waiters
147 .iter()
148 .min_by_key(|(_, w)| w.seq)
149 .map(|(&id, _)| id)
150 }
151
152 fn weakest(&self) -> Option<(u64, u32)> {
155 self.waiters
156 .iter()
157 .min_by(|(_, a), (_, b)| a.priority.cmp(&b.priority).then_with(|| b.seq.cmp(&a.seq)))
158 .map(|(&id, w)| (id, w.priority))
159 }
160}
161
162pub struct Queue<L, K = (), C = SystemClock>
189where
190 K: Eq + Hash + Clone + Send + Sync,
191 C: Clock,
192{
193 inner: L,
194 state: Mutex<State<K>>,
195 notify: Event,
196 capacity: usize,
197 overflow: Overflow,
198 clock: C,
199 epoch: Monotonic,
200}
201
202impl Queue<core::convert::Infallible, ()> {
205 #[must_use]
207 pub fn builder() -> QueueBuilder {
208 QueueBuilder::new()
209 }
210}
211
212impl<L, K, C> Queue<L, K, C>
213where
214 L: Limiter,
215 K: Eq + Hash + Clone + Send + Sync,
216 C: Clock + Clone,
217{
218 fn new(inner: L, capacity: usize, overflow: Overflow, clock: C) -> Self {
219 let epoch = clock.now();
220 Self {
221 inner,
222 state: Mutex::new(State::new()),
223 notify: Event::new(),
224 capacity: capacity.max(1),
225 overflow,
226 clock,
227 epoch,
228 }
229 }
230
231 #[must_use]
234 pub fn with_clock<C2>(self, clock: C2) -> Queue<L, K, C2>
235 where
236 C2: Clock + Clone,
237 {
238 Queue::new(self.inner, self.capacity, self.overflow, clock)
239 }
240
241 #[must_use]
243 pub fn len(&self) -> usize {
244 self.lock().waiters.len()
245 }
246
247 #[must_use]
249 pub fn is_empty(&self) -> bool {
250 self.lock().waiters.is_empty()
251 }
252
253 #[must_use]
255 pub fn capacity(&self) -> usize {
256 self.capacity
257 }
258
259 pub fn inner(&self) -> &L {
261 &self.inner
262 }
263
264 #[inline]
265 fn lock(&self) -> MutexGuard<'_, State<K>> {
266 self.state.lock().unwrap_or_else(PoisonError::into_inner)
267 }
268
269 #[inline]
270 fn now_ms(&self) -> u64 {
271 let elapsed = self.clock.now().saturating_duration_since(self.epoch);
272 u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX)
273 }
274
275 fn register(
281 &self,
282 now_ms: u64,
283 priority: u32,
284 deadline_ms: Option<u64>,
285 key: &K,
286 ) -> Result<(u64, Arc<AtomicBool>), ThrottleError> {
287 let mut did_evict = false;
288 let outcome = {
289 let mut state = self.lock();
290 state.prune_expired(now_ms);
291
292 if state.waiters.len() < self.capacity {
293 Ok(state.insert(priority, deadline_ms, key.clone()))
294 } else {
295 match self.overflow {
296 Overflow::Reject => Err(ThrottleError::QueueFull),
297 Overflow::DropOldest => match state.oldest() {
298 Some(victim) => {
299 evict(&mut state, victim);
300 did_evict = true;
301 Ok(state.insert(priority, deadline_ms, key.clone()))
302 }
303 None => Err(ThrottleError::QueueFull),
304 },
305 Overflow::DropLowestPriority => match state.weakest() {
306 Some((victim, weakest)) if priority > weakest => {
308 evict(&mut state, victim);
309 did_evict = true;
310 Ok(state.insert(priority, deadline_ms, key.clone()))
311 }
312 _ => Err(ThrottleError::QueueFull),
313 },
314 }
315 }
316 };
317
318 if did_evict {
319 crate::obs::queue_overflow(match self.overflow {
320 Overflow::Reject => "reject",
321 Overflow::DropOldest => "drop_oldest",
322 Overflow::DropLowestPriority => "drop_lowest_priority",
323 });
324 } else if outcome.is_err() {
325 crate::obs::queue_overflow("reject");
326 }
327 if did_evict || outcome.is_ok() {
328 let _ = self.notify.notify(usize::MAX);
329 crate::obs::queue_depth(self.len());
330 }
331 outcome
332 }
333
334 pub async fn acquire(
349 &self,
350 key: K,
351 priority: u32,
352 deadline: Option<Duration>,
353 ) -> Result<(), ThrottleError> {
354 let start_ms = self.now_ms();
355 let deadline_ms = deadline
356 .map(|d| start_ms.saturating_add(u64::try_from(d.as_millis()).unwrap_or(u64::MAX)));
357
358 let timer = crate::obs::Timer::start();
359 let (id, evicted) = self.register(start_ms, priority, deadline_ms, &key)?;
360 let _guard = LeaveGuard { queue: self, id };
362
363 loop {
364 let listener = self.notify.listen();
368
369 if evicted.load(Ordering::Acquire) {
370 return Err(ThrottleError::QueueFull);
371 }
372
373 let now_ms = self.now_ms();
374 if deadline_ms.is_some_and(|d| now_ms >= d) {
375 crate::obs::deadline_exceeded();
376 return Err(ThrottleError::DeadlineExceeded);
377 }
378
379 let wait = {
380 let mut state = self.lock();
381 if state.winner(now_ms) == Some(id) {
382 match self.inner.acquire_cost(1) {
383 Decision::Acquired => {
384 state.serve(id);
385 drop(state);
386 let _ = self.notify.notify(usize::MAX);
387 crate::obs::acquired("queue");
388 crate::obs::wait("queue", &timer);
389 crate::obs::trace_acquire("queue", 1, true, &timer);
390 return Ok(());
391 }
392 Decision::Impossible => {
393 return Err(ThrottleError::CostExceedsCapacity {
394 cost: 1,
395 capacity: self.inner.capacity(),
396 });
397 }
398 Decision::Retry { after } => after,
400 }
401 } else {
402 Duration::from_secs(3600)
404 }
405 };
406
407 let sleep_for = cap_to_deadline(wait, now_ms, deadline_ms);
408 futures_lite::future::or(listener, crate::rt::sleep(sleep_for)).await;
410 }
411 }
412}
413
/// Shortens `wait` so that it does not run past `deadline_ms`.
///
/// `now_ms` and `deadline_ms` share the same millisecond time base. A deadline already
/// in the past yields a zero duration. Without a deadline, `wait` is returned unchanged.
fn cap_to_deadline(wait: Duration, now_ms: u64, deadline_ms: Option<u64>) -> Duration {
    let Some(deadline) = deadline_ms else {
        return wait;
    };
    let remaining = Duration::from_millis(deadline.saturating_sub(now_ms));
    if remaining < wait {
        remaining
    } else {
        wait
    }
}
421
422fn evict<K: Eq + Hash + Clone>(state: &mut State<K>, id: u64) {
424 if let Some(w) = state.waiters.remove(&id) {
425 w.evicted.store(true, Ordering::Release);
426 }
427}
428
429struct LeaveGuard<'a, L, K, C>
431where
432 L: Limiter,
433 K: Eq + Hash + Clone + Send + Sync,
434 C: Clock + Clone,
435{
436 queue: &'a Queue<L, K, C>,
437 id: u64,
438}
439
440impl<L, K, C> Drop for LeaveGuard<'_, L, K, C>
441where
442 L: Limiter,
443 K: Eq + Hash + Clone + Send + Sync,
444 C: Clock + Clone,
445{
446 fn drop(&mut self) {
447 let depth = {
448 let mut state = self.queue.lock();
449 let _ = state.waiters.remove(&self.id);
450 state.waiters.len()
451 };
452 let _ = self.queue.notify.notify(usize::MAX);
454 crate::obs::queue_depth(depth);
455 }
456}
457
458#[derive(Debug, Clone, Copy)]
460pub struct QueueBuilder {
461 capacity: usize,
462 overflow: Overflow,
463}
464
465impl Default for QueueBuilder {
466 fn default() -> Self {
467 Self::new()
468 }
469}
470
471impl QueueBuilder {
472 #[must_use]
474 pub fn new() -> Self {
475 Self {
476 capacity: 1024,
477 overflow: Overflow::Reject,
478 }
479 }
480
481 #[must_use]
483 pub fn capacity(mut self, capacity: usize) -> Self {
484 self.capacity = capacity.max(1);
485 self
486 }
487
488 #[must_use]
490 pub fn overflow(mut self, overflow: Overflow) -> Self {
491 self.overflow = overflow;
492 self
493 }
494
495 #[must_use]
497 pub fn build<L, K>(self, limiter: L) -> Queue<L, K, SystemClock>
498 where
499 L: Limiter,
500 K: Eq + Hash + Clone + Send + Sync,
501 {
502 Queue::new(limiter, self.capacity, self.overflow, SystemClock::new())
503 }
504}
505
#[cfg(test)]
mod tests {
    // Tests unwrap freely: a panic is exactly the failure signal we want here.
    #![allow(clippy::unwrap_used)]

    use super::{Overflow, Queue, State, Waiter};
    use crate::throttle::Throttle;
    use core::time::Duration;
    use std::sync::atomic::AtomicBool;
    use std::sync::Arc;

    fn assert_send_sync<T: Send + Sync>() {}

    /// Inserts a waiter directly into `state`, using `id` as both map key and `seq`.
    fn enqueue(
        state: &mut State<&'static str>,
        id: u64,
        priority: u32,
        deadline_ms: Option<u64>,
        key: &'static str,
    ) {
        let _ = state.waiters.insert(
            id,
            Waiter {
                seq: id,
                priority,
                deadline_ms,
                key,
                evicted: Arc::new(AtomicBool::new(false)),
            },
        );
    }

    #[test]
    fn test_queue_is_send_sync() {
        assert_send_sync::<Queue<Throttle, &'static str>>();
    }

    #[tokio::test]
    async fn test_immediate_acquire_when_token_is_free() {
        let queue: Queue<Throttle, ()> = Queue::builder().build(Throttle::per_second(10));
        assert!(queue.acquire((), 0, None).await.is_ok());
        assert!(queue.is_empty());
    }

    #[tokio::test]
    async fn test_cost_exceeds_capacity_is_reported() {
        let queue: Queue<Throttle, ()> = Queue::builder().build(Throttle::per_second(0));
        let err = queue.acquire((), 0, Some(Duration::from_secs(1))).await;
        assert!(matches!(
            err,
            Err(crate::ThrottleError::CostExceedsCapacity { .. })
        ));
    }

    #[tokio::test]
    async fn test_deadline_exceeded_when_no_token_arrives() {
        let queue: Queue<Throttle, ()> =
            Queue::builder().build(Throttle::per_duration(1, Duration::from_secs(3600)));
        // Use up the only token; the next one is an hour away.
        assert!(queue.acquire((), 0, None).await.is_ok());
        let err = queue.acquire((), 0, Some(Duration::from_millis(30))).await;
        assert!(matches!(err, Err(crate::ThrottleError::DeadlineExceeded)));
        assert!(queue.is_empty(), "the expired waiter is removed");
    }

    #[tokio::test]
    async fn test_reject_overflow_when_full() {
        let queue: Arc<Queue<Throttle, ()>> = Arc::new(
            Queue::builder()
                .capacity(1)
                .overflow(Overflow::Reject)
                .build(Throttle::per_duration(1, Duration::from_secs(3600))),
        );
        assert!(queue.acquire((), 0, None).await.is_ok());
        let q = Arc::clone(&queue);
        let parked = tokio::spawn(async move { q.acquire((), 0, None).await });
        // Wait until the parked task occupies the single slot.
        while queue.is_empty() {
            tokio::task::yield_now().await;
        }
        let rejected = queue.acquire((), 0, Some(Duration::from_secs(1))).await;
        assert!(matches!(rejected, Err(crate::ThrottleError::QueueFull)));
        parked.abort();
    }

    #[tokio::test]
    async fn test_drop_oldest_overflow_evicts_the_first_waiter() {
        let queue: Arc<Queue<Throttle, ()>> = Arc::new(
            Queue::builder()
                .capacity(1)
                .overflow(Overflow::DropOldest)
                .build(Throttle::per_duration(1, Duration::from_secs(3600))),
        );
        assert!(queue.acquire((), 0, None).await.is_ok());
        let q = Arc::clone(&queue);
        let first = tokio::spawn(async move { q.acquire((), 0, None).await });
        while queue.is_empty() {
            tokio::task::yield_now().await;
        }
        let q = Arc::clone(&queue);
        let second = tokio::spawn(async move { q.acquire((), 0, None).await });
        let first_result = first.await.unwrap();
        assert!(matches!(first_result, Err(crate::ThrottleError::QueueFull)));
        second.abort();
    }

    #[tokio::test]
    async fn test_priority_is_served_high_first() {
        let queue: Arc<Queue<Throttle, ()>> = Arc::new(
            Queue::builder()
                .capacity(10)
                .build(Throttle::per_duration(1, Duration::from_millis(50))),
        );
        // Use up the current token so every spawned task has to queue.
        assert!(queue.acquire((), 0, None).await.is_ok());
        let order = Arc::new(std::sync::Mutex::new(Vec::new()));

        let mut handles = Vec::new();
        for priority in [1u32, 5, 3] {
            let q = Arc::clone(&queue);
            let order = Arc::clone(&order);
            handles.push(tokio::spawn(async move {
                q.acquire((), priority, None).await.unwrap();
                order.lock().unwrap().push(priority);
            }));
        }
        // Make sure all three are registered before the next token becomes available.
        while queue.len() < 3 {
            tokio::task::yield_now().await;
        }
        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(*order.lock().unwrap(), vec![5, 3, 1]);
    }

    #[test]
    fn test_fair_winner_rotates_across_keys_at_equal_priority() {
        let mut state = State::<&'static str>::new();
        enqueue(&mut state, 0, 0, None, "a");
        enqueue(&mut state, 1, 0, None, "a");
        enqueue(&mut state, 2, 0, None, "b");

        // Nobody served yet: earliest registration wins.
        assert_eq!(state.winner(0), Some(0));
        state.serve(0);
        // "a" was just served, so "b" goes ahead of the older "a" waiter.
        assert_eq!(state.winner(0), Some(2));
        state.serve(2);
        assert_eq!(state.winner(0), Some(1));
    }

    #[test]
    fn test_priority_beats_fairness_in_winner_selection() {
        let mut state = State::<&'static str>::new();
        enqueue(&mut state, 0, 1, None, "a");
        enqueue(&mut state, 1, 9, None, "b");
        assert_eq!(state.winner(0), Some(1));
    }

    #[test]
    fn test_winner_skips_expired_waiters() {
        let mut state = State::<&'static str>::new();
        enqueue(&mut state, 0, 9, Some(100), "a");
        enqueue(&mut state, 1, 1, None, "b");
        assert_eq!(state.winner(200), Some(1));
    }
}