1use core::time::Duration;
4use std::collections::HashMap;
5use std::hash::Hash;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::{PoisonError, RwLock, RwLockReadGuard, RwLockWriteGuard};
8
9use ahash::RandomState;
10use clock_lib::{Clock, Monotonic, SystemClock};
11
12use crate::decision::Decision;
13#[cfg(feature = "runtime")]
14use crate::error::ThrottleError;
15use crate::eviction::Eviction;
16use crate::limiter::Limiter;
17use crate::throttle::Throttle;
18
/// Shard count used by the `per_second` / `per_duration` constructors.
///
/// Written as a power of two because `PerKey::build` rounds any shard count up
/// to the next power of two; this value is therefore used exactly as given.
const DEFAULT_SHARDS: usize = 1 << 4;
22
23struct Entry<C: Clock> {
30 throttle: Throttle<C>,
31 last_seen: AtomicU64,
32}
33
34struct Shard<K, C: Clock> {
37 map: RwLock<HashMap<K, Entry<C>, RandomState>>,
38 seq: AtomicU64,
41}
42
43impl<K, C: Clock> Shard<K, C> {
44 fn new() -> Self {
45 Self {
46 map: RwLock::new(HashMap::default()),
47 seq: AtomicU64::new(0),
48 }
49 }
50}
51
52pub struct PerKey<K, C = SystemClock>
83where
84 C: Clock,
85{
86 shards: Box<[Shard<K, C>]>,
87 shard_mask: u64,
90 hasher: RandomState,
91 eviction: Eviction,
92 amount: u32,
93 period: Duration,
94 clock: C,
95 epoch: Monotonic,
96}
97
98impl<K> PerKey<K, SystemClock>
99where
100 K: Eq + Hash + Clone + Send + Sync + 'static,
101{
102 #[must_use]
114 pub fn per_second(rate: u32) -> Self {
115 Self::build(
116 rate,
117 Duration::from_secs(1),
118 SystemClock::new(),
119 DEFAULT_SHARDS,
120 Eviction::default(),
121 )
122 }
123
124 #[must_use]
137 pub fn per_duration(amount: u32, period: Duration) -> Self {
138 Self::build(
139 amount,
140 period,
141 SystemClock::new(),
142 DEFAULT_SHARDS,
143 Eviction::default(),
144 )
145 }
146}
147
148impl<K, C> PerKey<K, C>
149where
150 K: Eq + Hash + Clone + Send + Sync + 'static,
151 C: Clock + Clone,
152{
153 fn build(amount: u32, period: Duration, clock: C, shards: usize, eviction: Eviction) -> Self {
154 let shard_count = shards.max(1).next_power_of_two();
155 let shards = (0..shard_count)
156 .map(|_| Shard::new())
157 .collect::<Vec<_>>()
158 .into_boxed_slice();
159 let epoch = clock.now();
160 Self {
161 shards,
162 shard_mask: shard_count as u64 - 1,
163 hasher: RandomState::new(),
164 eviction,
165 amount,
166 period,
167 clock,
168 epoch,
169 }
170 }
171
172 #[must_use]
193 pub fn with_clock<C2>(self, clock: C2) -> PerKey<K, C2>
194 where
195 C2: Clock + Clone,
196 {
197 PerKey::build(
198 self.amount,
199 self.period,
200 clock,
201 self.shards.len(),
202 self.eviction,
203 )
204 }
205
206 #[must_use]
219 pub fn with_eviction(mut self, eviction: Eviction) -> Self {
220 self.eviction = eviction;
221 self
222 }
223
224 #[must_use]
238 pub fn with_shards(self, shards: usize) -> Self {
239 PerKey::build(self.amount, self.period, self.clock, shards, self.eviction)
240 }
241
242 #[inline]
244 #[must_use]
245 pub fn capacity(&self) -> u32 {
246 self.amount
247 }
248
249 #[inline]
251 #[must_use]
252 pub fn shard_count(&self) -> usize {
253 self.shards.len()
254 }
255
256 #[must_use]
261 pub fn len(&self) -> usize {
262 self.shards
263 .iter()
264 .map(|shard| read_guard(&shard.map).len())
265 .sum()
266 }
267
268 #[must_use]
270 pub fn is_empty(&self) -> bool {
271 self.shards
272 .iter()
273 .all(|shard| read_guard(&shard.map).is_empty())
274 }
275
276 #[inline]
289 #[must_use]
290 pub fn try_acquire(&self, key: &K) -> bool {
291 self.try_acquire_with_cost(key, 1)
292 }
293
294 #[inline]
296 #[must_use]
297 pub fn try_acquire_with_cost(&self, key: &K, cost: u32) -> bool {
298 self.decide(key, cost).is_acquired()
299 }
300
301 #[inline]
304 #[must_use]
305 pub fn peek(&self, key: &K, cost: u32) -> Decision {
306 let shard = self.shard_for(key);
307 let guard = read_guard(&shard.map);
308 match guard.get(key) {
309 Some(entry) => entry.throttle.peek(cost),
310 None if cost > self.amount => Decision::Impossible,
312 None => Decision::Acquired,
313 }
314 }
315
316 #[must_use]
319 pub fn available(&self, key: &K) -> u32 {
320 let shard = self.shard_for(key);
321 let guard = read_guard(&shard.map);
322 guard
323 .get(key)
324 .map_or(self.amount, |entry| entry.throttle.available())
325 }
326
327 #[inline]
329 fn make_throttle(&self) -> Throttle<C> {
330 Throttle::per_duration(self.amount, self.period).with_clock(self.clock.clone())
331 }
332
333 #[inline]
335 fn now_ms(&self) -> u64 {
336 let elapsed = self.clock.now().saturating_duration_since(self.epoch);
337 u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX)
338 }
339
340 #[inline]
343 fn stamp(&self, shard: &Shard<K, C>, now_ms: u64) -> u64 {
344 if self.eviction.idle_ttl().is_some() {
345 now_ms
346 } else {
347 shard.seq.fetch_add(1, Ordering::Relaxed)
348 }
349 }
350
351 #[inline]
352 fn shard_for(&self, key: &K) -> &Shard<K, C> {
353 let index = (self.hasher.hash_one(key) & self.shard_mask) as usize;
354 &self.shards[index]
355 }
356
357 fn decide(&self, key: &K, cost: u32) -> Decision {
360 let now_ms = self.now_ms();
361 let shard = self.shard_for(key);
362
363 {
367 let guard = read_guard(&shard.map);
368 if let Some(entry) = guard.get(key) {
369 entry
370 .last_seen
371 .store(self.stamp(shard, now_ms), Ordering::Relaxed);
372 return entry.throttle.acquire_cost(cost);
373 }
374 }
375
376 let mut guard = write_guard(&shard.map);
379 if let Some(entry) = guard.get(key) {
380 entry
381 .last_seen
382 .store(self.stamp(shard, now_ms), Ordering::Relaxed);
383 return entry.throttle.acquire_cost(cost);
384 }
385
386 let stamp = self.stamp(shard, now_ms);
387 self.evict_for_insert(&mut guard, now_ms);
388 let throttle = self.make_throttle();
389 let outcome = throttle.acquire_cost(cost);
390 let _ = guard.insert(
391 key.clone(),
392 Entry {
393 throttle,
394 last_seen: AtomicU64::new(stamp),
395 },
396 );
397 outcome
398 }
399
400 fn evict_for_insert(&self, map: &mut HashMap<K, Entry<C>, RandomState>, now_ms: u64) {
404 if let Some(ttl) = self.eviction.idle_ttl() {
405 let ttl_ms = u64::try_from(ttl.as_millis()).unwrap_or(u64::MAX);
406 map.retain(|_, entry| {
407 now_ms.saturating_sub(entry.last_seen.load(Ordering::Relaxed)) < ttl_ms
408 });
409 }
410
411 if let Some(max) = self.eviction.max_keys() {
412 let per_shard_cap = max.div_ceil(self.shards.len()).max(1);
413 while map.len() >= per_shard_cap {
414 let victim = map
415 .iter()
416 .min_by_key(|(_, entry)| entry.last_seen.load(Ordering::Relaxed))
417 .map(|(key, _)| key.clone());
418 match victim {
419 Some(key) => {
420 let _ = map.remove(&key);
421 }
422 None => break,
423 }
424 }
425 }
426 }
427}
428
#[cfg(feature = "runtime")]
#[cfg_attr(docsrs, doc(cfg(feature = "runtime")))]
impl<K, C> PerKey<K, C>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    C: Clock + Clone,
{
    /// Waits until one unit for `key` is granted.
    ///
    /// Shorthand for [`acquire_with_cost`](Self::acquire_with_cost) with a
    /// cost of 1.
    ///
    /// # Errors
    ///
    /// Returns [`ThrottleError::CostExceedsCapacity`] if a cost of 1 can never
    /// be granted by this limiter.
    pub async fn acquire(&self, key: &K) -> Result<(), ThrottleError> {
        self.acquire_with_cost(key, 1).await
    }

