use std::time::Duration;
use rand::Rng;
use rand::SeedableRng;
use rand::rngs::StdRng;
use url::Url;
use crate::config::RateLimitConfig;
use crate::fetcher::FetcherError;
use crate::fetcher::concurrency::{Pacer, PacerGuard};
use crate::fetcher::fetch::{ConditionalGet, FetchedPage, fetch_url_conditional};
use crate::fetcher::ssrf::SsrfLevel;
/// Retry decision for a single fetch attempt, produced by [`classify`].
#[derive(Debug)]
enum Class {
    /// Final success (2xx or 304). Boxed — presumably to keep `Class` small.
    Done(Box<FetchedPage>),
    /// Permanent failure: handed back to the caller without another attempt.
    Fatal(FetcherError),
    /// Transient failure: retry after an exponential, jittered backoff.
    Backoff(FetcherError),
    /// Transient failure where the server named the wait via `Retry-After`.
    RetryAfter(Duration, FetcherError),
}
/// Fetches `url`, retrying transient failures according to `cfg`.
///
/// One [`Pacer`] slot for the URL's host is acquired up front (honouring
/// `crawl_delay`). It is held until this function returns, including the sleeps
/// between attempts. Every attempt calls `fetch_url_conditional` with `level`,
/// `project_root`, `har_recorder` and `cond`. The outcome is classified by [`classify`]:
///
/// * success (2xx / 304) is returned as `Ok`;
/// * a fatal error is returned immediately;
/// * a backoff error sleeps for [`compute_jittered_backoff`] and retries;
/// * a `Retry-After` longer than `cfg.deferred_retry_threshold_secs` persists a
///   `Retry` task and returns `FetcherError::Deferred`. Shorter waits sleep for the
///   ceiling-clamped delay ([`compute_retry_after_wait`]) and retry.
///
/// # Errors
///
/// * `FetcherError::Ssrf(SsrfError::NoHost)` if `url` has no host.
/// * `FetcherError::RetryExhausted` once `cfg.max_retries` retries are used up.
/// * `FetcherError::Deferred` when the retry was handed off to the task table.
/// * Any error from inserting that task, and any fatal fetch error.
#[allow(clippy::too_many_arguments)]
pub async fn with_retries(
    db: &crate::storage::Db,
    pacer: &Pacer,
    client: &reqwest::Client,
    url: &Url,
    level: SsrfLevel,
    project_root: Option<&std::path::Path>,
    har_recorder: Option<&std::sync::Arc<crate::fetcher::har::HarRecorder>>,
    cond: &ConditionalGet,
    crawl_delay: Option<Duration>,
    cfg: &RateLimitConfig,
) -> Result<FetchedPage, FetcherError> {
    let host = url
        .host_str()
        .ok_or(FetcherError::Ssrf(crate::fetcher::ssrf::SsrfError::NoHost))?
        .to_string();
    // Named binding (not `_`): the slot stays held across every retry below.
    let _guard: PacerGuard<'_> = pacer.acquire(&host, crawl_delay).await;
    // A fixed seed makes the jitter sequence reproducible.
    let mut rng: StdRng = match cfg.jitter_seed {
        Some(s) => StdRng::seed_from_u64(s),
        None => StdRng::from_os_rng(),
    };
    // Number of retries performed so far (0 on the first fetch).
    let mut attempt: u8 = 0;
    loop {
        let result =
            fetch_url_conditional(client, url, level, project_root, har_recorder, cond).await;
        let wait = match classify(result, cfg) {
            Class::Done(page) => return Ok(*page),
            Class::Fatal(err) => return Err(err),
            Class::Backoff(err) => {
                if attempt >= cfg.max_retries {
                    return Err(retry_exhausted(attempt, err));
                }
                compute_jittered_backoff(attempt, cfg, &mut rng)
            }
            Class::RetryAfter(d, err) => {
                if d.as_secs() > cfg.deferred_retry_threshold_secs {
                    let task_id = schedule_deferred_retry(db, url, d, cfg).await?;
                    tracing::info!(
                        target: "rover::fetcher::retry",
                        requested_secs = d.as_secs(),
                        threshold_secs = cfg.deferred_retry_threshold_secs,
                        task_id = %task_id,
                        url = url.as_str(),
                        error = ?err,
                        "Retry-After exceeds deferral threshold; scheduling retry task"
                    );
                    return Err(FetcherError::Deferred { task_id });
                }
                if attempt >= cfg.max_retries {
                    return Err(retry_exhausted(attempt, err));
                }
                let (capped, clamped) = compute_retry_after_wait(d, cfg);
                if clamped {
                    tracing::warn!(
                        target: "rover::fetcher::retry",
                        requested_secs = d.as_secs(),
                        ceiling_secs = cfg.retry_after_ceiling.as_secs(),
                        "Retry-After exceeded ceiling; clamping"
                    );
                }
                capped
            }
        };
        tokio::time::sleep(wait).await;
        // Cannot overflow: we only get here while `attempt < cfg.max_retries`.
        attempt += 1;
    }
}

