// halldyll-core 0.1.0
//
// Core scraping engine for Halldyll - high-performance async web scraper for AI agents
// Documentation
//! Throttle - Rate limiting per domain

use std::collections::HashMap;
use std::sync::RwLock;
use std::time::{Duration, Instant};
use tokio::sync::Semaphore;

/// Throttling state for a domain.
///
/// Tracks when the last request was made, the minimum spacing to keep between
/// requests, a moving average of the observed latency, and any forced pause
/// requested after a rate-limit signal.
#[derive(Debug)]
pub struct ThrottleState {
    /// Instant recorded for the most recent request
    last_request: Instant,
    /// Minimum delay between requests
    min_delay: Duration,
    /// Exponential moving average of the observed latency, in milliseconds
    avg_latency_ms: f64,
    /// Number of rate-limit signals received
    rate_limit_count: u32,
    /// End of the forced pause, if one has been requested
    paused_until: Option<Instant>,
}

impl ThrottleState {
    /// Creates a new state with the given minimum delay between requests.
    ///
    /// The last-request instant is back-dated by `min_delay` so that the first
    /// request is allowed immediately. If that instant cannot be represented
    /// (for example a huge delay on a recently booted machine), the current
    /// instant is used instead, so the first request simply waits `min_delay`
    /// rather than panicking.
    pub fn new(min_delay: Duration) -> Self {
        let now = Instant::now();
        Self {
            // `Instant - Duration` panics on underflow; `checked_sub` does not.
            last_request: now.checked_sub(min_delay).unwrap_or(now),
            min_delay,
            avg_latency_ms: 0.0,
            rate_limit_count: 0,
            paused_until: None,
        }
    }

    /// Time until the next request is allowed.
    ///
    /// A request is allowed once both constraints hold: any forced pause has
    /// ended and `min_delay` has elapsed since the last request. The returned
    /// value is therefore the larger of the two remaining times, or
    /// `Duration::ZERO` when neither applies.
    pub fn time_until_next(&self) -> Duration {
        // Read the clock once so both remainders are measured from the same instant.
        let now = Instant::now();

        let pause_remaining = self
            .paused_until
            .map_or(Duration::ZERO, |until| until.saturating_duration_since(now));

        let since_last = now.saturating_duration_since(self.last_request);
        let delay_remaining = self.min_delay.saturating_sub(since_last);

        pause_remaining.max(delay_remaining)
    }

    /// Returns `true` when a request may be made right now.
    pub fn can_request_now(&self) -> bool {
        self.time_until_next().is_zero()
    }

    /// Records the current instant as the time of the latest request.
    pub fn mark_request(&mut self) {
        self.last_request = Instant::now();
    }

    /// Folds a new latency sample (in milliseconds) into the moving average.
    ///
    /// The average starts at `0.0`, so early values are biased towards zero
    /// until a few samples have been recorded.
    pub fn update_latency(&mut self, latency_ms: u64) {
        // Exponential moving average
        const ALPHA: f64 = 0.3;
        self.avg_latency_ms = ALPHA * (latency_ms as f64) + (1.0 - ALPHA) * self.avg_latency_ms;
    }

    /// Signals a rate limit (429/503).
    ///
    /// Starts a forced pause of `pause_duration` from now, counts the signal,
    /// and increases the minimum delay by 50%. The increase is never undone
    /// by this type.
    pub fn signal_rate_limit(&mut self, pause_duration: Duration) {
        // Saturate instead of overflowing after a very large number of signals.
        self.rate_limit_count = self.rate_limit_count.saturating_add(1);
        self.paused_until = Some(Instant::now() + pause_duration);

        // min_delay * 1.5, computed on the `Duration` itself so sub-millisecond
        // delays are not truncated to zero and huge delays saturate.
        self.min_delay = self.min_delay.saturating_add(self.min_delay / 2);
    }

    /// Adaptive delay based on latency.
    ///
    /// Returns the larger of `min_delay` and twice the average latency. This is
    /// a spacing between requests, not the time remaining until the next one.
    pub fn adaptive_delay(&self) -> Duration {
        let latency_based = Duration::from_millis((self.avg_latency_ms * 2.0) as u64);
        std::cmp::max(self.min_delay, latency_based)
    }
}

/// Per-domain throttler.
///
/// Combines per-domain request pacing (`ThrottleState`) with two levels of
/// concurrency limiting: one semaphore per domain and one global semaphore.
pub struct DomainThrottler {
    /// Throttling state keyed by domain (the URL host string)
    states: RwLock<HashMap<String, ThrottleState>>,
    /// Concurrency semaphore per domain, created lazily on first use
    semaphores: RwLock<HashMap<String, std::sync::Arc<Semaphore>>>,
    /// Delay used when the caller supplies no crawl delay
    default_delay: Duration,
    /// Number of permits each per-domain semaphore is created with
    max_concurrent_per_domain: usize,
    /// Semaphore bounding concurrency across all domains
    global_semaphore: Semaphore,
    /// When true, `acquire` derives its wait from `ThrottleState::adaptive_delay`
    adaptive: bool,
    /// Pause applied to a domain when `release` reports a rate limit
    rate_limit_pause: Duration,
}

