use mdk_storage_traits::groups::types as group_types;
use mdk_storage_traits::messages::types as message_types;
use mdk_storage_traits::{GroupId, MdkStorageProvider};
use nostr::Event;
use openmls::prelude::{MlsGroup, Sender, StagedCommit};
use crate::MDK;
use crate::error::Error;
use super::Result;
impl<Storage> MDK<Storage>
where
Storage: MdkStorageProvider,
{
pub(super) fn process_commit(
&self,
mls_group: &mut MlsGroup,
event: &Event,
staged_commit: StagedCommit,
commit_sender: &Sender,
) -> Result<()> {
self.validate_commit_authorization(mls_group, &staged_commit, commit_sender)?;
self.validate_commit_identities(mls_group, &staged_commit, commit_sender)?;
let group_id: GroupId = mls_group.group_id().into();
let current_epoch = mls_group.epoch().as_u64();
let content_hash = super::content_hash(&event.content);
if let Err(_e) = self.epoch_snapshots.create_snapshot(
self.storage(),
&group_id,
current_epoch,
&event.id,
event.created_at.as_secs(),
&content_hash,
) {
tracing::warn!(
target: "mdk_core::messages::process_commit",
"Failed to create snapshot for epoch {}",
current_epoch
);
return Err(Error::SnapshotCreationFailed(
"snapshot creation failed".to_string(),
));
}
mls_group
.merge_staged_commit(&self.provider, staged_commit)
.map_err(|_e| Error::Message("Failed to merge staged commit".to_string()))?;
if mls_group.own_leaf().is_none() {
return self.handle_local_member_eviction(&group_id, event);
}
self.exporter_secret(&group_id)?;
#[cfg(feature = "mip04")]
{
let mip04_secret = self.mip04_exporter_secret(&group_id)?;
self.storage()
.save_group_mip04_exporter_secret(mip04_secret)
.map_err(|_| {
Error::Group("Failed to persist MIP-04 exporter secret".to_string())
})?;
}
let min_epoch_to_keep = mls_group
.epoch()
.as_u64()
.saturating_sub(self.config.max_past_epochs as u64);
self.storage()
.prune_group_exporter_secrets_before_epoch(&group_id, min_epoch_to_keep)
.map_err(|_| Error::Group("Failed to prune exporter secrets".to_string()))?;
self.sync_group_metadata_from_mls(&group_id)?;
let processed_message = super::create_processed_message_record(
event.id,
None,
Some(mls_group.epoch().as_u64()),
Some(group_id.clone()),
message_types::ProcessedMessageState::ProcessedCommit,
None,
);
self.save_processed_message_record(processed_message)?;
Ok(())
}
pub(super) fn handle_local_member_eviction(
&self,
group_id: &GroupId,
event: &Event,
) -> Result<()> {
tracing::info!(
target: "mdk_core::messages::process_commit",
"Local member was removed from group, setting group state to Inactive"
);
let group_epoch = match self.get_group(group_id)? {
Some(mut group) => {
let epoch = group.epoch;
group.state = group_types::GroupState::Inactive;
self.save_group_record(group)?;
Some(epoch)
}
None => {
tracing::warn!(
target: "mdk_core::messages::process_commit",
"Group not found in storage while handling eviction"
);
None
}
};
let processed_message = super::create_processed_message_record(
event.id,
None,
group_epoch,
Some(group_id.clone()),
message_types::ProcessedMessageState::Processed,
None,
);
self.save_processed_message_record(processed_message)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use std::fmt;
use mdk_storage_traits::GroupId;
#[cfg(feature = "mip04")]
use mdk_storage_traits::groups::GroupStorage;
use mdk_storage_traits::groups::types as group_types;
use nostr::{EventBuilder, EventId, Keys, Kind};
use openmls::prelude::{Extension, UnknownExtension};
use tls_codec::Serialize as TlsSerialize;
use crate::extension::NostrGroupDataExtension;
use crate::messages::MessageProcessingResult;
use crate::test_util::*;
use crate::tests::create_test_mdk;
/// Alice (the only admin) adds Charlie to a group she already shares with Bob.
///
/// Alice's side is checked first: her epoch advances and she sees three members after
/// merging her own commit. Then Bob's side: he processes the incoming add-member commit,
/// gets a `Commit` result, and also sees three members. This exercises the
/// receiving-member commit path.
#[test]
fn test_member_addition_commit() {
    let alice_keys = Keys::generate();
    let bob_keys = Keys::generate();
    let charlie_keys = Keys::generate();
    let alice_mdk = create_test_mdk();
    let bob_mdk = create_test_mdk();
    let charlie_mdk = create_test_mdk();
    let bob_key_package = create_key_package_event(&bob_mdk, &bob_keys);
    let charlie_key_package = create_key_package_event(&charlie_mdk, &charlie_keys);
    let admin_pubkeys = vec![alice_keys.public_key()];
    let config = create_nostr_group_config_data(admin_pubkeys);
    let create_result = alice_mdk
        .create_group(&alice_keys.public_key(), vec![bob_key_package], config)
        .expect("Alice should create group");
    let group_id = create_result.group.mls_group_id.clone();
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Alice should merge commit");
    let bob_welcome_rumor = &create_result.welcome_rumors[0];
    let bob_welcome = bob_mdk
        .process_welcome(&nostr::EventId::all_zeros(), bob_welcome_rumor)
        .expect("Bob should process welcome");
    bob_mdk
        .accept_welcome(&bob_welcome)
        .expect("Bob should accept welcome");
    let initial_epoch = alice_mdk
        .get_group(&group_id)
        .expect("Should get group")
        .expect("Group should exist")
        .epoch;
    // `expect` (instead of `assert!(result.is_ok())`) reports the underlying error on failure.
    let alice_add_result = alice_mdk
        .add_members(&group_id, &[charlie_key_package])
        .expect("Alice should create pending commit");
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Alice should merge commit");
    let alice_epoch_after = alice_mdk
        .get_group(&group_id)
        .expect("Should get group")
        .expect("Group should exist")
        .epoch;
    assert!(
        alice_epoch_after > initial_epoch,
        "Alice's epoch should advance after commit"
    );
    let alice_members = alice_mdk
        .get_members(&group_id)
        .expect("Alice should get members");
    assert_eq!(
        alice_members.len(),
        3,
        "Alice should see 3 members after adding Charlie"
    );
    // Bob receives Alice's add-member commit as an ordinary group message. This is what
    // drives the receiving-member commit path.
    let bob_result = bob_mdk
        .process_message(&alice_add_result.evolution_event)
        .expect("Bob should process Alice's add-member commit");
    assert!(
        matches!(bob_result, MessageProcessingResult::Commit { .. }),
        "Bob's result should be a Commit"
    );
    let bob_members = bob_mdk
        .get_members(&group_id)
        .expect("Bob should get members");
    assert_eq!(
        bob_members.len(),
        3,
        "Bob should see 3 members after processing Alice's commit"
    );
}
/// With MIP-04 enabled, a member processing someone else's commit must end up with two
/// stored exporter secrets for the new epoch: the group-event (MIP-03) secret and the
/// encrypted-media (MIP-04) secret. The two secrets must differ.
#[cfg(feature = "mip04")]
#[test]
fn test_incoming_commit_persists_mip04_exporter_secret() {
    let alice = Keys::generate();
    let bob = Keys::generate();
    let charlie = Keys::generate();
    let alice_mdk = create_test_mdk();
    let bob_mdk = create_test_mdk();
    let charlie_mdk = create_test_mdk();

    let bob_kp = create_key_package_event(&bob_mdk, &bob);
    let charlie_kp = create_key_package_event(&charlie_mdk, &charlie);

    let group_config = create_nostr_group_config_data(vec![alice.public_key()]);
    let created = alice_mdk
        .create_group(&alice.public_key(), vec![bob_kp], group_config)
        .expect("Alice should create group");
    let group_id = created.group.mls_group_id.clone();
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Alice should merge create commit");

    let welcome = bob_mdk
        .process_welcome(&nostr::EventId::all_zeros(), &created.welcome_rumors[0])
        .expect("Bob should process welcome");
    bob_mdk
        .accept_welcome(&welcome)
        .expect("Bob should accept welcome");

    // Alice's add commit is not merged on her side; only Bob's processing is under test.
    let add = alice_mdk
        .add_members(&group_id, &[charlie_kp])
        .expect("Alice should create add-member commit");
    let outcome = bob_mdk
        .process_message(&add.evolution_event)
        .expect("Bob should process incoming commit");
    let is_commit_for_group = match &outcome {
        MessageProcessingResult::Commit { mls_group_id } => *mls_group_id == group_id,
        _ => false,
    };
    assert!(
        is_commit_for_group,
        "Expected commit processing result for the same group"
    );

    let epoch = bob_mdk
        .get_group(&group_id)
        .expect("Bob group lookup should succeed")
        .expect("Bob group should exist")
        .epoch;
    let storage = bob_mdk.storage();
    let group_event_secret = storage
        .get_group_exporter_secret(&group_id, epoch)
        .expect("Lookup of group-event exporter secret should succeed")
        .expect("group-event exporter secret should be persisted");
    let media_secret = storage
        .get_group_mip04_exporter_secret(&group_id, epoch)
        .expect("Lookup of encrypted-media exporter secret should succeed")
        .expect("encrypted-media exporter secret should be persisted");
    assert_ne!(
        group_event_secret.secret, media_secret.secret,
        "MIP-03 and MIP-04 exporter secrets must differ for the same epoch"
    );
}
/// Alice and Bob (both admins) each create an add-member commit against the same epoch.
///
/// Bob processes Alice's commit while his own commit is still pending, then processes his
/// own (now stale) commit event. The test requires one of three things: the stale commit
/// errors, is reported as `Unprocessable`, or Bob's stored epoch is past the baseline.
/// It also requires that Alice's epoch advances after she merges her own commit.
#[test]
fn test_concurrent_commit_race_conditions() {
    let alice_keys = Keys::generate();
    let bob_keys = Keys::generate();
    let alice_mdk = create_test_mdk();
    let bob_mdk = create_test_mdk();
    let admins = vec![alice_keys.public_key(), bob_keys.public_key()];
    let bob_key_package = create_key_package_event(&bob_mdk, &bob_keys);
    let create_result = alice_mdk
        .create_group(
            &alice_keys.public_key(),
            vec![bob_key_package],
            create_nostr_group_config_data(admins),
        )
        .expect("Alice should be able to create group");
    let group_id = create_result.group.mls_group_id.clone();
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Failed to merge Alice's create commit");
    let bob_welcome_rumor = &create_result.welcome_rumors[0];
    let bob_welcome = bob_mdk
        .process_welcome(&nostr::EventId::all_zeros(), bob_welcome_rumor)
        .expect("Bob should be able to process welcome");
    bob_mdk
        .accept_welcome(&bob_welcome)
        .expect("Bob should be able to accept welcome");
    assert_eq!(
        group_id, bob_welcome.mls_group_id,
        "Alice and Bob should have the same group ID"
    );
    // Baseline: both members must start from the same epoch before the competing commits.
    let alice_epoch = alice_mdk
        .get_group(&group_id)
        .expect("Failed to get Alice's group")
        .expect("Alice's group should exist")
        .epoch;
    let bob_epoch = bob_mdk
        .get_group(&bob_welcome.mls_group_id)
        .expect("Failed to get Bob's group")
        .expect("Bob's group should exist")
        .epoch;
    assert_eq!(
        alice_epoch, bob_epoch,
        "Alice and Bob should be in same epoch"
    );
    // Charlie's and Dave's key packages are created through Alice's and Bob's MDK instances.
    // Charlie and Dave never join in this test; only the resulting events are used.
    let charlie_keys = Keys::generate();
    let dave_keys = Keys::generate();
    let charlie_key_package = create_key_package_event(&alice_mdk, &charlie_keys);
    let dave_key_package = create_key_package_event(&bob_mdk, &dave_keys);
    // Two competing commits built on the same epoch; neither is merged yet.
    let alice_commit_result = alice_mdk
        .add_members(&group_id, std::slice::from_ref(&charlie_key_package))
        .expect("Alice should be able to create commit");
    let bob_commit_result = bob_mdk
        .add_members(&group_id, std::slice::from_ref(&dave_key_package))
        .expect("Bob should be able to create commit");
    assert_eq!(
        alice_commit_result.evolution_event.kind,
        Kind::MlsGroupMessage
    );
    assert_eq!(
        bob_commit_result.evolution_event.kind,
        Kind::MlsGroupMessage
    );
    // Bob applies Alice's commit while his own commit is still pending locally.
    let _bob_process_result = bob_mdk
        .process_message(&alice_commit_result.evolution_event)
        .expect("Bob should be able to process Alice's commit");
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Alice should merge her commit");
    // Bob now receives his own commit event, which was built on the epoch he has left.
    let bob_process_own = bob_mdk.process_message(&bob_commit_result.evolution_event);
    // NOTE(review): Bob's epoch presumably already advanced when he processed Alice's commit
    // above, so the last disjunct likely holds however the stale commit is handled. That
    // makes this assertion weak; consider pinning the exact outcome once it is known.
    let is_handled = bob_process_own.is_err()
        || matches!(
            &bob_process_own,
            Ok(MessageProcessingResult::Unprocessable { .. })
        )
        || bob_mdk.get_group(&group_id).unwrap().unwrap().epoch > bob_epoch;
    assert!(
        is_handled,
        "Bob's stale commit should be rejected, unprocessable, or epoch should have advanced, got: {:?}",
        bob_process_own
    );
    let final_alice_epoch = alice_mdk
        .get_group(&group_id)
        .expect("Failed to get Alice's group")
        .expect("Alice's group should exist")
        .epoch;
    assert!(
        final_alice_epoch > alice_epoch,
        "Epoch should have advanced after Alice's commit"
    );
}
/// A member who has been demoted from admin must not be able to add members.
///
/// Setup: Alice and Bob start as admins and Charlie is a regular member. Alice demotes
/// Bob. Bob's local state does not include the demotion, and he creates an add-member
/// commit. Charlie processes the demotion first. He must then reject Bob's commit, either
/// as `Unprocessable` or with `Error::CommitFromNonAdmin`.
#[test]
fn test_add_member_commit_from_non_admin_is_rejected() {
    let alice_keys = Keys::generate();
    let bob_keys = Keys::generate();
    let charlie_keys = Keys::generate();
    let dave_keys = Keys::generate();
    let alice_mdk = create_test_mdk();
    let bob_mdk = create_test_mdk();
    let charlie_mdk = create_test_mdk();
    let admins = vec![alice_keys.public_key(), bob_keys.public_key()];
    let bob_key_package = create_key_package_event(&bob_mdk, &bob_keys);
    let charlie_key_package = create_key_package_event(&charlie_mdk, &charlie_keys);
    let create_result = alice_mdk
        .create_group(
            &alice_keys.public_key(),
            vec![bob_key_package, charlie_key_package],
            create_nostr_group_config_data(admins.clone()),
        )
        .expect("Failed to create group");
    let group_id = create_result.group.mls_group_id.clone();
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Failed to merge pending commit");
    let bob_welcome_rumor = &create_result.welcome_rumors[0];
    let bob_welcome = bob_mdk
        .process_welcome(&nostr::EventId::all_zeros(), bob_welcome_rumor)
        .expect("Bob should process welcome");
    bob_mdk
        .accept_welcome(&bob_welcome)
        .expect("Bob should accept welcome");
    let charlie_welcome_rumor = &create_result.welcome_rumors[1];
    let charlie_welcome = charlie_mdk
        .process_welcome(&nostr::EventId::all_zeros(), charlie_welcome_rumor)
        .expect("Charlie should process welcome");
    charlie_mdk
        .accept_welcome(&charlie_welcome)
        .expect("Charlie should accept welcome");
    let dave_key_package = create_key_package_event(&bob_mdk, &dave_keys);
    // Alice demotes Bob by setting herself as the only admin.
    let update =
        crate::groups::NostrGroupDataUpdate::new().admins(vec![alice_keys.public_key()]);
    let alice_demote_result = alice_mdk
        .update_group_data(&group_id, update)
        .expect("Alice should demote Bob");
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Alice should merge demote commit");
    // Bob has not seen the demotion, so his local state still lists him as an admin.
    let bob_add_result = bob_mdk
        .add_members(&group_id, &[dave_key_package])
        .expect("Bob (admin) can create add commit");
    let mut bob_add_commit_event = bob_add_result.evolution_event;
    // Make sure Bob's commit is timestamped strictly after Alice's demotion. If it is not,
    // re-sign it with a later `created_at`. NOTE(review): presumably this stops Charlie from
    // preferring Bob's commit as the earlier of two competing commits — confirm.
    if bob_add_commit_event.created_at <= alice_demote_result.evolution_event.created_at {
        let new_ts = alice_demote_result.evolution_event.created_at + 1;
        let builder = EventBuilder::new(
            bob_add_commit_event.kind,
            bob_add_commit_event.content.clone(),
        )
        .tags(bob_add_commit_event.tags.iter().cloned())
        .custom_created_at(new_ts);
        bob_add_commit_event = builder
            .sign_with_keys(&bob_keys)
            .expect("Failed to re-sign Bob's event");
    }
    charlie_mdk
        .process_message(&alice_demote_result.evolution_event)
        .expect("Charlie should process Alice's demote commit");
    let result = charlie_mdk.process_message(&bob_add_commit_event);
    match result {
        // Either rejection path is acceptable.
        Ok(MessageProcessingResult::Unprocessable { .. })
        | Err(crate::Error::CommitFromNonAdmin) => {}
        Ok(MessageProcessingResult::Commit { .. }) => {
            panic!("Add-member commit from demoted admin should have been rejected");
        }
        // Print the actual outcome so an unexpected failure can be diagnosed.
        other => {
            panic!(
                "Unexpected result for add-member commit from demoted admin: {:?}",
                other
            );
        }
    }
}
/// An add-member commit from the group's only admin is accepted by a non-admin member,
/// and that member's roster grows to three.
#[test]
fn test_admin_add_member_commit_is_processed_successfully() {
    let alice = Keys::generate();
    let bob = Keys::generate();
    let charlie = Keys::generate();
    let alice_mdk = create_test_mdk();
    let bob_mdk = create_test_mdk();
    let charlie_mdk = create_test_mdk();

    let admins = vec![alice.public_key()];
    let bob_kp = create_key_package_event(&bob_mdk, &bob);
    let created = alice_mdk
        .create_group(
            &alice.public_key(),
            vec![bob_kp],
            create_nostr_group_config_data(admins),
        )
        .expect("Failed to create group");
    let group_id = created.group.mls_group_id.clone();
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Failed to merge pending commit");

    let welcome = bob_mdk
        .process_welcome(&nostr::EventId::all_zeros(), &created.welcome_rumors[0])
        .expect("Bob should process welcome");
    bob_mdk
        .accept_welcome(&welcome)
        .expect("Bob should accept welcome");

    // Precondition: Bob is a regular member, not an admin.
    let bob_is_admin = bob_mdk
        .get_group(&group_id)
        .expect("Failed to get group")
        .expect("Group should exist")
        .admin_pubkeys
        .contains(&bob.public_key());
    assert!(!bob_is_admin, "Bob should NOT be an admin");

    let charlie_kp = create_key_package_event(&charlie_mdk, &charlie);
    let add = alice_mdk
        .add_members(&group_id, &[charlie_kp])
        .expect("Alice (admin) can create add commit");
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Alice should merge add commit");

    let result = bob_mdk.process_message(&add.evolution_event);
    assert!(
        result.is_ok(),
        "Admin add-member commit should be processed successfully, got error: {:?}",
        result.err()
    );
    let processed = result.unwrap();
    assert!(
        matches!(processed, MessageProcessingResult::Commit { .. }),
        "Result should be a Commit"
    );

    let member_count = bob_mdk
        .get_members(&group_id)
        .expect("Failed to get members")
        .len();
    assert_eq!(member_count, 3, "Group should have 3 members");
}
/// A group-data update (new name and description) committed by the only admin is
/// accepted by a non-admin member, whose stored group name then shows the new value.
#[test]
fn test_admin_extension_update_commit_is_processed_successfully() {
    let alice = Keys::generate();
    let bob = Keys::generate();
    let alice_mdk = create_test_mdk();
    let bob_mdk = create_test_mdk();

    let bob_kp = create_key_package_event(&bob_mdk, &bob);
    let created = alice_mdk
        .create_group(
            &alice.public_key(),
            vec![bob_kp],
            create_nostr_group_config_data(vec![alice.public_key()]),
        )
        .expect("Failed to create group");
    let group_id = created.group.mls_group_id.clone();
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Failed to merge pending commit");

    let welcome = bob_mdk
        .process_welcome(&nostr::EventId::all_zeros(), &created.welcome_rumors[0])
        .expect("Bob should process welcome");
    bob_mdk
        .accept_welcome(&welcome)
        .expect("Bob should accept welcome");

    let new_data = crate::groups::NostrGroupDataUpdate::new()
        .name("Updated Group Name".to_string())
        .description("Updated description".to_string());
    let update = alice_mdk
        .update_group_data(&group_id, new_data)
        .expect("Alice (admin) can update group data");
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Alice should merge update commit");

    let result = bob_mdk.process_message(&update.evolution_event);
    assert!(
        result.is_ok(),
        "Admin extension update commit should be processed successfully, got error: {:?}",
        result.err()
    );
    let outcome = result.unwrap();
    assert!(
        matches!(outcome, MessageProcessingResult::Commit { .. }),
        "Result should be a Commit"
    );

    let bob_view = bob_mdk
        .get_group(&group_id)
        .expect("Failed to get group")
        .expect("Group should exist");
    assert_eq!(
        bob_view.name, "Updated Group Name",
        "Group name should be updated"
    );
}
/// Receivers must reject a commit that clears every admin from the group-data extension.
/// Their stored epoch and admin set must stay unchanged.
///
/// MDK's high-level update API is bypassed. Alice builds the group-context-extensions
/// commit directly through OpenMLS and wraps it in a group message event with
/// `build_message_event`.
#[test]
fn test_admin_extension_update_cannot_deplete_all_admins() {
    let alice_keys = Keys::generate();
    let bob_keys = Keys::generate();
    let alice_mdk = create_test_mdk();
    let bob_mdk = create_test_mdk();
    let admins = vec![alice_keys.public_key()];
    let bob_key_package = create_key_package_event(&bob_mdk, &bob_keys);
    let create_result = alice_mdk
        .create_group(
            &alice_keys.public_key(),
            vec![bob_key_package],
            create_nostr_group_config_data(admins.clone()),
        )
        .expect("Failed to create group");
    let group_id = create_result.group.mls_group_id.clone();
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Failed to merge pending commit");
    let bob_welcome_rumor = &create_result.welcome_rumors[0];
    let bob_welcome = bob_mdk
        .process_welcome(&EventId::all_zeros(), bob_welcome_rumor)
        .expect("Bob should process welcome");
    bob_mdk
        .accept_welcome(&bob_welcome)
        .expect("Bob should accept welcome");
    // Record Bob's epoch so we can show the rejected commit did not advance it.
    let bob_epoch_before = bob_mdk
        .get_group(&group_id)
        .expect("Failed to get Bob's group before malicious commit")
        .expect("Bob's group should exist before malicious commit")
        .epoch;
    // Start from Alice's current group data and empty its admin list.
    let mut alice_mls_group = alice_mdk
        .load_mls_group(&group_id)
        .expect("Failed to load Alice MLS group")
        .expect("Alice MLS group should exist");
    let mut group_data = NostrGroupDataExtension::from_group(&alice_mls_group)
        .expect("Alice MLS group should have group data");
    group_data.admins.clear();
    // Re-encode the modified group data as a raw extension and swap it into a copy of the
    // group's current extension list.
    let serialized_group_data = group_data
        .as_raw()
        .tls_serialize_detached()
        .expect("Failed to serialize group-data extension");
    let extension = Extension::Unknown(
        group_data.extension_type(),
        UnknownExtension(serialized_group_data),
    );
    let mut extensions = alice_mls_group.extensions().clone();
    extensions
        .add_or_replace(extension)
        .expect("Failed to replace group-data extension");
    // Build the commit with OpenMLS directly. This presumably skips any sender-side checks
    // the MDK update API would apply — confirm.
    let signature_keypair = alice_mdk
        .load_mls_signer(&alice_mls_group)
        .expect("Failed to load Alice signer");
    let (message_out, _, _) = alice_mls_group
        .update_group_context_extensions(&alice_mdk.provider, extensions, &signature_keypair)
        .expect("OpenMLS should create the extension update commit");
    let malicious_commit_event = alice_mdk
        .build_message_event(
            &group_id,
            message_out
                .tls_serialize_detached()
                .expect("Failed to serialize malicious commit"),
            None,
        )
        .expect("Failed to build malicious commit event");
    let result = bob_mdk.process_message(&malicious_commit_event);
    assert!(
        matches!(result, Ok(MessageProcessingResult::Unprocessable { .. })),
        "Admin-depleting commit should be rejected as unprocessable"
    );
    // Bob's stored state must be exactly as it was before the malicious commit.
    let bob_group = bob_mdk
        .get_group(&group_id)
        .expect("Failed to get Bob's group")
        .expect("Bob's group should exist");
    assert_eq!(
        bob_group.epoch, bob_epoch_before,
        "Rejected commit should not advance Bob's stored epoch"
    );
    assert_eq!(
        bob_group.admin_pubkeys,
        admins.into_iter().collect(),
        "Rejected commit should not alter Bob's stored admin set"
    );
}
/// The only admin removes Charlie. Bob, a remaining non-admin member, processes the
/// removal commit, and his roster drops from three members to two without Charlie.
#[test]
fn test_admin_remove_member_commit_is_processed_successfully() {
    let alice = Keys::generate();
    let bob = Keys::generate();
    let charlie = Keys::generate();
    let alice_mdk = create_test_mdk();
    let bob_mdk = create_test_mdk();
    let charlie_mdk = create_test_mdk();

    let invitees = vec![
        create_key_package_event(&bob_mdk, &bob),
        create_key_package_event(&charlie_mdk, &charlie),
    ];
    let created = alice_mdk
        .create_group(
            &alice.public_key(),
            invitees,
            create_nostr_group_config_data(vec![alice.public_key()]),
        )
        .expect("Failed to create group");
    let group_id = created.group.mls_group_id.clone();
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Failed to merge pending commit");

    // Welcome rumors are used in invite order: index 0 for Bob, index 1 for Charlie.
    let bob_welcome = bob_mdk
        .process_welcome(&nostr::EventId::all_zeros(), &created.welcome_rumors[0])
        .expect("Bob should process welcome");
    bob_mdk
        .accept_welcome(&bob_welcome)
        .expect("Bob should accept welcome");
    let charlie_welcome = charlie_mdk
        .process_welcome(&nostr::EventId::all_zeros(), &created.welcome_rumors[1])
        .expect("Charlie should process welcome");
    charlie_mdk
        .accept_welcome(&charlie_welcome)
        .expect("Charlie should accept welcome");

    let initial_count = bob_mdk
        .get_members(&group_id)
        .expect("Failed to get members")
        .len();
    assert_eq!(initial_count, 3, "Group should have 3 members initially");

    let removal = alice_mdk
        .remove_members(&group_id, &[charlie.public_key()])
        .expect("Alice (admin) can remove members");
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Alice should merge remove commit");

    let result = bob_mdk.process_message(&removal.evolution_event);
    assert!(
        result.is_ok(),
        "Admin remove-member commit should be processed successfully, got error: {:?}",
        result.err()
    );
    let outcome = result.unwrap();
    assert!(
        matches!(outcome, MessageProcessingResult::Commit { .. }),
        "Result should be a Commit"
    );

    let remaining = bob_mdk
        .get_members(&group_id)
        .expect("Failed to get members");
    assert_eq!(
        remaining.len(),
        2,
        "Group should have 2 members after removal"
    );
    assert!(
        !remaining.contains(&charlie.public_key()),
        "Charlie should be removed"
    );
}
/// When Bob processes the commit that removes him, the result is still a `Commit`, and
/// his stored group moves from `Active` to `Inactive`.
#[test]
fn test_removed_member_processes_own_removal_commit() {
    let alice = Keys::generate();
    let bob = Keys::generate();
    let alice_mdk = create_test_mdk();
    let bob_mdk = create_test_mdk();

    let bob_kp = create_key_package_event(&bob_mdk, &bob);
    let created = alice_mdk
        .create_group(
            &alice.public_key(),
            vec![bob_kp],
            create_nostr_group_config_data(vec![alice.public_key()]),
        )
        .expect("Failed to create group");
    let group_id = created.group.mls_group_id.clone();
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Failed to merge pending commit");

    let welcome = bob_mdk
        .process_welcome(&nostr::EventId::all_zeros(), &created.welcome_rumors[0])
        .expect("Bob should process welcome");
    bob_mdk
        .accept_welcome(&welcome)
        .expect("Bob should accept welcome");

    // Loads Bob's stored group record; used before and after the removal.
    let load_bob_group = || {
        bob_mdk
            .get_group(&group_id)
            .expect("Failed to get group")
            .expect("Group should exist")
    };

    assert_eq!(
        load_bob_group().state,
        group_types::GroupState::Active,
        "Bob's group should be Active before removal"
    );

    let removal = alice_mdk
        .remove_members(&group_id, &[bob.public_key()])
        .expect("Alice (admin) can remove members");
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Alice should merge remove commit");

    let result = bob_mdk.process_message(&removal.evolution_event);
    assert!(
        result.is_ok(),
        "Removed member should process their removal commit successfully, got error: {:?}",
        result.err()
    );
    let outcome = result.unwrap();
    assert!(
        matches!(outcome, MessageProcessingResult::Commit { .. }),
        "Result should be a Commit"
    );

    assert_eq!(
        load_bob_group().state,
        group_types::GroupState::Inactive,
        "Bob's group should be Inactive after being removed"
    );
}
/// Test implementation of `crate::callback::MdkCallback` that records every rollback
/// notification it receives, so assertions can inspect them afterwards.
struct TestCallback {
    /// Rollback notifications, appended in the order `on_rollback` is called.
    rollbacks: std::sync::Mutex<Vec<crate::RollbackInfo>>,
}
impl fmt::Debug for TestCallback {
    /// Prints only the number of recorded rollbacks. A poisoned lock is shown as zero.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rollback_count = match self.rollbacks.lock() {
            Ok(recorded) => recorded.len(),
            Err(_) => 0,
        };
        f.debug_struct("TestCallback")
            .field("rollback_count", &rollback_count)
            .finish()
    }
}
impl TestCallback {
    /// Creates a callback with no recorded rollbacks.
    fn new() -> Self {
        let rollbacks = std::sync::Mutex::new(Vec::new());
        Self { rollbacks }
    }

