cloud_disk_sync/core/
rate_limit.rs

1use super::traits::RateLimiter;
2use crate::error::Result;
3use async_trait::async_trait;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::sync::{Arc, Mutex};
6use std::time::{Duration, Instant};
7use tokio::sync::Semaphore;
8
9/// 令牌桶算法实现
10pub struct TokenBucketRateLimiter {
11    capacity: u64,
12    tokens: AtomicU64,
13    refill_rate: f64, // tokens per second
14    last_refill: parking_lot::Mutex<Instant>,
15    semaphore: Arc<Semaphore>,
16}
17
#[cfg(test)]
mod tests {
    use super::{SlidingWindowRateLimiter, TokenBucketRateLimiter};
    use crate::core::traits::RateLimiter;
    use std::time::Duration;

    #[tokio::test]
    async fn test_token_bucket_acquire() {
        let limiter = TokenBucketRateLimiter::new(2, 10.0);
        // The bucket starts full: a burst of `capacity` requests succeeds immediately.
        assert!(limiter.try_acquire());
        assert!(limiter.try_acquire());
        assert!(!limiter.try_acquire());
        // An empty bucket makes `acquire` wait for a refill instead of failing.
        limiter.acquire().await.unwrap();
    }

    #[tokio::test]
    async fn test_token_bucket_refills_over_time() {
        let limiter = TokenBucketRateLimiter::new(1, 10.0);
        assert!(limiter.try_acquire());
        // Comfortably longer than one refill period (100 ms at 10 tokens/s).
        tokio::time::sleep(Duration::from_millis(250)).await;
        assert!(limiter.try_acquire());
    }

    #[test]
    fn test_token_bucket_reports_configured_rate() {
        let limiter = TokenBucketRateLimiter::new(1, 7.5);
        assert_eq!(limiter.current_rate(), 7.5);
    }

    #[tokio::test]
    async fn test_sliding_window_acquire() {
        let limiter = SlidingWindowRateLimiter::new(Duration::from_millis(100), 1);
        assert!(limiter.try_acquire());
        assert!(!limiter.try_acquire());
        // A full window makes `acquire` wait until the oldest request ages out.
        limiter.acquire().await.unwrap();
    }

    #[tokio::test]
    async fn test_sliding_window_expires_old_requests() {
        let limiter = SlidingWindowRateLimiter::new(Duration::from_millis(100), 1);
        assert!(limiter.try_acquire());
        tokio::time::sleep(Duration::from_millis(250)).await;
        assert!(limiter.try_acquire());
    }

    #[test]
    fn test_sliding_window_current_rate() {
        let limiter = SlidingWindowRateLimiter::new(Duration::from_secs(10), 5);
        assert!(limiter.try_acquire());
        assert!(limiter.try_acquire());
        // Two requests recorded in a 10 s window.
        assert_eq!(limiter.current_rate(), 0.2);
    }