    /// Waits until `cost` units for `key` are granted.
    ///
    /// Each attempt is decided without waiting. On `Retry` the task sleeps for
    /// the suggested delay and tries again. No shard lock is held while
    /// sleeping.
    ///
    /// # Errors
    ///
    /// Returns [`ThrottleError::CostExceedsCapacity`] when the request can
    /// never succeed, i.e. the decision is `Impossible`.
    pub async fn acquire_with_cost(&self, key: &K, cost: u32) -> Result<(), ThrottleError> {
        loop {
            let delay = match self.decide(key, cost) {
                Decision::Acquired => break Ok(()),
                Decision::Impossible => {
                    break Err(ThrottleError::CostExceedsCapacity {
                        cost,
                        capacity: self.amount,
                    });
                }
                Decision::Retry { after } => after,
            };
            crate::rt::sleep(delay).await;
        }
    }
}
479
480impl<K, C> crate::limiter::KeyedLimiter<K> for PerKey<K, C>
481where
482 K: Eq + Hash + Clone + Send + Sync + 'static,
483 C: Clock + Clone + 'static,
484{
485 #[inline]
486 fn peek(&self, key: &K, cost: u32) -> Decision {
487 PerKey::peek(self, key, cost)
488 }
489
490 #[inline]
491 fn try_acquire_with_cost(&self, key: &K, cost: u32) -> bool {
492 PerKey::try_acquire_with_cost(self, key, cost)
493 }
494
495 #[inline]
496 fn capacity(&self) -> u32 {
497 PerKey::capacity(self)
498 }
499}
500
/// Takes a shared lock on `lock`, deliberately ignoring poisoning.
///
/// If another holder panicked, the guard is recovered from the `PoisonError`
/// rather than propagating the panic.
fn read_guard<T>(lock: &RwLock<T>) -> RwLockReadGuard<'_, T> {
    match lock.read() {
        Ok(guard) => guard,
        Err(poisoned) => PoisonError::into_inner(poisoned),
    }
}
506
/// Takes an exclusive lock on `lock`, deliberately ignoring poisoning.
///
/// If another holder panicked, the guard is recovered from the `PoisonError`
/// rather than propagating the panic.
fn write_guard<T>(lock: &RwLock<T>) -> RwLockWriteGuard<'_, T> {
    match lock.write() {
        Ok(guard) => guard,
        Err(poisoned) => PoisonError::into_inner(poisoned),
    }
}
511
#[cfg(test)]
mod tests {
    // An unexpected `Err` in a test should fail it loudly; unwrapping is the
    // simplest way to do that.
    #![allow(clippy::unwrap_used)]