    /// Returns the number of rollback notifications received so far.
    ///
    /// # Panics
    /// Panics if the lock is poisoned.
    fn rollback_count(&self) -> usize {
        let recorded = self.rollbacks.lock().unwrap();
        recorded.len()
    }

    /// Returns `(group_id, target_epoch, new_head_event)` for each recorded rollback, in
    /// the order received.
    ///
    /// # Panics
    /// Panics if the lock is poisoned.
    fn get_rollbacks(&self) -> Vec<(GroupId, u64, EventId)> {
        let recorded = self.rollbacks.lock().unwrap();
        let mut summary = Vec::with_capacity(recorded.len());
        for info in recorded.iter() {
            summary.push((info.group_id.clone(), info.target_epoch, info.new_head_event));
        }
        summary
    }

    /// Returns a copy of every recorded rollback notification.
    ///
    /// # Panics
    /// Panics if the lock is poisoned.
    // Not called in the visible part of this module; `allow(dead_code)` stops the unused
    // helper from triggering a warning.
    #[allow(dead_code)]
    fn get_rollback_infos(&self) -> Vec<crate::RollbackInfo> {
        let recorded = self.rollbacks.lock().unwrap();
        recorded.clone()
    }
}
impl crate::callback::MdkCallback for TestCallback {
    /// Stores a clone of `info` so tests can inspect it later.
    fn on_rollback(&self, info: &crate::RollbackInfo) {
        let mut recorded = self.rollbacks.lock().unwrap();
        recorded.push(info.clone());
    }
}
/// Orders two competing commit events by precedence and returns `(better, worse)`.
///
/// The event with the earlier `created_at` comes first. If the timestamps are equal, the
/// event whose hex-encoded id sorts lower comes first. If the ids are also equal,
/// `event_b` is returned first.
fn order_events_by_mip03<'a>(
    event_a: &'a nostr::Event,
    event_b: &'a nostr::Event,
) -> (&'a nostr::Event, &'a nostr::Event) {
    // Flat `else if` chain (clippy: collapsible_else_if); same decision order as before.
    if event_a.created_at < event_b.created_at {
        (event_a, event_b)
    } else if event_b.created_at < event_a.created_at {
        (event_b, event_a)
    } else if event_a.id.to_hex() < event_b.id.to_hex() {
        // Timestamp tie: break it on the hex-encoded event id.
        (event_a, event_b)
    } else {
        (event_b, event_a)
    }
}
/// Alice receives two competing add-member commits (from Bob and Carol) for the same
/// epoch, and the worse one arrives first.
///
/// After both are processed, Alice must report exactly one rollback. That rollback must
/// name this group, target the pre-race epoch, and make the better commit the new head.
/// Alice must end at `initial + 1`, with only the better commit's new member added.
#[test]
fn test_commit_race_simple_better_commit_wins() {
    let alice_keys = Keys::generate();
    let bob_keys = Keys::generate();
    let carol_keys = Keys::generate();
    let callback = std::sync::Arc::new(TestCallback::new());
    let alice_mdk = crate::MDK::builder(mdk_memory_storage::MdkMemoryStorage::default())
        .with_callback(callback.clone())
        .build();
    let bob_mdk = create_test_mdk();
    let carol_mdk = create_test_mdk();
    let admins = vec![
        alice_keys.public_key(),
        bob_keys.public_key(),
        carol_keys.public_key(),
    ];
    let bob_key_package = create_key_package_event(&bob_mdk, &bob_keys);
    let carol_key_package = create_key_package_event(&carol_mdk, &carol_keys);
    let create_result = alice_mdk
        .create_group(
            &alice_keys.public_key(),
            vec![bob_key_package, carol_key_package],
            create_nostr_group_config_data(admins),
        )
        .expect("Alice should be able to create group");
    let group_id = create_result.group.mls_group_id.clone();
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Failed to merge Alice's create commit");
    let bob_welcome_rumor = &create_result.welcome_rumors[0];
    let bob_welcome = bob_mdk
        .process_welcome(&nostr::EventId::all_zeros(), bob_welcome_rumor)
        .expect("Bob should be able to process welcome");
    bob_mdk
        .accept_welcome(&bob_welcome)
        .expect("Bob should be able to accept welcome");
    let carol_welcome_rumor = &create_result.welcome_rumors[1];
    let carol_welcome = carol_mdk
        .process_welcome(&nostr::EventId::all_zeros(), carol_welcome_rumor)
        .expect("Carol should be able to process welcome");
    carol_mdk
        .accept_welcome(&carol_welcome)
        .expect("Carol should be able to accept welcome");
    let initial_epoch = alice_mdk
        .get_group(&group_id)
        .expect("Failed to get group")
        .expect("Group should exist")
        .epoch;
    let dave_keys = Keys::generate();
    let eve_keys = Keys::generate();
    let dave_key_package = create_key_package_event(&bob_mdk, &dave_keys);
    let eve_key_package = create_key_package_event(&carol_mdk, &eve_keys);
    // Bob and Carol each build a commit on the same epoch.
    let bob_commit = bob_mdk
        .add_members(&group_id, std::slice::from_ref(&dave_key_package))
        .expect("Bob should create commit");
    let carol_commit = carol_mdk
        .add_members(&group_id, std::slice::from_ref(&eve_key_package))
        .expect("Carol should create commit");
    let (better_commit, worse_commit) =
        order_events_by_mip03(&bob_commit.evolution_event, &carol_commit.evolution_event);
    // Deliver the worse commit first so that the better one has to trigger a rollback.
    let worse_result = alice_mdk.process_message(worse_commit);
    assert!(
        worse_result.is_ok(),
        "Processing worse commit should succeed: {:?}",
        worse_result.err()
    );
    let epoch_after_worse = alice_mdk
        .get_group(&group_id)
        .expect("Failed to get group")
        .expect("Group should exist")
        .epoch;
    assert_eq!(
        epoch_after_worse,
        initial_epoch + 1,
        "Epoch should advance after processing commit"
    );
    let better_result = alice_mdk.process_message(better_commit);
    assert!(
        better_result.is_ok(),
        "Processing better commit should succeed via rollback: {:?}",
        better_result.err()
    );
    assert_eq!(
        callback.rollback_count(),
        1,
        "Should have triggered exactly one rollback"
    );
    // Indexing is safe: exactly one rollback was asserted above.
    let rollbacks = callback.get_rollbacks();
    // `assert_eq!` (instead of `assert!(a == b)`) prints both group ids on failure.
    assert_eq!(
        rollbacks[0].0, group_id,
        "Rollback should be for our group"
    );
    assert_eq!(
        rollbacks[0].1, initial_epoch,
        "Rollback should target the epoch before the competing commits"
    );
    assert_eq!(
        rollbacks[0].2, better_commit.id,
        "Rollback should identify the better commit as the new head"
    );
    let final_epoch = alice_mdk
        .get_group(&group_id)
        .expect("Failed to get group")
        .expect("Group should exist")
        .epoch;
    assert_eq!(
        final_epoch,
        initial_epoch + 1,
        "Final epoch should be initial + 1 (better commit applied)"
    );
    let members = alice_mdk
        .get_members(&group_id)
        .expect("Should be able to get group members");
    assert!(
        members.contains(&alice_keys.public_key()),
        "Alice should still be a member"
    );
    assert!(
        members.contains(&bob_keys.public_key()),
        "Bob should still be a member"
    );
    assert!(
        members.contains(&carol_keys.public_key()),
        "Carol should still be a member"
    );
    // Exactly one of Dave and Eve joined, depending on whose commit ranked better.
    let bob_commit_was_better = better_commit.id == bob_commit.evolution_event.id;
    if bob_commit_was_better {
        assert!(
            members.contains(&dave_keys.public_key()),
            "Dave should be a member (Bob's better commit added Dave)"
        );
        assert!(
            !members.contains(&eve_keys.public_key()),
            "Eve should NOT be a member (Carol's worse commit was rolled back)"
        );
    } else {
        assert!(
            members.contains(&eve_keys.public_key()),
            "Eve should be a member (Carol's better commit added Eve)"
        );
        assert!(
            !members.contains(&dave_keys.public_key()),
            "Dave should NOT be a member (Bob's worse commit was rolled back)"
        );
    }
    assert_eq!(
        members.len(),
        4,
        "Group should have exactly 4 members after rollback"
    );
}
/// Alice receives two competing add-member commits for the same epoch, and the better one
/// arrives first.
///
/// The late, worse commit must not cause a rollback, and Alice must stay at
/// `initial + 1`, keeping the better commit's state.
#[test]
fn test_commit_race_worse_late_commit_rejected() {
    let alice_keys = Keys::generate();
    let bob_keys = Keys::generate();
    let carol_keys = Keys::generate();
    let callback = std::sync::Arc::new(TestCallback::new());
    let alice_mdk = crate::MDK::builder(mdk_memory_storage::MdkMemoryStorage::default())
        .with_callback(callback.clone())
        .build();
    let bob_mdk = create_test_mdk();
    let carol_mdk = create_test_mdk();
    let admins = vec![
        alice_keys.public_key(),
        bob_keys.public_key(),
        carol_keys.public_key(),
    ];
    let bob_key_package = create_key_package_event(&bob_mdk, &bob_keys);
    let carol_key_package = create_key_package_event(&carol_mdk, &carol_keys);
    let create_result = alice_mdk
        .create_group(
            &alice_keys.public_key(),
            vec![bob_key_package, carol_key_package],
            create_nostr_group_config_data(admins),
        )
        .expect("Alice should create group");
    let group_id = create_result.group.mls_group_id.clone();
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Failed to merge commit");
    let bob_welcome_rumor = &create_result.welcome_rumors[0];
    let bob_welcome = bob_mdk
        .process_welcome(&nostr::EventId::all_zeros(), bob_welcome_rumor)
        .expect("Bob should process welcome");
    bob_mdk
        .accept_welcome(&bob_welcome)
        .expect("Bob should accept welcome");
    let carol_welcome_rumor = &create_result.welcome_rumors[1];
    let carol_welcome = carol_mdk
        .process_welcome(&nostr::EventId::all_zeros(), carol_welcome_rumor)
        .expect("Carol should process welcome");
    carol_mdk
        .accept_welcome(&carol_welcome)
        .expect("Carol should accept welcome");
    let initial_epoch = alice_mdk
        .get_group(&group_id)
        .expect("Failed to get group")
        .expect("Group should exist")
        .epoch;
    let dave_keys = Keys::generate();
    let eve_keys = Keys::generate();
    let dave_key_package = create_key_package_event(&bob_mdk, &dave_keys);
    let eve_key_package = create_key_package_event(&carol_mdk, &eve_keys);
    // Bob and Carol each build a commit on the same epoch.
    let bob_commit = bob_mdk
        .add_members(&group_id, std::slice::from_ref(&dave_key_package))
        .expect("Bob should create commit");
    let carol_commit = carol_mdk
        .add_members(&group_id, std::slice::from_ref(&eve_key_package))
        .expect("Carol should create commit");
    let (better_commit, worse_commit) =
        order_events_by_mip03(&bob_commit.evolution_event, &carol_commit.evolution_event);
    // Deliver the better commit first; the worse one arrives late.
    let better_result = alice_mdk.process_message(better_commit);
    assert!(
        better_result.is_ok(),
        "Processing better commit should succeed: {:?}",
        better_result.err()
    );
    let epoch_after_better = alice_mdk
        .get_group(&group_id)
        .expect("Failed to get group")
        .expect("Group should exist")
        .epoch;
    assert_eq!(
        epoch_after_better,
        initial_epoch + 1,
        "Epoch should advance"
    );
    let worse_result = alice_mdk.process_message(worse_commit);
    match worse_result {
        // This match deliberately accepts several outcomes. The rollback-count and
        // final-epoch assertions below are what pin that the better commit was kept.
        Ok(MessageProcessingResult::Unprocessable { .. })
        | Ok(MessageProcessingResult::Commit { .. })
        | Err(_) => {}
        // Print the full result, not just its discriminant, so failures are diagnosable.
        Ok(other) => {
            panic!("Unexpected result type for worse commit: {:?}", other);
        }
    }
    assert_eq!(
        callback.rollback_count(),
        0,
        "Should NOT have triggered any rollback"
    );
    let final_epoch = alice_mdk
        .get_group(&group_id)
        .expect("Failed to get group")
        .expect("Group should exist")
        .epoch;
    assert_eq!(
        final_epoch,
        initial_epoch + 1,
        "Epoch should remain at initial + 1 (better commit preserved)"
    );
}
/// Alice applies the MIP-03 *worse* of two competing commits first and then
/// receives the *better* one. She must roll back to the epoch before the race,
/// successfully apply the better commit, and end exactly one epoch past the start.
#[test]
fn test_commit_race_simple_rollback() {
    let alice_keys = Keys::generate();
    let bob_keys = Keys::generate();
    let carol_keys = Keys::generate();
    let callback = std::sync::Arc::new(TestCallback::new());
    let alice_mdk = crate::MDK::builder(mdk_memory_storage::MdkMemoryStorage::default())
        .with_callback(callback.clone())
        .build();
    let bob_mdk = create_test_mdk();
    let carol_mdk = create_test_mdk();
    let admins = vec![
        alice_keys.public_key(),
        bob_keys.public_key(),
        carol_keys.public_key(),
    ];
    let bob_key_package = create_key_package_event(&bob_mdk, &bob_keys);
    let carol_key_package = create_key_package_event(&carol_mdk, &carol_keys);
    let create_result = alice_mdk
        .create_group(
            &alice_keys.public_key(),
            vec![bob_key_package, carol_key_package],
            create_nostr_group_config_data(admins),
        )
        .expect("Alice should create group");
    let group_id = create_result.group.mls_group_id.clone();
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Failed to merge");
    let bob_welcome_rumor = &create_result.welcome_rumors[0];
    let bob_welcome = bob_mdk
        .process_welcome(&nostr::EventId::all_zeros(), bob_welcome_rumor)
        .expect("Bob should process welcome");
    bob_mdk
        .accept_welcome(&bob_welcome)
        .expect("Bob should accept welcome");
    let carol_welcome_rumor = &create_result.welcome_rumors[1];
    let carol_welcome = carol_mdk
        .process_welcome(&nostr::EventId::all_zeros(), carol_welcome_rumor)
        .expect("Carol should process welcome");
    carol_mdk
        .accept_welcome(&carol_welcome)
        .expect("Carol should accept welcome");
    let initial_epoch = alice_mdk
        .get_group(&group_id)
        .expect("Failed to get group")
        .expect("Group should exist")
        .epoch;

    // Bob (commit A) and Carol (commit A') both commit on top of `initial_epoch`.
    let dave_keys = Keys::generate();
    let eve_keys = Keys::generate();
    let dave_key_package = create_key_package_event(&bob_mdk, &dave_keys);
    let eve_key_package = create_key_package_event(&carol_mdk, &eve_keys);
    let commit_a = bob_mdk
        .add_members(&group_id, std::slice::from_ref(&dave_key_package))
        .expect("Bob should create commit A");
    let commit_a_prime = carol_mdk
        .add_members(&group_id, std::slice::from_ref(&eve_key_package))
        .expect("Carol should create commit A'");
    let (better_a, worse_a) =
        order_events_by_mip03(&commit_a.evolution_event, &commit_a_prime.evolution_event);

    // Alice applies the losing commit first.
    let result_a = alice_mdk.process_message(worse_a);
    assert!(
        result_a.is_ok(),
        "Processing worse A should succeed: {:?}",
        result_a.err()
    );
    let epoch_after_a = alice_mdk
        .get_group(&group_id)
        .expect("Failed to get group")
        .expect("Group should exist")
        .epoch;
    assert_eq!(epoch_after_a, initial_epoch + 1, "Epoch should be at +1");

    // If the commit Alice applied first was Bob's, Bob merges it locally as well.
    // This only touches Bob's MDK instance.
    if worse_a.id == commit_a.evolution_event.id {
        bob_mdk
            .merge_pending_commit(&group_id)
            .expect("Bob should merge his commit A");
    }

    // The better commit must be accepted, not just trigger a rollback, so its
    // result is asserted instead of discarded.
    let result_a_prime = alice_mdk.process_message(better_a);
    assert!(
        result_a_prime.is_ok(),
        "Processing better A' should succeed: {:?}",
        result_a_prime.err()
    );

    let rollback_count = callback.rollback_count();
    assert!(rollback_count > 0, "Rollback should have happened");
    let rollbacks = callback.get_rollbacks();
    assert_eq!(
        rollbacks[0].1, initial_epoch,
        "Rollback should target the epoch before competing commits"
    );
    let final_epoch = alice_mdk
        .get_group(&group_id)
        .expect("Failed to get group")
        .expect("Group should exist")
        .epoch;
    assert_eq!(
        final_epoch,
        initial_epoch + 1,
        "After rollback and applying better A', epoch should be initial + 1"
    );
}
/// Drives a single-member group through five merged self-updates with
/// `epoch_snapshot_retention` set to 3. Checks that each self-update advances the
/// stored epoch by exactly one and that the group ends at epoch 5.
///
/// NOTE(review): this test only observes epoch progression. It does not inspect
/// the snapshot store, so it does not by itself prove that snapshots outside the
/// retention window are pruned.
#[test]
fn test_epoch_snapshot_retention_pruning() {
    let alice_keys = Keys::generate();
    let config = crate::MdkConfig {
        epoch_snapshot_retention: 3,
        ..Default::default()
    };
    let alice_mdk = crate::MDK::builder(mdk_memory_storage::MdkMemoryStorage::default())
        .with_config(config)
        .build();
    let admins = vec![alice_keys.public_key()];
    let create_result = alice_mdk
        .create_group(
            &alice_keys.public_key(),
            vec![],
            create_nostr_group_config_data(admins),
        )
        .expect("Alice should create group");
    let group_id = create_result.group.mls_group_id.clone();
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Failed to merge");

    // Reads Alice's stored epoch for the group.
    let current_epoch = || {
        alice_mdk
            .get_group(&group_id)
            .expect("Failed to get group")
            .expect("Group should exist")
            .epoch
    };
    let start_epoch = current_epoch();

    for i in 0..5 {
        // Only the state transition matters here. The returned commit data is unused,
        // but it is held until the end of the iteration, as before.
        let _update = alice_mdk
            .self_update(&group_id)
            .unwrap_or_else(|e| panic!("Self-update {} should succeed: {:?}", i, e));
        alice_mdk
            .merge_pending_commit(&group_id)
            .unwrap_or_else(|e| panic!("Merge {} should succeed: {:?}", i, e));
        assert_eq!(
            current_epoch(),
            start_epoch + i + 1,
            "Epoch should advance by one after self-update {}",
            i
        );
    }

    let final_epoch = current_epoch();
    assert_eq!(final_epoch, 5, "Should have advanced through 5 epochs");
}
/// When two competing commits have the same `created_at`, MIP-03 ordering falls
/// back to event ID (smaller wins). Alice applies the larger-ID commit first.
/// Receiving the smaller-ID commit afterwards must succeed and trigger a rollback.
#[test]
fn test_commit_race_event_id_tiebreaker() {
    let alice_keys = Keys::generate();
    let bob_keys = Keys::generate();
    let carol_keys = Keys::generate();
    let callback = std::sync::Arc::new(TestCallback::new());
    let alice_mdk = crate::MDK::builder(mdk_memory_storage::MdkMemoryStorage::default())
        .with_callback(callback.clone())
        .build();
    let bob_mdk = create_test_mdk();
    let carol_mdk = create_test_mdk();
    let admins = vec![
        alice_keys.public_key(),
        bob_keys.public_key(),
        carol_keys.public_key(),
    ];
    let bob_key_package = create_key_package_event(&bob_mdk, &bob_keys);
    let carol_key_package = create_key_package_event(&carol_mdk, &carol_keys);
    let create_result = alice_mdk
        .create_group(
            &alice_keys.public_key(),
            vec![bob_key_package, carol_key_package],
            create_nostr_group_config_data(admins),
        )
        .expect("Alice should create group");
    let group_id = create_result.group.mls_group_id.clone();
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Failed to merge");
    let bob_welcome_rumor = &create_result.welcome_rumors[0];
    let bob_welcome = bob_mdk
        .process_welcome(&nostr::EventId::all_zeros(), bob_welcome_rumor)
        .expect("Bob should process welcome");
    bob_mdk
        .accept_welcome(&bob_welcome)
        .expect("Bob should accept welcome");
    let carol_welcome_rumor = &create_result.welcome_rumors[1];
    let carol_welcome = carol_mdk
        .process_welcome(&nostr::EventId::all_zeros(), carol_welcome_rumor)
        .expect("Carol should process welcome");
    carol_mdk
        .accept_welcome(&carol_welcome)
        .expect("Carol should accept welcome");
    let dave_keys = Keys::generate();
    let eve_keys = Keys::generate();
    let dave_key_package = create_key_package_event(&bob_mdk, &dave_keys);
    let eve_key_package = create_key_package_event(&carol_mdk, &eve_keys);
    let bob_commit = bob_mdk
        .add_members(&group_id, std::slice::from_ref(&dave_key_package))
        .expect("Bob should create commit");
    let mut carol_commit = carol_mdk
        .add_members(&group_id, std::slice::from_ref(&eve_key_package))
        .expect("Carol should create commit");

    // Force Carol's commit onto Bob's timestamp so only the event ID can break the tie.
    // NOTE(review): the rebuilt event is signed with Carol's identity keys, which may
    // differ from the key that signed the original evolution event. Confirm that
    // commit processing does not depend on the outer event's author.
    if carol_commit.evolution_event.created_at != bob_commit.evolution_event.created_at {
        let target_ts = bob_commit.evolution_event.created_at;
        let builder = nostr::EventBuilder::new(
            carol_commit.evolution_event.kind,
            carol_commit.evolution_event.content.clone(),
        )
        .tags(carol_commit.evolution_event.tags.clone())
        .custom_created_at(target_ts);
        carol_commit.evolution_event = builder
            .sign_with_keys(&carol_keys)
            .expect("Failed to re-sign Carol's event");
    }
    // Precondition for the whole test. Asserting it here makes the rollback check
    // below unconditional instead of silently skipping on a timestamp mismatch.
    assert_eq!(
        bob_commit.evolution_event.created_at,
        carol_commit.evolution_event.created_at,
        "Competing commits must share a timestamp for the event ID tiebreaker to apply"
    );

    // Assumes `to_hex` yields fixed-width lowercase hex, which orders like the raw bytes.
    let bob_id = bob_commit.evolution_event.id.to_hex();
    let carol_id = carol_commit.evolution_event.id.to_hex();
    let (smaller_id_event, larger_id_event) = if bob_id < carol_id {
        (&bob_commit.evolution_event, &carol_commit.evolution_event)
    } else {
        (&carol_commit.evolution_event, &bob_commit.evolution_event)
    };

    // Apply the losing (larger-ID) commit first, then the winning one.
    let result1 = alice_mdk.process_message(larger_id_event);
    assert!(
        result1.is_ok(),
        "First commit should process successfully: {:?}",
        result1.err()
    );
    let result2 = alice_mdk.process_message(smaller_id_event);
    assert!(
        result2.is_ok(),
        "Second commit should process successfully: {:?}",
        result2.err()
    );
    assert!(
        callback.rollback_count() >= 1,
        "Should trigger rollback when using event ID tiebreaker"
    );
}
/// Unit tests for `EpochSnapshotManager`, run directly against in-memory storage.
///
/// Every `create_snapshot` call is asserted to succeed. Several tests here make only
/// negative assertions (`!is_better_candidate(..)`), which would otherwise pass
/// vacuously if snapshot setup had failed.
mod epoch_snapshot_manager_tests {
    use mdk_storage_traits::GroupId;
    use nostr::EventId;

