1#![forbid(unsafe_code)]
65#![warn(missing_docs)]
66#![warn(rustdoc::bare_urls)]
67
68use std::path::Path;
69use std::sync::Arc;
70
71use mdk_storage_traits::{Backend, GroupId, MdkStorageError, MdkStorageProvider};
72use openmls_traits::storage::{StorageProvider, traits};
73use rusqlite::Connection;
74use std::sync::Mutex;
75
76mod db;
77pub mod encryption;
78pub mod error;
79mod groups;
80pub mod keyring;
81mod messages;
82mod migrations;
83mod mls_storage;
84mod permissions;
85#[cfg(test)]
86mod test_utils;
87mod validation;
88mod welcomes;
89
90pub use self::encryption::EncryptionConfig;
91use self::error::Error;
92use self::mls_storage::{GroupDataType, STORAGE_PROVIDER_VERSION};
93pub use self::permissions::verify_permissions;
94use self::permissions::{
95 FileCreationOutcome, precreate_secure_database_file, set_secure_file_permissions,
96};
97
/// SQLite-backed storage implementing both [`MdkStorageProvider`] and the
/// OpenMLS [`StorageProvider`].
///
/// All database access goes through a single [`Connection`] guarded by a
/// [`Mutex`], so operations on one instance are serialized.
pub struct MdkSqliteStorage {
    /// The one SQLite connection used for every operation.
    connection: Arc<Mutex<Connection>>,
}
132
133impl MdkSqliteStorage {
134 pub fn new<P>(file_path: P, service_id: &str, db_key_id: &str) -> Result<Self, Error>
183 where
184 P: AsRef<Path>,
185 {
186 let file_path = file_path.as_ref();
187
188 let creation_outcome = precreate_secure_database_file(file_path)?;
192
193 let config = match creation_outcome {
194 FileCreationOutcome::Created | FileCreationOutcome::Skipped => {
195 keyring::get_or_create_db_key(service_id, db_key_id)?
198 }
199 FileCreationOutcome::AlreadyExisted => {
200 match keyring::get_db_key(service_id, db_key_id)? {
209 Some(config) => {
210 config
213 }
214 None => {
215 if !encryption::is_database_encrypted(file_path)? {
219 return Err(Error::UnencryptedDatabaseWithEncryption);
220 }
221
222 return Err(Error::KeyringEntryMissingForExistingDatabase {
224 db_path: file_path.display().to_string(),
225 service_id: service_id.to_string(),
226 db_key_id: db_key_id.to_string(),
227 });
228 }
229 }
230 }
231 };
232
233 Self::new_internal_skip_precreate(file_path, Some(config))
234 }
235
236 pub fn new_with_key<P>(file_path: P, config: EncryptionConfig) -> Result<Self, Error>
267 where
268 P: AsRef<Path>,
269 {
270 let file_path = file_path.as_ref();
271
272 if file_path.exists() && !encryption::is_database_encrypted(file_path)? {
275 return Err(Error::UnencryptedDatabaseWithEncryption);
276 }
277
278 Self::new_internal(file_path, Some(config))
279 }
280
281 pub fn new_unencrypted<P>(file_path: P) -> Result<Self, Error>
306 where
307 P: AsRef<Path>,
308 {
309 tracing::warn!(
310 "Creating unencrypted database. Sensitive MLS state will be stored in plaintext. \
311 For production use, use new() or new_with_key() instead."
312 );
313 Self::new_internal(file_path, None)
314 }
315
316 fn new_internal<P>(
320 file_path: P,
321 encryption_config: Option<EncryptionConfig>,
322 ) -> Result<Self, Error>
323 where
324 P: AsRef<Path>,
325 {
326 let file_path = file_path.as_ref();
327
328 precreate_secure_database_file(file_path)?;
330
331 Self::new_internal_skip_precreate(file_path, encryption_config)
332 }
333
334 fn new_internal_skip_precreate(
339 file_path: &Path,
340 encryption_config: Option<EncryptionConfig>,
341 ) -> Result<Self, Error> {
342 let mut connection = Self::open_connection(file_path, encryption_config.as_ref())?;
344
345 migrations::run_migrations(&mut connection)?;
347
348 Self::apply_secure_permissions(file_path)?;
350
351 Ok(Self {
352 connection: Arc::new(Mutex::new(connection)),
353 })
354 }
355
356 fn open_connection(
358 file_path: &Path,
359 encryption_config: Option<&EncryptionConfig>,
360 ) -> Result<Connection, Error> {
361 let conn = Connection::open(file_path)?;
362
363 if let Some(config) = encryption_config {
365 encryption::apply_encryption(&conn, config)?;
366 }
367
368 conn.execute_batch("PRAGMA foreign_keys = ON;")?;
370
371 Ok(conn)
372 }
373
374 fn apply_secure_permissions(db_path: &Path) -> Result<(), Error> {
400 let path_str = db_path.to_string_lossy();
402 if path_str.is_empty() || path_str == ":memory:" || path_str.starts_with(':') {
403 return Ok(());
404 }
405
406 set_secure_file_permissions(db_path)?;
408
409 let parent = db_path.parent();
413 let stem = db_path.file_name().and_then(|n| n.to_str());
414
415 if let (Some(parent), Some(stem)) = (parent, stem) {
416 for suffix in &["-wal", "-shm", "-journal"] {
417 let sidecar = parent.join(format!("{}{}", stem, suffix));
418 if sidecar.exists() {
419 set_secure_file_permissions(&sidecar)?;
420 }
421 }
422 }
423
424 Ok(())
425 }
426
427 #[cfg(test)]
435 pub fn new_in_memory() -> Result<Self, Error> {
436 let mut connection = Connection::open_in_memory()?;
438
439 connection.execute_batch("PRAGMA foreign_keys = ON;")?;
441
442 migrations::run_migrations(&mut connection)?;
444
445 Ok(Self {
446 connection: Arc::new(Mutex::new(connection)),
447 })
448 }
449
450 pub(crate) fn with_connection<F, T>(&self, f: F) -> T
454 where
455 F: FnOnce(&Connection) -> T,
456 {
457 let conn = self.connection.lock().unwrap();
458 f(&conn)
459 }
460
461 fn snapshot_group_state(&self, group_id: &GroupId, name: &str) -> Result<(), Error> {
464 let conn = self.connection.lock().unwrap();
465 let group_id_bytes = group_id.as_slice();
466 let mls_group_id_bytes = mls_storage::MlsCodec::serialize(group_id)
469 .map_err(|e| Error::Database(e.to_string()))?;
470 let now = std::time::SystemTime::now()
471 .duration_since(std::time::UNIX_EPOCH)
472 .map_err(|e| Error::Database(format!("Time error: {}", e)))?
473 .as_secs() as i64;
474
475 conn.execute("BEGIN IMMEDIATE", [])
477 .map_err(|e| Error::Database(e.to_string()))?;
478
479 let result = (|| -> Result<(), Error> {
480 let mut insert_stmt = conn
482 .prepare_cached(
483 "INSERT INTO group_state_snapshots
484 (snapshot_name, group_id, table_name, row_key, row_data, created_at)
485 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
486 )
487 .map_err(|e| Error::Database(e.to_string()))?;
488
489 Self::snapshot_openmls_group_data(
492 &conn,
493 &mut insert_stmt,
494 name,
495 group_id_bytes,
496 &mls_group_id_bytes,
497 now,
498 )?;
499 Self::snapshot_openmls_proposals(
500 &conn,
501 &mut insert_stmt,
502 name,
503 group_id_bytes,
504 &mls_group_id_bytes,
505 now,
506 )?;
507 Self::snapshot_openmls_own_leaf_nodes(
508 &conn,
509 &mut insert_stmt,
510 name,
511 group_id_bytes,
512 &mls_group_id_bytes,
513 now,
514 )?;
515 Self::snapshot_openmls_epoch_key_pairs(
516 &conn,
517 &mut insert_stmt,
518 name,
519 group_id_bytes,
520 &mls_group_id_bytes,
521 now,
522 )?;
523 Self::snapshot_groups_table(&conn, &mut insert_stmt, name, group_id_bytes, now)?;
525 Self::snapshot_group_relays(&conn, &mut insert_stmt, name, group_id_bytes, now)?;
526 Self::snapshot_group_exporter_secrets(
527 &conn,
528 &mut insert_stmt,
529 name,
530 group_id_bytes,
531 now,
532 )?;
533
534 Ok(())
535 })();
536
537 match result {
538 Ok(()) => {
539 conn.execute("COMMIT", [])
540 .map_err(|e| Error::Database(e.to_string()))?;
541 Ok(())
542 }
543 Err(e) => {
544 let _ = conn.execute("ROLLBACK", []);
545 Err(e)
546 }
547 }
548 }
549
550 fn snapshot_openmls_group_data(
552 conn: &rusqlite::Connection,
553 insert_stmt: &mut rusqlite::CachedStatement<'_>,
554 snapshot_name: &str,
555 group_id_bytes: &[u8],
556 mls_group_id_bytes: &[u8],
557 now: i64,
558 ) -> Result<(), Error> {
559 let mut stmt = conn
560 .prepare(
561 "SELECT group_id, data_type, group_data FROM openmls_group_data WHERE group_id = ?",
562 )
563 .map_err(|e| Error::Database(e.to_string()))?;
564 let mut rows = stmt
565 .query([mls_group_id_bytes])
566 .map_err(|e| Error::Database(e.to_string()))?;
567
568 while let Some(row) = rows.next().map_err(|e| Error::Database(e.to_string()))? {
569 let gid: Vec<u8> = row.get(0).map_err(|e| Error::Database(e.to_string()))?;
570 let data_type: String = row.get(1).map_err(|e| Error::Database(e.to_string()))?;
571 let data: Vec<u8> = row.get(2).map_err(|e| Error::Database(e.to_string()))?;
572 let row_key = serde_json::to_vec(&(&gid, &data_type))
573 .map_err(|e| Error::Database(e.to_string()))?;
574 insert_stmt
575 .execute(rusqlite::params![
576 snapshot_name,
577 group_id_bytes,
578 "openmls_group_data",
579 row_key,
580 data,
581 now
582 ])
583 .map_err(|e| Error::Database(e.to_string()))?;
584 }
585 Ok(())
586 }
587
588 fn snapshot_openmls_proposals(
590 conn: &rusqlite::Connection,
591 insert_stmt: &mut rusqlite::CachedStatement<'_>,
592 snapshot_name: &str,
593 group_id_bytes: &[u8],
594 mls_group_id_bytes: &[u8],
595 now: i64,
596 ) -> Result<(), Error> {
597 let mut stmt = conn
598 .prepare(
599 "SELECT group_id, proposal_ref, proposal FROM openmls_proposals WHERE group_id = ?",
600 )
601 .map_err(|e| Error::Database(e.to_string()))?;
602 let mut rows = stmt
603 .query([mls_group_id_bytes])
604 .map_err(|e| Error::Database(e.to_string()))?;
605
606 while let Some(row) = rows.next().map_err(|e| Error::Database(e.to_string()))? {
607 let gid: Vec<u8> = row.get(0).map_err(|e| Error::Database(e.to_string()))?;
608 let proposal_ref: Vec<u8> = row.get(1).map_err(|e| Error::Database(e.to_string()))?;
609 let proposal: Vec<u8> = row.get(2).map_err(|e| Error::Database(e.to_string()))?;
610 let row_key = serde_json::to_vec(&(&gid, &proposal_ref))
611 .map_err(|e| Error::Database(e.to_string()))?;
612 insert_stmt
613 .execute(rusqlite::params![
614 snapshot_name,
615 group_id_bytes,
616 "openmls_proposals",
617 row_key,
618 proposal,
619 now
620 ])
621 .map_err(|e| Error::Database(e.to_string()))?;
622 }
623 Ok(())
624 }
625
626 fn snapshot_openmls_own_leaf_nodes(
628 conn: &rusqlite::Connection,
629 insert_stmt: &mut rusqlite::CachedStatement<'_>,
630 snapshot_name: &str,
631 group_id_bytes: &[u8],
632 mls_group_id_bytes: &[u8],
633 now: i64,
634 ) -> Result<(), Error> {
635 let mut stmt = conn
636 .prepare(
637 "SELECT id, group_id, leaf_node FROM openmls_own_leaf_nodes WHERE group_id = ?",
638 )
639 .map_err(|e| Error::Database(e.to_string()))?;
640 let mut rows = stmt
641 .query([mls_group_id_bytes])
642 .map_err(|e| Error::Database(e.to_string()))?;
643
644 while let Some(row) = rows.next().map_err(|e| Error::Database(e.to_string()))? {
645 let id: i64 = row.get(0).map_err(|e| Error::Database(e.to_string()))?;
646 let gid: Vec<u8> = row.get(1).map_err(|e| Error::Database(e.to_string()))?;
647 let leaf_node: Vec<u8> = row.get(2).map_err(|e| Error::Database(e.to_string()))?;
648 let row_key = serde_json::to_vec(&id).map_err(|e| Error::Database(e.to_string()))?;
649 let row_data = serde_json::to_vec(&(&gid, &leaf_node))
650 .map_err(|e| Error::Database(e.to_string()))?;
651 insert_stmt
652 .execute(rusqlite::params![
653 snapshot_name,
654 group_id_bytes,
655 "openmls_own_leaf_nodes",
656 row_key,
657 row_data,
658 now
659 ])
660 .map_err(|e| Error::Database(e.to_string()))?;
661 }
662 Ok(())
663 }
664
665 fn snapshot_openmls_epoch_key_pairs(
667 conn: &rusqlite::Connection,
668 insert_stmt: &mut rusqlite::CachedStatement<'_>,
669 snapshot_name: &str,
670 group_id_bytes: &[u8],
671 mls_group_id_bytes: &[u8],
672 now: i64,
673 ) -> Result<(), Error> {
674 let mut stmt = conn
675 .prepare(
676 "SELECT group_id, epoch_id, leaf_index, key_pairs
677 FROM openmls_epoch_key_pairs WHERE group_id = ?",
678 )
679 .map_err(|e| Error::Database(e.to_string()))?;
680 let mut rows = stmt
681 .query([mls_group_id_bytes])
682 .map_err(|e| Error::Database(e.to_string()))?;
683
684 while let Some(row) = rows.next().map_err(|e| Error::Database(e.to_string()))? {
685 let gid: Vec<u8> = row.get(0).map_err(|e| Error::Database(e.to_string()))?;
686 let epoch_id: Vec<u8> = row.get(1).map_err(|e| Error::Database(e.to_string()))?;
687 let leaf_index: i64 = row.get(2).map_err(|e| Error::Database(e.to_string()))?;
688 let key_pairs: Vec<u8> = row.get(3).map_err(|e| Error::Database(e.to_string()))?;
689 let row_key = serde_json::to_vec(&(&gid, &epoch_id, leaf_index))
690 .map_err(|e| Error::Database(e.to_string()))?;
691 insert_stmt
692 .execute(rusqlite::params![
693 snapshot_name,
694 group_id_bytes,
695 "openmls_epoch_key_pairs",
696 row_key,
697 key_pairs,
698 now
699 ])
700 .map_err(|e| Error::Database(e.to_string()))?;
701 }
702 Ok(())
703 }
704
705 fn snapshot_groups_table(
707 conn: &rusqlite::Connection,
708 insert_stmt: &mut rusqlite::CachedStatement<'_>,
709 snapshot_name: &str,
710 group_id_bytes: &[u8],
711 now: i64,
712 ) -> Result<(), Error> {
713 let mut stmt = conn
714 .prepare(
715 "SELECT mls_group_id, nostr_group_id, name, description, admin_pubkeys,
716 last_message_id, last_message_at, last_message_processed_at, epoch, state,
717 image_hash, image_key, image_nonce, last_self_update_at
718 FROM groups WHERE mls_group_id = ?",
719 )
720 .map_err(|e| Error::Database(e.to_string()))?;
721 let mut rows = stmt
722 .query([group_id_bytes])
723 .map_err(|e| Error::Database(e.to_string()))?;
724
725 while let Some(row) = rows.next().map_err(|e| Error::Database(e.to_string()))? {
726 let mls_group_id: Vec<u8> = row.get(0).map_err(|e| Error::Database(e.to_string()))?;
727 let nostr_group_id: Vec<u8> = row.get(1).map_err(|e| Error::Database(e.to_string()))?;
728 let name_val: String = row.get(2).map_err(|e| Error::Database(e.to_string()))?;
729 let description: String = row.get(3).map_err(|e| Error::Database(e.to_string()))?;
730 let admin_pubkeys: String = row.get(4).map_err(|e| Error::Database(e.to_string()))?;
731 let last_message_id: Option<Vec<u8>> =
732 row.get(5).map_err(|e| Error::Database(e.to_string()))?;
733 let last_message_at: Option<i64> =
734 row.get(6).map_err(|e| Error::Database(e.to_string()))?;
735 let last_message_processed_at: Option<i64> =
736 row.get(7).map_err(|e| Error::Database(e.to_string()))?;
737 let epoch: i64 = row.get(8).map_err(|e| Error::Database(e.to_string()))?;
738 let state: String = row.get(9).map_err(|e| Error::Database(e.to_string()))?;
739 let image_hash: Option<Vec<u8>> =
740 row.get(10).map_err(|e| Error::Database(e.to_string()))?;
741 let image_key: Option<Vec<u8>> =
742 row.get(11).map_err(|e| Error::Database(e.to_string()))?;
743 let image_nonce: Option<Vec<u8>> =
744 row.get(12).map_err(|e| Error::Database(e.to_string()))?;
745 let last_self_update_at: i64 =
746 row.get(13).map_err(|e| Error::Database(e.to_string()))?;
747
748 let row_key =
749 serde_json::to_vec(&mls_group_id).map_err(|e| Error::Database(e.to_string()))?;
750 let row_data = serde_json::to_vec(&(
751 &nostr_group_id,
752 &name_val,
753 &description,
754 &admin_pubkeys,
755 &last_message_id,
756 &last_message_at,
757 &last_message_processed_at,
758 epoch,
759 &state,
760 &image_hash,
761 &image_key,
762 &image_nonce,
763 &last_self_update_at,
764 ))
765 .map_err(|e| Error::Database(e.to_string()))?;
766
767 insert_stmt
768 .execute(rusqlite::params![
769 snapshot_name,
770 group_id_bytes,
771 "groups",
772 row_key,
773 row_data,
774 now
775 ])
776 .map_err(|e| Error::Database(e.to_string()))?;
777 }
778 Ok(())
779 }
780
781 fn snapshot_group_relays(
783 conn: &rusqlite::Connection,
784 insert_stmt: &mut rusqlite::CachedStatement<'_>,
785 snapshot_name: &str,
786 group_id_bytes: &[u8],
787 now: i64,
788 ) -> Result<(), Error> {
789 let mut stmt = conn
790 .prepare("SELECT id, mls_group_id, relay_url FROM group_relays WHERE mls_group_id = ?")
791 .map_err(|e| Error::Database(e.to_string()))?;
792 let mut rows = stmt
793 .query([group_id_bytes])
794 .map_err(|e| Error::Database(e.to_string()))?;
795
796 while let Some(row) = rows.next().map_err(|e| Error::Database(e.to_string()))? {
797 let id: i64 = row.get(0).map_err(|e| Error::Database(e.to_string()))?;
798 let mls_group_id: Vec<u8> = row.get(1).map_err(|e| Error::Database(e.to_string()))?;
799 let relay_url: String = row.get(2).map_err(|e| Error::Database(e.to_string()))?;
800 let row_key = serde_json::to_vec(&id).map_err(|e| Error::Database(e.to_string()))?;
801 let row_data = serde_json::to_vec(&(&mls_group_id, &relay_url))
802 .map_err(|e| Error::Database(e.to_string()))?;
803 insert_stmt
804 .execute(rusqlite::params![
805 snapshot_name,
806 group_id_bytes,
807 "group_relays",
808 row_key,
809 row_data,
810 now
811 ])
812 .map_err(|e| Error::Database(e.to_string()))?;
813 }
814 Ok(())
815 }
816
817 fn snapshot_group_exporter_secrets(
819 conn: &rusqlite::Connection,
820 insert_stmt: &mut rusqlite::CachedStatement<'_>,
821 snapshot_name: &str,
822 group_id_bytes: &[u8],
823 now: i64,
824 ) -> Result<(), Error> {
825 let mut stmt = conn
826 .prepare(
827 "SELECT mls_group_id, epoch, label, secret FROM group_exporter_secrets WHERE mls_group_id = ?",
828 )
829 .map_err(|e| Error::Database(e.to_string()))?;
830 let mut rows = stmt
831 .query([group_id_bytes])
832 .map_err(|e| Error::Database(e.to_string()))?;
833
834 while let Some(row) = rows.next().map_err(|e| Error::Database(e.to_string()))? {
835 let mls_group_id: Vec<u8> = row.get(0).map_err(|e| Error::Database(e.to_string()))?;
836 let epoch: i64 = row.get(1).map_err(|e| Error::Database(e.to_string()))?;
837 let label: String = row.get(2).map_err(|e| Error::Database(e.to_string()))?;
838 let secret: Vec<u8> = row.get(3).map_err(|e| Error::Database(e.to_string()))?;
839 let row_key = serde_json::to_vec(&(&mls_group_id, epoch, &label))
840 .map_err(|e| Error::Database(e.to_string()))?;
841 insert_stmt
842 .execute(rusqlite::params![
843 snapshot_name,
844 group_id_bytes,
845 "group_exporter_secrets",
846 row_key,
847 secret,
848 now
849 ])
850 .map_err(|e| Error::Database(e.to_string()))?;
851 }
852 Ok(())
853 }
854
855 fn restore_group_from_snapshot(&self, group_id: &GroupId, name: &str) -> Result<(), Error> {
858 let conn = self.connection.lock().unwrap();
859 let group_id_bytes = group_id.as_slice();
860 let mls_group_id_bytes = mls_storage::MlsCodec::serialize(group_id)
863 .map_err(|e| Error::Database(e.to_string()))?;
864
865 let snapshot_exists: bool = conn
869 .query_row(
870 "SELECT EXISTS(SELECT 1 FROM group_state_snapshots WHERE snapshot_name = ? AND group_id = ?)",
871 rusqlite::params![name, group_id_bytes],
872 |row| row.get(0),
873 )
874 .map_err(|e| Error::Database(e.to_string()))?;
875
876 if !snapshot_exists {
877 return Err(Error::Database("Snapshot not found".to_string()));
878 }
879
880 let snapshot_rows: Vec<(String, Vec<u8>, Vec<u8>)> = {
885 let mut stmt = conn
886 .prepare(
887 "SELECT table_name, row_key, row_data FROM group_state_snapshots
888 WHERE snapshot_name = ? AND group_id = ?",
889 )
890 .map_err(|e| Error::Database(e.to_string()))?;
891
892 let rows = stmt
893 .query_map(rusqlite::params![name, group_id_bytes], |row| {
894 Ok((row.get(0)?, row.get(1)?, row.get(2)?))
895 })
896 .map_err(|e| Error::Database(e.to_string()))?;
897
898 rows.collect::<Result<Vec<_>, _>>()
899 .map_err(|e| Error::Database(e.to_string()))?
900 };
901
902 #[allow(clippy::type_complexity)]
905 let other_snapshots: Vec<(String, String, Vec<u8>, Vec<u8>, i64)> = {
906 let mut stmt = conn
907 .prepare(
908 "SELECT snapshot_name, table_name, row_key, row_data, created_at
909 FROM group_state_snapshots
910 WHERE group_id = ? AND snapshot_name != ?",
911 )
912 .map_err(|e| Error::Database(e.to_string()))?;
913
914 let rows = stmt
915 .query_map(rusqlite::params![group_id_bytes, name], |row| {
916 Ok((
917 row.get(0)?,
918 row.get(1)?,
919 row.get(2)?,
920 row.get(3)?,
921 row.get(4)?,
922 ))
923 })
924 .map_err(|e| Error::Database(e.to_string()))?;
925
926 rows.collect::<Result<Vec<_>, _>>()
927 .map_err(|e| Error::Database(e.to_string()))?
928 };
929
930 conn.execute("BEGIN IMMEDIATE", [])
932 .map_err(|e| Error::Database(e.to_string()))?;
933
934 let result = (|| -> Result<(), Error> {
935 conn.execute(
938 "DELETE FROM openmls_group_data WHERE group_id = ?",
939 [&mls_group_id_bytes],
940 )
941 .map_err(|e| Error::Database(e.to_string()))?;
942
943 conn.execute(
944 "DELETE FROM openmls_proposals WHERE group_id = ?",
945 [&mls_group_id_bytes],
946 )
947 .map_err(|e| Error::Database(e.to_string()))?;
948
949 conn.execute(
950 "DELETE FROM openmls_own_leaf_nodes WHERE group_id = ?",
951 [&mls_group_id_bytes],
952 )
953 .map_err(|e| Error::Database(e.to_string()))?;
954
955 conn.execute(
956 "DELETE FROM openmls_epoch_key_pairs WHERE group_id = ?",
957 [&mls_group_id_bytes],
958 )
959 .map_err(|e| Error::Database(e.to_string()))?;
960
961 conn.execute(
964 "DELETE FROM group_exporter_secrets WHERE mls_group_id = ?",
965 [group_id_bytes],
966 )
967 .map_err(|e| Error::Database(e.to_string()))?;
968
969 conn.execute(
970 "DELETE FROM group_relays WHERE mls_group_id = ?",
971 [group_id_bytes],
972 )
973 .map_err(|e| Error::Database(e.to_string()))?;
974
975 conn.execute(
976 "DELETE FROM groups WHERE mls_group_id = ?",
977 [group_id_bytes],
978 )
979 .map_err(|e| Error::Database(e.to_string()))?;
980
981 for (table_name, row_key, row_data) in &snapshot_rows {
989 if table_name != "groups" {
990 continue;
991 }
992 let mls_group_id: Vec<u8> =
993 serde_json::from_slice(row_key).map_err(|e| Error::Database(e.to_string()))?;
994 #[allow(clippy::type_complexity)]
995 let (
996 nostr_group_id,
997 name_val,
998 description,
999 admin_pubkeys,
1000 last_message_id,
1001 last_message_at,
1002 last_message_processed_at,
1003 epoch,
1004 state,
1005 image_hash,
1006 image_key,
1007 image_nonce,
1008 last_self_update_at,
1009 ): (
1010 Vec<u8>,
1011 String,
1012 String,
1013 String,
1014 Option<Vec<u8>>,
1015 Option<i64>,
1016 Option<i64>,
1017 i64,
1018 String,
1019 Option<Vec<u8>>,
1020 Option<Vec<u8>>,
1021 Option<Vec<u8>>,
1022 i64,
1023 ) = serde_json::from_slice(row_data).map_err(|e| Error::Database(e.to_string()))?;
1024 conn.execute(
1025 "INSERT INTO groups (mls_group_id, nostr_group_id, name, description, admin_pubkeys,
1026 last_message_id, last_message_at, last_message_processed_at, epoch, state,
1027 image_hash, image_key, image_nonce, last_self_update_at)
1028 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
1029 rusqlite::params![
1030 mls_group_id,
1031 nostr_group_id,
1032 name_val,
1033 description,
1034 admin_pubkeys,
1035 last_message_id,
1036 last_message_at,
1037 last_message_processed_at,
1038 epoch,
1039 state,
1040 image_hash,
1041 image_key,
1042 image_nonce,
1043 last_self_update_at
1044 ],
1045 )
1046 .map_err(|e| Error::Database(e.to_string()))?;
1047 }
1048
1049 for (table_name, row_key, row_data) in &snapshot_rows {
1051 match table_name.as_str() {
1052 "openmls_group_data" => {
1053 let (gid, data_type): (Vec<u8>, String) =
1054 serde_json::from_slice(row_key)
1055 .map_err(|e| Error::Database(e.to_string()))?;
1056 conn.execute(
1057 "INSERT INTO openmls_group_data (provider_version, group_id, data_type, group_data)
1058 VALUES (1, ?, ?, ?)",
1059 rusqlite::params![gid, data_type, row_data],
1060 )
1061 .map_err(|e| Error::Database(e.to_string()))?;
1062 }
1063 "openmls_proposals" => {
1064 let (gid, proposal_ref): (Vec<u8>, Vec<u8>) =
1065 serde_json::from_slice(row_key)
1066 .map_err(|e| Error::Database(e.to_string()))?;
1067 conn.execute(
1068 "INSERT INTO openmls_proposals (provider_version, group_id, proposal_ref, proposal)
1069 VALUES (1, ?, ?, ?)",
1070 rusqlite::params![gid, proposal_ref, row_data],
1071 )
1072 .map_err(|e| Error::Database(e.to_string()))?;
1073 }
1074 "openmls_own_leaf_nodes" => {
1075 let (gid, leaf_node): (Vec<u8>, Vec<u8>) = serde_json::from_slice(row_data)
1076 .map_err(|e| Error::Database(e.to_string()))?;
1077 conn.execute(
1078 "INSERT INTO openmls_own_leaf_nodes (provider_version, group_id, leaf_node)
1079 VALUES (1, ?, ?)",
1080 rusqlite::params![gid, leaf_node],
1081 )
1082 .map_err(|e| Error::Database(e.to_string()))?;
1083 }
1084 "openmls_epoch_key_pairs" => {
1085 let (gid, epoch_id, leaf_index): (Vec<u8>, Vec<u8>, i64) =
1086 serde_json::from_slice(row_key)
1087 .map_err(|e| Error::Database(e.to_string()))?;
1088 conn.execute(
1089 "INSERT INTO openmls_epoch_key_pairs (provider_version, group_id, epoch_id, leaf_index, key_pairs)
1090 VALUES (1, ?, ?, ?, ?)",
1091 rusqlite::params![gid, epoch_id, leaf_index, row_data],
1092 )
1093 .map_err(|e| Error::Database(e.to_string()))?;
1094 }
1095 "groups" => {
1096 }
1098 "group_relays" => {
1099 let (mls_group_id, relay_url): (Vec<u8>, String) =
1100 serde_json::from_slice(row_data)
1101 .map_err(|e| Error::Database(e.to_string()))?;
1102 conn.execute(
1103 "INSERT INTO group_relays (mls_group_id, relay_url) VALUES (?, ?)",
1104 rusqlite::params![mls_group_id, relay_url],
1105 )
1106 .map_err(|e| Error::Database(e.to_string()))?;
1107 }
1108 "group_exporter_secrets" => {
1109 let (mls_group_id, epoch, label): (Vec<u8>, i64, String) =
1110 match serde_json::from_slice(row_key) {
1111 Ok(v) => v,
1112 Err(_) => {
1113 let (mls_group_id, epoch): (Vec<u8>, i64) =
1114 serde_json::from_slice(row_key)
1115 .map_err(|e| Error::Database(e.to_string()))?;
1116 (mls_group_id, epoch, "group-event".to_string())
1117 }
1118 };
1119 conn.execute(
1120 "INSERT INTO group_exporter_secrets (mls_group_id, epoch, label, secret) VALUES (?, ?, ?, ?)",
1121 rusqlite::params![mls_group_id, epoch, label, row_data],
1122 )
1123 .map_err(|e| Error::Database(e.to_string()))?;
1124 }
1125 _ => {
1126 }
1128 }
1129 }
1130
1131 conn.execute(
1133 "DELETE FROM group_state_snapshots WHERE snapshot_name = ? AND group_id = ?",
1134 rusqlite::params![name, group_id_bytes],
1135 )
1136 .map_err(|e| Error::Database(e.to_string()))?;
1137
1138 for (snap_name, table_name, row_key, row_data, created_at) in &other_snapshots {
1141 conn.execute(
1142 "INSERT INTO group_state_snapshots (snapshot_name, group_id, table_name, row_key, row_data, created_at)
1143 VALUES (?, ?, ?, ?, ?, ?)",
1144 rusqlite::params![snap_name, group_id_bytes, table_name, row_key, row_data, created_at],
1145 )
1146 .map_err(|e| Error::Database(e.to_string()))?;
1147 }
1148
1149 Ok(())
1150 })();
1151
1152 match result {
1153 Ok(()) => {
1154 conn.execute("COMMIT", [])
1155 .map_err(|e| Error::Database(e.to_string()))?;
1156 Ok(())
1157 }
1158 Err(e) => {
1159 let _ = conn.execute("ROLLBACK", []);
1160 Err(e)
1161 }
1162 }
1163 }
1164
1165 fn delete_group_snapshot(&self, group_id: &GroupId, name: &str) -> Result<(), Error> {
1167 let conn = self.connection.lock().unwrap();
1168 conn.execute(
1169 "DELETE FROM group_state_snapshots WHERE snapshot_name = ? AND group_id = ?",
1170 rusqlite::params![name, group_id.as_slice()],
1171 )
1172 .map_err(|e| Error::Database(e.to_string()))?;
1173 Ok(())
1174 }
1175}
1176
1177impl MdkStorageProvider for MdkSqliteStorage {
1179 fn backend(&self) -> Backend {
1185 Backend::SQLite
1186 }
1187
1188 fn create_group_snapshot(&self, group_id: &GroupId, name: &str) -> Result<(), MdkStorageError> {
1189 self.snapshot_group_state(group_id, name)
1190 .map_err(|e| MdkStorageError::Database(e.to_string()))
1191 }
1192
1193 fn rollback_group_to_snapshot(
1194 &self,
1195 group_id: &GroupId,
1196 name: &str,
1197 ) -> Result<(), MdkStorageError> {
1198 self.restore_group_from_snapshot(group_id, name)
1199 .map_err(|e| MdkStorageError::Database(e.to_string()))
1200 }
1201
1202 fn release_group_snapshot(
1203 &self,
1204 group_id: &GroupId,
1205 name: &str,
1206 ) -> Result<(), MdkStorageError> {
1207 self.delete_group_snapshot(group_id, name)
1208 .map_err(|e| MdkStorageError::Database(e.to_string()))
1209 }
1210
1211 fn list_group_snapshots(
1212 &self,
1213 group_id: &GroupId,
1214 ) -> Result<Vec<(String, u64)>, MdkStorageError> {
1215 let conn = self.connection.lock().unwrap();
1216 let mut stmt = conn
1217 .prepare_cached(
1218 "SELECT DISTINCT snapshot_name, created_at FROM group_state_snapshots
1219 WHERE group_id = ? ORDER BY created_at ASC",
1220 )
1221 .map_err(|e| MdkStorageError::Database(e.to_string()))?;
1222
1223 let rows = stmt
1224 .query_map(rusqlite::params![group_id.as_slice()], |row| {
1225 let name: String = row.get(0)?;
1226 let created_at: i64 = row.get(1)?;
1227 Ok((name, created_at as u64))
1228 })
1229 .map_err(|e| MdkStorageError::Database(e.to_string()))?;
1230
1231 rows.collect::<Result<Vec<_>, _>>()
1232 .map_err(|e| MdkStorageError::Database(e.to_string()))
1233 }
1234
1235 fn prune_expired_snapshots(&self, min_timestamp: u64) -> Result<usize, MdkStorageError> {
1236 let conn = self.connection.lock().unwrap();
1237 let deleted = conn
1238 .execute(
1239 "DELETE FROM group_state_snapshots WHERE created_at < ?",
1240 rusqlite::params![min_timestamp as i64],
1241 )
1242 .map_err(|e| MdkStorageError::Database(e.to_string()))?;
1243 Ok(deleted)
1244 }
1245}
1246
1247impl StorageProvider<STORAGE_PROVIDER_VERSION> for MdkSqliteStorage {
    /// Storage failures are reported as [`MdkStorageError`].
    type Error = MdkStorageError;
1253
1254 fn write_mls_join_config<GroupId, MlsGroupJoinConfig>(
1259 &self,
1260 group_id: &GroupId,
1261 config: &MlsGroupJoinConfig,
1262 ) -> Result<(), Self::Error>
1263 where
1264 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1265 MlsGroupJoinConfig: traits::MlsGroupJoinConfig<STORAGE_PROVIDER_VERSION>,
1266 {
1267 self.with_connection(|conn| {
1268 mls_storage::write_group_data(conn, group_id, GroupDataType::JoinGroupConfig, config)
1269 })
1270 }
1271
1272 fn append_own_leaf_node<GroupId, LeafNode>(
1273 &self,
1274 group_id: &GroupId,
1275 leaf_node: &LeafNode,
1276 ) -> Result<(), Self::Error>
1277 where
1278 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1279 LeafNode: traits::LeafNode<STORAGE_PROVIDER_VERSION>,
1280 {
1281 self.with_connection(|conn| mls_storage::append_own_leaf_node(conn, group_id, leaf_node))
1282 }
1283
1284 fn queue_proposal<GroupId, ProposalRef, QueuedProposal>(
1285 &self,
1286 group_id: &GroupId,
1287 proposal_ref: &ProposalRef,
1288 proposal: &QueuedProposal,
1289 ) -> Result<(), Self::Error>
1290 where
1291 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1292 ProposalRef: traits::ProposalRef<STORAGE_PROVIDER_VERSION>,
1293 QueuedProposal: traits::QueuedProposal<STORAGE_PROVIDER_VERSION>,
1294 {
1295 self.with_connection(|conn| {
1296 mls_storage::queue_proposal(conn, group_id, proposal_ref, proposal)
1297 })
1298 }
1299
1300 fn write_tree<GroupId, TreeSync>(
1301 &self,
1302 group_id: &GroupId,
1303 tree: &TreeSync,
1304 ) -> Result<(), Self::Error>
1305 where
1306 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1307 TreeSync: traits::TreeSync<STORAGE_PROVIDER_VERSION>,
1308 {
1309 self.with_connection(|conn| {
1310 mls_storage::write_group_data(conn, group_id, GroupDataType::Tree, tree)
1311 })
1312 }
1313
1314 fn write_interim_transcript_hash<GroupId, InterimTranscriptHash>(
1315 &self,
1316 group_id: &GroupId,
1317 interim_transcript_hash: &InterimTranscriptHash,
1318 ) -> Result<(), Self::Error>
1319 where
1320 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1321 InterimTranscriptHash: traits::InterimTranscriptHash<STORAGE_PROVIDER_VERSION>,
1322 {
1323 self.with_connection(|conn| {
1324 mls_storage::write_group_data(
1325 conn,
1326 group_id,
1327 GroupDataType::InterimTranscriptHash,
1328 interim_transcript_hash,
1329 )
1330 })
1331 }
1332
1333 fn write_context<GroupId, GroupContext>(
1334 &self,
1335 group_id: &GroupId,
1336 group_context: &GroupContext,
1337 ) -> Result<(), Self::Error>
1338 where
1339 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1340 GroupContext: traits::GroupContext<STORAGE_PROVIDER_VERSION>,
1341 {
1342 self.with_connection(|conn| {
1343 mls_storage::write_group_data(conn, group_id, GroupDataType::Context, group_context)
1344 })
1345 }
1346
1347 fn write_confirmation_tag<GroupId, ConfirmationTag>(
1348 &self,
1349 group_id: &GroupId,
1350 confirmation_tag: &ConfirmationTag,
1351 ) -> Result<(), Self::Error>
1352 where
1353 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1354 ConfirmationTag: traits::ConfirmationTag<STORAGE_PROVIDER_VERSION>,
1355 {
1356 self.with_connection(|conn| {
1357 mls_storage::write_group_data(
1358 conn,
1359 group_id,
1360 GroupDataType::ConfirmationTag,
1361 confirmation_tag,
1362 )
1363 })
1364 }
1365
1366 fn write_group_state<GroupState, GroupId>(
1367 &self,
1368 group_id: &GroupId,
1369 group_state: &GroupState,
1370 ) -> Result<(), Self::Error>
1371 where
1372 GroupState: traits::GroupState<STORAGE_PROVIDER_VERSION>,
1373 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1374 {
1375 self.with_connection(|conn| {
1376 mls_storage::write_group_data(conn, group_id, GroupDataType::GroupState, group_state)
1377 })
1378 }
1379
1380 fn write_message_secrets<GroupId, MessageSecrets>(
1381 &self,
1382 group_id: &GroupId,
1383 message_secrets: &MessageSecrets,
1384 ) -> Result<(), Self::Error>
1385 where
1386 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1387 MessageSecrets: traits::MessageSecrets<STORAGE_PROVIDER_VERSION>,
1388 {
1389 self.with_connection(|conn| {
1390 mls_storage::write_group_data(
1391 conn,
1392 group_id,
1393 GroupDataType::MessageSecrets,
1394 message_secrets,
1395 )
1396 })
1397 }
1398
1399 fn write_resumption_psk_store<GroupId, ResumptionPskStore>(
1400 &self,
1401 group_id: &GroupId,
1402 resumption_psk_store: &ResumptionPskStore,
1403 ) -> Result<(), Self::Error>
1404 where
1405 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1406 ResumptionPskStore: traits::ResumptionPskStore<STORAGE_PROVIDER_VERSION>,
1407 {
1408 self.with_connection(|conn| {
1409 mls_storage::write_group_data(
1410 conn,
1411 group_id,
1412 GroupDataType::ResumptionPskStore,
1413 resumption_psk_store,
1414 )
1415 })
1416 }
1417
1418 fn write_own_leaf_index<GroupId, LeafNodeIndex>(
1419 &self,
1420 group_id: &GroupId,
1421 own_leaf_index: &LeafNodeIndex,
1422 ) -> Result<(), Self::Error>
1423 where
1424 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1425 LeafNodeIndex: traits::LeafNodeIndex<STORAGE_PROVIDER_VERSION>,
1426 {
1427 self.with_connection(|conn| {
1428 mls_storage::write_group_data(
1429 conn,
1430 group_id,
1431 GroupDataType::OwnLeafIndex,
1432 own_leaf_index,
1433 )
1434 })
1435 }
1436
1437 fn write_group_epoch_secrets<GroupId, GroupEpochSecrets>(
1438 &self,
1439 group_id: &GroupId,
1440 group_epoch_secrets: &GroupEpochSecrets,
1441 ) -> Result<(), Self::Error>
1442 where
1443 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1444 GroupEpochSecrets: traits::GroupEpochSecrets<STORAGE_PROVIDER_VERSION>,
1445 {
1446 self.with_connection(|conn| {
1447 mls_storage::write_group_data(
1448 conn,
1449 group_id,
1450 GroupDataType::GroupEpochSecrets,
1451 group_epoch_secrets,
1452 )
1453 })
1454 }
1455
1456 fn write_signature_key_pair<SignaturePublicKey, SignatureKeyPair>(
1457 &self,
1458 public_key: &SignaturePublicKey,
1459 signature_key_pair: &SignatureKeyPair,
1460 ) -> Result<(), Self::Error>
1461 where
1462 SignaturePublicKey: traits::SignaturePublicKey<STORAGE_PROVIDER_VERSION>,
1463 SignatureKeyPair: traits::SignatureKeyPair<STORAGE_PROVIDER_VERSION>,
1464 {
1465 self.with_connection(|conn| {
1466 mls_storage::write_signature_key_pair(conn, public_key, signature_key_pair)
1467 })
1468 }
1469
1470 fn write_encryption_key_pair<EncryptionKey, HpkeKeyPair>(
1471 &self,
1472 public_key: &EncryptionKey,
1473 key_pair: &HpkeKeyPair,
1474 ) -> Result<(), Self::Error>
1475 where
1476 EncryptionKey: traits::EncryptionKey<STORAGE_PROVIDER_VERSION>,
1477 HpkeKeyPair: traits::HpkeKeyPair<STORAGE_PROVIDER_VERSION>,
1478 {
1479 self.with_connection(|conn| {
1480 mls_storage::write_encryption_key_pair(conn, public_key, key_pair)
1481 })
1482 }
1483
1484 fn write_encryption_epoch_key_pairs<GroupId, EpochKey, HpkeKeyPair>(
1485 &self,
1486 group_id: &GroupId,
1487 epoch: &EpochKey,
1488 leaf_index: u32,
1489 key_pairs: &[HpkeKeyPair],
1490 ) -> Result<(), Self::Error>
1491 where
1492 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1493 EpochKey: traits::EpochKey<STORAGE_PROVIDER_VERSION>,
1494 HpkeKeyPair: traits::HpkeKeyPair<STORAGE_PROVIDER_VERSION>,
1495 {
1496 self.with_connection(|conn| {
1497 mls_storage::write_encryption_epoch_key_pairs(
1498 conn, group_id, epoch, leaf_index, key_pairs,
1499 )
1500 })
1501 }
1502
1503 fn write_key_package<HashReference, KeyPackage>(
1504 &self,
1505 hash_ref: &HashReference,
1506 key_package: &KeyPackage,
1507 ) -> Result<(), Self::Error>
1508 where
1509 HashReference: traits::HashReference<STORAGE_PROVIDER_VERSION>,
1510 KeyPackage: traits::KeyPackage<STORAGE_PROVIDER_VERSION>,
1511 {
1512 self.with_connection(|conn| mls_storage::write_key_package(conn, hash_ref, key_package))
1513 }
1514
1515 fn write_psk<PskId, PskBundle>(
1516 &self,
1517 psk_id: &PskId,
1518 psk: &PskBundle,
1519 ) -> Result<(), Self::Error>
1520 where
1521 PskId: traits::PskId<STORAGE_PROVIDER_VERSION>,
1522 PskBundle: traits::PskBundle<STORAGE_PROVIDER_VERSION>,
1523 {
1524 self.with_connection(|conn| mls_storage::write_psk(conn, psk_id, psk))
1525 }
1526
1527 fn mls_group_join_config<GroupId, MlsGroupJoinConfig>(
1532 &self,
1533 group_id: &GroupId,
1534 ) -> Result<Option<MlsGroupJoinConfig>, Self::Error>
1535 where
1536 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1537 MlsGroupJoinConfig: traits::MlsGroupJoinConfig<STORAGE_PROVIDER_VERSION>,
1538 {
1539 self.with_connection(|conn| {
1540 mls_storage::read_group_data(conn, group_id, GroupDataType::JoinGroupConfig)
1541 })
1542 }
1543
1544 fn own_leaf_nodes<GroupId, LeafNode>(
1545 &self,
1546 group_id: &GroupId,
1547 ) -> Result<Vec<LeafNode>, Self::Error>
1548 where
1549 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1550 LeafNode: traits::LeafNode<STORAGE_PROVIDER_VERSION>,
1551 {
1552 self.with_connection(|conn| mls_storage::read_own_leaf_nodes(conn, group_id))
1553 }
1554
1555 fn queued_proposal_refs<GroupId, ProposalRef>(
1556 &self,
1557 group_id: &GroupId,
1558 ) -> Result<Vec<ProposalRef>, Self::Error>
1559 where
1560 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1561 ProposalRef: traits::ProposalRef<STORAGE_PROVIDER_VERSION>,
1562 {
1563 self.with_connection(|conn| mls_storage::read_queued_proposal_refs(conn, group_id))
1564 }
1565
1566 fn queued_proposals<GroupId, ProposalRef, QueuedProposal>(
1567 &self,
1568 group_id: &GroupId,
1569 ) -> Result<Vec<(ProposalRef, QueuedProposal)>, Self::Error>
1570 where
1571 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1572 ProposalRef: traits::ProposalRef<STORAGE_PROVIDER_VERSION>,
1573 QueuedProposal: traits::QueuedProposal<STORAGE_PROVIDER_VERSION>,
1574 {
1575 self.with_connection(|conn| mls_storage::read_queued_proposals(conn, group_id))
1576 }
1577
1578 fn tree<GroupId, TreeSync>(&self, group_id: &GroupId) -> Result<Option<TreeSync>, Self::Error>
1579 where
1580 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1581 TreeSync: traits::TreeSync<STORAGE_PROVIDER_VERSION>,
1582 {
1583 self.with_connection(|conn| {
1584 mls_storage::read_group_data(conn, group_id, GroupDataType::Tree)
1585 })
1586 }
1587
1588 fn group_context<GroupId, GroupContext>(
1589 &self,
1590 group_id: &GroupId,
1591 ) -> Result<Option<GroupContext>, Self::Error>
1592 where
1593 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1594 GroupContext: traits::GroupContext<STORAGE_PROVIDER_VERSION>,
1595 {
1596 self.with_connection(|conn| {
1597 mls_storage::read_group_data(conn, group_id, GroupDataType::Context)
1598 })
1599 }
1600
1601 fn interim_transcript_hash<GroupId, InterimTranscriptHash>(
1602 &self,
1603 group_id: &GroupId,
1604 ) -> Result<Option<InterimTranscriptHash>, Self::Error>
1605 where
1606 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1607 InterimTranscriptHash: traits::InterimTranscriptHash<STORAGE_PROVIDER_VERSION>,
1608 {
1609 self.with_connection(|conn| {
1610 mls_storage::read_group_data(conn, group_id, GroupDataType::InterimTranscriptHash)
1611 })
1612 }
1613
1614 fn confirmation_tag<GroupId, ConfirmationTag>(
1615 &self,
1616 group_id: &GroupId,
1617 ) -> Result<Option<ConfirmationTag>, Self::Error>
1618 where
1619 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1620 ConfirmationTag: traits::ConfirmationTag<STORAGE_PROVIDER_VERSION>,
1621 {
1622 self.with_connection(|conn| {
1623 mls_storage::read_group_data(conn, group_id, GroupDataType::ConfirmationTag)
1624 })
1625 }
1626
1627 fn group_state<GroupState, GroupId>(
1628 &self,
1629 group_id: &GroupId,
1630 ) -> Result<Option<GroupState>, Self::Error>
1631 where
1632 GroupState: traits::GroupState<STORAGE_PROVIDER_VERSION>,
1633 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1634 {
1635 self.with_connection(|conn| {
1636 mls_storage::read_group_data(conn, group_id, GroupDataType::GroupState)
1637 })
1638 }
1639
1640 fn message_secrets<GroupId, MessageSecrets>(
1641 &self,
1642 group_id: &GroupId,
1643 ) -> Result<Option<MessageSecrets>, Self::Error>
1644 where
1645 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1646 MessageSecrets: traits::MessageSecrets<STORAGE_PROVIDER_VERSION>,
1647 {
1648 self.with_connection(|conn| {
1649 mls_storage::read_group_data(conn, group_id, GroupDataType::MessageSecrets)
1650 })
1651 }
1652
1653 fn resumption_psk_store<GroupId, ResumptionPskStore>(
1654 &self,
1655 group_id: &GroupId,
1656 ) -> Result<Option<ResumptionPskStore>, Self::Error>
1657 where
1658 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1659 ResumptionPskStore: traits::ResumptionPskStore<STORAGE_PROVIDER_VERSION>,
1660 {
1661 self.with_connection(|conn| {
1662 mls_storage::read_group_data(conn, group_id, GroupDataType::ResumptionPskStore)
1663 })
1664 }
1665
1666 fn own_leaf_index<GroupId, LeafNodeIndex>(
1667 &self,
1668 group_id: &GroupId,
1669 ) -> Result<Option<LeafNodeIndex>, Self::Error>
1670 where
1671 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1672 LeafNodeIndex: traits::LeafNodeIndex<STORAGE_PROVIDER_VERSION>,
1673 {
1674 self.with_connection(|conn| {
1675 mls_storage::read_group_data(conn, group_id, GroupDataType::OwnLeafIndex)
1676 })
1677 }
1678
1679 fn group_epoch_secrets<GroupId, GroupEpochSecrets>(
1680 &self,
1681 group_id: &GroupId,
1682 ) -> Result<Option<GroupEpochSecrets>, Self::Error>
1683 where
1684 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1685 GroupEpochSecrets: traits::GroupEpochSecrets<STORAGE_PROVIDER_VERSION>,
1686 {
1687 self.with_connection(|conn| {
1688 mls_storage::read_group_data(conn, group_id, GroupDataType::GroupEpochSecrets)
1689 })
1690 }
1691
1692 fn signature_key_pair<SignaturePublicKey, SignatureKeyPair>(
1693 &self,
1694 public_key: &SignaturePublicKey,
1695 ) -> Result<Option<SignatureKeyPair>, Self::Error>
1696 where
1697 SignaturePublicKey: traits::SignaturePublicKey<STORAGE_PROVIDER_VERSION>,
1698 SignatureKeyPair: traits::SignatureKeyPair<STORAGE_PROVIDER_VERSION>,
1699 {
1700 self.with_connection(|conn| mls_storage::read_signature_key_pair(conn, public_key))
1701 }
1702
1703 fn encryption_key_pair<HpkeKeyPair, EncryptionKey>(
1704 &self,
1705 public_key: &EncryptionKey,
1706 ) -> Result<Option<HpkeKeyPair>, Self::Error>
1707 where
1708 HpkeKeyPair: traits::HpkeKeyPair<STORAGE_PROVIDER_VERSION>,
1709 EncryptionKey: traits::EncryptionKey<STORAGE_PROVIDER_VERSION>,
1710 {
1711 self.with_connection(|conn| mls_storage::read_encryption_key_pair(conn, public_key))
1712 }
1713
1714 fn encryption_epoch_key_pairs<GroupId, EpochKey, HpkeKeyPair>(
1715 &self,
1716 group_id: &GroupId,
1717 epoch: &EpochKey,
1718 leaf_index: u32,
1719 ) -> Result<Vec<HpkeKeyPair>, Self::Error>
1720 where
1721 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1722 EpochKey: traits::EpochKey<STORAGE_PROVIDER_VERSION>,
1723 HpkeKeyPair: traits::HpkeKeyPair<STORAGE_PROVIDER_VERSION>,
1724 {
1725 self.with_connection(|conn| {
1726 mls_storage::read_encryption_epoch_key_pairs(conn, group_id, epoch, leaf_index)
1727 })
1728 }
1729
1730 fn key_package<HashReference, KeyPackage>(
1731 &self,
1732 hash_ref: &HashReference,
1733 ) -> Result<Option<KeyPackage>, Self::Error>
1734 where
1735 HashReference: traits::HashReference<STORAGE_PROVIDER_VERSION>,
1736 KeyPackage: traits::KeyPackage<STORAGE_PROVIDER_VERSION>,
1737 {
1738 self.with_connection(|conn| mls_storage::read_key_package(conn, hash_ref))
1739 }
1740
1741 fn psk<PskBundle, PskId>(&self, psk_id: &PskId) -> Result<Option<PskBundle>, Self::Error>
1742 where
1743 PskBundle: traits::PskBundle<STORAGE_PROVIDER_VERSION>,
1744 PskId: traits::PskId<STORAGE_PROVIDER_VERSION>,
1745 {
1746 self.with_connection(|conn| mls_storage::read_psk(conn, psk_id))
1747 }
1748
1749 fn remove_proposal<GroupId, ProposalRef>(
1754 &self,
1755 group_id: &GroupId,
1756 proposal_ref: &ProposalRef,
1757 ) -> Result<(), Self::Error>
1758 where
1759 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1760 ProposalRef: traits::ProposalRef<STORAGE_PROVIDER_VERSION>,
1761 {
1762 self.with_connection(|conn| mls_storage::remove_proposal(conn, group_id, proposal_ref))
1763 }
1764
1765 fn delete_own_leaf_nodes<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1766 where
1767 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1768 {
1769 self.with_connection(|conn| mls_storage::delete_own_leaf_nodes(conn, group_id))
1770 }
1771
1772 fn delete_group_config<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1773 where
1774 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1775 {
1776 self.with_connection(|conn| {
1777 mls_storage::delete_group_data(conn, group_id, GroupDataType::JoinGroupConfig)
1778 })
1779 }
1780
1781 fn delete_tree<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1782 where
1783 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1784 {
1785 self.with_connection(|conn| {
1786 mls_storage::delete_group_data(conn, group_id, GroupDataType::Tree)
1787 })
1788 }
1789
1790 fn delete_confirmation_tag<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1791 where
1792 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1793 {
1794 self.with_connection(|conn| {
1795 mls_storage::delete_group_data(conn, group_id, GroupDataType::ConfirmationTag)
1796 })
1797 }
1798
1799 fn delete_group_state<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1800 where
1801 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1802 {
1803 self.with_connection(|conn| {
1804 mls_storage::delete_group_data(conn, group_id, GroupDataType::GroupState)
1805 })
1806 }
1807
1808 fn delete_context<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1809 where
1810 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1811 {
1812 self.with_connection(|conn| {
1813 mls_storage::delete_group_data(conn, group_id, GroupDataType::Context)
1814 })
1815 }
1816
1817 fn delete_interim_transcript_hash<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1818 where
1819 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1820 {
1821 self.with_connection(|conn| {
1822 mls_storage::delete_group_data(conn, group_id, GroupDataType::InterimTranscriptHash)
1823 })
1824 }
1825
1826 fn delete_message_secrets<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1827 where
1828 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1829 {
1830 self.with_connection(|conn| {
1831 mls_storage::delete_group_data(conn, group_id, GroupDataType::MessageSecrets)
1832 })
1833 }
1834
1835 fn delete_all_resumption_psk_secrets<GroupId>(
1836 &self,
1837 group_id: &GroupId,
1838 ) -> Result<(), Self::Error>
1839 where
1840 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1841 {
1842 self.with_connection(|conn| {
1843 mls_storage::delete_group_data(conn, group_id, GroupDataType::ResumptionPskStore)
1844 })
1845 }
1846
1847 fn delete_own_leaf_index<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1848 where
1849 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1850 {
1851 self.with_connection(|conn| {
1852 mls_storage::delete_group_data(conn, group_id, GroupDataType::OwnLeafIndex)
1853 })
1854 }
1855
1856 fn delete_group_epoch_secrets<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1857 where
1858 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1859 {
1860 self.with_connection(|conn| {
1861 mls_storage::delete_group_data(conn, group_id, GroupDataType::GroupEpochSecrets)
1862 })
1863 }
1864
1865 fn clear_proposal_queue<GroupId, ProposalRef>(
1866 &self,
1867 group_id: &GroupId,
1868 ) -> Result<(), Self::Error>
1869 where
1870 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1871 ProposalRef: traits::ProposalRef<STORAGE_PROVIDER_VERSION>,
1872 {
1873 self.with_connection(|conn| mls_storage::clear_proposal_queue(conn, group_id))
1874 }
1875
1876 fn delete_signature_key_pair<SignaturePublicKey>(
1877 &self,
1878 public_key: &SignaturePublicKey,
1879 ) -> Result<(), Self::Error>
1880 where
1881 SignaturePublicKey: traits::SignaturePublicKey<STORAGE_PROVIDER_VERSION>,
1882 {
1883 self.with_connection(|conn| mls_storage::delete_signature_key_pair(conn, public_key))
1884 }
1885
1886 fn delete_encryption_key_pair<EncryptionKey>(
1887 &self,
1888 public_key: &EncryptionKey,
1889 ) -> Result<(), Self::Error>
1890 where
1891 EncryptionKey: traits::EncryptionKey<STORAGE_PROVIDER_VERSION>,
1892 {
1893 self.with_connection(|conn| mls_storage::delete_encryption_key_pair(conn, public_key))
1894 }
1895
1896 fn delete_encryption_epoch_key_pairs<GroupId, EpochKey>(
1897 &self,
1898 group_id: &GroupId,
1899 epoch: &EpochKey,
1900 leaf_index: u32,
1901 ) -> Result<(), Self::Error>
1902 where
1903 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1904 EpochKey: traits::EpochKey<STORAGE_PROVIDER_VERSION>,
1905 {
1906 self.with_connection(|conn| {
1907 mls_storage::delete_encryption_epoch_key_pairs(conn, group_id, epoch, leaf_index)
1908 })
1909 }
1910
1911 fn delete_key_package<HashReference>(&self, hash_ref: &HashReference) -> Result<(), Self::Error>
1912 where
1913 HashReference: traits::HashReference<STORAGE_PROVIDER_VERSION>,
1914 {
1915 self.with_connection(|conn| mls_storage::delete_key_package(conn, hash_ref))
1916 }
1917
1918 fn delete_psk<PskId>(&self, psk_id: &PskId) -> Result<(), Self::Error>
1919 where
1920 PskId: traits::PskId<STORAGE_PROVIDER_VERSION>,
1921 {
1922 self.with_connection(|conn| mls_storage::delete_psk(conn, psk_id))
1923 }
1924}
1925
1926#[cfg(test)]
1927mod tests {
1928 use std::collections::BTreeSet;
1929
1930 use mdk_storage_traits::GroupId;
1931 use mdk_storage_traits::Secret;
1932 use mdk_storage_traits::groups::GroupStorage;
1933 use mdk_storage_traits::groups::types::{
1934 Group, GroupExporterSecret, GroupState, SelfUpdateState,
1935 };
1936 use tempfile::tempdir;
1937
1938 use super::*;
1939
1940 #[test]
1941 fn test_new_in_memory() {
1942 let storage = MdkSqliteStorage::new_in_memory();
1943 assert!(storage.is_ok());
1944 let storage = storage.unwrap();
1945 assert_eq!(storage.backend(), Backend::SQLite);
1946 }
1947
1948 #[test]
1949 fn test_backend_type() {
1950 let storage = MdkSqliteStorage::new_in_memory().unwrap();
1951 assert_eq!(storage.backend(), Backend::SQLite);
1952 assert!(storage.backend().is_persistent());
1953 }
1954
1955 #[test]
1956 fn test_file_based_storage() {
1957 let temp_dir = tempdir().unwrap();
1958 let db_path = temp_dir.path().join("test_db.sqlite");
1959
1960 let storage = MdkSqliteStorage::new_unencrypted(&db_path);
1962 assert!(storage.is_ok());
1963
1964 assert!(db_path.exists());
1966
1967 let storage2 = MdkSqliteStorage::new_unencrypted(&db_path);
1969 assert!(storage2.is_ok());
1970
1971 drop(storage);
1973 drop(storage2);
1974 temp_dir.close().unwrap();
1975 }
1976
1977 #[test]
1978 fn test_database_tables() {
1979 let temp_dir = tempdir().unwrap();
1980 let db_path = temp_dir.path().join("migration_test.sqlite");
1981
1982 let storage = MdkSqliteStorage::new_unencrypted(&db_path).unwrap();
1984
1985 storage.with_connection(|conn| {
1987 let mut stmt = conn
1989 .prepare("SELECT name FROM sqlite_master WHERE type='table'")
1990 .unwrap();
1991 let table_names: Vec<String> = stmt
1992 .query_map([], |row| row.get(0))
1993 .unwrap()
1994 .map(|r| r.unwrap())
1995 .collect();
1996
1997 assert!(table_names.contains(&"groups".to_string()));
1999 assert!(table_names.contains(&"messages".to_string()));
2000 assert!(table_names.contains(&"welcomes".to_string()));
2001 assert!(table_names.contains(&"processed_messages".to_string()));
2002 assert!(table_names.contains(&"processed_welcomes".to_string()));
2003 assert!(table_names.contains(&"group_relays".to_string()));
2004 assert!(table_names.contains(&"group_exporter_secrets".to_string()));
2005
2006 assert!(table_names.contains(&"openmls_group_data".to_string()));
2008 assert!(table_names.contains(&"openmls_proposals".to_string()));
2009 assert!(table_names.contains(&"openmls_own_leaf_nodes".to_string()));
2010 assert!(table_names.contains(&"openmls_key_packages".to_string()));
2011 assert!(table_names.contains(&"openmls_psks".to_string()));
2012 assert!(table_names.contains(&"openmls_signature_keys".to_string()));
2013 assert!(table_names.contains(&"openmls_encryption_keys".to_string()));
2014 assert!(table_names.contains(&"openmls_epoch_key_pairs".to_string()));
2015
2016 let mut pk_stmt = conn
2018 .prepare("PRAGMA table_info(group_exporter_secrets)")
2019 .unwrap();
2020 let pk_columns: Vec<(String, i64)> = pk_stmt
2021 .query_map([], |row| Ok((row.get(1)?, row.get(5)?)))
2022 .unwrap()
2023 .map(|r| r.unwrap())
2024 .filter(|(_, pk_pos)| *pk_pos > 0)
2025 .collect();
2026
2027 assert_eq!(
2028 pk_columns,
2029 vec![
2030 ("mls_group_id".to_string(), 1),
2031 ("epoch".to_string(), 2),
2032 ("label".to_string(), 3),
2033 ]
2034 );
2035 });
2036
2037 drop(storage);
2039 temp_dir.close().unwrap();
2040 }
2041
2042 #[test]
2043 fn test_group_exporter_secrets() {
2044 let storage = MdkSqliteStorage::new_in_memory().unwrap();
2046
2047 let mls_group_id = GroupId::from_slice(vec![1, 2, 3, 4].as_slice());
2049 let group = Group {
2050 mls_group_id: mls_group_id.clone(),
2051 nostr_group_id: [0u8; 32],
2052 name: "Test Group".to_string(),
2053 description: "A test group for exporter secrets".to_string(),
2054 admin_pubkeys: BTreeSet::new(),
2055 last_message_id: None,
2056 last_message_at: None,
2057 last_message_processed_at: None,
2058 epoch: 0,
2059 state: GroupState::Active,
2060 image_hash: None,
2061 image_key: None,
2062 image_nonce: None,
2063 self_update_state: SelfUpdateState::Required,
2064 };
2065
2066 storage.save_group(group.clone()).unwrap();
2068
2069 let secret_epoch_0 = GroupExporterSecret {
2071 mls_group_id: mls_group_id.clone(),
2072 epoch: 0,
2073 secret: Secret::new([0u8; 32]),
2074 };
2075
2076 let secret_epoch_1 = GroupExporterSecret {
2077 mls_group_id: mls_group_id.clone(),
2078 epoch: 1,
2079 secret: Secret::new([0u8; 32]),
2080 };
2081
2082 let mip04_secret_epoch_1 = GroupExporterSecret {
2083 mls_group_id: mls_group_id.clone(),
2084 epoch: 1,
2085 secret: Secret::new([7u8; 32]),
2086 };
2087
2088 storage
2090 .save_group_exporter_secret(secret_epoch_0.clone())
2091 .unwrap();
2092 storage
2093 .save_group_exporter_secret(secret_epoch_1.clone())
2094 .unwrap();
2095 storage
2096 .save_group_mip04_exporter_secret(mip04_secret_epoch_1.clone())
2097 .unwrap();
2098
2099 let retrieved_secret_0 = storage.get_group_exporter_secret(&mls_group_id, 0).unwrap();
2101 assert!(retrieved_secret_0.is_some());
2102 let retrieved_secret_0 = retrieved_secret_0.unwrap();
2103 assert_eq!(retrieved_secret_0, secret_epoch_0);
2104
2105 let retrieved_secret_1 = storage.get_group_exporter_secret(&mls_group_id, 1).unwrap();
2106 assert!(retrieved_secret_1.is_some());
2107 let retrieved_secret_1 = retrieved_secret_1.unwrap();
2108 assert_eq!(retrieved_secret_1, secret_epoch_1);
2109
2110 let retrieved_mip04_secret_1 = storage
2111 .get_group_mip04_exporter_secret(&mls_group_id, 1)
2112 .unwrap();
2113 assert!(retrieved_mip04_secret_1.is_some());
2114 let retrieved_mip04_secret_1 = retrieved_mip04_secret_1.unwrap();
2115 assert_eq!(retrieved_mip04_secret_1, mip04_secret_epoch_1);
2116
2117 let non_existent_epoch = storage
2119 .get_group_exporter_secret(&mls_group_id, 999)
2120 .unwrap();
2121 assert!(non_existent_epoch.is_none());
2122
2123 let non_existent_group_id = GroupId::from_slice(&[9, 9, 9, 9]);
2125 let result = storage.get_group_exporter_secret(&non_existent_group_id, 0);
2126 assert!(result.is_err());
2127
2128 let updated_secret_0 = GroupExporterSecret {
2130 mls_group_id: mls_group_id.clone(),
2131 epoch: 0,
2132 secret: Secret::new([0u8; 32]),
2133 };
2134 storage
2135 .save_group_exporter_secret(updated_secret_0.clone())
2136 .unwrap();
2137
2138 let retrieved_updated_secret = storage
2139 .get_group_exporter_secret(&mls_group_id, 0)
2140 .unwrap()
2141 .unwrap();
2142 assert_eq!(retrieved_updated_secret, updated_secret_0);
2143
2144 let updated_mip04_secret_1 = GroupExporterSecret {
2146 mls_group_id: mls_group_id.clone(),
2147 epoch: 1,
2148 secret: Secret::new([9u8; 32]),
2149 };
2150 storage
2151 .save_group_mip04_exporter_secret(updated_mip04_secret_1.clone())
2152 .unwrap();
2153 let still_group_event_1 = storage
2154 .get_group_exporter_secret(&mls_group_id, 1)
2155 .unwrap()
2156 .unwrap();
2157 assert_eq!(still_group_event_1, secret_epoch_1);
2158 let now_mip04_1 = storage
2159 .get_group_mip04_exporter_secret(&mls_group_id, 1)
2160 .unwrap()
2161 .unwrap();
2162 assert_eq!(now_mip04_1, updated_mip04_secret_1);
2163
2164 let invalid_secret = GroupExporterSecret {
2166 mls_group_id: non_existent_group_id.clone(),
2167 epoch: 0,
2168 secret: Secret::new([0u8; 32]),
2169 };
2170 let result = storage.save_group_exporter_secret(invalid_secret);
2171 assert!(result.is_err());
2172 }
2173
2174 mod encryption_tests {
2179 #[cfg(unix)]
2180 use std::os::unix::fs::PermissionsExt;
2181 use std::thread;
2182
2183 use mdk_storage_traits::Secret;
2184 use mdk_storage_traits::groups::GroupStorage;
2185 use mdk_storage_traits::groups::types::{Group, GroupExporterSecret, GroupState};
2186 use mdk_storage_traits::messages::MessageStorage;
2187 use mdk_storage_traits::test_utils::cross_storage::{
2188 create_test_group, create_test_message, create_test_welcome,
2189 };
2190 use mdk_storage_traits::welcomes::WelcomeStorage;
2191 use nostr::EventId;
2192
2193 use super::*;
2194 use crate::test_utils::ensure_mock_store;
2195
2196 #[test]
2197 fn test_encrypted_storage_creation() {
2198 let temp_dir = tempdir().unwrap();
2199 let db_path = temp_dir.path().join("encrypted.db");
2200
2201 let config = EncryptionConfig::generate().unwrap();
2202 let storage = MdkSqliteStorage::new_with_key(&db_path, config);
2203 assert!(storage.is_ok());
2204
2205 assert!(db_path.exists());
2207
2208 assert!(
2210 encryption::is_database_encrypted(&db_path).unwrap(),
2211 "Database should be encrypted"
2212 );
2213 }
2214
2215 #[test]
2216 fn test_encrypted_storage_reopen_with_correct_key() {
2217 let temp_dir = tempdir().unwrap();
2218 let db_path = temp_dir.path().join("encrypted_reopen.db");
2219
2220 let config = EncryptionConfig::generate().unwrap();
2222 let key = *config.key();
2223
2224 {
2225 let storage = MdkSqliteStorage::new_with_key(&db_path, config).unwrap();
2226 let _ = storage.backend();
2228 }
2229
2230 let config2 = EncryptionConfig::new(key);
2232 let storage2 = MdkSqliteStorage::new_with_key(&db_path, config2);
2233 assert!(
2234 storage2.is_ok(),
2235 "Should be able to reopen with correct key"
2236 );
2237 }
2238
2239 #[test]
2240 fn test_encrypted_storage_wrong_key_fails() {
2241 let temp_dir = tempdir().unwrap();
2242 let db_path = temp_dir.path().join("encrypted_wrong_key.db");
2243
2244 let config1 = EncryptionConfig::generate().unwrap();
2246 {
2247 let storage = MdkSqliteStorage::new_with_key(&db_path, config1).unwrap();
2248 drop(storage);
2249 }
2250
2251 let config2 = EncryptionConfig::generate().unwrap();
2253 let result = MdkSqliteStorage::new_with_key(&db_path, config2);
2254
2255 assert!(result.is_err(), "Opening with wrong key should fail");
2256
2257 match result {
2259 Err(error::Error::WrongEncryptionKey) => {}
2260 Err(e) => panic!("Expected WrongEncryptionKey error, got: {:?}", e),
2261 Ok(_) => panic!("Expected error but got success"),
2262 }
2263 }
2264
2265 #[test]
2266 fn test_unencrypted_cannot_read_encrypted() {
2267 let temp_dir = tempdir().unwrap();
2268 let db_path = temp_dir.path().join("encrypted_only.db");
2269
2270 let config = EncryptionConfig::generate().unwrap();
2272 {
2273 let storage = MdkSqliteStorage::new_with_key(&db_path, config).unwrap();
2274 drop(storage);
2275 }
2276
2277 let result = MdkSqliteStorage::new_unencrypted(&db_path);
2279
2280 assert!(
2282 result.is_err(),
2283 "Opening encrypted database without key should fail"
2284 );
2285 }
2286
2287 #[test]
2288 fn test_encrypted_storage_data_persistence() {
2289 let temp_dir = tempdir().unwrap();
2290 let db_path = temp_dir.path().join("encrypted_persist.db");
2291
2292 let config = EncryptionConfig::generate().unwrap();
2293 let key = *config.key();
2294
2295 let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);
2297 {
2298 let storage = MdkSqliteStorage::new_with_key(&db_path, config).unwrap();
2299
2300 let group = Group {
2301 mls_group_id: mls_group_id.clone(),
2302 nostr_group_id: [0u8; 32],
2303 name: "Encrypted Group".to_string(),
2304 description: "Testing encrypted persistence".to_string(),
2305 admin_pubkeys: BTreeSet::new(),
2306 last_message_id: None,
2307 last_message_at: None,
2308 last_message_processed_at: None,
2309 epoch: 0,
2310 state: GroupState::Active,
2311 image_hash: None,
2312 image_key: None,
2313 image_nonce: None,
2314 self_update_state: SelfUpdateState::Required,
2315 };
2316
2317 storage.save_group(group).unwrap();
2318 }
2319
2320 let config2 = EncryptionConfig::new(key);
2322 let storage2 = MdkSqliteStorage::new_with_key(&db_path, config2).unwrap();
2323
2324 let found_group = storage2.find_group_by_mls_group_id(&mls_group_id).unwrap();
2325 assert!(found_group.is_some());
2326 assert_eq!(found_group.unwrap().name, "Encrypted Group");
2327 }
2328
2329 #[test]
2330 fn test_file_permissions_are_secure() {
2331 let temp_dir = tempdir().unwrap();
2332 let db_path = temp_dir.path().join("secure_perms.db");
2333
2334 let config = EncryptionConfig::generate().unwrap();
2335 let _storage = MdkSqliteStorage::new_with_key(&db_path, config).unwrap();
2336
2337 #[cfg(unix)]
2339 {
2340 let metadata = std::fs::metadata(&db_path).unwrap();
2341 let mode = metadata.permissions().mode();
2342
2343 assert_eq!(
2345 mode & 0o077,
2346 0,
2347 "Database file should have owner-only permissions, got {:o}",
2348 mode & 0o777
2349 );
2350 }
2351 }
2352
2353 #[test]
2354 fn test_encrypted_storage_multiple_groups() {
2355 let temp_dir = tempdir().unwrap();
2356 let db_path = temp_dir.path().join("multi_groups.db");
2357
2358 let config = EncryptionConfig::generate().unwrap();
2359 let key = *config.key();
2360
2361 {
2363 let storage = MdkSqliteStorage::new_with_key(&db_path, config).unwrap();
2364
2365 for i in 0..5 {
2366 let mls_group_id = GroupId::from_slice(&[i; 8]);
2367 let mut group = create_test_group(mls_group_id);
2368 group.name = format!("Group {}", i);
2369 group.description = format!("Description {}", i);
2370 storage.save_group(group).unwrap();
2371 }
2372 }
2373
2374 let config2 = EncryptionConfig::new(key);
2376 let storage2 = MdkSqliteStorage::new_with_key(&db_path, config2).unwrap();
2377
2378 let groups = storage2.all_groups().unwrap();
2379 assert_eq!(groups.len(), 5);
2380
2381 for i in 0..5u8 {
2382 let mls_group_id = GroupId::from_slice(&[i; 8]);
2383 let group = storage2
2384 .find_group_by_mls_group_id(&mls_group_id)
2385 .unwrap()
2386 .unwrap();
2387 assert_eq!(group.name, format!("Group {}", i));
2388 }
2389 }
2390
#[test]
fn test_encrypted_storage_messages() {
    // A message saved through an encrypted handle must read back verbatim
    // after the database is reopened with the same key.
    let temp_dir = tempdir().unwrap();
    let db_path = temp_dir.path().join("messages.db");

    let first_config = EncryptionConfig::generate().unwrap();
    let raw_key = *first_config.key();
    let group_id = GroupId::from_slice(&[1, 2, 3, 4]);

    {
        let writer = MdkSqliteStorage::new_with_key(&db_path, first_config).unwrap();
        writer.save_group(create_test_group(group_id.clone())).unwrap();

        let mut outgoing = create_test_message(group_id.clone(), EventId::all_zeros());
        outgoing.content = "Test message content".to_string();
        writer.save_message(outgoing).unwrap();
    }

    let reader =
        MdkSqliteStorage::new_with_key(&db_path, EncryptionConfig::new(raw_key)).unwrap();
    let stored = reader.messages(&group_id, None).unwrap();

    assert_eq!(stored.len(), 1);
    assert_eq!(stored[0].content, "Test message content");
}
2423
#[test]
fn test_encrypted_storage_welcomes() {
    // A welcome saved through an encrypted handle must still be pending
    // after the database is reopened with the same key.
    let temp_dir = tempdir().unwrap();
    let db_path = temp_dir.path().join("welcomes.db");

    let first_config = EncryptionConfig::generate().unwrap();
    let raw_key = *first_config.key();
    let group_id = GroupId::from_slice(&[5, 6, 7, 8]);

    {
        let writer = MdkSqliteStorage::new_with_key(&db_path, first_config).unwrap();
        writer.save_group(create_test_group(group_id.clone())).unwrap();
        writer
            .save_welcome(create_test_welcome(group_id.clone(), EventId::all_zeros()))
            .unwrap();
    }

    let reader =
        MdkSqliteStorage::new_with_key(&db_path, EncryptionConfig::new(raw_key)).unwrap();
    let pending = reader.pending_welcomes(None).unwrap();

    assert_eq!(pending.len(), 1);
}
2453
#[test]
fn test_encrypted_storage_exporter_secrets() {
    // Exporter secrets for epochs 0..=5 must survive an encrypted reopen, and
    // a lookup for an epoch that was never stored must yield `None`.
    let temp_dir = tempdir().unwrap();
    let db_path = temp_dir.path().join("exporter_secrets.db");

    let first_config = EncryptionConfig::generate().unwrap();
    let raw_key = *first_config.key();
    let group_id = GroupId::from_slice(&[10, 20, 30, 40]);
    let epochs = 0..=5u64;

    {
        let writer = MdkSqliteStorage::new_with_key(&db_path, first_config).unwrap();

        writer
            .save_group(Group {
                mls_group_id: group_id.clone(),
                nostr_group_id: [0u8; 32],
                name: "Exporter Secret Test".to_string(),
                description: "Testing exporter secrets".to_string(),
                admin_pubkeys: BTreeSet::new(),
                last_message_id: None,
                last_message_at: None,
                last_message_processed_at: None,
                epoch: 5,
                state: GroupState::Active,
                image_hash: None,
                image_key: None,
                image_nonce: None,
                self_update_state: SelfUpdateState::Required,
            })
            .unwrap();

        // Each epoch's secret is filled with the epoch number so it can be
        // told apart on read-back.
        for epoch in epochs.clone() {
            writer
                .save_group_exporter_secret(GroupExporterSecret {
                    mls_group_id: group_id.clone(),
                    epoch,
                    secret: Secret::new([epoch as u8; 32]),
                })
                .unwrap();
        }
    }

    let reader =
        MdkSqliteStorage::new_with_key(&db_path, EncryptionConfig::new(raw_key)).unwrap();

    for epoch in epochs {
        let stored = reader
            .get_group_exporter_secret(&group_id, epoch)
            .unwrap()
            .unwrap();
        assert_eq!(stored.epoch, epoch);
        assert_eq!(stored.secret[0], epoch as u8);
    }

    let never_stored = reader.get_group_exporter_secret(&group_id, 999).unwrap();
    assert!(never_stored.is_none());
}
2516
#[test]
fn test_encrypted_storage_with_nested_directory() {
    // Opening an encrypted database under directories that do not exist yet
    // must create the missing parents and the file, and the file must be encrypted.
    let temp_dir = tempdir().unwrap();
    let db_path = temp_dir
        .path()
        .join("deep")
        .join("nested")
        .join("path")
        .join("db.sqlite");

    let config = EncryptionConfig::generate().unwrap();
    // `expect` (instead of `assert!(result.is_ok())`) prints the real error on failure.
    // The handle is kept alive for the rest of the test, as in the original.
    let _storage = MdkSqliteStorage::new_with_key(&db_path, config)
        .expect("opening an encrypted database in a nested directory should succeed");

    assert!(db_path.parent().unwrap().exists());
    assert!(db_path.exists());
    assert!(encryption::is_database_encrypted(&db_path).unwrap());
}
2538
#[test]
fn test_encrypted_unencrypted_incompatibility() {
    // Unencrypted and encrypted databases must be distinguishable on disk, and
    // an unencrypted file must not be openable as an encrypted one.
    let temp_dir = tempdir().unwrap();

    let plain_path = temp_dir.path().join("compat_test.db");
    {
        let _storage = MdkSqliteStorage::new_unencrypted(&plain_path).unwrap();
    }
    assert!(!encryption::is_database_encrypted(&plain_path).unwrap());

    let encrypted_path = temp_dir.path().join("compat_encrypted.db");
    {
        let config = EncryptionConfig::generate().unwrap();
        let _storage = MdkSqliteStorage::new_with_key(&encrypted_path, config).unwrap();
    }
    assert!(encryption::is_database_encrypted(&encrypted_path).unwrap());

    // This is the actual incompatibility check the test name promises: opening
    // the unencrypted file with a key must be rejected.
    match MdkSqliteStorage::new_with_key(&plain_path, EncryptionConfig::generate().unwrap()) {
        Err(Error::UnencryptedDatabaseWithEncryption) => {}
        Err(other) => panic!("Unexpected error: {:?}", other),
        Ok(_) => panic!("Expected an error but got Ok"),
    }
}
2562
#[test]
fn test_new_on_unencrypted_database_returns_correct_error() {
    // `new` on an existing *unencrypted* file must report that fact, rather
    // than claiming the keyring entry for an encrypted file is missing.
    ensure_mock_store();

    let temp_dir = tempdir().unwrap();
    let db_path = temp_dir.path().join("unencrypted_then_new.db");

    // Create a plain database and release the handle immediately.
    drop(MdkSqliteStorage::new_unencrypted(&db_path).unwrap());
    assert!(!encryption::is_database_encrypted(&db_path).unwrap());

    let result = MdkSqliteStorage::new(&db_path, "com.test.app", "test.key.id");
    assert!(result.is_err());

    match result {
        Ok(_) => panic!("Expected an error but got Ok"),
        Err(Error::UnencryptedDatabaseWithEncryption) => {}
        Err(Error::KeyringEntryMissingForExistingDatabase { .. }) => panic!(
            "Got KeyringEntryMissingForExistingDatabase but should have gotten \
             UnencryptedDatabaseWithEncryption. The database is unencrypted, not \
             encrypted with a missing key."
        ),
        Err(other) => panic!("Unexpected error: {:?}", other),
    }
}
2608
#[test]
fn test_new_with_key_on_unencrypted_database_returns_correct_error() {
    // `new_with_key` on an existing *unencrypted* file must say so, rather
    // than reporting that the supplied key is wrong.
    let temp_dir = tempdir().unwrap();
    let db_path = temp_dir.path().join("unencrypted_then_new_with_key.db");

    // Create a plain database and release the handle immediately.
    drop(MdkSqliteStorage::new_unencrypted(&db_path).unwrap());
    assert!(!encryption::is_database_encrypted(&db_path).unwrap());

    let result = MdkSqliteStorage::new_with_key(&db_path, EncryptionConfig::generate().unwrap());
    assert!(result.is_err());

    match result {
        Ok(_) => panic!("Expected an error but got Ok"),
        Err(Error::UnencryptedDatabaseWithEncryption) => {}
        Err(Error::WrongEncryptionKey) => panic!(
            "Got WrongEncryptionKey but should have gotten \
             UnencryptedDatabaseWithEncryption. The database is unencrypted, not \
             encrypted with a different key."
        ),
        Err(other) => panic!("Unexpected error: {:?}", other),
    }
}
2652
#[test]
fn test_encrypted_storage_large_data() {
    // A 10_000-character message body must round-trip unchanged through
    // encrypted storage across a reopen.
    let temp_dir = tempdir().unwrap();
    let db_path = temp_dir.path().join("large_data.db");

    let first_config = EncryptionConfig::generate().unwrap();
    let raw_key = *first_config.key();
    let group_id = GroupId::from_slice(&[99; 8]);
    let large_content = "x".repeat(10_000);

    {
        let writer = MdkSqliteStorage::new_with_key(&db_path, first_config).unwrap();

        let mut group = create_test_group(group_id.clone());
        group.name = "Large Data Test".to_string();
        group.description = "Testing large data".to_string();
        writer.save_group(group).unwrap();

        let mut outgoing = create_test_message(group_id.clone(), EventId::all_zeros());
        outgoing.content = large_content.clone();
        writer.save_message(outgoing).unwrap();
    }

    let reader =
        MdkSqliteStorage::new_with_key(&db_path, EncryptionConfig::new(raw_key)).unwrap();
    let stored = reader.messages(&group_id, None).unwrap();

    assert_eq!(stored.len(), 1);
    assert_eq!(stored[0].content, large_content);
}
2687
#[test]
fn test_encrypted_storage_concurrent_reads() {
    // Two handles on the same encrypted file, both open at the same time
    // (on this thread), must read back the same group.
    let temp_dir = tempdir().unwrap();
    let db_path = temp_dir.path().join("concurrent.db");

    let seed_config = EncryptionConfig::generate().unwrap();
    let raw_key = *seed_config.key();
    let group_id = GroupId::from_slice(&[77; 8]);

    {
        let writer = MdkSqliteStorage::new_with_key(&db_path, seed_config).unwrap();

        let mut group = create_test_group(group_id.clone());
        group.name = "Concurrent Test".to_string();
        group.description = "Testing concurrent access".to_string();
        writer.save_group(group).unwrap();
    }

    let first_reader =
        MdkSqliteStorage::new_with_key(&db_path, EncryptionConfig::new(raw_key)).unwrap();
    let second_reader =
        MdkSqliteStorage::new_with_key(&db_path, EncryptionConfig::new(raw_key)).unwrap();

    let seen_by_first = first_reader
        .find_group_by_mls_group_id(&group_id)
        .unwrap()
        .unwrap();
    let seen_by_second = second_reader
        .find_group_by_mls_group_id(&group_id)
        .unwrap()
        .unwrap();

    assert_eq!(seen_by_first.name, seen_by_second.name);
}
2727
#[cfg(unix)]
#[test]
fn test_encrypted_storage_sidecar_file_permissions() {
    // The database file, and any SQLite sidecar files that exist, must grant
    // no permissions to group or others (mode & 0o077 == 0).
    let temp_dir = tempdir().unwrap();
    let db_path = temp_dir.path().join("sidecar_test.db");

    let initial_config = EncryptionConfig::generate().unwrap();
    let raw_key = *initial_config.key();

    // Write some rows so SQLite has work to do (and may create sidecar files).
    {
        let writer = MdkSqliteStorage::new_with_key(&db_path, initial_config).unwrap();
        for i in 0..10u8 {
            let mut group = create_test_group(GroupId::from_slice(&[i; 8]));
            group.name = format!("Group {}", i);
            group.description = format!("Description {}", i);
            writer.save_group(group).unwrap();
        }
    }

    // Keep a reopened handle alive while the permissions are inspected.
    let _reopened =
        MdkSqliteStorage::new_with_key(&db_path, EncryptionConfig::new(raw_key)).unwrap();

    let db_mode = std::fs::metadata(&db_path).unwrap().permissions().mode();
    assert_eq!(
        db_mode & 0o077,
        0,
        "Database file should have owner-only permissions, got {:o}",
        db_mode & 0o777
    );

    // Sidecars are optional; only the ones that actually exist are checked.
    ["-wal", "-shm", "-journal"]
        .iter()
        .map(|suffix| {
            let path = temp_dir.path().join(format!("sidecar_test.db{}", suffix));
            (suffix, path)
        })
        .filter(|(_, path)| path.exists())
        .for_each(|(suffix, path)| {
            let mode = std::fs::metadata(&path).unwrap().permissions().mode();
            assert_eq!(
                mode & 0o077,
                0,
                "Sidecar file {} should have owner-only permissions, got {:o}",
                suffix,
                mode & 0o777
            );
        });
}
2782
#[test]
fn test_encryption_config_key_is_accessible() {
    // `EncryptionConfig::key` must return exactly the bytes passed to `new`.
    let key = [0xDE; 32];
    let config = EncryptionConfig::new(key);

    assert_eq!(config.key().len(), 32);
    // Compare the whole array, not only the first and last bytes, so a
    // corrupted middle byte is also caught.
    assert_eq!(*config.key(), key);
}
2793
#[test]
fn test_encrypted_storage_empty_group_name() {
    // Empty strings in group name and description must round-trip as empty
    // (not NULL or missing) through an encrypted reopen.
    let temp_dir = tempdir().unwrap();
    let db_path = temp_dir.path().join("empty_name.db");

    let first_config = EncryptionConfig::generate().unwrap();
    let raw_key = *first_config.key();
    let group_id = GroupId::from_slice(&[0xAB; 8]);

    {
        let writer = MdkSqliteStorage::new_with_key(&db_path, first_config).unwrap();
        let mut blank = create_test_group(group_id.clone());
        blank.name = String::new();
        blank.description = String::new();
        writer.save_group(blank).unwrap();
    }

    let reader =
        MdkSqliteStorage::new_with_key(&db_path, EncryptionConfig::new(raw_key)).unwrap();
    let stored = reader
        .find_group_by_mls_group_id(&group_id)
        .unwrap()
        .unwrap();

    assert!(stored.name.is_empty());
    assert!(stored.description.is_empty());
}
2825
#[test]
fn test_encrypted_storage_unicode_content() {
    // Multi-script text and emoji in group metadata and message content must
    // round-trip byte-for-byte through an encrypted reopen.
    let temp_dir = tempdir().unwrap();
    let db_path = temp_dir.path().join("unicode.db");

    let first_config = EncryptionConfig::generate().unwrap();
    let raw_key = *first_config.key();

    let mls_group_id = GroupId::from_slice(&[0xCD; 8]);
    let unicode_content = "Hello 世界! 🎉 Ñoño مرحبا Привет 日本語 한국어 ελληνικά";

    {
        let storage = MdkSqliteStorage::new_with_key(&db_path, first_config).unwrap();

        let mut group = create_test_group(mls_group_id.clone());
        group.name = "Тест группа 测试组".to_string();
        group.description = "描述 описание".to_string();
        storage.save_group(group).unwrap();

        let mut message = create_test_message(mls_group_id.clone(), EventId::all_zeros());
        message.content = unicode_content.to_string();
        storage.save_message(message).unwrap();
    }

    let storage2 =
        MdkSqliteStorage::new_with_key(&db_path, EncryptionConfig::new(raw_key)).unwrap();

    let group = storage2
        .find_group_by_mls_group_id(&mls_group_id)
        .unwrap()
        .unwrap();
    assert_eq!(group.name, "Тест группа 测试组");
    assert_eq!(group.description, "描述 описание");

    let messages = storage2.messages(&mls_group_id, None).unwrap();
    // Check the count before indexing, as the sibling message tests do, so a
    // lost message fails with a clear assertion instead of an index-out-of-bounds panic.
    assert_eq!(messages.len(), 1);
    assert_eq!(messages[0].content, unicode_content);
}
2866
#[test]
fn test_existing_db_with_missing_keyring_entry_fails() {
    // If the database file exists but its key has disappeared from the keyring,
    // `new` must refuse to open it and must not store a replacement key.
    ensure_mock_store();

    let temp_dir = tempdir().unwrap();
    let db_path = temp_dir.path().join("missing_key_test.db");
    let service_id = "test.mdk.storage.missingkey";
    let db_key_id = "test.key.missingkeytest";

    // Start from a clean keyring; it is fine if there was nothing to delete.
    let _ = keyring::delete_db_key(service_id, db_key_id);

    // The first open creates the file (and its key).
    {
        let created = MdkSqliteStorage::new(&db_path, service_id, db_key_id);
        assert!(created.is_ok(), "Should create database successfully");
    }
    assert!(db_path.exists(), "Database file should exist");

    // Simulate losing the key.
    keyring::delete_db_key(service_id, db_key_id).unwrap();
    assert!(
        keyring::get_db_key(service_id, db_key_id).unwrap().is_none(),
        "Key should be deleted"
    );

    let reopen = MdkSqliteStorage::new(&db_path, service_id, db_key_id);
    assert!(reopen.is_err(), "Should fail when keyring entry is missing");

    match reopen {
        Ok(_) => panic!("Expected error but got success"),
        Err(error::Error::KeyringEntryMissingForExistingDatabase {
            db_path: reported_path,
            service_id: reported_service,
            db_key_id: reported_key,
        }) => {
            assert!(
                reported_path.contains("missing_key_test.db"),
                "Error should contain database path"
            );
            assert_eq!(reported_service, service_id);
            assert_eq!(reported_key, db_key_id);
        }
        Err(e) => panic!(
            "Expected KeyringEntryMissingForExistingDatabase error, got: {:?}",
            e
        ),
    }

    // The failed open must not have written a new key.
    assert!(
        keyring::get_db_key(service_id, db_key_id).unwrap().is_none(),
        "No new key should have been stored in keyring"
    );
}
2933
#[test]
fn test_new_db_with_keyring_creates_key() {
    // Creating a brand-new database through `new` must store a key in the
    // keyring and produce an encrypted file.
    ensure_mock_store();

    let temp_dir = tempdir().unwrap();
    let db_path = temp_dir.path().join("new_db_keyring.db");
    let service_id = "test.mdk.storage.newdb";
    let db_key_id = "test.key.newdbtest";

    let _ = keyring::delete_db_key(service_id, db_key_id);
    assert!(!db_path.exists(), "Database should not exist yet");

    let opened = MdkSqliteStorage::new(&db_path, service_id, db_key_id);
    assert!(opened.is_ok(), "Should create database successfully");
    assert!(db_path.exists(), "Database file should exist");

    let stored_key = keyring::get_db_key(service_id, db_key_id).unwrap();
    assert!(stored_key.is_some(), "Key should be stored in keyring");

    let encrypted = encryption::is_database_encrypted(&db_path).unwrap();
    assert!(encrypted, "Database should be encrypted");

    // Release the handle before removing the keyring entry.
    drop(opened);
    keyring::delete_db_key(service_id, db_key_id).unwrap();
}
2972
#[test]
fn test_reopen_db_with_keyring_succeeds() {
    // A database created through `new` must reopen through `new` with the same
    // service/key ids, with its data intact.
    ensure_mock_store();

    let temp_dir = tempdir().unwrap();
    let db_path = temp_dir.path().join("reopen_keyring.db");
    let service_id = "test.mdk.storage.reopen";
    let db_key_id = "test.key.reopentest";

    let _ = keyring::delete_db_key(service_id, db_key_id);

    let group_id = GroupId::from_slice(&[0xAA; 8]);

    {
        let writer = MdkSqliteStorage::new(&db_path, service_id, db_key_id).unwrap();
        let mut group = create_test_group(group_id.clone());
        group.name = "Keyring Reopen Test".to_string();
        writer.save_group(group).unwrap();
    }

    let reopened = MdkSqliteStorage::new(&db_path, service_id, db_key_id);
    assert!(reopened.is_ok(), "Should reopen database successfully");
    let reader = reopened.unwrap();

    let stored = reader
        .find_group_by_mls_group_id(&group_id)
        .unwrap()
        .unwrap();
    assert_eq!(stored.name, "Keyring Reopen Test");

    // Release the handle before removing the keyring entry.
    drop(reader);
    keyring::delete_db_key(service_id, db_key_id).unwrap();
}
3014
#[test]
fn test_concurrent_encrypted_access_same_key() {
    // Several threads, each opening its own handle with the same raw key,
    // must all see the single group written beforehand.
    let temp_dir = tempdir().unwrap();
    let db_path = temp_dir.path().join("concurrent_encrypted.db");

    let seed_config = EncryptionConfig::generate().unwrap();
    let raw_key = *seed_config.key();

    {
        let writer = MdkSqliteStorage::new_with_key(&db_path, seed_config).unwrap();
        writer
            .save_group(create_test_group(GroupId::from_slice(&[1, 2, 3, 4])))
            .unwrap();
    }

    const READERS: usize = 5;
    let mut handles = Vec::with_capacity(READERS);
    for _ in 0..READERS {
        let path = db_path.clone();
        handles.push(thread::spawn(move || {
            let reader =
                MdkSqliteStorage::new_with_key(&path, EncryptionConfig::new(raw_key)).unwrap();
            let groups = reader.all_groups().unwrap();
            assert_eq!(groups.len(), 1);
            groups
        }));
    }

    for handle in handles {
        let groups = handle.join().unwrap();
        assert_eq!(groups.len(), 1);
    }
}
3052
#[test]
fn test_multiple_encrypted_databases_different_keys() {
    // Three databases, each with its own freshly generated key, must stay
    // independent: each reopens with its own key, and a key belonging to one
    // database must be rejected by another one as the wrong key.
    let temp_dir = tempdir().unwrap();

    let db1_path = temp_dir.path().join("db1.db");
    let db2_path = temp_dir.path().join("db2.db");
    let db3_path = temp_dir.path().join("db3.db");

    let config1 = EncryptionConfig::generate().unwrap();
    let config2 = EncryptionConfig::generate().unwrap();
    let config3 = EncryptionConfig::generate().unwrap();

    let key1 = *config1.key();
    let key2 = *config2.key();
    let key3 = *config3.key();

    // All three handles are open at the same time while writing.
    {
        let storage1 = MdkSqliteStorage::new_with_key(&db1_path, config1).unwrap();
        let mut group1 = create_test_group(GroupId::from_slice(&[1]));
        group1.name = "Database 1".to_string();
        storage1.save_group(group1).unwrap();

        let storage2 = MdkSqliteStorage::new_with_key(&db2_path, config2).unwrap();
        let mut group2 = create_test_group(GroupId::from_slice(&[2]));
        group2.name = "Database 2".to_string();
        storage2.save_group(group2).unwrap();

        let storage3 = MdkSqliteStorage::new_with_key(&db3_path, config3).unwrap();
        let mut group3 = create_test_group(GroupId::from_slice(&[3]));
        group3.name = "Database 3".to_string();
        storage3.save_group(group3).unwrap();
    }

    let storage1 = MdkSqliteStorage::new_with_key(&db1_path, EncryptionConfig::new(key1)).unwrap();
    let storage2 = MdkSqliteStorage::new_with_key(&db2_path, EncryptionConfig::new(key2)).unwrap();
    let storage3 = MdkSqliteStorage::new_with_key(&db3_path, EncryptionConfig::new(key3)).unwrap();

    for (storage, id, expected_name) in [
        (&storage1, 1u8, "Database 1"),
        (&storage2, 2u8, "Database 2"),
        (&storage3, 3u8, "Database 3"),
    ] {
        let group = storage
            .find_group_by_mls_group_id(&GroupId::from_slice(&[id]))
            .unwrap()
            .unwrap();
        assert_eq!(group.name, expected_name);
    }

    // key1 is a valid key, just not db2's key. Requiring the specific
    // `WrongEncryptionKey` variant (not just any error) stops unrelated
    // failures, such as I/O errors, from making this check pass by accident.
    match MdkSqliteStorage::new_with_key(&db2_path, EncryptionConfig::new(key1)) {
        Err(Error::WrongEncryptionKey) => {}
        Err(other) => panic!("Expected WrongEncryptionKey, got: {:?}", other),
        Ok(_) => panic!("Opening db2 with db1's key should fail"),
    }
}
3122 }
3123
3124 mod migration_tests {
3129 use super::*;
3130
3131 #[test]
3132 fn test_fresh_database_has_all_tables() {
3133 let storage = MdkSqliteStorage::new_in_memory().unwrap();
3134
3135 let expected_mdk_tables = [
3137 "groups",
3138 "group_relays",
3139 "group_exporter_secrets",
3140 "messages",
3141 "processed_messages",
3142 "welcomes",
3143 "processed_welcomes",
3144 ];
3145
3146 let expected_openmls_tables = [
3148 "openmls_group_data",
3149 "openmls_proposals",
3150 "openmls_own_leaf_nodes",
3151 "openmls_key_packages",
3152 "openmls_psks",
3153 "openmls_signature_keys",
3154 "openmls_encryption_keys",
3155 "openmls_epoch_key_pairs",
3156 ];
3157
3158 storage.with_connection(|conn| {
3159 let mut stmt = conn
3161 .prepare(
3162 "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'",
3163 )
3164 .unwrap();
3165 let table_names: Vec<String> = stmt
3166 .query_map([], |row| row.get(0))
3167 .unwrap()
3168 .map(|r| r.unwrap())
3169 .collect();
3170
3171 for table in &expected_mdk_tables {
3173 assert!(
3174 table_names.contains(&table.to_string()),
3175 "Missing MDK table: {}",
3176 table
3177 );
3178 }
3179
3180 for table in &expected_openmls_tables {
3182 assert!(
3183 table_names.contains(&table.to_string()),
3184 "Missing OpenMLS table: {}",
3185 table
3186 );
3187 }
3188 });
3189 }
3190
3191 #[test]
3192 fn test_all_indexes_exist() {
3193 let storage = MdkSqliteStorage::new_in_memory().unwrap();
3194
3195 let expected_indexes = [
3197 "idx_groups_nostr_group_id",
3198 "idx_group_relays_mls_group_id",
3199 "idx_group_exporter_secrets_mls_group_id",
3200 "idx_messages_mls_group_id",
3201 "idx_messages_wrapper_event_id",
3202 "idx_messages_created_at",
3203 "idx_messages_pubkey",
3204 "idx_messages_kind",
3205 "idx_messages_state",
3206 "idx_processed_messages_message_event_id",
3207 "idx_processed_messages_state",
3208 "idx_processed_messages_processed_at",
3209 "idx_welcomes_mls_group_id",
3210 "idx_welcomes_wrapper_event_id",
3211 "idx_welcomes_state",
3212 "idx_welcomes_nostr_group_id",
3213 "idx_processed_welcomes_welcome_event_id",
3214 "idx_processed_welcomes_state",
3215 "idx_processed_welcomes_processed_at",
3216 ];
3217
3218 storage.with_connection(|conn| {
3219 let mut stmt = conn
3220 .prepare("SELECT name FROM sqlite_master WHERE type='index' AND name NOT LIKE 'sqlite_%'")
3221 .unwrap();
3222 let index_names: Vec<String> = stmt
3223 .query_map([], |row| row.get(0))
3224 .unwrap()
3225 .map(|r| r.unwrap())
3226 .collect();
3227
3228 for idx in &expected_indexes {
3229 assert!(
3230 index_names.contains(&idx.to_string()),
3231 "Missing index: {}. Found indexes: {:?}",
3232 idx,
3233 index_names
3234 );
3235 }
3236 });
3237 }
3238
3239 #[test]
3240 fn test_foreign_key_constraints_work() {
3241 let storage = MdkSqliteStorage::new_in_memory().unwrap();
3242
3243 storage.with_connection(|conn| {
3244 let fk_enabled: i32 = conn
3246 .query_row("PRAGMA foreign_keys", [], |row| row.get(0))
3247 .unwrap();
3248 assert_eq!(fk_enabled, 1, "Foreign keys should be enabled");
3249
3250 let result = conn.execute(
3252 "INSERT INTO group_relays (mls_group_id, relay_url) VALUES (?, ?)",
3253 rusqlite::params![vec![1u8, 2u8, 3u8, 4u8], "wss://relay.example.com"],
3254 );
3255 assert!(result.is_err(), "Should fail due to foreign key constraint");
3256 });
3257 }
3258
3259 #[test]
3260 fn test_openmls_group_data_check_constraint() {
3261 let storage = MdkSqliteStorage::new_in_memory().unwrap();
3262
3263 storage.with_connection(|conn| {
3264 let valid_result = conn.execute(
3266 "INSERT INTO openmls_group_data (provider_version, group_id, data_type, group_data) VALUES (?, ?, ?, ?)",
3267 rusqlite::params![1, vec![1u8, 2u8, 3u8], "tree", vec![4u8, 5u8, 6u8]],
3268 );
3269 assert!(valid_result.is_ok(), "Valid data_type should succeed");
3270
3271 let invalid_result = conn.execute(
3273 "INSERT INTO openmls_group_data (provider_version, group_id, data_type, group_data) VALUES (?, ?, ?, ?)",
3274 rusqlite::params![1, vec![7u8, 8u8, 9u8], "invalid_type", vec![10u8, 11u8]],
3275 );
3276 assert!(
3277 invalid_result.is_err(),
3278 "Invalid data_type should fail CHECK constraint"
3279 );
3280 });
3281 }
3282
3283 #[test]
3284 fn test_schema_matches_plan_specification() {
3285 let storage = MdkSqliteStorage::new_in_memory().unwrap();
3286
3287 storage.with_connection(|conn| {
3288 let groups_info: Vec<(String, String)> = conn
3290 .prepare("PRAGMA table_info(groups)")
3291 .unwrap()
3292 .query_map([], |row| Ok((row.get(1)?, row.get(2)?)))
3293 .unwrap()
3294 .map(|r| r.unwrap())
3295 .collect();
3296
3297 let groups_columns: Vec<&str> =
3298 groups_info.iter().map(|(n, _)| n.as_str()).collect();
3299 assert!(groups_columns.contains(&"mls_group_id"));
3300 assert!(groups_columns.contains(&"nostr_group_id"));
3301 assert!(groups_columns.contains(&"name"));
3302 assert!(groups_columns.contains(&"description"));
3303 assert!(groups_columns.contains(&"admin_pubkeys"));
3304 assert!(groups_columns.contains(&"epoch"));
3305 assert!(groups_columns.contains(&"state"));
3306
3307 let messages_info: Vec<String> = conn
3309 .prepare("PRAGMA table_info(messages)")
3310 .unwrap()
3311 .query_map([], |row| row.get(1))
3312 .unwrap()
3313 .map(|r| r.unwrap())
3314 .collect();
3315
3316 assert!(messages_info.contains(&"mls_group_id".to_string()));
3317 assert!(messages_info.contains(&"id".to_string()));
3318 assert!(messages_info.contains(&"pubkey".to_string()));
3319 assert!(messages_info.contains(&"kind".to_string()));
3320 assert!(messages_info.contains(&"created_at".to_string()));
3321 assert!(messages_info.contains(&"content".to_string()));
3322 assert!(messages_info.contains(&"wrapper_event_id".to_string()));
3323 });
3324 }
3325 }
3326
3327 mod snapshot_tests {
3332 use std::collections::BTreeSet;
3333
3334 use mdk_storage_traits::groups::GroupStorage;
3335 use mdk_storage_traits::groups::types::{
3336 Group, GroupExporterSecret, GroupState, SelfUpdateState,
3337 };
3338 use mdk_storage_traits::{GroupId, MdkStorageProvider, Secret};
3339
3340 use super::*;
3341
3342 fn create_test_group(id: u8) -> Group {
3343 Group {
3344 mls_group_id: GroupId::from_slice(&[id; 32]),
3345 nostr_group_id: [id; 32],
3346 name: format!("Test Group {}", id),
3347 description: format!("Description {}", id),
3348 admin_pubkeys: BTreeSet::new(),
3349 last_message_id: None,
3350 last_message_at: None,
3351 last_message_processed_at: None,
3352 epoch: 0,
3353 state: GroupState::Active,
3354 image_hash: None,
3355 image_key: None,
3356 image_nonce: None,
3357 self_update_state: SelfUpdateState::Required,
3358 }
3359 }
3360
#[test]
fn test_snapshot_and_rollback_group_state() {
    // Snapshot a group at epoch 0, modify it, then roll back and expect the
    // original name and epoch to be restored.
    let storage = MdkSqliteStorage::new_in_memory().unwrap();
    let load = |id: &GroupId| storage.find_group_by_mls_group_id(id).unwrap().unwrap();

    let original = create_test_group(1);
    let group_id = original.mls_group_id.clone();
    storage.save_group(original).unwrap();

    let before = load(&group_id);
    assert_eq!(before.name, "Test Group 1");
    assert_eq!(before.epoch, 0);

    storage
        .create_group_snapshot(&group_id, "snap_epoch_0")
        .unwrap();

    let mut changed = before.clone();
    changed.name = "Modified Group".to_string();
    changed.epoch = 1;
    storage.save_group(changed).unwrap();

    let in_between = load(&group_id);
    assert_eq!(in_between.name, "Modified Group");
    assert_eq!(in_between.epoch, 1);

    storage
        .rollback_group_to_snapshot(&group_id, "snap_epoch_0")
        .unwrap();

    let restored = load(&group_id);
    assert_eq!(restored.name, "Test Group 1");
    assert_eq!(restored.epoch, 0);
}
3410
#[test]
fn test_snapshot_release_without_rollback() {
    // Releasing a snapshot (instead of rolling back to it) must keep whatever
    // changes were made after the snapshot was taken.
    let storage = MdkSqliteStorage::new_in_memory().unwrap();
    let load = |id: &GroupId| storage.find_group_by_mls_group_id(id).unwrap().unwrap();

    let original = create_test_group(2);
    let group_id = original.mls_group_id.clone();
    storage.save_group(original).unwrap();

    storage
        .create_group_snapshot(&group_id, "snap_to_release")
        .unwrap();

    let mut renamed = load(&group_id);
    renamed.name = "Modified Name".to_string();
    storage.save_group(renamed).unwrap();

    storage
        .release_group_snapshot(&group_id, "snap_to_release")
        .unwrap();

    assert_eq!(load(&group_id).name, "Modified Name");
}
3445
#[test]
fn test_snapshot_with_exporter_secrets_rollback() {
    // An exporter secret added after a snapshot must disappear on rollback,
    // while the one that existed at snapshot time must remain.
    let storage = MdkSqliteStorage::new_in_memory().unwrap();

    let original = create_test_group(3);
    let group_id = original.mls_group_id.clone();
    storage.save_group(original).unwrap();

    let secret_for = |epoch: u64, fill: u8| GroupExporterSecret {
        mls_group_id: group_id.clone(),
        epoch,
        secret: Secret::new([fill; 32]),
    };
    let lookup = |epoch: u64| storage.get_group_exporter_secret(&group_id, epoch).unwrap();

    storage
        .save_group_exporter_secret(secret_for(0, 0u8))
        .unwrap();

    storage
        .create_group_snapshot(&group_id, "snap_secrets")
        .unwrap();

    storage
        .save_group_exporter_secret(secret_for(1, 1u8))
        .unwrap();
    assert!(lookup(1).is_some());

    storage
        .rollback_group_to_snapshot(&group_id, "snap_secrets")
        .unwrap();

    assert!(lookup(1).is_none());
    assert!(lookup(0).is_some());
}
3493
#[test]
fn test_snapshot_rollback_preserves_exporter_secret_labels() {
    // The group-event and MIP-04 exporter secrets for the same epoch are stored
    // separately; rollback must restore each to its own pre-snapshot value.
    let storage = MdkSqliteStorage::new_in_memory().unwrap();

    let original = create_test_group(33);
    let group_id = original.mls_group_id.clone();
    storage.save_group(original).unwrap();

    let epoch_zero_secret = |fill: u8| GroupExporterSecret {
        mls_group_id: group_id.clone(),
        epoch: 0,
        secret: Secret::new([fill; 32]),
    };

    let initial_group_event = epoch_zero_secret(10u8);
    let initial_mip04 = epoch_zero_secret(20u8);

    storage
        .save_group_exporter_secret(initial_group_event.clone())
        .unwrap();
    storage
        .save_group_mip04_exporter_secret(initial_mip04.clone())
        .unwrap();

    storage
        .create_group_snapshot(&group_id, "snap_labels")
        .unwrap();

    // Overwrite both secrets after the snapshot.
    storage
        .save_group_exporter_secret(epoch_zero_secret(30u8))
        .unwrap();
    storage
        .save_group_mip04_exporter_secret(epoch_zero_secret(40u8))
        .unwrap();

    storage
        .rollback_group_to_snapshot(&group_id, "snap_labels")
        .unwrap();

    let group_event_after = storage
        .get_group_exporter_secret(&group_id, 0)
        .unwrap()
        .unwrap();
    let mip04_after = storage
        .get_group_mip04_exporter_secret(&group_id, 0)
        .unwrap()
        .unwrap();

    assert_eq!(group_event_after, initial_group_event);
    assert_eq!(mip04_after, initial_mip04);
}
3558
#[test]
fn test_snapshot_rollback_legacy_unlabeled_exporter_secret_row_key_regression() {
    // Regression test: snapshot rows whose `row_key` is the two-element
    // (group_id, epoch) form with no label must still restore correctly, as the
    // group-event secret and not as a MIP-04 secret.
    let storage = MdkSqliteStorage::new_in_memory().unwrap();

    let original = create_test_group(34);
    let group_id = original.mls_group_id.clone();
    storage.save_group(original).unwrap();

    let initial_group_event = GroupExporterSecret {
        mls_group_id: group_id.clone(),
        epoch: 7,
        secret: Secret::new([55u8; 32]),
    };
    storage
        .save_group_exporter_secret(initial_group_event.clone())
        .unwrap();

    let snapshot_name = "snap_legacy_unlabeled_key";
    storage
        .create_group_snapshot(&group_id, snapshot_name)
        .unwrap();

    // Rewrite this snapshot's exporter-secret row keys into the unlabeled
    // JSON (group_id bytes, epoch) form, presumably the pre-label layout.
    let legacy_row_key =
        serde_json::to_vec(&(group_id.as_slice().to_vec(), 7_i64)).unwrap();
    storage
        .with_connection(|conn| {
            conn.execute(
                "UPDATE group_state_snapshots SET row_key = ? WHERE snapshot_name = ? AND group_id = ? AND table_name = 'group_exporter_secrets'",
                rusqlite::params![legacy_row_key, snapshot_name, group_id.as_slice()],
            )
            .map_err(|e| Error::Database(e.to_string()))?;
            Ok::<(), Error>(())
        })
        .unwrap();

    // Overwrite the secret after the snapshot.
    storage
        .save_group_exporter_secret(GroupExporterSecret {
            mls_group_id: group_id.clone(),
            epoch: 7,
            secret: Secret::new([99u8; 32]),
        })
        .unwrap();

    storage
        .rollback_group_to_snapshot(&group_id, snapshot_name)
        .unwrap();

    let restored = storage
        .get_group_exporter_secret(&group_id, 7)
        .unwrap()
        .unwrap();
    assert_eq!(restored.mls_group_id, group_id);
    assert_eq!(restored.epoch, 7);
    assert_eq!(restored, initial_group_event);

    let mip04 = storage
        .get_group_mip04_exporter_secret(&group_id, 7)
        .unwrap();
    assert!(
        mip04.is_none(),
        "Legacy unlabeled row_key should restore as label='group-event', not MIP-04"
    );
}
3629
3630 #[test]
3631 fn test_snapshot_isolation_between_groups() {
3632 let storage = MdkSqliteStorage::new_in_memory().unwrap();
3633
3634 let group1 = create_test_group(10);
3636 let group2 = create_test_group(20);
3637 let group1_id = group1.mls_group_id.clone();
3638 let group2_id = group2.mls_group_id.clone();
3639
3640 storage.save_group(group1).unwrap();
3641 storage.save_group(group2).unwrap();
3642
3643 storage
3645 .create_group_snapshot(&group1_id, "snap_group1")
3646 .unwrap();
3647
3648 let mut mod1 = storage
3650 .find_group_by_mls_group_id(&group1_id)
3651 .unwrap()
3652 .unwrap();
3653 let mut mod2 = storage
3654 .find_group_by_mls_group_id(&group2_id)
3655 .unwrap()
3656 .unwrap();
3657 mod1.name = "Modified Group 1".to_string();
3658 mod2.name = "Modified Group 2".to_string();
3659 storage.save_group(mod1).unwrap();
3660 storage.save_group(mod2).unwrap();
3661
3662 storage
3664 .rollback_group_to_snapshot(&group1_id, "snap_group1")
3665 .unwrap();
3666
3667 let final1 = storage
3669 .find_group_by_mls_group_id(&group1_id)
3670 .unwrap()
3671 .unwrap();
3672 assert_eq!(final1.name, "Test Group 10");
3673
3674 let final2 = storage
3676 .find_group_by_mls_group_id(&group2_id)
3677 .unwrap()
3678 .unwrap();
3679 assert_eq!(final2.name, "Modified Group 2");
3680 }
3681
3682 #[test]
3683 fn test_rollback_nonexistent_snapshot_returns_error() {
3684 let storage = MdkSqliteStorage::new_in_memory().unwrap();
3685
3686 let group = create_test_group(5);
3687 let group_id = group.mls_group_id.clone();
3688 storage.save_group(group.clone()).unwrap();
3689
3690 let result = storage.rollback_group_to_snapshot(&group_id, "nonexistent_snap");
3695 assert!(
3696 result.is_err(),
3697 "Rollback to nonexistent snapshot should return an error"
3698 );
3699
3700 let after_rollback = storage.find_group_by_mls_group_id(&group_id).unwrap();
3702 assert!(
3703 after_rollback.is_some(),
3704 "Group should NOT be deleted when rolling back to nonexistent snapshot"
3705 );
3706 assert_eq!(
3707 after_rollback.unwrap().epoch,
3708 group.epoch,
3709 "Group data should be unchanged"
3710 );
3711 }
3712
3713 #[test]
3714 fn test_release_nonexistent_snapshot_succeeds() {
3715 let storage = MdkSqliteStorage::new_in_memory().unwrap();
3716
3717 let group = create_test_group(6);
3718 let group_id = group.mls_group_id.clone();
3719 storage.save_group(group).unwrap();
3720
3721 let result = storage.release_group_snapshot(&group_id, "nonexistent_snap");
3723 assert!(result.is_ok());
3725 }
3726
3727 #[test]
3728 fn test_multiple_snapshots_same_group() {
3729 let storage = MdkSqliteStorage::new_in_memory().unwrap();
3730
3731 let group = create_test_group(7);
3732 let group_id = group.mls_group_id.clone();
3733 storage.save_group(group).unwrap();
3734
3735 storage
3737 .create_group_snapshot(&group_id, "snap_epoch_0")
3738 .unwrap();
3739
3740 let mut mod1 = storage
3742 .find_group_by_mls_group_id(&group_id)
3743 .unwrap()
3744 .unwrap();
3745 mod1.epoch = 1;
3746 mod1.name = "Epoch 1".to_string();
3747 storage.save_group(mod1).unwrap();
3748
3749 storage
3751 .create_group_snapshot(&group_id, "snap_epoch_1")
3752 .unwrap();
3753
3754 let mut mod2 = storage
3756 .find_group_by_mls_group_id(&group_id)
3757 .unwrap()
3758 .unwrap();
3759 mod2.epoch = 2;
3760 mod2.name = "Epoch 2".to_string();
3761 storage.save_group(mod2).unwrap();
3762
3763 storage
3765 .rollback_group_to_snapshot(&group_id, "snap_epoch_1")
3766 .unwrap();
3767
3768 let after_rollback = storage
3769 .find_group_by_mls_group_id(&group_id)
3770 .unwrap()
3771 .unwrap();
3772 assert_eq!(after_rollback.epoch, 1);
3773 assert_eq!(after_rollback.name, "Epoch 1");
3774
3775 storage
3777 .rollback_group_to_snapshot(&group_id, "snap_epoch_0")
3778 .unwrap();
3779
3780 let final_state = storage
3781 .find_group_by_mls_group_id(&group_id)
3782 .unwrap()
3783 .unwrap();
3784 assert_eq!(final_state.epoch, 0);
3785 assert_eq!(final_state.name, "Test Group 7");
3786 }
3787
3788 #[test]
3789 fn test_list_group_snapshots_empty() {
3790 use mdk_storage_traits::MdkStorageProvider;
3791
3792 let temp_dir = tempdir().unwrap();
3793 let db_path = temp_dir.path().join("list_snapshots_empty.db");
3794 let storage = MdkSqliteStorage::new_unencrypted(&db_path).unwrap();
3795
3796 let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
3797
3798 let snapshots = storage.list_group_snapshots(&group_id).unwrap();
3799 assert!(
3800 snapshots.is_empty(),
3801 "Should return empty list for no snapshots"
3802 );
3803 }
3804
3805 #[test]
3806 fn test_list_group_snapshots_returns_snapshots_sorted_by_created_at() {
3807 use mdk_storage_traits::MdkStorageProvider;
3808
3809 let temp_dir = tempdir().unwrap();
3810 let db_path = temp_dir.path().join("list_snapshots_sorted.db");
3811 let storage = MdkSqliteStorage::new_unencrypted(&db_path).unwrap();
3812
3813 let group_id = GroupId::from_slice(&[8; 32]);
3814 let nostr_group_id: [u8; 32] = [9; 32];
3815
3816 let group = Group {
3818 mls_group_id: group_id.clone(),
3819 nostr_group_id,
3820 name: "Test Group".to_string(),
3821 description: "".to_string(),
3822 admin_pubkeys: BTreeSet::new(),
3823 last_message_id: None,
3824 last_message_at: None,
3825 last_message_processed_at: None,
3826 epoch: 1,
3827 state: GroupState::Active,
3828 image_hash: None,
3829 image_key: None,
3830 image_nonce: None,
3831 self_update_state: SelfUpdateState::Required,
3832 };
3833 storage.save_group(group).unwrap();
3834
3835 storage
3837 .create_group_snapshot(&group_id, "snap_first")
3838 .unwrap();
3839 std::thread::sleep(std::time::Duration::from_millis(10));
3840 storage
3841 .create_group_snapshot(&group_id, "snap_second")
3842 .unwrap();
3843 std::thread::sleep(std::time::Duration::from_millis(10));
3844 storage
3845 .create_group_snapshot(&group_id, "snap_third")
3846 .unwrap();
3847
3848 let result = storage.list_group_snapshots(&group_id).unwrap();
3849
3850 assert_eq!(result.len(), 3);
3851 assert_eq!(result[0].0, "snap_first");
3853 assert_eq!(result[1].0, "snap_second");
3854 assert_eq!(result[2].0, "snap_third");
3855 assert!(result[0].1 <= result[1].1);
3857 assert!(result[1].1 <= result[2].1);
3858 }
3859
3860 #[test]
3861 fn test_list_group_snapshots_only_returns_matching_group() {
3862 use mdk_storage_traits::MdkStorageProvider;
3863
3864 let temp_dir = tempdir().unwrap();
3865 let db_path = temp_dir.path().join("list_snapshots_filtered.db");
3866 let storage = MdkSqliteStorage::new_unencrypted(&db_path).unwrap();
3867
3868 let group1 = GroupId::from_slice(&[1; 32]);
3869 let group2 = GroupId::from_slice(&[2; 32]);
3870
3871 let g1 = Group {
3873 mls_group_id: group1.clone(),
3874 nostr_group_id: [11; 32],
3875 name: "Group 1".to_string(),
3876 description: "".to_string(),
3877 admin_pubkeys: BTreeSet::new(),
3878 last_message_id: None,
3879 last_message_at: None,
3880 last_message_processed_at: None,
3881 epoch: 1,
3882 state: GroupState::Active,
3883 image_hash: None,
3884 image_key: None,
3885 image_nonce: None,
3886 self_update_state: SelfUpdateState::Required,
3887 };
3888 let g2 = Group {
3889 mls_group_id: group2.clone(),
3890 nostr_group_id: [22; 32],
3891 name: "Group 2".to_string(),
3892 ..g1.clone()
3893 };
3894 storage.save_group(g1).unwrap();
3895 storage.save_group(g2).unwrap();
3896
3897 storage.create_group_snapshot(&group1, "snap_g1").unwrap();
3899 storage.create_group_snapshot(&group2, "snap_g2").unwrap();
3900
3901 let result1 = storage.list_group_snapshots(&group1).unwrap();
3902 let result2 = storage.list_group_snapshots(&group2).unwrap();
3903
3904 assert_eq!(result1.len(), 1);
3905 assert_eq!(result1[0].0, "snap_g1");
3906
3907 assert_eq!(result2.len(), 1);
3908 assert_eq!(result2[0].0, "snap_g2");
3909 }
3910
3911 #[test]
3912 fn test_prune_expired_snapshots_removes_old_snapshots() {
3913 use mdk_storage_traits::MdkStorageProvider;
3914
3915 let temp_dir = tempdir().unwrap();
3916 let db_path = temp_dir.path().join("prune_expired.db");
3917 let storage = MdkSqliteStorage::new_unencrypted(&db_path).unwrap();
3918
3919 let group_id = GroupId::from_slice(&[3; 32]);
3920
3921 let group = Group {
3922 mls_group_id: group_id.clone(),
3923 nostr_group_id: [33; 32],
3924 name: "Test Group".to_string(),
3925 description: "".to_string(),
3926 admin_pubkeys: BTreeSet::new(),
3927 last_message_id: None,
3928 last_message_at: None,
3929 last_message_processed_at: None,
3930 epoch: 1,
3931 state: GroupState::Active,
3932 image_hash: None,
3933 image_key: None,
3934 image_nonce: None,
3935 self_update_state: SelfUpdateState::Required,
3936 };
3937 storage.save_group(group).unwrap();
3938
3939 storage
3941 .create_group_snapshot(&group_id, "old_snap")
3942 .unwrap();
3943
3944 let snapshots_before = storage.list_group_snapshots(&group_id).unwrap();
3946 assert_eq!(snapshots_before.len(), 1);
3947 let old_ts = snapshots_before[0].1;
3948
3949 let pruned = storage.prune_expired_snapshots(old_ts + 1).unwrap();
3951 assert_eq!(pruned, 1, "Should have pruned 1 snapshot");
3952
3953 let remaining = storage.list_group_snapshots(&group_id).unwrap();
3954 assert!(remaining.is_empty());
3955 }
3956
3957 #[test]
3958 fn test_prune_expired_snapshots_keeps_recent_snapshots() {
3959 use mdk_storage_traits::MdkStorageProvider;
3960
3961 let temp_dir = tempdir().unwrap();
3962 let db_path = temp_dir.path().join("prune_keeps_recent.db");
3963 let storage = MdkSqliteStorage::new_unencrypted(&db_path).unwrap();
3964
3965 let group_id = GroupId::from_slice(&[4; 32]);
3966
3967 let group = Group {
3968 mls_group_id: group_id.clone(),
3969 nostr_group_id: [44; 32],
3970 name: "Test Group".to_string(),
3971 description: "".to_string(),
3972 admin_pubkeys: BTreeSet::new(),
3973 last_message_id: None,
3974 last_message_at: None,
3975 last_message_processed_at: None,
3976 epoch: 1,
3977 state: GroupState::Active,
3978 image_hash: None,
3979 image_key: None,
3980 image_nonce: None,
3981 self_update_state: SelfUpdateState::Required,
3982 };
3983 storage.save_group(group).unwrap();
3984
3985 storage
3987 .create_group_snapshot(&group_id, "recent_snap")
3988 .unwrap();
3989
3990 let pruned = storage.prune_expired_snapshots(0).unwrap();
3992 assert_eq!(pruned, 0, "Should have pruned 0 snapshots");
3993
3994 let remaining = storage.list_group_snapshots(&group_id).unwrap();
3995 assert_eq!(remaining.len(), 1);
3996 assert_eq!(remaining[0].0, "recent_snap");
3997 }
3998
3999 #[test]
4000 fn test_prune_expired_snapshots_with_cascade_delete() {
4001 use mdk_storage_traits::MdkStorageProvider;
4004
4005 let temp_dir = tempdir().unwrap();
4006 let db_path = temp_dir.path().join("prune_cascade.db");
4007 let storage = MdkSqliteStorage::new_unencrypted(&db_path).unwrap();
4008
4009 let group_id = GroupId::from_slice(&[5; 32]);
4010
4011 let group = Group {
4012 mls_group_id: group_id.clone(),
4013 nostr_group_id: [55; 32],
4014 name: "Test Group".to_string(),
4015 description: "".to_string(),
4016 admin_pubkeys: BTreeSet::new(),
4017 last_message_id: None,
4018 last_message_at: None,
4019 last_message_processed_at: None,
4020 epoch: 1,
4021 state: GroupState::Active,
4022 image_hash: None,
4023 image_key: None,
4024 image_nonce: None,
4025 self_update_state: SelfUpdateState::Required,
4026 };
4027 storage.save_group(group).unwrap();
4028
4029 storage
4031 .create_group_snapshot(&group_id, "to_prune")
4032 .unwrap();
4033
4034 let before = storage.list_group_snapshots(&group_id).unwrap();
4036 assert_eq!(before.len(), 1);
4037
4038 let ts = before[0].1;
4040 let pruned = storage.prune_expired_snapshots(ts + 1).unwrap();
4041 assert_eq!(pruned, 1);
4042
4043 let after = storage.list_group_snapshots(&group_id).unwrap();
4045 assert!(after.is_empty());
4046
4047 let rollback_result = storage.rollback_group_to_snapshot(&group_id, "to_prune");
4049 assert!(rollback_result.is_err());
4050 }
4051 }
4052
4053 mod snapshot_openmls_tests {
4064 use std::collections::BTreeSet;
4065
4066 use mdk_storage_traits::groups::GroupStorage;
4067 use mdk_storage_traits::groups::types::{Group, GroupState};
4068 use mdk_storage_traits::mls_codec::MlsCodec;
4069 use mdk_storage_traits::{GroupId, MdkStorageProvider};
4070 use rusqlite::params;
4071
4072 use super::*;
4073
4074 fn create_test_group(id: u8) -> Group {
4076 Group {
4077 mls_group_id: GroupId::from_slice(&[id; 32]),
4078 nostr_group_id: [id; 32],
4079 name: format!("Test Group {}", id),
4080 description: format!("Description {}", id),
4081 admin_pubkeys: BTreeSet::new(),
4082 last_message_id: None,
4083 last_message_at: None,
4084 last_message_processed_at: None,
4085 epoch: 0,
4086 state: GroupState::Active,
4087 image_hash: None,
4088 image_key: None,
4089 image_nonce: None,
4090 self_update_state: SelfUpdateState::Required,
4091 }
4092 }
4093
4094 fn count_openmls_rows(storage: &MdkSqliteStorage, table: &str, group_id: &GroupId) -> i64 {
4097 let mls_key = MlsCodec::serialize(group_id).unwrap();
4098 storage.with_connection(|conn| {
4099 conn.query_row(
4100 &format!("SELECT COUNT(*) FROM {} WHERE group_id = ?", table),
4101 params![mls_key],
4102 |row| row.get(0),
4103 )
4104 .unwrap()
4105 })
4106 }
4107
4108 fn count_snapshot_rows_for_table(
4110 storage: &MdkSqliteStorage,
4111 snapshot_name: &str,
4112 group_id: &GroupId,
4113 table_name: &str,
4114 ) -> i64 {
4115 let raw_bytes = group_id.as_slice().to_vec();
4116 storage.with_connection(|conn| {
4117 conn.query_row(
4118 "SELECT COUNT(*) FROM group_state_snapshots
4119 WHERE snapshot_name = ? AND group_id = ? AND table_name = ?",
4120 params![snapshot_name, raw_bytes, table_name],
4121 |row| row.get(0),
4122 )
4123 .unwrap()
4124 })
4125 }
4126
4127 fn seed_openmls_data(storage: &MdkSqliteStorage, group_id: &GroupId) {
4133 let mls_key = MlsCodec::serialize(group_id).unwrap();
4134 storage.with_connection(|conn| {
4135 conn.execute(
4137 "INSERT OR REPLACE INTO openmls_group_data
4138 (group_id, data_type, group_data, provider_version)
4139 VALUES (?, ?, ?, ?)",
4140 params![mls_key, "group_state", b"test_crypto_state" as &[u8], 1i32],
4141 )
4142 .unwrap();
4143
4144 conn.execute(
4145 "INSERT OR REPLACE INTO openmls_group_data
4146 (group_id, data_type, group_data, provider_version)
4147 VALUES (?, ?, ?, ?)",
4148 params![mls_key, "tree", b"test_tree_data" as &[u8], 1i32],
4149 )
4150 .unwrap();
4151
4152 conn.execute(
4154 "INSERT INTO openmls_own_leaf_nodes
4155 (group_id, leaf_node, provider_version)
4156 VALUES (?, ?, ?)",
4157 params![mls_key, b"test_leaf_node" as &[u8], 1i32],
4158 )
4159 .unwrap();
4160
4161 let proposal_ref = MlsCodec::serialize(&vec![10u8, 20, 30]).unwrap();
4163 conn.execute(
4164 "INSERT OR REPLACE INTO openmls_proposals
4165 (group_id, proposal_ref, proposal, provider_version)
4166 VALUES (?, ?, ?, ?)",
4167 params![mls_key, proposal_ref, b"test_proposal" as &[u8], 1i32],
4168 )
4169 .unwrap();
4170
4171 let epoch_key = MlsCodec::serialize(&5u64).unwrap();
4173 conn.execute(
4174 "INSERT OR REPLACE INTO openmls_epoch_key_pairs
4175 (group_id, epoch_id, leaf_index, key_pairs, provider_version)
4176 VALUES (?, ?, ?, ?, ?)",
4177 params![mls_key, epoch_key, 0i32, b"test_key_pairs" as &[u8], 1i32],
4178 )
4179 .unwrap();
4180 });
4181 }
4182
4183 #[test]
4190 fn test_snapshot_captures_openmls_group_data() {
4191 let storage = MdkSqliteStorage::new_in_memory().unwrap();
4192
4193 let group = create_test_group(1);
4194 let group_id = group.mls_group_id.clone();
4195 storage.save_group(group).unwrap();
4196
4197 seed_openmls_data(&storage, &group_id);
4199
4200 assert_eq!(
4202 count_openmls_rows(&storage, "openmls_group_data", &group_id),
4203 2,
4204 "openmls_group_data should have 2 rows (group_state + tree)"
4205 );
4206
4207 storage
4209 .create_group_snapshot(&group_id, "snap_mls")
4210 .unwrap();
4211
4212 let snap_count = count_snapshot_rows_for_table(
4214 &storage,
4215 "snap_mls",
4216 &group_id,
4217 "openmls_group_data",
4218 );
4219 assert_eq!(
4220 snap_count, 2,
4221 "Snapshot must capture openmls_group_data rows written via StorageProvider \
4222 (MlsCodec-serialized group_id keys)"
4223 );
4224 }
4225
4226 #[test]
4228 fn test_snapshot_captures_openmls_own_leaf_nodes() {
4229 let storage = MdkSqliteStorage::new_in_memory().unwrap();
4230
4231 let group = create_test_group(2);
4232 let group_id = group.mls_group_id.clone();
4233 storage.save_group(group).unwrap();
4234
4235 seed_openmls_data(&storage, &group_id);
4236
4237 assert_eq!(
4238 count_openmls_rows(&storage, "openmls_own_leaf_nodes", &group_id),
4239 1,
4240 "openmls_own_leaf_nodes should have 1 row"
4241 );
4242
4243 storage
4244 .create_group_snapshot(&group_id, "snap_leaf")
4245 .unwrap();
4246
4247 let snap_count = count_snapshot_rows_for_table(
4248 &storage,
4249 "snap_leaf",
4250 &group_id,
4251 "openmls_own_leaf_nodes",
4252 );
4253 assert_eq!(
4254 snap_count, 1,
4255 "Snapshot must capture openmls_own_leaf_nodes rows written via \
4256 StorageProvider"
4257 );
4258 }
4259
4260 #[test]
4262 fn test_snapshot_captures_openmls_proposals() {
4263 let storage = MdkSqliteStorage::new_in_memory().unwrap();
4264
4265 let group = create_test_group(3);
4266 let group_id = group.mls_group_id.clone();
4267 storage.save_group(group).unwrap();
4268
4269 seed_openmls_data(&storage, &group_id);
4270
4271 assert_eq!(
4272 count_openmls_rows(&storage, "openmls_proposals", &group_id),
4273 1,
4274 "openmls_proposals should have 1 row"
4275 );
4276
4277 storage
4278 .create_group_snapshot(&group_id, "snap_prop")
4279 .unwrap();
4280
4281 let snap_count = count_snapshot_rows_for_table(
4282 &storage,
4283 "snap_prop",
4284 &group_id,
4285 "openmls_proposals",
4286 );
4287 assert_eq!(
4288 snap_count, 1,
4289 "Snapshot must capture openmls_proposals rows written via \
4290 StorageProvider"
4291 );
4292 }
4293
4294 #[test]
4296 fn test_snapshot_captures_openmls_epoch_key_pairs() {
4297 let storage = MdkSqliteStorage::new_in_memory().unwrap();
4298
4299 let group = create_test_group(4);
4300 let group_id = group.mls_group_id.clone();
4301 storage.save_group(group).unwrap();
4302
4303 seed_openmls_data(&storage, &group_id);
4304
4305 assert_eq!(
4306 count_openmls_rows(&storage, "openmls_epoch_key_pairs", &group_id),
4307 1,
4308 "openmls_epoch_key_pairs should have 1 row"
4309 );
4310
4311 storage
4312 .create_group_snapshot(&group_id, "snap_epoch")
4313 .unwrap();
4314
4315 let snap_count = count_snapshot_rows_for_table(
4316 &storage,
4317 "snap_epoch",
4318 &group_id,
4319 "openmls_epoch_key_pairs",
4320 );
4321 assert_eq!(
4322 snap_count, 1,
4323 "Snapshot must capture openmls_epoch_key_pairs rows written via \
4324 StorageProvider"
4325 );
4326 }
4327
4328 #[test]
4342 fn test_rollback_restores_openmls_group_data() {
4343 let storage = MdkSqliteStorage::new_in_memory().unwrap();
4344
4345 let group = create_test_group(5);
4346 let group_id = group.mls_group_id.clone();
4347 storage.save_group(group).unwrap();
4348
4349 let mls_key = MlsCodec::serialize(&group_id).unwrap();
4350
4351 storage.with_connection(|conn| {
4353 conn.execute(
4354 "INSERT OR REPLACE INTO openmls_group_data
4355 (group_id, data_type, group_data, provider_version)
4356 VALUES (?, ?, ?, ?)",
4357 params![mls_key, "group_state", b"epoch5_crypto" as &[u8], 1i32],
4358 )
4359 .unwrap();
4360 });
4361
4362 storage.create_group_snapshot(&group_id, "snap_e5").unwrap();
4364
4365 storage.with_connection(|conn| {
4367 conn.execute(
4368 "UPDATE openmls_group_data SET group_data = ?
4369 WHERE group_id = ? AND data_type = ?",
4370 params![b"epoch6_crypto" as &[u8], mls_key, "group_state"],
4371 )
4372 .unwrap();
4373 });
4374
4375 let crypto_before_rollback: Vec<u8> = storage.with_connection(|conn| {
4377 conn.query_row(
4378 "SELECT group_data FROM openmls_group_data
4379 WHERE group_id = ? AND data_type = ?",
4380 params![mls_key, "group_state"],
4381 |row| row.get(0),
4382 )
4383 .unwrap()
4384 });
4385 assert_eq!(crypto_before_rollback, b"epoch6_crypto");
4386
4387 storage
4389 .rollback_group_to_snapshot(&group_id, "snap_e5")
4390 .unwrap();
4391
4392 let crypto_after_rollback: Vec<u8> = storage.with_connection(|conn| {
4394 conn.query_row(
4395 "SELECT group_data FROM openmls_group_data
4396 WHERE group_id = ? AND data_type = ?",
4397 params![mls_key, "group_state"],
4398 |row| row.get(0),
4399 )
4400 .unwrap()
4401 });
4402 assert_eq!(
4403 crypto_after_rollback, b"epoch5_crypto",
4404 "Rollback must restore openmls_group_data to the snapshot state. \
4405 If this is epoch6_crypto, the snapshot failed to capture the \
4406 OpenMLS rows due to group_id encoding mismatch \
4407 (as_slice vs MlsCodec::serialize)."
4408 );
4409 }
4410
4411 #[test]
4413 fn test_rollback_restores_openmls_epoch_key_pairs() {
4414 let storage = MdkSqliteStorage::new_in_memory().unwrap();
4415
4416 let group = create_test_group(6);
4417 let group_id = group.mls_group_id.clone();
4418 storage.save_group(group).unwrap();
4419
4420 let mls_key = MlsCodec::serialize(&group_id).unwrap();
4421 let epoch_key = MlsCodec::serialize(&5u64).unwrap();
4422
4423 storage.with_connection(|conn| {
4425 conn.execute(
4426 "INSERT OR REPLACE INTO openmls_epoch_key_pairs
4427 (group_id, epoch_id, leaf_index, key_pairs, provider_version)
4428 VALUES (?, ?, ?, ?, ?)",
4429 params![mls_key, epoch_key, 0i32, b"epoch5_keys" as &[u8], 1i32],
4430 )
4431 .unwrap();
4432 });
4433
4434 storage
4436 .create_group_snapshot(&group_id, "snap_keys")
4437 .unwrap();
4438
4439 storage.with_connection(|conn| {
4441 conn.execute(
4442 "UPDATE openmls_epoch_key_pairs SET key_pairs = ?
4443 WHERE group_id = ? AND epoch_id = ? AND leaf_index = ?",
4444 params![b"epoch6_keys" as &[u8], mls_key, epoch_key, 0i32],
4445 )
4446 .unwrap();
4447 });
4448
4449 storage
4451 .rollback_group_to_snapshot(&group_id, "snap_keys")
4452 .unwrap();
4453
4454 let keys_after: Vec<u8> = storage.with_connection(|conn| {
4456 conn.query_row(
4457 "SELECT key_pairs FROM openmls_epoch_key_pairs
4458 WHERE group_id = ? AND epoch_id = ? AND leaf_index = ?",
4459 params![mls_key, epoch_key, 0i32],
4460 |row| row.get(0),
4461 )
4462 .unwrap()
4463 });
4464 assert_eq!(
4465 keys_after, b"epoch5_keys",
4466 "Rollback must restore openmls_epoch_key_pairs to snapshot state"
4467 );
4468 }
4469
4470 #[test]
4472 fn test_rollback_restores_openmls_own_leaf_nodes() {
4473 let storage = MdkSqliteStorage::new_in_memory().unwrap();
4474
4475 let group = create_test_group(7);
4476 let group_id = group.mls_group_id.clone();
4477 storage.save_group(group).unwrap();
4478
4479 let mls_key = MlsCodec::serialize(&group_id).unwrap();
4480
4481 storage.with_connection(|conn| {
4483 conn.execute(
4484 "INSERT INTO openmls_own_leaf_nodes
4485 (group_id, leaf_node, provider_version)
4486 VALUES (?, ?, ?)",
4487 params![mls_key, b"original_leaf" as &[u8], 1i32],
4488 )
4489 .unwrap();
4490 });
4491
4492 storage
4494 .create_group_snapshot(&group_id, "snap_leaf")
4495 .unwrap();
4496
4497 storage.with_connection(|conn| {
4499 conn.execute(
4500 "INSERT INTO openmls_own_leaf_nodes
4501 (group_id, leaf_node, provider_version)
4502 VALUES (?, ?, ?)",
4503 params![mls_key, b"added_after_snapshot" as &[u8], 1i32],
4504 )
4505 .unwrap();
4506 });
4507
4508 assert_eq!(
4510 count_openmls_rows(&storage, "openmls_own_leaf_nodes", &group_id),
4511 2
4512 );
4513
4514 storage
4516 .rollback_group_to_snapshot(&group_id, "snap_leaf")
4517 .unwrap();
4518
4519 let leaf_count = count_openmls_rows(&storage, "openmls_own_leaf_nodes", &group_id);
4521 assert_eq!(
4522 leaf_count, 1,
4523 "Rollback must restore openmls_own_leaf_nodes to snapshot state \
4524 (1 leaf, not 2)"
4525 );
4526
4527 let leaf_data: Vec<u8> = storage.with_connection(|conn| {
4528 conn.query_row(
4529 "SELECT leaf_node FROM openmls_own_leaf_nodes
4530 WHERE group_id = ? AND provider_version = ?",
4531 params![mls_key, 1i32],
4532 |row| row.get(0),
4533 )
4534 .unwrap()
4535 });
4536 assert_eq!(
4537 leaf_data, b"original_leaf",
4538 "Rollback must restore the original leaf node data"
4539 );
4540 }
4541
4542 #[test]
4549 fn test_rollback_metadata_crypto_consistency() {
4550 let storage = MdkSqliteStorage::new_in_memory().unwrap();
4551
4552 let group = create_test_group(8);
4553 let group_id = group.mls_group_id.clone();
4554 storage.save_group(group).unwrap();
4555
4556 let mls_key = MlsCodec::serialize(&group_id).unwrap();
4557
4558 {
4560 let mut g = storage
4561 .find_group_by_mls_group_id(&group_id)
4562 .unwrap()
4563 .unwrap();
4564 g.epoch = 5;
4565 storage.save_group(g).unwrap();
4566 }
4567 storage.with_connection(|conn| {
4568 conn.execute(
4569 "INSERT OR REPLACE INTO openmls_group_data
4570 (group_id, data_type, group_data, provider_version)
4571 VALUES (?, ?, ?, ?)",
4572 params![mls_key, "group_state", b"epoch5_state" as &[u8], 1i32],
4573 )
4574 .unwrap();
4575 });
4576
4577 storage
4579 .create_group_snapshot(&group_id, "snap_epoch5")
4580 .unwrap();
4581
4582 {
4584 let mut g = storage
4585 .find_group_by_mls_group_id(&group_id)
4586 .unwrap()
4587 .unwrap();
4588 g.epoch = 6;
4589 storage.save_group(g).unwrap();
4590 }
4591 storage.with_connection(|conn| {
4592 conn.execute(
4593 "UPDATE openmls_group_data SET group_data = ?
4594 WHERE group_id = ? AND data_type = ?",
4595 params![b"epoch6_state" as &[u8], mls_key, "group_state"],
4596 )
4597 .unwrap();
4598 });
4599
4600 storage
4602 .rollback_group_to_snapshot(&group_id, "snap_epoch5")
4603 .unwrap();
4604
4605 let group_after = storage
4607 .find_group_by_mls_group_id(&group_id)
4608 .unwrap()
4609 .unwrap();
4610 assert_eq!(
4611 group_after.epoch, 5,
4612 "MDK groups.epoch should be 5 after rollback"
4613 );
4614
4615 let crypto_after: Vec<u8> = storage.with_connection(|conn| {
4617 conn.query_row(
4618 "SELECT group_data FROM openmls_group_data
4619 WHERE group_id = ? AND data_type = ?",
4620 params![mls_key, "group_state"],
4621 |row| row.get(0),
4622 )
4623 .unwrap()
4624 });
4625 assert_eq!(
4626 crypto_after, b"epoch5_state",
4627 "OpenMLS crypto state must match MDK metadata epoch after rollback. \
4628 groups.epoch=5 but crypto state is still epoch6 data means \
4629 split-brain: MDK thinks epoch 5, MLS engine has epoch 6 keys. \
4630 Every subsequent message in this group will fail to decrypt."
4631 );
4632 }
4633
4634 #[test]
4642 fn test_restore_deletes_openmls_data_before_reinserting() {
4643 let storage = MdkSqliteStorage::new_in_memory().unwrap();
4644
4645 let group = create_test_group(9);
4646 let group_id = group.mls_group_id.clone();
4647 storage.save_group(group).unwrap();
4648
4649 let mls_key = MlsCodec::serialize(&group_id).unwrap();
4650
4651 storage.with_connection(|conn| {
4653 conn.execute(
4654 "INSERT OR REPLACE INTO openmls_group_data
4655 (group_id, data_type, group_data, provider_version)
4656 VALUES (?, ?, ?, ?)",
4657 params![mls_key, "group_state", b"initial_state" as &[u8], 1i32],
4658 )
4659 .unwrap();
4660 });
4661
4662 storage
4664 .create_group_snapshot(&group_id, "snap_initial")
4665 .unwrap();
4666
4667 storage.with_connection(|conn| {
4669 conn.execute(
4670 "UPDATE openmls_group_data SET group_data = ?
4671 WHERE group_id = ? AND data_type = ?",
4672 params![b"modified_state" as &[u8], mls_key, "group_state"],
4673 )
4674 .unwrap();
4675 });
4676
4677 storage
4679 .rollback_group_to_snapshot(&group_id, "snap_initial")
4680 .unwrap();
4681
4682 let row_count = count_openmls_rows(&storage, "openmls_group_data", &group_id);
4687 assert_eq!(
4688 row_count, 1,
4689 "After rollback, there should be exactly 1 openmls_group_data \
4690 row. More than 1 means the DELETE used the wrong key format and \
4691 failed to remove the stale OpenMLS row before re-inserting from \
4692 snapshot."
4693 );
4694 }
4695 }
4696}