use super::super::error::ChatError;
use rand::Rng;
/// Largest left-shift applied to the base delay in `backoff_delay_ms`,
/// i.e. the exponential multiplier never exceeds `2^BACKOFF_MAX_SHIFT`.
const BACKOFF_MAX_SHIFT: u64 = 10;
/// Jitter in `backoff_delay_ms` is drawn from `0..=(delay / JITTER_DIVISOR)`.
const JITTER_DIVISOR: u64 = 5;
/// Base delay (milliseconds) used by the network-transient policy.
const BASE_FAST_MS: u64 = 1_000;
/// Base delay (milliseconds) used by the network-error, server-overloaded
/// and abnormal-finish policies.
const BASE_MEDIUM_MS: u64 = 2_000;
/// Base delay (milliseconds) used by the server-error and
/// fallback-overloaded policies.
const BASE_SLOW_MS: u64 = 3_000;
/// Base delay (milliseconds) used when rate limited without a retry-after hint.
const BASE_RATE_LIMIT_MS: u64 = 5_000;
/// Delay cap (milliseconds) shared by most policies.
const CAP_DEFAULT_MS: u64 = 30_000;
/// Delay cap (milliseconds) for the abnormal-finish policy.
const CAP_ABNORMAL_MS: u64 = 20_000;
/// Delay cap (milliseconds) when rate limited without a retry-after hint.
const CAP_RATE_LIMIT_MS: u64 = 60_000;
/// Delay cap (milliseconds) when a retry-after hint is available.
const CAP_RETRY_AFTER_MS: u64 = 120_000;
/// `max_attempts` value for the network-transient and network-error policies.
const MAX_ATTEMPTS_AGGRESSIVE: u32 = 5;
/// `max_attempts` value for the server-overloaded policy.
const MAX_ATTEMPTS_MODERATE: u32 = 4;
/// `max_attempts` value for the server-error, blind rate-limit,
/// abnormal-finish and fallback-overloaded policies.
const MAX_ATTEMPTS_CONSERVATIVE: u32 = 3;
/// `max_attempts` value for the rate-limit policy that has a retry-after hint.
const MAX_ATTEMPTS_ONCE: u32 = 1;
/// Upper bound (seconds) applied to a retry-after hint before it is
/// converted to milliseconds.
const RETRY_AFTER_CAP_SECS: u64 = 120;
/// Retry parameters selected for a failed chat request by `retry_policy_for`.
#[derive(Debug)]
pub(super) struct RetryPolicy {
/// Attempt limit for this policy.
/// NOTE(review): whether this counts the initial request or only retries
/// is decided by the caller, which is not visible here — confirm.
pub(super) max_attempts: u32,
/// Base delay in milliseconds; presumably passed as `base_ms` to
/// `backoff_delay_ms` — verify against the caller.
pub(super) base_ms: u64,
/// Delay cap in milliseconds; presumably passed as `cap_ms` to
/// `backoff_delay_ms` — verify against the caller.
pub(super) cap_ms: u64,
}
impl RetryPolicy {
    /// Assembles a policy from its three tuning values.
    const fn from_parts(max_attempts: u32, base_ms: u64, cap_ms: u64) -> Self {
        Self {
            max_attempts,
            base_ms,
            cap_ms,
        }
    }

    /// Policy with the aggressive attempt limit, the fast base delay and the
    /// default cap.
    const fn network_transient() -> Self {
        Self::from_parts(MAX_ATTEMPTS_AGGRESSIVE, BASE_FAST_MS, CAP_DEFAULT_MS)
    }

    /// Policy with the aggressive attempt limit, the medium base delay and
    /// the default cap.
    const fn network_error() -> Self {
        Self::from_parts(MAX_ATTEMPTS_AGGRESSIVE, BASE_MEDIUM_MS, CAP_DEFAULT_MS)
    }

    /// Policy with the moderate attempt limit, the medium base delay and the
    /// default cap.
    const fn server_overloaded() -> Self {
        Self::from_parts(MAX_ATTEMPTS_MODERATE, BASE_MEDIUM_MS, CAP_DEFAULT_MS)
    }

    /// Policy with the conservative attempt limit, the slow base delay and
    /// the default cap.
    const fn server_error() -> Self {
        Self::from_parts(MAX_ATTEMPTS_CONSERVATIVE, BASE_SLOW_MS, CAP_DEFAULT_MS)
    }

    /// Policy for a rate limit that carries a retry-after hint of `secs`
    /// seconds.
    ///
    /// The hint is clamped to `RETRY_AFTER_CAP_SECS` and converted to
    /// milliseconds to become the base delay; only a single attempt is
    /// allowed.
    fn rate_limit_with_retry_after(secs: u64) -> Self {
        let capped_secs = std::cmp::min(secs, RETRY_AFTER_CAP_SECS);
        let wait_ms = capped_secs * 1_000;
        Self::from_parts(MAX_ATTEMPTS_ONCE, wait_ms, CAP_RETRY_AFTER_MS)
    }

