Skip to main content

cloakpipe_core/
vault_sqlite.rs

//! SQLite-backed vault — persistent entity<->pseudo-token mappings.
//!
//! Replaces the file-based vault with SQLite for:
//! - Better concurrent access and crash safety (WAL mode)
//! - Per-row AES-256-GCM encryption of sensitive values
//! - Multi-user support via user_id scoping
//! - Efficient lookups without loading the entire vault into memory
8
9use crate::{EntityCategory, PseudoToken};
10use aes_gcm::{
11    aead::{Aead, KeyInit, OsRng},
12    Aes256Gcm, Nonce,
13};
14use anyhow::{bail, Context, Result};
15use rand::RngCore;
16use rusqlite::{params, Connection};
17use std::collections::HashMap;
18use zeroize::Zeroize;
19
20/// SQLite-backed vault with per-value encryption.
21pub struct SqliteVault {
22    conn: Connection,
23    cipher: Aes256Gcm,
24    /// In-memory caches for hot path performance
25    forward_cache: HashMap<String, PseudoToken>,
26    reverse_cache: HashMap<String, String>,
27}
28
29/// Bytes that are zeroized from memory when dropped.
30struct SensitiveBytes(Vec<u8>);
31
32impl Drop for SensitiveBytes {
33    fn drop(&mut self) {
34        self.0.zeroize();
35    }
36}
37
38impl SqliteVault {
39    /// Open or create a SQLite vault at the given path.
40    pub fn open(path: &str, key: Vec<u8>) -> Result<Self> {
41        if key.len() != 32 {
42            bail!("Vault key must be exactly 32 bytes (AES-256)");
43        }
44
45        let conn = Connection::open(path)
46            .with_context(|| format!("Failed to open vault database: {}", path))?;
47
48        // WAL mode for better concurrent read performance
49        conn.pragma_update(None, "journal_mode", "WAL")?;
50        conn.pragma_update(None, "foreign_keys", "ON")?;
51
52        Self::init_schema(&conn)?;
53
54        let cipher = Aes256Gcm::new_from_slice(&key)
55            .map_err(|_| anyhow::anyhow!("Invalid AES-256-GCM key"))?;
56
57        let _key_guard = SensitiveBytes(key);
58
59        let mut vault = Self {
60            conn,
61            cipher,
62            forward_cache: HashMap::new(),
63            reverse_cache: HashMap::new(),
64        };
65
66        vault.load_cache()?;
67        Ok(vault)
68    }
69
70    /// Create an ephemeral (in-memory) vault for testing.
71    pub fn ephemeral() -> Self {
72        let mut key = vec![0u8; 32];
73        OsRng.fill_bytes(&mut key);
74
75        let conn = Connection::open_in_memory().expect("Failed to open in-memory SQLite");
76        Self::init_schema(&conn).expect("Failed to init schema");
77
78        let cipher = Aes256Gcm::new_from_slice(&key)
79            .expect("Invalid key");
80
81        let _key_guard = SensitiveBytes(key);
82
83        Self {
84            conn,
85            cipher,
86            forward_cache: HashMap::new(),
87            reverse_cache: HashMap::new(),
88        }
89    }
90
91    fn init_schema(conn: &Connection) -> Result<()> {
92        conn.execute_batch(
93            "CREATE TABLE IF NOT EXISTS mappings (
94                id INTEGER PRIMARY KEY AUTOINCREMENT,
95                original_enc BLOB NOT NULL,
96                token TEXT NOT NULL UNIQUE,
97                category TEXT NOT NULL,
98                token_id INTEGER NOT NULL,
99                user_id TEXT,
100                created_at TEXT NOT NULL DEFAULT (datetime('now'))
101            );
102            CREATE INDEX IF NOT EXISTS idx_mappings_token ON mappings(token);
103            CREATE INDEX IF NOT EXISTS idx_mappings_user ON mappings(user_id);
104
105            CREATE TABLE IF NOT EXISTS counters (
106                category TEXT NOT NULL,
107                user_id TEXT,
108                counter INTEGER NOT NULL DEFAULT 0,
109                PRIMARY KEY (category, user_id)
110            );"
111        ).context("Failed to initialize vault schema")?;
112        Ok(())
113    }
114
115    /// Load all mappings into the in-memory cache.
116    fn load_cache(&mut self) -> Result<()> {
117        let mut stmt = self.conn.prepare(
118            "SELECT original_enc, token, category, token_id FROM mappings"
119        )?;
120
121        let rows = stmt.query_map([], |row| {
122            let enc: Vec<u8> = row.get(0)?;
123            let token: String = row.get(1)?;
124            let category: String = row.get(2)?;
125            let token_id: u32 = row.get(3)?;
126            Ok((enc, token, category, token_id))
127        })?;
128
129        for row in rows {
130            let (enc, token, category_str, token_id) = row?;
131            let original = self.decrypt_value(&enc)
132                .unwrap_or_else(|_| "[decrypt_failed]".to_string());
133            let category = Self::parse_category(&category_str);
134
135            let pseudo = PseudoToken {
136                token: token.clone(),
137                category,
138                id: token_id,
139            };
140
141            self.forward_cache.insert(original.clone(), pseudo);
142            self.reverse_cache.insert(token, original);
143        }
144
145        Ok(())
146    }
147
148    /// Get or create a pseudo-token for the given original value.
149    pub fn get_or_create(&mut self, original: &str, category: &EntityCategory) -> PseudoToken {
150        self.get_or_create_for_user(original, category, None)
151    }
152
153    /// Get or create a pseudo-token scoped to a user.
154    pub fn get_or_create_for_user(
155        &mut self,
156        original: &str,
157        category: &EntityCategory,
158        user_id: Option<&str>,
159    ) -> PseudoToken {
160        // Check cache first
161        if let Some(token) = self.forward_cache.get(original) {
162            return token.clone();
163        }
164
165        let prefix = Self::category_prefix(category);
166        let user_key = user_id.unwrap_or("");
167
168        // Increment counter
169        self.conn.execute(
170            "INSERT INTO counters (category, user_id, counter) VALUES (?1, ?2, 1)
171             ON CONFLICT(category, user_id) DO UPDATE SET counter = counter + 1",
172            params![prefix, user_key],
173        ).expect("Failed to update counter");
174
175        let counter: u32 = self.conn.query_row(
176            "SELECT counter FROM counters WHERE category = ?1 AND user_id = ?2",
177            params![prefix, user_key],
178            |row| row.get(0),
179        ).expect("Failed to read counter");
180
181        // Include user_id in token to avoid collisions across users
182        let token_str = if user_key.is_empty() {
183            format!("{}_{}", prefix, counter)
184        } else {
185            format!("{}_{}_{}", prefix, user_key, counter)
186        };
187
188        let token = PseudoToken {
189            token: token_str,
190            category: category.clone(),
191            id: counter,
192        };
193
194        // Encrypt and store
195        let encrypted = self.encrypt_value(original)
196            .expect("Failed to encrypt value");
197
198        self.conn.execute(
199            "INSERT INTO mappings (original_enc, token, category, token_id, user_id) VALUES (?1, ?2, ?3, ?4, ?5)",
200            params![encrypted, token.token, prefix, counter, user_id],
201        ).expect("Failed to insert mapping");
202
203        // Update caches
204        self.forward_cache.insert(original.to_string(), token.clone());
205        self.reverse_cache.insert(token.token.clone(), original.to_string());
206
207        token
208    }
209
210    /// Look up the original value for a pseudo-token.
211    pub fn lookup(&self, token: &str) -> Option<&str> {
212        self.reverse_cache.get(token).map(|s| s.as_str())
213    }
214
215    /// Get all reverse mappings.
216    pub fn reverse_mappings(&self) -> HashMap<String, String> {
217        self.reverse_cache.clone()
218    }
219
220    /// Save is a no-op for SQLite (writes are immediate).
221    pub fn save(&self) -> Result<()> {
222        Ok(())
223    }
224
225    /// Get vault statistics.
226    pub fn stats(&self) -> VaultStats {
227        let mut categories = HashMap::new();
228
229        let mut stmt = self.conn.prepare(
230            "SELECT category, counter FROM counters WHERE user_id = ''"
231        ).expect("Failed to prepare stats query");
232
233        let rows = stmt.query_map([], |row| {
234            let cat: String = row.get(0)?;
235            let count: u32 = row.get(1)?;
236            Ok((cat, count))
237        }).expect("Failed to query stats");
238
239        for (cat, count) in rows.flatten() {
240            categories.insert(cat, count);
241        }
242
243        VaultStats {
244            total_mappings: self.forward_cache.len(),
245            categories,
246        }
247    }
248
249    /// Get mappings for a specific user.
250    pub fn user_stats(&self, user_id: &str) -> Result<VaultStats> {
251        let count: usize = self.conn.query_row(
252            "SELECT COUNT(*) FROM mappings WHERE user_id = ?1",
253            params![user_id],
254            |row| row.get(0),
255        )?;
256
257        let mut categories = HashMap::new();
258        let mut stmt = self.conn.prepare(
259            "SELECT category, counter FROM counters WHERE user_id = ?1"
260        )?;
261
262        let rows = stmt.query_map(params![user_id], |row| {
263            let cat: String = row.get(0)?;
264            let cnt: u32 = row.get(1)?;
265            Ok((cat, cnt))
266        })?;
267
268        for (cat, cnt) in rows.flatten() {
269            categories.insert(cat, cnt);
270        }
271
272        Ok(VaultStats {
273            total_mappings: count,
274            categories,
275        })
276    }
277
278    fn encrypt_value(&self, plaintext: &str) -> Result<Vec<u8>> {
279        let mut nonce_bytes = [0u8; 12];
280        OsRng.fill_bytes(&mut nonce_bytes);
281        let nonce = Nonce::from_slice(&nonce_bytes);
282
283        let ciphertext = self.cipher
284            .encrypt(nonce, plaintext.as_bytes())
285            .map_err(|e| anyhow::anyhow!("Encryption failed: {}", e))?;
286
287        let mut output = Vec::with_capacity(12 + ciphertext.len());
288        output.extend_from_slice(&nonce_bytes);
289        output.extend_from_slice(&ciphertext);
290        Ok(output)
291    }
292
293    fn decrypt_value(&self, data: &[u8]) -> Result<String> {
294        if data.len() < 12 {
295            bail!("Encrypted data too short");
296        }
297        let nonce = Nonce::from_slice(&data[..12]);
298        let ciphertext = &data[12..];
299
300        let plaintext = self.cipher
301            .decrypt(nonce, ciphertext)
302            .map_err(|_| anyhow::anyhow!("Decryption failed — wrong key or corrupted data"))?;
303
304        String::from_utf8(plaintext).context("Decrypted value is not valid UTF-8")
305    }
306
307    fn category_prefix(category: &EntityCategory) -> String {
308        match category {
309            EntityCategory::Person => "PERSON".into(),
310            EntityCategory::Organization => "ORG".into(),
311            EntityCategory::Location => "LOC".into(),
312            EntityCategory::Amount => "AMOUNT".into(),
313            EntityCategory::Percentage => "PCT".into(),
314            EntityCategory::Date => "DATE".into(),
315            EntityCategory::Email => "EMAIL".into(),
316            EntityCategory::PhoneNumber => "PHONE".into(),
317            EntityCategory::IpAddress => "IP".into(),
318            EntityCategory::Secret => "SECRET".into(),
319            EntityCategory::Url => "URL".into(),
320            EntityCategory::Project => "PROJECT".into(),
321            EntityCategory::Business => "BIZ".into(),
322            EntityCategory::Infra => "INFRA".into(),
323            EntityCategory::Custom(name) => name.to_uppercase(),
324        }
325    }
326
327    fn parse_category(s: &str) -> EntityCategory {
328        match s {
329            "PERSON" => EntityCategory::Person,
330            "ORG" => EntityCategory::Organization,
331            "LOC" => EntityCategory::Location,
332            "AMOUNT" => EntityCategory::Amount,
333            "PCT" => EntityCategory::Percentage,
334            "DATE" => EntityCategory::Date,
335            "EMAIL" => EntityCategory::Email,
336            "PHONE" => EntityCategory::PhoneNumber,
337            "IP" => EntityCategory::IpAddress,
338            "SECRET" => EntityCategory::Secret,
339            "URL" => EntityCategory::Url,
340            "PROJECT" => EntityCategory::Project,
341            "BIZ" => EntityCategory::Business,
342            "INFRA" => EntityCategory::Infra,
343            other => EntityCategory::Custom(other.to_string()),
344        }
345    }
346}
347
348#[derive(Debug, serde::Serialize)]
349pub struct VaultStats {
350    pub total_mappings: usize,
351    pub categories: HashMap<String, u32>,
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EntityCategory;

    /// Create a temporary directory and return it together with the path of
    /// a `test.db` file inside it. The directory guard must be kept alive for
    /// as long as the database is in use.
    fn temp_db() -> (tempfile::TempDir, String) {
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("test.db").to_str().unwrap().to_string();
        (dir, db_path)
    }

    /// Asking twice for the same original yields the same token.
    #[test]
    fn test_sqlite_vault_get_or_create() {
        let mut vault = SqliteVault::ephemeral();
        let org = EntityCategory::Organization;
        let first = vault.get_or_create("Tata Motors", &org);
        let second = vault.get_or_create("Tata Motors", &org);
        assert_eq!(first.token, second.token);
        assert_eq!(first.token, "ORG_1");
    }

    /// Distinct originals in one category get consecutive tokens.
    #[test]
    fn test_sqlite_vault_different_entities() {
        let mut vault = SqliteVault::ephemeral();
        let org = EntityCategory::Organization;
        let tata = vault.get_or_create("Tata Motors", &org);
        let infosys = vault.get_or_create("Infosys", &org);
        assert_ne!(tata.token, infosys.token);
        assert_eq!(tata.token, "ORG_1");
        assert_eq!(infosys.token, "ORG_2");
    }

    /// Reverse lookup returns the original for known tokens and `None` otherwise.
    #[test]
    fn test_sqlite_vault_lookup() {
        let mut vault = SqliteVault::ephemeral();
        vault.get_or_create("secret@example.com", &EntityCategory::Email);
        assert_eq!(vault.lookup("EMAIL_1"), Some("secret@example.com"));
        assert_eq!(vault.lookup("NONEXISTENT_99"), None);
    }

    /// Stats report the total mapping count and per-category counters.
    #[test]
    fn test_sqlite_vault_stats() {
        let mut vault = SqliteVault::ephemeral();
        for name in ["Alice", "Bob"] {
            vault.get_or_create(name, &EntityCategory::Person);
        }
        vault.get_or_create("Acme Corp", &EntityCategory::Organization);

        let stats = vault.stats();
        assert_eq!(stats.total_mappings, 3);
        assert_eq!(stats.categories.get("PERSON"), Some(&2));
        assert_eq!(stats.categories.get("ORG"), Some(&1));
    }

    /// Mappings written by one vault instance are visible after reopening
    /// the same file with the same key.
    #[test]
    fn test_sqlite_vault_persistence() {
        let (_dir, path_str) = temp_db();
        let key = vec![0xAB; 32];

        // Create and populate
        {
            let mut writer = SqliteVault::open(&path_str, key.clone()).unwrap();
            writer.get_or_create("Tata Motors", &EntityCategory::Organization);
            writer.get_or_create("$1.2M", &EntityCategory::Amount);
        }

        // Reopen and verify
        {
            let reader = SqliteVault::open(&path_str, key).unwrap();
            assert_eq!(reader.lookup("ORG_1"), Some("Tata Motors"));
            assert_eq!(reader.lookup("AMOUNT_1"), Some("$1.2M"));
            assert_eq!(reader.stats().total_mappings, 2);
        }
    }

    /// Opening with the wrong key succeeds but never reveals the plaintext.
    #[test]
    fn test_sqlite_vault_wrong_key() {
        let (_dir, path_str) = temp_db();

        {
            let mut writer = SqliteVault::open(&path_str, vec![0xAB; 32]).unwrap();
            writer.get_or_create("secret", &EntityCategory::Secret);
        }

        // Reopen with wrong key — decryption should fail but not crash
        let reopened = SqliteVault::open(&path_str, vec![0xCD; 32]).unwrap();
        // The decrypted value will be "[decrypt_failed]"
        assert_ne!(reopened.lookup("SECRET_1"), Some("secret"));
    }

    /// Tokens embed the user id and each user has an independent counter.
    #[test]
    fn test_sqlite_vault_multi_user() {
        let mut vault = SqliteVault::ephemeral();
        let person = EntityCategory::Person;
        let alice = vault.get_or_create_for_user("Alice", &person, Some("user-1"));
        let bob = vault.get_or_create_for_user("Bob", &person, Some("user-2"));
        assert_eq!(alice.token, "PERSON_user-1_1");
        assert_eq!(bob.token, "PERSON_user-2_1");
        // Each user gets their own counter
        assert_eq!(alice.id, 1);
        assert_eq!(bob.id, 1);
    }
}