Skip to main content

cloakpipe_core/
vault.rs

//! Encrypted mapping vault — stores entity<->pseudo-token mappings.
//!
//! Security properties:
//! - AES-256-GCM encrypted at rest
//! - `zeroize` on the encryption key and on reverse-map original values when
//!   dropped (forward-map keys are plain `String`s and are not zeroized)
//! - Persistent across sessions for consistent pseudonymization
//! - Atomic file writes (write to .tmp, rename)
8
9use crate::{EntityCategory, PseudoToken};
10use aes_gcm::{
11    aead::{Aead, KeyInit, OsRng},
12    Aes256Gcm, Nonce,
13};
14use anyhow::{bail, Context, Result};
15use rand::RngCore;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use zeroize::Zeroize;
19
20/// The mapping vault that maintains entity<->token consistency.
21pub struct Vault {
22    /// Forward map: original value -> pseudo-token
23    forward: HashMap<String, PseudoToken>,
24    /// Reverse map: pseudo-token string -> original value
25    reverse: HashMap<String, SensitiveString>,
26    /// Next ID counter per category
27    counters: HashMap<String, u32>,
28    /// File path for persistence (None = ephemeral/in-memory only)
29    path: Option<String>,
30    /// Encryption key (zeroized on drop)
31    key: SensitiveBytes,
32}
33
34/// A string that is zeroized from memory when dropped.
35#[derive(Clone, Serialize, Deserialize)]
36pub struct SensitiveString(String);
37
38impl Drop for SensitiveString {
39    fn drop(&mut self) {
40        self.0.zeroize();
41    }
42}
43
44impl std::fmt::Debug for SensitiveString {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        write!(f, "[REDACTED]")
47    }
48}
49
50/// Bytes that are zeroized from memory when dropped.
51pub struct SensitiveBytes(Vec<u8>);
52
53impl Drop for SensitiveBytes {
54    fn drop(&mut self) {
55        self.0.zeroize();
56    }
57}
58
59/// Serializable vault data for persistence.
60#[derive(Serialize, Deserialize)]
61struct VaultData {
62    forward: Vec<(String, StoredToken)>,
63    counters: HashMap<String, u32>,
64}
65
66#[derive(Serialize, Deserialize)]
67struct StoredToken {
68    token: String,
69    category: EntityCategory,
70    id: u32,
71    original: String,
72}
73
74impl Vault {
75    /// Create or load a vault from the given path.
76    pub fn open(path: &str, key: Vec<u8>) -> Result<Self> {
77        if key.len() != 32 {
78            bail!("Vault key must be exactly 32 bytes (AES-256)");
79        }
80        if std::path::Path::new(path).exists() {
81            Self::load(path, &key)
82        } else {
83            Ok(Self {
84                forward: HashMap::new(),
85                reverse: HashMap::new(),
86                counters: HashMap::new(),
87                path: Some(path.to_string()),
88                key: SensitiveBytes(key),
89            })
90        }
91    }
92
93    /// Create an ephemeral (in-memory only) vault for testing.
94    pub fn ephemeral() -> Self {
95        let mut key = vec![0u8; 32];
96        rand::rngs::OsRng.fill_bytes(&mut key);
97        Self {
98            forward: HashMap::new(),
99            reverse: HashMap::new(),
100            counters: HashMap::new(),
101            path: None,
102            key: SensitiveBytes(key),
103        }
104    }
105
106    /// Get or create a pseudo-token for the given original value.
107    pub fn get_or_create(&mut self, original: &str, category: &EntityCategory) -> PseudoToken {
108        if let Some(token) = self.forward.get(original) {
109            return token.clone();
110        }
111
112        let prefix = Self::category_prefix(category);
113        let counter = self.counters.entry(prefix.clone()).or_insert(0);
114        *counter += 1;
115
116        let token = PseudoToken {
117            token: format!("{}_{}", prefix, counter),
118            category: category.clone(),
119            id: *counter,
120        };
121
122        self.forward.insert(original.to_string(), token.clone());
123        self.reverse.insert(
124            token.token.clone(),
125            SensitiveString(original.to_string()),
126        );
127
128        token
129    }
130
131    /// Look up the original value for a pseudo-token (for rehydration).
132    pub fn lookup(&self, token: &str) -> Option<&str> {
133        self.reverse.get(token).map(|s| s.0.as_str())
134    }
135
136    /// Get all reverse mappings (for rehydration).
137    pub fn reverse_mappings(&self) -> HashMap<String, String> {
138        self.reverse
139            .iter()
140            .map(|(k, v)| (k.clone(), v.0.clone()))
141            .collect()
142    }
143
144    /// Save the vault to disk (AES-256-GCM encrypted).
145    pub fn save(&self) -> Result<()> {
146        let path = match &self.path {
147            Some(p) => p,
148            None => return Ok(()), // ephemeral vault, nothing to save
149        };
150
151        let data = self.to_vault_data();
152        let json = serde_json::to_vec(&data).context("Failed to serialize vault")?;
153
154        let encrypted = self.encrypt(&json)?;
155
156        // Atomic write: write to .tmp, then rename
157        let tmp_path = format!("{}.tmp", path);
158        if let Some(parent) = std::path::Path::new(path).parent() {
159            std::fs::create_dir_all(parent).context("Failed to create vault directory")?;
160        }
161        std::fs::write(&tmp_path, &encrypted).context("Failed to write vault temp file")?;
162        std::fs::rename(&tmp_path, path).context("Failed to rename vault file")?;
163
164        Ok(())
165    }
166
167    /// Get vault statistics (safe to expose — no sensitive data).
168    pub fn stats(&self) -> VaultStats {
169        VaultStats {
170            total_mappings: self.forward.len(),
171            categories: self.counters.clone(),
172        }
173    }
174
175    fn category_prefix(category: &EntityCategory) -> String {
176        match category {
177            EntityCategory::Person => "PERSON".into(),
178            EntityCategory::Organization => "ORG".into(),
179            EntityCategory::Location => "LOC".into(),
180            EntityCategory::Amount => "AMOUNT".into(),
181            EntityCategory::Percentage => "PCT".into(),
182            EntityCategory::Date => "DATE".into(),
183            EntityCategory::Email => "EMAIL".into(),
184            EntityCategory::PhoneNumber => "PHONE".into(),
185            EntityCategory::IpAddress => "IP".into(),
186            EntityCategory::Secret => "SECRET".into(),
187            EntityCategory::Url => "URL".into(),
188            EntityCategory::Project => "PROJECT".into(),
189            EntityCategory::Business => "BIZ".into(),
190            EntityCategory::Infra => "INFRA".into(),
191            EntityCategory::Custom(name) => name.to_uppercase(),
192        }
193    }
194
195    fn load(path: &str, key: &[u8]) -> Result<Self> {
196        let encrypted = std::fs::read(path).context("Failed to read vault file")?;
197        let json = Self::decrypt_bytes(key, &encrypted)?;
198        let data: VaultData =
199            serde_json::from_slice(&json).context("Failed to deserialize vault")?;
200
201        let mut forward = HashMap::new();
202        let mut reverse = HashMap::new();
203
204        for (_original_key, stored) in &data.forward {
205            let token = PseudoToken {
206                token: stored.token.clone(),
207                category: stored.category.clone(),
208                id: stored.id,
209            };
210            forward.insert(stored.original.clone(), token.clone());
211            reverse.insert(
212                stored.token.clone(),
213                SensitiveString(stored.original.clone()),
214            );
215        }
216
217        Ok(Self {
218            forward,
219            reverse,
220            counters: data.counters,
221            path: Some(path.to_string()),
222            key: SensitiveBytes(key.to_vec()),
223        })
224    }
225
226    fn to_vault_data(&self) -> VaultData {
227        let forward: Vec<(String, StoredToken)> = self
228            .forward
229            .iter()
230            .map(|(original, token)| {
231                (
232                    original.clone(),
233                    StoredToken {
234                        token: token.token.clone(),
235                        category: token.category.clone(),
236                        id: token.id,
237                        original: original.clone(),
238                    },
239                )
240            })
241            .collect();
242
243        VaultData {
244            forward,
245            counters: self.counters.clone(),
246        }
247    }
248
249    /// Encrypt plaintext with AES-256-GCM. Output: 12-byte nonce || ciphertext.
250    fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
251        let cipher =
252            Aes256Gcm::new_from_slice(&self.key.0)
253                .map_err(|_| anyhow::anyhow!("Invalid AES-256-GCM key"))?;
254
255        let mut nonce_bytes = [0u8; 12];
256        OsRng.fill_bytes(&mut nonce_bytes);
257        let nonce = Nonce::from_slice(&nonce_bytes);
258
259        let ciphertext = cipher
260            .encrypt(nonce, plaintext)
261            .map_err(|e| anyhow::anyhow!("Encryption failed: {}", e))?;
262
263        let mut output = Vec::with_capacity(12 + ciphertext.len());
264        output.extend_from_slice(&nonce_bytes);
265        output.extend_from_slice(&ciphertext);
266        Ok(output)
267    }
268
269    /// Decrypt ciphertext (12-byte nonce || ciphertext) with AES-256-GCM.
270    fn decrypt_bytes(key: &[u8], data: &[u8]) -> Result<Vec<u8>> {
271        if data.len() < 12 {
272            bail!("Vault data too short — corrupted or wrong format");
273        }
274
275        let cipher = Aes256Gcm::new_from_slice(key)
276            .map_err(|_| anyhow::anyhow!("Invalid AES-256-GCM key"))?;
277        let nonce = Nonce::from_slice(&data[..12]);
278        let ciphertext = &data[12..];
279
280        cipher
281            .decrypt(nonce, ciphertext)
282            .map_err(|_| anyhow::anyhow!("Decryption failed — wrong key or corrupted vault"))
283    }
284}
285
286#[derive(Debug, Serialize)]
287pub struct VaultStats {
288    pub total_mappings: usize,
289    pub categories: HashMap<String, u32>,
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EntityCategory;

    /// Fixed 32-byte key shared by the persistence tests.
    fn test_key() -> Vec<u8> {
        vec![0xAB; 32]
    }

    /// Fresh temp directory plus the path of a not-yet-created vault file in it.
    /// The directory guard must stay alive for as long as the path is used.
    fn temp_vault() -> (tempfile::TempDir, String) {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.vault");
        let path = file.to_str().unwrap().to_string();
        (dir, path)
    }

    /// Asking twice for the same value yields the same token.
    #[test]
    fn test_vault_get_or_create_consistency() {
        let mut vault = Vault::ephemeral();
        let first = vault.get_or_create("Tata Motors", &EntityCategory::Organization);
        let repeat = vault.get_or_create("Tata Motors", &EntityCategory::Organization);
        assert_eq!(first.token, repeat.token);
        assert_eq!(first.token, "ORG_1");
    }

    /// Distinct values in one category get consecutive IDs.
    #[test]
    fn test_vault_different_entities() {
        let mut vault = Vault::ephemeral();
        let tata = vault.get_or_create("Tata Motors", &EntityCategory::Organization);
        let infosys = vault.get_or_create("Infosys", &EntityCategory::Organization);
        assert_ne!(tata.token, infosys.token);
        assert_eq!(tata.token, "ORG_1");
        assert_eq!(infosys.token, "ORG_2");
    }

    /// Known tokens resolve to their original; unknown tokens resolve to `None`.
    #[test]
    fn test_vault_lookup() {
        let mut vault = Vault::ephemeral();
        vault.get_or_create("secret@example.com", &EntityCategory::Email);
        assert_eq!(vault.lookup("EMAIL_1"), Some("secret@example.com"));
        assert_eq!(vault.lookup("NONEXISTENT_99"), None);
    }

    /// Stats report the total mapping count and per-prefix counters.
    #[test]
    fn test_vault_stats() {
        let mut vault = Vault::ephemeral();
        for name in ["Alice", "Bob"] {
            vault.get_or_create(name, &EntityCategory::Person);
        }
        vault.get_or_create("Acme Corp", &EntityCategory::Organization);

        let stats = vault.stats();
        assert_eq!(stats.total_mappings, 3);
        assert_eq!(stats.categories.get("PERSON"), Some(&2));
        assert_eq!(stats.categories.get("ORG"), Some(&1));
    }

    /// Mappings written by `save` are visible after reopening with the same key.
    #[test]
    fn test_vault_roundtrip_persistence() {
        let (_dir, path) = temp_vault();

        // Create vault and add mappings
        {
            let mut writer = Vault::open(&path, test_key()).unwrap();
            writer.get_or_create("Tata Motors", &EntityCategory::Organization);
            writer.get_or_create("$1.2M", &EntityCategory::Amount);
            writer.save().unwrap();
        }

        // Load vault and verify mappings persisted
        {
            let reader = Vault::open(&path, test_key()).unwrap();
            assert_eq!(reader.lookup("ORG_1"), Some("Tata Motors"));
            assert_eq!(reader.lookup("AMOUNT_1"), Some("$1.2M"));
            assert_eq!(reader.stats().total_mappings, 2);
        }
    }

    /// Opening a saved vault with a different key is an error.
    #[test]
    fn test_vault_wrong_key_fails() {
        let (_dir, path) = temp_vault();

        // Create with one key
        {
            let mut writer = Vault::open(&path, test_key()).unwrap();
            writer.get_or_create("secret", &EntityCategory::Secret);
            writer.save().unwrap();
        }

        // Try to open with wrong key
        let reopened = Vault::open(&path, vec![0xCD; 32]);
        assert!(reopened.is_err());
    }

    /// Keys that are not 32 bytes long are rejected.
    #[test]
    fn test_vault_invalid_key_length() {
        let short_key = vec![0u8; 16];
        let outcome = Vault::open("/tmp/test.vault", short_key);
        assert!(outcome.is_err());
    }
}
387}