    use crate::epoch_snapshots::EpochSnapshotManager;

    #[test]
    fn test_is_better_candidate_earlier_timestamp_wins() {
        let manager = EpochSnapshotManager::new(5);
        let storage = mdk_memory_storage::MdkMemoryStorage::default();
        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        let applied_commit_id = EventId::all_zeros();
        let applied_ts = 1000u64;
        assert!(
            manager
                .create_snapshot(
                    &storage,
                    &group_id,
                    0,
                    &applied_commit_id,
                    applied_ts,
                    &[1u8; 32],
                )
                .is_ok(),
            "Snapshot creation for epoch 0 should succeed"
        );
        let candidate_id = EventId::from_slice(&[1u8; 32]).unwrap();
        let earlier_ts = 999u64;
        assert!(
            manager.is_better_candidate(
                &storage,
                &group_id,
                0,
                earlier_ts,
                &candidate_id,
                &[2u8; 32]
            ),
            "Earlier timestamp should be better"
        );
        let later_ts = 1001u64;
        assert!(
            !manager.is_better_candidate(
                &storage,
                &group_id,
                0,
                later_ts,
                &candidate_id,
                &[2u8; 32]
            ),
            "Later timestamp should not be better"
        );
    }

    #[test]
    fn test_is_better_candidate_smaller_id_wins_on_same_timestamp() {
        let manager = EpochSnapshotManager::new(5);
        let storage = mdk_memory_storage::MdkMemoryStorage::default();
        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        let applied_commit_id = EventId::from_slice(&[0x80u8; 32]).unwrap();
        let ts = 1000u64;
        assert!(
            manager
                .create_snapshot(&storage, &group_id, 0, &applied_commit_id, ts, &[1u8; 32])
                .is_ok(),
            "Snapshot creation for epoch 0 should succeed"
        );
        let smaller_id = EventId::from_slice(&[0x70u8; 32]).unwrap();
        assert!(
            manager.is_better_candidate(&storage, &group_id, 0, ts, &smaller_id, &[2u8; 32]),
            "Smaller ID should be better when timestamps are equal"
        );
        let larger_id = EventId::from_slice(&[0x90u8; 32]).unwrap();
        assert!(
            !manager.is_better_candidate(&storage, &group_id, 0, ts, &larger_id, &[2u8; 32]),
            "Larger ID should not be better when timestamps are equal"
        );
    }

