Skip to main content

mdk_memory_storage/
snapshot.rs

//! Snapshot and rollback support for memory storage.
//!
//! This module provides the ability to create snapshots of all in-memory state
//! and restore them later. This provides functionality analogous to SQLite
//! savepoints for testing and rollback scenarios. It also defines
//! [`GroupScopedSnapshot`], which captures only the state belonging to a single
//! group so that one group can be rolled back without affecting others.
//!
//! # Concurrency
//!
//! Snapshot creation and restoration are **atomic** operations:
//!
//! - `create_snapshot()` acquires a global read lock on the storage state, so no
//!   writer can modify the state while the snapshot is being taken (concurrent
//!   readers are still allowed).
//! - `restore_snapshot()` acquires a global write lock on the storage state,
//!   ensuring the restore is consistent and blocks all other operations.

16use std::collections::{BTreeSet, HashMap};
17
18use lru::LruCache;
19use mdk_storage_traits::GroupId;
20use mdk_storage_traits::groups::types::{Group, GroupExporterSecret, GroupRelay};
21use mdk_storage_traits::messages::types::{Message, ProcessedMessage};
22use mdk_storage_traits::welcomes::types::{ProcessedWelcome, Welcome};
23use nostr::EventId;
24
25use crate::mls_storage::GroupDataType;
26
27/// A group-scoped snapshot that only contains data for a single group.
28///
29/// Unlike [`MemoryStorageSnapshot`] which captures all data in the storage,
30/// this snapshot only captures data relevant to a specific group. This enables
31/// proper rollback isolation where rolling back Group A doesn't affect Group B.
32///
33/// This matches the behavior of SQLite's group-scoped snapshots where:
34/// - `snapshot_group_state()` only copies rows WHERE `group_id = ?`
35/// - `restore_group_from_snapshot()` only deletes/restores rows for that group
36///
37/// # Group-Scoped Data
38///
39/// The following data is captured per group:
40/// - MLS group data (tree state, join config, etc.)
41/// - MLS own leaf nodes for this group
42/// - MLS proposals for this group
43/// - MLS epoch key pairs for this group
44/// - MDK group record
45/// - MDK group relays
46/// - MDK group exporter secrets
47///
48/// The following data is NOT captured (not group-scoped):
49/// - MLS key packages (identity-scoped, not group-scoped)
50/// - MLS PSKs (identity-scoped, not group-scoped)
51/// - MLS signature keys (identity-scoped, not group-scoped)
52/// - MLS encryption keys (identity-scoped, not group-scoped)
53/// - Messages (handled separately via `invalidate_messages_after_epoch`)
54/// - Welcomes (keyed by EventId, not group-scoped)
55#[derive(Clone)]
56pub struct GroupScopedSnapshot {
57    /// The group ID this snapshot is for
58    pub(crate) group_id: GroupId,
59
60    /// Unix timestamp when this snapshot was created
61    pub(crate) created_at: u64,
62
63    // MLS data (filtered by group_id)
64    /// MLS group data: (group_id, data_type) -> data
65    pub(crate) mls_group_data: HashMap<(Vec<u8>, GroupDataType), Vec<u8>>,
66    /// MLS own leaf nodes for this group
67    pub(crate) mls_own_leaf_nodes: Vec<Vec<u8>>,
68    /// MLS proposals: proposal_ref -> proposal (group_id is implicit)
69    pub(crate) mls_proposals: HashMap<Vec<u8>, Vec<u8>>,
70    /// MLS epoch key pairs: (epoch_id, leaf_index) -> key_pairs (group_id is implicit)
71    pub(crate) mls_epoch_key_pairs: HashMap<(Vec<u8>, u32), Vec<u8>>,
72
73    // MDK data
74    /// The group record itself
75    pub(crate) group: Option<Group>,
76    /// Group relays
77    pub(crate) group_relays: BTreeSet<GroupRelay>,
78    /// Group exporter secrets (MIP-03 group-event): epoch -> secret
79    pub(crate) group_exporter_secrets: HashMap<u64, GroupExporterSecret>,
80    /// Group MIP-04 encrypted-media exporter secrets: epoch -> secret
81    pub(crate) group_mip04_exporter_secrets: HashMap<u64, GroupExporterSecret>,
82}
83
84/// A snapshot of all in-memory state that can be restored later.
85///
86/// This enables rollback functionality similar to SQLite savepoints,
87/// allowing you to:
88/// 1. Create a snapshot before an operation
89/// 2. Attempt the operation
90/// 3. Restore the snapshot if the operation fails or needs to be undone
91///
92/// # Concurrency
93///
94/// Snapshot creation and restoration are **atomic**. `create_snapshot()` acquires
95/// a global read lock and `restore_snapshot()` acquires a global write lock,
96/// ensuring consistency in multi-threaded environments.
97///
98/// # Example
99///
100/// ```ignore
101/// let storage = MdkMemoryStorage::default();
102///
103/// // Make some changes
104/// storage.save_group(group)?;
105///
106/// // Create a snapshot (ensure no concurrent operations)
107/// let snapshot = storage.create_snapshot();
108///
109/// // Try an operation that might need rollback
110/// storage.save_message(message)?;
111///
112/// // If we need to undo (ensure no concurrent operations):
113/// storage.restore_snapshot(snapshot);
114/// ```
115#[derive(Clone)]
116pub struct MemoryStorageSnapshot {
117    // MLS data
118    pub(crate) mls_group_data: HashMap<(Vec<u8>, GroupDataType), Vec<u8>>,
119    pub(crate) mls_own_leaf_nodes: HashMap<Vec<u8>, Vec<Vec<u8>>>,
120    pub(crate) mls_proposals: HashMap<(Vec<u8>, Vec<u8>), Vec<u8>>,
121    pub(crate) mls_key_packages: HashMap<Vec<u8>, Vec<u8>>,
122    pub(crate) mls_psks: HashMap<Vec<u8>, Vec<u8>>,
123    pub(crate) mls_signature_keys: HashMap<Vec<u8>, Vec<u8>>,
124    pub(crate) mls_encryption_keys: HashMap<Vec<u8>, Vec<u8>>,
125    pub(crate) mls_epoch_key_pairs: HashMap<(Vec<u8>, Vec<u8>, u32), Vec<u8>>,
126
127    // MDK data - cloned from LRU caches
128    pub(crate) groups: HashMap<GroupId, Group>,
129    pub(crate) groups_by_nostr_id: HashMap<[u8; 32], Group>,
130    pub(crate) group_relays: HashMap<GroupId, BTreeSet<GroupRelay>>,
131    pub(crate) group_exporter_secrets: HashMap<(GroupId, u64), GroupExporterSecret>,
132    pub(crate) group_mip04_exporter_secrets: HashMap<(GroupId, u64), GroupExporterSecret>,
133    pub(crate) welcomes: HashMap<EventId, Welcome>,
134    pub(crate) processed_welcomes: HashMap<EventId, ProcessedWelcome>,
135    pub(crate) messages: HashMap<EventId, Message>,
136    pub(crate) messages_by_group: HashMap<GroupId, HashMap<EventId, Message>>,
137    pub(crate) processed_messages: HashMap<EventId, ProcessedMessage>,
138}
139
#[cfg(test)]
impl MemoryStorageSnapshot {
    /// Build a snapshot in which every captured collection is empty.
    ///
    /// Only compiled for tests; production snapshots are produced from live storage.
    pub(crate) fn new() -> Self {
        Self {
            // MLS state
            mls_group_data: HashMap::default(),
            mls_own_leaf_nodes: HashMap::default(),
            mls_proposals: HashMap::default(),
            mls_key_packages: HashMap::default(),
            mls_psks: HashMap::default(),
            mls_signature_keys: HashMap::default(),
            mls_encryption_keys: HashMap::default(),
            mls_epoch_key_pairs: HashMap::default(),
            // MDK state
            groups: HashMap::default(),
            groups_by_nostr_id: HashMap::default(),
            group_relays: HashMap::default(),
            group_exporter_secrets: HashMap::default(),
            group_mip04_exporter_secrets: HashMap::default(),
            welcomes: HashMap::default(),
            processed_welcomes: HashMap::default(),
            messages: HashMap::default(),
            messages_by_group: HashMap::default(),
            processed_messages: HashMap::default(),
        }
    }
}

