mofa_foundation/llm/
retry.rs1use super::provider::LLMProvider;
7use super::types::*;
8use std::sync::Arc;
9use tracing::{debug, info, warn};
10
11pub struct RetryExecutor {
19 provider: Arc<dyn LLMProvider>,
20 policy: LLMRetryPolicy,
21}
22
23impl RetryExecutor {
24 pub fn new(provider: Arc<dyn LLMProvider>, policy: LLMRetryPolicy) -> Self {
26 Self { provider, policy }
27 }
28
29 pub async fn chat(
31 &self,
32 mut request: ChatCompletionRequest,
33 ) -> LLMResult<ChatCompletionResponse> {
34 let max_attempts = self.policy.max_attempts.max(1);
35 let mut error_history = Vec::new();
36
37 for attempt in 0..max_attempts {
38 if attempt > 0 {
40 let delay = self.policy.backoff.delay(attempt - 1);
41 debug!(
42 "Retry attempt {}/{} after {}ms",
43 attempt + 1,
44 max_attempts,
45 delay.as_millis()
46 );
47 tokio::time::sleep(delay).await;
48 }
49
50 match self.provider.chat(request.clone()).await {
52 Ok(response) => {
53 if let Some(json_error) = self.validate_json_response(&request, &response) {
55 let error = LLMError::SerializationError(json_error.to_string());
56 if attempt < max_attempts - 1 && self.policy.should_retry_error(&error) {
57 warn!(
58 "JSON validation failed (attempt {}): {}",
59 attempt + 1,
60 json_error
61 );
62 error_history.push(error.clone());
63 request = self.prepare_retry_request(request, &error);
64 continue;
65 }
66 return Err(error);
67 }
68 if attempt > 0 {
70 info!("Request succeeded on attempt {}", attempt + 1);
71 }
72 return Ok(response);
73 }
74 Err(error) => {
75 if attempt < max_attempts - 1 && self.policy.should_retry_error(&error) {
77 warn!(
78 "Request failed (attempt {}): {}, retrying",
79 attempt + 1,
80 error
81 );
82 error_history.push(error.clone());
83 request = self.prepare_retry_request(request, &error);
84 continue;
85 }
86 if !error_history.is_empty() {
88 warn!(
89 "Request failed after {} attempts. Last error: {}",
90 attempt + 1,
91 error
92 );
93 }
94 return Err(error);
95 }
96 }
97 }
98
99 Err(LLMError::Other(
101 "Retry loop completed without result".into(),
102 ))
103 }
104
105 fn validate_json_response(
109 &self,
110 request: &ChatCompletionRequest,
111 response: &ChatCompletionResponse,
112 ) -> Option<JSONValidationError> {
113 let is_json_mode = request
115 .response_format
116 .as_ref()
117 .map(|rf| rf.format_type == "json_object" || rf.format_type == "json_schema")
118 .unwrap_or(false);
119
120 if !is_json_mode {
121 return None;
122 }
123
124 let content = response.content()?;
125 let trimmed = content.trim();
126
127 let content_to_parse = if trimmed.starts_with("```json") {
129 trimmed
130 .strip_prefix("```json")
131 .and_then(|s| s.strip_suffix("```"))
132 .map(|s| s.trim())
133 .unwrap_or(trimmed)
134 } else if trimmed.starts_with("```") {
135 trimmed
136 .strip_prefix("```")
137 .and_then(|s| s.strip_suffix("```"))
138 .map(|s| s.trim())
139 .unwrap_or(trimmed)
140 } else {
141 trimmed
142 };
143
144 match serde_json::from_str::<serde_json::Value>(content_to_parse) {
146 Ok(_) => None,
147 Err(e) => Some(JSONValidationError {
148 raw_content: content.to_string(),
149 parse_error: e.to_string(),
150 expected_schema: request
151 .response_format
152 .as_ref()
153 .and_then(|rf| rf.json_schema.clone()),
154 }),
155 }
156 }
157
158 fn prepare_retry_request(
160 &self,
161 mut request: ChatCompletionRequest,
162 error: &LLMError,
163 ) -> ChatCompletionRequest {
164 let strategy = self.policy.strategy_for_error(error);
165
166 match strategy {
167 RetryStrategy::NoRetry | RetryStrategy::DirectRetry => {
168 request
170 }
171 RetryStrategy::PromptRetry => {
172 let error_message = format!(
174 "Previous attempt failed with error: {}. The response must be valid JSON.",
175 error
176 );
177
178 if let Some(msg) = request.messages.iter_mut().find(|m| m.role == Role::System) {
180 msg.content = Some(MessageContent::Text(format!(
182 "{}\n\n[RETRY CONTEXT: {}. Please fix the JSON and try again.]",
183 msg.text_content().unwrap_or(""),
184 error_message
185 )));
186 } else {
187 request.messages.insert(
189 0,
190 ChatMessage::system(format!(
191 "[RETRY CONTEXT: {}. Please fix the JSON and try again.]",
192 error_message
193 )),
194 );
195 }
196 request
197 }
198 }
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Provider stub that replays a scripted list of results and records how often it was
    /// called. Calls beyond the end of the script fail with `LLMError::Other`.
    struct MockProvider {
        responses: Vec<LLMResult<ChatCompletionResponse>>,
        call_count: AtomicUsize,
    }

    impl MockProvider {
        fn new(responses: Vec<LLMResult<ChatCompletionResponse>>) -> Self {
            Self {
                responses,
                call_count: AtomicUsize::new(0),
            }
        }

        /// Number of `chat` calls received so far.
        fn calls(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait::async_trait]
    impl LLMProvider for MockProvider {
        fn name(&self) -> &str {
            "mock"
        }

        async fn chat(&self, _request: ChatCompletionRequest) -> LLMResult<ChatCompletionResponse> {
            let index = self.call_count.fetch_add(1, Ordering::SeqCst);
            self.responses
                .get(index)
                .cloned()
                .unwrap_or_else(|| Err(LLMError::Other("Unexpected call".to_string())))
        }
    }

    /// Builds an executor over a shared mock, so the test can still inspect its call count.
    fn executor_for(provider: &Arc<MockProvider>, policy: LLMRetryPolicy) -> RetryExecutor {
        let provider: Arc<dyn LLMProvider> = provider.clone();
        RetryExecutor::new(provider, policy)
    }

    fn create_json_response(content: &str) -> ChatCompletionResponse {
        ChatCompletionResponse {
            id: "test".to_string(),
            object: "chat.completion".to_string(),
            created: 0,
            model: "test-model".to_string(),
            choices: vec![Choice {
                index: 0,
                message: ChatMessage::assistant(content),
                finish_reason: Some(FinishReason::Stop),
                logprobs: None,
            }],
            usage: None,
            system_fingerprint: None,
        }
    }

    fn create_json_request() -> ChatCompletionRequest {
        let mut request = ChatCompletionRequest::new("test-model");
        request.messages.push(ChatMessage::user("Return JSON"));
        request.response_format = Some(ResponseFormat::json());
        request
    }

    #[tokio::test]
    async fn test_retry_success_on_second_attempt() {
        let provider = Arc::new(MockProvider::new(vec![
            Err(LLMError::NetworkError("Temporary failure".to_string())),
            Ok(create_json_response(r#"{"status": "ok"}"#)),
        ]));
        let executor = executor_for(&provider, LLMRetryPolicy::default());

        let response = executor
            .chat(create_json_request())
            .await
            .expect("second attempt should succeed");
        assert_eq!(response.content().unwrap(), r#"{"status": "ok"}"#);
        assert_eq!(provider.calls(), 2);
    }

    #[tokio::test]
    async fn test_retry_json_validation_failure() {
        let provider = Arc::new(MockProvider::new(vec![
            Ok(create_json_response("Not valid JSON")),
            Ok(create_json_response(r#"{"valid": "json"}"#)),
        ]));
        let executor = executor_for(&provider, LLMRetryPolicy::default());

        let response = executor
            .chat(create_json_request())
            .await
            .expect("retry after invalid JSON should succeed");
        assert_eq!(response.content().unwrap(), r#"{"valid": "json"}"#);
        assert_eq!(provider.calls(), 2);
    }

    #[tokio::test]
    async fn test_retry_json_with_markdown_blocks() {
        let fenced = "```json\n{\"wrapped\": \"content\"}\n```";
        let provider = Arc::new(MockProvider::new(vec![Ok(create_json_response(fenced))]));
        let executor = executor_for(&provider, LLMRetryPolicy::default());

        let response = executor
            .chat(create_json_request())
            .await
            .expect("fenced JSON should validate");
        // The fence is only ignored for validation; the response is returned untouched.
        assert_eq!(response.content().unwrap(), fenced);
        assert_eq!(provider.calls(), 1);
    }

    #[tokio::test]
    async fn test_no_retry_exhausted() {
        let provider = Arc::new(MockProvider::new(vec![
            Err(LLMError::NetworkError("Persistent failure".to_string())),
            Err(LLMError::NetworkError("Still failing".to_string())),
            Err(LLMError::NetworkError("Giving up".to_string())),
        ]));
        let executor = executor_for(&provider, LLMRetryPolicy::default());

        let result = executor.chat(create_json_request()).await;
        assert!(result.is_err());
        // The default policy retries network errors at least once.
        assert!(provider.calls() >= 2);
    }

    #[tokio::test]
    async fn test_no_retry_policy() {
        let provider = Arc::new(MockProvider::new(vec![Err(LLMError::NetworkError(
            "Should not retry".to_string(),
        ))]));
        let executor = executor_for(&provider, LLMRetryPolicy::no_retry());

        let result = executor.chat(create_json_request()).await;
        assert!(result.is_err());
        assert_eq!(provider.calls(), 1);
    }

    #[tokio::test]
    async fn test_non_json_request_is_not_validated() {
        let provider = Arc::new(MockProvider::new(vec![Ok(create_json_response("plain text"))]));
        let executor = executor_for(&provider, LLMRetryPolicy::default());

        let mut request = ChatCompletionRequest::new("test-model");
        request.messages.push(ChatMessage::user("Say hi"));

        let response = executor
            .chat(request)
            .await
            .expect("non-JSON requests skip validation");
        assert_eq!(response.content().unwrap(), "plain text");
        assert_eq!(provider.calls(), 1);
    }

    #[test]
    fn test_validate_json_response_reports_and_accepts() {
        let executor = executor_for(
            &Arc::new(MockProvider::new(Vec::new())),
            LLMRetryPolicy::default(),
        );
        let request = create_json_request();

        let failure = executor
            .validate_json_response(&request, &create_json_response("Not valid JSON"))
            .expect("plain text is not JSON");
        assert_eq!(failure.raw_content, "Not valid JSON");

        assert!(executor
            .validate_json_response(&request, &create_json_response("```\n[1, 2]\n```"))
            .is_none());
    }

    #[tokio::test]
    async fn test_prompt_retry_modifies_system_message() {
        let mut request = create_json_request();
        request
            .messages
            .insert(0, ChatMessage::system("You are a helpful assistant."));
        assert_eq!(request.messages[0].role, Role::System);

        let error = LLMError::SerializationError("Invalid JSON".to_string());
        let provider = Arc::new(MockProvider::new(vec![Ok(create_json_response(
            r#"{"ok": true}"#,
        ))]));
        let executor = executor_for(&provider, LLMRetryPolicy::default());
        let modified_request = executor.prepare_retry_request(request.clone(), &error);

        assert_eq!(modified_request.messages[0].role, Role::System);
        let system_content = modified_request.messages[0].text_content().unwrap();
        assert!(system_content.starts_with("You are a helpful assistant."));
        assert!(system_content.contains("RETRY CONTEXT"));
        assert!(system_content.contains("Invalid JSON"));
    }

    #[test]
    fn test_prompt_retry_inserts_system_message_when_missing() {
        let request = create_json_request();
        assert!(request.messages.iter().all(|m| m.role != Role::System));

        let error = LLMError::SerializationError("Invalid JSON".to_string());
        let executor = executor_for(
            &Arc::new(MockProvider::new(Vec::new())),
            LLMRetryPolicy::default(),
        );
        let modified_request = executor.prepare_retry_request(request.clone(), &error);

        assert_eq!(modified_request.messages.len(), request.messages.len() + 1);
        assert_eq!(modified_request.messages[0].role, Role::System);
        assert!(modified_request.messages[0]
            .text_content()
            .unwrap()
            .contains("RETRY CONTEXT"));
    }
}