1use super::backend::{BatchOperations, StorageBackend};
5use crate::error::{Error, Result};
6use ring::aead::{AES_256_GCM, Aad, LessSafeKey, Nonce, UnboundKey};
7use ring::pbkdf2;
8use ring::rand::{SecureRandom, SystemRandom};
9use std::num::NonZeroU32;
10use std::sync::Arc;
11
// --- AES-256-GCM / PBKDF2 parameters -------------------------------------------------

/// Length in bytes of the random nonce prepended to every stored ciphertext.
const NONCE_LEN: usize = 12;
/// Length in bytes of the AES-256 key produced by the PBKDF2 derivation.
const KEY_LEN: usize = 32;
/// Length in bytes of the per-store PBKDF2 salt.
const SALT_LEN: usize = 32;
/// Length in bytes of the AES-GCM authentication tag appended to each ciphertext.
const TAG_LEN: usize = 16;
/// PBKDF2-HMAC-SHA256 iteration count used to stretch the passphrase.
const PBKDF2_ITERATIONS: u32 = 600_000;

// --- Reserved storage keys -----------------------------------------------------------

/// Storage key under which the raw (unencrypted) salt is persisted.
const SALT_KEY: &[u8] = b"_crypto/salt";
/// Storage key holding an encrypted known value used to verify the passphrase.
const CHECK_KEY: &[u8] = b"_crypto/check";
/// Known plaintext sealed under [`CHECK_KEY`].
const CHECK_PLAINTEXT: &[u8] = b"mqdb";
21
22pub struct EncryptedBackend {
23 inner: Arc<dyn StorageBackend>,
24 key: Arc<LessSafeKey>,
25}
26
27impl EncryptedBackend {
28 pub fn open(inner: Arc<dyn StorageBackend>, passphrase: &str) -> Result<Self> {
31 let existing_salt = inner.get(SALT_KEY)?;
32
33 let salt = if let Some(s) = existing_salt {
34 if s.len() != SALT_LEN {
35 return Err(Error::Internal("corrupt encryption salt".into()));
36 }
37 let mut arr = [0u8; SALT_LEN];
38 arr.copy_from_slice(&s);
39 arr
40 } else {
41 let new_salt = generate_salt()?;
42 inner.insert(SALT_KEY, &new_salt)?;
43 new_salt
44 };
45
46 let key = Arc::new(derive_key(passphrase, &salt)?);
47
48 let existing_check = inner.get(CHECK_KEY)?;
49 if let Some(encrypted_check) = existing_check {
50 decrypt(&key, CHECK_KEY, &encrypted_check)
51 .map_err(|_| Error::Internal("invalid passphrase".into()))?;
52 } else {
53 let encrypted = encrypt(&key, CHECK_KEY, CHECK_PLAINTEXT)?;
54 inner.insert(CHECK_KEY, &encrypted)?;
55 }
56
57 Ok(Self { inner, key })
58 }
59}
60
61impl StorageBackend for EncryptedBackend {
62 fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
63 match self.inner.get(key)? {
64 Some(encrypted) => Ok(Some(decrypt(&self.key, key, &encrypted)?)),
65 None => Ok(None),
66 }
67 }
68
69 fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
70 let encrypted = encrypt(&self.key, key, value)?;
71 self.inner.insert(key, &encrypted)
72 }
73
74 fn remove(&self, key: &[u8]) -> Result<()> {
75 self.inner.remove(key)
76 }
77
78 fn prefix_scan(&self, prefix: &[u8]) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
79 let raw = self.inner.prefix_scan(prefix)?;
80 decrypt_pairs(&self.key, raw)
81 }
82
83 fn prefix_count(&self, prefix: &[u8]) -> Result<usize> {
84 self.inner.prefix_count(prefix)
85 }
86
87 fn prefix_scan_keys(&self, prefix: &[u8]) -> Result<Vec<Vec<u8>>> {
88 self.inner.prefix_scan_keys(prefix)
89 }
90
91 fn prefix_scan_batch(
92 &self,
93 prefix: &[u8],
94 batch_size: usize,
95 after_key: Option<&[u8]>,
96 ) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
97 let raw = self
98 .inner
99 .prefix_scan_batch(prefix, batch_size, after_key)?;
100 decrypt_pairs(&self.key, raw)
101 }
102
103 fn range_scan(&self, start: &[u8], end: &[u8]) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
104 let raw = self.inner.range_scan(start, end)?;
105 decrypt_pairs(&self.key, raw)
106 }
107
108 fn batch(&self) -> Box<dyn BatchOperations> {
109 Box::new(EncryptedBatch {
110 inner: self.inner.batch(),
111 backend: Arc::clone(&self.inner),
112 key: Arc::clone(&self.key),
113 pending_expects: Vec::new(),
114 })
115 }
116
117 fn flush(&self) -> Result<()> {
118 self.inner.flush()
119 }
120}
121
122struct EncryptedBatch {
123 inner: Box<dyn BatchOperations>,
124 backend: Arc<dyn StorageBackend>,
125 key: Arc<LessSafeKey>,
126 pending_expects: Vec<(Vec<u8>, Vec<u8>)>,
127}
128
129impl BatchOperations for EncryptedBatch {
130 fn insert(&mut self, key: Vec<u8>, value: Vec<u8>) {
131 match encrypt(&self.key, &key, &value) {
132 Ok(encrypted) => self.inner.insert(key, encrypted),
133 Err(_) => self.inner.insert(key, value),
134 }
135 }
136
137 fn remove(&mut self, key: Vec<u8>) {
138 self.inner.remove(key);
139 }
140
141 fn expect_value(&mut self, key: Vec<u8>, expected_value: Vec<u8>) {
142 self.pending_expects.push((key, expected_value));
143 }
144
145 fn commit(mut self: Box<Self>) -> Result<()> {
146 for (key, expected_plaintext) in &self.pending_expects {
147 let stored = self.backend.get(key)?;
148 match stored {
149 Some(encrypted) => {
150 let decrypted = decrypt(&self.key, key, &encrypted)?;
151 if decrypted != *expected_plaintext {
152 return Err(Error::Conflict(
153 "optimistic lock failed: value was modified".into(),
154 ));
155 }
156 self.inner.expect_value(key.clone(), encrypted);
157 }
158 None => {
159 return Err(Error::Conflict(
160 "optimistic lock failed: value was modified".into(),
161 ));
162 }
163 }
164 }
165 self.inner.commit()
166 }
167}
168
169fn generate_salt() -> Result<[u8; SALT_LEN]> {
170 let rng = SystemRandom::new();
171 let mut salt = [0u8; SALT_LEN];
172 rng.fill(&mut salt)
173 .map_err(|_| Error::Internal("random generation failed".into()))?;
174 Ok(salt)
175}
176
177fn derive_key(passphrase: &str, salt: &[u8]) -> Result<LessSafeKey> {
178 let iterations =
179 NonZeroU32::new(PBKDF2_ITERATIONS).expect("PBKDF2_ITERATIONS is non-zero constant");
180 let mut key_bytes = [0u8; KEY_LEN];
181 pbkdf2::derive(
182 pbkdf2::PBKDF2_HMAC_SHA256,
183 iterations,
184 salt,
185 passphrase.as_bytes(),
186 &mut key_bytes,
187 );
188 let unbound = UnboundKey::new(&AES_256_GCM, &key_bytes)
189 .map_err(|_| Error::Internal("key construction failed".into()))?;
190 Ok(LessSafeKey::new(unbound))
191}
192
193fn encrypt(key: &LessSafeKey, storage_key: &[u8], plaintext: &[u8]) -> Result<Vec<u8>> {
194 let rng = SystemRandom::new();
195 let mut nonce_bytes = [0u8; NONCE_LEN];
196 rng.fill(&mut nonce_bytes)
197 .map_err(|_| Error::Internal("nonce generation failed".into()))?;
198 let nonce = Nonce::try_assume_unique_for_key(&nonce_bytes)
199 .map_err(|_| Error::Internal("nonce construction failed".into()))?;
200 let aad = Aad::from(storage_key);
201
202 let mut in_out = Vec::with_capacity(plaintext.len() + TAG_LEN);
203 in_out.extend_from_slice(plaintext);
204 key.seal_in_place_append_tag(nonce, aad, &mut in_out)
205 .map_err(|_| Error::Internal("encryption failed".into()))?;
206
207 let mut output = Vec::with_capacity(NONCE_LEN + in_out.len());
208 output.extend_from_slice(&nonce_bytes);
209 output.extend_from_slice(&in_out);
210 Ok(output)
211}
212
213fn decrypt(key: &LessSafeKey, storage_key: &[u8], data: &[u8]) -> Result<Vec<u8>> {
214 if data.len() < NONCE_LEN + TAG_LEN + 1 {
215 return Err(Error::Internal("ciphertext too short".into()));
216 }
217 let (nonce_bytes, ciphertext) = data.split_at(NONCE_LEN);
218 let nonce = Nonce::try_assume_unique_for_key(nonce_bytes)
219 .map_err(|_| Error::Internal("nonce construction failed".into()))?;
220 let aad = Aad::from(storage_key);
221
222 let mut in_out = ciphertext.to_vec();
223 let plaintext = key
224 .open_in_place(nonce, aad, &mut in_out)
225 .map_err(|_| Error::Internal("decryption failed".into()))?;
226 Ok(plaintext.to_vec())
227}
228
229fn decrypt_pairs(
230 key: &LessSafeKey,
231 pairs: Vec<(Vec<u8>, Vec<u8>)>,
232) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
233 let mut decrypted = Vec::with_capacity(pairs.len());
234 for (k, v) in pairs {
235 let plaintext = decrypt(key, &k, &v)?;
236 decrypted.push((k, plaintext));
237 }
238 Ok(decrypted)
239}
240
#[cfg(test)]
mod tests {
    //! Behavioural tests for `EncryptedBackend`, run against an in-memory store.

