kumo 0.3.1

An async web crawling framework for Rust - Scrapy for Rust
Documentation
use std::{
    collections::HashMap,
    sync::Arc,
    time::{Duration, Instant},
};

use rand::Rng;
use tokio::sync::Mutex;

use crate::{
    frontier::Frontier,
    request::{CrawlRequest, FrontierRequest},
};

use super::{domain::domain_key, fingerprint::FingerprintPolicy, policy::PolitenessPolicy};

/// Per-domain bookkeeping the scheduler keeps to enforce politeness limits.
#[derive(Debug, Default)]
struct DomainState {
    /// Requests released for this domain that have not yet been reported
    /// back through `CrawlScheduler::finish`.
    in_flight: usize,
    /// Earliest instant at which another request for this domain may be
    /// released; `None` means no delay has been recorded yet.
    next_available_at: Option<Instant>,
    /// Crawl delay stored by `CrawlScheduler::observe_robots_crawl_delay`,
    /// if one was reported for this domain.
    robots_delay: Option<Duration>,
}

/// Releases queued requests from a [`Frontier`] while applying the
/// per-domain concurrency limit and delay taken from a [`PolitenessPolicy`].
pub struct CrawlScheduler {
    /// Backing queue holding the pending requests.
    frontier: Arc<dyn Frontier>,
    /// Source of per-domain concurrency/delay settings and the jitter range.
    policy: PolitenessPolicy,
    /// Produces the dedup key attached to requests before they are queued.
    fingerprint_policy: FingerprintPolicy,
    /// Scheduling state per domain, keyed by the value `domain_key` returns.
    domains: Mutex<HashMap<String, DomainState>>,
}

/// Outcome of a single scheduling pass over the frontier.
pub(crate) enum SchedulerPoll {
    /// A request that may be dispatched right now.
    Ready(Box<FrontierRequest>),
    /// Nothing is dispatchable yet; carries the shortest wait reported by
    /// the candidates that were inspected.
    Pending(Duration),
    /// The frontier reported no queued requests (or none could be popped).
    Empty,
}

/// Classification of one popped request against schedule and domain limits.
enum CandidateState {
    /// The request may be dispatched immediately.
    Ready,
    /// The request must wait at least this long before being reconsidered.
    Pending(Duration),
}

/// Returns `base` extended by a random amount drawn uniformly from
/// `0..=jitter`.
///
/// When `jitter` is `None` or zero, `base` is returned unchanged and no
/// random number is generated. The sum saturates instead of overflowing.
fn delay_with_jitter(base: Duration, jitter: Option<Duration>) -> Duration {
    match jitter {
        Some(max_extra) if !max_extra.is_zero() => {
            let extra = rand::rng().random_range(Duration::ZERO..=max_extra);
            base.saturating_add(extra)
        }
        // Jitter disabled (absent or zero): keep the base delay as-is.
        _ => base,
    }
}

impl CrawlScheduler {
    /// Creates a scheduler that owns `frontier` and enforces `policy`.
    ///
    /// The frontier is wrapped in an `Arc` and handed to [`Self::from_arc`].
    pub fn new(frontier: impl Frontier + 'static, policy: PolitenessPolicy) -> Self {
        Self::from_arc(Arc::new(frontier), policy)
    }

    /// Creates a scheduler around an already shared `frontier`.
    ///
    /// Starts with the default [`FingerprintPolicy`] and no recorded
    /// per-domain state.
    pub fn from_arc(frontier: Arc<dyn Frontier>, policy: PolitenessPolicy) -> Self {
        Self {
            frontier,
            policy,
            fingerprint_policy: FingerprintPolicy::default(),
            domains: Mutex::new(HashMap::new()),
        }
    }

    /// Builder-style setter that replaces the fingerprint policy used to
    /// derive dedup keys, returning the updated scheduler.
    pub fn with_fingerprint_policy(mut self, policy: FingerprintPolicy) -> Self {
        self.fingerprint_policy = policy;
        self
    }

