1use crate::{EntityCategory, PseudoToken, resolver::EntityResolver};
10use aes_gcm::{
11 aead::{Aead, KeyInit, OsRng},
12 Aes256Gcm, Nonce,
13};
14use anyhow::{bail, Context, Result};
15use rand::RngCore;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use zeroize::Zeroize;
19
/// Bidirectional store of pseudonymization mappings, optionally persisted to
/// an AES-256-GCM encrypted file.
pub struct Vault {
    /// Original text -> token handed out for it.
    forward: HashMap<String, PseudoToken>,
    /// Token string -> original text it stands for (zeroized on drop).
    reverse: HashMap<String, SensitiveString>,
    /// Last numeric id issued for each category prefix (e.g. `"ORG"`).
    counters: HashMap<String, u32>,
    /// Backing file path; `None` for an in-memory vault that is never saved.
    path: Option<String>,
    /// 32-byte AES-256-GCM key used to encrypt and decrypt the backing file.
    key: SensitiveBytes,
    /// Optional resolver consulted to map a new string onto an existing entry.
    resolver: Option<EntityResolver>,
}
35
/// String wrapper for sensitive text: its contents are zeroized on drop and
/// its `Debug` output is redacted.
#[derive(Clone, Serialize, Deserialize)]
pub struct SensitiveString(String);
39
40impl Drop for SensitiveString {
41 fn drop(&mut self) {
42 self.0.zeroize();
43 }
44}
45
46impl std::fmt::Debug for SensitiveString {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 write!(f, "[REDACTED]")
49 }
50}
51
/// Byte-buffer wrapper for secret material (the vault key); its contents are
/// zeroized on drop.
pub struct SensitiveBytes(Vec<u8>);
54
55impl Drop for SensitiveBytes {
56 fn drop(&mut self) {
57 self.0.zeroize();
58 }
59}
60
/// Serializable snapshot of a vault: this is the structure that is turned
/// into JSON and then encrypted when the vault is saved.
#[derive(Serialize, Deserialize)]
struct VaultData {
    /// One `(key, stored token)` pair per entry of the in-memory forward map.
    forward: Vec<(String, StoredToken)>,
    /// Per-prefix id counters, copied verbatim from the vault.
    counters: HashMap<String, u32>,
}
67
/// Persisted form of a single mapping: the token's fields plus the sensitive
/// text it replaces.
#[derive(Serialize, Deserialize)]
struct StoredToken {
    /// Token string, e.g. `"ORG_1"`.
    token: String,
    /// Category the token was issued for.
    category: EntityCategory,
    /// Numeric id of the token within its category prefix.
    id: u32,
    /// Sensitive text the token stands for.
    original: String,
}
75
76impl Vault {
77 pub fn open(path: &str, key: Vec<u8>) -> Result<Self> {
79 if key.len() != 32 {
80 bail!("Vault key must be exactly 32 bytes (AES-256)");
81 }
82 if std::path::Path::new(path).exists() {
83 Self::load(path, &key)
84 } else {
85 Ok(Self {
86 forward: HashMap::new(),
87 reverse: HashMap::new(),
88 counters: HashMap::new(),
89 path: Some(path.to_string()),
90 key: SensitiveBytes(key),
91 resolver: None,
92 })
93 }
94 }
95
96 pub fn ephemeral() -> Self {
98 let mut key = vec![0u8; 32];
99 rand::rngs::OsRng.fill_bytes(&mut key);
100 Self {
101 forward: HashMap::new(),
102 reverse: HashMap::new(),
103 counters: HashMap::new(),
104 path: None,
105 key: SensitiveBytes(key),
106 resolver: None,
107 }
108 }
109
110 pub fn set_resolver(&mut self, resolver: EntityResolver) {
112 self.resolver = Some(resolver);
113 }
114
115 pub fn get_or_create(&mut self, original: &str, category: &EntityCategory) -> PseudoToken {
121 if let Some(token) = self.forward.get(original) {
123 return token.clone();
124 }
125
126 if let Some(ref resolver) = self.resolver {
128 let existing: HashMap<String, EntityCategory> = self
130 .forward
131 .iter()
132 .map(|(k, v)| (k.clone(), v.category.clone()))
133 .collect();
134
135 if let Some(canonical) = resolver.resolve(original, category, &existing) {
136 if let Some(token) = self.forward.get(&canonical) {
137 let token = token.clone();
138 self.forward.insert(original.to_string(), token.clone());
140 return token;
141 }
142 }
143 }
144
145 let prefix = Self::category_prefix(category);
147 let counter = self.counters.entry(prefix.clone()).or_insert(0);
148 *counter += 1;
149
150 let token = PseudoToken {
151 token: format!("{}_{}", prefix, counter),
152 category: category.clone(),
153 id: *counter,
154 };
155
156 self.forward.insert(original.to_string(), token.clone());
157 self.reverse.insert(
158 token.token.clone(),
159 SensitiveString(original.to_string()),
160 );
161
162 token
163 }
164
165 pub fn lookup(&self, token: &str) -> Option<&str> {
167 self.reverse.get(token).map(|s| s.0.as_str())
168 }
169
170 pub fn reverse_mappings(&self) -> HashMap<String, String> {
172 self.reverse
173 .iter()
174 .map(|(k, v)| (k.clone(), v.0.clone()))
175 .collect()
176 }
177
178 pub fn save(&self) -> Result<()> {
180 let path = match &self.path {
181 Some(p) => p,
182 None => return Ok(()), };
184
185 let data = self.to_vault_data();
186 let json = serde_json::to_vec(&data).context("Failed to serialize vault")?;
187
188 let encrypted = self.encrypt(&json)?;
189
190 let tmp_path = format!("{}.tmp", path);
192 if let Some(parent) = std::path::Path::new(path).parent() {
193 std::fs::create_dir_all(parent).context("Failed to create vault directory")?;
194 }
195 std::fs::write(&tmp_path, &encrypted).context("Failed to write vault temp file")?;
196 std::fs::rename(&tmp_path, path).context("Failed to rename vault file")?;
197
198 Ok(())
199 }
200
201 pub fn stats(&self) -> VaultStats {
203 VaultStats {
204 total_mappings: self.forward.len(),
205 categories: self.counters.clone(),
206 }
207 }
208
209 fn category_prefix(category: &EntityCategory) -> String {
210 match category {
211 EntityCategory::Person => "PERSON".into(),
212 EntityCategory::Organization => "ORG".into(),
213 EntityCategory::Location => "LOC".into(),
214 EntityCategory::Amount => "AMOUNT".into(),
215 EntityCategory::Percentage => "PCT".into(),
216 EntityCategory::Date => "DATE".into(),
217 EntityCategory::Email => "EMAIL".into(),
218 EntityCategory::PhoneNumber => "PHONE".into(),
219 EntityCategory::IpAddress => "IP".into(),
220 EntityCategory::Secret => "SECRET".into(),
221 EntityCategory::Url => "URL".into(),
222 EntityCategory::Project => "PROJECT".into(),
223 EntityCategory::Business => "BIZ".into(),
224 EntityCategory::Infra => "INFRA".into(),
225 EntityCategory::Custom(name) => name.to_uppercase(),
226 }
227 }
228
229 fn load(path: &str, key: &[u8]) -> Result<Self> {
230 let encrypted = std::fs::read(path).context("Failed to read vault file")?;
231 let json = Self::decrypt_bytes(key, &encrypted)?;
232 let data: VaultData =
233 serde_json::from_slice(&json).context("Failed to deserialize vault")?;
234
235 let mut forward = HashMap::new();
236 let mut reverse = HashMap::new();
237
238 for (_original_key, stored) in &data.forward {
239 let token = PseudoToken {
240 token: stored.token.clone(),
241 category: stored.category.clone(),
242 id: stored.id,
243 };
244 forward.insert(stored.original.clone(), token.clone());
245 reverse.insert(
246 stored.token.clone(),
247 SensitiveString(stored.original.clone()),
248 );
249 }
250
251 Ok(Self {
252 forward,
253 reverse,
254 counters: data.counters,
255 path: Some(path.to_string()),
256 key: SensitiveBytes(key.to_vec()),
257 resolver: None,
258 })
259 }
260
261 fn to_vault_data(&self) -> VaultData {
262 let forward: Vec<(String, StoredToken)> = self
263 .forward
264 .iter()
265 .map(|(original, token)| {
266 (
267 original.clone(),
268 StoredToken {
269 token: token.token.clone(),
270 category: token.category.clone(),
271 id: token.id,
272 original: original.clone(),
273 },
274 )
275 })
276 .collect();
277
278 VaultData {
279 forward,
280 counters: self.counters.clone(),
281 }
282 }
283
284 fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
286 let cipher =
287 Aes256Gcm::new_from_slice(&self.key.0)
288 .map_err(|_| anyhow::anyhow!("Invalid AES-256-GCM key"))?;
289
290 let mut nonce_bytes = [0u8; 12];
291 OsRng.fill_bytes(&mut nonce_bytes);
292 let nonce = Nonce::from_slice(&nonce_bytes);
293
294 let ciphertext = cipher
295 .encrypt(nonce, plaintext)
296 .map_err(|e| anyhow::anyhow!("Encryption failed: {}", e))?;
297
298 let mut output = Vec::with_capacity(12 + ciphertext.len());
299 output.extend_from_slice(&nonce_bytes);
300 output.extend_from_slice(&ciphertext);
301 Ok(output)
302 }
303
304 fn decrypt_bytes(key: &[u8], data: &[u8]) -> Result<Vec<u8>> {
306 if data.len() < 12 {
307 bail!("Vault data too short — corrupted or wrong format");
308 }
309
310 let cipher = Aes256Gcm::new_from_slice(key)
311 .map_err(|_| anyhow::anyhow!("Invalid AES-256-GCM key"))?;
312 let nonce = Nonce::from_slice(&data[..12]);
313 let ciphertext = &data[12..];
314
315 cipher
316 .decrypt(nonce, ciphertext)
317 .map_err(|_| anyhow::anyhow!("Decryption failed — wrong key or corrupted vault"))
318 }
319}
320
/// Summary of a vault's contents, as returned by `Vault::stats`.
#[derive(Debug, Serialize)]
pub struct VaultStats {
    /// Number of entries in the vault's forward map.
    pub total_mappings: usize,
    /// Copy of the vault's per-prefix id counters.
    pub categories: HashMap<String, u32>,
}
326
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EntityCategory;

    /// Fixed 32-byte key shared by the persistence tests.
    fn test_key() -> Vec<u8> {
        [0xAB_u8; 32].to_vec()
    }

    /// Asking twice for the same string yields the same token.
    #[test]
    fn test_vault_get_or_create_consistency() {
        let mut vault = Vault::ephemeral();
        let org = EntityCategory::Organization;
        let first = vault.get_or_create("Tata Motors", &org);
        let second = vault.get_or_create("Tata Motors", &org);
        assert_eq!(first.token, second.token);
        assert_eq!(first.token, "ORG_1");
    }

    /// Distinct strings in one category get consecutive, distinct tokens.
    #[test]
    fn test_vault_different_entities() {
        let mut vault = Vault::ephemeral();
        let org = EntityCategory::Organization;
        let tata = vault.get_or_create("Tata Motors", &org);
        let infosys = vault.get_or_create("Infosys", &org);
        assert_ne!(tata.token, infosys.token);
        assert_eq!(tata.token, "ORG_1");
        assert_eq!(infosys.token, "ORG_2");
    }

    /// Known tokens map back to their original; unknown tokens yield `None`.
    #[test]
    fn test_vault_lookup() {
        let mut vault = Vault::ephemeral();
        vault.get_or_create("secret@example.com", &EntityCategory::Email);
        assert_eq!(vault.lookup("EMAIL_1"), Some("secret@example.com"));
        assert_eq!(vault.lookup("NONEXISTENT_99"), None);
    }

    /// Stats report the total mapping count and per-prefix counters.
    #[test]
    fn test_vault_stats() {
        let mut vault = Vault::ephemeral();
        vault.get_or_create("Alice", &EntityCategory::Person);
        vault.get_or_create("Bob", &EntityCategory::Person);
        vault.get_or_create("Acme Corp", &EntityCategory::Organization);

        let VaultStats {
            total_mappings,
            categories,
        } = vault.stats();
        assert_eq!(total_mappings, 3);
        assert_eq!(categories.get("PERSON"), Some(&2));
        assert_eq!(categories.get("ORG"), Some(&1));
    }

    /// Mappings saved by one vault instance are visible after reopening the file.
    #[test]
    fn test_vault_roundtrip_persistence() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.vault");
        let path = file.to_str().unwrap();

        let mut writer = Vault::open(path, test_key()).unwrap();
        writer.get_or_create("Tata Motors", &EntityCategory::Organization);
        writer.get_or_create("$1.2M", &EntityCategory::Amount);
        writer.save().unwrap();
        drop(writer);

        let reader = Vault::open(path, test_key()).unwrap();
        assert_eq!(reader.lookup("ORG_1"), Some("Tata Motors"));
        assert_eq!(reader.lookup("AMOUNT_1"), Some("$1.2M"));
        assert_eq!(reader.stats().total_mappings, 2);
    }

    /// Opening an existing vault with a different key is rejected.
    #[test]
    fn test_vault_wrong_key_fails() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.vault");
        let path = file.to_str().unwrap();

        let mut writer = Vault::open(path, test_key()).unwrap();
        writer.get_or_create("secret", &EntityCategory::Secret);
        writer.save().unwrap();
        drop(writer);

        let wrong_key = vec![0xCD; 32];
        assert!(Vault::open(path, wrong_key).is_err());
    }

    /// Keys that are not 32 bytes long are rejected up front.
    #[test]
    fn test_vault_invalid_key_length() {
        let short_key = vec![0u8; 16];
        assert!(Vault::open("/tmp/test.vault", short_key).is_err());
    }
}