chio_guards/external/
token_bucket.rs1use std::sync::Arc;
11use std::sync::Mutex;
12
13use tokio::time::Instant;
14
15use super::cache::{Clock, TokioClock};
16
17pub struct TokenBucket {
23 inner: Mutex<BucketInner>,
24 rate_per_second: f64,
25 burst: f64,
26 clock: Arc<dyn Clock>,
27}
28
/// Mutable state of a [`TokenBucket`], kept behind its mutex.
#[derive(Debug)]
struct BucketInner {
    /// Tokens currently available; the bucket keeps this within `0.0..=burst`.
    tokens: f64,
    /// Time at which `tokens` was last brought up to date by a refill.
    last_refill: Instant,
}
34
35impl TokenBucket {
36 pub fn new(rate_per_second: f64, burst: u32) -> Self {
41 Self::with_clock(rate_per_second, burst, Arc::new(TokioClock))
42 }
43
44 pub fn with_clock(rate_per_second: f64, burst: u32, clock: Arc<dyn Clock>) -> Self {
46 let rate = rate_per_second.max(0.0);
47 let burst_f = f64::from(burst.max(1));
48 let now = clock.now();
49 Self {
50 inner: Mutex::new(BucketInner {
51 tokens: burst_f,
52 last_refill: now,
53 }),
54 rate_per_second: rate,
55 burst: burst_f,
56 clock,
57 }
58 }
59
60 pub fn rate_per_second(&self) -> f64 {
62 self.rate_per_second
63 }
64
65 pub fn burst(&self) -> u32 {
67 self.burst as u32
68 }
69
70 pub fn try_acquire(&self) -> bool {
73 self.try_acquire_n(1.0)
74 }
75
76 pub fn try_acquire_n(&self, n: f64) -> bool {
78 if n <= 0.0 {
79 return true;
80 }
81 let now = self.clock.now();
82 let Ok(mut inner) = self.inner.lock() else {
83 return false;
84 };
85 self.refill(&mut inner, now);
86 if inner.tokens + f64::EPSILON >= n {
87 inner.tokens -= n;
88 if inner.tokens < 0.0 {
89 inner.tokens = 0.0;
90 }
91 true
92 } else {
93 false
94 }
95 }
96
97 pub fn available(&self) -> f64 {
99 let now = self.clock.now();
100 let Ok(mut inner) = self.inner.lock() else {
101 return 0.0;
102 };
103 self.refill(&mut inner, now);
104 inner.tokens
105 }
106
107 fn refill(&self, inner: &mut BucketInner, now: Instant) {
108 if self.rate_per_second == 0.0 {
109 inner.last_refill = now;
110 return;
111 }
112 let elapsed = now
113 .saturating_duration_since(inner.last_refill)
114 .as_secs_f64();
115 if elapsed <= 0.0 {
116 return;
117 }
118 let added = elapsed * self.rate_per_second;
119 inner.tokens = (inner.tokens + added).min(self.burst);
120 inner.last_refill = now;
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Asserts two token counts agree up to float rounding.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {expected}, got {actual}"
        );
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn starts_full_with_burst_capacity() {
        let bucket = TokenBucket::new(1.0, 3);
        assert!(bucket.try_acquire());
        assert!(bucket.try_acquire());
        assert!(bucket.try_acquire());
        assert!(!bucket.try_acquire());
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn refills_at_configured_rate() {
        let bucket = TokenBucket::new(2.0, 2);
        assert!(bucket.try_acquire());
        assert!(bucket.try_acquire());
        assert!(!bucket.try_acquire());
        tokio::time::advance(Duration::from_millis(1100)).await;
        // 1.1 s at 2 tokens/s earns 2.2 tokens, capped at the burst of 2.
        assert!(bucket.try_acquire());
        assert!(bucket.try_acquire());
        assert!(!bucket.try_acquire());
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn zero_rate_only_uses_burst() {
        let bucket = TokenBucket::new(0.0, 1);
        assert!(bucket.try_acquire());
        tokio::time::advance(Duration::from_secs(60)).await;
        assert!(!bucket.try_acquire());
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn clamps_invalid_configuration() {
        // Negative rate clamps to zero; zero burst is raised to one.
        let bucket = TokenBucket::new(-5.0, 0);
        assert_close(bucket.rate_per_second(), 0.0);
        assert_eq!(bucket.burst(), 1);
        assert!(bucket.try_acquire());
        assert!(!bucket.try_acquire());
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn non_positive_requests_always_succeed() {
        let bucket = TokenBucket::new(0.0, 1);
        assert!(bucket.try_acquire());
        // The bucket is now empty, yet zero and negative requests still pass.
        assert!(bucket.try_acquire_n(0.0));
        assert!(bucket.try_acquire_n(-1.0));
        assert_close(bucket.available(), 0.0);
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn rejects_requests_larger_than_balance_without_consuming() {
        let bucket = TokenBucket::new(10.0, 2);
        assert!(!bucket.try_acquire_n(3.0));
        assert_close(bucket.available(), 2.0);
        assert!(bucket.try_acquire_n(2.0));
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn available_reflects_partial_refill() {
        let bucket = TokenBucket::new(4.0, 4);
        assert!(bucket.try_acquire_n(4.0));
        assert_close(bucket.available(), 0.0);
        tokio::time::advance(Duration::from_millis(500)).await;
        // 0.5 s at 4 tokens/s earns 2 tokens.
        assert_close(bucket.available(), 2.0);
        assert!(bucket.try_acquire_n(1.5));
        assert_close(bucket.available(), 0.5);
    }
}