1use crate::credential::{OAuthCredential, OAuthCredentialStorage};
2use crate::error::OAuthError;
3use age::scrypt::Identity;
4use age::secrecy::SecretString;
5use age::{Decryptor, Encryptor};
6use async_trait::async_trait;
7use dirs::home_dir;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::env::var;
11use std::io::{Read, Write};
12use std::iter::once;
13use std::path::{Path, PathBuf};
14use std::sync::Mutex;
15
16const DEFAULT_PASSWORD_ENV: &str = "AETHER_CREDENTIALS_PASSWORD";
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19struct CredentialStore {
20 credentials: HashMap<String, OAuthCredential>,
21}
22
/// OAuth credential storage that keeps every credential in a single passphrase-encrypted file.
///
/// Construct it with [`EncryptedFileOAuthCredentialStorage::new`], or with
/// [`EncryptedFileOAuthCredentialStorage::from_settings`] to resolve the path and passphrase from
/// user configuration.
pub struct EncryptedFileOAuthCredentialStorage {
    /// Location of the encrypted credential file.
    path: PathBuf,
    /// Passphrase used to encrypt and decrypt the file. Never printed by `Debug`.
    passphrase: String,
    /// Serializes read-modify-write cycles performed through this instance.
    write_guard: Mutex<()>,
}

