#![forbid(unsafe_code)]
#![warn(missing_docs)]
#![warn(rustdoc::bare_urls)]
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
#![doc = include_str!("../README.md")]
use std::sync::Arc;
use mdk_storage_traits::MdkStorageProvider;
use openmls::prelude::*;
use openmls_rust_crypto::RustCrypto;
pub mod callback;
mod constant;
#[cfg(feature = "mip04")]
#[cfg_attr(docsrs, doc(cfg(feature = "mip04")))]
pub mod encrypted_media;
pub mod epoch_snapshots;
pub mod error;
pub mod extension;
pub mod groups;
pub mod key_packages;
pub mod media_processing;
pub mod messages;
#[cfg(feature = "mip05")]
#[cfg_attr(docsrs, doc(cfg(feature = "mip05")))]
pub mod mip05;
pub mod prelude;
mod state_validation;
#[cfg(any(test, feature = "test-utils"))]
#[cfg_attr(docsrs, doc(cfg(feature = "test-utils")))]
pub mod test_util;
mod util;
pub mod welcomes;
use self::callback::{MdkCallback, RollbackInfo};
use self::constant::{
DEFAULT_CIPHERSUITE, SUPPORTED_EXTENSIONS, SUPPORTED_PROPOSALS, TAG_PROPOSALS,
};
use self::epoch_snapshots::EpochSnapshotManager;
pub use self::error::Error;
use self::util::NostrTagFormat;
pub use mdk_storage_traits::GroupId;
/// Tunable limits and retention settings for an `MDK` instance.
///
/// All fields are public, so a partial override can use struct-update syntax,
/// e.g. `MdkConfig { out_of_order_tolerance: 5, ..Default::default() }`.
/// The stock values are listed in the [`Default`] implementation.
#[derive(Debug, Clone)]
pub struct MdkConfig {
    /// Maximum event age in seconds (default: 45 days).
    /// NOTE(review): presumably older incoming events are rejected; the check
    /// lives outside this file — confirm.
    pub max_event_age_secs: u64,
    /// Allowed clock skew for event timestamps in the future, in seconds
    /// (default: 5 minutes). NOTE(review): enforced outside this file — confirm.
    pub max_future_skew_secs: u64,
    /// Sender-ratchet out-of-order tolerance (default: 100).
    ///
    /// The tests in this file show that with a value of 5, a message 19
    /// generations behind the newest received one can no longer be decrypted.
    /// Presumably forwarded to OpenMLS's sender ratchet configuration.
    pub out_of_order_tolerance: u32,
    /// Sender-ratchet maximum forward distance (default: 1000).
    /// Presumably forwarded to OpenMLS's sender ratchet configuration.
    pub maximum_forward_distance: u32,
    /// Number of past epochs to retain (default: 5).
    /// NOTE(review): used outside this file — confirm exact semantics.
    pub max_past_epochs: usize,
    /// Retention passed to the epoch snapshot manager when the `MDK` is built
    /// (default: 5).
    pub epoch_snapshot_retention: usize,
    /// Age in seconds after which stored snapshots are pruned when an `MDK`
    /// is built over a persistent backend (default: 7 days).
    pub snapshot_ttl_seconds: u64,
}

impl Default for MdkConfig {
    fn default() -> Self {
        const SECS_PER_MINUTE: u64 = 60;
        const SECS_PER_DAY: u64 = 24 * 60 * SECS_PER_MINUTE;

        Self {
            max_event_age_secs: 45 * SECS_PER_DAY,
            max_future_skew_secs: 5 * SECS_PER_MINUTE,
            out_of_order_tolerance: 100,
            maximum_forward_distance: 1000,
            max_past_epochs: 5,
            epoch_snapshot_retention: 5,
            snapshot_ttl_seconds: 7 * SECS_PER_DAY,
        }
    }
}