    /// Policy for a rate limit without a retry-after hint: conservative
    /// attempt limit, rate-limit base delay and rate-limit cap.
    const fn rate_limit_blind() -> Self {
        Self::from_parts(
            MAX_ATTEMPTS_CONSERVATIVE,
            BASE_RATE_LIMIT_MS,
            CAP_RATE_LIMIT_MS,
        )
    }

    /// Policy with the conservative attempt limit, the medium base delay and
    /// the abnormal-finish cap.
    const fn abnormal_finish() -> Self {
        Self::from_parts(MAX_ATTEMPTS_CONSERVATIVE, BASE_MEDIUM_MS, CAP_ABNORMAL_MS)
    }

    /// Policy with the conservative attempt limit, the slow base delay and
    /// the default cap (same values as `server_error`).
    const fn fallback_overloaded() -> Self {
        Self::from_parts(MAX_ATTEMPTS_CONSERVATIVE, BASE_SLOW_MS, CAP_DEFAULT_MS)
    }
}
/// Chooses the retry policy for `error`, or `None` when this function does
/// not consider the error retryable.
///
/// Mapping:
/// - `NetworkTimeout` / `StreamInterrupted` -> network-transient policy
/// - `NetworkError` -> network-error policy
/// - `ApiServerError` with status 503, 504 or 529 -> server-overloaded policy;
///   status 500 or 502 -> server-error policy; any other status -> `None`
/// - `ApiRateLimit` -> retry-after policy when `retry_after_secs` is set,
///   otherwise the blind rate-limit policy
/// - `AbnormalFinish` whose reason is exactly `"network_error"`, `"timeout"`
///   or `"overloaded"` -> abnormal-finish policy
/// - `Other` whose message contains one of the overload markers ->
///   fallback-overloaded policy
/// - everything else -> `None`
pub(super) fn retry_policy_for(error: &ChatError) -> Option<RetryPolicy> {
    // Substrings that mark an otherwise unclassified error as "overloaded".
    // The first two are Chinese for "traffic too heavy" and "overload";
    // "1305" is presumably a provider error code — TODO confirm.
    const OVERLOAD_MARKERS: [&str; 5] = ["访问量过大", "过载", "overloaded", "too busy", "1305"];

    let policy = match error {
        ChatError::NetworkTimeout(_) | ChatError::StreamInterrupted(_) => {
            RetryPolicy::network_transient()
        }
        ChatError::NetworkError(_) => RetryPolicy::network_error(),
        ChatError::ApiServerError { status, .. } => match status {
            503 | 504 | 529 => RetryPolicy::server_overloaded(),
            500 | 502 => RetryPolicy::server_error(),
            _ => return None,
        },
        ChatError::ApiRateLimit {
            retry_after_secs, ..
        } => match retry_after_secs {
            Some(secs) => RetryPolicy::rate_limit_with_retry_after(*secs),
            None => RetryPolicy::rate_limit_blind(),
        },
        ChatError::AbnormalFinish(reason)
            if matches!(reason.as_str(), "network_error" | "timeout" | "overloaded") =>
        {
            RetryPolicy::abnormal_finish()
        }
        ChatError::Other(msg) if OVERLOAD_MARKERS.iter().any(|&marker| msg.contains(marker)) => {
            RetryPolicy::fallback_overloaded()
        }
        _ => return None,
    };

    Some(policy)
}
/// Computes the delay in milliseconds to wait before retry number `attempt`.
///
/// The delay is `base_ms * 2^(attempt - 1)`, with the exponent limited to
/// `BACKOFF_MAX_SHIFT` and the product limited to `cap_ms`, plus a random
/// jitter drawn uniformly from `0..=(delay / JITTER_DIVISOR)`.
///
/// Because the jitter is added after capping, the returned value can exceed
/// `cap_ms` by up to `cap_ms / JITTER_DIVISOR`.
///
/// `attempt` is 1-based; `0` is treated like `1` instead of underflowing.
/// All arithmetic saturates, so the function never panics on overflow.
pub(super) fn backoff_delay_ms(attempt: u32, base_ms: u64, cap_ms: u64) -> u64 {
    // `saturating_sub` guards against `attempt == 0`, which would otherwise
    // panic in debug builds and wrap to the maximum shift in release builds.
    let shift = u64::from(attempt.saturating_sub(1)).min(BACKOFF_MAX_SHIFT);
    let exp = base_ms.saturating_mul(1u64 << shift).min(cap_ms);
    let jitter = rand::thread_rng().gen_range(0..=(exp / JITTER_DIVISOR));
    // Saturate so that a cap close to `u64::MAX` cannot overflow.
    exp.saturating_add(jitter)
}