// dbx_core/storage/encryption/wos.rs
use std::ops::RangeBounds;
use std::path::Path;

use crate::error::DbxResult;
use crate::storage::StorageBackend;
use crate::storage::encryption::EncryptionConfig;
use crate::storage::wos::WosBackend;
33
34pub struct EncryptedWosBackend {
40 inner: WosBackend,
41 encryption: EncryptionConfig,
42}
43
44impl EncryptedWosBackend {
45 pub fn open(path: &Path, encryption: EncryptionConfig) -> DbxResult<Self> {
47 let inner = WosBackend::open(path)?;
48 Ok(Self { inner, encryption })
49 }
50
51 pub fn open_temporary(encryption: EncryptionConfig) -> DbxResult<Self> {
53 let inner = WosBackend::open_temporary()?;
54 Ok(Self { inner, encryption })
55 }
56
57 pub fn encryption_config(&self) -> &EncryptionConfig {
59 &self.encryption
60 }
61
62 pub fn rekey(&mut self, new_encryption: EncryptionConfig) -> DbxResult<usize> {
73 let table_names = self.inner.table_names()?;
74 let mut rekey_count = 0;
75
76 for table_name in &table_names {
77 let entries: Vec<(Vec<u8>, Vec<u8>)> = self
79 .inner
80 .scan(table_name, ..)?
81 .into_iter()
82 .filter_map(|(key, encrypted_value)| {
83 let aad = table_name.as_bytes();
85 self.encryption
86 .decrypt_with_aad(&encrypted_value, aad)
87 .ok()
88 .map(|plain| (key, plain))
89 })
90 .collect();
91
92 for (key, plaintext) in &entries {
94 let aad = table_name.as_bytes();
95 let new_ciphertext = new_encryption.encrypt_with_aad(plaintext, aad)?;
96 self.inner.insert(table_name, key, &new_ciphertext)?;
97 rekey_count += 1;
98 }
99 }
100
101 self.encryption = new_encryption;
102 self.inner.flush()?;
103 Ok(rekey_count)
104 }
105}
106
107impl StorageBackend for EncryptedWosBackend {
108 fn insert(&self, table: &str, key: &[u8], value: &[u8]) -> DbxResult<()> {
109 let aad = table.as_bytes();
110 let encrypted = self.encryption.encrypt_with_aad(value, aad)?;
111 self.inner.insert(table, key, &encrypted)
112 }
113
114 fn insert_batch(&self, table: &str, rows: Vec<(Vec<u8>, Vec<u8>)>) -> DbxResult<()> {
115 let aad = table.as_bytes();
116 let encrypted_rows: Vec<(Vec<u8>, Vec<u8>)> = rows
117 .into_iter()
118 .map(|(key, value)| {
119 let encrypted = self.encryption.encrypt_with_aad(&value, aad)?;
120 Ok((key, encrypted))
121 })
122 .collect::<DbxResult<Vec<_>>>()?;
123
124 self.inner.insert_batch(table, encrypted_rows)
125 }
126
127 fn get(&self, table: &str, key: &[u8]) -> DbxResult<Option<Vec<u8>>> {
128 match self.inner.get(table, key)? {
129 Some(encrypted) => {
130 let aad = table.as_bytes();
131 let decrypted = self.encryption.decrypt_with_aad(&encrypted, aad)?;
132 Ok(Some(decrypted))
133 }
134 None => Ok(None),
135 }
136 }
137
138 fn delete(&self, table: &str, key: &[u8]) -> DbxResult<bool> {
139 self.inner.delete(table, key)
140 }
141
142 fn scan<R: RangeBounds<Vec<u8>> + Clone>(
143 &self,
144 table: &str,
145 range: R,
146 ) -> DbxResult<Vec<(Vec<u8>, Vec<u8>)>> {
147 let encrypted_entries = self.inner.scan(table, range)?;
148 let aad = table.as_bytes();
149
150 encrypted_entries
151 .into_iter()
152 .map(|(key, encrypted)| {
153 let decrypted = self.encryption.decrypt_with_aad(&encrypted, aad)?;
154 Ok((key, decrypted))
155 })
156 .collect()
157 }
158
159 fn scan_one<R: RangeBounds<Vec<u8>> + Clone>(
160 &self,
161 table: &str,
162 range: R,
163 ) -> DbxResult<Option<(Vec<u8>, Vec<u8>)>> {
164 let aad = table.as_bytes();
165 match self.inner.scan_one(table, range)? {
166 Some((key, encrypted)) => {
167 let decrypted = self.encryption.decrypt_with_aad(&encrypted, aad)?;
168 Ok(Some((key, decrypted)))
169 }
170 None => Ok(None),
171 }
172 }
173
174 fn flush(&self) -> DbxResult<()> {
175 self.inner.flush()
176 }
177
178 fn count(&self, table: &str) -> DbxResult<usize> {
179 self.inner.count(table)
180 }
181
182 fn table_names(&self) -> DbxResult<Vec<String>> {
183 self.inner.table_names()
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::encryption::EncryptionAlgorithm;

    /// Fresh temporary encrypted store keyed from a fixed test password.
    fn encrypted_wos() -> EncryptedWosBackend {
        let enc = EncryptionConfig::from_password("test-password");
        EncryptedWosBackend::open_temporary(enc).unwrap()
    }

    #[test]
    fn insert_and_get_round_trip() {
        let wos = encrypted_wos();
        wos.insert("users", b"key1", b"Alice").unwrap();
        let result = wos.get("users", b"key1").unwrap();
        assert_eq!(result, Some(b"Alice".to_vec()));
    }

    #[test]
    fn get_nonexistent_returns_none() {
        let wos = encrypted_wos();
        assert_eq!(wos.get("users", b"missing").unwrap(), None);
    }

    #[test]
    fn delete_existing() {
        let wos = encrypted_wos();
        wos.insert("users", b"key1", b"Alice").unwrap();
        assert!(wos.delete("users", b"key1").unwrap());
        assert_eq!(wos.get("users", b"key1").unwrap(), None);
    }

    #[test]
    fn upsert_overwrites() {
        let wos = encrypted_wos();
        wos.insert("t", b"k", b"v1").unwrap();
        wos.insert("t", b"k", b"v2").unwrap();
        assert_eq!(wos.get("t", b"k").unwrap(), Some(b"v2".to_vec()));
    }

    #[test]
    fn scan_all_decrypted() {
        let wos = encrypted_wos();
        wos.insert("t", b"a", b"1").unwrap();
        wos.insert("t", b"b", b"2").unwrap();
        wos.insert("t", b"c", b"3").unwrap();

        let all = wos.scan("t", ..).unwrap();
        assert_eq!(all.len(), 3);
        assert_eq!(all[0], (b"a".to_vec(), b"1".to_vec()));
        assert_eq!(all[1], (b"b".to_vec(), b"2".to_vec()));
        assert_eq!(all[2], (b"c".to_vec(), b"3".to_vec()));
    }

    #[test]
    fn scan_range_decrypts_subset() {
        let wos = encrypted_wos();
        wos.insert("t", b"a", b"1").unwrap();
        wos.insert("t", b"b", b"2").unwrap();
        wos.insert("t", b"c", b"3").unwrap();

        let tail = wos.scan("t", b"b".to_vec()..).unwrap();
        assert_eq!(
            tail,
            vec![(b"b".to_vec(), b"2".to_vec()), (b"c".to_vec(), b"3".to_vec())]
        );
    }

    #[test]
    fn scan_one_returns_decrypted_entry() {
        let wos = encrypted_wos();
        wos.insert("t", b"a", b"1").unwrap();
        wos.insert("t", b"b", b"2").unwrap();
        assert_eq!(
            wos.scan_one("t", ..).unwrap(),
            Some((b"a".to_vec(), b"1".to_vec()))
        );
    }

    #[test]
    fn scan_one_empty_table_returns_none() {
        let wos = encrypted_wos();
        wos.insert("t", b"a", b"1").unwrap();
        wos.delete("t", b"a").unwrap();
        assert_eq!(wos.scan_one("t", ..).unwrap(), None);
    }

    #[test]
    fn count_accuracy() {
        let wos = encrypted_wos();
        assert_eq!(wos.count("t").unwrap(), 0);
        wos.insert("t", b"a", b"1").unwrap();
        wos.insert("t", b"b", b"2").unwrap();
        assert_eq!(wos.count("t").unwrap(), 2);
    }

    #[test]
    fn table_names_tracks_tables() {
        let wos = encrypted_wos();
        wos.insert("users", b"a", b"1").unwrap();
        wos.insert("orders", b"b", b"2").unwrap();
        let mut names = wos.table_names().unwrap();
        names.sort();
        assert_eq!(names, vec!["orders".to_string(), "users".to_string()]);
    }

    #[test]
    fn values_are_encrypted_at_rest() {
        let wos = encrypted_wos();
        let secret: &[u8] = b"top-secret-value";
        wos.insert("t", b"k", secret).unwrap();

        // Read the raw bytes straight from the inner store, bypassing decryption.
        let raw = wos.inner.get("t", b"k").unwrap().unwrap();
        assert_ne!(raw.as_slice(), secret);
        assert!(
            !raw.windows(secret.len()).any(|w| w == secret),
            "plaintext must not appear verbatim in the stored bytes"
        );
    }

    #[test]
    fn wrong_password_cannot_decrypt() {
        let enc1 = EncryptionConfig::from_password("correct");
        let enc2 = EncryptionConfig::from_password("wrong");

        let wos = EncryptedWosBackend::open_temporary(enc1).unwrap();
        wos.insert("t", b"k", b"secret").unwrap();

        let raw = wos.inner.get("t", b"k").unwrap().unwrap();

        let result = enc2.decrypt_with_aad(&raw, b"t");
        assert!(result.is_err());
    }

    #[test]
    fn tampered_ciphertext_is_rejected() {
        let wos = encrypted_wos();
        wos.insert("t", b"k", b"payload").unwrap();

        // Flip one bit of the stored ciphertext and write it back unencrypted.
        let mut raw = wos.inner.get("t", b"k").unwrap().unwrap();
        let last = raw.last_mut().expect("stored ciphertext is never empty");
        *last ^= 0x01;
        wos.inner.insert("t", b"k", &raw).unwrap();

        assert!(
            wos.get("t", b"k").is_err(),
            "tampered ciphertext must fail authentication"
        );
    }

    #[test]
    fn cross_table_aad_prevents_swap() {
        let wos = encrypted_wos();
        wos.insert("table_a", b"k", b"data_a").unwrap();

        let raw = wos.inner.get("table_a", b"k").unwrap().unwrap();

        wos.inner.insert("table_b", b"k", &raw).unwrap();

        let result = wos.get("table_b", b"k");
        assert!(result.is_err(), "Cross-table AAD should prevent decryption");
    }

    #[test]
    fn insert_batch_encrypted() {
        let wos = encrypted_wos();
        let rows = vec![
            (b"k1".to_vec(), b"v1".to_vec()),
            (b"k2".to_vec(), b"v2".to_vec()),
            (b"k3".to_vec(), b"v3".to_vec()),
        ];
        wos.insert_batch("t", rows).unwrap();

        assert_eq!(wos.get("t", b"k1").unwrap(), Some(b"v1".to_vec()));
        assert_eq!(wos.get("t", b"k2").unwrap(), Some(b"v2".to_vec()));
        assert_eq!(wos.get("t", b"k3").unwrap(), Some(b"v3".to_vec()));
    }

    #[test]
    fn rekey_preserves_data() {
        let enc_old = EncryptionConfig::from_password("old-password");
        let enc_new = EncryptionConfig::from_password("new-password")
            .with_algorithm(EncryptionAlgorithm::ChaCha20Poly1305);

        let mut wos = EncryptedWosBackend::open_temporary(enc_old).unwrap();
        wos.insert("users", b"alice", b"Alice Data").unwrap();
        wos.insert("users", b"bob", b"Bob Data").unwrap();
        wos.insert("orders", b"order1", b"Order Data").unwrap();

        let rekeyed = wos.rekey(enc_new).unwrap();
        assert_eq!(rekeyed, 3);

        assert_eq!(
            wos.get("users", b"alice").unwrap(),
            Some(b"Alice Data".to_vec())
        );
        assert_eq!(
            wos.get("users", b"bob").unwrap(),
            Some(b"Bob Data".to_vec())
        );
        assert_eq!(
            wos.get("orders", b"order1").unwrap(),
            Some(b"Order Data".to_vec())
        );
    }

    #[test]
    fn rekey_empty_store_reports_zero() {
        let mut wos = encrypted_wos();
        let rekeyed = wos
            .rekey(EncryptionConfig::from_password("another-password"))
            .unwrap();
        assert_eq!(rekeyed, 0);

        // The store keeps working under the new configuration.
        wos.insert("t", b"k", b"v").unwrap();
        assert_eq!(wos.get("t", b"k").unwrap(), Some(b"v".to_vec()));
    }

    #[test]
    fn flush_persists() {
        let wos = encrypted_wos();
        wos.insert("t", b"key", b"val").unwrap();
        wos.flush().unwrap();
        assert_eq!(wos.get("t", b"key").unwrap(), Some(b"val".to_vec()));
    }

    #[test]
    fn multiple_tables_isolation() {
        let wos = encrypted_wos();
        wos.insert("t1", b"k", b"v1").unwrap();
        wos.insert("t2", b"k", b"v2").unwrap();
        assert_eq!(wos.get("t1", b"k").unwrap(), Some(b"v1".to_vec()));
        assert_eq!(wos.get("t2", b"k").unwrap(), Some(b"v2".to_vec()));
    }

    #[test]
    fn large_value_round_trip() {
        let wos = encrypted_wos();
        let large_value: Vec<u8> = (0..100_000).map(|i| (i % 256) as u8).collect();
        wos.insert("t", b"big", &large_value).unwrap();
        let result = wos.get("t", b"big").unwrap().unwrap();
        assert_eq!(result, large_value);
    }
}