1#![forbid(unsafe_code)]
49#![warn(missing_docs)]
50#![warn(rustdoc::bare_urls)]
51
52use std::collections::{BTreeSet, HashMap};
53use std::fmt;
54use std::num::NonZeroUsize;
55
56use lru::LruCache;
57use mdk_storage_traits::GroupId;
58use mdk_storage_traits::groups::types::{Group, GroupExporterSecret, GroupRelay};
59use mdk_storage_traits::messages::types::{Message, ProcessedMessage};
60use mdk_storage_traits::welcomes::types::{ProcessedWelcome, Welcome};
61use mdk_storage_traits::{Backend, MdkStorageError, MdkStorageProvider};
62use nostr::EventId;
63use openmls_traits::storage::{StorageProvider, traits};
64use parking_lot::RwLock;
65
66mod groups;
67mod messages;
68mod mls_storage;
69mod snapshot;
70mod welcomes;
71
72use self::mls_storage::{
73 GroupDataType, MlsEncryptionKeys, MlsEpochKeyPairs, MlsGroupData, MlsKeyPackages,
74 MlsOwnLeafNodes, MlsProposals, MlsPsks, MlsSignatureKeys, STORAGE_PROVIDER_VERSION,
75};
76pub use self::snapshot::{GroupScopedSnapshot, MemoryStorageSnapshot};
77use self::snapshot::{HashMapToLruExt, LruCacheExt};
78
/// Cache capacity used when the caller does not supply one (1000 entries).
const DEFAULT_CACHE_SIZE: NonZeroUsize = if let Some(size) = NonZeroUsize::new(1_000) {
    size
} else {
    panic!("cache size must be non-zero")
};

/// Default value for [`ValidationLimits::max_relays_per_group`].
pub const DEFAULT_MAX_RELAYS_PER_GROUP: usize = 100;

/// Default value for [`ValidationLimits::max_messages_per_group`].
pub const DEFAULT_MAX_MESSAGES_PER_GROUP: usize = 10_000;

/// Default value for [`ValidationLimits::max_group_name_length`].
pub const DEFAULT_MAX_GROUP_NAME_LENGTH: usize = 256;

/// Default value for [`ValidationLimits::max_group_description_length`].
pub const DEFAULT_MAX_GROUP_DESCRIPTION_LENGTH: usize = 4_096;

/// Default value for [`ValidationLimits::max_admins_per_group`].
pub const DEFAULT_MAX_ADMINS_PER_GROUP: usize = 100;

/// Default value for [`ValidationLimits::max_relays_per_welcome`].
pub const DEFAULT_MAX_RELAYS_PER_WELCOME: usize = 100;

/// Default value for [`ValidationLimits::max_admins_per_welcome`].
pub const DEFAULT_MAX_ADMINS_PER_WELCOME: usize = 100;

/// Default value for [`ValidationLimits::max_relay_url_length`].
pub const DEFAULT_MAX_RELAY_URL_LENGTH: usize = 512;
119
/// Configurable size limits for the in-memory storage.
///
/// The `Default` implementation fills every field from the matching
/// `DEFAULT_*` constant, and the `with_*` builder methods override one field
/// at a time. Where the `max_*` limits are enforced is not visible in this
/// part of the file.
#[derive(Debug, Clone, Copy)]
pub struct ValidationLimits {
    /// Capacity given to each LRU cache built by `MdkMemoryStorage::with_limits`.
    pub cache_size: usize,
    /// Maximum relays per group (default `DEFAULT_MAX_RELAYS_PER_GROUP`).
    pub max_relays_per_group: usize,
    /// Maximum messages per group (default `DEFAULT_MAX_MESSAGES_PER_GROUP`).
    pub max_messages_per_group: usize,
    /// Maximum group name length (default `DEFAULT_MAX_GROUP_NAME_LENGTH`).
    pub max_group_name_length: usize,
    /// Maximum group description length (default `DEFAULT_MAX_GROUP_DESCRIPTION_LENGTH`).
    pub max_group_description_length: usize,
    /// Maximum admins per group (default `DEFAULT_MAX_ADMINS_PER_GROUP`).
    pub max_admins_per_group: usize,
    /// Maximum relays per welcome (default `DEFAULT_MAX_RELAYS_PER_WELCOME`).
    pub max_relays_per_welcome: usize,
    /// Maximum admins per welcome (default `DEFAULT_MAX_ADMINS_PER_WELCOME`).
    pub max_admins_per_welcome: usize,
    /// Maximum relay URL length (default `DEFAULT_MAX_RELAY_URL_LENGTH`).
    pub max_relay_url_length: usize,
}
157
158impl Default for ValidationLimits {
159 fn default() -> Self {
160 Self {
161 cache_size: DEFAULT_CACHE_SIZE.get(),
162 max_relays_per_group: DEFAULT_MAX_RELAYS_PER_GROUP,
163 max_messages_per_group: DEFAULT_MAX_MESSAGES_PER_GROUP,
164 max_group_name_length: DEFAULT_MAX_GROUP_NAME_LENGTH,
165 max_group_description_length: DEFAULT_MAX_GROUP_DESCRIPTION_LENGTH,
166 max_admins_per_group: DEFAULT_MAX_ADMINS_PER_GROUP,
167 max_relays_per_welcome: DEFAULT_MAX_RELAYS_PER_WELCOME,
168 max_admins_per_welcome: DEFAULT_MAX_ADMINS_PER_WELCOME,
169 max_relay_url_length: DEFAULT_MAX_RELAY_URL_LENGTH,
170 }
171 }
172}
173
174impl ValidationLimits {
175 pub fn new() -> Self {
177 Self::default()
178 }
179
180 pub fn with_cache_size(mut self, size: usize) -> Self {
186 assert!(size > 0, "cache_size must be greater than 0");
187 self.cache_size = size;
188 self
189 }
190
191 pub fn with_max_relays_per_group(mut self, limit: usize) -> Self {
197 assert!(limit > 0, "max_relays_per_group must be greater than 0");
198 self.max_relays_per_group = limit;
199 self
200 }
201
202 pub fn with_max_messages_per_group(mut self, limit: usize) -> Self {
208 assert!(limit > 0, "max_messages_per_group must be greater than 0");
209 self.max_messages_per_group = limit;
210 self
211 }
212
213 pub fn with_max_group_name_length(mut self, limit: usize) -> Self {
219 assert!(limit > 0, "max_group_name_length must be greater than 0");
220 self.max_group_name_length = limit;
221 self
222 }
223
224 pub fn with_max_group_description_length(mut self, limit: usize) -> Self {
230 assert!(
231 limit > 0,
232 "max_group_description_length must be greater than 0"
233 );
234 self.max_group_description_length = limit;
235 self
236 }
237
238 pub fn with_max_admins_per_group(mut self, limit: usize) -> Self {
244 assert!(limit > 0, "max_admins_per_group must be greater than 0");
245 self.max_admins_per_group = limit;
246 self
247 }
248
249 pub fn with_max_relays_per_welcome(mut self, limit: usize) -> Self {
255 assert!(limit > 0, "max_relays_per_welcome must be greater than 0");
256 self.max_relays_per_welcome = limit;
257 self
258 }
259
260 pub fn with_max_admins_per_welcome(mut self, limit: usize) -> Self {
266 assert!(limit > 0, "max_admins_per_welcome must be greater than 0");
267 self.max_admins_per_welcome = limit;
268 self
269 }
270
271 pub fn with_max_relay_url_length(mut self, limit: usize) -> Self {
277 assert!(limit > 0, "max_relay_url_length must be greater than 0");
278 self.max_relay_url_length = limit;
279 self
280 }
281}
282
/// In-memory storage backend.
///
/// Implements `MdkStorageProvider` and the OpenMLS `StorageProvider` trait on
/// top of state guarded by `RwLock`s.
pub struct MdkMemoryStorage {
    /// Limits this instance was constructed with; exposed through `limits()`.
    limits: ValidationLimits,
    /// MLS stores and LRU caches, all behind a single lock.
    inner: RwLock<MdkMemoryStorageInner>,
    /// Named group-scoped snapshots keyed by `(group id, snapshot name)`.
    /// Guarded by its own lock, separate from `inner`.
    group_snapshots: RwLock<HashMap<(GroupId, String), GroupScopedSnapshot>>,
}
340
/// State held behind `MdkMemoryStorage::inner`: the MLS stores backing the
/// OpenMLS `StorageProvider` methods, plus the LRU caches for MDK data.
struct MdkMemoryStorageInner {
    /// MLS group data entries, addressed by group id and `GroupDataType`.
    mls_group_data: MlsGroupData,
    /// Own leaf nodes stored per group.
    mls_own_leaf_nodes: MlsOwnLeafNodes,
    /// Queued proposals, addressed by group id and proposal reference.
    mls_proposals: MlsProposals,
    /// Key packages, addressed by hash reference.
    mls_key_packages: MlsKeyPackages,
    /// PSK bundles, addressed by PSK id.
    mls_psks: MlsPsks,
    /// Signature key pairs, addressed by public key.
    mls_signature_keys: MlsSignatureKeys,
    /// HPKE encryption key pairs, addressed by public key.
    mls_encryption_keys: MlsEncryptionKeys,
    /// Epoch key pairs, addressed by group id, epoch and leaf index.
    mls_epoch_key_pairs: MlsEpochKeyPairs,

