use std::sync::Arc;
use async_trait::async_trait;
use lellm_core::{ChatResponse, LlmError, Message, ToolError};
/// The condition that triggered a call to [`FallbackStrategy::handle`].
#[derive(Debug)]
pub enum FallbackReason {
    /// An error reported by the LLM layer.
    LlmError(LlmError),
    /// An error reported while running a tool.
    ToolError(ToolError),
    /// The caller detected that the interaction is looping.
    LoopDetected,
    /// The caller's iteration limit was reached.
    MaxIterationsReached,
}
/// Everything a [`FallbackStrategy`] is given when deciding how to recover.
///
/// Derives `Debug` so the context can be logged when a fallback fires; every
/// field type already implements `Debug`.
#[derive(Debug)]
pub struct FallbackContext {
    /// What went wrong.
    pub reason: FallbackReason,
    /// The conversation so far, shared via `Arc` so building a context does
    /// not copy the message history.
    pub conversation: Arc<[Message]>,
    /// Attempt counter supplied by the caller.
    /// NOTE(review): whether this is 0- or 1-based is decided by the caller — confirm.
    pub attempt: usize,
    /// Upper bound on attempts supplied by the caller.
    pub max_attempts: usize,
}
/// The decision returned by a [`FallbackStrategy`].
#[derive(Debug, Clone)]
pub enum FallbackAction {
    /// Try the failed step again without changing the conversation.
    Retry,
    /// Try again using the supplied messages.
    /// NOTE(review): whether these replace or extend the conversation is up to the caller — confirm.
    RetryWithMessages(Vec<Message>),
    /// Switch to the provider named by the string (naming scheme defined by the caller).
    SwitchProvider(String),
    /// Stop and use the given response as the result.
    Complete(ChatResponse),
    /// Give up.
    Abort,
}
/// Policy object that turns a failure description into a recovery action.
///
/// Implementors must be `Send + Sync`, so a single strategy instance can be
/// shared between threads.
#[async_trait]
pub trait FallbackStrategy: Send + Sync {
    /// Inspects `ctx` and returns the [`FallbackAction`] to take next.
    async fn handle(&self, ctx: &FallbackContext) -> FallbackAction;
}
/// Built-in [`FallbackStrategy`] that retries transient LLM errors a bounded
/// number of times and aborts on everything else.
///
/// Derives the common traits so the configuration can be logged, cloned and
/// compared (e.g. in tests or config diffs).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DefaultFallback {
    /// Retries are allowed only while the context's attempt counter is below this value.
    max_retries: usize,
}
impl DefaultFallback {
    /// Creates a strategy that allows retries while the context's attempt
    /// counter is below `max_retries`.
    #[must_use]
    pub const fn new(max_retries: usize) -> Self {
        Self { max_retries }
    }

    /// Returns `true` for errors that are likely transient and worth retrying:
    /// timeouts, network failures, HTTP 408 (Request Timeout), HTTP 429
    /// (Too Many Requests) and any status `>= 500`.
    fn is_retriable(error: &LlmError) -> bool {
        match error {
            LlmError::Timeout | LlmError::Network { .. } => true,
            // 408 and 429 are transient by HTTP definition (the request may be
            // repeated later); 5xx signals a server-side failure.
            LlmError::ApiError { status, .. } => {
                *status == 408 || *status == 429 || *status >= 500
            }
            // All other variants are treated as permanent failures.
            _ => false,
        }
    }
}
impl Default for DefaultFallback {
fn default() -> Self {
Self::new(3)
}
}
#[async_trait]
impl FallbackStrategy for DefaultFallback {
    /// Retries a retriable LLM error while `ctx.attempt < self.max_retries`;
    /// aborts in every other case.
    ///
    /// Only `self.max_retries` bounds the retries here; `ctx.max_attempts` is
    /// not consulted.
    async fn handle(&self, ctx: &FallbackContext) -> FallbackAction {
        // Exhaustive on purpose: a new `FallbackReason` variant must be
        // classified here explicitly.
        let llm_error = match &ctx.reason {
            FallbackReason::LlmError(inner) => inner,
            FallbackReason::ToolError(_)
            | FallbackReason::LoopDetected
            | FallbackReason::MaxIterationsReached => return FallbackAction::Abort,
        };

        let within_budget = ctx.attempt < self.max_retries;
        if within_budget && Self::is_retriable(llm_error) {
            FallbackAction::Retry
        } else {
            FallbackAction::Abort
        }
    }
}