    use super::PerKey;
    use crate::eviction::Eviction;
    use clock_lib::ManualClock;
    use core::time::Duration;
    use std::sync::Arc;

    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn test_perkey_is_send_sync() {
        assert_send_sync::<PerKey<String>>();
        assert_send_sync::<PerKey<u64>>();
    }

    #[test]
    fn test_keys_are_independent() {
        let limiter: PerKey<&str> = PerKey::per_second(1);
        assert!(limiter.try_acquire(&"a"));
        assert!(!limiter.try_acquire(&"a"));
        // Exhausting "a" must not affect "b".
        assert!(limiter.try_acquire(&"b"));
    }

    #[test]
    fn test_first_acquire_creates_exactly_one_key() {
        let limiter: PerKey<&str> = PerKey::per_second(10);
        assert_eq!(limiter.len(), 0);
        assert!(limiter.try_acquire(&"a"));
        assert_eq!(limiter.len(), 1);
        assert!(limiter.try_acquire(&"a"));
        assert_eq!(limiter.len(), 1);
    }

    #[test]
    fn test_shard_count_rounds_up_to_power_of_two() {
        assert_eq!(PerKey::<u64>::per_second(1).with_shards(5).shard_count(), 8);
        assert_eq!(
            PerKey::<u64>::per_second(1).with_shards(16).shard_count(),
            16
        );
        assert_eq!(PerKey::<u64>::per_second(1).with_shards(0).shard_count(), 1);
    }

