use openmls_basic_credential::SignatureKeyPair;
use openmls_test::openmls_test;
use crate::{
binary_tree::LeafNodeIndex,
framing::*,
group::*,
messages::{
external_proposals::*,
proposals::{Proposal, ProposalType},
},
test_utils::frankenstein::*,
treesync::LeafNodeParameters,
};
use openmls_traits::types::Ciphersuite;
use crate::group::tests_and_kats::utils::*;
/// Fixture produced by `validation_test_setup`: the same two-member group as
/// seen by each of its members.
///
/// Each field pairs that member's `MlsGroup` with the `SignatureKeyPair` the
/// member signs with.
struct ProposalValidationTestSetup {
    /// Alice's view of the group (she created it and added Bob) and her signer.
    alice_group: (MlsGroup, SignatureKeyPair),
    /// Bob's view of the group (joined from Alice's Welcome) and his signer.
    bob_group: (MlsGroup, SignatureKeyPair),
}
/// Creates a fresh single-member group owned by `identity`.
///
/// The group gets a random group ID and is configured with the given wire
/// format policy and ciphersuite. The creator's credential, public key and
/// signer are returned next to the group so the caller can keep acting on the
/// creator's behalf.
fn new_test_group(
    identity: &str,
    wire_format_policy: WireFormatPolicy,
    ciphersuite: Ciphersuite,
    provider: &impl crate::storage::OpenMlsProvider,
) -> (MlsGroup, CredentialWithKeyAndSigner) {
    // Draw the group ID before generating the credential, matching the order
    // in which the provider's randomness has always been consumed here.
    let id = GroupId::random(provider.rand());
    let creator =
        generate_credential_with_key(identity.into(), ciphersuite.signature_algorithm(), provider);

    let create_config = MlsGroupCreateConfig::builder()
        .wire_format_policy(wire_format_policy)
        .ciphersuite(ciphersuite)
        .build();

    let group = MlsGroup::new_with_group_id(
        provider,
        &creator.signer,
        &create_config,
        id,
        creator.credential_with_key.clone(),
    )
    .unwrap();

    (group, creator)
}
fn validation_test_setup(
wire_format_policy: WireFormatPolicy,
ciphersuite: Ciphersuite,
alice_provider: &impl crate::storage::OpenMlsProvider,
bob_provider: &impl crate::storage::OpenMlsProvider,
) -> ProposalValidationTestSetup {
let (mut alice_group, alice_signer_with_keys) =
new_test_group("Alice", wire_format_policy, ciphersuite, alice_provider);
let bob_credential_with_key = generate_credential_with_key(
"Bob".into(),
ciphersuite.signature_algorithm(),
bob_provider,
);
let bob_key_package = generate_key_package(
ciphersuite,
Extensions::empty(),
bob_provider,
bob_credential_with_key.clone(),
);
let (_message, welcome, _group_info) = alice_group
.add_members(
alice_provider,
&alice_signer_with_keys.signer,
core::slice::from_ref(bob_key_package.key_package()),
)
.expect("error adding Bob to group");
alice_group
.merge_pending_commit(alice_provider)
.expect("error merging pending commit");
let mls_group_config = MlsGroupJoinConfig::builder()
.wire_format_policy(wire_format_policy)
.build();
let welcome: MlsMessageIn = welcome.into();
let welcome = welcome
.into_welcome()
.expect("expected message to be a welcome");
let bob_group = StagedWelcome::new_from_welcome(
bob_provider,
&mls_group_config,
welcome,
Some(alice_group.export_ratchet_tree().into()),
)
.expect("error creating group from welcome")
.into_group(bob_provider)
.expect("error creating group from welcome");
ProposalValidationTestSetup {
alice_group: (alice_group, alice_signer_with_keys.signer),
bob_group: (bob_group, bob_credential_with_key.signer),
}
}
// An external Add proposal (sender `NewMemberProposal`) created by a
// prospective member, Charlie, is accepted and stored by both existing
// members. Alice then commits it, Bob merges the commit, and Charlie joins
// from the resulting Welcome. The scenario runs once per wire format policy.
#[openmls_test]
fn external_join_add_proposal_should_succeed() {
    let alice_provider = &Provider::default();
    let bob_provider = &Provider::default();
    for policy in WIRE_FORMAT_POLICIES {
        let ProposalValidationTestSetup {
            alice_group,
            bob_group,
        } = validation_test_setup(policy, ciphersuite, alice_provider, bob_provider);
        let (mut alice_group, alice_signer) = alice_group;
        let (mut bob_group, _bob_signer) = bob_group;
        assert_eq!(alice_group.members().count(), 2);
        assert_eq!(bob_group.members().count(), 2);

        // Charlie is not a member yet: he creates his own credential and key
        // package and signs a join proposal for the group's current epoch.
        let charlie_provider = &Provider::default();
        let charlie_credential = generate_credential_with_key(
            "Charlie".into(),
            ciphersuite.signature_algorithm(),
            charlie_provider,
        );
        let charlie_kp = generate_key_package(
            ciphersuite,
            Extensions::empty(),
            charlie_provider,
            charlie_credential.clone(),
        );
        let proposal =
            JoinProposal::new::<<Provider as openmls_traits::OpenMlsProvider>::StorageProvider>(
                charlie_kp.key_package().clone(),
                alice_group.group_id().clone(),
                alice_group.epoch(),
                &charlie_credential.signer,
            )
            .unwrap();

        // The outgoing message must be a public Add proposal from a new member.
        let verify_proposal = |msg: &PublicMessage| {
            *msg.sender() == Sender::NewMemberProposal
                && msg.content_type() == ContentType::Proposal
                && matches!(msg.content(), FramedContentBody::Proposal(p) if p.proposal_type() == ProposalType::Add)
        };
        assert!(
            matches!(proposal.body, MlsMessageBodyOut::PublicMessage(ref msg) if verify_proposal(msg))
        );

        // Alice accepts the proposal, checks it carries Charlie's key package,
        // and keeps it pending. The clone is needed because Bob processes the
        // same message below.
        let msg = alice_group
            .process_message(
                alice_provider,
                proposal.clone().into_protocol_message().unwrap(),
            )
            .unwrap();
        match msg.into_content() {
            ProcessedMessageContent::ExternalJoinProposalMessage(proposal) => {
                assert!(matches!(proposal.sender(), Sender::NewMemberProposal));
                let add_proposal = match proposal.proposal() {
                    Proposal::Add(add) => add,
                    _ => unreachable!("expected the external join proposal to be an Add"),
                };
                assert_eq!(add_proposal.key_package(), charlie_kp.key_package());
                alice_group
                    .store_pending_proposal(alice_provider.storage(), *proposal)
                    .unwrap()
            }
            _ => unreachable!("Alice expected an ExternalJoinProposalMessage"),
        }

        // Bob must store the same proposal so he can apply Alice's commit,
        // which references it.
        let msg = bob_group
            .process_message(bob_provider, proposal.into_protocol_message().unwrap())
            .unwrap();
        match msg.into_content() {
            ProcessedMessageContent::ExternalJoinProposalMessage(proposal) => bob_group
                .store_pending_proposal(bob_provider.storage(), *proposal)
                .unwrap(),
            _ => unreachable!("Bob expected an ExternalJoinProposalMessage"),
        }

        // Alice commits the pending proposal, adding Charlie.
        let (commit, welcome, _group_info) = alice_group
            .commit_to_pending_proposals(alice_provider, &alice_signer)
            .unwrap();
        alice_group.merge_pending_commit(alice_provider).unwrap();
        assert_eq!(alice_group.members().count(), 3);

        let msg = bob_group
            .process_message(bob_provider, commit.into_protocol_message().unwrap())
            .unwrap();
        match msg.into_content() {
            ProcessedMessageContent::StagedCommitMessage(commit) => bob_group
                .merge_staged_commit(bob_provider, *commit)
                .unwrap(),
            _ => unreachable!("Bob expected a StagedCommitMessage"),
        }
        assert_eq!(bob_group.members().count(), 3);

        // Charlie joins from the Welcome produced by Alice's commit.
        let welcome: MlsMessageIn = welcome.expect("expected a welcome").into();
        let welcome = welcome
            .into_welcome()
            .expect("expected message to be a welcome");
        let mls_group_config = MlsGroupJoinConfig::builder()
            .wire_format_policy(policy)
            .build();
        let charlie_group = StagedWelcome::new_from_welcome(
            charlie_provider,
            &mls_group_config,
            welcome,
            Some(alice_group.export_ratchet_tree().into()),
        )
        .unwrap()
        .into_group(charlie_provider)
        .unwrap();
        assert_eq!(charlie_group.members().count(), 3);
    }
}
// A join proposal must be signed by the credential inside the key package it
// carries. Here the key package is built from an "Attacker" credential, but
// the proposal is signed with Charlie's signer. Alice must therefore reject
// it with `InvalidSignature`.
#[openmls_test]
fn external_join_add_proposal_should_be_signed_by_key_package_it_references() {
    let alice_provider = &Provider::default();
    let bob_provider = &Provider::default();
    let charlie_provider = &Provider::default();
    let attacker_provider = &Provider::default();

    let ProposalValidationTestSetup {
        alice_group: (mut alice_group, _alice_signer),
        ..
    } = validation_test_setup(
        PURE_PLAINTEXT_WIRE_FORMAT_POLICY,
        ciphersuite,
        alice_provider,
        bob_provider,
    );

    let attacker_credential = generate_credential_with_key(
        "Attacker".into(),
        ciphersuite.signature_algorithm(),
        attacker_provider,
    );
    let charlie_credential = generate_credential_with_key(
        "Charlie".into(),
        ciphersuite.signature_algorithm(),
        charlie_provider,
    );

    // The key package embeds the attacker's credential ...
    let mismatched_key_package = generate_key_package(
        ciphersuite,
        Extensions::empty(),
        charlie_provider,
        attacker_credential,
    );
    // ... while the proposal carrying it is signed with Charlie's key.
    let forged_proposal =
        JoinProposal::new::<<Provider as openmls_traits::OpenMlsProvider>::StorageProvider>(
            mismatched_key_package.key_package().clone(),
            alice_group.group_id().clone(),
            alice_group.epoch(),
            &charlie_credential.signer,
        )
        .unwrap();

    let incoming = forged_proposal.into_protocol_message().unwrap();
    let rejection = alice_group
        .process_message(alice_provider, incoming)
        .unwrap_err();
    assert!(matches!(
        rejection,
        ProcessMessageError::ValidationError(ValidationError::InvalidSignature)
    ));
}
// A message with the `NewMemberProposal` sender type must carry an Add
// proposal. Charlie's join proposal is tampered with: its body is swapped for
// a Remove of leaf 0, and no re-signing happens in this test. Alice must then
// reject it with `NotAnExternalAddProposal`. Runs once per wire format policy.
#[openmls_test]
fn test_valn1504() {
    let alice_provider = &Provider::default();
    let bob_provider = &Provider::default();
    for policy in WIRE_FORMAT_POLICIES {
        let ProposalValidationTestSetup {
            alice_group,
            bob_group,
        } = validation_test_setup(policy, ciphersuite, alice_provider, bob_provider);
        let (mut alice_group, _alice_signer) = alice_group;
        let (bob_group, _bob_signer) = bob_group;
        assert_eq!(alice_group.members().count(), 2);
        assert_eq!(bob_group.members().count(), 2);

        let charlie_provider = &Provider::default();
        let charlie_credential = generate_credential_with_key(
            "Charlie".into(),
            ciphersuite.signature_algorithm(),
            charlie_provider,
        );
        let charlie_kp = generate_key_package(
            ciphersuite,
            Extensions::empty(),
            charlie_provider,
            charlie_credential.clone(),
        );
        let proposal =
            JoinProposal::new::<<Provider as openmls_traits::OpenMlsProvider>::StorageProvider>(
                charlie_kp.key_package().clone(),
                alice_group.group_id().clone(),
                alice_group.epoch(),
                &charlie_credential.signer,
            )
            .unwrap();

        // Replace the Add proposal with a Remove, keeping the NewMemberProposal
        // sender untouched.
        let mut franken_message = FrankenMlsMessage::from(proposal);
        match franken_message.body {
            FrankenMlsMessageBody::PublicMessage(ref mut message) => {
                let incorrect_proposal =
                    FrankenProposal::Remove(FrankenRemoveProposal { removed: 0 });
                message.content.body = FrankenFramedContentBody::Proposal(incorrect_proposal);
            }
            _ => unreachable!("expected the join proposal to be a public message"),
        }

        // The tampered message is consumed here and not used again, so no clone is needed.
        let tampered: MlsMessageOut = franken_message.into();
        let err = alice_group
            .process_message(alice_provider, tampered.into_protocol_message().unwrap())
            .expect_err("Should return an error");
        assert_eq!(
            err,
            ProcessMessageError::ValidationError(ValidationError::NotAnExternalAddProposal)
        );
    }
}
#[openmls_test]
fn new_member_proposal_sender_should_be_reserved_for_join_proposals() {
let alice_provider = &Provider::default();
let bob_provider = &Provider::default();
let any_provider = &Provider::default();
let ProposalValidationTestSetup {
alice_group,
bob_group,
} = validation_test_setup(
PURE_PLAINTEXT_WIRE_FORMAT_POLICY,
ciphersuite,
alice_provider,
bob_provider,
);
let (mut alice_group, alice_signer) = alice_group;
let (mut bob_group, _bob_signer) = bob_group;
let any_credential = generate_credential_with_key(
"Any".into(),
ciphersuite.signature_algorithm(),
any_provider,
);
let any_kp = generate_key_package(
ciphersuite,
Extensions::empty(),
any_provider,
any_credential.clone(),
);
let join_proposal =
JoinProposal::new::<<Provider as openmls_traits::OpenMlsProvider>::StorageProvider>(
any_kp.key_package().clone(),
alice_group.group_id().clone(),
alice_group.epoch(),
&any_credential.signer,
)
.unwrap();
if let MlsMessageBodyOut::PublicMessage(plaintext) = &join_proposal.body {
assert!(matches!(
plaintext.content(),
FramedContentBody::Proposal(Proposal::Add(_))
));
assert!(matches!(plaintext.sender(), Sender::NewMemberProposal));
assert!(bob_group
.process_message(bob_provider, join_proposal.into_protocol_message().unwrap())
.is_ok());
} else {
panic!()
};
alice_group
.clear_pending_proposals(alice_provider.storage())
.unwrap();
let remove_proposal = alice_group
.propose_remove_member(alice_provider, &alice_signer, LeafNodeIndex::new(1))
.map(|(out, _)| MlsMessageIn::from(out))
.unwrap();
if let MlsMessageBodyIn::PublicMessage(mut plaintext) = remove_proposal.body {
plaintext.set_sender(Sender::NewMemberProposal);
assert!(matches!(
bob_group
.process_message(bob_provider, plaintext)
.unwrap_err(),
ProcessMessageError::ValidationError(ValidationError::NotAnExternalAddProposal)
));
} else {
panic!()
};
alice_group
.clear_pending_proposals(alice_provider.storage())
.unwrap();
let update_proposal = alice_group
.propose_self_update(alice_provider, &alice_signer, LeafNodeParameters::default())
.map(|(out, _)| MlsMessageIn::from(out))
.unwrap();
if let MlsMessageBodyIn::PublicMessage(mut plaintext) = update_proposal.body {
plaintext.set_sender(Sender::NewMemberProposal);
assert!(matches!(
bob_group
.process_message(bob_provider, plaintext)
.unwrap_err(),
ProcessMessageError::ValidationError(ValidationError::NotAnExternalAddProposal)
));
} else {
panic!()
};
alice_group
.clear_pending_proposals(alice_provider.storage())
.unwrap();
}