use std::sync::atomic::{AtomicU8, AtomicUsize, Ordering};
use std::time::Duration;
use tracing::{debug, warn};
/// Degradation tier, ordered from least (`Normal`) to most (`AlgorithmOnly`) degraded.
///
/// The explicit discriminants are the `u8` encoding understood by the
/// `From<u8>` conversion below.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FallbackLevel {
    /// Baseline tier; also the value produced by `Default`.
    Normal = 0,
    /// First degraded tier, one step above `Normal`.
    Retry = 1,
    /// Second degraded tier.
    Simplified = 2,
    /// Highest tier; every raw value of 3 or more decodes to this variant.
    AlgorithmOnly = 3,
}

impl Default for FallbackLevel {
    /// Starts at [`FallbackLevel::Normal`].
    fn default() -> Self {
        FallbackLevel::Normal
    }
}

impl From<u8> for FallbackLevel {
    /// Decodes a raw level, saturating anything above 3 to `AlgorithmOnly`.
    fn from(value: u8) -> Self {
        match value {
            0 => FallbackLevel::Normal,
            1 => FallbackLevel::Retry,
            2 => FallbackLevel::Simplified,
            // Out-of-range values are clamped rather than rejected.
            3..=u8::MAX => FallbackLevel::AlgorithmOnly,
        }
    }
}
/// Tunables for retry backoff and fallback-level transitions.
#[derive(Debug, Clone)]
pub struct FallbackConfig {
    /// Retry budget: `FallbackManager::record_failure` answers `Retry` for
    /// retryable errors only while fewer than this many retries were started.
    pub max_retries: usize,
    /// Base backoff delay in milliseconds (the delay for attempt 0).
    pub initial_delay_ms: u64,
    /// Upper bound, in milliseconds, applied to every computed backoff delay.
    pub max_delay_ms: u64,
    /// Factor the delay is multiplied by for each additional attempt.
    pub backoff_multiplier: f64,
    /// Consecutive failures that trigger one level of escalation.
    pub failures_before_escalate: usize,
    /// Consecutive successes that trigger one level of de-escalation.
    pub successes_before_deescalate: usize,
}

impl Default for FallbackConfig {
    /// Three retries, 1 s initial delay doubling up to 10 s, escalate after
    /// three failures, de-escalate after two successes.
    fn default() -> Self {
        FallbackConfig {
            max_retries: 3,
            initial_delay_ms: 1_000,
            max_delay_ms: 10_000,
            backoff_multiplier: 2.0,
            failures_before_escalate: 3,
            successes_before_deescalate: 2,
        }
    }
}
/// Failure categories that drive the fallback decision.
///
/// The `is_retryable`, `needs_simplification` and `needs_algorithm_fallback`
/// helpers classify these variants.
#[derive(Debug, Clone, thiserror::Error)]
pub enum FallbackError {
    /// Network failure with a free-form description; classified as retryable.
    #[error("Network error: {0}")]
    Network(String),

    /// Rate limit hit; classified as retryable.
    #[error("Rate limited")]
    RateLimited,

    /// Token limit exceeded; classified as needing simplification.
    #[error("Token limit exceeded")]
    TokenLimitExceeded,

    /// LLM unavailable, with a description; classified as needing the
    /// algorithm fallback.
    #[error("LLM unavailable: {0}")]
    Unavailable(String),

    /// Response could not be parsed, with a description.
    #[error("Response parsing failed: {0}")]
    ParseError(String),

    /// Every fallback strategy has been tried; classified as needing the
    /// algorithm fallback.
    #[error("All fallback strategies exhausted")]
    Exhausted,
}
impl FallbackError {
    /// Returns `true` for `Network` and `RateLimited`, `false` otherwise.
    pub fn is_retryable(&self) -> bool {
        match self {
            Self::Network(_) | Self::RateLimited => true,
            Self::TokenLimitExceeded
            | Self::Unavailable(_)
            | Self::ParseError(_)
            | Self::Exhausted => false,
        }
    }

    /// Returns `true` only for `TokenLimitExceeded`.
    pub fn needs_simplification(&self) -> bool {
        match self {
            Self::TokenLimitExceeded => true,
            Self::Network(_)
            | Self::RateLimited
            | Self::Unavailable(_)
            | Self::ParseError(_)
            | Self::Exhausted => false,
        }
    }

