codex-helper-core 0.15.0

Core library for codex-helper.
Documentation
use std::collections::HashSet;

use crate::config::RetryStrategy;
use crate::lb::LoadBalancer;
use crate::logging::log_retry_trace;

use super::attempt_selection::station_upstreams_exhausted;
use super::retry::{RetryLayerOptions, RetryPlan};

/// Decides whether a request may fail over to a different station.
///
/// Returns `true` only when all three conditions hold:
/// * `strict_multi_config` is set,
/// * `plan.allow_cross_station_before_first_output` is set, and
/// * the provider layer's strategy is [`RetryStrategy::Failover`].
pub(super) fn cross_station_failover_enabled(
    strict_multi_config: bool,
    plan: &RetryPlan,
    provider_opt: &RetryLayerOptions,
) -> bool {
    // Either switch being off rules out cross-station failover outright.
    if !strict_multi_config || !plan.allow_cross_station_before_first_output {
        return false;
    }

    provider_opt.strategy == RetryStrategy::Failover
}

/// Returns the number of provider-level attempts to allow.
///
/// When `cross_station_failover_enabled` is `false` the limit is always `1`;
/// otherwise `provider_max_attempts` is returned unchanged.
pub(super) fn provider_attempt_limit(
    cross_station_failover_enabled: bool,
    provider_max_attempts: u32,
) -> u32 {
    if !cross_station_failover_enabled {
        return 1;
    }

    provider_max_attempts
}

/// Picks the first load balancer in `lbs` whose `service.name` is not in
/// `tried_stations`.
///
/// Returns a clone of that load balancer, or `None` when every entry has
/// already been tried (including when `lbs` is empty).
pub(super) fn next_provider_load_balancer(
    lbs: &[LoadBalancer],
    tried_stations: &HashSet<String>,
) -> Option<LoadBalancer> {
    for candidate in lbs {
        if !tried_stations.contains(&candidate.service.name) {
            return Some(candidate.clone());
        }
    }

    None
}

/// Reports whether the current station is exhausted after an attempt and
/// emits the matching retry-trace event.
///
/// The return value is the result of
/// `station_upstreams_exhausted(upstream_total, avoid_set)`.
///
/// A trace event is logged when the station is exhausted, or when it is not
/// exhausted but `avoid_set` is non-empty. Nothing is logged when the station
/// is not exhausted and `avoid_set` is empty.
pub(super) fn station_loop_action_after_attempt(
    service_name: &str,
    request_id: u64,
    station_name: &str,
    upstream_total: usize,
    avoid_set: &HashSet<usize>,
) -> bool {
    let exhausted = station_upstreams_exhausted(upstream_total, avoid_set);

    // The same trace helper covers both outcomes; the `exhausted` flag picks
    // the event it reports.
    if exhausted || !avoid_set.is_empty() {
        log_same_station_failover_trace(
            service_name,
            request_id,
            station_name,
            upstream_total,
            avoid_set,
            exhausted,
        );
    }

    exhausted
}

/// Inputs for [`log_cross_station_failover_blocked`].
///
/// Bundles the values that function uses to decide whether to emit a
/// `cross_station_failover_blocked` trace event, together with the values it
/// reports in that event.
pub(super) struct CrossStationFailoverBlockedParams<'a> {
    /// Reported as the `service` field of the trace event.
    pub(super) service_name: &'a str,
    /// Reported as the `request_id` field of the trace event.
    pub(super) request_id: u64,
    /// Reported as the `station_name` field of the trace event.
    pub(super) station_name: &'a str,
    /// Gate: the event is only emitted when this is `true`.
    pub(super) strict_multi_config: bool,
    /// Gate: the event is only emitted when this is `0`.
    pub(super) provider_attempt: u32,
    /// Gate: the event is only emitted when this is `false`.
    pub(super) cross_station_failover_enabled: bool,
    /// Provider-layer retry options. `max_attempts` must exceed `1` for the
    /// event to be emitted; `strategy` and `max_attempts` are also reported.
    pub(super) provider_opt: &'a RetryLayerOptions,
    /// Reported as `effective_provider_max_attempts`.
    pub(super) provider_attempt_limit: u32,
    /// Reported as `allow_cross_station_before_first_output`.
    pub(super) allow_cross_station_before_first_output: bool,
}

pub(super) fn log_cross_station_failover_blocked(params: CrossStationFailoverBlockedParams<'_>) {
    let CrossStationFailoverBlockedParams {
        service_name,
        request_id,
        station_name,
        strict_multi_config,
        provider_attempt,
        cross_station_failover_enabled,
        provider_opt,
        provider_attempt_limit,
        allow_cross_station_before_first_output,
    } = params;

    if !(strict_multi_config
        && provider_attempt == 0
        && !cross_station_failover_enabled
        && provider_opt.max_attempts > 1)
    {
        return;
    }

    log_retry_trace(serde_json::json!({
        "event": "cross_station_failover_blocked",
        "service": service_name,
        "request_id": request_id,
        "station_name": station_name,
        "provider_strategy": if provider_opt.strategy == RetryStrategy::Failover { "failover" } else { "same_upstream" },
        "configured_provider_max_attempts": provider_opt.max_attempts,
        "effective_provider_max_attempts": provider_attempt_limit,
        "allow_cross_station_before_first_output": allow_cross_station_before_first_output,
    }));
}

