use std::sync::Arc;
use async_trait::async_trait;
use lellm_core::{LlmError, Message};
/// Information handed to a [`FallbackStrategy`] when a request has failed.
///
/// The context borrows the triggering error for `'a`, so it cannot outlive
/// the error it describes.
pub struct FallbackContext<'a> {
    /// The error that triggered the fallback decision.
    pub error: &'a LlmError,
    /// Attempt counter for the failing request. [`DefaultFallback`] keeps
    /// retrying while this value is below its retry budget.
    // NOTE(review): whether this counter is 0- or 1-based is decided by the
    // caller that builds the context — confirm there.
    pub attempt: usize,
    /// Iteration counter supplied by the caller.
    // NOTE(review): presumably the number of loop iterations completed so
    // far — confirm against the code that fills this in.
    pub iterations: usize,
    /// Shared, reference-counted slice of conversation messages.
    pub conversation: Arc<[Message]>,
}
/// Decision returned by a [`FallbackStrategy`] after inspecting a failure.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare decisions
/// directly (`action == FallbackAction::Retry`, `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FallbackAction {
    /// Try the failed operation again.
    Retry,
    /// Give up on the failed operation.
    Abort,
}
/// Policy that decides what to do after a request has failed.
///
/// Implementors must be `Send + Sync`, so a strategy value can be shared
/// between threads.
#[async_trait]
pub trait FallbackStrategy: Send + Sync {
    /// Inspects the failure described by `ctx` and returns the
    /// [`FallbackAction`] the caller should take.
    async fn handle(&self, ctx: &FallbackContext) -> FallbackAction;
}
/// Built-in [`FallbackStrategy`] that retries retriable errors until a fixed
/// retry budget is used up.
///
/// `Debug` and `Clone` are derived (matching [`FallbackAction`]) so the
/// strategy can be logged and duplicated like any other plain value.
#[derive(Debug, Clone)]
pub struct DefaultFallback {
    /// Retry budget: a retry is only granted while the context's `attempt`
    /// counter is strictly below this value.
    max_retries: usize,
}
impl DefaultFallback {
pub fn new(max_retries: usize) -> Self {
Self { max_retries }
}
fn is_retriable(error: &LlmError) -> bool {
match error {
LlmError::Timeout { .. } | LlmError::Network { .. } => true,
LlmError::Provider {
status: Some(s), ..
} => *s >= 500,
_ => false,
}
}
}
impl Default for DefaultFallback {
fn default() -> Self {
Self::new(3)
}
}
#[async_trait]
impl FallbackStrategy for DefaultFallback {
    /// Returns [`FallbackAction::Retry`] when the error is retriable and
    /// `ctx.attempt` is still below the configured retry budget; otherwise
    /// returns [`FallbackAction::Abort`].
    async fn handle(&self, ctx: &FallbackContext) -> FallbackAction {
        // `ctx.error` is already a `&LlmError`; pass it through as-is instead
        // of borrowing it again and relying on auto-deref of `&&LlmError`.
        if Self::is_retriable(ctx.error) && ctx.attempt < self.max_retries {
            FallbackAction::Retry
        } else {
            FallbackAction::Abort
        }
    }
}