/// Builds the terminal error after `attempt` retries, i.e. `attempt + 1` fetches.
///
/// Saturates so `cfg.max_retries == u8::MAX` cannot overflow the counter.
fn retry_exhausted(attempt: u8, last: FetcherError) -> FetcherError {
    FetcherError::RetryExhausted {
        attempts: attempt.saturating_add(1),
        last: Box::new(last),
    }
}

/// Persists a `Retry` task for `url` whose initial wait is `wait` (at least 1s).
///
/// Returns the new task's id (a UUIDv7 string).
///
/// NOTE(review): the task always starts at `attempt: 1`, regardless of how many
/// in-process attempts already ran. Confirm that is what the task runner expects.
async fn schedule_deferred_retry(
    db: &crate::storage::Db,
    url: &Url,
    wait: Duration,
    cfg: &RateLimitConfig,
) -> Result<String, FetcherError> {
    let task_id = uuid::Uuid::now_v7().to_string();
    // `as_millis` is u128; saturate rather than silently truncate absurd values.
    let wait_ms = wait.as_millis().min(u128::from(u64::MAX)) as u64;
    let params = crate::tasks::types::RetryParams {
        url: url.to_string(),
        attempt: 1,
        wait_ms_initial: std::cmp::max(wait_ms, 1_000),
        max_attempts: cfg.max_retries.max(1),
        parent_task_id: None,
    };
    let params_json =
        serde_json::to_string(&params).expect("RetryParams serialization is infallible");
    crate::storage::tasks::insert(
        db,
        crate::storage::tasks::TaskInsert {
            id: task_id.clone(),
            kind: crate::storage::tasks::TaskKind::Retry,
            params_json,
            owner_pid: Some(std::process::id() as i64),
        },
    )
    .await?;
    Ok(task_id)
}
/// Turns the outcome of one fetch attempt into a retry decision.
///
/// Any 2xx status or 304 (Not Modified) is a final success. Other responses are
/// delegated to [`classify_non_2xx`], and errors to [`classify_err`]. `_cfg` is
/// currently unused.
fn classify(result: Result<FetchedPage, FetcherError>, _cfg: &RateLimitConfig) -> Class {
    match result {
        Err(e) => classify_err(e),
        Ok(page) if matches!(page.status, 200..=299 | 304) => Class::Done(Box::new(page)),
        Ok(page) => classify_non_2xx(page),
    }
}
/// Classifies a response whose status is neither 2xx nor 304.
///
/// * 429 / 503 with a parseable `Retry-After` header → [`Class::RetryAfter`].
/// * 429 / 503 without one, and every other 5xx → [`Class::Backoff`].
/// * Anything else (1xx, 3xx, other 4xx) → [`Class::Fatal`].
///
/// Every variant carries a `FetcherError::Status` holding the status and final URL.
fn classify_non_2xx(page: FetchedPage) -> Class {
    let status = page.status;
    let err = FetcherError::Status {
        status,
        url: page.final_url.to_string(),
    };
    match status {
        // `Retry-After` is only honoured for these two statuses, so parse it here only.
        429 | 503 => match page.retry_after.as_deref().and_then(parse_retry_after) {
            Some(wait) => Class::RetryAfter(wait, err),
            None => Class::Backoff(err),
        },
        500..=599 => Class::Backoff(err),
        _ => Class::Fatal(err),
    }
}
/// Decides whether a fetch error is worth retrying.
///
/// Only `FetcherError::Http` errors that reqwest flags as a timeout or a connect
/// failure are retried ([`Class::Backoff`]). They are not retried when
/// `crate::fetcher::dns::dial_blocked_cause` returns `Some` for them. Every other
/// error is [`Class::Fatal`].
///
/// The variants are listed explicitly (no `_` arm), so adding a new `FetcherError`
/// variant forces a deliberate decision here.
fn classify_err(e: FetcherError) -> Class {
    let transient = match &e {
        FetcherError::Http(re) => {
            crate::fetcher::dns::dial_blocked_cause(re).is_none()
                && (re.is_timeout() || re.is_connect())
        }
        FetcherError::Ssrf(_)
        | FetcherError::Url(_)
        | FetcherError::Decode
        | FetcherError::Storage(_)
        | FetcherError::Status { .. }
        | FetcherError::Dns { .. }
        | FetcherError::Extract(_)
        | FetcherError::RetryExhausted { .. }
        | FetcherError::RateLimited { .. }
        | FetcherError::RobotsDisallowed { .. }
        | FetcherError::RobotsFetchFailed { .. }
        | FetcherError::Deferred { .. }
        | FetcherError::HeadlessFeatureNotCompiled
        | FetcherError::HeadlessRendererUnavailable => false,
        #[cfg(feature = "headless")]
        FetcherError::Headless(_) => false,
    };
    if transient {
        Class::Backoff(e)
    } else {
        Class::Fatal(e)
    }
}
/// Exponential backoff delay before retry number `attempt` (0-based), plus jitter.
///
/// The base delay is `initial_backoff * 2^attempt` (saturating), capped at
/// `max_backoff`. A uniformly random extra of `0..=base/2` milliseconds is then
/// added, so the result may reach 1.5 × `max_backoff`. Exactly one value is drawn
/// from `rng`, so a seeded `rng` gives a reproducible sequence.
pub(crate) fn compute_jittered_backoff(
    attempt: u8,
    cfg: &crate::config::RateLimitConfig,
    rng: &mut StdRng,
) -> Duration {
    let multiplier = 2u32.saturating_pow(u32::from(attempt));
    let delay = std::cmp::min(cfg.initial_backoff.saturating_mul(multiplier), cfg.max_backoff);
    let half_ms = (delay.as_millis() as u64) / 2;
    let extra = Duration::from_millis(rng.random_range(0..=half_ms));
    delay + extra
}
/// Clamps a server-requested `Retry-After` delay to `cfg.retry_after_ceiling`.
///
/// Returns the delay to actually sleep, and whether clamping happened. The flag is
/// `true` only when `requested` is strictly greater than the ceiling.
pub(crate) fn compute_retry_after_wait(
    requested: Duration,
    cfg: &crate::config::RateLimitConfig,
) -> (Duration, bool) {
    let ceiling = cfg.retry_after_ceiling;
    if requested > ceiling {
        (ceiling, true)
    } else {
        (requested, false)
    }
}
/// Parses an HTTP `Retry-After` header value into a wait duration.
///
/// Surrounding whitespace is ignored. The value may be an integer number of
/// seconds (as accepted by `u64::from_str`) or an HTTP-date. A date that is
/// already in the past yields a zero duration. Returns `None` when the value is
/// neither form.
pub fn parse_retry_after(value: &str) -> Option<Duration> {
    let value = value.trim();
    match value.parse::<u64>() {
        Ok(secs) => Some(Duration::from_secs(secs)),
        Err(_) => httpdate::parse_http_date(value).ok().map(|when| {
            when.duration_since(std::time::SystemTime::now())
                .unwrap_or(Duration::ZERO)
        }),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fast, deterministic limits shared by the unit tests.
    fn cfg() -> RateLimitConfig {
        RateLimitConfig {
            requests_per_minute_per_domain: 6000,
            per_domain_concurrency: 2,
            global_concurrency: 8,
            max_retries: 3,
            initial_backoff: Duration::from_millis(10),
            max_backoff: Duration::from_secs(1),
            retry_after_ceiling: Duration::from_secs(60),
            jitter_seed: Some(0),
            deferred_retry_threshold_secs: 30,
        }
    }

    /// A minimal 200 response for `https://example.com`. Tests override `status` /
    /// `retry_after` with struct-update syntax.
    fn page() -> FetchedPage {
        let url = Url::parse("https://example.com").unwrap();
        FetchedPage {
            final_url: url.clone(),
            canonical_url: url,
            status: 200,
            content_type: None,
            body: String::new(),
            charset: crate::fetcher::charset::Detected::default(),
            link_header: None,
            etag: None,
            last_modified: None,
            cache_control: None,
            expires: None,
            retry_after: None,
        }
    }

    #[test]
    fn parse_retry_after_seconds() {
        assert_eq!(parse_retry_after("30"), Some(Duration::from_secs(30)));
        assert_eq!(parse_retry_after(" 5 "), Some(Duration::from_secs(5)));
        assert_eq!(parse_retry_after("0"), Some(Duration::from_secs(0)));
    }

    #[test]
    fn parse_retry_after_http_date_future() {
        let t = std::time::SystemTime::now() + Duration::from_secs(3600);
        let s = httpdate::fmt_http_date(t);
        let d = parse_retry_after(&s).unwrap();
        assert!(d.as_secs() > 3500 && d.as_secs() < 3700, "got {d:?}");
    }

    #[test]
    fn parse_retry_after_http_date_past() {
        let t = std::time::SystemTime::now() - Duration::from_secs(60);
        let s = httpdate::fmt_http_date(t);
        let d = parse_retry_after(&s).unwrap();
        assert_eq!(d, Duration::from_secs(0));
    }

    #[test]
    fn parse_retry_after_garbage_returns_none() {
        assert_eq!(parse_retry_after("not a date or number"), None);
        assert_eq!(parse_retry_after(""), None);
    }

    #[test]
    fn classify_2xx_is_done() {
        match classify(Ok(page()), &cfg()) {
            Class::Done(_) => {}
            other => panic!("expected Done, got {other:?}"),
        }
    }

    #[test]
    fn classify_304_is_done() {
        let page = FetchedPage {
            status: 304,
            ..page()
        };
        assert!(matches!(classify(Ok(page), &cfg()), Class::Done(_)));
    }

    #[test]
    fn classify_429_with_retry_after_is_retry_after() {
        let page = FetchedPage {
            status: 429,
            retry_after: Some("3".to_string()),
            ..page()
        };
        match classify(Ok(page), &cfg()) {
            Class::RetryAfter(d, _) => assert_eq!(d, Duration::from_secs(3)),
            other => panic!("expected RetryAfter, got {other:?}"),
        }
    }

    #[test]
    fn classify_503_without_retry_after_is_backoff() {
        let page = FetchedPage {
            status: 503,
            ..page()
        };
        assert!(matches!(classify(Ok(page), &cfg()), Class::Backoff(_)));
    }

    #[test]
    fn classify_500_is_backoff() {
        let page = FetchedPage {
            status: 500,
            ..page()
        };
        assert!(matches!(classify(Ok(page), &cfg()), Class::Backoff(_)));
    }

    #[test]
    fn classify_404_is_fatal() {
        let page = FetchedPage {
            status: 404,
            ..page()
        };
        assert!(matches!(classify(Ok(page), &cfg()), Class::Fatal(_)));
    }

    #[test]
    fn retry_after_wait_clamps_only_above_ceiling() {
        let cfg = cfg();
        assert_eq!(
            compute_retry_after_wait(Duration::from_secs(120), &cfg),
            (Duration::from_secs(60), true)
        );
        assert_eq!(
            compute_retry_after_wait(Duration::from_secs(60), &cfg),
            (Duration::from_secs(60), false)
        );
        assert_eq!(
            compute_retry_after_wait(Duration::from_secs(5), &cfg),
            (Duration::from_secs(5), false)
        );
    }

    #[test]
    fn jittered_backoff_is_bounded_and_seeded() {
        let cfg = cfg();
        let mut rng = StdRng::seed_from_u64(0);
        // attempt 0: base 10ms, jitter up to 5ms.
        let first = compute_jittered_backoff(0, &cfg, &mut rng);
        assert!(
            first >= Duration::from_millis(10) && first <= Duration::from_millis(15),
            "got {first:?}"
        );
        // attempt 20: base far above max_backoff (1s), so capped + up to 500ms.
        let late = compute_jittered_backoff(20, &cfg, &mut rng);
        assert!(
            late >= Duration::from_secs(1) && late <= Duration::from_millis(1500),
            "got {late:?}"
        );
        let a = compute_jittered_backoff(3, &cfg, &mut StdRng::seed_from_u64(7));
        let b = compute_jittered_backoff(3, &cfg, &mut StdRng::seed_from_u64(7));
        assert_eq!(a, b);
    }

    #[tokio::test]
    async fn long_retry_after_produces_deferred_error() {
        use wiremock::matchers::method;
        use wiremock::{Mock, MockServer, ResponseTemplate};
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .respond_with(ResponseTemplate::new(429).insert_header("Retry-After", "120"))
            .mount(&server)
            .await;
        let tmp = tempfile::tempdir().unwrap();
        let db = crate::storage::Db::open(tmp.path().join("rover.db"))
            .await
            .unwrap();
        let url = Url::parse(&server.uri()).unwrap();
        let cfg = RateLimitConfig {
            deferred_retry_threshold_secs: 30,
            max_retries: 3,
            ..Default::default()
        };
        let pacer = Pacer::new(&cfg);
        crate::fetcher::client::install_ring_provider();
        let client = reqwest::Client::new();
        let cond = ConditionalGet::default();
        let res = with_retries(
            &db,
            &pacer,
            &client,
            &url,
            SsrfLevel::Loopback,
            None,
            None,
            &cond,
            None,
            &cfg,
        )
        .await;
        match res {
            Err(FetcherError::Deferred { task_id }) => {
                let row = crate::storage::tasks::get(&db, &task_id)
                    .await
                    .unwrap()
                    .unwrap();
                assert_eq!(row.kind, crate::storage::tasks::TaskKind::Retry);
            }
            other => panic!("expected Deferred, got {other:?}"),
        }
    }
}