1use crate::{EntityCategory, PseudoToken};
10use aes_gcm::{
11 aead::{Aead, KeyInit, OsRng},
12 Aes256Gcm, Nonce,
13};
14use anyhow::{bail, Context, Result};
15use rand::RngCore;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use zeroize::Zeroize;
19
20pub struct Vault {
22 forward: HashMap<String, PseudoToken>,
24 reverse: HashMap<String, SensitiveString>,
26 counters: HashMap<String, u32>,
28 path: Option<String>,
30 key: SensitiveBytes,
32}
33
34#[derive(Clone, Serialize, Deserialize)]
36pub struct SensitiveString(String);
37
38impl Drop for SensitiveString {
39 fn drop(&mut self) {
40 self.0.zeroize();
41 }
42}
43
44impl std::fmt::Debug for SensitiveString {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 write!(f, "[REDACTED]")
47 }
48}
49
50pub struct SensitiveBytes(Vec<u8>);
52
53impl Drop for SensitiveBytes {
54 fn drop(&mut self) {
55 self.0.zeroize();
56 }
57}
58
59#[derive(Serialize, Deserialize)]
61struct VaultData {
62 forward: Vec<(String, StoredToken)>,
63 counters: HashMap<String, u32>,
64}
65
66#[derive(Serialize, Deserialize)]
67struct StoredToken {
68 token: String,
69 category: EntityCategory,
70 id: u32,
71 original: String,
72}
73
74impl Vault {
75 pub fn open(path: &str, key: Vec<u8>) -> Result<Self> {
77 if key.len() != 32 {
78 bail!("Vault key must be exactly 32 bytes (AES-256)");
79 }
80 if std::path::Path::new(path).exists() {
81 Self::load(path, &key)
82 } else {
83 Ok(Self {
84 forward: HashMap::new(),
85 reverse: HashMap::new(),
86 counters: HashMap::new(),
87 path: Some(path.to_string()),
88 key: SensitiveBytes(key),
89 })
90 }
91 }
92
93 pub fn ephemeral() -> Self {
95 let mut key = vec![0u8; 32];
96 rand::rngs::OsRng.fill_bytes(&mut key);
97 Self {
98 forward: HashMap::new(),
99 reverse: HashMap::new(),
100 counters: HashMap::new(),
101 path: None,
102 key: SensitiveBytes(key),
103 }
104 }
105
106 pub fn get_or_create(&mut self, original: &str, category: &EntityCategory) -> PseudoToken {
108 if let Some(token) = self.forward.get(original) {
109 return token.clone();
110 }
111
112 let prefix = Self::category_prefix(category);
113 let counter = self.counters.entry(prefix.clone()).or_insert(0);
114 *counter += 1;
115
116 let token = PseudoToken {
117 token: format!("{}_{}", prefix, counter),
118 category: category.clone(),
119 id: *counter,
120 };
121
122 self.forward.insert(original.to_string(), token.clone());
123 self.reverse.insert(
124 token.token.clone(),
125 SensitiveString(original.to_string()),
126 );
127
128 token
129 }
130
131 pub fn lookup(&self, token: &str) -> Option<&str> {
133 self.reverse.get(token).map(|s| s.0.as_str())
134 }
135
136 pub fn reverse_mappings(&self) -> HashMap<String, String> {
138 self.reverse
139 .iter()
140 .map(|(k, v)| (k.clone(), v.0.clone()))
141 .collect()
142 }
143
144 pub fn save(&self) -> Result<()> {
146 let path = match &self.path {
147 Some(p) => p,
148 None => return Ok(()), };
150
151 let data = self.to_vault_data();
152 let json = serde_json::to_vec(&data).context("Failed to serialize vault")?;
153
154 let encrypted = self.encrypt(&json)?;
155
156 let tmp_path = format!("{}.tmp", path);
158 if let Some(parent) = std::path::Path::new(path).parent() {
159 std::fs::create_dir_all(parent).context("Failed to create vault directory")?;
160 }
161 std::fs::write(&tmp_path, &encrypted).context("Failed to write vault temp file")?;
162 std::fs::rename(&tmp_path, path).context("Failed to rename vault file")?;
163
164 Ok(())
165 }
166
167 pub fn stats(&self) -> VaultStats {
169 VaultStats {
170 total_mappings: self.forward.len(),
171 categories: self.counters.clone(),
172 }
173 }
174
175 fn category_prefix(category: &EntityCategory) -> String {
176 match category {
177 EntityCategory::Person => "PERSON".into(),
178 EntityCategory::Organization => "ORG".into(),
179 EntityCategory::Location => "LOC".into(),
180 EntityCategory::Amount => "AMOUNT".into(),
181 EntityCategory::Percentage => "PCT".into(),
182 EntityCategory::Date => "DATE".into(),
183 EntityCategory::Email => "EMAIL".into(),
184 EntityCategory::PhoneNumber => "PHONE".into(),
185 EntityCategory::IpAddress => "IP".into(),
186 EntityCategory::Secret => "SECRET".into(),
187 EntityCategory::Url => "URL".into(),
188 EntityCategory::Project => "PROJECT".into(),
189 EntityCategory::Business => "BIZ".into(),
190 EntityCategory::Infra => "INFRA".into(),
191 EntityCategory::Custom(name) => name.to_uppercase(),
192 }
193 }
194
195 fn load(path: &str, key: &[u8]) -> Result<Self> {
196 let encrypted = std::fs::read(path).context("Failed to read vault file")?;
197 let json = Self::decrypt_bytes(key, &encrypted)?;
198 let data: VaultData =
199 serde_json::from_slice(&json).context("Failed to deserialize vault")?;
200
201 let mut forward = HashMap::new();
202 let mut reverse = HashMap::new();
203
204 for (_original_key, stored) in &data.forward {
205 let token = PseudoToken {
206 token: stored.token.clone(),
207 category: stored.category.clone(),
208 id: stored.id,
209 };
210 forward.insert(stored.original.clone(), token.clone());
211 reverse.insert(
212 stored.token.clone(),
213 SensitiveString(stored.original.clone()),
214 );
215 }
216
217 Ok(Self {
218 forward,
219 reverse,
220 counters: data.counters,
221 path: Some(path.to_string()),
222 key: SensitiveBytes(key.to_vec()),
223 })
224 }
225
226 fn to_vault_data(&self) -> VaultData {
227 let forward: Vec<(String, StoredToken)> = self
228 .forward
229 .iter()
230 .map(|(original, token)| {
231 (
232 original.clone(),
233 StoredToken {
234 token: token.token.clone(),
235 category: token.category.clone(),
236 id: token.id,
237 original: original.clone(),
238 },
239 )
240 })
241 .collect();
242
243 VaultData {
244 forward,
245 counters: self.counters.clone(),
246 }
247 }
248
249 fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
251 let cipher =
252 Aes256Gcm::new_from_slice(&self.key.0)
253 .map_err(|_| anyhow::anyhow!("Invalid AES-256-GCM key"))?;
254
255 let mut nonce_bytes = [0u8; 12];
256 OsRng.fill_bytes(&mut nonce_bytes);
257 let nonce = Nonce::from_slice(&nonce_bytes);
258
259 let ciphertext = cipher
260 .encrypt(nonce, plaintext)
261 .map_err(|e| anyhow::anyhow!("Encryption failed: {}", e))?;
262
263 let mut output = Vec::with_capacity(12 + ciphertext.len());
264 output.extend_from_slice(&nonce_bytes);
265 output.extend_from_slice(&ciphertext);
266 Ok(output)
267 }
268
269 fn decrypt_bytes(key: &[u8], data: &[u8]) -> Result<Vec<u8>> {
271 if data.len() < 12 {
272 bail!("Vault data too short — corrupted or wrong format");
273 }
274
275 let cipher = Aes256Gcm::new_from_slice(key)
276 .map_err(|_| anyhow::anyhow!("Invalid AES-256-GCM key"))?;
277 let nonce = Nonce::from_slice(&data[..12]);
278 let ciphertext = &data[12..];
279
280 cipher
281 .decrypt(nonce, ciphertext)
282 .map_err(|_| anyhow::anyhow!("Decryption failed — wrong key or corrupted vault"))
283 }
284}
285
286#[derive(Debug, Serialize)]
287pub struct VaultStats {
288 pub total_mappings: usize,
289 pub categories: HashMap<String, u32>,
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EntityCategory;

    /// Fixed, valid 32-byte test key.
    fn test_key() -> Vec<u8> {
        vec![0xAB; 32]
    }

    #[test]
    fn test_vault_get_or_create_consistency() {
        let mut vault = Vault::ephemeral();
        let t1 = vault.get_or_create("Tata Motors", &EntityCategory::Organization);
        let t2 = vault.get_or_create("Tata Motors", &EntityCategory::Organization);
        assert_eq!(t1.token, t2.token);
        assert_eq!(t1.token, "ORG_1");
    }

    #[test]
    fn test_vault_different_entities() {
        let mut vault = Vault::ephemeral();
        let t1 = vault.get_or_create("Tata Motors", &EntityCategory::Organization);
        let t2 = vault.get_or_create("Infosys", &EntityCategory::Organization);
        assert_ne!(t1.token, t2.token);
        assert_eq!(t1.token, "ORG_1");
        assert_eq!(t2.token, "ORG_2");
    }

    #[test]
    fn test_vault_lookup() {
        let mut vault = Vault::ephemeral();
        vault.get_or_create("secret@example.com", &EntityCategory::Email);
        assert_eq!(vault.lookup("EMAIL_1"), Some("secret@example.com"));
        assert_eq!(vault.lookup("NONEXISTENT_99"), None);
    }

    #[test]
    fn test_vault_reverse_mappings() {
        let mut vault = Vault::ephemeral();
        vault.get_or_create("Alice", &EntityCategory::Person);
        let map = vault.reverse_mappings();
        assert_eq!(map.len(), 1);
        assert_eq!(map.get("PERSON_1").map(String::as_str), Some("Alice"));
    }

    #[test]
    fn test_vault_stats() {
        let mut vault = Vault::ephemeral();
        vault.get_or_create("Alice", &EntityCategory::Person);
        vault.get_or_create("Bob", &EntityCategory::Person);
        vault.get_or_create("Acme Corp", &EntityCategory::Organization);
        let stats = vault.stats();
        assert_eq!(stats.total_mappings, 3);
        assert_eq!(stats.categories.get("PERSON"), Some(&2));
        assert_eq!(stats.categories.get("ORG"), Some(&1));
    }

    #[test]
    fn test_sensitive_string_debug_is_redacted() {
        let s = SensitiveString("hunter2".to_string());
        assert_eq!(format!("{:?}", s), "[REDACTED]");
    }

    #[test]
    fn test_ephemeral_save_is_noop() {
        let mut vault = Vault::ephemeral();
        vault.get_or_create("Alice", &EntityCategory::Person);
        assert!(vault.save().is_ok());
    }

    #[test]
    fn test_vault_roundtrip_persistence() {
        let dir = tempfile::tempdir().unwrap();
        let vault_path = dir.path().join("test.vault");
        let path_str = vault_path.to_str().unwrap();

        {
            let mut vault = Vault::open(path_str, test_key()).unwrap();
            vault.get_or_create("Tata Motors", &EntityCategory::Organization);
            vault.get_or_create("$1.2M", &EntityCategory::Amount);
            vault.save().unwrap();
        }

        // The temp file used for the atomic replace must not linger.
        assert!(!dir.path().join("test.vault.tmp").exists());

        {
            let vault = Vault::open(path_str, test_key()).unwrap();
            assert_eq!(vault.lookup("ORG_1"), Some("Tata Motors"));
            assert_eq!(vault.lookup("AMOUNT_1"), Some("$1.2M"));
            assert_eq!(vault.stats().total_mappings, 2);
        }
    }

    #[test]
    fn test_vault_counters_continue_after_reload() {
        let dir = tempfile::tempdir().unwrap();
        let vault_path = dir.path().join("test.vault");
        let path_str = vault_path.to_str().unwrap();

        {
            let mut vault = Vault::open(path_str, test_key()).unwrap();
            vault.get_or_create("Tata Motors", &EntityCategory::Organization);
            vault.save().unwrap();
        }

        let mut vault = Vault::open(path_str, test_key()).unwrap();
        // Known originals keep their token; new ones continue the numbering.
        assert_eq!(
            vault
                .get_or_create("Tata Motors", &EntityCategory::Organization)
                .token,
            "ORG_1"
        );
        assert_eq!(
            vault
                .get_or_create("Infosys", &EntityCategory::Organization)
                .token,
            "ORG_2"
        );
        assert_eq!(vault.lookup("ORG_1"), Some("Tata Motors"));
        assert_eq!(vault.lookup("ORG_2"), Some("Infosys"));
    }

    #[test]
    fn test_vault_open_missing_file_is_lazy() {
        let dir = tempfile::tempdir().unwrap();
        let vault_path = dir.path().join("nested").join("test.vault");
        let vault = Vault::open(vault_path.to_str().unwrap(), test_key()).unwrap();
        assert_eq!(vault.stats().total_mappings, 0);
        // Nothing touches the disk until `save`, which also creates parents.
        assert!(!vault_path.exists());
        vault.save().unwrap();
        assert!(vault_path.exists());
    }

    #[test]
    fn test_vault_wrong_key_fails() {
        let dir = tempfile::tempdir().unwrap();
        let vault_path = dir.path().join("test.vault");
        let path_str = vault_path.to_str().unwrap();

        {
            let mut vault = Vault::open(path_str, test_key()).unwrap();
            vault.get_or_create("secret", &EntityCategory::Secret);
            vault.save().unwrap();
        }

        let wrong_key = vec![0xCD; 32];
        let result = Vault::open(path_str, wrong_key);
        assert!(result.is_err());
    }

    #[test]
    fn test_vault_truncated_file_fails() {
        let dir = tempfile::tempdir().unwrap();
        let vault_path = dir.path().join("test.vault");
        std::fs::write(&vault_path, b"short").unwrap();
        assert!(Vault::open(vault_path.to_str().unwrap(), test_key()).is_err());
    }

    #[test]
    fn test_vault_invalid_key_length() {
        let dir = tempfile::tempdir().unwrap();
        let vault_path = dir.path().join("test.vault");
        let path_str = vault_path.to_str().unwrap();
        assert!(Vault::open(path_str, vec![0u8; 16]).is_err());
        assert!(Vault::open(path_str, vec![0u8; 33]).is_err());
        assert!(!vault_path.exists());
    }
}