use std::slice::from_ref;
use openmls::{
prelude::{test_utils::new_credential, *},
storage::OpenMlsProvider,
treesync::LeafNodeParameters,
};
use openmls_test::openmls_test;
use openmls_traits::signatures::Signer;
/// Builds a key package for `credential_with_key` in the given `ciphersuite`,
/// using `signer` and attaching `extensions` as key-package extensions.
///
/// Only the [`KeyPackage`] is handed back; it is cloned out of the value
/// returned by the builder, which is then dropped.
///
/// # Panics
///
/// Panics if building the key package fails.
fn generate_key_package<Provider: OpenMlsProvider>(
    ciphersuite: Ciphersuite,
    extensions: Extensions<KeyPackage>,
    provider: &Provider,
    credential_with_key: CredentialWithKey,
    signer: &impl Signer,
) -> KeyPackage {
    let builder = KeyPackage::builder().key_package_extensions(extensions);
    let bundle = builder
        .build(ciphersuite, provider, signer, credential_with_key)
        .unwrap();
    bundle.key_package().clone()
}
/// Checks what happens when the *same* key package is proposed for addition twice.
///
/// For every wire-format policy, Alice creates a group and adds Bob. She then
/// proposes removing Bob and proposes adding one fresh Bob key package twice.
/// Three proposals are pending, but the resulting commit carries only two
/// queued proposals, and after merging the group again consists of Alice and
/// Bob. Committing does not fail.
///
/// `ciphersuite` and `Provider` are not declared here; they are supplied by
/// the `#[openmls_test]` macro.
#[openmls_test]
fn mls_duplicate_signature_key_detection_same_key_package() {
    for wire_format_policy in WIRE_FORMAT_POLICIES.iter() {
        let group_id = GroupId::from_slice(b"Test Group");
        let alice_provider = &Provider::default();
        let bob_provider = &Provider::default();
        let (alice_credential, alice_signer) =
            new_credential(alice_provider, b"Alice", ciphersuite.signature_algorithm());
        let (bob_credential, bob_signer) =
            new_credential(bob_provider, b"Bob", ciphersuite.signature_algorithm());
        let bob_key_package = generate_key_package(
            ciphersuite,
            Extensions::empty(),
            bob_provider,
            bob_credential.clone(),
            &bob_signer,
        );
        let mls_group_create_config = MlsGroupCreateConfig::builder()
            .wire_format_policy(*wire_format_policy)
            .ciphersuite(ciphersuite)
            .build();
        let mut alice_group = MlsGroup::new_with_group_id(
            alice_provider,
            &alice_signer,
            &mls_group_create_config,
            group_id.clone(),
            alice_credential.clone(),
        )
        .expect("An unexpected error occurred.");

        // Initial setup: Alice adds Bob.
        let welcome =
            match alice_group.add_members(alice_provider, &alice_signer, &[bob_key_package]) {
                Ok((_, welcome, _)) => welcome,
                Err(e) => panic!("Could not add member to group: {e:?}"),
            };
        if let Some(staged_commit) = alice_group.pending_commit() {
            let add = staged_commit
                .add_proposals()
                .next()
                .expect("Expected a proposal.");
            assert_eq!(
                add.add_proposal().key_package().leaf_node().credential(),
                &bob_credential.credential
            );
            assert!(
                matches!(add.sender(), Sender::Member(member) if *member == alice_group.own_leaf_index())
            );
        } else {
            unreachable!("Expected a StagedCommit.");
        }
        alice_group
            .merge_pending_commit(alice_provider)
            .expect("error merging pending commit");

        // Bob joins from the Welcome using Alice's exported ratchet tree.
        let welcome: MlsMessageIn = welcome.into();
        let welcome = welcome
            .into_welcome()
            .expect("expected the message to be a welcome message");
        let bob_group = StagedWelcome::new_from_welcome(
            bob_provider,
            mls_group_create_config.join_config(),
            welcome,
            Some(alice_group.export_ratchet_tree().into()),
        )
        .expect("Error creating StagedWelcome from Welcome")
        .into_group(bob_provider)
        .expect("Error creating group from StagedWelcome");
        assert_eq!(alice_group.members().count(), 2);
        let members = alice_group.members().collect::<Vec<Member>>();
        let credential0 = members[0].credential.serialized_content();
        let credential1 = members[1].credential.serialized_content();
        assert_eq!(credential0, b"Alice");
        assert_eq!(credential1, b"Bob");
        assert_eq!(alice_group.pending_proposals().count(), 0);

        // Propose removing Bob ...
        let bob_leaf_node_index = bob_group.own_leaf_index();
        if let Err(e) =
            alice_group.propose_remove_member(alice_provider, &alice_signer, bob_leaf_node_index)
        {
            panic!("Could not propose removing member from group: {e:?}");
        }

        // ... and propose re-adding him twice with the *same* key package.
        let bob_key_package_2 = generate_key_package(
            ciphersuite,
            Extensions::empty(),
            bob_provider,
            bob_credential.clone(),
            &bob_signer,
        );
        for _ in 0..2 {
            if let Err(e) =
                alice_group.propose_add_member(alice_provider, &alice_signer, &bob_key_package_2)
            {
                panic!("Could not add member to group: {e:?}");
            }
        }
        assert_eq!(alice_group.pending_proposals().count(), 3);

        // Committing succeeds; only two of the three pending proposals are queued
        // in the commit.
        if let Err(e) = alice_group.commit_to_pending_proposals(alice_provider, &alice_signer) {
            panic!("Could not commit proposals: {e:?}");
        }
        let pending_commit = match alice_group.pending_commit() {
            Some(pending_commit) => pending_commit,
            None => panic!("No pending commit was created"),
        };
        assert_eq!(pending_commit.queued_proposals().count(), 2);
        alice_group
            .merge_pending_commit(alice_provider)
            .expect("error merging pending commit");

        // The group is back to exactly Alice and Bob.
        assert_eq!(alice_group.members().count(), 2);
        let members = alice_group.members().collect::<Vec<Member>>();
        let credential0 = members[0].credential.serialized_content();
        let credential1 = members[1].credential.serialized_content();
        assert_eq!(credential0, b"Alice");
        assert_eq!(credential1, b"Bob");
    }
}
/// Checks that committing two Add proposals for the same member fails when the
/// key packages are *different* but share a signature key.
///
/// For every wire-format policy, Alice creates a group and adds Bob. She then
/// proposes removing Bob and proposes adding two freshly generated Bob key
/// packages, both built from the same `bob_credential` / `bob_signer`.
/// `commit_to_pending_proposals` must fail with
/// `ProposalValidationError::DuplicateSignatureKey`.
///
/// `ciphersuite` and `Provider` are not declared here; they are supplied by
/// the `#[openmls_test]` macro.
#[openmls_test]
fn mls_duplicate_signature_key_detection_different_key_package() {
    for wire_format_policy in WIRE_FORMAT_POLICIES.iter() {
        let group_id = GroupId::from_slice(b"Test Group");
        let alice_provider = &Provider::default();
        let bob_provider = &Provider::default();
        let (alice_credential, alice_signer) =
            new_credential(alice_provider, b"Alice", ciphersuite.signature_algorithm());
        let (bob_credential, bob_signer) =
            new_credential(bob_provider, b"Bob", ciphersuite.signature_algorithm());
        let bob_key_package = generate_key_package(
            ciphersuite,
            Extensions::empty(),
            bob_provider,
            bob_credential.clone(),
            &bob_signer,
        );
        let mls_group_create_config = MlsGroupCreateConfig::builder()
            .wire_format_policy(*wire_format_policy)
            .ciphersuite(ciphersuite)
            .build();
        let mut alice_group = MlsGroup::new_with_group_id(
            alice_provider,
            &alice_signer,
            &mls_group_create_config,
            group_id.clone(),
            alice_credential.clone(),
        )
        .expect("An unexpected error occurred.");

        // Initial setup: Alice adds Bob.
        let welcome = match alice_group.add_members(
            alice_provider,
            &alice_signer,
            from_ref(&bob_key_package),
        ) {
            Ok((_, welcome, _)) => welcome,
            Err(e) => panic!("Could not add member to group: {e:?}"),
        };
        if let Some(staged_commit) = alice_group.pending_commit() {
            let add = staged_commit
                .add_proposals()
                .next()
                .expect("Expected a proposal.");
            assert_eq!(
                add.add_proposal().key_package().leaf_node().credential(),
                &bob_credential.credential
            );
            assert!(
                matches!(add.sender(), Sender::Member(member) if *member == alice_group.own_leaf_index())
            );
        } else {
            unreachable!("Expected a StagedCommit.");
        }
        alice_group
            .merge_pending_commit(alice_provider)
            .expect("error merging pending commit");

        // Bob joins from the Welcome using Alice's exported ratchet tree.
        let welcome: MlsMessageIn = welcome.into();
        let welcome = welcome
            .into_welcome()
            .expect("expected the message to be a welcome message");
        let bob_group = StagedWelcome::new_from_welcome(
            bob_provider,
            mls_group_create_config.join_config(),
            welcome,
            Some(alice_group.export_ratchet_tree().into()),
        )
        .expect("Error creating StagedWelcome from Welcome")
        .into_group(bob_provider)
        .expect("Error creating group from StagedWelcome");
        assert_eq!(alice_group.members().count(), 2);
        let members = alice_group.members().collect::<Vec<Member>>();
        let credential0 = members[0].credential.serialized_content();
        let credential1 = members[1].credential.serialized_content();
        assert_eq!(credential0, b"Alice");
        assert_eq!(credential1, b"Bob");
        assert_eq!(alice_group.pending_proposals().count(), 0);

        // Propose removing Bob. `expect` does not format its message, so use
        // `unwrap_or_else` + `panic!` to actually report the underlying error.
        let bob_leaf_node_index = bob_group.own_leaf_index();
        alice_group
            .propose_remove_member(alice_provider, &alice_signer, bob_leaf_node_index)
            .unwrap_or_else(|e| panic!("Could not propose removing member from group: {e:?}"));

        // Propose adding Bob twice, each time with a *new* key package that is
        // built from the same credential and signer.
        for _ in 0..2 {
            let bob_key_package = generate_key_package(
                ciphersuite,
                Extensions::empty(),
                bob_provider,
                bob_credential.clone(),
                &bob_signer,
            );
            if let Err(e) =
                alice_group.propose_add_member(alice_provider, &alice_signer, &bob_key_package)
            {
                panic!("Could not add member to group: {e:?}");
            }
        }
        assert_eq!(alice_group.pending_proposals().count(), 3);

        // The commit must be rejected because of the duplicate signature key.
        use openmls::group::{
            CommitToPendingProposalsError, CreateCommitError, ProposalValidationError,
        };
        match alice_group.commit_to_pending_proposals(alice_provider, &alice_signer) {
            Err(CommitToPendingProposalsError::CreateCommitError(
                CreateCommitError::ProposalValidationError(
                    ProposalValidationError::DuplicateSignatureKey,
                ),
            )) => (),
            Err(e) => panic!("Wrong error type returned: {e:?}."),
            Ok(e) => panic!("Creating commit should fail: {e:?}"),
        }
    }
}
/// End-to-end walkthrough of core `MlsGroup` operations, repeated for every
/// entry in `WIRE_FORMAT_POLICIES`:
///
/// 1. Alice creates a group and adds Bob; Bob joins via the Welcome.
/// 2. Alice sends an application message that Bob processes.
/// 3. Bob commits a self-update. Alice then proposes a self-update and commits it.
/// 4. Bob adds Charlie, who joins, sends a message and commits a self-update.
/// 5. Charlie removes Bob, and Bob's group becomes inactive after merging.
/// 6. Alice proposes removing Charlie and re-adding Bob, then commits both;
///    Bob rejoins via the Welcome.
/// 7. Bob leaves via `leave_group`. He cannot commit his own removal, so
///    Alice commits it.
/// 8. Alice re-adds Bob, and exported secrets are compared before and after
///    Bob's group is reloaded from storage.
///
/// `ciphersuite` and `Provider` are not declared here; they are supplied by
/// the `#[openmls_test]` macro.
#[openmls_test]
fn mls_group_operations() {
    for wire_format_policy in WIRE_FORMAT_POLICIES.iter() {
        let group_id = GroupId::from_slice(b"Test Group");
        // Each member gets its own provider (and therefore its own storage).
        let alice_provider = &Provider::default();
        let bob_provider = &Provider::default();
        let charlie_provider = &Provider::default();
        let (alice_credential, alice_signer) =
            new_credential(alice_provider, b"Alice", ciphersuite.signature_algorithm());
        let (bob_credential, bob_signer) =
            new_credential(bob_provider, b"Bob", ciphersuite.signature_algorithm());
        let (charlie_credential, charlie_signer) = new_credential(
            charlie_provider,
            b"Charlie",
            ciphersuite.signature_algorithm(),
        );
        let bob_key_package = generate_key_package(
            ciphersuite,
            Extensions::empty(),
            bob_provider,
            bob_credential.clone(),
            &bob_signer,
        );
        let mls_group_create_config = MlsGroupCreateConfig::builder()
            .wire_format_policy(*wire_format_policy)
            .ciphersuite(ciphersuite)
            .build();
        let mut alice_group = MlsGroup::new_with_group_id(
            alice_provider,
            &alice_signer,
            &mls_group_create_config,
            group_id.clone(),
            alice_credential.clone(),
        )
        .expect("An unexpected error occurred.");

        // --- Step 1: Alice adds Bob. ---
        let welcome =
            match alice_group.add_members(alice_provider, &alice_signer, &[bob_key_package]) {
                Ok((_, welcome, _)) => welcome,
                Err(e) => panic!("Could not add member to group: {e:?}"),
            };
        // The pending commit must contain an Add for Bob, sent by Alice.
        if let Some(staged_commit) = alice_group.pending_commit() {
            let add = staged_commit
                .add_proposals()
                .next()
                .expect("Expected a proposal.");
            assert_eq!(
                add.add_proposal().key_package().leaf_node().credential(),
                &bob_credential.credential
            );
            assert!(
                matches!(add.sender(), Sender::Member(member) if *member == alice_group.own_leaf_index())
            );
        } else {
            unreachable!("Expected a StagedCommit.");
        }
        alice_group
            .merge_pending_commit(alice_provider)
            .expect("error merging pending commit");
        assert_eq!(alice_group.members().count(), 2);
        let members = alice_group.members().collect::<Vec<Member>>();
        let credential0 = members[0].credential.serialized_content();
        let credential1 = members[1].credential.serialized_content();
        assert_eq!(credential0, b"Alice");
        assert_eq!(credential1, b"Bob");
        // Bob joins from the Welcome, using Alice's exported ratchet tree.
        let welcome: MlsMessageIn = welcome.into();
        let welcome = welcome
            .into_welcome()
            .expect("expected the message to be a welcome message");
        let mut bob_group = StagedWelcome::new_from_welcome(
            bob_provider,
            mls_group_create_config.join_config(),
            welcome,
            Some(alice_group.export_ratchet_tree().into()),
        )
        .expect("Error creating StagedWelcome from Welcome")
        .into_group(bob_provider)
        .expect("Error creating group from StagedWelcome");
        // Both views agree on membership and on the epoch authenticator.
        assert!(alice_group.members().eq(bob_group.members()));
        assert_eq!(
            alice_group.epoch_authenticator().as_slice(),
            bob_group.epoch_authenticator().as_slice()
        );

        // --- Step 2: Alice -> Bob application message. ---
        let message_alice = b"Hi, I'm Alice!";
        let queued_message = alice_group
            .create_message(alice_provider, &alice_signer, message_alice)
            .expect("Error creating application message");
        let processed_message = bob_group
            .process_message(
                bob_provider,
                queued_message
                    .clone()
                    .into_protocol_message()
                    .expect("Unexpected message type"),
            )
            .expect("Could not process message.");
        // Read the sender's credential before `into_content` consumes the message.
        let sender = processed_message.credential().clone();
        if let ProcessedMessageContent::ApplicationMessage(application_message) =
            processed_message.into_content()
        {
            assert_eq!(application_message.into_bytes(), message_alice);
            assert_eq!(
                &sender,
                alice_group
                    .credential()
                    .expect("An unexpected error occurred.")
            );
        } else {
            unreachable!("Expected an ApplicationMessage.");
        }

        // --- Step 3a: Bob commits a self-update. ---
        let (queued_message, welcome_option, _group_info) = bob_group
            .self_update(bob_provider, &bob_signer, LeafNodeParameters::default())
            .unwrap()
            .into_contents();
        let alice_processed_message = alice_group
            .process_message(
                alice_provider,
                queued_message
                    .clone()
                    .into_protocol_message()
                    .expect("Unexpected message type"),
            )
            .expect("Could not process message.");
        if let ProcessedMessageContent::StagedCommitMessage(staged_commit) =
            alice_processed_message.into_content()
        {
            alice_group
                .merge_staged_commit(alice_provider, *staged_commit)
                .unwrap();
        } else {
            unreachable!("Expected a StagedCommit.");
        }
        bob_group
            .merge_pending_commit(bob_provider)
            .expect("error merging pending commit");
        // A self-update adds nobody, so no Welcome is produced.
        assert!(welcome_option.is_none());
        // Both members are in the same epoch: exported secrets and trees match.
        assert_eq!(
            alice_group
                .export_secret(alice_provider.crypto(), "", &[], 32)
                .unwrap(),
            bob_group
                .export_secret(bob_provider.crypto(), "", &[], 32)
                .unwrap()
        );
        assert_eq!(
            alice_group.export_ratchet_tree(),
            bob_group.export_ratchet_tree()
        );

        // --- Step 3b: Alice proposes a self-update, then commits it. ---
        let (queued_message, _) = alice_group
            .propose_self_update(alice_provider, &alice_signer, LeafNodeParameters::default())
            .unwrap();
        let bob_processed_message = bob_group
            .process_message(
                bob_provider,
                queued_message
                    .clone()
                    .into_protocol_message()
                    .expect("Unexpected message type"),
            )
            .expect("Could not process message.");
        if let ProcessedMessageContent::ProposalMessage(staged_proposal) =
            bob_processed_message.into_content()
        {
            if let Proposal::Update(ref update_proposal) = staged_proposal.proposal() {
                assert_eq!(
                    update_proposal.leaf_node().credential(),
                    &alice_credential.credential
                );
                // NOTE(review): Alice stores Bob's processed copy of her own
                // proposal in her group. Whether `propose_self_update` already
                // queued it locally is not visible here; confirm this does not
                // duplicate it.
                alice_group
                    .store_pending_proposal(alice_provider.storage(), *staged_proposal.clone())
                    .unwrap();
            } else {
                unreachable!("Expected a Proposal.");
            }
            assert!(matches!(
                staged_proposal.sender(),
                Sender::Member(member) if *member == alice_group.own_leaf_index()
            ));
            // Bob must queue the proposal so he can validate Alice's commit.
            bob_group
                .store_pending_proposal(bob_provider.storage(), *staged_proposal)
                .unwrap();
        } else {
            unreachable!("Expected a QueuedProposal.");
        }
        let (queued_message, _welcome_option, _group_info) = alice_group
            .commit_to_pending_proposals(alice_provider, &alice_signer)
            .unwrap();
        let bob_processed_message = bob_group
            .process_message(
                bob_provider,
                queued_message
                    .clone()
                    .into_protocol_message()
                    .expect("Unexpected message type"),
            )
            .expect("Could not process message.");
        if let ProcessedMessageContent::StagedCommitMessage(staged_commit) =
            bob_processed_message.into_content()
        {
            bob_group
                .merge_staged_commit(bob_provider, *staged_commit)
                .unwrap();
        } else {
            unreachable!("Expected a StagedCommit.");
        }
        alice_group
            .merge_pending_commit(alice_provider)
            .expect("error merging pending commit");
        assert_eq!(
            alice_group
                .export_secret(alice_provider.crypto(), "", &[], 32)
                .unwrap(),
            bob_group
                .export_secret(bob_provider.crypto(), "", &[], 32)
                .unwrap()
        );
        assert_eq!(
            alice_group.export_ratchet_tree(),
            bob_group.export_ratchet_tree()
        );

        // --- Step 4: Bob adds Charlie. ---
        let charlie_key_package = generate_key_package(
            ciphersuite,
            Extensions::empty(),
            charlie_provider,
            charlie_credential,
            &charlie_signer,
        );
        let (queued_message, welcome, _group_info) = bob_group
            .add_members(bob_provider, &bob_signer, &[charlie_key_package])
            .unwrap();
        let alice_processed_message = alice_group
            .process_message(
                alice_provider,
                queued_message
                    .clone()
                    .into_protocol_message()
                    .expect("Unexpected message type"),
            )
            .expect("Could not process message.");
        bob_group
            .merge_pending_commit(bob_provider)
            .expect("error merging pending commit");
        if let ProcessedMessageContent::StagedCommitMessage(staged_commit) =
            alice_processed_message.into_content()
        {
            alice_group
                .merge_staged_commit(alice_provider, *staged_commit)
                .unwrap();
        } else {
            unreachable!("Expected a StagedCommit.");
        }
        // Charlie joins with the tree exported by the committer (Bob).
        let welcome: MlsMessageIn = welcome.into();
        let welcome = welcome
            .into_welcome()
            .expect("expected the message to be a welcome message");
        let mut charlie_group = StagedWelcome::new_from_welcome(
            charlie_provider,
            mls_group_create_config.join_config(),
            welcome,
            Some(bob_group.export_ratchet_tree().into()),
        )
        .expect("Error creating staged join from Welcome")
        .into_group(charlie_provider)
        .expect("Error creating group from staged join");
        assert_eq!(
            alice_group.export_ratchet_tree(),
            bob_group.export_ratchet_tree(),
        );
        assert_eq!(
            alice_group.export_ratchet_tree(),
            charlie_group.export_ratchet_tree()
        );
        // `members` (Alice, Bob, Charlie) is reused by the removal checks below.
        let members = alice_group.members().collect::<Vec<Member>>();
        let credential0 = members[0].credential.serialized_content();
        let credential1 = members[1].credential.serialized_content();
        let credential2 = members[2].credential.serialized_content();
        assert_eq!(credential0, b"Alice");
        assert_eq!(credential1, b"Bob");
        assert_eq!(credential2, b"Charlie");
        // Charlie -> Alice and Bob application message; only successful processing is checked.
        let message_charlie = b"Hi, I'm Charlie!";
        let queued_message = charlie_group
            .create_message(charlie_provider, &charlie_signer, message_charlie)
            .expect("Error creating application message");
        let _alice_processed_message = alice_group
            .process_message(
                alice_provider,
                queued_message
                    .clone()
                    .into_protocol_message()
                    .expect("Unexpected message type"),
            )
            .expect("Could not process message.");
        let _bob_processed_message = bob_group
            .process_message(
                bob_provider,
                queued_message
                    .clone()
                    .into_protocol_message()
                    .expect("Unexpected message type"),
            )
            .expect("Could not process message.");
        // Charlie commits a self-update that both Alice and Bob process.
        let (queued_message, welcome_option, _group_info) = charlie_group
            .self_update(
                charlie_provider,
                &charlie_signer,
                LeafNodeParameters::default(),
            )
            .unwrap()
            .into_contents();
        let alice_processed_message = alice_group
            .process_message(
                alice_provider,
                queued_message
                    .clone()
                    .into_protocol_message()
                    .expect("Unexpected message type"),
            )
            .expect("Could not process message.");
        let bob_processed_message = bob_group
            .process_message(
                bob_provider,
                queued_message
                    .clone()
                    .into_protocol_message()
                    .expect("Unexpected message type"),
            )
            .expect("Could not process message.");
        charlie_group
            .merge_pending_commit(charlie_provider)
            .expect("error merging pending commit");
        if let ProcessedMessageContent::StagedCommitMessage(staged_commit) =
            alice_processed_message.into_content()
        {
            alice_group
                .merge_staged_commit(alice_provider, *staged_commit)
                .unwrap();
        } else {
            unreachable!("Expected a StagedCommit.");
        }
        if let ProcessedMessageContent::StagedCommitMessage(staged_commit) =
            bob_processed_message.into_content()
        {
            bob_group
                .merge_staged_commit(bob_provider, *staged_commit)
                .unwrap();
        } else {
            unreachable!("Expected a StagedCommit.");
        }
        assert!(welcome_option.is_none());
        assert_eq!(
            alice_group
                .export_secret(alice_provider.crypto(), "", &[], 32)
                .unwrap(),
            bob_group
                .export_secret(bob_provider.crypto(), "", &[], 32)
                .unwrap()
        );
        assert_eq!(
            alice_group
                .export_secret(alice_provider.crypto(), "", &[], 32)
                .unwrap(),
            charlie_group
                .export_secret(charlie_provider.crypto(), "", &[], 32)
                .unwrap()
        );
        assert_eq!(
            alice_group.export_ratchet_tree(),
            bob_group.export_ratchet_tree(),
        );
        assert_eq!(
            alice_group.export_ratchet_tree(),
            charlie_group.export_ratchet_tree()
        );

        // --- Step 5: Charlie removes Bob. ---
        // Progress marker written to stdout.
        println!(" >>> Charlie is removing bob");
        let (queued_message, welcome_option, _group_info) = charlie_group
            .remove_members(
                charlie_provider,
                &charlie_signer,
                &[bob_group.own_leaf_index()],
            )
            .expect("Could not remove member from group.");
        // Bob stays active until he merges the commit that removes him.
        assert!(bob_group.is_active());
        let alice_processed_message = alice_group
            .process_message(
                alice_provider,
                queued_message
                    .clone()
                    .into_protocol_message()
                    .expect("Unexpected message type"),
            )
            .expect("Could not process message.");
        let bob_processed_message = bob_group
            .process_message(
                bob_provider,
                queued_message
                    .clone()
                    .into_protocol_message()
                    .expect("Unexpected message type"),
            )
            .expect("Could not process message.");
        charlie_group
            .merge_pending_commit(charlie_provider)
            .expect("error merging pending commit");
        // Here `members` is still (Alice, Bob, Charlie): members[1] is Bob (removed),
        // members[2] is Charlie (the sender).
        if let ProcessedMessageContent::StagedCommitMessage(staged_commit) =
            alice_processed_message.into_content()
        {
            let remove = staged_commit
                .remove_proposals()
                .next()
                .expect("Expected a proposal.");
            assert_eq!(remove.remove_proposal().removed(), members[1].index);
            assert!(
                matches!(remove.sender(), Sender::Member(member) if *member == members[2].index)
            );
            alice_group
                .merge_staged_commit(alice_provider, *staged_commit)
                .unwrap();
        } else {
            unreachable!("Expected a StagedCommit.");
        }
        if let ProcessedMessageContent::StagedCommitMessage(staged_commit) =
            bob_processed_message.into_content()
        {
            let remove = staged_commit
                .remove_proposals()
                .next()
                .expect("Expected a proposal.");
            assert_eq!(remove.remove_proposal().removed(), members[1].index);
            assert!(
                matches!(remove.sender(), Sender::Member(member) if *member == members[2].index)
            );
            bob_group
                .merge_staged_commit(bob_provider, *staged_commit)
                .unwrap();
        } else {
            unreachable!("Expected a StagedCommit.");
        }
        assert!(welcome_option.is_none());
        // After merging his own removal, Bob's group is inactive.
        assert!(!bob_group.is_active());
        assert_eq!(
            alice_group.export_ratchet_tree(),
            charlie_group.export_ratchet_tree()
        );
        // `members` is rebound to (Alice, Charlie) for step 6.
        assert_eq!(alice_group.members().count(), 2);
        let members = alice_group.members().collect::<Vec<Member>>();
        let credential0 = members[0].credential.serialized_content();
        let credential1 = members[1].credential.serialized_content();
        assert_eq!(credential0, b"Alice");
        assert_eq!(credential1, b"Charlie");
        // An inactive group must refuse to create messages.
        assert!(bob_group
            .create_message(bob_provider, &bob_signer, b"Should not go through")
            .is_err());

        // --- Step 6: Alice removes Charlie and re-adds Bob in one commit. ---
        let bob_key_package = generate_key_package(
            ciphersuite,
            Extensions::empty(),
            bob_provider,
            bob_credential.clone(),
            &bob_signer,
        );
        let (queued_message, _) = alice_group
            .propose_remove_member(
                alice_provider,
                &alice_signer,
                charlie_group.own_leaf_index(),
            )
            .expect("Could not create proposal to remove Charlie");
        let charlie_processed_message = charlie_group
            .process_message(
                charlie_provider,
                queued_message
                    .clone()
                    .into_protocol_message()
                    .expect("Unexpected message type"),
            )
            .expect("Could not process message.");
        // members[1] is Charlie (target), members[0] is Alice (sender).
        if let ProcessedMessageContent::ProposalMessage(staged_proposal) =
            charlie_processed_message.into_content()
        {
            if let Proposal::Remove(ref remove_proposal) = staged_proposal.proposal() {
                assert_eq!(remove_proposal.removed(), members[1].index);
                charlie_group
                    .store_pending_proposal(charlie_provider.storage(), *staged_proposal.clone())
                    .unwrap();
            } else {
                unreachable!("Expected a Proposal.");
            }
            assert!(matches!(
                staged_proposal.sender(),
                Sender::Member(member) if *member == members[0].index
            ));
        } else {
            unreachable!("Expected a QueuedProposal.");
        }
        let (queued_message, _) = alice_group
            .propose_add_member(alice_provider, &alice_signer, &bob_key_package)
            .expect("Could not create proposal to add Bob");
        let charlie_processed_message = charlie_group
            .process_message(
                charlie_provider,
                queued_message
                    .clone()
                    .into_protocol_message()
                    .expect("Unexpected message type"),
            )
            .expect("Could not process message.");
        if let ProcessedMessageContent::ProposalMessage(staged_proposal) =
            charlie_processed_message.into_content()
        {
            if let Proposal::Add(add_proposal) = staged_proposal.proposal() {
                assert_eq!(
                    add_proposal.key_package().leaf_node().credential(),
                    &bob_credential.credential
                );
            } else {
                unreachable!("Expected an AddProposal.");
            }
            assert!(matches!(
                staged_proposal.sender(),
                Sender::Member(member) if *member == members[0].index
            ));
            charlie_group
                .store_pending_proposal(charlie_provider.storage(), *staged_proposal)
                .unwrap();
        } else {
            unreachable!("Expected a QueuedProposal.");
        }
        // Alice commits both queued proposals; the Add produces a Welcome for Bob.
        let (queued_message, welcome_option, _group_info) = alice_group
            .commit_to_pending_proposals(alice_provider, &alice_signer)
            .expect("Could not flush proposals");
        let charlie_processed_message = charlie_group
            .process_message(
                charlie_provider,
                queued_message
                    .clone()
                    .into_protocol_message()
                    .expect("Unexpected message type"),
            )
            .expect("Could not process message.");
        alice_group
            .merge_pending_commit(alice_provider)
            .expect("error merging pending commit");
        if let ProcessedMessageContent::StagedCommitMessage(staged_commit) =
            charlie_processed_message.into_content()
        {
            charlie_group
                .merge_staged_commit(charlie_provider, *staged_commit)
                .unwrap();
        } else {
            unreachable!("Expected a StagedCommit.");
        }
        assert_eq!(alice_group.members().count(), 2);
        let members = alice_group.members().collect::<Vec<Member>>();
        let credential0 = members[0].credential.serialized_content();
        let credential1 = members[1].credential.serialized_content();
        assert_eq!(credential0, b"Alice");
        assert_eq!(credential1, b"Bob");
        // Bob rejoins: delete his stale (inactive) group state first, then join anew.
        let welcome: MlsMessageIn = welcome_option.expect("Welcome was not returned").into();
        let welcome = welcome
            .into_welcome()
            .expect("expected the message to be a welcome message");
        bob_group.delete(bob_provider.storage()).unwrap();
        let mut bob_group = StagedWelcome::new_from_welcome(
            bob_provider,
            mls_group_create_config.join_config(),
            welcome,
            Some(alice_group.export_ratchet_tree().into()),
        )
        .expect("Error creating staged join from Welcome")
        .into_group(bob_provider)
        .expect("Error creating group from staged join");
        assert_eq!(alice_group.members().count(), 2);
        let members = alice_group.members().collect::<Vec<Member>>();
        let credential0 = members[0].credential.serialized_content();
        let credential1 = members[1].credential.serialized_content();
        assert_eq!(credential0, b"Alice");
        assert_eq!(credential1, b"Bob");
        assert_eq!(bob_group.members().count(), 2);
        let members = bob_group.members().collect::<Vec<Member>>();
        let credential0 = members[0].credential.serialized_content();
        let credential1 = members[1].credential.serialized_content();
        assert_eq!(credential0, b"Alice");
        assert_eq!(credential1, b"Bob");
        // Alice -> rejoined Bob application message.
        let message_alice = b"Hi, I'm Alice!";
        let queued_message = alice_group
            .create_message(alice_provider, &alice_signer, message_alice)
            .expect("Error creating application message");
        let bob_processed_message = bob_group
            .process_message(
                bob_provider,
                queued_message
                    .clone()
                    .into_protocol_message()
                    .expect("Unexpected message type"),
            )
            .expect("Could not process message.");
        let sender = bob_processed_message.credential().clone();
        if let ProcessedMessageContent::ApplicationMessage(application_message) =
            bob_processed_message.into_content()
        {
            assert_eq!(application_message.into_bytes(), message_alice);
            assert_eq!(
                &sender,
                alice_group.credential().expect("Expected a credential")
            );
        } else {
            unreachable!("Expected an ApplicationMessage.");
        }

        // --- Step 7: Bob leaves; Alice commits the self-remove proposal. ---
        let queued_message = bob_group
            .leave_group(bob_provider, &bob_signer)
            .expect("Could not leave group");
        let alice_processed_message = alice_group
            .process_message(
                alice_provider,
                queued_message
                    .clone()
                    .into_protocol_message()
                    .expect("Unexpected message type"),
            )
            .expect("Could not process message.");
        if let ProcessedMessageContent::ProposalMessage(staged_proposal) =
            alice_processed_message.into_content()
        {
            alice_group
                .store_pending_proposal(alice_provider.storage(), *staged_proposal)
                .unwrap();
        } else {
            unreachable!("Expected a QueuedProposal.");
        }
        // Bob cannot commit his own removal.
        assert!(matches!(
            bob_group.commit_to_pending_proposals(bob_provider, &bob_signer),
            Err(CommitToPendingProposalsError::CreateCommitError(
                CreateCommitError::CannotRemoveSelf
            ))
        ));
        let (queued_message, _welcome_option, _group_info) = alice_group
            .commit_to_pending_proposals(alice_provider, &alice_signer)
            .expect("Could not commit to proposals.");
        // Still active: Bob has not yet processed Alice's commit.
        assert!(bob_group.is_active());
        let bob_leaf_index = bob_group.own_leaf_index();
        // The Remove targets Bob and is attributed to Bob (he proposed it).
        if let Some(staged_commit) = alice_group.pending_commit() {
            let remove = staged_commit
                .remove_proposals()
                .next()
                .expect("Expected a proposal.");
            assert_eq!(remove.remove_proposal().removed(), bob_leaf_index);
            assert!(matches!(remove.sender(), Sender::Member(member) if *member == bob_leaf_index));
        } else {
            unreachable!("Expected a StagedCommit.");
        }
        alice_group
            .merge_pending_commit(alice_provider)
            .expect("Could not merge Commit.");
        let bob_processed_message = bob_group
            .process_message(
                bob_provider,
                queued_message
                    .clone()
                    .into_protocol_message()
                    .expect("Unexpected message type"),
            )
            .expect("Could not process message.");
        if let ProcessedMessageContent::StagedCommitMessage(staged_commit) =
            bob_processed_message.into_content()
        {
            let remove = staged_commit
                .remove_proposals()
                .next()
                .expect("Expected a proposal.");
            assert_eq!(remove.remove_proposal().removed(), bob_leaf_index);
            assert!(matches!(remove.sender(), Sender::Member(member) if *member == bob_leaf_index));
            assert!(staged_commit.self_removed());
            bob_group
                .merge_staged_commit(bob_provider, *staged_commit)
                .unwrap();
        } else {
            unreachable!("Expected a StagedCommit.");
        }
        assert!(!bob_group.is_active());
        assert_eq!(alice_group.members().count(), 1);
        let members = alice_group.members().collect::<Vec<Member>>();
        let credential0 = members[0].credential.serialized_content();
        assert_eq!(credential0, b"Alice");

        // --- Step 8: Alice re-adds Bob; check state survives reload from storage. ---
        let bob_key_package = generate_key_package(
            ciphersuite,
            Extensions::empty(),
            bob_provider,
            bob_credential,
            &bob_signer,
        );
        let (_queued_message, welcome, _group_info) = alice_group
            .add_members(alice_provider, &alice_signer, &[bob_key_package])
            .expect("Could not add Bob");
        // Alice's group must be loadable from her storage while a commit is pending.
        let _test_group = MlsGroup::load(alice_provider.storage(), &group_id)
            .expect("Could not load the group state due to an error.")
            .expect("Could not load the group state because the group does not exist.");
        alice_group
            .merge_pending_commit(alice_provider)
            .expect("error merging pending commit");
        let welcome: MlsMessageIn = welcome.into();
        let welcome = welcome
            .into_welcome()
            .expect("expected the message to be a welcome message");
        bob_group.delete(bob_provider.storage()).unwrap();
        let mut bob_group = StagedWelcome::new_from_welcome(
            bob_provider,
            mls_group_create_config.join_config(),
            welcome,
            Some(alice_group.export_ratchet_tree().into()),
        )
        .expect("Could not create staged join from Welcome")
        .into_group(bob_provider)
        .expect("Could not create group from staged join");
        assert_eq!(
            alice_group
                .export_secret(alice_provider.crypto(), "before load", &[], 32)
                .unwrap(),
            bob_group
                .export_secret(bob_provider.crypto(), "before load", &[], 32)
                .unwrap()
        );
        // Replace Bob's in-memory group with the copy loaded from his storage.
        bob_group = MlsGroup::load(bob_provider.storage(), &group_id)
            .expect("Could not load group from file because of an error")
            .expect("Could not load group from file because there is no group with given id");
        assert_eq!(
            alice_group
                .export_secret(alice_provider.crypto(), "after load", &[], 32)
                .unwrap(),
            bob_group
                .export_secret(bob_provider.crypto(), "after load", &[], 32)
                .unwrap()
        );
    }
}
/// Verifies that members added in a single `add_members` call are assigned
/// leaves in the order their key packages were passed.
///
/// For every wire-format policy, Alice adds Bob and Charlie in one call.
/// The staged commit must list Bob's Add before Charlie's, and after merging
/// Bob sits at leaf 1 and Charlie at leaf 2.
///
/// The three providers are shared across all policy iterations, and a fresh
/// random group id is drawn on each iteration.
#[openmls_test]
fn addition_order() {
    let alice_provider = &Provider::default();
    let bob_provider = &Provider::default();
    let charlie_provider = &Provider::default();

    for policy in WIRE_FORMAT_POLICIES.iter() {
        let group_id = GroupId::random(alice_provider.rand());

        let (alice_credential, alice_signer) =
            new_credential(alice_provider, b"Alice", ciphersuite.signature_algorithm());
        let (bob_credential, bob_signer) =
            new_credential(bob_provider, b"Bob", ciphersuite.signature_algorithm());
        let (charlie_credential, charlie_signer) = new_credential(
            charlie_provider,
            b"Charlie",
            ciphersuite.signature_algorithm(),
        );

        let bob_kp = generate_key_package(
            ciphersuite,
            Extensions::empty(),
            bob_provider,
            bob_credential.clone(),
            &bob_signer,
        );
        let charlie_kp = generate_key_package(
            ciphersuite,
            Extensions::empty(),
            charlie_provider,
            charlie_credential.clone(),
            &charlie_signer,
        );

        let create_config = MlsGroupCreateConfig::builder()
            .wire_format_policy(*policy)
            .ciphersuite(ciphersuite)
            .build();
        let mut alice_group = MlsGroup::new_with_group_id(
            alice_provider,
            &alice_signer,
            &create_config,
            group_id.clone(),
            alice_credential.clone(),
        )
        .expect("An unexpected error occurred.");

        // Bob is passed before Charlie; the remaining checks pin that order.
        let (_, _welcome, _) = alice_group
            .add_members(alice_provider, &alice_signer, &[bob_kp, charlie_kp])
            .unwrap_or_else(|e| panic!("Could not add member to group: {e:?}"));

        // The staged commit must list the Add proposals in insertion order.
        match alice_group.pending_commit() {
            Some(staged_commit) => {
                let mut adds = staged_commit.add_proposals();
                for expected in [&bob_credential.credential, &charlie_credential.credential] {
                    let add = adds.next().expect("Expected a proposal.");
                    assert_eq!(
                        add.add_proposal().key_package().leaf_node().credential(),
                        expected
                    );
                }
            }
            None => unreachable!("Expected a StagedCommit."),
        }

        alice_group
            .merge_pending_commit(alice_provider)
            .expect("error merging pending commit");

        // After merging, Bob and Charlie occupy leaves 1 and 2 respectively.
        let members: Vec<Member> = alice_group.members().collect();
        let expectations = [(&b"Bob"[..], 1u32), (&b"Charlie"[..], 2u32)];
        for (member, (name, leaf)) in members[1..=2].iter().zip(expectations) {
            assert_eq!(member.credential.serialized_content(), name);
            assert_eq!(member.index, LeafNodeIndex::new(leaf));
        }
    }
}
/// Checks that a single commit can carry more remove proposals than add
/// proposals, for several group sizes.
#[openmls_test]
fn more_remove_than_add_proposals_in_commit() {
    let provider = &Provider::default();
    more_remove_than_add_proposals_in_commit_inner::<6, 2, 1>(provider, ciphersuite);
    more_remove_than_add_proposals_in_commit_inner::<10, 2, 1>(provider, ciphersuite);
    more_remove_than_add_proposals_in_commit_inner::<22, 6, 5>(provider, ciphersuite);

    /// Creates a group of `INITIAL_GROUP_SIZE` members and queues
    /// `REMOVE_PROPOSALS_COUNT` remove proposals. It then calls `add_members`
    /// with `ADD_PROPOSALS_COUNT` key packages, so the resulting commit carries
    /// both kinds of proposal. Runs once per entry of `WIRE_FORMAT_POLICIES`.
    ///
    /// # Panics
    ///
    /// Panics if the const parameters would remove the group creator (leaf 0),
    /// if they do not describe "more removes than adds", or if any group
    /// operation fails.
    fn more_remove_than_add_proposals_in_commit_inner<
        const INITIAL_GROUP_SIZE: usize,
        const REMOVE_PROPOSALS_COUNT: usize,
        const ADD_PROPOSALS_COUNT: usize,
    >(
        provider: &Provider,
        ciphersuite: Ciphersuite,
    ) {
        // The creator sits at leaf 0 and issues every proposal, so it must
        // survive. The scenario under test also needs strictly more removes
        // than adds, otherwise it silently degenerates into a different test.
        assert!(
            REMOVE_PROPOSALS_COUNT < INITIAL_GROUP_SIZE,
            "cannot remove the group creator at leaf 0"
        );
        assert!(
            REMOVE_PROPOSALS_COUNT > ADD_PROPOSALS_COUNT,
            "test requires more remove proposals than add proposals"
        );

        // Loop-invariant sizes, derived once from the const parameters.
        let all_members_count = INITIAL_GROUP_SIZE + ADD_PROPOSALS_COUNT;
        let remaining_members_count = INITIAL_GROUP_SIZE - REMOVE_PROPOSALS_COUNT;

        for wire_format_policy in WIRE_FORMAT_POLICIES.iter() {
            // One credential per member that will ever join.
            // Index 0 is the creator. Indices 1..INITIAL_GROUP_SIZE join in
            // the first commit, and the rest join in the mixed commit.
            let ids = (0..all_members_count)
                .map(|i| format!("member {i}").into_bytes())
                .collect::<Vec<_>>();
            let (credentials, signers): (Vec<_>, Vec<_>) = ids
                .iter()
                .map(|id| new_credential(provider, id, ciphersuite.signature_algorithm()))
                .unzip();

            let mls_group_create_config = MlsGroupCreateConfig::builder()
                .wire_format_policy(*wire_format_policy)
                .ciphersuite(ciphersuite)
                .build();
            let group_id = GroupId::random(provider.rand());
            let mut alice_group = MlsGroup::new_with_group_id(
                provider,
                &signers[0],
                &mls_group_create_config,
                group_id.clone(),
                credentials[0].clone(),
            )
            .expect("could not create group");

            let key_packages = credentials[1..]
                .iter()
                .zip(signers[1..].iter())
                .map(|(credential, signer)| {
                    generate_key_package(
                        ciphersuite,
                        Extensions::empty(),
                        provider,
                        credential.clone(),
                        signer,
                    )
                })
                .collect::<Vec<_>>();

            // Grow the group to its initial size.
            alice_group
                .add_members(
                    provider,
                    &signers[0],
                    &key_packages[..INITIAL_GROUP_SIZE - 1],
                )
                .expect("Could not add initial members to group");
            alice_group
                .merge_pending_commit(provider)
                .expect("error merging pending commit");
            assert_eq!(alice_group.members().count(), INITIAL_GROUP_SIZE);

            // Propose removing the highest-indexed initial members. Because of
            // the precondition above, leaf 0 is never in this range.
            for index in remaining_members_count..INITIAL_GROUP_SIZE {
                let leaf_index = LeafNodeIndex::new(
                    u32::try_from(index).expect("leaf index does not fit in u32"),
                );
                alice_group
                    .propose_remove_member_by_value(provider, &signers[0], leaf_index)
                    .expect("could not propose removing member");
            }
            assert_eq!(
                alice_group.pending_proposals().count(),
                REMOVE_PROPOSALS_COUNT
            );

            // Adding the remaining key packages commits the queued removes
            // together with the new adds.
            alice_group
                .add_members(
                    provider,
                    &signers[0],
                    &key_packages[INITIAL_GROUP_SIZE - 1..],
                )
                .expect("Could not add member to group");
            alice_group
                .merge_pending_commit(provider)
                .expect("error merging pending commit");
            assert_eq!(
                alice_group.members().count(),
                all_members_count - REMOVE_PROPOSALS_COUNT
            );
        }
    }
}
/// Passing an empty slice to `add_members` or `remove_members` must be
/// rejected with the matching `EmptyInputError` variant.
#[openmls_test]
fn test_empty_input_errors() {
    let provider = &Provider::default();
    let group_id = GroupId::from_slice(b"Test Group");
    let (credential, signer) =
        new_credential(provider, b"Alice", ciphersuite.signature_algorithm());
    let config = MlsGroupCreateConfig::test_default(ciphersuite);
    let mut group =
        MlsGroup::new_with_group_id(provider, &signer, &config, group_id, credential)
            .expect("An unexpected error occurred.");

    // Adding nobody is an input error.
    let add_error = group
        .add_members(provider, &signer, &[])
        .expect_err("No EmptyInputError when trying to pass an empty slice to `add_members`.");
    assert!(matches!(
        add_error,
        AddMembersError::EmptyInput(EmptyInputError::AddMembers)
    ));

    // Removing nobody is an input error as well.
    let remove_error = group
        .remove_members(provider, &signer, &[])
        .expect_err("No EmptyInputError when trying to pass an empty slice to `remove_members`.");
    assert!(matches!(
        remove_error,
        RemoveMembersError::EmptyInput(EmptyInputError::RemoveMembers)
    ));
}
/// Verifies how a Welcome behaves with and without the ratchet tree extension,
/// for every wire format policy:
/// - with the extension enabled, Bob can join from the Welcome alone;
/// - without it (and no tree supplied out of band), joining fails with
///   `WelcomeError::MissingRatchetTree`.
#[openmls_test]
fn mls_group_ratchet_tree_extension() {
    for wire_format_policy in WIRE_FORMAT_POLICIES.iter() {
        let alice_provider = &Provider::default();
        let bob_provider = &Provider::default();

        // Case 1: ratchet tree extension enabled, so the join succeeds.
        {
            let (alice_credential, alice_signer) =
                new_credential(alice_provider, b"Alice", ciphersuite.signature_algorithm());
            let (bob_credential, bob_signer) =
                new_credential(bob_provider, b"Bob", ciphersuite.signature_algorithm());
            let bob_key_package = generate_key_package(
                ciphersuite,
                Extensions::empty(),
                bob_provider,
                bob_credential,
                &bob_signer,
            );
            let mls_group_create_config = MlsGroupCreateConfig::builder()
                .wire_format_policy(*wire_format_policy)
                .use_ratchet_tree_extension(true)
                .ciphersuite(ciphersuite)
                .build();
            let group_id = GroupId::random(alice_provider.rand());
            let mut alice_group = MlsGroup::new_with_group_id(
                alice_provider,
                &alice_signer,
                &mls_group_create_config,
                group_id,
                alice_credential,
            )
            .expect("An unexpected error occurred.");
            let (_queued_message, welcome, _group_info) = alice_group
                .add_members(alice_provider, &alice_signer, from_ref(&bob_key_package))
                .unwrap();
            let welcome: MlsMessageIn = welcome.into();
            let welcome = welcome
                .into_welcome()
                .expect("expected the message to be a welcome message");
            let bob_group = StagedWelcome::new_from_welcome(
                bob_provider,
                mls_group_create_config.join_config(),
                welcome,
                // No out-of-band tree: the Welcome itself must carry it.
                None,
            )
            .expect("Error creating staged join from Welcome")
            .into_group(bob_provider)
            .expect("Error creating group from staged join");
            // Bob's view reconstructed from the Welcome contains Alice and himself.
            assert_eq!(bob_group.members().count(), 2);
        }

        // Case 2: ratchet tree extension disabled, so the join must fail.
        // This uses the same wire format policy as case 1, so each iteration
        // exercises a distinct configuration.
        {
            let (alice_credential, alice_signer) =
                new_credential(alice_provider, b"Alice", ciphersuite.signature_algorithm());
            let (bob_credential, bob_signer) =
                new_credential(bob_provider, b"Bob", ciphersuite.signature_algorithm());
            let bob_key_package = generate_key_package(
                ciphersuite,
                Extensions::empty(),
                bob_provider,
                bob_credential,
                &bob_signer,
            );
            let mls_group_create_config = MlsGroupCreateConfig::builder()
                .wire_format_policy(*wire_format_policy)
                .use_ratchet_tree_extension(false)
                .ciphersuite(ciphersuite)
                .build();
            let group_id = GroupId::random(alice_provider.rand());
            let mut alice_group = MlsGroup::new_with_group_id(
                alice_provider,
                &alice_signer,
                &mls_group_create_config,
                group_id,
                alice_credential,
            )
            .expect("An unexpected error occurred.");
            let (_queued_message, welcome, _group_info) = alice_group
                .add_members(alice_provider, &alice_signer, from_ref(&bob_key_package))
                .unwrap();
            let welcome: MlsMessageIn = welcome.into();
            let welcome = welcome
                .into_welcome()
                .expect("expected the message to be a welcome message");
            let error = StagedWelcome::new_from_welcome(
                bob_provider,
                mls_group_create_config.join_config(),
                welcome,
                None,
            )
            .expect_err("Could join a group without a ratchet tree");
            assert!(matches!(error, WelcomeError::MissingRatchetTree));
        }
    }
}
/// Exercises group-context-extensions (GCE) proposals:
/// 1. a single GCE proposal is committed and its extensions become the group's;
/// 2. committing two GCE proposals at once is rejected;
/// 3. a GCE proposal that requires an extension type unsupported by the
///    members is rejected when it is built.
#[openmls_test]
fn group_context_extensions_proposal() {
    let alice_provider = &Provider::default();
    let (alice_credential_with_key, alice_signer) =
        new_credential(alice_provider, b"Alice", ciphersuite.signature_algorithm());
    let mut alice_group = MlsGroup::builder()
        .ciphersuite(ciphersuite)
        .build(alice_provider, &alice_signer, alice_credential_with_key)
        .expect("error creating group using builder");

    // A freshly built group starts with no group context extensions.
    assert!(alice_group.extensions().required_capabilities().is_none());
    let group_context_before = alice_group.export_group_context().clone();
    assert_eq!(group_context_before.extensions(), &Extensions::empty());

    let new_extensions = Extensions::single(Extension::RequiredCapabilities(
        RequiredCapabilitiesExtension::new(&[ExtensionType::RequiredCapabilities], &[], &[]),
    ))
    .expect("failed to create single-element extensions list");
    let new_extensions_2 = Extensions::single(Extension::RequiredCapabilities(
        RequiredCapabilitiesExtension::new(&[ExtensionType::RatchetTree], &[], &[]),
    ))
    .expect("failed to create single-element extensions list");

    // 1. A single GCE proposal is staged and then merged.
    alice_group
        .propose_group_context_extensions(alice_provider, new_extensions.clone(), &alice_signer)
        .expect("failed to build group context extensions proposal");
    assert_eq!(alice_group.pending_proposals().count(), 1);
    alice_group
        .commit_to_pending_proposals(alice_provider, &alice_signer)
        .expect("failed to commit to pending proposals");
    let group_context_staged = alice_group
        .pending_commit()
        .unwrap()
        .group_context()
        .clone();
    assert_eq!(group_context_staged.extensions(), &new_extensions);
    alice_group
        .merge_pending_commit(alice_provider)
        .expect("error merging pending commit");
    let required_capabilities = alice_group
        .extensions()
        .required_capabilities()
        .expect("couldn't get required_capabilities");
    // assert_eq! reports both sides on failure, unlike assert!(a == b).
    assert_eq!(
        required_capabilities.extension_types(),
        [ExtensionType::RequiredCapabilities]
    );

    // 2. Two GCE proposals cannot be committed together.
    alice_group
        .propose_group_context_extensions(alice_provider, new_extensions, &alice_signer)
        .expect("failed to build group context extensions proposal");
    alice_group
        .propose_group_context_extensions(alice_provider, new_extensions_2, &alice_signer)
        .expect("failed to build group context extensions proposal");
    assert_eq!(alice_group.pending_proposals().count(), 2);
    alice_group
        .commit_to_pending_proposals(alice_provider, &alice_signer)
        .expect_err(
            "expected error when committing to multiple group context extensions proposals",
        );

    // 3. Requiring an unknown extension type (0xf042) is rejected up front.
    let new_extensions = Extensions::single(Extension::RequiredCapabilities(
        RequiredCapabilitiesExtension::new(&[ExtensionType::Unknown(0xf042)], &[], &[]),
    ))
    .expect("failed to create single-element extensions list");
    alice_group
        .propose_group_context_extensions(alice_provider, new_extensions, &alice_signer)
        .expect_err("expected an error building GCE proposal with bad required_capabilities");
}