Skip to main content

cloakpipe_core/
vault.rs

1//! Encrypted mapping vault — stores entity<->pseudo-token mappings.
2//!
3//! Security properties:
4//! - AES-256-GCM encrypted at rest
5//! - `zeroize` on all in-memory sensitive values when dropped
6//! - Persistent across sessions for consistent pseudonymization
7//! - Atomic file writes (write to .tmp, rename)
8
9use crate::{EntityCategory, PseudoToken, resolver::EntityResolver};
10use aes_gcm::{
11    aead::{Aead, KeyInit, OsRng},
12    Aes256Gcm, Nonce,
13};
14use anyhow::{bail, Context, Result};
15use rand::RngCore;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use zeroize::Zeroize;
19
/// The mapping vault that maintains entity<->token consistency.
///
/// Keeps both lookup directions in memory. When `path` is set the contents
/// can be written to disk in encrypted form; when it is `None` the vault is
/// purely in-memory.
pub struct Vault {
    /// Forward map: original value -> pseudo-token.
    ///
    /// Several keys may point at the same token: variant spellings matched
    /// by the fuzzy resolver are stored here as additional alias keys.
    /// Note that the keys are plain `String`s, not `SensitiveString`s.
    forward: HashMap<String, PseudoToken>,
    /// Reverse map: pseudo-token string -> original value (used for rehydration).
    reverse: HashMap<String, SensitiveString>,
    /// Next ID counter per category, keyed by the category prefix
    /// (for example `"PERSON"` or `"ORG"`).
    counters: HashMap<String, u32>,
    /// File path for persistence (None = ephemeral/in-memory only)
    path: Option<String>,
    /// Encryption key (zeroized on drop)
    key: SensitiveBytes,
    /// Optional fuzzy entity resolver for merging variant spellings
    resolver: Option<EntityResolver>,
}
35
36/// A string that is zeroized from memory when dropped.
37#[derive(Clone, Serialize, Deserialize)]
38pub struct SensitiveString(String);
39
40impl Drop for SensitiveString {
41    fn drop(&mut self) {
42        self.0.zeroize();
43    }
44}
45
46impl std::fmt::Debug for SensitiveString {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        write!(f, "[REDACTED]")
49    }
50}
51
52/// Bytes that are zeroized from memory when dropped.
53pub struct SensitiveBytes(Vec<u8>);
54
55impl Drop for SensitiveBytes {
56    fn drop(&mut self) {
57        self.0.zeroize();
58    }
59}
60
/// Serializable vault data for persistence.
///
/// This is the structure that is serialized to JSON and then encrypted
/// when the vault is saved, and deserialized again when it is loaded.
#[derive(Serialize, Deserialize)]
struct VaultData {
    /// One entry per forward-map key, paired with its stored token record.
    forward: Vec<(String, StoredToken)>,
    /// Snapshot of the per-category-prefix ID counters.
    counters: HashMap<String, u32>,
}
67
/// On-disk record of a single pseudo-token, mirroring the fields of
/// `PseudoToken` plus the plaintext it stands for.
#[derive(Serialize, Deserialize)]
struct StoredToken {
    /// The pseudo-token string (for example `"ORG_1"`).
    token: String,
    /// Category the token was issued under.
    category: EntityCategory,
    /// Numeric ID of the token within its category.
    id: u32,
    /// Original (sensitive) plaintext value that this token stands for.
    original: String,
}
75
76impl Vault {
77    /// Create or load a vault from the given path.
78    pub fn open(path: &str, key: Vec<u8>) -> Result<Self> {
79        if key.len() != 32 {
80            bail!("Vault key must be exactly 32 bytes (AES-256)");
81        }
82        if std::path::Path::new(path).exists() {
83            Self::load(path, &key)
84        } else {
85            Ok(Self {
86                forward: HashMap::new(),
87                reverse: HashMap::new(),
88                counters: HashMap::new(),
89                path: Some(path.to_string()),
90                key: SensitiveBytes(key),
91                resolver: None,
92            })
93        }
94    }
95
96    /// Create an ephemeral (in-memory only) vault for testing.
97    pub fn ephemeral() -> Self {
98        let mut key = vec![0u8; 32];
99        rand::rngs::OsRng.fill_bytes(&mut key);
100        Self {
101            forward: HashMap::new(),
102            reverse: HashMap::new(),
103            counters: HashMap::new(),
104            path: None,
105            key: SensitiveBytes(key),
106            resolver: None,
107        }
108    }
109
110    /// Configure the fuzzy entity resolver on this vault.
111    pub fn set_resolver(&mut self, resolver: EntityResolver) {
112        self.resolver = Some(resolver);
113    }
114
115    /// Get or create a pseudo-token for the given original value.
116    ///
117    /// If a fuzzy entity resolver is configured, checks for similar existing
118    /// entries before creating a new token. This merges variant spellings
119    /// (e.g., "Rishi" and "Rishikesh") to the same token.
120    pub fn get_or_create(&mut self, original: &str, category: &EntityCategory) -> PseudoToken {
121        // 1. Exact match (existing behavior)
122        if let Some(token) = self.forward.get(original) {
123            return token.clone();
124        }
125
126        // 2. Fuzzy match via resolver (if configured)
127        if let Some(ref resolver) = self.resolver {
128            // Build a category map of existing entries for the resolver
129            let existing: HashMap<String, EntityCategory> = self
130                .forward
131                .iter()
132                .map(|(k, v)| (k.clone(), v.category.clone()))
133                .collect();
134
135            if let Some(canonical) = resolver.resolve(original, category, &existing) {
136                if let Some(token) = self.forward.get(&canonical) {
137                    let token = token.clone();
138                    // Store this variant as an alias → same token
139                    self.forward.insert(original.to_string(), token.clone());
140                    return token;
141                }
142            }
143        }
144
145        // 3. No match — create new token
146        let prefix = Self::category_prefix(category);
147        let counter = self.counters.entry(prefix.clone()).or_insert(0);
148        *counter += 1;
149
150        let token = PseudoToken {
151            token: format!("{}_{}", prefix, counter),
152            category: category.clone(),
153            id: *counter,
154        };
155
156        self.forward.insert(original.to_string(), token.clone());
157        self.reverse.insert(
158            token.token.clone(),
159            SensitiveString(original.to_string()),
160        );
161
162        token
163    }
164
165    /// Look up the original value for a pseudo-token (for rehydration).
166    pub fn lookup(&self, token: &str) -> Option<&str> {
167        self.reverse.get(token).map(|s| s.0.as_str())
168    }
169
170    /// Get all reverse mappings (for rehydration).
171    pub fn reverse_mappings(&self) -> HashMap<String, String> {
172        self.reverse
173            .iter()
174            .map(|(k, v)| (k.clone(), v.0.clone()))
175            .collect()
176    }
177
178    /// Save the vault to disk (AES-256-GCM encrypted).
179    pub fn save(&self) -> Result<()> {
180        let path = match &self.path {
181            Some(p) => p,
182            None => return Ok(()), // ephemeral vault, nothing to save
183        };
184
185        let data = self.to_vault_data();
186        let json = serde_json::to_vec(&data).context("Failed to serialize vault")?;
187
188        let encrypted = self.encrypt(&json)?;
189
190        // Atomic write: write to .tmp, then rename
191        let tmp_path = format!("{}.tmp", path);
192        if let Some(parent) = std::path::Path::new(path).parent() {
193            std::fs::create_dir_all(parent).context("Failed to create vault directory")?;
194        }
195        std::fs::write(&tmp_path, &encrypted).context("Failed to write vault temp file")?;
196        std::fs::rename(&tmp_path, path).context("Failed to rename vault file")?;
197
198        Ok(())
199    }
200
201    /// Get vault statistics (safe to expose — no sensitive data).
202    pub fn stats(&self) -> VaultStats {
203        VaultStats {
204            total_mappings: self.forward.len(),
205            categories: self.counters.clone(),
206        }
207    }
208
209    fn category_prefix(category: &EntityCategory) -> String {
210        match category {
211            EntityCategory::Person => "PERSON".into(),
212            EntityCategory::Organization => "ORG".into(),
213            EntityCategory::Location => "LOC".into(),
214            EntityCategory::Amount => "AMOUNT".into(),
215            EntityCategory::Percentage => "PCT".into(),
216            EntityCategory::Date => "DATE".into(),
217            EntityCategory::Email => "EMAIL".into(),
218            EntityCategory::PhoneNumber => "PHONE".into(),
219            EntityCategory::IpAddress => "IP".into(),
220            EntityCategory::Secret => "SECRET".into(),
221            EntityCategory::Url => "URL".into(),
222            EntityCategory::Project => "PROJECT".into(),
223            EntityCategory::Business => "BIZ".into(),
224            EntityCategory::Infra => "INFRA".into(),
225            EntityCategory::Custom(name) => name.to_uppercase(),
226        }
227    }
228
229    fn load(path: &str, key: &[u8]) -> Result<Self> {
230        let encrypted = std::fs::read(path).context("Failed to read vault file")?;
231        let json = Self::decrypt_bytes(key, &encrypted)?;
232        let data: VaultData =
233            serde_json::from_slice(&json).context("Failed to deserialize vault")?;
234
235        let mut forward = HashMap::new();
236        let mut reverse = HashMap::new();
237
238        for (_original_key, stored) in &data.forward {
239            let token = PseudoToken {
240                token: stored.token.clone(),
241                category: stored.category.clone(),
242                id: stored.id,
243            };
244            forward.insert(stored.original.clone(), token.clone());
245            reverse.insert(
246                stored.token.clone(),
247                SensitiveString(stored.original.clone()),
248            );
249        }
250
251        Ok(Self {
252            forward,
253            reverse,
254            counters: data.counters,
255            path: Some(path.to_string()),
256            key: SensitiveBytes(key.to_vec()),
257            resolver: None,
258        })
259    }
260
261    fn to_vault_data(&self) -> VaultData {
262        let forward: Vec<(String, StoredToken)> = self
263            .forward
264            .iter()
265            .map(|(original, token)| {
266                (
267                    original.clone(),
268                    StoredToken {
269                        token: token.token.clone(),
270                        category: token.category.clone(),
271                        id: token.id,
272                        original: original.clone(),
273                    },
274                )
275            })
276            .collect();
277
278        VaultData {
279            forward,
280            counters: self.counters.clone(),
281        }
282    }
283
284    /// Encrypt plaintext with AES-256-GCM. Output: 12-byte nonce || ciphertext.
285    fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
286        let cipher =
287            Aes256Gcm::new_from_slice(&self.key.0)
288                .map_err(|_| anyhow::anyhow!("Invalid AES-256-GCM key"))?;
289
290        let mut nonce_bytes = [0u8; 12];
291        OsRng.fill_bytes(&mut nonce_bytes);
292        let nonce = Nonce::from_slice(&nonce_bytes);
293
294        let ciphertext = cipher
295            .encrypt(nonce, plaintext)
296            .map_err(|e| anyhow::anyhow!("Encryption failed: {}", e))?;
297
298        let mut output = Vec::with_capacity(12 + ciphertext.len());
299        output.extend_from_slice(&nonce_bytes);
300        output.extend_from_slice(&ciphertext);
301        Ok(output)
302    }
303
304    /// Decrypt ciphertext (12-byte nonce || ciphertext) with AES-256-GCM.
305    fn decrypt_bytes(key: &[u8], data: &[u8]) -> Result<Vec<u8>> {
306        if data.len() < 12 {
307            bail!("Vault data too short — corrupted or wrong format");
308        }
309
310        let cipher = Aes256Gcm::new_from_slice(key)
311            .map_err(|_| anyhow::anyhow!("Invalid AES-256-GCM key"))?;
312        let nonce = Nonce::from_slice(&data[..12]);
313        let ciphertext = &data[12..];
314
315        cipher
316            .decrypt(nonce, ciphertext)
317            .map_err(|_| anyhow::anyhow!("Decryption failed — wrong key or corrupted vault"))
318    }
319}
320
/// Summary counts describing a vault's contents.
#[derive(Debug, Serialize)]
pub struct VaultStats {
    /// Number of entries in the vault's forward map.
    pub total_mappings: usize,
    /// Copy of the vault's ID counters, keyed by category prefix.
    pub categories: HashMap<String, u32>,
}
326
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EntityCategory;

    /// Fixed 32-byte key shared by the persistence tests.
    fn test_key() -> Vec<u8> {
        [0xAB_u8; 32].to_vec()
    }

    /// Fresh temp directory plus the path of a vault file inside it.
    /// The directory guard must be kept alive for as long as the path is used.
    fn temp_vault_path() -> (tempfile::TempDir, String) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.vault").to_str().unwrap().to_string();
        (dir, path)
    }

    /// Asking twice for the same value yields the same token.
    #[test]
    fn test_vault_get_or_create_consistency() {
        let mut vault = Vault::ephemeral();
        let first = vault.get_or_create("Tata Motors", &EntityCategory::Organization);
        let second = vault.get_or_create("Tata Motors", &EntityCategory::Organization);
        assert_eq!(first.token, second.token);
        assert_eq!(first.token, "ORG_1");
    }

    /// Distinct values in one category get consecutive, distinct tokens.
    #[test]
    fn test_vault_different_entities() {
        let mut vault = Vault::ephemeral();
        let tata = vault.get_or_create("Tata Motors", &EntityCategory::Organization);
        let infosys = vault.get_or_create("Infosys", &EntityCategory::Organization);
        assert_ne!(tata.token, infosys.token);
        assert_eq!(tata.token, "ORG_1");
        assert_eq!(infosys.token, "ORG_2");
    }

    /// Reverse lookup returns the original value, or `None` for unknown tokens.
    #[test]
    fn test_vault_lookup() {
        let mut vault = Vault::ephemeral();
        vault.get_or_create("secret@example.com", &EntityCategory::Email);
        assert_eq!(vault.lookup("EMAIL_1"), Some("secret@example.com"));
        assert_eq!(vault.lookup("NONEXISTENT_99"), None);
    }

    /// Stats report the mapping count and the per-prefix counters.
    #[test]
    fn test_vault_stats() {
        let mut vault = Vault::ephemeral();
        for person in ["Alice", "Bob"] {
            vault.get_or_create(person, &EntityCategory::Person);
        }
        vault.get_or_create("Acme Corp", &EntityCategory::Organization);

        let stats = vault.stats();
        assert_eq!(stats.total_mappings, 3);
        assert_eq!(stats.categories.get("PERSON"), Some(&2));
        assert_eq!(stats.categories.get("ORG"), Some(&1));
    }

    /// Mappings written by `save` are visible after reopening with the same key.
    #[test]
    fn test_vault_roundtrip_persistence() {
        let (_dir, path) = temp_vault_path();

        // Create vault and add mappings
        let mut writer = Vault::open(&path, test_key()).unwrap();
        writer.get_or_create("Tata Motors", &EntityCategory::Organization);
        writer.get_or_create("$1.2M", &EntityCategory::Amount);
        writer.save().unwrap();
        drop(writer);

        // Load vault and verify mappings persisted
        let reader = Vault::open(&path, test_key()).unwrap();
        assert_eq!(reader.lookup("ORG_1"), Some("Tata Motors"));
        assert_eq!(reader.lookup("AMOUNT_1"), Some("$1.2M"));
        assert_eq!(reader.stats().total_mappings, 2);
    }

    /// Opening an existing vault with a different key is an error.
    #[test]
    fn test_vault_wrong_key_fails() {
        let (_dir, path) = temp_vault_path();

        // Create with one key
        let mut vault = Vault::open(&path, test_key()).unwrap();
        vault.get_or_create("secret", &EntityCategory::Secret);
        vault.save().unwrap();
        drop(vault);

        // Try to open with wrong key
        let wrong_key = vec![0xCD; 32];
        assert!(Vault::open(&path, wrong_key).is_err());
    }

    /// Keys that are not exactly 32 bytes are rejected.
    #[test]
    fn test_vault_invalid_key_length() {
        let short_key = vec![0u8; 16];
        assert!(Vault::open("/tmp/test.vault", short_key).is_err());
    }
}