claude_code_switcher/
credentials.rs

1//! Credential management module for Claude Code Switcher
2//!
3//! This module provides functionality to save and retrieve API keys for different AI providers.
4//! Credentials are stored in plain text since they're typically managed through environment variables.
5//!
6//! Version management strategy:
7//! - Current version: v2 (simplified from previous encryption-based approach)
8//! - Future versions should increment the version number when format changes are needed
9
10use anyhow::{Result, anyhow};
11use chrono::Utc;
12use inquire::{Confirm, Select, Text};
13use serde::{Deserialize, Serialize};
14use std::fs;
15use std::path::PathBuf;
16use uuid::Uuid;
17
18use crate::TemplateType;
19
/// Current credential data format version
///
/// Stored in the `version` field of every credential this module constructs.
pub const CURRENT_CREDENTIAL_VERSION: &str = "v2";
22
23/// Core credential data structure
24#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
25pub struct CredentialData {
26    /// Data format version for compatibility
27    pub version: String,
28    /// Unique identifier for the credential
29    pub id: String,
30    /// User-friendly name for the credential
31    pub name: String,
32    /// API key in plain text
33    pub api_key: String,
34    /// Template type this credential is associated with
35    pub template_type: TemplateType,
36    /// Creation timestamp in UTC
37    pub created_at: String,
38    /// Last update timestamp in UTC
39    pub updated_at: String,
40    /// Optional metadata for future extensibility
41    pub metadata: Option<std::collections::HashMap<String, String>>,
42}
43
44impl Default for CredentialData {
45    fn default() -> Self {
46        let now = Utc::now().format("%Y-%m-%d %H:%M:%S UTC").to_string();
47        Self {
48            version: CURRENT_CREDENTIAL_VERSION.to_string(),
49            id: Uuid::new_v4().to_string(),
50            name: String::new(),
51            api_key: String::new(),
52            template_type: TemplateType::KatCoder,
53            created_at: now.clone(),
54            updated_at: now,
55            metadata: None,
56        }
57    }
58}
59
60impl CredentialData {
61    /// Create a new credential
62    pub fn new(name: String, api_key: String, template_type: TemplateType) -> Self {
63        let now = Utc::now().format("%Y-%m-%d %H:%M:%S UTC").to_string();
64        Self {
65            version: CURRENT_CREDENTIAL_VERSION.to_string(),
66            id: Uuid::new_v4().to_string(),
67            name,
68            api_key,
69            template_type,
70            created_at: now.clone(),
71            updated_at: now,
72            metadata: None,
73        }
74    }
75
76    /// Update the timestamp to current time
77    pub fn update_timestamp(&mut self) {
78        self.updated_at = Utc::now().format("%Y-%m-%d %H:%M:%S UTC").to_string();
79    }
80
81    /// Get credential ID
82    pub fn id(&self) -> &str {
83        &self.id
84    }
85
86    /// Get credential name
87    pub fn name(&self) -> &str {
88        &self.name
89    }
90
91    /// Get API key
92    pub fn api_key(&self) -> &str {
93        &self.api_key
94    }
95
96    /// Get template type
97    pub fn template_type(&self) -> &TemplateType {
98        &self.template_type
99    }
100
101    /// Get creation timestamp
102    pub fn created_at(&self) -> &str {
103        &self.created_at
104    }
105
106    /// Get update timestamp
107    pub fn updated_at(&self) -> &str {
108        &self.updated_at
109    }
110
111    /// Get metadata
112    pub fn metadata(&self) -> Option<&std::collections::HashMap<String, String>> {
113        self.metadata.as_ref()
114    }
115
116    /// Update metadata
117    pub fn set_metadata(&mut self, metadata: std::collections::HashMap<String, String>) {
118        self.metadata = Some(metadata);
119        self.update_timestamp();
120    }
121}
122
/// Alias for [`CredentialData`], used by the store APIs for credentials that
/// are saved to or loaded from disk
pub type SavedCredential = CredentialData;
125
126/// Storage backend for credential files
127pub struct SavedCredentialStore {
128    pub credentials_dir: PathBuf,
129}
130
131impl SavedCredentialStore {
132    /// Create a new credential store with default directory
133    pub fn new() -> Result<Self> {
134        let home_dir = dirs::home_dir().ok_or_else(|| anyhow!("Could not find home directory"))?;
135        let credentials_dir = home_dir.join(".claude").join("credentials");
136
137        let store = Self { credentials_dir };
138        store.ensure_dir()?;
139        Ok(store)
140    }
141
142    /// Create a new credential store with custom directory (for backward compatibility)
143    pub fn new_with_dir(credentials_dir: PathBuf) -> Self {
144        Self { credentials_dir }
145    }
146
147    /// Ensure the credentials directory exists
148    pub fn ensure_dir(&self) -> Result<()> {
149        if !self.credentials_dir.exists() {
150            fs::create_dir_all(&self.credentials_dir)
151                .map_err(|e| anyhow!("Failed to create credentials directory: {}", e))?;
152        }
153        Ok(())
154    }
155
156    /// Get the file path for a credential
157    pub fn credential_path(&self, credential_id: &str) -> PathBuf {
158        self.credentials_dir.join(format!("{}.json", credential_id))
159    }
160
161    /// Save a credential to disk
162    pub fn save(&self, credential: &CredentialData) -> Result<()> {
163        self.ensure_dir()?;
164        let path = self.credential_path(&credential.id);
165
166        let content = serde_json::to_string_pretty(credential)
167            .map_err(|e| anyhow!("Failed to serialize credential: {}", e))?;
168
169        fs::write(&path, content)
170            .map_err(|e| anyhow!("Failed to write credential file {}: {}", path.display(), e))?;
171
172        Ok(())
173    }
174
175    /// Load a credential from disk
176    pub fn load(&self, credential_id: &str) -> Result<SavedCredential> {
177        let path = self.credential_path(credential_id);
178
179        if !path.exists() {
180            return Err(anyhow!("Credential '{}' not found", credential_id));
181        }
182
183        let content = fs::read_to_string(&path)
184            .map_err(|e| anyhow!("Failed to read credential file {}: {}", path.display(), e))?;
185
186        // Parse as current format
187        serde_json::from_str::<CredentialData>(&content)
188            .map_err(|e| anyhow!("Failed to parse credential file {}: {}", path.display(), e))
189    }
190
191    /// List all saved credentials
192    pub fn list(&self) -> Result<Vec<SavedCredential>> {
193        self.ensure_dir()?;
194
195        let mut credentials = Vec::new();
196
197        let entries = fs::read_dir(&self.credentials_dir)
198            .map_err(|e| anyhow!("Failed to read credentials directory: {}", e))?;
199
200        for entry in entries {
201            let entry = entry.map_err(|e| anyhow!("Failed to read directory entry: {}", e))?;
202            let path = entry.path();
203
204            if path.extension().and_then(|s| s.to_str()) == Some("json") {
205                let credential_id = path
206                    .file_stem()
207                    .and_then(|s| s.to_str())
208                    .ok_or_else(|| anyhow!("Invalid credential file name: {}", path.display()))?;
209
210                match self.load(credential_id) {
211                    Ok(credential) => credentials.push(credential),
212                    Err(e) => {
213                        // Log the error but continue loading other credentials
214                        eprintln!(
215                            "Warning: Failed to load credential '{}': {}",
216                            credential_id, e
217                        );
218                    }
219                }
220            }
221        }
222
223        // Sort by creation time (newest first)
224        credentials.sort_by(|a, b| b.created_at().cmp(a.created_at()));
225
226        Ok(credentials)
227    }
228
229    /// Delete a credential
230    pub fn delete(&self, credential_id: &str) -> Result<()> {
231        let path = self.credential_path(credential_id);
232
233        if !path.exists() {
234            return Err(anyhow!("Credential '{}' not found", credential_id));
235        }
236
237        fs::remove_file(&path)
238            .map_err(|e| anyhow!("Failed to delete credential file {}: {}", path.display(), e))?;
239
240        Ok(())
241    }
242
243    /// Check if a credential exists
244    pub fn exists(&self, credential_id: &str) -> bool {
245        self.credential_path(credential_id).exists()
246    }
247
248    /// Get all credential names
249    pub fn list_names(&self) -> Result<Vec<String>> {
250        let credentials = self.list()?;
251        Ok(credentials
252            .into_iter()
253            .map(|c| c.name().to_string())
254            .collect())
255    }
256
257    /// Find credentials by template type
258    pub fn find_by_template_type(
259        &self,
260        template_type: &TemplateType,
261    ) -> Result<Vec<SavedCredential>> {
262        let credentials = self.list()?;
263        Ok(credentials
264            .into_iter()
265            .filter(|c| c.template_type() == template_type)
266            .collect())
267    }
268}
269
270/// High-level credential management
271pub struct CredentialStore {
272    pub store: SavedCredentialStore,
273}
274
275impl CredentialStore {
276    /// Create a new credential store
277    pub fn new() -> Result<Self> {
278        Ok(Self {
279            store: SavedCredentialStore::new()?,
280        })
281    }
282
283    /// Create and save a new credential
284    pub fn create_credential(
285        &self,
286        name: String,
287        api_key: &str,
288        template_type: TemplateType,
289    ) -> Result<SavedCredential> {
290        let credential = CredentialData::new(name, api_key.to_string(), template_type);
291        self.store.save(&credential)?;
292        Ok(credential)
293    }
294
295    /// Get the API key from a credential
296    pub fn get_api_key(&self, credential: &SavedCredential) -> Result<String> {
297        Ok(credential.api_key().to_string())
298    }
299
300    /// Update credential name
301    pub fn update_name(&self, credential_id: &str, new_name: String) -> Result<()> {
302        let mut credential = self.store.load(credential_id)?;
303        credential.name = new_name;
304        credential.update_timestamp();
305        self.store.save(&credential)?;
306        Ok(())
307    }
308
309    /// Update credential metadata
310    pub fn update_metadata(
311        &self,
312        credential_id: &str,
313        metadata: std::collections::HashMap<String, String>,
314    ) -> Result<()> {
315        let mut credential = self.store.load(credential_id)?;
316        credential.set_metadata(metadata);
317        self.store.save(&credential)?;
318        Ok(())
319    }
320}
321
322impl crate::CredentialManager for CredentialStore {
323    fn save_credential(
324        &self,
325        name: String,
326        api_key: &str,
327        template_type: TemplateType,
328    ) -> Result<()> {
329        self.create_credential(name, api_key, template_type)?;
330        Ok(())
331    }
332
333    fn load_credentials(&self) -> Result<Vec<SavedCredential>> {
334        self.store.list()
335    }
336
337    fn delete_credential(&self, credential_id: &str) -> Result<()> {
338        self.store.delete(credential_id)
339    }
340
341    fn clear_credentials(&self) -> Result<()> {
342        let credentials = self.store.list()?;
343        for credential in credentials {
344            self.store.delete(&credential.id())?;
345        }
346        Ok(())
347    }
348}
349
350/// Helper function to select a credential from a list
351pub fn select_credential<'a>(
352    credentials: &'a [SavedCredential],
353    message: &str,
354) -> Result<&'a SavedCredential> {
355    let options: Vec<String> = credentials
356        .iter()
357        .map(|c| {
358            format!(
359                "{} ({} - {})",
360                c.name(),
361                c.template_type(),
362                mask_api_key(c.api_key())
363            )
364        })
365        .collect();
366
367    let selected = Select::new(message, options.clone())
368        .prompt()
369        .map_err(|e| anyhow!("Failed to select credential: {}", e))?;
370
371    let index = options.iter().position(|o| o == &selected).unwrap();
372    Ok(&credentials[index])
373}
374
375/// Prompt user to save a credential interactively
376pub fn prompt_save_credential(
377    api_key: &str,
378    template_type: TemplateType,
379) -> Result<Option<SavedCredential>> {
380    if let Ok(should_save) = Confirm::new("Would you like to save this API key for future use?")
381        .with_default(true)
382        .prompt()
383    {
384        if should_save {
385            let name = Text::new("Enter a name for this credential:")
386                .with_placeholder(&format!("{} API Key", template_type))
387                .prompt()
388                .map_err(|e| anyhow!("Failed to get credential name: {}", e))?;
389
390            let store = CredentialStore::new()?;
391            let credential = store.create_credential(name, api_key, template_type)?;
392
393            println!("✓ Credential saved successfully!");
394            return Ok(Some(credential));
395        }
396    }
397    Ok(None)
398}
399
400/// Get API key interactively with option to save
401pub fn get_api_key_interactively(template_type: TemplateType) -> Result<String> {
402    // Try to use saved credentials first
403    if let Ok(credential_store) = CredentialStore::new() {
404        if let Ok(credentials) = credential_store.store.find_by_template_type(&template_type) {
405            if !credentials.is_empty() {
406                println!("Found saved credentials for {}:", template_type);
407
408                for credential in &credentials {
409                    println!(
410                        "  • {}: {} ({})",
411                        STYLE_CYAN.apply_to(credential.name()),
412                        mask_api_key(credential.api_key()),
413                        credential.created_at()
414                    );
415                }
416
417                if let Ok(continue_use) = Confirm::new("Use one of these saved credentials?")
418                    .with_default(true)
419                    .prompt()
420                {
421                    if continue_use {
422                        if let Ok(selected) =
423                            select_credential(&credentials, "Select a credential:")
424                        {
425                            return credential_store.get_api_key(&selected);
426                        }
427                    }
428                }
429            }
430        }
431    }
432
433    // If no saved credentials or user chooses not to use them, prompt for API key
434    let prompt_text = format!("Enter your {} API key:", template_type);
435    let api_key = Text::new(&prompt_text)
436        .with_placeholder("sk-...")
437        .prompt()
438        .map_err(|e| anyhow!("Failed to get API key: {}", e))?;
439
440    // Offer to save the credential
441    if let Some(_) = prompt_save_credential(&api_key, template_type)? {
442        // Credential was saved
443    }
444
445    Ok(api_key)
446}
447
/// Mask an API key for display, keeping only its first 4 and last 4 characters
///
/// Keys of 8 characters or fewer are replaced entirely by eight bullets so
/// that nothing of a short key is revealed. Lengths and cut points are
/// measured in characters rather than bytes, so keys containing multi-byte
/// UTF-8 text are masked correctly instead of panicking on a slice that is
/// not on a character boundary.
fn mask_api_key(api_key: &str) -> String {
    /// Number of characters left readable at each end
    const VISIBLE: usize = 4;

    let chars: Vec<char> = api_key.chars().collect();
    if chars.len() <= VISIBLE * 2 {
        return "••••••••".to_string();
    }

    let head: String = chars[..VISIBLE].iter().collect();
    let tail: String = chars[chars.len() - VISIBLE..].iter().collect();
    let hidden = "•".repeat(chars.len() - VISIBLE * 2);

    format!("{}{}{}", head, hidden, tail)
}
461
462use console::Style;
/// Cyan terminal style applied to credential names when saved credentials are listed
const STYLE_CYAN: Style = Style::new().cyan();
464
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a store in its own uniquely named temp directory
    ///
    /// A fixed shared directory would let test runs see each other's files
    /// and leave plain-text credential files behind.
    fn create_test_store() -> CredentialStore {
        let temp_dir = std::env::temp_dir().join(format!("ccs_test_{}", Uuid::new_v4()));
        CredentialStore {
            store: SavedCredentialStore::new_with_dir(temp_dir),
        }
    }

    #[test]
    fn test_credential_creation() {
        let credential = CredentialData::new(
            "test".to_string(),
            "test-key".to_string(),
            TemplateType::KatCoder,
        );

        assert_eq!(credential.name(), "test");
        assert_eq!(credential.api_key(), "test-key");
        assert_eq!(credential.version, CURRENT_CREDENTIAL_VERSION);
    }

    #[test]
    fn test_credential_save_and_load() {
        let store = create_test_store();

        let credential = store
            .create_credential("test".to_string(), "test-key", TemplateType::KatCoder)
            .unwrap();
        let loaded = store.store.load(credential.id());

        // Clean up before asserting so a failing assertion cannot leak files
        let _ = fs::remove_dir_all(&store.store.credentials_dir);

        let loaded = loaded.unwrap();
        assert_eq!(credential.name(), loaded.name());
        assert_eq!(credential.api_key(), loaded.api_key());
    }

    #[test]
    fn test_mask_api_key() {
        assert_eq!(mask_api_key("sk-1234567890"), "sk-1•••••7890");
        assert_eq!(mask_api_key("short"), "••••••••");
    }
}