    #[test]
    fn test_is_better_candidate_wrong_epoch_returns_false() {
        let manager = EpochSnapshotManager::new(5);
        let storage = mdk_memory_storage::MdkMemoryStorage::default();
        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        let applied_commit_id = EventId::all_zeros();
        let ts = 1000u64;
        assert!(
            manager
                .create_snapshot(&storage, &group_id, 0, &applied_commit_id, ts, &[1u8; 32])
                .is_ok(),
            "Snapshot creation for epoch 0 should succeed"
        );
        let candidate_id = EventId::from_slice(&[1u8; 32]).unwrap();
        assert!(
            !manager.is_better_candidate(
                &storage,
                &group_id,
                1,
                999,
                &candidate_id,
                &[2u8; 32]
            ),
            "Should return false for epoch with no snapshot"
        );
    }

    #[test]
    fn test_rollback_removes_subsequent_snapshots() {
        let manager = EpochSnapshotManager::new(10);
        let storage = mdk_memory_storage::MdkMemoryStorage::default();
        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        for epoch in 0..3u64 {
            let commit_id = EventId::from_slice(&[epoch as u8; 32]).unwrap();
            assert!(
                manager
                    .create_snapshot(
                        &storage,
                        &group_id,
                        epoch,
                        &commit_id,
                        1000 + epoch,
                        &[1u8; 32],
                    )
                    .is_ok(),
                "Snapshot creation for epoch {} should succeed",
                epoch
            );
        }
        let result = manager.rollback_to_epoch(&storage, &group_id, 1);
        assert!(result.is_ok(), "Rollback should succeed");
        let candidate_id = EventId::from_slice(&[0xFFu8; 32]).unwrap();
        assert!(
            !manager.is_better_candidate(
                &storage,
                &group_id,
                2,
                999,
                &candidate_id,
                &[2u8; 32]
            ),
            "Epoch 2 snapshot should have been removed after rollback to epoch 1"
        );
    }

