Skip to main content

pulith_fetch/rate/
bandwidth.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicU8, AtomicU64, Ordering};
3use std::time::{Duration, Instant};
4use tokio::time::sleep;
5
6/// Token bucket algorithm implementation for bandwidth limiting.
7///
8/// This implementation uses atomic operations for thread safety and
9/// provides async token acquisition with proper rate limiting.
10///
11/// The adaptive version can dynamically adjust the refill rate based on
12/// network conditions and congestion control algorithms.
13pub struct TokenBucket {
14    tokens: AtomicU64,
15    capacity: u64,
16    refill_rate: AtomicU64, // Changed to AtomicU64 for dynamic adjustment
17    last_refill: Arc<AtomicInstant>,
18    // Adaptive rate limiting fields
19    adaptive_config: Arc<AdaptiveConfig>,
20    metrics: Arc<RateMetrics>,
21    congestion_state: AtomicU8, // 0: normal, 1: congestion, 2: recovery
22}
23
/// Tunable parameters for adaptive rate limiting.
///
/// The rates bound the refill rate; the thresholds are compared against the
/// measured utilization, and the factors scale the rate when it is lowered
/// or raised.
#[derive(Debug, Clone)]
pub struct AdaptiveConfig {
    /// Lower bound for the refill rate, in bytes per second.
    pub min_rate: u64,
    /// Upper bound for the refill rate, in bytes per second.
    pub max_rate: u64,
    /// Utilization target, expressed as a fraction (0.0 to 1.0).
    pub target_utilization: f64,
    /// Utilization above which congestion is assumed (0.0 to 1.0).
    pub congestion_threshold: f64,
    /// Multiplier applied to the rate when congestion is detected (0.0 to 1.0).
    pub recovery_factor: f64,
    /// Multiplier applied to the rate when it is increased (1.0 to 2.0).
    pub growth_factor: f64,
    /// Length of the measurement window between adjustments, in milliseconds.
    pub measurement_window_ms: u64,
}

impl Default for AdaptiveConfig {
    /// Defaults: rate between 1 KiB/s and 100 MiB/s, 80% target utilization,
    /// 95% congestion threshold, halve on congestion, grow by 10%, and a
    /// one-second measurement window.
    fn default() -> Self {
        const KIB: u64 = 1024;
        const MIB: u64 = 1024 * KIB;
        const WINDOW_MS: u64 = 1000;

        AdaptiveConfig {
            min_rate: KIB,
            max_rate: 100 * MIB,
            target_utilization: 0.8,
            congestion_threshold: 0.95,
            recovery_factor: 0.5,
            growth_factor: 1.1,
            measurement_window_ms: WINDOW_MS,
        }
    }
}
56
/// Counters describing how a rate limiter has been used.
///
/// Every field is an atomic, so the counters can be updated through a shared
/// reference. All accesses use relaxed ordering: each counter is individually
/// consistent, but a reader may observe counters from different moments.
#[derive(Debug)]
pub struct RateMetrics {
    /// Total bytes acquired
    pub total_bytes: AtomicU64,
    /// Total wait time in microseconds
    pub total_wait_time_us: AtomicU64,
    /// Number of acquisitions
    pub acquisition_count: AtomicU64,
    /// Number of acquisitions recorded with a non-zero wait time
    pub wait_count: AtomicU64,
    /// Start of the current measurement window (Unix timestamp in milliseconds)
    pub last_measurement: AtomicU64,
    /// Bytes acquired in current measurement window
    pub window_bytes: AtomicU64,
}

impl Default for RateMetrics {
    /// All counters start at zero and the measurement window starts now.
    fn default() -> Self {
        Self {
            total_bytes: AtomicU64::new(0),
            total_wait_time_us: AtomicU64::new(0),
            acquisition_count: AtomicU64::new(0),
            wait_count: AtomicU64::new(0),
            last_measurement: AtomicU64::new(Self::unix_millis()),
            window_bytes: AtomicU64::new(0),
        }
    }
}