    #[test]
    fn test_peek_does_not_create_state() {
        let limiter: PerKey<&str> = PerKey::per_second(5);
        assert!(limiter.peek(&"ghost", 1).is_acquired());
        assert_eq!(limiter.len(), 0, "peek must not insert a key");
    }

    #[test]
    fn test_available_reports_full_capacity_for_unseen_key() {
        let limiter: PerKey<&str> = PerKey::per_second(7);
        assert_eq!(limiter.available(&"unseen"), 7);
        assert!(limiter.try_acquire_with_cost(&"seen", 3));
        assert_eq!(limiter.available(&"seen"), 4);
    }

    #[test]
    fn test_refill_under_manual_clock() {
        let clock = Arc::new(ManualClock::new());
        let limiter = PerKey::<&str>::per_second(2).with_clock(clock.clone());

        assert!(limiter.try_acquire(&"k"));
        assert!(limiter.try_acquire(&"k"));
        assert!(!limiter.try_acquire(&"k"));

        clock.advance(Duration::from_secs(1));
        assert!(limiter.try_acquire(&"k"));
    }

    #[test]
    fn test_capacity_bounds_total_keys_under_unique_flood() {
        let shards = 8;
        let cap = 100usize;
        let limiter: PerKey<u64> = PerKey::per_second(10)
            .with_shards(shards)
            .with_eviction(Eviction::capacity(cap));

        for k in 0..10_000u64 {
            let _ = limiter.try_acquire(&k);
        }

        // The limit is enforced per shard, so the global bound is the rounded-up
        // per-shard share times the shard count.
        let per_shard_cap = cap.div_ceil(shards).max(1);
        let bound = per_shard_cap * shards;
        assert!(
            limiter.len() <= bound,
            "flood grew to {} keys, bound {bound}",
            limiter.len()
        );
    }

    #[test]
    fn test_ttl_reclaims_idle_keys_on_later_insert() {
        let clock = Arc::new(ManualClock::new());
        // Capacity 2 leaves room for both keys, so only the TTL, not LRU
        // capacity eviction, can reclaim the idle key.
        let limiter = PerKey::<&str>::per_second(10)
            .with_clock(clock.clone())
            .with_eviction(Eviction::idle(Duration::from_millis(1000)).with_capacity(2))
            .with_shards(1);

        assert!(limiter.try_acquire(&"idle"));
        assert_eq!(limiter.len(), 1);

        clock.advance(Duration::from_millis(2000));
        assert!(limiter.try_acquire(&"fresh"));
        assert_eq!(limiter.len(), 1, "the idle key should have been reclaimed");
    }

    #[test]
    fn test_ttl_keeps_keys_seen_within_window() {
        let clock = Arc::new(ManualClock::new());
        let limiter = PerKey::<&str>::per_second(10)
            .with_clock(clock.clone())
            .with_eviction(Eviction::idle(Duration::from_millis(1000)).with_capacity(8))
            .with_shards(1);

        assert!(limiter.try_acquire(&"early"));
        clock.advance(Duration::from_millis(500));
        assert!(limiter.try_acquire(&"late"));
        assert_eq!(
            limiter.len(),
            2,
            "a key idle for less than the TTL must survive"
        );
    }

    #[test]
    fn test_recently_seen_key_survives_eviction_pressure() {
        let limiter: PerKey<String> = PerKey::per_second(1_000)
            .with_shards(1)
            .with_eviction(Eviction::capacity(4));

        for round in 0..50u64 {
            assert!(limiter.try_acquire(&"hot".to_string()));
            let _ = limiter.try_acquire(&round.to_string());
        }
        assert!(limiter.try_acquire(&"hot".to_string()));
    }

    #[cfg(feature = "runtime")]
    #[tokio::test]
    async fn test_acquire_errors_when_cost_exceeds_capacity() {
        use crate::error::ThrottleError;

        let limiter: PerKey<&str> = PerKey::per_second(5);
        let err = limiter.acquire_with_cost(&"k", 9).await.unwrap_err();
        assert_eq!(
            err,
            ThrottleError::CostExceedsCapacity {
                cost: 9,
                capacity: 5,
            }
        );
    }

    #[cfg(feature = "runtime")]
    #[tokio::test]
    async fn test_acquire_waits_then_succeeds() {
        let limiter: PerKey<&str> = PerKey::per_second(1000);
        for _ in 0..1000 {
            assert!(limiter.try_acquire(&"k"));
        }
        assert!(!limiter.try_acquire(&"k"));
        assert!(limiter.acquire(&"k").await.is_ok());
    }
}