    use super::*;
    use crate::storage::MemoryBackend;

    /// Opens an encrypted view over a fresh in-memory store. Also returns the raw store so
    /// tests can inspect the stored bytes.
    fn make_encrypted(passphrase: &str) -> (Arc<MemoryBackend>, EncryptedBackend) {
        let memory = Arc::new(MemoryBackend::new());
        let shared = Arc::clone(&memory) as Arc<dyn StorageBackend>;
        let backend = EncryptedBackend::open(shared, passphrase).unwrap();
        (memory, backend)
    }

    #[test]
    fn roundtrip() {
        let (_memory, backend) = make_encrypted("test-passphrase");

        backend.insert(b"users/1", b"alice").unwrap();
        let fetched = backend.get(b"users/1").unwrap();
        assert_eq!(fetched.as_deref(), Some(&b"alice"[..]));
    }

    #[test]
    fn stored_values_are_encrypted() {
        let (memory, backend) = make_encrypted("test-passphrase");

        backend.insert(b"users/1", b"alice").unwrap();
        let raw = memory
            .get(b"users/1")
            .unwrap()
            .expect("value should be stored");
        // Raw bytes must differ from the plaintext and carry nonce/tag overhead.
        assert_ne!(raw.as_slice(), &b"alice"[..]);
        assert!(raw.len() > b"alice".len());
    }