    /// Returns `true` for `Unavailable` and `Exhausted`, `false` otherwise.
    pub fn needs_algorithm_fallback(&self) -> bool {
        match self {
            Self::Unavailable(_) | Self::Exhausted => true,
            Self::Network(_)
            | Self::RateLimited
            | Self::TokenLimitExceeded
            | Self::ParseError(_) => false,
        }
    }
}
/// Snapshot of fallback counters.
///
/// NOTE(review): `FallbackManager::stats` currently fills in only
/// `current_level`; every `usize` counter is left at its default of zero.
#[derive(Debug, Clone, Default)]
pub struct FallbackStats {
    /// Total number of attempts.
    pub total_attempts: usize,
    /// Number of successful attempts.
    pub successful: usize,
    /// Number of retried attempts.
    pub retried: usize,
    /// Number of simplified attempts.
    pub simplified: usize,
    /// Number of algorithm fallbacks.
    pub algorithm_fallbacks: usize,
    /// Fallback level at the time the snapshot was taken.
    pub current_level: FallbackLevel,
}
/// Tracks the current fallback level and the success/failure streaks that
/// move it, using atomics so every method can take `&self`.
pub struct FallbackManager {
    /// Thresholds and backoff parameters; fixed at construction.
    config: FallbackConfig,
    /// Current level stored as its `u8` discriminant (see `FallbackLevel`).
    current_level: AtomicU8,
    /// Failures recorded since the last success or escalation check reset.
    consecutive_failures: AtomicUsize,
    /// Successes recorded since the last failure or de-escalation check reset.
    consecutive_successes: AtomicUsize,
    /// Retries started since the last `reset_retry_count`/`reset`.
    retry_attempts: AtomicUsize,
}
impl std::fmt::Debug for FallbackManager {
    /// Formats a point-in-time view of the manager.
    ///
    /// The level is shown as the decoded `FallbackLevel` rather than the raw
    /// `u8`, and every counter is included so the output is sufficient to
    /// diagnose why a level change did or did not happen. Each value is loaded
    /// independently, so the fields are not a single atomic snapshot.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FallbackManager")
            .field("config", &self.config)
            .field("current_level", &self.current_level())
            .field(
                "consecutive_failures",
                &self.consecutive_failures.load(Ordering::Relaxed),
            )
            .field(
                "consecutive_successes",
                &self.consecutive_successes.load(Ordering::Relaxed),
            )
            .field(
                "retry_attempts",
                &self.retry_attempts.load(Ordering::Relaxed),
            )
            .finish()
    }
}
impl FallbackManager {
    /// Creates a manager at `FallbackLevel::Normal` with all counters at zero.
    pub fn new(config: FallbackConfig) -> Self {
        Self {
            config,
            current_level: AtomicU8::new(FallbackLevel::Normal as u8),
            consecutive_failures: AtomicUsize::new(0),
            consecutive_successes: AtomicUsize::new(0),
            retry_attempts: AtomicUsize::new(0),
        }
    }

    /// Creates a manager using `FallbackConfig::default()`.
    pub fn with_defaults() -> Self {
        Self::new(FallbackConfig::default())
    }

    /// Returns the current level, decoded from its stored `u8` form.
    pub fn current_level(&self) -> FallbackLevel {
        self.current_level.load(Ordering::Relaxed).into()
    }

    /// Returns `true` when the current level is `AlgorithmOnly`.
    pub fn is_algorithm_only(&self) -> bool {
        self.current_level() == FallbackLevel::AlgorithmOnly
    }

    /// Returns `true` when the current level is `Simplified` or `AlgorithmOnly`.
    pub fn should_simplify(&self) -> bool {
        matches!(
            self.current_level(),
            FallbackLevel::Simplified | FallbackLevel::AlgorithmOnly
        )
    }

