// kumo 0.3.8
//
// An async web crawling framework for Rust - Scrapy for Rust
// Documentation
use std::{
    sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    },
    time::Duration,
};

use bytes::Bytes;
use kumo::{
    error::KumoError,
    extract::Response,
    fetch::{CachingFetcher, Fetcher, MockFetcher},
    middleware::FetchRequest,
};
use reqwest::Method;

/// Builds a [`FetchRequest`] for `url` via `FetchRequest::new(url, 0)`.
///
/// The request is returned exactly as the constructor produced it; this helper does not
/// touch the method or body.
fn req(url: &str) -> FetchRequest {
    FetchRequest::new(url, 0)
}

/// Builds a `POST` [`FetchRequest`] for `url` carrying a small JSON payload.
///
/// The request starts from `FetchRequest::new(url, 0)`; its method is then set to
/// `Method::POST` and its body to the bytes of `{"page":1}`.
fn post_req(url: &str) -> FetchRequest {
    let payload: &[u8] = br#"{"page":1}"#;

    let mut post = FetchRequest::new(url, 0);
    post.body = Some(payload.to_vec());
    post.method = Method::POST;
    post
}

/// Body payload a scripted [`SequenceResponse`] can carry.
#[derive(Clone)]
enum SequenceBody {
    /// Textual body; `SequenceFetcher::fetch` passes it to `Response::from_parts`.
    Text(&'static str),
    /// Raw byte body; `SequenceFetcher::fetch` converts it to `Bytes` and passes it to
    /// `Response::from_bytes`.
    Bytes(&'static [u8]),
}

/// One scripted reply that `SequenceFetcher` hands back for a single `fetch` call.
#[derive(Clone)]
struct SequenceResponse {
    /// Status code forwarded unchanged to the constructed `Response`.
    status: u16,
    /// Body to return, either text or raw bytes.
    body: SequenceBody,
}

/// Test double that replays a fixed list of responses and counts how often it is called.
struct SequenceFetcher {
    /// Number of `fetch` calls made so far. Held in an `Arc` so a test can keep a handle
    /// to the counter after the fetcher itself has been moved into another value.
    calls: Arc<AtomicUsize>,
    /// Scripted responses, selected by call index in `fetch`.
    responses: Vec<SequenceResponse>,
}

impl SequenceFetcher {
    /// Creates a fetcher that will replay `responses`, with its call counter at zero.
    fn new(responses: Vec<SequenceResponse>) -> Self {
        // `AtomicUsize::default()` is an atomic initialised to 0.
        let calls = Arc::new(AtomicUsize::default());
        Self { calls, responses }
    }

    /// Returns another handle to the shared call counter.
    fn calls(&self) -> Arc<AtomicUsize> {
        Arc::clone(&self.calls)
    }
}

#[async_trait::async_trait]
impl Fetcher for SequenceFetcher {
    /// Returns the scripted response for this call and bumps the call counter.
    ///
    /// The n-th call (zero-based) yields the n-th scripted response; once the script is
    /// exhausted the last response is repeated. The returned `Response` carries the
    /// request's URL and the scripted status.
    ///
    /// # Panics
    ///
    /// Panics if the fetcher was constructed with an empty response list.
    async fn fetch(&self, request: &FetchRequest) -> Result<Response, KumoError> {
        // `fetch_add` returns the previous value, i.e. the zero-based index of this call.
        let call_number = self.calls.fetch_add(1, Ordering::SeqCst);

        let scripted = match self.responses.get(call_number) {
            Some(entry) => entry,
            // Past the end of the script: keep replaying the final entry.
            None => self
                .responses
                .last()
                .expect("sequence fetcher needs at least one response"),
        };

        let url = request.url().to_string();
        let reply = match scripted.body {
            SequenceBody::Text(text) => Response::from_parts(url, scripted.status, text),
            SequenceBody::Bytes(raw) => {
                Response::from_bytes(url, scripted.status, Bytes::from(raw))
            }
        };
        Ok(reply)
    }
}

/// The first fetch through a fresh `CachingFetcher` returns the status and body that
/// were registered on the inner `MockFetcher`.
#[tokio::test]
async fn first_request_fetches_from_inner() {
    let cache_dir = tempfile::TempDir::new().unwrap();
    let upstream =
        MockFetcher::new().with_response("https://example.com", 200, "<h1>Hello</h1>");
    let caching = CachingFetcher::new(upstream, cache_dir.path()).unwrap();

    let request = req("https://example.com");
    let response = caching.fetch(&request).await.unwrap();

    assert_eq!(response.status(), 200);
    assert_eq!(response.text(), Some("<h1>Hello</h1>"));
}

#[tokio::test]
async fn second_request_served_from_cache() {
    let tmp = tempfile::TempDir::new().unwrap();
    let inner = MockFetcher::new().with_response("https://example.com", 200, "from network");
    let cf = CachingFetcher::new(inner, tmp.path()).unwrap();

    cf.fetch(&req("https://example.com")).await.unwrap();
    let res2 = cf.fetch(&req("https://example.com")).await.unwrap();
    assert_eq!(res2.text(), Some("from network"));
}

/// After one fetch of a text response, the cache directory contains exactly one entry.
#[tokio::test]
async fn cache_file_is_created_after_fetch() {
    let cache_dir = tempfile::TempDir::new().unwrap();
    let upstream = MockFetcher::new().with_response("https://example.com", 200, "body");
    let caching = CachingFetcher::new(upstream, cache_dir.path()).unwrap();

    let request = req("https://example.com");
    caching.fetch(&request).await.unwrap();

    // Count directory entries directly instead of collecting them first.
    let entry_count = std::fs::read_dir(cache_dir.path()).unwrap().count();
    assert_eq!(entry_count, 1);
}

/// With a TTL of zero seconds, a second fetch of the same URL returns the inner
/// fetcher's second scripted body rather than the first one.
#[tokio::test]
async fn expired_entry_is_refetched() {
    // Every scripted reply in this test is a 200 with a text body.
    let ok_text = |text: &'static str| SequenceResponse {
        status: 200,
        body: SequenceBody::Text(text),
    };

    let cache_dir = tempfile::TempDir::new().unwrap();
    let upstream = SequenceFetcher::new(vec![ok_text("cached"), ok_text("refetched")]);
    let zero_ttl = Duration::from_secs(0);
    let caching = CachingFetcher::new(upstream, cache_dir.path())
        .unwrap()
        .ttl(zero_ttl);

    let request = req("https://example.com");
    caching.fetch(&request).await.unwrap();
    let second = caching.fetch(&request).await.unwrap();

    assert_eq!(second.status(), 200);
    assert_eq!(second.text(), Some("refetched"));
}

/// A 404 response is replayed from the cache with its original status and body.
///
/// The inner fetcher would return a 200 "live" response on a second call, so the
/// assertions fail if the repeat request reaches it.
#[tokio::test]
async fn cached_response_preserves_status_on_hit() {
    let cache_dir = tempfile::TempDir::new().unwrap();

    let not_found = SequenceResponse {
        status: 404,
        body: SequenceBody::Text("missing"),
    };
    let live = SequenceResponse {
        status: 200,
        body: SequenceBody::Text("live"),
    };
    let upstream = SequenceFetcher::new(vec![not_found, live]);
    let upstream_calls = upstream.calls();
    let caching = CachingFetcher::new(upstream, cache_dir.path()).unwrap();

    let request = req("https://example.com/missing");
    caching.fetch(&request).await.unwrap();
    let replayed = caching.fetch(&request).await.unwrap();

    // Only the first request may have reached the inner fetcher.
    assert_eq!(upstream_calls.load(Ordering::SeqCst), 1);
    assert_eq!(replayed.status(), 404);
    assert_eq!(replayed.text(), Some("missing"));
}

/// Two identical POST requests both reach the inner fetcher and leave the cache
/// directory empty.
#[tokio::test]
async fn non_get_requests_bypass_cache() {
    // Every scripted reply in this test is a 200 with a text body.
    let ok_text = |text: &'static str| SequenceResponse {
        status: 200,
        body: SequenceBody::Text(text),
    };

