// auto_commit/api/provider.rs
use std::fmt;
2
/// Backend API providers recognised by this module.
///
/// The type is a plain tag: it is `Copy`, comparable with `==`, and printable
/// with `{:?}`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Provider {
    /// The OpenAI API.
    OpenAi,
    /// The DeepSeek API.
    DeepSeek,
    /// The Google Gemini API.
    Gemini,
}
10
11impl Provider {
12 pub fn detect() -> Option<(Self, String)> {
14 if let Ok(key) = std::env::var("OPENAI_API_KEY") {
16 if !key.is_empty() {
17 return Some((Self::OpenAi, key));
18 }
19 }
20 if let Ok(key) = std::env::var("DEEPSEEK_API_KEY") {
21 if !key.is_empty() {
22 return Some((Self::DeepSeek, key));
23 }
24 }
25 if let Ok(key) = std::env::var("GEMINI_API_KEY") {
26 if !key.is_empty() {
27 return Some((Self::Gemini, key));
28 }
29 }
30 None
31 }
32
33 pub fn base_url(&self) -> &'static str {
35 match self {
36 Self::OpenAi => "https://api.openai.com",
37 Self::DeepSeek => "https://api.deepseek.com",
38 Self::Gemini => "https://generativelanguage.googleapis.com",
39 }
40 }
41
42 pub fn default_model(&self) -> &'static str {
44 match self {
45 Self::OpenAi => "gpt-4o-mini",
46 Self::DeepSeek => "deepseek-chat",
47 Self::Gemini => "gemini-2.0-flash",
48 }
49 }
50
51 pub fn is_openai_compatible(&self) -> bool {
53 matches!(self, Self::OpenAi | Self::DeepSeek)
54 }
55}
56
57impl fmt::Display for Provider {
58 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59 match self {
60 Self::OpenAi => write!(f, "OpenAI"),
61 Self::DeepSeek => write!(f, "DeepSeek"),
62 Self::Gemini => write!(f, "Gemini"),
63 }
64 }
65}
66
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;

    /// Removes every API-key variable that `Provider::detect` looks at.
    fn clear_env_keys() {
        std::env::remove_var("OPENAI_API_KEY");
        std::env::remove_var("DEEPSEEK_API_KEY");
        std::env::remove_var("GEMINI_API_KEY");
    }

    /// Runs `Provider::detect` with exactly the given variables set.
    ///
    /// The environment is cleared before the call and again before returning,
    /// so a failing assertion in the caller cannot leave keys behind.
    fn detect_with(vars: &[(&str, &str)]) -> Option<(Provider, String)> {
        clear_env_keys();
        for &(name, value) in vars {
            std::env::set_var(name, value);
        }
        let detected = Provider::detect();
        clear_env_keys();
        detected
    }

    // Tests that touch the process environment are marked `#[serial]` so they
    // do not run concurrently with each other.

    #[test]
    #[serial]
    fn test_detect_openai() {
        assert_eq!(
            detect_with(&[("OPENAI_API_KEY", "sk-test")]),
            Some((Provider::OpenAi, "sk-test".to_string()))
        );
    }

    #[test]
    #[serial]
    fn test_detect_deepseek() {
        assert_eq!(
            detect_with(&[("DEEPSEEK_API_KEY", "sk-deepseek")]),
            Some((Provider::DeepSeek, "sk-deepseek".to_string()))
        );
    }

    #[test]
    #[serial]
    fn test_detect_gemini() {
        assert_eq!(
            detect_with(&[("GEMINI_API_KEY", "AIza-test")]),
            Some((Provider::Gemini, "AIza-test".to_string()))
        );
    }

    #[test]
    #[serial]
    fn test_detect_priority() {
        // With every key present, OpenAI wins.
        let detected = detect_with(&[
            ("OPENAI_API_KEY", "openai-key"),
            ("DEEPSEEK_API_KEY", "deepseek-key"),
            ("GEMINI_API_KEY", "gemini-key"),
        ]);
        assert_eq!(
            detected,
            Some((Provider::OpenAi, "openai-key".to_string()))
        );
    }

    #[test]
    #[serial]
    fn test_detect_skips_empty_key() {
        // An empty higher-priority key must not shadow a real one.
        let detected = detect_with(&[
            ("OPENAI_API_KEY", ""),
            ("DEEPSEEK_API_KEY", "sk-deepseek"),
        ]);
        assert_eq!(
            detected,
            Some((Provider::DeepSeek, "sk-deepseek".to_string()))
        );
    }

    #[test]
    #[serial]
    fn test_detect_none() {
        assert_eq!(detect_with(&[]), None);
    }

    #[test]
    fn test_base_url() {
        assert_eq!(Provider::OpenAi.base_url(), "https://api.openai.com");
        assert_eq!(Provider::DeepSeek.base_url(), "https://api.deepseek.com");
        assert!(Provider::Gemini.base_url().contains("googleapis.com"));
    }

    #[test]
    fn test_default_model() {
        assert_eq!(Provider::OpenAi.default_model(), "gpt-4o-mini");
        assert_eq!(Provider::DeepSeek.default_model(), "deepseek-chat");
        assert_eq!(Provider::Gemini.default_model(), "gemini-2.0-flash");
    }

    #[test]
    fn test_openai_compatible() {
        assert!(Provider::OpenAi.is_openai_compatible());
        assert!(Provider::DeepSeek.is_openai_compatible());
        assert!(!Provider::Gemini.is_openai_compatible());
    }

    #[test]
    fn test_display() {
        assert_eq!(Provider::OpenAi.to_string(), "OpenAI");
        assert_eq!(Provider::DeepSeek.to_string(), "DeepSeek");
        assert_eq!(Provider::Gemini.to_string(), "Gemini");
    }
}