auto_commit/api/
provider.rs

1use std::fmt;
2
/// Supported LLM providers.
///
/// Each variant is associated with one API-key environment variable (see
/// [`Provider::detect`]), a base URL, and a default model name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Provider {
    /// OpenAI; detected through `OPENAI_API_KEY`.
    OpenAi,
    /// DeepSeek; detected through `DEEPSEEK_API_KEY`.
    DeepSeek,
    /// Gemini; detected through `GEMINI_API_KEY`.
    Gemini,
}

impl Provider {
    /// API-key environment variables paired with their provider, highest
    /// priority first.
    ///
    /// Detection walks this table in order, so the ordering here *is* the
    /// priority: `OPENAI_API_KEY` > `DEEPSEEK_API_KEY` > `GEMINI_API_KEY`.
    const ENV_KEYS: [(&'static str, Provider); 3] = [
        ("OPENAI_API_KEY", Provider::OpenAi),
        ("DEEPSEEK_API_KEY", Provider::DeepSeek),
        ("GEMINI_API_KEY", Provider::Gemini),
    ];

    /// Detect the provider from environment variables (priority order).
    ///
    /// Returns the first provider whose API-key variable holds a non-empty
    /// value, together with that value. A variable that is unset, empty, or
    /// not valid Unicode is skipped and the next candidate is tried.
    /// Returns `None` when no candidate yields a key.
    pub fn detect() -> Option<(Self, String)> {
        Self::detect_with(|name| std::env::var(name).ok())
    }

    /// Core of [`Provider::detect`], with the variable lookup injected.
    ///
    /// `lookup` is called with each variable name in priority order until one
    /// returns a non-empty value; later names are not looked up after a hit.
    /// Keeping the lookup pluggable lets the selection logic be exercised
    /// without touching the process-global environment.
    fn detect_with<F>(lookup: F) -> Option<(Self, String)>
    where
        F: Fn(&str) -> Option<String>,
    {
        Self::ENV_KEYS.iter().find_map(|&(name, provider)| {
            lookup(name)
                // An empty value counts as "not configured".
                .filter(|key| !key.is_empty())
                .map(|key| (provider, key))
        })
    }

    /// Get the base URL for API requests.
    ///
    /// The returned string is scheme plus host only: no path and no trailing
    /// slash.
    pub fn base_url(&self) -> &'static str {
        match self {
            Self::OpenAi => "https://api.openai.com",
            Self::DeepSeek => "https://api.deepseek.com",
            Self::Gemini => "https://generativelanguage.googleapis.com",
        }
    }