    /// Groups by `GroupId`.
    groups_cache: LruCache<GroupId, Group>,
    /// The same groups, indexed by their 32-byte `nostr_group_id`.
    groups_by_nostr_id_cache: LruCache<[u8; 32], Group>,
    /// Relay sets per group.
    group_relays_cache: LruCache<GroupId, BTreeSet<GroupRelay>>,
    /// Welcomes by event id.
    welcomes_cache: LruCache<EventId, Welcome>,
    /// Processed-welcome records by event id.
    processed_welcomes_cache: LruCache<EventId, ProcessedWelcome>,
    /// Messages by event id.
    messages_cache: LruCache<EventId, Message>,
    /// Messages grouped per group, then keyed by event id.
    messages_by_group_cache: LruCache<GroupId, HashMap<EventId, Message>>,
    /// Processed-message records by event id.
    processed_messages_cache: LruCache<EventId, ProcessedMessage>,
    /// Exporter secrets keyed by `(group id, epoch)`.
    group_exporter_secrets_cache: LruCache<(GroupId, u64), GroupExporterSecret>,
    /// MIP-04 exporter secrets keyed by `(group id, epoch)`.
    group_mip04_exporter_secrets_cache: LruCache<(GroupId, u64), GroupExporterSecret>,
}
369
370impl fmt::Debug for MdkMemoryStorage {
371 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
372 f.debug_struct("MdkMemoryStorage")
373 .field("limits", &self.limits)
374 .field("inner", &"RwLock<MdkMemoryStorageInner>")
375 .finish()
376 }
377}
378
379impl Default for MdkMemoryStorage {
380 fn default() -> Self {
386 Self::new()
387 }
388}
389
390impl MdkMemoryStorage {
391 pub fn new() -> Self {
397 Self::with_cache_size(DEFAULT_CACHE_SIZE)
398 }
399
400 pub fn with_cache_size(cache_size: NonZeroUsize) -> Self {
410 Self::with_limits(ValidationLimits::default().with_cache_size(cache_size.get()))
411 }
412
413 pub fn with_limits(limits: ValidationLimits) -> Self {
423 let cache_size =
424 NonZeroUsize::new(limits.cache_size).expect("cache_size must be greater than 0");
425
426 let inner = MdkMemoryStorageInner {
427 mls_group_data: MlsGroupData::new(),
429 mls_own_leaf_nodes: MlsOwnLeafNodes::new(),
430 mls_proposals: MlsProposals::new(),
431 mls_key_packages: MlsKeyPackages::new(),
432 mls_psks: MlsPsks::new(),
433 mls_signature_keys: MlsSignatureKeys::new(),
434 mls_encryption_keys: MlsEncryptionKeys::new(),
435 mls_epoch_key_pairs: MlsEpochKeyPairs::new(),
436 groups_cache: LruCache::new(cache_size),
438 groups_by_nostr_id_cache: LruCache::new(cache_size),
439 group_relays_cache: LruCache::new(cache_size),
440 welcomes_cache: LruCache::new(cache_size),
441 processed_welcomes_cache: LruCache::new(cache_size),
442 messages_cache: LruCache::new(cache_size),
443 messages_by_group_cache: LruCache::new(cache_size),
444 processed_messages_cache: LruCache::new(cache_size),
445 group_exporter_secrets_cache: LruCache::new(cache_size),
446 group_mip04_exporter_secrets_cache: LruCache::new(cache_size),
447 };
448
449 MdkMemoryStorage {
450 limits,
451 inner: RwLock::new(inner),
452 group_snapshots: RwLock::new(HashMap::new()),
453 }
454 }
455
456 pub fn create_snapshot(&self) -> MemoryStorageSnapshot {
473 let inner = self.inner.read();
474 MemoryStorageSnapshot {
475 mls_group_data: inner.mls_group_data.clone_data(),
477 mls_own_leaf_nodes: inner.mls_own_leaf_nodes.clone_data(),
478 mls_proposals: inner.mls_proposals.clone_data(),
479 mls_key_packages: inner.mls_key_packages.clone_data(),
480 mls_psks: inner.mls_psks.clone_data(),
481 mls_signature_keys: inner.mls_signature_keys.clone_data(),
482 mls_encryption_keys: inner.mls_encryption_keys.clone_data(),
483 mls_epoch_key_pairs: inner.mls_epoch_key_pairs.clone_data(),
484 groups: inner.groups_cache.clone_to_hashmap(),
486 groups_by_nostr_id: inner.groups_by_nostr_id_cache.clone_to_hashmap(),
487 group_relays: inner.group_relays_cache.clone_to_hashmap(),
488 group_exporter_secrets: inner.group_exporter_secrets_cache.clone_to_hashmap(),
489 group_mip04_exporter_secrets: inner
490 .group_mip04_exporter_secrets_cache
491 .clone_to_hashmap(),
492 welcomes: inner.welcomes_cache.clone_to_hashmap(),
493 processed_welcomes: inner.processed_welcomes_cache.clone_to_hashmap(),
494 messages: inner.messages_cache.clone_to_hashmap(),
495 messages_by_group: inner.messages_by_group_cache.clone_to_hashmap(),
496 processed_messages: inner.processed_messages_cache.clone_to_hashmap(),
497 }
498 }
499
500 pub fn restore_snapshot(&self, snapshot: MemoryStorageSnapshot) {
513 let mut inner = self.inner.write();
514
515 inner.mls_group_data.restore_data(snapshot.mls_group_data);
517 inner
518 .mls_own_leaf_nodes
519 .restore_data(snapshot.mls_own_leaf_nodes);
520 inner.mls_proposals.restore_data(snapshot.mls_proposals);
521 inner
522 .mls_key_packages
523 .restore_data(snapshot.mls_key_packages);
524 inner.mls_psks.restore_data(snapshot.mls_psks);
525 inner
526 .mls_signature_keys
527 .restore_data(snapshot.mls_signature_keys);
528 inner
529 .mls_encryption_keys
530 .restore_data(snapshot.mls_encryption_keys);
531 inner
532 .mls_epoch_key_pairs
533 .restore_data(snapshot.mls_epoch_key_pairs);
534
535 snapshot.groups.restore_to_lru(&mut inner.groups_cache);
537 snapshot
538 .groups_by_nostr_id
539 .restore_to_lru(&mut inner.groups_by_nostr_id_cache);
540 snapshot
541 .group_relays
542 .restore_to_lru(&mut inner.group_relays_cache);
543 snapshot
544 .group_exporter_secrets
545 .restore_to_lru(&mut inner.group_exporter_secrets_cache);
546 snapshot
547 .group_mip04_exporter_secrets
548 .restore_to_lru(&mut inner.group_mip04_exporter_secrets_cache);
549 snapshot.welcomes.restore_to_lru(&mut inner.welcomes_cache);
550 snapshot
551 .processed_welcomes
552 .restore_to_lru(&mut inner.processed_welcomes_cache);
553 snapshot.messages.restore_to_lru(&mut inner.messages_cache);
554 snapshot
555 .messages_by_group
556 .restore_to_lru(&mut inner.messages_by_group_cache);
557 snapshot
558 .processed_messages
559 .restore_to_lru(&mut inner.processed_messages_cache);
560 }
561
562 pub fn create_group_scoped_snapshot(&self, group_id: &GroupId) -> GroupScopedSnapshot {
585 let inner = self.inner.read();
586
587 let mls_group_id_bytes = mls_storage::MlsCodec::serialize(group_id.inner())
590 .expect("Failed to serialize group_id for MLS lookup");
591
592 let mls_group_data: HashMap<(Vec<u8>, GroupDataType), Vec<u8>> = inner
594 .mls_group_data
595 .data
596 .iter()
597 .filter(|((gid, _), _)| *gid == mls_group_id_bytes)
598 .map(|(k, v)| (k.clone(), v.clone()))
599 .collect();
600
601 let mls_own_leaf_nodes = inner
603 .mls_own_leaf_nodes
604 .data
605 .get(&mls_group_id_bytes)
606 .cloned()
607 .unwrap_or_default();
608
609 let mls_proposals: HashMap<Vec<u8>, Vec<u8>> = inner
611 .mls_proposals
612 .data
613 .iter()
614 .filter(|((gid, _), _)| *gid == mls_group_id_bytes)
615 .map(|((_, prop_ref), prop)| (prop_ref.clone(), prop.clone()))
616 .collect();
617
618 let mls_epoch_key_pairs: HashMap<(Vec<u8>, u32), Vec<u8>> = inner
620 .mls_epoch_key_pairs
621 .data
622 .iter()
623 .filter(|((gid, _, _), _)| *gid == mls_group_id_bytes)
624 .map(|((_, epoch_id, leaf_idx), kp)| ((epoch_id.clone(), *leaf_idx), kp.clone()))
625 .collect();
626
627 let group = inner.groups_cache.peek(group_id).cloned();
629
630 let group_relays = inner
631 .group_relays_cache
632 .peek(group_id)
633 .cloned()
634 .unwrap_or_default();
635
636 let group_exporter_secrets: HashMap<u64, GroupExporterSecret> = inner
637 .group_exporter_secrets_cache
638 .iter()
639 .filter(|((gid, _), _)| gid == group_id)
640 .map(|((_, epoch), secret)| (*epoch, secret.clone()))
641 .collect();
642
643 let group_mip04_exporter_secrets: HashMap<u64, GroupExporterSecret> = inner
644 .group_mip04_exporter_secrets_cache
645 .iter()
646 .filter(|((gid, _), _)| gid == group_id)
647 .map(|((_, epoch), secret)| (*epoch, secret.clone()))
648 .collect();
649
650 let created_at = std::time::SystemTime::now()
652 .duration_since(std::time::UNIX_EPOCH)
653 .expect("System time before Unix epoch")
654 .as_secs();
655
656 GroupScopedSnapshot {
657 group_id: group_id.clone(),
658 created_at,
659 mls_group_data,
660 mls_own_leaf_nodes,
661 mls_proposals,
662 mls_epoch_key_pairs,
663 group,
664 group_relays,
665 group_exporter_secrets,
666 group_mip04_exporter_secrets,
667 }
668 }
669
    /// Replaces the current state of one group with the contents of `snapshot`.
    ///
    /// Runs in two phases under a single write lock: first everything stored
    /// for `snapshot.group_id` is removed, then the snapshot's data is
    /// inserted. The message and welcome caches are not touched.
    ///
    /// # Panics
    ///
    /// Panics if the group id cannot be serialized.
    pub fn restore_group_scoped_snapshot(&self, snapshot: GroupScopedSnapshot) {
        let mut inner = self.inner.write();
        let group_id = &snapshot.group_id;

        // The MLS stores are keyed by the serialized group id.
        let mls_group_id_bytes = mls_storage::MlsCodec::serialize(group_id.inner())
            .expect("Failed to serialize group_id for MLS lookup");

        // ---- Phase 1: remove the group's current state ----

        // Drop every MLS group-data entry belonging to this group.
        inner
            .mls_group_data
            .data
            .retain(|(gid, _), _| *gid != mls_group_id_bytes);

        // Drop the group's own leaf nodes.
        inner.mls_own_leaf_nodes.data.remove(&mls_group_id_bytes);

        // Drop the group's queued proposals.
        inner
            .mls_proposals
            .data
            .retain(|(gid, _), _| *gid != mls_group_id_bytes);

        // Drop the group's epoch key pairs.
        inner
            .mls_epoch_key_pairs
            .data
            .retain(|(gid, _, _), _| *gid != mls_group_id_bytes);

        // Read the nostr id of the currently cached group before popping it, so
        // the entry in the secondary (by-nostr-id) index can be removed too.
        let nostr_group_id = inner.groups_cache.peek(group_id).map(|g| g.nostr_group_id);
        inner.groups_cache.pop(group_id);
        if let Some(nostr_id) = nostr_group_id {
            inner.groups_by_nostr_id_cache.pop(&nostr_id);
        }

        inner.group_relays_cache.pop(group_id);

        // Collect the matching keys first: the cache cannot be mutated while it
        // is being iterated.
        let keys_to_remove: Vec<_> = inner
            .group_exporter_secrets_cache
            .iter()
            .filter(|((gid, _), _)| gid == group_id)
            .map(|(k, _)| k.clone())
            .collect();
        for key in keys_to_remove {
            inner.group_exporter_secrets_cache.pop(&key);
        }

        // Same two-step removal for the MIP-04 exporter secrets.
        let mip04_keys_to_remove: Vec<_> = inner
            .group_mip04_exporter_secrets_cache
            .iter()
            .filter(|((gid, _), _)| gid == group_id)
            .map(|(k, _)| k.clone())
            .collect();
        for key in mip04_keys_to_remove {
            inner.group_mip04_exporter_secrets_cache.pop(&key);
        }

        // ---- Phase 2: insert the snapshot's state ----

        // Snapshot keys for MLS group data already include the group id.
        for (key, value) in snapshot.mls_group_data {
            inner.mls_group_data.data.insert(key, value);
        }

        // An empty value means there is nothing to restore, so no entry is
        // re-created for the group.
        if !snapshot.mls_own_leaf_nodes.is_empty() {
            inner
                .mls_own_leaf_nodes
                .data
                .insert(mls_group_id_bytes.clone(), snapshot.mls_own_leaf_nodes);
        }

        // Proposals were stored without the group id; put it back into the key.
        for (prop_ref, prop) in snapshot.mls_proposals {
            inner
                .mls_proposals
                .data
                .insert((mls_group_id_bytes.clone(), prop_ref), prop);
        }

        // Epoch key pairs likewise get the group id put back into the key.
        for ((epoch_id, leaf_idx), kp) in snapshot.mls_epoch_key_pairs {
            inner
                .mls_epoch_key_pairs
                .data
                .insert((mls_group_id_bytes.clone(), epoch_id, leaf_idx), kp);
        }

        // Restore the group record into both the primary and the by-nostr-id cache.
        if let Some(group) = snapshot.group {
            let nostr_id = group.nostr_group_id;
            inner.groups_cache.put(group_id.clone(), group.clone());
            inner.groups_by_nostr_id_cache.put(nostr_id, group);
        }

        // An empty relay set is not re-inserted.
        if !snapshot.group_relays.is_empty() {
            inner
                .group_relays_cache
                .put(group_id.clone(), snapshot.group_relays);
        }

        for (epoch, secret) in snapshot.group_exporter_secrets {
            inner
                .group_exporter_secrets_cache
                .put((group_id.clone(), epoch), secret);
        }

        for (epoch, secret) in snapshot.group_mip04_exporter_secrets {
            inner
                .group_mip04_exporter_secrets_cache
                .put((group_id.clone(), epoch), secret);
        }
    }
803
804 pub fn limits(&self) -> &ValidationLimits {
806 &self.limits
807 }
808}
809
810impl MdkStorageProvider for MdkMemoryStorage {
    /// Reports this provider's backend kind, which is always `Backend::Memory`.
    fn backend(&self) -> Backend {
        Backend::Memory
    }
820
821 fn create_group_snapshot(&self, group_id: &GroupId, name: &str) -> Result<(), MdkStorageError> {
822 let snapshot = self.create_group_scoped_snapshot(group_id);
825 self.group_snapshots
826 .write()
827 .insert((group_id.clone(), name.to_string()), snapshot);
828 Ok(())
829 }
830
831 fn rollback_group_to_snapshot(
832 &self,
833 group_id: &GroupId,
834 name: &str,
835 ) -> Result<(), MdkStorageError> {
836 let key = (group_id.clone(), name.to_string());
837 let snapshot = self
839 .group_snapshots
840 .write()
841 .remove(&key)
842 .ok_or_else(|| MdkStorageError::NotFound("Snapshot not found".to_string()))?;
843 self.restore_group_scoped_snapshot(snapshot);
844 Ok(())
845 }
846
847 fn release_group_snapshot(
848 &self,
849 group_id: &GroupId,
850 name: &str,
851 ) -> Result<(), MdkStorageError> {
852 let key = (group_id.clone(), name.to_string());
853 self.group_snapshots.write().remove(&key);
854 Ok(())
855 }
856
857 fn list_group_snapshots(
858 &self,
859 group_id: &GroupId,
860 ) -> Result<Vec<(String, u64)>, MdkStorageError> {
861 let snapshots = self.group_snapshots.read();
862 let mut result: Vec<(String, u64)> = snapshots
863 .iter()
864 .filter(|((gid, _), _)| gid == group_id)
865 .map(|((_, name), snap)| (name.clone(), snap.created_at))
866 .collect();
867 result.sort_by_key(|(_, created_at)| *created_at);
869 Ok(result)
870 }
871
872 fn prune_expired_snapshots(&self, min_timestamp: u64) -> Result<usize, MdkStorageError> {
873 let mut snapshots = self.group_snapshots.write();
874 let initial_count = snapshots.len();
875 snapshots.retain(|_, snap| snap.created_at >= min_timestamp);
876 let pruned_count = initial_count - snapshots.len();
877 Ok(pruned_count)
878 }
879}
880
881impl StorageProvider<STORAGE_PROVIDER_VERSION> for MdkMemoryStorage {
886 type Error = MdkStorageError;
887
888 fn write_mls_join_config<GroupId, MlsGroupJoinConfig>(
893 &self,
894 group_id: &GroupId,
895 config: &MlsGroupJoinConfig,
896 ) -> Result<(), Self::Error>
897 where
898 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
899 MlsGroupJoinConfig: traits::MlsGroupJoinConfig<STORAGE_PROVIDER_VERSION>,
900 {
901 self.inner
902 .write()
903 .mls_group_data
904 .write(group_id, GroupDataType::JoinGroupConfig, config)
905 }
906
907 fn append_own_leaf_node<GroupId, LeafNode>(
908 &self,
909 group_id: &GroupId,
910 leaf_node: &LeafNode,
911 ) -> Result<(), Self::Error>
912 where
913 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
914 LeafNode: traits::LeafNode<STORAGE_PROVIDER_VERSION>,
915 {
916 self.inner
917 .write()
918 .mls_own_leaf_nodes
919 .append(group_id, leaf_node)
920 }
921
922 fn queue_proposal<GroupId, ProposalRef, QueuedProposal>(
923 &self,
924 group_id: &GroupId,
925 proposal_ref: &ProposalRef,
926 proposal: &QueuedProposal,
927 ) -> Result<(), Self::Error>
928 where
929 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
930 ProposalRef: traits::ProposalRef<STORAGE_PROVIDER_VERSION>,
931 QueuedProposal: traits::QueuedProposal<STORAGE_PROVIDER_VERSION>,
932 {
933 self.inner
934 .write()
935 .mls_proposals
936 .queue(group_id, proposal_ref, proposal)
937 }
938
939 fn write_tree<GroupId, TreeSync>(
940 &self,
941 group_id: &GroupId,
942 tree: &TreeSync,
943 ) -> Result<(), Self::Error>
944 where
945 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
946 TreeSync: traits::TreeSync<STORAGE_PROVIDER_VERSION>,
947 {
948 self.inner
949 .write()
950 .mls_group_data
951 .write(group_id, GroupDataType::Tree, tree)
952 }
953
954 fn write_interim_transcript_hash<GroupId, InterimTranscriptHash>(
955 &self,
956 group_id: &GroupId,
957 interim_transcript_hash: &InterimTranscriptHash,
958 ) -> Result<(), Self::Error>
959 where
960 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
961 InterimTranscriptHash: traits::InterimTranscriptHash<STORAGE_PROVIDER_VERSION>,
962 {
963 self.inner.write().mls_group_data.write(
964 group_id,
965 GroupDataType::InterimTranscriptHash,
966 interim_transcript_hash,
967 )
968 }
969
970 fn write_context<GroupId, GroupContext>(
971 &self,
972 group_id: &GroupId,
973 group_context: &GroupContext,
974 ) -> Result<(), Self::Error>
975 where
976 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
977 GroupContext: traits::GroupContext<STORAGE_PROVIDER_VERSION>,
978 {
979 self.inner
980 .write()
981 .mls_group_data
982 .write(group_id, GroupDataType::Context, group_context)
983 }
984
985 fn write_confirmation_tag<GroupId, ConfirmationTag>(
986 &self,
987 group_id: &GroupId,
988 confirmation_tag: &ConfirmationTag,
989 ) -> Result<(), Self::Error>
990 where
991 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
992 ConfirmationTag: traits::ConfirmationTag<STORAGE_PROVIDER_VERSION>,
993 {
994 self.inner.write().mls_group_data.write(
995 group_id,
996 GroupDataType::ConfirmationTag,
997 confirmation_tag,
998 )
999 }
1000
1001 fn write_group_state<GroupState, GroupId>(
1002 &self,
1003 group_id: &GroupId,
1004 group_state: &GroupState,
1005 ) -> Result<(), Self::Error>
1006 where
1007 GroupState: traits::GroupState<STORAGE_PROVIDER_VERSION>,
1008 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1009 {
1010 self.inner
1011 .write()
1012 .mls_group_data
1013 .write(group_id, GroupDataType::GroupState, group_state)
1014 }
1015
1016 fn write_message_secrets<GroupId, MessageSecrets>(
1017 &self,
1018 group_id: &GroupId,
1019 message_secrets: &MessageSecrets,
1020 ) -> Result<(), Self::Error>
1021 where
1022 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1023 MessageSecrets: traits::MessageSecrets<STORAGE_PROVIDER_VERSION>,
1024 {
1025 self.inner.write().mls_group_data.write(
1026 group_id,
1027 GroupDataType::MessageSecrets,
1028 message_secrets,
1029 )
1030 }
1031
1032 fn write_resumption_psk_store<GroupId, ResumptionPskStore>(
1033 &self,
1034 group_id: &GroupId,
1035 resumption_psk_store: &ResumptionPskStore,
1036 ) -> Result<(), Self::Error>
1037 where
1038 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1039 ResumptionPskStore: traits::ResumptionPskStore<STORAGE_PROVIDER_VERSION>,
1040 {
1041 self.inner.write().mls_group_data.write(
1042 group_id,
1043 GroupDataType::ResumptionPskStore,
1044 resumption_psk_store,
1045 )
1046 }
1047
1048 fn write_own_leaf_index<GroupId, LeafNodeIndex>(
1049 &self,
1050 group_id: &GroupId,
1051 own_leaf_index: &LeafNodeIndex,
1052 ) -> Result<(), Self::Error>
1053 where
1054 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1055 LeafNodeIndex: traits::LeafNodeIndex<STORAGE_PROVIDER_VERSION>,
1056 {
1057 self.inner.write().mls_group_data.write(
1058 group_id,
1059 GroupDataType::OwnLeafIndex,
1060 own_leaf_index,
1061 )
1062 }
1063
1064 fn write_group_epoch_secrets<GroupId, GroupEpochSecrets>(
1065 &self,
1066 group_id: &GroupId,
1067 group_epoch_secrets: &GroupEpochSecrets,
1068 ) -> Result<(), Self::Error>
1069 where
1070 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1071 GroupEpochSecrets: traits::GroupEpochSecrets<STORAGE_PROVIDER_VERSION>,
1072 {
1073 self.inner.write().mls_group_data.write(
1074 group_id,
1075 GroupDataType::GroupEpochSecrets,
1076 group_epoch_secrets,
1077 )
1078 }
1079
1080 fn write_signature_key_pair<SignaturePublicKey, SignatureKeyPair>(
1081 &self,
1082 public_key: &SignaturePublicKey,
1083 signature_key_pair: &SignatureKeyPair,
1084 ) -> Result<(), Self::Error>
1085 where
1086 SignaturePublicKey: traits::SignaturePublicKey<STORAGE_PROVIDER_VERSION>,
1087 SignatureKeyPair: traits::SignatureKeyPair<STORAGE_PROVIDER_VERSION>,
1088 {
1089 self.inner
1090 .write()
1091 .mls_signature_keys
1092 .write(public_key, signature_key_pair)
1093 }
1094
1095 fn write_encryption_key_pair<EncryptionKey, HpkeKeyPair>(
1096 &self,
1097 public_key: &EncryptionKey,
1098 key_pair: &HpkeKeyPair,
1099 ) -> Result<(), Self::Error>
1100 where
1101 EncryptionKey: traits::EncryptionKey<STORAGE_PROVIDER_VERSION>,
1102 HpkeKeyPair: traits::HpkeKeyPair<STORAGE_PROVIDER_VERSION>,
1103 {
1104 self.inner
1105 .write()
1106 .mls_encryption_keys
1107 .write(public_key, key_pair)
1108 }
1109
1110 fn write_encryption_epoch_key_pairs<GroupId, EpochKey, HpkeKeyPair>(
1111 &self,
1112 group_id: &GroupId,
1113 epoch: &EpochKey,
1114 leaf_index: u32,
1115 key_pairs: &[HpkeKeyPair],
1116 ) -> Result<(), Self::Error>
1117 where
1118 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1119 EpochKey: traits::EpochKey<STORAGE_PROVIDER_VERSION>,
1120 HpkeKeyPair: traits::HpkeKeyPair<STORAGE_PROVIDER_VERSION>,
1121 {
1122 self.inner
1123 .write()
1124 .mls_epoch_key_pairs
1125 .write(group_id, epoch, leaf_index, key_pairs)
1126 }
1127
1128 fn write_key_package<HashReference, KeyPackage>(
1129 &self,
1130 hash_ref: &HashReference,
1131 key_package: &KeyPackage,
1132 ) -> Result<(), Self::Error>
1133 where
1134 HashReference: traits::HashReference<STORAGE_PROVIDER_VERSION>,
1135 KeyPackage: traits::KeyPackage<STORAGE_PROVIDER_VERSION>,
1136 {
1137 self.inner
1138 .write()
1139 .mls_key_packages
1140 .write(hash_ref, key_package)
1141 }
1142
1143 fn write_psk<PskId, PskBundle>(
1144 &self,
1145 psk_id: &PskId,
1146 psk: &PskBundle,
1147 ) -> Result<(), Self::Error>
1148 where
1149 PskId: traits::PskId<STORAGE_PROVIDER_VERSION>,
1150 PskBundle: traits::PskBundle<STORAGE_PROVIDER_VERSION>,
1151 {
1152 self.inner.write().mls_psks.write(psk_id, psk)
1153 }
1154
1155 fn mls_group_join_config<GroupId, MlsGroupJoinConfig>(
1160 &self,
1161 group_id: &GroupId,
1162 ) -> Result<Option<MlsGroupJoinConfig>, Self::Error>
1163 where
1164 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1165 MlsGroupJoinConfig: traits::MlsGroupJoinConfig<STORAGE_PROVIDER_VERSION>,
1166 {
1167 self.inner
1168 .read()
1169 .mls_group_data
1170 .read(group_id, GroupDataType::JoinGroupConfig)
1171 }
1172
1173 fn own_leaf_nodes<GroupId, LeafNode>(
1174 &self,
1175 group_id: &GroupId,
1176 ) -> Result<Vec<LeafNode>, Self::Error>
1177 where
1178 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1179 LeafNode: traits::LeafNode<STORAGE_PROVIDER_VERSION>,
1180 {
1181 self.inner.read().mls_own_leaf_nodes.read(group_id)
1182 }
1183
1184 fn queued_proposal_refs<GroupId, ProposalRef>(
1185 &self,
1186 group_id: &GroupId,
1187 ) -> Result<Vec<ProposalRef>, Self::Error>
1188 where
1189 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1190 ProposalRef: traits::ProposalRef<STORAGE_PROVIDER_VERSION>,
1191 {
1192 self.inner.read().mls_proposals.read_refs(group_id)
1193 }
1194
1195 fn queued_proposals<GroupId, ProposalRef, QueuedProposal>(
1196 &self,
1197 group_id: &GroupId,
1198 ) -> Result<Vec<(ProposalRef, QueuedProposal)>, Self::Error>
1199 where
1200 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1201 ProposalRef: traits::ProposalRef<STORAGE_PROVIDER_VERSION>,
1202 QueuedProposal: traits::QueuedProposal<STORAGE_PROVIDER_VERSION>,
1203 {
1204 self.inner.read().mls_proposals.read_proposals(group_id)
1205 }
1206
1207 fn tree<GroupId, TreeSync>(&self, group_id: &GroupId) -> Result<Option<TreeSync>, Self::Error>
1208 where
1209 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1210 TreeSync: traits::TreeSync<STORAGE_PROVIDER_VERSION>,
1211 {
1212 self.inner
1213 .read()
1214 .mls_group_data
1215 .read(group_id, GroupDataType::Tree)
1216 }
1217
1218 fn group_context<GroupId, GroupContext>(
1219 &self,
1220 group_id: &GroupId,
1221 ) -> Result<Option<GroupContext>, Self::Error>
1222 where
1223 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1224 GroupContext: traits::GroupContext<STORAGE_PROVIDER_VERSION>,
1225 {
1226 self.inner
1227 .read()
1228 .mls_group_data
1229 .read(group_id, GroupDataType::Context)
1230 }
1231
1232 fn interim_transcript_hash<GroupId, InterimTranscriptHash>(
1233 &self,
1234 group_id: &GroupId,
1235 ) -> Result<Option<InterimTranscriptHash>, Self::Error>
1236 where
1237 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1238 InterimTranscriptHash: traits::InterimTranscriptHash<STORAGE_PROVIDER_VERSION>,
1239 {
1240 self.inner
1241 .read()
1242 .mls_group_data
1243 .read(group_id, GroupDataType::InterimTranscriptHash)
1244 }
1245
1246 fn confirmation_tag<GroupId, ConfirmationTag>(
1247 &self,
1248 group_id: &GroupId,
1249 ) -> Result<Option<ConfirmationTag>, Self::Error>
1250 where
1251 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1252 ConfirmationTag: traits::ConfirmationTag<STORAGE_PROVIDER_VERSION>,
1253 {
1254 self.inner
1255 .read()
1256 .mls_group_data
1257 .read(group_id, GroupDataType::ConfirmationTag)
1258 }
1259
1260 fn group_state<GroupState, GroupId>(
1261 &self,
1262 group_id: &GroupId,
1263 ) -> Result<Option<GroupState>, Self::Error>
1264 where
1265 GroupState: traits::GroupState<STORAGE_PROVIDER_VERSION>,
1266 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1267 {
1268 self.inner
1269 .read()
1270 .mls_group_data
1271 .read(group_id, GroupDataType::GroupState)
1272 }
1273
1274 fn message_secrets<GroupId, MessageSecrets>(
1275 &self,
1276 group_id: &GroupId,
1277 ) -> Result<Option<MessageSecrets>, Self::Error>
1278 where
1279 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1280 MessageSecrets: traits::MessageSecrets<STORAGE_PROVIDER_VERSION>,
1281 {
1282 self.inner
1283 .read()
1284 .mls_group_data
1285 .read(group_id, GroupDataType::MessageSecrets)
1286 }
1287
1288 fn resumption_psk_store<GroupId, ResumptionPskStore>(
1289 &self,
1290 group_id: &GroupId,
1291 ) -> Result<Option<ResumptionPskStore>, Self::Error>
1292 where
1293 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1294 ResumptionPskStore: traits::ResumptionPskStore<STORAGE_PROVIDER_VERSION>,
1295 {
1296 self.inner
1297 .read()
1298 .mls_group_data
1299 .read(group_id, GroupDataType::ResumptionPskStore)
1300 }
1301
1302 fn own_leaf_index<GroupId, LeafNodeIndex>(
1303 &self,
1304 group_id: &GroupId,
1305 ) -> Result<Option<LeafNodeIndex>, Self::Error>
1306 where
1307 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1308 LeafNodeIndex: traits::LeafNodeIndex<STORAGE_PROVIDER_VERSION>,
1309 {
1310 self.inner
1311 .read()
1312 .mls_group_data
1313 .read(group_id, GroupDataType::OwnLeafIndex)
1314 }
1315
1316 fn group_epoch_secrets<GroupId, GroupEpochSecrets>(
1317 &self,
1318 group_id: &GroupId,
1319 ) -> Result<Option<GroupEpochSecrets>, Self::Error>
1320 where
1321 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1322 GroupEpochSecrets: traits::GroupEpochSecrets<STORAGE_PROVIDER_VERSION>,
1323 {
1324 self.inner
1325 .read()
1326 .mls_group_data
1327 .read(group_id, GroupDataType::GroupEpochSecrets)
1328 }
1329
1330 fn signature_key_pair<SignaturePublicKey, SignatureKeyPair>(
1331 &self,
1332 public_key: &SignaturePublicKey,
1333 ) -> Result<Option<SignatureKeyPair>, Self::Error>
1334 where
1335 SignaturePublicKey: traits::SignaturePublicKey<STORAGE_PROVIDER_VERSION>,
1336 SignatureKeyPair: traits::SignatureKeyPair<STORAGE_PROVIDER_VERSION>,
1337 {
1338 self.inner.read().mls_signature_keys.read(public_key)
1339 }
1340
1341 fn encryption_key_pair<HpkeKeyPair, EncryptionKey>(
1342 &self,
1343 public_key: &EncryptionKey,
1344 ) -> Result<Option<HpkeKeyPair>, Self::Error>
1345 where
1346 HpkeKeyPair: traits::HpkeKeyPair<STORAGE_PROVIDER_VERSION>,
1347 EncryptionKey: traits::EncryptionKey<STORAGE_PROVIDER_VERSION>,
1348 {
1349 self.inner.read().mls_encryption_keys.read(public_key)
1350 }
1351
1352 fn encryption_epoch_key_pairs<GroupId, EpochKey, HpkeKeyPair>(
1353 &self,
1354 group_id: &GroupId,
1355 epoch: &EpochKey,
1356 leaf_index: u32,
1357 ) -> Result<Vec<HpkeKeyPair>, Self::Error>
1358 where
1359 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1360 EpochKey: traits::EpochKey<STORAGE_PROVIDER_VERSION>,
1361 HpkeKeyPair: traits::HpkeKeyPair<STORAGE_PROVIDER_VERSION>,
1362 {
1363 self.inner
1364 .read()
1365 .mls_epoch_key_pairs
1366 .read(group_id, epoch, leaf_index)
1367 }
1368
1369 fn key_package<HashReference, KeyPackage>(
1370 &self,
1371 hash_ref: &HashReference,
1372 ) -> Result<Option<KeyPackage>, Self::Error>
1373 where
1374 HashReference: traits::HashReference<STORAGE_PROVIDER_VERSION>,
1375 KeyPackage: traits::KeyPackage<STORAGE_PROVIDER_VERSION>,
1376 {
1377 self.inner.read().mls_key_packages.read(hash_ref)
1378 }
1379
1380 fn psk<PskBundle, PskId>(&self, psk_id: &PskId) -> Result<Option<PskBundle>, Self::Error>
1381 where
1382 PskBundle: traits::PskBundle<STORAGE_PROVIDER_VERSION>,
1383 PskId: traits::PskId<STORAGE_PROVIDER_VERSION>,
1384 {
1385 self.inner.read().mls_psks.read(psk_id)
1386 }
1387
1388 fn remove_proposal<GroupId, ProposalRef>(
1393 &self,
1394 group_id: &GroupId,
1395 proposal_ref: &ProposalRef,
1396 ) -> Result<(), Self::Error>
1397 where
1398 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1399 ProposalRef: traits::ProposalRef<STORAGE_PROVIDER_VERSION>,
1400 {
1401 self.inner
1402 .write()
1403 .mls_proposals
1404 .remove(group_id, proposal_ref)
1405 }
1406
1407 fn delete_own_leaf_nodes<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1408 where
1409 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1410 {
1411 self.inner.write().mls_own_leaf_nodes.delete(group_id)
1412 }
1413
1414 fn delete_group_config<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1415 where
1416 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1417 {
1418 self.inner
1419 .write()
1420 .mls_group_data
1421 .delete(group_id, GroupDataType::JoinGroupConfig)
1422 }
1423
1424 fn delete_tree<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1425 where
1426 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1427 {
1428 self.inner
1429 .write()
1430 .mls_group_data
1431 .delete(group_id, GroupDataType::Tree)
1432 }
1433
1434 fn delete_confirmation_tag<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1435 where
1436 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1437 {
1438 self.inner
1439 .write()
1440 .mls_group_data
1441 .delete(group_id, GroupDataType::ConfirmationTag)
1442 }
1443
1444 fn delete_group_state<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1445 where
1446 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1447 {
1448 self.inner
1449 .write()
1450 .mls_group_data
1451 .delete(group_id, GroupDataType::GroupState)
1452 }
1453
1454 fn delete_context<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1455 where
1456 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1457 {
1458 self.inner
1459 .write()
1460 .mls_group_data
1461 .delete(group_id, GroupDataType::Context)
1462 }
1463
1464 fn delete_interim_transcript_hash<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1465 where
1466 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1467 {
1468 self.inner
1469 .write()
1470 .mls_group_data
1471 .delete(group_id, GroupDataType::InterimTranscriptHash)
1472 }
1473
1474 fn delete_message_secrets<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1475 where
1476 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1477 {
1478 self.inner
1479 .write()
1480 .mls_group_data
1481 .delete(group_id, GroupDataType::MessageSecrets)
1482 }
1483
1484 fn delete_all_resumption_psk_secrets<GroupId>(
1485 &self,
1486 group_id: &GroupId,
1487 ) -> Result<(), Self::Error>
1488 where
1489 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1490 {
1491 self.inner
1492 .write()
1493 .mls_group_data
1494 .delete(group_id, GroupDataType::ResumptionPskStore)
1495 }
1496
1497 fn delete_own_leaf_index<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1498 where
1499 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1500 {
1501 self.inner
1502 .write()
1503 .mls_group_data
1504 .delete(group_id, GroupDataType::OwnLeafIndex)
1505 }
1506
1507 fn delete_group_epoch_secrets<GroupId>(&self, group_id: &GroupId) -> Result<(), Self::Error>
1508 where
1509 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1510 {
1511 self.inner
1512 .write()
1513 .mls_group_data
1514 .delete(group_id, GroupDataType::GroupEpochSecrets)
1515 }
1516
1517 fn clear_proposal_queue<GroupId, ProposalRef>(
1518 &self,
1519 group_id: &GroupId,
1520 ) -> Result<(), Self::Error>
1521 where
1522 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1523 ProposalRef: traits::ProposalRef<STORAGE_PROVIDER_VERSION>,
1524 {
1525 self.inner.write().mls_proposals.clear(group_id)
1526 }
1527
1528 fn delete_signature_key_pair<SignaturePublicKey>(
1529 &self,
1530 public_key: &SignaturePublicKey,
1531 ) -> Result<(), Self::Error>
1532 where
1533 SignaturePublicKey: traits::SignaturePublicKey<STORAGE_PROVIDER_VERSION>,
1534 {
1535 self.inner.write().mls_signature_keys.delete(public_key)
1536 }
1537
1538 fn delete_encryption_key_pair<EncryptionKey>(
1539 &self,
1540 public_key: &EncryptionKey,
1541 ) -> Result<(), Self::Error>
1542 where
1543 EncryptionKey: traits::EncryptionKey<STORAGE_PROVIDER_VERSION>,
1544 {
1545 self.inner.write().mls_encryption_keys.delete(public_key)
1546 }
1547
1548 fn delete_encryption_epoch_key_pairs<GroupId, EpochKey>(
1549 &self,
1550 group_id: &GroupId,
1551 epoch: &EpochKey,
1552 leaf_index: u32,
1553 ) -> Result<(), Self::Error>
1554 where
1555 GroupId: traits::GroupId<STORAGE_PROVIDER_VERSION>,
1556 EpochKey: traits::EpochKey<STORAGE_PROVIDER_VERSION>,
1557 {
1558 self.inner
1559 .write()
1560 .mls_epoch_key_pairs
1561 .delete(group_id, epoch, leaf_index)
1562 }
1563
1564 fn delete_key_package<HashReference>(&self, hash_ref: &HashReference) -> Result<(), Self::Error>
1565 where
1566 HashReference: traits::HashReference<STORAGE_PROVIDER_VERSION>,
1567 {
1568 self.inner.write().mls_key_packages.delete(hash_ref)
1569 }
1570
1571 fn delete_psk<PskId>(&self, psk_id: &PskId) -> Result<(), Self::Error>
1572 where
1573 PskId: traits::PskId<STORAGE_PROVIDER_VERSION>,
1574 {
1575 self.inner.write().mls_psks.delete(psk_id)
1576 }
1577}
1578
1579#[cfg(test)]
1580mod tests {
1581 use std::collections::BTreeSet;
1582
1583 use mdk_storage_traits::groups::GroupStorage;
1584 use mdk_storage_traits::groups::types::{
1585 Group, GroupExporterSecret, GroupState, SelfUpdateState,
1586 };
1587 use mdk_storage_traits::messages::MessageStorage;
1588 use mdk_storage_traits::messages::error::MessageError;
1589 use mdk_storage_traits::messages::types::{Message, MessageState, ProcessedMessageState};
1590 use mdk_storage_traits::test_utils::crypto_utils::generate_random_bytes;
1591 use mdk_storage_traits::welcomes::WelcomeStorage;
1592 use mdk_storage_traits::welcomes::types::{ProcessedWelcomeState, Welcome, WelcomeState};
1593 use mdk_storage_traits::{GroupId, MdkStorageProvider, Secret};
1594 use nostr::{EventId, Kind, PublicKey, RelayUrl, Tags, Timestamp, UnsignedEvent};
1595
1596 use super::*;
1597
1598 fn create_test_group_id() -> GroupId {
1599 GroupId::from_slice(&[1, 2, 3, 4])
1600 }
1601
// A storage built with `new()` reports the in-memory backend.
#[test]
fn test_new() {
    let backend = MdkMemoryStorage::new().backend();
    assert_eq!(backend, Backend::Memory);
}
1607
// A storage built with `default()` reports the in-memory backend.
#[test]
fn test_default() {
    let backend = MdkMemoryStorage::default().backend();
    assert_eq!(backend, Backend::Memory);
}
1613
// The backend is `Memory` and is reported as non-persistent.
#[test]
fn test_backend_type() {
    let storage = MdkMemoryStorage::default();
    let backend = storage.backend();
    assert_eq!(backend, Backend::Memory);
    assert!(!storage.backend().is_persistent());
}
1620
// The backend of a default storage is reported as non-persistent.
#[test]
fn test_storage_is_memory_based() {
    let storage = MdkMemoryStorage::default();
    let backend = storage.backend();
    assert!(!backend.is_persistent());
}
1626
// The reported backend equals `Memory` and differs from `SQLite`.
#[test]
fn test_compare_backend_types() {
    let storage = MdkMemoryStorage::default();
    let backend = storage.backend();
    assert_eq!(backend, Backend::Memory);
    assert_ne!(backend, Backend::SQLite);
}
1634
// Two independently constructed storages both report the in-memory backend.
#[test]
fn test_create_multiple_instances() {
    let first = MdkMemoryStorage::new();
    let second = MdkMemoryStorage::new();

    assert_eq!(first.backend(), second.backend());
    for storage in [&first, &second] {
        assert_eq!(storage.backend(), Backend::Memory);
    }
}
1644
// Saving a group makes it retrievable by MLS group id and populates both group caches.
#[test]
fn test_group_cache() {
    let storage = MdkMemoryStorage::default();
    let group_id = create_test_group_id();
    let nostr_group_id = generate_random_bytes(32).try_into().unwrap();
    let group = Group {
        mls_group_id: group_id.clone(),
        nostr_group_id,
        name: "Test Group".to_string(),
        description: "A test group".to_string(),
        admin_pubkeys: BTreeSet::new(),
        last_message_id: None,
        last_message_at: None,
        last_message_processed_at: None,
        epoch: 0,
        state: GroupState::Active,
        image_hash: Some(generate_random_bytes(32).try_into().unwrap()),
        image_key: Some(Secret::new(generate_random_bytes(32).try_into().unwrap())),
        image_nonce: Some(Secret::new(generate_random_bytes(12).try_into().unwrap())),
        self_update_state: SelfUpdateState::Required,
    };
    storage.save_group(group.clone()).unwrap();

    let stored = storage
        .find_group_by_mls_group_id(&group_id)
        .unwrap()
        .unwrap();
    assert_eq!(stored.mls_group_id, group_id);
    assert_eq!(stored.nostr_group_id, nostr_group_id);

    // Inspect both caches under a single read guard.
    let state = storage.inner.read();
    assert!(state.groups_cache.contains(&group_id));
    assert!(state.groups_by_nostr_id_cache.contains(&nostr_group_id));
}
1689
// Replacing a group's relays stores both URLs and populates the relay cache.
#[test]
fn test_group_relays() {
    let storage = MdkMemoryStorage::default();
    let group_id = create_test_group_id();
    let group = Group {
        mls_group_id: group_id.clone(),
        nostr_group_id: generate_random_bytes(32).try_into().unwrap(),
        name: "Another Test Group".to_string(),
        description: "Another test group".to_string(),
        admin_pubkeys: BTreeSet::new(),
        last_message_id: None,
        last_message_at: None,
        last_message_processed_at: None,
        epoch: 0,
        state: GroupState::Active,
        image_hash: Some(generate_random_bytes(32).try_into().unwrap()),
        image_key: Some(Secret::new(generate_random_bytes(32).try_into().unwrap())),
        image_nonce: Some(Secret::new(generate_random_bytes(12).try_into().unwrap())),
        self_update_state: SelfUpdateState::Required,
    };
    storage.save_group(group.clone()).unwrap();

    let relays = BTreeSet::from([
        RelayUrl::parse("wss://relay1.example.com").unwrap(),
        RelayUrl::parse("wss://relay2.example.com").unwrap(),
    ]);
    storage.replace_group_relays(&group_id, relays).unwrap();

    let stored_relays = storage.group_relays(&group_id).unwrap();
    assert_eq!(stored_relays.len(), 2);

    let state = storage.inner.read();
    let relay_cache = &state.group_relays_cache;
    assert!(relay_cache.contains(&group_id));
    match relay_cache.peek(&group_id) {
        Some(cached) => assert_eq!(cached.len(), 2),
        None => panic!("Group relays not found in cache"),
    }
}
1736
// Exporter secrets saved for epochs 0 and 1 can be read back and are cached per
// (group id, epoch); an epoch that was never saved yields `None`.
#[test]
fn test_group_exporter_secret_cache() {
    let storage = MdkMemoryStorage::default();
    let group_id = create_test_group_id();
    let group = Group {
        mls_group_id: group_id.clone(),
        nostr_group_id: generate_random_bytes(32).try_into().unwrap(),
        name: "Test Group".to_string(),
        description: "A test group".to_string(),
        admin_pubkeys: BTreeSet::new(),
        last_message_id: None,
        last_message_at: None,
        last_message_processed_at: None,
        epoch: 0,
        state: GroupState::Active,
        image_hash: Some(generate_random_bytes(32).try_into().unwrap()),
        image_key: Some(Secret::new(generate_random_bytes(32).try_into().unwrap())),
        image_nonce: Some(Secret::new(generate_random_bytes(12).try_into().unwrap())),
        self_update_state: SelfUpdateState::Required,
    };
    storage.save_group(group.clone()).unwrap();

    let secret_epoch_0 = GroupExporterSecret {
        mls_group_id: group_id.clone(),
        epoch: 0,
        secret: Secret::new([0u8; 32]),
    };
    let secret_epoch_1 = GroupExporterSecret {
        mls_group_id: group_id.clone(),
        epoch: 1,
        secret: Secret::new([0u8; 32]),
    };
    let saved = [secret_epoch_0, secret_epoch_1];

    // Save everything first, then read everything back, as separate passes.
    for secret in &saved {
        storage.save_group_exporter_secret(secret.clone()).unwrap();
    }
    for secret in &saved {
        let found = storage
            .get_group_exporter_secret(&group_id, secret.epoch)
            .unwrap()
            .unwrap();
        assert_eq!(&found, secret);
    }

    let missing = storage.get_group_exporter_secret(&group_id, 999).unwrap();
    assert!(missing.is_none());

    let state = storage.inner.read();
    let secrets_cache = &state.group_exporter_secrets_cache;
    assert!(secrets_cache.contains(&(group_id.clone(), 0)));
    assert!(secrets_cache.contains(&(group_id.clone(), 1)));
    assert!(!secrets_cache.contains(&(group_id.clone(), 999)));
}
1802
// A saved welcome is found by its event id and cached under that id; a saved processed
// welcome is found by its wrapper event id and cached under that id.
#[test]
fn test_welcome_cache() {
    let nostr_storage = MdkMemoryStorage::default();

    // The welcome event and its wrapper get DIFFERENT ids. With identical ids, a lookup
    // keyed by the wrong one of the two would still pass every assertion below.
    let event_id = EventId::all_zeros();
    let wrapper_id =
        EventId::from_hex("0000000000000000000000000000000000000000000000000000000000000001")
            .unwrap();

    let pubkey =
        PublicKey::from_hex("aabbccddeeffaabbccddeeffaabbccddeeffaabbccddeeffaabbccddeeffaabb")
            .unwrap();

    let mls_group_id = create_test_group_id();
    let nostr_group_id = generate_random_bytes(32).try_into().unwrap();
    let welcome = Welcome {
        id: event_id,
        event: UnsignedEvent::new(
            pubkey,
            Timestamp::now(),
            Kind::MlsWelcome,
            Tags::new(),
            "test".to_string(),
        ),
        mls_group_id: mls_group_id.clone(),
        nostr_group_id,
        group_name: "Test Welcome Group".to_string(),
        group_description: "A test welcome group".to_string(),
        group_image_key: None,
        group_image_hash: None,
        group_image_nonce: None,
        group_admin_pubkeys: BTreeSet::from([pubkey]),
        group_relays: BTreeSet::from([RelayUrl::parse("wss://relay.example.com").unwrap()]),
        welcomer: pubkey,
        member_count: 2,
        state: WelcomeState::Pending,
        wrapper_event_id: wrapper_id,
    };

    // `welcome` is not used again, so it is moved rather than cloned.
    nostr_storage.save_welcome(welcome).unwrap();

    let found_welcome = nostr_storage
        .find_welcome_by_event_id(&event_id)
        .unwrap()
        .unwrap();
    assert_eq!(found_welcome.id, event_id);
    assert_eq!(found_welcome.mls_group_id, mls_group_id);

    {
        let inner = nostr_storage.inner.read();
        assert!(inner.welcomes_cache.contains(&event_id));
    }

    let processed_welcome = ProcessedWelcome {
        wrapper_event_id: wrapper_id,
        welcome_event_id: Some(event_id),
        processed_at: Timestamp::now(),
        state: ProcessedWelcomeState::Processed,
        failure_reason: None,
    };

    nostr_storage
        .save_processed_welcome(processed_welcome)
        .unwrap();

    let found_processed_welcome = nostr_storage
        .find_processed_welcome_by_event_id(&wrapper_id)
        .unwrap()
        .unwrap();
    assert_eq!(found_processed_welcome.wrapper_event_id, wrapper_id);
    assert_eq!(found_processed_welcome.welcome_event_id, Some(event_id));

    {
        let inner = nostr_storage.inner.read();
        assert!(inner.processed_welcomes_cache.contains(&wrapper_id));
    }
}
1888
// A saved message is found by (group id, event id) and lands in both message caches; a
// saved processed message is found by its wrapper event id and cached under that id.
#[test]
fn test_message_cache() {
    let nostr_storage = MdkMemoryStorage::default();
    let mls_group_id = create_test_group_id();
    let nostr_group_id = generate_random_bytes(32).try_into().unwrap();
    let image_hash = Some(generate_random_bytes(32).try_into().unwrap());
    let image_key = Some(Secret::new(generate_random_bytes(32).try_into().unwrap()));
    let image_nonce = Some(Secret::new(generate_random_bytes(12).try_into().unwrap()));
    let group = Group {
        mls_group_id: mls_group_id.clone(),
        nostr_group_id,
        name: "Message Test Group".to_string(),
        description: "A group for testing messages".to_string(),
        admin_pubkeys: BTreeSet::new(),
        last_message_id: None,
        last_message_at: None,
        last_message_processed_at: None,
        epoch: 0,
        state: GroupState::Active,
        image_hash,
        image_key,
        image_nonce,
        self_update_state: SelfUpdateState::Required,
    };
    // `group` is not used again, so it is moved rather than cloned.
    nostr_storage.save_group(group).unwrap();

    // The message and its wrapper get DIFFERENT ids. With identical ids, a lookup keyed
    // by the wrong one of the two would still pass every assertion below.
    let event_id = EventId::all_zeros();
    let wrapper_id =
        EventId::from_hex("0000000000000000000000000000000000000000000000000000000000000001")
            .unwrap();
    let pubkey =
        PublicKey::from_hex("aabbccddeeffaabbccddeeffaabbccddeeffaabbccddeeffaabbccddeeffaabb")
            .unwrap();
    let now = Timestamp::now();
    let message = Message {
        id: event_id,
        pubkey,
        kind: Kind::MlsGroupMessage,
        mls_group_id: mls_group_id.clone(),
        created_at: now,
        processed_at: now,
        content: "Hello, world!".to_string(),
        tags: Tags::new(),
        event: UnsignedEvent::new(
            pubkey,
            now,
            Kind::MlsGroupMessage,
            Tags::new(),
            "Hello, world!".to_string(),
        ),
        wrapper_event_id: wrapper_id,
        state: MessageState::Created,
        epoch: None,
    };
    nostr_storage.save_message(message).unwrap();
    let found_message = nostr_storage
        .find_message_by_event_id(&mls_group_id, &event_id)
        .unwrap()
        .unwrap();
    assert_eq!(found_message.id, event_id);
    assert_eq!(found_message.mls_group_id, mls_group_id);

    {
        let inner = nostr_storage.inner.read();
        assert!(inner.messages_cache.contains(&event_id));

        let cache = &inner.messages_by_group_cache;
        assert!(cache.contains(&mls_group_id));
        match cache.peek(&mls_group_id) {
            Some(msgs) => {
                assert_eq!(msgs.len(), 1);
                assert!(msgs.contains_key(&event_id));
                assert_eq!(msgs.get(&event_id).unwrap().id, event_id);
            }
            None => panic!("Messages not found in group cache"),
        }
    }

    let processed_message = ProcessedMessage {
        wrapper_event_id: wrapper_id,
        message_event_id: Some(event_id),
        processed_at: Timestamp::now(),
        epoch: None,
        mls_group_id: None,
        state: ProcessedMessageState::Processed,
        failure_reason: None,
    };
    nostr_storage
        .save_processed_message(processed_message)
        .unwrap();
    let found_processed = nostr_storage
        .find_processed_message_by_event_id(&wrapper_id)
        .unwrap()
        .unwrap();
    assert_eq!(found_processed.wrapper_event_id, wrapper_id);
    {
        let inner = nostr_storage.inner.read();
        assert!(inner.processed_messages_cache.contains(&wrapper_id));
    }
}
1990
// Saving a message whose group was never saved fails with `InvalidParameters` and
// leaves the per-group message cache without an entry for that group.
#[test]
fn test_save_message_for_nonexistent_group() {
    let storage = MdkMemoryStorage::default();
    let missing_group_id = create_test_group_id();
    let event_id = EventId::all_zeros();
    let wrapper_id = EventId::all_zeros();
    let pubkey =
        PublicKey::from_hex("aabbccddeeffaabbccddeeffaabbccddeeffaabbccddeeffaabbccddeeffaabb")
            .unwrap();
    let now = Timestamp::now();
    let event = UnsignedEvent::new(
        pubkey,
        now,
        Kind::MlsGroupMessage,
        Tags::new(),
        "Hello, world!".to_string(),
    );
    let message = Message {
        id: event_id,
        pubkey,
        kind: Kind::MlsGroupMessage,
        mls_group_id: missing_group_id.clone(),
        created_at: now,
        processed_at: now,
        content: "Hello, world!".to_string(),
        tags: Tags::new(),
        event,
        wrapper_event_id: wrapper_id,
        state: MessageState::Created,
        epoch: None,
    };

    let outcome = storage.save_message(message);
    assert!(outcome.is_err());
    if let Err(MessageError::InvalidParameters(msg)) = outcome {
        assert!(msg.contains("not found"));
    } else {
        panic!("Expected InvalidParameters error");
    }

    let state = storage.inner.read();
    assert!(!state.messages_by_group_cache.contains(&missing_group_id));
}
2039
// Saving a second message with the same event id replaces the first: lookups return the
// new content and the per-group cache still holds exactly one message.
#[test]
fn test_update_existing_message() {
    let storage = MdkMemoryStorage::default();
    let group_id = create_test_group_id();
    storage
        .save_group(Group {
            mls_group_id: group_id.clone(),
            nostr_group_id: generate_random_bytes(32).try_into().unwrap(),
            name: "Update Test Group".to_string(),
            description: "A group for testing message updates".to_string(),
            admin_pubkeys: BTreeSet::new(),
            last_message_id: None,
            last_message_at: None,
            last_message_processed_at: None,
            epoch: 0,
            state: GroupState::Active,
            image_hash: None,
            image_key: None,
            image_nonce: None,
            self_update_state: SelfUpdateState::Required,
        })
        .unwrap();

    let event_id = EventId::all_zeros();
    let wrapper_id = EventId::all_zeros();
    let pubkey =
        PublicKey::from_hex("aabbccddeeffaabbccddeeffaabbccddeeffaabbccddeeffaabbccddeeffaabb")
            .unwrap();
    let now = Timestamp::now();
    let original = Message {
        id: event_id,
        pubkey,
        kind: Kind::MlsGroupMessage,
        mls_group_id: group_id.clone(),
        created_at: now,
        processed_at: now,
        content: "Original message".to_string(),
        tags: Tags::new(),
        event: UnsignedEvent::new(
            pubkey,
            now,
            Kind::MlsGroupMessage,
            Tags::new(),
            "Original message".to_string(),
        ),
        wrapper_event_id: wrapper_id,
        state: MessageState::Created,
        epoch: None,
    };
    storage.save_message(original.clone()).unwrap();

    let before_update = storage
        .find_message_by_event_id(&group_id, &event_id)
        .unwrap()
        .unwrap();
    assert_eq!(before_update.content, "Original message");

    // Same id as `original`; only the content and the embedded event differ.
    let replacement = Message {
        content: "Updated message".to_string(),
        event: UnsignedEvent::new(
            pubkey,
            Timestamp::now(),
            Kind::MlsGroupMessage,
            Tags::new(),
            "Updated message".to_string(),
        ),
        ..original.clone()
    };
    storage.save_message(replacement.clone()).unwrap();

    let after_update = storage
        .find_message_by_event_id(&group_id, &event_id)
        .unwrap()
        .unwrap();
    assert_eq!(after_update.content, "Updated message");

    let state = storage.inner.read();
    let cached_messages = state.messages_by_group_cache.peek(&group_id).unwrap();
    assert_eq!(cached_messages.len(), 1);
    let cached = cached_messages.get(&event_id).unwrap();
    assert_eq!(cached.content, "Updated message");
    assert_eq!(cached.id, event_id);
}
2136
// Two messages with different event ids saved to the same group are both retrievable
// and both present in the per-group message cache.
#[test]
fn test_save_multiple_messages_for_same_group() {
    let storage = MdkMemoryStorage::default();
    let group_id = create_test_group_id();
    storage
        .save_group(Group {
            mls_group_id: group_id.clone(),
            nostr_group_id: generate_random_bytes(32).try_into().unwrap(),
            name: "Multiple Messages Group".to_string(),
            description: "A group for testing multiple messages".to_string(),
            admin_pubkeys: BTreeSet::new(),
            last_message_id: None,
            last_message_at: None,
            last_message_processed_at: None,
            epoch: 0,
            state: GroupState::Active,
            image_hash: None,
            image_key: None,
            image_nonce: None,
            self_update_state: SelfUpdateState::Required,
        })
        .unwrap();

    let pubkey =
        PublicKey::from_hex("aabbccddeeffaabbccddeeffaabbccddeeffaabbccddeeffaabbccddeeffaabb")
            .unwrap();
    let now = Timestamp::now();

    // Both messages share every field except the event id and the text.
    let build_message = |id: EventId, text: &str| Message {
        id,
        pubkey,
        kind: Kind::MlsGroupMessage,
        mls_group_id: group_id.clone(),
        created_at: now,
        processed_at: now,
        content: text.to_string(),
        tags: Tags::new(),
        event: UnsignedEvent::new(
            pubkey,
            now,
            Kind::MlsGroupMessage,
            Tags::new(),
            text.to_string(),
        ),
        wrapper_event_id: EventId::all_zeros(),
        state: MessageState::Created,
        epoch: None,
    };

    let first_id =
        EventId::from_hex("0000000000000000000000000000000000000000000000000000000000000001")
            .unwrap();
    let first = build_message(first_id, "First message");
    storage.save_message(first.clone()).unwrap();

    let second_id =
        EventId::from_hex("0000000000000000000000000000000000000000000000000000000000000002")
            .unwrap();
    let second = build_message(second_id, "Second message");
    storage.save_message(second.clone()).unwrap();

    let expectations = [(first_id, "First message"), (second_id, "Second message")];

    for (id, text) in expectations {
        let found = storage
            .find_message_by_event_id(&group_id, &id)
            .unwrap()
            .unwrap();
        assert_eq!(found.content, text);
    }

    let state = storage.inner.read();
    let cached_messages = state.messages_by_group_cache.peek(&group_id).unwrap();
    assert_eq!(cached_messages.len(), 2);
    for (id, text) in expectations {
        assert_eq!(cached_messages.get(&id).unwrap().content, text);
    }
}
2248
// A message addressed to an unknown group is rejected, and the rejection leaves no
// trace in either message cache nor in the cache entry of an existing group.
#[test]
fn test_save_message_verifies_group_existence_before_cache_insertion() {
    let storage = MdkMemoryStorage::default();
    let existing_group_id = create_test_group_id();
    let unknown_group_id = GroupId::from_slice(&[9, 9, 9, 9]);
    let event_id = EventId::all_zeros();
    let wrapper_id = EventId::all_zeros();
    let pubkey =
        PublicKey::from_hex("aabbccddeeffaabbccddeeffaabbccddeeffaabbccddeeffaabbccddeeffaabb")
            .unwrap();

    // Only `existing_group_id` is ever saved.
    storage
        .save_group(Group {
            mls_group_id: existing_group_id.clone(),
            nostr_group_id: generate_random_bytes(32).try_into().unwrap(),
            name: "Test Group".to_string(),
            description: "A test group".to_string(),
            admin_pubkeys: BTreeSet::new(),
            last_message_id: None,
            last_message_at: None,
            last_message_processed_at: None,
            epoch: 0,
            state: GroupState::Active,
            image_hash: None,
            image_key: None,
            image_nonce: None,
            self_update_state: SelfUpdateState::Required,
        })
        .unwrap();

    let now = Timestamp::now();
    let message = Message {
        id: event_id,
        pubkey,
        kind: Kind::MlsGroupMessage,
        mls_group_id: unknown_group_id.clone(),
        created_at: now,
        processed_at: now,
        content: "Hello, world!".to_string(),
        tags: Tags::new(),
        event: UnsignedEvent::new(
            pubkey,
            now,
            Kind::MlsGroupMessage,
            Tags::new(),
            "Hello, world!".to_string(),
        ),
        wrapper_event_id: wrapper_id,
        state: MessageState::Created,
        epoch: None,
    };

    let outcome = storage.save_message(message);
    assert!(outcome.is_err());

    let state = storage.inner.read();
    assert!(!state.messages_cache.contains(&event_id));

    let by_group = &state.messages_by_group_cache;
    assert!(!by_group.contains(&unknown_group_id));
    // The existing group may or may not have a cache entry; if it does, it must be empty.
    if let Some(messages) = by_group.peek(&existing_group_id) {
        assert!(messages.is_empty());
    }
}
2327
// A storage built with a custom cache size reports that size in its limits and still
// saves and finds groups.
#[test]
fn test_with_custom_cache_size() {
    let storage = MdkMemoryStorage::with_cache_size(NonZeroUsize::new(50).unwrap());

    assert_eq!(storage.limits().cache_size, 50);

    let group_id = create_test_group_id();
    let group = Group {
        mls_group_id: group_id.clone(),
        nostr_group_id: generate_random_bytes(32).try_into().unwrap(),
        name: "Custom Cache Group".to_string(),
        description: "A group for testing custom cache size".to_string(),
        admin_pubkeys: BTreeSet::new(),
        last_message_id: None,
        last_message_at: None,
        last_message_processed_at: None,
        epoch: 0,
        state: GroupState::Active,
        image_hash: Some(generate_random_bytes(32).try_into().unwrap()),
        image_key: Some(Secret::new(generate_random_bytes(32).try_into().unwrap())),
        image_nonce: Some(Secret::new(generate_random_bytes(12).try_into().unwrap())),
        self_update_state: SelfUpdateState::Required,
    };

    storage.save_group(group.clone()).unwrap();

    let lookup = storage.find_group_by_mls_group_id(&group_id);
    assert!(lookup.is_ok());
    let stored = lookup.unwrap().unwrap();
    assert_eq!(stored.mls_group_id, group_id);
}
2368
// A storage built with `default()` saves a group and finds it again by MLS group id.
#[test]
fn test_default_implementation() {
    let storage = MdkMemoryStorage::default();

    let group_id = create_test_group_id();
    let group = Group {
        mls_group_id: group_id.clone(),
        nostr_group_id: generate_random_bytes(32).try_into().unwrap(),
        name: "Default Implementation Group".to_string(),
        description: "A group for testing default implementation".to_string(),
        admin_pubkeys: BTreeSet::new(),
        last_message_id: None,
        last_message_at: None,
        last_message_processed_at: None,
        epoch: 0,
        state: GroupState::Active,
        image_hash: Some(generate_random_bytes(32).try_into().unwrap()),
        image_key: Some(Secret::new(generate_random_bytes(32).try_into().unwrap())),
        image_nonce: Some(Secret::new(generate_random_bytes(12).try_into().unwrap())),
        self_update_state: SelfUpdateState::Required,
    };

    storage.save_group(group.clone()).unwrap();

    let lookup = storage.find_group_by_mls_group_id(&group_id);
    assert!(lookup.is_ok());
    let stored = lookup.unwrap().unwrap();
    assert_eq!(stored.mls_group_id, group_id);
}
2406
// A snapshot taken before a group is modified brings back the original name and epoch
// when it is restored.
#[test]
fn test_snapshot_and_restore() {
    let storage = MdkMemoryStorage::default();

    let group_id = create_test_group_id();
    let original = Group {
        mls_group_id: group_id.clone(),
        nostr_group_id: generate_random_bytes(32).try_into().unwrap(),
        name: "Snapshot Test Group".to_string(),
        description: "A group for testing snapshots".to_string(),
        admin_pubkeys: BTreeSet::new(),
        last_message_id: None,
        last_message_at: None,
        last_message_processed_at: None,
        epoch: 0,
        state: GroupState::Active,
        image_hash: None,
        image_key: None,
        image_nonce: None,
        self_update_state: SelfUpdateState::Required,
    };
    storage.save_group(original.clone()).unwrap();

    // Capture the state while the group still has its original name and epoch.
    let checkpoint = storage.create_snapshot();

    let changed = Group {
        name: "Modified Group Name".to_string(),
        epoch: 5,
        ..original.clone()
    };
    storage.save_group(changed.clone()).unwrap();

    let after_change = storage
        .find_group_by_mls_group_id(&group_id)
        .unwrap()
        .unwrap();
    assert_eq!(after_change.name, "Modified Group Name");
    assert_eq!(after_change.epoch, 5);

    storage.restore_snapshot(checkpoint);

    let after_restore = storage
        .find_group_by_mls_group_id(&group_id)
        .unwrap()
        .unwrap();
    assert_eq!(after_restore.name, "Snapshot Test Group");
    assert_eq!(after_restore.epoch, 0);
}
2462
#[test]
/// Stores one message, snapshots, stores a second message, and asserts that
/// restoring the snapshot leaves only the first message in the group.
fn test_snapshot_with_messages() {
    let storage = MdkMemoryStorage::default();

    // A group must exist for the messages to be attached to.
    let mls_group_id = create_test_group_id();
    let nostr_group_id = generate_random_bytes(32).try_into().unwrap();
    storage
        .save_group(Group {
            mls_group_id: mls_group_id.clone(),
            nostr_group_id,
            name: String::from("Message Snapshot Group"),
            description: String::from("A group for testing message snapshots"),
            admin_pubkeys: BTreeSet::new(),
            last_message_id: None,
            last_message_at: None,
            last_message_processed_at: None,
            epoch: 0,
            state: GroupState::Active,
            image_hash: None,
            image_key: None,
            image_nonce: None,
            self_update_state: SelfUpdateState::Required,
        })
        .unwrap();

    let pubkey =
        PublicKey::from_hex("aabbccddeeffaabbccddeeffaabbccddeeffaabbccddeeffaabbccddeeffaabb")
            .unwrap();
    let now = Timestamp::now();

    // Both messages differ only in id and text, so build them from one template.
    let build_message = |id: EventId, text: &str| Message {
        id,
        pubkey,
        kind: Kind::MlsGroupMessage,
        mls_group_id: mls_group_id.clone(),
        created_at: now,
        processed_at: now,
        content: text.to_string(),
        tags: Tags::new(),
        event: UnsignedEvent::new(
            pubkey,
            now,
            Kind::MlsGroupMessage,
            Tags::new(),
            text.to_string(),
        ),
        wrapper_event_id: EventId::all_zeros(),
        state: MessageState::Created,
        epoch: None,
    };

    // First message: present when the snapshot is taken.
    let first_id =
        EventId::from_hex("0000000000000000000000000000000000000000000000000000000000000001")
            .unwrap();
    storage
        .save_message(build_message(first_id, "Original message"))
        .unwrap();

    let snapshot = storage.create_snapshot();

    // Second message: added only after the snapshot.
    let second_id =
        EventId::from_hex("0000000000000000000000000000000000000000000000000000000000000002")
            .unwrap();
    storage
        .save_message(build_message(second_id, "Second message"))
        .unwrap();

    // Both messages are visible before the restore.
    let before_restore = storage.messages(&mls_group_id, None).unwrap();
    assert_eq!(before_restore.len(), 2);

    storage.restore_snapshot(snapshot);

    // Only the message saved before the snapshot remains.
    let messages_after = storage.messages(&mls_group_id, None).unwrap();
    assert_eq!(messages_after.len(), 1);
    assert_eq!(messages_after[0].content, "Original message");
}
2559
#[test]
/// Takes a snapshot of a store that does not yet contain the group, inserts
/// the group, and asserts that restoring the snapshot makes it absent again.
fn test_snapshot_with_new_group_rollback() {
    let storage = MdkMemoryStorage::default();
    let mls_group_id = GroupId::from_slice(&[13, 14, 15, 16]);

    // Precondition: the group is unknown to the store.
    assert!(
        storage
            .find_group_by_mls_group_id(&mls_group_id)
            .unwrap()
            .is_none()
    );

    let snapshot = storage.create_snapshot();

    // Insert a brand-new group after the snapshot was taken.
    let nostr_group_id = generate_random_bytes(32).try_into().unwrap();
    let new_group = Group {
        mls_group_id: mls_group_id.clone(),
        nostr_group_id,
        name: String::from("New Group"),
        description: String::from("A new group"),
        admin_pubkeys: BTreeSet::new(),
        last_message_id: None,
        last_message_at: None,
        last_message_processed_at: None,
        epoch: 0,
        state: GroupState::Active,
        image_hash: None,
        image_key: None,
        image_nonce: None,
        self_update_state: SelfUpdateState::Required,
    };
    storage.save_group(new_group).unwrap();

    // The insert is visible until the snapshot is restored.
    assert!(
        storage
            .find_group_by_mls_group_id(&mls_group_id)
            .unwrap()
            .is_some()
    );

    storage.restore_snapshot(snapshot);

    // The group inserted after the snapshot must be gone.
    assert!(
        storage
            .find_group_by_mls_group_id(&mls_group_id)
            .unwrap()
            .is_none()
    );
}
2607
#[test]
/// Applies two successive overwrites to a group after a snapshot and asserts
/// that a single restore returns the group to its pre-snapshot name and epoch.
fn test_snapshot_with_multiple_modifications_rollback() {
    let storage = MdkMemoryStorage::default();

    let mls_group_id = GroupId::from_slice(&[17, 18, 19, 20]);
    let nostr_group_id = generate_random_bytes(32).try_into().unwrap();
    let baseline = Group {
        mls_group_id: mls_group_id.clone(),
        nostr_group_id,
        name: String::from("Original Name"),
        description: String::from("A group for testing modification rollback"),
        admin_pubkeys: BTreeSet::new(),
        last_message_id: None,
        last_message_at: None,
        last_message_processed_at: None,
        epoch: 1,
        state: GroupState::Active,
        image_hash: None,
        image_key: None,
        image_nonce: None,
        self_update_state: SelfUpdateState::Required,
    };
    storage.save_group(baseline.clone()).unwrap();

    // Sanity check: the baseline is what the store currently holds.
    let stored = storage
        .find_group_by_mls_group_id(&mls_group_id)
        .unwrap()
        .unwrap();
    assert_eq!(stored.name, "Original Name");
    assert_eq!(stored.epoch, 1);

    let snapshot = storage.create_snapshot();

    // Overwrite the group twice, each time starting from the baseline fields.
    for (name, epoch) in [("Modified Once", 10), ("Modified Twice", 20)] {
        let revised = Group {
            name: name.to_string(),
            epoch,
            ..baseline.clone()
        };
        storage.save_group(revised).unwrap();
    }

    // The store reflects the most recent overwrite.
    let latest = storage
        .find_group_by_mls_group_id(&mls_group_id)
        .unwrap()
        .unwrap();
    assert_eq!(latest.name, "Modified Twice");
    assert_eq!(latest.epoch, 20);

    storage.restore_snapshot(snapshot);

    // Both overwrites are undone by the restore.
    let rolled_back = storage
        .find_group_by_mls_group_id(&mls_group_id)
        .unwrap()
        .unwrap();
    assert_eq!(rolled_back.name, "Original Name");
    assert_eq!(rolled_back.epoch, 1);
}
2678
#[test]
/// Replaces a group's relay set after a snapshot and asserts that restoring
/// the snapshot brings the relay count back from two to one.
fn test_snapshot_with_relays_rollback() {
    let storage = MdkMemoryStorage::default();

    let mls_group_id = GroupId::from_slice(&[21, 22, 23, 24]);
    let nostr_group_id = generate_random_bytes(32).try_into().unwrap();
    let group = Group {
        mls_group_id: mls_group_id.clone(),
        nostr_group_id,
        name: "Relay Test Group".to_string(),
        description: "A group for testing relay rollback".to_string(),
        admin_pubkeys: BTreeSet::new(),
        last_message_id: None,
        last_message_at: None,
        last_message_processed_at: None,
        epoch: 0,
        state: GroupState::Active,
        image_hash: None,
        image_key: None,
        image_nonce: None,
        self_update_state: SelfUpdateState::Required,
    };
    storage.save_group(group).unwrap();

    // Start with a single relay; cloned because `relay1` is reused below.
    let relay1 = RelayUrl::parse("wss://relay1.example.com").unwrap();
    storage
        .replace_group_relays(&mls_group_id, BTreeSet::from([relay1.clone()]))
        .unwrap();

    let snapshot = storage.create_snapshot();

    // Replace the set with two relays. This is the last use of `relay1`,
    // so it is moved into the set instead of being cloned again.
    let relay2 = RelayUrl::parse("wss://relay2.example.com").unwrap();
    storage
        .replace_group_relays(&mls_group_id, BTreeSet::from([relay1, relay2]))
        .unwrap();

    let relays_before = storage.group_relays(&mls_group_id).unwrap();
    assert_eq!(relays_before.len(), 2);

    storage.restore_snapshot(snapshot);

    // Only the relay that existed at snapshot time should remain.
    let relays_after = storage.group_relays(&mls_group_id).unwrap();
    assert_eq!(relays_after.len(), 1);
}
2730
#[test]
/// Saves an epoch-0 exporter secret, snapshots, saves an epoch-1 secret, and
/// asserts that restoring the snapshot drops epoch 1 while keeping epoch 0.
fn test_snapshot_with_exporter_secrets_rollback() {
    let storage = MdkMemoryStorage::default();

    let mls_group_id = GroupId::from_slice(&[25, 26, 27, 28]);
    let nostr_group_id = generate_random_bytes(32).try_into().unwrap();
    let group = Group {
        mls_group_id: mls_group_id.clone(),
        nostr_group_id,
        name: "Secret Test Group".to_string(),
        description: "A group for testing secret rollback".to_string(),
        admin_pubkeys: BTreeSet::new(),
        last_message_id: None,
        last_message_at: None,
        last_message_processed_at: None,
        epoch: 0,
        state: GroupState::Active,
        image_hash: None,
        image_key: None,
        image_nonce: None,
        self_update_state: SelfUpdateState::Required,
    };
    storage.save_group(group).unwrap();

    // Epoch-0 secret: stored before the snapshot. It is moved into the store;
    // cloning it would only create an extra copy of the secret bytes.
    let secret_0 = GroupExporterSecret {
        mls_group_id: mls_group_id.clone(),
        epoch: 0,
        secret: Secret::new([1u8; 32]),
    };
    storage.save_group_exporter_secret(secret_0).unwrap();

    let snapshot = storage.create_snapshot();

    // Epoch-1 secret: stored after the snapshot, likewise moved.
    let secret_1 = GroupExporterSecret {
        mls_group_id: mls_group_id.clone(),
        epoch: 1,
        secret: Secret::new([2u8; 32]),
    };
    storage.save_group_exporter_secret(secret_1).unwrap();

    let found_1 = storage.get_group_exporter_secret(&mls_group_id, 1).unwrap();
    assert!(found_1.is_some());

    storage.restore_snapshot(snapshot);

    // The secret added after the snapshot must be gone...
    let after_rollback = storage.get_group_exporter_secret(&mls_group_id, 1).unwrap();
    assert!(after_rollback.is_none());

    // ...while the one captured by the snapshot must still be available.
    let epoch_0_exists = storage.get_group_exporter_secret(&mls_group_id, 0).unwrap();
    assert!(epoch_0_exists.is_some());
}
2794
#[test]
/// Snapshots an empty store, saves a welcome, and asserts that restoring the
/// snapshot makes the welcome unreachable by its event id.
fn test_snapshot_with_welcomes_rollback() {
    let storage = MdkMemoryStorage::default();

    // Snapshot is taken before any welcome exists.
    let snapshot = storage.create_snapshot();

    let welcome_event_id = EventId::all_zeros();
    let wrapper_event_id = EventId::all_zeros();
    let sender =
        PublicKey::from_hex("aabbccddeeffaabbccddeeffaabbccddeeffaabbccddeeffaabbccddeeffaabb")
            .unwrap();
    let nostr_group_id = generate_random_bytes(32).try_into().unwrap();
    let mls_group_id = GroupId::from_slice(&[29, 30, 31, 32]);

    let rumor = UnsignedEvent::new(
        sender,
        Timestamp::now(),
        Kind::MlsWelcome,
        Tags::new(),
        "test".to_string(),
    );
    let relay = RelayUrl::parse("wss://relay.example.com").unwrap();

    storage
        .save_welcome(Welcome {
            id: welcome_event_id,
            event: rumor,
            mls_group_id: mls_group_id.clone(),
            nostr_group_id,
            group_name: String::from("Welcome Test Group"),
            group_description: String::from("A test welcome group"),
            group_image_key: None,
            group_image_hash: None,
            group_image_nonce: None,
            group_admin_pubkeys: BTreeSet::from([sender]),
            group_relays: BTreeSet::from([relay]),
            welcomer: sender,
            member_count: 2,
            state: WelcomeState::Pending,
            wrapper_event_id,
        })
        .unwrap();

    // The welcome can be looked up while it is still stored.
    assert!(
        storage
            .find_welcome_by_event_id(&welcome_event_id)
            .unwrap()
            .is_some()
    );

    storage.restore_snapshot(snapshot);

    // Restoring the pre-welcome snapshot removes it again.
    assert!(
        storage
            .find_welcome_by_event_id(&welcome_event_id)
            .unwrap()
            .is_none()
    );
}
2847
#[test]
/// Performs three changes after a snapshot (adds a second group, adds a
/// message to the first group, overwrites the first group) and asserts that
/// one restore undoes all of them.
fn test_snapshot_multiple_operations_rollback() {
    let storage = MdkMemoryStorage::default();

    // Baseline: a single group exists when the snapshot is taken.
    let first_id = GroupId::from_slice(&[33, 34, 35, 36]);
    let first_nostr_id = generate_random_bytes(32).try_into().unwrap();
    storage
        .save_group(Group {
            mls_group_id: first_id.clone(),
            nostr_group_id: first_nostr_id,
            name: String::from("Group 1"),
            description: String::from("First group"),
            admin_pubkeys: BTreeSet::new(),
            last_message_id: None,
            last_message_at: None,
            last_message_processed_at: None,
            epoch: 0,
            state: GroupState::Active,
            image_hash: None,
            image_key: None,
            image_nonce: None,
            self_update_state: SelfUpdateState::Required,
        })
        .unwrap();

    let snapshot = storage.create_snapshot();

    // Change 1: add a second group.
    let second_id = GroupId::from_slice(&[37, 38, 39, 40]);
    let second_nostr_id = generate_random_bytes(32).try_into().unwrap();
    storage
        .save_group(Group {
            mls_group_id: second_id.clone(),
            nostr_group_id: second_nostr_id,
            name: String::from("Group 2"),
            description: String::from("Second group"),
            admin_pubkeys: BTreeSet::new(),
            last_message_id: None,
            last_message_at: None,
            last_message_processed_at: None,
            epoch: 0,
            state: GroupState::Active,
            image_hash: None,
            image_key: None,
            image_nonce: None,
            self_update_state: SelfUpdateState::Required,
        })
        .unwrap();

    // Change 2: add a message to the first group.
    let author =
        PublicKey::from_hex("aabbccddeeffaabbccddeeffaabbccddeeffaabbccddeeffaabbccddeeffaabb")
            .unwrap();
    let message_id =
        EventId::from_hex("0000000000000000000000000000000000000000000000000000000000000099")
            .unwrap();
    let now = Timestamp::now();
    let rumor = UnsignedEvent::new(
        author,
        now,
        Kind::MlsGroupMessage,
        Tags::new(),
        "Test message".to_string(),
    );
    storage
        .save_message(Message {
            id: message_id,
            pubkey: author,
            kind: Kind::MlsGroupMessage,
            mls_group_id: first_id.clone(),
            created_at: now,
            processed_at: now,
            content: String::from("Test message"),
            tags: Tags::new(),
            event: rumor,
            wrapper_event_id: EventId::all_zeros(),
            state: MessageState::Created,
            epoch: None,
        })
        .unwrap();

    // Change 3: overwrite the first group with a new name and epoch.
    storage
        .save_group(Group {
            mls_group_id: first_id.clone(),
            nostr_group_id: first_nostr_id,
            name: String::from("Modified Group 1"),
            description: String::from("First group"),
            admin_pubkeys: BTreeSet::new(),
            last_message_id: None,
            last_message_at: None,
            last_message_processed_at: None,
            epoch: 5,
            state: GroupState::Active,
            image_hash: None,
            image_key: None,
            image_nonce: None,
            self_update_state: SelfUpdateState::Required,
        })
        .unwrap();

    // All three changes are visible before the restore.
    assert_eq!(storage.all_groups().unwrap().len(), 2);
    assert_eq!(storage.messages(&first_id, None).unwrap().len(), 1);
    let modified = storage
        .find_group_by_mls_group_id(&first_id)
        .unwrap()
        .unwrap();
    assert_eq!(modified.name, "Modified Group 1");
    assert_eq!(modified.epoch, 5);

    storage.restore_snapshot(snapshot);

    // The second group is gone, the message is gone, and the first group is
    // back to its original name and epoch.
    assert_eq!(storage.all_groups().unwrap().len(), 1);
    assert!(
        storage
            .find_group_by_mls_group_id(&second_id)
            .unwrap()
            .is_none()
    );
    assert_eq!(storage.messages(&first_id, None).unwrap().len(), 0);
    let restored = storage
        .find_group_by_mls_group_id(&first_id)
        .unwrap()
        .unwrap();
    assert_eq!(restored.name, "Group 1");
    assert_eq!(restored.epoch, 0);
}
2976
#[test]
/// Takes two snapshots at different states (A and B) and asserts that each
/// can be restored in any order, and that restoring one does not alter the
/// other.
fn test_snapshot_preserves_snapshot_independence() {
    let storage = MdkMemoryStorage::default();

    let mls_group_id = GroupId::from_slice(&[41, 42, 43, 44]);
    let nostr_group_id = generate_random_bytes(32).try_into().unwrap();
    let group = Group {
        mls_group_id: mls_group_id.clone(),
        nostr_group_id,
        name: "State A".to_string(),
        description: "Initial state".to_string(),
        admin_pubkeys: BTreeSet::new(),
        last_message_id: None,
        last_message_at: None,
        last_message_processed_at: None,
        epoch: 0,
        state: GroupState::Active,
        image_hash: None,
        image_key: None,
        image_nonce: None,
        self_update_state: SelfUpdateState::Required,
    };
    // Cloned because `group` is the template for states B and C below.
    storage.save_group(group.clone()).unwrap();

    let snapshot_a = storage.create_snapshot();

    // Move the store to state B and snapshot it.
    let group_b = Group {
        name: "State B".to_string(),
        epoch: 1,
        ..group.clone()
    };
    // `group_b` is not used again, so it is moved rather than cloned.
    storage.save_group(group_b).unwrap();

    let snapshot_b = storage.create_snapshot();

    // Move the store to state C, which is never snapshotted.
    let group_c = Group {
        name: "State C".to_string(),
        epoch: 2,
        ..group.clone()
    };
    storage.save_group(group_c).unwrap();

    let current = storage
        .find_group_by_mls_group_id(&mls_group_id)
        .unwrap()
        .unwrap();
    assert_eq!(current.name, "State C");

    // Restore A; the clone keeps `snapshot_a` available for the final restore.
    storage.restore_snapshot(snapshot_a.clone());
    let after_a = storage
        .find_group_by_mls_group_id(&mls_group_id)
        .unwrap()
        .unwrap();
    assert_eq!(after_a.name, "State A");

    // Restoring B after A must still yield state B.
    storage.restore_snapshot(snapshot_b);
    let after_b = storage
        .find_group_by_mls_group_id(&mls_group_id)
        .unwrap()
        .unwrap();
    assert_eq!(after_b.name, "State B");

    // Restoring A once more must yield state A again.
    storage.restore_snapshot(snapshot_a);
    let final_state = storage
        .find_group_by_mls_group_id(&mls_group_id)
        .unwrap()
        .unwrap();
    assert_eq!(final_state.name, "State A");
}
3055
#[test]
/// Creates a named snapshot for group 1 only, modifies both groups, rolls
/// group 1 back, and asserts that group 2 keeps its modifications.
fn test_snapshot_isolation_between_groups() {
    let storage = MdkMemoryStorage::default();

    let group1_id = GroupId::from_slice(&[1; 32]);
    let group2_id = GroupId::from_slice(&[2; 32]);
    let nostr_group_id_1: [u8; 32] = generate_random_bytes(32).try_into().unwrap();
    let nostr_group_id_2: [u8; 32] = generate_random_bytes(32).try_into().unwrap();

    let original1 = Group {
        mls_group_id: group1_id.clone(),
        nostr_group_id: nostr_group_id_1,
        name: String::from("Group 1 Original"),
        description: String::from("First group"),
        admin_pubkeys: BTreeSet::new(),
        last_message_id: None,
        last_message_at: None,
        last_message_processed_at: None,
        epoch: 5,
        state: GroupState::Active,
        image_hash: None,
        image_key: None,
        image_nonce: None,
        self_update_state: SelfUpdateState::Required,
    };
    let original2 = Group {
        mls_group_id: group2_id.clone(),
        nostr_group_id: nostr_group_id_2,
        name: String::from("Group 2 Original"),
        description: String::from("Second group"),
        admin_pubkeys: BTreeSet::new(),
        last_message_id: None,
        last_message_at: None,
        last_message_processed_at: None,
        epoch: 10,
        state: GroupState::Active,
        image_hash: None,
        image_key: None,
        image_nonce: None,
        self_update_state: SelfUpdateState::Required,
    };

    storage.save_group(original1.clone()).unwrap();
    storage.save_group(original2.clone()).unwrap();

    // Only group 1 gets a snapshot.
    storage
        .create_group_snapshot(&group1_id, "group1_snap")
        .unwrap();

    // Modify both groups after the snapshot.
    let changed1 = Group {
        name: String::from("Group 1 Modified"),
        epoch: 6,
        ..original1.clone()
    };
    let changed2 = Group {
        name: String::from("Group 2 Modified"),
        epoch: 11,
        ..original2.clone()
    };
    storage.save_group(changed1).unwrap();
    storage.save_group(changed2).unwrap();

    // Both modifications are visible before the rollback.
    let seen1 = storage
        .find_group_by_mls_group_id(&group1_id)
        .unwrap()
        .unwrap();
    let seen2 = storage
        .find_group_by_mls_group_id(&group2_id)
        .unwrap()
        .unwrap();
    assert_eq!(seen1.name, "Group 1 Modified");
    assert_eq!(seen1.epoch, 6);
    assert_eq!(seen2.name, "Group 2 Modified");
    assert_eq!(seen2.epoch, 11);

    // Roll back group 1 only.
    storage
        .rollback_group_to_snapshot(&group1_id, "group1_snap")
        .unwrap();

    // Group 1 is back to its original values.
    let final1 = storage
        .find_group_by_mls_group_id(&group1_id)
        .unwrap()
        .unwrap();
    assert_eq!(final1.name, "Group 1 Original");
    assert_eq!(final1.epoch, 5);

    // Group 2 still carries its post-snapshot modifications.
    let final2 = storage
        .find_group_by_mls_group_id(&group2_id)
        .unwrap()
        .unwrap();
    assert_eq!(
        final2.name, "Group 2 Modified",
        "Group 2 should NOT be affected by Group 1's rollback"
    );
    assert_eq!(
        final2.epoch, 11,
        "Group 2's epoch should NOT be affected by Group 1's rollback"
    );
}
3168
#[test]
/// Snapshots group 1, adds an epoch-1 exporter secret to both groups, rolls
/// group 1 back, and asserts that only group 1 loses its epoch-1 secret.
fn test_snapshot_isolation_with_exporter_secrets() {
    let storage = MdkMemoryStorage::default();

    let group1_id = GroupId::from_slice(&[11; 32]);
    let group2_id = GroupId::from_slice(&[22; 32]);
    let nostr_group_id_1: [u8; 32] = generate_random_bytes(32).try_into().unwrap();
    let nostr_group_id_2: [u8; 32] = generate_random_bytes(32).try_into().unwrap();

    let group1 = Group {
        mls_group_id: group1_id.clone(),
        nostr_group_id: nostr_group_id_1,
        name: "Group 1".to_string(),
        description: "".to_string(),
        admin_pubkeys: BTreeSet::new(),
        last_message_id: None,
        last_message_at: None,
        last_message_processed_at: None,
        epoch: 1,
        state: GroupState::Active,
        image_hash: None,
        image_key: None,
        image_nonce: None,
        self_update_state: SelfUpdateState::Required,
    };

    let group2 = Group {
        mls_group_id: group2_id.clone(),
        nostr_group_id: nostr_group_id_2,
        name: "Group 2".to_string(),
        description: "".to_string(),
        admin_pubkeys: BTreeSet::new(),
        last_message_id: None,
        last_message_at: None,
        last_message_processed_at: None,
        epoch: 1,
        state: GroupState::Active,
        image_hash: None,
        image_key: None,
        image_nonce: None,
        self_update_state: SelfUpdateState::Required,
    };

    storage.save_group(group1).unwrap();
    storage.save_group(group2).unwrap();

    // Epoch-0 secrets for both groups, stored before the snapshot.
    let secret1_epoch0 = GroupExporterSecret {
        mls_group_id: group1_id.clone(),
        epoch: 0,
        secret: Secret::new([1u8; 32]),
    };
    let secret2_epoch0 = GroupExporterSecret {
        mls_group_id: group2_id.clone(),
        epoch: 0,
        secret: Secret::new([2u8; 32]),
    };

    // Neither secret is used again, so both are moved instead of cloned.
    storage.save_group_exporter_secret(secret1_epoch0).unwrap();
    storage.save_group_exporter_secret(secret2_epoch0).unwrap();

    // Only group 1 gets a snapshot.
    storage
        .create_group_snapshot(&group1_id, "group1_secrets_snap")
        .unwrap();

    // Epoch-1 secrets for both groups, stored after the snapshot.
    let secret1_epoch1 = GroupExporterSecret {
        mls_group_id: group1_id.clone(),
        epoch: 1,
        secret: Secret::new([11u8; 32]),
    };
    let secret2_epoch1 = GroupExporterSecret {
        mls_group_id: group2_id.clone(),
        epoch: 1,
        secret: Secret::new([22u8; 32]),
    };

    storage.save_group_exporter_secret(secret1_epoch1).unwrap();
    storage.save_group_exporter_secret(secret2_epoch1).unwrap();

    // Both epoch-1 secrets are present before the rollback.
    assert!(
        storage
            .get_group_exporter_secret(&group1_id, 1)
            .unwrap()
            .is_some()
    );
    assert!(
        storage
            .get_group_exporter_secret(&group2_id, 1)
            .unwrap()
            .is_some()
    );

    // Roll back group 1 only.
    storage
        .rollback_group_to_snapshot(&group1_id, "group1_secrets_snap")
        .unwrap();

    assert!(
        storage
            .get_group_exporter_secret(&group1_id, 1)
            .unwrap()
            .is_none(),
        "Group 1's epoch 1 secret should be rolled back"
    );

    assert!(
        storage
            .get_group_exporter_secret(&group2_id, 1)
            .unwrap()
            .is_some(),
        "Group 2's epoch 1 secret should NOT be affected by Group 1's rollback"
    );
}
3293
#[test]
/// Same isolation check as the exporter-secret variant, but exercised through
/// the `*_mip04_exporter_secret` save/get methods: rolling group 1 back must
/// remove only group 1's epoch-1 MIP-04 secret.
fn test_snapshot_isolation_with_mip04_exporter_secrets() {
    let storage = MdkMemoryStorage::default();

    let group1_id = GroupId::from_slice(&[33; 32]);
    let group2_id = GroupId::from_slice(&[44; 32]);
    let nostr_group_id_1: [u8; 32] = generate_random_bytes(32).try_into().unwrap();
    let nostr_group_id_2: [u8; 32] = generate_random_bytes(32).try_into().unwrap();

    let group1 = Group {
        mls_group_id: group1_id.clone(),
        nostr_group_id: nostr_group_id_1,
        name: "Group 1".to_string(),
        description: "".to_string(),
        admin_pubkeys: BTreeSet::new(),
        last_message_id: None,
        last_message_at: None,
        last_message_processed_at: None,
        epoch: 1,
        state: GroupState::Active,
        image_hash: None,
        image_key: None,
        image_nonce: None,
        self_update_state: SelfUpdateState::Required,
    };

    let group2 = Group {
        mls_group_id: group2_id.clone(),
        nostr_group_id: nostr_group_id_2,
        name: "Group 2".to_string(),
        description: "".to_string(),
        admin_pubkeys: BTreeSet::new(),
        last_message_id: None,
        last_message_at: None,
        last_message_processed_at: None,
        epoch: 1,
        state: GroupState::Active,
        image_hash: None,
        image_key: None,
        image_nonce: None,
        self_update_state: SelfUpdateState::Required,
    };

    storage.save_group(group1).unwrap();
    storage.save_group(group2).unwrap();

    // Epoch-0 MIP-04 secrets for both groups, stored before the snapshot.
    let secret1_epoch0 = GroupExporterSecret {
        mls_group_id: group1_id.clone(),
        epoch: 0,
        secret: Secret::new([3u8; 32]),
    };
    let secret2_epoch0 = GroupExporterSecret {
        mls_group_id: group2_id.clone(),
        epoch: 0,
        secret: Secret::new([4u8; 32]),
    };

    // Neither secret is used again, so both are moved instead of cloned.
    storage
        .save_group_mip04_exporter_secret(secret1_epoch0)
        .unwrap();
    storage
        .save_group_mip04_exporter_secret(secret2_epoch0)
        .unwrap();

    // Only group 1 gets a snapshot.
    storage
        .create_group_snapshot(&group1_id, "group1_mip04_secrets_snap")
        .unwrap();

    // Epoch-1 MIP-04 secrets for both groups, stored after the snapshot.
    let secret1_epoch1 = GroupExporterSecret {
        mls_group_id: group1_id.clone(),
        epoch: 1,
        secret: Secret::new([33u8; 32]),
    };
    let secret2_epoch1 = GroupExporterSecret {
        mls_group_id: group2_id.clone(),
        epoch: 1,
        secret: Secret::new([44u8; 32]),
    };

    storage
        .save_group_mip04_exporter_secret(secret1_epoch1)
        .unwrap();
    storage
        .save_group_mip04_exporter_secret(secret2_epoch1)
        .unwrap();

    // Both epoch-1 secrets are present before the rollback.
    assert!(
        storage
            .get_group_mip04_exporter_secret(&group1_id, 1)
            .unwrap()
            .is_some()
    );
    assert!(
        storage
            .get_group_mip04_exporter_secret(&group2_id, 1)
            .unwrap()
            .is_some()
    );

    // Roll back group 1 only.
    storage
        .rollback_group_to_snapshot(&group1_id, "group1_mip04_secrets_snap")
        .unwrap();

    assert!(
        storage
            .get_group_mip04_exporter_secret(&group1_id, 1)
            .unwrap()
            .is_none(),
        "Group 1's MIP-04 epoch 1 secret should be rolled back"
    );

    assert!(
        storage
            .get_group_mip04_exporter_secret(&group2_id, 1)
            .unwrap()
            .is_some(),
        "Group 2's MIP-04 epoch 1 secret should NOT be affected by Group 1's rollback"
    );
}
3422
#[test]
/// Asserts that rolling back to a snapshot name that was never created yields
/// `MdkStorageError::NotFound` with a message containing "Snapshot not found".
fn test_rollback_nonexistent_snapshot_returns_error() {
    let storage = MdkMemoryStorage::default();

    let nostr_group_id: [u8; 32] = generate_random_bytes(32).try_into().unwrap();
    let group_id = GroupId::from_slice(&[99; 32]);

    // The group itself exists; only the requested snapshot is missing.
    storage
        .save_group(Group {
            mls_group_id: group_id.clone(),
            nostr_group_id,
            name: String::from("Test Group"),
            description: String::from(""),
            admin_pubkeys: BTreeSet::new(),
            last_message_id: None,
            last_message_at: None,
            last_message_processed_at: None,
            epoch: 1,
            state: GroupState::Active,
            image_hash: None,
            image_key: None,
            image_nonce: None,
            self_update_state: SelfUpdateState::Required,
        })
        .unwrap();

    let outcome = storage.rollback_group_to_snapshot(&group_id, "nonexistent_snapshot");

    assert!(
        outcome.is_err(),
        "Should return error for nonexistent snapshot"
    );

    // The error must be the NotFound variant and mention the missing snapshot.
    if let Err(MdkStorageError::NotFound(msg)) = outcome {
        assert!(
            msg.contains("Snapshot not found"),
            "Error should indicate snapshot not found"
        );
    } else {
        panic!("Expected NotFound error");
    }
}
3467
#[test]
/// Asserts that listing snapshots for a group on a fresh store returns an
/// empty collection.
fn test_list_group_snapshots_empty() {
    let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
    let storage = MdkMemoryStorage::default();

    // Nothing has been snapshotted, so the listing must be empty.
    let listed = storage.list_group_snapshots(&group_id).unwrap();
    assert!(
        listed.is_empty(),
        "Should return empty list for no snapshots"
    );
}
3479
#[test]
/// Inserts three snapshots with out-of-order `created_at` values directly into
/// `group_snapshots` and asserts that the listing comes back ordered by
/// ascending `created_at`.
fn test_list_group_snapshots_returns_snapshots_sorted_by_created_at() {
    use crate::snapshot::GroupScopedSnapshot;

    let storage = MdkMemoryStorage::default();
    let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
    let nostr_group_id: [u8; 32] = generate_random_bytes(32).try_into().unwrap();

    storage
        .save_group(Group {
            mls_group_id: group_id.clone(),
            nostr_group_id,
            name: String::from("Test Group"),
            description: String::from(""),
            admin_pubkeys: BTreeSet::new(),
            last_message_id: None,
            last_message_at: None,
            last_message_processed_at: None,
            epoch: 1,
            state: GroupState::Active,
            image_hash: None,
            image_key: None,
            image_nonce: None,
            self_update_state: SelfUpdateState::Required,
        })
        .unwrap();

    // Scope the write guard so it is released before the listing call.
    {
        let mut table = storage.group_snapshots.write();

        // Empty snapshot used both as the oldest entry and as a template.
        let oldest = GroupScopedSnapshot {
            group_id: group_id.clone(),
            created_at: 1000,
            mls_group_data: std::collections::HashMap::new(),
            mls_own_leaf_nodes: Vec::new(),
            mls_proposals: std::collections::HashMap::new(),
            mls_epoch_key_pairs: std::collections::HashMap::new(),
            group: None,
            group_relays: std::collections::BTreeSet::new(),
            group_exporter_secrets: std::collections::HashMap::new(),
            group_mip04_exporter_secrets: std::collections::HashMap::new(),
        };
        let newest = GroupScopedSnapshot {
            group_id: group_id.clone(),
            created_at: 3000,
            ..oldest.clone()
        };
        let middle = GroupScopedSnapshot {
            group_id: group_id.clone(),
            created_at: 2000,
            ..oldest.clone()
        };

        // Inserted as oldest, newest, middle — not in timestamp order.
        table.insert((group_id.clone(), String::from("snap_oldest")), oldest);
        table.insert((group_id.clone(), String::from("snap_newest")), newest);
        table.insert((group_id.clone(), String::from("snap_middle")), middle);
    }

    let listing = storage.list_group_snapshots(&group_id).unwrap();

    // Entries are (name, created_at) pairs ordered by ascending created_at.
    assert_eq!(listing.len(), 3);
    assert_eq!(listing[0].0, "snap_oldest");
    assert_eq!(listing[0].1, 1000);
    assert_eq!(listing[1].0, "snap_middle");
    assert_eq!(listing[1].1, 2000);
    assert_eq!(listing[2].0, "snap_newest");
    assert_eq!(listing[2].1, 3000);
}
3550
#[test]
/// Inserts one snapshot for each of two groups and asserts that listing by
/// group id returns only that group's snapshot.
fn test_list_group_snapshots_only_returns_matching_group() {
    use crate::snapshot::GroupScopedSnapshot;

    let storage = MdkMemoryStorage::default();
    let group1 = GroupId::from_slice(&[1, 1, 1, 1]);
    let group2 = GroupId::from_slice(&[2, 2, 2, 2]);

    // Scope the write guard so it is released before the listing calls.
    {
        let mut table = storage.group_snapshots.write();

        let for_group1 = GroupScopedSnapshot {
            group_id: group1.clone(),
            created_at: 1000,
            mls_group_data: std::collections::HashMap::new(),
            mls_own_leaf_nodes: Vec::new(),
            mls_proposals: std::collections::HashMap::new(),
            mls_epoch_key_pairs: std::collections::HashMap::new(),
            group: None,
            group_relays: std::collections::BTreeSet::new(),
            group_exporter_secrets: std::collections::HashMap::new(),
            group_mip04_exporter_secrets: std::collections::HashMap::new(),
        };
        // Same empty payload, but owned by the second group.
        let for_group2 = GroupScopedSnapshot {
            group_id: group2.clone(),
            created_at: 2000,
            ..for_group1.clone()
        };

        table.insert((group1.clone(), String::from("snap_group1")), for_group1);
        table.insert((group2.clone(), String::from("snap_group2")), for_group2);
    }

    let listed1 = storage.list_group_snapshots(&group1).unwrap();
    let listed2 = storage.list_group_snapshots(&group2).unwrap();

    // Each listing contains exactly the snapshot keyed under that group.
    assert_eq!(listed1.len(), 1);
    assert_eq!(listed1[0].0, "snap_group1");

    assert_eq!(listed2.len(), 1);
    assert_eq!(listed2[0].0, "snap_group2");
}
3591
3592 #[test]
3593 fn test_prune_expired_snapshots_removes_old_snapshots() {
3594 let storage = MdkMemoryStorage::default();
3595 let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
3596
3597 {
3598 let mut snapshots = storage.group_snapshots.write();
3599
3600 let base_snap = crate::snapshot::GroupScopedSnapshot {
3601 group_id: group_id.clone(),
3602 created_at: 0,
3603 mls_group_data: std::collections::HashMap::new(),
3604 mls_own_leaf_nodes: vec![],
3605 mls_proposals: std::collections::HashMap::new(),
3606 mls_epoch_key_pairs: std::collections::HashMap::new(),
3607 group: None,
3608 group_relays: std::collections::BTreeSet::new(),
3609 group_exporter_secrets: std::collections::HashMap::new(),
3610 group_mip04_exporter_secrets: std::collections::HashMap::new(),
3611 };
3612
3613 let old_snap = crate::snapshot::GroupScopedSnapshot {
3615 created_at: 1000,
3616 ..base_snap.clone()
3617 };
3618 let new_snap = crate::snapshot::GroupScopedSnapshot {
3620 created_at: 5000,
3621 ..base_snap.clone()
3622 };
3623
3624 snapshots.insert((group_id.clone(), "old_snap".to_string()), old_snap);
3625 snapshots.insert((group_id.clone(), "new_snap".to_string()), new_snap);
3626 }
3627
3628 let pruned = storage.prune_expired_snapshots(3000).unwrap();
3630
3631 assert_eq!(pruned, 1, "Should have pruned 1 snapshot");
3632
3633 let remaining = storage.list_group_snapshots(&group_id).unwrap();
3634 assert_eq!(remaining.len(), 1);
3635 assert_eq!(remaining[0].0, "new_snap");
3636 assert_eq!(remaining[0].1, 5000);
3637 }
3638
3639 #[test]
3640 fn test_prune_expired_snapshots_returns_zero_when_nothing_to_prune() {
3641 let storage = MdkMemoryStorage::default();
3642 let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
3643
3644 {
3645 let mut snapshots = storage.group_snapshots.write();
3646
3647 let snap = crate::snapshot::GroupScopedSnapshot {
3648 group_id: group_id.clone(),
3649 created_at: 5000, mls_group_data: std::collections::HashMap::new(),
3651 mls_own_leaf_nodes: vec![],
3652 mls_proposals: std::collections::HashMap::new(),
3653 mls_epoch_key_pairs: std::collections::HashMap::new(),
3654 group: None,
3655 group_relays: std::collections::BTreeSet::new(),
3656 group_exporter_secrets: std::collections::HashMap::new(),
3657 group_mip04_exporter_secrets: std::collections::HashMap::new(),
3658 };
3659
3660 snapshots.insert((group_id.clone(), "recent_snap".to_string()), snap);
3661 }
3662
3663 let pruned = storage.prune_expired_snapshots(1000).unwrap();
3665
3666 assert_eq!(pruned, 0, "Should have pruned 0 snapshots");
3667
3668 let remaining = storage.list_group_snapshots(&group_id).unwrap();
3669 assert_eq!(remaining.len(), 1);
3670 }
3671
3672 #[test]
3673 fn test_prune_expired_snapshots_across_multiple_groups() {
3674 let storage = MdkMemoryStorage::default();
3675 let group1 = GroupId::from_slice(&[1, 1, 1, 1]);
3676 let group2 = GroupId::from_slice(&[2, 2, 2, 2]);
3677
3678 {
3679 let mut snapshots = storage.group_snapshots.write();
3680
3681 let base_snap1 = crate::snapshot::GroupScopedSnapshot {
3682 group_id: group1.clone(),
3683 created_at: 1000, mls_group_data: std::collections::HashMap::new(),
3685 mls_own_leaf_nodes: vec![],
3686 mls_proposals: std::collections::HashMap::new(),
3687 mls_epoch_key_pairs: std::collections::HashMap::new(),
3688 group: None,
3689 group_relays: std::collections::BTreeSet::new(),
3690 group_exporter_secrets: std::collections::HashMap::new(),
3691 group_mip04_exporter_secrets: std::collections::HashMap::new(),
3692 };
3693 let base_snap2 = crate::snapshot::GroupScopedSnapshot {
3694 group_id: group2.clone(),
3695 created_at: 2000, ..base_snap1.clone()
3697 };
3698 let new_snap1 = crate::snapshot::GroupScopedSnapshot {
3699 group_id: group1.clone(),
3700 created_at: 5000, ..base_snap1.clone()
3702 };
3703
3704 snapshots.insert((group1.clone(), "old_snap_g1".to_string()), base_snap1);
3705 snapshots.insert((group2.clone(), "old_snap_g2".to_string()), base_snap2);
3706 snapshots.insert((group1.clone(), "new_snap_g1".to_string()), new_snap1);
3707 }
3708
3709 let pruned = storage.prune_expired_snapshots(3000).unwrap();
3711
3712 assert_eq!(pruned, 2, "Should have pruned 2 snapshots across groups");
3713
3714 let remaining1 = storage.list_group_snapshots(&group1).unwrap();
3715 let remaining2 = storage.list_group_snapshots(&group2).unwrap();
3716
3717 assert_eq!(remaining1.len(), 1);
3718 assert_eq!(remaining1[0].0, "new_snap_g1");
3719 assert!(remaining2.is_empty());
3720 }
3721
3722 #[test]
3734 fn test_snapshot_captures_mls_group_data() {
3735 let storage = MdkMemoryStorage::default();
3736
3737 let group_id = GroupId::from_slice(&[1; 32]);
3738 let nostr_group_id: [u8; 32] = generate_random_bytes(32).try_into().unwrap();
3739 let group = Group {
3740 mls_group_id: group_id.clone(),
3741 nostr_group_id,
3742 name: "MLS Data Test".to_string(),
3743 description: "".to_string(),
3744 admin_pubkeys: BTreeSet::new(),
3745 last_message_id: None,
3746 last_message_at: None,
3747 last_message_processed_at: None,
3748 epoch: 5,
3749 state: GroupState::Active,
3750 image_hash: None,
3751 image_key: None,
3752 image_nonce: None,
3753 self_update_state: SelfUpdateState::Required,
3754 };
3755 storage.save_group(group).unwrap();
3756
3757 {
3759 let mut inner = storage.inner.write();
3760 inner
3761 .mls_group_data
3762 .write(
3763 &group_id,
3764 mls_storage::GroupDataType::GroupState,
3765 &"epoch5_state".to_string(),
3766 )
3767 .unwrap();
3768 inner
3769 .mls_group_data
3770 .write(
3771 &group_id,
3772 mls_storage::GroupDataType::Tree,
3773 &"epoch5_tree".to_string(),
3774 )
3775 .unwrap();
3776 }
3777
3778 storage
3780 .create_group_snapshot(&group_id, "snap_mls")
3781 .unwrap();
3782
3783 {
3785 let mut inner = storage.inner.write();
3786 inner
3787 .mls_group_data
3788 .write(
3789 &group_id,
3790 mls_storage::GroupDataType::GroupState,
3791 &"epoch6_state".to_string(),
3792 )
3793 .unwrap();
3794 inner
3795 .mls_group_data
3796 .write(
3797 &group_id,
3798 mls_storage::GroupDataType::Tree,
3799 &"epoch6_tree".to_string(),
3800 )
3801 .unwrap();
3802 }
3803
3804 {
3806 let inner = storage.inner.read();
3807 let state: Option<String> = inner
3808 .mls_group_data
3809 .read(&group_id, mls_storage::GroupDataType::GroupState)
3810 .unwrap();
3811 assert_eq!(state.as_deref(), Some("epoch6_state"));
3812 }
3813
3814 storage
3816 .rollback_group_to_snapshot(&group_id, "snap_mls")
3817 .unwrap();
3818
3819 {
3821 let inner = storage.inner.read();
3822 let state: Option<String> = inner
3823 .mls_group_data
3824 .read(&group_id, mls_storage::GroupDataType::GroupState)
3825 .unwrap();
3826 assert_eq!(
3827 state.as_deref(),
3828 Some("epoch5_state"),
3829 "MLS group_data (GroupState) must be restored to snapshot state"
3830 );
3831
3832 let tree: Option<String> = inner
3833 .mls_group_data
3834 .read(&group_id, mls_storage::GroupDataType::Tree)
3835 .unwrap();
3836 assert_eq!(
3837 tree.as_deref(),
3838 Some("epoch5_tree"),
3839 "MLS group_data (Tree) must be restored to snapshot state"
3840 );
3841 }
3842 }
3843
3844 #[test]
3846 fn test_snapshot_captures_mls_own_leaf_nodes() {
3847 let storage = MdkMemoryStorage::default();
3848
3849 let group_id = GroupId::from_slice(&[2; 32]);
3850 let nostr_group_id: [u8; 32] = generate_random_bytes(32).try_into().unwrap();
3851 let group = Group {
3852 mls_group_id: group_id.clone(),
3853 nostr_group_id,
3854 name: "Leaf Node Test".to_string(),
3855 description: "".to_string(),
3856 admin_pubkeys: BTreeSet::new(),
3857 last_message_id: None,
3858 last_message_at: None,
3859 last_message_processed_at: None,
3860 epoch: 0,
3861 state: GroupState::Active,
3862 image_hash: None,
3863 image_key: None,
3864 image_nonce: None,
3865 self_update_state: SelfUpdateState::Required,
3866 };
3867 storage.save_group(group).unwrap();
3868
3869 {
3871 let mut inner = storage.inner.write();
3872 inner
3873 .mls_own_leaf_nodes
3874 .append(&group_id, &"original_leaf".to_string())
3875 .unwrap();
3876 }
3877
3878 storage
3880 .create_group_snapshot(&group_id, "snap_leaf")
3881 .unwrap();
3882
3883 {
3885 let mut inner = storage.inner.write();
3886 inner
3887 .mls_own_leaf_nodes
3888 .append(&group_id, &"added_after_snapshot".to_string())
3889 .unwrap();
3890 }
3891
3892 {
3894 let inner = storage.inner.read();
3895 let leaves: Vec<String> = inner.mls_own_leaf_nodes.read(&group_id).unwrap();
3896 assert_eq!(leaves.len(), 2);
3897 }
3898
3899 storage
3901 .rollback_group_to_snapshot(&group_id, "snap_leaf")
3902 .unwrap();
3903
3904 {
3906 let inner = storage.inner.read();
3907 let leaves: Vec<String> = inner.mls_own_leaf_nodes.read(&group_id).unwrap();
3908 assert_eq!(
3909 leaves.len(),
3910 1,
3911 "Rollback must restore own_leaf_nodes to snapshot state (1 leaf, not 2)"
3912 );
3913 assert_eq!(leaves[0], "original_leaf");
3914 }
3915 }
3916
3917 #[test]
3919 fn test_snapshot_captures_mls_epoch_key_pairs() {
3920 let storage = MdkMemoryStorage::default();
3921
3922 let group_id = GroupId::from_slice(&[3; 32]);
3923 let nostr_group_id: [u8; 32] = generate_random_bytes(32).try_into().unwrap();
3924 let group = Group {
3925 mls_group_id: group_id.clone(),
3926 nostr_group_id,
3927 name: "Epoch Keys Test".to_string(),
3928 description: "".to_string(),
3929 admin_pubkeys: BTreeSet::new(),
3930 last_message_id: None,
3931 last_message_at: None,
3932 last_message_processed_at: None,
3933 epoch: 5,
3934 state: GroupState::Active,
3935 image_hash: None,
3936 image_key: None,
3937 image_nonce: None,
3938 self_update_state: SelfUpdateState::Required,
3939 };
3940 storage.save_group(group).unwrap();
3941
3942 let epoch_id = 5u64;
3944 let leaf_index = 0u32;
3945 {
3946 let mut inner = storage.inner.write();
3947 inner
3948 .mls_epoch_key_pairs
3949 .write(
3950 &group_id,
3951 &epoch_id,
3952 leaf_index,
3953 &["epoch5_key_pair".to_string()],
3954 )
3955 .unwrap();
3956 }
3957
3958 storage
3960 .create_group_snapshot(&group_id, "snap_keys")
3961 .unwrap();
3962
3963 {
3965 let mut inner = storage.inner.write();
3966 inner
3967 .mls_epoch_key_pairs
3968 .write(
3969 &group_id,
3970 &epoch_id,
3971 leaf_index,
3972 &["epoch6_key_pair".to_string()],
3973 )
3974 .unwrap();
3975 }
3976
3977 storage
3979 .rollback_group_to_snapshot(&group_id, "snap_keys")
3980 .unwrap();
3981
3982 {
3984 let inner = storage.inner.read();
3985 let key_pairs: Vec<String> = inner
3986 .mls_epoch_key_pairs
3987 .read(&group_id, &epoch_id, leaf_index)
3988 .unwrap();
3989 assert_eq!(
3990 key_pairs,
3991 vec!["epoch5_key_pair"],
3992 "Rollback must restore epoch_key_pairs to snapshot state"
3993 );
3994 }
3995 }
3996
3997 #[test]
4004 fn test_rollback_metadata_crypto_consistency() {
4005 let storage = MdkMemoryStorage::default();
4006
4007 let group_id = GroupId::from_slice(&[4; 32]);
4008 let nostr_group_id: [u8; 32] = generate_random_bytes(32).try_into().unwrap();
4009 let group = Group {
4010 mls_group_id: group_id.clone(),
4011 nostr_group_id,
4012 name: "Consistency Test".to_string(),
4013 description: "".to_string(),
4014 admin_pubkeys: BTreeSet::new(),
4015 last_message_id: None,
4016 last_message_at: None,
4017 last_message_processed_at: None,
4018 epoch: 5,
4019 state: GroupState::Active,
4020 image_hash: None,
4021 image_key: None,
4022 image_nonce: None,
4023 self_update_state: SelfUpdateState::Required,
4024 };
4025 storage.save_group(group).unwrap();
4026
4027 {
4029 let mut inner = storage.inner.write();
4030 inner
4031 .mls_group_data
4032 .write(
4033 &group_id,
4034 mls_storage::GroupDataType::GroupState,
4035 &"epoch5_state".to_string(),
4036 )
4037 .unwrap();
4038 }
4039
4040 storage
4042 .create_group_snapshot(&group_id, "snap_epoch5")
4043 .unwrap();
4044
4045 {
4047 let mut g = storage
4048 .find_group_by_mls_group_id(&group_id)
4049 .unwrap()
4050 .unwrap();
4051 g.epoch = 6;
4052 storage.save_group(g).unwrap();
4053 }
4054 {
4055 let mut inner = storage.inner.write();
4056 inner
4057 .mls_group_data
4058 .write(
4059 &group_id,
4060 mls_storage::GroupDataType::GroupState,
4061 &"epoch6_state".to_string(),
4062 )
4063 .unwrap();
4064 }
4065
4066 storage
4068 .rollback_group_to_snapshot(&group_id, "snap_epoch5")
4069 .unwrap();
4070
4071 let group_after = storage
4073 .find_group_by_mls_group_id(&group_id)
4074 .unwrap()
4075 .unwrap();
4076 assert_eq!(group_after.epoch, 5, "MDK epoch should be 5 after rollback");
4077
4078 let crypto_after: Option<String> = {
4080 let inner = storage.inner.read();
4081 inner
4082 .mls_group_data
4083 .read(&group_id, mls_storage::GroupDataType::GroupState)
4084 .unwrap()
4085 };
4086 assert_eq!(
4087 crypto_after.as_deref(),
4088 Some("epoch5_state"),
4089 "MLS crypto state must match MDK metadata epoch after rollback. \
4090 groups.epoch=5 but crypto state is epoch6 means split-brain: \
4091 MDK thinks epoch 5, MLS engine has epoch 6 keys."
4092 );
4093 }
4094}