Skip to main content

cloud_lite_core_rs/
rate_limit.rs

1//! Rate limiting configuration and runtime for API requests.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::{OwnedSemaphorePermit, Semaphore};
6
/// Configuration for per-API concurrency limiting.
///
/// Holds a fallback concurrency limit plus optional per-host overrides.
/// Provider crates supply their own defaults (e.g. GCP API quotas).
///
/// The type implements `PartialEq`/`Eq`, so two configurations can be
/// compared by value (handy in tests and when detecting config changes).
///
/// # Example
///
/// ```rust
/// use cloud_lite_core::rate_limit::RateLimitConfig;
///
/// let config = RateLimitConfig::new(20)
///     .with_api_limit("api.example.com", 10);
///
/// let disabled = RateLimitConfig::disabled();
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Concurrency limit for APIs not in `api_limits`.
    pub default_limit: usize,
    /// Per-API concurrency limits keyed by host (e.g. "compute.googleapis.com").
    pub api_limits: HashMap<String, usize>,
}
29
30impl RateLimitConfig {
31    /// Create a new config with the given default concurrency limit.
32    pub fn new(default_limit: usize) -> Self {
33        Self {
34            default_limit,
35            api_limits: HashMap::new(),
36        }
37    }
38
39    /// Create a config that effectively disables rate limiting.
40    pub fn disabled() -> Self {
41        Self {
42            default_limit: usize::MAX,
43            api_limits: HashMap::new(),
44        }
45    }
46
47    /// Override the default concurrency limit for unknown APIs.
48    pub fn with_default_limit(mut self, limit: usize) -> Self {
49        self.default_limit = limit;
50        self
51    }
52
53    /// Set or override the concurrency limit for a specific API host.
54    pub fn with_api_limit(mut self, host: &str, limit: usize) -> Self {
55        self.api_limits.insert(host.to_string(), limit);
56        self
57    }
58}
59
/// Snapshot of rate limiting state for a single API.
///
/// Values are captured at the moment the snapshot is taken and are not
/// updated afterwards. Implements `PartialEq`/`Eq` so snapshots can be
/// compared by value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitStats {
    /// API host name, or "default" for the fallback semaphore.
    pub api: String,
    /// Configured concurrency limit.
    pub limit: usize,
    /// Permits currently available.
    pub available: usize,
    /// Requests currently in flight (`limit - available`).
    pub in_flight: usize,
}
72
/// Extract the host from an `http://` or `https://` URL
/// (e.g. "https://compute.googleapis.com/v1/..." -> "compute.googleapis.com").
///
/// The host is the text between the scheme and the first `/`, `?` or `#`,
/// so a URL without a path ("https://host?x=1") still yields just the host.
/// A port or userinfo component, if present, is kept as part of the result.
///
/// Returns `None` when the URL does not start with a supported scheme or
/// when the host part is empty.
fn extract_host(url: &str) -> Option<&str> {
    let after_scheme = url
        .strip_prefix("https://")
        .or_else(|| url.strip_prefix("http://"))?;
    // Path, query and fragment each terminate the authority component.
    let end = after_scheme
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(after_scheme.len());
    let host = &after_scheme[..end];
    (!host.is_empty()).then_some(host)
}
80
/// Maximum permits tokio allows on a semaphore (`usize::MAX >> 3`).
///
/// Configured limits are capped to this value before being passed to
/// `Semaphore::new`.
const MAX_SEMAPHORE_PERMITS: usize = Semaphore::MAX_PERMITS;
83
84/// Semaphore-based per-API concurrency limiter.
85pub struct RateLimiter {
86    default_limit: usize,
87    default_semaphore: Arc<Semaphore>,
88    api_limits: HashMap<String, usize>,
89    api_semaphores: HashMap<String, Arc<Semaphore>>,
90}
91
92impl RateLimiter {
93    /// Create a new rate limiter from the given configuration.
94    pub fn new(config: RateLimitConfig) -> Self {
95        let capped_default = config.default_limit.min(MAX_SEMAPHORE_PERMITS);
96        let default_semaphore = Arc::new(Semaphore::new(capped_default));
97        let api_semaphores = config
98            .api_limits
99            .iter()
100            .map(|(host, &limit)| {
101                (
102                    host.clone(),
103                    Arc::new(Semaphore::new(limit.min(MAX_SEMAPHORE_PERMITS))),
104                )
105            })
106            .collect();
107        let api_limits: HashMap<String, usize> = config
108            .api_limits
109            .into_iter()
110            .map(|(host, limit)| (host, limit.min(MAX_SEMAPHORE_PERMITS)))
111            .collect();
112        Self {
113            default_limit: capped_default,
114            default_semaphore,
115            api_limits,
116            api_semaphores,
117        }
118    }
119
120    /// Acquire a permit for the given URL, blocking until one is available.
121    pub async fn acquire(&self, url: &str) -> OwnedSemaphorePermit {
122        let semaphore = self.semaphore_for(url);
123        semaphore
124            .acquire_owned()
125            .await
126            .expect("rate limiter semaphore closed unexpectedly")
127    }
128
129    fn semaphore_for(&self, url: &str) -> Arc<Semaphore> {
130        if let Some(host) = extract_host(url)
131            && let Some(sem) = self.api_semaphores.get(host)
132        {
133            return Arc::clone(sem);
134        }
135        Arc::clone(&self.default_semaphore)
136    }
137
138    /// Get a snapshot of current rate limiting state.
139    pub fn stats(&self) -> Vec<RateLimitStats> {
140        let mut result = Vec::with_capacity(self.api_semaphores.len() + 1);
141
142        // Default
143        let available = self.default_semaphore.available_permits();
144        result.push(RateLimitStats {
145            api: "default".into(),
146            limit: self.default_limit,
147            available,
148            in_flight: self.default_limit.saturating_sub(available),
149        });
150
151        // Per-API
152        for (host, sem) in &self.api_semaphores {
153            let limit = self.api_limits[host];
154            let available = sem.available_permits();
155            result.push(RateLimitStats {
156                api: host.clone(),
157                limit,
158                available,
159                in_flight: limit.saturating_sub(available),
160            });
161        }
162
163        result
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Fetch the stats entry named `api`, panicking if the limiter reports none.
    fn stats_for(limiter: &RateLimiter, api: &str) -> RateLimitStats {
        limiter
            .stats()
            .into_iter()
            .find(|entry| entry.api == api)
            .unwrap()
    }

    // --- RateLimitConfig -------------------------------------------------

    #[test]
    fn new_config_has_given_default() {
        let RateLimitConfig {
            default_limit,
            api_limits,
        } = RateLimitConfig::new(20);
        assert_eq!(default_limit, 20);
        assert!(api_limits.is_empty());
    }

    #[test]
    fn disabled_config_uses_usize_max() {
        let RateLimitConfig {
            default_limit,
            api_limits,
        } = RateLimitConfig::disabled();
        assert_eq!(default_limit, usize::MAX);
        assert!(api_limits.is_empty());
    }

    #[test]
    fn with_default_limit_overrides() {
        let overridden = RateLimitConfig::new(20).with_default_limit(30);
        assert_eq!(overridden.default_limit, 30);
    }

    #[test]
    fn with_api_limit_adds_entry() {
        let cfg = RateLimitConfig::new(20).with_api_limit("test.example.com", 5);
        assert_eq!(cfg.api_limits.get("test.example.com"), Some(&5));
        assert_eq!(cfg.default_limit, 20);
    }

    // --- extract_host ----------------------------------------------------

    #[test]
    fn extract_host_from_standard_url() {
        let host = extract_host("https://compute.googleapis.com/compute/v1/projects/foo");
        assert_eq!(host, Some("compute.googleapis.com"));
    }

    #[test]
    fn extract_host_returns_none_for_garbage() {
        let host = extract_host("not-a-url");
        assert_eq!(host, None);
    }

    // --- RateLimiter: initial state --------------------------------------

    #[test]
    fn rate_limiter_uses_api_specific_semaphore() {
        let limiter =
            RateLimiter::new(RateLimitConfig::new(20).with_api_limit("test.example.com", 5));

        let entry = stats_for(&limiter, "test.example.com");
        assert_eq!(entry.limit, 5);
        assert_eq!(entry.available, 5);
        assert_eq!(entry.in_flight, 0);
    }

    #[test]
    fn rate_limiter_default_semaphore_in_stats() {
        let limiter = RateLimiter::new(RateLimitConfig::new(20));

        let entry = stats_for(&limiter, "default");
        assert_eq!(entry.limit, 20);
        assert_eq!(entry.available, 20);
    }

    // --- RateLimiter: acquiring permits ----------------------------------

    #[tokio::test]
    async fn acquire_uses_correct_semaphore() {
        let limiter = RateLimiter::new(
            RateLimitConfig::new(100).with_api_limit("compute.googleapis.com", 2),
        );

        // Hold both permits of the per-API semaphore for the rest of the test.
        let _first = limiter
            .acquire("https://compute.googleapis.com/v1/foo")
            .await;
        let _second = limiter
            .acquire("https://compute.googleapis.com/v1/bar")
            .await;

        let compute = stats_for(&limiter, "compute.googleapis.com");
        assert_eq!(compute.in_flight, 2);
        assert_eq!(compute.available, 0);

        // The fallback semaphore must be untouched.
        let fallback = stats_for(&limiter, "default");
        assert_eq!(fallback.in_flight, 0);
    }

    #[tokio::test]
    async fn acquire_falls_back_to_default() {
        let limiter = RateLimiter::new(RateLimitConfig::new(3));

        let _held = limiter
            .acquire("https://unknown.googleapis.com/v1/foo")
            .await;

        let fallback = stats_for(&limiter, "default");
        assert_eq!(fallback.in_flight, 1);
    }

    #[tokio::test]
    async fn permit_released_on_drop() {
        let limiter =
            RateLimiter::new(RateLimitConfig::new(20).with_api_limit("test.example.com", 1));

        let held = limiter.acquire("https://test.example.com/v1/foo").await;
        assert_eq!(stats_for(&limiter, "test.example.com").in_flight, 1);

        // Releasing the permit must return it to the semaphore.
        drop(held);
        assert_eq!(stats_for(&limiter, "test.example.com").in_flight, 0);
    }
}