impl DomainThrottler {
    /// New throttler
    pub fn new(
        default_delay_ms: u64,
        max_concurrent_per_domain: usize,
        max_concurrent_total: usize,
        adaptive: bool,
        rate_limit_pause_ms: u64,
    ) -> Self {
        Self {
            states: RwLock::new(HashMap::new()),
            semaphores: RwLock::new(HashMap::new()),
            default_delay: Duration::from_millis(default_delay_ms),
            max_concurrent_per_domain,
            global_semaphore: Semaphore::new(max_concurrent_total),
            adaptive,
            rate_limit_pause: Duration::from_millis(rate_limit_pause_ms),
        }
    }

    /// Extracts the domain from a URL
    fn domain(url: &url::Url) -> String {
        url.host_str().unwrap_or("").to_string()
    }

    /// Retrieves or creates a semaphore for a domain
    fn get_or_create_semaphore(&self, domain: &str) -> std::sync::Arc<Semaphore> {
        {
            let semaphores = self.semaphores.read().unwrap();
            if let Some(sem) = semaphores.get(domain) {
                return sem.clone();
            }
        }

        let mut semaphores = self.semaphores.write().unwrap();
        semaphores
            .entry(domain.to_string())
            .or_insert_with(|| std::sync::Arc::new(Semaphore::new(self.max_concurrent_per_domain)))
            .clone()
    }

    /// Retrieves or creates the state for a domain
    fn get_or_create_state(&self, domain: &str, crawl_delay: Option<Duration>) -> ThrottleState {
        let delay = crawl_delay.unwrap_or(self.default_delay);
        
        {
            let states = self.states.read().unwrap();
            if let Some(state) = states.get(domain) {
                return ThrottleState {
                    last_request: state.last_request,
                    min_delay: delay,
                    avg_latency_ms: state.avg_latency_ms,
                    rate_limit_count: state.rate_limit_count,
                    paused_until: state.paused_until,
                };
            }
        }

        ThrottleState::new(delay)
    }

    /// Waits for the green light to make a request
    pub async fn acquire(&self, url: &url::Url, crawl_delay: Option<Duration>) {
        let domain = Self::domain(url);
        
        // Acquire the global semaphore
        let _global_permit = self.global_semaphore.acquire().await.unwrap();
        
        // Acquire the domain semaphore
        let domain_sem = self.get_or_create_semaphore(&domain);
        let _domain_permit = domain_sem.acquire().await.unwrap();

        // Wait for the delay
        let state = self.get_or_create_state(&domain, crawl_delay);
        let wait_time = if self.adaptive {
            state.adaptive_delay()
        } else {
            state.time_until_next()
        };

        if wait_time > Duration::ZERO {
            tokio::time::sleep(wait_time).await;
        }

        // Mark the request
        let mut states = self.states.write().unwrap();
        let state = states
            .entry(domain)
            .or_insert_with(|| ThrottleState::new(self.default_delay));
        state.mark_request();
    }

    /// Signals a completed request with its latency
    pub fn release(&self, url: &url::Url, latency_ms: u64, was_rate_limited: bool) {
        let domain = Self::domain(url);
        let mut states = self.states.write().unwrap();
        
        if let Some(state) = states.get_mut(&domain) {
            state.update_latency(latency_ms);
            
            if was_rate_limited {
                state.signal_rate_limit(self.rate_limit_pause);
            }
        }
    }

    /// Stats for a domain
    pub fn get_stats(&self, url: &url::Url) -> Option<DomainStats> {
        let domain = Self::domain(url);
        let states = self.states.read().unwrap();
        
        states.get(&domain).map(|state| DomainStats {
            avg_latency_ms: state.avg_latency_ms,
            rate_limit_count: state.rate_limit_count,
            min_delay_ms: state.min_delay.as_millis() as u64,
        })
    }
}

/// Snapshot of the throttling statistics for one domain, as returned by
/// `DomainThrottler::get_stats`.
#[derive(Debug, Clone)]
pub struct DomainStats {
    /// Exponential moving average of the request latency, in milliseconds
    pub avg_latency_ms: f64,
    /// Number of times a rate limit was signalled for the domain
    pub rate_limit_count: u32,
    /// Stored minimum delay between requests, in milliseconds
    pub min_delay_ms: u64,
}