167/// Helper trait to clone LRU cache contents into a HashMap.
168pub(crate) trait LruCacheExt<K, V> {
169    /// Clone all entries from the LRU cache into a HashMap.
170    fn clone_to_hashmap(&self) -> HashMap<K, V>
171    where
172        K: Clone + std::hash::Hash + Eq,
173        V: Clone;
174}
175
176impl<K, V> LruCacheExt<K, V> for LruCache<K, V> {
177    fn clone_to_hashmap(&self) -> HashMap<K, V>
178    where
179        K: Clone + std::hash::Hash + Eq,
180        V: Clone,
181    {
182        self.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
183    }
184}
185
186/// Helper trait to restore HashMap contents back into an LRU cache.
187pub(crate) trait HashMapToLruExt<K, V> {
188    /// Restore entries from a HashMap into an LRU cache.
189    fn restore_to_lru(&self, cache: &mut LruCache<K, V>)
190    where
191        K: Clone + std::hash::Hash + Eq,
192        V: Clone;
193}
194
195impl<K, V> HashMapToLruExt<K, V> for HashMap<K, V> {
196    fn restore_to_lru(&self, cache: &mut LruCache<K, V>)
197    where
198        K: Clone + std::hash::Hash + Eq,
199        V: Clone,
200    {
201        cache.clear();
202        for (k, v) in self.iter() {
203            cache.put(k.clone(), v.clone());
204        }
205    }
206}
207
#[cfg(test)]
mod tests {
    use std::num::NonZeroUsize;

