// codex-helper-core 0.12.1
//
// Core library for codex-helper.
// Documentation
use axum::http::HeaderMap;
use rand::Rng;
use tokio::time::sleep;

use crate::config::ResolvedRetryConfig;
use crate::config::ResolvedRetryLayerConfig;
use crate::config::RetryStrategy;
use crate::logging::RetryInfo;

/// Retry settings for one retry layer, as built by `layer_options` from a
/// `ResolvedRetryLayerConfig`.
#[derive(Clone)]
pub(super) struct RetryLayerOptions {
    /// Attempt budget for this layer; `layer_options` clamps it to `1..=8`.
    /// NOTE(review): the loop that enforces it is not visible in this file section.
    pub(super) max_attempts: u32,
    /// Base delay in milliseconds for exponential backoff. A value of `0` makes
    /// `backoff_sleep` return without sleeping.
    pub(super) base_backoff_ms: u64,
    /// Ceiling in milliseconds for the backoff delay (before jitter) and for honored
    /// `Retry-After` waits. The effective ceiling is the larger of this and `base_backoff_ms`.
    pub(super) max_backoff_ms: u64,
    /// Upper bound in milliseconds of the random jitter added to a delay; `0` disables jitter.
    pub(super) jitter_ms: u64,
    /// Inclusive `(low, high)` HTTP status ranges tested by `should_retry_status`.
    pub(super) retry_status_ranges: Vec<(u16, u16)>,
    /// Error class names compared for exact equality by `should_retry_class`.
    pub(super) retry_error_classes: Vec<String>,
    /// Strategy copied from the layer config.
    /// NOTE(review): not read by any helper visible here; semantics live with the caller.
    pub(super) strategy: RetryStrategy,
}

/// Complete retry plan: per-layer options plus global guardrails and cooldown settings, as
/// built by `retry_plan` from a `ResolvedRetryConfig`.
#[derive(Clone)]
pub(super) struct RetryPlan {
    /// Options built from the `upstream` section of the config.
    pub(super) upstream: RetryLayerOptions,
    /// Options built from the `provider` section of the config.
    pub(super) provider: RetryLayerOptions,
    /// Inclusive `(low, high)` status ranges tested by `should_never_retry_status`.
    pub(super) never_status_ranges: Vec<(u16, u16)>,
    /// Error class names compared for exact equality by `should_never_retry_class`.
    pub(super) never_error_classes: Vec<String>,
    /// Copied unchanged from the config.
    /// NOTE(review): presumably a cooldown in seconds after a Cloudflare challenge — the
    /// consumer is not visible here.
    pub(super) cloudflare_challenge_cooldown_secs: u64,
    /// Copied unchanged from the config.
    /// NOTE(review): presumably a cooldown in seconds after a Cloudflare timeout — confirm
    /// against the consumer.
    pub(super) cloudflare_timeout_cooldown_secs: u64,
    /// Copied unchanged from the config.
    /// NOTE(review): presumably a cooldown in seconds after a transport error — confirm
    /// against the consumer.
    pub(super) transport_cooldown_secs: u64,
    /// Cooldown backoff factor; `retry_plan` clamps it to `1..=16`.
    pub(super) cooldown_backoff_factor: u64,
    /// Cooldown backoff ceiling in seconds; `retry_plan` limits it to at most 24 hours.
    pub(super) cooldown_backoff_max_secs: u64,
}

/// Parses a comma-separated status specification such as `"429,500-599"` into inclusive
/// `(low, high)` ranges.
///
/// Each entry is either a single code or a `start-end` pair. Reversed pairs are normalized so
/// the smaller bound comes first. Blank entries and entries that do not parse as `u16` are
/// dropped silently.
pub(super) fn parse_status_ranges(spec: &str) -> Vec<(u16, u16)> {
    spec.split(',')
        .map(str::trim)
        .filter(|entry| !entry.is_empty())
        .filter_map(|entry| match entry.split_once('-') {
            // `a-b`: both bounds must parse, otherwise the whole entry is discarded.
            Some((left, right)) => {
                let left = left.trim().parse::<u16>().ok()?;
                let right = right.trim().parse::<u16>().ok()?;
                Some((left.min(right), left.max(right)))
            }
            // Bare code: represented as a one-element range.
            None => entry.parse::<u16>().ok().map(|code| (code, code)),
        })
        .collect()
}

/// Builds the runtime options for one retry layer from its resolved configuration.
///
/// `max_attempts` is clamped to `1..=8` and `on_status` is parsed into inclusive status
/// ranges; every other value is carried over unchanged.
fn layer_options(cfg: &ResolvedRetryLayerConfig) -> RetryLayerOptions {
    RetryLayerOptions {
        max_attempts: cfg.max_attempts.clamp(1, 8),
        base_backoff_ms: cfg.backoff_ms,
        max_backoff_ms: cfg.backoff_max_ms,
        jitter_ms: cfg.jitter_ms,
        retry_status_ranges: parse_status_ranges(cfg.on_status.as_str()),
        retry_error_classes: cfg.on_class.clone(),
        strategy: cfg.strategy,
    }
}