    #[test]
    fn test_is_better_candidate_same_id_returns_false() {
        let manager = EpochSnapshotManager::new(5);
        let storage = mdk_memory_storage::MdkMemoryStorage::default();
        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        let applied_commit_id = EventId::from_slice(&[0x50u8; 32]).unwrap();
        let ts = 1000u64;
        assert!(
            manager
                .create_snapshot(&storage, &group_id, 0, &applied_commit_id, ts, &[1u8; 32])
                .is_ok(),
            "Snapshot creation for epoch 0 should succeed"
        );
        assert!(
            !manager.is_better_candidate(
                &storage,
                &group_id,
                0,
                ts,
                &applied_commit_id,
                &[2u8; 32]
            ),
            "Same event ID should not be considered better than itself"
        );
    }

    #[test]
    fn test_rollback_to_nonexistent_epoch_fails() {
        let manager = EpochSnapshotManager::new(5);
        let storage = mdk_memory_storage::MdkMemoryStorage::default();
        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        let commit_id = EventId::all_zeros();
        assert!(
            manager
                .create_snapshot(&storage, &group_id, 0, &commit_id, 1000, &[1u8; 32])
                .is_ok(),
            "Snapshot creation for epoch 0 should succeed"
        );
        let result = manager.rollback_to_epoch(&storage, &group_id, 5);
        assert!(result.is_err(), "Rollback to nonexistent epoch should fail");
    }