    /// Get the default model name for this provider.
    pub fn default_model(&self) -> &'static str {
        match self {
            Self::OpenAi => "gpt-4o-mini",
            Self::DeepSeek => "deepseek-chat",
            Self::Gemini => "gemini-2.0-flash",
        }
    }

    /// Check whether this provider uses an OpenAI-compatible API.
    ///
    /// `true` for [`Provider::OpenAi`] and [`Provider::DeepSeek`]; `false`
    /// for [`Provider::Gemini`].
    pub fn is_openai_compatible(&self) -> bool {
        matches!(self, Self::OpenAi | Self::DeepSeek)
    }
}
56
57impl fmt::Display for Provider {
58    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59        match self {
60            Self::OpenAi => write!(f, "OpenAI"),
61            Self::DeepSeek => write!(f, "DeepSeek"),
62            Self::Gemini => write!(f, "Gemini"),
63        }
64    }
65}
66
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;

    /// Every environment variable that `Provider::detect` consults.
    const API_KEY_VARS: [&str; 3] = ["OPENAI_API_KEY", "DEEPSEEK_API_KEY", "GEMINI_API_KEY"];

    /// Remove all provider API-key variables from the process environment.
    fn clear_env_keys() {
        for name in API_KEY_VARS.iter() {
            std::env::remove_var(name);
        }
    }

    /// Scope guard that gives a test a clean environment and restores it.
    ///
    /// The keys are cleared when the guard is created and again when it is
    /// dropped. Drop also runs while a failed assertion unwinds, so a failing
    /// test cannot leave fake API keys behind for whatever runs next.
    struct CleanEnv;

    impl CleanEnv {
        fn new() -> Self {
            clear_env_keys();
            CleanEnv
        }
    }

    impl Drop for CleanEnv {
        fn drop(&mut self) {
            clear_env_keys();
        }
    }

    // The detect tests mutate the process-global environment, so they are
    // marked #[serial] to keep them from interleaving with one another.

    #[test]
    #[serial]
    fn test_detect_openai() {
        let _env = CleanEnv::new();
        std::env::set_var("OPENAI_API_KEY", "sk-test");

        assert_eq!(
            Provider::detect(),
            Some((Provider::OpenAi, String::from("sk-test")))
        );
    }

    #[test]
    #[serial]
    fn test_detect_deepseek() {
        let _env = CleanEnv::new();
        std::env::set_var("DEEPSEEK_API_KEY", "sk-deepseek");

        assert_eq!(
            Provider::detect(),
            Some((Provider::DeepSeek, String::from("sk-deepseek")))
        );
    }

    #[test]
    #[serial]
    fn test_detect_gemini() {
        let _env = CleanEnv::new();
        std::env::set_var("GEMINI_API_KEY", "AIza-test");

        assert_eq!(
            Provider::detect(),
            Some((Provider::Gemini, String::from("AIza-test")))
        );
    }

    #[test]
    #[serial]
    fn test_detect_priority() {
        let _env = CleanEnv::new();
        // Set all keys - OpenAI should win, and its own key must be returned.
        std::env::set_var("OPENAI_API_KEY", "openai-key");
        std::env::set_var("DEEPSEEK_API_KEY", "deepseek-key");
        std::env::set_var("GEMINI_API_KEY", "gemini-key");

        assert_eq!(
            Provider::detect(),
            Some((Provider::OpenAi, String::from("openai-key")))
        );

        // With OpenAI out of the picture, DeepSeek outranks Gemini.
        std::env::remove_var("OPENAI_API_KEY");
        assert_eq!(
            Provider::detect(),
            Some((Provider::DeepSeek, String::from("deepseek-key")))
        );
    }

    #[test]
    #[serial]
    fn test_detect_skips_empty_key() {
        let _env = CleanEnv::new();
        // An empty value counts as "not configured": detection falls through
        // to the next provider in priority order.
        std::env::set_var("OPENAI_API_KEY", "");
        std::env::set_var("DEEPSEEK_API_KEY", "sk-deepseek");

        assert_eq!(
            Provider::detect(),
            Some((Provider::DeepSeek, String::from("sk-deepseek")))
        );
    }

    #[test]
    #[serial]
    fn test_detect_all_empty() {
        let _env = CleanEnv::new();
        for name in API_KEY_VARS.iter() {
            std::env::set_var(name, "");
        }

        assert!(Provider::detect().is_none());
    }

    #[test]
    #[serial]
    fn test_detect_none() {
        let _env = CleanEnv::new();

        assert!(Provider::detect().is_none());
    }

    #[test]
    fn test_base_url() {
        assert_eq!(Provider::OpenAi.base_url(), "https://api.openai.com");
        assert_eq!(Provider::DeepSeek.base_url(), "https://api.deepseek.com");
        assert!(Provider::Gemini.base_url().contains("googleapis.com"));
    }

    #[test]
    fn test_default_model() {
        assert_eq!(Provider::OpenAi.default_model(), "gpt-4o-mini");
        assert_eq!(Provider::DeepSeek.default_model(), "deepseek-chat");
        assert_eq!(Provider::Gemini.default_model(), "gemini-2.0-flash");
    }

    #[test]
    fn test_openai_compatible() {
        assert!(Provider::OpenAi.is_openai_compatible());
        assert!(Provider::DeepSeek.is_openai_compatible());
        assert!(!Provider::Gemini.is_openai_compatible());
    }

    #[test]
    fn test_display() {
        assert_eq!(Provider::OpenAi.to_string(), "OpenAI");
        assert_eq!(Provider::DeepSeek.to_string(), "DeepSeek");
        assert_eq!(Provider::Gemini.to_string(), "Gemini");
    }
}