// cloud-lite-core-rs 0.1.1
//
// Shared utilities for cloud-lite provider crates.
// Documentation
//! Rate limiting configuration and runtime for API requests.

use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};

/// Configuration for per-API concurrency limiting.
///
/// Holds a fallback concurrency limit plus optional per-host overrides.
/// Provider crates supply their own defaults (e.g. GCP API quotas).
///
/// Two configs compare equal when their default limit and their full set of
/// per-host overrides are equal.
///
/// # Example
///
/// ```rust
/// use cloud_lite_core::rate_limit::RateLimitConfig;
///
/// let config = RateLimitConfig::new(20)
///     .with_api_limit("api.example.com", 10);
///
/// let disabled = RateLimitConfig::disabled();
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Concurrency limit for APIs not in `api_limits`.
    pub default_limit: usize,
    /// Per-API concurrency limits keyed by host (e.g. "compute.googleapis.com").
    pub api_limits: HashMap<String, usize>,
}

impl RateLimitConfig {
    /// Create a config with the given default concurrency limit and no
    /// per-host overrides.
    pub fn new(default_limit: usize) -> Self {
        Self {
            default_limit,
            api_limits: HashMap::new(),
        }
    }

    /// Create a config that effectively disables rate limiting.
    ///
    /// The default limit is set to `usize::MAX` and no per-host overrides are
    /// registered.
    pub fn disabled() -> Self {
        Self::new(usize::MAX)
    }

    /// Override the default concurrency limit for unknown APIs.
    ///
    /// Consumes and returns `self` so calls can be chained.
    pub fn with_default_limit(mut self, limit: usize) -> Self {
        self.default_limit = limit;
        self
    }

    /// Set or override the concurrency limit for a specific API host.
    ///
    /// Calling this again with the same `host` replaces the earlier limit.
    /// Consumes and returns `self` so calls can be chained.
    pub fn with_api_limit(mut self, host: &str, limit: usize) -> Self {
        self.api_limits.insert(host.to_string(), limit);
        self
    }
}

/// Snapshot of rate limiting state for a single API.
///
/// Values are captured at the moment the snapshot is taken and are not
/// updated afterwards. Snapshots compare field by field, which makes them
/// convenient to assert on.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitStats {
    /// API host name, or "default" for the fallback semaphore.
    pub api: String,
    /// Configured concurrency limit.
    pub limit: usize,
    /// Permits currently available.
    pub available: usize,
    /// Requests currently in flight (`limit - available`).
    pub in_flight: usize,
}

/// Extract the host from a URL (e.g. "https://compute.googleapis.com/v1/..." -> "compute.googleapis.com").
///
/// Returns `None` when `url` does not start with `https://` or `http://`
/// (the scheme match is case-sensitive).
///
/// The returned slice is everything between the scheme and the first `/`,
/// `?`, or `#`, so a URL without a path but with a query string or fragment
/// ("https://api.example.com?x=1") still yields just the host. An explicit
/// port or userinfo, if present, is left in the returned value unchanged.
fn extract_host(url: &str) -> Option<&str> {
    let after_scheme = url
        .strip_prefix("https://")
        .or_else(|| url.strip_prefix("http://"))?;
    // The authority ends at the first path, query, or fragment delimiter,
    // whichever comes first; with none present the whole remainder is the host.
    let end = after_scheme
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(after_scheme.len());
    // All three delimiters are ASCII, so `end` is always a char boundary.
    Some(&after_scheme[..end])
}

/// Maximum permits tokio allows on a semaphore (`usize::MAX >> 3`).
///
/// Configured limits are capped to this value before a semaphore is created,
/// so oversized limits such as the `usize::MAX` used by
/// `RateLimitConfig::disabled()` still produce a valid semaphore size.
const MAX_SEMAPHORE_PERMITS: usize = Semaphore::MAX_PERMITS;

/// Semaphore-based per-API concurrency limiter.
///
/// Each configured host gets its own semaphore; requests to any other host
/// share a single default semaphore.
#[derive(Debug)]
pub struct RateLimiter {
    /// Effective permit count of `default_semaphore`, kept for stats reporting.
    default_limit: usize,
    /// Semaphore shared by every URL that has no host-specific entry.
    default_semaphore: Arc<Semaphore>,
    /// Effective permit count per configured host, kept for stats reporting.
    api_limits: HashMap<String, usize>,
    /// One semaphore per configured host, keyed like `api_limits`.
    api_semaphores: HashMap<String, Arc<Semaphore>>,
}

