git_iris/
providers.rs

1//! LLM Provider configuration.
2//!
3//! Single source of truth for supported providers and their defaults.
4
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::fmt;
8use std::str::FromStr;
9
10/// Supported LLM providers
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
12#[serde(rename_all = "lowercase")]
13pub enum Provider {
14    #[default]
15    OpenAI,
16    Anthropic,
17    Google,
18}
19
20impl Provider {
21    /// All available providers
22    pub const ALL: &'static [Provider] = &[Provider::OpenAI, Provider::Anthropic, Provider::Google];
23
24    /// Provider name as used in config files and CLI
25    pub const fn name(&self) -> &'static str {
26        match self {
27            Self::OpenAI => "openai",
28            Self::Anthropic => "anthropic",
29            Self::Google => "google",
30        }
31    }
32
33    /// Default model for complex analysis tasks
34    pub const fn default_model(&self) -> &'static str {
35        match self {
36            Self::OpenAI => "gpt-5.1",
37            Self::Anthropic => "claude-sonnet-4-5-20250929",
38            Self::Google => "gemini-3-pro-preview",
39        }
40    }
41
42    /// Fast model for simple tasks (status updates, parsing)
43    pub const fn default_fast_model(&self) -> &'static str {
44        match self {
45            Self::OpenAI => "gpt-5.1-mini",
46            Self::Anthropic => "claude-haiku-4-5-20251001",
47            Self::Google => "gemini-2.5-flash",
48        }
49    }
50
51    /// Context window size (max tokens)
52    pub const fn context_window(&self) -> usize {
53        match self {
54            Self::OpenAI => 128_000,
55            Self::Anthropic => 200_000,
56            Self::Google => 1_000_000,
57        }
58    }
59
60    /// Environment variable name for the API key
61    pub const fn api_key_env(&self) -> &'static str {
62        match self {
63            Self::OpenAI => "OPENAI_API_KEY",
64            Self::Anthropic => "ANTHROPIC_API_KEY",
65            Self::Google => "GOOGLE_API_KEY",
66        }
67    }
68
69    /// Valid API key prefixes for format validation
70    ///
71    /// Returns the expected prefixes for the provider's API keys.
72    /// `OpenAI` has multiple valid prefixes (sk-, sk-proj-).
73    pub fn api_key_prefixes(&self) -> &'static [&'static str] {
74        match self {
75            Self::OpenAI => &["sk-", "sk-proj-"],
76            Self::Anthropic => &["sk-ant-"],
77            Self::Google => &[], // Google API keys don't have a consistent prefix
78        }
79    }
80
81    /// Expected API key prefix for basic format validation (primary prefix)
82    ///
83    /// Returns the primary expected prefix for display in error messages.
84    pub const fn api_key_prefix(&self) -> Option<&'static str> {
85        match self {
86            Self::OpenAI => Some("sk-"),
87            Self::Anthropic => Some("sk-ant-"),
88            Self::Google => None,
89        }
90    }
91
92    /// Validate API key format
93    ///
94    /// Performs basic validation to catch obvious misconfigurations:
95    /// - Checks for expected prefix (`OpenAI`: `sk-` or `sk-proj-`, `Anthropic`: `sk-ant-`)
96    /// - Ensures key is not suspiciously short
97    ///
98    /// Returns `Ok(())` if valid, or a warning message if potentially invalid.
99    /// Note: A valid format doesn't guarantee the key works - it may still be
100    /// expired or revoked. This just catches typos.
101    pub fn validate_api_key_format(&self, key: &str) -> Result<(), String> {
102        // Check minimum length (API keys are typically 30+ chars)
103        if key.len() < 20 {
104            return Err(format!(
105                "{} API key appears too short (got {} chars, expected 20+)",
106                self.name(),
107                key.len()
108            ));
109        }
110
111        // Check expected prefixes
112        let prefixes = self.api_key_prefixes();
113        if !prefixes.is_empty() && !prefixes.iter().any(|p| key.starts_with(p)) {
114            let expected = if prefixes.len() == 1 {
115                format!("'{}'", prefixes[0])
116            } else {
117                prefixes
118                    .iter()
119                    .map(|p| format!("'{p}'"))
120                    .collect::<Vec<_>>()
121                    .join(" or ")
122            };
123            return Err(format!(
124                "{} API key should start with {} (key has unexpected prefix)",
125                self.name(),
126                expected
127            ));
128        }
129
130        Ok(())
131    }
132
133    /// Get all provider names as strings
134    pub fn all_names() -> Vec<&'static str> {
135        Self::ALL.iter().map(Self::name).collect()
136    }
137}
138
139impl FromStr for Provider {
140    type Err = ProviderError;
141
142    fn from_str(s: &str) -> Result<Self, Self::Err> {
143        let lower = s.to_lowercase();
144        // Handle legacy "claude" alias
145        let normalized = if lower == "claude" {
146            "anthropic"
147        } else {
148            &lower
149        };
150
151        Self::ALL
152            .iter()
153            .find(|p| p.name() == normalized)
154            .copied()
155            .ok_or_else(|| ProviderError::Unknown(s.to_string()))
156    }
157}
158
159impl fmt::Display for Provider {
160    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161        write!(f, "{}", self.name())
162    }
163}
164
165/// Provider configuration error
166#[derive(Debug, thiserror::Error)]
167pub enum ProviderError {
168    #[error("Unknown provider: {0}. Supported: openai, anthropic, google")]
169    Unknown(String),
170    #[error("API key required for provider: {0}")]
171    MissingApiKey(String),
172}
173
174/// Per-provider configuration
175#[derive(Debug, Clone, Default, Serialize, Deserialize)]
176pub struct ProviderConfig {
177    /// API key (loaded from env or config)
178    #[serde(default, skip_serializing_if = "String::is_empty")]
179    pub api_key: String,
180    /// Primary model for complex analysis
181    #[serde(default, skip_serializing_if = "String::is_empty")]
182    pub model: String,
183    /// Fast model for simple tasks
184    #[serde(default, skip_serializing_if = "Option::is_none")]
185    pub fast_model: Option<String>,
186    /// Token limit override
187    #[serde(default, skip_serializing_if = "Option::is_none")]
188    pub token_limit: Option<usize>,
189    /// Additional provider-specific params
190    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
191    pub additional_params: HashMap<String, String>,
192}
193
194impl ProviderConfig {
195    /// Create config with defaults for a provider
196    pub fn with_defaults(provider: Provider) -> Self {
197        Self {
198            api_key: String::new(),
199            model: provider.default_model().to_string(),
200            fast_model: Some(provider.default_fast_model().to_string()),
201            token_limit: None,
202            additional_params: HashMap::new(),
203        }
204    }
205
206    /// Get effective model (configured or default)
207    pub fn effective_model(&self, provider: Provider) -> &str {
208        if self.model.is_empty() {
209            provider.default_model()
210        } else {
211            &self.model
212        }
213    }
214
215    /// Get effective fast model (configured or default)
216    pub fn effective_fast_model(&self, provider: Provider) -> &str {
217        self.fast_model
218            .as_deref()
219            .unwrap_or_else(|| provider.default_fast_model())
220    }
221
222    /// Get effective token limit (configured or default)
223    pub fn effective_token_limit(&self, provider: Provider) -> usize {
224        self.token_limit
225            .unwrap_or_else(|| provider.context_window())
226    }
227
228    /// Check if this config has an API key set
229    pub fn has_api_key(&self) -> bool {
230        !self.api_key.is_empty()
231    }
232
233    /// Get API key if set (non-empty), otherwise None
234    ///
235    /// This is the canonical way to extract an API key for passing to
236    /// provider builders. Returns `None` for empty strings, allowing
237    /// fallback to environment variables.
238    pub fn api_key_if_set(&self) -> Option<&str> {
239        if self.api_key.is_empty() {
240            None
241        } else {
242            Some(&self.api_key)
243        }
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_from_str() {
        assert_eq!("openai".parse::<Provider>().ok(), Some(Provider::OpenAI));
        assert_eq!(
            "ANTHROPIC".parse::<Provider>().ok(),
            Some(Provider::Anthropic)
        );
        assert_eq!("claude".parse::<Provider>().ok(), Some(Provider::Anthropic)); // Legacy alias
        assert!("invalid".parse::<Provider>().is_err());
    }

    #[test]
    fn test_provider_display_round_trips_through_from_str() {
        for &provider in Provider::ALL {
            assert_eq!(provider.to_string(), provider.name());
            assert_eq!(provider.to_string().parse::<Provider>().ok(), Some(provider));
        }
    }

    #[test]
    fn test_all_names_follows_all() {
        assert_eq!(Provider::all_names(), vec!["openai", "anthropic", "google"]);
        assert_eq!(Provider::default(), Provider::OpenAI);
    }

    #[test]
    fn test_unknown_provider_error_lists_supported_names() {
        let msg = "bogus".parse::<Provider>().unwrap_err().to_string();
        assert!(msg.contains("bogus"));
        for name in Provider::all_names() {
            assert!(msg.contains(name), "error message should list {name}");
        }
    }

    #[test]
    fn test_provider_defaults() {
        assert_eq!(Provider::OpenAI.default_model(), "gpt-5.1");
        assert_eq!(Provider::Anthropic.context_window(), 200_000);
        assert_eq!(Provider::Google.api_key_env(), "GOOGLE_API_KEY");
    }

    #[test]
    fn test_provider_config_defaults() {
        let config = ProviderConfig::with_defaults(Provider::Anthropic);
        assert_eq!(config.model, "claude-sonnet-4-5-20250929");
        assert_eq!(
            config.fast_model.as_deref(),
            Some("claude-haiku-4-5-20251001")
        );
    }

    #[test]
    fn test_effective_values_fall_back_to_provider_defaults() {
        let config = ProviderConfig::default();
        let provider = Provider::Google;
        assert_eq!(config.effective_model(provider), provider.default_model());
        assert_eq!(
            config.effective_fast_model(provider),
            provider.default_fast_model()
        );
        assert_eq!(
            config.effective_token_limit(provider),
            provider.context_window()
        );
        assert!(!config.has_api_key());
    }

    #[test]
    fn test_effective_values_prefer_configured_overrides() {
        let config = ProviderConfig {
            model: "custom-model".to_string(),
            fast_model: Some("custom-fast".to_string()),
            token_limit: Some(4096),
            ..ProviderConfig::default()
        };
        assert_eq!(config.effective_model(Provider::OpenAI), "custom-model");
        assert_eq!(config.effective_fast_model(Provider::OpenAI), "custom-fast");
        assert_eq!(config.effective_token_limit(Provider::OpenAI), 4096);
    }

    #[test]
    fn test_api_key_prefix() {
        assert_eq!(Provider::OpenAI.api_key_prefix(), Some("sk-"));
        assert_eq!(Provider::Anthropic.api_key_prefix(), Some("sk-ant-"));
        assert_eq!(Provider::Google.api_key_prefix(), None);
    }

    #[test]
    fn test_api_key_if_set() {
        // Non-empty key returns Some
        let mut config = ProviderConfig::with_defaults(Provider::OpenAI);
        config.api_key = "sk-test-key-12345678901234567890".to_string();
        assert_eq!(
            config.api_key_if_set(),
            Some("sk-test-key-12345678901234567890")
        );
        assert!(config.has_api_key());

        // Empty key returns None
        config.api_key = String::new();
        assert_eq!(config.api_key_if_set(), None);
        assert!(!config.has_api_key());
    }

    #[test]
    fn test_api_key_prefixes() {
        // OpenAI accepts multiple prefixes
        assert_eq!(Provider::OpenAI.api_key_prefixes(), &["sk-", "sk-proj-"]);
        assert_eq!(Provider::Anthropic.api_key_prefixes(), &["sk-ant-"]);
        assert!(Provider::Google.api_key_prefixes().is_empty());
    }

    #[test]
    fn test_api_key_validation_valid_openai() {
        // Valid OpenAI key format (starts with sk-, long enough)
        let result = Provider::OpenAI.validate_api_key_format("sk-1234567890abcdefghijklmnop");
        assert!(result.is_ok());
    }

    #[test]
    fn test_api_key_validation_valid_openai_project_key() {
        // Valid OpenAI project key format (starts with sk-proj-, long enough)
        let result =
            Provider::OpenAI.validate_api_key_format("sk-proj-1234567890abcdefghijklmnop");
        assert!(result.is_ok());
    }

    #[test]
    fn test_api_key_validation_valid_anthropic() {
        // Valid Anthropic key format (starts with sk-ant-, long enough)
        let result =
            Provider::Anthropic.validate_api_key_format("sk-ant-1234567890abcdefghijklmnop");
        assert!(result.is_ok());
    }

    #[test]
    fn test_api_key_validation_valid_google() {
        // Google keys don't have a prefix requirement, just length
        let result = Provider::Google.validate_api_key_format("AIzaSyA1234567890abcdefgh");
        assert!(result.is_ok());
    }

    #[test]
    fn test_api_key_validation_too_short() {
        let result = Provider::OpenAI.validate_api_key_format("sk-short");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("too short"));
    }

    #[test]
    fn test_api_key_validation_wrong_prefix_openai() {
        // Long enough but wrong prefix
        let result = Provider::OpenAI.validate_api_key_format("wrong-prefix-1234567890abcdef");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.contains("should start with"));
        // Error should mention valid prefixes
        assert!(err.contains("'sk-'") || err.contains("'sk-proj-'"));
        // Verify we don't expose the actual key prefix in error messages
        assert!(!err.contains("wrong-"));
    }

    #[test]
    fn test_api_key_validation_wrong_prefix_anthropic() {
        // Has sk- but not sk-ant- (might be OpenAI key used for Anthropic)
        let result = Provider::Anthropic.validate_api_key_format("sk-1234567890abcdefghijklmnop");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.contains("sk-ant-"));
        assert!(err.contains("unexpected prefix"));
        // Verify we don't expose the actual key content
        assert!(!err.contains("1234567890"));
    }
}