    /// Computes the backoff delay for the given zero-based `attempt`.
    ///
    /// The delay is `initial_delay_ms * backoff_multiplier^attempt`, capped at
    /// `max_delay_ms` and truncated to whole milliseconds.
    pub fn retry_delay(&self, attempt: usize) -> Duration {
        // A plain `attempt as i32` wraps to a negative exponent for very large
        // attempts, which would shrink the delay; saturate instead.
        let exponent = attempt.min(i32::MAX as usize) as i32;
        let uncapped =
            self.config.initial_delay_ms as f64 * self.config.backoff_multiplier.powi(exponent);
        let capped = uncapped.min(self.config.max_delay_ms as f64);
        // Float-to-int `as` saturates, so an out-of-range value cannot wrap.
        Duration::from_millis(capped as u64)
    }

    /// Records a success: clears the failure streak and, once
    /// `successes_before_deescalate` consecutive successes accumulate, lowers
    /// the level by one step (never below `Normal`) and restarts the streak.
    pub fn record_success(&self) {
        self.consecutive_failures.store(0, Ordering::Relaxed);
        let successes = self.consecutive_successes.fetch_add(1, Ordering::Relaxed) + 1;
        if successes >= self.config.successes_before_deescalate {
            // Single atomic read-modify-write: a separate load followed by
            // `fetch_sub` lets two concurrent callers both pass the `> 0`
            // check and wrap the level from 0 to 255 (decoded as
            // `AlgorithmOnly`).
            let lowered = self.current_level.fetch_update(
                Ordering::Relaxed,
                Ordering::Relaxed,
                |level| level.checked_sub(1),
            );
            if let Ok(previous) = lowered {
                // `previous >= 1` here, otherwise `checked_sub` returned `None`.
                debug!(
                    "Fallback level de-escalated to {:?}",
                    FallbackLevel::from(previous - 1)
                );
            }
            self.consecutive_successes.store(0, Ordering::Relaxed);
        }
    }

    /// Records a failure and returns the action chosen for `error`.
    ///
    /// Clears the success streak, and escalates one level once
    /// `failures_before_escalate` consecutive failures accumulate (the failure
    /// streak then restarts). The returned action depends only on the error
    /// kind and the retry budget:
    ///
    /// * `Network` / `RateLimited`: `Retry` while fewer than `max_retries`
    ///   retries were started, otherwise `Escalate`.
    /// * `TokenLimitExceeded`: `Simplify`.
    /// * `Unavailable` / `Exhausted`: `UseAlgorithm`.
    /// * `ParseError`: `UseDefault`.
    pub fn record_failure(&self, error: &FallbackError) -> FallbackAction {
        self.consecutive_successes.store(0, Ordering::Relaxed);
        let failures = self.consecutive_failures.fetch_add(1, Ordering::Relaxed) + 1;
        if failures >= self.config.failures_before_escalate {
            self.escalate_level();
            self.consecutive_failures.store(0, Ordering::Relaxed);
        }

        match error {
            FallbackError::Network(_) | FallbackError::RateLimited => {
                if self.retry_attempts.load(Ordering::Relaxed) < self.config.max_retries {
                    FallbackAction::Retry
                } else {
                    FallbackAction::Escalate
                }
            }
            FallbackError::TokenLimitExceeded => FallbackAction::Simplify,
            FallbackError::Unavailable(_) | FallbackError::Exhausted => {
                FallbackAction::UseAlgorithm
            }
            FallbackError::ParseError(_) => FallbackAction::UseDefault,
        }
    }

    /// Raises the level by one step, saturating at `AlgorithmOnly`.
    fn escalate_level(&self) {
        let ceiling = FallbackLevel::AlgorithmOnly as u8;
        // Bounded increment as one atomic operation so concurrent escalations
        // cannot push the raw value past the highest discriminant.
        let raised = self.current_level.fetch_update(
            Ordering::Relaxed,
            Ordering::Relaxed,
            |level| {
                if level < ceiling {
                    Some(level + 1)
                } else {
                    None
                }
            },
        );
        if let Ok(previous) = raised {
            warn!(
                "Fallback level escalated to {:?}",
                FallbackLevel::from(previous + 1)
            );
        }
    }

    /// Counts one started retry against the `max_retries` budget.
    pub fn start_retry(&self) {
        self.retry_attempts.fetch_add(1, Ordering::Relaxed);
    }

