Skip to main content

magi_core/
provider.rs

1// Author: Julian Bolivar
2// Version: 1.0.0
3// Date: 2026-04-05
4
5use crate::error::ProviderError;
6use std::sync::Arc;
7use std::time::Duration;
8
/// Configuration for LLM completion requests.
///
/// Controls parameters like token limits and sampling temperature.
/// Marked `#[non_exhaustive]` to allow adding fields in future versions
/// without breaking downstream crates. As a consequence, code outside this
/// crate cannot build the struct with a literal; start from
/// [`CompletionConfig::default`] and adjust the public fields instead.
///
/// Implements `PartialEq` so configurations can be compared by value
/// (for example in tests or when de-duplicating requests).
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub struct CompletionConfig {
    /// Maximum number of tokens in the LLM response.
    pub max_tokens: u32,
    /// Sampling temperature (0.0 = deterministic).
    pub temperature: f64,
}

impl Default for CompletionConfig {
    /// Returns the default configuration: `max_tokens = 4096` and
    /// `temperature = 0.0`.
    fn default() -> Self {
        Self {
            max_tokens: 4096,
            temperature: 0.0,
        }
    }
}
31
/// Abstraction for LLM backends.
///
/// Any LLM provider (Claude, Gemini, OpenAI, local models) implements this
/// trait. Uses `async-trait` because native async traits in Rust do not yet
/// support `dyn Trait` dispatch, which is required for `Arc<dyn LlmProvider>`
/// with `tokio::spawn`.
///
/// The `Send + Sync` bounds are required because `Arc<dyn LlmProvider>` is
/// shared across `tokio::spawn` tasks.
#[async_trait::async_trait]
pub trait LlmProvider: Send + Sync {
    /// Sends a completion request to the LLM provider.
    ///
    /// # Parameters
    /// - `system_prompt`: The system-level instruction for the LLM.
    /// - `user_prompt`: The user's input content.
    /// - `config`: Completion parameters (max_tokens, temperature).
    ///
    /// # Returns
    /// The LLM's text response, or a `ProviderError` on failure.
    ///
    /// # Errors
    /// Implementations report any failure to produce a response as a
    /// `ProviderError`.
    async fn complete(
        &self,
        system_prompt: &str,
        user_prompt: &str,
        config: &CompletionConfig,
    ) -> Result<String, ProviderError>;

    /// Returns the provider's name (e.g., "claude", "claude-cli", "openai").
    fn name(&self) -> &str;