    /// Queues `request` at the given crawl `depth`.
    ///
    /// Unless the request opted out of filtering, a dedup key derived from
    /// the fingerprint policy is attached first. The returned flag is
    /// whatever the frontier's `push_request` reports (presumably whether
    /// the request was accepted — confirm against the `Frontier` trait).
    pub async fn push_request(&self, request: CrawlRequest, depth: usize) -> bool {
        let fingerprinted = self.apply_fingerprint(request);
        self.frontier.push_request(fingerprinted, depth).await
    }

    /// Hands an already-built `FrontierRequest` straight to the frontier's
    /// `push_request_force`, without applying a fingerprint here.
    // NOTE(review): "force" presumably bypasses duplicate filtering in the
    // frontier — confirm against the `Frontier` implementation.
    pub async fn push_request_force(&self, queued: FrontierRequest) {
        self.frontier.push_request_force(queued).await;
    }

    /// Reports whether the underlying frontier considers itself empty.
    ///
    /// Only the frontier is consulted; in-flight counters tracked by the
    /// scheduler are not taken into account.
    pub async fn is_empty(&self) -> bool {
        self.frontier.is_empty().await
    }

    /// Delegates to the frontier's `flush`.
    ///
    /// # Errors
    ///
    /// Returns whatever `KumoError` the frontier reports.
    pub async fn flush(&self) -> Result<(), crate::error::KumoError> {
        self.frontier.flush().await
    }

    /// Runs one scheduling pass and returns a dispatchable request, if any.
    ///
    /// Does not wait: both "nothing queued" and "everything must wait"
    /// yield `None`.
    pub async fn try_next_ready(&self) -> Option<FrontierRequest> {
        let SchedulerPoll::Ready(boxed) = self.poll_next().await else {
            // `Pending` and `Empty` are indistinguishable to this caller.
            return None;
        };
        Some(*boxed)
    }

    /// Crate-internal entry point that runs one scheduling pass and exposes
    /// the full [`SchedulerPoll`] outcome, including the suggested wait.
    pub(crate) async fn poll_ready(&self) -> SchedulerPoll {
        self.poll_next().await
    }

    /// Waits until a request becomes dispatchable and returns it.
    ///
    /// Sleeps for the wait reported by each pending pass and polls again.
    /// Returns `None` as soon as a pass reports the frontier as empty.
    pub async fn next_ready(&self) -> Option<FrontierRequest> {
        loop {
            let wait = match self.poll_next().await {
                SchedulerPoll::Ready(boxed) => break Some(*boxed),
                SchedulerPoll::Empty => break None,
                SchedulerPoll::Pending(wait) => wait,
            };
            tokio::time::sleep(wait).await;
        }
    }

    pub async fn finish(&self, queued: &FrontierRequest) {
        let Some(domain) = domain_key(queued.request().url()) else {
            return;
        };

        let mut domains = self.domains.lock().await;
        let state = domains.entry(domain.clone()).or_default();
        state.in_flight = state.in_flight.saturating_sub(1);
        let policy_delay = self.policy.policy_for(&domain).delay();
        let robots_delay = if self.policy.respects_robots_crawl_delay() {
            state.robots_delay
        } else {
            None
        };
        if let Some(delay) = [policy_delay, robots_delay].into_iter().flatten().max() {
            let delay = delay_with_jitter(delay, self.policy.jitter_range());
            state.next_available_at = Some(Instant::now() + delay);
        }
    }

    /// Stores `delay` as the robots crawl delay for the domain of `url`.
    ///
    /// Does nothing when no domain key can be derived from `url`. A later
    /// call for the same domain overwrites the previous value.
    pub async fn observe_robots_crawl_delay(&self, url: &str, delay: Duration) {
        if let Some(domain) = domain_key(url) {
            self.domains
                .lock()
                .await
                .entry(domain)
                .or_default()
                .robots_delay = Some(delay);
        }
    }