/// Copies the indices in `avoid` into a vector sorted in ascending order.
///
/// `HashSet` iteration order is unspecified, so sorting gives the trace
/// output a stable ordering.
fn sorted_avoid_indices(avoid: &HashSet<usize>) -> Vec<usize> {
    let mut ordered: Vec<usize> = Vec::with_capacity(avoid.len());
    ordered.extend(avoid.iter().copied());
    ordered.sort_unstable();
    ordered
}

pub(super) fn log_same_station_failover_trace(
    service_name: &str,
    request_id: u64,
    station_name: &str,
    upstream_total: usize,
    avoid_set: &HashSet<usize>,
    exhausted: bool,
) {
    let event = if exhausted {
        "same_station_exhausted"
    } else {
        "same_station_failover"
    };
    log_retry_trace(serde_json::json!({
        "event": event,
        "service": service_name,
        "request_id": request_id,
        "station_name": station_name,
        "upstream_total": upstream_total,
        "avoided_indices": sorted_avoid_indices(avoid_set),
        "next_action": if exhausted {
            "consider_next_station"
        } else {
            "retry_another_upstream_within_station"
        },
    }));
}

#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};
    use std::sync::{Arc, Mutex};

    use super::*;
    use crate::config::{RetryStrategy, ServiceConfig, UpstreamAuth, UpstreamConfig};
    use crate::lb::LbState;

    /// Builds a load balancer for a service called `name` with one upstream.
    fn test_load_balancer(name: &str) -> LoadBalancer {
        let upstream = UpstreamConfig {
            base_url: format!("https://{name}.example/v1"),
            auth: UpstreamAuth::default(),
            tags: HashMap::new(),
            supported_models: HashMap::new(),
            model_mapping: HashMap::new(),
        };
        let service = ServiceConfig {
            name: name.to_string(),
            alias: None,
            enabled: true,
            level: 1,
            upstreams: vec![upstream],
        };
        let shared_state = Arc::new(Mutex::new(HashMap::<String, LbState>::new()));

        LoadBalancer::new(Arc::new(service), shared_state)
    }

    #[test]
    fn next_provider_load_balancer_skips_tried_station_names() {
        let candidates = [test_load_balancer("a"), test_load_balancer("b")];
        let mut already_tried = HashSet::new();
        already_tried.insert(String::from("a"));

        let picked = next_provider_load_balancer(&candidates, &already_tried).expect("next");

        assert_eq!(picked.service.name, "b");
    }

    #[test]
    fn provider_attempt_limit_respects_cross_station_flag() {
        let configured_max = 4;

        assert_eq!(provider_attempt_limit(false, configured_max), 1);
        assert_eq!(provider_attempt_limit(true, configured_max), configured_max);
    }

    #[test]
    fn station_loop_action_ignores_out_of_range_avoids() {
        // Index 99 is outside the two upstreams and must not count as tried.
        let one_left = HashSet::from([0usize, 99usize]);
        let none_left = HashSet::from([0usize, 1usize, 99usize]);

        let exhausted_with_one_left =
            station_loop_action_after_attempt("codex", 1, "alpha", 2, &one_left);
        let exhausted_with_none_left =
            station_loop_action_after_attempt("codex", 1, "alpha", 2, &none_left);

        assert!(!exhausted_with_one_left);
        assert!(exhausted_with_none_left);
    }

    #[test]
    fn cross_station_failover_enabled_requires_failover_strategy_and_guardrail() {
        let provider_opt = RetryLayerOptions {
            max_attempts: 3,
            base_backoff_ms: 0,
            max_backoff_ms: 0,
            jitter_ms: 0,
            retry_status_ranges: Vec::new(),
            retry_error_classes: Vec::new(),
            strategy: RetryStrategy::Failover,
        };
        let mut plan = RetryPlan {
            upstream: provider_opt.clone(),
            route: provider_opt.clone(),
            allow_cross_station_before_first_output: true,
            never_status_ranges: Vec::new(),
            never_error_classes: Vec::new(),
            cloudflare_challenge_cooldown_secs: 0,
            cloudflare_timeout_cooldown_secs: 0,
            transport_cooldown_secs: 0,
            cooldown_backoff_factor: 1,
            cooldown_backoff_max_secs: 0,
        };

        // All three conditions met, then strict multi-config switched off.
        assert!(cross_station_failover_enabled(true, &plan, &provider_opt));
        assert!(!cross_station_failover_enabled(false, &plan, &provider_opt));

        // The plan-level guardrail switched off.
        plan.allow_cross_station_before_first_output = false;
        assert!(!cross_station_failover_enabled(true, &plan, &provider_opt));

        // Guardrail back on, but the provider strategy is not failover.
        let same_upstream_provider = RetryLayerOptions {
            strategy: RetryStrategy::SameUpstream,
            ..provider_opt.clone()
        };
        let guardrail_on_plan = RetryPlan {
            allow_cross_station_before_first_output: true,
            route: provider_opt.clone(),
            ..plan
        };
        assert!(!cross_station_failover_enabled(
            true,
            &guardrail_on_plan,
            &same_upstream_provider,
        ));
    }
}