1use std::collections::HashMap;
22use std::path::{Path, PathBuf};
23
24use aes_gcm::aead::Aead;
25use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
26use rand::RngCore;
27use serde::{Deserialize, Serialize};
28use zeroize::Zeroizing;
29
30const VAULT_DIR: &str = "vault";
31const SECRETS_FILE: &str = "secrets.json";
32const KEYRING_LABEL: &str = "koi-vault-master";
33const NONCE_LEN: usize = 12;
34const MASTER_KEY_LEN: usize = 32;
35
/// Errors returned by [`Vault`] operations.
#[derive(Debug, thiserror::Error)]
pub enum VaultError {
    /// Reading or writing files under the vault directory failed.
    #[error("vault I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// The secrets file could not be serialized to, or parsed from, JSON.
    #[error("vault serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
    /// Building the AES-256-GCM cipher or encrypting a value failed.
    #[error("vault encryption error: {0}")]
    Encryption(String),
    /// A stored entry could not be decrypted: bad nonce length, failed
    /// authentication (wrong key or modified data), or non-UTF-8 plaintext.
    #[error("vault decryption error: {0}")]
    Decryption(String),
    /// The master key could not be loaded, sealed, or derived.
    #[error("vault master key error: {0}")]
    MasterKey(String),
}
51
/// Encrypted key/value secret store persisted as JSON under `<data_dir>/vault/`.
///
/// Every value is encrypted individually with AES-256-GCM under one master key,
/// which is held in memory for the lifetime of the handle.
pub struct Vault {
    /// Directory that holds the secrets file (`<data_dir>/vault`).
    vault_dir: PathBuf,
    /// AES-256 master key; `Zeroizing` wipes it from memory when dropped.
    master_key: Zeroizing<[u8; MASTER_KEY_LEN]>,
    /// Which key source produced `master_key`: `"keyring"` or `"machine-bound"`.
    backend_name: &'static str,
}
58
59impl Vault {
60 pub fn open(data_dir: &Path) -> Result<Self, VaultError> {
65 let vault_dir = data_dir.join(VAULT_DIR);
66 std::fs::create_dir_all(&vault_dir)?;
67
68 let (master_key, backend_name) = if crate::tpm::is_available() {
69 match Self::load_or_create_keyring_master() {
70 Ok(key) => (key, "keyring"),
71 Err(e) => {
72 tracing::warn!("Keyring master key failed, falling back to machine-bound: {e}");
73 (Self::derive_machine_master()?, "machine-bound")
74 }
75 }
76 } else {
77 (Self::derive_machine_master()?, "machine-bound")
78 };
79
80 Ok(Self {
81 vault_dir,
82 master_key,
83 backend_name,
84 })
85 }
86
87 pub fn backend_name(&self) -> &'static str {
89 self.backend_name
90 }
91
92 pub fn store(&self, key: &str, value: &str) -> Result<(), VaultError> {
94 let mut secrets = self.load_secrets()?;
95 secrets
96 .entries
97 .insert(key.to_string(), self.encrypt(value)?);
98 self.save_secrets(&secrets)
99 }
100
101 pub fn retrieve(&self, key: &str) -> Result<Option<String>, VaultError> {
103 let secrets = self.load_secrets()?;
104 match secrets.entries.get(key) {
105 Some(entry) => Ok(Some(self.decrypt(entry)?)),
106 None => Ok(None),
107 }
108 }
109
110 pub fn delete(&self, key: &str) -> Result<(), VaultError> {
112 let mut secrets = self.load_secrets()?;
113 secrets.entries.remove(key);
114 self.save_secrets(&secrets)
115 }
116
117 pub fn list_keys(&self) -> Result<Vec<String>, VaultError> {
119 let secrets = self.load_secrets()?;
120 Ok(secrets.entries.keys().cloned().collect())
121 }
122
123 fn load_or_create_keyring_master() -> Result<Zeroizing<[u8; MASTER_KEY_LEN]>, VaultError> {
127 match crate::tpm::unseal_key_material(KEYRING_LABEL) {
128 Ok(data) if data.len() == MASTER_KEY_LEN => {
129 let mut key = Zeroizing::new([0u8; MASTER_KEY_LEN]);
130 key.copy_from_slice(&data);
131 Ok(key)
132 }
133 _ => {
134 let mut key = Zeroizing::new([0u8; MASTER_KEY_LEN]);
136 rand::rng().fill_bytes(key.as_mut());
137 crate::tpm::seal_key_material(KEYRING_LABEL, &*key)
138 .map_err(|e| VaultError::MasterKey(e.to_string()))?;
139 tracing::info!("Vault master key created and sealed in platform credential store");
140 Ok(key)
141 }
142 }
143 }
144
145 fn derive_machine_master() -> Result<Zeroizing<[u8; MASTER_KEY_LEN]>, VaultError> {
147 let machine_id = get_machine_id()
148 .map_err(|e| VaultError::MasterKey(format!("machine ID unavailable: {e}")))?;
149
150 let salt = sha2::Sha256::digest(format!("koi-vault-salt:{machine_id}").as_bytes());
151 let params = argon2::Params::new(65536, 3, 4, Some(MASTER_KEY_LEN))
152 .map_err(|e| VaultError::MasterKey(e.to_string()))?;
153 let argon2 =
154 argon2::Argon2::new(argon2::Algorithm::Argon2id, argon2::Version::V0x13, params);
155
156 let mut key = Zeroizing::new([0u8; MASTER_KEY_LEN]);
157 argon2
158 .hash_password_into(machine_id.as_bytes(), &salt[..16], key.as_mut())
159 .map_err(|e| VaultError::MasterKey(e.to_string()))?;
160 Ok(key)
161 }
162
163 fn encrypt(&self, plaintext: &str) -> Result<EncryptedEntry, VaultError> {
166 let cipher = Aes256Gcm::new_from_slice(&*self.master_key)
167 .map_err(|e| VaultError::Encryption(e.to_string()))?;
168
169 let mut nonce_bytes = [0u8; NONCE_LEN];
170 rand::rng().fill_bytes(&mut nonce_bytes);
171 let nonce = Nonce::from(nonce_bytes);
172
173 let ciphertext = cipher
174 .encrypt(&nonce, plaintext.as_bytes())
175 .map_err(|e| VaultError::Encryption(e.to_string()))?;
176
177 Ok(EncryptedEntry {
178 ciphertext,
179 nonce: nonce_bytes.to_vec(),
180 })
181 }
182
183 fn decrypt(&self, entry: &EncryptedEntry) -> Result<String, VaultError> {
184 let cipher = Aes256Gcm::new_from_slice(&*self.master_key)
185 .map_err(|e| VaultError::Decryption(e.to_string()))?;
186
187 let nonce_arr: [u8; NONCE_LEN] = entry
188 .nonce
189 .as_slice()
190 .try_into()
191 .map_err(|_| VaultError::Decryption("invalid nonce length".into()))?;
192 let nonce = Nonce::from(nonce_arr);
193
194 let plaintext = cipher
195 .decrypt(&nonce, entry.ciphertext.as_ref())
196 .map_err(|e| VaultError::Decryption(e.to_string()))?;
197
198 String::from_utf8(plaintext)
199 .map_err(|e| VaultError::Decryption(format!("not valid UTF-8: {e}")))
200 }
201
202 fn secrets_path(&self) -> PathBuf {
205 self.vault_dir.join(SECRETS_FILE)
206 }
207
208 fn load_secrets(&self) -> Result<SecretsFile, VaultError> {
209 let path = self.secrets_path();
210 if !path.exists() {
211 return Ok(SecretsFile {
212 version: 1,
213 entries: HashMap::new(),
214 });
215 }
216 let data = std::fs::read(&path)?;
217 Ok(serde_json::from_slice(&data)?)
218 }
219
220 fn save_secrets(&self, secrets: &SecretsFile) -> Result<(), VaultError> {
221 let data = serde_json::to_vec_pretty(secrets)?;
222 let path = self.secrets_path();
223 std::fs::write(&path, &data)?;
224
225 #[cfg(unix)]
226 {
227 use std::os::unix::fs::PermissionsExt;
228 let _ = std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600));
229 }
230
231 Ok(())
232 }
233}
234
/// On-disk layout of the vault's JSON secrets file.
#[derive(Serialize, Deserialize)]
struct SecretsFile {
    /// Format version. New files are created with `1`; the value is preserved on
    /// re-save and is not otherwise checked when loading.
    version: u8,
    /// Encrypted entries keyed by secret name.
    entries: HashMap<String, EncryptedEntry>,
}
242
/// A single secret encrypted with AES-256-GCM.
///
/// Both fields are byte vectors. With `serde_json` they are stored as arrays of numbers.
#[derive(Serialize, Deserialize)]
struct EncryptedEntry {
    /// Output of `Aead::encrypt`: the ciphertext followed by the GCM authentication tag.
    ciphertext: Vec<u8>,
    /// Random nonce used for this entry; must be exactly `NONCE_LEN` (12) bytes to decrypt.
    nonce: Vec<u8>,
}
248
249use sha2::Digest;
252
/// Returns a stable per-machine identifier used to derive the machine-bound vault key.
///
/// Sources by platform:
/// - Linux: the first non-empty value of `/etc/machine-id` or `/var/lib/dbus/machine-id`.
/// - Windows: `HKLM\SOFTWARE\Microsoft\Cryptography\MachineGuid`, queried via `reg`.
/// - macOS: `IOPlatformUUID` from `ioreg -rd1 -c IOPlatformExpertDevice`.
///
/// Any other target returns an error, so the crate still compiles there.
///
/// # Errors
///
/// A human-readable message when no non-empty identifier can be obtained.
fn get_machine_id() -> Result<String, String> {
    #[cfg(target_os = "linux")]
    {
        // An empty machine-id file (e.g. in some container images) must not be
        // accepted: every such machine would derive the same key from "".
        let mut failures = Vec::new();
        for candidate in ["/etc/machine-id", "/var/lib/dbus/machine-id"] {
            match std::fs::read_to_string(candidate) {
                Ok(contents) if !contents.trim().is_empty() => {
                    return Ok(contents.trim().to_string());
                }
                Ok(_) => failures.push(format!("{candidate} is empty")),
                Err(e) => failures.push(format!("{candidate}: {e}")),
            }
        }
        Err(failures.join("; "))
    }

    #[cfg(target_os = "windows")]
    {
        let output = std::process::Command::new("reg")
            .args([
                "query",
                r"HKLM\SOFTWARE\Microsoft\Cryptography",
                "/v",
                "MachineGuid",
            ])
            .output()
            .map_err(|e| e.to_string())?;
        let stdout = String::from_utf8_lossy(&output.stdout);
        // Expected line shape: `    MachineGuid    REG_SZ    <guid>`.
        stdout
            .lines()
            .find_map(|line| {
                let parts: Vec<&str> = line.split_whitespace().collect();
                if parts.len() >= 3 && parts[0] == "MachineGuid" {
                    Some(parts[2].to_string())
                } else {
                    None
                }
            })
            .ok_or_else(|| "MachineGuid not found in registry".to_string())
    }

    #[cfg(target_os = "macos")]
    {
        let output = std::process::Command::new("ioreg")
            .args(["-rd1", "-c", "IOPlatformExpertDevice"])
            .output()
            .map_err(|e| e.to_string())?;
        let stdout = String::from_utf8_lossy(&output.stdout);
        // Expected line shape: `  "IOPlatformUUID" = "<uuid>"`; the UUID is the 4th quote-split field.
        stdout
            .lines()
            .find(|line| line.contains("IOPlatformUUID"))
            .and_then(|line| line.split('"').nth(3))
            .filter(|uuid| !uuid.is_empty())
            .map(|s| s.to_string())
            .ok_or_else(|| "IOPlatformUUID not found".to_string())
    }

    #[cfg(not(any(target_os = "linux", target_os = "windows", target_os = "macos")))]
    {
        Err("machine ID lookup is not supported on this platform".to_string())
    }
}
303
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn round_trip_store_retrieve() {
        let tmp = tempfile::tempdir().unwrap();
        let vault = Vault::open(tmp.path()).unwrap();

        vault.store("db-password", "s3cret!").unwrap();
        assert_eq!(
            vault.retrieve("db-password").unwrap(),
            Some("s3cret!".to_string())
        );

        vault.store("api-key", "tok_abc123").unwrap();
        let keys = vault.list_keys().unwrap();
        assert!(keys.contains(&"db-password".to_string()));
        assert!(keys.contains(&"api-key".to_string()));

        vault.delete("db-password").unwrap();
        assert_eq!(vault.retrieve("db-password").unwrap(), None);
    }

    #[test]
    fn retrieve_missing_returns_none() {
        let tmp = tempfile::tempdir().unwrap();
        let vault = Vault::open(tmp.path()).unwrap();
        assert_eq!(vault.retrieve("nonexistent").unwrap(), None);
    }

    #[test]
    fn overwrite_replaces_value() {
        let tmp = tempfile::tempdir().unwrap();
        let vault = Vault::open(tmp.path()).unwrap();

        vault.store("key", "v1").unwrap();
        vault.store("key", "v2").unwrap();
        assert_eq!(vault.retrieve("key").unwrap(), Some("v2".to_string()));
    }

    #[test]
    fn persistence_across_open() {
        let _ = koi_common::test::ensure_data_dir("koi-vault-persist-tests");
        let tmp = tempfile::tempdir().unwrap();
        {
            let vault = Vault::open(tmp.path()).unwrap();
            vault.store("persist-test", "hello").unwrap();
        }
        {
            let vault = Vault::open(tmp.path()).unwrap();
            assert_eq!(
                vault.retrieve("persist-test").unwrap(),
                Some("hello".to_string())
            );
        }
    }

    /// Deleting a key that was never stored must succeed and leave the vault empty.
    #[test]
    fn delete_missing_key_is_ok() {
        let tmp = tempfile::tempdir().unwrap();
        let vault = Vault::open(tmp.path()).unwrap();
        vault.delete("never-stored").unwrap();
        assert!(vault.list_keys().unwrap().is_empty());
    }

    /// GCM authentication must reject a ciphertext that was modified on disk.
    #[test]
    fn tampered_ciphertext_is_rejected() {
        let tmp = tempfile::tempdir().unwrap();
        let vault = Vault::open(tmp.path()).unwrap();
        let mut entry = vault.encrypt("payload").unwrap();
        entry.ciphertext[0] ^= 0x01;
        assert!(matches!(
            vault.decrypt(&entry),
            Err(VaultError::Decryption(_))
        ));
    }

    /// A nonce that is not exactly `NONCE_LEN` bytes must be reported, not panic.
    #[test]
    fn truncated_nonce_is_rejected() {
        let tmp = tempfile::tempdir().unwrap();
        let vault = Vault::open(tmp.path()).unwrap();
        let mut entry = vault.encrypt("payload").unwrap();
        entry.nonce.pop();
        assert!(matches!(
            vault.decrypt(&entry),
            Err(VaultError::Decryption(_))
        ));
    }

    /// The secrets file must be readable and writable by its owner only.
    #[cfg(unix)]
    #[test]
    fn secrets_file_is_owner_only() {
        use std::os::unix::fs::PermissionsExt;

        let tmp = tempfile::tempdir().unwrap();
        let vault = Vault::open(tmp.path()).unwrap();
        vault.store("k", "v").unwrap();
        let mode = std::fs::metadata(vault.secrets_path())
            .unwrap()
            .permissions()
            .mode();
        assert_eq!(mode & 0o777, 0o600);
    }
}