1use std::sync::Arc;
2use std::sync::atomic::{AtomicU8, AtomicU64, Ordering};
3use std::time::{Duration, Instant};
4use tokio::time::sleep;
5
6pub struct TokenBucket {
14 tokens: AtomicU64,
15 capacity: u64,
16 refill_rate: AtomicU64, last_refill: Arc<AtomicInstant>,
18 adaptive_config: Arc<AdaptiveConfig>,
20 metrics: Arc<RateMetrics>,
21 congestion_state: AtomicU8, }
23
/// Tuning knobs for adaptive rate control. Rates are in tokens (bytes) per second.
#[derive(Debug, Clone)]
pub struct AdaptiveConfig {
    /// Floor applied when the rate is cut, and lower bound for explicit `set_rate` calls.
    pub min_rate: u64,
    /// Ceiling applied when the rate grows, and upper bound for explicit `set_rate` calls.
    pub max_rate: u64,
    /// In the normal state, utilization below this lets the rate grow. In the congestion
    /// state, utilization below this moves the controller to recovery.
    pub target_utilization: f64,
    /// In the normal state, utilization above this enters congestion and cuts the rate. In
    /// recovery, utilization below this returns to normal.
    pub congestion_threshold: f64,
    /// Multiplier applied to the rate whenever it is cut.
    pub recovery_factor: f64,
    /// Multiplier applied to the rate whenever it grows.
    pub growth_factor: f64,
    /// Minimum number of milliseconds between two adaptive adjustments.
    pub measurement_window_ms: u64,
}
42
43impl Default for AdaptiveConfig {
44 fn default() -> Self {
45 Self {
46 min_rate: 1024, max_rate: 100 * 1024 * 1024, target_utilization: 0.8,
49 congestion_threshold: 0.95,
50 recovery_factor: 0.5,
51 growth_factor: 1.1,
52 measurement_window_ms: 1000, }
54 }
55}
56
/// Lock-free counters describing the limiter's acquisitions.
///
/// Lifetime totals sit alongside a resettable "window" used for adaptive utilization.
#[derive(Debug)]
pub struct RateMetrics {
    /// Sum of bytes over every recorded acquisition.
    pub total_bytes: AtomicU64,
    /// Sum of recorded wait times, in microseconds.
    pub total_wait_time_us: AtomicU64,
    /// Number of recorded acquisitions.
    pub acquisition_count: AtomicU64,
    /// Number of recorded acquisitions whose wait time was non-zero.
    pub wait_count: AtomicU64,
    /// Unix-epoch timestamp, in milliseconds, at which the current window began.
    pub last_measurement: AtomicU64,
    /// Bytes recorded since the current window began.
    pub window_bytes: AtomicU64,
}
73
74impl Default for RateMetrics {
75 fn default() -> Self {
76 let now = std::time::SystemTime::now()
77 .duration_since(std::time::UNIX_EPOCH)
78 .unwrap()
79 .as_millis() as u64;
80 Self {
81 total_bytes: AtomicU64::new(0),
82 total_wait_time_us: AtomicU64::new(0),
83 acquisition_count: AtomicU64::new(0),
84 wait_count: AtomicU64::new(0),
85 last_measurement: AtomicU64::new(now),
86 window_bytes: AtomicU64::new(0),
87 }
88 }
89}
90
91impl RateMetrics {
92 pub fn record_acquisition(&self, bytes: u64, wait_time_us: u64) {
93 self.total_bytes.fetch_add(bytes, Ordering::Relaxed);
94 self.total_wait_time_us
95 .fetch_add(wait_time_us, Ordering::Relaxed);
96 self.acquisition_count.fetch_add(1, Ordering::Relaxed);
97 if wait_time_us > 0 {
98 self.wait_count.fetch_add(1, Ordering::Relaxed);
99 }
100 self.window_bytes.fetch_add(bytes, Ordering::Relaxed);
101 }
102
103 pub fn get_throughput(&self) -> f64 {
104 let total_bytes = self.total_bytes.load(Ordering::Relaxed);
105 let total_wait_us = self.total_wait_time_us.load(Ordering::Relaxed);
106 let count = self.acquisition_count.load(Ordering::Relaxed);
107
108 if count == 0 {
109 return 0.0;
110 }
111
112 let total_time_s = total_wait_us as f64 / 1_000_000.0;
114 if total_time_s > 0.0 {
115 total_bytes as f64 / total_time_s
116 } else {
117 0.0
118 }
119 }
120
121 pub fn get_utilization(&self, current_rate: u64) -> f64 {
122 let window_bytes = self.window_bytes.load(Ordering::Relaxed);
123 let window_duration_s = 1.0; let expected_bytes = current_rate as f64 * window_duration_s;
125
126 if expected_bytes > 0.0 {
127 window_bytes as f64 / expected_bytes
128 } else {
129 0.0
130 }
131 }
132
133 pub fn reset_window(&self) {
134 self.window_bytes.store(0, Ordering::Relaxed);
135 let now = std::time::SystemTime::now()
136 .duration_since(std::time::UNIX_EPOCH)
137 .unwrap()
138 .as_millis() as u64;
139 self.last_measurement.store(now, Ordering::Relaxed);
140 }
141}
142
/// Phase of the adaptive rate controller.
///
/// It is stored in `TokenBucket::congestion_state` as its `u8` discriminant, so the explicit
/// values below must stay stable.
#[derive(Debug, Clone, Copy, PartialEq)]
enum CongestionState {
    /// Steady state: low utilization lets the rate grow, and high utilization enters congestion.
    Normal = 0,
    /// Utilization crossed the congestion threshold, so the rate is being cut.
    Congestion = 1,
    /// Utilization dropped below target after congestion; waiting to return to normal.
    Recovery = 2,
}
149
/// Shared, mutable `Instant`.
///
/// Despite the name this is not lock-free: the value sits behind a `std::sync::Mutex`.
#[derive(Debug)]
struct AtomicInstant {
    /// The guarded instant.
    instant: std::sync::Mutex<Instant>,
}
155
156impl AtomicInstant {
157 fn new(instant: Instant) -> Self {
158 Self {
159 instant: std::sync::Mutex::new(instant),
160 }
161 }
162
163 fn get(&self) -> Instant {
164 *self.instant.lock().unwrap()
165 }
166
167 fn set(&self, instant: Instant) {
168 *self.instant.lock().unwrap() = instant;
169 }
170}
171
172impl TokenBucket {
173 pub fn new(capacity: u64, refill_rate: u64) -> Self {
180 let now = Instant::now();
181 Self {
182 tokens: AtomicU64::new(capacity),
183 capacity,
184 refill_rate: AtomicU64::new(refill_rate),
185 last_refill: Arc::new(AtomicInstant::new(now)),
186 adaptive_config: Arc::new(AdaptiveConfig::default()),
187 metrics: Arc::new(RateMetrics::default()),
188 congestion_state: AtomicU8::new(CongestionState::Normal as u8),
189 }
190 }
191
192 pub fn new_adaptive(capacity: u64, refill_rate: u64, config: AdaptiveConfig) -> Self {
200 let now = Instant::now();
201 Self {
202 tokens: AtomicU64::new(capacity),
203 capacity,
204 refill_rate: AtomicU64::new(refill_rate),
205 last_refill: Arc::new(AtomicInstant::new(now)),
206 adaptive_config: Arc::new(config),
207 metrics: Arc::new(RateMetrics::default()),
208 congestion_state: AtomicU8::new(CongestionState::Normal as u8),
209 }
210 }
211
212 pub async fn acquire(&self, bytes: usize) {
220 let tokens_needed = bytes as u64;
221 let start_time = Instant::now();
222
223 loop {
224 self.refill();
226
227 let current_tokens = self.tokens.load(Ordering::Relaxed);
229 if current_tokens >= tokens_needed {
230 if self
232 .tokens
233 .compare_exchange_weak(
234 current_tokens,
235 current_tokens - tokens_needed,
236 Ordering::Relaxed,
237 Ordering::Relaxed,
238 )
239 .is_ok()
240 {
241 let wait_time_us = start_time.elapsed().as_micros() as u64;
242 self.metrics.record_acquisition(tokens_needed, wait_time_us);
243
244 return;
245 }
246 continue;
248 }
249
250 let deficit = tokens_needed - current_tokens;
252 let current_rate = self.refill_rate.load(Ordering::Relaxed);
253
254 let wait_time = Duration::from_secs_f64(deficit as f64 / current_rate as f64);
256
257 sleep(wait_time).await;
259 }
260 }
261
262 pub fn try_acquire(&self, bytes: usize) -> bool {
270 let tokens_needed = bytes as u64;
271
272 self.refill();
274
275 let current_tokens = self.tokens.load(Ordering::Relaxed);
277 if current_tokens >= tokens_needed {
278 if self
280 .tokens
281 .compare_exchange_weak(
282 current_tokens,
283 current_tokens - tokens_needed,
284 Ordering::Relaxed,
285 Ordering::Relaxed,
286 )
287 .is_ok()
288 {
289 return true;
290 }
291 }
292
293 false
294 }
295
296 fn refill(&self) {
298 let now = Instant::now();
299 let last_refill = self.last_refill.get();
300 let elapsed = now.duration_since(last_refill);
301
302 if elapsed.as_secs_f64() > 0.0 {
303 let current_rate = self.refill_rate.load(Ordering::Relaxed);
304 let tokens_to_add = (current_rate as f64 * elapsed.as_secs_f64()) as u64;
305 let current_tokens = self.tokens.load(Ordering::Relaxed);
306 let new_tokens = (current_tokens + tokens_to_add).min(self.capacity);
307
308 self.tokens.store(new_tokens, Ordering::Relaxed);
309 self.last_refill.set(now);
310 }
311 }
312
313 pub fn available_tokens(&self) -> u64 {
315 self.refill();
316 self.tokens.load(Ordering::Relaxed)
317 }
318
319 pub fn current_rate(&self) -> u64 {
321 self.refill_rate.load(Ordering::Relaxed)
322 }
323
324 pub fn check_and_adjust_rate(&self) {
326 let config = &self.adaptive_config;
327 let current_rate = self.refill_rate.load(Ordering::Relaxed);
328 let utilization = self.metrics.get_utilization(current_rate);
329
330 let now = std::time::SystemTime::now()
332 .duration_since(std::time::UNIX_EPOCH)
333 .unwrap()
334 .as_millis() as u64;
335 let last_measurement = self.metrics.last_measurement.load(Ordering::Relaxed);
336
337 if now - last_measurement >= config.measurement_window_ms {
338 self.adjust_rate_based_on_conditions(utilization);
339 self.metrics.reset_window();
340 }
341 }
342
343 fn adjust_rate_based_on_conditions(&self, utilization: f64) {
345 let config = &self.adaptive_config;
346 let current_rate = self.refill_rate.load(Ordering::Relaxed);
347 let current_state = self.congestion_state.load(Ordering::Relaxed);
348
349 match current_state {
350 0 => {
351 if utilization > config.congestion_threshold {
353 self.congestion_state
355 .store(CongestionState::Congestion as u8, Ordering::Relaxed);
356 let new_rate = (current_rate as f64 * config.recovery_factor)
357 .max(config.min_rate as f64) as u64;
358 self.refill_rate.store(new_rate, Ordering::Relaxed);
359 } else if utilization < config.target_utilization {
360 let new_rate = (current_rate as f64 * config.growth_factor)
362 .min(config.max_rate as f64) as u64;
363 self.refill_rate.store(new_rate, Ordering::Relaxed);
364 }
365 }
366 1 => {
367 if utilization < config.target_utilization {
369 self.congestion_state
371 .store(CongestionState::Recovery as u8, Ordering::Relaxed);
372 } else {
373 let new_rate = (current_rate as f64 * config.recovery_factor)
375 .max(config.min_rate as f64) as u64;
376 self.refill_rate.store(new_rate, Ordering::Relaxed);
377 }
378 }
379 2 => {
380 if utilization < config.congestion_threshold {
382 self.congestion_state
384 .store(CongestionState::Normal as u8, Ordering::Relaxed);
385 let new_rate = (current_rate as f64 * config.growth_factor)
386 .min(config.max_rate as f64) as u64;
387 self.refill_rate.store(new_rate, Ordering::Relaxed);
388 } else {
389 self.congestion_state
391 .store(CongestionState::Congestion as u8, Ordering::Relaxed);
392 }
393 }
394 _ => {}
395 }
396 }
397
398 pub fn get_metrics(&self) -> &RateMetrics {
400 &self.metrics
401 }
402
403 pub fn set_rate(&self, new_rate: u64) {
405 let config = &self.adaptive_config;
406 let clamped_rate = new_rate.clamp(config.min_rate, config.max_rate);
407 self.refill_rate.store(clamped_rate, Ordering::Relaxed);
408 }
409}
410
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, Duration};

    /// Adaptive config shared by the adaptive tests.
    ///
    /// Rates are bounded to 10..=1000 tokens/s, halved on congestion and grown by 20%.
    fn bounded_config(
        target_utilization: f64,
        congestion_threshold: f64,
        measurement_window_ms: u64,
    ) -> AdaptiveConfig {
        AdaptiveConfig {
            min_rate: 10,
            max_rate: 1000,
            target_utilization,
            congestion_threshold,
            recovery_factor: 0.5,
            growth_factor: 1.2,
            measurement_window_ms,
        }
    }

    #[tokio::test]
    async fn test_token_bucket_basic() {
        let bucket = TokenBucket::new(100, 50);

        // Two half-bucket acquisitions drain the initial burst.
        bucket.acquire(50).await;
        assert!(bucket.available_tokens() <= 50);
        bucket.acquire(50).await;
        assert_eq!(bucket.available_tokens(), 0);

        sleep(Duration::from_millis(10)).await;

        // 25 tokens at 50 tokens/s should take roughly half a second.
        let started = Instant::now();
        bucket.acquire(25).await;
        let waited = started.elapsed();
        assert!(waited >= Duration::from_millis(450));
        assert!(waited <= Duration::from_millis(550));
    }

    #[tokio::test]
    async fn test_token_bucket_refill() {
        let bucket = TokenBucket::new(100, 100);
        bucket.acquire(100).await;
        assert_eq!(bucket.available_tokens(), 0);

        // ~100 ms at 100 tokens/s refills about 10 tokens.
        sleep(Duration::from_millis(100)).await;
        let available = bucket.available_tokens();
        assert!(available > 5);
        assert!(available <= 15);
    }

    #[tokio::test]
    async fn test_token_bucket_concurrent() {
        let bucket = Arc::new(TokenBucket::new(1000, 100));

        // Ten tasks together request exactly the bucket's capacity.
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let bucket = Arc::clone(&bucket);
                tokio::spawn(async move { bucket.acquire(100).await })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }
        assert_eq!(bucket.available_tokens(), 0);
    }

    #[tokio::test]
    async fn test_adaptive_rate_limiting() {
        let bucket = TokenBucket::new_adaptive(100, 100, bounded_config(0.8, 0.9, 100));
        assert_eq!(bucket.current_rate(), 100);

        // set_rate clamps the requested rate into [min_rate, max_rate].
        for (requested, expected) in [(500, 500), (2000, 1000), (5, 10)] {
            bucket.set_rate(requested);
            assert_eq!(bucket.current_rate(), expected);
        }
    }

    #[tokio::test]
    async fn test_congestion_detection() {
        let bucket = TokenBucket::new_adaptive(1000, 100, bounded_config(0.5, 0.8, 50));

        // Push 200 bytes through a 100 tokens/s bucket, then let the window elapse.
        for _ in 0..20 {
            bucket.acquire(10).await;
        }
        sleep(Duration::from_millis(60)).await;

        bucket.check_and_adjust_rate();
        bucket.acquire(10).await;

        // Utilization above the congestion threshold must have cut the rate.
        assert!(bucket.current_rate() < 100);
    }

    #[tokio::test]
    async fn test_metrics_collection() {
        let bucket = TokenBucket::new(100, 100);
        let metrics = bucket.get_metrics();
        let bytes = || metrics.total_bytes.load(Ordering::Relaxed);
        let count = || metrics.acquisition_count.load(Ordering::Relaxed);

        assert_eq!(bytes(), 0);
        assert_eq!(count(), 0);

        bucket.acquire(50).await;

        assert!(bytes() > 0);
        assert!(count() > 0);
    }
}