    /// Clears the started-retry counter.
    pub fn reset_retry_count(&self) {
        self.retry_attempts.store(0, Ordering::Relaxed);
    }

    /// Returns the manager to `Normal` and zeroes every counter.
    pub fn reset(&self) {
        self.current_level
            .store(FallbackLevel::Normal as u8, Ordering::Relaxed);
        self.consecutive_failures.store(0, Ordering::Relaxed);
        self.consecutive_successes.store(0, Ordering::Relaxed);
        self.retry_attempts.store(0, Ordering::Relaxed);
    }

    /// Returns a stats snapshot.
    ///
    /// Only `current_level` is populated; the manager keeps no running totals,
    /// so every other field is zero.
    pub fn stats(&self) -> FallbackStats {
        FallbackStats {
            current_level: self.current_level(),
            ..Default::default()
        }
    }

    /// Borrows the configuration this manager was built with.
    pub fn config(&self) -> &FallbackConfig {
        &self.config
    }
}
/// Action returned by `FallbackManager::record_failure` for a given error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FallbackAction {
    /// Returned for `Network`/`RateLimited` errors while the retry budget is
    /// not yet used up.
    Retry,
    /// Returned for `TokenLimitExceeded` errors.
    Simplify,
    /// Returned for `Network`/`RateLimited` errors once the retry budget is
    /// used up.
    Escalate,
    /// Returned for `Unavailable`/`Exhausted` errors.
    UseAlgorithm,
    /// Returned for `ParseError` errors.
    UseDefault,
}
#[cfg(test)]
mod tests {
    //! Unit tests for level conversion, backoff computation, streak-driven
    //! level transitions and error-to-action mapping.

    use super::*;

    #[test]
    fn test_fallback_level_conversion() {
        assert_eq!(FallbackLevel::from(0), FallbackLevel::Normal);
        assert_eq!(FallbackLevel::from(1), FallbackLevel::Retry);
        assert_eq!(FallbackLevel::from(2), FallbackLevel::Simplified);
        assert_eq!(FallbackLevel::from(3), FallbackLevel::AlgorithmOnly);
        // Out-of-range raw values saturate to the highest level.
        assert_eq!(FallbackLevel::from(4), FallbackLevel::AlgorithmOnly);
    }

    #[test]
    fn test_fallback_manager_creation() {
        let manager = FallbackManager::with_defaults();
        assert_eq!(manager.current_level(), FallbackLevel::Normal);
        assert!(!manager.is_algorithm_only());
        assert!(!manager.should_simplify());
    }

    #[test]
    fn test_retry_delay() {
        let manager = FallbackManager::with_defaults();
        let d0 = manager.retry_delay(0);
        let d1 = manager.retry_delay(1);
        let d2 = manager.retry_delay(2);
        assert!(d1 > d0);
        assert!(d2 > d1);
    }

    #[test]
    fn test_retry_delay_values() {
        // Defaults: 1000 ms base, x2 per attempt, capped at 10000 ms.
        let manager = FallbackManager::with_defaults();
        assert_eq!(manager.retry_delay(0), Duration::from_millis(1000));
        assert_eq!(manager.retry_delay(1), Duration::from_millis(2000));
        assert_eq!(manager.retry_delay(3), Duration::from_millis(8000));
        assert_eq!(manager.retry_delay(4), Duration::from_millis(10000));
    }

    #[test]
    fn test_retry_delay_max() {
        let config = FallbackConfig {
            max_delay_ms: 5000,
            ..Default::default()
        };
        let manager = FallbackManager::new(config);
        let delay = manager.retry_delay(10);
        assert!(delay.as_millis() <= 5000);
    }

    #[test]
    fn test_record_success() {
        let manager = FallbackManager::with_defaults();
        manager.current_level.store(1, Ordering::Relaxed);
        for _ in 0..manager.config.successes_before_deescalate {
            manager.record_success();
        }
        assert_eq!(manager.current_level(), FallbackLevel::Normal);
    }