    #[test]
    fn test_rollback_to_unknown_group_fails() {
        let manager = EpochSnapshotManager::new(5);
        let storage = mdk_memory_storage::MdkMemoryStorage::default();
        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        let unknown_group_id = GroupId::from_slice(&[9, 9, 9, 9]);
        let commit_id = EventId::all_zeros();
        assert!(
            manager
                .create_snapshot(&storage, &group_id, 0, &commit_id, 1000, &[1u8; 32])
                .is_ok(),
            "Snapshot creation for epoch 0 should succeed"
        );
        let result = manager.rollback_to_epoch(&storage, &unknown_group_id, 0);
        assert!(result.is_err(), "Rollback for unknown group should fail");
    }

    #[test]
    fn test_snapshots_isolated_per_group() {
        let manager = EpochSnapshotManager::new(5);
        let storage = mdk_memory_storage::MdkMemoryStorage::default();
        let group_a = GroupId::from_slice(&[1, 1, 1, 1]);
        let group_b = GroupId::from_slice(&[2, 2, 2, 2]);
        let commit_id_a = EventId::from_slice(&[0x10u8; 32]).unwrap();
        let commit_id_b = EventId::from_slice(&[0x20u8; 32]).unwrap();
        assert!(
            manager
                .create_snapshot(&storage, &group_a, 0, &commit_id_a, 1000, &[1u8; 32])
                .is_ok(),
            "Snapshot creation for group A should succeed"
        );
        assert!(
            manager
                .create_snapshot(&storage, &group_b, 0, &commit_id_b, 2000, &[1u8; 32])
                .is_ok(),
            "Snapshot creation for group B should succeed"
        );
        let candidate = EventId::from_slice(&[0x05u8; 32]).unwrap();
        assert!(
            manager.is_better_candidate(&storage, &group_a, 0, 999, &candidate, &[2u8; 32]),
            "Earlier timestamp (999) should be better for group A (ts=1000)"
        );
        assert!(
            !manager.is_better_candidate(&storage, &group_a, 0, 1001, &candidate, &[2u8; 32]),
            "Later timestamp (1001) should not be better for group A (ts=1000)"
        );
        assert!(
            manager.is_better_candidate(&storage, &group_b, 0, 1999, &candidate, &[2u8; 32]),
            "Earlier timestamp (1999) should be better for group B (ts=2000)"
        );
        assert!(
            !manager.is_better_candidate(&storage, &group_b, 0, 2001, &candidate, &[2u8; 32]),
            "Later timestamp (2001) should not be better for group B (ts=2000)"
        );
    }

