Skip to main content

aster/core/
retry_logic.rs

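//! 上下文溢出自动恢复逻辑 (automatic recovery from context-overflow errors)
//!
//! 解析上下文溢出错误,动态调整 max_tokens,自动重试
//! (parse context-overflow errors, dynamically adjust `max_tokens`, and retry automatically)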
4
5use regex::Regex;
6use std::future::Future;
7use thiserror::Error;
8
/// Minimum number of output tokens a recovery must leave room for. If the space
/// remaining after the input and `RESERVE_BUFFER` is smaller than this, the
/// overflow is treated as unrecoverable.
const MIN_OUTPUT_TOKENS: u64 = 3_000;

/// Tokens held back as a safety margin when computing how much of the context
/// window is still available for output.
const RESERVE_BUFFER: u64 = 1_000;
14
/// Token counts extracted from an API context-overflow error message.
///
/// For the message
///
/// ```text
/// input length and `max_tokens` exceed context limit: 195000 + 8192 > 200000
/// ```
///
/// the fields are `input_tokens = 195000`, `max_tokens = 8192` and
/// `context_limit = 200000`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContextOverflowError {
    /// Tokens in the request input (first number in the message).
    pub input_tokens: u64,
    /// `max_tokens` value the failed request asked for (second number).
    pub max_tokens: u64,
    /// Context limit reported in the message (third number).
    pub context_limit: u64,
}
25
26/// 溢出恢复错误
27#[derive(Debug, Error)]
28pub enum OverflowRecoveryError {
29    #[error("Not a context overflow error")]
30    NotOverflowError,
31    #[error("Cannot recover: input={input_tokens}, limit={context_limit}")]
32    CannotRecover {
33        input_tokens: u64,
34        context_limit: u64,
35    },
36    #[error("Max retries exceeded after {attempts} attempts")]
37    MaxRetriesExceeded { attempts: u32 },
38    #[error("Request failed: {0}")]
39    RequestFailed(String),
40}
41
42/// 解析上下文溢出错误
43///
44/// 错误格式示例:
45/// "input length and `max_tokens` exceed context limit: 195000 + 8192 > 200000"
46pub fn parse_context_overflow_error(status: u16, message: &str) -> Option<ContextOverflowError> {
47    // 检查是否为 400 错误
48    if status != 400 {
49        return None;
50    }
51
52    // 匹配错误消息模式
53    let pattern =
54        Regex::new(r"input length and `max_tokens` exceed context limit: (\d+) \+ (\d+) > (\d+)")
55            .ok()?;
56
57    let captures = pattern.captures(message)?;
58
59    let input_tokens: u64 = captures.get(1)?.as_str().parse().ok()?;
60    let max_tokens: u64 = captures.get(2)?.as_str().parse().ok()?;
61    let context_limit: u64 = captures.get(3)?.as_str().parse().ok()?;
62
63    Some(ContextOverflowError {
64        input_tokens,
65        max_tokens,
66        context_limit,
67    })
68}
69
70/// 计算调整后的 max_tokens
71///
72/// 策略:
73/// 1. 计算可用空间 = contextLimit - inputTokens - reserve
74/// 2. 如果可用空间 < MIN_OUTPUT_TOKENS,无法恢复
75/// 3. 否则,返回 max(MIN_OUTPUT_TOKENS, available, thinkingTokens + 1)
76pub fn calculate_adjusted_max_tokens(
77    overflow: &ContextOverflowError,
78    max_thinking_tokens: u64,
79) -> Option<u64> {
80    let available = overflow
81        .context_limit
82        .saturating_sub(overflow.input_tokens)
83        .saturating_sub(RESERVE_BUFFER);
84
85    // 如果可用空间不足最小要求,无法恢复
86    if available < MIN_OUTPUT_TOKENS {
87        return None;
88    }
89
90    // 计算调整后的值
91    let thinking = max_thinking_tokens + 1;
92    let adjusted = available.max(MIN_OUTPUT_TOKENS).max(thinking);
93
94    Some(adjusted)
95}
96
97/// 处理上下文溢出错误
98///
99/// 返回调整后的 max_tokens,如果无法恢复则返回错误
100pub fn handle_context_overflow(
101    status: u16,
102    message: &str,
103    max_thinking_tokens: u64,
104) -> Result<u64, OverflowRecoveryError> {
105    let overflow = parse_context_overflow_error(status, message)
106        .ok_or(OverflowRecoveryError::NotOverflowError)?;
107
108    let adjusted = calculate_adjusted_max_tokens(&overflow, max_thinking_tokens).ok_or(
109        OverflowRecoveryError::CannotRecover {
110            input_tokens: overflow.input_tokens,
111            context_limit: overflow.context_limit,
112        },
113    )?;
114
115    tracing::warn!(
116        "Context overflow detected. Adjusting max_tokens from {} to {}",
117        overflow.max_tokens,
118        adjusted
119    );
120    tracing::warn!(
121        "  Input: {}, Limit: {}, Available: {}",
122        overflow.input_tokens,
123        overflow.context_limit,
124        adjusted
125    );
126
127    Ok(adjusted)
128}
129
/// Settings for [`execute_with_overflow_recovery`].
#[derive(Debug, Clone)]
pub struct OverflowRecoveryOptions {
    /// `max_tokens` handed to the first attempt; `None` is passed through to the
    /// request closure unchanged.
    pub max_tokens: Option<u64>,
    /// Thinking-token budget. Adjusted `max_tokens` values never go below this
    /// value plus one.
    pub max_thinking_tokens: u64,
    /// Attempt budget. With the default of 3, the request is sent at most three
    /// times, and the initial request counts as one attempt.
    pub max_retries: u32,
}
140
141impl Default for OverflowRecoveryOptions {
142    fn default() -> Self {
143        Self {
144            max_tokens: None,
145            max_thinking_tokens: 0,
146            max_retries: 3,
147        }
148    }
149}
150
/// Status and message of a failed request, as consumed by
/// [`execute_with_overflow_recovery`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RequestError {
    /// Status code of the failed response. Overflow errors are only recognised
    /// for `400`.
    pub status: u16,
    /// Error message text, inspected for the context-overflow pattern.
    pub message: String,
}
156
157/// 执行带溢出恢复的请求
158pub async fn execute_with_overflow_recovery<T, E, F, Fut>(
159    execute_request: F,
160    options: OverflowRecoveryOptions,
161    mut on_retry: Option<impl FnMut(u32, u64)>,
162) -> Result<T, OverflowRecoveryError>
163where
164    F: Fn(Option<u64>) -> Fut,
165    Fut: Future<Output = Result<T, E>>,
166    E: Into<RequestError>,
167{
168    let mut current_max_tokens = options.max_tokens;
169
170    for attempt in 1..=options.max_retries {
171        match execute_request(current_max_tokens).await {
172            Ok(result) => return Ok(result),
173            Err(error) => {
174                let req_error: RequestError = error.into();
175
176                let overflow =
177                    match parse_context_overflow_error(req_error.status, &req_error.message) {
178                        Some(o) => o,
179                        None => {
180                            return Err(OverflowRecoveryError::RequestFailed(req_error.message));
181                        }
182                    };
183
184                if attempt >= options.max_retries {
185                    tracing::error!(
186                        "Context overflow recovery failed after {} attempts",
187                        options.max_retries
188                    );
189                    return Err(OverflowRecoveryError::MaxRetriesExceeded { attempts: attempt });
190                }
191
192                let adjusted =
193                    match calculate_adjusted_max_tokens(&overflow, options.max_thinking_tokens) {
194                        Some(a) => a,
195                        None => {
196                            return Err(OverflowRecoveryError::CannotRecover {
197                                input_tokens: overflow.input_tokens,
198                                context_limit: overflow.context_limit,
199                            });
200                        }
201                    };
202
203                tracing::warn!(
204                    "[Retry {}/{}] Context overflow detected. Adjusting max_tokens from {:?} to {}",
205                    attempt,
206                    options.max_retries,
207                    current_max_tokens,
208                    adjusted
209                );
210
211                current_max_tokens = Some(adjusted);
212
213                if let Some(ref mut callback) = on_retry {
214                    callback(attempt, adjusted);
215                }
216            }
217        }
218    }
219
220    Err(OverflowRecoveryError::MaxRetriesExceeded {
221        attempts: options.max_retries,
222    })
223}