    let cache_dir = tempfile::TempDir::new().unwrap();
    let upstream = SequenceFetcher::new(vec![ok_text("post 1"), ok_text("post 2")]);
    let upstream_calls = upstream.calls();
    let caching = CachingFetcher::new(upstream, cache_dir.path()).unwrap();

    let request = post_req("https://example.com/search");
    let first = caching.fetch(&request).await.unwrap();
    let second = caching.fetch(&request).await.unwrap();

    // Each POST must produce its own inner call and its own scripted body.
    assert_eq!(upstream_calls.load(Ordering::SeqCst), 2);
    assert_eq!(first.text(), Some("post 1"));
    assert_eq!(second.text(), Some("post 2"));

    let cached_entries = std::fs::read_dir(cache_dir.path()).unwrap().count();
    assert_eq!(cached_entries, 0);
}

/// Two GETs for a URL whose responses are built from raw bytes both reach the inner
/// fetcher and leave the cache directory empty.
#[tokio::test]
async fn binary_responses_bypass_cache_writes() {
    // Every scripted reply in this test is a 200 with a raw byte body.
    let ok_bytes = |raw: &'static [u8]| SequenceResponse {
        status: 200,
        body: SequenceBody::Bytes(raw),
    };

    let cache_dir = tempfile::TempDir::new().unwrap();
    let upstream = SequenceFetcher::new(vec![
        ok_bytes(b"\x89PNG first"),
        ok_bytes(b"\x89PNG second"),
    ]);
    let upstream_calls = upstream.calls();
    let caching = CachingFetcher::new(upstream, cache_dir.path()).unwrap();

    let request = req("https://example.com/image.png");
    let first = caching.fetch(&request).await.unwrap();
    let second = caching.fetch(&request).await.unwrap();

    // Each request must produce its own inner call and its own scripted bytes.
    assert_eq!(upstream_calls.load(Ordering::SeqCst), 2);
    assert_eq!(first.bytes(), b"\x89PNG first");
    assert_eq!(second.bytes(), b"\x89PNG second");

    let cached_entries = std::fs::read_dir(cache_dir.path()).unwrap().count();
    assert_eq!(cached_entries, 0);
}