Skip to main content

aether_auth/
encrypted_file.rs

1use crate::credential::{OAuthCredential, OAuthCredentialStorage};
2use crate::error::OAuthError;
3use age::scrypt::Identity;
4use age::secrecy::SecretString;
5use age::{Decryptor, Encryptor};
6use async_trait::async_trait;
7use dirs::home_dir;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::env::var;
11use std::io::{Read, Write};
12use std::iter::once;
13use std::path::{Path, PathBuf};
14use std::sync::Mutex;
15
/// Name of the environment variable that supplies the store passphrase when the settings
/// do not name a different one (see `EncryptedFileOAuthCredentialStorage::from_settings`).
const DEFAULT_PASSWORD_ENV: &str = "AETHER_CREDENTIALS_PASSWORD";
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19struct CredentialStore {
20    credentials: HashMap<String, OAuthCredential>,
21}
22
/// Credential storage backend that keeps every credential in one age-encrypted,
/// passphrase-protected file.
pub struct EncryptedFileOAuthCredentialStorage {
    /// Location of the encrypted credential file.
    path: PathBuf,
    /// Passphrase used both to encrypt and to decrypt the file.
    passphrase: String,
    /// Serializes read-modify-write cycles issued through this instance.
    write_guard: Mutex<()>,
}
28
29impl EncryptedFileOAuthCredentialStorage {
30    pub fn new(path: PathBuf, passphrase: String) -> Self {
31        Self { path, passphrase, write_guard: Mutex::new(()) }
32    }
33
34    pub fn from_settings(path: Option<PathBuf>, password_env: Option<&str>) -> Result<Self, OAuthError> {
35        let path = path.map_or_else(default_path, Ok)?;
36        let env_var = password_env.unwrap_or(DEFAULT_PASSWORD_ENV);
37        let passphrase = var(env_var).ok().filter(|pass| !pass.is_empty()).ok_or_else(|| {
38            OAuthError::CredentialStore(format!(
39                "Encrypted file credential store requires a passphrase. \
40                     Set the {env_var} environment variable or configure a custom `passwordEnv` in settings."
41            ))
42        })?;
43
44        Ok(Self::new(path, passphrase))
45    }
46
47    fn encrypt(plaintext: &[u8], passphrase: &str) -> Result<Vec<u8>, OAuthError> {
48        let fail = |e| OAuthError::CredentialStore(format!("Encryption failed: {e}"));
49
50        let mut ciphertext = Vec::new();
51        let mut writer = Encryptor::with_user_passphrase(SecretString::from(passphrase))
52            .wrap_output(&mut ciphertext)
53            .map_err(fail)?;
54        writer.write_all(plaintext).map_err(fail)?;
55        writer.finish().map_err(fail)?;
56
57        Ok(ciphertext)
58    }
59
60    fn decrypt(ciphertext: &[u8], passphrase: &str) -> Result<Vec<u8>, OAuthError> {
61        let decryptor = Decryptor::new(ciphertext)
62            .map_err(|e| OAuthError::CredentialStore(format!("Invalid encrypted file: {e}")))?;
63
64        let mut reader =
65            decryptor.decrypt(once(&Identity::new(SecretString::from(passphrase)) as &dyn age::Identity)).map_err(
66                |e| OAuthError::CredentialStore(format!("Decryption failed — wrong passphrase or corrupted file: {e}")),
67            )?;
68
69        let mut plaintext = Vec::new();
70        reader
71            .read_to_end(&mut plaintext)
72            .map_err(|e| OAuthError::CredentialStore(format!("Decryption failed: {e}")))?;
73
74        Ok(plaintext)
75    }
76
77    fn load(&self) -> Result<CredentialStore, OAuthError> {
78        if !self.path.exists() {
79            return Ok(CredentialStore { credentials: HashMap::new() });
80        }
81
82        let bytes = std::fs::read(&self.path)?;
83        if bytes.is_empty() {
84            return Ok(CredentialStore { credentials: HashMap::new() });
85        }
86
87        let plaintext = Self::decrypt(&bytes, &self.passphrase)?;
88        serde_json::from_slice(&plaintext)
89            .map_err(|e| OAuthError::CredentialStore(format!("Invalid credential data: {e}")))
90    }
91
92    fn update(&self, mutate: impl FnOnce(&mut CredentialStore)) -> Result<(), OAuthError> {
93        let _guard = self
94            .write_guard
95            .lock()
96            .map_err(|_| OAuthError::CredentialStore("Failed to acquire write lock on credential store".to_string()))?;
97
98        let mut store = self.load()?;
99        mutate(&mut store);
100
101        let plaintext = serde_json::to_vec(&store)
102            .map_err(|e| OAuthError::CredentialStore(format!("Failed to serialize credentials: {e}")))?;
103
104        let ciphertext = Self::encrypt(&plaintext, &self.passphrase)?;
105        write_atomic(&self.path, &ciphertext)
106    }
107}
108
109fn default_path() -> Result<PathBuf, OAuthError> {
110    home_dir().map(|home| home.join(".aether").join("credentials.enc")).ok_or_else(|| {
111        OAuthError::CredentialStore(
112            "Could not determine the home directory for the encrypted credential file".to_string(),
113        )
114    })
115}
116
117fn write_atomic(path: &Path, data: &[u8]) -> Result<(), OAuthError> {
118    if let Some(parent) = path.parent() {
119        std::fs::create_dir_all(parent)?;
120    }
121
122    let temp_path = path.with_extension("tmp");
123
124    {
125        let mut file = std::fs::File::create(&temp_path)?;
126        file.write_all(data)?;
127        file.sync_all()?;
128    }
129
130    std::fs::rename(&temp_path, path)?;
131
132    Ok(())
133}
134
135#[async_trait]
136impl OAuthCredentialStorage for EncryptedFileOAuthCredentialStorage {
137    async fn load_credential(&self, key: &str) -> Result<Option<OAuthCredential>, OAuthError> {
138        let store = self.load()?;
139        Ok(store.credentials.get(key).cloned())
140    }
141
142    async fn save_credential(&self, key: &str, credential: OAuthCredential) -> Result<(), OAuthError> {
143        self.update(|store| {
144            store.credentials.insert(key.to_string(), credential);
145        })
146    }
147
148    async fn delete_credential(&self, key: &str) -> Result<(), OAuthError> {
149        self.update(|store| {
150            store.credentials.remove(key);
151        })
152    }
153
154    fn has_credential(&self, key: &str) -> bool {
155        self.load().is_ok_and(|store| store.credentials.contains_key(key))
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn save_then_load_round_trips() {
        let (_dir, store) = temp_store("correct-passphrase");

        store.save_credential("server-1", test_credential()).await.unwrap();
        let loaded = store.load_credential("server-1").await.unwrap().unwrap();

        assert_eq!(loaded.client_id, "client_1");
        assert_eq!(loaded.access_token, "tok_abc");
        assert_eq!(loaded.refresh_token.as_deref(), Some("ref_xyz"));
        assert_eq!(loaded.granted_scopes, vec!["scope1"]);
    }

    #[tokio::test]
    async fn load_returns_none_for_missing_key() {
        let (_dir, store) = temp_store("pass");
        assert!(store.load_credential("nonexistent").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn empty_file_loads_as_empty_store() {
        let (_dir, store) = temp_store("pass");
        std::fs::write(&store.path, b"").unwrap();

        assert!(store.load_credential("server-1").await.unwrap().is_none());
        assert!(!store.has_credential("server-1"));
    }

    #[tokio::test]
    async fn corrupted_file_fails_to_load() {
        let (_dir, store) = temp_store("pass");
        std::fs::write(&store.path, b"definitely not an age file").unwrap();

        let msg = store.load_credential("server-1").await.unwrap_err().to_string();
        assert!(msg.contains("Invalid encrypted file"), "Expected header error, got: {msg}");
    }

    #[tokio::test]
    async fn delete_removes_credential() {
        let (_dir, store) = temp_store("pass");
        store.save_credential("server-1", test_credential()).await.unwrap();
        assert!(store.has_credential("server-1"));

        store.delete_credential("server-1").await.unwrap();
        assert!(!store.has_credential("server-1"));
    }

    #[tokio::test]
    async fn wrong_passphrase_fails_to_load() {
        let (_dir, store) = temp_store("correct-pass");
        store.save_credential("server-1", test_credential()).await.unwrap();

        let wrong_store = EncryptedFileOAuthCredentialStorage::new(store.path.clone(), "wrong-pass".to_string());

        let err = wrong_store.load_credential("server-1").await.unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("Decryption failed"), "Expected decryption error, got: {msg}");
        // `has_credential` maps load failures to `false` rather than erroring.
        assert!(!wrong_store.has_credential("server-1"));
    }

    #[tokio::test]
    async fn multiple_credentials_are_isolated() {
        let (_dir, store) = temp_store("pass");

        store.save_credential("server-a", simple_credential("a", "token_a")).await.unwrap();
        store.save_credential("server-b", simple_credential("b", "token_b")).await.unwrap();

        let loaded_a = store.load_credential("server-a").await.unwrap().unwrap();
        let loaded_b = store.load_credential("server-b").await.unwrap().unwrap();

        assert_eq!(loaded_a.access_token, "token_a");
        assert_eq!(loaded_b.access_token, "token_b");
    }

    #[tokio::test]
    async fn save_overwrites_existing_credential() {
        let (_dir, store) = temp_store("pass");

        store.save_credential("server", simple_credential("c", "v1")).await.unwrap();
        store.save_credential("server", simple_credential("c", "v2")).await.unwrap();

        let loaded = store.load_credential("server").await.unwrap().unwrap();
        assert_eq!(loaded.access_token, "v2");
    }

    #[test]
    fn encrypt_decrypt_round_trips() {
        let plaintext = b"hello, world!";
        let ciphertext = EncryptedFileOAuthCredentialStorage::encrypt(plaintext, "passphrase").unwrap();
        let decrypted = EncryptedFileOAuthCredentialStorage::decrypt(&ciphertext, "passphrase").unwrap();
        assert_eq!(decrypted, plaintext);
    }

    /// A fully populated credential used by the round-trip tests.
    fn test_credential() -> OAuthCredential {
        OAuthCredential {
            client_id: "client_1".to_string(),
            access_token: "tok_abc".to_string(),
            refresh_token: Some("ref_xyz".to_string()),
            expires_at: Some(9_999_999_999_999),
            granted_scopes: vec!["scope1".to_string()],
        }
    }

    /// A minimal credential with no refresh token, expiry, or scopes.
    fn simple_credential(client_id: &str, access_token: &str) -> OAuthCredential {
        OAuthCredential {
            client_id: client_id.to_string(),
            access_token: access_token.to_string(),
            refresh_token: None,
            expires_at: None,
            granted_scopes: vec![],
        }
    }

    /// Returns a store pointing at a not-yet-created file inside a fresh temp directory.
    ///
    /// The returned `TempDir` must stay alive for the whole test: dropping it deletes the
    /// directory. The previous `keep()` leaked one directory per test run.
    fn temp_store(passphrase: &str) -> (tempfile::TempDir, EncryptedFileOAuthCredentialStorage) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("creds.enc");
        (dir, EncryptedFileOAuthCredentialStorage::new(path, passphrase.to_string()))
    }
}