/// Converts the resolved retry configuration into a [`RetryPlan`].
///
/// Both layers go through `layer_options`, `never_on_status` is parsed into inclusive ranges,
/// the cooldown backoff factor is clamped to `1..=16`, and the cooldown backoff ceiling is
/// limited to 24 hours. The remaining cooldown values are copied as-is.
pub(super) fn retry_plan(cfg: &ResolvedRetryConfig) -> RetryPlan {
    RetryPlan {
        upstream: layer_options(&cfg.upstream),
        provider: layer_options(&cfg.provider),
        never_status_ranges: parse_status_ranges(cfg.never_on_status.as_str()),
        never_error_classes: cfg.never_on_class.clone(),
        cloudflare_challenge_cooldown_secs: cfg.cloudflare_challenge_cooldown_secs,
        cloudflare_timeout_cooldown_secs: cfg.cloudflare_timeout_cooldown_secs,
        transport_cooldown_secs: cfg.transport_cooldown_secs,
        cooldown_backoff_factor: cfg.cooldown_backoff_factor.clamp(1, 16),
        cooldown_backoff_max_secs: cfg.cooldown_backoff_max_secs.clamp(0, 24 * 60 * 60),
    }
}

/// Summarizes an upstream attempt chain as [`RetryInfo`], or returns `None` when no retry
/// actually took place.
///
/// A trailing entry starting with `all_upstreams_avoided` is a marker rather than a real
/// attempt, so it is excluded from the attempt count (the full chain is still returned).
/// Chains with at most one real attempt yield `None`.
pub(super) fn retry_info_for_chain(chain: &[String]) -> Option<RetryInfo> {
    let ends_with_marker = matches!(
        chain.last(),
        Some(entry) if entry.starts_with("all_upstreams_avoided")
    );
    let recorded = chain.len() as u32;
    let attempts = if ends_with_marker {
        recorded.saturating_sub(1)
    } else {
        recorded
    };

    (attempts > 1).then(|| RetryInfo {
        attempts,
        upstream_chain: chain.to_vec(),
    })
}

/// Returns `true` when `status_code` lies inside any of the layer's inclusive retry ranges.
pub(super) fn should_retry_status(opt: &RetryLayerOptions, status_code: u16) -> bool {
    opt.retry_status_ranges
        .iter()
        .any(|&(low, high)| (low..=high).contains(&status_code))
}

/// Returns `true` when `class` is present and exactly equals one of the layer's retryable
/// error classes. A missing class never matches.
pub(super) fn should_retry_class(opt: &RetryLayerOptions, class: Option<&str>) -> bool {
    match class {
        Some(name) => opt
            .retry_error_classes
            .iter()
            .any(|known| known.as_str() == name),
        None => false,
    }
}

/// Returns `true` when `status_code` lies inside any of the plan's inclusive never-retry
/// ranges.
pub(super) fn should_never_retry_status(plan: &RetryPlan, status_code: u16) -> bool {
    for &(low, high) in &plan.never_status_ranges {
        if low <= status_code && status_code <= high {
            return true;
        }
    }
    false
}

/// Returns `true` when `class` is present and exactly equals one of the plan's never-retry
/// error classes. A missing class never matches.
pub(super) fn should_never_retry_class(plan: &RetryPlan, class: Option<&str>) -> bool {
    class.is_some_and(|name| {
        plan.never_error_classes
            .iter()
            .any(|blocked| blocked.as_str() == name)
    })
}

/// Combined "never retry" guardrail.
///
/// - A `never_on_class` match is final: the response is never retried.
/// - A `never_on_status` match blocks retries only when the class is not explicitly listed in
///   `on_class` of either retry layer. This lets classified errors stay retryable even with a
///   guarded status (e.g. Cloudflare/WAF HTML challenge pages may return status 400).
/// - Otherwise nothing is blocked.
pub(super) fn should_never_retry(plan: &RetryPlan, status_code: u16, class: Option<&str>) -> bool {
    should_never_retry_class(plan, class)
        || (should_never_retry_status(plan, status_code)
            && !(should_retry_class(&plan.upstream, class)
                || should_retry_class(&plan.provider, class)))
}

/// Reads the `retry-after` header as a whole number of seconds and converts it to
/// milliseconds, limited to the layer's backoff ceiling (the larger of `max_backoff_ms` and
/// `base_backoff_ms`).
///
/// Returns `None` when the header is absent, is not valid header text, or is not a
/// non-negative integer (which includes empty values and HTTP-date forms).
fn retry_after_ms(headers: &HeaderMap, opt: &RetryLayerOptions) -> Option<u64> {
    let header_text = headers.get("retry-after")?.to_str().ok()?;
    // An empty or whitespace-only value fails to parse and therefore also yields `None`.
    let delay_secs: u64 = header_text.trim().parse().ok()?;
    let ceiling_ms = opt.max_backoff_ms.max(opt.base_backoff_ms);
    Some(delay_secs.saturating_mul(1000).min(ceiling_ms))
}