impl MdkConfig {
    /// Returns the default configuration; identical to [`MdkConfig::default`].
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}
/// Builder for `MDK` instances.
///
/// Created with `MdkBuilder::new` (or `MDK::builder`). It starts from
/// `MdkConfig::default()` with no callback and is consumed by `build()`.
#[derive(Debug)]
pub struct MdkBuilder<Storage> {
/// Storage backend; moved into the `MdkProvider` on build.
storage: Storage,
/// Configuration copied into the resulting `MDK`.
config: MdkConfig,
/// Optional callback copied into the resulting `MDK`.
callback: Option<Arc<dyn MdkCallback>>,
}
impl<Storage> MdkBuilder<Storage>
where
    Storage: MdkStorageProvider,
{
    /// Creates a builder around `storage`, starting from [`MdkConfig::default`]
    /// with no callback registered.
    pub fn new(storage: Storage) -> Self {
        Self {
            storage,
            config: MdkConfig::default(),
            callback: None,
        }
    }

    // Setters generated by `mdk_macros::setters!`: `with_config` takes a full
    // `MdkConfig`; `with_callback` takes the callback stored on the resulting
    // `MDK` (exact generated signatures are defined by the macro).
    mdk_macros::setters! {
        with_config<direct> -> config: MdkConfig;
        with_callback -> callback: Arc<dyn MdkCallback>;
    }

    /// Consumes the builder and constructs the `MDK` instance.
    ///
    /// If the storage backend reports itself as persistent, snapshots are
    /// pruned first against the cutoff `now - config.snapshot_ttl_seconds`.
    /// Pruning is best-effort: a misconfigured system clock or a storage error
    /// is logged and never prevents construction (and never panics).
    pub fn build(self) -> MDK<Storage> {
        if self.storage.backend().is_persistent() {
            self.prune_expired_snapshots_on_startup();
        }

        let epoch_snapshots = Arc::new(EpochSnapshotManager::new(
            self.config.epoch_snapshot_retention,
        ));

        MDK {
            ciphersuite: DEFAULT_CIPHERSUITE,
            extensions: SUPPORTED_EXTENSIONS.to_vec(),
            provider: MdkProvider {
                crypto: RustCrypto::default(),
                storage: self.storage,
            },
            config: self.config,
            epoch_snapshots,
            callback: self.callback,
        }
    }

    /// Best-effort startup cleanup: asks storage to prune snapshots against
    /// the cutoff `now - snapshot_ttl_seconds` (in Unix seconds).
    ///
    /// Failures are reported via `tracing::warn!` rather than propagated,
    /// because this cleanup is optional and `build` is infallible.
    fn prune_expired_snapshots_on_startup(&self) {
        let ttl_seconds = self.config.snapshot_ttl_seconds;

        let now_secs = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_secs(),
            Err(err) => {
                // A clock set before 1970 is a host misconfiguration; skipping an
                // optional cleanup is preferable to panicking in a constructor.
                tracing::warn!(
                    error = %err,
                    "System time before Unix epoch; skipping snapshot pruning"
                );
                return;
            }
        };

        // Clamp the cutoff at 0 instead of underflowing when the TTL exceeds
        // the current timestamp.
        let min_timestamp = now_secs.saturating_sub(ttl_seconds);

        match self.storage.prune_expired_snapshots(min_timestamp) {
            Ok(0) => {}
            Ok(pruned_count) => tracing::info!(
                pruned = pruned_count,
                ttl_seconds = ttl_seconds,
                "Pruned expired snapshots on startup"
            ),
            // Previously discarded silently; surface it so operators can see
            // that stale snapshots are accumulating.
            Err(err) => tracing::warn!(
                error = ?err,
                "Failed to prune expired snapshots on startup"
            ),
        }
    }
}
/// MLS client state bound to a storage backend.
///
/// Construct with `MDK::new` or `MDK::builder`. The tests in this file call
/// group, welcome and message operations on it (`create_group`,
/// `merge_pending_commit`, `process_welcome`, `accept_welcome`,
/// `create_message`, `process_message`); those are implemented in other modules.
#[derive(Debug)]
pub struct MDK<Storage>
where
Storage: MdkStorageProvider,
{
/// Ciphersuite advertised in capabilities; the builder sets `DEFAULT_CIPHERSUITE`.
pub ciphersuite: Ciphersuite,
/// Extension types advertised in capabilities; the builder copies `SUPPORTED_EXTENSIONS`.
pub extensions: Vec<ExtensionType>,
/// OpenMLS provider pairing `RustCrypto` with the storage backend.
pub provider: MdkProvider<Storage>,
/// Runtime configuration supplied through the builder.
pub config: MdkConfig,
/// Epoch snapshot manager; created with `config.epoch_snapshot_retention`.
epoch_snapshots: Arc<EpochSnapshotManager>,
/// Optional callback registered through the builder (`with_callback`).
callback: Option<Arc<dyn MdkCallback>>,
}
/// OpenMLS provider used by `MDK`.
///
/// One `RustCrypto` value backs both the crypto and the randomness providers;
/// storage is the caller-supplied backend.
#[derive(Debug)]
pub struct MdkProvider<Storage>
where
Storage: MdkStorageProvider,
{
/// Crypto and RNG implementation.
crypto: RustCrypto,
/// Caller-supplied storage backend.
storage: Storage,
}
impl<Storage> OpenMlsProvider for MdkProvider<Storage>
where
    Storage: MdkStorageProvider,
{
    // `RustCrypto` implements both the crypto and RNG traits, so one field
    // serves both associated types.
    type CryptoProvider = RustCrypto;
    type RandProvider = RustCrypto;
    type StorageProvider = Storage;

    fn crypto(&self) -> &Self::CryptoProvider {
        let Self { crypto, .. } = self;
        crypto
    }

    /// Returns the same `RustCrypto` instance as `crypto()`.
    fn rand(&self) -> &Self::RandProvider {
        let Self { crypto, .. } = self;
        crypto
    }

    fn storage(&self) -> &Self::StorageProvider {
        let Self { storage, .. } = self;
        storage
    }
}
impl<Storage> MDK<Storage>
where
    Storage: MdkStorageProvider,
{
    /// Returns an [`MdkBuilder`] for configuring an instance backed by `storage`.
    pub fn builder(storage: Storage) -> MdkBuilder<Storage> {
        MdkBuilder::new(storage)
    }

    /// Creates an instance backed by `storage` with the default configuration
    /// and no callback. Equivalent to `MDK::builder(storage).build()`.
    pub fn new(storage: Storage) -> Self {
        MdkBuilder::new(storage).build()
    }

    /// Capabilities advertised by this client: its ciphersuite, extensions and
    /// `SUPPORTED_PROPOSALS`, with GREASE values added via `with_grease`.
    ///
    /// `with_grease` receives this instance's `RustCrypto`, presumably as the
    /// randomness source for the GREASE values.
    #[inline]
    pub(crate) fn capabilities(&self) -> Capabilities {
        Capabilities::new(
            None, // protocol versions
            Some(&[self.ciphersuite]),
            Some(&self.extensions),
            Some(&SUPPORTED_PROPOSALS),
            None, // credential types
        )
        .with_grease(&self.provider.crypto)
    }

    /// Nostr tag representation of [`Self::ciphersuite`].
    pub(crate) fn ciphersuite_value(&self) -> String {
        let suite = &self.ciphersuite;
        suite.to_nostr_tag()
    }

    /// Nostr tag representation of each entry in [`Self::extensions`], in order.
    pub(crate) fn extensions_value(&self) -> Vec<String> {
        self.extensions
            .iter()
            .map(|extension| extension.to_nostr_tag())
            .collect()
    }

    /// Nostr tag representation of each entry in `TAG_PROPOSALS`, in order.
    pub(crate) fn proposals_value(&self) -> Vec<String> {
        TAG_PROPOSALS
            .iter()
            .map(|proposal| proposal.to_nostr_tag())
            .collect()
    }

    /// Borrows the storage backend held inside the provider.
    pub(crate) fn storage(&self) -> &Storage {
        let provider = &self.provider;
        &provider.storage
    }
}
/// Unit tests for `MDK` construction, GREASE capabilities, and sender-ratchet
/// configuration, all using in-memory storage.
#[cfg(test)]
pub mod tests {
use mdk_memory_storage::MdkMemoryStorage;
use super::*;
/// Builds an `MDK` over fresh in-memory storage with the default config.
pub fn create_test_mdk() -> MDK<MdkMemoryStorage> {
MDK::new(MdkMemoryStorage::default())
}
/// Builds an `MDK` over fresh in-memory storage with the given config.
pub fn create_test_mdk_with_config(config: MdkConfig) -> MDK<MdkMemoryStorage> {
MDK::builder(MdkMemoryStorage::default())
.with_config(config)
.build()
}
// Checks that `capabilities()` adds GREASE entries to every capability list
// while still advertising the real values.
mod grease_tests {
use openmls_traits::types::VerifiableCiphersuite;
use super::*;
#[test]
fn test_capabilities_include_grease_ciphersuites() {
let mdk = create_test_mdk();
let caps = mdk.capabilities();
let has_grease_ciphersuite = caps.ciphersuites().iter().any(|cs| cs.is_grease());
assert!(
has_grease_ciphersuite,
"Capabilities should include at least one GREASE ciphersuite"
);
}
#[test]
fn test_capabilities_include_grease_extensions() {
let mdk = create_test_mdk();
let caps = mdk.capabilities();
let has_grease_extension = caps.extensions().iter().any(|ext| ext.is_grease());
assert!(
has_grease_extension,
"Capabilities should include at least one GREASE extension"
);
}
#[test]
fn test_capabilities_include_grease_proposals() {
let mdk = create_test_mdk();
let caps = mdk.capabilities();
let has_grease_proposal = caps.proposals().iter().any(|prop| prop.is_grease());
assert!(
has_grease_proposal,
"Capabilities should include at least one GREASE proposal type"
);
}
#[test]
fn test_capabilities_include_grease_credentials() {
let mdk = create_test_mdk();
let caps = mdk.capabilities();
let has_grease_credential = caps.credentials().iter().any(|cred| cred.is_grease());
assert!(
has_grease_credential,
"Capabilities should include at least one GREASE credential type"
);
}
// GREASE must be additive: the real ciphersuite and LastResort extension stay present.
#[test]
fn test_capabilities_still_include_real_values() {
let mdk = create_test_mdk();
let caps = mdk.capabilities();
let expected_cs: VerifiableCiphersuite = DEFAULT_CIPHERSUITE.into();
let has_real_ciphersuite = caps.ciphersuites().contains(&expected_cs);
assert!(
has_real_ciphersuite,
"Capabilities should still include the real ciphersuite"
);
let has_last_resort = caps.extensions().contains(&ExtensionType::LastResort);
assert!(
has_last_resort,
"Capabilities should still include LastResort extension"
);
}
// NOTE(review): despite its name, this test only asserts that both instances
// carry GREASE ciphersuites; it never compares `grease_cs1` with `grease_cs2`.
// If GREASE values are chosen at random they could legitimately collide, so
// an inequality assertion here might be flaky — consider renaming instead.
#[test]
fn test_different_mdk_instances_get_different_grease_values() {
let mdk1 = create_test_mdk();
let mdk2 = create_test_mdk();
let caps1 = mdk1.capabilities();
let caps2 = mdk2.capabilities();
let grease_cs1: Vec<_> = caps1
.ciphersuites()
.iter()
.filter(|cs| cs.is_grease())
.collect();
let grease_cs2: Vec<_> = caps2
.ciphersuites()
.iter()
.filter(|cs| cs.is_grease())
.collect();
assert!(
!grease_cs1.is_empty(),
"MDK1 should have GREASE ciphersuites"
);
assert!(
!grease_cs2.is_empty(),
"MDK2 should have GREASE ciphersuites"
);
}
}
// End-to-end checks of the sender-ratchet settings (`out_of_order_tolerance`,
// `maximum_forward_distance`): Alice creates a group with Bob, sends
// application messages, and Bob receives them in various orders.
mod sender_ratchet_tests {
use nostr::Keys;
use super::*;
use crate::messages::MessageProcessingResult;
use crate::test_util::{
create_key_package_event, create_nostr_group_config_data, create_test_rumor,
};
// A custom config is stored verbatim on both MDKs, and a full group setup
// (create, merge commit, process and accept welcome) still succeeds with it.
#[test]
fn test_custom_config_is_applied() {
let config = MdkConfig {
out_of_order_tolerance: 50,
maximum_forward_distance: 500,
max_event_age_secs: 86400,
max_future_skew_secs: 120,
max_past_epochs: 5,
epoch_snapshot_retention: 5,
snapshot_ttl_seconds: 604800,
};
let alice_keys = Keys::generate();
let bob_keys = Keys::generate();
let alice_mdk = create_test_mdk_with_config(config.clone());
let bob_mdk = create_test_mdk_with_config(config.clone());
assert_eq!(alice_mdk.config.out_of_order_tolerance, 50);
assert_eq!(alice_mdk.config.maximum_forward_distance, 500);
assert_eq!(bob_mdk.config.out_of_order_tolerance, 50);
assert_eq!(bob_mdk.config.maximum_forward_distance, 500);
let admins = vec![alice_keys.public_key(), bob_keys.public_key()];
let bob_key_package = create_key_package_event(&bob_mdk, &bob_keys);
let create_result = alice_mdk
.create_group(
&alice_keys.public_key(),
vec![bob_key_package],
create_nostr_group_config_data(admins),
)
.expect("Alice should create group");
let group_id = create_result.group.mls_group_id.clone();
alice_mdk
.merge_pending_commit(&group_id)
.expect("Alice should merge commit");
let bob_welcome_rumor = &create_result.welcome_rumors[0];
let bob_welcome = bob_mdk
.process_welcome(&nostr::EventId::all_zeros(), bob_welcome_rumor)
.expect("Bob should process welcome");
bob_mdk
.accept_welcome(&bob_welcome)
.expect("Bob should accept welcome");
assert_eq!(group_id, bob_welcome.mls_group_id);
}
// With the default config (out_of_order_tolerance = 100), 50 messages
// delivered in interleaved order 49, 0, 48, 1, ..., 25, 24 must all decrypt
// to their original content.
#[test]
fn test_high_tolerance_allows_reordered_messages() {
let alice_keys = Keys::generate();
let bob_keys = Keys::generate();
let alice_mdk = create_test_mdk();
let bob_mdk = create_test_mdk();
let admins = vec![alice_keys.public_key(), bob_keys.public_key()];
let bob_key_package = create_key_package_event(&bob_mdk, &bob_keys);
let create_result = alice_mdk
.create_group(
&alice_keys.public_key(),
vec![bob_key_package],
create_nostr_group_config_data(admins),
)
.expect("Alice should create group");
let group_id = create_result.group.mls_group_id.clone();
alice_mdk
.merge_pending_commit(&group_id)
.expect("Alice should merge commit");
let bob_welcome_rumor = &create_result.welcome_rumors[0];
let bob_welcome = bob_mdk
.process_welcome(&nostr::EventId::all_zeros(), bob_welcome_rumor)
.expect("Bob should process welcome");
bob_mdk
.accept_welcome(&bob_welcome)
.expect("Bob should accept welcome");
let num_messages = 50;
let mut message_events = Vec::new();
for i in 0..num_messages {
let rumor = create_test_rumor(&alice_keys, &format!("Message {}", i));
let msg_event = alice_mdk
.create_message(&group_id, rumor, None)
.expect("Alice should send message");
message_events.push(msg_event);
}
// Alternate newest-remaining and oldest-remaining indices: 49, 0, 48, 1, ...
let mut receive_order: Vec<usize> = Vec::new();
for i in 0..num_messages / 2 {
receive_order.push(num_messages - 1 - i); receive_order.push(i); }
for &idx in &receive_order {
let msg_event = &message_events[idx];
let result = bob_mdk
.process_message(msg_event)
.unwrap_or_else(|e| panic!("Bob should decrypt message {idx}: {e}"));
match result {
MessageProcessingResult::ApplicationMessage(msg) => {
assert_eq!(msg.content, format!("Message {}", idx));
}
other => panic!("Expected ApplicationMessage for message {idx}, got {other:?}"),
}
}
}
// With out_of_order_tolerance = 5, Bob first receives message 19; message 0
// must then be rejected, while messages 15..=18 still decrypt.
#[test]
fn test_low_tolerance_rejects_distant_messages() {
let config = MdkConfig {
out_of_order_tolerance: 5,
..Default::default()
};
let alice_keys = Keys::generate();
let bob_keys = Keys::generate();
let alice_mdk = create_test_mdk_with_config(config.clone());
let bob_mdk = create_test_mdk_with_config(config);
let admins = vec![alice_keys.public_key(), bob_keys.public_key()];
let bob_key_package = create_key_package_event(&bob_mdk, &bob_keys);
let create_result = alice_mdk
.create_group(
&alice_keys.public_key(),
vec![bob_key_package],
create_nostr_group_config_data(admins),
)
.expect("Alice should create group");
let group_id = create_result.group.mls_group_id.clone();
alice_mdk
.merge_pending_commit(&group_id)
.expect("Alice should merge commit");
let bob_welcome_rumor = &create_result.welcome_rumors[0];
let bob_welcome = bob_mdk
.process_welcome(&nostr::EventId::all_zeros(), bob_welcome_rumor)
.expect("Bob should process welcome");
bob_mdk
.accept_welcome(&bob_welcome)
.expect("Bob should accept welcome");
let num_messages = 20;
let mut message_events = Vec::new();
for i in 0..num_messages {
let rumor = create_test_rumor(&alice_keys, &format!("Message {}", i));
let msg_event = alice_mdk
.create_message(&group_id, rumor, None)
.expect("Alice should send message");
message_events.push(msg_event);
}
// Receive the newest message (index 19) first.
let last_msg = &message_events[num_messages - 1];
let result = bob_mdk
.process_message(last_msg)
.expect("Bob should decrypt the latest message");
match result {
MessageProcessingResult::ApplicationMessage(msg) => {
assert_eq!(msg.content, format!("Message {}", num_messages - 1));
}
_ => panic!("Expected ApplicationMessage"),
}
// Message 0 is 19 generations behind: it must not decrypt.
let first_msg = &message_events[0];
let result = bob_mdk.process_message(first_msg);
match result {
// Either an `Unprocessable` result or an error counts as rejection.
Ok(MessageProcessingResult::Unprocessable { .. }) => {
}
Ok(MessageProcessingResult::ApplicationMessage(_)) => {
panic!(
"Message 0 should NOT decrypt after receiving message 19 with tolerance 5"
);
}
Err(_) => {
}
other => {
panic!("Unexpected result: {:?}", other);
}
}
// take(19).skip(15) yields indices 15, 16, 17, 18 — all within the
// tolerance window behind message 19, so each must decrypt.
for (i, msg_event) in message_events
.iter()
.enumerate()
.take(num_messages - 1)
.skip(num_messages - 5)
{
let result = bob_mdk.process_message(msg_event).unwrap_or_else(|e| {
panic!("Message {i} should decrypt (within tolerance): {e}")
});
match result {
MessageProcessingResult::ApplicationMessage(msg) => {
assert_eq!(msg.content, format!("Message {}", i));
}
other => panic!("Expected ApplicationMessage for message {i}, got {other:?}"),
}
}
}
}
}