impl RateMetrics {
    /// Current wall-clock time as milliseconds since the Unix epoch.
    ///
    /// Returns 0 instead of panicking when the system clock is set before the
    /// epoch: a wrong timestamp only delays or hastens one measurement
    /// window, which is preferable to aborting the caller.
    fn unix_millis() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |since_epoch| since_epoch.as_millis() as u64)
    }

    /// Record one acquisition of `bytes` bytes that waited `wait_time_us`
    /// microseconds.
    ///
    /// The acquisition counts towards `wait_count` only when `wait_time_us`
    /// is non-zero.
    pub fn record_acquisition(&self, bytes: u64, wait_time_us: u64) {
        self.total_bytes.fetch_add(bytes, Ordering::Relaxed);
        self.total_wait_time_us
            .fetch_add(wait_time_us, Ordering::Relaxed);
        self.acquisition_count.fetch_add(1, Ordering::Relaxed);
        if wait_time_us > 0 {
            self.wait_count.fetch_add(1, Ordering::Relaxed);
        }
        self.window_bytes.fetch_add(bytes, Ordering::Relaxed);
    }

    /// Total bytes acquired divided by the total recorded wait time, in bytes
    /// per second.
    ///
    /// Only wait time is used as the time base, so this returns `0.0` when
    /// nothing has been recorded or when no acquisition has waited.
    pub fn get_throughput(&self) -> f64 {
        if self.acquisition_count.load(Ordering::Relaxed) == 0 {
            return 0.0;
        }

        let total_bytes = self.total_bytes.load(Ordering::Relaxed);
        let total_wait_s = self.total_wait_time_us.load(Ordering::Relaxed) as f64 / 1_000_000.0;
        if total_wait_s > 0.0 {
            total_bytes as f64 / total_wait_s
        } else {
            0.0
        }
    }

    /// Bytes acquired in the current window relative to what `current_rate`
    /// (bytes per second) allows in one second.
    ///
    /// The window is always treated as one second long here, regardless of
    /// how much time has passed since the last reset. Returns `0.0` when
    /// `current_rate` is zero.
    pub fn get_utilization(&self, current_rate: u64) -> f64 {
        // The result is normalised against a fixed one-second window.
        const WINDOW_DURATION_S: f64 = 1.0;

        let window_bytes = self.window_bytes.load(Ordering::Relaxed);
        let expected_bytes = current_rate as f64 * WINDOW_DURATION_S;
        if expected_bytes > 0.0 {
            window_bytes as f64 / expected_bytes
        } else {
            0.0
        }
    }

    /// Start a new measurement window: clear the per-window byte counter and
    /// stamp the window start with the current time.
    pub fn reset_window(&self) {
        self.window_bytes.store(0, Ordering::Relaxed);
        self.last_measurement
            .store(Self::unix_millis(), Ordering::Relaxed);
    }
}
142
/// Congestion state of a bucket.
///
/// The explicit discriminants are the values kept in the bucket's `AtomicU8`
/// state field; `#[repr(u8)]` pins the representation to that width.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CongestionState {
    /// No congestion detected.
    Normal = 0,
    /// Utilization exceeded the congestion threshold.
    Congestion = 1,
    /// Utilization dropped after congestion; not yet back to normal.
    Recovery = 2,
}
149
/// Shared, interior-mutable `Instant`.
///
/// Despite the name this is not lock-free: the value lives behind a
/// `std::sync::Mutex`, and each accessor takes the lock only for the
/// duration of a single read or write.
#[derive(Debug)]
struct AtomicInstant {
    instant: std::sync::Mutex<Instant>,
}

impl AtomicInstant {
    /// Wrap `instant` as the initial value.
    fn new(instant: Instant) -> Self {
        let instant = std::sync::Mutex::new(instant);
        Self { instant }
    }

    /// Return a copy of the stored instant.
    ///
    /// Panics if the mutex is poisoned.
    fn get(&self) -> Instant {
        let guard = self.instant.lock().unwrap();
        *guard
    }

