Skip to main content

lellm_agent/tools/
fallback.rs

1//! 降级策略 — 可注入的 Fallback 回调。
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use lellm_core::{ChatResponse, LlmError, Message, ToolError};
7
8/// 降级原因
9#[derive(Debug)]
10pub enum FallbackReason {
11    LlmError(LlmError),
12    ToolError(ToolError),
13    LoopDetected,
14    MaxIterationsReached,
15}
16
17/// Fallback 上下文
18pub struct FallbackContext {
19    pub reason: FallbackReason,
20    pub conversation: Arc<[Message]>,
21    pub attempt: usize,
22    pub max_attempts: usize,
23}
24
25/// Fallback 动作
26#[derive(Debug, Clone)]
27pub enum FallbackAction {
28    Retry,
29    RetryWithMessages(Vec<Message>),
30    SwitchProvider(String),
31    Complete(ChatResponse),
32    Abort,
33}
34
35/// Fallback 策略 trait
36#[async_trait]
37pub trait FallbackStrategy: Send + Sync {
38    async fn handle(&self, ctx: &FallbackContext) -> FallbackAction;
39}
40
/// Default fallback strategy.
///
/// Holds only a retry limit, so it is cheap to clone and safe to print;
/// `Debug` and `Clone` are derived for that reason.
#[derive(Debug, Clone)]
pub struct DefaultFallback {
    /// Upper bound (exclusive) on the attempt counter for which a retry is
    /// still granted.
    max_retries: usize,
}
45
46impl DefaultFallback {
47    pub fn new(max_retries: usize) -> Self {
48        Self { max_retries }
49    }
50
51    /// 判断错误是否可重试
52    fn is_retriable(error: &LlmError) -> bool {
53        match error {
54            LlmError::Timeout | LlmError::Network { .. } => true,
55            LlmError::ApiError { status, .. } => *status >= 500,
56            _ => false,
57        }
58    }
59}
60
61impl Default for DefaultFallback {
62    fn default() -> Self {
63        Self::new(3)
64    }
65}
66
67#[async_trait]
68impl FallbackStrategy for DefaultFallback {
69    async fn handle(&self, ctx: &FallbackContext) -> FallbackAction {
70        match &ctx.reason {
71            FallbackReason::LlmError(error) => {
72                if Self::is_retriable(error) && ctx.attempt < self.max_retries {
73                    FallbackAction::Retry
74                } else {
75                    FallbackAction::Abort
76                }
77            }
78            FallbackReason::ToolError(_)
79            | FallbackReason::LoopDetected
80            | FallbackReason::MaxIterationsReached => FallbackAction::Abort,
81        }
82    }
83}