impl RateLimiter {
    /// Create a new rate limiter from the given configuration.
    ///
    /// Every configured limit is clamped to `1..=MAX_SEMAPHORE_PERMITS`
    /// before its semaphore is created:
    ///
    /// - the upper bound keeps oversized limits (such as the `usize::MAX`
    ///   from `RateLimitConfig::disabled()`) within what tokio accepts;
    /// - the lower bound prevents a limit of `0` from producing a semaphore
    ///   that never hands out a permit, which would make every `acquire`
    ///   on it wait forever.
    ///
    /// The clamped values are the ones reported by [`RateLimiter::stats`].
    pub fn new(config: RateLimitConfig) -> Self {
        let default_limit = Self::effective_limit(config.default_limit);
        let default_semaphore = Arc::new(Semaphore::new(default_limit));

        // Build both per-host maps in one pass so they always share the
        // exact same key set and the same clamped limit.
        let host_count = config.api_limits.len();
        let mut api_limits = HashMap::with_capacity(host_count);
        let mut api_semaphores = HashMap::with_capacity(host_count);
        for (host, limit) in config.api_limits {
            let limit = Self::effective_limit(limit);
            api_semaphores.insert(host.clone(), Arc::new(Semaphore::new(limit)));
            api_limits.insert(host, limit);
        }

        Self {
            default_limit,
            default_semaphore,
            api_limits,
            api_semaphores,
        }
    }

    /// Clamp a configured limit to a permit count that is both usable
    /// (at least one permit) and accepted by the semaphore.
    fn effective_limit(limit: usize) -> usize {
        limit.clamp(1, MAX_SEMAPHORE_PERMITS)
    }

    /// Acquire a permit for the given URL, waiting asynchronously until one
    /// is available.
    ///
    /// The permit is taken from the semaphore of the URL's host when that
    /// host has a configured limit, and from the default semaphore otherwise.
    /// Dropping the returned permit releases it.
    ///
    /// # Panics
    ///
    /// Panics if the underlying semaphore has been closed. This type never
    /// closes its semaphores, so that would indicate a broken invariant.
    pub async fn acquire(&self, url: &str) -> OwnedSemaphorePermit {
        self.semaphore_for(url)
            .acquire_owned()
            .await
            .expect("rate limiter semaphore closed unexpectedly")
    }

    /// Pick the semaphore governing `url`: the host-specific one when the
    /// URL's host has a configured limit, otherwise the default.
    fn semaphore_for(&self, url: &str) -> Arc<Semaphore> {
        extract_host(url)
            .and_then(|host| self.api_semaphores.get(host))
            .map(Arc::clone)
            .unwrap_or_else(|| Arc::clone(&self.default_semaphore))
    }

    /// Get a snapshot of current rate limiting state.
    ///
    /// The first entry is always the fallback semaphore, reported under the
    /// name "default"; it is followed by one entry per configured host in
    /// unspecified order.
    pub fn stats(&self) -> Vec<RateLimitStats> {
        let mut result = Vec::with_capacity(self.api_semaphores.len() + 1);

        result.push(Self::snapshot(
            "default",
            self.default_limit,
            &self.default_semaphore,
        ));

        // `new` fills both maps from the same keys, so the limit lookup
        // always succeeds; `filter_map` merely avoids a panicking index.
        result.extend(self.api_semaphores.iter().filter_map(|(host, semaphore)| {
            let limit = *self.api_limits.get(host)?;
            Some(Self::snapshot(host, limit, semaphore))
        }));

        result
    }

    /// Build the stats entry for one semaphore from its configured limit and
    /// its currently available permits.
    fn snapshot(api: &str, limit: usize, semaphore: &Semaphore) -> RateLimitStats {
        let available = semaphore.available_permits();
        RateLimitStats {
            api: api.to_string(),
            limit,
            available,
            in_flight: limit.saturating_sub(available),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Host used by the tests that configure a dedicated per-API limit.
    const TEST_HOST: &str = "test.example.com";

    /// Find the entry for `api` in a stats snapshot, panicking if it is absent.
    fn stat_for<'a>(stats: &'a [RateLimitStats], api: &str) -> &'a RateLimitStats {
        stats.iter().find(|entry| entry.api == api).unwrap()
    }

