Skip to main content

magi_core/
provider.rs

1// Author: Julian Bolivar
2// Version: 1.0.0
3// Date: 2026-04-05
4
5use crate::error::ProviderError;
6use std::sync::Arc;
7use std::time::Duration;
8
/// Tunable parameters sent along with every LLM completion request.
///
/// Currently covers the response token budget and the sampling temperature.
/// The struct is `#[non_exhaustive]`, so further knobs can be introduced in
/// later versions without breaking crates that depend on this one.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct CompletionConfig {
    /// Upper bound on the number of tokens the LLM may produce.
    pub max_tokens: u32,
    /// Sampling temperature; `0.0` means deterministic output.
    pub temperature: f64,
}

impl Default for CompletionConfig {
    /// Yields the baseline configuration: 4096 tokens at temperature `0.0`.
    fn default() -> Self {
        const DEFAULT_MAX_TOKENS: u32 = 4096;
        const DEFAULT_TEMPERATURE: f64 = 0.0;

        CompletionConfig {
            max_tokens: DEFAULT_MAX_TOKENS,
            temperature: DEFAULT_TEMPERATURE,
        }
    }
}
31
32/// Abstraction for LLM backends.
33///
34/// Any LLM provider (Claude, Gemini, OpenAI, local models) implements this
35/// trait. Uses `async-trait` because native async traits in Rust do not yet
36/// support `dyn Trait` dispatch, which is required for `Arc<dyn LlmProvider>`
37/// with `tokio::spawn`.
38///
39/// The `Send + Sync` bounds are required because `Arc<dyn LlmProvider>` is
40/// shared across `tokio::spawn` tasks.
41#[async_trait::async_trait]
42pub trait LlmProvider: Send + Sync {
43    /// Sends a completion request to the LLM provider.
44    ///
45    /// # Parameters
46    /// - `system_prompt`: The system-level instruction for the LLM.
47    /// - `user_prompt`: The user's input content.
48    /// - `config`: Completion parameters (max_tokens, temperature).
49    ///
50    /// # Returns
51    /// The LLM's text response, or a `ProviderError` on failure.
52    async fn complete(
53        &self,
54        system_prompt: &str,
55        user_prompt: &str,
56        config: &CompletionConfig,
57    ) -> Result<String, ProviderError>;
58
59    /// Returns the provider's name (e.g., "claude", "claude-cli", "openai").
60    fn name(&self) -> &str;
61
62    /// Returns the model identifier (e.g., "claude-sonnet-4-6").
63    fn model(&self) -> &str;
64}
65
66/// Resolves a short Claude model alias to a full model identifier.
67///
68/// Used by both `ClaudeProvider` (HTTP) and `ClaudeCliProvider` (subprocess).
69/// Other providers (Gemini, OpenAI) should implement their own alias resolvers.
70///
71/// # Aliases
72///
73/// - `"sonnet"` → `"claude-sonnet-4-6"`
74/// - `"opus"` → `"claude-opus-4-7"`
75/// - `"haiku"` → `"claude-haiku-4-5-20251001"`
76/// - Any string containing `"claude-"` passes through as-is
77///
78/// # Errors
79///
80/// Returns `ProviderError::Auth` if the alias is unknown.
81///
82/// # Examples
83///
84/// ```
85/// use magi_core::provider::resolve_claude_alias;
86///
87/// assert_eq!(resolve_claude_alias("opus").unwrap(), "claude-opus-4-7");
88/// assert_eq!(resolve_claude_alias("claude-custom").unwrap(), "claude-custom");
89/// assert!(resolve_claude_alias("unknown").is_err());
90/// ```
91pub fn resolve_claude_alias(model: &str) -> Result<String, ProviderError> {
92    match model {
93        "sonnet" => Ok("claude-sonnet-4-6".to_string()),
94        "opus" => Ok("claude-opus-4-7".to_string()),
95        "haiku" => Ok("claude-haiku-4-5-20251001".to_string()),
96        m if m.contains("claude-") => Ok(m.to_string()),
97        _ => Err(ProviderError::Auth {
98            message: format!("unknown model alias: {model}"),
99        }),
100    }
101}
102
103/// Opt-in retry wrapper for any `LlmProvider`.
104///
105/// Wraps an inner provider and retries transient errors (timeout, network,
106/// HTTP 500/429) up to `max_retries` times with exponential backoff starting
107/// from `base_delay`. Non-retryable errors (auth, process, nested session,
108/// other HTTP status codes) are returned immediately.
109///
110/// Implements `LlmProvider` itself, making it transparent to consumers.
111pub struct RetryProvider {
112    inner: Arc<dyn LlmProvider>,
113    /// Maximum number of retry attempts after the first failure.
114    pub max_retries: u32,
115    /// Delay between retry attempts.
116    pub base_delay: Duration,
117}
118
119impl RetryProvider {
120    /// Creates a new `RetryProvider` with default settings (3 retries, 1s delay).
121    ///
122    /// # Parameters
123    /// - `inner`: The provider to wrap with retry logic.
124    pub fn new(inner: Arc<dyn LlmProvider>) -> Self {
125        Self {
126            inner,
127            max_retries: 3,
128            base_delay: Duration::from_secs(1),
129        }
130    }
131
132    /// Creates a new `RetryProvider` with custom retry settings.
133    ///
134    /// # Parameters
135    /// - `inner`: The provider to wrap with retry logic.
136    /// - `max_retries`: Maximum retry attempts after the initial failure.
137    /// - `base_delay`: Initial delay between retries; doubles on each subsequent attempt.
138    pub fn with_config(
139        inner: Arc<dyn LlmProvider>,
140        max_retries: u32,
141        base_delay: Duration,
142    ) -> Self {
143        Self {
144            inner,
145            max_retries,
146            base_delay,
147        }
148    }
149}
150
151/// Determines whether a `ProviderError` is transient and should be retried.
152///
153/// Retryable errors:
154/// - `Timeout`: Provider did not respond in time.
155/// - `Network`: DNS, connection refused, etc.
156/// - `Http` with status 500 (server error) or 429 (rate limit).
157///
158/// Non-retryable errors:
159/// - `Auth`: Invalid credentials won't become valid on retry.
160/// - `Process`: CLI subprocess failure.
161/// - `NestedSession`: Structural environment issue.
162/// - `Http` with other status codes (e.g., 400, 403, 404).
163fn is_retryable(error: &ProviderError) -> bool {
164    match error {
165        ProviderError::Timeout { .. } | ProviderError::Network { .. } => true,
166        ProviderError::Http { status, .. } => *status == 500 || *status == 429,
167        _ => false,
168    }
169}
170
171#[async_trait::async_trait]
172impl LlmProvider for RetryProvider {
173    async fn complete(
174        &self,
175        system_prompt: &str,
176        user_prompt: &str,
177        config: &CompletionConfig,
178    ) -> Result<String, ProviderError> {
179        let mut last_error = None;
180        let mut delay = self.base_delay;
181        for attempt in 0..=self.max_retries {
182            match self
183                .inner
184                .complete(system_prompt, user_prompt, config)
185                .await
186            {
187                Ok(response) => return Ok(response),
188                Err(err) => {
189                    if !is_retryable(&err) || attempt == self.max_retries {
190                        return Err(err);
191                    }
192                    last_error = Some(err);
193                    tokio::time::sleep(delay).await;
194                    delay = delay.saturating_mul(2);
195                }
196            }
197        }
198        Err(last_error.expect("at least one attempt must have been made"))
199    }
200
201    fn name(&self) -> &str {
202        self.inner.name()
203    }
204
205    fn model(&self) -> &str {
206        self.inner.model()
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    /// Hand-written mock backend that replays a scripted list of outcomes.
    struct MockProvider {
        provider_name: String,
        provider_model: String,
        /// Outcomes handed out front-to-back, one per `complete` call.
        script: Mutex<VecDeque<Result<String, ProviderError>>>,
        /// Number of `complete` calls received so far.
        calls: AtomicU32,
    }

    impl MockProvider {
        /// Mock with an empty script: every call yields the default response.
        fn new(name: &str, model: &str) -> Self {
            Self::with_responses(name, model, Vec::new())
        }

        /// Mock that returns `responses` in the given order, then falls back
        /// to the default response.
        fn with_responses(
            name: &str,
            model: &str,
            responses: Vec<Result<String, ProviderError>>,
        ) -> Self {
            Self {
                provider_name: name.to_string(),
                provider_model: model.to_string(),
                script: Mutex::new(VecDeque::from(responses)),
                calls: AtomicU32::new(0),
            }
        }

        fn call_count(&self) -> u32 {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait::async_trait]
    impl LlmProvider for MockProvider {
        async fn complete(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _config: &CompletionConfig,
        ) -> Result<String, ProviderError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            let next = self.script.lock().unwrap().pop_front();
            next.unwrap_or_else(|| Ok("default response".to_string()))
        }

        fn name(&self) -> &str {
            &self.provider_name
        }

        fn model(&self) -> &str {
            &self.provider_model
        }
    }

    /// Shorthand for a scripted `ProviderError::Timeout` outcome.
    fn timeout(tag: &'static str) -> Result<String, ProviderError> {
        Err(ProviderError::Timeout {
            message: tag.into(),
        })
    }

    /// Sends one request through a `RetryProvider` (1 ms base delay) that
    /// wraps a mock scripted with `responses`. Returns the outcome together
    /// with the number of calls the mock received.
    async fn run_scripted(
        responses: Vec<Result<String, ProviderError>>,
        max_retries: u32,
    ) -> (Result<String, ProviderError>, u32) {
        let mock = Arc::new(MockProvider::with_responses("p", "m", responses));
        let retry =
            RetryProvider::with_config(mock.clone(), max_retries, Duration::from_millis(1));
        let config = CompletionConfig::default();
        let outcome = retry.complete("sys", "usr", &config).await;
        (outcome, mock.call_count())
    }

    // -- CompletionConfig tests --

    /// The default configuration is max_tokens=4096, temperature=0.0.
    #[test]
    fn test_completion_config_default_values() {
        let config = CompletionConfig::default();
        assert_eq!(config.max_tokens, 4096);
        assert!((config.temperature - 0.0).abs() < f64::EPSILON);
    }

    /// `#[non_exhaustive]` has no effect inside the defining crate, so this
    /// only confirms that `Default` works and the fields are readable.
    #[test]
    fn test_completion_config_is_non_exhaustive() {
        let config = CompletionConfig::default();
        assert_eq!(config.max_tokens, 4096);
        assert!(config.temperature.abs() < f64::EPSILON);
    }

    // -- RetryProvider delegation tests --

    /// `name()` is answered by the wrapped provider.
    #[tokio::test]
    async fn test_retry_provider_delegates_name() {
        let retry = RetryProvider::new(Arc::new(MockProvider::new(
            "test-provider",
            "test-model",
        )));
        assert_eq!(retry.name(), "test-provider");
    }

    /// `model()` is answered by the wrapped provider.
    #[tokio::test]
    async fn test_retry_provider_delegates_model() {
        let retry = RetryProvider::new(Arc::new(MockProvider::new(
            "test-provider",
            "test-model",
        )));
        assert_eq!(retry.model(), "test-model");
    }

    // -- RetryProvider retry behavior --

    /// Timeouts are retried until a success arrives within `max_retries`.
    #[tokio::test]
    async fn test_retry_provider_retries_on_timeout() {
        let script = vec![timeout("t1"), timeout("t2"), Ok("success".into())];
        let (result, calls) = run_scripted(script, 3).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(calls, 3);
    }

    /// HTTP 500 is retried.
    #[tokio::test]
    async fn test_retry_provider_retries_on_http_500() {
        let script = vec![
            Err(ProviderError::Http {
                status: 500,
                body: "err".into(),
            }),
            Ok("ok".into()),
        ];
        let (result, calls) = run_scripted(script, 3).await;
        assert!(result.is_ok());
        assert_eq!(calls, 2);
    }

    /// HTTP 429 is retried.
    #[tokio::test]
    async fn test_retry_provider_retries_on_http_429() {
        let script = vec![
            Err(ProviderError::Http {
                status: 429,
                body: "rate limit".into(),
            }),
            Ok("ok".into()),
        ];
        let (result, calls) = run_scripted(script, 3).await;
        assert!(result.is_ok());
        assert_eq!(calls, 2);
    }

    /// Network errors are retried.
    #[tokio::test]
    async fn test_retry_provider_retries_on_network() {
        let script = vec![
            Err(ProviderError::Network {
                message: "dns".into(),
            }),
            Ok("ok".into()),
        ];
        let (result, calls) = run_scripted(script, 3).await;
        assert!(result.is_ok());
        assert_eq!(calls, 2);
    }

    /// Auth errors are returned after a single call.
    #[tokio::test]
    async fn test_retry_provider_does_not_retry_on_auth() {
        let script = vec![Err(ProviderError::Auth {
            message: "bad key".into(),
        })];
        let (result, calls) = run_scripted(script, 3).await;
        assert!(result.is_err());
        assert_eq!(calls, 1);
    }

    /// Process errors are returned after a single call.
    #[tokio::test]
    async fn test_retry_provider_does_not_retry_on_process() {
        let script = vec![Err(ProviderError::Process {
            exit_code: Some(1),
            stderr: "fail".into(),
        })];
        let (result, calls) = run_scripted(script, 3).await;
        assert!(result.is_err());
        assert_eq!(calls, 1);
    }

    /// Nested-session errors are returned after a single call.
    #[tokio::test]
    async fn test_retry_provider_does_not_retry_on_nested_session() {
        let script = vec![Err(ProviderError::NestedSession)];
        let (result, calls) = run_scripted(script, 3).await;
        assert!(result.is_err());
        assert_eq!(calls, 1);
    }

    /// HTTP 4xx other than 429 is returned after a single call.
    #[tokio::test]
    async fn test_retry_provider_does_not_retry_on_http_4xx() {
        let script = vec![Err(ProviderError::Http {
            status: 403,
            body: "forbidden".into(),
        })];
        let (result, calls) = run_scripted(script, 3).await;
        assert!(result.is_err());
        assert_eq!(calls, 1);
    }

    /// Once the retry budget is spent, the error of the final attempt is
    /// what the caller sees.
    #[tokio::test]
    async fn test_retry_provider_returns_last_error_after_exhausting_retries() {
        let script = vec![timeout("t1"), timeout("t2"), timeout("t3")];
        // max_retries=2 means 1 initial + 2 retries = 3 total attempts
        let (result, calls) = run_scripted(script, 2).await;
        assert!(result.is_err());
        assert_eq!(calls, 3);
        match result.unwrap_err() {
            ProviderError::Timeout { message } => assert_eq!(message, "t3"),
            other => panic!("expected Timeout, got: {other}"),
        }
    }

    /// A success on the very first retry is returned as-is.
    #[tokio::test]
    async fn test_retry_provider_returns_success_on_first_retry() {
        let script = vec![timeout("t1"), Ok("recovered".into())];
        let (result, calls) = run_scripted(script, 3).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "recovered");
        assert_eq!(calls, 2);
    }

    /// `RetryProvider::new` uses 3 retries and a 1s delay.
    #[test]
    fn test_retry_provider_default_config() {
        let retry = RetryProvider::new(Arc::new(MockProvider::new("p", "m")));
        assert_eq!(retry.max_retries, 3);
        assert_eq!(retry.base_delay, Duration::from_secs(1));
    }

    // -- resolve_claude_alias tests --

    #[test]
    fn test_resolve_claude_alias_opus_returns_claude_opus_4_7() {
        assert_eq!(resolve_claude_alias("opus").unwrap(), "claude-opus-4-7");
    }

    #[test]
    fn test_resolve_claude_alias_sonnet_returns_claude_sonnet_4_6() {
        assert_eq!(resolve_claude_alias("sonnet").unwrap(), "claude-sonnet-4-6");
    }

    #[test]
    fn test_resolve_claude_alias_haiku_returns_claude_haiku_4_5_20251001() {
        assert_eq!(
            resolve_claude_alias("haiku").unwrap(),
            "claude-haiku-4-5-20251001"
        );
    }

    /// Backward compatibility: callers that pinned the literal
    /// "claude-opus-4-6" in v0.1.x must get it back unchanged, because any
    /// string containing "claude-" passes through the resolver as-is.
    #[test]
    fn test_resolve_claude_alias_accepts_literal_claude_opus_4_6_passthrough() {
        let pinned = "claude-opus-4-6";
        assert_eq!(resolve_claude_alias(pinned).unwrap(), pinned);
    }
}