cloud_lite_core_rs/
rate_limit.rs1use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::{OwnedSemaphorePermit, Semaphore};
6
/// Limits used to construct a [`RateLimiter`]: one default limit plus optional
/// per-host overrides.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Limit used for any URL whose host has no entry in `api_limits`.
    pub default_limit: usize,
    /// Per-host limits, keyed by host string (e.g. `compute.googleapis.com`).
    pub api_limits: HashMap<String, usize>,
}

impl RateLimitConfig {
    /// Creates a configuration with the given default limit and no per-host
    /// overrides.
    pub fn new(default_limit: usize) -> Self {
        let api_limits = HashMap::new();
        Self {
            default_limit,
            api_limits,
        }
    }

    /// Creates a configuration whose default limit is `usize::MAX` and which
    /// has no per-host overrides.
    pub fn disabled() -> Self {
        Self::new(usize::MAX)
    }

    /// Returns the configuration with its default limit replaced by `limit`;
    /// per-host overrides are carried over unchanged.
    pub fn with_default_limit(self, limit: usize) -> Self {
        Self {
            default_limit: limit,
            ..self
        }
    }

    /// Returns the configuration with `host` mapped to `limit`, replacing any
    /// limit previously stored for the same host.
    pub fn with_api_limit(mut self, host: &str, limit: usize) -> Self {
        self.api_limits.insert(host.to_owned(), limit);
        self
    }
}
59
/// Point-in-time view of one semaphore, as reported by `RateLimiter::stats`.
///
/// Derives `PartialEq`/`Eq` so two snapshots can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitStats {
    /// `"default"` for the fallback semaphore, otherwise the configured host.
    pub api: String,
    /// Number of permits the semaphore was created with.
    pub limit: usize,
    /// Permits that were available when the snapshot was taken.
    pub available: usize,
    /// `limit - available`, saturating at zero.
    pub in_flight: usize,
}
72
/// Returns the host portion of an `http://` or `https://` URL.
///
/// The host is everything after the scheme up to the first `/`, `?` or `#`
/// (or the end of the string), so a URL that carries a query or fragment
/// directly after the host still yields the bare host. A port, if present,
/// is kept as part of the returned slice.
///
/// Returns `None` when `url` starts with neither `https://` nor `http://`.
fn extract_host(url: &str) -> Option<&str> {
    let after_scheme = url
        .strip_prefix("https://")
        .or_else(|| url.strip_prefix("http://"))?;
    // `split` always yields at least one item, so the fallback is only a
    // formality; it keeps the whole remainder when no delimiter is present.
    let host = after_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or(after_scheme);
    Some(host)
}
80
81const MAX_SEMAPHORE_PERMITS: usize = Semaphore::MAX_PERMITS;
83
84pub struct RateLimiter {
86 default_limit: usize,
87 default_semaphore: Arc<Semaphore>,
88 api_limits: HashMap<String, usize>,
89 api_semaphores: HashMap<String, Arc<Semaphore>>,
90}
91
92impl RateLimiter {
93 pub fn new(config: RateLimitConfig) -> Self {
95 let capped_default = config.default_limit.min(MAX_SEMAPHORE_PERMITS);
96 let default_semaphore = Arc::new(Semaphore::new(capped_default));
97 let api_semaphores = config
98 .api_limits
99 .iter()
100 .map(|(host, &limit)| {
101 (
102 host.clone(),
103 Arc::new(Semaphore::new(limit.min(MAX_SEMAPHORE_PERMITS))),
104 )
105 })
106 .collect();
107 let api_limits: HashMap<String, usize> = config
108 .api_limits
109 .into_iter()
110 .map(|(host, limit)| (host, limit.min(MAX_SEMAPHORE_PERMITS)))
111 .collect();
112 Self {
113 default_limit: capped_default,
114 default_semaphore,
115 api_limits,
116 api_semaphores,
117 }
118 }
119
120 pub async fn acquire(&self, url: &str) -> OwnedSemaphorePermit {
122 let semaphore = self.semaphore_for(url);
123 semaphore
124 .acquire_owned()
125 .await
126 .expect("rate limiter semaphore closed unexpectedly")
127 }
128
129 fn semaphore_for(&self, url: &str) -> Arc<Semaphore> {
130 if let Some(host) = extract_host(url)
131 && let Some(sem) = self.api_semaphores.get(host)
132 {
133 return Arc::clone(sem);
134 }
135 Arc::clone(&self.default_semaphore)
136 }
137
138 pub fn stats(&self) -> Vec<RateLimitStats> {
140 let mut result = Vec::with_capacity(self.api_semaphores.len() + 1);
141
142 let available = self.default_semaphore.available_permits();
144 result.push(RateLimitStats {
145 api: "default".into(),
146 limit: self.default_limit,
147 available,
148 in_flight: self.default_limit.saturating_sub(available),
149 });
150
151 for (host, sem) in &self.api_semaphores {
153 let limit = self.api_limits[host];
154 let available = sem.available_permits();
155 result.push(RateLimitStats {
156 api: host.clone(),
157 limit,
158 available,
159 in_flight: limit.saturating_sub(available),
160 });
161 }
162
163 result
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Host used by the tests that configure a single per-host limit.
    const TEST_HOST: &str = "test.example.com";

    /// Returns the entry labelled `api` from a stats snapshot, panicking if it
    /// is missing.
    fn entry<'a>(snapshot: &'a [RateLimitStats], api: &str) -> &'a RateLimitStats {
        snapshot.iter().find(|s| s.api == api).unwrap()
    }

    // --- RateLimitConfig builders ---

    #[test]
    fn new_config_has_given_default() {
        let RateLimitConfig {
            default_limit,
            api_limits,
        } = RateLimitConfig::new(20);
        assert_eq!(default_limit, 20);
        assert!(api_limits.is_empty());
    }

    #[test]
    fn disabled_config_uses_usize_max() {
        let RateLimitConfig {
            default_limit,
            api_limits,
        } = RateLimitConfig::disabled();
        assert_eq!(default_limit, usize::MAX);
        assert!(api_limits.is_empty());
    }

    #[test]
    fn with_default_limit_overrides() {
        let overridden = RateLimitConfig::new(20).with_default_limit(30);
        assert_eq!(overridden.default_limit, 30);
    }

    #[test]
    fn with_api_limit_adds_entry() {
        let with_host = RateLimitConfig::new(20).with_api_limit(TEST_HOST, 5);
        assert_eq!(with_host.api_limits.get(TEST_HOST), Some(&5));
        assert_eq!(with_host.default_limit, 20);
    }

    // --- extract_host ---

    #[test]
    fn extract_host_from_standard_url() {
        let host = extract_host("https://compute.googleapis.com/compute/v1/projects/foo");
        assert_eq!(host, Some("compute.googleapis.com"));
    }

    #[test]
    fn extract_host_returns_none_for_garbage() {
        let host = extract_host("not-a-url");
        assert_eq!(host, None);
    }

    // --- RateLimiter stats on a fresh limiter ---

    #[test]
    fn rate_limiter_uses_api_specific_semaphore() {
        let limiter = RateLimiter::new(RateLimitConfig::new(20).with_api_limit(TEST_HOST, 5));
        let snapshot = limiter.stats();
        let test_api = entry(&snapshot, TEST_HOST);
        assert_eq!(test_api.limit, 5);
        assert_eq!(test_api.available, 5);
        assert_eq!(test_api.in_flight, 0);
    }

    #[test]
    fn rate_limiter_default_semaphore_in_stats() {
        let limiter = RateLimiter::new(RateLimitConfig::new(20));
        let snapshot = limiter.stats();
        let default = entry(&snapshot, "default");
        assert_eq!(default.limit, 20);
        assert_eq!(default.available, 20);
    }

    // --- acquiring and releasing permits ---

    #[tokio::test]
    async fn acquire_uses_correct_semaphore() {
        let limiter = RateLimiter::new(
            RateLimitConfig::new(100).with_api_limit("compute.googleapis.com", 2),
        );

        // Hold both permits so the host-specific semaphore is exhausted.
        let _first = limiter
            .acquire("https://compute.googleapis.com/v1/foo")
            .await;
        let _second = limiter
            .acquire("https://compute.googleapis.com/v1/bar")
            .await;

        let snapshot = limiter.stats();

        let compute = entry(&snapshot, "compute.googleapis.com");
        assert_eq!(compute.in_flight, 2);
        assert_eq!(compute.available, 0);

        // The default semaphore must be untouched by host-specific acquires.
        let default = entry(&snapshot, "default");
        assert_eq!(default.in_flight, 0);
    }

    #[tokio::test]
    async fn acquire_falls_back_to_default() {
        let limiter = RateLimiter::new(RateLimitConfig::new(3));

        let _held = limiter
            .acquire("https://unknown.googleapis.com/v1/foo")
            .await;

        let snapshot = limiter.stats();
        assert_eq!(entry(&snapshot, "default").in_flight, 1);
    }

    #[tokio::test]
    async fn permit_released_on_drop() {
        let limiter = RateLimiter::new(RateLimitConfig::new(20).with_api_limit(TEST_HOST, 1));

        let permit = limiter.acquire("https://test.example.com/v1/foo").await;
        let while_held = limiter.stats();
        assert_eq!(entry(&while_held, TEST_HOST).in_flight, 1);

        // Dropping the permit must hand it back to the semaphore.
        drop(permit);
        let after_drop = limiter.stats();
        assert_eq!(entry(&after_drop, TEST_HOST).in_flight, 0);
    }
}