use openmls_traits::{prelude::*, signatures::Signer, types::Ciphersuite};
use proposal_store::QueuedProposal;
use tls_codec::{Deserialize, Serialize};
use crate::{
binary_tree::LeafNodeIndex,
ciphersuite::signable::Signable,
extensions::Extensions,
framing::*,
group::{
tests_and_kats::utils::{
generate_credential_with_key, generate_key_package, resign_message,
CredentialWithKeyAndSigner,
},
*,
},
messages::proposals::*,
schedule::{ExternalPsk, PreSharedKeyId, Psk},
treesync::{
errors::ApplyUpdatePathError, node::parent_node::PlainUpdatePathNode, treekem::UpdatePath,
LeafNodeParameters,
},
};
/// Group states and credential produced by `validation_test_setup` and
/// consumed by the commit validation tests in this file.
struct CommitValidationTestSetup {
/// Alice's view of the group. She creates the group and adds Bob and Charlie.
alice_group: MlsGroup,
/// Alice's credential (with key) together with the signer used for her messages.
alice_credential: CredentialWithKeyAndSigner,
/// Bob's view of the group, created from the welcome produced by Alice's add commit.
bob_group: MlsGroup,
/// Charlie's view of the group, created from the same welcome as Bob's.
charlie_group: MlsGroup,
}
fn validation_test_setup(
wire_format_policy: WireFormatPolicy,
ciphersuite: Ciphersuite,
alice_provider: &impl crate::storage::OpenMlsProvider,
bob_provider: &impl crate::storage::OpenMlsProvider,
charlie_provider: &impl crate::storage::OpenMlsProvider,
) -> CommitValidationTestSetup {
let group_id = GroupId::from_slice(b"Test Group");
let alice_credential = generate_credential_with_key(
"Alice".into(),
ciphersuite.signature_algorithm(),
alice_provider,
);
let bob_credential = generate_credential_with_key(
"Bob".into(),
ciphersuite.signature_algorithm(),
bob_provider,
);
let charlie_credential = generate_credential_with_key(
"Charlie".into(),
ciphersuite.signature_algorithm(),
charlie_provider,
);
let bob_key_package = generate_key_package(
ciphersuite,
Extensions::empty(),
bob_provider,
bob_credential,
);
let charlie_key_package = generate_key_package(
ciphersuite,
Extensions::empty(),
charlie_provider,
charlie_credential,
);
let mls_group_create_config = MlsGroupCreateConfig::builder()
.wire_format_policy(wire_format_policy)
.ciphersuite(ciphersuite)
.build();
let mut alice_group = MlsGroup::new_with_group_id(
alice_provider,
&alice_credential.signer,
&mls_group_create_config,
group_id,
alice_credential.credential_with_key.clone(),
)
.expect("An unexpected error occurred.");
let (_message, welcome, _group_info) = alice_group
.add_members(
alice_provider,
&alice_credential.signer,
&[
bob_key_package.key_package().clone(),
charlie_key_package.key_package().clone(),
],
)
.expect("error adding Bob to group");
alice_group
.merge_pending_commit(alice_provider)
.expect("error merging pending commit");
let welcome: MlsMessageIn = welcome.into();
let welcome = welcome
.into_welcome()
.expect("expected message to be a welcome");
let bob_group = StagedWelcome::new_from_welcome(
bob_provider,
mls_group_create_config.join_config(),
welcome.clone(),
Some(alice_group.export_ratchet_tree().into()),
)
.expect("error creating staged join from welcome")
.into_group(bob_provider)
.expect("error creating group from staged join");
let charlie_group = StagedWelcome::new_from_welcome(
charlie_provider,
mls_group_create_config.join_config(),
welcome,
Some(alice_group.export_ratchet_tree().into()),
)
.expect("error creating staged join from welcome")
.into_group(charlie_provider)
.expect("error creating group from staged join");
CommitValidationTestSetup {
alice_group,
alice_credential,
bob_group,
charlie_group,
}
}
// ValSem200: a commit that carries an inline Remove proposal targeting the
// committer itself must be rejected with `AttemptedSelfRemoval`.
//
// The test builds a valid self-update commit from Alice, injects a remove
// proposal for Alice's own leaf, re-signs the message and checks that Bob
// refuses it. Finally it checks that Bob accepts the unmodified commit.
#[openmls_test::openmls_test]
fn test_valsem200() {
    let alice_provider = &Provider::default();
    let bob_provider = &Provider::default();
    let charlie_provider = &Provider::default();

    let CommitValidationTestSetup {
        mut alice_group,
        alice_credential,
        mut bob_group,
        ..
    } = validation_test_setup(
        PURE_PLAINTEXT_WIRE_FORMAT_POLICY,
        ciphersuite,
        alice_provider,
        bob_provider,
        charlie_provider,
    );

    // Let Alice create a remove proposal for her own leaf so that we obtain a
    // well-formed proposal to inject into the commit below.
    let serialized_proposal_message = alice_group
        .propose_remove_member(
            alice_provider,
            &alice_credential.signer,
            alice_group.own_leaf_index(),
        )
        .expect("error creating remove proposal")
        .tls_serialize_detached()
        .expect("serialization error");

    // Round-trip through the wire format to get at the proposal itself.
    let proposal_message =
        MlsMessageIn::tls_deserialize(&mut serialized_proposal_message.as_slice())
            .expect("Could not deserialize message.")
            .into_plaintext()
            .expect("Message was not a plaintext.");

    let proposal = match proposal_message.content() {
        FramedContentBody::Proposal(proposal) => proposal.clone(),
        _ => panic!("Unexpected content type."),
    };

    // The proposal must not be committed by reference, so drop it from
    // Alice's pending proposals before creating the commit.
    alice_group
        .clear_pending_proposals(alice_provider.storage())
        .unwrap();

    // Create an otherwise valid commit (a plain self-update).
    let serialized_message = alice_group
        .self_update(
            alice_provider,
            &alice_credential.signer,
            LeafNodeParameters::default(),
        )
        .expect("Error creating self-update")
        .into_messages()
        .tls_serialize_detached()
        .expect("Could not serialize message.");

    let mut plaintext = MlsMessageIn::tls_deserialize(&mut serialized_message.as_slice())
        .expect("Could not deserialize message.")
        .into_plaintext()
        .expect("Message was not a plaintext.");

    // Keep the untouched commit for the positive case at the end.
    let original_plaintext = plaintext.clone();

    let mut commit_content = match plaintext.content() {
        FramedContentBody::Commit(commit) => commit.clone(),
        _ => panic!("Unexpected content type."),
    };

    // Inject the self-remove as an inline proposal.
    commit_content
        .proposals
        .push(ProposalOrRef::proposal(proposal));

    plaintext.set_content(FramedContentBody::Commit(commit_content));

    let serialized_context = alice_group
        .export_group_context()
        .tls_serialize_detached()
        .expect("error serializing context");

    // The content changed, so the message has to be signed again.
    let tbs: FramedContentTbs = plaintext.into();
    let mut signed_plaintext: AuthenticatedContent = tbs
        .with_context(serialized_context)
        .sign(&alice_credential.signer)
        .expect("Error signing modified payload.");

    // Re-use the confirmation tag of the original commit.
    signed_plaintext.set_confirmation_tag(
        original_plaintext
            .confirmation_tag()
            .expect("no confirmation tag on original message")
            .clone(),
    );

    let mut signed_plaintext: PublicMessage = signed_plaintext.into();

    // Recompute the membership tag over the modified message.
    let membership_key = alice_group.message_secrets().membership_key();
    signed_plaintext
        .set_membership_tag(
            alice_provider.crypto(),
            ciphersuite,
            membership_key,
            alice_group.message_secrets().serialized_context(),
        )
        .expect("error refreshing membership tag");

    // Negative case: Bob must reject the commit with the inline self-remove.
    let message_in = ProtocolMessage::from(signed_plaintext);
    let err = bob_group
        .process_message(bob_provider, message_in)
        .expect_err("Could process unverified message despite self remove.");

    assert!(matches!(
        err,
        ProcessMessageError::InvalidCommit(StageCommitError::AttemptedSelfRemoval)
    ));

    // Positive case: the unmodified commit is accepted.
    bob_group
        .process_message(bob_provider, ProtocolMessage::from(original_plaintext))
        .expect("Unexpected error.");
}
// ValSem201: a commit must contain an update path whenever at least one of
// the proposals it covers requires one.
//
// For every combination of proposals in `cases` the test
//   1. checks that the commit built by Alice has a path exactly when one is
//      expected,
//   2. if a path is expected, strips it and checks that Bob rejects the
//      commit with `RequiredPathNotFound`,
//   3. checks that Bob accepts the unmodified commit.
#[openmls_test::openmls_test]
fn test_valsem201() {
    let alice_provider = &Provider::default();
    let bob_provider = &Provider::default();
    let charlie_provider = &Provider::default();

    let wire_format_policy = PURE_PLAINTEXT_WIRE_FORMAT_POLICY;
    let CommitValidationTestSetup {
        mut alice_group,
        alice_credential,
        mut bob_group,
        charlie_group,
        ..
    } = validation_test_setup(
        wire_format_policy,
        ciphersuite,
        alice_provider,
        bob_provider,
        charlie_provider,
    );

    // Wraps a proposal into a `QueuedProposal` with Alice as the sender.
    let queued = |proposal: Proposal| {
        QueuedProposal::from_proposal_and_sender(
            ciphersuite,
            alice_provider.crypto(),
            proposal,
            &Sender::Member(alice_group.own_leaf_index()),
        )
        .unwrap()
    };

    // Add proposal for a fresh member "Dave"; every call generates a new
    // credential and key package.
    let dave_provider = &Provider::default();
    let add_proposal = || {
        let dave_credential = generate_credential_with_key(
            "Dave".into(),
            ciphersuite.signature_algorithm(),
            dave_provider,
        );
        let dave_key_package = generate_key_package(
            ciphersuite,
            Extensions::empty(),
            dave_provider,
            dave_credential,
        );

        queued(Proposal::add(AddProposal {
            key_package: dave_key_package.key_package().clone(),
        }))
    };

    // PSK proposal for a fresh external PSK. The secret is stored for both
    // Alice and Bob so that either side can resolve the PSK id.
    let psk_proposal = || {
        let secret = Secret::random(ciphersuite, alice_provider.rand()).unwrap();
        let rand = alice_provider
            .rand()
            .random_vec(ciphersuite.hash_length())
            .unwrap();
        let psk_id = PreSharedKeyId::new(
            ciphersuite,
            alice_provider.rand(),
            Psk::External(ExternalPsk::new(rand)),
        )
        .unwrap();
        psk_id.store(alice_provider, secret.as_slice()).unwrap();
        psk_id.store(bob_provider, secret.as_slice()).unwrap();
        queued(Proposal::psk(PreSharedKeyProposal::new(psk_id)))
    };

    // Update proposal carrying Alice's current leaf node.
    let update_proposal = queued(Proposal::update(UpdateProposal {
        leaf_node: alice_group
            .own_leaf()
            .expect("Unable to get own leaf")
            .clone(),
    }));

    // Remove proposal targeting Charlie.
    let remove_proposal = || {
        queued(Proposal::remove(RemoveProposal {
            removed: charlie_group.own_leaf_index(),
        }))
    };

    // Group context extensions proposal that re-proposes the current
    // extensions. The closure consumes the extensions, so it can only be
    // called once.
    let group_context_extensions: Extensions<GroupContext> =
        alice_group.context().extensions().clone();
    let gce_proposal = || {
        queued(Proposal::group_context_extensions(
            GroupContextExtensionProposal::new(group_context_extensions),
        ))
    };

    // (proposals to commit, whether the commit is expected to carry a path)
    let cases = vec![
        (vec![add_proposal()], false),
        (vec![psk_proposal()], false),
        (vec![update_proposal.clone()], true),
        (vec![remove_proposal()], true),
        (vec![gce_proposal()], true),
        // no path required + no path required => no path required
        (vec![add_proposal(), psk_proposal()], false),
        // path required + no path required => path required
        (vec![remove_proposal(), add_proposal()], true),
        // path required + path required => path required
        (vec![update_proposal, remove_proposal()], true),
    ];

    for (proposals, is_path_required) in cases {
        // The proposals are owned here, so they are moved into the store
        // without cloning.
        for proposal in proposals {
            alice_group
                .store_pending_proposal(alice_provider.storage(), proposal)
                .unwrap();
        }

        // Build a commit over the pending proposals without forcing a
        // self-update, so that a path only appears when it is required.
        let commit = alice_group
            .commit_builder()
            .force_self_update(false)
            .load_psks(alice_provider.storage())
            .unwrap()
            .build(
                alice_provider.rand(),
                alice_provider.crypto(),
                &alice_credential.signer,
                |_| true,
            )
            .unwrap()
            .commit_result()
            .commit;

        // The commit carries a path exactly when one is expected.
        if let FramedContentBody::Commit(commit) = commit.content() {
            assert_eq!(commit.has_path(), is_path_required);
        } else {
            panic!("Unexpected content type.");
        };

        let mut commit: PublicMessage = commit.into();
        let membership_key = alice_group.message_secrets().membership_key();
        commit
            .set_membership_tag(
                alice_provider.crypto(),
                ciphersuite,
                membership_key,
                alice_group.message_secrets().serialized_context(),
            )
            .unwrap();

        // Negative case: a required path that has been stripped is detected.
        if is_path_required {
            let commit_wo_path = erase_path(
                alice_provider,
                ciphersuite,
                commit.clone(),
                &alice_group,
                &alice_credential.signer,
            );

            let processed_msg = bob_group.process_message(bob_provider, commit_wo_path);
            assert!(matches!(
                processed_msg.unwrap_err(),
                ProcessMessageError::InvalidCommit(StageCommitError::RequiredPathNotFound)
            ));
        }

        // Positive case: the unmodified commit is accepted.
        let process_message_result = bob_group.process_message(bob_provider, commit);
        assert!(process_message_result.is_ok(), "{process_message_result:?}");

        // Reset both groups for the next iteration.
        alice_group
            .clear_pending_proposals(alice_provider.storage())
            .unwrap();
        alice_group
            .clear_pending_commit(alice_provider.storage())
            .unwrap();
        bob_group
            .clear_pending_commit(bob_provider.storage())
            .unwrap();
    }
}
/// Returns the given commit message with its update path removed.
///
/// The commit content of `plaintext` is cloned, its `path` is set to `None`,
/// and the modified message is handed to `resign_message` together with the
/// unmodified original before being converted into a `ProtocolMessage`.
///
/// Panics if `plaintext` does not contain a commit.
fn erase_path(
    provider: &impl crate::storage::OpenMlsProvider,
    ciphersuite: Ciphersuite,
    mut plaintext: PublicMessage,
    alice_group: &MlsGroup,
    alice_signer: &impl Signer,
) -> ProtocolMessage {
    // Snapshot of the message before it is modified.
    let unmodified = plaintext.clone();

    let mut pathless_commit = match plaintext.content() {
        FramedContentBody::Commit(commit) => commit.clone(),
        _ => panic!("Unexpected content type."),
    };
    pathless_commit.path = None;
    plaintext.set_content(FramedContentBody::Commit(pathless_commit));

    resign_message(
        alice_group,
        plaintext,
        &unmodified,
        provider,
        alice_signer,
        ciphersuite,
    )
    .into()
}
#[openmls_test::openmls_test]
fn test_valsem202() {
let alice_provider = &Provider::default();
let bob_provider = &Provider::default();
let charlie_provider = &Provider::default();
let CommitValidationTestSetup {
mut alice_group,
alice_credential,
mut bob_group,
..
} = validation_test_setup(
PURE_PLAINTEXT_WIRE_FORMAT_POLICY,
ciphersuite,
alice_provider,
bob_provider,
charlie_provider,
);
let serialized_update = alice_group
.self_update(
alice_provider,
&alice_credential.signer,
LeafNodeParameters::default(),
)
.expect("Error creating self-update")
.into_messages()
.tls_serialize_detached()
.expect("Could not serialize message.");
let mut plaintext = MlsMessageIn::tls_deserialize(&mut serialized_update.as_slice())
.expect("Could not deserialize message.")
.into_plaintext()
.expect("Message was not a plaintext.");
let original_plaintext = plaintext.clone();
let mut commit_content = if let FramedContentBody::Commit(commit) = plaintext.content() {
commit.clone()
} else {
panic!("Unexpected content type.");
};
if let Some(ref mut path) = commit_content.path {
path.pop();
};
plaintext.set_content(FramedContentBody::Commit(commit_content));
let plaintext = resign_message(
&alice_group,
plaintext,
&original_plaintext,
alice_provider,
&alice_credential.signer,
ciphersuite,
);
let update_message_in = ProtocolMessage::from(plaintext);
let err = bob_group
.process_message(bob_provider, update_message_in)
.expect_err("Could process unverified message despite path length mismatch.");
assert!(matches!(
err,
ProcessMessageError::InvalidCommit(StageCommitError::UpdatePathError(
ApplyUpdatePathError::PathLengthMismatch
))
));
let original_update_plaintext =
MlsMessageIn::tls_deserialize(&mut serialized_update.as_slice())
.expect("Could not deserialize message.");
bob_group
.process_message(
bob_provider,
original_update_plaintext
.try_into_protocol_message()
.unwrap(),
)
.expect("Unexpected error.");
}
#[openmls_test::openmls_test]
fn test_valsem203() {
let alice_provider = &Provider::default();
let bob_provider = &Provider::default();
let charlie_provider = &Provider::default();
let CommitValidationTestSetup {
mut alice_group,
alice_credential,
mut bob_group,
..
} = validation_test_setup(
PURE_PLAINTEXT_WIRE_FORMAT_POLICY,
ciphersuite,
alice_provider,
bob_provider,
charlie_provider,
);
let serialized_update = alice_group
.self_update(
alice_provider,
&alice_credential.signer,
LeafNodeParameters::default(),
)
.expect("Error creating self-update")
.into_messages()
.tls_serialize_detached()
.expect("Could not serialize message.");
let mut plaintext = MlsMessageIn::tls_deserialize(&mut serialized_update.as_slice())
.expect("Could not deserialize message.")
.into_plaintext()
.expect("Message was not a plaintext.");
let original_plaintext = plaintext.clone();
let mut commit_content = if let FramedContentBody::Commit(commit) = plaintext.content() {
commit.clone()
} else {
panic!("Unexpected content type.");
};
if let Some(ref mut path) = commit_content.path {
path.flip_eps_bytes();
};
plaintext.set_content(FramedContentBody::Commit(commit_content));
let plaintext = resign_message(
&alice_group,
plaintext,
&original_plaintext,
alice_provider,
&alice_credential.signer,
ciphersuite,
);
let update_message_in = ProtocolMessage::from(plaintext);
let err = bob_group
.process_message(bob_provider, update_message_in)
.expect_err("Could process unverified message despite scrambled ciphertexts.");
assert!(matches!(
err,
ProcessMessageError::InvalidCommit(StageCommitError::UpdatePathError(
ApplyUpdatePathError::UnableToDecrypt
))
));
let original_update_plaintext =
MlsMessageIn::tls_deserialize(&mut serialized_update.as_slice())
.expect("Could not deserialize message.");
bob_group
.process_message(
bob_provider,
original_update_plaintext
.try_into_protocol_message()
.unwrap(),
)
.expect("Unexpected error.");
}
// ValSem204: a commit whose update path does not match the secrets it
// encrypts must be rejected with `PathMismatch`.
//
// The test keeps the encryption keys of a valid self-update commit but
// replaces the encrypted path secrets with freshly generated random ones,
// re-signs the commit and checks that Bob refuses it. Finally it checks that
// Bob accepts the unmodified commit.
#[openmls_test::openmls_test]
fn test_valsem204() {
let alice_provider = &Provider::default();
let bob_provider = &Provider::default();
let charlie_provider = &Provider::default();
let CommitValidationTestSetup {
mut alice_group,
alice_credential,
mut bob_group,
mut charlie_group,
} = validation_test_setup(
PURE_PLAINTEXT_WIRE_FORMAT_POLICY,
ciphersuite,
alice_provider,
bob_provider,
charlie_provider,
);
// A valid self-update commit from Alice, in serialized form.
let serialized_update = alice_group
.self_update(
alice_provider,
&alice_credential.signer,
LeafNodeParameters::default(),
)
.expect("Error creating self-update")
.into_messages()
.tls_serialize_detached()
.expect("Could not serialize message.");
let mut plaintext = MlsMessageIn::tls_deserialize(&mut serialized_update.as_slice())
.expect("Could not deserialize message.")
.into_plaintext()
.expect("Message was not a plaintext.");
// Keep the untouched commit: Charlie processes it below and it is needed
// again when re-signing the modified message.
let original_plaintext = plaintext.clone();
let mut commit_content = if let FramedContentBody::Commit(commit) = plaintext.content() {
commit.clone()
} else {
panic!("Unexpected content type.");
};
// Let Charlie process and merge the original commit so that the post-merge
// tree hash can be read from his group.
let message = charlie_group
.process_message(charlie_provider, original_plaintext.clone())
.unwrap();
match message.into_content() {
ProcessedMessageContent::StagedCommitMessage(staged_commit) => charlie_group
.merge_staged_commit(charlie_provider, *staged_commit)
.unwrap(),
_ => panic!("Unexpected message type."),
}
// Build the context used for encrypting the new path: Alice's current
// group context with the epoch incremented and the tree hash replaced by
// the post-merge one taken from Charlie. All other fields keep their
// pre-merge values.
let mut encryption_context = alice_group.export_group_context().clone();
let post_merge_tree_hash = charlie_group.export_group_context().tree_hash().to_vec();
encryption_context.increment_epoch();
encryption_context.update_tree_hash(post_merge_tree_hash);
// Rebuild the path with the original encryption keys and leaf node but with
// fresh random secrets. Only the encrypted secrets change, which is
// presumably what makes the receiver's derived keys disagree with the keys
// in the path (see the expected `PathMismatch` below).
if let Some(ref mut path) = commit_content.path {
let new_plain_path: Vec<PlainUpdatePathNode> = path
.nodes()
.iter()
.map(|upn| {
PlainUpdatePathNode::new(
upn.encryption_key().clone(),
Secret::random(ciphersuite, alice_provider.rand())
.unwrap()
.into(),
)
})
.collect();
// NOTE(review): `&[].into()` looks like an empty exclusion list and
// `LeafNodeIndex::new(0)` like Alice's own leaf index — confirm against
// `encrypt_path`.
let new_nodes = alice_group
.public_group()
.encrypt_path(
alice_provider,
ciphersuite,
&new_plain_path,
&encryption_context.tls_serialize_detached().unwrap(),
&[].into(),
LeafNodeIndex::new(0),
)
.unwrap();
let new_path = UpdatePath::new(path.leaf_node().clone(), new_nodes);
commit_content.path = Some(new_path);
};
plaintext.set_content(FramedContentBody::Commit(commit_content));
// The content changed, so the message is re-signed before it is sent to Bob.
let plaintext = resign_message(
&alice_group,
plaintext,
&original_plaintext,
alice_provider,
&alice_credential.signer,
ciphersuite,
);
// Negative case: Bob must reject the commit with the replaced path secrets.
let update_message_in = ProtocolMessage::from(plaintext);
let err = bob_group
.process_message(bob_provider, update_message_in)
.expect_err("Could process unverified message despite modified public key in path.");
assert!(matches!(
err,
ProcessMessageError::InvalidCommit(StageCommitError::UpdatePathError(
ApplyUpdatePathError::PathMismatch
))
));
// Positive case: the commit as originally serialized is accepted.
let original_update_plaintext =
MlsMessageIn::tls_deserialize(&mut serialized_update.as_slice())
.expect("Could not deserialize message.");
bob_group
.process_message(
bob_provider,
original_update_plaintext
.try_into_protocol_message()
.unwrap(),
)
.expect("Unexpected error.");
}
// ValSem205: a commit with a wrong confirmation tag must be rejected with
// `ConfirmationTagMismatch`.
//
// The test flips the last byte of the confirmation tag of an otherwise valid
// self-update commit, refreshes the membership tag and checks that Bob
// refuses the message, then checks that Bob accepts the unmodified commit.
#[openmls_test::openmls_test]
fn test_valsem205() {
    let alice_provider = &Provider::default();
    let bob_provider = &Provider::default();
    let charlie_provider = &Provider::default();

    let CommitValidationTestSetup {
        mut alice_group,
        alice_credential,
        mut bob_group,
        ..
    } = validation_test_setup(
        PURE_PLAINTEXT_WIRE_FORMAT_POLICY,
        ciphersuite,
        alice_provider,
        bob_provider,
        charlie_provider,
    );

    // A valid self-update commit from Alice, in serialized form.
    let serialized_update = alice_group
        .self_update(
            alice_provider,
            &alice_credential.signer,
            LeafNodeParameters::default(),
        )
        .expect("Error creating self-update")
        .into_messages()
        .tls_serialize_detached()
        .expect("Could not serialize message.");

    let original_plaintext = MlsMessageIn::tls_deserialize(&mut serialized_update.as_slice())
        .expect("Could not deserialize message.")
        .into_plaintext()
        .expect("Message was not a plaintext.");

    // Work on a copy; the original is used for the positive case.
    let mut tampered = original_plaintext.clone();

    // Corrupt the confirmation tag.
    let mut broken_tag = tampered
        .confirmation_tag()
        .expect("no confirmation tag on commit")
        .clone();
    broken_tag.0.flip_last_byte();
    tampered.set_confirmation_tag(Some(broken_tag));

    // Recompute the membership tag over the modified message.
    tampered
        .set_membership_tag(
            alice_provider.crypto(),
            ciphersuite,
            alice_group.message_secrets().membership_key(),
            alice_group.message_secrets().serialized_context(),
        )
        .expect("error refreshing membership tag");

    // Negative case: Bob must notice the confirmation tag mismatch.
    let err = bob_group
        .process_message(bob_provider, ProtocolMessage::from(tampered))
        .expect_err("Could process unverified message despite confirmation tag mismatch.");

    assert!(matches!(
        err,
        ProcessMessageError::InvalidCommit(StageCommitError::ConfirmationTagMismatch)
    ));

    // Positive case: the unmodified commit is accepted.
    bob_group
        .process_message(bob_provider, ProtocolMessage::from(original_plaintext))
        .expect("Unexpected error.");
}
// Checks that a commit covering only a subset of the proposals sent in an
// epoch can be merged by the committer and processed by another member.
//
// Alice sends two proposals (remove Charlie, self-update) which Bob stores.
// Alice then reduces her own proposal store to a single proposal and commits.
// Both Alice and Bob must be able to merge the resulting commit.
#[openmls_test::openmls_test]
fn test_partial_proposal_commit() {
    let alice_provider = &Provider::default();
    let bob_provider = &Provider::default();
    let charlie_provider = &Provider::default();

    let CommitValidationTestSetup {
        mut alice_group,
        alice_credential,
        mut bob_group,
        ..
    } = validation_test_setup(
        PURE_PLAINTEXT_WIRE_FORMAT_POLICY,
        ciphersuite,
        alice_provider,
        bob_provider,
        charlie_provider,
    );

    // Look up Charlie's leaf index by his credential content.
    let charlie_index = alice_group
        .members()
        .find(|member| member.credential.serialized_content() == b"Charlie")
        .unwrap()
        .index;

    // First proposal: Alice proposes to remove Charlie; Bob processes and
    // stores it.
    let (remove_out, _) = alice_group
        .propose_remove_member(alice_provider, &alice_credential.signer, charlie_index)
        .unwrap();
    let processed_remove = bob_group
        .process_message(
            bob_provider,
            MlsMessageIn::from(remove_out)
                .try_into_protocol_message()
                .unwrap(),
        )
        .unwrap();
    if let ProcessedMessageContent::ProposalMessage(queued) = processed_remove.into_content() {
        bob_group
            .store_pending_proposal(bob_provider.storage(), *queued)
            .unwrap();
    } else {
        unreachable!();
    }

    // Second proposal: Alice proposes a self-update; Bob processes and
    // stores it as well.
    let (update_out, _) = alice_group
        .propose_self_update(
            alice_provider,
            &alice_credential.signer,
            LeafNodeParameters::default(),
        )
        .unwrap();
    let processed_update = bob_group
        .process_message(
            bob_provider,
            MlsMessageIn::from(update_out)
                .try_into_protocol_message()
                .unwrap(),
        )
        .unwrap();
    if let ProcessedMessageContent::ProposalMessage(queued) = processed_update.into_content() {
        bob_group
            .store_pending_proposal(bob_provider.storage(), *queued)
            .unwrap();
    } else {
        unreachable!();
    }

    // Reduce Alice's proposal store to the first proposal it yields, so that
    // her commit covers only part of what Bob has stored.
    let kept_proposal = alice_group
        .proposal_store()
        .proposals()
        .next()
        .cloned()
        .unwrap();
    let alice_store = alice_group.proposal_store_mut();
    alice_store.empty();
    alice_store.add(kept_proposal);

    let (commit, _, _) = alice_group
        .commit_to_pending_proposals(alice_provider, &alice_credential.signer)
        .unwrap();

    // Alice can merge her own commit ...
    alice_group
        .merge_pending_commit(alice_provider)
        .expect("Commits with partial proposals are not supported");

    // ... and Bob can process and merge it.
    bob_group
        .process_message(bob_provider, commit.into_protocol_message().unwrap())
        .expect("Commits with partial proposals are not supported");
    bob_group
        .merge_pending_commit(bob_provider)
        .expect("Commits with partial proposals are not supported");
}