Skip to main content

mdk_storage_traits/
lib.rs

//! MDK storage: a set of storage provider traits and types for implementing MLS storage.
//!
//! It is designed to be used in conjunction with the `openmls` crate.
3
4#![deny(unsafe_code)]
5#![warn(missing_docs)]
6#![warn(rustdoc::bare_urls)]
7
8use openmls_traits::storage::StorageProvider;
9
/// Generate a fieldless enum whose variants map one-to-one onto string literals.
///
/// The expansion produces:
/// - the enum, deriving `Debug`, `Clone`, `Copy`, `PartialEq`, `Eq`, `PartialOrd`,
///   `Ord` and `Hash`
/// - `as_str()`, returning the variant's literal
/// - `Display`, writing that literal (width/fill/alignment flags are honoured)
/// - `FromStr`, rejecting unknown input with
///   `$error_ty::InvalidParameters(format!($invalid_message, input))`
/// - `serde::Serialize` / `serde::Deserialize` using the same literals
/// - a `#[cfg(test)]` module exercising the round-trips above
///
/// `$error_ty` must have an `InvalidParameters(String)` variant and implement
/// `Display` (required by `serde::de::Error::custom`).
///
/// Trait methods are called through fully-qualified paths so the expansion site does
/// not need `FromStr` or `serde::Deserialize` imported.
macro_rules! string_enum {
    (
        $(#[$enum_meta:meta])*
        $vis:vis enum $name:ident => $error_ty:ty, $invalid_message:literal {
            $(
                $(#[$variant_meta:meta])*
                $variant:ident => $value:literal
            ),+ $(,)?
        }
    ) => {
        $(#[$enum_meta])*
        #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
        $vis enum $name {
            $(
                $(#[$variant_meta])*
                $variant,
            )+
        }

        impl $name {
            /// Get as `&str`
            pub const fn as_str(&self) -> &'static str {
                match self {
                    $(Self::$variant => $value,)+
                }
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                // `pad` (unlike `write!`) respects width, fill and alignment flags,
                // matching how `str` itself is displayed.
                f.pad(self.as_str())
            }
        }

        impl std::str::FromStr for $name {
            type Err = $error_ty;

            fn from_str(s: &str) -> Result<Self, Self::Err> {
                match s {
                    $($value => Ok(Self::$variant),)+
                    _ => Err(<$error_ty>::InvalidParameters(format!($invalid_message, s))),
                }
            }
        }

        impl serde::Serialize for $name {
            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                serializer.serialize_str(self.as_str())
            }
        }

        impl<'de> serde::Deserialize<'de> for $name {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                // Fully-qualified calls: neither `serde::Deserialize` nor `FromStr`
                // is guaranteed to be imported where this macro is expanded.
                let s = <String as serde::Deserialize<'de>>::deserialize(deserializer)?;
                <Self as std::str::FromStr>::from_str(&s).map_err(serde::de::Error::custom)
            }
        }

        pastey::paste! {
            #[cfg(test)]
            mod [<$name:snake _string_enum_tests>] {
                use super::*;
                use std::str::FromStr;

                #[test]
                fn from_str_valid() {
                    $(
                        assert_eq!(
                            $name::from_str($value).unwrap(),
                            $name::$variant,
                        );
                    )+
                }

                #[test]
                fn from_str_invalid() {
                    assert!($name::from_str("__invalid_test_value__").is_err());
                }

                #[test]
                fn to_string_matches_value() {
                    $(
                        assert_eq!($name::$variant.to_string(), $value);
                    )+
                }

                #[test]
                fn serde_roundtrip() {
                    $(
                        let serialized = serde_json::to_string(&$name::$variant).unwrap();
                        assert_eq!(serialized, format!("\"{}\"", $value));
                        let deserialized: $name = serde_json::from_str(&serialized).unwrap();
                        assert_eq!(deserialized, $name::$variant);
                    )+
                }

                #[test]
                fn serde_invalid() {
                    let result = serde_json::from_str::<$name>(r#""__invalid_test_value__""#);
                    assert!(result.is_err());
                }
            }
        }
    };
}
121
/// Expand to a `Pagination` type with bounded `limit`/`offset` plus a limit validator.
///
/// Generated items:
/// - `pub struct Pagination` holding `limit: Option<usize>` and `offset: Option<usize>`,
///   plus one `Option<_>` field per entry of an optional `extra { .. }` section
/// - `Pagination::new(limit, offset)`, which leaves every extra field as `None`
/// - `Pagination::limit()` (falls back to `default_limit`) and `Pagination::offset()`
///   (falls back to `0`)
/// - a `Default` impl equal to `new(Some(default_limit), Some(0))`
/// - `pub fn <validate_fn>(limit: usize)`, accepting `1..=max_limit` and otherwise
///   returning `<error_type>::InvalidParameters(..)`
///
/// The struct derives `Copy`, so every extra field type must be `Copy` too.
macro_rules! bounded_pagination {
    (
        $(#[$attr:meta])*
        default_limit: $default_limit:expr,
        max_limit: $max_limit:expr,
        error_type: $err_ty:ty,
        validate_fn: $validator:ident
        $(, extra {
            $( $(#[$extra_attr:meta])* $extra_name:ident : $extra_ty:ty ),+ $(,)?
        })?
    ) => {
        $(#[$attr])*
        #[derive(Debug, Clone, Copy)]
        pub struct Pagination {
            /// Maximum number of items to return
            pub limit: Option<usize>,
            /// Number of items to skip
            pub offset: Option<usize>,
            $( $(
                $(#[$extra_attr])*
                pub $extra_name: Option<$extra_ty>,
            )+ )?
        }

        impl Pagination {
            /// Build a `Pagination` from an explicit limit and offset; extra fields start as `None`.
            pub fn new(limit: Option<usize>, offset: Option<usize>) -> Self {
                Self {
                    $( $( $extra_name: None, )+ )?
                    limit,
                    offset,
                }
            }

            /// Effective limit: the stored value, or the configured default when unset.
            pub fn limit(&self) -> usize {
                self.limit.unwrap_or($default_limit)
            }

            /// Effective offset: the stored value, or `0` when unset.
            pub fn offset(&self) -> usize {
                self.offset.unwrap_or_default()
            }
        }

        impl Default for Pagination {
            fn default() -> Self {
                // Same as the constructor, but with the default limit and a zero
                // offset filled in explicitly.
                Self::new(Some($default_limit), Some(0))
            }
        }

        /// Check that `limit` lies within `1..=max_limit`.
        ///
        /// # Errors
        ///
        /// Returns `InvalidParameters` naming the accepted range when `limit` is zero
        /// or larger than the maximum.
        #[inline]
        pub fn $validator(limit: usize) -> Result<(), $err_ty> {
            let allowed = 1..=$max_limit;
            if !allowed.contains(&limit) {
                return Err(<$err_ty>::InvalidParameters(format!(
                    "Limit must be between 1 and {}, got {}",
                    $max_limit, limit
                )));
            }
            Ok(())
        }
    };
}
204
/// Generate a storage error enum with the shared `InvalidParameters(String)` and
/// `DatabaseError(String)` variants, plus any module-specific extras.
///
/// The enum derives `Debug` and `thiserror::Error`, so every extra variant must carry
/// its own `#[error(...)]` attribute. Extra variants may be unit variants or tuple
/// variants with one or more fields. Tuple fields may carry attributes such as
/// thiserror's `#[from]` or `#[source]`, e.g. `Io(#[from] std::io::Error)`.
macro_rules! storage_error {
    (
        $(#[$enum_meta:meta])*
        $vis:vis enum $name:ident {
            $(
                $(#[$extra_meta:meta])*
                $extra_variant:ident $( (
                    // Field attributes (e.g. `#[from]`) and multiple fields are accepted;
                    // a plain single-field `Variant(Type)` still matches as before.
                    $( $(#[$field_meta:meta])* $extra_inner:ty ),+ $(,)?
                ) )?
            ),* $(,)?
        }
    ) => {
        $(#[$enum_meta])*
        #[derive(Debug, thiserror::Error)]
        $vis enum $name {
            /// Invalid parameters
            #[error("Invalid parameters: {0}")]
            InvalidParameters(String),
            /// Database error
            #[error("Database error: {0}")]
            DatabaseError(String),
            $(
                $(#[$extra_meta])*
                $extra_variant $( ( $( $(#[$field_meta])* $extra_inner ),+ ) )?,
            )*
        }
    };
}
233
234pub mod error;
235pub mod group_id;
236pub mod groups;
237pub mod messages;
238pub mod mls_codec;
239/// Secret wrapper for zeroization
240pub mod secret;
241#[cfg(feature = "test-utils")]
242pub mod test_utils;
243
244pub mod welcomes;
245
// Re-export commonly used types (error type, group ID, secret helpers) for convenience
247pub use error::MdkStorageError;
248pub use group_id::GroupId;
249pub use secret::{Secret, Zeroize};
250
251use self::groups::GroupStorage;
252use self::messages::MessageStorage;
253use self::welcomes::WelcomeStorage;
254
/// Const generic argument given to OpenMLS's `StorageProvider` in the
/// `MdkStorageProvider` supertrait bound (keep in sync with `StorageProvider<1>`
/// in that trait's documentation).
const CURRENT_VERSION: u16 = 1_u16;
256
/// Kind of storage backend behind an [`MdkStorageProvider`] implementation.
///
/// The enum is fieldless, so it is `Copy` and `Hash` like the other small
/// enums in this crate and can be passed by value or used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Backend {
    /// In-memory storage; reported as non-persistent.
    Memory,
    /// SQLite-backed storage; reported as persistent.
    SQLite,
}

impl Backend {
    /// Check if it's a persistent backend.
    ///
    /// Every variant other than [`Backend::Memory`] is considered persistent.
    /// This is a `const fn`, so it is usable in constant contexts as well.
    pub const fn is_persistent(&self) -> bool {
        !matches!(self, Self::Memory)
    }
}
274
275/// Storage provider for MDK.
276///
277/// This trait combines all MDK storage requirements with the OpenMLS
278/// `StorageProvider` trait, enabling unified storage implementations
279/// that can atomically manage both MLS state and MDK-specific data.
280///
281/// Implementors must provide:
282/// - Group storage for MLS group metadata and relays
283/// - Message storage for encrypted messages
284/// - Welcome storage for pending welcome messages
285/// - Full OpenMLS `StorageProvider<1>` implementation for MLS cryptographic state
286pub trait MdkStorageProvider:
287    GroupStorage + MessageStorage + WelcomeStorage + StorageProvider<CURRENT_VERSION>
288{
289    /// Returns the backend type.
290    ///
291    /// # Returns
292    ///
293    /// The storage backend type (e.g., [`Backend::Memory`] or [`Backend::SQLite`]).
294    fn backend(&self) -> Backend;
295
296    /// Create a snapshot of a group's state before applying a commit.
297    ///
298    /// This captures all MLS and MDK state for the specified group,
299    /// enabling rollback if a better commit arrives later (MIP-03).
300    ///
301    /// The snapshot is stored persistently (in SQLite) or in memory,
302    /// keyed by both the group ID and snapshot name.
303    fn create_group_snapshot(&self, group_id: &GroupId, name: &str) -> Result<(), MdkStorageError>;
304
305    /// Rollback a group's state to a previously created snapshot.
306    ///
307    /// This restores all MLS and MDK state for the group to what it was
308    /// when the snapshot was created. The snapshot is consumed (deleted) after use.
309    fn rollback_group_to_snapshot(
310        &self,
311        group_id: &GroupId,
312        name: &str,
313    ) -> Result<(), MdkStorageError>;
314
315    /// Release a snapshot that is no longer needed.
316    ///
317    /// Call this to free resources when a snapshot won't be used for rollback.
318    fn release_group_snapshot(&self, group_id: &GroupId, name: &str)
319    -> Result<(), MdkStorageError>;
320
321    /// List all snapshots for a specific group with their creation timestamps.
322    ///
323    /// Returns a list of (snapshot_name, created_at_unix_timestamp) tuples
324    /// ordered by creation time (oldest first). This is used for:
325    /// - Hydrating the EpochSnapshotManager after restart
326    /// - Auditing existing snapshots
327    ///
328    /// # Arguments
329    ///
330    /// * `group_id` - The group to list snapshots for
331    ///
332    /// # Returns
333    ///
334    /// A vector of (snapshot_name, created_at) tuples, or an error.
335    fn list_group_snapshots(
336        &self,
337        group_id: &GroupId,
338    ) -> Result<Vec<(String, u64)>, MdkStorageError>;
339
340    /// Prune all snapshots created before the given Unix timestamp.
341    ///
342    /// This is used for TTL-based cleanup of old snapshots to prevent
343    /// indefinite storage growth and ensure cryptographic key material
344    /// doesn't persist longer than necessary.
345    ///
346    /// # Arguments
347    ///
348    /// * `min_timestamp` - Unix timestamp cutoff; snapshots with `created_at < min_timestamp` are deleted
349    ///
350    /// # Returns
351    ///
352    /// The number of snapshots deleted, or an error.
353    fn prune_expired_snapshots(&self, min_timestamp: u64) -> Result<usize, MdkStorageError>;
354
355    /// Delete all local state for a group.
356    ///
357    /// Removes the group, its messages, processed message records, MLS tree state,
358    /// epoch secrets, key material, proposals, and snapshots from local storage.
359    ///
360    /// This is irreversible. After deletion, the group cannot receive or decrypt
361    /// new messages. Call `leave_group()` before this method to notify other members.
362    ///
363    /// Idempotent: deleting a nonexistent group returns `Ok(())`.
364    ///
365    /// This is a local-only operation with no protocol-level side effects.
366    fn delete_group(&self, group_id: &GroupId) -> Result<(), MdkStorageError>;
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Each backend paired with the persistence flag it must report.
    #[test]
    fn test_backend_is_persistent() {
        let expectations = [(Backend::Memory, false), (Backend::SQLite, true)];
        for (backend, persistent) in expectations.iter() {
            assert_eq!(backend.is_persistent(), *persistent);
        }
    }
}