    async fn poll_next(&self) -> SchedulerPoll {
        let queued_len = self.frontier.len().await;
        if queued_len == 0 {
            return SchedulerPoll::Empty;
        };

        let mut deferred = Vec::new();
        let mut shortest_wait: Option<Duration> = None;

        for _ in 0..queued_len {
            let Some(queued) = self.frontier.pop_request().await else {
                break;
            };

            match self.classify_candidate(&queued).await {
                CandidateState::Ready => {
                    for item in deferred {
                        self.frontier.push_request_force(item).await;
                    }
                    return SchedulerPoll::Ready(Box::new(queued));
                }
                CandidateState::Pending(wait) => {
                    shortest_wait = Some(shortest_wait.map_or(wait, |current| current.min(wait)));
                    deferred.push(queued);
                }
            }
        }

        for item in deferred {
            self.frontier.push_request_force(item).await;
        }

        shortest_wait.map_or(SchedulerPoll::Empty, SchedulerPoll::Pending)
    }

    /// Decides whether `queued` may be dispatched now.
    ///
    /// Checks, in order: the request's own `scheduled_at` time, the domain's
    /// concurrency limit, and the domain's `next_available_at` instant.
    /// When the request is ready and has a domain key, the domain's
    /// in-flight counter is incremented as a side effect. Requests without a
    /// domain key are always ready and are not tracked.
    async fn classify_candidate(&self, queued: &FrontierRequest) -> CandidateState {
        // A request scheduled for the future waits regardless of its domain.
        let until_scheduled = queued
            .scheduled_at()
            .and_then(|at| at.duration_since(std::time::SystemTime::now()).ok());
        if let Some(wait) = until_scheduled {
            return CandidateState::Pending(wait);
        }

        let Some(domain) = domain_key(queued.request().url()) else {
            return CandidateState::Ready;
        };

        let mut guard = self.domains.lock().await;
        let entry = guard.entry(domain.clone()).or_default();
        let limit = self.policy.policy_for(&domain).concurrency();

        if entry.in_flight >= limit {
            // Domain is saturated: retry shortly rather than computing a
            // precise wait.
            return CandidateState::Pending(Duration::from_millis(10));
        }

        let remaining = entry
            .next_available_at
            .and_then(|next| next.checked_duration_since(Instant::now()));
        match remaining {
            Some(wait) => CandidateState::Pending(wait),
            None => {
                entry.in_flight += 1;
                CandidateState::Ready
            }
        }
    }

    /// Attaches a dedup key to `request` unless it opted out of filtering.
    ///
    /// The key is the fingerprint of the request URL; if fingerprinting
    /// fails, the raw URL string is used instead.
    fn apply_fingerprint(&self, request: CrawlRequest) -> CrawlRequest {
        if request.dont_filter_enabled() {
            return request;
        }

        let key = match self.fingerprint_policy.fingerprint(request.url()) {
            Ok(fingerprint) => fingerprint,
            Err(_) => request.url().to_string(),
        };
        request.with_dedup_key(key)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The jittered delay must never fall below the base delay nor exceed
    /// base plus the configured jitter.
    #[test]
    fn delay_with_jitter_keeps_delay_within_configured_range() {
        let base = Duration::from_millis(10);
        let jitter = Duration::from_millis(50);
        let allowed = base..=(base + jitter);

        // The helper is randomized, so sample it repeatedly.
        for _ in 0..100 {
            let delay = delay_with_jitter(base, Some(jitter));
            assert!(allowed.contains(&delay));
        }
    }

    /// Absent or zero jitter leaves the base delay untouched.
    #[test]
    fn delay_with_jitter_returns_base_when_jitter_is_disabled() {
        let base = Duration::from_millis(10);

        for disabled in [None, Some(Duration::ZERO)] {
            assert_eq!(delay_with_jitter(base, disabled), base);
        }
    }
}