    /// Overwrite the stored instant.
    ///
    /// Panics if the mutex is poisoned.
    fn set(&self, instant: Instant) {
        let mut guard = self.instant.lock().unwrap();
        *guard = instant;
    }
}
171
172impl TokenBucket {
173    /// Create a new TokenBucket with the specified capacity and refill rate.
174    ///
175    /// # Arguments
176    ///
177    /// * `capacity` - Maximum number of tokens the bucket can hold (in bytes)
178    /// * `refill_rate` - Rate at which tokens are refilled (in bytes per second)
179    pub fn new(capacity: u64, refill_rate: u64) -> Self {
180        let now = Instant::now();
181        Self {
182            tokens: AtomicU64::new(capacity),
183            capacity,
184            refill_rate: AtomicU64::new(refill_rate),
185            last_refill: Arc::new(AtomicInstant::new(now)),
186            adaptive_config: Arc::new(AdaptiveConfig::default()),
187            metrics: Arc::new(RateMetrics::default()),
188            congestion_state: AtomicU8::new(CongestionState::Normal as u8),
189        }
190    }
191
192    /// Create a new adaptive TokenBucket with custom configuration.
193    ///
194    /// # Arguments
195    ///
196    /// * `capacity` - Maximum number of tokens the bucket can hold (in bytes)
197    /// * `refill_rate` - Initial rate at which tokens are refilled (in bytes per second)
198    /// * `config` - Adaptive configuration for rate adjustments
199    pub fn new_adaptive(capacity: u64, refill_rate: u64, config: AdaptiveConfig) -> Self {
200        let now = Instant::now();
201        Self {
202            tokens: AtomicU64::new(capacity),
203            capacity,
204            refill_rate: AtomicU64::new(refill_rate),
205            last_refill: Arc::new(AtomicInstant::new(now)),
206            adaptive_config: Arc::new(config),
207            metrics: Arc::new(RateMetrics::default()),
208            congestion_state: AtomicU8::new(CongestionState::Normal as u8),
209        }
210    }
211
212    /// Acquire the specified number of tokens, waiting if necessary.
213    ///
214    /// This method will block until enough tokens are available.
215    ///
216    /// # Arguments
217    ///
218    /// * `bytes` - Number of tokens (bytes) to acquire
219    pub async fn acquire(&self, bytes: usize) {
220        let tokens_needed = bytes as u64;
221        let start_time = Instant::now();
222
223        loop {
224            // Refill tokens based on elapsed time
225            self.refill();
226
227            // Try to acquire tokens
228            let current_tokens = self.tokens.load(Ordering::Relaxed);
229            if current_tokens >= tokens_needed {
230                // Successfully acquire tokens
231                if self
232                    .tokens
233                    .compare_exchange_weak(
234                        current_tokens,
235                        current_tokens - tokens_needed,
236                        Ordering::Relaxed,
237                        Ordering::Relaxed,
238                    )
239                    .is_ok()
240                {
241                    let wait_time_us = start_time.elapsed().as_micros() as u64;
242                    self.metrics.record_acquisition(tokens_needed, wait_time_us);
243
244                    return;
245                }
246                // If compare_exchange failed, retry the loop
247                continue;
248            }
249
250            // Not enough tokens, calculate wait time based on deficit
251            let deficit = tokens_needed - current_tokens;
252            let current_rate = self.refill_rate.load(Ordering::Relaxed);
253
254            // Calculate exact time needed for deficit tokens
255            let wait_time = Duration::from_secs_f64(deficit as f64 / current_rate as f64);
256
257            // Wait for tokens to be available
258            sleep(wait_time).await;
259        }
260    }
261
262    /// Try to acquire tokens without waiting.
263    ///
264    /// Returns true if tokens were acquired, false otherwise.
265    ///
266    /// # Arguments
267    ///
268    /// * `bytes` - Number of tokens (bytes) to acquire
269    pub fn try_acquire(&self, bytes: usize) -> bool {
270        let tokens_needed = bytes as u64;
271
272        // Refill tokens based on elapsed time
273        self.refill();
274
275        // Try to acquire tokens
276        let current_tokens = self.tokens.load(Ordering::Relaxed);
277        if current_tokens >= tokens_needed {
278            // Successfully acquire tokens
279            if self
280                .tokens
281                .compare_exchange_weak(
282                    current_tokens,
283                    current_tokens - tokens_needed,
284                    Ordering::Relaxed,
285                    Ordering::Relaxed,
286                )
287                .is_ok()
288            {
289                return true;
290            }
291        }
292
293        false
294    }
295
296    /// Refill tokens based on elapsed time since last refill.
297    fn refill(&self) {
298        let now = Instant::now();
299        let last_refill = self.last_refill.get();
300        let elapsed = now.duration_since(last_refill);
301
302        if elapsed.as_secs_f64() > 0.0 {
303            let current_rate = self.refill_rate.load(Ordering::Relaxed);
304            let tokens_to_add = (current_rate as f64 * elapsed.as_secs_f64()) as u64;
305            let current_tokens = self.tokens.load(Ordering::Relaxed);
306            let new_tokens = (current_tokens + tokens_to_add).min(self.capacity);
307
308            self.tokens.store(new_tokens, Ordering::Relaxed);
309            self.last_refill.set(now);
310        }
311    }
312
313    /// Get the current number of tokens in the bucket.
314    pub fn available_tokens(&self) -> u64 {
315        self.refill();
316        self.tokens.load(Ordering::Relaxed)
317    }
318
319    /// Get the current refill rate.
320    pub fn current_rate(&self) -> u64 {
321        self.refill_rate.load(Ordering::Relaxed)
322    }
323
324    /// Check network conditions and adjust the refill rate accordingly.
325    pub fn check_and_adjust_rate(&self) {
326        let config = &self.adaptive_config;
327        let current_rate = self.refill_rate.load(Ordering::Relaxed);
328        let utilization = self.metrics.get_utilization(current_rate);
329
330        // Check if we should adjust the rate based on measurement window
331        let now = std::time::SystemTime::now()
332            .duration_since(std::time::UNIX_EPOCH)
333            .unwrap()
334            .as_millis() as u64;
335        let last_measurement = self.metrics.last_measurement.load(Ordering::Relaxed);
336
337        if now - last_measurement >= config.measurement_window_ms {
338            self.adjust_rate_based_on_conditions(utilization);
339            self.metrics.reset_window();
340        }
341    }
342
343    /// Adjust the rate based on current utilization and congestion state.
344    fn adjust_rate_based_on_conditions(&self, utilization: f64) {
345        let config = &self.adaptive_config;
346        let current_rate = self.refill_rate.load(Ordering::Relaxed);
347        let current_state = self.congestion_state.load(Ordering::Relaxed);
348
349        match current_state {
350            0 => {
351                // Normal state
352                if utilization > config.congestion_threshold {
353                    // Enter congestion state
354                    self.congestion_state
355                        .store(CongestionState::Congestion as u8, Ordering::Relaxed);
356                    let new_rate = (current_rate as f64 * config.recovery_factor)
357                        .max(config.min_rate as f64) as u64;
358                    self.refill_rate.store(new_rate, Ordering::Relaxed);
359                } else if utilization < config.target_utilization {
360                    // Increase rate gradually
361                    let new_rate = (current_rate as f64 * config.growth_factor)
362                        .min(config.max_rate as f64) as u64;
363                    self.refill_rate.store(new_rate, Ordering::Relaxed);
364                }
365            }
366            1 => {
367                // Congestion state
368                if utilization < config.target_utilization {
369                    // Enter recovery state
370                    self.congestion_state
371                        .store(CongestionState::Recovery as u8, Ordering::Relaxed);
372                } else {
373                    // Stay in congestion, further reduce rate
374                    let new_rate = (current_rate as f64 * config.recovery_factor)
375                        .max(config.min_rate as f64) as u64;
376                    self.refill_rate.store(new_rate, Ordering::Relaxed);
377                }
378            }
379            2 => {
380                // Recovery state
381                if utilization < config.congestion_threshold {
382                    // Back to normal state
383                    self.congestion_state
384                        .store(CongestionState::Normal as u8, Ordering::Relaxed);
385                    let new_rate = (current_rate as f64 * config.growth_factor)
386                        .min(config.max_rate as f64) as u64;
387                    self.refill_rate.store(new_rate, Ordering::Relaxed);
388                } else {
389                    // Still congested, go back to congestion state
390                    self.congestion_state
391                        .store(CongestionState::Congestion as u8, Ordering::Relaxed);
392                }
393            }
394            _ => {}
395        }
396    }
397
398    /// Get current metrics for monitoring.
399    pub fn get_metrics(&self) -> &RateMetrics {
400        &self.metrics
401    }
402
403    /// Force a rate adjustment (useful for testing or manual control).
404    pub fn set_rate(&self, new_rate: u64) {
405        let config = &self.adaptive_config;
406        let clamped_rate = new_rate.clamp(config.min_rate, config.max_rate);
407        self.refill_rate.store(clamped_rate, Ordering::Relaxed);
408    }
409}
410
// Timing-sensitive tests: they run against the real clock, so the asserted
// bounds are tolerances rather than exact values.
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{Duration, sleep};

    /// A full bucket serves requests immediately; an empty one makes the
    /// caller wait for the deficit to refill.
    #[tokio::test]
    async fn test_token_bucket_basic() {
        // Create a bucket with 100 bytes capacity and 50 bytes/second refill rate
        let bucket = TokenBucket::new(100, 50);

        // Should be able to acquire 50 bytes immediately
        bucket.acquire(50).await;
        assert!(bucket.available_tokens() <= 50);

        // Should be able to acquire another 50 bytes immediately
        bucket.acquire(50).await;
        assert_eq!(bucket.available_tokens(), 0);

        // Short pause; at 50 bytes/second this is too brief to refill a whole token
        tokio::time::sleep(Duration::from_millis(10)).await;

        // Acquiring more should require waiting
        let start = Instant::now();
        bucket.acquire(25).await;
        let elapsed = start.elapsed();

        // Expect about 0.5 seconds (25 bytes at 50 bytes/second), within +/- 50 ms
        assert!(elapsed >= Duration::from_millis(450));
        assert!(elapsed <= Duration::from_millis(550));
    }

    /// An emptied bucket regains tokens in proportion to elapsed time.
    #[tokio::test]
    async fn test_token_bucket_refill() {
        let bucket = TokenBucket::new(100, 100);

        // Acquire all tokens
        bucket.acquire(100).await;
        assert_eq!(bucket.available_tokens(), 0);

        // Wait for refill
        sleep(Duration::from_millis(100)).await;

        // Should have some tokens available
        let available = bucket.available_tokens();
        assert!(available > 5); // Nominally 10 bytes (100 bytes/s * 0.1s); the lower bound is deliberately loose
        assert!(available <= 15); // Allow some tolerance
    }

    /// Concurrent acquisitions that together equal the capacity drain the
    /// bucket exactly.
    #[tokio::test]
    async fn test_token_bucket_concurrent() {
        let bucket = Arc::new(TokenBucket::new(1000, 100));
        let mut handles = vec![];

        // Spawn 10 concurrent tasks each trying to acquire 100 bytes
        for _ in 0..10 {
            let bucket_clone = Arc::clone(&bucket);
            let handle = tokio::spawn(async move {
                bucket_clone.acquire(100).await;
            });
            handles.push(handle);
        }

        // Wait for all tasks to complete
        for handle in handles {
            handle.await.unwrap();
        }

        // All tokens should be consumed
        assert_eq!(bucket.available_tokens(), 0);
    }

    /// `set_rate` stores the requested rate, limited to the configured
    /// minimum and maximum.
    #[tokio::test]
    async fn test_adaptive_rate_limiting() {
        let config = AdaptiveConfig {
            min_rate: 10,
            max_rate: 1000,
            target_utilization: 0.8,
            congestion_threshold: 0.9,
            recovery_factor: 0.5,
            growth_factor: 1.2,
            measurement_window_ms: 100,
        };

        let bucket = TokenBucket::new_adaptive(100, 100, config);

        // Initially should have the configured rate
        assert_eq!(bucket.current_rate(), 100);

        // Force a rate adjustment
        bucket.set_rate(500);
        assert_eq!(bucket.current_rate(), 500);

        // Test that rate is clamped to max
        bucket.set_rate(2000);
        assert_eq!(bucket.current_rate(), 1000);

        // Test that rate is clamped to min
        bucket.set_rate(5);
        assert_eq!(bucket.current_rate(), 10);
    }

    /// Utilization above the congestion threshold lowers the refill rate at
    /// the next adjustment.
    #[tokio::test]
    async fn test_congestion_detection() {
        let config = AdaptiveConfig {
            min_rate: 10,
            max_rate: 1000,
            target_utilization: 0.5,
            congestion_threshold: 0.8,
            recovery_factor: 0.5,
            growth_factor: 1.2,
            measurement_window_ms: 50,
        };

        let bucket = TokenBucket::new_adaptive(1000, 100, config);

        // Simulate high utilization: 200 bytes taken from the initial burst, against a 100 bytes/s rate
        for _ in 0..20 {
            bucket.acquire(10).await;
        }

        // Wait for measurement window
        sleep(Duration::from_millis(60)).await;

        // Manually trigger rate adjustment
        bucket.check_and_adjust_rate();

        // Acquire more tokens
        bucket.acquire(10).await;

        // Rate should have been adjusted due to congestion
        assert!(bucket.current_rate() < 100);
    }

    /// Acquisitions are reflected in the byte and acquisition counters.
    #[tokio::test]
    async fn test_metrics_collection() {
        let bucket = TokenBucket::new(100, 100);
        let metrics = bucket.get_metrics();

        // Initially no metrics
        assert_eq!(metrics.total_bytes.load(Ordering::Relaxed), 0);
        assert_eq!(metrics.acquisition_count.load(Ordering::Relaxed), 0);

        // Acquire tokens to update metrics
        bucket.acquire(50).await;

        // Metrics should now show the acquisition
        assert!(metrics.total_bytes.load(Ordering::Relaxed) > 0);
        assert!(metrics.acquisition_count.load(Ordering::Relaxed) > 0);
    }
}