// mdk_sqlite_storage/encryption.rs
use std::fmt;
8use std::fs::File;
9use std::io::{ErrorKind, Read};
10use std::path::Path;
11
12use mdk_storage_traits::Secret;
13use rusqlite::Connection;
14
15use crate::error::Error;
16
17#[derive(Clone)]
30pub struct EncryptionConfig {
31 key: Secret<[u8; 32]>,
33}
34
35impl EncryptionConfig {
36 #[must_use]
51 pub fn new(key: [u8; 32]) -> Self {
52 Self {
53 key: Secret::new(key),
54 }
55 }
56
57 pub fn from_slice(key: &[u8]) -> Result<Self, Error> {
67 let key: [u8; 32] = key
68 .try_into()
69 .map_err(|_| Error::InvalidKeyLength(key.len()))?;
70 Ok(Self {
71 key: Secret::new(key),
72 })
73 }
74
75 pub fn generate() -> Result<Self, Error> {
84 let mut key = [0u8; 32];
85 getrandom::fill(&mut key).map_err(|e| Error::KeyGeneration(e.to_string()))?;
86 Ok(Self {
87 key: Secret::new(key),
88 })
89 }
90
91 #[must_use]
93 pub fn key(&self) -> &[u8; 32] {
94 &self.key
95 }
96
97 fn to_sqlcipher_key(&self) -> String {
101 format!("x'{}'", hex::encode(self.key.as_ref()))
102 }
103}
104
105impl fmt::Debug for EncryptionConfig {
107 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108 f.debug_struct("EncryptionConfig")
109 .field("key", &"[REDACTED]")
110 .finish()
111 }
112}
113
114pub fn apply_encryption(conn: &Connection, config: &EncryptionConfig) -> Result<(), Error> {
134 let key = config.to_sqlcipher_key();
135
136 conn.execute_batch(&format!("PRAGMA key = \"{key}\";"))?;
138
139 conn.execute_batch("PRAGMA cipher_compatibility = 4;")?;
141
142 conn.execute_batch("PRAGMA temp_store = MEMORY;")?;
144
145 validate_encryption_key(conn)?;
148
149 Ok(())
150}
151
152fn validate_encryption_key(conn: &Connection) -> Result<(), Error> {
157 match conn.query_row("SELECT count(*) FROM sqlite_master;", [], |row| {
158 row.get::<_, i64>(0)
159 }) {
160 Ok(_) => Ok(()),
161 Err(rusqlite::Error::SqliteFailure(err, _))
162 if err.code == rusqlite::ffi::ErrorCode::NotADatabase =>
163 {
164 Err(Error::WrongEncryptionKey)
166 }
167 Err(e) => Err(e.into()),
168 }
169}
170
171pub fn is_database_encrypted<P>(path: P) -> Result<bool, Error>
188where
189 P: AsRef<Path>,
190{
191 let path = path.as_ref();
192
193 if !path.exists() {
194 return Ok(false);
196 }
197
198 let mut file = File::open(path)?;
199 let mut header = [0u8; 16];
200
201 match file.read_exact(&mut header) {
202 Ok(()) => {
203 const SQLITE_HEADER: &[u8; 16] = b"SQLite format 3\0";
205 Ok(header != *SQLITE_HEADER)
206 }
207 Err(e) if e.kind() == ErrorKind::UnexpectedEof => {
208 Ok(false)
210 }
211 Err(e) => Err(e.into()),
212 }
213}
214
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    #[test]
    fn test_encryption_config_new() {
        let key = [0x42u8; 32];
        let config = EncryptionConfig::new(key);
        assert_eq!(config.key(), &key);
    }

    #[test]
    fn test_encryption_config_from_slice() {
        let key = vec![0x42u8; 32];
        let config = EncryptionConfig::from_slice(&key).unwrap();
        assert_eq!(config.key(), key.as_slice());
    }

    #[test]
    fn test_encryption_config_from_slice_invalid_length() {
        let short_key = vec![0x42u8; 16];
        let result = EncryptionConfig::from_slice(&short_key);
        assert!(matches!(result, Err(Error::InvalidKeyLength(16))));

        let long_key = vec![0x42u8; 64];
        let result = EncryptionConfig::from_slice(&long_key);
        assert!(matches!(result, Err(Error::InvalidKeyLength(64))));
    }

    #[test]
    fn test_encryption_config_from_slice_empty() {
        let empty_key: Vec<u8> = vec![];
        let result = EncryptionConfig::from_slice(&empty_key);
        assert!(matches!(result, Err(Error::InvalidKeyLength(0))));
    }

    #[test]
    fn test_encryption_config_debug_redacts_key() {
        let key = [0x42u8; 32];
        let config = EncryptionConfig::new(key);
        let debug_str = format!("{:?}", config);
        assert!(debug_str.contains("REDACTED"));
        // A leaked key would show up as hex "42" or, via the array's `Debug`, decimal "66".
        assert!(!debug_str.contains("42"));
        assert!(!debug_str.contains("66"));
    }

    #[test]
    fn test_encryption_config_generate() {
        let config1 = EncryptionConfig::generate().unwrap();
        let config2 = EncryptionConfig::generate().unwrap();

        assert_ne!(config1.key(), config2.key());
        assert_eq!(config1.key().len(), 32);
    }

    #[test]
    fn test_encryption_config_clone() {
        let key = [0x42u8; 32];
        let config1 = EncryptionConfig::new(key);
        let config2 = config1.clone();
        assert_eq!(config1.key(), config2.key());
    }

    #[test]
    fn test_to_sqlcipher_key_format() {
        let key = [0x00u8; 32];
        let config = EncryptionConfig::new(key);
        let sqlcipher_key = config.to_sqlcipher_key();

        assert!(sqlcipher_key.starts_with("x'"));
        assert!(sqlcipher_key.ends_with('\''));
        // "x'" + 64 hex digits + "'"
        assert_eq!(sqlcipher_key.len(), 2 + 64 + 1);
        assert_eq!(sqlcipher_key, format!("x'{}'", "0".repeat(64)));
    }

    #[test]
    fn test_to_sqlcipher_key_format_nonzero() {
        let mut key = [0u8; 32];
        key[0] = 0xAB;
        key[31] = 0xCD;
        let config = EncryptionConfig::new(key);
        let sqlcipher_key = config.to_sqlcipher_key();

        assert!(sqlcipher_key.starts_with("x'ab"));
        assert!(sqlcipher_key.ends_with("cd'"));
    }

    #[test]
    fn test_is_database_encrypted_nonexistent() {
        let result = is_database_encrypted("/nonexistent/path/db.sqlite");
        assert!(matches!(result, Ok(false)));
    }

    #[test]
    fn test_is_database_encrypted_empty_file() {
        let temp_dir = tempfile::tempdir().unwrap();
        let db_path = temp_dir.path().join("empty.db");

        std::fs::File::create(&db_path).unwrap();

        let result = is_database_encrypted(&db_path);
        assert!(matches!(result, Ok(false)));
    }

    #[test]
    fn test_is_database_encrypted_small_file() {
        let temp_dir = tempfile::tempdir().unwrap();
        let db_path = temp_dir.path().join("small.db");

        std::fs::write(&db_path, b"too small").unwrap();

        let result = is_database_encrypted(&db_path);
        assert!(matches!(result, Ok(false)));
    }

    #[test]
    fn test_is_database_encrypted_unencrypted_sqlite() {
        let temp_dir = tempfile::tempdir().unwrap();
        let db_path = temp_dir.path().join("plain.db");

        let conn = Connection::open(&db_path).unwrap();
        conn.execute_batch("CREATE TABLE test (id INTEGER);").unwrap();
        drop(conn);

        let result = is_database_encrypted(&db_path);
        assert!(matches!(result, Ok(false)));
    }

    #[test]
    fn test_is_database_encrypted_encrypted_sqlite() {
        let temp_dir = tempfile::tempdir().unwrap();
        let db_path = temp_dir.path().join("encrypted.db");

        let config = EncryptionConfig::generate().unwrap();
        let conn = Connection::open(&db_path).unwrap();
        apply_encryption(&conn, &config).unwrap();
        conn.execute_batch("CREATE TABLE test (id INTEGER);").unwrap();
        drop(conn);

        let result = is_database_encrypted(&db_path);
        assert!(matches!(result, Ok(true)));
    }

    #[test]
    fn test_apply_encryption_new_database() {
        let temp_dir = tempfile::tempdir().unwrap();
        let db_path = temp_dir.path().join("new_encrypted.db");

        let config = EncryptionConfig::generate().unwrap();
        let conn = Connection::open(&db_path).unwrap();

        let result = apply_encryption(&conn, &config);
        assert!(result.is_ok());

        conn.execute_batch("CREATE TABLE test (id INTEGER);").unwrap();
        conn.execute("INSERT INTO test VALUES (42)", []).unwrap();

        let count: i64 = conn
            .query_row("SELECT COUNT(*) FROM test", [], |row| row.get(0))
            .unwrap();
        assert_eq!(count, 1);
    }

    #[test]
    fn test_apply_encryption_reopen_correct_key() {
        let temp_dir = tempfile::tempdir().unwrap();
        let db_path = temp_dir.path().join("reopen.db");

        let config = EncryptionConfig::generate().unwrap();
        let key = *config.key();

        {
            let conn = Connection::open(&db_path).unwrap();
            apply_encryption(&conn, &config).unwrap();
            conn.execute_batch("CREATE TABLE test (id INTEGER);").unwrap();
            conn.execute("INSERT INTO test VALUES (123)", []).unwrap();
        }

        let config2 = EncryptionConfig::new(key);
        let conn2 = Connection::open(&db_path).unwrap();
        let result = apply_encryption(&conn2, &config2);
        assert!(result.is_ok());

        let value: i64 = conn2
            .query_row("SELECT id FROM test", [], |row| row.get(0))
            .unwrap();
        assert_eq!(value, 123);
    }

    #[test]
    fn test_apply_encryption_wrong_key() {
        let temp_dir = tempfile::tempdir().unwrap();
        let db_path = temp_dir.path().join("wrong_key.db");

        let config1 = EncryptionConfig::generate().unwrap();
        {
            let conn = Connection::open(&db_path).unwrap();
            apply_encryption(&conn, &config1).unwrap();
            conn.execute_batch("CREATE TABLE test (id INTEGER);").unwrap();
        }

        let config2 = EncryptionConfig::generate().unwrap();
        let conn2 = Connection::open(&db_path).unwrap();
        let result = apply_encryption(&conn2, &config2);

        assert!(matches!(result, Err(Error::WrongEncryptionKey)));
    }

    #[test]
    fn test_apply_encryption_on_plain_database_fails() {
        let temp_dir = tempfile::tempdir().unwrap();
        let db_path = temp_dir.path().join("plain_then_encrypt.db");

        {
            let conn = Connection::open(&db_path).unwrap();
            conn.execute_batch("CREATE TABLE test (id INTEGER);").unwrap();
            conn.execute("INSERT INTO test VALUES (1)", []).unwrap();
        }

        let config = EncryptionConfig::generate().unwrap();
        let conn2 = Connection::open(&db_path).unwrap();
        let result = apply_encryption(&conn2, &config);

        assert!(result.is_err());
    }

    #[test]
    fn test_validate_encryption_key_success() {
        let temp_dir = tempfile::tempdir().unwrap();
        let db_path = temp_dir.path().join("validate.db");

        let config = EncryptionConfig::generate().unwrap();
        let conn = Connection::open(&db_path).unwrap();
        apply_encryption(&conn, &config).unwrap();

        conn.execute_batch("CREATE TABLE test (id INTEGER);").unwrap();

        let result = validate_encryption_key(&conn);
        assert!(result.is_ok());
    }

    #[test]
    fn test_encryption_persists_across_connections() {
        let temp_dir = tempfile::tempdir().unwrap();
        let db_path = temp_dir.path().join("persist.db");

        let config = EncryptionConfig::generate().unwrap();
        let key = *config.key();

        {
            let conn = Connection::open(&db_path).unwrap();
            apply_encryption(&conn, &config).unwrap();
            conn.execute_batch("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT);")
                .unwrap();
            conn.execute("INSERT INTO users (name) VALUES ('Alice')", [])
                .unwrap();
        }

        {
            let config2 = EncryptionConfig::new(key);
            let conn = Connection::open(&db_path).unwrap();
            apply_encryption(&conn, &config2).unwrap();
            conn.execute("INSERT INTO users (name) VALUES ('Bob')", [])
                .unwrap();
        }

        let config3 = EncryptionConfig::new(key);
        let conn = Connection::open(&db_path).unwrap();
        apply_encryption(&conn, &config3).unwrap();

        let count: i64 = conn
            .query_row("SELECT COUNT(*) FROM users", [], |row| row.get(0))
            .unwrap();
        assert_eq!(count, 2);

        let names: Vec<String> = conn
            .prepare("SELECT name FROM users ORDER BY id")
            .unwrap()
            .query_map([], |row| row.get(0))
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(names, vec!["Alice", "Bob"]);
    }

    #[test]
    fn test_encrypted_database_binary_data() {
        let temp_dir = tempfile::tempdir().unwrap();
        let db_path = temp_dir.path().join("binary.db");

        let config = EncryptionConfig::generate().unwrap();
        let key = *config.key();

        let binary_data: Vec<u8> = (0..=255).collect();

        {
            let conn = Connection::open(&db_path).unwrap();
            apply_encryption(&conn, &config).unwrap();
            conn.execute_batch("CREATE TABLE blobs (data BLOB);").unwrap();
            conn.execute("INSERT INTO blobs VALUES (?)", [&binary_data])
                .unwrap();
        }

        let config2 = EncryptionConfig::new(key);
        let conn = Connection::open(&db_path).unwrap();
        apply_encryption(&conn, &config2).unwrap();

        let retrieved: Vec<u8> = conn
            .query_row("SELECT data FROM blobs", [], |row| row.get(0))
            .unwrap();
        assert_eq!(retrieved, binary_data);
    }

    #[test]
    fn test_apply_encryption_on_corrupted_database() {
        let temp_dir = tempfile::tempdir().unwrap();
        let db_path = temp_dir.path().join("corrupted.db");

        std::fs::write(&db_path, b"corrupted database content").unwrap();

        let config = EncryptionConfig::generate().unwrap();
        let conn = Connection::open(&db_path).unwrap();
        let result = apply_encryption(&conn, &config);

        assert!(result.is_err());
    }

    #[test]
    fn test_is_database_encrypted_with_partial_write() {
        let temp_dir = tempfile::tempdir().unwrap();
        let db_path = temp_dir.path().join("partial.db");

        std::fs::write(&db_path, b"partial").unwrap();

        let result = is_database_encrypted(&db_path).unwrap();
        assert!(!result);
    }

    #[test]
    fn test_encryption_config_generate_produces_unique_keys() {
        const COUNT: usize = 100;

        // Collecting into a set deduplicates in O(n); any collision shrinks the set.
        let keys: HashSet<[u8; 32]> = (0..COUNT)
            .map(|_| *EncryptionConfig::generate().unwrap().key())
            .collect();

        assert_eq!(
            keys.len(),
            COUNT,
            "Generated keys should be unique (with overwhelming probability)"
        );
    }
}