    /// Returns the model identifier (e.g., "claude-sonnet-4-6").
    fn model(&self) -> &str;
}
65
/// Opt-in retry wrapper for any `LlmProvider`.
///
/// Wraps an inner provider and retries transient errors (timeout, network,
/// HTTP 500/429) up to `max_retries` times with exponential backoff starting
/// from `base_delay`. Non-retryable errors (auth, process, nested session,
/// other HTTP status codes) are returned immediately.
///
/// Implements `LlmProvider` itself, making it transparent to consumers.
pub struct RetryProvider {
    /// The wrapped provider that every call is forwarded to.
    inner: Arc<dyn LlmProvider>,
    /// Maximum number of retry attempts after the first failure.
    pub max_retries: u32,
    /// Delay before the first retry; it is doubled after each further retry.
    pub base_delay: Duration,
}
81
82impl RetryProvider {
83    /// Creates a new `RetryProvider` with default settings (3 retries, 1s delay).
84    ///
85    /// # Parameters
86    /// - `inner`: The provider to wrap with retry logic.
87    pub fn new(inner: Arc<dyn LlmProvider>) -> Self {
88        Self {
89            inner,
90            max_retries: 3,
91            base_delay: Duration::from_secs(1),
92        }
93    }
94
95    /// Creates a new `RetryProvider` with custom retry settings.
96    ///
97    /// # Parameters
98    /// - `inner`: The provider to wrap with retry logic.
99    /// - `max_retries`: Maximum retry attempts after the initial failure.
100    /// - `base_delay`: Initial delay between retries; doubles on each subsequent attempt.
101    pub fn with_config(
102        inner: Arc<dyn LlmProvider>,
103        max_retries: u32,
104        base_delay: Duration,
105    ) -> Self {
106        Self {
107            inner,
108            max_retries,
109            base_delay,
110        }
111    }
112}
113
114/// Determines whether a `ProviderError` is transient and should be retried.
115///
116/// Retryable errors:
117/// - `Timeout`: Provider did not respond in time.
118/// - `Network`: DNS, connection refused, etc.
119/// - `Http` with status 500 (server error) or 429 (rate limit).
120///
121/// Non-retryable errors:
122/// - `Auth`: Invalid credentials won't become valid on retry.
123/// - `Process`: CLI subprocess failure.
124/// - `NestedSession`: Structural environment issue.
125/// - `Http` with other status codes (e.g., 400, 403, 404).
126fn is_retryable(error: &ProviderError) -> bool {
127    match error {
128        ProviderError::Timeout { .. } | ProviderError::Network { .. } => true,
129        ProviderError::Http { status, .. } => *status == 500 || *status == 429,
130        _ => false,
131    }
132}
133
134#[async_trait::async_trait]
135impl LlmProvider for RetryProvider {
136    async fn complete(
137        &self,
138        system_prompt: &str,
139        user_prompt: &str,
140        config: &CompletionConfig,
141    ) -> Result<String, ProviderError> {
142        let mut last_error = None;
143        let mut delay = self.base_delay;
144        for attempt in 0..=self.max_retries {
145            match self
146                .inner
147                .complete(system_prompt, user_prompt, config)
148                .await
149            {
150                Ok(response) => return Ok(response),
151                Err(err) => {
152                    if !is_retryable(&err) || attempt == self.max_retries {
153                        return Err(err);
154                    }
155                    last_error = Some(err);
156                    tokio::time::sleep(delay).await;
157                    delay = delay.saturating_mul(2);
158                }
159            }
160        }
161        Err(last_error.expect("at least one attempt must have been made"))
162    }
163
164    fn name(&self) -> &str {
165        self.inner.name()
166    }
167
168    fn model(&self) -> &str {
169        self.inner.model()
170    }
171}
172
173#[cfg(test)]
174mod tests {
175    use super::*;
176    use std::sync::Arc;
177    use std::sync::atomic::{AtomicU32, Ordering};
178    use std::time::Duration;
179
    /// Manual mock provider for testing.
    ///
    /// Replays scripted results and counts how often `complete` is called.
    struct MockProvider {
        /// Value returned by `name()`.
        provider_name: String,
        /// Value returned by `model()`.
        provider_model: String,
        /// Scripted results; `complete` pops them from the end of the vector.
        responses: std::sync::Mutex<Vec<Result<String, ProviderError>>>,
        /// Number of `complete` calls made so far.
        call_count: AtomicU32,
    }
187
188    impl MockProvider {
189        fn new(name: &str, model: &str) -> Self {
190            Self {
191                provider_name: name.to_string(),
192                provider_model: model.to_string(),
193                responses: std::sync::Mutex::new(Vec::new()),
194                call_count: AtomicU32::new(0),
195            }
196        }
197
198        fn with_responses(
199            name: &str,
200            model: &str,
201            responses: Vec<Result<String, ProviderError>>,
202        ) -> Self {
203            // Reverse so we can pop from the end (FIFO order)
204            let mut reversed = responses;
205            reversed.reverse();
206            Self {
207                provider_name: name.to_string(),
208                provider_model: model.to_string(),
209                responses: std::sync::Mutex::new(reversed),
210                call_count: AtomicU32::new(0),
211            }
212        }
213
214        fn call_count(&self) -> u32 {
215            self.call_count.load(Ordering::SeqCst)
216        }
217    }
218
219    #[async_trait::async_trait]
220    impl LlmProvider for MockProvider {
221        async fn complete(
222            &self,
223            _system_prompt: &str,
224            _user_prompt: &str,
225            _config: &CompletionConfig,
226        ) -> Result<String, ProviderError> {
227            self.call_count.fetch_add(1, Ordering::SeqCst);
228            let mut responses = self.responses.lock().unwrap();
229            if let Some(result) = responses.pop() {
230                result
231            } else {
232                Ok("default response".to_string())
233            }
234        }
235
236        fn name(&self) -> &str {
237            &self.provider_name
238        }
239
240        fn model(&self) -> &str {
241            &self.provider_model
242        }
243    }
244
245    // -- CompletionConfig tests --
246
247    /// CompletionConfig::default has max_tokens=4096, temperature=0.0.
248    #[test]
249    fn test_completion_config_default_values() {
250        let config = CompletionConfig::default();
251        assert_eq!(config.max_tokens, 4096);
252        assert!((config.temperature - 0.0).abs() < f64::EPSILON);
253    }
254
255    /// CompletionConfig is #[non_exhaustive] — verify Default works and fields accessible.
256    #[test]
257    fn test_completion_config_is_non_exhaustive() {
258        let config = CompletionConfig::default();
259        assert_eq!(config.max_tokens, 4096);
260        assert!((config.temperature).abs() < f64::EPSILON);
261    }
262
263    // -- RetryProvider delegation tests --
264
265    /// RetryProvider wraps inner provider and delegates name().
266    #[tokio::test]
267    async fn test_retry_provider_delegates_name() {
268        let mock = Arc::new(MockProvider::new("test-provider", "test-model"));
269        let retry = RetryProvider::new(mock);
270        assert_eq!(retry.name(), "test-provider");
271    }
272
273    /// RetryProvider wraps inner provider and delegates model().
274    #[tokio::test]
275    async fn test_retry_provider_delegates_model() {
276        let mock = Arc::new(MockProvider::new("test-provider", "test-model"));
277        let retry = RetryProvider::new(mock);
278        assert_eq!(retry.model(), "test-model");
279    }
280
281    // -- RetryProvider retry behavior --
282
283    /// RetryProvider retries on ProviderError::Timeout up to max_retries.
284    #[tokio::test]
285    async fn test_retry_provider_retries_on_timeout() {
286        let mock = Arc::new(MockProvider::with_responses(
287            "p",
288            "m",
289            vec![
290                Err(ProviderError::Timeout {
291                    message: "t1".into(),
292                }),
293                Err(ProviderError::Timeout {
294                    message: "t2".into(),
295                }),
296                Ok("success".into()),
297            ],
298        ));
299        let retry = RetryProvider::with_config(mock.clone(), 3, Duration::from_millis(1));
300        let config = CompletionConfig::default();
301        let result = retry.complete("sys", "usr", &config).await;
302        assert!(result.is_ok());
303        assert_eq!(result.unwrap(), "success");
304        assert_eq!(mock.call_count(), 3);
305    }
306
307    /// RetryProvider retries on ProviderError::Http with status 500.
308    #[tokio::test]
309    async fn test_retry_provider_retries_on_http_500() {
310        let mock = Arc::new(MockProvider::with_responses(
311            "p",
312            "m",
313            vec![
314                Err(ProviderError::Http {
315                    status: 500,
316                    body: "err".into(),
317                }),
318                Ok("ok".into()),
319            ],
320        ));
321        let retry = RetryProvider::with_config(mock.clone(), 3, Duration::from_millis(1));
322        let config = CompletionConfig::default();
323        let result = retry.complete("sys", "usr", &config).await;
324        assert!(result.is_ok());
325        assert_eq!(mock.call_count(), 2);
326    }
327
328    /// RetryProvider retries on ProviderError::Http with status 429.
329    #[tokio::test]
330    async fn test_retry_provider_retries_on_http_429() {
331        let mock = Arc::new(MockProvider::with_responses(
332            "p",
333            "m",
334            vec![
335                Err(ProviderError::Http {
336                    status: 429,
337                    body: "rate limit".into(),
338                }),
339                Ok("ok".into()),
340            ],
341        ));
342        let retry = RetryProvider::with_config(mock.clone(), 3, Duration::from_millis(1));
343        let config = CompletionConfig::default();
344        let result = retry.complete("sys", "usr", &config).await;
345        assert!(result.is_ok());
346        assert_eq!(mock.call_count(), 2);
347    }
348
349    /// RetryProvider retries on ProviderError::Network.
350    #[tokio::test]
351    async fn test_retry_provider_retries_on_network() {
352        let mock = Arc::new(MockProvider::with_responses(
353            "p",
354            "m",
355            vec![
356                Err(ProviderError::Network {
357                    message: "dns".into(),
358                }),
359                Ok("ok".into()),
360            ],
361        ));
362        let retry = RetryProvider::with_config(mock.clone(), 3, Duration::from_millis(1));
363        let config = CompletionConfig::default();
364        let result = retry.complete("sys", "usr", &config).await;
365        assert!(result.is_ok());
366        assert_eq!(mock.call_count(), 2);
367    }
368
369    /// RetryProvider does NOT retry on ProviderError::Auth.
370    #[tokio::test]
371    async fn test_retry_provider_does_not_retry_on_auth() {
372        let mock = Arc::new(MockProvider::with_responses(
373            "p",
374            "m",
375            vec![Err(ProviderError::Auth {
376                message: "bad key".into(),
377            })],
378        ));
379        let retry = RetryProvider::with_config(mock.clone(), 3, Duration::from_millis(1));
380        let config = CompletionConfig::default();
381        let result = retry.complete("sys", "usr", &config).await;
382        assert!(result.is_err());
383        assert_eq!(mock.call_count(), 1);
384    }
385
386    /// RetryProvider does NOT retry on ProviderError::Process.
387    #[tokio::test]
388    async fn test_retry_provider_does_not_retry_on_process() {
389        let mock = Arc::new(MockProvider::with_responses(
390            "p",
391            "m",
392            vec![Err(ProviderError::Process {
393                exit_code: Some(1),
394                stderr: "fail".into(),
395            })],
396        ));
397        let retry = RetryProvider::with_config(mock.clone(), 3, Duration::from_millis(1));
398        let config = CompletionConfig::default();
399        let result = retry.complete("sys", "usr", &config).await;
400        assert!(result.is_err());
401        assert_eq!(mock.call_count(), 1);
402    }
403
404    /// RetryProvider does NOT retry on ProviderError::NestedSession.
405    #[tokio::test]
406    async fn test_retry_provider_does_not_retry_on_nested_session() {
407        let mock = Arc::new(MockProvider::with_responses(
408            "p",
409            "m",
410            vec![Err(ProviderError::NestedSession)],
411        ));
412        let retry = RetryProvider::with_config(mock.clone(), 3, Duration::from_millis(1));
413        let config = CompletionConfig::default();
414        let result = retry.complete("sys", "usr", &config).await;
415        assert!(result.is_err());
416        assert_eq!(mock.call_count(), 1);
417    }
418
419    /// RetryProvider does NOT retry on ProviderError::Http with 4xx (except 429).
420    #[tokio::test]
421    async fn test_retry_provider_does_not_retry_on_http_4xx() {
422        let mock = Arc::new(MockProvider::with_responses(
423            "p",
424            "m",
425            vec![Err(ProviderError::Http {
426                status: 403,
427                body: "forbidden".into(),
428            })],
429        ));
430        let retry = RetryProvider::with_config(mock.clone(), 3, Duration::from_millis(1));
431        let config = CompletionConfig::default();
432        let result = retry.complete("sys", "usr", &config).await;
433        assert!(result.is_err());
434        assert_eq!(mock.call_count(), 1);
435    }
436
437    /// RetryProvider returns last error after exhausting retries.
438    #[tokio::test]
439    async fn test_retry_provider_returns_last_error_after_exhausting_retries() {
440        let mock = Arc::new(MockProvider::with_responses(
441            "p",
442            "m",
443            vec![
444                Err(ProviderError::Timeout {
445                    message: "t1".into(),
446                }),
447                Err(ProviderError::Timeout {
448                    message: "t2".into(),
449                }),
450                Err(ProviderError::Timeout {
451                    message: "t3".into(),
452                }),
453            ],
454        ));
455        // max_retries=2 means 1 initial + 2 retries = 3 total attempts
456        let retry = RetryProvider::with_config(mock.clone(), 2, Duration::from_millis(1));
457        let config = CompletionConfig::default();
458        let result = retry.complete("sys", "usr", &config).await;
459        assert!(result.is_err());
460        assert_eq!(mock.call_count(), 3);
461        match result.unwrap_err() {
462            ProviderError::Timeout { message } => assert_eq!(message, "t3"),
463            other => panic!("expected Timeout, got: {other}"),
464        }
465    }
466
467    /// RetryProvider returns success on first successful retry.
468    #[tokio::test]
469    async fn test_retry_provider_returns_success_on_first_retry() {
470        let mock = Arc::new(MockProvider::with_responses(
471            "p",
472            "m",
473            vec![
474                Err(ProviderError::Timeout {
475                    message: "t1".into(),
476                }),
477                Ok("recovered".into()),
478            ],
479        ));
480        let retry = RetryProvider::with_config(mock.clone(), 3, Duration::from_millis(1));
481        let config = CompletionConfig::default();
482        let result = retry.complete("sys", "usr", &config).await;
483        assert!(result.is_ok());
484        assert_eq!(result.unwrap(), "recovered");
485        assert_eq!(mock.call_count(), 2);
486    }
487
488    /// RetryProvider default config: 3 retries, 1s delay.
489    #[test]
490    fn test_retry_provider_default_config() {
491        let mock = Arc::new(MockProvider::new("p", "m"));
492        let retry = RetryProvider::new(mock);
493        assert_eq!(retry.max_retries, 3);
494        assert_eq!(retry.base_delay, Duration::from_secs(1));
495    }
496}