// dbx_core/storage/encryption/wos.rs

1//! Encrypted WOS (Write-Optimized Store) — transparent encryption wrapper.
2//!
3//! Wraps [`WosBackend`] to encrypt all values before storage and decrypt on read.
4//! Keys remain unencrypted to support range scans and ordered iteration.
5//!
6//! # Architecture
7//!
8//! ```text
9//! Application
10//!     │ plaintext value
11//!     ▼
12//! EncryptedWosBackend
13//!     │ encrypt(value) → ciphertext
14//!     ▼
15//! WosBackend (sled)
16//!     │ store ciphertext
17//!     ▼
18//! Disk
19//! ```
20//!
21//! # Security Properties
22//!
23//! - Values are encrypted with AEAD (confidentiality + integrity)
24//! - Keys are stored in plaintext (trade-off for range scan support)
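//! - AAD includes table name for cross-table attack prevention
//! - AAD does not include the row key, so a ciphertext copied to a different
//!   key within the same table still decrypts successfully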
26
27use crate::error::DbxResult;
28use crate::storage::StorageBackend;
29use crate::storage::encryption::EncryptionConfig;
30use crate::storage::wos::WosBackend;
31use std::ops::RangeBounds;
32use std::path::Path;
33
/// Tier 3 with transparent encryption: sled-backed storage with AEAD encryption.
///
/// All values are encrypted before being written to sled and decrypted on read.
/// Keys remain in plaintext to preserve ordered iteration and range scans.
pub struct EncryptedWosBackend {
    /// Underlying store. This wrapper's write paths only ever hand it
    /// ciphertext values; keys are passed through unchanged.
    inner: WosBackend,
    /// Config used to encrypt values on write and decrypt them on read.
    /// Replaced by [`EncryptedWosBackend::rekey`].
    encryption: EncryptionConfig,
}
43
44impl EncryptedWosBackend {
45    /// Open an encrypted WOS at the given directory path.
46    pub fn open(path: &Path, encryption: EncryptionConfig) -> DbxResult<Self> {
47        let inner = WosBackend::open(path)?;
48        Ok(Self { inner, encryption })
49    }
50
51    /// Open a temporary encrypted WOS (for testing).
52    pub fn open_temporary(encryption: EncryptionConfig) -> DbxResult<Self> {
53        let inner = WosBackend::open_temporary()?;
54        Ok(Self { inner, encryption })
55    }
56
57    /// Get a reference to the encryption config.
58    pub fn encryption_config(&self) -> &EncryptionConfig {
59        &self.encryption
60    }
61
62    /// Re-key all data with a new encryption config.
63    ///
64    /// Reads all existing data, decrypts with the current key,
65    /// and re-encrypts with the new key.
66    ///
67    /// # Warning
68    ///
69    /// This operation is NOT atomic — if interrupted, some data may be
70    /// encrypted with the old key and some with the new key.
71    /// Always checkpoint/backup before re-keying.
72    pub fn rekey(&mut self, new_encryption: EncryptionConfig) -> DbxResult<usize> {
73        let table_names = self.inner.table_names()?;
74        let mut rekey_count = 0;
75
76        for table_name in &table_names {
77            // Read all entries with current key
78            let entries: Vec<(Vec<u8>, Vec<u8>)> = self
79                .inner
80                .scan(table_name, ..)?
81                .into_iter()
82                .filter_map(|(key, encrypted_value)| {
83                    // Decrypt with old key
84                    let aad = table_name.as_bytes();
85                    self.encryption
86                        .decrypt_with_aad(&encrypted_value, aad)
87                        .ok()
88                        .map(|plain| (key, plain))
89                })
90                .collect();
91
92            // Re-encrypt with new key and write back
93            for (key, plaintext) in &entries {
94                let aad = table_name.as_bytes();
95                let new_ciphertext = new_encryption.encrypt_with_aad(plaintext, aad)?;
96                self.inner.insert(table_name, key, &new_ciphertext)?;
97                rekey_count += 1;
98            }
99        }
100
101        self.encryption = new_encryption;
102        self.inner.flush()?;
103        Ok(rekey_count)
104    }
105}
106
107impl StorageBackend for EncryptedWosBackend {
108    fn insert(&self, table: &str, key: &[u8], value: &[u8]) -> DbxResult<()> {
109        let aad = table.as_bytes();
110        let encrypted = self.encryption.encrypt_with_aad(value, aad)?;
111        self.inner.insert(table, key, &encrypted)
112    }
113
114    fn insert_batch(&self, table: &str, rows: Vec<(Vec<u8>, Vec<u8>)>) -> DbxResult<()> {
115        let aad = table.as_bytes();
116        let encrypted_rows: Vec<(Vec<u8>, Vec<u8>)> = rows
117            .into_iter()
118            .map(|(key, value)| {
119                let encrypted = self.encryption.encrypt_with_aad(&value, aad)?;
120                Ok((key, encrypted))
121            })
122            .collect::<DbxResult<Vec<_>>>()?;
123
124        self.inner.insert_batch(table, encrypted_rows)
125    }
126
127    fn get(&self, table: &str, key: &[u8]) -> DbxResult<Option<Vec<u8>>> {
128        match self.inner.get(table, key)? {
129            Some(encrypted) => {
130                let aad = table.as_bytes();
131                let decrypted = self.encryption.decrypt_with_aad(&encrypted, aad)?;
132                Ok(Some(decrypted))
133            }
134            None => Ok(None),
135        }
136    }
137
138    fn delete(&self, table: &str, key: &[u8]) -> DbxResult<bool> {
139        self.inner.delete(table, key)
140    }
141
142    fn scan<R: RangeBounds<Vec<u8>> + Clone>(
143        &self,
144        table: &str,
145        range: R,
146    ) -> DbxResult<Vec<(Vec<u8>, Vec<u8>)>> {
147        let encrypted_entries = self.inner.scan(table, range)?;
148        let aad = table.as_bytes();
149
150        encrypted_entries
151            .into_iter()
152            .map(|(key, encrypted)| {
153                let decrypted = self.encryption.decrypt_with_aad(&encrypted, aad)?;
154                Ok((key, decrypted))
155            })
156            .collect()
157    }
158
159    fn scan_one<R: RangeBounds<Vec<u8>> + Clone>(
160        &self,
161        table: &str,
162        range: R,
163    ) -> DbxResult<Option<(Vec<u8>, Vec<u8>)>> {
164        let aad = table.as_bytes();
165        match self.inner.scan_one(table, range)? {
166            Some((key, encrypted)) => {
167                let decrypted = self.encryption.decrypt_with_aad(&encrypted, aad)?;
168                Ok(Some((key, decrypted)))
169            }
170            None => Ok(None),
171        }
172    }
173
174    fn flush(&self) -> DbxResult<()> {
175        self.inner.flush()
176    }
177
178    fn count(&self, table: &str) -> DbxResult<usize> {
179        self.inner.count(table)
180    }
181
182    fn table_names(&self) -> DbxResult<Vec<String>> {
183        self.inner.table_names()
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::encryption::EncryptionAlgorithm;

    /// Fresh temporary backend keyed from a fixed test password.
    fn fresh_store() -> EncryptedWosBackend {
        let config = EncryptionConfig::from_password("test-password");
        EncryptedWosBackend::open_temporary(config).unwrap()
    }

    #[test]
    fn insert_and_get_round_trip() {
        let store = fresh_store();
        store.insert("users", b"key1", b"Alice").unwrap();
        assert_eq!(
            store.get("users", b"key1").unwrap(),
            Some(b"Alice".to_vec())
        );
    }

    #[test]
    fn get_nonexistent_returns_none() {
        let store = fresh_store();
        let missing = store.get("users", b"missing").unwrap();
        assert_eq!(missing, None);
    }

    #[test]
    fn delete_existing() {
        let store = fresh_store();
        store.insert("users", b"key1", b"Alice").unwrap();

        let removed = store.delete("users", b"key1").unwrap();
        assert!(removed);
        assert_eq!(store.get("users", b"key1").unwrap(), None);
    }

    #[test]
    fn upsert_overwrites() {
        let store = fresh_store();
        store.insert("t", b"k", b"v1").unwrap();
        store.insert("t", b"k", b"v2").unwrap();

        let latest = store.get("t", b"k").unwrap();
        assert_eq!(latest, Some(b"v2".to_vec()));
    }

    #[test]
    fn scan_all_decrypted() {
        let store = fresh_store();
        store.insert("t", b"a", b"1").unwrap();
        store.insert("t", b"b", b"2").unwrap();
        store.insert("t", b"c", b"3").unwrap();

        // Comparing whole vectors checks both the length and every element.
        let expected: Vec<(Vec<u8>, Vec<u8>)> = vec![
            (b"a".to_vec(), b"1".to_vec()),
            (b"b".to_vec(), b"2".to_vec()),
            (b"c".to_vec(), b"3".to_vec()),
        ];
        assert_eq!(store.scan("t", ..).unwrap(), expected);
    }

    #[test]
    fn count_accuracy() {
        let store = fresh_store();
        assert_eq!(store.count("t").unwrap(), 0);

        store.insert("t", b"a", b"1").unwrap();
        store.insert("t", b"b", b"2").unwrap();
        assert_eq!(store.count("t").unwrap(), 2);
    }

    #[test]
    fn table_names_tracks_tables() {
        let store = fresh_store();
        store.insert("users", b"a", b"1").unwrap();
        store.insert("orders", b"b", b"2").unwrap();

        // Sort because the test does not depend on the order tables are listed in.
        let mut names = store.table_names().unwrap();
        names.sort();
        assert_eq!(names, ["orders", "users"]);
    }

    #[test]
    fn wrong_password_cannot_decrypt() {
        let right_key = EncryptionConfig::from_password("correct");
        let wrong_key = EncryptionConfig::from_password("wrong");

        let store = EncryptedWosBackend::open_temporary(right_key).unwrap();
        store.insert("t", b"k", b"secret").unwrap();

        // Bypass the wrapper to grab the ciphertext exactly as stored.
        let stored = store.inner.get("t", b"k").unwrap().unwrap();

        // A config derived from a different password must reject it.
        assert!(wrong_key.decrypt_with_aad(&stored, b"t").is_err());
    }

    #[test]
    fn cross_table_aad_prevents_swap() {
        let store = fresh_store();
        store.insert("table_a", b"k", b"data_a").unwrap();

        // Copy table_a's raw ciphertext into table_b behind the wrapper's back.
        let stolen = store.inner.get("table_a", b"k").unwrap().unwrap();
        store.inner.insert("table_b", b"k", &stolen).unwrap();

        // The AAD is now "table_b", so authentication must fail.
        let outcome = store.get("table_b", b"k");
        assert!(outcome.is_err(), "Cross-table AAD should prevent decryption");
    }

    #[test]
    fn insert_batch_encrypted() {
        let store = fresh_store();
        let batch = vec![
            (b"k1".to_vec(), b"v1".to_vec()),
            (b"k2".to_vec(), b"v2".to_vec()),
            (b"k3".to_vec(), b"v3".to_vec()),
        ];
        store.insert_batch("t", batch).unwrap();

        let expected: [(&[u8], &[u8]); 3] = [(b"k1", b"v1"), (b"k2", b"v2"), (b"k3", b"v3")];
        for &(key, value) in &expected {
            assert_eq!(store.get("t", key).unwrap(), Some(value.to_vec()));
        }
    }

    #[test]
    fn rekey_preserves_data() {
        let old_config = EncryptionConfig::from_password("old-password");
        let new_config = EncryptionConfig::from_password("new-password")
            .with_algorithm(EncryptionAlgorithm::ChaCha20Poly1305);

        let seeded: [(&str, &[u8], &[u8]); 3] = [
            ("users", b"alice", b"Alice Data"),
            ("users", b"bob", b"Bob Data"),
            ("orders", b"order1", b"Order Data"),
        ];

        let mut store = EncryptedWosBackend::open_temporary(old_config).unwrap();
        for &(table, key, value) in &seeded {
            store.insert(table, key, value).unwrap();
        }

        assert_eq!(store.rekey(new_config).unwrap(), 3);

        // Every seeded entry must still be readable under the new config.
        for &(table, key, value) in &seeded {
            assert_eq!(store.get(table, key).unwrap(), Some(value.to_vec()));
        }
    }

    #[test]
    fn flush_persists() {
        let store = fresh_store();
        store.insert("t", b"key", b"val").unwrap();
        store.flush().unwrap();

        let after_flush = store.get("t", b"key").unwrap();
        assert_eq!(after_flush, Some(b"val".to_vec()));
    }

    #[test]
    fn multiple_tables_isolation() {
        let store = fresh_store();
        store.insert("t1", b"k", b"v1").unwrap();
        store.insert("t2", b"k", b"v2").unwrap();

        let from_t1 = store.get("t1", b"k").unwrap();
        let from_t2 = store.get("t2", b"k").unwrap();
        assert_eq!(from_t1, Some(b"v1".to_vec()));
        assert_eq!(from_t2, Some(b"v2".to_vec()));
    }

    #[test]
    fn large_value_round_trip() {
        let store = fresh_store();
        // 100_000 bytes cycling 0, 1, ..., 255, 0, 1, ...
        let payload: Vec<u8> = (0u8..=255).cycle().take(100_000).collect();
        store.insert("t", b"big", &payload).unwrap();

        let fetched = store.get("t", b"big").unwrap().unwrap();
        assert_eq!(fetched, payload);
    }
}