impl std::fmt::Debug for EncryptedFileOAuthCredentialStorage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Hand-written instead of derived so the passphrase can never leak into logs or panics.
        f.debug_struct("EncryptedFileOAuthCredentialStorage")
            .field("path", &self.path)
            .field("passphrase", &"<redacted>")
            .finish_non_exhaustive()
    }
}
28
29impl EncryptedFileOAuthCredentialStorage {
30 pub fn new(path: PathBuf, passphrase: String) -> Self {
31 Self { path, passphrase, write_guard: Mutex::new(()) }
32 }
33
34 pub fn from_settings(path: Option<PathBuf>, password_env: Option<&str>) -> Result<Self, OAuthError> {
35 let path = path.map_or_else(default_path, Ok)?;
36 let env_var = password_env.unwrap_or(DEFAULT_PASSWORD_ENV);
37 let passphrase = var(env_var).ok().filter(|pass| !pass.is_empty()).ok_or_else(|| {
38 OAuthError::CredentialStore(format!(
39 "Encrypted file credential store requires a passphrase. \
40 Set the {env_var} environment variable or configure a custom `passwordEnv` in settings."
41 ))
42 })?;
43
44 Ok(Self::new(path, passphrase))
45 }
46
47 fn encrypt(plaintext: &[u8], passphrase: &str) -> Result<Vec<u8>, OAuthError> {
48 let fail = |e| OAuthError::CredentialStore(format!("Encryption failed: {e}"));
49
50 let mut ciphertext = Vec::new();
51 let mut writer = Encryptor::with_user_passphrase(SecretString::from(passphrase))
52 .wrap_output(&mut ciphertext)
53 .map_err(fail)?;
54 writer.write_all(plaintext).map_err(fail)?;
55 writer.finish().map_err(fail)?;
56
57 Ok(ciphertext)
58 }
59
60 fn decrypt(ciphertext: &[u8], passphrase: &str) -> Result<Vec<u8>, OAuthError> {
61 let decryptor = Decryptor::new(ciphertext)
62 .map_err(|e| OAuthError::CredentialStore(format!("Invalid encrypted file: {e}")))?;
63
64 let mut reader =
65 decryptor.decrypt(once(&Identity::new(SecretString::from(passphrase)) as &dyn age::Identity)).map_err(
66 |e| OAuthError::CredentialStore(format!("Decryption failed — wrong passphrase or corrupted file: {e}")),
67 )?;
68
69 let mut plaintext = Vec::new();
70 reader
71 .read_to_end(&mut plaintext)
72 .map_err(|e| OAuthError::CredentialStore(format!("Decryption failed: {e}")))?;
73
74 Ok(plaintext)
75 }
76
77 fn load(&self) -> Result<CredentialStore, OAuthError> {
78 if !self.path.exists() {
79 return Ok(CredentialStore { credentials: HashMap::new() });
80 }
81
82 let bytes = std::fs::read(&self.path)?;
83 if bytes.is_empty() {
84 return Ok(CredentialStore { credentials: HashMap::new() });
85 }
86
87 let plaintext = Self::decrypt(&bytes, &self.passphrase)?;
88 serde_json::from_slice(&plaintext)
89 .map_err(|e| OAuthError::CredentialStore(format!("Invalid credential data: {e}")))
90 }
91
92 fn update(&self, mutate: impl FnOnce(&mut CredentialStore)) -> Result<(), OAuthError> {
93 let _guard = self
94 .write_guard
95 .lock()
96 .map_err(|_| OAuthError::CredentialStore("Failed to acquire write lock on credential store".to_string()))?;
97
98 let mut store = self.load()?;
99 mutate(&mut store);
100
101 let plaintext = serde_json::to_vec(&store)
102 .map_err(|e| OAuthError::CredentialStore(format!("Failed to serialize credentials: {e}")))?;
103
104 let ciphertext = Self::encrypt(&plaintext, &self.passphrase)?;
105 write_atomic(&self.path, &ciphertext)
106 }
107}
108
109fn default_path() -> Result<PathBuf, OAuthError> {
110 home_dir().map(|home| home.join(".aether").join("credentials.enc")).ok_or_else(|| {
111 OAuthError::CredentialStore(
112 "Could not determine the home directory for the encrypted credential file".to_string(),
113 )
114 })
115}
116
117fn write_atomic(path: &Path, data: &[u8]) -> Result<(), OAuthError> {
118 if let Some(parent) = path.parent() {
119 std::fs::create_dir_all(parent)?;
120 }
121
122 let temp_path = path.with_extension("tmp");
123
124 {
125 let mut file = std::fs::File::create(&temp_path)?;
126 file.write_all(data)?;
127 file.sync_all()?;
128 }
129
130 std::fs::rename(&temp_path, path)?;
131
132 Ok(())
133}
134
135#[async_trait]
136impl OAuthCredentialStorage for EncryptedFileOAuthCredentialStorage {
137 async fn load_credential(&self, key: &str) -> Result<Option<OAuthCredential>, OAuthError> {
138 let store = self.load()?;
139 Ok(store.credentials.get(key).cloned())
140 }
141
142 async fn save_credential(&self, key: &str, credential: OAuthCredential) -> Result<(), OAuthError> {
143 self.update(|store| {
144 store.credentials.insert(key.to_string(), credential);
145 })
146 }
147
148 async fn delete_credential(&self, key: &str) -> Result<(), OAuthError> {
149 self.update(|store| {
150 store.credentials.remove(key);
151 })
152 }
153
154 fn has_credential(&self, key: &str) -> bool {
155 self.load().is_ok_and(|store| store.credentials.contains_key(key))
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn save_then_load_round_trips() {
        let (_dir, store) = temp_store("correct-passphrase");

        store.save_credential("server-1", test_credential()).await.unwrap();
        let loaded = store.load_credential("server-1").await.unwrap().unwrap();

        assert_eq!(loaded.client_id, "client_1");
        assert_eq!(loaded.access_token, "tok_abc");
        assert_eq!(loaded.refresh_token.as_deref(), Some("ref_xyz"));
        assert_eq!(loaded.expires_at, Some(9_999_999_999_999));
        assert_eq!(loaded.granted_scopes, vec!["scope1"]);
    }

    #[tokio::test]
    async fn load_returns_none_for_missing_key() {
        let (_dir, store) = temp_store("pass");
        assert!(store.load_credential("nonexistent").await.unwrap().is_none());
        assert!(!store.has_credential("nonexistent"));
    }

    #[tokio::test]
    async fn empty_file_behaves_as_empty_store() {
        let (_dir, store) = temp_store("pass");
        std::fs::write(&store.path, b"").unwrap();

        assert!(store.load_credential("server-1").await.unwrap().is_none());
        assert!(!store.has_credential("server-1"));
    }

    #[tokio::test]
    async fn corrupted_file_fails_to_load() {
        let (_dir, store) = temp_store("pass");
        std::fs::write(&store.path, b"definitely not an age file").unwrap();

        assert!(store.load_credential("server-1").await.is_err());
        assert!(!store.has_credential("server-1"));
    }

    #[tokio::test]
    async fn delete_removes_credential() {
        let (_dir, store) = temp_store("pass");
        store.save_credential("server-1", test_credential()).await.unwrap();
        assert!(store.has_credential("server-1"));

        store.delete_credential("server-1").await.unwrap();
        assert!(!store.has_credential("server-1"));
    }

    #[tokio::test]
    async fn delete_of_missing_key_succeeds() {
        let (_dir, store) = temp_store("pass");
        store.delete_credential("nonexistent").await.unwrap();
        assert!(!store.has_credential("nonexistent"));
    }

    #[tokio::test]
    async fn wrong_passphrase_fails_to_load() {
        let (_dir, store) = temp_store("correct-pass");
        store.save_credential("server-1", test_credential()).await.unwrap();

        let wrong_store = EncryptedFileOAuthCredentialStorage::new(store.path.clone(), "wrong-pass".to_string());

        let err = wrong_store.load_credential("server-1").await.unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("Decryption failed"), "Expected decryption error, got: {msg}");
        assert!(!wrong_store.has_credential("server-1"));
    }

    #[tokio::test]
    async fn multiple_credentials_are_isolated() {
        let (_dir, store) = temp_store("pass");

        store.save_credential("server-a", credential("a", "token_a")).await.unwrap();
        store.save_credential("server-b", credential("b", "token_b")).await.unwrap();

        let loaded_a = store.load_credential("server-a").await.unwrap().unwrap();
        let loaded_b = store.load_credential("server-b").await.unwrap().unwrap();

        assert_eq!(loaded_a.access_token, "token_a");
        assert_eq!(loaded_b.access_token, "token_b");
    }

    #[tokio::test]
    async fn save_overwrites_existing_credential() {
        let (_dir, store) = temp_store("pass");

        store.save_credential("server", credential("c", "v1")).await.unwrap();
        store.save_credential("server", credential("c", "v2")).await.unwrap();

        let loaded = store.load_credential("server").await.unwrap().unwrap();
        assert_eq!(loaded.access_token, "v2");
    }

    #[test]
    fn encrypt_decrypt_round_trips() {
        let plaintext = b"hello, world!";
        let ciphertext = EncryptedFileOAuthCredentialStorage::encrypt(plaintext, "passphrase").unwrap();
        let decrypted = EncryptedFileOAuthCredentialStorage::decrypt(&ciphertext, "passphrase").unwrap();
        assert_eq!(decrypted, plaintext);
    }

    /// Fully populated credential used by the round-trip tests.
    fn test_credential() -> OAuthCredential {
        OAuthCredential {
            client_id: "client_1".to_string(),
            access_token: "tok_abc".to_string(),
            refresh_token: Some("ref_xyz".to_string()),
            expires_at: Some(9_999_999_999_999),
            granted_scopes: vec!["scope1".to_string()],
        }
    }

    /// Minimal credential with only a client id and an access token set.
    fn credential(client_id: &str, access_token: &str) -> OAuthCredential {
        OAuthCredential {
            client_id: client_id.to_string(),
            access_token: access_token.to_string(),
            refresh_token: None,
            expires_at: None,
            granted_scopes: vec![],
        }
    }

    /// Creates a storage whose file lives in a fresh temporary directory.
    ///
    /// The returned [`tempfile::TempDir`] deletes the directory when dropped, so callers keep it
    /// alive for the whole test by binding it to `_dir` (a plain `_` would drop it immediately).
    fn temp_store(passphrase: &str) -> (tempfile::TempDir, EncryptedFileOAuthCredentialStorage) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("creds.enc");
        (dir, EncryptedFileOAuthCredentialStorage::new(path, passphrase.to_string()))
    }
}