claude_code_switcher/
credentials.rs

1//! Credential management module for Claude Code Switcher
2//!
3//! This module provides functionality to save and retrieve API keys for different AI providers.
//! Credentials are stored unencrypted (as plain text), on the rationale that these API keys are typically also exposed through environment variables.
5//!
6//! Version management strategy:
7//! - Current version: v2 (simplified from previous encryption-based approach)
8//! - Future versions should increment the version number when format changes are needed
9
10use anyhow::{Result, anyhow};
11use chrono::Utc;
12use inquire::{Confirm, Select, Text};
13use serde::{Deserialize, Serialize};
14use std::fs;
15use std::path::PathBuf;
16use uuid::Uuid;
17
18use crate::TemplateType;
19
/// Format version stamped into every newly created [`CredentialData`].
///
/// Bump this whenever the serialized on-disk layout of a credential changes.
pub const CURRENT_CREDENTIAL_VERSION: &str = "v2";
22
/// A single saved API-key credential, serialized as one JSON document per credential.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct CredentialData {
    /// Data format version for compatibility (see [`CURRENT_CREDENTIAL_VERSION`])
    pub version: String,
    /// Unique identifier (a UUID v4 string when created by this module)
    pub id: String,
    /// User-friendly name for the credential
    pub name: String,
    /// API key, stored in plain text
    pub api_key: String,
    /// Template type this credential is associated with
    pub template_type: TemplateType,
    /// Creation timestamp in UTC, formatted as `YYYY-MM-DD HH:MM:SS UTC`
    pub created_at: String,
    /// Last update timestamp in UTC, same format as `created_at`
    pub updated_at: String,
    /// Optional free-form key/value metadata (e.g. an `endpoint_id` entry)
    pub metadata: Option<std::collections::HashMap<String, String>>,
}
43
44impl Default for CredentialData {
45    fn default() -> Self {
46        let now = Utc::now().format("%Y-%m-%d %H:%M:%S UTC").to_string();
47        Self {
48            version: CURRENT_CREDENTIAL_VERSION.to_string(),
49            id: Uuid::new_v4().to_string(),
50            name: String::new(),
51            api_key: String::new(),
52            template_type: TemplateType::KatCoder,
53            created_at: now.clone(),
54            updated_at: now,
55            metadata: None,
56        }
57    }
58}
59
60impl CredentialData {
61    /// Create a new credential
62    pub fn new(name: String, api_key: String, template_type: TemplateType) -> Self {
63        let now = Utc::now().format("%Y-%m-%d %H:%M:%S UTC").to_string();
64        Self {
65            version: CURRENT_CREDENTIAL_VERSION.to_string(),
66            id: Uuid::new_v4().to_string(),
67            name,
68            api_key,
69            template_type,
70            created_at: now.clone(),
71            updated_at: now,
72            metadata: None,
73        }
74    }
75
76    /// Update the timestamp to current time
77    pub fn update_timestamp(&mut self) {
78        self.updated_at = Utc::now().format("%Y-%m-%d %H:%M:%S UTC").to_string();
79    }
80
81    /// Get credential ID
82    pub fn id(&self) -> &str {
83        &self.id
84    }
85
86    /// Get credential name
87    pub fn name(&self) -> &str {
88        &self.name
89    }
90
91    /// Get API key
92    pub fn api_key(&self) -> &str {
93        &self.api_key
94    }
95
96    /// Get template type
97    pub fn template_type(&self) -> &TemplateType {
98        &self.template_type
99    }
100
101    /// Get creation timestamp
102    pub fn created_at(&self) -> &str {
103        &self.created_at
104    }
105
106    /// Get update timestamp
107    pub fn updated_at(&self) -> &str {
108        &self.updated_at
109    }
110
111    /// Get metadata
112    pub fn metadata(&self) -> Option<&std::collections::HashMap<String, String>> {
113        self.metadata.as_ref()
114    }
115
116    /// Update metadata
117    pub fn set_metadata(&mut self, metadata: std::collections::HashMap<String, String>) {
118        self.metadata = Some(metadata);
119        self.update_timestamp();
120    }
121
122    /// Get a specific metadata value
123    pub fn get_metadata(&self, key: &str) -> Option<String> {
124        self.metadata.as_ref()?.get(key).cloned()
125    }
126
127    /// Set a specific metadata value
128    pub fn set_metadata_value(&mut self, key: String, value: String) {
129        if let Some(ref mut metadata) = self.metadata {
130            metadata.insert(key, value);
131        } else {
132            let mut new_metadata = std::collections::HashMap::new();
133            new_metadata.insert(key, value);
134            self.metadata = Some(new_metadata);
135        }
136        self.update_timestamp();
137    }
138}
139
/// Alias for [`CredentialData`]: the type returned by the store's load/list operations.
///
/// Note: despite the name this is not a `Result`; it is the credential record itself.
pub type SavedCredential = CredentialData;
142
/// File-backed credential storage: each credential is one pretty-printed JSON file
/// named `<id>.json` inside `credentials_dir`.
pub struct SavedCredentialStore {
    /// Directory holding the per-credential JSON files.
    pub credentials_dir: PathBuf,
}
147
148impl SavedCredentialStore {
149    /// Create a new credential store with default directory
150    pub fn new() -> Result<Self> {
151        let home_dir = dirs::home_dir().ok_or_else(|| anyhow!("Could not find home directory"))?;
152        let credentials_dir = home_dir.join(".claude").join("credentials");
153
154        let store = Self { credentials_dir };
155        store.ensure_dir()?;
156        Ok(store)
157    }
158
159    /// Create a new credential store with custom directory (for backward compatibility)
160    pub fn new_with_dir(credentials_dir: PathBuf) -> Self {
161        Self { credentials_dir }
162    }
163
164    /// Ensure the credentials directory exists
165    pub fn ensure_dir(&self) -> Result<()> {
166        if !self.credentials_dir.exists() {
167            fs::create_dir_all(&self.credentials_dir)
168                .map_err(|e| anyhow!("Failed to create credentials directory: {}", e))?;
169        }
170        Ok(())
171    }
172
173    /// Get the file path for a credential
174    pub fn credential_path(&self, credential_id: &str) -> PathBuf {
175        self.credentials_dir.join(format!("{}.json", credential_id))
176    }
177
178    /// Save a credential to disk
179    pub fn save(&self, credential: &CredentialData) -> Result<()> {
180        self.ensure_dir()?;
181        let path = self.credential_path(&credential.id);
182
183        let content = serde_json::to_string_pretty(credential)
184            .map_err(|e| anyhow!("Failed to serialize credential: {}", e))?;
185
186        fs::write(&path, content)
187            .map_err(|e| anyhow!("Failed to write credential file {}: {}", path.display(), e))?;
188
189        Ok(())
190    }
191
192    /// Load a credential from disk
193    pub fn load(&self, credential_id: &str) -> Result<SavedCredential> {
194        let path = self.credential_path(credential_id);
195
196        if !path.exists() {
197            return Err(anyhow!("Credential '{}' not found", credential_id));
198        }
199
200        let content = fs::read_to_string(&path)
201            .map_err(|e| anyhow!("Failed to read credential file {}: {}", path.display(), e))?;
202
203        // Parse as current format
204        serde_json::from_str::<CredentialData>(&content)
205            .map_err(|e| anyhow!("Failed to parse credential file {}: {}", path.display(), e))
206    }
207
208    /// List all saved credentials
209    pub fn list(&self) -> Result<Vec<SavedCredential>> {
210        self.ensure_dir()?;
211
212        let mut credentials = Vec::new();
213
214        let entries = fs::read_dir(&self.credentials_dir)
215            .map_err(|e| anyhow!("Failed to read credentials directory: {}", e))?;
216
217        for entry in entries {
218            let entry = entry.map_err(|e| anyhow!("Failed to read directory entry: {}", e))?;
219            let path = entry.path();
220
221            if path.extension().and_then(|s| s.to_str()) == Some("json") {
222                let credential_id = path
223                    .file_stem()
224                    .and_then(|s| s.to_str())
225                    .ok_or_else(|| anyhow!("Invalid credential file name: {}", path.display()))?;
226
227                match self.load(credential_id) {
228                    Ok(credential) => credentials.push(credential),
229                    Err(e) => {
230                        // Log the error but continue loading other credentials
231                        eprintln!(
232                            "Warning: Failed to load credential '{}': {}",
233                            credential_id, e
234                        );
235                    }
236                }
237            }
238        }
239
240        // Sort by creation time (newest first)
241        credentials.sort_by(|a, b| b.created_at().cmp(a.created_at()));
242
243        Ok(credentials)
244    }
245
246    /// Delete a credential
247    pub fn delete(&self, credential_id: &str) -> Result<()> {
248        let path = self.credential_path(credential_id);
249
250        if !path.exists() {
251            return Err(anyhow!("Credential '{}' not found", credential_id));
252        }
253
254        fs::remove_file(&path)
255            .map_err(|e| anyhow!("Failed to delete credential file {}: {}", path.display(), e))?;
256
257        Ok(())
258    }
259
260    /// Check if a credential exists
261    pub fn exists(&self, credential_id: &str) -> bool {
262        self.credential_path(credential_id).exists()
263    }
264
265    /// Get all credential names
266    pub fn list_names(&self) -> Result<Vec<String>> {
267        let credentials = self.list()?;
268        Ok(credentials
269            .into_iter()
270            .map(|c| c.name().to_string())
271            .collect())
272    }
273
274    /// Find credentials by template type
275    pub fn find_by_template_type(
276        &self,
277        template_type: &TemplateType,
278    ) -> Result<Vec<SavedCredential>> {
279        let credentials = self.list()?;
280        Ok(credentials
281            .into_iter()
282            .filter(|c| c.template_type() == template_type)
283            .collect())
284    }
285}
286
/// High-level credential operations (create, duplicate checks, endpoint-id metadata,
/// renames) built on top of a [`SavedCredentialStore`].
pub struct CredentialStore {
    /// Underlying file-backed storage.
    pub store: SavedCredentialStore,
}
291
292impl CredentialStore {
293    /// Create a new credential store
294    pub fn new() -> Result<Self> {
295        Ok(Self {
296            store: SavedCredentialStore::new()?,
297        })
298    }
299
300    /// Create and save a new credential
301    pub fn create_credential(
302        &self,
303        name: String,
304        api_key: &str,
305        template_type: TemplateType,
306    ) -> Result<SavedCredential> {
307        let credential = CredentialData::new(name, api_key.to_string(), template_type);
308        self.store.save(&credential)?;
309        Ok(credential)
310    }
311
312    /// Get the API key from a credential
313    pub fn get_api_key(&self, credential: &SavedCredential) -> Result<String> {
314        Ok(credential.api_key().to_string())
315    }
316
317    /// Check if API key already exists for this template type
318    pub fn has_api_key(&self, api_key: &str, template_type: &TemplateType) -> bool {
319        if let Ok(credentials) = self.store.find_by_template_type(template_type) {
320            for credential in credentials {
321                if credential.api_key() == api_key {
322                    return true;
323                }
324            }
325        }
326        false
327    }
328
329    /// Get saved endpoint IDs for a template type (from credential metadata)
330    pub fn get_endpoint_ids(&self, template_type: &TemplateType) -> Vec<(String, String)> {
331        let mut endpoint_ids = Vec::new();
332        if let Ok(credentials) = self.store.find_by_template_type(template_type) {
333            for credential in credentials {
334                if let Some(endpoint_id) = credential.get_metadata("endpoint_id") {
335                    let name = format!("{} - {}", credential.name(), endpoint_id);
336                    endpoint_ids.push((name, endpoint_id));
337                }
338            }
339        }
340        endpoint_ids
341    }
342
343    /// Save endpoint ID to credential metadata
344    pub fn save_endpoint_id(&self, credential_id: &str, endpoint_id: &str) -> Result<()> {
345        let mut credential = self.store.load(credential_id)?;
346        credential.set_metadata_value("endpoint_id".to_string(), endpoint_id.to_string());
347        self.store.save(&credential)?;
348        Ok(())
349    }
350
351    /// Check if endpoint ID exists
352    pub fn has_endpoint_id(&self, endpoint_id: &str, template_type: &TemplateType) -> bool {
353        if let Ok(credentials) = self.store.find_by_template_type(template_type) {
354            for credential in credentials {
355                if let Some(saved_endpoint) = credential.get_metadata("endpoint_id")
356                    && saved_endpoint == endpoint_id
357                {
358                    return true;
359                }
360            }
361        }
362        false
363    }
364
365    /// Update credential name
366    pub fn update_name(&self, credential_id: &str, new_name: String) -> Result<()> {
367        let mut credential = self.store.load(credential_id)?;
368        credential.name = new_name;
369        credential.update_timestamp();
370        self.store.save(&credential)?;
371        Ok(())
372    }
373
374    /// Update credential metadata
375    pub fn update_metadata(
376        &self,
377        credential_id: &str,
378        metadata: std::collections::HashMap<String, String>,
379    ) -> Result<()> {
380        let mut credential = self.store.load(credential_id)?;
381        credential.set_metadata(metadata);
382        self.store.save(&credential)?;
383        Ok(())
384    }
385}
386
387impl crate::CredentialManager for CredentialStore {
388    fn save_credential(
389        &self,
390        name: String,
391        api_key: &str,
392        template_type: TemplateType,
393    ) -> Result<()> {
394        self.create_credential(name, api_key, template_type)?;
395        Ok(())
396    }
397
398    fn load_credentials(&self) -> Result<Vec<SavedCredential>> {
399        self.store.list()
400    }
401
402    fn delete_credential(&self, credential_id: &str) -> Result<()> {
403        self.store.delete(credential_id)
404    }
405
406    fn clear_credentials(&self) -> Result<()> {
407        let credentials = self.store.list()?;
408        for credential in credentials {
409            self.store.delete(credential.id())?;
410        }
411        Ok(())
412    }
413}
414
415/// Helper function to select a credential from a list
416pub fn select_credential<'a>(
417    credentials: &'a [SavedCredential],
418    message: &str,
419) -> Result<&'a SavedCredential> {
420    let options: Vec<String> = credentials
421        .iter()
422        .map(|c| {
423            format!(
424                "{} ({} - {})",
425                c.name(),
426                c.template_type(),
427                mask_api_key(c.api_key())
428            )
429        })
430        .collect();
431
432    let selected = Select::new(message, options.clone())
433        .prompt()
434        .map_err(|e| anyhow!("Failed to select credential: {}", e))?;
435
436    let index = options.iter().position(|o| o == &selected).unwrap();
437    Ok(&credentials[index])
438}
439
440/// Prompt user to save a credential interactively
441pub fn prompt_save_credential(
442    api_key: &str,
443    template_type: TemplateType,
444) -> Result<Option<SavedCredential>> {
445    if let Ok(should_save) = Confirm::new("Would you like to save this API key for future use?")
446        .with_default(true)
447        .prompt()
448        && should_save
449    {
450        let name = Text::new("Enter a name for this credential:")
451            .with_placeholder(&format!("{} API Key", template_type))
452            .prompt()
453            .map_err(|e| anyhow!("Failed to get credential name: {}", e))?;
454
455        let store = CredentialStore::new()?;
456        let credential = store.create_credential(name, api_key, template_type)?;
457
458        println!("✓ Credential saved successfully!");
459        return Ok(Some(credential));
460    }
461    Ok(None)
462}
463
464/// Get API key interactively using simple selector
465pub fn get_api_key_interactively(template_type: TemplateType) -> Result<String> {
466    // First, try to get API key from environment variables
467    let env_var_name = crate::templates::get_env_var_name(&template_type);
468    if let Ok(api_key) = std::env::var(env_var_name)
469        && !api_key.trim().is_empty()
470    {
471        println!("✓ Using API key from environment variable {}", env_var_name);
472        return Ok(api_key);
473    }
474
475    // Get saved credentials
476    let credentials = if let Ok(credential_store) = CredentialStore::new() {
477        credential_store
478            .store
479            .find_by_template_type(&template_type)
480            .unwrap_or_default()
481    } else {
482        Vec::new()
483    };
484
485    // Use simple selector
486    let mut selector =
487        crate::simple_selector::SimpleCredentialSelector::new(credentials, template_type.clone());
488
489    match selector.run()? {
490        Some(api_key) => {
491            // Auto-save the credential if it's new
492            if let Ok(credential_store) = CredentialStore::new()
493                && !credential_store.has_api_key(&api_key, &template_type)
494            {
495                let default_name = format!("{} API Key", template_type);
496                if credential_store
497                    .create_credential(default_name, &api_key, template_type)
498                    .is_ok()
499                {
500                    println!("✓ API key saved automatically for future use.");
501                }
502            }
503            Ok(api_key)
504        }
505        None => Err(anyhow!("No API key selected")),
506    }
507}
508
/// Mask an API key for display, keeping only its first and last four characters.
///
/// Keys of eight characters or fewer are replaced by eight bullets, so no part of a
/// short key is revealed. Lengths are counted in `char`s rather than bytes, so keys
/// containing non-ASCII characters cannot trigger a slice-on-char-boundary panic.
fn mask_api_key(api_key: &str) -> String {
    const VISIBLE: usize = 4;

    let chars: Vec<char> = api_key.chars().collect();
    if chars.len() <= 2 * VISIBLE {
        return "••••••••".to_string();
    }

    let head: String = chars[..VISIBLE].iter().collect();
    let tail: String = chars[chars.len() - VISIBLE..].iter().collect();
    format!(
        "{}{}{}",
        head,
        "•".repeat(chars.len() - 2 * VISIBLE),
        tail
    )
}
522
#[cfg(test)]
mod tests {
    use super::*;