    #[test]
    fn test_snapshot_retention_pruning() {
        // Retention of 3: after snapshotting epochs 0..=4, only epochs 2..=4 should remain.
        let manager = EpochSnapshotManager::new(3);
        let storage = mdk_memory_storage::MdkMemoryStorage::default();
        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        for epoch in 0..5u64 {
            let commit_id = EventId::from_slice(&[epoch as u8; 32]).unwrap();
            assert!(
                manager
                    .create_snapshot(
                        &storage,
                        &group_id,
                        epoch,
                        &commit_id,
                        1000 + epoch,
                        &[1u8; 32],
                    )
                    .is_ok(),
                "Snapshot creation for epoch {} should succeed",
                epoch
            );
        }
        let candidate = EventId::from_slice(&[0xFFu8; 32]).unwrap();
        assert!(
            !manager.is_better_candidate(&storage, &group_id, 0, 0, &candidate, &[2u8; 32]),
            "Epoch 0 snapshot should have been pruned"
        );
        assert!(
            !manager.is_better_candidate(&storage, &group_id, 1, 0, &candidate, &[2u8; 32]),
            "Epoch 1 snapshot should have been pruned"
        );
        assert!(
            manager.is_better_candidate(&storage, &group_id, 2, 1001, &candidate, &[2u8; 32]),
            "Epoch 2 snapshot should still exist"
        );
        assert!(
            manager.is_better_candidate(&storage, &group_id, 4, 1003, &candidate, &[2u8; 32]),
            "Epoch 4 snapshot should still exist"
        );
    }
}
/// Alice (the sole admin) removes Bob, and Bob processes that removal commit.
/// Bob's stored processed-message record for the commit must reference the
/// commit's event ID, be in the `Processed` state, and carry no failure reason.
#[test]
fn test_removed_member_processed_message_saved_correctly() {
    use mdk_storage_traits::messages::MessageStorage;
    use mdk_storage_traits::messages::types::ProcessedMessageState;

    let alice_keys = Keys::generate();
    let bob_keys = Keys::generate();
    let alice_mdk = create_test_mdk();
    let bob_mdk = create_test_mdk();

    // Two-person group: Alice is the only admin, Bob is invited.
    let admins = vec![alice_keys.public_key()];
    let bob_key_package = create_key_package_event(&bob_mdk, &bob_keys);
    let create_result = alice_mdk
        .create_group(
            &alice_keys.public_key(),
            vec![bob_key_package],
            create_nostr_group_config_data(admins),
        )
        .expect("Failed to create group");
    let group_id = create_result.group.mls_group_id.clone();
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Failed to merge pending commit");

    let bob_welcome = bob_mdk
        .process_welcome(
            &nostr::EventId::all_zeros(),
            &create_result.welcome_rumors[0],
        )
        .expect("Bob should process welcome");
    bob_mdk
        .accept_welcome(&bob_welcome)
        .expect("Bob should accept welcome");

    // Alice evicts Bob, merges locally, and Bob then receives the commit.
    let removal = alice_mdk
        .remove_members(&group_id, &[bob_keys.public_key()])
        .expect("Alice (admin) can remove members");
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Alice should merge remove commit");
    let removal_event = &removal.evolution_event;
    bob_mdk
        .process_message(removal_event)
        .expect("Bob should process removal commit");

    let record = bob_mdk
        .storage()
        .find_processed_message_by_event_id(&removal_event.id)
        .expect("Failed to get processed message")
        .expect("Processed message should exist");

    assert_eq!(
        record.wrapper_event_id, removal_event.id,
        "Wrapper event ID should match"
    );
    assert_eq!(
        record.state,
        ProcessedMessageState::Processed,
        "Processed message state should be Processed"
    );
    assert!(
        record.failure_reason.is_none(),
        "There should be no failure reason for successful processing"
    );
}
/// After a rollback caused by a better competing commit, Alice's view of the group
/// keeps its admin set and original members. It also gains exactly one new member,
/// the one added by the winning commit.
#[test]
fn test_group_membership_preserved_after_rollback() {
    let alice_keys = Keys::generate();
    let bob_keys = Keys::generate();
    let carol_keys = Keys::generate();

    // Alice gets a callback so rollbacks can be counted.
    let callback = std::sync::Arc::new(TestCallback::new());
    let alice_mdk = crate::MDK::builder(mdk_memory_storage::MdkMemoryStorage::default())
        .with_callback(callback.clone())
        .build();
    let bob_mdk = create_test_mdk();
    let carol_mdk = create_test_mdk();

    let admins = vec![
        alice_keys.public_key(),
        bob_keys.public_key(),
        carol_keys.public_key(),
    ];
    let invitee_key_packages = vec![
        create_key_package_event(&bob_mdk, &bob_keys),
        create_key_package_event(&carol_mdk, &carol_keys),
    ];
    let create_result = alice_mdk
        .create_group(
            &alice_keys.public_key(),
            invitee_key_packages,
            create_nostr_group_config_data(admins),
        )
        .expect("Alice should create group");
    let group_id = create_result.group.mls_group_id.clone();
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Failed to merge");

    // Bob then Carol join via their welcome rumors.
    let joiners = [
        (
            &bob_mdk,
            &create_result.welcome_rumors[0],
            "Bob should process welcome",
            "Bob should accept welcome",
        ),
        (
            &carol_mdk,
            &create_result.welcome_rumors[1],
            "Carol should process welcome",
            "Carol should accept welcome",
        ),
    ];
    for (member_mdk, welcome_rumor, process_msg, accept_msg) in joiners {
        let welcome = member_mdk
            .process_welcome(&nostr::EventId::all_zeros(), welcome_rumor)
            .expect(process_msg);
        member_mdk.accept_welcome(&welcome).expect(accept_msg);
    }

    // Baseline state before the race.
    let initial_members = alice_mdk
        .get_members(&group_id)
        .expect("Should get members");
    let initial_group = alice_mdk
        .get_group(&group_id)
        .expect("Should get group")
        .expect("Group should exist");
    assert_eq!(initial_members.len(), 3, "Should have 3 members initially");
    assert!(initial_members.contains(&alice_keys.public_key()));
    assert!(initial_members.contains(&bob_keys.public_key()));
    assert!(initial_members.contains(&carol_keys.public_key()));

    // Bob and Carol race to add Dave and Eve respectively.
    let dave_keys = Keys::generate();
    let eve_keys = Keys::generate();
    let dave_key_package = create_key_package_event(&bob_mdk, &dave_keys);
    let eve_key_package = create_key_package_event(&carol_mdk, &eve_keys);
    let commit_a = bob_mdk
        .add_members(&group_id, std::slice::from_ref(&dave_key_package))
        .expect("Bob should create commit");
    let commit_a_prime = carol_mdk
        .add_members(&group_id, std::slice::from_ref(&eve_key_package))
        .expect("Carol should create commit");
    let (better_commit, worse_commit) =
        order_events_by_mip03(&commit_a.evolution_event, &commit_a_prime.evolution_event);

    // Losing commit first, then the winner; the winner forces a rollback.
    alice_mdk
        .process_message(worse_commit)
        .expect("Processing worse commit should succeed");
    alice_mdk
        .process_message(better_commit)
        .expect("Processing better commit should succeed");
    assert!(
        callback.rollback_count() > 0,
        "Rollback should have occurred when better commit arrived"
    );

    let group_after_rollback = alice_mdk
        .get_group(&group_id)
        .expect("Should be able to get group after rollback")
        .expect("Group MUST still exist after rollback");
    assert_eq!(
        group_after_rollback.admin_pubkeys, initial_group.admin_pubkeys,
        "Admin pubkeys should be preserved after rollback"
    );

    let members_after_rollback = alice_mdk
        .get_members(&group_id)
        .expect("Should get members after rollback");
    assert!(
        members_after_rollback.contains(&alice_keys.public_key()),
        "Alice should still be a member after rollback"
    );
    assert!(
        members_after_rollback.contains(&bob_keys.public_key()),
        "Bob should still be a member after rollback"
    );
    assert!(
        members_after_rollback.contains(&carol_keys.public_key()),
        "Carol should still be a member after rollback"
    );
    assert_eq!(
        members_after_rollback.len(),
        4,
        "Should have 4 members after applying winning commit (3 original + 1 new)"
    );
}
/// Alice sends a message while sitting on the losing commit's epoch. When the
/// better commit arrives and forces a rollback, the rollback info reported to the
/// callback must list that message as invalidated or as needing a refetch.
#[test]
fn test_message_invalidation_during_rollback() {
    let alice_keys = Keys::generate();
    let bob_keys = Keys::generate();
    let carol_keys = Keys::generate();

    // Alice gets a callback so rollback info can be inspected.
    let callback = std::sync::Arc::new(TestCallback::new());
    let alice_mdk = crate::MDK::builder(mdk_memory_storage::MdkMemoryStorage::default())
        .with_callback(callback.clone())
        .build();
    let bob_mdk = create_test_mdk();
    let carol_mdk = create_test_mdk();

    let admins = vec![
        alice_keys.public_key(),
        bob_keys.public_key(),
        carol_keys.public_key(),
    ];
    let invitee_key_packages = vec![
        create_key_package_event(&bob_mdk, &bob_keys),
        create_key_package_event(&carol_mdk, &carol_keys),
    ];
    let create_result = alice_mdk
        .create_group(
            &alice_keys.public_key(),
            invitee_key_packages,
            create_nostr_group_config_data(admins),
        )
        .expect("Alice should create group");
    let group_id = create_result.group.mls_group_id.clone();
    alice_mdk
        .merge_pending_commit(&group_id)
        .expect("Failed to merge");

    // Bob then Carol join via their welcome rumors.
    let joiners = [
        (
            &bob_mdk,
            &create_result.welcome_rumors[0],
            "Bob should process welcome",
            "Bob should accept welcome",
        ),
        (
            &carol_mdk,
            &create_result.welcome_rumors[1],
            "Carol should process welcome",
            "Carol should accept welcome",
        ),
    ];
    for (member_mdk, welcome_rumor, process_msg, accept_msg) in joiners {
        let welcome = member_mdk
            .process_welcome(&nostr::EventId::all_zeros(), welcome_rumor)
            .expect(process_msg);
        member_mdk.accept_welcome(&welcome).expect(accept_msg);
    }

    // Bob and Carol race to add Dave and Eve respectively.
    let dave_keys = Keys::generate();
    let eve_keys = Keys::generate();
    let dave_key_package = create_key_package_event(&bob_mdk, &dave_keys);
    let eve_key_package = create_key_package_event(&carol_mdk, &eve_keys);
    let bob_commit = bob_mdk
        .add_members(&group_id, std::slice::from_ref(&dave_key_package))
        .expect("Bob should create commit");
    let carol_commit = carol_mdk
        .add_members(&group_id, std::slice::from_ref(&eve_key_package))
        .expect("Carol should create commit");
    let (better_commit, worse_commit) =
        order_events_by_mip03(&bob_commit.evolution_event, &carol_commit.evolution_event);

    // Alice lands on the losing branch first.
    alice_mdk
        .process_message(worse_commit)
        .expect("Alice should process worse commit");

    // She sends a message on that (soon to be rolled back) epoch; record its rumor ID.
    let mut rumor = create_test_rumor(&alice_keys, "Message at wrong epoch");
    let message_id = rumor.id();
    let _message_event = alice_mdk
        .create_message(&group_id, rumor, None)
        .expect("Alice should create message");

    // The winning commit arrives and forces a rollback.
    alice_mdk
        .process_message(better_commit)
        .expect("Alice should process better commit");
    assert!(
        callback.rollback_count() > 0,
        "Rollback should have occurred when better commit arrived"
    );

    let rollback_infos = callback.get_rollback_infos();
    assert!(!rollback_infos.is_empty(), "Should have rollback info");
    let first_info = &rollback_infos[0];
    let flagged = first_info.invalidated_messages.contains(&message_id)
        || first_info.messages_needing_refetch.contains(&message_id);
    assert!(
        flagged,
        "Message sent at rolled-back epoch should be invalidated or need refetch. \
        Message ID: {:?}, Invalidated: {:?}, Needing refetch: {:?}",
        message_id,
        first_info.invalidated_messages,
        first_info.messages_needing_refetch
    );
}
}