    #[test]
    fn test_sliding_window_zero_capacity_rejects() {
        let limiter = SlidingWindowRateLimiter::new(Duration::from_secs(1), 0);
        assert!(!limiter.try_acquire());
    }
}
41impl TokenBucketRateLimiter {
42    pub fn new(capacity: u64, requests_per_second: f64) -> Self {
43        Self {
44            capacity,
45            tokens: AtomicU64::new(capacity),
46            refill_rate: requests_per_second,
47            last_refill: parking_lot::Mutex::new(Instant::now()),
48            semaphore: Arc::new(Semaphore::new(capacity as usize)),
49        }
50    }
51
52    fn refill_tokens(&self) {
53        let mut last_refill = self.last_refill.lock();
54        let now = Instant::now();
55        let elapsed = now.duration_since(*last_refill);
56
57        if elapsed.as_secs_f64() > 0.0 {
58            let new_tokens = (elapsed.as_secs_f64() * self.refill_rate) as u64;
59            if new_tokens > 0 {
60                let current = self.tokens.load(Ordering::Relaxed);
61                let new_total = (current + new_tokens).min(self.capacity);
62                self.tokens.store(new_total, Ordering::Relaxed);
63                *last_refill = now;
64            }
65        }
66    }
67}
68
69#[async_trait]
70impl RateLimiter for TokenBucketRateLimiter {
71    async fn acquire(&self) -> Result<()> {
72        self.refill_tokens();
73
74        loop {
75            let current = self.tokens.load(Ordering::Relaxed);
76            if current == 0 {
77                tokio::time::sleep(Duration::from_secs_f64(1.0 / self.refill_rate)).await;
78                self.refill_tokens();
79                continue;
80            }
81
82            if self
83                .tokens
84                .compare_exchange(current, current - 1, Ordering::AcqRel, Ordering::Relaxed)
85                .is_ok()
86            {
87                break;
88            }
89        }
90
91        Ok(())
92    }
93
94    fn current_rate(&self) -> f64 {
95        self.refill_rate
96    }
97
98    fn set_rate(&mut self, requests_per_second: f64) {
99        self.refill_rate = requests_per_second;
100    }
101
102    fn try_acquire(&self) -> bool {
103        self.refill_tokens();
104
105        let current = self.tokens.load(Ordering::Relaxed);
106        if current == 0 {
107            return false;
108        }
109
110        self.tokens
111            .compare_exchange(current, current - 1, Ordering::AcqRel, Ordering::Relaxed)
112            .is_ok()
113    }
114}
115
/// Sliding-window rate limiter.
///
/// Keeps the timestamps of admitted requests and allows up to `max_requests` of them
/// per `window_size`.
pub struct SlidingWindowRateLimiter {
    /// Length of the sliding window.
    pub(crate) window_size: Duration,
    /// Maximum number of requests allowed within one window.
    pub(crate) max_requests: u64,
    /// Timestamps of recorded requests, appended in arrival order.
    pub(crate) requests: Mutex<Vec<Instant>>,
}
122
123impl SlidingWindowRateLimiter {
124    pub fn new(window_size: Duration, max_requests: u64) -> Self {
125        Self {
126            window_size,
127            max_requests,
128            requests: Mutex::new(Vec::new()),
129        }
130    }
131
132    fn cleanup_old_requests(&self) {
133        let mut requests = self.requests.lock().unwrap();
134        let cutoff = Instant::now() - self.window_size;
135        requests.retain(|&time| time > cutoff);
136    }
137}
138
139#[async_trait]
140impl RateLimiter for SlidingWindowRateLimiter {
141    async fn acquire<'a>(&'a self) -> Result<()>
142    where
143        Self: 'a,
144    {
145        loop {
146            self.cleanup_old_requests();
147            let wait_time_opt = {
148                let requests = self.requests.lock().unwrap();
149                if requests.len() < self.max_requests as usize {
150                    None
151                } else {
152                    let oldest = *requests.first().unwrap();
153                    Some(self.window_size - oldest.elapsed())
154                }
155            };
156            if let Some(wait_time) = wait_time_opt {
157                if wait_time > Duration::ZERO {
158                    tokio::time::sleep(wait_time).await;
159                    continue;
160                }
161            }
162            let mut requests = self.requests.lock().unwrap();
163            requests.push(Instant::now());
164            return Ok(());
165        }
166    }
167
168    fn current_rate(&self) -> f64 {
169        self.cleanup_old_requests();
170        let requests = self.requests.lock().unwrap();
171        requests.len() as f64 / self.window_size.as_secs_f64()
172    }
173
174    fn set_rate(&mut self, requests_per_second: f64) {
175        // 调整窗口大小或最大请求数
176        // 这里简化处理
177    }
178
179    fn try_acquire(&self) -> bool {
180        self.cleanup_old_requests();
181        let mut requests = self.requests.lock().unwrap();
182        if requests.len() < self.max_requests as usize {
183            requests.push(Instant::now());
184            true
185        } else {
186            false
187        }
188    }
189}