mdk_storage_traits/lib.rs
//! MDK storage: a set of storage provider traits and types for implementing MLS storage.
//!
//! It is designed to be used in conjunction with the `openmls` crate.
3
4#![deny(unsafe_code)]
5#![warn(missing_docs)]
6#![warn(rustdoc::bare_urls)]
7
8use openmls_traits::storage::StorageProvider;
9
/// Define a fieldless enum whose variants map one-to-one onto string values.
///
/// For every `Variant => "value"` pair the expansion provides:
/// - the enum itself, deriving `Debug`, `Clone`, `Copy`, `PartialEq`, `Eq`,
///   `PartialOrd`, `Ord` and `Hash`;
/// - `as_str()`, returning the variant's declared string;
/// - `Display`, writing that same string;
/// - `FromStr`, mapping a declared string back to its variant and returning
///   `$error_ty::InvalidParameters(format!($invalid_message, input))` otherwise;
/// - `serde::Serialize` / `serde::Deserialize` as that plain string;
/// - a `#[cfg(test)]` module exercising all of the above.
///
/// Requirements on the caller:
/// - `$error_ty` must have an `InvalidParameters(String)` variant.
/// - `$error_ty` must implement `std::fmt::Display`, because the `FromStr` error
///   is handed to `serde::de::Error::custom` during deserialization.
/// - `$invalid_message` must be a format string with exactly one `{}`
///   placeholder, which receives the rejected input.
///
/// The generated trait calls use fully-qualified paths, so the invoking module
/// does not need `std::str::FromStr` or `serde::Deserialize` imported.
macro_rules! string_enum {
    (
        $(#[$enum_meta:meta])*
        $vis:vis enum $name:ident => $error_ty:ty, $invalid_message:literal {
            $(
                $(#[$variant_meta:meta])*
                $variant:ident => $value:literal
            ),+ $(,)?
        }
    ) => {
        $(#[$enum_meta])*
        #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
        $vis enum $name {
            $(
                $(#[$variant_meta])*
                $variant,
            )+
        }

        impl $name {
            /// Get as `&str`
            ///
            /// The returned value is the string literal the variant was declared
            /// with, so it is `'static` and does not borrow `self`.
            pub fn as_str(&self) -> &'static str {
                match self {
                    $(Self::$variant => $value,)+
                }
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str(self.as_str())
            }
        }

        impl std::str::FromStr for $name {
            type Err = $error_ty;

            fn from_str(s: &str) -> Result<Self, Self::Err> {
                match s {
                    $($value => Ok(Self::$variant),)+
                    _ => Err(<$error_ty>::InvalidParameters(format!($invalid_message, s))),
                }
            }
        }

        impl serde::Serialize for $name {
            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                serializer.serialize_str(self.as_str())
            }
        }

        impl<'de> serde::Deserialize<'de> for $name {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                // Fully-qualified calls: `String::deserialize` / `Self::from_str`
                // would only resolve if the *invoking* module happened to import
                // `serde::Deserialize` and `std::str::FromStr`.
                let s = <String as serde::Deserialize<'de>>::deserialize(deserializer)?;
                <Self as std::str::FromStr>::from_str(&s).map_err(serde::de::Error::custom)
            }
        }

        pastey::paste! {
            #[cfg(test)]
            mod [<$name:snake _string_enum_tests>] {
                use super::*;
                use std::str::FromStr;

                #[test]
                fn from_str_valid() {
                    $(
                        assert_eq!(
                            $name::from_str($value).unwrap(),
                            $name::$variant,
                        );
                    )+
                }

                #[test]
                fn from_str_invalid() {
                    assert!($name::from_str("__invalid_test_value__").is_err());
                }

                #[test]
                fn to_string_matches_value() {
                    $(
                        assert_eq!($name::$variant.to_string(), $value);
                    )+
                }

                #[test]
                fn serde_roundtrip() {
                    $(
                        let serialized = serde_json::to_string(&$name::$variant).unwrap();
                        assert_eq!(serialized, format!("\"{}\"", $value));
                        let deserialized: $name = serde_json::from_str(&serialized).unwrap();
                        assert_eq!(deserialized, $name::$variant);
                    )+
                }

                #[test]
                fn serde_invalid() {
                    let result = serde_json::from_str::<$name>(r#""__invalid_test_value__""#);
                    assert!(result.is_err());
                }
            }
        }
    };
}
121
/// Generate a bounded `Pagination` type together with a limit-validation function.
///
/// The expansion contains:
/// - `pub struct Pagination`, holding optional `limit` / `offset` values plus any
///   caller-declared extra fields (each stored as `Option<T>`);
/// - `Pagination::new(limit, offset)`, which leaves every extra field `None`;
/// - `limit()` / `offset()` accessors that fall back to the default limit and `0`;
/// - a `Default` impl equal to `new(Some(default_limit), Some(0))`;
/// - `pub fn <validate_fn>(limit)`, accepting exactly `1..=max_limit`.
///
/// The error type must expose an `InvalidParameters(String)` variant, which is
/// used to report an out-of-range limit.
macro_rules! bounded_pagination {
    (
        $(#[$attr:meta])*
        default_limit: $default_limit:expr,
        max_limit: $max_limit:expr,
        error_type: $error:ty,
        validate_fn: $validate_fn:ident
        $(, extra {
            $( $(#[$extra_attr:meta])* $extra_name:ident : $extra_ty:ty ),+ $(,)?
        })?
    ) => {
        $(#[$attr])*
        #[derive(Debug, Clone, Copy)]
        pub struct Pagination {
            /// Maximum number of items to return
            pub limit: Option<usize>,
            /// Number of items to skip
            pub offset: Option<usize>,
            $( $(
                $(#[$extra_attr])*
                pub $extra_name: Option<$extra_ty>,
            )+ )?
        }

        impl Pagination {
            /// Create a new Pagination with specified limit and offset
            pub fn new(limit: Option<usize>, offset: Option<usize>) -> Self {
                Self {
                    limit,
                    offset,
                    $( $( $extra_name: None, )+ )?
                }
            }

            /// Get the limit value, using default if not specified
            pub fn limit(&self) -> usize {
                self.limit.unwrap_or($default_limit)
            }

            /// Get the offset value, using 0 if not specified
            pub fn offset(&self) -> usize {
                // `usize::default()` is 0.
                self.offset.unwrap_or_default()
            }
        }

        impl Default for Pagination {
            fn default() -> Self {
                // `new` already clears every extra field to `None`.
                Self::new(Some($default_limit), Some(0))
            }
        }

        /// Validate that a limit is within the allowed range.
        ///
        /// Returns `Ok(())` if `limit` is between 1 and the maximum (inclusive),
        /// or an error otherwise.
        #[inline]
        pub fn $validate_fn(limit: usize) -> Result<(), $error> {
            let allowed = 1..=$max_limit;
            if !allowed.contains(&limit) {
                return Err(<$error>::InvalidParameters(format!(
                    "Limit must be between 1 and {}, got {}",
                    $max_limit, limit
                )));
            }
            Ok(())
        }
    };
}
204
/// Generate a storage error enum that always starts with the shared
/// `InvalidParameters(String)` and `DatabaseError(String)` variants, followed by
/// whatever module-specific variants the caller lists.
///
/// The enum derives `Debug` and `thiserror::Error`. Each caller-listed variant is
/// either a unit variant or a single-field tuple variant. Its attributes (for
/// example the `#[error(...)]` message used by `thiserror`) are forwarded
/// unchanged.
macro_rules! storage_error {
    (
        $(#[$attr:meta])*
        $visibility:vis enum $error_name:ident {
            $(
                $(#[$variant_attr:meta])*
                $variant:ident $( ($payload:ty) )?
            ),* $(,)?
        }
    ) => {
        $(#[$attr])*
        #[derive(Debug, thiserror::Error)]
        $visibility enum $error_name {
            /// The caller supplied invalid parameters
            #[error("Invalid parameters: {0}")]
            InvalidParameters(String),
            /// The underlying database reported an error
            #[error("Database error: {0}")]
            DatabaseError(String),
            $(
                $(#[$variant_attr])*
                $variant $( ($payload) )?,
            )*
        }
    };
}
233
234pub mod error;
235pub mod group_id;
236pub mod groups;
237pub mod messages;
238pub mod mls_codec;
239/// Secret wrapper for zeroization
240pub mod secret;
241#[cfg(feature = "test-utils")]
242pub mod test_utils;
243
244pub mod welcomes;
245
// Re-export commonly used types at the crate root for convenience
247pub use error::MdkStorageError;
248pub use group_id::GroupId;
249pub use secret::{Secret, Zeroize};
250
251use self::groups::GroupStorage;
252use self::messages::MessageStorage;
253use self::welcomes::WelcomeStorage;
254
/// OpenMLS storage-provider version required of every [`MdkStorageProvider`]
/// (used as the `StorageProvider<CURRENT_VERSION>` supertrait bound).
const CURRENT_VERSION: u16 = 1_u16;
256
/// Kind of storage backend behind an [`MdkStorageProvider`].
///
/// Reported by [`MdkStorageProvider::backend`]. This is a plain fieldless value
/// type, so it is `Copy` and can be used as a hash-map / hash-set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Backend {
    /// In-memory storage; not persistent (see [`Backend::is_persistent`]).
    Memory,
    /// SQLite database storage; persistent.
    SQLite,
}

impl Backend {
    /// Check if it's a persistent backend
    ///
    /// All values different from [`Backend::Memory`] are considered persistent,
    /// so any backend variant added later is treated as persistent by default.
    #[must_use]
    pub const fn is_persistent(&self) -> bool {
        !matches!(self, Self::Memory)
    }
}
274
/// Storage provider for MDK.
///
/// This trait combines all MDK storage requirements with the OpenMLS
/// `StorageProvider` trait, enabling unified storage implementations
/// that can atomically manage both MLS state and MDK-specific data.
///
/// Implementors must provide:
/// - Group storage for MLS group metadata and relays ([`GroupStorage`])
/// - Message storage for encrypted messages ([`MessageStorage`])
/// - Welcome storage for pending welcome messages ([`WelcomeStorage`])
/// - A full OpenMLS `StorageProvider<CURRENT_VERSION>` implementation (currently
///   version 1) for MLS cryptographic state
pub trait MdkStorageProvider:
    GroupStorage + MessageStorage + WelcomeStorage + StorageProvider<CURRENT_VERSION>
{
    /// Returns the backend type.
    ///
    /// # Returns
    ///
    /// The storage backend type (e.g., [`Backend::Memory`] or [`Backend::SQLite`]).
    fn backend(&self) -> Backend;

    /// Create a snapshot of a group's state before applying a commit.
    ///
    /// This captures all MLS and MDK state for the specified group,
    /// enabling rollback if a better commit arrives later (MIP-03).
    ///
    /// Where the snapshot lives depends on the backend (persistently in SQLite,
    /// or in memory). It is keyed by both the group ID and the snapshot name.
    ///
    /// # Arguments
    ///
    /// * `group_id` - The group whose state is captured
    /// * `name` - Caller-chosen snapshot name; together with `group_id` it identifies the snapshot
    ///
    /// # Errors
    ///
    /// Returns an [`MdkStorageError`] if the snapshot cannot be created.
    /// NOTE(review): whether reusing an existing `name` overwrites the old
    /// snapshot or fails is not specified here; confirm with implementations.
    fn create_group_snapshot(&self, group_id: &GroupId, name: &str) -> Result<(), MdkStorageError>;

    /// Rollback a group's state to a previously created snapshot.
    ///
    /// This restores all MLS and MDK state for the group to what it was
    /// when the snapshot was created. The snapshot is consumed (deleted) after use.
    ///
    /// # Arguments
    ///
    /// * `group_id` - The group to restore
    /// * `name` - Name the snapshot was created under via `create_group_snapshot`
    ///
    /// # Errors
    ///
    /// Returns an [`MdkStorageError`] if the rollback fails.
    /// NOTE(review): behavior when no snapshot named `name` exists for the group
    /// is not specified here; confirm with implementations.
    fn rollback_group_to_snapshot(
        &self,
        group_id: &GroupId,
        name: &str,
    ) -> Result<(), MdkStorageError>;

    /// Release a snapshot that is no longer needed.
    ///
    /// Call this to free resources when a snapshot won't be used for rollback.
    ///
    /// # Arguments
    ///
    /// * `group_id` - The group the snapshot belongs to
    /// * `name` - Name of the snapshot to release
    ///
    /// # Errors
    ///
    /// Returns an [`MdkStorageError`] if the release fails.
    /// NOTE(review): whether releasing an unknown snapshot is an error or a
    /// no-op is not specified here; confirm with implementations.
    fn release_group_snapshot(&self, group_id: &GroupId, name: &str)
        -> Result<(), MdkStorageError>;

    /// List all snapshots for a specific group with their creation timestamps.
    ///
    /// Returns a list of (snapshot_name, created_at_unix_timestamp) tuples
    /// ordered by creation time (oldest first). This is used for:
    /// - Hydrating the EpochSnapshotManager after restart
    /// - Auditing existing snapshots
    ///
    /// # Arguments
    ///
    /// * `group_id` - The group to list snapshots for
    ///
    /// # Returns
    ///
    /// A vector of (snapshot_name, created_at) tuples, or an error.
    fn list_group_snapshots(
        &self,
        group_id: &GroupId,
    ) -> Result<Vec<(String, u64)>, MdkStorageError>;

    /// Prune all snapshots created before the given Unix timestamp.
    ///
    /// This is used for TTL-based cleanup of old snapshots to prevent
    /// indefinite storage growth and ensure cryptographic key material
    /// doesn't persist longer than necessary.
    ///
    /// Unlike the other snapshot methods, this applies across all groups
    /// (it takes no `group_id`).
    ///
    /// # Arguments
    ///
    /// * `min_timestamp` - Unix timestamp cutoff; snapshots with `created_at < min_timestamp` are deleted
    ///
    /// # Returns
    ///
    /// The number of snapshots deleted, or an error.
    fn prune_expired_snapshots(&self, min_timestamp: u64) -> Result<usize, MdkStorageError>;

    /// Delete all local state for a group.
    ///
    /// Removes the group, its messages, processed message records, MLS tree state,
    /// epoch secrets, key material, proposals, and snapshots from local storage.
    ///
    /// This is irreversible. After deletion, the group cannot receive or decrypt
    /// new messages. Call `leave_group()` before this method to notify other members.
    ///
    /// Idempotent: deleting a nonexistent group returns `Ok(())`.
    ///
    /// This is a local-only operation with no protocol-level side effects.
    ///
    /// # Errors
    ///
    /// Returns an [`MdkStorageError`] if the underlying storage fails while
    /// deleting; a missing group is not an error (see above).
    fn delete_group(&self, group_id: &GroupId) -> Result<(), MdkStorageError>;
}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Only the in-memory backend is non-persistent; every other backend is persistent.
    #[test]
    fn test_backend_is_persistent() {
        let expectations = [(Backend::Memory, false), (Backend::SQLite, true)];
        for (backend, persistent) in expectations.iter() {
            assert_eq!(
                backend.is_persistent(),
                *persistent,
                "unexpected persistence for {:?}",
                backend
            );
        }
    }
}