    #[test]
    fn get_missing_key() {
        let (_memory, backend) = make_encrypted("test-passphrase");
        assert!(backend.get(b"nonexistent").unwrap().is_none());
    }

    #[test]
    fn remove_key() {
        let (_memory, backend) = make_encrypted("test-passphrase");

        backend.insert(b"key", b"value").unwrap();
        backend.remove(b"key").unwrap();
        assert!(backend.get(b"key").unwrap().is_none());
    }

    #[test]
    fn prefix_scan_decrypts() {
        let (_memory, backend) = make_encrypted("test-passphrase");

        for (key, value) in [
            (&b"users/1"[..], &b"alice"[..]),
            (&b"users/2"[..], &b"bob"[..]),
            (&b"posts/1"[..], &b"hello"[..]),
        ] {
            backend.insert(key, value).unwrap();
        }

        let expected = vec![
            (b"users/1".to_vec(), b"alice".to_vec()),
            (b"users/2".to_vec(), b"bob".to_vec()),
        ];
        assert_eq!(backend.prefix_scan(b"users/").unwrap(), expected);
    }

    #[test]
    fn range_scan_decrypts() {
        let (_memory, backend) = make_encrypted("test-passphrase");

        for (key, value) in [
            (&b"a"[..], &b"1"[..]),
            (&b"b"[..], &b"2"[..]),
            (&b"c"[..], &b"3"[..]),
            (&b"d"[..], &b"4"[..]),
        ] {
            backend.insert(key, value).unwrap();
        }

        // End bound is exclusive: "d" must not appear.
        let expected = vec![
            (b"b".to_vec(), b"2".to_vec()),
            (b"c".to_vec(), b"3".to_vec()),
        ];
        assert_eq!(backend.range_scan(b"b", b"d").unwrap(), expected);
    }

