use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
/// Concurrency limits for outgoing requests, keyed by URL host.
///
/// A request whose URL host matches a key in `api_limits` exactly (case-sensitive) is governed by
/// that entry; every other request shares `default_limit`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Maximum concurrent requests for hosts without a dedicated entry.
    pub default_limit: usize,
    /// Per-host maximum concurrent requests; keys are bare hosts such as
    /// `"compute.googleapis.com"` (no scheme, no path).
    pub api_limits: HashMap<String, usize>,
}

impl RateLimitConfig {
    /// Creates a config with the given default limit and no per-host overrides.
    #[must_use]
    pub fn new(default_limit: usize) -> Self {
        Self {
            default_limit,
            api_limits: HashMap::new(),
        }
    }

    /// Creates a config whose default limit is `usize::MAX`, i.e. effectively unlimited.
    ///
    /// Consumers that need a bounded permit count are expected to clamp this value.
    #[must_use]
    pub fn disabled() -> Self {
        Self::new(usize::MAX)
    }

    /// Replaces the default limit, returning the updated config.
    #[must_use]
    pub fn with_default_limit(mut self, limit: usize) -> Self {
        self.default_limit = limit;
        self
    }

    /// Sets (or overwrites) the limit for `host`, returning the updated config.
    #[must_use]
    pub fn with_api_limit(mut self, host: &str, limit: usize) -> Self {
        self.api_limits.insert(host.to_string(), limit);
        self
    }
}
/// Point-in-time usage snapshot for one rate-limit bucket.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitStats {
    /// Bucket label: a host name, or `"default"` for the shared fallback bucket.
    pub api: String,
    /// Maximum concurrent requests allowed in this bucket.
    pub limit: usize,
    /// Permits currently free.
    pub available: usize,
    /// Permits currently held (`limit - available`, saturating at zero).
    pub in_flight: usize,
}
/// Returns the authority (host) portion of an `http://` or `https://` URL.
///
/// Per RFC 3986 the authority ends at the first `/`, `?` or `#`, so
/// `"https://api.example.com?x=1"` yields `"api.example.com"`. Returns `None` when the URL
/// has any other scheme (or none), or when the authority is empty (e.g. `"https:///path"`).
fn extract_host(url: &str) -> Option<&str> {
    let after_scheme = url
        .strip_prefix("https://")
        .or_else(|| url.strip_prefix("http://"))?;
    let end = after_scheme
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(after_scheme.len());
    let host = &after_scheme[..end];
    (!host.is_empty()).then_some(host)
}
/// Largest permit count a tokio [`Semaphore`] can hold.
///
/// Every configured limit is clamped to this value before a semaphore is built (see
/// `RateLimiter::new`), so very large limits such as the `usize::MAX` produced by
/// `RateLimitConfig::disabled` stay within what tokio accepts.
const MAX_SEMAPHORE_PERMITS: usize = Semaphore::MAX_PERMITS;
pub struct RateLimiter {
default_limit: usize,
default_semaphore: Arc<Semaphore>,
api_limits: HashMap<String, usize>,
api_semaphores: HashMap<String, Arc<Semaphore>>,
}
impl RateLimiter {
pub fn new(config: RateLimitConfig) -> Self {
let capped_default = config.default_limit.min(MAX_SEMAPHORE_PERMITS);
let default_semaphore = Arc::new(Semaphore::new(capped_default));
let api_semaphores = config
.api_limits
.iter()
.map(|(host, &limit)| {
(
host.clone(),
Arc::new(Semaphore::new(limit.min(MAX_SEMAPHORE_PERMITS))),
)
})
.collect();
let api_limits: HashMap<String, usize> = config
.api_limits
.into_iter()
.map(|(host, limit)| (host, limit.min(MAX_SEMAPHORE_PERMITS)))
.collect();
Self {
default_limit: capped_default,
default_semaphore,
api_limits,
api_semaphores,
}
}
pub async fn acquire(&self, url: &str) -> OwnedSemaphorePermit {
let semaphore = self.semaphore_for(url);
semaphore
.acquire_owned()
.await
.expect("rate limiter semaphore closed unexpectedly")
}
fn semaphore_for(&self, url: &str) -> Arc<Semaphore> {
if let Some(host) = extract_host(url)
&& let Some(sem) = self.api_semaphores.get(host)
{
return Arc::clone(sem);
}
Arc::clone(&self.default_semaphore)
}
pub fn stats(&self) -> Vec<RateLimitStats> {
let mut result = Vec::with_capacity(self.api_semaphores.len() + 1);
let available = self.default_semaphore.available_permits();
result.push(RateLimitStats {
api: "default".into(),
limit: self.default_limit,
available,
in_flight: self.default_limit.saturating_sub(available),
});
for (host, sem) in &self.api_semaphores {
let limit = self.api_limits[host];
let available = sem.available_permits();
result.push(RateLimitStats {
api: host.clone(),
limit,
available,
in_flight: limit.saturating_sub(available),
});
}
result
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Finds the stats entry labelled `api`, panicking if it is missing.
    fn stat_for<'a>(stats: &'a [RateLimitStats], api: &str) -> &'a RateLimitStats {
        stats.iter().find(|entry| entry.api == api).unwrap()
    }

    #[test]
    fn new_config_has_given_default() {
        let cfg = RateLimitConfig::new(20);
        assert_eq!(cfg.default_limit, 20);
        assert!(cfg.api_limits.is_empty());
    }

    #[test]
    fn disabled_config_uses_usize_max() {
        let cfg = RateLimitConfig::disabled();
        assert_eq!(cfg.default_limit, usize::MAX);
        assert!(cfg.api_limits.is_empty());
    }

    #[test]
    fn with_default_limit_overrides() {
        assert_eq!(
            RateLimitConfig::new(20).with_default_limit(30).default_limit,
            30
        );
    }

    #[test]
    fn with_api_limit_adds_entry() {
        let cfg = RateLimitConfig::new(20).with_api_limit("test.example.com", 5);
        assert_eq!(cfg.api_limits.get("test.example.com"), Some(&5));
        assert_eq!(cfg.default_limit, 20);
    }

    #[test]
    fn extract_host_from_standard_url() {
        let url = "https://compute.googleapis.com/compute/v1/projects/foo";
        assert_eq!(extract_host(url), Some("compute.googleapis.com"));
    }

    #[test]
    fn extract_host_returns_none_for_garbage() {
        assert!(extract_host("not-a-url").is_none());
    }

    #[test]
    fn rate_limiter_uses_api_specific_semaphore() {
        let limiter =
            RateLimiter::new(RateLimitConfig::new(20).with_api_limit("test.example.com", 5));
        let stats = limiter.stats();
        let entry = stat_for(&stats, "test.example.com");
        assert_eq!((entry.limit, entry.available, entry.in_flight), (5, 5, 0));
    }

    #[test]
    fn rate_limiter_default_semaphore_in_stats() {
        let limiter = RateLimiter::new(RateLimitConfig::new(20));
        let stats = limiter.stats();
        let entry = stat_for(&stats, "default");
        assert_eq!((entry.limit, entry.available), (20, 20));
    }

    #[tokio::test]
    async fn acquire_uses_correct_semaphore() {
        let limiter = RateLimiter::new(
            RateLimitConfig::new(100).with_api_limit("compute.googleapis.com", 2),
        );
        let _first = limiter.acquire("https://compute.googleapis.com/v1/foo").await;
        let _second = limiter.acquire("https://compute.googleapis.com/v1/bar").await;

        let stats = limiter.stats();
        let compute = stat_for(&stats, "compute.googleapis.com");
        assert_eq!(compute.in_flight, 2);
        assert_eq!(compute.available, 0);
        assert_eq!(stat_for(&stats, "default").in_flight, 0);
    }

    #[tokio::test]
    async fn acquire_falls_back_to_default() {
        let limiter = RateLimiter::new(RateLimitConfig::new(3));
        let _permit = limiter.acquire("https://unknown.googleapis.com/v1/foo").await;
        let stats = limiter.stats();
        assert_eq!(stat_for(&stats, "default").in_flight, 1);
    }

    #[tokio::test]
    async fn permit_released_on_drop() {
        let limiter =
            RateLimiter::new(RateLimitConfig::new(20).with_api_limit("test.example.com", 1));

        let permit = limiter.acquire("https://test.example.com/v1/foo").await;
        let held = limiter.stats();
        assert_eq!(stat_for(&held, "test.example.com").in_flight, 1);

        drop(permit);
        let released = limiter.stats();
        assert_eq!(stat_for(&released, "test.example.com").in_flight, 0);
    }
}