    /// Removes the wrapped directory on drop, so tests clean up even when they fail.
    struct TempDirGuard(PathBuf);

    impl Drop for TempDirGuard {
        fn drop(&mut self) {
            // Best-effort cleanup; a leftover temp dir must not fail the test.
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    /// Create a store rooted in a unique temp directory.
    ///
    /// Each call gets its own directory, so tests running in parallel never see each
    /// other's files. Keep the returned guard alive for the duration of the test.
    fn create_test_store() -> (CredentialStore, TempDirGuard) {
        let temp_dir = std::env::temp_dir().join(format!("ccs_test_{}", Uuid::new_v4()));
        let store = CredentialStore {
            store: SavedCredentialStore::new_with_dir(temp_dir.clone()),
        };
        (store, TempDirGuard(temp_dir))
    }

    #[test]
    fn test_credential_creation() {
        let credential = CredentialData::new(
            "test".to_string(),
            "test-key".to_string(),
            TemplateType::KatCoder,
        );

        assert_eq!(credential.name(), "test");
        assert_eq!(credential.api_key(), "test-key");
        assert_eq!(credential.version, CURRENT_CREDENTIAL_VERSION);
    }

    #[test]
    fn test_credential_save_and_load() {
        let (store, _guard) = create_test_store();

        let credential = store
            .create_credential("test".to_string(), "test-key", TemplateType::KatCoder)
            .unwrap();

        let loaded = store.store.load(credential.id()).unwrap();
        assert_eq!(credential, loaded);
    }

    #[test]
    fn test_list_and_delete() {
        let (store, _guard) = create_test_store();

        let first = store
            .create_credential("first".to_string(), "key-one", TemplateType::KatCoder)
            .unwrap();
        let second = store
            .create_credential("second".to_string(), "key-two", TemplateType::KatCoder)
            .unwrap();

        let mut names = store.store.list_names().unwrap();
        names.sort();
        assert_eq!(names, vec!["first".to_string(), "second".to_string()]);
        assert!(store.has_api_key("key-one", &TemplateType::KatCoder));
        assert!(!store.has_api_key("missing", &TemplateType::KatCoder));

        store.store.delete(first.id()).unwrap();
        assert!(!store.store.exists(first.id()));
        assert!(store.store.exists(second.id()));
        assert!(store.store.delete(first.id()).is_err());
    }

    #[test]
    fn test_endpoint_id_metadata() {
        let (store, _guard) = create_test_store();

        let credential = store
            .create_credential("ep".to_string(), "key", TemplateType::KatCoder)
            .unwrap();
        assert!(!store.has_endpoint_id("ep-123", &TemplateType::KatCoder));

        store.save_endpoint_id(credential.id(), "ep-123").unwrap();
        assert!(store.has_endpoint_id("ep-123", &TemplateType::KatCoder));
        assert_eq!(
            store.get_endpoint_ids(&TemplateType::KatCoder),
            vec![("ep - ep-123".to_string(), "ep-123".to_string())]
        );
    }

    #[test]
    fn test_mask_api_key() {
        assert_eq!(mask_api_key("sk-1234567890"), "sk-1•••••7890");
        assert_eq!(mask_api_key("short"), "••••••••");
    }
}
565}