1#![forbid(unsafe_code)]
65#![warn(missing_docs)]
66#![warn(rustdoc::bare_urls)]
67
68use std::path::Path;
69use std::sync::Arc;
70
71use mdk_storage_traits::{Backend, GroupId, MdkStorageError, MdkStorageProvider};
72use openmls_traits::storage::{StorageProvider, traits};
73use rusqlite::Connection;
74use std::sync::Mutex;
75
76mod db;
77pub mod encryption;
78pub mod error;
79mod groups;
80pub mod keyring;
81mod messages;
82mod migrations;
83mod mls_storage;
84mod permissions;
85#[cfg(test)]
86mod test_utils;
87mod validation;
88mod welcomes;
89
90pub use self::encryption::EncryptionConfig;
91use self::error::Error;
92use self::mls_storage::{GroupDataType, STORAGE_PROVIDER_VERSION};
93pub use self::permissions::verify_permissions;
94use self::permissions::{
95 FileCreationOutcome, precreate_secure_database_file, set_secure_file_permissions,
96};
97
/// SQLite-backed implementation of the MDK storage traits.
///
/// All database access in this type goes through a single [`Connection`]
/// wrapped in a [`Mutex`] and shared via [`Arc`]; every operation locks that
/// mutex before touching the database.
pub struct MdkSqliteStorage {
    /// The one SQLite connection used for all reads and writes; lock before use.
    connection: Arc<Mutex<Connection>>,
}
132
133impl MdkSqliteStorage {
134 pub fn new<P>(file_path: P, service_id: &str, db_key_id: &str) -> Result<Self, Error>
183 where
184 P: AsRef<Path>,
185 {
186 let file_path = file_path.as_ref();
187
188 let creation_outcome = precreate_secure_database_file(file_path)?;
192
193 let config = match creation_outcome {
194 FileCreationOutcome::Created | FileCreationOutcome::Skipped => {
195 keyring::get_or_create_db_key(service_id, db_key_id)?
198 }
199 FileCreationOutcome::AlreadyExisted => {
200 match keyring::get_db_key(service_id, db_key_id)? {
209 Some(config) => {
210 config
213 }
214 None => {
215 if !encryption::is_database_encrypted(file_path)? {
219 return Err(Error::UnencryptedDatabaseWithEncryption);
220 }
221
222 return Err(Error::KeyringEntryMissingForExistingDatabase {
224 db_path: file_path.display().to_string(),
225 service_id: service_id.to_string(),
226 db_key_id: db_key_id.to_string(),
227 });
228 }
229 }
230 }
231 };
232
233 Self::new_internal_skip_precreate(file_path, Some(config))
234 }
235
236 pub fn new_with_key<P>(file_path: P, config: EncryptionConfig) -> Result<Self, Error>
267 where
268 P: AsRef<Path>,
269 {
270 let file_path = file_path.as_ref();
271
272 if file_path.exists() && !encryption::is_database_encrypted(file_path)? {
275 return Err(Error::UnencryptedDatabaseWithEncryption);
276 }
277
278 Self::new_internal(file_path, Some(config))
279 }
280
281 pub fn new_unencrypted<P>(file_path: P) -> Result<Self, Error>
306 where
307 P: AsRef<Path>,
308 {
309 tracing::warn!(
310 "Creating unencrypted database. Sensitive MLS state will be stored in plaintext. \
311 For production use, use new() or new_with_key() instead."
312 );
313 Self::new_internal(file_path, None)
314 }
315
316 fn new_internal<P>(
320 file_path: P,
321 encryption_config: Option<EncryptionConfig>,
322 ) -> Result<Self, Error>
323 where
324 P: AsRef<Path>,
325 {
326 let file_path = file_path.as_ref();
327
328 precreate_secure_database_file(file_path)?;
330
331 Self::new_internal_skip_precreate(file_path, encryption_config)
332 }
333
334 fn new_internal_skip_precreate(
339 file_path: &Path,
340 encryption_config: Option<EncryptionConfig>,
341 ) -> Result<Self, Error> {
342 let mut connection = Self::open_connection(file_path, encryption_config.as_ref())?;
344
345 migrations::run_migrations(&mut connection)?;
347
348 Self::apply_secure_permissions(file_path)?;
350
351 Ok(Self {
352 connection: Arc::new(Mutex::new(connection)),
353 })
354 }
355
356 fn open_connection(
358 file_path: &Path,
359 encryption_config: Option<&EncryptionConfig>,
360 ) -> Result<Connection, Error> {
361 let conn = Connection::open(file_path)?;
362
363 if let Some(config) = encryption_config {
365 encryption::apply_encryption(&conn, config)?;
366 }
367
368 conn.execute_batch("PRAGMA foreign_keys = ON;")?;
370
371 Ok(conn)
372 }
373
374 fn apply_secure_permissions(db_path: &Path) -> Result<(), Error> {
400 let path_str = db_path.to_string_lossy();
402 if path_str.is_empty() || path_str == ":memory:" || path_str.starts_with(':') {
403 return Ok(());
404 }
405
406 set_secure_file_permissions(db_path)?;
408
409 let parent = db_path.parent();
413 let stem = db_path.file_name().and_then(|n| n.to_str());
414
415 if let (Some(parent), Some(stem)) = (parent, stem) {
416 for suffix in &["-wal", "-shm", "-journal"] {
417 let sidecar = parent.join(format!("{}{}", stem, suffix));
418 if sidecar.exists() {
419 set_secure_file_permissions(&sidecar)?;
420 }
421 }
422 }
423
424 Ok(())
425 }
426
427 #[cfg(test)]
435 pub fn new_in_memory() -> Result<Self, Error> {
436 let mut connection = Connection::open_in_memory()?;
438
439 connection.execute_batch("PRAGMA foreign_keys = ON;")?;
441
442 migrations::run_migrations(&mut connection)?;
444
445 Ok(Self {
446 connection: Arc::new(Mutex::new(connection)),
447 })
448 }
449
450 pub(crate) fn with_connection<F, T>(&self, f: F) -> T
454 where
455 F: FnOnce(&Connection) -> T,
456 {
457 let conn = self.connection.lock().unwrap();
458 f(&conn)
459 }
460
461 fn snapshot_group_state(&self, group_id: &GroupId, name: &str) -> Result<(), Error> {
464 let conn = self.connection.lock().unwrap();
465 let group_id_bytes = group_id.as_slice();
466 let mls_group_id_bytes = mls_storage::MlsCodec::serialize(group_id)
469 .map_err(|e| Error::Database(e.to_string()))?;
470 let now = std::time::SystemTime::now()
471 .duration_since(std::time::UNIX_EPOCH)
472 .map_err(|e| Error::Database(format!("Time error: {}", e)))?
473 .as_secs() as i64;
474
475 conn.execute("BEGIN IMMEDIATE", [])
477 .map_err(|e| Error::Database(e.to_string()))?;
478
479 let result = (|| -> Result<(), Error> {
480 let mut insert_stmt = conn
482 .prepare_cached(
483 "INSERT INTO group_state_snapshots
484 (snapshot_name, group_id, table_name, row_key, row_data, created_at)
485 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
486 )
487 .map_err(|e| Error::Database(e.to_string()))?;
488
489 Self::snapshot_openmls_group_data(
492 &conn,
493 &mut insert_stmt,
494 name,
495 group_id_bytes,
496 &mls_group_id_bytes,
497 now,
498 )?;
499 Self::snapshot_openmls_proposals(
500 &conn,
501 &mut insert_stmt,
502 name,
503 group_id_bytes,
504 &mls_group_id_bytes,
505 now,
506 )?;
507 Self::snapshot_openmls_own_leaf_nodes(
508 &conn,
509 &mut insert_stmt,
510 name,
511 group_id_bytes,
512 &mls_group_id_bytes,
513 now,
514 )?;
515 Self::snapshot_openmls_epoch_key_pairs(
516 &conn,
517 &mut insert_stmt,
518 name,
519 group_id_bytes,
520 &mls_group_id_bytes,
521 now,
522 )?;
523 Self::snapshot_groups_table(&conn, &mut insert_stmt, name, group_id_bytes, now)?;
525 Self::snapshot_group_relays(&conn, &mut insert_stmt, name, group_id_bytes, now)?;
526 Self::snapshot_group_exporter_secrets(
527 &conn,
528 &mut insert_stmt,
529 name,
530 group_id_bytes,
531 now,
532 )?;
533
534 Ok(())
535 })();
536
537 match result {
538 Ok(()) => {
539 conn.execute("COMMIT", [])
540 .map_err(|e| Error::Database(e.to_string()))?;
541 Ok(())
542 }
543 Err(e) => {
544 let _ = conn.execute("ROLLBACK", []);
545 Err(e)
546 }
547 }
548 }
549
550 fn snapshot_openmls_group_data(
552 conn: &rusqlite::Connection,
553 insert_stmt: &mut rusqlite::CachedStatement<'_>,
554 snapshot_name: &str,
555 group_id_bytes: &[u8],
556 mls_group_id_bytes: &[u8],
557 now: i64,
558 ) -> Result<(), Error> {
559 let mut stmt = conn
560 .prepare(
561 "SELECT group_id, data_type, group_data FROM openmls_group_data WHERE group_id = ?",
562 )
563 .map_err(|e| Error::Database(e.to_string()))?;
564 let mut rows = stmt
565 .query([mls_group_id_bytes])
566 .map_err(|e| Error::Database(e.to_string()))?;
567
568 while let Some(row) = rows.next().map_err(|e| Error::Database(e.to_string()))? {
569 let gid: Vec<u8> = row.get(0).map_err(|e| Error::Database(e.to_string()))?;
570 let data_type: String = row.get(1).map_err(|e| Error::Database(e.to_string()))?;
571 let data: Vec<u8> = row.get(2).map_err(|e| Error::Database(e.to_string()))?;
572 let row_key = serde_json::to_vec(&(&gid, &data_type))
573 .map_err(|e| Error::Database(e.to_string()))?;
574 insert_stmt
575 .execute(rusqlite::params![
576 snapshot_name,
577 group_id_bytes,
578 "openmls_group_data",
579 row_key,
580 data,
581 now
582 ])
583 .map_err(|e| Error::Database(e.to_string()))?;
584 }
585 Ok(())
586 }
587
588 fn snapshot_openmls_proposals(
590 conn: &rusqlite::Connection,
591 insert_stmt: &mut rusqlite::CachedStatement<'_>,
592 snapshot_name: &str,
593 group_id_bytes: &[u8],
594 mls_group_id_bytes: &[u8],
595 now: i64,
596 ) -> Result<(), Error> {
597 let mut stmt = conn
598 .prepare(
599 "SELECT group_id, proposal_ref, proposal FROM openmls_proposals WHERE group_id = ?",
600 )
601 .map_err(|e| Error::Database(e.to_string()))?;
602 let mut rows = stmt
603 .query([mls_group_id_bytes])
604 .map_err(|e| Error::Database(e.to_string()))?;
605
606 while let Some(row) = rows.next().map_err(|e| Error::Database(e.to_string()))? {
607 let gid: Vec<u8> = row.get(0).map_err(|e| Error::Database(e.to_string()))?;
608 let proposal_ref: Vec<u8> = row.get(1).map_err(|e| Error::Database(e.to_string()))?;
609 let proposal: Vec<u8> = row.get(2).map_err(|e| Error::Database(e.to_string()))?;
610 let row_key = serde_json::to_vec(&(&gid, &proposal_ref))
611 .map_err(|e| Error::Database(e.to_string()))?;
612 insert_stmt
613 .execute(rusqlite::params![
614 snapshot_name,
615 group_id_bytes,
616 "openmls_proposals",
617 row_key,
618 proposal,
619 now
620 ])
621 .map_err(|e| Error::Database(e.to_string()))?;
622 }
623 Ok(())
624 }
625
626 fn snapshot_openmls_own_leaf_nodes(
628 conn: &rusqlite::Connection,
629 insert_stmt: &mut rusqlite::CachedStatement<'_>,
630 snapshot_name: &str,
631 group_id_bytes: &[u8],
632 mls_group_id_bytes: &[u8],
633 now: i64,
634 ) -> Result<(), Error> {
635 let mut stmt = conn
636 .prepare(
637 "SELECT id, group_id, leaf_node FROM openmls_own_leaf_nodes WHERE group_id = ?",
638 )
639 .map_err(|e| Error::Database(e.to_string()))?;
640 let mut rows = stmt
641 .query([mls_group_id_bytes])
642 .map_err(|e| Error::Database(e.to_string()))?;
643
644 while let Some(row) = rows.next().map_err(|e| Error::Database(e.to_string()))? {
645 let id: i64 = row.get(0).map_err(|e| Error::Database(e.to_string()))?;
646 let gid: Vec<u8> = row.get(1).map_err(|e| Error::Database(e.to_string()))?;
647 let leaf_node: Vec<u8> = row.get(2).map_err(|e| Error::Database(e.to_string()))?;
648 let row_key = serde_json::to_vec(&id).map_err(|e| Error::Database(e.to_string()))?;
649 let row_data = serde_json::to_vec(&(&gid, &leaf_node))
650 .map_err(|e| Error::Database(e.to_string()))?;
651 insert_stmt
652 .execute(rusqlite::params![
653 snapshot_name,
654 group_id_bytes,
655 "openmls_own_leaf_nodes",
656 row_key,
657 row_data,
658 now
659 ])
660 .map_err(|e| Error::Database(e.to_string()))?;
661 }
662 Ok(())
663 }
664
665 fn snapshot_openmls_epoch_key_pairs(
667 conn: &rusqlite::Connection,
668 insert_stmt: &mut rusqlite::CachedStatement<'_>,
669 snapshot_name: &str,
670 group_id_bytes: &[u8],
671 mls_group_id_bytes: &[u8],
672 now: i64,
673 ) -> Result<(), Error> {
674 let mut stmt = conn
675 .prepare(
676 "SELECT group_id, epoch_id, leaf_index, key_pairs
677 FROM openmls_epoch_key_pairs WHERE group_id = ?",
678 )
679 .map_err(|e| Error::Database(e.to_string()))?;
680 let mut rows = stmt
681 .query([mls_group_id_bytes])
682 .map_err(|e| Error::Database(e.to_string()))?;
683
684 while let Some(row) = rows.next().map_err(|e| Error::Database(e.to_string()))? {
685 let gid: Vec<u8> = row.get(0).map_err(|e| Error::Database(e.to_string()))?;
686 let epoch_id: Vec<u8> = row.get(1).map_err(|e| Error::Database(e.to_string()))?;
687 let leaf_index: i64 = row.get(2).map_err(|e| Error::Database(e.to_string()))?;
688 let key_pairs: Vec<u8> = row.get(3).map_err(|e| Error::Database(e.to_string()))?;
689 let row_key = serde_json::to_vec(&(&gid, &epoch_id, leaf_index))
690 .map_err(|e| Error::Database(e.to_string()))?;
691 insert_stmt
692 .execute(rusqlite::params![
693 snapshot_name,
694 group_id_bytes,
695 "openmls_epoch_key_pairs",
696 row_key,
697 key_pairs,
698 now
699 ])
700 .map_err(|e| Error::Database(e.to_string()))?;
701 }
702 Ok(())
703 }
704
705 fn snapshot_groups_table(
707 conn: &rusqlite::Connection,
708 insert_stmt: &mut rusqlite::CachedStatement<'_>,
709 snapshot_name: &str,
710 group_id_bytes: &[u8],
711 now: i64,
712 ) -> Result<(), Error> {
713 let mut stmt = conn
714 .prepare(
715 "SELECT mls_group_id, nostr_group_id, name, description, admin_pubkeys,
716 last_message_id, last_message_at, last_message_processed_at, epoch, state,
717 image_hash, image_key, image_nonce, last_self_update_at
718 FROM groups WHERE mls_group_id = ?",
719 )
720 .map_err(|e| Error::Database(e.to_string()))?;
721 let mut rows = stmt
722 .query([group_id_bytes])
723 .map_err(|e| Error::Database(e.to_string()))?;
724
725 while let Some(row) = rows.next().map_err(|e| Error::Database(e.to_string()))? {
726 let mls_group_id: Vec<u8> = row.get(0).map_err(|e| Error::Database(e.to_string()))?;
727 let nostr_group_id: Vec<u8> = row.get(1).map_err(|e| Error::Database(e.to_string()))?;
728 let name_val: String = row.get(2).map_err(|e| Error::Database(e.to_string()))?;
729 let description: String = row.get(3).map_err(|e| Error::Database(e.to_string()))?;
730 let admin_pubkeys: String = row.get(4).map_err(|e| Error::Database(e.to_string()))?;
731 let last_message_id: Option<Vec<u8>> =
732 row.get(5).map_err(|e| Error::Database(e.to_string()))?;
733 let last_message_at: Option<i64> =
734 row.get(6).map_err(|e| Error::Database(e.to_string()))?;
735 let last_message_processed_at: Option<i64> =
736 row.get(7).map_err(|e| Error::Database(e.to_string()))?;
737 let epoch: i64 = row.get(8).map_err(|e| Error::Database(e.to_string()))?;
738 let state: String = row.get(9).map_err(|e| Error::Database(e.to_string()))?;
739 let image_hash: Option<Vec<u8>> =
740 row.get(10).map_err(|e| Error::Database(e.to_string()))?;
741 let image_key: Option<Vec<u8>> =
742 row.get(11).map_err(|e| Error::Database(e.to_string()))?;
743 let image_nonce: Option<Vec<u8>> =
744 row.get(12).map_err(|e| Error::Database(e.to_string()))?;
745 let last_self_update_at: i64 =
746 row.get(13).map_err(|e| Error::Database(e.to_string()))?;
747
748 let row_key =
749 serde_json::to_vec(&mls_group_id).map_err(|e| Error::Database(e.to_string()))?;
750 let row_data = serde_json::to_vec(&(
751 &nostr_group_id,
752 &name_val,
753 &description,
754 &admin_pubkeys,
755 &last_message_id,
756 &last_message_at,
757 &last_message_processed_at,
758 epoch,
759 &state,
760 &image_hash,
761 &image_key,
762 &image_nonce,
763 &last_self_update_at,
764 ))
765 .map_err(|e| Error::Database(e.to_string()))?;
766
767 insert_stmt
768 .execute(rusqlite::params![
769 snapshot_name,
770 group_id_bytes,
771 "groups",
772 row_key,
773 row_data,
774 now
775 ])
776 .map_err(|e| Error::Database(e.to_string()))?;
777 }
778 Ok(())
779 }
780
781 fn snapshot_group_relays(
783 conn: &rusqlite::Connection,
784 insert_stmt: &mut rusqlite::CachedStatement<'_>,
785 snapshot_name: &str,
786 group_id_bytes: &[u8],
787 now: i64,
788 ) -> Result<(), Error> {
789 let mut stmt = conn
790 .prepare("SELECT id, mls_group_id, relay_url FROM group_relays WHERE mls_group_id = ?")
791 .map_err(|e| Error::Database(e.to_string()))?;
792 let mut rows = stmt
793 .query([group_id_bytes])
794 .map_err(|e| Error::Database(e.to_string()))?;
795
796 while let Some(row) = rows.next().map_err(|e| Error::Database(e.to_string()))? {
797 let id: i64 = row.get(0).map_err(|e| Error::Database(e.to_string()))?;
798 let mls_group_id: Vec<u8> = row.get(1).map_err(|e| Error::Database(e.to_string()))?;
799 let relay_url: String = row.get(2).map_err(|e| Error::Database(e.to_string()))?;
800 let row_key = serde_json::to_vec(&id).map_err(|e| Error::Database(e.to_string()))?;
801 let row_data = serde_json::to_vec(&(&mls_group_id, &relay_url))
802 .map_err(|e| Error::Database(e.to_string()))?;
803 insert_stmt
804 .execute(rusqlite::params![
805 snapshot_name,
806 group_id_bytes,
807 "group_relays",
808 row_key,
809 row_data,
810 now
811 ])
812 .map_err(|e| Error::Database(e.to_string()))?;
813 }
814 Ok(())
815 }
816
817 fn snapshot_group_exporter_secrets(
819 conn: &rusqlite::Connection,
820 insert_stmt: &mut rusqlite::CachedStatement<'_>,
821 snapshot_name: &str,
822 group_id_bytes: &[u8],
823 now: i64,
824 ) -> Result<(), Error> {
825 let mut stmt = conn
826 .prepare(
827 "SELECT mls_group_id, epoch, secret FROM group_exporter_secrets WHERE mls_group_id = ?",
828 )
829 .map_err(|e| Error::Database(e.to_string()))?;
830 let mut rows = stmt
831 .query([group_id_bytes])
832 .map_err(|e| Error::Database(e.to_string()))?;
833
834 while let Some(row) = rows.next().map_err(|e| Error::Database(e.to_string()))? {
835 let mls_group_id: Vec<u8> = row.get(0).map_err(|e| Error::Database(e.to_string()))?;
836 let epoch: i64 = row.get(1).map_err(|e| Error::Database(e.to_string()))?;
837 let secret: Vec<u8> = row.get(2).map_err(|e| Error::Database(e.to_string()))?;
838 let row_key = serde_json::to_vec(&(&mls_group_id, epoch))
839 .map_err(|e| Error::Database(e.to_string()))?;
840 insert_stmt
841 .execute(rusqlite::params![
842 snapshot_name,
843 group_id_bytes,
844 "group_exporter_secrets",
845 row_key,
846 secret,
847 now
848 ])
849 .map_err(|e| Error::Database(e.to_string()))?;
850 }
851 Ok(())
852 }
853
854 fn restore_group_from_snapshot(&self, group_id: &GroupId, name: &str) -> Result<(), Error> {
857 let conn = self.connection.lock().unwrap();
858 let group_id_bytes = group_id.as_slice();
859 let mls_group_id_bytes = mls_storage::MlsCodec::serialize(group_id)
862 .map_err(|e| Error::Database(e.to_string()))?;
863
864 let snapshot_exists: bool = conn
868 .query_row(
869 "SELECT EXISTS(SELECT 1 FROM group_state_snapshots WHERE snapshot_name = ? AND group_id = ?)",
870 rusqlite::params![name, group_id_bytes],
871 |row| row.get(0),
872 )
873 .map_err(|e| Error::Database(e.to_string()))?;
874
875 if !snapshot_exists {
876 return Err(Error::Database("Snapshot not found".to_string()));
877 }
878
879 let snapshot_rows: Vec<(String, Vec<u8>, Vec<u8>)> = {
884 let mut stmt = conn
885 .prepare(
886 "SELECT table_name, row_key, row_data FROM group_state_snapshots
887 WHERE snapshot_name = ? AND group_id = ?",
888 )
889 .map_err(|e| Error::Database(e.to_string()))?;
890
891 let rows = stmt
892 .query_map(rusqlite::params![name, group_id_bytes], |row| {
893 Ok((row.get(0)?, row.get(1)?, row.get(2)?))
894 })
895 .map_err(|e| Error::Database(e.to_string()))?;
896
897 rows.collect::<Result<Vec<_>, _>>()
898 .map_err(|e| Error::Database(e.to_string()))?
899 };
900
901 #[allow(clippy::type_complexity)]
904 let other_snapshots: Vec<(String, String, Vec<u8>, Vec<u8>, i64)> = {
905 let mut stmt = conn
906 .prepare(
907 "SELECT snapshot_name, table_name, row_key, row_data, created_at
908 FROM group_state_snapshots
909 WHERE group_id = ? AND snapshot_name != ?",
910 )
911 .map_err(|e| Error::Database(e.to_string()))?;
912
913 let rows = stmt
914 .query_map(rusqlite::params![group_id_bytes, name], |row| {
915 Ok((
916 row.get(0)?,
917 row.get(1)?,
918 row.get(2)?,
919 row.get(3)?,
920 row.get(4)?,
921 ))
922 })
923 .map_err(|e| Error::Database(e.to_string()))?;
924
925 rows.collect::<Result<Vec<_>, _>>()
926 .map_err(|e| Error::Database(e.to_string()))?
927 };
928
929 conn.execute("BEGIN IMMEDIATE", [])
931 .map_err(|e| Error::Database(e.to_string()))?;
932
933 let result = (|| -> Result<(), Error> {
934 conn.execute(
937 "DELETE FROM openmls_group_data WHERE group_id = ?",
938 [&mls_group_id_bytes],
939 )
940 .map_err(|e| Error::Database(e.to_string()))?;
941
942 conn.execute(
943 "DELETE FROM openmls_proposals WHERE group_id = ?",
944 [&mls_group_id_bytes],
945 )
946 .map_err(|e| Error::Database(e.to_string()))?;
947
948 conn.execute(
949 "DELETE FROM openmls_own_leaf_nodes WHERE group_id = ?",
950 [&mls_group_id_bytes],
951 )
952 .map_err(|e| Error::Database(e.to_string()))?;
953
954 conn.execute(
955 "DELETE FROM openmls_epoch_key_pairs WHERE group_id = ?",
956 [&mls_group_id_bytes],
957 )
958 .map_err(|e| Error::Database(e.to_string()))?;
959
960 conn.execute(
963 "DELETE FROM group_exporter_secrets WHERE mls_group_id = ?",
964 [group_id_bytes],
965 )
966 .map_err(|e| Error::Database(e.to_string()))?;
967
968 conn.execute(
969 "DELETE FROM group_relays WHERE mls_group_id = ?",
970 [group_id_bytes],
971 )
972 .map_err(|e| Error::Database(e.to_string()))?;
973
974 conn.execute(
975 "DELETE FROM groups WHERE mls_group_id = ?",
976 [group_id_bytes],
977 )
978 .map_err(|e| Error::Database(e.to_string()))?;
979
980 for (table_name, row_key, row_data) in &snapshot_rows {
988 if table_name != "groups" {
989 continue;
990 }
991 let mls_group_id: Vec<u8> =
992 serde_json::from_slice(row_key).map_err(|e| Error::Database(e.to_string()))?;
993 #[allow(clippy::type_complexity)]
994 let (
995 nostr_group_id,
996 name_val,
997 description,
998 admin_pubkeys,
999 last_message_id,
1000 last_message_at,
1001 last_message_processed_at,
1002 epoch,
1003 state,
1004 image_hash,
1005 image_key,
1006 image_nonce,
1007 last_self_update_at,
1008 ): (
1009 Vec<u8>,
1010 String,
1011 String,
1012 String,
1013 Option<Vec<u8>>,
1014 Option<i64>,
1015 Option<i64>,
1016 i64,
1017 String,
1018 Option<Vec<u8>>,
1019 Option<Vec<u8>>,
1020 Option<Vec<u8>>,
1021 i64,
1022 ) = serde_json::from_slice(row_data).map_err(|e| Error::Database(e.to_string()))?;
1023 conn.execute(
1024 "INSERT INTO groups (mls_group_id, nostr_group_id, name, description, admin_pubkeys,
1025 last_message_id, last_message_at, last_message_processed_at, epoch, state,
1026 image_hash, image_key, image_nonce, last_self_update_at)
1027 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
1028 rusqlite::params![
1029 mls_group_id,
1030 nostr_group_id,
1031 name_val,
1032 description,
1033 admin_pubkeys,
1034 last_message_id,
1035 last_message_at,
1036 last_message_processed_at,
1037 epoch,
1038 state,
1039 image_hash,
1040 image_key,
1041 image_nonce,
1042 last_self_update_at
1043 ],
1044 )
1045 .map_err(|e| Error::Database(e.to_string()))?;
1046 }
1047
1048 for (table_name, row_key, row_data) in &snapshot_rows {
1050 match table_name.as_str() {
1051 "openmls_group_data" => {
1052 let (gid, data_type): (Vec<u8>, String) =
1053 serde_json::from_slice(row_key)
1054 .map_err(|e| Error::Database(e.to_string()))?;
1055 conn.execute(
1056 "INSERT INTO openmls_group_data (provider_version, group_id, data_type, group_data)
1057 VALUES (1, ?, ?, ?)",
1058 rusqlite::params![gid, data_type, row_data],
1059 )
1060 .map_err(|e| Error::Database(e.to_string()))?;
1061 }
1062 "openmls_proposals" => {
1063 let (gid, proposal_ref): (Vec<u8>, Vec<u8>) =
1064 serde_json::from_slice(row_key)
1065 .map_err(|e| Error::Database(e.to_string()))?;
1066 conn.execute(
1067 "INSERT INTO openmls_proposals (provider_version, group_id, proposal_ref, proposal)
1068 VALUES (1, ?, ?, ?)",
1069 rusqlite::params![gid, proposal_ref, row_data],
1070 )
1071 .map_err(|e| Error::Database(e.to_string()))?;
1072 }
1073 "openmls_own_leaf_nodes" => {
1074 let (gid, leaf_node): (Vec<u8>, Vec<u8>) = serde_json::from_slice(row_data)
1075 .map_err(|e| Error::Database(e.to_string()))?;
1076 conn.execute(
1077 "INSERT INTO openmls_own_leaf_nodes (provider_version, group_id, leaf_node)
1078 VALUES (1, ?, ?)",
1079 rusqlite::params![gid, leaf_node],
1080 )
1081 .map_err(|e| Error::Database(e.to_string()))?;
1082 }
1083 "openmls_epoch_key_pairs" => {
1084 let (gid, epoch_id, leaf_index): (Vec<u8>, Vec<u8>, i64) =
1085 serde_json::from_slice(row_key)
1086 .map_err(|e| Error::Database(e.to_string()))?;
1087 conn.execute(
1088 "INSERT INTO openmls_epoch_key_pairs (provider_version, group_id, epoch_id, leaf_index, key_pairs)
1089 VALUES (1, ?, ?, ?, ?)",
1090 rusqlite::params![gid, epoch_id, leaf_index, row_data],
1091 )
1092 .map_err(|e| Error::Database(e.to_string()))?;
1093 }
1094 "groups" => {
1095 }
1097 "group_relays" => {
1098 let (mls_group_id, relay_url): (Vec<u8>, String) =
1099 serde_json::from_slice(row_data)
1100 .map_err(|e| Error::Database(e.to_string()))?;
1101 conn.execute(
1102 "INSERT INTO group_relays (mls_group_id, relay_url) VALUES (?, ?)",
1103 rusqlite::params![mls_group_id, relay_url],
1104 )
1105 .map_err(|e| Error::Database(e.to_string()))?;
1106 }
1107 "group_exporter_secrets" => {
1108 let (mls_group_id, epoch): (Vec<u8>, i64) = serde_json::from_slice(row_key)
1109 .map_err(|e| Error::Database(e.to_string()))?;
1110 conn.execute(
1111 "INSERT INTO group_exporter_secrets (mls_group_id, epoch, secret) VALUES (?, ?, ?)",
1112 rusqlite::params![mls_group_id, epoch, row_data],
1113 )
1114 .map_err(|e| Error::Database(e.to_string()))?;
1115 }
1116 _ => {
1117 }
1119 }
1120 }
1121
1122 conn.execute(
1124 "DELETE FROM group_state_snapshots WHERE snapshot_name = ? AND group_id = ?",
1125 rusqlite::params![name, group_id_bytes],
1126 )
1127 .map_err(|e| Error::Database(e.to_string()))?;
1128
1129 for (snap_name, table_name, row_key, row_data, created_at) in &other_snapshots {
1132 conn.execute(
1133 "INSERT INTO group_state_snapshots (snapshot_name, group_id, table_name, row_key, row_data, created_at)
1134 VALUES (?, ?, ?, ?, ?, ?)",
1135 rusqlite::params![snap_name, group_id_bytes, table_name, row_key, row_data, created_at],
1136 )
1137 .map_err(|e| Error::Database(e.to_string()))?;
1138 }
1139
1140 Ok(())
1141 })();
1142
1143 match result {
1144 Ok(()) => {
1145 conn.execute("COMMIT", [])
1146 .map_err(|e| Error::Database(e.to_string()))?;
1147 Ok(())
1148 }
1149 Err(e) => {
1150 let _ = conn.execute("ROLLBACK", []);
1151 Err(e)
1152 }
1153 }
1154 }
1155
1156 fn delete_group_snapshot(&self, group_id: &GroupId, name: &str) -> Result<(), Error> {
1158 let conn = self.connection.lock().unwrap();
1159 conn.execute(
1160 "DELETE FROM group_state_snapshots WHERE snapshot_name = ? AND group_id = ?",
1161 rusqlite::params![name, group_id.as_slice()],
1162 )
1163 .map_err(|e| Error::Database(e.to_string()))?;
1164 Ok(())
1165 }
1166}
1167
1168impl MdkStorageProvider for MdkSqliteStorage {
1170 fn backend(&self) -> Backend {
1176 Backend::SQLite
1177 }
1178
1179 fn create_group_snapshot(&self, group_id: &GroupId, name: &str) -> Result<(), MdkStorageError> {
1180 self.snapshot_group_state(group_id, name)
1181 .map_err(|e| MdkStorageError::Database(e.to_string()))
1182 }
1183
1184 fn rollback_group_to_snapshot(
1185 &self,
1186 group_id: &GroupId,
1187 name: &str,
1188 ) -> Result<(), MdkStorageError> {
1189 self.restore_group_from_snapshot(group_id, name)
1190 .map_err(|e| MdkStorageError::Database(e.to_string()))
1191 }
1192
1193 fn release_group_snapshot(
1194 &self,
1195 group_id: &GroupId,
1196 name: &str,
1197 ) -> Result<(), MdkStorageError> {
1198 self.delete_group_snapshot(group_id, name)
1199 .map_err(|e| MdkStorageError::Database(e.to_string()))
1200 }
1201
1202 fn list_group_snapshots(
1203 &self,
1204 group_id: &GroupId,
1205 ) -> Result<Vec<(String, u64)>, MdkStorageError> {
1206 let conn = self.connection.lock().unwrap();
1207 let mut stmt = conn
1208 .prepare_cached(
1209 "SELECT DISTINCT snapshot_name, created_at FROM group_state_snapshots
1210 WHERE group_id = ? ORDER BY created_at ASC",
1211 )
1212 .map_err(|e| MdkStorageError::Database(e.to_string()))?;
1213
1214 let rows = stmt
1215 .query_map(rusqlite::params![group_id.as_slice()], |row| {
1216 let name: String = row.get(0)?;
1217 let created_at: i64 = row.get(1)?;
1218 Ok((name, created_at as u64))
1219 })
1220 .map_err(|e| MdkStorageError::Database(e.to_string()))?;
1221
1222 rows.collect::<Result<Vec<_>, _>>()
1223 .map_err(|e| MdkStorageError::Database(e.to_string()))
1224 }
1225
1226 fn prune_expired_snapshots(&self, min_timestamp: u64) -> Result<usize, MdkStorageError> {
1227 let conn = self.connection.lock().unwrap();
1228 let deleted = conn
1229 .execute(
1230 "DELETE FROM group_state_snapshots WHERE created_at < ?",
1231 rusqlite::params![min_timestamp as i64],
1232 )
1233 .map_err(|e| MdkStorageError::Database(e.to_string()))?;
1234 Ok(deleted)
1235 }
1236}
1237
1238impl StorageProvider<STORAGE_PROVIDER_VERSION> for MdkSqliteStorage {
1243 type Error = MdkStorageError;
1244
1245 fn write_mls_join_config<GroupId, MlsGroupJoinConfig>(
1250 &self,
1251 group_id: &GroupId,
1252 config: &MlsGroupJoinConfig,
1253 ) -> Result<(), Self::Error>
1254 where
1255 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1256 MlsGroupJoinConfig: traits::MlsGroupJoinConfig<STORAGE_PROVIDER_VERSION>,
1257 {
1258 self.with_connection(|conn| {
1259 mls_storage::write_group_data(conn, group_id, GroupDataType::JoinGroupConfig, config)
1260 })
1261 }
1262
1263 fn append_own_leaf_node<GroupId, LeafNode>(
1264 &self,
1265 group_id: &GroupId,
1266 leaf_node: &LeafNode,
1267 ) -> Result<(), Self::Error>
1268 where
1269 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1270 LeafNode: traits::LeafNode<STORAGE_PROVIDER_VERSION>,
1271 {
1272 self.with_connection(|conn| mls_storage::append_own_leaf_node(conn, group_id, leaf_node))
1273 }
1274
1275 fn queue_proposal<GroupId, ProposalRef, QueuedProposal>(
1276 &self,
1277 group_id: &GroupId,
1278 proposal_ref: &ProposalRef,
1279 proposal: &QueuedProposal,
1280 ) -> Result<(), Self::Error>
1281 where
1282 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1283 ProposalRef: traits::ProposalRef<STORAGE_PROVIDER_VERSION>,
1284 QueuedProposal: traits::QueuedProposal<STORAGE_PROVIDER_VERSION>,
1285 {
1286 self.with_connection(|conn| {
1287 mls_storage::queue_proposal(conn, group_id, proposal_ref, proposal)
1288 })
1289 }
1290
1291 fn write_tree<GroupId, TreeSync>(
1292 &self,
1293 group_id: &GroupId,
1294 tree: &TreeSync,
1295 ) -> Result<(), Self::Error>
1296 where
1297 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1298 TreeSync: traits::TreeSync<STORAGE_PROVIDER_VERSION>,
1299 {
1300 self.with_connection(|conn| {
1301 mls_storage::write_group_data(conn, group_id, GroupDataType::Tree, tree)
1302 })
1303 }
1304
1305 fn write_interim_transcript_hash<GroupId, InterimTranscriptHash>(
1306 &self,
1307 group_id: &GroupId,
1308 interim_transcript_hash: &InterimTranscriptHash,
1309 ) -> Result<(), Self::Error>
1310 where
1311 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1312 InterimTranscriptHash: traits::InterimTranscriptHash<STORAGE_PROVIDER_VERSION>,
1313 {
1314 self.with_connection(|conn| {
1315 mls_storage::write_group_data(
1316 conn,
1317 group_id,
1318 GroupDataType::InterimTranscriptHash,
1319 interim_transcript_hash,
1320 )
1321 })
1322 }
1323
1324 fn write_context<GroupId, GroupContext>(
1325 &self,
1326 group_id: &GroupId,
1327 group_context: &GroupContext,
1328 ) -> Result<(), Self::Error>
1329 where
1330 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1331 GroupContext: traits::GroupContext<STORAGE_PROVIDER_VERSION>,
1332 {
1333 self.with_connection(|conn| {
1334 mls_storage::write_group_data(conn, group_id, GroupDataType::Context, group_context)
1335 })
1336 }
1337
1338 fn write_confirmation_tag<GroupId, ConfirmationTag>(
1339 &self,
1340 group_id: &GroupId,
1341 confirmation_tag: &ConfirmationTag,
1342 ) -> Result<(), Self::Error>
1343 where
1344 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1345 ConfirmationTag: traits::ConfirmationTag<STORAGE_PROVIDER_VERSION>,
1346 {
1347 self.with_connection(|conn| {
1348 mls_storage::write_group_data(
1349 conn,
1350 group_id,
1351 GroupDataType::ConfirmationTag,
1352 confirmation_tag,
1353 )
1354 })
1355 }
1356
1357 fn write_group_state<GroupState, GroupId>(
1358 &self,
1359 group_id: &GroupId,
1360 group_state: &GroupState,
1361 ) -> Result<(), Self::Error>
1362 where
1363 GroupState: traits::GroupState<STORAGE_PROVIDER_VERSION>,
1364 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1365 {
1366 self.with_connection(|conn| {
1367 mls_storage::write_group_data(conn, group_id, GroupDataType::GroupState, group_state)
1368 })
1369 }
1370
1371 fn write_message_secrets<GroupId, MessageSecrets>(
1372 &self,
1373 group_id: &GroupId,
1374 message_secrets: &MessageSecrets,
1375 ) -> Result<(), Self::Error>
1376 where
1377 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1378 MessageSecrets: traits::MessageSecrets<STORAGE_PROVIDER_VERSION>,
1379 {
1380 self.with_connection(|conn| {
1381 mls_storage::write_group_data(
1382 conn,
1383 group_id,
1384 GroupDataType::MessageSecrets,
1385 message_secrets,
1386 )
1387 })
1388 }
1389
1390 fn write_resumption_psk_store<GroupId, ResumptionPskStore>(
1391 &self,
1392 group_id: &GroupId,
1393 resumption_psk_store: &ResumptionPskStore,
1394 ) -> Result<(), Self::Error>
1395 where
1396 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1397 ResumptionPskStore: traits::ResumptionPskStore<STORAGE_PROVIDER_VERSION>,
1398 {
1399 self.with_connection(|conn| {
1400 mls_storage::write_group_data(
1401 conn,
1402 group_id,
1403 GroupDataType::ResumptionPskStore,
1404 resumption_psk_store,
1405 )
1406 })
1407 }
1408
1409 fn write_own_leaf_index<GroupId, LeafNodeIndex>(
1410 &self,
1411 group_id: &GroupId,
1412 own_leaf_index: &LeafNodeIndex,
1413 ) -> Result<(), Self::Error>
1414 where
1415 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1416 LeafNodeIndex: traits::LeafNodeIndex<STORAGE_PROVIDER_VERSION>,
1417 {
1418 self.with_connection(|conn| {
1419 mls_storage::write_group_data(
1420 conn,
1421 group_id,
1422 GroupDataType::OwnLeafIndex,
1423 own_leaf_index,
1424 )
1425 })
1426 }
1427
1428 fn write_group_epoch_secrets<GroupId, GroupEpochSecrets>(
1429 &self,
1430 group_id: &GroupId,
1431 group_epoch_secrets: &GroupEpochSecrets,
1432 ) -> Result<(), Self::Error>
1433 where
1434 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1435 GroupEpochSecrets: traits::GroupEpochSecrets<STORAGE_PROVIDER_VERSION>,
1436 {
1437 self.with_connection(|conn| {
1438 mls_storage::write_group_data(
1439 conn,
1440 group_id,
1441 GroupDataType::GroupEpochSecrets,
1442 group_epoch_secrets,
1443 )
1444 })
1445 }
1446
1447 fn write_signature_key_pair<SignaturePublicKey, SignatureKeyPair>(
1448 &self,
1449 public_key: &SignaturePublicKey,
1450 signature_key_pair: &SignatureKeyPair,
1451 ) -> Result<(), Self::Error>
1452 where
1453 SignaturePublicKey: traits::SignaturePublicKey<STORAGE_PROVIDER_VERSION>,
1454 SignatureKeyPair: traits::SignatureKeyPair<STORAGE_PROVIDER_VERSION>,
1455 {
1456 self.with_connection(|conn| {
1457 mls_storage::write_signature_key_pair(conn, public_key, signature_key_pair)
1458 })
1459 }
1460
1461 fn write_encryption_key_pair<EncryptionKey, HpkeKeyPair>(
1462 &self,
1463 public_key: &EncryptionKey,
1464 key_pair: &HpkeKeyPair,
1465 ) -> Result<(), Self::Error>
1466 where
1467 EncryptionKey: traits::EncryptionKey<STORAGE_PROVIDER_VERSION>,
1468 HpkeKeyPair: traits::HpkeKeyPair<STORAGE_PROVIDER_VERSION>,
1469 {
1470 self.with_connection(|conn| {
1471 mls_storage::write_encryption_key_pair(conn, public_key, key_pair)
1472 })
1473 }
1474
1475 fn write_encryption_epoch_key_pairs<GroupId, EpochKey, HpkeKeyPair>(
1476 &self,
1477 group_id: &GroupId,
1478 epoch: &EpochKey,
1479 leaf_index: u32,
1480 key_pairs: &[HpkeKeyPair],
1481 ) -> Result<(), Self::Error>
1482 where
1483 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1484 EpochKey: traits::EpochKey<STORAGE_PROVIDER_VERSION>,
1485 HpkeKeyPair: traits::HpkeKeyPair<STORAGE_PROVIDER_VERSION>,
1486 {
1487 self.with_connection(|conn| {
1488 mls_storage::write_encryption_epoch_key_pairs(
1489 conn, group_id, epoch, leaf_index, key_pairs,
1490 )
1491 })
1492 }
1493
1494 fn write_key_package<HashReference, KeyPackage>(
1495 &self,
1496 hash_ref: &HashReference,
1497 key_package: &KeyPackage,
1498 ) -> Result<(), Self::Error>
1499 where
1500 HashReference: traits::HashReference<STORAGE_PROVIDER_VERSION>,
1501 KeyPackage: traits::KeyPackage<STORAGE_PROVIDER_VERSION>,
1502 {
1503 self.with_connection(|conn| mls_storage::write_key_package(conn, hash_ref, key_package))
1504 }
1505
1506 fn write_psk<PskId, PskBundle>(
1507 &self,
1508 psk_id: &PskId,
1509 psk: &PskBundle,
1510 ) -> Result<(), Self::Error>
1511 where
1512 PskId: traits::PskId<STORAGE_PROVIDER_VERSION>,
1513 PskBundle: traits::PskBundle<STORAGE_PROVIDER_VERSION>,
1514 {
1515 self.with_connection(|conn| mls_storage::write_psk(conn, psk_id, psk))
1516 }
1517
1518 fn mls_group_join_config<GroupId, MlsGroupJoinConfig>(
1523 &self,
1524 group_id: &GroupId,
1525 ) -> Result<Option<MlsGroupJoinConfig>, Self::Error>
1526 where
1527 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1528 MlsGroupJoinConfig: traits::MlsGroupJoinConfig<STORAGE_PROVIDER_VERSION>,
1529 {
1530 self.with_connection(|conn| {
1531 mls_storage::read_group_data(conn, group_id, GroupDataType::JoinGroupConfig)
1532 })
1533 }
1534
1535 fn own_leaf_nodes<GroupId, LeafNode>(
1536 &self,
1537 group_id: &GroupId,
1538 ) -> Result<Vec<LeafNode>, Self::Error>
1539 where
1540 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1541 LeafNode: traits::LeafNode<STORAGE_PROVIDER_VERSION>,
1542 {
1543 self.with_connection(|conn| mls_storage::read_own_leaf_nodes(conn, group_id))
1544 }
1545
1546 fn queued_proposal_refs<GroupId, ProposalRef>(
1547 &self,
1548 group_id: &GroupId,
1549 ) -> Result<Vec<ProposalRef>, Self::Error>
1550 where
1551 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1552 ProposalRef: traits::ProposalRef<STORAGE_PROVIDER_VERSION>,
1553 {
1554 self.with_connection(|conn| mls_storage::read_queued_proposal_refs(conn, group_id))
1555 }
1556
1557 fn queued_proposals<GroupId, ProposalRef, QueuedProposal>(
1558 &self,
1559 group_id: &GroupId,
1560 ) -> Result<Vec<(ProposalRef, QueuedProposal)>, Self::Error>
1561 where
1562 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1563 ProposalRef: traits::ProposalRef<STORAGE_PROVIDER_VERSION>,
1564 QueuedProposal: traits::QueuedProposal<STORAGE_PROVIDER_VERSION>,
1565 {
1566 self.with_connection(|conn| mls_storage::read_queued_proposals(conn, group_id))
1567 }
1568
1569 fn tree<GroupId, TreeSync>(&self, group_id: &GroupId) -> Result<Option<TreeSync>, Self::Error>
1570 where
1571 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1572 TreeSync: traits::TreeSync<STORAGE_PROVIDER_VERSION>,
1573 {
1574 self.with_connection(|conn| {
1575 mls_storage::read_group_data(conn, group_id, GroupDataType::Tree)
1576 })
1577 }
1578
1579 fn group_context<GroupId, GroupContext>(
1580 &self,
1581 group_id: &GroupId,
1582 ) -> Result<Option<GroupContext>, Self::Error>
1583 where
1584 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1585 GroupContext: traits::GroupContext<STORAGE_PROVIDER_VERSION>,
1586 {
1587 self.with_connection(|conn| {
1588 mls_storage::read_group_data(conn, group_id, GroupDataType::Context)
1589 })
1590 }
1591
1592 fn interim_transcript_hash<GroupId, InterimTranscriptHash>(
1593 &self,
1594 group_id: &GroupId,
1595 ) -> Result<Option<InterimTranscriptHash>, Self::Error>
1596 where
1597 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1598 InterimTranscriptHash: traits::InterimTranscriptHash<STORAGE_PROVIDER_VERSION>,
1599 {
1600 self.with_connection(|conn| {
1601 mls_storage::read_group_data(conn, group_id, GroupDataType::InterimTranscriptHash)
1602 })
1603 }
1604
1605 fn confirmation_tag<GroupId, ConfirmationTag>(
1606 &self,
1607 group_id: &GroupId,
1608 ) -> Result<Option<ConfirmationTag>, Self::Error>
1609 where
1610 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1611 ConfirmationTag: traits::ConfirmationTag<STORAGE_PROVIDER_VERSION>,
1612 {
1613 self.with_connection(|conn| {
1614 mls_storage::read_group_data(conn, group_id, GroupDataType::ConfirmationTag)
1615 })
1616 }
1617
1618 fn group_state<GroupState, GroupId>(
1619 &self,
1620 group_id: &GroupId,
1621 ) -> Result<Option<GroupState>, Self::Error>
1622 where
1623 GroupState: traits::GroupState<STORAGE_PROVIDER_VERSION>,
1624 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1625 {
1626 self.with_connection(|conn| {
1627 mls_storage::read_group_data(conn, group_id, GroupDataType::GroupState)
1628 })
1629 }
1630
1631 fn message_secrets<GroupId, MessageSecrets>(
1632 &self,
1633 group_id: &GroupId,
1634 ) -> Result<Option<MessageSecrets>, Self::Error>
1635 where
1636 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1637 MessageSecrets: traits::MessageSecrets<STORAGE_PROVIDER_VERSION>,
1638 {
1639 self.with_connection(|conn| {
1640 mls_storage::read_group_data(conn, group_id, GroupDataType::MessageSecrets)
1641 })
1642 }
1643
1644 fn resumption_psk_store<GroupId, ResumptionPskStore>(
1645 &self,
1646 group_id: &GroupId,
1647 ) -> Result<Option<ResumptionPskStore>, Self::Error>
1648 where
1649 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1650 ResumptionPskStore: traits::ResumptionPskStore<STORAGE_PROVIDER_VERSION>,
1651 {
1652 self.with_connection(|conn| {
1653 mls_storage::read_group_data(conn, group_id, GroupDataType::ResumptionPskStore)
1654 })
1655 }
1656
1657 fn own_leaf_index<GroupId, LeafNodeIndex>(
1658 &self,
1659 group_id: &GroupId,
1660 ) -> Result<Option<LeafNodeIndex>, Self::Error>
1661 where
1662 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1663 LeafNodeIndex: traits::LeafNodeIndex<STORAGE_PROVIDER_VERSION>,
1664 {
1665 self.with_connection(|conn| {
1666 mls_storage::read_group_data(conn, group_id, GroupDataType::OwnLeafIndex)
1667 })
1668 }
1669
1670 fn group_epoch_secrets<GroupId, GroupEpochSecrets>(
1671 &self,
1672 group_id: &GroupId,
1673 ) -> Result<Option<GroupEpochSecrets>, Self::Error>
1674 where
1675 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1676 GroupEpochSecrets: traits::GroupEpochSecrets<STORAGE_PROVIDER_VERSION>,
1677 {
1678 self.with_connection(|conn| {
1679 mls_storage::read_group_data(conn, group_id, GroupDataType::GroupEpochSecrets)
1680 })
1681 }
1682
1683 fn signature_key_pair<SignaturePublicKey, SignatureKeyPair>(
1684 &self,
1685 public_key: &SignaturePublicKey,
1686 ) -> Result<Option<SignatureKeyPair>, Self::Error>
1687 where
1688 SignaturePublicKey: traits::SignaturePublicKey<STORAGE_PROVIDER_VERSION>,
1689 SignatureKeyPair: traits::SignatureKeyPair<STORAGE_PROVIDER_VERSION>,
1690 {
1691 self.with_connection(|conn| mls_storage::read_signature_key_pair(conn, public_key))
1692 }
1693
1694 fn encryption_key_pair<HpkeKeyPair, EncryptionKey>(
1695 &self,
1696 public_key: &EncryptionKey,
1697 ) -> Result<Option<HpkeKeyPair>, Self::Error>
1698 where
1699 HpkeKeyPair: traits::HpkeKeyPair<STORAGE_PROVIDER_VERSION>,
1700 EncryptionKey: traits::EncryptionKey<STORAGE_PROVIDER_VERSION>,
1701 {
1702 self.with_connection(|conn| mls_storage::read_encryption_key_pair(conn, public_key))
1703 }
1704
1705 fn encryption_epoch_key_pairs<GroupId, EpochKey, HpkeKeyPair>(
1706 &self,
1707 group_id: &GroupId,
1708 epoch: &EpochKey,
1709 leaf_index: u32,
1710 ) -> Result<Vec<HpkeKeyPair>, Self::Error>
1711 where
1712 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1713 EpochKey: traits::EpochKey<STORAGE_PROVIDER_VERSION>,
1714 HpkeKeyPair: traits::HpkeKeyPair<STORAGE_PROVIDER_VERSION>,
1715 {
1716 self.with_connection(|conn| {
1717 mls_storage::read_encryption_epoch_key_pairs(conn, group_id, epoch, leaf_index)
1718 })
1719 }
1720
1721 fn key_package<HashReference, KeyPackage>(
1722 &self,
1723 hash_ref: &HashReference,
1724 ) -> Result<Option<KeyPackage>, Self::Error>
1725 where
1726 HashReference: traits::HashReference<STORAGE_PROVIDER_VERSION>,
1727 KeyPackage: traits::KeyPackage<STORAGE_PROVIDER_VERSION>,
1728 {
1729 self.with_connection(|conn| mls_storage::read_key_package(conn, hash_ref))
1730 }
1731
1732 fn psk<PskBundle, PskId>(&self, psk_id: &PskId) -> Result<Option<PskBundle>, Self::Error>
1733 where
1734 PskBundle: traits::PskBundle<STORAGE_PROVIDER_VERSION>,
1735 PskId: traits::PskId<STORAGE_PROVIDER_VERSION>,
1736 {
1737 self.with_connection(|conn| mls_storage::read_psk(conn, psk_id))
1738 }
1739
1740 fn remove_proposal<GroupId, ProposalRef>(
1745 &self,
1746 group_id: &GroupId,
1747 proposal_ref: &ProposalRef,
1748 ) -> Result<(), Self::Error>
1749 where
1750 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1751 ProposalRef: traits::ProposalRef<STORAGE_PROVIDER_VERSION>,
1752 {
1753 self.with_connection(|conn| mls_storage::remove_proposal(conn, group_id, proposal_ref))
1754 }
1755
1756 fn delete_own_leaf_nodes<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1757 where
1758 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1759 {
1760 self.with_connection(|conn| mls_storage::delete_own_leaf_nodes(conn, group_id))
1761 }
1762
1763 fn delete_group_config<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1764 where
1765 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1766 {
1767 self.with_connection(|conn| {
1768 mls_storage::delete_group_data(conn, group_id, GroupDataType::JoinGroupConfig)
1769 })
1770 }
1771
1772 fn delete_tree<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1773 where
1774 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1775 {
1776 self.with_connection(|conn| {
1777 mls_storage::delete_group_data(conn, group_id, GroupDataType::Tree)
1778 })
1779 }
1780
1781 fn delete_confirmation_tag<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1782 where
1783 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1784 {
1785 self.with_connection(|conn| {
1786 mls_storage::delete_group_data(conn, group_id, GroupDataType::ConfirmationTag)
1787 })
1788 }
1789
1790 fn delete_group_state<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1791 where
1792 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1793 {
1794 self.with_connection(|conn| {
1795 mls_storage::delete_group_data(conn, group_id, GroupDataType::GroupState)
1796 })
1797 }
1798
1799 fn delete_context<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1800 where
1801 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1802 {
1803 self.with_connection(|conn| {
1804 mls_storage::delete_group_data(conn, group_id, GroupDataType::Context)
1805 })
1806 }
1807
1808 fn delete_interim_transcript_hash<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1809 where
1810 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1811 {
1812 self.with_connection(|conn| {
1813 mls_storage::delete_group_data(conn, group_id, GroupDataType::InterimTranscriptHash)
1814 })
1815 }
1816
1817 fn delete_message_secrets<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1818 where
1819 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1820 {
1821 self.with_connection(|conn| {
1822 mls_storage::delete_group_data(conn, group_id, GroupDataType::MessageSecrets)
1823 })
1824 }
1825
1826 fn delete_all_resumption_psk_secrets<GroupId>(
1827 &self,
1828 group_id: &GroupId,
1829 ) -> Result<(), Self::Error>
1830 where
1831 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1832 {
1833 self.with_connection(|conn| {
1834 mls_storage::delete_group_data(conn, group_id, GroupDataType::ResumptionPskStore)
1835 })
1836 }
1837
1838 fn delete_own_leaf_index<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1839 where
1840 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1841 {
1842 self.with_connection(|conn| {
1843 mls_storage::delete_group_data(conn, group_id, GroupDataType::OwnLeafIndex)
1844 })
1845 }
1846
1847 fn delete_group_epoch_secrets<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1848 where
1849 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1850 {
1851 self.with_connection(|conn| {
1852 mls_storage::delete_group_data(conn, group_id, GroupDataType::GroupEpochSecrets)
1853 })
1854 }
1855
1856 fn clear_proposal_queue<GroupId, ProposalRef>(
1857 &self,
1858 group_id: &GroupId,
1859 ) -> Result<(), Self::Error>
1860 where
1861 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1862 ProposalRef: traits::ProposalRef<STORAGE_PROVIDER_VERSION>,
1863 {
1864 self.with_connection(|conn| mls_storage::clear_proposal_queue(conn, group_id))
1865 }
1866
1867 fn delete_signature_key_pair<SignaturePublicKey>(
1868 &self,
1869 public_key: &SignaturePublicKey,
1870 ) -> Result<(), Self::Error>
1871 where
1872 SignaturePublicKey: traits::SignaturePublicKey<STORAGE_PROVIDER_VERSION>,
1873 {
1874 self.with_connection(|conn| mls_storage::delete_signature_key_pair(conn, public_key))
1875 }
1876
1877 fn delete_encryption_key_pair<EncryptionKey>(
1878 &self,
1879 public_key: &EncryptionKey,
1880 ) -> Result<(), Self::Error>
1881 where
1882 EncryptionKey: traits::EncryptionKey<STORAGE_PROVIDER_VERSION>,
1883 {
1884 self.with_connection(|conn| mls_storage::delete_encryption_key_pair(conn, public_key))
1885 }
1886
1887 fn delete_encryption_epoch_key_pairs<GroupId, EpochKey>(
1888 &self,
1889 group_id: &GroupId,
1890 epoch: &EpochKey,
1891 leaf_index: u32,
1892 ) -> Result<(), Self::Error>
1893 where
1894 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1895 EpochKey: traits::EpochKey<STORAGE_PROVIDER_VERSION>,
1896 {
1897 self.with_connection(|conn| {
1898 mls_storage::delete_encryption_epoch_key_pairs(conn, group_id, epoch, leaf_index)
1899 })
1900 }
1901
1902 fn delete_key_package<HashReference>(&self, hash_ref: &HashReference) -> Result<(), Self::Error>
1903 where
1904 HashReference: traits::HashReference<STORAGE_PROVIDER_VERSION>,
1905 {
1906 self.with_connection(|conn| mls_storage::delete_key_package(conn, hash_ref))
1907 }
1908
1909 fn delete_psk<PskId>(&self, psk_id: &PskId) -> Result<(), Self::Error>
1910 where
1911 PskId: traits::PskId<STORAGE_PROVIDER_VERSION>,
1912 {
1913 self.with_connection(|conn| mls_storage::delete_psk(conn, psk_id))
1914 }
1915}
1916
1917#[cfg(test)]
1918mod tests {
1919 use std::collections::BTreeSet;
1920
1921 use mdk_storage_traits::GroupId;
1922 use mdk_storage_traits::Secret;
1923 use mdk_storage_traits::groups::GroupStorage;
1924 use mdk_storage_traits::groups::types::{
1925 Group, GroupExporterSecret, GroupState, SelfUpdateState,
1926 };
1927 use tempfile::tempdir;
1928
1929 use super::*;
1930
1931 #[test]
1932 fn test_new_in_memory() {
1933 let storage = MdkSqliteStorage::new_in_memory();
1934 assert!(storage.is_ok());
1935 let storage = storage.unwrap();
1936 assert_eq!(storage.backend(), Backend::SQLite);
1937 }
1938
1939 #[test]
1940 fn test_backend_type() {
1941 let storage = MdkSqliteStorage::new_in_memory().unwrap();
1942 assert_eq!(storage.backend(), Backend::SQLite);
1943 assert!(storage.backend().is_persistent());
1944 }
1945
1946 #[test]
1947 fn test_file_based_storage() {
1948 let temp_dir = tempdir().unwrap();
1949 let db_path = temp_dir.path().join("test_db.sqlite");
1950
1951 let storage = MdkSqliteStorage::new_unencrypted(&db_path);
1953 assert!(storage.is_ok());
1954
1955 assert!(db_path.exists());
1957
1958 let storage2 = MdkSqliteStorage::new_unencrypted(&db_path);
1960 assert!(storage2.is_ok());
1961
1962 drop(storage);
1964 drop(storage2);
1965 temp_dir.close().unwrap();
1966 }
1967
1968 #[test]
1969 fn test_database_tables() {
1970 let temp_dir = tempdir().unwrap();
1971 let db_path = temp_dir.path().join("migration_test.sqlite");
1972
1973 let storage = MdkSqliteStorage::new_unencrypted(&db_path).unwrap();
1975
1976 storage.with_connection(|conn| {
1978 let mut stmt = conn
1980 .prepare("SELECT name FROM sqlite_master WHERE type='table'")
1981 .unwrap();
1982 let table_names: Vec<String> = stmt
1983 .query_map([], |row| row.get(0))
1984 .unwrap()
1985 .map(|r| r.unwrap())
1986 .collect();
1987
1988 assert!(table_names.contains(&"groups".to_string()));
1990 assert!(table_names.contains(&"messages".to_string()));
1991 assert!(table_names.contains(&"welcomes".to_string()));
1992 assert!(table_names.contains(&"processed_messages".to_string()));
1993 assert!(table_names.contains(&"processed_welcomes".to_string()));
1994 assert!(table_names.contains(&"group_relays".to_string()));
1995 assert!(table_names.contains(&"group_exporter_secrets".to_string()));
1996
1997 assert!(table_names.contains(&"openmls_group_data".to_string()));
1999 assert!(table_names.contains(&"openmls_proposals".to_string()));
2000 assert!(table_names.contains(&"openmls_own_leaf_nodes".to_string()));
2001 assert!(table_names.contains(&"openmls_key_packages".to_string()));
2002 assert!(table_names.contains(&"openmls_psks".to_string()));
2003 assert!(table_names.contains(&"openmls_signature_keys".to_string()));
2004 assert!(table_names.contains(&"openmls_encryption_keys".to_string()));
2005 assert!(table_names.contains(&"openmls_epoch_key_pairs".to_string()));
2006 });
2007
2008 drop(storage);
2010 temp_dir.close().unwrap();
2011 }
2012
2013 #[test]
2014 fn test_group_exporter_secrets() {
2015 let storage = MdkSqliteStorage::new_in_memory().unwrap();
2017
2018 let mls_group_id = GroupId::from_slice(vec![1, 2, 3, 4].as_slice());
2020 let group = Group {
2021 mls_group_id: mls_group_id.clone(),
2022 nostr_group_id: [0u8; 32],
2023 name: "Test Group".to_string(),
2024 description: "A test group for exporter secrets".to_string(),
2025 admin_pubkeys: BTreeSet::new(),
2026 last_message_id: None,
2027 last_message_at: None,
2028 last_message_processed_at: None,
2029 epoch: 0,
2030 state: GroupState::Active,
2031 image_hash: None,
2032 image_key: None,
2033 image_nonce: None,
2034 self_update_state: SelfUpdateState::Required,
2035 };
2036
2037 storage.save_group(group.clone()).unwrap();
2039
2040 let secret_epoch_0 = GroupExporterSecret {
2042 mls_group_id: mls_group_id.clone(),
2043 epoch: 0,
2044 secret: Secret::new([0u8; 32]),
2045 };
2046
2047 let secret_epoch_1 = GroupExporterSecret {
2048 mls_group_id: mls_group_id.clone(),
2049 epoch: 1,
2050 secret: Secret::new([0u8; 32]),
2051 };
2052
2053 storage
2055 .save_group_exporter_secret(secret_epoch_0.clone())
2056 .unwrap();
2057 storage
2058 .save_group_exporter_secret(secret_epoch_1.clone())
2059 .unwrap();
2060
2061 let retrieved_secret_0 = storage.get_group_exporter_secret(&mls_group_id, 0).unwrap();
2063 assert!(retrieved_secret_0.is_some());
2064 let retrieved_secret_0 = retrieved_secret_0.unwrap();
2065 assert_eq!(retrieved_secret_0, secret_epoch_0);
2066
2067 let retrieved_secret_1 = storage.get_group_exporter_secret(&mls_group_id, 1).unwrap();
2068 assert!(retrieved_secret_1.is_some());
2069 let retrieved_secret_1 = retrieved_secret_1.unwrap();
2070 assert_eq!(retrieved_secret_1, secret_epoch_1);
2071
2072 let non_existent_epoch = storage
2074 .get_group_exporter_secret(&mls_group_id, 999)
2075 .unwrap();
2076 assert!(non_existent_epoch.is_none());
2077
2078 let non_existent_group_id = GroupId::from_slice(&[9, 9, 9, 9]);
2080 let result = storage.get_group_exporter_secret(&non_existent_group_id, 0);
2081 assert!(result.is_err());
2082
2083 let updated_secret_0 = GroupExporterSecret {
2085 mls_group_id: mls_group_id.clone(),
2086 epoch: 0,
2087 secret: Secret::new([0u8; 32]),
2088 };
2089 storage
2090 .save_group_exporter_secret(updated_secret_0.clone())
2091 .unwrap();
2092
2093 let retrieved_updated_secret = storage
2094 .get_group_exporter_secret(&mls_group_id, 0)
2095 .unwrap()
2096 .unwrap();
2097 assert_eq!(retrieved_updated_secret, updated_secret_0);
2098
2099 let invalid_secret = GroupExporterSecret {
2101 mls_group_id: non_existent_group_id.clone(),
2102 epoch: 0,
2103 secret: Secret::new([0u8; 32]),
2104 };
2105 let result = storage.save_group_exporter_secret(invalid_secret);
2106 assert!(result.is_err());
2107 }
2108
2109 mod encryption_tests {
2114 #[cfg(unix)]
2115 use std::os::unix::fs::PermissionsExt;
2116 use std::thread;
2117
2118 use mdk_storage_traits::Secret;
2119 use mdk_storage_traits::groups::GroupStorage;
2120 use mdk_storage_traits::groups::types::{Group, GroupExporterSecret, GroupState};
2121 use mdk_storage_traits::messages::MessageStorage;
2122 use mdk_storage_traits::test_utils::cross_storage::{
2123 create_test_group, create_test_message, create_test_welcome,
2124 };
2125 use mdk_storage_traits::welcomes::WelcomeStorage;
2126 use nostr::EventId;
2127
2128 use super::*;
2129 use crate::test_utils::ensure_mock_store;
2130
2131 #[test]
2132 fn test_encrypted_storage_creation() {
2133 let temp_dir = tempdir().unwrap();
2134 let db_path = temp_dir.path().join("encrypted.db");
2135
2136 let config = EncryptionConfig::generate().unwrap();
2137 let storage = MdkSqliteStorage::new_with_key(&db_path, config);
2138 assert!(storage.is_ok());
2139
2140 assert!(db_path.exists());
2142
2143 assert!(
2145 encryption::is_database_encrypted(&db_path).unwrap(),
2146 "Database should be encrypted"
2147 );
2148 }
2149
2150 #[test]
2151 fn test_encrypted_storage_reopen_with_correct_key() {
2152 let temp_dir = tempdir().unwrap();
2153 let db_path = temp_dir.path().join("encrypted_reopen.db");
2154
2155 let config = EncryptionConfig::generate().unwrap();
2157 let key = *config.key();
2158
2159 {
2160 let storage = MdkSqliteStorage::new_with_key(&db_path, config).unwrap();
2161 let _ = storage.backend();
2163 }
2164
2165 let config2 = EncryptionConfig::new(key);
2167 let storage2 = MdkSqliteStorage::new_with_key(&db_path, config2);
2168 assert!(
2169 storage2.is_ok(),
2170 "Should be able to reopen with correct key"
2171 );
2172 }
2173
2174 #[test]
2175 fn test_encrypted_storage_wrong_key_fails() {
2176 let temp_dir = tempdir().unwrap();
2177 let db_path = temp_dir.path().join("encrypted_wrong_key.db");
2178
2179 let config1 = EncryptionConfig::generate().unwrap();
2181 {
2182 let storage = MdkSqliteStorage::new_with_key(&db_path, config1).unwrap();
2183 drop(storage);
2184 }
2185
2186 let config2 = EncryptionConfig::generate().unwrap();
2188 let result = MdkSqliteStorage::new_with_key(&db_path, config2);
2189
2190 assert!(result.is_err(), "Opening with wrong key should fail");
2191
2192 match result {
2194 Err(error::Error::WrongEncryptionKey) => {}
2195 Err(e) => panic!("Expected WrongEncryptionKey error, got: {:?}", e),
2196 Ok(_) => panic!("Expected error but got success"),
2197 }
2198 }
2199
2200 #[test]
2201 fn test_unencrypted_cannot_read_encrypted() {
2202 let temp_dir = tempdir().unwrap();
2203 let db_path = temp_dir.path().join("encrypted_only.db");
2204
2205 let config = EncryptionConfig::generate().unwrap();
2207 {
2208 let storage = MdkSqliteStorage::new_with_key(&db_path, config).unwrap();
2209 drop(storage);
2210 }
2211
2212 let result = MdkSqliteStorage::new_unencrypted(&db_path);
2214
2215 assert!(
2217 result.is_err(),
2218 "Opening encrypted database without key should fail"
2219 );
2220 }
2221
2222 #[test]
2223 fn test_encrypted_storage_data_persistence() {
2224 let temp_dir = tempdir().unwrap();
2225 let db_path = temp_dir.path().join("encrypted_persist.db");
2226
2227 let config = EncryptionConfig::generate().unwrap();
2228 let key = *config.key();
2229
2230 let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);
2232 {
2233 let storage = MdkSqliteStorage::new_with_key(&db_path, config).unwrap();
2234
2235 let group = Group {
2236 mls_group_id: mls_group_id.clone(),
2237 nostr_group_id: [0u8; 32],
2238 name: "Encrypted Group".to_string(),
2239 description: "Testing encrypted persistence".to_string(),
2240 admin_pubkeys: BTreeSet::new(),
2241 last_message_id: None,
2242 last_message_at: None,
2243 last_message_processed_at: None,
2244 epoch: 0,
2245 state: GroupState::Active,
2246 image_hash: None,
2247 image_key: None,
2248 image_nonce: None,
2249 self_update_state: SelfUpdateState::Required,
2250 };
2251
2252 storage.save_group(group).unwrap();
2253 }
2254
2255 let config2 = EncryptionConfig::new(key);
2257 let storage2 = MdkSqliteStorage::new_with_key(&db_path, config2).unwrap();
2258
2259 let found_group = storage2.find_group_by_mls_group_id(&mls_group_id).unwrap();
2260 assert!(found_group.is_some());
2261 assert_eq!(found_group.unwrap().name, "Encrypted Group");
2262 }
2263
2264 #[test]
2265 fn test_file_permissions_are_secure() {
2266 let temp_dir = tempdir().unwrap();
2267 let db_path = temp_dir.path().join("secure_perms.db");
2268
2269 let config = EncryptionConfig::generate().unwrap();
2270 let _storage = MdkSqliteStorage::new_with_key(&db_path, config).unwrap();
2271
2272 #[cfg(unix)]
2274 {
2275 let metadata = std::fs::metadata(&db_path).unwrap();
2276 let mode = metadata.permissions().mode();
2277
2278 assert_eq!(
2280 mode & 0o077,
2281 0,
2282 "Database file should have owner-only permissions, got {:o}",
2283 mode & 0o777
2284 );
2285 }
2286 }
2287
2288 #[test]
2289 fn test_encrypted_storage_multiple_groups() {
2290 let temp_dir = tempdir().unwrap();
2291 let db_path = temp_dir.path().join("multi_groups.db");
2292
2293 let config = EncryptionConfig::generate().unwrap();
2294 let key = *config.key();
2295
2296 {
2298 let storage = MdkSqliteStorage::new_with_key(&db_path, config).unwrap();
2299
2300 for i in 0..5 {
2301 let mls_group_id = GroupId::from_slice(&[i; 8]);
2302 let mut group = create_test_group(mls_group_id);
2303 group.name = format!("Group {}", i);
2304 group.description = format!("Description {}", i);
2305 storage.save_group(group).unwrap();
2306 }
2307 }
2308
2309 let config2 = EncryptionConfig::new(key);
2311 let storage2 = MdkSqliteStorage::new_with_key(&db_path, config2).unwrap();
2312
2313 let groups = storage2.all_groups().unwrap();
2314 assert_eq!(groups.len(), 5);
2315
2316 for i in 0..5u8 {
2317 let mls_group_id = GroupId::from_slice(&[i; 8]);
2318 let group = storage2
2319 .find_group_by_mls_group_id(&mls_group_id)
2320 .unwrap()
2321 .unwrap();
2322 assert_eq!(group.name, format!("Group {}", i));
2323 }
2324 }
2325
2326 #[test]
2327 fn test_encrypted_storage_messages() {
2328 let temp_dir = tempdir().unwrap();
2329 let db_path = temp_dir.path().join("messages.db");
2330
2331 let config = EncryptionConfig::generate().unwrap();
2332 let key = *config.key();
2333
2334 let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);
2335
2336 {
2338 let storage = MdkSqliteStorage::new_with_key(&db_path, config).unwrap();
2339
2340 let group = create_test_group(mls_group_id.clone());
2341 storage.save_group(group).unwrap();
2342
2343 let event_id = EventId::all_zeros();
2345 let mut message = create_test_message(mls_group_id.clone(), event_id);
2346 message.content = "Test message content".to_string();
2347 storage.save_message(message).unwrap();
2348 }
2349
2350 let config2 = EncryptionConfig::new(key);
2352 let storage2 = MdkSqliteStorage::new_with_key(&db_path, config2).unwrap();
2353
2354 let messages = storage2.messages(&mls_group_id, None).unwrap();
2355 assert_eq!(messages.len(), 1);
2356 assert_eq!(messages[0].content, "Test message content");
2357 }
2358
2359 #[test]
2360 fn test_encrypted_storage_welcomes() {
2361 let temp_dir = tempdir().unwrap();
2362 let db_path = temp_dir.path().join("welcomes.db");
2363
2364 let config = EncryptionConfig::generate().unwrap();
2365 let key = *config.key();
2366
2367 let mls_group_id = GroupId::from_slice(&[5, 6, 7, 8]);
2368
2369 {
2371 let storage = MdkSqliteStorage::new_with_key(&db_path, config).unwrap();
2372
2373 let group = create_test_group(mls_group_id.clone());
2374 storage.save_group(group).unwrap();
2375
2376 let event_id = EventId::all_zeros();
2377 let welcome = create_test_welcome(mls_group_id.clone(), event_id);
2378 storage.save_welcome(welcome).unwrap();
2379 }
2380
2381 let config2 = EncryptionConfig::new(key);
2383 let storage2 = MdkSqliteStorage::new_with_key(&db_path, config2).unwrap();
2384
2385 let welcomes = storage2.pending_welcomes(None).unwrap();
2386 assert_eq!(welcomes.len(), 1);
2387 }
2388
/// Exporter secrets for epochs 0..=5 round-trip through an encrypted database
/// across a close/reopen, and an epoch that was never stored reads back as `None`.
#[test]
fn test_encrypted_storage_exporter_secrets() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("exporter_secrets.db");

    let initial_config = EncryptionConfig::generate().unwrap();
    let raw_key = *initial_config.key();

    let group_id = GroupId::from_slice(&[10, 20, 30, 40]);
    let epochs = 0..=5u64;

    {
        let writer = MdkSqliteStorage::new_with_key(&path, initial_config).unwrap();

        writer
            .save_group(Group {
                mls_group_id: group_id.clone(),
                nostr_group_id: [0u8; 32],
                name: "Exporter Secret Test".to_string(),
                description: "Testing exporter secrets".to_string(),
                admin_pubkeys: BTreeSet::new(),
                last_message_id: None,
                last_message_at: None,
                last_message_processed_at: None,
                epoch: 5,
                state: GroupState::Active,
                image_hash: None,
                image_key: None,
                image_nonce: None,
                self_update_state: SelfUpdateState::Required,
            })
            .unwrap();

        // Every byte of each secret is its epoch number, so reads can be checked.
        for epoch in epochs.clone() {
            writer
                .save_group_exporter_secret(GroupExporterSecret {
                    mls_group_id: group_id.clone(),
                    epoch,
                    secret: Secret::new([epoch as u8; 32]),
                })
                .unwrap();
        }
    }

    let reader = MdkSqliteStorage::new_with_key(&path, EncryptionConfig::new(raw_key)).unwrap();

    for epoch in epochs {
        let stored = reader
            .get_group_exporter_secret(&group_id, epoch)
            .unwrap()
            .unwrap();
        assert_eq!(stored.epoch, epoch);
        assert_eq!(stored.secret[0], epoch as u8);
    }

    assert!(reader
        .get_group_exporter_secret(&group_id, 999)
        .unwrap()
        .is_none());
}
2451
/// Opening an encrypted store at a path whose parent directories do not exist
/// yet creates those directories and an encrypted database file.
#[test]
fn test_encrypted_storage_with_nested_directory() {
    let temp_dir = tempdir().unwrap();
    let db_path = temp_dir
        .path()
        .join("deep")
        .join("nested")
        .join("path")
        .join("db.sqlite");

    let config = EncryptionConfig::generate().unwrap();
    // `expect` (rather than `assert!(result.is_ok())`) reports the actual error on
    // failure. The binding keeps the store open while the file is inspected below.
    let _storage = MdkSqliteStorage::new_with_key(&db_path, config)
        .expect("creating an encrypted database in a nested directory should succeed");

    assert!(db_path.parent().unwrap().exists());
    assert!(db_path.exists());
    assert!(encryption::is_database_encrypted(&db_path).unwrap());
}
2473
/// A database created with `new_unencrypted` is detected as plaintext, while one
/// created with `new_with_key` is detected as encrypted.
#[test]
fn test_encrypted_unencrypted_incompatibility() {
    let dir = tempdir().unwrap();
    let plain_path = dir.path().join("compat_test.db");
    let encrypted_path = dir.path().join("compat_encrypted.db");

    // Create the plaintext database and close it immediately.
    drop(MdkSqliteStorage::new_unencrypted(&plain_path).unwrap());
    assert!(!encryption::is_database_encrypted(&plain_path).unwrap());

    // Same for the encrypted database.
    let config = EncryptionConfig::generate().unwrap();
    drop(MdkSqliteStorage::new_with_key(&encrypted_path, config).unwrap());
    assert!(encryption::is_database_encrypted(&encrypted_path).unwrap());
}
2497
/// `MdkSqliteStorage::new` on an existing *unencrypted* database must fail with
/// `UnencryptedDatabaseWithEncryption`, not with a missing-keyring-entry error.
#[test]
fn test_new_on_unencrypted_database_returns_correct_error() {
    ensure_mock_store();

    let dir = tempdir().unwrap();
    let db_path = dir.path().join("unencrypted_then_new.db");

    // Create a plaintext database and close it.
    drop(MdkSqliteStorage::new_unencrypted(&db_path).unwrap());
    assert!(!encryption::is_database_encrypted(&db_path).unwrap());

    let result = MdkSqliteStorage::new(&db_path, "com.test.app", "test.key.id");
    assert!(result.is_err());

    let Err(err) = result else {
        panic!("Expected an error but got Ok");
    };
    match err {
        Error::UnencryptedDatabaseWithEncryption => {}
        Error::KeyringEntryMissingForExistingDatabase { .. } => panic!(
            "Got KeyringEntryMissingForExistingDatabase but should have gotten \
             UnencryptedDatabaseWithEncryption. The database is unencrypted, not \
             encrypted with a missing key."
        ),
        other => panic!("Unexpected error: {:?}", other),
    }
}
2543
/// `MdkSqliteStorage::new_with_key` on an existing *unencrypted* database must
/// fail with `UnencryptedDatabaseWithEncryption`, not `WrongEncryptionKey`.
#[test]
fn test_new_with_key_on_unencrypted_database_returns_correct_error() {
    let dir = tempdir().unwrap();
    let db_path = dir.path().join("unencrypted_then_new_with_key.db");

    // Create a plaintext database and close it.
    drop(MdkSqliteStorage::new_unencrypted(&db_path).unwrap());
    assert!(!encryption::is_database_encrypted(&db_path).unwrap());

    let result = MdkSqliteStorage::new_with_key(&db_path, EncryptionConfig::generate().unwrap());
    assert!(result.is_err());

    let Err(err) = result else {
        panic!("Expected an error but got Ok");
    };
    match err {
        Error::UnencryptedDatabaseWithEncryption => {}
        Error::WrongEncryptionKey => panic!(
            "Got WrongEncryptionKey but should have gotten \
             UnencryptedDatabaseWithEncryption. The database is unencrypted, not \
             encrypted with a different key."
        ),
        other => panic!("Unexpected error: {:?}", other),
    }
}
2587
/// A 10 000-character message body survives a round trip through an encrypted
/// database across a close/reopen.
#[test]
fn test_encrypted_storage_large_data() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("large_data.db");

    let initial_config = EncryptionConfig::generate().unwrap();
    let raw_key = *initial_config.key();

    let group_id = GroupId::from_slice(&[99; 8]);
    let payload = "x".repeat(10_000);

    {
        let writer = MdkSqliteStorage::new_with_key(&path, initial_config).unwrap();

        let mut group = create_test_group(group_id.clone());
        group.name = "Large Data Test".to_string();
        group.description = "Testing large data".to_string();
        writer.save_group(group).unwrap();

        let mut message = create_test_message(group_id.clone(), EventId::all_zeros());
        message.content = payload.clone();
        writer.save_message(message).unwrap();
    }

    let reader = MdkSqliteStorage::new_with_key(&path, EncryptionConfig::new(raw_key)).unwrap();
    let stored = reader.messages(&group_id, None).unwrap();
    assert_eq!(stored.len(), 1);
    assert_eq!(stored[0].content, payload);
}
2622
/// Two handles opened at the same time on one encrypted file, each with its own
/// config built from the same raw key, see the same group data.
#[test]
fn test_encrypted_storage_concurrent_reads() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("concurrent.db");

    let initial_config = EncryptionConfig::generate().unwrap();
    let raw_key = *initial_config.key();
    let group_id = GroupId::from_slice(&[77; 8]);

    {
        let writer = MdkSqliteStorage::new_with_key(&path, initial_config).unwrap();

        let mut group = create_test_group(group_id.clone());
        group.name = "Concurrent Test".to_string();
        group.description = "Testing concurrent access".to_string();
        writer.save_group(group).unwrap();
    }

    // Both handles stay open while they are queried.
    let readers = [
        MdkSqliteStorage::new_with_key(&path, EncryptionConfig::new(raw_key)).unwrap(),
        MdkSqliteStorage::new_with_key(&path, EncryptionConfig::new(raw_key)).unwrap(),
    ];

    let names: Vec<String> = readers
        .iter()
        .map(|reader| {
            reader
                .find_group_by_mls_group_id(&group_id)
                .unwrap()
                .unwrap()
                .name
        })
        .collect();

    assert_eq!(names[0], names[1]);
}
2662
/// After writing to and reopening an encrypted database, the main file and any
/// SQLite sidecar files (`-wal`, `-shm`, `-journal`) that exist must have no
/// group or other permission bits set.
#[cfg(unix)]
#[test]
fn test_encrypted_storage_sidecar_file_permissions() {
    let temp_dir = tempdir().unwrap();
    let db_path = temp_dir.path().join("sidecar_test.db");

    let config = EncryptionConfig::generate().unwrap();
    let key = *config.key();

    {
        // Write a handful of groups so the database has seen some write activity.
        let storage = MdkSqliteStorage::new_with_key(&db_path, config).unwrap();

        for i in 0..10 {
            let mut group = create_test_group(GroupId::from_slice(&[i; 8]));
            group.name = format!("Group {}", i);
            group.description = format!("Description {}", i);
            storage.save_group(group).unwrap();
        }
    }

    // Keep a handle open while the files are inspected.
    let config2 = EncryptionConfig::new(key);
    let _storage2 = MdkSqliteStorage::new_with_key(&db_path, config2).unwrap();

    let db_mode = std::fs::metadata(&db_path).unwrap().permissions().mode();
    assert_eq!(
        db_mode & 0o077,
        0,
        "Database file should have owner-only permissions, got {:o}",
        db_mode & 0o777
    );

    // Derive sidecar names from `db_path` rather than repeating the file name.
    // A stale hard-coded name would make every `exists()` check false, and the
    // loop would then silently verify nothing.
    let db_file_name = db_path.file_name().unwrap().to_string_lossy();
    for suffix in ["-wal", "-shm", "-journal"] {
        let sidecar_path = db_path.with_file_name(format!("{}{}", db_file_name, suffix));
        if sidecar_path.exists() {
            let mode = std::fs::metadata(&sidecar_path).unwrap().permissions().mode();
            assert_eq!(
                mode & 0o077,
                0,
                "Sidecar file {} should have owner-only permissions, got {:o}",
                suffix,
                mode & 0o777
            );
        }
    }
}
2717
/// `EncryptionConfig::key` hands back exactly the 32 bytes it was built from.
#[test]
fn test_encryption_config_key_is_accessible() {
    let config = EncryptionConfig::new([0xDE; 32]);
    let stored = config.key();

    assert_eq!(stored.len(), 32);
    assert_eq!(stored[0], 0xDE);
    assert_eq!(stored[31], 0xDE);
}
2728
/// Empty group name and description strings survive an encrypted round trip
/// as empty strings.
#[test]
fn test_encrypted_storage_empty_group_name() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("empty_name.db");

    let initial_config = EncryptionConfig::generate().unwrap();
    let raw_key = *initial_config.key();
    let group_id = GroupId::from_slice(&[0xAB; 8]);

    // First session: store a group whose text fields are blank, then close.
    let writer = MdkSqliteStorage::new_with_key(&path, initial_config).unwrap();
    let mut blank = create_test_group(group_id.clone());
    blank.name = String::new();
    blank.description = String::new();
    writer.save_group(blank).unwrap();
    drop(writer);

    let reader = MdkSqliteStorage::new_with_key(&path, EncryptionConfig::new(raw_key)).unwrap();
    let stored = reader
        .find_group_by_mls_group_id(&group_id)
        .unwrap()
        .unwrap();
    assert!(stored.name.is_empty());
    assert!(stored.description.is_empty());
}
2760
/// Non-ASCII text in group metadata and in message content survives a round
/// trip through an encrypted database unchanged.
#[test]
fn test_encrypted_storage_unicode_content() {
    let temp_dir = tempdir().unwrap();
    let db_path = temp_dir.path().join("unicode.db");

    let config = EncryptionConfig::generate().unwrap();
    let key = *config.key();

    let mls_group_id = GroupId::from_slice(&[0xCD; 8]);
    let unicode_content = "Hello 世界! 🎉 Ñoño مرحبا Привет 日本語 한국어 ελληνικά";

    {
        let storage = MdkSqliteStorage::new_with_key(&db_path, config).unwrap();

        let mut group = create_test_group(mls_group_id.clone());
        group.name = "Тест группа 测试组".to_string();
        group.description = "描述 описание".to_string();
        storage.save_group(group).unwrap();

        let event_id = EventId::all_zeros();
        let mut message = create_test_message(mls_group_id.clone(), event_id);
        message.content = unicode_content.to_string();
        storage.save_message(message).unwrap();
    }

    let config2 = EncryptionConfig::new(key);
    let storage2 = MdkSqliteStorage::new_with_key(&db_path, config2).unwrap();

    let group = storage2
        .find_group_by_mls_group_id(&mls_group_id)
        .unwrap()
        .unwrap();
    assert_eq!(group.name, "Тест группа 测试组");
    assert_eq!(group.description, "描述 описание");

    let messages = storage2.messages(&mls_group_id, None).unwrap();
    // Check the count first. A missing row then fails with a clear assertion
    // instead of an index-out-of-bounds panic, and unexpected extra rows are caught.
    assert_eq!(messages.len(), 1);
    assert_eq!(messages[0].content, unicode_content);
}
2801
/// Reopening an existing encrypted database whose keyring entry has been
/// removed fails with `KeyringEntryMissingForExistingDatabase`, and does not
/// store a replacement key in the keyring.
#[test]
fn test_existing_db_with_missing_keyring_entry_fails() {
    ensure_mock_store();

    let dir = tempdir().unwrap();
    let db_path = dir.path().join("missing_key_test.db");

    let service_id = "test.mdk.storage.missingkey";
    let db_key_id = "test.key.missingkeytest";

    // Start from a clean keyring; the outcome of this delete is deliberately ignored.
    let _ = keyring::delete_db_key(service_id, db_key_id);

    let created = MdkSqliteStorage::new(&db_path, service_id, db_key_id);
    assert!(created.is_ok(), "Should create database successfully");
    drop(created);

    assert!(db_path.exists(), "Database file should exist");

    // Remove the key out from under the existing database.
    keyring::delete_db_key(service_id, db_key_id).unwrap();
    assert!(
        keyring::get_db_key(service_id, db_key_id)
            .unwrap()
            .is_none(),
        "Key should be deleted"
    );

    let result = MdkSqliteStorage::new(&db_path, service_id, db_key_id);
    assert!(result.is_err(), "Should fail when keyring entry is missing");

    let Err(err) = result else {
        panic!("Expected error but got success");
    };
    match err {
        error::Error::KeyringEntryMissingForExistingDatabase {
            db_path: reported_path,
            service_id: reported_service,
            db_key_id: reported_key,
        } => {
            assert!(
                reported_path.contains("missing_key_test.db"),
                "Error should contain database path"
            );
            assert_eq!(reported_service, service_id);
            assert_eq!(reported_key, db_key_id);
        }
        e => panic!(
            "Expected KeyringEntryMissingForExistingDatabase error, got: {:?}",
            e
        ),
    }

    assert!(
        keyring::get_db_key(service_id, db_key_id)
            .unwrap()
            .is_none(),
        "No new key should have been stored in keyring"
    );
}
2868
/// Calling `new` for a database file that does not exist yet creates an
/// encrypted database and stores its key in the keyring.
#[test]
fn test_new_db_with_keyring_creates_key() {
    ensure_mock_store();

    let dir = tempdir().unwrap();
    let db_path = dir.path().join("new_db_keyring.db");
    let (service_id, db_key_id) = ("test.mdk.storage.newdb", "test.key.newdbtest");

    // Start from a clean keyring; the outcome of this delete is deliberately ignored.
    let _ = keyring::delete_db_key(service_id, db_key_id);
    assert!(!db_path.exists(), "Database should not exist yet");

    let storage = MdkSqliteStorage::new(&db_path, service_id, db_key_id);
    assert!(storage.is_ok(), "Should create database successfully");
    assert!(db_path.exists(), "Database file should exist");

    let stored_key = keyring::get_db_key(service_id, db_key_id).unwrap();
    assert!(stored_key.is_some(), "Key should be stored in keyring");
    assert!(
        encryption::is_database_encrypted(&db_path).unwrap(),
        "Database should be encrypted"
    );

    // Close the database before removing the keyring entry.
    drop(storage);
    keyring::delete_db_key(service_id, db_key_id).unwrap();
}
2907
/// A database created through `new` can be reopened through `new` with the same
/// keyring identifiers, and its data is readable afterwards.
#[test]
fn test_reopen_db_with_keyring_succeeds() {
    ensure_mock_store();

    let dir = tempdir().unwrap();
    let db_path = dir.path().join("reopen_keyring.db");
    let service_id = "test.mdk.storage.reopen";
    let db_key_id = "test.key.reopentest";

    // Start from a clean keyring; the outcome of this delete is deliberately ignored.
    let _ = keyring::delete_db_key(service_id, db_key_id);

    let group_id = GroupId::from_slice(&[0xAA; 8]);

    // First session: create the database and write one group.
    let first = MdkSqliteStorage::new(&db_path, service_id, db_key_id).unwrap();
    let mut group = create_test_group(group_id.clone());
    group.name = "Keyring Reopen Test".to_string();
    first.save_group(group).unwrap();
    drop(first);

    // Second session: reopen with the same keyring identifiers.
    let reopened = MdkSqliteStorage::new(&db_path, service_id, db_key_id);
    assert!(reopened.is_ok(), "Should reopen database successfully");
    let reopened = reopened.unwrap();

    let stored = reopened
        .find_group_by_mls_group_id(&group_id)
        .unwrap()
        .unwrap();
    assert_eq!(stored.name, "Keyring Reopen Test");

    drop(reopened);
    keyring::delete_db_key(service_id, db_key_id).unwrap();
}
2949
/// Five threads, each opening its own handle from the same raw key, can all
/// read the single group stored in a shared encrypted database.
#[test]
fn test_concurrent_encrypted_access_same_key() {
    const THREAD_COUNT: usize = 5;

    let dir = tempdir().unwrap();
    let path = dir.path().join("concurrent_encrypted.db");

    let initial_config = EncryptionConfig::generate().unwrap();
    let raw_key = *initial_config.key();

    // Seed the database with a single group, then close the writer.
    let writer = MdkSqliteStorage::new_with_key(&path, initial_config).unwrap();
    writer
        .save_group(create_test_group(GroupId::from_slice(&[1, 2, 3, 4])))
        .unwrap();
    drop(writer);

    // Spawn every worker first, then join them all.
    let mut workers = Vec::with_capacity(THREAD_COUNT);
    for _ in 0..THREAD_COUNT {
        let thread_path = path.clone();
        workers.push(thread::spawn(move || {
            let storage =
                MdkSqliteStorage::new_with_key(&thread_path, EncryptionConfig::new(raw_key))
                    .unwrap();
            let groups = storage.all_groups().unwrap();
            assert_eq!(groups.len(), 1);
            groups
        }));
    }

    for worker in workers {
        assert_eq!(worker.join().unwrap().len(), 1);
    }
}
2987
/// Three databases, each encrypted with its own key, are independent. Each one
/// reopens with its own key, and opening one with another database's key fails
/// with `Error::WrongEncryptionKey`.
#[test]
fn test_multiple_encrypted_databases_different_keys() {
    let temp_dir = tempdir().unwrap();

    let db1_path = temp_dir.path().join("db1.db");
    let db2_path = temp_dir.path().join("db2.db");
    let db3_path = temp_dir.path().join("db3.db");

    let config1 = EncryptionConfig::generate().unwrap();
    let config2 = EncryptionConfig::generate().unwrap();
    let config3 = EncryptionConfig::generate().unwrap();

    let key1 = *config1.key();
    let key2 = *config2.key();
    let key3 = *config3.key();

    {
        let storage1 = MdkSqliteStorage::new_with_key(&db1_path, config1).unwrap();
        let mut group1 = create_test_group(GroupId::from_slice(&[1]));
        group1.name = "Database 1".to_string();
        storage1.save_group(group1).unwrap();

        let storage2 = MdkSqliteStorage::new_with_key(&db2_path, config2).unwrap();
        let mut group2 = create_test_group(GroupId::from_slice(&[2]));
        group2.name = "Database 2".to_string();
        storage2.save_group(group2).unwrap();

        let storage3 = MdkSqliteStorage::new_with_key(&db3_path, config3).unwrap();
        let mut group3 = create_test_group(GroupId::from_slice(&[3]));
        group3.name = "Database 3".to_string();
        storage3.save_group(group3).unwrap();
    }

    let config1_reopen = EncryptionConfig::new(key1);
    let config2_reopen = EncryptionConfig::new(key2);
    let config3_reopen = EncryptionConfig::new(key3);

    let storage1 = MdkSqliteStorage::new_with_key(&db1_path, config1_reopen).unwrap();
    let storage2 = MdkSqliteStorage::new_with_key(&db2_path, config2_reopen).unwrap();
    let storage3 = MdkSqliteStorage::new_with_key(&db3_path, config3_reopen).unwrap();

    let group1 = storage1
        .find_group_by_mls_group_id(&GroupId::from_slice(&[1]))
        .unwrap()
        .unwrap();
    assert_eq!(group1.name, "Database 1");

    let group2 = storage2
        .find_group_by_mls_group_id(&GroupId::from_slice(&[2]))
        .unwrap()
        .unwrap();
    assert_eq!(group2.name, "Database 2");

    let group3 = storage3
        .find_group_by_mls_group_id(&GroupId::from_slice(&[3]))
        .unwrap()
        .unwrap();
    assert_eq!(group3.name, "Database 3");

    // Opening db2 with db1's key must fail specifically as a wrong-key error.
    // A bare `is_err()` would also accept an unrelated failure (e.g. I/O).
    let wrong_config = EncryptionConfig::new(key1);
    match MdkSqliteStorage::new_with_key(&db2_path, wrong_config) {
        Err(Error::WrongEncryptionKey) => {}
        Err(other) => panic!("Expected WrongEncryptionKey, got: {:?}", other),
        Ok(_) => panic!("Opening db2 with db1's key should fail"),
    }
}
3057 }
3058
3059 mod migration_tests {
3064 use super::*;
3065
3066 #[test]
3067 fn test_fresh_database_has_all_tables() {
3068 let storage = MdkSqliteStorage::new_in_memory().unwrap();
3069
3070 let expected_mdk_tables = [
3072 "groups",
3073 "group_relays",
3074 "group_exporter_secrets",
3075 "messages",
3076 "processed_messages",
3077 "welcomes",
3078 "processed_welcomes",
3079 ];
3080
3081 let expected_openmls_tables = [
3083 "openmls_group_data",
3084 "openmls_proposals",
3085 "openmls_own_leaf_nodes",
3086 "openmls_key_packages",
3087 "openmls_psks",
3088 "openmls_signature_keys",
3089 "openmls_encryption_keys",
3090 "openmls_epoch_key_pairs",
3091 ];
3092
3093 storage.with_connection(|conn| {
3094 let mut stmt = conn
3096 .prepare(
3097 "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'",
3098 )
3099 .unwrap();
3100 let table_names: Vec<String> = stmt
3101 .query_map([], |row| row.get(0))
3102 .unwrap()
3103 .map(|r| r.unwrap())
3104 .collect();
3105
3106 for table in &expected_mdk_tables {
3108 assert!(
3109 table_names.contains(&table.to_string()),
3110 "Missing MDK table: {}",
3111 table
3112 );
3113 }
3114
3115 for table in &expected_openmls_tables {
3117 assert!(
3118 table_names.contains(&table.to_string()),
3119 "Missing OpenMLS table: {}",
3120 table
3121 );
3122 }
3123 });
3124 }
3125
3126 #[test]
3127 fn test_all_indexes_exist() {
3128 let storage = MdkSqliteStorage::new_in_memory().unwrap();
3129
3130 let expected_indexes = [
3132 "idx_groups_nostr_group_id",
3133 "idx_group_relays_mls_group_id",
3134 "idx_group_exporter_secrets_mls_group_id",
3135 "idx_messages_mls_group_id",
3136 "idx_messages_wrapper_event_id",
3137 "idx_messages_created_at",
3138 "idx_messages_pubkey",
3139 "idx_messages_kind",
3140 "idx_messages_state",
3141 "idx_processed_messages_message_event_id",
3142 "idx_processed_messages_state",
3143 "idx_processed_messages_processed_at",
3144 "idx_welcomes_mls_group_id",
3145 "idx_welcomes_wrapper_event_id",
3146 "idx_welcomes_state",
3147 "idx_welcomes_nostr_group_id",
3148 "idx_processed_welcomes_welcome_event_id",
3149 "idx_processed_welcomes_state",
3150 "idx_processed_welcomes_processed_at",
3151 ];
3152
3153 storage.with_connection(|conn| {
3154 let mut stmt = conn
3155 .prepare("SELECT name FROM sqlite_master WHERE type='index' AND name NOT LIKE 'sqlite_%'")
3156 .unwrap();
3157 let index_names: Vec<String> = stmt
3158 .query_map([], |row| row.get(0))
3159 .unwrap()
3160 .map(|r| r.unwrap())
3161 .collect();
3162
3163 for idx in &expected_indexes {
3164 assert!(
3165 index_names.contains(&idx.to_string()),
3166 "Missing index: {}. Found indexes: {:?}",
3167 idx,
3168 index_names
3169 );
3170 }
3171 });
3172 }
3173
3174 #[test]
3175 fn test_foreign_key_constraints_work() {
3176 let storage = MdkSqliteStorage::new_in_memory().unwrap();
3177
3178 storage.with_connection(|conn| {
3179 let fk_enabled: i32 = conn
3181 .query_row("PRAGMA foreign_keys", [], |row| row.get(0))
3182 .unwrap();
3183 assert_eq!(fk_enabled, 1, "Foreign keys should be enabled");
3184
3185 let result = conn.execute(
3187 "INSERT INTO group_relays (mls_group_id, relay_url) VALUES (?, ?)",
3188 rusqlite::params![vec![1u8, 2u8, 3u8, 4u8], "wss://relay.example.com"],
3189 );
3190 assert!(result.is_err(), "Should fail due to foreign key constraint");
3191 });
3192 }
3193
3194 #[test]
3195 fn test_openmls_group_data_check_constraint() {
3196 let storage = MdkSqliteStorage::new_in_memory().unwrap();
3197
3198 storage.with_connection(|conn| {
3199 let valid_result = conn.execute(
3201 "INSERT INTO openmls_group_data (provider_version, group_id, data_type, group_data) VALUES (?, ?, ?, ?)",
3202 rusqlite::params![1, vec![1u8, 2u8, 3u8], "tree", vec![4u8, 5u8, 6u8]],
3203 );
3204 assert!(valid_result.is_ok(), "Valid data_type should succeed");
3205
3206 let invalid_result = conn.execute(
3208 "INSERT INTO openmls_group_data (provider_version, group_id, data_type, group_data) VALUES (?, ?, ?, ?)",
3209 rusqlite::params![1, vec![7u8, 8u8, 9u8], "invalid_type", vec![10u8, 11u8]],
3210 );
3211 assert!(
3212 invalid_result.is_err(),
3213 "Invalid data_type should fail CHECK constraint"
3214 );
3215 });
3216 }
3217
3218 #[test]
3219 fn test_schema_matches_plan_specification() {
3220 let storage = MdkSqliteStorage::new_in_memory().unwrap();
3221
3222 storage.with_connection(|conn| {
3223 let groups_info: Vec<(String, String)> = conn
3225 .prepare("PRAGMA table_info(groups)")
3226 .unwrap()
3227 .query_map([], |row| Ok((row.get(1)?, row.get(2)?)))
3228 .unwrap()
3229 .map(|r| r.unwrap())
3230 .collect();
3231
3232 let groups_columns: Vec<&str> =
3233 groups_info.iter().map(|(n, _)| n.as_str()).collect();
3234 assert!(groups_columns.contains(&"mls_group_id"));
3235 assert!(groups_columns.contains(&"nostr_group_id"));
3236 assert!(groups_columns.contains(&"name"));
3237 assert!(groups_columns.contains(&"description"));
3238 assert!(groups_columns.contains(&"admin_pubkeys"));
3239 assert!(groups_columns.contains(&"epoch"));
3240 assert!(groups_columns.contains(&"state"));
3241
3242 let messages_info: Vec<String> = conn
3244 .prepare("PRAGMA table_info(messages)")
3245 .unwrap()
3246 .query_map([], |row| row.get(1))
3247 .unwrap()
3248 .map(|r| r.unwrap())
3249 .collect();
3250
3251 assert!(messages_info.contains(&"mls_group_id".to_string()));
3252 assert!(messages_info.contains(&"id".to_string()));
3253 assert!(messages_info.contains(&"pubkey".to_string()));
3254 assert!(messages_info.contains(&"kind".to_string()));
3255 assert!(messages_info.contains(&"created_at".to_string()));
3256 assert!(messages_info.contains(&"content".to_string()));
3257 assert!(messages_info.contains(&"wrapper_event_id".to_string()));
3258 });
3259 }
3260 }
3261
3262 mod snapshot_tests {
3267 use std::collections::BTreeSet;
3268
3269 use mdk_storage_traits::groups::GroupStorage;
3270 use mdk_storage_traits::groups::types::{
3271 Group, GroupExporterSecret, GroupState, SelfUpdateState,
3272 };
3273 use mdk_storage_traits::{GroupId, MdkStorageProvider, Secret};
3274
3275 use super::*;
3276
3277 fn create_test_group(id: u8) -> Group {
3278 Group {
3279 mls_group_id: GroupId::from_slice(&[id; 32]),
3280 nostr_group_id: [id; 32],
3281 name: format!("Test Group {}", id),
3282 description: format!("Description {}", id),
3283 admin_pubkeys: BTreeSet::new(),
3284 last_message_id: None,
3285 last_message_at: None,
3286 last_message_processed_at: None,
3287 epoch: 0,
3288 state: GroupState::Active,
3289 image_hash: None,
3290 image_key: None,
3291 image_nonce: None,
3292 self_update_state: SelfUpdateState::Required,
3293 }
3294 }
3295
/// Rolling a group back to a snapshot restores the name and epoch it had when
/// the snapshot was taken.
#[test]
fn test_snapshot_and_rollback_group_state() {
    let storage = MdkSqliteStorage::new_in_memory().unwrap();
    let load = |id: &GroupId| storage.find_group_by_mls_group_id(id).unwrap().unwrap();

    let seed = create_test_group(1);
    let group_id = seed.mls_group_id.clone();
    storage.save_group(seed).unwrap();

    let before = load(&group_id);
    assert_eq!(before.name, "Test Group 1");
    assert_eq!(before.epoch, 0);

    storage
        .create_group_snapshot(&group_id, "snap_epoch_0")
        .unwrap();

    // Move the group past the snapshot.
    let mut advanced = before.clone();
    advanced.name = "Modified Group".to_string();
    advanced.epoch = 1;
    storage.save_group(advanced).unwrap();

    let after_mod = load(&group_id);
    assert_eq!(after_mod.name, "Modified Group");
    assert_eq!(after_mod.epoch, 1);

    storage
        .rollback_group_to_snapshot(&group_id, "snap_epoch_0")
        .unwrap();

    let restored = load(&group_id);
    assert_eq!(restored.name, "Test Group 1");
    assert_eq!(restored.epoch, 0);
}
3345
/// Releasing a snapshot leaves the group's current data (written after the
/// snapshot) in place.
#[test]
fn test_snapshot_release_without_rollback() {
    let storage = MdkSqliteStorage::new_in_memory().unwrap();
    let load = |id: &GroupId| storage.find_group_by_mls_group_id(id).unwrap().unwrap();

    let seed = create_test_group(2);
    let group_id = seed.mls_group_id.clone();
    storage.save_group(seed).unwrap();

    storage
        .create_group_snapshot(&group_id, "snap_to_release")
        .unwrap();

    let mut renamed = load(&group_id);
    renamed.name = "Modified Name".to_string();
    storage.save_group(renamed).unwrap();

    storage
        .release_group_snapshot(&group_id, "snap_to_release")
        .unwrap();

    assert_eq!(load(&group_id).name, "Modified Name");
}
3380
/// Rolling back to a snapshot removes exporter secrets saved after it, while
/// keeping secrets that existed when it was taken.
#[test]
fn test_snapshot_with_exporter_secrets_rollback() {
    let storage = MdkSqliteStorage::new_in_memory().unwrap();

    let seed = create_test_group(3);
    let group_id = seed.mls_group_id.clone();
    storage.save_group(seed).unwrap();

    // Builds a secret for `epoch` whose 32 bytes are all `fill`.
    let secret_for = |epoch: u64, fill: u8| GroupExporterSecret {
        mls_group_id: group_id.clone(),
        epoch,
        secret: Secret::new([fill; 32]),
    };

    storage.save_group_exporter_secret(secret_for(0, 0)).unwrap();

    storage
        .create_group_snapshot(&group_id, "snap_secrets")
        .unwrap();

    storage.save_group_exporter_secret(secret_for(1, 1)).unwrap();
    assert!(storage
        .get_group_exporter_secret(&group_id, 1)
        .unwrap()
        .is_some());

    storage
        .rollback_group_to_snapshot(&group_id, "snap_secrets")
        .unwrap();

    assert!(storage
        .get_group_exporter_secret(&group_id, 1)
        .unwrap()
        .is_none());
    assert!(storage
        .get_group_exporter_secret(&group_id, 0)
        .unwrap()
        .is_some());
}
3428
/// Rolling back one group's snapshot does not undo changes made to a different
/// group.
#[test]
fn test_snapshot_isolation_between_groups() {
    let storage = MdkSqliteStorage::new_in_memory().unwrap();
    let load = |id: &GroupId| storage.find_group_by_mls_group_id(id).unwrap().unwrap();

    let first = create_test_group(10);
    let second = create_test_group(20);
    let first_id = first.mls_group_id.clone();
    let second_id = second.mls_group_id.clone();

    storage.save_group(first).unwrap();
    storage.save_group(second).unwrap();

    // Only the first group gets a snapshot.
    storage
        .create_group_snapshot(&first_id, "snap_group1")
        .unwrap();

    let mut first_edit = load(&first_id);
    let mut second_edit = load(&second_id);
    first_edit.name = "Modified Group 1".to_string();
    second_edit.name = "Modified Group 2".to_string();
    storage.save_group(first_edit).unwrap();
    storage.save_group(second_edit).unwrap();

    storage
        .rollback_group_to_snapshot(&first_id, "snap_group1")
        .unwrap();

    assert_eq!(load(&first_id).name, "Test Group 10");
    assert_eq!(load(&second_id).name, "Modified Group 2");
}
3480
/// Rolling back to a snapshot name that was never created returns an error and
/// leaves the group's stored data unchanged.
#[test]
fn test_rollback_nonexistent_snapshot_returns_error() {
    let storage = MdkSqliteStorage::new_in_memory().unwrap();

    let group = create_test_group(5);
    let group_id = group.mls_group_id.clone();
    let expected_epoch = group.epoch;
    storage.save_group(group).unwrap();

    let rollback = storage.rollback_group_to_snapshot(&group_id, "nonexistent_snap");
    assert!(
        rollback.is_err(),
        "Rollback to nonexistent snapshot should return an error"
    );

    let Some(current) = storage.find_group_by_mls_group_id(&group_id).unwrap() else {
        panic!("Group should NOT be deleted when rolling back to nonexistent snapshot");
    };
    assert_eq!(
        current.epoch, expected_epoch,
        "Group data should be unchanged"
    );
}
3511
/// Releasing a snapshot name that was never created is not an error.
#[test]
fn test_release_nonexistent_snapshot_succeeds() {
    let storage = MdkSqliteStorage::new_in_memory().unwrap();

    let group = create_test_group(6);
    let group_id = group.mls_group_id.clone();
    storage.save_group(group).unwrap();

    assert!(storage
        .release_group_snapshot(&group_id, "nonexistent_snap")
        .is_ok());
}
3525
/// A group can hold several snapshots at once. Rolling back to the newer one and
/// then to the older one restores each recorded state in turn.
#[test]
fn test_multiple_snapshots_same_group() {
    let storage = MdkSqliteStorage::new_in_memory().unwrap();
    let load = |id: &GroupId| storage.find_group_by_mls_group_id(id).unwrap().unwrap();

    let seed = create_test_group(7);
    let group_id = seed.mls_group_id.clone();
    storage.save_group(seed).unwrap();

    // For each step: take a snapshot, then advance the group to the next epoch.
    for (snapshot, epoch, name) in [("snap_epoch_0", 1, "Epoch 1"), ("snap_epoch_1", 2, "Epoch 2")] {
        storage.create_group_snapshot(&group_id, snapshot).unwrap();

        let mut next = load(&group_id);
        next.epoch = epoch;
        next.name = name.to_string();
        storage.save_group(next).unwrap();
    }

    storage
        .rollback_group_to_snapshot(&group_id, "snap_epoch_1")
        .unwrap();

    let at_epoch_1 = load(&group_id);
    assert_eq!(at_epoch_1.epoch, 1);
    assert_eq!(at_epoch_1.name, "Epoch 1");

    storage
        .rollback_group_to_snapshot(&group_id, "snap_epoch_0")
        .unwrap();

    let at_epoch_0 = load(&group_id);
    assert_eq!(at_epoch_0.epoch, 0);
    assert_eq!(at_epoch_0.name, "Test Group 7");
}
3586
/// Listing snapshots for a group that has never been snapshotted returns an
/// empty list rather than an error.
#[test]
fn test_list_group_snapshots_empty() {
    // The trait that provides `list_group_snapshots` is already imported at module
    // scope, so the previous function-local `use` was redundant and is dropped.
    let temp_dir = tempdir().unwrap();
    let db_path = temp_dir.path().join("list_snapshots_empty.db");
    let storage = MdkSqliteStorage::new_unencrypted(&db_path).unwrap();

    let group_id = GroupId::from_slice(&[1, 2, 3, 4]);

    let snapshots = storage.list_group_snapshots(&group_id).unwrap();
    assert!(
        snapshots.is_empty(),
        "Should return empty list for no snapshots"
    );
}
3603
/// `list_group_snapshots` returns a group's snapshots in creation order, with
/// non-decreasing `created_at` values.
///
/// NOTE(review): the 10 ms pauses only give distinct timestamps if
/// `created_at` has sub-second resolution. With one-second resolution, the
/// name ordering asserted here depends on how the query breaks ties, so
/// confirm this against the implementation.
#[test]
fn test_list_group_snapshots_returns_snapshots_sorted_by_created_at() {
    use mdk_storage_traits::MdkStorageProvider;

    let scratch = tempdir().unwrap();
    let path = scratch.path().join("list_snapshots_sorted.db");
    let storage = MdkSqliteStorage::new_unencrypted(&path).unwrap();

    let gid = GroupId::from_slice(&[8; 32]);
    storage
        .save_group(Group {
            mls_group_id: gid.clone(),
            nostr_group_id: [9; 32],
            name: "Test Group".to_string(),
            description: "".to_string(),
            admin_pubkeys: BTreeSet::new(),
            last_message_id: None,
            last_message_at: None,
            last_message_processed_at: None,
            epoch: 1,
            state: GroupState::Active,
            image_hash: None,
            image_key: None,
            image_nonce: None,
            self_update_state: SelfUpdateState::Required,
        })
        .unwrap();

    let expected_order = ["snap_first", "snap_second", "snap_third"];
    for (idx, name) in expected_order.iter().enumerate() {
        if idx > 0 {
            // Pause between snapshots so their creation times can differ.
            std::thread::sleep(std::time::Duration::from_millis(10));
        }
        storage.create_group_snapshot(&gid, *name).unwrap();
    }

    let listed = storage.list_group_snapshots(&gid).unwrap();

    assert_eq!(listed.len(), expected_order.len());
    for (entry, want) in listed.iter().zip(expected_order.iter()) {
        assert_eq!(entry.0, *want);
    }
    for pair in listed.windows(2) {
        assert!(pair[0].1 <= pair[1].1);
    }
}
3658
/// Snapshots are scoped per group: listing one group's snapshots never
/// returns a snapshot taken of another group.
#[test]
fn test_list_group_snapshots_only_returns_matching_group() {
    use mdk_storage_traits::MdkStorageProvider;

    let scratch = tempdir().unwrap();
    let path = scratch.path().join("list_snapshots_filtered.db");
    let storage = MdkSqliteStorage::new_unencrypted(&path).unwrap();

    let group1 = GroupId::from_slice(&[1; 32]);
    let group2 = GroupId::from_slice(&[2; 32]);

    let first = Group {
        mls_group_id: group1.clone(),
        nostr_group_id: [11; 32],
        name: "Group 1".to_string(),
        description: "".to_string(),
        admin_pubkeys: BTreeSet::new(),
        last_message_id: None,
        last_message_at: None,
        last_message_processed_at: None,
        epoch: 1,
        state: GroupState::Active,
        image_hash: None,
        image_key: None,
        image_nonce: None,
        self_update_state: SelfUpdateState::Required,
    };
    // The second group differs from the first only in its ids and name.
    let second = Group {
        mls_group_id: group2.clone(),
        nostr_group_id: [22; 32],
        name: "Group 2".to_string(),
        ..first.clone()
    };
    storage.save_group(first).unwrap();
    storage.save_group(second).unwrap();

    storage.create_group_snapshot(&group1, "snap_g1").unwrap();
    storage.create_group_snapshot(&group2, "snap_g2").unwrap();

    for (gid, expected) in [(&group1, "snap_g1"), (&group2, "snap_g2")] {
        let listed = storage.list_group_snapshots(gid).unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].0, expected);
    }
}
3709
/// A cutoff one unit past a snapshot's recorded timestamp prunes that
/// snapshot, and it no longer appears in the group's listing.
#[test]
fn test_prune_expired_snapshots_removes_old_snapshots() {
    use mdk_storage_traits::MdkStorageProvider;

    let scratch = tempdir().unwrap();
    let path = scratch.path().join("prune_expired.db");
    let storage = MdkSqliteStorage::new_unencrypted(&path).unwrap();

    let gid = GroupId::from_slice(&[3; 32]);
    let group = Group {
        mls_group_id: gid.clone(),
        nostr_group_id: [33; 32],
        name: "Test Group".to_string(),
        description: "".to_string(),
        admin_pubkeys: BTreeSet::new(),
        last_message_id: None,
        last_message_at: None,
        last_message_processed_at: None,
        epoch: 1,
        state: GroupState::Active,
        image_hash: None,
        image_key: None,
        image_nonce: None,
        self_update_state: SelfUpdateState::Required,
    };
    storage.save_group(group).unwrap();

    storage.create_group_snapshot(&gid, "old_snap").unwrap();

    let listed = storage.list_group_snapshots(&gid).unwrap();
    assert_eq!(listed.len(), 1);
    // Derive the cutoff from the snapshot's own timestamp rather than the clock.
    let cutoff = listed[0].1 + 1;

    let pruned = storage.prune_expired_snapshots(cutoff).unwrap();
    assert_eq!(pruned, 1, "Should have pruned 1 snapshot");

    let leftover = storage.list_group_snapshots(&gid).unwrap();
    assert!(leftover.is_empty());
}
3755
/// Pruning with a cutoff of 0 removes nothing: the freshly created snapshot
/// survives and is still listed.
#[test]
fn test_prune_expired_snapshots_keeps_recent_snapshots() {
    use mdk_storage_traits::MdkStorageProvider;

    let scratch = tempdir().unwrap();
    let path = scratch.path().join("prune_keeps_recent.db");
    let storage = MdkSqliteStorage::new_unencrypted(&path).unwrap();

    let gid = GroupId::from_slice(&[4; 32]);
    storage
        .save_group(Group {
            mls_group_id: gid.clone(),
            nostr_group_id: [44; 32],
            name: "Test Group".to_string(),
            description: "".to_string(),
            admin_pubkeys: BTreeSet::new(),
            last_message_id: None,
            last_message_at: None,
            last_message_processed_at: None,
            epoch: 1,
            state: GroupState::Active,
            image_hash: None,
            image_key: None,
            image_nonce: None,
            self_update_state: SelfUpdateState::Required,
        })
        .unwrap();

    storage.create_group_snapshot(&gid, "recent_snap").unwrap();

    let pruned = storage.prune_expired_snapshots(0).unwrap();
    assert_eq!(pruned, 0, "Should have pruned 0 snapshots");

    let survivors = storage.list_group_snapshots(&gid).unwrap();
    assert_eq!(survivors.len(), 1);
    assert_eq!(survivors[0].0, "recent_snap");
}
3797
/// Pruning a snapshot removes it completely:
///
/// - it disappears from the listing;
/// - every row it captured is deleted from `group_state_snapshots`;
/// - it can no longer be used as a rollback target.
///
/// The original version only checked the listing and the rollback failure,
/// so it never observed the stored snapshot rows that a cascade delete has
/// to remove.
#[test]
fn test_prune_expired_snapshots_with_cascade_delete() {
    use mdk_storage_traits::MdkStorageProvider;

    /// Counts rows stored for `snapshot_name` across all captured tables.
    fn stored_snapshot_rows(storage: &MdkSqliteStorage, snapshot_name: &str) -> i64 {
        storage.with_connection(|conn| {
            conn.query_row(
                "SELECT COUNT(*) FROM group_state_snapshots WHERE snapshot_name = ?",
                rusqlite::params![snapshot_name],
                |row| row.get(0),
            )
            .unwrap()
        })
    }

    let temp_dir = tempdir().unwrap();
    let db_path = temp_dir.path().join("prune_cascade.db");
    let storage = MdkSqliteStorage::new_unencrypted(&db_path).unwrap();

    let group_id = GroupId::from_slice(&[5; 32]);

    let group = Group {
        mls_group_id: group_id.clone(),
        nostr_group_id: [55; 32],
        name: "Test Group".to_string(),
        description: "".to_string(),
        admin_pubkeys: BTreeSet::new(),
        last_message_id: None,
        last_message_at: None,
        last_message_processed_at: None,
        epoch: 1,
        state: GroupState::Active,
        image_hash: None,
        image_key: None,
        image_nonce: None,
        self_update_state: SelfUpdateState::Required,
    };
    storage.save_group(group).unwrap();

    storage
        .create_group_snapshot(&group_id, "to_prune")
        .unwrap();

    let before = storage.list_group_snapshots(&group_id).unwrap();
    assert_eq!(before.len(), 1);
    // Precondition: the snapshot actually captured rows. Without this, the
    // "deleted" check after pruning would pass vacuously.
    assert!(
        stored_snapshot_rows(&storage, "to_prune") > 0,
        "snapshot should have captured rows before pruning"
    );

    let ts = before[0].1;
    let pruned = storage.prune_expired_snapshots(ts + 1).unwrap();
    assert_eq!(pruned, 1);

    let after = storage.list_group_snapshots(&group_id).unwrap();
    assert!(after.is_empty());
    assert_eq!(
        stored_snapshot_rows(&storage, "to_prune"),
        0,
        "pruning must delete every captured row of the snapshot"
    );

    let rollback_result = storage.rollback_group_to_snapshot(&group_id, "to_prune");
    assert!(
        rollback_result.is_err(),
        "rollback to a pruned snapshot must fail"
    );
}
3850 }
3851
3852 mod snapshot_openmls_tests {
3863 use std::collections::BTreeSet;
3864
3865 use mdk_storage_traits::groups::GroupStorage;
3866 use mdk_storage_traits::groups::types::{Group, GroupState};
3867 use mdk_storage_traits::mls_codec::MlsCodec;
3868 use mdk_storage_traits::{GroupId, MdkStorageProvider};
3869 use rusqlite::params;
3870
3871 use super::*;
3872
3873 fn create_test_group(id: u8) -> Group {
3875 Group {
3876 mls_group_id: GroupId::from_slice(&[id; 32]),
3877 nostr_group_id: [id; 32],
3878 name: format!("Test Group {}", id),
3879 description: format!("Description {}", id),
3880 admin_pubkeys: BTreeSet::new(),
3881 last_message_id: None,
3882 last_message_at: None,
3883 last_message_processed_at: None,
3884 epoch: 0,
3885 state: GroupState::Active,
3886 image_hash: None,
3887 image_key: None,
3888 image_nonce: None,
3889 self_update_state: SelfUpdateState::Required,
3890 }
3891 }
3892
3893 fn count_openmls_rows(storage: &MdkSqliteStorage, table: &str, group_id: &GroupId) -> i64 {
3896 let mls_key = MlsCodec::serialize(group_id).unwrap();
3897 storage.with_connection(|conn| {
3898 conn.query_row(
3899 &format!("SELECT COUNT(*) FROM {} WHERE group_id = ?", table),
3900 params![mls_key],
3901 |row| row.get(0),
3902 )
3903 .unwrap()
3904 })
3905 }
3906
3907 fn count_snapshot_rows_for_table(
3909 storage: &MdkSqliteStorage,
3910 snapshot_name: &str,
3911 group_id: &GroupId,
3912 table_name: &str,
3913 ) -> i64 {
3914 let raw_bytes = group_id.as_slice().to_vec();
3915 storage.with_connection(|conn| {
3916 conn.query_row(
3917 "SELECT COUNT(*) FROM group_state_snapshots
3918 WHERE snapshot_name = ? AND group_id = ? AND table_name = ?",
3919 params![snapshot_name, raw_bytes, table_name],
3920 |row| row.get(0),
3921 )
3922 .unwrap()
3923 })
3924 }
3925
3926 fn seed_openmls_data(storage: &MdkSqliteStorage, group_id: &GroupId) {
3932 let mls_key = MlsCodec::serialize(group_id).unwrap();
3933 storage.with_connection(|conn| {
3934 conn.execute(
3936 "INSERT OR REPLACE INTO openmls_group_data
3937 (group_id, data_type, group_data, provider_version)
3938 VALUES (?, ?, ?, ?)",
3939 params![mls_key, "group_state", b"test_crypto_state" as &[u8], 1i32],
3940 )
3941 .unwrap();
3942
3943 conn.execute(
3944 "INSERT OR REPLACE INTO openmls_group_data
3945 (group_id, data_type, group_data, provider_version)
3946 VALUES (?, ?, ?, ?)",
3947 params![mls_key, "tree", b"test_tree_data" as &[u8], 1i32],
3948 )
3949 .unwrap();
3950
3951 conn.execute(
3953 "INSERT INTO openmls_own_leaf_nodes
3954 (group_id, leaf_node, provider_version)
3955 VALUES (?, ?, ?)",
3956 params![mls_key, b"test_leaf_node" as &[u8], 1i32],
3957 )
3958 .unwrap();
3959
3960 let proposal_ref = MlsCodec::serialize(&vec![10u8, 20, 30]).unwrap();
3962 conn.execute(
3963 "INSERT OR REPLACE INTO openmls_proposals
3964 (group_id, proposal_ref, proposal, provider_version)
3965 VALUES (?, ?, ?, ?)",
3966 params![mls_key, proposal_ref, b"test_proposal" as &[u8], 1i32],
3967 )
3968 .unwrap();
3969
3970 let epoch_key = MlsCodec::serialize(&5u64).unwrap();
3972 conn.execute(
3973 "INSERT OR REPLACE INTO openmls_epoch_key_pairs
3974 (group_id, epoch_id, leaf_index, key_pairs, provider_version)
3975 VALUES (?, ?, ?, ?, ?)",
3976 params![mls_key, epoch_key, 0i32, b"test_key_pairs" as &[u8], 1i32],
3977 )
3978 .unwrap();
3979 });
3980 }
3981
/// A snapshot must capture the `openmls_group_data` rows, which are keyed by
/// the `MlsCodec`-serialized group id rather than the raw id bytes.
#[test]
fn test_snapshot_captures_openmls_group_data() {
    let storage = MdkSqliteStorage::new_in_memory().unwrap();

    let group = create_test_group(1);
    let gid = group.mls_group_id.clone();
    storage.save_group(group).unwrap();
    seed_openmls_data(&storage, &gid);

    // Precondition: both seeded group_data rows are live.
    let live_rows = count_openmls_rows(&storage, "openmls_group_data", &gid);
    assert_eq!(
        live_rows, 2,
        "openmls_group_data should have 2 rows (group_state + tree)"
    );

    storage.create_group_snapshot(&gid, "snap_mls").unwrap();

    let captured =
        count_snapshot_rows_for_table(&storage, "snap_mls", &gid, "openmls_group_data");
    assert_eq!(
        captured, 2,
        "Snapshot must capture openmls_group_data rows written via StorageProvider \
         (MlsCodec-serialized group_id keys)"
    );
}
4024
/// A snapshot must capture the group's `openmls_own_leaf_nodes` row.
#[test]
fn test_snapshot_captures_openmls_own_leaf_nodes() {
    let storage = MdkSqliteStorage::new_in_memory().unwrap();

    let group = create_test_group(2);
    let gid = group.mls_group_id.clone();
    storage.save_group(group).unwrap();
    seed_openmls_data(&storage, &gid);

    let live_rows = count_openmls_rows(&storage, "openmls_own_leaf_nodes", &gid);
    assert_eq!(live_rows, 1, "openmls_own_leaf_nodes should have 1 row");

    storage.create_group_snapshot(&gid, "snap_leaf").unwrap();

    let captured =
        count_snapshot_rows_for_table(&storage, "snap_leaf", &gid, "openmls_own_leaf_nodes");
    assert_eq!(
        captured, 1,
        "Snapshot must capture openmls_own_leaf_nodes rows written via \
         StorageProvider"
    );
}
4058
/// A snapshot must capture the group's `openmls_proposals` row.
#[test]
fn test_snapshot_captures_openmls_proposals() {
    let storage = MdkSqliteStorage::new_in_memory().unwrap();

    let group = create_test_group(3);
    let gid = group.mls_group_id.clone();
    storage.save_group(group).unwrap();
    seed_openmls_data(&storage, &gid);

    let live_rows = count_openmls_rows(&storage, "openmls_proposals", &gid);
    assert_eq!(live_rows, 1, "openmls_proposals should have 1 row");

    storage.create_group_snapshot(&gid, "snap_prop").unwrap();

    let captured =
        count_snapshot_rows_for_table(&storage, "snap_prop", &gid, "openmls_proposals");
    assert_eq!(
        captured, 1,
        "Snapshot must capture openmls_proposals rows written via \
         StorageProvider"
    );
}
4092
/// A snapshot must capture the group's `openmls_epoch_key_pairs` row.
#[test]
fn test_snapshot_captures_openmls_epoch_key_pairs() {
    let storage = MdkSqliteStorage::new_in_memory().unwrap();

    let group = create_test_group(4);
    let gid = group.mls_group_id.clone();
    storage.save_group(group).unwrap();
    seed_openmls_data(&storage, &gid);

    let live_rows = count_openmls_rows(&storage, "openmls_epoch_key_pairs", &gid);
    assert_eq!(live_rows, 1, "openmls_epoch_key_pairs should have 1 row");

    storage.create_group_snapshot(&gid, "snap_epoch").unwrap();

    let captured =
        count_snapshot_rows_for_table(&storage, "snap_epoch", &gid, "openmls_epoch_key_pairs");
    assert_eq!(
        captured, 1,
        "Snapshot must capture openmls_epoch_key_pairs rows written via \
         StorageProvider"
    );
}
4126
/// Rollback restores the `openmls_group_data` blob captured by the snapshot,
/// undoing an in-place update made afterwards.
#[test]
fn test_rollback_restores_openmls_group_data() {
    let storage = MdkSqliteStorage::new_in_memory().unwrap();

    let group = create_test_group(5);
    let gid = group.mls_group_id.clone();
    storage.save_group(group).unwrap();

    let mls_key = MlsCodec::serialize(&gid).unwrap();

    // Reads the current `group_state` blob for this group.
    let current_state = || -> Vec<u8> {
        storage.with_connection(|conn| {
            conn.query_row(
                "SELECT group_data FROM openmls_group_data
                 WHERE group_id = ? AND data_type = ?",
                params![mls_key, "group_state"],
                |row| row.get(0),
            )
            .unwrap()
        })
    };

    // State at "epoch 5", then snapshot it.
    storage.with_connection(|conn| {
        conn.execute(
            "INSERT OR REPLACE INTO openmls_group_data
             (group_id, data_type, group_data, provider_version)
             VALUES (?, ?, ?, ?)",
            params![mls_key, "group_state", b"epoch5_crypto" as &[u8], 1i32],
        )
        .unwrap();
    });
    storage.create_group_snapshot(&gid, "snap_e5").unwrap();

    // Advance the live row to "epoch 6".
    storage.with_connection(|conn| {
        conn.execute(
            "UPDATE openmls_group_data SET group_data = ?
             WHERE group_id = ? AND data_type = ?",
            params![b"epoch6_crypto" as &[u8], mls_key, "group_state"],
        )
        .unwrap();
    });
    assert_eq!(current_state(), b"epoch6_crypto");

    storage.rollback_group_to_snapshot(&gid, "snap_e5").unwrap();

    assert_eq!(
        current_state(),
        b"epoch5_crypto",
        "Rollback must restore openmls_group_data to the snapshot state. \
         If this is epoch6_crypto, the snapshot failed to capture the \
         OpenMLS rows due to group_id encoding mismatch \
         (as_slice vs MlsCodec::serialize)."
    );
}
4209
/// Rollback restores the `openmls_epoch_key_pairs` blob captured by the
/// snapshot, undoing an in-place update made afterwards.
#[test]
fn test_rollback_restores_openmls_epoch_key_pairs() {
    let storage = MdkSqliteStorage::new_in_memory().unwrap();

    let group = create_test_group(6);
    let gid = group.mls_group_id.clone();
    storage.save_group(group).unwrap();

    let mls_key = MlsCodec::serialize(&gid).unwrap();
    let epoch_key = MlsCodec::serialize(&5u64).unwrap();
    let leaf_index = 0i32;

    storage.with_connection(|conn| {
        conn.execute(
            "INSERT OR REPLACE INTO openmls_epoch_key_pairs
             (group_id, epoch_id, leaf_index, key_pairs, provider_version)
             VALUES (?, ?, ?, ?, ?)",
            params![mls_key, epoch_key, leaf_index, b"epoch5_keys" as &[u8], 1i32],
        )
        .unwrap();
    });

    storage.create_group_snapshot(&gid, "snap_keys").unwrap();

    // Overwrite the live key pairs after the snapshot was taken.
    storage.with_connection(|conn| {
        conn.execute(
            "UPDATE openmls_epoch_key_pairs SET key_pairs = ?
             WHERE group_id = ? AND epoch_id = ? AND leaf_index = ?",
            params![b"epoch6_keys" as &[u8], mls_key, epoch_key, leaf_index],
        )
        .unwrap();
    });

    storage.rollback_group_to_snapshot(&gid, "snap_keys").unwrap();

    let restored_keys: Vec<u8> = storage.with_connection(|conn| {
        conn.query_row(
            "SELECT key_pairs FROM openmls_epoch_key_pairs
             WHERE group_id = ? AND epoch_id = ? AND leaf_index = ?",
            params![mls_key, epoch_key, leaf_index],
            |row| row.get(0),
        )
        .unwrap()
    });
    assert_eq!(
        restored_keys, b"epoch5_keys",
        "Rollback must restore openmls_epoch_key_pairs to snapshot state"
    );
}
4268
/// Rollback drops own-leaf-node rows added after the snapshot and keeps the
/// one that existed when the snapshot was taken.
#[test]
fn test_rollback_restores_openmls_own_leaf_nodes() {
    let storage = MdkSqliteStorage::new_in_memory().unwrap();

    let group = create_test_group(7);
    let gid = group.mls_group_id.clone();
    storage.save_group(group).unwrap();

    let mls_key = MlsCodec::serialize(&gid).unwrap();

    // Appends one own-leaf-node row for this group.
    let insert_leaf = |leaf: &[u8]| {
        storage.with_connection(|conn| {
            conn.execute(
                "INSERT INTO openmls_own_leaf_nodes
                 (group_id, leaf_node, provider_version)
                 VALUES (?, ?, ?)",
                params![mls_key, leaf, 1i32],
            )
            .unwrap();
        });
    };

    insert_leaf(b"original_leaf");
    storage.create_group_snapshot(&gid, "snap_leaf").unwrap();
    insert_leaf(b"added_after_snapshot");

    assert_eq!(
        count_openmls_rows(&storage, "openmls_own_leaf_nodes", &gid),
        2
    );

    storage.rollback_group_to_snapshot(&gid, "snap_leaf").unwrap();

    let remaining = count_openmls_rows(&storage, "openmls_own_leaf_nodes", &gid);
    assert_eq!(
        remaining, 1,
        "Rollback must restore openmls_own_leaf_nodes to snapshot state \
         (1 leaf, not 2)"
    );

    let surviving_leaf: Vec<u8> = storage.with_connection(|conn| {
        conn.query_row(
            "SELECT leaf_node FROM openmls_own_leaf_nodes
             WHERE group_id = ? AND provider_version = ?",
            params![mls_key, 1i32],
            |row| row.get(0),
        )
        .unwrap()
    });
    assert_eq!(
        surviving_leaf, b"original_leaf",
        "Rollback must restore the original leaf node data"
    );
}
4340
/// After rollback, the MDK metadata epoch (`groups.epoch`) and the OpenMLS
/// `group_state` blob must come from the same snapshot, so the two layers
/// never disagree about which epoch the group is in.
#[test]
fn test_rollback_metadata_crypto_consistency() {
    let storage = MdkSqliteStorage::new_in_memory().unwrap();

    let group = create_test_group(8);
    let gid = group.mls_group_id.clone();
    storage.save_group(group).unwrap();

    let mls_key = MlsCodec::serialize(&gid).unwrap();

    // Persists a new MDK metadata epoch for the group.
    let set_metadata_epoch = |epoch| {
        let mut current = storage
            .find_group_by_mls_group_id(&gid)
            .unwrap()
            .unwrap();
        current.epoch = epoch;
        storage.save_group(current).unwrap();
    };

    // Both layers at epoch 5, then snapshot.
    set_metadata_epoch(5);
    storage.with_connection(|conn| {
        conn.execute(
            "INSERT OR REPLACE INTO openmls_group_data
             (group_id, data_type, group_data, provider_version)
             VALUES (?, ?, ?, ?)",
            params![mls_key, "group_state", b"epoch5_state" as &[u8], 1i32],
        )
        .unwrap();
    });
    storage.create_group_snapshot(&gid, "snap_epoch5").unwrap();

    // Both layers advance to epoch 6.
    set_metadata_epoch(6);
    storage.with_connection(|conn| {
        conn.execute(
            "UPDATE openmls_group_data SET group_data = ?
             WHERE group_id = ? AND data_type = ?",
            params![b"epoch6_state" as &[u8], mls_key, "group_state"],
        )
        .unwrap();
    });

    storage.rollback_group_to_snapshot(&gid, "snap_epoch5").unwrap();

    let metadata = storage
        .find_group_by_mls_group_id(&gid)
        .unwrap()
        .unwrap();
    assert_eq!(
        metadata.epoch, 5,
        "MDK groups.epoch should be 5 after rollback"
    );

    let crypto_state: Vec<u8> = storage.with_connection(|conn| {
        conn.query_row(
            "SELECT group_data FROM openmls_group_data
             WHERE group_id = ? AND data_type = ?",
            params![mls_key, "group_state"],
            |row| row.get(0),
        )
        .unwrap()
    });
    assert_eq!(
        crypto_state, b"epoch5_state",
        "OpenMLS crypto state must match MDK metadata epoch after rollback. \
         groups.epoch=5 but crypto state is still epoch6 data means \
         split-brain: MDK thinks epoch 5, MLS engine has epoch 6 keys. \
         Every subsequent message in this group will fail to decrypt."
    );
}
4432
/// After rollback the group has exactly one `openmls_group_data` row. The
/// restore must delete the live row, using the same key encoding it was
/// written with, before re-inserting the snapshot copy; otherwise a stale
/// row would remain next to the restored one.
#[test]
fn test_restore_deletes_openmls_data_before_reinserting() {
    let storage = MdkSqliteStorage::new_in_memory().unwrap();

    let group = create_test_group(9);
    let gid = group.mls_group_id.clone();
    storage.save_group(group).unwrap();

    let mls_key = MlsCodec::serialize(&gid).unwrap();

    storage.with_connection(|conn| {
        conn.execute(
            "INSERT OR REPLACE INTO openmls_group_data
             (group_id, data_type, group_data, provider_version)
             VALUES (?, ?, ?, ?)",
            params![mls_key, "group_state", b"initial_state" as &[u8], 1i32],
        )
        .unwrap();
    });

    storage.create_group_snapshot(&gid, "snap_initial").unwrap();

    // Modify the live row so the restore has something to replace.
    storage.with_connection(|conn| {
        conn.execute(
            "UPDATE openmls_group_data SET group_data = ?
             WHERE group_id = ? AND data_type = ?",
            params![b"modified_state" as &[u8], mls_key, "group_state"],
        )
        .unwrap();
    });

    storage.rollback_group_to_snapshot(&gid, "snap_initial").unwrap();

    assert_eq!(
        count_openmls_rows(&storage, "openmls_group_data", &gid),
        1,
        "After rollback, there should be exactly 1 openmls_group_data \
         row. More than 1 means the DELETE used the wrong key format and \
         failed to remove the stale OpenMLS row before re-inserting from \
         snapshot."
    );
}
4494 }
4495}