1use uuid::Uuid;
2
3use crate::errors::CoreError;
4use crate::models::{
5 ProviderAddInput, ProviderRecord, ProviderRemoveInput, ProviderSetActiveInput,
6 ProviderUpdateInput,
7};
8
/// Request for [`check_auth`]: names the engine whose local credentials should be probed.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CheckAuthInput {
    /// Engine identifier. `check_auth` recognizes `"claude"`, `"codex"`, `"gemini"` and
    /// `"cursor"`; any other value is reported as an unsupported engine.
    pub engine: String,
}
13
/// Outcome of [`check_auth`].
///
/// Serialized with camelCase field names (`credentialDetected`, `verified`, `method`,
/// `detail`) because of the `rename_all` attribute below.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CheckAuthResult {
    /// Whether the probed credential source (file, directory or env var) was found.
    pub credential_detected: bool,
    /// Whether the credential was actually validated; `check_auth` always sets `false`.
    pub verified: bool,
    /// Detection method used: `"config_file"`, `"env_var"`, `"cli_config"` or `"unknown"`.
    pub method: String,
    /// Human-readable explanation of the result.
    pub detail: String,
}
22
/// Raw row of the `providers` table, as selected by the queries in this module.
#[derive(sqlx::FromRow)]
struct ProviderRow {
    id: String,
    name: String,
    base_url: String,
    /// Encrypted API key ciphertext; turned back into plaintext via
    /// `crate::crypto::decrypt_secret`.
    api_key: String,
    /// JSON-encoded model mapping, parsed with `serde_json::from_str`.
    model_mapping: String,
    /// SQLite integer flag; any non-zero value is treated as active.
    is_active: i64,
    /// Timestamp text; rows written by this module use UTC `%Y-%m-%d %H:%M:%S`.
    created_at: String,
    updated_at: String,
}
34
35impl ProviderRow {
36 fn decrypt_api_key(&self) -> String {
37 match crate::crypto::decrypt_secret(&self.api_key) {
38 Ok(plaintext) => plaintext,
39 Err(e) => {
40 eprintln!(
41 "Failed to decrypt API key for provider {}: {e}. Returning empty string to avoid leaking ciphertext.",
42 self.id
43 );
44 String::new()
45 }
46 }
47 }
48
49 fn into_masked(self) -> ProviderRecord {
50 let decrypted = self.decrypt_api_key();
51 ProviderRecord {
52 id: self.id,
53 name: self.name,
54 base_url: self.base_url,
55 api_key: Some(mask_api_key(&decrypted)),
56 model_mapping: serde_json::from_str(&self.model_mapping).unwrap_or_default(),
57 is_active: self.is_active != 0,
58 created_at: self.created_at,
59 updated_at: self.updated_at,
60 }
61 }
62
63 fn into_internal(self) -> ProviderRecord {
64 let decrypted = self.decrypt_api_key();
65 ProviderRecord {
66 id: self.id,
67 name: self.name,
68 base_url: self.base_url,
69 api_key: Some(decrypted),
70 model_mapping: serde_json::from_str(&self.model_mapping).unwrap_or_default(),
71 is_active: self.is_active != 0,
72 created_at: self.created_at,
73 updated_at: self.updated_at,
74 }
75 }
76}
77
/// Renders an API key in a display-safe form.
///
/// Keys of eight characters or fewer collapse to `"****"`. Longer keys keep their
/// first three and last three characters around a `***` marker. Lengths are counted
/// in `char`s, so multi-byte keys are never split mid-codepoint.
fn mask_api_key(key: &str) -> String {
    let chars: Vec<char> = key.chars().collect();
    match chars.len() {
        0..=8 => "****".to_owned(),
        len => {
            let prefix: String = chars[..3].iter().collect();
            let suffix: String = chars[len - 3..].iter().collect();
            format!("{prefix}***{suffix}")
        }
    }
}
91
92pub async fn list(db: &sqlx::SqlitePool) -> crate::Result<Vec<ProviderRecord>> {
93 let rows = sqlx::query_as!(
94 ProviderRow,
95 "SELECT id, name, base_url, api_key, model_mapping, is_active, created_at, updated_at FROM providers ORDER BY created_at DESC"
96 )
97 .fetch_all(db)
98 .await?;
99 Ok(rows.into_iter().map(ProviderRow::into_masked).collect())
100}
101
102pub async fn get(
103 db: &sqlx::SqlitePool,
104 input: ProviderRemoveInput,
105) -> crate::Result<Option<ProviderRecord>> {
106 let row = sqlx::query_as!(
107 ProviderRow,
108 "SELECT id, name, base_url, api_key, model_mapping, is_active, created_at, updated_at FROM providers WHERE id = ?1",
109 input.id
110 )
111 .fetch_optional(db)
112 .await?;
113 Ok(row.map(ProviderRow::into_masked))
114}
115
116pub async fn add(db: &sqlx::SqlitePool, input: ProviderAddInput) -> crate::Result<ProviderRecord> {
117 let id = format!("provider-{}", Uuid::new_v4());
118 let now = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
119 let mapping_json = serde_json::to_string(&input.model_mapping)?;
120 let encrypted_key = crate::crypto::encrypt_secret("").map_err(CoreError::Internal)?;
124
125 sqlx::query!(
126 "INSERT INTO providers (id, name, base_url, api_key, model_mapping, is_active, created_at, updated_at)
127 VALUES (?1, ?2, ?3, ?4, ?5, 0, ?6, ?6)",
128 id,
129 input.name,
130 input.base_url,
131 encrypted_key,
132 mapping_json,
133 now
134 )
135 .execute(db)
136 .await?;
137
138 Ok(ProviderRecord {
139 id,
140 name: input.name,
141 base_url: input.base_url,
142 api_key: None,
143 model_mapping: input.model_mapping,
144 is_active: false,
145 created_at: now.clone(),
146 updated_at: now,
147 })
148}
149
150pub async fn update(
151 db: &sqlx::SqlitePool,
152 input: ProviderUpdateInput,
153) -> crate::Result<ProviderRecord> {
154 let now = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
155
156 let row = sqlx::query_as!(
157 ProviderRow,
158 "SELECT id, name, base_url, api_key, model_mapping, is_active, created_at, updated_at FROM providers WHERE id = ?1",
159 input.id
160 )
161 .fetch_optional(db)
162 .await?
163 .ok_or_else(|| CoreError::NotFound(format!(
164 "provider '{}' not found. List current providers with `difflore providers list`.",
165 input.id
166 )))?;
167
168 let mut provider = row.into_internal();
169
170 if let Some(name) = input.name {
171 provider.name = name;
172 }
173 if let Some(base_url) = input.base_url {
174 provider.base_url = base_url;
175 }
176 if let Some(mm) = input.model_mapping {
177 provider.model_mapping = mm;
178 }
179 provider.updated_at = now;
180
181 let mapping_json = serde_json::to_string(&provider.model_mapping)?;
182 let encrypted_secret = crate::crypto::encrypt_secret("").map_err(CoreError::Internal)?;
185
186 let result = sqlx::query!(
187 "UPDATE providers SET name=?1, base_url=?2, api_key=?3, model_mapping=?4, updated_at=?5 WHERE id=?6",
188 provider.name,
189 provider.base_url,
190 encrypted_secret,
191 mapping_json,
192 provider.updated_at,
193 provider.id
194 )
195 .execute(db)
196 .await?;
197 if result.rows_affected() == 0 {
198 return Err(CoreError::NotFound(format!(
199 "provider '{}' not found — cannot update. List current providers with `difflore providers list`.",
200 provider.id
201 )));
202 }
203
204 Ok(ProviderRecord {
205 id: provider.id,
206 name: provider.name,
207 base_url: provider.base_url,
208 api_key: None,
209 model_mapping: provider.model_mapping,
210 is_active: provider.is_active,
211 created_at: provider.created_at,
212 updated_at: provider.updated_at,
213 })
214}
215
216pub async fn remove(db: &sqlx::SqlitePool, input: ProviderRemoveInput) -> crate::Result<()> {
217 let result = sqlx::query!("DELETE FROM providers WHERE id = ?1", input.id)
218 .execute(db)
219 .await?;
220 if result.rows_affected() == 0 {
221 return Err(CoreError::NotFound(format!(
222 "provider '{}' not found. List current providers with `difflore providers list`.",
223 input.id
224 )));
225 }
226 Ok(())
227}
228
229pub async fn set_active(db: &sqlx::SqlitePool, input: ProviderSetActiveInput) -> crate::Result<()> {
230 let mut tx = db.begin().await?;
231 sqlx::query!("UPDATE providers SET is_active = 0")
232 .execute(&mut *tx)
233 .await?;
234 if input.is_active {
235 let result = sqlx::query!("UPDATE providers SET is_active = 1 WHERE id = ?1", input.id)
236 .execute(&mut *tx)
237 .await?;
238 if result.rows_affected() == 0 {
239 return Err(CoreError::NotFound(format!(
243 "provider '{}' not found. List current providers with `difflore providers list`.",
244 input.id
245 )));
246 }
247 }
248 tx.commit().await?;
249 Ok(())
250}
251
252pub async fn check_auth(input: CheckAuthInput) -> crate::Result<CheckAuthResult> {
253 let home = dirs::home_dir()
254 .ok_or_else(|| CoreError::Internal("cannot resolve home directory".into()))?;
255
256 let (detected, method, detail) = match input.engine.as_str() {
257 "claude" => {
258 let path = home.join(".claude").join(".credentials.json");
259 let found = path.exists();
260 (
261 found,
262 "config_file".to_owned(),
263 if found {
264 "Credentials file detected".to_owned()
265 } else {
266 "No credentials file detected".to_owned()
267 },
268 )
269 }
270 "codex" => {
271 let found = crate::env::var(crate::env::OPENAI_API_KEY).is_some();
272 (
273 found,
274 "env_var".to_owned(),
275 if found {
276 "OPENAI_API_KEY environment variable detected".to_owned()
277 } else {
278 "OPENAI_API_KEY not found in environment".to_owned()
279 },
280 )
281 }
282 "gemini" => {
283 let path = home.join(".gemini").join("credentials.json");
284 let found = path.exists();
285 (
286 found,
287 "config_file".to_owned(),
288 if found {
289 "Credentials file detected".to_owned()
290 } else {
291 "No credentials file detected".to_owned()
292 },
293 )
294 }
295 "cursor" => {
296 let path = home.join(".cursor");
297 let found = path.exists();
298 (
299 found,
300 "cli_config".to_owned(),
301 if found {
302 "Config directory detected".to_owned()
303 } else {
304 "Config directory not found".to_owned()
305 },
306 )
307 }
308 other => (
309 false,
310 "unknown".to_owned(),
311 format!("unsupported engine: {other}"),
312 ),
313 };
314
315 Ok(CheckAuthResult {
316 credential_detected: detected,
317 verified: false,
318 method,
319 detail,
320 })
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mask_api_key_table() {
        let cases: [(&str, &str); 5] = [
            ("", "****"),
            ("short", "****"),
            ("12345678", "****"),
            ("sk-ant-1234567890abcd", "sk-***bcd"),
            ("abcdefghijk", "abc***ijk"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(mask_api_key(input), expected, "input: {input}");
        }
    }

    #[tokio::test]
    async fn check_auth_unknown_engine_reports_unsupported() {
        let CheckAuthResult {
            credential_detected,
            verified,
            method,
            detail,
        } = check_auth(CheckAuthInput {
            engine: "bogus-engine".into(),
        })
        .await
        .expect("check_auth should not error for unknown engine");
        assert!(!credential_detected);
        assert!(!verified);
        assert_eq!(method, "unknown");
        assert!(detail.contains("unsupported"));
    }

    #[tokio::test]
    async fn check_auth_codex_method_is_env_var() {
        let outcome = check_auth(CheckAuthInput {
            engine: "codex".into(),
        })
        .await
        .unwrap();
        assert_eq!(outcome.method, "env_var");
        assert!(!outcome.verified);
    }

    #[tokio::test]
    async fn check_auth_claude_uses_config_file_method() {
        let outcome = check_auth(CheckAuthInput {
            engine: "claude".into(),
        })
        .await
        .unwrap();
        assert_eq!(outcome.method, "config_file");
        assert!(!outcome.verified);
    }
}
377}