#![deny(unsafe_code)]
#![warn(missing_docs)]
#![warn(rustdoc::bare_urls)]
use openmls_traits::storage::StorageProvider;
/// Defines a fieldless enum whose variants map one-to-one onto string literals.
///
/// Invocation shape:
///
/// ```ignore
/// string_enum! {
///     /// Docs for the enum.
///     pub enum Color => MyError, "Invalid color: {}" {
///         /// Docs for the variant.
///         Red => "red",
///         Green => "green",
///     }
/// }
/// ```
///
/// The expansion contains:
/// - the enum, deriving `Debug`, `Clone`, `Copy`, `PartialEq`, `Eq`, `PartialOrd`, `Ord`, `Hash`;
/// - `as_str`, returning the literal attached to the variant;
/// - `Display`, writing that same literal;
/// - `FromStr`, which accepts exactly the listed literals and otherwise returns
///   `<error_ty>::InvalidParameters(format!(invalid_message, input))`;
/// - `serde::Serialize` / `serde::Deserialize` as a plain string;
/// - a `#[cfg(test)]` module (named after the enum in snake case, suffixed with
///   `_string_enum_tests`) exercising parsing, display and the serde round trip.
///
/// Requirements on the invoking crate, as implied by the expansion: `serde` and
/// `pastey` must be resolvable, `serde_json` must be available to tests, the
/// error type needs an `InvalidParameters(String)` tuple variant and a `Display`
/// impl (for `serde::de::Error::custom`), and `invalid_message` must contain
/// exactly one positional `{}` placeholder.
macro_rules! string_enum {
    (
        $(#[$enum_meta:meta])*
        $vis:vis enum $name:ident => $error_ty:ty, $invalid_message:literal {
            $(
                $(#[$variant_meta:meta])*
                $variant:ident => $value:literal
            ),+ $(,)?
        }
    ) => {
        $(#[$enum_meta])*
        #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
        $vis enum $name {
            $(
                $(#[$variant_meta])*
                $variant,
            )+
        }

        impl $name {
            /// Returns the string literal associated with this variant.
            // Every arm is a literal, so the result does not need to borrow from `self`.
            pub fn as_str(&self) -> &'static str {
                match self {
                    $(Self::$variant => $value,)+
                }
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str(self.as_str())
            }
        }

        impl std::str::FromStr for $name {
            type Err = $error_ty;

            fn from_str(s: &str) -> Result<Self, Self::Err> {
                match s {
                    $($value => Ok(Self::$variant),)+
                    _ => Err(<$error_ty>::InvalidParameters(format!($invalid_message, s))),
                }
            }
        }

        impl serde::Serialize for $name {
            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                serializer.serialize_str(self.as_str())
            }
        }

        impl<'de> serde::Deserialize<'de> for $name {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                // Trait items are resolved in the scope of the macro *call site*, so
                // the calls below are written so that the invoking module does not
                // have to import `serde::Deserialize` or `std::str::FromStr`.
                let raw = <String as serde::Deserialize<'de>>::deserialize(deserializer)?;
                raw.parse::<Self>().map_err(serde::de::Error::custom)
            }
        }

        pastey::paste! {
            #[cfg(test)]
            mod [<$name:snake _string_enum_tests>] {
                use super::*;
                use std::str::FromStr;

                #[test]
                fn from_str_valid() {
                    $(
                        assert_eq!(
                            $name::from_str($value).unwrap(),
                            $name::$variant,
                        );
                    )+
                }

                #[test]
                fn from_str_invalid() {
                    assert!($name::from_str("__invalid_test_value__").is_err());
                }

                #[test]
                fn to_string_matches_value() {
                    $(
                        assert_eq!($name::$variant.to_string(), $value);
                    )+
                }

                #[test]
                fn serde_roundtrip() {
                    $(
                        let serialized = serde_json::to_string(&$name::$variant).unwrap();
                        assert_eq!(serialized, format!("\"{}\"", $value));
                        let deserialized: $name = serde_json::from_str(&serialized).unwrap();
                        assert_eq!(deserialized, $name::$variant);
                    )+
                }

                #[test]
                fn serde_invalid() {
                    let result = serde_json::from_str::<$name>(r#""__invalid_test_value__""#);
                    assert!(result.is_err());
                }
            }
        }
    };
}
/// Generates a `Pagination` struct plus a limit validator for one module.
///
/// Invocation shape:
///
/// ```ignore
/// bounded_pagination! {
///     /// Docs for the generated struct.
///     default_limit: 1000,
///     max_limit: 10_000,
///     error_type: MyError,
///     validate_fn: validate_limit,
///     extra {
///         /// Docs for the extra field.
///         sort_order: SortOrder,
///     }
/// }
/// ```
///
/// The expansion contains:
/// - `pub struct Pagination` with `limit` / `offset` (both `Option<usize>`) and
///   one `Option<T>` field per `extra` entry; the struct derives `Copy`, so
///   every extra field type must be `Copy` as well;
/// - `Pagination::new`, which stores the given limit/offset and sets every
///   extra field to `None`;
/// - `limit()` / `offset()` accessors that fall back to `default_limit` and `0`;
/// - a `Default` impl with `Some(default_limit)`, `Some(0)` and `None` extras;
/// - `pub fn <validate_fn>(limit: usize)`, which accepts `1..=max_limit` and
///   otherwise returns `<error_type>::InvalidParameters(..)`.
///
/// Note that the accessors do not clamp to `max_limit`; range checking only
/// happens when the generated validator is called.
macro_rules! bounded_pagination {
    (
        $(#[$struct_meta:meta])*
        default_limit: $default:expr,
        max_limit: $max:expr,
        error_type: $err:ty,
        validate_fn: $validate:ident
        $(, extra {
            $( $(#[$field_meta:meta])* $field:ident : $field_ty:ty ),+ $(,)?
        })?
    ) => {
        $(#[$struct_meta])*
        #[derive(Debug, Clone, Copy)]
        pub struct Pagination {
            /// Requested page size; `None` means "use the default limit".
            pub limit: Option<usize>,
            /// Number of leading items to skip; `None` means `0`.
            pub offset: Option<usize>,
            $( $(
                $(#[$field_meta])*
                pub $field: Option<$field_ty>,
            )+ )?
        }

        impl Pagination {
            /// Builds a pagination request from an optional limit and offset.
            ///
            /// Any extra fields declared for this struct start out as `None`.
            pub fn new(limit: Option<usize>, offset: Option<usize>) -> Self {
                Self {
                    limit,
                    offset,
                    $( $( $field: None, )+ )?
                }
            }

            /// Returns the requested limit, or the default limit when unset.
            pub fn limit(&self) -> usize {
                self.limit.unwrap_or($default)
            }

            /// Returns the requested offset, or `0` when unset.
            pub fn offset(&self) -> usize {
                self.offset.unwrap_or(0)
            }
        }

        impl Default for Pagination {
            fn default() -> Self {
                Self {
                    limit: Some($default),
                    offset: Some(0),
                    $( $( $field: None, )+ )?
                }
            }
        }

        /// Checks that `limit` lies within `1..=max_limit`.
        ///
        /// # Errors
        ///
        /// Returns the `InvalidParameters` variant of the configured error type
        /// when `limit` is `0` or greater than the maximum.
        #[inline]
        pub fn $validate(limit: usize) -> Result<(), $err> {
            // Bind once: `$max` is an arbitrary expression and would otherwise be
            // expanded (and evaluated) again when building the error message.
            let max: usize = $max;
            if (1..=max).contains(&limit) {
                Ok(())
            } else {
                Err(<$err>::InvalidParameters(format!(
                    "Limit must be between 1 and {}, got {}",
                    max, limit
                )))
            }
        }
    };
}
/// Defines an error enum that always carries two built-in variants and may add
/// caller-specified ones.
///
/// Invocation shape:
///
/// ```ignore
/// storage_error! {
///     /// Docs for the enum.
///     pub enum WidgetError {
///         /// Docs for the extra variant.
///         #[error("Widget not found")]
///         NotFound,
///         #[error("Widget is broken: {0}")]
///         Broken(String),
///     }
/// }
/// ```
///
/// The generated enum derives `Debug` and `thiserror::Error` and starts with
/// `InvalidParameters(String)` and `DatabaseError(String)`. Extra variants are
/// appended after those two; each is either a unit variant or a tuple variant
/// with a single field, and must bring its own attributes (such as
/// `#[error(...)]`) because the macro forwards them unchanged.
macro_rules! storage_error {
    (
        $(#[$enum_meta:meta])*
        $vis:vis enum $name:ident {
            $(
                $(#[$extra_meta:meta])*
                $extra_variant:ident $( ($extra_inner:ty) )?
            ),* $(,)?
        }
    ) => {
        $(#[$enum_meta])*
        #[derive(Debug, thiserror::Error)]
        $vis enum $name {
            /// An argument was rejected; the payload is appended to the
            /// `Invalid parameters: ` prefix when displayed.
            #[error("Invalid parameters: {0}")]
            InvalidParameters(String),
            /// A database-level failure; the payload is appended to the
            /// `Database error: ` prefix when displayed.
            #[error("Database error: {0}")]
            DatabaseError(String),
            $(
                $(#[$extra_meta])*
                $extra_variant $( ($extra_inner) )?,
            )*
        }
    };
}
pub mod error;
pub mod group_id;
pub mod groups;
pub mod messages;
pub mod mls_codec;
pub mod secret;
#[cfg(feature = "test-utils")]
pub mod test_utils;
pub mod welcomes;
pub use error::MdkStorageError;
pub use group_id::GroupId;
pub use secret::{Secret, Zeroize};
use self::groups::GroupStorage;
use self::messages::MessageStorage;
use self::welcomes::WelcomeStorage;
/// Version number supplied as the const generic argument of the OpenMLS
/// `StorageProvider` trait in the `MdkStorageProvider` supertrait bound.
const CURRENT_VERSION: u16 = 1;
/// Identifies the kind of storage behind an [`MdkStorageProvider`].
///
/// The enum is fieldless, so it is `Copy` and `Hash` in addition to the
/// comparison traits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Backend {
    /// In-memory storage; reported as non-persistent by [`Backend::is_persistent`].
    Memory,
    /// SQLite storage; reported as persistent by [`Backend::is_persistent`].
    SQLite,
}

impl Backend {
    /// Returns `true` for every backend except [`Backend::Memory`].
    pub const fn is_persistent(&self) -> bool {
        !matches!(self, Self::Memory)
    }
}
/// Umbrella trait for a complete storage implementation.
///
/// An implementor must also implement [`GroupStorage`], [`MessageStorage`],
/// [`WelcomeStorage`] and the OpenMLS [`StorageProvider`] at version
/// `CURRENT_VERSION`. On top of those, this trait adds backend identification,
/// per-group snapshot management and group deletion.
///
/// NOTE(review): the exact semantics of each operation (atomicity, behavior on
/// unknown groups or duplicate snapshot names) are defined by implementors and
/// are not visible from this declaration.
pub trait MdkStorageProvider:
GroupStorage + MessageStorage + WelcomeStorage + StorageProvider<CURRENT_VERSION>
{
/// Returns the [`Backend`] kind of this provider.
fn backend(&self) -> Backend;
/// Creates a snapshot identified by `name` for the group `group_id`.
///
/// NOTE(review): which state a snapshot captures is presumably the group's
/// stored data — confirm against the implementations.
fn create_group_snapshot(&self, group_id: &GroupId, name: &str) -> Result<(), MdkStorageError>;
/// Rolls the group `group_id` back to the snapshot identified by `name`.
///
/// NOTE(review): whether the snapshot is consumed by the rollback is not
/// specified here — confirm against the implementations.
fn rollback_group_to_snapshot(
&self,
group_id: &GroupId,
name: &str,
) -> Result<(), MdkStorageError>;
/// Releases the snapshot identified by `name` for the group `group_id`.
///
/// NOTE(review): presumably discards the snapshot without restoring it —
/// confirm against the implementations.
fn release_group_snapshot(&self, group_id: &GroupId, name: &str)
-> Result<(), MdkStorageError>;
/// Lists the snapshots of the group `group_id` as `(String, u64)` pairs.
///
/// NOTE(review): presumably `(snapshot name, creation timestamp)`; the
/// timestamp unit is not visible here — TODO confirm.
fn list_group_snapshots(
&self,
group_id: &GroupId,
) -> Result<Vec<(String, u64)>, MdkStorageError>;
/// Prunes snapshots relative to `min_timestamp` and returns a count.
///
/// NOTE(review): presumably removes snapshots older than `min_timestamp`
/// and returns how many were removed — TODO confirm.
fn prune_expired_snapshots(&self, min_timestamp: u64) -> Result<usize, MdkStorageError>;
/// Deletes the group `group_id` from storage.
///
/// NOTE(review): whether associated messages, welcomes and snapshots are
/// removed as well is not visible here — confirm against the implementations.
fn delete_group(&self, group_id: &GroupId) -> Result<(), MdkStorageError>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Only the in-memory backend is reported as non-persistent.
    #[test]
    fn test_backend_is_persistent() {
        let expectations = [(Backend::Memory, false), (Backend::SQLite, true)];

        for (backend, expected) in expectations {
            assert_eq!(
                backend.is_persistent(),
                expected,
                "unexpected persistence flag for {:?}",
                backend
            );
        }
    }
}