    #[test]
    fn test_deescalation_stops_at_normal() {
        let manager = FallbackManager::with_defaults();
        for _ in 0..(manager.config.successes_before_deescalate * 3) {
            manager.record_success();
        }
        // The raw value must stay at 0 rather than wrapping around.
        assert_eq!(manager.current_level.load(Ordering::Relaxed), 0);
        assert_eq!(manager.current_level(), FallbackLevel::Normal);
    }

    #[test]
    fn test_failure_resets_success_streak() {
        let manager = FallbackManager::with_defaults();
        manager.current_level.store(1, Ordering::Relaxed);
        manager.record_success();
        let _ = manager.record_failure(&FallbackError::TokenLimitExceeded);
        manager.record_success();
        // Two successes in total, but not consecutive: no de-escalation.
        assert_eq!(manager.current_level(), FallbackLevel::Retry);
    }

    #[test]
    fn test_record_failure_escalate() {
        let manager = FallbackManager::with_defaults();
        for _ in 0..manager.config.failures_before_escalate {
            let action = manager.record_failure(&FallbackError::Network("test".to_string()));
            assert!(matches!(
                action,
                FallbackAction::Retry | FallbackAction::Escalate
            ));
        }
        assert_eq!(manager.current_level(), FallbackLevel::Retry);
    }

    #[test]
    fn test_escalation_saturates() {
        let manager = FallbackManager::with_defaults();
        // Enough failures for five escalation steps; only three levels exist.
        for _ in 0..(manager.config.failures_before_escalate * 5) {
            let _ = manager.record_failure(&FallbackError::Unavailable("test".to_string()));
        }
        assert_eq!(manager.current_level.load(Ordering::Relaxed), 3);
        assert!(manager.is_algorithm_only());
        assert!(manager.should_simplify());
    }

    #[test]
    fn test_record_failure_token_limit() {
        let manager = FallbackManager::with_defaults();
        let action = manager.record_failure(&FallbackError::TokenLimitExceeded);
        assert_eq!(action, FallbackAction::Simplify);
    }

    #[test]
    fn test_record_failure_unavailable() {
        let manager = FallbackManager::with_defaults();
        let action = manager.record_failure(&FallbackError::Unavailable("test".to_string()));
        assert_eq!(action, FallbackAction::UseAlgorithm);
    }

    #[test]
    fn test_record_failure_parse_error() {
        let manager = FallbackManager::with_defaults();
        let action = manager.record_failure(&FallbackError::ParseError("test".to_string()));
        assert_eq!(action, FallbackAction::UseDefault);
    }

    #[test]
    fn test_retry_budget_exhaustion() {
        let manager = FallbackManager::with_defaults();
        for _ in 0..manager.config.max_retries {
            manager.start_retry();
        }
        assert_eq!(
            manager.record_failure(&FallbackError::RateLimited),
            FallbackAction::Escalate
        );
        manager.reset_retry_count();
        assert_eq!(
            manager.record_failure(&FallbackError::RateLimited),
            FallbackAction::Retry
        );
    }

    #[test]
    fn test_reset() {
        let manager = FallbackManager::with_defaults();
        manager.current_level.store(3, Ordering::Relaxed);
        manager.consecutive_failures.store(5, Ordering::Relaxed);
        manager.reset();
        assert_eq!(manager.current_level(), FallbackLevel::Normal);
        assert_eq!(manager.consecutive_failures.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_error_retryable() {
        assert!(FallbackError::Network("test".to_string()).is_retryable());
        assert!(FallbackError::RateLimited.is_retryable());
        assert!(!FallbackError::TokenLimitExceeded.is_retryable());
        assert!(!FallbackError::Unavailable("test".to_string()).is_retryable());
    }

    #[test]
    fn test_error_needs_simplification() {
        assert!(FallbackError::TokenLimitExceeded.needs_simplification());
        assert!(!FallbackError::Network("test".to_string()).needs_simplification());
    }

    #[test]
    fn test_error_needs_algorithm_fallback() {
        assert!(FallbackError::Unavailable("test".to_string()).needs_algorithm_fallback());
        assert!(FallbackError::Exhausted.needs_algorithm_fallback());
        assert!(!FallbackError::ParseError("test".to_string()).needs_algorithm_fallback());
        assert!(!FallbackError::RateLimited.needs_algorithm_fallback());
    }
}