// aster/core/retry_logic.rs
use regex::Regex;
use std::future::Future;
use thiserror::Error;
8
/// Smallest amount of remaining room (in tokens) for which a retry is still
/// attempted; `calculate_adjusted_max_tokens` returns `None` below this.
const MIN_OUTPUT_TOKENS: u64 = 3_000;

/// Tokens subtracted from the remaining room, in addition to the input size,
/// before a new `max_tokens` value is chosen.
const RESERVE_BUFFER: u64 = 1_000;
14
/// Token counts extracted from a context-overflow error message.
///
/// The fields map, in order, onto the three numbers of a message shaped like
/// `... exceed context limit: <input> + <max_tokens> > <limit>`.
///
/// `PartialEq`/`Eq` are derived so parsed values can be compared directly
/// (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContextOverflowError {
    /// First number in the message: the size of the input, in tokens.
    pub input_tokens: u64,
    /// Second number in the message: the `max_tokens` value that was requested.
    pub max_tokens: u64,
    /// Third number in the message: the limit the sum of the other two exceeded.
    pub context_limit: u64,
}
25
26#[derive(Debug, Error)]
28pub enum OverflowRecoveryError {
29 #[error("Not a context overflow error")]
30 NotOverflowError,
31 #[error("Cannot recover: input={input_tokens}, limit={context_limit}")]
32 CannotRecover {
33 input_tokens: u64,
34 context_limit: u64,
35 },
36 #[error("Max retries exceeded after {attempts} attempts")]
37 MaxRetriesExceeded { attempts: u32 },
38 #[error("Request failed: {0}")]
39 RequestFailed(String),
40}
41
42pub fn parse_context_overflow_error(status: u16, message: &str) -> Option<ContextOverflowError> {
47 if status != 400 {
49 return None;
50 }
51
52 let pattern =
54 Regex::new(r"input length and `max_tokens` exceed context limit: (\d+) \+ (\d+) > (\d+)")
55 .ok()?;
56
57 let captures = pattern.captures(message)?;
58
59 let input_tokens: u64 = captures.get(1)?.as_str().parse().ok()?;
60 let max_tokens: u64 = captures.get(2)?.as_str().parse().ok()?;
61 let context_limit: u64 = captures.get(3)?.as_str().parse().ok()?;
62
63 Some(ContextOverflowError {
64 input_tokens,
65 max_tokens,
66 context_limit,
67 })
68}
69
70pub fn calculate_adjusted_max_tokens(
77 overflow: &ContextOverflowError,
78 max_thinking_tokens: u64,
79) -> Option<u64> {
80 let available = overflow
81 .context_limit
82 .saturating_sub(overflow.input_tokens)
83 .saturating_sub(RESERVE_BUFFER);
84
85 if available < MIN_OUTPUT_TOKENS {
87 return None;
88 }
89
90 let thinking = max_thinking_tokens + 1;
92 let adjusted = available.max(MIN_OUTPUT_TOKENS).max(thinking);
93
94 Some(adjusted)
95}
96
97pub fn handle_context_overflow(
101 status: u16,
102 message: &str,
103 max_thinking_tokens: u64,
104) -> Result<u64, OverflowRecoveryError> {
105 let overflow = parse_context_overflow_error(status, message)
106 .ok_or(OverflowRecoveryError::NotOverflowError)?;
107
108 let adjusted = calculate_adjusted_max_tokens(&overflow, max_thinking_tokens).ok_or(
109 OverflowRecoveryError::CannotRecover {
110 input_tokens: overflow.input_tokens,
111 context_limit: overflow.context_limit,
112 },
113 )?;
114
115 tracing::warn!(
116 "Context overflow detected. Adjusting max_tokens from {} to {}",
117 overflow.max_tokens,
118 adjusted
119 );
120 tracing::warn!(
121 " Input: {}, Limit: {}, Available: {}",
122 overflow.input_tokens,
123 overflow.context_limit,
124 adjusted
125 );
126
127 Ok(adjusted)
128}
129
/// Settings consumed by [`execute_with_overflow_recovery`].
#[derive(Debug, Clone)]
pub struct OverflowRecoveryOptions {
    /// Value handed to the request closure on the first attempt; `None` is
    /// passed through as-is.
    pub max_tokens: Option<u64>,
    /// Forwarded to `calculate_adjusted_max_tokens`, which never returns a
    /// value at or below it.
    pub max_thinking_tokens: u64,
    /// Upper bound on how many times the request closure is invoked.
    pub max_retries: u32,
}

impl Default for OverflowRecoveryOptions {
    /// No initial `max_tokens`, a thinking budget of zero, and three attempts.
    fn default() -> Self {
        OverflowRecoveryOptions {
            max_retries: 3,
            max_thinking_tokens: 0,
            max_tokens: None,
        }
    }
}
150
/// Status and message of a failed request, in the form the recovery loop
/// inspects.
///
/// Error types used with `execute_with_overflow_recovery` must convert into
/// this type. `Debug` and `Clone` are derived to match the other public
/// structs in this module and so values can be logged or embedded in other
/// `Debug` types.
#[derive(Debug, Clone)]
pub struct RequestError {
    /// Status code of the failed request; overflow parsing only accepts `400`.
    pub status: u16,
    /// Error text that is searched for the context-overflow details.
    pub message: String,
}
156
157pub async fn execute_with_overflow_recovery<T, E, F, Fut>(
159 execute_request: F,
160 options: OverflowRecoveryOptions,
161 mut on_retry: Option<impl FnMut(u32, u64)>,
162) -> Result<T, OverflowRecoveryError>
163where
164 F: Fn(Option<u64>) -> Fut,
165 Fut: Future<Output = Result<T, E>>,
166 E: Into<RequestError>,
167{
168 let mut current_max_tokens = options.max_tokens;
169
170 for attempt in 1..=options.max_retries {
171 match execute_request(current_max_tokens).await {
172 Ok(result) => return Ok(result),
173 Err(error) => {
174 let req_error: RequestError = error.into();
175
176 let overflow =
177 match parse_context_overflow_error(req_error.status, &req_error.message) {
178 Some(o) => o,
179 None => {
180 return Err(OverflowRecoveryError::RequestFailed(req_error.message));
181 }
182 };
183
184 if attempt >= options.max_retries {
185 tracing::error!(
186 "Context overflow recovery failed after {} attempts",
187 options.max_retries
188 );
189 return Err(OverflowRecoveryError::MaxRetriesExceeded { attempts: attempt });
190 }
191
192 let adjusted =
193 match calculate_adjusted_max_tokens(&overflow, options.max_thinking_tokens) {
194 Some(a) => a,
195 None => {
196 return Err(OverflowRecoveryError::CannotRecover {
197 input_tokens: overflow.input_tokens,
198 context_limit: overflow.context_limit,
199 });
200 }
201 };
202
203 tracing::warn!(
204 "[Retry {}/{}] Context overflow detected. Adjusting max_tokens from {:?} to {}",
205 attempt,
206 options.max_retries,
207 current_max_tokens,
208 adjusted
209 );
210
211 current_max_tokens = Some(adjusted);
212
213 if let Some(ref mut callback) = on_retry {
214 callback(attempt, adjusted);
215 }
216 }
217 }
218 }
219
220 Err(OverflowRecoveryError::MaxRetriesExceeded {
221 attempts: options.max_retries,
222 })
223}