/// Sleeps for an exponentially growing delay: `base_backoff_ms * 2^attempt_index`, limited to
/// the layer's backoff ceiling, plus a random jitter in `0..=jitter_ms`.
///
/// Returns immediately, without jitter, when `base_backoff_ms` is `0`. The exponent is limited
/// to 20 and all arithmetic saturates.
pub(super) async fn backoff_sleep(opt: &RetryLayerOptions, attempt_index: u32) {
    let base_ms = opt.base_backoff_ms;
    if base_ms == 0 {
        return;
    }

    let ceiling_ms = opt.max_backoff_ms.max(base_ms);
    let multiplier = 1u64 << attempt_index.min(20);
    let delay_ms = base_ms.saturating_mul(multiplier).min(ceiling_ms);

    // Jitter is added on top of the capped delay, so the total may exceed the ceiling.
    let jitter_ms = match opt.jitter_ms {
        0 => 0,
        upper => rand::thread_rng().gen_range(0..=upper),
    };

    let total_ms = delay_ms.saturating_add(jitter_ms);
    sleep(std::time::Duration::from_millis(total_ms)).await;
}

/// Waits before the next retry attempt.
///
/// When the response carries a usable `retry-after` value, that wait is used (already limited
/// to the backoff ceiling). If jitter is enabled, it is added and the sum is limited to the
/// ceiling again. A resulting wait of zero skips sleeping. Without a usable header, this
/// falls back to `backoff_sleep`.
pub(super) async fn retry_sleep(
    opt: &RetryLayerOptions,
    attempt_index: u32,
    resp_headers: &HeaderMap,
) {
    let Some(server_ms) = retry_after_ms(resp_headers, opt) else {
        backoff_sleep(opt, attempt_index).await;
        return;
    };

    let wait_ms = if opt.jitter_ms > 0 {
        let extra_ms = rand::thread_rng().gen_range(0..=opt.jitter_ms);
        let ceiling_ms = opt.max_backoff_ms.max(opt.base_backoff_ms);
        server_ms.saturating_add(extra_ms).min(ceiling_ms)
    } else {
        server_ms
    };

    if wait_ms > 0 {
        sleep(std::time::Duration::from_millis(wait_ms)).await;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::config::RetryProfileName;
    use axum::http::HeaderValue;
    use pretty_assertions::assert_eq;

    /// Turns string literals into the owned chain entries `retry_info_for_chain` expects.
    fn owned_chain(entries: &[&str]) -> Vec<String> {
        entries.iter().map(|entry| entry.to_string()).collect()
    }

    #[test]
    fn parse_status_ranges_accepts_single_codes_and_ranges() {
        let parsed = parse_status_ranges("429,500-599");
        assert_eq!(parsed, vec![(429, 429), (500, 599)]);
    }

    #[test]
    fn retry_after_ms_parses_seconds_and_caps() {
        // The server asks for 10s while the layer allows at most 2s.
        let opt = RetryLayerOptions {
            max_attempts: 3,
            base_backoff_ms: 200,
            max_backoff_ms: 2_000,
            jitter_ms: 0,
            retry_status_ranges: vec![(429, 429)],
            retry_error_classes: Vec::new(),
            strategy: RetryStrategy::Failover,
        };
        let mut headers = HeaderMap::new();
        headers.insert("retry-after", HeaderValue::from_static("10"));

        let waited = retry_after_ms(&headers, &opt);
        assert_eq!(waited, Some(2_000));
    }

    #[test]
    fn retry_info_attempts_excludes_all_upstreams_avoided_marker() {
        let chain = owned_chain(&[
            "https://a.example/v1 (idx=0) status=502 class=-",
            "https://b.example/v1 (idx=1) status=502 class=-",
            "all_upstreams_avoided total=2",
        ]);

        let info = retry_info_for_chain(&chain).unwrap();
        assert_eq!(info.attempts, 2);
        assert_eq!(info.upstream_chain, chain);
    }

    #[test]
    fn retry_info_is_none_when_only_one_real_attempt_happened() {
        let chain = owned_chain(&[
            "https://a.example/v1 (idx=0) status=502 class=-",
            "all_upstreams_avoided total=1",
        ]);

        assert!(retry_info_for_chain(&chain).is_none());
    }

    #[test]
    fn should_never_retry_allows_on_class_to_override_never_on_status() {
        let plan = retry_plan(&RetryProfileName::Balanced.defaults());

        // A bare 400 is not blocked by the default status guardrail.
        let blocked_without_class = should_never_retry(&plan, 400, None);
        assert!(!blocked_without_class);

        // Cloudflare/WAF challenge pages stay retryable even when they arrive as 400.
        let blocked_for_challenge = should_never_retry(&plan, 400, Some("cloudflare_challenge"));
        assert!(!blocked_for_challenge);

        // Explicitly non-retryable client-side mistakes remain blocked.
        let blocked_for_client_error =
            should_never_retry(&plan, 400, Some("client_error_non_retryable"));
        assert!(blocked_for_client_error);
    }
}