Skip to main content

mofa_foundation/llm/
retry.rs

1//! Retry mechanism for LLM calls with intelligent error handling
2//!
3//! This module provides a retry executor that handles transient failures in LLM calls,
4//! with special support for JSON mode validation failures.
5
6use super::provider::LLMProvider;
7use super::types::*;
8use std::sync::Arc;
9use tracing::{debug, info, warn};
10
11/// Retry executor for LLM calls
12///
13/// Wraps an LLM provider with retry logic, supporting:
14/// - Configurable retry strategies (NoRetry, DirectRetry, PromptRetry)
15/// - Exponential backoff with jitter
16/// - JSON validation for JSON mode requests
17/// - Error-specific retry strategies
18pub struct RetryExecutor {
19    provider: Arc<dyn LLMProvider>,
20    policy: LLMRetryPolicy,
21}
22
23impl RetryExecutor {
24    /// Create a new retry executor
25    pub fn new(provider: Arc<dyn LLMProvider>, policy: LLMRetryPolicy) -> Self {
26        Self { provider, policy }
27    }
28
29    /// Execute a chat completion request with retry logic
30    pub async fn chat(
31        &self,
32        mut request: ChatCompletionRequest,
33    ) -> LLMResult<ChatCompletionResponse> {
34        let max_attempts = self.policy.max_attempts.max(1);
35        let mut error_history = Vec::new();
36
37        for attempt in 0..max_attempts {
38            // Apply backoff delay if this is a retry attempt
39            if attempt > 0 {
40                let delay = self.policy.backoff.delay(attempt - 1);
41                debug!(
42                    "Retry attempt {}/{} after {}ms",
43                    attempt + 1,
44                    max_attempts,
45                    delay.as_millis()
46                );
47                tokio::time::sleep(delay).await;
48            }
49
50            // Try to execute the request
51            match self.provider.chat(request.clone()).await {
52                Ok(response) => {
53                    // Validate JSON if in JSON mode
54                    if let Some(json_error) = self.validate_json_response(&request, &response) {
55                        let error = LLMError::SerializationError(json_error.to_string());
56                        if attempt < max_attempts - 1 && self.policy.should_retry_error(&error) {
57                            warn!(
58                                "JSON validation failed (attempt {}): {}",
59                                attempt + 1,
60                                json_error
61                            );
62                            error_history.push(error.clone());
63                            request = self.prepare_retry_request(request, &error);
64                            continue;
65                        }
66                        return Err(error);
67                    }
68                    // Success
69                    if attempt > 0 {
70                        info!("Request succeeded on attempt {}", attempt + 1);
71                    }
72                    return Ok(response);
73                }
74                Err(error) => {
75                    // Check if we should retry this error
76                    if attempt < max_attempts - 1 && self.policy.should_retry_error(&error) {
77                        warn!(
78                            "Request failed (attempt {}): {}, retrying",
79                            attempt + 1,
80                            error
81                        );
82                        error_history.push(error.clone());
83                        request = self.prepare_retry_request(request, &error);
84                        continue;
85                    }
86                    // No more retries or non-retryable error
87                    if !error_history.is_empty() {
88                        warn!(
89                            "Request failed after {} attempts. Last error: {}",
90                            attempt + 1,
91                            error
92                        );
93                    }
94                    return Err(error);
95                }
96            }
97        }
98
99        // This should not be reached, but handle it for completeness
100        Err(LLMError::Other(
101            "Retry loop completed without result".into(),
102        ))
103    }
104
105    /// Validate JSON response when JSON mode is enabled
106    ///
107    /// Returns `Some(JSONValidationError)` if validation fails, `None` otherwise.
108    fn validate_json_response(
109        &self,
110        request: &ChatCompletionRequest,
111        response: &ChatCompletionResponse,
112    ) -> Option<JSONValidationError> {
113        // Check if JSON mode is enabled
114        let is_json_mode = request
115            .response_format
116            .as_ref()
117            .map(|rf| rf.format_type == "json_object" || rf.format_type == "json_schema")
118            .unwrap_or(false);
119
120        if !is_json_mode {
121            return None;
122        }
123
124        let content = response.content()?;
125        let trimmed = content.trim();
126
127        // Handle markdown code blocks: ```json ... ```
128        let content_to_parse = if trimmed.starts_with("```json") {
129            trimmed
130                .strip_prefix("```json")
131                .and_then(|s| s.strip_suffix("```"))
132                .map(|s| s.trim())
133                .unwrap_or(trimmed)
134        } else if trimmed.starts_with("```") {
135            trimmed
136                .strip_prefix("```")
137                .and_then(|s| s.strip_suffix("```"))
138                .map(|s| s.trim())
139                .unwrap_or(trimmed)
140        } else {
141            trimmed
142        };
143
144        // Try to parse as JSON
145        match serde_json::from_str::<serde_json::Value>(content_to_parse) {
146            Ok(_) => None,
147            Err(e) => Some(JSONValidationError {
148                raw_content: content.to_string(),
149                parse_error: e.to_string(),
150                expected_schema: request
151                    .response_format
152                    .as_ref()
153                    .and_then(|rf| rf.json_schema.clone()),
154            }),
155        }
156    }
157
158    /// Prepare request for retry based on the error and strategy
159    fn prepare_retry_request(
160        &self,
161        mut request: ChatCompletionRequest,
162        error: &LLMError,
163    ) -> ChatCompletionRequest {
164        let strategy = self.policy.strategy_for_error(error);
165
166        match strategy {
167            RetryStrategy::NoRetry | RetryStrategy::DirectRetry => {
168                // No modification needed
169                request
170            }
171            RetryStrategy::PromptRetry => {
172                // Append error context to system prompt
173                let error_message = format!(
174                    "Previous attempt failed with error: {}. The response must be valid JSON.",
175                    error
176                );
177
178                // Find or create system message
179                if let Some(msg) = request.messages.iter_mut().find(|m| m.role == Role::System) {
180                    // Append to existing system message
181                    msg.content = Some(MessageContent::Text(format!(
182                        "{}\n\n[RETRY CONTEXT: {}. Please fix the JSON and try again.]",
183                        msg.text_content().unwrap_or(""),
184                        error_message
185                    )));
186                } else {
187                    // No system message exists, insert one at the beginning
188                    request.messages.insert(
189                        0,
190                        ChatMessage::system(format!(
191                            "[RETRY CONTEXT: {}. Please fix the JSON and try again.]",
192                            error_message
193                        )),
194                    );
195                }
196                request
197            }
198        }
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Scripted provider: the n-th call returns the n-th entry of
    /// `responses`; any further call yields an "Unexpected call" error.
    struct MockProvider {
        responses: Vec<LLMResult<ChatCompletionResponse>>,
        call_count: AtomicUsize,
    }

    impl MockProvider {
        fn new(responses: Vec<LLMResult<ChatCompletionResponse>>) -> Self {
            Self {
                responses,
                call_count: AtomicUsize::new(0),
            }
        }

        /// Number of times `chat` has been invoked so far.
        fn calls(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait::async_trait]
    impl LLMProvider for MockProvider {
        fn name(&self) -> &str {
            "mock"
        }

        async fn chat(&self, _request: ChatCompletionRequest) -> LLMResult<ChatCompletionResponse> {
            let index = self.call_count.fetch_add(1, Ordering::SeqCst);
            self.responses
                .get(index)
                .cloned()
                .unwrap_or_else(|| Err(LLMError::Other("Unexpected call".to_string())))
        }
    }

    /// Build a single-choice assistant response carrying `content`.
    fn create_json_response(content: &str) -> ChatCompletionResponse {
        ChatCompletionResponse {
            id: "test".to_string(),
            object: "chat.completion".to_string(),
            created: 0,
            model: "test-model".to_string(),
            choices: vec![Choice {
                index: 0,
                message: ChatMessage::assistant(content),
                finish_reason: Some(FinishReason::Stop),
                logprobs: None,
            }],
            usage: None,
            system_fingerprint: None,
        }
    }

    /// Build a JSON-mode request containing one user message.
    fn create_json_request() -> ChatCompletionRequest {
        let mut request = ChatCompletionRequest::new("test-model");
        request.messages.push(ChatMessage::user("Return JSON"));
        request.response_format = Some(ResponseFormat::json());
        request
    }

    #[tokio::test]
    async fn test_retry_success_on_second_attempt() {
        let provider = Arc::new(MockProvider::new(vec![
            Err(LLMError::NetworkError("Temporary failure".to_string())),
            Ok(create_json_response(r#"{"status": "ok"}"#)),
        ]));

        let executor = RetryExecutor::new(provider.clone(), LLMRetryPolicy::default());
        let result = executor.chat(create_json_request()).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().content().unwrap(), r#"{"status": "ok"}"#);
        assert_eq!(provider.calls(), 2);
    }

    #[tokio::test]
    async fn test_retry_json_validation_failure() {
        let provider = Arc::new(MockProvider::new(vec![
            Ok(create_json_response("Not valid JSON")),
            Ok(create_json_response(r#"{"valid": "json"}"#)),
        ]));

        let executor = RetryExecutor::new(provider.clone(), LLMRetryPolicy::default());
        let result = executor.chat(create_json_request()).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().content().unwrap(), r#"{"valid": "json"}"#);
        assert_eq!(provider.calls(), 2);
    }

    #[tokio::test]
    async fn test_retry_json_with_markdown_blocks() {
        let provider = Arc::new(MockProvider::new(vec![Ok(create_json_response(
            "```json\n{\"wrapped\": \"content\"}\n```",
        ))]));

        let executor = RetryExecutor::new(provider.clone(), LLMRetryPolicy::default());
        let result = executor.chat(create_json_request()).await;

        assert!(result.is_ok());
        // A fenced but valid payload must be accepted on the first attempt.
        assert_eq!(provider.calls(), 1);
    }

    #[tokio::test]
    async fn test_no_retry_exhausted() {
        let provider = Arc::new(MockProvider::new(vec![
            Err(LLMError::NetworkError("Persistent failure".to_string())),
            Err(LLMError::NetworkError("Still failing".to_string())),
            Err(LLMError::NetworkError("Giving up".to_string())),
        ]));

        let executor = RetryExecutor::new(provider.clone(), LLMRetryPolicy::default());
        let result = executor.chat(create_json_request()).await;

        assert!(result.is_err());
        // The default policy retries network errors (see
        // `test_retry_success_on_second_attempt`), so giving up must have
        // taken more than one call.
        assert!(provider.calls() >= 2);
    }

    #[tokio::test]
    async fn test_no_retry_policy() {
        let provider = Arc::new(MockProvider::new(vec![Err(LLMError::NetworkError(
            "Should not retry".to_string(),
        ))]));

        let executor = RetryExecutor::new(provider.clone(), LLMRetryPolicy::no_retry());
        let result = executor.chat(create_json_request()).await;

        assert!(result.is_err());
        // Without this check the test would also pass if the executor
        // retried and received the mock's "Unexpected call" error.
        assert_eq!(provider.calls(), 1);
    }

    // `prepare_retry_request` is synchronous, so no async runtime is needed.
    #[test]
    fn test_prompt_retry_modifies_system_message() {
        // Request with an existing system message in front of the user message.
        let mut request = create_json_request();
        request
            .messages
            .insert(0, ChatMessage::system("You are a helpful assistant."));
        assert_eq!(request.messages[0].role, Role::System);

        let error = LLMError::SerializationError("Invalid JSON".to_string());

        // The provider is never called in this test.
        let provider = Arc::new(MockProvider::new(Vec::new()));
        let executor = RetryExecutor::new(provider, LLMRetryPolicy::default());
        let modified_request = executor.prepare_retry_request(request, &error);

        // No message is added or removed; the system message is extended.
        assert_eq!(modified_request.messages.len(), 2);
        assert_eq!(modified_request.messages[0].role, Role::System);
        let system_content = modified_request.messages[0].text_content().unwrap();
        assert!(system_content.starts_with("You are a helpful assistant."));
        assert!(system_content.contains("RETRY CONTEXT"));
        assert!(system_content.contains("Invalid JSON"));
    }
}