Skip to main content

difflore_core/domain/
providers.rs

1use uuid::Uuid;
2
3use crate::errors::CoreError;
4use crate::models::{
5    ProviderAddInput, ProviderRecord, ProviderRemoveInput, ProviderSetActiveInput,
6    ProviderUpdateInput,
7};
8
9#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
10pub struct CheckAuthInput {
11    pub engine: String,
12}
13
14#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
15#[serde(rename_all = "camelCase")]
16pub struct CheckAuthResult {
17    pub credential_detected: bool,
18    pub verified: bool,
19    pub method: String,
20    pub detail: String,
21}
22
23#[derive(sqlx::FromRow)]
24struct ProviderRow {
25    id: String,
26    name: String,
27    base_url: String,
28    api_key: String,
29    model_mapping: String,
30    is_active: i64,
31    created_at: String,
32    updated_at: String,
33}
34
35impl ProviderRow {
36    fn decrypt_api_key(&self) -> String {
37        match crate::crypto::decrypt_secret(&self.api_key) {
38            Ok(plaintext) => plaintext,
39            Err(e) => {
40                eprintln!(
41                    "Failed to decrypt API key for provider {}: {e}. Returning empty string to avoid leaking ciphertext.",
42                    self.id
43                );
44                String::new()
45            }
46        }
47    }
48
49    fn into_masked(self) -> ProviderRecord {
50        let decrypted = self.decrypt_api_key();
51        ProviderRecord {
52            id: self.id,
53            name: self.name,
54            base_url: self.base_url,
55            api_key: Some(mask_api_key(&decrypted)),
56            model_mapping: serde_json::from_str(&self.model_mapping).unwrap_or_default(),
57            is_active: self.is_active != 0,
58            created_at: self.created_at,
59            updated_at: self.updated_at,
60        }
61    }
62
63    fn into_internal(self) -> ProviderRecord {
64        let decrypted = self.decrypt_api_key();
65        ProviderRecord {
66            id: self.id,
67            name: self.name,
68            base_url: self.base_url,
69            api_key: Some(decrypted),
70            model_mapping: serde_json::from_str(&self.model_mapping).unwrap_or_default(),
71            is_active: self.is_active != 0,
72            created_at: self.created_at,
73            updated_at: self.updated_at,
74        }
75    }
76}
77
/// Masks an API key for display, keeping only its first and last three characters.
///
/// Keys of eight characters or fewer are replaced wholesale by `"****"`, so short secrets
/// are not mostly revealed. Length is measured in `char`s, not bytes, and the slice
/// offsets come from `char_indices`. A stray multi-byte character (an emoji or a non-ASCII
/// paste artefact) therefore can never split a UTF-8 sequence and cause a slicing panic.
fn mask_api_key(key: &str) -> String {
    let char_len = key.chars().count();
    if char_len > 8 {
        // Byte offset at which the `nth` char starts. Both lookups below are in range
        // because `char_len > 8`; the `key.len()` fallback only keeps the closure total.
        let byte_offset =
            |nth: usize| key.char_indices().nth(nth).map_or(key.len(), |(at, _)| at);
        let head = &key[..byte_offset(3)];
        let tail = &key[byte_offset(char_len - 3)..];
        format!("{head}***{tail}")
    } else {
        "****".to_owned()
    }
}
91
92pub async fn list(db: &sqlx::SqlitePool) -> crate::Result<Vec<ProviderRecord>> {
93    let rows = sqlx::query_as!(
94        ProviderRow,
95        "SELECT id, name, base_url, api_key, model_mapping, is_active, created_at, updated_at FROM providers ORDER BY created_at DESC"
96    )
97    .fetch_all(db)
98    .await?;
99    Ok(rows.into_iter().map(ProviderRow::into_masked).collect())
100}
101
102pub async fn get(
103    db: &sqlx::SqlitePool,
104    input: ProviderRemoveInput,
105) -> crate::Result<Option<ProviderRecord>> {
106    let row = sqlx::query_as!(
107        ProviderRow,
108        "SELECT id, name, base_url, api_key, model_mapping, is_active, created_at, updated_at FROM providers WHERE id = ?1",
109        input.id
110    )
111    .fetch_optional(db)
112    .await?;
113    Ok(row.map(ProviderRow::into_masked))
114}
115
116pub async fn add(db: &sqlx::SqlitePool, input: ProviderAddInput) -> crate::Result<ProviderRecord> {
117    let id = format!("provider-{}", Uuid::new_v4());
118    let now = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
119    let mapping_json = serde_json::to_string(&input.model_mapping)?;
120    // BYOK has been removed from the local CLI. Provider rows now only
121    // describe an agent-cli sentinel (`agent-cli://<tool>`); the column
122    // stays for back-compat with older DBs but is always written empty.
123    let encrypted_key = crate::crypto::encrypt_secret("").map_err(CoreError::Internal)?;
124
125    sqlx::query!(
126        "INSERT INTO providers (id, name, base_url, api_key, model_mapping, is_active, created_at, updated_at)
127         VALUES (?1, ?2, ?3, ?4, ?5, 0, ?6, ?6)",
128        id,
129        input.name,
130        input.base_url,
131        encrypted_key,
132        mapping_json,
133        now
134    )
135    .execute(db)
136    .await?;
137
138    Ok(ProviderRecord {
139        id,
140        name: input.name,
141        base_url: input.base_url,
142        api_key: None,
143        model_mapping: input.model_mapping,
144        is_active: false,
145        created_at: now.clone(),
146        updated_at: now,
147    })
148}
149
150pub async fn update(
151    db: &sqlx::SqlitePool,
152    input: ProviderUpdateInput,
153) -> crate::Result<ProviderRecord> {
154    let now = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
155
156    let row = sqlx::query_as!(
157        ProviderRow,
158        "SELECT id, name, base_url, api_key, model_mapping, is_active, created_at, updated_at FROM providers WHERE id = ?1",
159        input.id
160    )
161    .fetch_optional(db)
162    .await?
163    .ok_or_else(|| CoreError::NotFound(format!(
164        "provider '{}' not found. List current providers with `difflore providers list`.",
165        input.id
166    )))?;
167
168    let mut provider = row.into_internal();
169
170    if let Some(name) = input.name {
171        provider.name = name;
172    }
173    if let Some(base_url) = input.base_url {
174        provider.base_url = base_url;
175    }
176    if let Some(mm) = input.model_mapping {
177        provider.model_mapping = mm;
178    }
179    provider.updated_at = now;
180
181    let mapping_json = serde_json::to_string(&provider.model_mapping)?;
182    // BYOK has been removed; the api_key column is left in place for
183    // older schemas but always overwritten with an encrypted empty string.
184    let encrypted_secret = crate::crypto::encrypt_secret("").map_err(CoreError::Internal)?;
185
186    let result = sqlx::query!(
187        "UPDATE providers SET name=?1, base_url=?2, api_key=?3, model_mapping=?4, updated_at=?5 WHERE id=?6",
188        provider.name,
189        provider.base_url,
190        encrypted_secret,
191        mapping_json,
192        provider.updated_at,
193        provider.id
194    )
195    .execute(db)
196    .await?;
197    if result.rows_affected() == 0 {
198        return Err(CoreError::NotFound(format!(
199            "provider '{}' not found — cannot update. List current providers with `difflore providers list`.",
200            provider.id
201        )));
202    }
203
204    Ok(ProviderRecord {
205        id: provider.id,
206        name: provider.name,
207        base_url: provider.base_url,
208        api_key: None,
209        model_mapping: provider.model_mapping,
210        is_active: provider.is_active,
211        created_at: provider.created_at,
212        updated_at: provider.updated_at,
213    })
214}
215
216pub async fn remove(db: &sqlx::SqlitePool, input: ProviderRemoveInput) -> crate::Result<()> {
217    let result = sqlx::query!("DELETE FROM providers WHERE id = ?1", input.id)
218        .execute(db)
219        .await?;
220    if result.rows_affected() == 0 {
221        return Err(CoreError::NotFound(format!(
222            "provider '{}' not found. List current providers with `difflore providers list`.",
223            input.id
224        )));
225    }
226    Ok(())
227}
228
229pub async fn set_active(db: &sqlx::SqlitePool, input: ProviderSetActiveInput) -> crate::Result<()> {
230    let mut tx = db.begin().await?;
231    sqlx::query!("UPDATE providers SET is_active = 0")
232        .execute(&mut *tx)
233        .await?;
234    if input.is_active {
235        let result = sqlx::query!("UPDATE providers SET is_active = 1 WHERE id = ?1", input.id)
236            .execute(&mut *tx)
237            .await?;
238        if result.rows_affected() == 0 {
239            // Match the actionable message style used by `remove`: name
240            // the bad id and tell the user where to look up the right
241            // ones, instead of bare "provider".
242            return Err(CoreError::NotFound(format!(
243                "provider '{}' not found. List current providers with `difflore providers list`.",
244                input.id
245            )));
246        }
247    }
248    tx.commit().await?;
249    Ok(())
250}
251
252pub async fn check_auth(input: CheckAuthInput) -> crate::Result<CheckAuthResult> {
253    let home = dirs::home_dir()
254        .ok_or_else(|| CoreError::Internal("cannot resolve home directory".into()))?;
255
256    let (detected, method, detail) = match input.engine.as_str() {
257        "claude" => {
258            let path = home.join(".claude").join(".credentials.json");
259            let found = path.exists();
260            (
261                found,
262                "config_file".to_owned(),
263                if found {
264                    "Credentials file detected".to_owned()
265                } else {
266                    "No credentials file detected".to_owned()
267                },
268            )
269        }
270        "codex" => {
271            let found = crate::env::var(crate::env::OPENAI_API_KEY).is_some();
272            (
273                found,
274                "env_var".to_owned(),
275                if found {
276                    "OPENAI_API_KEY environment variable detected".to_owned()
277                } else {
278                    "OPENAI_API_KEY not found in environment".to_owned()
279                },
280            )
281        }
282        "gemini" => {
283            let path = home.join(".gemini").join("credentials.json");
284            let found = path.exists();
285            (
286                found,
287                "config_file".to_owned(),
288                if found {
289                    "Credentials file detected".to_owned()
290                } else {
291                    "No credentials file detected".to_owned()
292                },
293            )
294        }
295        "cursor" => {
296            let path = home.join(".cursor");
297            let found = path.exists();
298            (
299                found,
300                "cli_config".to_owned(),
301                if found {
302                    "Config directory detected".to_owned()
303                } else {
304                    "Config directory not found".to_owned()
305                },
306            )
307        }
308        other => (
309            false,
310            "unknown".to_owned(),
311            format!("unsupported engine: {other}"),
312        ),
313    };
314
315    Ok(CheckAuthResult {
316        credential_detected: detected,
317        verified: false,
318        method,
319        detail,
320    })
321}
322
323#[cfg(test)]
324mod tests {
325    use super::*;
326
327    #[test]
328    fn mask_api_key_table() {
329        let cases: &[(&str, &str)] = &[
330            ("", "****"),
331            ("short", "****"),
332            ("12345678", "****"),
333            ("sk-ant-1234567890abcd", "sk-***bcd"),
334            ("abcdefghijk", "abc***ijk"),
335        ];
336        for (input, expected) in cases {
337            assert_eq!(mask_api_key(input), *expected, "input: {input}");
338        }
339    }
340
341    #[tokio::test]
342    async fn check_auth_unknown_engine_reports_unsupported() {
343        let res = check_auth(CheckAuthInput {
344            engine: "bogus-engine".into(),
345        })
346        .await
347        .expect("check_auth should not error for unknown engine");
348        assert!(!res.credential_detected);
349        assert!(!res.verified);
350        assert_eq!(res.method, "unknown");
351        assert!(res.detail.contains("unsupported"));
352    }
353
354    #[tokio::test]
355    async fn check_auth_codex_method_is_env_var() {
356        // Just validate the method/shape — don't mutate env to avoid races with
357        // other parallel tests.
358        let res = check_auth(CheckAuthInput {
359            engine: "codex".into(),
360        })
361        .await
362        .unwrap();
363        assert_eq!(res.method, "env_var");
364        assert!(!res.verified);
365    }
366
367    #[tokio::test]
368    async fn check_auth_claude_uses_config_file_method() {
369        let res = check_auth(CheckAuthInput {
370            engine: "claude".into(),
371        })
372        .await
373        .unwrap();
374        assert_eq!(res.method, "config_file");
375        assert!(!res.verified);
376    }
377}