    // --- RateLimitConfig builders -------------------------------------------

    #[test]
    fn new_config_has_given_default() {
        let cfg = RateLimitConfig::new(20);
        assert_eq!(cfg.default_limit, 20);
        assert!(cfg.api_limits.is_empty());
    }

    #[test]
    fn disabled_config_uses_usize_max() {
        let cfg = RateLimitConfig::disabled();
        assert_eq!(cfg.default_limit, usize::MAX);
        assert!(cfg.api_limits.is_empty());
    }

    #[test]
    fn with_default_limit_overrides() {
        let cfg = RateLimitConfig::new(20).with_default_limit(30);
        assert_eq!(cfg.default_limit, 30);
    }

    #[test]
    fn with_api_limit_adds_entry() {
        let cfg = RateLimitConfig::new(20).with_api_limit(TEST_HOST, 5);
        assert_eq!(cfg.api_limits.get(TEST_HOST), Some(&5));
        assert_eq!(cfg.default_limit, 20);
    }

    // --- extract_host --------------------------------------------------------

    #[test]
    fn extract_host_from_standard_url() {
        let url = "https://compute.googleapis.com/compute/v1/projects/foo";
        assert_eq!(extract_host(url), Some("compute.googleapis.com"));
    }

    #[test]
    fn extract_host_returns_none_for_garbage() {
        assert_eq!(extract_host("not-a-url"), None);
    }

    // --- RateLimiter stats on a fresh limiter --------------------------------

    #[test]
    fn rate_limiter_uses_api_specific_semaphore() {
        let limiter = RateLimiter::new(RateLimitConfig::new(20).with_api_limit(TEST_HOST, 5));
        let stats = limiter.stats();
        let entry = stat_for(&stats, TEST_HOST);
        assert_eq!(entry.limit, 5);
        assert_eq!(entry.available, 5);
        assert_eq!(entry.in_flight, 0);
    }

    #[test]
    fn rate_limiter_default_semaphore_in_stats() {
        let limiter = RateLimiter::new(RateLimitConfig::new(20));
        let stats = limiter.stats();
        let fallback = stat_for(&stats, "default");
        assert_eq!(fallback.limit, 20);
        assert_eq!(fallback.available, 20);
    }

    // --- acquire / release ---------------------------------------------------

    #[tokio::test]
    async fn acquire_uses_correct_semaphore() {
        let limiter = RateLimiter::new(
            RateLimitConfig::new(100).with_api_limit("compute.googleapis.com", 2),
        );

        // Hold both permits of the host-specific semaphore.
        let _first = limiter
            .acquire("https://compute.googleapis.com/v1/foo")
            .await;
        let _second = limiter
            .acquire("https://compute.googleapis.com/v1/bar")
            .await;

        let stats = limiter.stats();

        let compute = stat_for(&stats, "compute.googleapis.com");
        assert_eq!(compute.in_flight, 2);
        assert_eq!(compute.available, 0);

        // The fallback semaphore must be untouched.
        let fallback = stat_for(&stats, "default");
        assert_eq!(fallback.in_flight, 0);
    }

    #[tokio::test]
    async fn acquire_falls_back_to_default() {
        let limiter = RateLimiter::new(RateLimitConfig::new(3));

        let _held = limiter
            .acquire("https://unknown.googleapis.com/v1/foo")
            .await;

        let stats = limiter.stats();
        assert_eq!(stat_for(&stats, "default").in_flight, 1);
    }

    #[tokio::test]
    async fn permit_released_on_drop() {
        let limiter = RateLimiter::new(RateLimitConfig::new(20).with_api_limit(TEST_HOST, 1));

        {
            let _held = limiter.acquire("https://test.example.com/v1/foo").await;
            let while_held = limiter.stats();
            assert_eq!(stat_for(&while_held, TEST_HOST).in_flight, 1);
        }
        // Leaving the scope dropped the permit.

        let after_drop = limiter.stats();
        assert_eq!(stat_for(&after_drop, TEST_HOST).in_flight, 0);
    }
}