1use core::hash::Hash;
22use core::time::Duration;
23use std::collections::HashMap;
24use std::sync::atomic::{AtomicBool, Ordering};
25use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
26
27use clock_lib::{Clock, Monotonic, SystemClock};
28use tokio::sync::Notify;
29
30use crate::decision::Decision;
31use crate::error::ThrottleError;
32use crate::limiter::Limiter;
33
/// What a [`Queue`] does when a request arrives while the queue already holds `capacity`
/// waiters.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Overflow {
    /// Refuse the newcomer with `ThrottleError::QueueFull` (the default policy).
    #[default]
    Reject,
    /// Evict the longest-waiting request (its `acquire` then fails with `QueueFull`) and
    /// admit the newcomer.
    DropOldest,
    /// Evict the lowest-priority waiter (the most recently queued one among equals), but only
    /// when the newcomer's priority is strictly higher; otherwise the newcomer is refused.
    DropLowestPriority,
}
49
/// One parked caller of `Queue::acquire`.
struct Waiter<K> {
    /// Arrival order; equal to the waiter's map key and used as the final tie-breaker.
    seq: u64,
    /// Larger values are served first.
    priority: u32,
    /// Millisecond offset from the queue epoch at which the waiter goes stale; `None` means
    /// it never expires.
    deadline_ms: Option<u64>,
    /// Fairness key: at equal priority, the key served least recently wins.
    key: K,
    /// Raised by an overflow eviction so the owning task can notice it was dropped.
    evicted: Arc<AtomicBool>,
}

impl<K> Waiter<K> {
    /// `true` unless the waiter has a deadline that is at or before `now_ms`.
    fn is_live(&self, now_ms: u64) -> bool {
        !matches!(self.deadline_ms, Some(deadline) if deadline <= now_ms)
    }
}

/// Queue bookkeeping, kept behind the queue's mutex.
struct State<K> {
    /// Parked waiters, keyed by their sequence number.
    waiters: HashMap<u64, Waiter<K>>,
    /// Bumped on every successful serve; its value is the recency stamp stored per key.
    service_seq: u64,
    /// Next sequence number handed out by `insert`.
    next_seq: u64,
    /// Recency stamp of each key's most recent serve. Nothing in this module removes entries,
    /// so this grows with the number of distinct keys ever served.
    last_served: HashMap<K, u64>,
}

impl<K: Eq + Hash + Clone> State<K> {
    /// Empty state: no waiters, no service history, counters at zero.
    fn new() -> Self {
        Self {
            last_served: HashMap::new(),
            next_seq: 0,
            service_seq: 0,
            waiters: HashMap::new(),
        }
    }

    /// Drops every waiter whose deadline is at or before `now_ms`.
    fn prune_expired(&mut self, now_ms: u64) {
        self.waiters.retain(|_, waiter| waiter.is_live(now_ms));
    }

    /// Id of the live waiter that should be served next, if any.
    ///
    /// Ordering: highest priority, then least recently served key, then earliest arrival.
    fn winner(&self, now_ms: u64) -> Option<u64> {
        self.waiters
            .iter()
            .filter(|(_, waiter)| waiter.is_live(now_ms))
            .min_by_key(|(_, waiter)| {
                (
                    core::cmp::Reverse(waiter.priority),
                    self.recency(&waiter.key),
                    waiter.seq,
                )
            })
            .map(|(&id, _)| id)
    }

    /// Recency stamp of `key`'s last serve; 0 when it has never been served.
    fn recency(&self, key: &K) -> u64 {
        self.last_served.get(key).copied().unwrap_or_default()
    }

    /// Removes waiter `id` and stamps its key as the most recently served.
    /// Unknown ids are ignored and leave the counters untouched.
    fn serve(&mut self, id: u64) {
        let Some(served) = self.waiters.remove(&id) else {
            return;
        };
        self.service_seq += 1;
        let _ = self.last_served.insert(served.key, self.service_seq);
    }

    /// Registers a new waiter, returning its id and the flag an eviction will raise.
    fn insert(
        &mut self,
        priority: u32,
        deadline_ms: Option<u64>,
        key: K,
    ) -> (u64, Arc<AtomicBool>) {
        let seq = self.next_seq;
        self.next_seq += 1;
        let evicted = Arc::new(AtomicBool::new(false));
        let waiter = Waiter {
            seq,
            priority,
            deadline_ms,
            key,
            evicted: Arc::clone(&evicted),
        };
        let _ = self.waiters.insert(seq, waiter);
        (seq, evicted)
    }