    #[test]
    fn batch_insert_and_get() {
        let (_memory, backend) = make_encrypted("test-passphrase");

        let mut batch = backend.batch();
        batch.insert(b"k1".to_vec(), b"v1".to_vec());
        batch.insert(b"k2".to_vec(), b"v2".to_vec());
        batch.commit().unwrap();

        assert_eq!(backend.get(b"k1").unwrap(), Some(b"v1".to_vec()));
        assert_eq!(backend.get(b"k2").unwrap(), Some(b"v2".to_vec()));
    }

    #[test]
    fn batch_expect_value_success() {
        let (_memory, backend) = make_encrypted("test-passphrase");
        backend.insert(b"key", b"original").unwrap();

        let mut batch = backend.batch();
        batch.expect_value(b"key".to_vec(), b"original".to_vec());
        batch.insert(b"key".to_vec(), b"updated".to_vec());
        batch.commit().unwrap();

        assert_eq!(backend.get(b"key").unwrap(), Some(b"updated".to_vec()));
    }

    #[test]
    fn batch_expect_value_failure() {
        let (_memory, backend) = make_encrypted("test-passphrase");
        backend.insert(b"key", b"original").unwrap();

        let mut batch = backend.batch();
        batch.expect_value(b"key".to_vec(), b"wrong".to_vec());
        batch.insert(b"key".to_vec(), b"updated".to_vec());

        assert!(batch.commit().is_err());
        // A failed expectation must leave the stored value untouched.
        assert_eq!(backend.get(b"key").unwrap(), Some(b"original".to_vec()));
    }

    #[test]
    fn wrong_passphrase_rejected() {
        let shared: Arc<dyn StorageBackend> = Arc::new(MemoryBackend::new());
        let _first = EncryptedBackend::open(Arc::clone(&shared), "correct-passphrase").unwrap();

        assert!(EncryptedBackend::open(shared, "wrong-passphrase").is_err());
    }

    #[test]
    fn correct_passphrase_reopens() {
        let shared: Arc<dyn StorageBackend> = Arc::new(MemoryBackend::new());

        {
            let first = EncryptedBackend::open(Arc::clone(&shared), "my-passphrase").unwrap();
            first.insert(b"secret", b"data").unwrap();
        }

        let reopened = EncryptedBackend::open(shared, "my-passphrase").unwrap();
        assert_eq!(reopened.get(b"secret").unwrap(), Some(b"data".to_vec()));
    }

    #[test]
    fn batch_remove() {
        let (_memory, backend) = make_encrypted("test-passphrase");
        backend.insert(b"key", b"value").unwrap();

        let mut batch = backend.batch();
        batch.remove(b"key".to_vec());
        batch.commit().unwrap();

        assert!(backend.get(b"key").unwrap().is_none());
    }
}