cloud_disk_sync/core/
rate_limit.rs1use super::traits::RateLimiter;
2use crate::error::Result;
3use async_trait::async_trait;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::sync::{Arc, Mutex};
6use std::time::{Duration, Instant};
7use tokio::sync::Semaphore;
8
9pub struct TokenBucketRateLimiter {
11 capacity: u64,
12 tokens: AtomicU64,
13 refill_rate: f64, last_refill: parking_lot::Mutex<Instant>,
15 semaphore: Arc<Semaphore>,
16}
17
#[cfg(test)]
mod tests {
    use super::{SlidingWindowRateLimiter, TokenBucketRateLimiter};
    use crate::core::traits::RateLimiter;
    use std::time::Duration;

    /// A two-token bucket admits exactly two immediate requests, refuses the
    /// third, and a blocking `acquire` then succeeds once a token is refilled.
    #[tokio::test]
    async fn test_token_bucket_acquire() {
        let bucket = TokenBucketRateLimiter::new(2, 10.0);
        for _ in 0..2 {
            assert!(bucket.try_acquire());
        }
        assert!(!bucket.try_acquire());

        bucket.acquire().await.unwrap();
    }

    /// A one-request window refuses a second immediate request, and a blocking
    /// `acquire` succeeds after the first request leaves the 100 ms window.
    #[tokio::test]
    async fn test_sliding_window_acquire() {
        let window = SlidingWindowRateLimiter::new(Duration::from_millis(100), 1);
        let first = window.try_acquire();
        let second = window.try_acquire();
        assert!(first);
        assert!(!second);

        window.acquire().await.unwrap();
    }
}
41impl TokenBucketRateLimiter {
42 pub fn new(capacity: u64, requests_per_second: f64) -> Self {
43 Self {
44 capacity,
45 tokens: AtomicU64::new(capacity),
46 refill_rate: requests_per_second,
47 last_refill: parking_lot::Mutex::new(Instant::now()),
48 semaphore: Arc::new(Semaphore::new(capacity as usize)),
49 }
50 }
51
52 fn refill_tokens(&self) {
53 let mut last_refill = self.last_refill.lock();
54 let now = Instant::now();
55 let elapsed = now.duration_since(*last_refill);
56
57 if elapsed.as_secs_f64() > 0.0 {
58 let new_tokens = (elapsed.as_secs_f64() * self.refill_rate) as u64;
59 if new_tokens > 0 {
60 let current = self.tokens.load(Ordering::Relaxed);
61 let new_total = (current + new_tokens).min(self.capacity);
62 self.tokens.store(new_total, Ordering::Relaxed);
63 *last_refill = now;
64 }
65 }
66 }
67}
68
69#[async_trait]
70impl RateLimiter for TokenBucketRateLimiter {
71 async fn acquire(&self) -> Result<()> {
72 self.refill_tokens();
73
74 loop {
75 let current = self.tokens.load(Ordering::Relaxed);
76 if current == 0 {
77 tokio::time::sleep(Duration::from_secs_f64(1.0 / self.refill_rate)).await;
78 self.refill_tokens();
79 continue;
80 }
81
82 if self
83 .tokens
84 .compare_exchange(current, current - 1, Ordering::AcqRel, Ordering::Relaxed)
85 .is_ok()
86 {
87 break;
88 }
89 }
90
91 Ok(())
92 }
93
94 fn current_rate(&self) -> f64 {
95 self.refill_rate
96 }
97
98 fn set_rate(&mut self, requests_per_second: f64) {
99 self.refill_rate = requests_per_second;
100 }
101
102 fn try_acquire(&self) -> bool {
103 self.refill_tokens();
104
105 let current = self.tokens.load(Ordering::Relaxed);
106 if current == 0 {
107 return false;
108 }
109
110 self.tokens
111 .compare_exchange(current, current - 1, Ordering::AcqRel, Ordering::Relaxed)
112 .is_ok()
113 }
114}
115
/// Sliding-window rate limiter: limits requests to `max_requests` per trailing
/// interval of length `window_size`.
pub struct SlidingWindowRateLimiter {
    /// Timestamps of admitted requests, appended in admission order.
    pub(crate) requests: Mutex<Vec<Instant>>,
    /// Maximum number of requests admitted per window.
    pub(crate) max_requests: u64,
    /// Length of the trailing window.
    pub(crate) window_size: Duration,
}
122
123impl SlidingWindowRateLimiter {
124 pub fn new(window_size: Duration, max_requests: u64) -> Self {
125 Self {
126 window_size,
127 max_requests,
128 requests: Mutex::new(Vec::new()),
129 }
130 }
131
132 fn cleanup_old_requests(&self) {
133 let mut requests = self.requests.lock().unwrap();
134 let cutoff = Instant::now() - self.window_size;
135 requests.retain(|&time| time > cutoff);
136 }
137}
138
139#[async_trait]
140impl RateLimiter for SlidingWindowRateLimiter {
141 async fn acquire<'a>(&'a self) -> Result<()>
142 where
143 Self: 'a,
144 {
145 loop {
146 self.cleanup_old_requests();
147 let wait_time_opt = {
148 let requests = self.requests.lock().unwrap();
149 if requests.len() < self.max_requests as usize {
150 None
151 } else {
152 let oldest = *requests.first().unwrap();
153 Some(self.window_size - oldest.elapsed())
154 }
155 };
156 if let Some(wait_time) = wait_time_opt {
157 if wait_time > Duration::ZERO {
158 tokio::time::sleep(wait_time).await;
159 continue;
160 }
161 }
162 let mut requests = self.requests.lock().unwrap();
163 requests.push(Instant::now());
164 return Ok(());
165 }
166 }
167
168 fn current_rate(&self) -> f64 {
169 self.cleanup_old_requests();
170 let requests = self.requests.lock().unwrap();
171 requests.len() as f64 / self.window_size.as_secs_f64()
172 }
173
174 fn set_rate(&mut self, requests_per_second: f64) {
175 }
178
179 fn try_acquire(&self) -> bool {
180 self.cleanup_old_requests();
181 let mut requests = self.requests.lock().unwrap();
182 if requests.len() < self.max_requests as usize {
183 requests.push(Instant::now());
184 true
185 } else {
186 false
187 }
188 }
189}