    /// Id of the waiter with the smallest sequence number (the one queued longest).
    fn oldest(&self) -> Option<u64> {
        self.waiters
            .iter()
            .min_by_key(|(_, waiter)| waiter.seq)
            .map(|(id, _)| *id)
    }

    /// Id and priority of the eviction candidate: lowest priority, newest among equals.
    fn weakest(&self) -> Option<(u64, u32)> {
        self.waiters
            .iter()
            .min_by_key(|(_, waiter)| (waiter.priority, core::cmp::Reverse(waiter.seq)))
            .map(|(&id, waiter)| (id, waiter.priority))
    }
}
161
/// An async admission queue in front of a [`Limiter`].
///
/// Callers park in `acquire` until they are the queue's current winner and the inner limiter
/// grants one unit of cost. The winner is chosen by highest priority, then the key served
/// least recently, then arrival order. At most `capacity` waiters are held; what happens
/// beyond that is governed by [`Overflow`].
pub struct Queue<L, K = (), C = SystemClock>
where
    K: Eq + Hash + Clone + Send + Sync,
    C: Clock,
{
    /// Limiter each winning waiter takes one unit of cost from.
    inner: L,
    /// Waiter bookkeeping. Guards are dropped before any `.await` in this module.
    state: Mutex<State<K>>,
    /// Wakes parked waiters so they re-check whether they are now the winner.
    notify: Notify,
    /// Maximum number of parked waiters (at least one).
    capacity: usize,
    /// Policy applied when a request arrives at a full queue.
    overflow: Overflow,
    /// Time source for deadlines and waiter expiry.
    clock: C,
    /// Clock reading taken at construction; all millisecond stamps are offsets from it.
    epoch: Monotonic,
}
201
202impl Queue<core::convert::Infallible, ()> {
205 #[must_use]
207 pub fn builder() -> QueueBuilder {
208 QueueBuilder::new()
209 }
210}
211
212impl<L, K, C> Queue<L, K, C>
213where
214 L: Limiter,
215 K: Eq + Hash + Clone + Send + Sync,
216 C: Clock + Clone,
217{
218 fn new(inner: L, capacity: usize, overflow: Overflow, clock: C) -> Self {
219 let epoch = clock.now();
220 Self {
221 inner,
222 state: Mutex::new(State::new()),
223 notify: Notify::new(),
224 capacity: capacity.max(1),
225 overflow,
226 clock,
227 epoch,
228 }
229 }
230
231 #[must_use]
234 pub fn with_clock<C2>(self, clock: C2) -> Queue<L, K, C2>
235 where
236 C2: Clock + Clone,
237 {
238 Queue::new(self.inner, self.capacity, self.overflow, clock)
239 }
240
241 #[must_use]
243 pub fn len(&self) -> usize {
244 self.lock().waiters.len()
245 }
246
247 #[must_use]
249 pub fn is_empty(&self) -> bool {
250 self.lock().waiters.is_empty()
251 }
252
253 #[must_use]
255 pub fn capacity(&self) -> usize {
256 self.capacity
257 }
258
259 pub fn inner(&self) -> &L {
261 &self.inner
262 }
263
264 #[inline]
265 fn lock(&self) -> MutexGuard<'_, State<K>> {
266 self.state.lock().unwrap_or_else(PoisonError::into_inner)
267 }
268
269 #[inline]
270 fn now_ms(&self) -> u64 {
271 let elapsed = self.clock.now().saturating_duration_since(self.epoch);
272 u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX)
273 }
274
275 fn register(
281 &self,
282 now_ms: u64,
283 priority: u32,
284 deadline_ms: Option<u64>,
285 key: &K,
286 ) -> Result<(u64, Arc<AtomicBool>), ThrottleError> {
287 let mut did_evict = false;
288 let outcome = {
289 let mut state = self.lock();
290 state.prune_expired(now_ms);
291
292 if state.waiters.len() < self.capacity {
293 Ok(state.insert(priority, deadline_ms, key.clone()))
294 } else {
295 match self.overflow {
296 Overflow::Reject => Err(ThrottleError::QueueFull),
297 Overflow::DropOldest => match state.oldest() {
298 Some(victim) => {
299 evict(&mut state, victim);
300 did_evict = true;
301 Ok(state.insert(priority, deadline_ms, key.clone()))
302 }
303 None => Err(ThrottleError::QueueFull),
304 },
305 Overflow::DropLowestPriority => match state.weakest() {
306 Some((victim, weakest)) if priority > weakest => {
308 evict(&mut state, victim);
309 did_evict = true;
310 Ok(state.insert(priority, deadline_ms, key.clone()))
311 }
312 _ => Err(ThrottleError::QueueFull),
313 },
314 }
315 }
316 };
317
318 if did_evict {
319 crate::obs::queue_overflow(match self.overflow {
320 Overflow::Reject => "reject",
321 Overflow::DropOldest => "drop_oldest",
322 Overflow::DropLowestPriority => "drop_lowest_priority",
323 });
324 } else if outcome.is_err() {
325 crate::obs::queue_overflow("reject");
326 }
327 if did_evict || outcome.is_ok() {
328 self.notify.notify_waiters();
329 crate::obs::queue_depth(self.len());
330 }
331 outcome
332 }
333
334 pub async fn acquire(
349 &self,
350 key: K,
351 priority: u32,
352 deadline: Option<Duration>,
353 ) -> Result<(), ThrottleError> {
354 let start_ms = self.now_ms();
355 let deadline_ms = deadline
356 .map(|d| start_ms.saturating_add(u64::try_from(d.as_millis()).unwrap_or(u64::MAX)));
357
358 let timer = crate::obs::Timer::start();
359 let (id, evicted) = self.register(start_ms, priority, deadline_ms, &key)?;
360 let _guard = LeaveGuard { queue: self, id };
362
363 loop {
364 let notified = self.notify.notified();
367 tokio::pin!(notified);
368 let _ = notified.as_mut().enable();
371
372 if evicted.load(Ordering::Acquire) {
373 return Err(ThrottleError::QueueFull);
374 }
375
376 let now_ms = self.now_ms();
377 if deadline_ms.is_some_and(|d| now_ms >= d) {
378 crate::obs::deadline_exceeded();
379 return Err(ThrottleError::DeadlineExceeded);
380 }
381
382 let wait = {
383 let mut state = self.lock();
384 if state.winner(now_ms) == Some(id) {
385 match self.inner.acquire_cost(1) {
386 Decision::Acquired => {
387 state.serve(id);
388 drop(state);
389 self.notify.notify_waiters();
390 crate::obs::acquired("queue");
391 crate::obs::wait("queue", &timer);
392 crate::obs::trace_acquire("queue", 1, true, &timer);
393 return Ok(());
394 }
395 Decision::Impossible => {
396 return Err(ThrottleError::CostExceedsCapacity {
397 cost: 1,
398 capacity: self.inner.capacity(),
399 });
400 }
401 Decision::Retry { after } => after,
403 }
404 } else {
405 Duration::from_secs(3600)
407 }
408 };
409
410 let sleep_for = cap_to_deadline(wait, now_ms, deadline_ms);
411 tokio::select! {
412 () = notified.as_mut() => {}
413 () = tokio::time::sleep(sleep_for) => {}
414 }
415 }
416 }
417}
418
/// Shortens `wait` so a sleeping waiter wakes no later than its deadline.
///
/// `now_ms` and `deadline_ms` are millisecond offsets from the queue epoch. Without a
/// deadline `wait` is returned unchanged; a deadline already reached yields zero.
fn cap_to_deadline(wait: Duration, now_ms: u64, deadline_ms: Option<u64>) -> Duration {
    deadline_ms.map_or(wait, |deadline| {
        let remaining = Duration::from_millis(deadline.saturating_sub(now_ms));
        wait.min(remaining)
    })
}
426
427fn evict<K: Eq + Hash + Clone>(state: &mut State<K>, id: u64) {
429 if let Some(w) = state.waiters.remove(&id) {
430 w.evicted.store(true, Ordering::Release);
431 }
432}
433
/// Drop guard tying a registered waiter to its `acquire` call.
///
/// Its `Drop` impl removes waiter `id` from the queue (a no-op if it was already served or
/// evicted) and wakes the remaining waiters, on success, error, or cancellation alike.
struct LeaveGuard<'a, L, K, C>
where
    L: Limiter,
    K: Eq + Hash + Clone + Send + Sync,
    C: Clock + Clone,
{
    /// Queue the waiter was registered with.
    queue: &'a Queue<L, K, C>,
    /// Id handed out by `register` for this waiter.
    id: u64,
}
444
445impl<L, K, C> Drop for LeaveGuard<'_, L, K, C>
446where
447 L: Limiter,
448 K: Eq + Hash + Clone + Send + Sync,
449 C: Clock + Clone,
450{
451 fn drop(&mut self) {
452 let depth = {
453 let mut state = self.queue.lock();
454 let _ = state.waiters.remove(&self.id);
455 state.waiters.len()
456 };
457 self.queue.notify.notify_waiters();
459 crate::obs::queue_depth(depth);
460 }
461}
462
/// Configuration for a [`Queue`]; obtain one with `Queue::builder()` or `QueueBuilder::new()`.
///
/// Defaults: capacity 1024, [`Overflow::Reject`].
#[derive(Debug, Clone, Copy)]
pub struct QueueBuilder {
    /// Maximum number of parked waiters (kept at least one).
    capacity: usize,
    /// Policy applied when a request arrives at a full queue.
    overflow: Overflow,
}
469
470impl Default for QueueBuilder {
471 fn default() -> Self {
472 Self::new()
473 }
474}
475
476impl QueueBuilder {
477 #[must_use]
479 pub fn new() -> Self {
480 Self {
481 capacity: 1024,
482 overflow: Overflow::Reject,
483 }
484 }
485
486 #[must_use]
488 pub fn capacity(mut self, capacity: usize) -> Self {
489 self.capacity = capacity.max(1);
490 self
491 }
492
493 #[must_use]
495 pub fn overflow(mut self, overflow: Overflow) -> Self {
496 self.overflow = overflow;
497 self
498 }
499
500 #[must_use]
502 pub fn build<L, K>(self, limiter: L) -> Queue<L, K, SystemClock>
503 where
504 L: Limiter,
505 K: Eq + Hash + Clone + Send + Sync,
506 {
507 Queue::new(limiter, self.capacity, self.overflow, SystemClock::new())
508 }
509}
510
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]

    use super::{Overflow, Queue};
    use crate::throttle::Throttle;
    use core::time::Duration;
    use std::sync::Arc;

    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn test_queue_is_send_sync() {
        assert_send_sync::<Queue<Throttle, &'static str>>();
    }

    #[test]
    fn test_cap_to_deadline_clamps_to_the_remaining_budget() {
        use super::cap_to_deadline;

        let hour = Duration::from_secs(3600);
        assert_eq!(cap_to_deadline(hour, 100, None), hour);
        assert_eq!(cap_to_deadline(hour, 100, Some(150)), Duration::from_millis(50));
        assert_eq!(
            cap_to_deadline(Duration::from_millis(10), 100, Some(150)),
            Duration::from_millis(10)
        );
        assert_eq!(cap_to_deadline(hour, 200, Some(150)), Duration::ZERO);
    }

    #[tokio::test]
    async fn test_immediate_acquire_when_token_is_free() {
        let queue: Queue<Throttle, ()> = Queue::builder().build(Throttle::per_second(10));
        assert!(queue.acquire((), 0, None).await.is_ok());
        assert!(queue.is_empty());
    }

    #[tokio::test]
    async fn test_cost_exceeds_capacity_is_reported() {
        // A zero-rate limiter can never grant even a single unit.
        let queue: Queue<Throttle, ()> = Queue::builder().build(Throttle::per_second(0));
        let err = queue.acquire((), 0, Some(Duration::from_secs(1))).await;
        assert!(matches!(
            err,
            Err(crate::ThrottleError::CostExceedsCapacity { .. })
        ));
    }

    #[tokio::test]
    async fn test_deadline_exceeded_when_no_token_arrives() {
        // One token per hour: the first call drains it, so the second can only time out.
        let queue: Queue<Throttle, ()> =
            Queue::builder().build(Throttle::per_duration(1, Duration::from_secs(3600)));
        assert!(queue.acquire((), 0, None).await.is_ok());

        let err = queue.acquire((), 0, Some(Duration::from_millis(30))).await;
        assert!(matches!(err, Err(crate::ThrottleError::DeadlineExceeded)));
        assert!(queue.is_empty(), "the expired waiter is removed");
    }

    #[tokio::test]
    async fn test_reject_overflow_when_full() {
        let queue: Arc<Queue<Throttle, ()>> = Arc::new(
            Queue::builder()
                .capacity(1)
                .overflow(Overflow::Reject)
                .build(Throttle::per_duration(1, Duration::from_secs(3600))),
        );
        // Drain the only token so the next caller has to park.
        assert!(queue.acquire((), 0, None).await.is_ok());

        let q = Arc::clone(&queue);
        let parked = tokio::spawn(async move { q.acquire((), 0, None).await });
        // Wait until the parked caller occupies the single slot.
        while queue.is_empty() {
            tokio::task::yield_now().await;
        }

        let rejected = queue.acquire((), 0, Some(Duration::from_secs(1))).await;
        assert!(matches!(rejected, Err(crate::ThrottleError::QueueFull)));
        parked.abort();
    }

    #[tokio::test]
    async fn test_drop_oldest_overflow_evicts_the_first_waiter() {
        let queue: Arc<Queue<Throttle, ()>> = Arc::new(
            Queue::builder()
                .capacity(1)
                .overflow(Overflow::DropOldest)
                .build(Throttle::per_duration(1, Duration::from_secs(3600))),
        );
        // Drain the only token so both callers below have to park.
        assert!(queue.acquire((), 0, None).await.is_ok());

        let q = Arc::clone(&queue);
        let first = tokio::spawn(async move { q.acquire((), 0, None).await });
        while queue.is_empty() {
            tokio::task::yield_now().await;
        }

        // The second arrival displaces the first, which must observe its eviction.
        let q = Arc::clone(&queue);
        let second = tokio::spawn(async move { q.acquire((), 0, None).await });
        let first_result = first.await.unwrap();
        assert!(matches!(first_result, Err(crate::ThrottleError::QueueFull)));
        second.abort();
    }

    #[tokio::test]
    async fn test_priority_is_served_high_first() {
        // One token per 50 ms keeps the contenders parked long enough to all register.
        let queue: Arc<Queue<Throttle, ()>> = Arc::new(
            Queue::builder()
                .capacity(10)
                .build(Throttle::per_duration(1, Duration::from_millis(50))),
        );
        // Drain the initial token so every contender below has to queue.
        assert!(queue.acquire((), 0, None).await.is_ok());

        let order = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut handles = Vec::new();
        for priority in [1u32, 5, 3] {
            let q = Arc::clone(&queue);
            let order = Arc::clone(&order);
            handles.push(tokio::spawn(async move {
                q.acquire((), priority, None).await.unwrap();
                order.lock().unwrap().push(priority);
            }));
        }
        // Make sure all three are parked before the next token becomes available.
        while queue.len() < 3 {
            tokio::task::yield_now().await;
        }
        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(*order.lock().unwrap(), vec![5, 3, 1]);
    }

    #[test]
    fn test_fair_winner_rotates_across_keys_at_equal_priority() {
        use super::{State, Waiter};
        use std::sync::atomic::AtomicBool;

        fn enqueue(state: &mut State<&'static str>, id: u64, priority: u32, key: &'static str) {
            let _ = state.waiters.insert(
                id,
                Waiter {
                    seq: id,
                    priority,
                    deadline_ms: None,
                    key,
                    evicted: Arc::new(AtomicBool::new(false)),
                },
            );
        }

        // Key "a" queued twice ahead of key "b", all at the same priority.
        let mut state = State::<&'static str>::new();
        enqueue(&mut state, 0, 0, "a");
        enqueue(&mut state, 1, 0, "a");
        enqueue(&mut state, 2, 0, "b");

        // Nobody has been served yet, so arrival order decides.
        assert_eq!(state.winner(0), Some(0));
        state.serve(0);
        // "a" was just served, so "b" jumps ahead of the older "a" waiter.
        assert_eq!(state.winner(0), Some(2));
        state.serve(2);
        // Only "a" remains.
        assert_eq!(state.winner(0), Some(1));
    }

    #[test]
    fn test_priority_beats_fairness_in_winner_selection() {
        use super::{State, Waiter};
        use std::sync::atomic::AtomicBool;

        let mut state = State::<&'static str>::new();
        let _ = state.waiters.insert(
            0,
            Waiter {
                seq: 0,
                priority: 1,
                deadline_ms: None,
                key: "a",
                evicted: Arc::new(AtomicBool::new(false)),
            },
        );
        let _ = state.waiters.insert(
            1,
            Waiter {
                seq: 1,
                priority: 9,
                deadline_ms: None,
                key: "b",
                evicted: Arc::new(AtomicBool::new(false)),
            },
        );
        // The later, higher-priority waiter wins despite arriving second.
        assert_eq!(state.winner(0), Some(1));
    }

    #[test]
    fn test_winner_skips_expired_waiters() {
        use super::{State, Waiter};
        use std::sync::atomic::AtomicBool;

        let mut state = State::<&'static str>::new();
        let _ = state.waiters.insert(
            0,
            Waiter {
                seq: 0,
                priority: 9,
                deadline_ms: Some(100),
                key: "a",
                evicted: Arc::new(AtomicBool::new(false)),
            },
        );
        let _ = state.waiters.insert(
            1,
            Waiter {
                seq: 1,
                priority: 1,
                deadline_ms: None,
                key: "b",
                evicted: Arc::new(AtomicBool::new(false)),
            },
        );
        // At t=200 the high-priority waiter's deadline (100) has passed.
        assert_eq!(state.winner(200), Some(1));
    }
}