1use crate::{EntityCategory, PseudoToken};
10use aes_gcm::{
11 aead::{Aead, KeyInit, OsRng},
12 Aes256Gcm, Nonce,
13};
14use anyhow::{bail, Context, Result};
15use rand::RngCore;
16use rusqlite::{params, Connection};
17use std::collections::HashMap;
18use zeroize::Zeroize;
19
20pub struct SqliteVault {
22 conn: Connection,
23 cipher: Aes256Gcm,
24 forward_cache: HashMap<String, PseudoToken>,
26 reverse_cache: HashMap<String, String>,
27}
28
29struct SensitiveBytes(Vec<u8>);
31
32impl Drop for SensitiveBytes {
33 fn drop(&mut self) {
34 self.0.zeroize();
35 }
36}
37
38impl SqliteVault {
39 pub fn open(path: &str, key: Vec<u8>) -> Result<Self> {
41 if key.len() != 32 {
42 bail!("Vault key must be exactly 32 bytes (AES-256)");
43 }
44
45 let conn = Connection::open(path)
46 .with_context(|| format!("Failed to open vault database: {}", path))?;
47
48 conn.pragma_update(None, "journal_mode", "WAL")?;
50 conn.pragma_update(None, "foreign_keys", "ON")?;
51
52 Self::init_schema(&conn)?;
53
54 let cipher = Aes256Gcm::new_from_slice(&key)
55 .map_err(|_| anyhow::anyhow!("Invalid AES-256-GCM key"))?;
56
57 let _key_guard = SensitiveBytes(key);
58
59 let mut vault = Self {
60 conn,
61 cipher,
62 forward_cache: HashMap::new(),
63 reverse_cache: HashMap::new(),
64 };
65
66 vault.load_cache()?;
67 Ok(vault)
68 }
69
70 pub fn ephemeral() -> Self {
72 let mut key = vec![0u8; 32];
73 OsRng.fill_bytes(&mut key);
74
75 let conn = Connection::open_in_memory().expect("Failed to open in-memory SQLite");
76 Self::init_schema(&conn).expect("Failed to init schema");
77
78 let cipher = Aes256Gcm::new_from_slice(&key)
79 .expect("Invalid key");
80
81 let _key_guard = SensitiveBytes(key);
82
83 Self {
84 conn,
85 cipher,
86 forward_cache: HashMap::new(),
87 reverse_cache: HashMap::new(),
88 }
89 }
90
91 fn init_schema(conn: &Connection) -> Result<()> {
92 conn.execute_batch(
93 "CREATE TABLE IF NOT EXISTS mappings (
94 id INTEGER PRIMARY KEY AUTOINCREMENT,
95 original_enc BLOB NOT NULL,
96 token TEXT NOT NULL UNIQUE,
97 category TEXT NOT NULL,
98 token_id INTEGER NOT NULL,
99 user_id TEXT,
100 created_at TEXT NOT NULL DEFAULT (datetime('now'))
101 );
102 CREATE INDEX IF NOT EXISTS idx_mappings_token ON mappings(token);
103 CREATE INDEX IF NOT EXISTS idx_mappings_user ON mappings(user_id);
104
105 CREATE TABLE IF NOT EXISTS counters (
106 category TEXT NOT NULL,
107 user_id TEXT,
108 counter INTEGER NOT NULL DEFAULT 0,
109 PRIMARY KEY (category, user_id)
110 );"
111 ).context("Failed to initialize vault schema")?;
112 Ok(())
113 }
114
115 fn load_cache(&mut self) -> Result<()> {
117 let mut stmt = self.conn.prepare(
118 "SELECT original_enc, token, category, token_id FROM mappings"
119 )?;
120
121 let rows = stmt.query_map([], |row| {
122 let enc: Vec<u8> = row.get(0)?;
123 let token: String = row.get(1)?;
124 let category: String = row.get(2)?;
125 let token_id: u32 = row.get(3)?;
126 Ok((enc, token, category, token_id))
127 })?;
128
129 for row in rows {
130 let (enc, token, category_str, token_id) = row?;
131 let original = self.decrypt_value(&enc)
132 .unwrap_or_else(|_| "[decrypt_failed]".to_string());
133 let category = Self::parse_category(&category_str);
134
135 let pseudo = PseudoToken {
136 token: token.clone(),
137 category,
138 id: token_id,
139 };
140
141 self.forward_cache.insert(original.clone(), pseudo);
142 self.reverse_cache.insert(token, original);
143 }
144
145 Ok(())
146 }
147
148 pub fn get_or_create(&mut self, original: &str, category: &EntityCategory) -> PseudoToken {
150 self.get_or_create_for_user(original, category, None)
151 }
152
153 pub fn get_or_create_for_user(
155 &mut self,
156 original: &str,
157 category: &EntityCategory,
158 user_id: Option<&str>,
159 ) -> PseudoToken {
160 if let Some(token) = self.forward_cache.get(original) {
162 return token.clone();
163 }
164
165 let prefix = Self::category_prefix(category);
166 let user_key = user_id.unwrap_or("");
167
168 self.conn.execute(
170 "INSERT INTO counters (category, user_id, counter) VALUES (?1, ?2, 1)
171 ON CONFLICT(category, user_id) DO UPDATE SET counter = counter + 1",
172 params![prefix, user_key],
173 ).expect("Failed to update counter");
174
175 let counter: u32 = self.conn.query_row(
176 "SELECT counter FROM counters WHERE category = ?1 AND user_id = ?2",
177 params![prefix, user_key],
178 |row| row.get(0),
179 ).expect("Failed to read counter");
180
181 let token_str = if user_key.is_empty() {
183 format!("{}_{}", prefix, counter)
184 } else {
185 format!("{}_{}_{}", prefix, user_key, counter)
186 };
187
188 let token = PseudoToken {
189 token: token_str,
190 category: category.clone(),
191 id: counter,
192 };
193
194 let encrypted = self.encrypt_value(original)
196 .expect("Failed to encrypt value");
197
198 self.conn.execute(
199 "INSERT INTO mappings (original_enc, token, category, token_id, user_id) VALUES (?1, ?2, ?3, ?4, ?5)",
200 params![encrypted, token.token, prefix, counter, user_id],
201 ).expect("Failed to insert mapping");
202
203 self.forward_cache.insert(original.to_string(), token.clone());
205 self.reverse_cache.insert(token.token.clone(), original.to_string());
206
207 token
208 }
209
210 pub fn lookup(&self, token: &str) -> Option<&str> {
212 self.reverse_cache.get(token).map(|s| s.as_str())
213 }
214
215 pub fn reverse_mappings(&self) -> HashMap<String, String> {
217 self.reverse_cache.clone()
218 }
219
220 pub fn save(&self) -> Result<()> {
222 Ok(())
223 }
224
225 pub fn stats(&self) -> VaultStats {
227 let mut categories = HashMap::new();
228
229 let mut stmt = self.conn.prepare(
230 "SELECT category, counter FROM counters WHERE user_id = ''"
231 ).expect("Failed to prepare stats query");
232
233 let rows = stmt.query_map([], |row| {
234 let cat: String = row.get(0)?;
235 let count: u32 = row.get(1)?;
236 Ok((cat, count))
237 }).expect("Failed to query stats");
238
239 for (cat, count) in rows.flatten() {
240 categories.insert(cat, count);
241 }
242
243 VaultStats {
244 total_mappings: self.forward_cache.len(),
245 categories,
246 }
247 }
248
249 pub fn user_stats(&self, user_id: &str) -> Result<VaultStats> {
251 let count: usize = self.conn.query_row(
252 "SELECT COUNT(*) FROM mappings WHERE user_id = ?1",
253 params![user_id],
254 |row| row.get(0),
255 )?;
256
257 let mut categories = HashMap::new();
258 let mut stmt = self.conn.prepare(
259 "SELECT category, counter FROM counters WHERE user_id = ?1"
260 )?;
261
262 let rows = stmt.query_map(params![user_id], |row| {
263 let cat: String = row.get(0)?;
264 let cnt: u32 = row.get(1)?;
265 Ok((cat, cnt))
266 })?;
267
268 for (cat, cnt) in rows.flatten() {
269 categories.insert(cat, cnt);
270 }
271
272 Ok(VaultStats {
273 total_mappings: count,
274 categories,
275 })
276 }
277
278 fn encrypt_value(&self, plaintext: &str) -> Result<Vec<u8>> {
279 let mut nonce_bytes = [0u8; 12];
280 OsRng.fill_bytes(&mut nonce_bytes);
281 let nonce = Nonce::from_slice(&nonce_bytes);
282
283 let ciphertext = self.cipher
284 .encrypt(nonce, plaintext.as_bytes())
285 .map_err(|e| anyhow::anyhow!("Encryption failed: {}", e))?;
286
287 let mut output = Vec::with_capacity(12 + ciphertext.len());
288 output.extend_from_slice(&nonce_bytes);
289 output.extend_from_slice(&ciphertext);
290 Ok(output)
291 }
292
293 fn decrypt_value(&self, data: &[u8]) -> Result<String> {
294 if data.len() < 12 {
295 bail!("Encrypted data too short");
296 }
297 let nonce = Nonce::from_slice(&data[..12]);
298 let ciphertext = &data[12..];
299
300 let plaintext = self.cipher
301 .decrypt(nonce, ciphertext)
302 .map_err(|_| anyhow::anyhow!("Decryption failed — wrong key or corrupted data"))?;
303
304 String::from_utf8(plaintext).context("Decrypted value is not valid UTF-8")
305 }
306
307 fn category_prefix(category: &EntityCategory) -> String {
308 match category {
309 EntityCategory::Person => "PERSON".into(),
310 EntityCategory::Organization => "ORG".into(),
311 EntityCategory::Location => "LOC".into(),
312 EntityCategory::Amount => "AMOUNT".into(),
313 EntityCategory::Percentage => "PCT".into(),
314 EntityCategory::Date => "DATE".into(),
315 EntityCategory::Email => "EMAIL".into(),
316 EntityCategory::PhoneNumber => "PHONE".into(),
317 EntityCategory::IpAddress => "IP".into(),
318 EntityCategory::Secret => "SECRET".into(),
319 EntityCategory::Url => "URL".into(),
320 EntityCategory::Project => "PROJECT".into(),
321 EntityCategory::Business => "BIZ".into(),
322 EntityCategory::Infra => "INFRA".into(),
323 EntityCategory::Custom(name) => name.to_uppercase(),
324 }
325 }
326
327 fn parse_category(s: &str) -> EntityCategory {
328 match s {
329 "PERSON" => EntityCategory::Person,
330 "ORG" => EntityCategory::Organization,
331 "LOC" => EntityCategory::Location,
332 "AMOUNT" => EntityCategory::Amount,
333 "PCT" => EntityCategory::Percentage,
334 "DATE" => EntityCategory::Date,
335 "EMAIL" => EntityCategory::Email,
336 "PHONE" => EntityCategory::PhoneNumber,
337 "IP" => EntityCategory::IpAddress,
338 "SECRET" => EntityCategory::Secret,
339 "URL" => EntityCategory::Url,
340 "PROJECT" => EntityCategory::Project,
341 "BIZ" => EntityCategory::Business,
342 "INFRA" => EntityCategory::Infra,
343 other => EntityCategory::Custom(other.to_string()),
344 }
345 }
346}
347
348#[derive(Debug, serde::Serialize)]
349pub struct VaultStats {
350 pub total_mappings: usize,
351 pub categories: HashMap<String, u32>,
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EntityCategory;

    /// Creates a temporary directory and returns it with the path of `test.db`
    /// inside it. Keep the `TempDir` alive while the path is in use; dropping it
    /// deletes the directory.
    fn temp_db() -> (tempfile::TempDir, String) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.db").to_str().unwrap().to_owned();
        (dir, path)
    }

    #[test]
    fn test_sqlite_vault_get_or_create() {
        let mut vault = SqliteVault::ephemeral();
        let t1 = vault.get_or_create("Tata Motors", &EntityCategory::Organization);
        let t2 = vault.get_or_create("Tata Motors", &EntityCategory::Organization);
        assert_eq!(t1.token, t2.token);
        assert_eq!(t1.token, "ORG_1");
    }

    #[test]
    fn test_sqlite_vault_different_entities() {
        let mut vault = SqliteVault::ephemeral();
        let t1 = vault.get_or_create("Tata Motors", &EntityCategory::Organization);
        let t2 = vault.get_or_create("Infosys", &EntityCategory::Organization);
        assert_ne!(t1.token, t2.token);
        assert_eq!(t1.token, "ORG_1");
        assert_eq!(t2.token, "ORG_2");
    }

    #[test]
    fn test_sqlite_vault_lookup() {
        let mut vault = SqliteVault::ephemeral();
        vault.get_or_create("secret@example.com", &EntityCategory::Email);
        assert_eq!(vault.lookup("EMAIL_1"), Some("secret@example.com"));
        assert_eq!(vault.lookup("NONEXISTENT_99"), None);
    }

    #[test]
    fn test_sqlite_vault_stats() {
        let mut vault = SqliteVault::ephemeral();
        vault.get_or_create("Alice", &EntityCategory::Person);
        vault.get_or_create("Bob", &EntityCategory::Person);
        vault.get_or_create("Acme Corp", &EntityCategory::Organization);
        let stats = vault.stats();
        assert_eq!(stats.total_mappings, 3);
        assert_eq!(stats.categories.get("PERSON"), Some(&2));
        assert_eq!(stats.categories.get("ORG"), Some(&1));
    }

    #[test]
    fn test_sqlite_vault_persistence() {
        let (_dir, path) = temp_db();
        let key = vec![0xAB; 32];

        {
            let mut vault = SqliteVault::open(&path, key.clone()).unwrap();
            vault.get_or_create("Tata Motors", &EntityCategory::Organization);
            vault.get_or_create("$1.2M", &EntityCategory::Amount);
        }

        {
            let mut vault = SqliteVault::open(&path, key).unwrap();
            assert_eq!(vault.lookup("ORG_1"), Some("Tata Motors"));
            assert_eq!(vault.lookup("AMOUNT_1"), Some("$1.2M"));
            assert_eq!(vault.stats().total_mappings, 2);

            // Known values keep their token after reopening, and counters resume
            // from their persisted value instead of restarting at 1.
            let again = vault.get_or_create("Tata Motors", &EntityCategory::Organization);
            assert_eq!(again.token, "ORG_1");
            let next = vault.get_or_create("Infosys", &EntityCategory::Organization);
            assert_eq!(next.token, "ORG_2");
        }
    }

    #[test]
    fn test_sqlite_vault_wrong_key() {
        let (_dir, path) = temp_db();

        {
            let mut vault = SqliteVault::open(&path, vec![0xAB; 32]).unwrap();
            vault.get_or_create("secret", &EntityCategory::Secret);
        }

        let mut vault = SqliteVault::open(&path, vec![0xCD; 32]).unwrap();
        // The stored ciphertext cannot be authenticated under a different key:
        // the plaintext must never come back, only the placeholder.
        assert_ne!(vault.lookup("SECRET_1"), Some("secret"));
        assert_eq!(vault.lookup("SECRET_1"), Some("[decrypt_failed]"));

        // The vault stays usable, and new tokens never reuse a counter value.
        let fresh = vault.get_or_create("secret", &EntityCategory::Secret);
        assert_eq!(fresh.token, "SECRET_2");
    }

    #[test]
    fn test_sqlite_vault_rejects_bad_key_length() {
        let (_dir, path) = temp_db();
        assert!(SqliteVault::open(&path, vec![0xAB; 16]).is_err());
        assert!(SqliteVault::open(&path, vec![0xAB; 33]).is_err());
    }

    #[test]
    fn test_sqlite_vault_multi_user() {
        let mut vault = SqliteVault::ephemeral();
        let t1 = vault.get_or_create_for_user("Alice", &EntityCategory::Person, Some("user-1"));
        let t2 = vault.get_or_create_for_user("Bob", &EntityCategory::Person, Some("user-2"));
        assert_eq!(t1.token, "PERSON_user-1_1");
        assert_eq!(t2.token, "PERSON_user-2_1");
        assert_eq!(t1.id, 1);
        assert_eq!(t2.id, 1);
    }

    #[test]
    fn test_sqlite_vault_user_stats() {
        let mut vault = SqliteVault::ephemeral();
        vault.get_or_create_for_user("Alice", &EntityCategory::Person, Some("user-1"));
        vault.get_or_create_for_user("Acme", &EntityCategory::Organization, Some("user-1"));
        vault.get_or_create_for_user("Bob", &EntityCategory::Person, Some("user-2"));

        let stats = vault.user_stats("user-1").unwrap();
        assert_eq!(stats.total_mappings, 2);
        assert_eq!(stats.categories.get("PERSON"), Some(&1));
        assert_eq!(stats.categories.get("ORG"), Some(&1));
    }
}
449}