Skip to main content

mqdb_core/storage/
encrypted_backend.rs

1// Copyright 2025-2026 LabOverWire. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::backend::{BatchOperations, StorageBackend};
5use crate::error::{Error, Result};
6use ring::aead::{AES_256_GCM, Aad, LessSafeKey, Nonce, UnboundKey};
7use ring::pbkdf2;
8use ring::rand::{SecureRandom, SystemRandom};
9use std::num::NonZeroU32;
10use std::sync::Arc;
11
/// Length in bytes of the random nonce that `encrypt` prepends to every stored value.
const NONCE_LEN: usize = 12;
/// Length in bytes of the key material derived from the passphrase for AES-256-GCM.
const KEY_LEN: usize = 32;
/// Length in bytes of the PBKDF2 salt stored under `SALT_KEY`.
const SALT_LEN: usize = 32;
/// Length in bytes of the authentication tag that sealing appends to the ciphertext.
const TAG_LEN: usize = 16;
/// Iteration count passed to PBKDF2-HMAC-SHA256 when deriving the key.
const PBKDF2_ITERATIONS: u32 = 600_000;

/// Storage key under which the salt is written (unencrypted) to the inner backend.
const SALT_KEY: &[u8] = b"_crypto/salt";
/// Storage key holding an encrypted marker used to verify the passphrase on reopen.
const CHECK_KEY: &[u8] = b"_crypto/check";
/// Plaintext that is encrypted and stored under `CHECK_KEY` when a store is first opened.
const CHECK_PLAINTEXT: &[u8] = b"mqdb";
21
22pub struct EncryptedBackend {
23    inner: Arc<dyn StorageBackend>,
24    key: Arc<LessSafeKey>,
25}
26
27impl EncryptedBackend {
28    /// # Errors
29    /// Returns an error if key derivation or passphrase verification fails.
30    pub fn open(inner: Arc<dyn StorageBackend>, passphrase: &str) -> Result<Self> {
31        let existing_salt = inner.get(SALT_KEY)?;
32
33        let salt = if let Some(s) = existing_salt {
34            if s.len() != SALT_LEN {
35                return Err(Error::Internal("corrupt encryption salt".into()));
36            }
37            let mut arr = [0u8; SALT_LEN];
38            arr.copy_from_slice(&s);
39            arr
40        } else {
41            let new_salt = generate_salt()?;
42            inner.insert(SALT_KEY, &new_salt)?;
43            new_salt
44        };
45
46        let key = Arc::new(derive_key(passphrase, &salt)?);
47
48        let existing_check = inner.get(CHECK_KEY)?;
49        if let Some(encrypted_check) = existing_check {
50            decrypt(&key, CHECK_KEY, &encrypted_check)
51                .map_err(|_| Error::Internal("invalid passphrase".into()))?;
52        } else {
53            let encrypted = encrypt(&key, CHECK_KEY, CHECK_PLAINTEXT)?;
54            inner.insert(CHECK_KEY, &encrypted)?;
55        }
56
57        Ok(Self { inner, key })
58    }
59}
60
61impl StorageBackend for EncryptedBackend {
62    fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
63        match self.inner.get(key)? {
64            Some(encrypted) => Ok(Some(decrypt(&self.key, key, &encrypted)?)),
65            None => Ok(None),
66        }
67    }
68
69    fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
70        let encrypted = encrypt(&self.key, key, value)?;
71        self.inner.insert(key, &encrypted)
72    }
73
74    fn remove(&self, key: &[u8]) -> Result<()> {
75        self.inner.remove(key)
76    }
77
78    fn prefix_scan(&self, prefix: &[u8]) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
79        let raw = self.inner.prefix_scan(prefix)?;
80        decrypt_pairs(&self.key, raw)
81    }
82
83    fn prefix_count(&self, prefix: &[u8]) -> Result<usize> {
84        self.inner.prefix_count(prefix)
85    }
86
87    fn prefix_scan_keys(&self, prefix: &[u8]) -> Result<Vec<Vec<u8>>> {
88        self.inner.prefix_scan_keys(prefix)
89    }
90
91    fn prefix_scan_batch(
92        &self,
93        prefix: &[u8],
94        batch_size: usize,
95        after_key: Option<&[u8]>,
96    ) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
97        let raw = self
98            .inner
99            .prefix_scan_batch(prefix, batch_size, after_key)?;
100        decrypt_pairs(&self.key, raw)
101    }
102
103    fn range_scan(&self, start: &[u8], end: &[u8]) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
104        let raw = self.inner.range_scan(start, end)?;
105        decrypt_pairs(&self.key, raw)
106    }
107
108    fn batch(&self) -> Box<dyn BatchOperations> {
109        Box::new(EncryptedBatch {
110            inner: self.inner.batch(),
111            backend: Arc::clone(&self.inner),
112            key: Arc::clone(&self.key),
113            pending_expects: Vec::new(),
114        })
115    }
116
117    fn flush(&self) -> Result<()> {
118        self.inner.flush()
119    }
120}
121
122struct EncryptedBatch {
123    inner: Box<dyn BatchOperations>,
124    backend: Arc<dyn StorageBackend>,
125    key: Arc<LessSafeKey>,
126    pending_expects: Vec<(Vec<u8>, Vec<u8>)>,
127}
128
129impl BatchOperations for EncryptedBatch {
130    fn insert(&mut self, key: Vec<u8>, value: Vec<u8>) {
131        match encrypt(&self.key, &key, &value) {
132            Ok(encrypted) => self.inner.insert(key, encrypted),
133            Err(_) => self.inner.insert(key, value),
134        }
135    }
136
137    fn remove(&mut self, key: Vec<u8>) {
138        self.inner.remove(key);
139    }
140
141    fn expect_value(&mut self, key: Vec<u8>, expected_value: Vec<u8>) {
142        self.pending_expects.push((key, expected_value));
143    }
144
145    fn commit(mut self: Box<Self>) -> Result<()> {
146        for (key, expected_plaintext) in &self.pending_expects {
147            let stored = self.backend.get(key)?;
148            match stored {
149                Some(encrypted) => {
150                    let decrypted = decrypt(&self.key, key, &encrypted)?;
151                    if decrypted != *expected_plaintext {
152                        return Err(Error::Conflict(
153                            "optimistic lock failed: value was modified".into(),
154                        ));
155                    }
156                    self.inner.expect_value(key.clone(), encrypted);
157                }
158                None => {
159                    return Err(Error::Conflict(
160                        "optimistic lock failed: value was modified".into(),
161                    ));
162                }
163            }
164        }
165        self.inner.commit()
166    }
167}
168
169fn generate_salt() -> Result<[u8; SALT_LEN]> {
170    let rng = SystemRandom::new();
171    let mut salt = [0u8; SALT_LEN];
172    rng.fill(&mut salt)
173        .map_err(|_| Error::Internal("random generation failed".into()))?;
174    Ok(salt)
175}
176
177fn derive_key(passphrase: &str, salt: &[u8]) -> Result<LessSafeKey> {
178    let iterations =
179        NonZeroU32::new(PBKDF2_ITERATIONS).expect("PBKDF2_ITERATIONS is non-zero constant");
180    let mut key_bytes = [0u8; KEY_LEN];
181    pbkdf2::derive(
182        pbkdf2::PBKDF2_HMAC_SHA256,
183        iterations,
184        salt,
185        passphrase.as_bytes(),
186        &mut key_bytes,
187    );
188    let unbound = UnboundKey::new(&AES_256_GCM, &key_bytes)
189        .map_err(|_| Error::Internal("key construction failed".into()))?;
190    Ok(LessSafeKey::new(unbound))
191}
192
193fn encrypt(key: &LessSafeKey, storage_key: &[u8], plaintext: &[u8]) -> Result<Vec<u8>> {
194    let rng = SystemRandom::new();
195    let mut nonce_bytes = [0u8; NONCE_LEN];
196    rng.fill(&mut nonce_bytes)
197        .map_err(|_| Error::Internal("nonce generation failed".into()))?;
198    let nonce = Nonce::try_assume_unique_for_key(&nonce_bytes)
199        .map_err(|_| Error::Internal("nonce construction failed".into()))?;
200    let aad = Aad::from(storage_key);
201
202    let mut in_out = Vec::with_capacity(plaintext.len() + TAG_LEN);
203    in_out.extend_from_slice(plaintext);
204    key.seal_in_place_append_tag(nonce, aad, &mut in_out)
205        .map_err(|_| Error::Internal("encryption failed".into()))?;
206
207    let mut output = Vec::with_capacity(NONCE_LEN + in_out.len());
208    output.extend_from_slice(&nonce_bytes);
209    output.extend_from_slice(&in_out);
210    Ok(output)
211}
212
213fn decrypt(key: &LessSafeKey, storage_key: &[u8], data: &[u8]) -> Result<Vec<u8>> {
214    if data.len() < NONCE_LEN + TAG_LEN + 1 {
215        return Err(Error::Internal("ciphertext too short".into()));
216    }
217    let (nonce_bytes, ciphertext) = data.split_at(NONCE_LEN);
218    let nonce = Nonce::try_assume_unique_for_key(nonce_bytes)
219        .map_err(|_| Error::Internal("nonce construction failed".into()))?;
220    let aad = Aad::from(storage_key);
221
222    let mut in_out = ciphertext.to_vec();
223    let plaintext = key
224        .open_in_place(nonce, aad, &mut in_out)
225        .map_err(|_| Error::Internal("decryption failed".into()))?;
226    Ok(plaintext.to_vec())
227}
228
229fn decrypt_pairs(
230    key: &LessSafeKey,
231    pairs: Vec<(Vec<u8>, Vec<u8>)>,
232) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
233    let mut decrypted = Vec::with_capacity(pairs.len());
234    for (k, v) in pairs {
235        let plaintext = decrypt(key, &k, &v)?;
236        decrypted.push((k, plaintext));
237    }
238    Ok(decrypted)
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemoryBackend;

    /// Passphrase used by tests that do not care about its value.
    const PASSPHRASE: &str = "test-passphrase";

    /// Opens an `EncryptedBackend` over a fresh in-memory store and returns both, so tests can
    /// inspect the raw stored bytes.
    fn make_encrypted(passphrase: &str) -> (Arc<MemoryBackend>, EncryptedBackend) {
        let memory = Arc::new(MemoryBackend::new());
        let shared = Arc::clone(&memory) as Arc<dyn StorageBackend>;
        let encrypted = EncryptedBackend::open(shared, passphrase).unwrap();
        (memory, encrypted)
    }

    #[test]
    fn roundtrip() {
        let (_raw, store) = make_encrypted(PASSPHRASE);

        store.insert(b"users/1", b"alice").unwrap();
        let fetched = store.get(b"users/1").unwrap().unwrap();
        assert_eq!(fetched, b"alice");
    }

    #[test]
    fn stored_values_are_encrypted() {
        let (raw, store) = make_encrypted(PASSPHRASE);

        store.insert(b"users/1", b"alice").unwrap();
        let on_disk = raw.get(b"users/1").unwrap().unwrap();
        assert_ne!(on_disk, b"alice");
        assert!(on_disk.len() > b"alice".len());
    }

    #[test]
    fn get_missing_key() {
        let (_raw, store) = make_encrypted(PASSPHRASE);
        assert_eq!(store.get(b"nonexistent").unwrap(), None);
    }

    #[test]
    fn remove_key() {
        let (_raw, store) = make_encrypted(PASSPHRASE);

        store.insert(b"key", b"value").unwrap();
        store.remove(b"key").unwrap();
        assert_eq!(store.get(b"key").unwrap(), None);
    }

    #[test]
    fn prefix_scan_decrypts() {
        let (_raw, store) = make_encrypted(PASSPHRASE);

        store.insert(b"users/1", b"alice").unwrap();
        store.insert(b"users/2", b"bob").unwrap();
        store.insert(b"posts/1", b"hello").unwrap();

        let users = store.prefix_scan(b"users/").unwrap();
        assert_eq!(users.len(), 2);
        assert_eq!(users[0], (b"users/1".to_vec(), b"alice".to_vec()));
        assert_eq!(users[1], (b"users/2".to_vec(), b"bob".to_vec()));
    }

    #[test]
    fn range_scan_decrypts() {
        let (_raw, store) = make_encrypted(PASSPHRASE);

        store.insert(b"a", b"1").unwrap();
        store.insert(b"b", b"2").unwrap();
        store.insert(b"c", b"3").unwrap();
        store.insert(b"d", b"4").unwrap();

        let window = store.range_scan(b"b", b"d").unwrap();
        assert_eq!(window.len(), 2);
        assert_eq!(window[0], (b"b".to_vec(), b"2".to_vec()));
        assert_eq!(window[1], (b"c".to_vec(), b"3".to_vec()));
    }

    #[test]
    fn batch_insert_and_get() {
        let (_raw, store) = make_encrypted(PASSPHRASE);

        let mut pending = store.batch();
        pending.insert(b"k1".to_vec(), b"v1".to_vec());
        pending.insert(b"k2".to_vec(), b"v2".to_vec());
        pending.commit().unwrap();

        assert_eq!(store.get(b"k1").unwrap(), Some(b"v1".to_vec()));
        assert_eq!(store.get(b"k2").unwrap(), Some(b"v2".to_vec()));
    }

    #[test]
    fn batch_expect_value_success() {
        let (_raw, store) = make_encrypted(PASSPHRASE);

        store.insert(b"key", b"original").unwrap();

        let mut pending = store.batch();
        pending.expect_value(b"key".to_vec(), b"original".to_vec());
        pending.insert(b"key".to_vec(), b"updated".to_vec());
        pending.commit().unwrap();

        assert_eq!(store.get(b"key").unwrap(), Some(b"updated".to_vec()));
    }

    #[test]
    fn batch_expect_value_failure() {
        let (_raw, store) = make_encrypted(PASSPHRASE);

        store.insert(b"key", b"original").unwrap();

        let mut pending = store.batch();
        pending.expect_value(b"key".to_vec(), b"wrong".to_vec());
        pending.insert(b"key".to_vec(), b"updated".to_vec());

        // The mismatched expectation must abort the batch and leave the value untouched.
        assert!(pending.commit().is_err());
        assert_eq!(store.get(b"key").unwrap(), Some(b"original".to_vec()));
    }

    #[test]
    fn wrong_passphrase_rejected() {
        let memory = Arc::new(MemoryBackend::new());
        let _first = EncryptedBackend::open(
            Arc::clone(&memory) as Arc<dyn StorageBackend>,
            "correct-passphrase",
        )
        .unwrap();

        let reopened =
            EncryptedBackend::open(memory as Arc<dyn StorageBackend>, "wrong-passphrase");
        assert!(reopened.is_err());
    }

    #[test]
    fn correct_passphrase_reopens() {
        let memory = Arc::new(MemoryBackend::new());

        {
            let first = EncryptedBackend::open(
                Arc::clone(&memory) as Arc<dyn StorageBackend>,
                "my-passphrase",
            )
            .unwrap();
            first.insert(b"secret", b"data").unwrap();
        }

        let reopened =
            EncryptedBackend::open(memory as Arc<dyn StorageBackend>, "my-passphrase").unwrap();
        assert_eq!(reopened.get(b"secret").unwrap(), Some(b"data".to_vec()));
    }

    #[test]
    fn batch_remove() {
        let (_raw, store) = make_encrypted(PASSPHRASE);

        store.insert(b"key", b"value").unwrap();

        let mut pending = store.batch();
        pending.remove(b"key".to_vec());
        pending.commit().unwrap();

        assert_eq!(store.get(b"key").unwrap(), None);
    }
}