    use super::*;

    #[test]
    fn test_lru_cache_clone_to_hashmap() {
        let mut cache: LruCache<String, i32> = LruCache::new(NonZeroUsize::new(10).unwrap());
        cache.put("a".to_string(), 1);
        cache.put("b".to_string(), 2);
        cache.put("c".to_string(), 3);

        let map = cache.clone_to_hashmap();
        assert_eq!(map.len(), 3);
        assert_eq!(map.get("a"), Some(&1));
        assert_eq!(map.get("b"), Some(&2));
        assert_eq!(map.get("c"), Some(&3));
    }

    #[test]
    fn test_lru_cache_clone_to_hashmap_empty() {
        let cache: LruCache<String, i32> = LruCache::new(NonZeroUsize::new(4).unwrap());
        assert!(cache.clone_to_hashmap().is_empty());
    }

    #[test]
    fn test_hashmap_restore_to_lru() {
        let mut map = HashMap::new();
        map.insert("x".to_string(), 10);
        map.insert("y".to_string(), 20);

        let mut cache: LruCache<String, i32> = LruCache::new(NonZeroUsize::new(10).unwrap());
        cache.put("old".to_string(), 999);

        map.restore_to_lru(&mut cache);

        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get(&"x".to_string()), Some(&10));
        assert_eq!(cache.get(&"y".to_string()), Some(&20));
        assert!(cache.get(&"old".to_string()).is_none());
    }

    #[test]
    fn test_restore_empty_map_clears_cache() {
        let map: HashMap<String, i32> = HashMap::new();
        let mut cache: LruCache<String, i32> = LruCache::new(NonZeroUsize::new(4).unwrap());
        cache.put("stale".to_string(), 1);

        map.restore_to_lru(&mut cache);

        assert!(cache.is_empty());
    }

    #[test]
    fn test_clone_then_restore_round_trip() {
        let mut cache: LruCache<u32, &str> = LruCache::new(NonZeroUsize::new(8).unwrap());
        cache.put(1, "one");
        cache.put(2, "two");

        let snapshot = cache.clone_to_hashmap();

        // Mutate after the snapshot: add a new key and overwrite an existing one.
        cache.put(3, "three");
        cache.put(2, "TWO");

        snapshot.restore_to_lru(&mut cache);

        assert_eq!(cache.len(), 2);
        assert_eq!(cache.peek(&1), Some(&"one"));
        assert_eq!(cache.peek(&2), Some(&"two"));
        assert!(cache.peek(&3).is_none());
    }

    #[test]
    fn test_empty_snapshot() {
        let snapshot = MemoryStorageSnapshot::new();
        assert!(snapshot.mls_group_data.is_empty());
        assert!(snapshot.mls_own_leaf_nodes.is_empty());
        assert!(snapshot.mls_proposals.is_empty());
        assert!(snapshot.mls_key_packages.is_empty());
        assert!(snapshot.mls_psks.is_empty());
        assert!(snapshot.mls_signature_keys.is_empty());
        assert!(snapshot.mls_encryption_keys.is_empty());
        assert!(snapshot.mls_epoch_key_pairs.is_empty());
        assert!(snapshot.groups.is_empty());
        assert!(snapshot.groups_by_nostr_id.is_empty());
        assert!(snapshot.group_relays.is_empty());
        assert!(snapshot.group_exporter_secrets.is_empty());
        assert!(snapshot.group_mip04_exporter_secrets.is_empty());
        assert!(snapshot.welcomes.is_empty());
        assert!(snapshot.processed_welcomes.is_empty());
        assert!(snapshot.messages.is_empty());
        assert!(snapshot.messages_by_group.is_empty());
        assert!(snapshot.processed_messages.is_empty());
    }
}