Skip to main content

git_iris/
providers.rs

1//! LLM Provider configuration.
2//!
3//! Single source of truth for supported providers and their defaults.
4
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::fmt;
8use std::str::FromStr;
9
10/// Supported LLM providers
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
12#[serde(rename_all = "lowercase")]
13pub enum Provider {
14    #[default]
15    OpenAI,
16    Anthropic,
17    Google,
18}
19
20impl Provider {
21    /// All available providers
22    pub const ALL: &'static [Provider] = &[Provider::OpenAI, Provider::Anthropic, Provider::Google];
23
24    /// Provider name as used in config files and CLI
25    #[must_use]
26    pub const fn name(&self) -> &'static str {
27        match self {
28            Self::OpenAI => "openai",
29            Self::Anthropic => "anthropic",
30            Self::Google => "google",
31        }
32    }
33
34    /// Default model for complex analysis tasks
35    #[must_use]
36    pub const fn default_model(&self) -> &'static str {
37        match self {
38            Self::OpenAI => "gpt-5.4",
39            Self::Anthropic => "claude-opus-4-6",
40            Self::Google => "gemini-3-pro-preview",
41        }
42    }
43
44    /// Fast model for simple tasks (status updates, parsing)
45    #[must_use]
46    pub const fn default_fast_model(&self) -> &'static str {
47        match self {
48            Self::OpenAI => "gpt-5.4-mini",
49            Self::Anthropic => "claude-haiku-4-5-20251001",
50            Self::Google => "gemini-2.5-flash",
51        }
52    }
53
54    /// Context window size (max tokens)
55    #[must_use]
56    pub const fn context_window(&self) -> usize {
57        match self {
58            Self::OpenAI => 128_000,
59            Self::Anthropic => 200_000,
60            Self::Google => 1_000_000,
61        }
62    }
63
64    /// Environment variable name for the API key
65    #[must_use]
66    pub const fn api_key_env(&self) -> &'static str {
67        match self {
68            Self::OpenAI => "OPENAI_API_KEY",
69            Self::Anthropic => "ANTHROPIC_API_KEY",
70            Self::Google => "GOOGLE_API_KEY",
71        }
72    }
73
74    /// Valid API key prefixes for format validation
75    ///
76    /// Returns the expected prefixes for the provider's API keys.
77    /// `OpenAI` has multiple valid prefixes (sk-, sk-proj-).
78    #[must_use]
79    pub fn api_key_prefixes(&self) -> &'static [&'static str] {
80        match self {
81            Self::OpenAI => &["sk-", "sk-proj-"],
82            Self::Anthropic => &["sk-ant-"],
83            Self::Google => &[], // Google API keys don't have a consistent prefix
84        }
85    }
86
87    /// Expected API key prefix for basic format validation (primary prefix)
88    ///
89    /// Returns the primary expected prefix for display in error messages.
90    #[must_use]
91    pub const fn api_key_prefix(&self) -> Option<&'static str> {
92        match self {
93            Self::OpenAI => Some("sk-"),
94            Self::Anthropic => Some("sk-ant-"),
95            Self::Google => None,
96        }
97    }
98
99    /// Validate API key format
100    ///
101    /// Performs basic validation to catch obvious misconfigurations:
102    /// - Checks for expected prefix (`OpenAI`: `sk-` or `sk-proj-`, `Anthropic`: `sk-ant-`)
103    /// - Ensures key is not suspiciously short
104    ///
105    /// Returns `Ok(())` if valid, or a warning message if potentially invalid.
106    /// Note: A valid format doesn't guarantee the key works - it may still be
107    /// expired or revoked. This just catches typos.
108    ///
109    /// # Errors
110    ///
111    /// Returns an error string when the key format is clearly invalid for the provider.
112    pub fn validate_api_key_format(&self, key: &str) -> Result<(), String> {
113        // Check minimum length (API keys are typically 30+ chars)
114        if key.len() < 20 {
115            return Err(format!(
116                "{} API key appears too short (got {} chars, expected 20+)",
117                self.name(),
118                key.len()
119            ));
120        }
121
122        // Check expected prefixes
123        let prefixes = self.api_key_prefixes();
124        if !prefixes.is_empty() && !prefixes.iter().any(|p| key.starts_with(p)) {
125            let expected = if prefixes.len() == 1 {
126                format!("'{}'", prefixes[0])
127            } else {
128                prefixes
129                    .iter()
130                    .map(|p| format!("'{p}'"))
131                    .collect::<Vec<_>>()
132                    .join(" or ")
133            };
134            return Err(format!(
135                "{} API key should start with {} (key has unexpected prefix)",
136                self.name(),
137                expected
138            ));
139        }
140
141        Ok(())
142    }
143
144    /// Get all provider names as strings
145    pub fn all_names() -> Vec<&'static str> {
146        Self::ALL.iter().map(Self::name).collect()
147    }
148}
149
150impl FromStr for Provider {
151    type Err = ProviderError;
152
153    fn from_str(s: &str) -> Result<Self, Self::Err> {
154        let lower = s.to_lowercase();
155        // Handle legacy/common aliases
156        let normalized = match lower.as_str() {
157            "claude" => "anthropic",
158            "gemini" => "google",
159            _ => &lower,
160        };
161
162        Self::ALL
163            .iter()
164            .find(|p| p.name() == normalized)
165            .copied()
166            .ok_or_else(|| ProviderError::Unknown(s.to_string()))
167    }
168}
169
170impl fmt::Display for Provider {
171    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172        write!(f, "{}", self.name())
173    }
174}
175
176/// Provider configuration error
177#[derive(Debug, thiserror::Error)]
178pub enum ProviderError {
179    #[error("Unknown provider: {0}. Supported: openai, anthropic, google")]
180    Unknown(String),
181    #[error("API key required for provider: {0}")]
182    MissingApiKey(String),
183}
184
185/// Per-provider configuration
186#[derive(Clone, Default, Serialize, Deserialize)]
187pub struct ProviderConfig {
188    /// API key (loaded from env or config)
189    #[serde(default, skip_serializing_if = "String::is_empty")]
190    pub api_key: String,
191    /// Primary model for complex analysis
192    #[serde(default, skip_serializing_if = "String::is_empty")]
193    pub model: String,
194    /// Fast model for simple tasks
195    #[serde(default, skip_serializing_if = "Option::is_none")]
196    pub fast_model: Option<String>,
197    /// Token limit override
198    #[serde(default, skip_serializing_if = "Option::is_none")]
199    pub token_limit: Option<usize>,
200    /// Additional provider-specific params
201    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
202    pub additional_params: HashMap<String, String>,
203}
204
205impl fmt::Debug for ProviderConfig {
206    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207        f.debug_struct("ProviderConfig")
208            .field(
209                "api_key",
210                if self.api_key.is_empty() {
211                    &"<empty>"
212                } else {
213                    &"[REDACTED]"
214                },
215            )
216            .field("model", &self.model)
217            .field("fast_model", &self.fast_model)
218            .field("token_limit", &self.token_limit)
219            .field("additional_params", &self.additional_params)
220            .finish()
221    }
222}
223
224impl ProviderConfig {
225    /// Create config with defaults for a provider
226    #[must_use]
227    pub fn with_defaults(provider: Provider) -> Self {
228        Self {
229            api_key: String::new(),
230            model: provider.default_model().to_string(),
231            fast_model: Some(provider.default_fast_model().to_string()),
232            token_limit: None,
233            additional_params: HashMap::new(),
234        }
235    }
236
237    /// Get effective model (configured or default)
238    #[must_use]
239    pub fn effective_model(&self, provider: Provider) -> &str {
240        if self.model.is_empty() {
241            provider.default_model()
242        } else {
243            &self.model
244        }
245    }
246
247    /// Get effective fast model (configured or default)
248    #[must_use]
249    pub fn effective_fast_model(&self, provider: Provider) -> &str {
250        self.fast_model
251            .as_deref()
252            .unwrap_or_else(|| provider.default_fast_model())
253    }
254
255    /// Get effective token limit (configured or default)
256    #[must_use]
257    pub fn effective_token_limit(&self, provider: Provider) -> usize {
258        self.token_limit
259            .unwrap_or_else(|| provider.context_window())
260    }
261
262    /// Check if this config has an API key set
263    #[must_use]
264    pub fn has_api_key(&self) -> bool {
265        !self.api_key.is_empty()
266    }
267
268    /// Get API key if set (non-empty), otherwise None
269    ///
270    /// This is the canonical way to extract an API key for passing to
271    /// provider builders. Returns `None` for empty strings, allowing
272    /// fallback to environment variables.
273    #[must_use]
274    pub fn api_key_if_set(&self) -> Option<&str> {
275        if self.api_key.is_empty() {
276            None
277        } else {
278            Some(&self.api_key)
279        }
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_from_str() {
        assert_eq!("openai".parse::<Provider>().ok(), Some(Provider::OpenAI));
        assert_eq!(
            "ANTHROPIC".parse::<Provider>().ok(),
            Some(Provider::Anthropic)
        );
        assert_eq!("claude".parse::<Provider>().ok(), Some(Provider::Anthropic)); // Legacy alias
        assert_eq!("gemini".parse::<Provider>().ok(), Some(Provider::Google)); // Common alias
        assert!("invalid".parse::<Provider>().is_err());
    }

    #[test]
    fn test_provider_unknown_error_lists_supported() {
        let err = "invalid".parse::<Provider>().expect_err("should be err");
        assert_eq!(
            err.to_string(),
            "Unknown provider: invalid. Supported: openai, anthropic, google"
        );
    }

    #[test]
    fn test_provider_names_round_trip() {
        // Every canonical name parses back to its provider, and Display agrees with name().
        for &provider in Provider::ALL {
            assert_eq!(provider.name().parse::<Provider>().ok(), Some(provider));
            assert_eq!(provider.to_string(), provider.name());
        }
        assert_eq!(Provider::all_names(), ["openai", "anthropic", "google"]);
        assert_eq!(Provider::default(), Provider::OpenAI);
    }

    #[test]
    fn test_provider_defaults() {
        assert_eq!(Provider::OpenAI.default_model(), "gpt-5.4");
        assert_eq!(Provider::OpenAI.default_fast_model(), "gpt-5.4-mini");
        assert_eq!(Provider::Anthropic.context_window(), 200_000);
        assert_eq!(Provider::Google.api_key_env(), "GOOGLE_API_KEY");
    }

    #[test]
    fn test_provider_config_defaults() {
        let config = ProviderConfig::with_defaults(Provider::Anthropic);
        assert_eq!(config.model, "claude-opus-4-6");
        assert_eq!(
            config.fast_model.as_deref(),
            Some("claude-haiku-4-5-20251001")
        );
        assert!(config.api_key.is_empty());
        assert_eq!(config.token_limit, None);
        assert!(config.additional_params.is_empty());
    }

    #[test]
    fn test_provider_config_effective_fallbacks() {
        // An unconfigured ProviderConfig falls back to the provider's defaults.
        let config = ProviderConfig::default();
        assert_eq!(
            config.effective_model(Provider::Google),
            Provider::Google.default_model()
        );
        assert_eq!(
            config.effective_fast_model(Provider::Google),
            Provider::Google.default_fast_model()
        );
        assert_eq!(config.effective_token_limit(Provider::Google), 1_000_000);
        assert!(!config.has_api_key());

        // Explicit values win over defaults.
        let mut custom = ProviderConfig::default();
        custom.model = "my-model".to_string();
        custom.fast_model = Some("my-fast".to_string());
        custom.token_limit = Some(4_096);
        assert_eq!(custom.effective_model(Provider::OpenAI), "my-model");
        assert_eq!(custom.effective_fast_model(Provider::OpenAI), "my-fast");
        assert_eq!(custom.effective_token_limit(Provider::OpenAI), 4_096);
    }

    #[test]
    fn test_api_key_prefix() {
        assert_eq!(Provider::OpenAI.api_key_prefix(), Some("sk-"));
        assert_eq!(Provider::Anthropic.api_key_prefix(), Some("sk-ant-"));
        assert_eq!(Provider::Google.api_key_prefix(), None);

        // The primary prefix is always the first entry of the full prefix list.
        for provider in Provider::ALL {
            assert_eq!(
                provider.api_key_prefix(),
                provider.api_key_prefixes().first().copied()
            );
        }
    }

    #[test]
    fn test_api_key_if_set() {
        // Non-empty key returns Some
        let mut config = ProviderConfig::with_defaults(Provider::OpenAI);
        config.api_key = "sk-test-key-12345678901234567890".to_string();
        assert_eq!(
            config.api_key_if_set(),
            Some("sk-test-key-12345678901234567890")
        );
        assert!(config.has_api_key());

        // Empty key returns None
        config.api_key = String::new();
        assert_eq!(config.api_key_if_set(), None);
        assert!(!config.has_api_key());
    }

    #[test]
    fn test_api_key_prefixes() {
        // OpenAI accepts multiple prefixes
        assert_eq!(Provider::OpenAI.api_key_prefixes(), &["sk-", "sk-proj-"]);
        assert_eq!(Provider::Anthropic.api_key_prefixes(), &["sk-ant-"]);
        assert!(Provider::Google.api_key_prefixes().is_empty());
    }

    #[test]
    fn test_api_key_validation_valid_openai() {
        // Valid OpenAI key format (starts with sk-, long enough)
        let result = Provider::OpenAI.validate_api_key_format("sk-1234567890abcdefghijklmnop");
        assert!(result.is_ok());
    }

    #[test]
    fn test_api_key_validation_valid_openai_project_key() {
        // Valid OpenAI project key format (starts with sk-proj-, long enough)
        let result = Provider::OpenAI.validate_api_key_format("sk-proj-1234567890abcdefghijklmnop");
        assert!(result.is_ok());
    }

    #[test]
    fn test_api_key_validation_valid_anthropic() {
        // Valid Anthropic key format (starts with sk-ant-, long enough)
        let result =
            Provider::Anthropic.validate_api_key_format("sk-ant-1234567890abcdefghijklmnop");
        assert!(result.is_ok());
    }

    #[test]
    fn test_api_key_validation_valid_google() {
        // Google keys don't have a prefix requirement, just length
        let result = Provider::Google.validate_api_key_format("AIzaSyA1234567890abcdefgh");
        assert!(result.is_ok());
    }

    #[test]
    fn test_api_key_validation_too_short() {
        let result = Provider::OpenAI.validate_api_key_format("sk-short");
        assert!(result.is_err());
        let err = result.expect_err("should be err");
        assert!(err.contains("too short"));
        assert!(err.contains("got 8 chars"));

        // Length applies to every provider, even those without a prefix rule.
        assert!(Provider::Google.validate_api_key_format("AIza123").is_err());
    }

    #[test]
    fn test_api_key_validation_length_boundary() {
        // 20 characters is the minimum accepted length; 19 is rejected.
        assert!(Provider::OpenAI
            .validate_api_key_format("sk-abcdefghijklmnopq")
            .is_ok());
        assert!(Provider::OpenAI
            .validate_api_key_format("sk-abcdefghijklmnop")
            .is_err());
    }

    #[test]
    fn test_api_key_validation_wrong_prefix_openai() {
        // Long enough but wrong prefix
        let result = Provider::OpenAI.validate_api_key_format("wrong-prefix-1234567890abcdef");
        assert!(result.is_err());
        let err = result.expect_err("should be err");
        assert!(err.contains("should start with"));
        // Error should mention every valid prefix
        assert!(err.contains("'sk-'") && err.contains("'sk-proj-'"));
        // Verify we don't expose the actual key prefix in error messages
        assert!(!err.contains("wrong-"));
    }

    #[test]
    fn test_api_key_validation_wrong_prefix_anthropic() {
        // Has sk- but not sk-ant- (might be OpenAI key used for Anthropic)
        let result = Provider::Anthropic.validate_api_key_format("sk-1234567890abcdefghijklmnop");
        assert!(result.is_err());
        let err = result.expect_err("should be err");
        assert!(err.contains("'sk-ant-'"));
        assert!(err.contains("unexpected prefix"));
        // Verify we don't expose the actual key content
        assert!(!err.contains("1234567890"));
    }
}