use openmls_traits::{signatures::Signer, types::Ciphersuite};
use crate::{framing::*, group::*, treesync::LeafNodeParameters};
use crate::group::tests_and_kats::utils::{
generate_credential_with_key, generate_key_package, CredentialWithKeyAndSigner,
};
/// Builds a new group owned by a freshly generated "Alice" member.
///
/// The group is created with a random group ID, the given `wire_format_policy`,
/// the ratchet tree extension enabled, and `ciphersuite`. Alice's credential and
/// signer are returned alongside the group so the caller can sign later operations.
fn create_group(
    ciphersuite: Ciphersuite,
    provider: &impl crate::storage::OpenMlsProvider,
    wire_format_policy: WireFormatPolicy,
) -> (MlsGroup, CredentialWithKeyAndSigner) {
    // Draw the group ID first, before generating Alice's credential.
    let id = GroupId::random(provider.rand());
    let alice =
        generate_credential_with_key("Alice".into(), ciphersuite.signature_algorithm(), provider);

    let create_config = MlsGroupCreateConfig::builder()
        .wire_format_policy(wire_format_policy)
        .use_ratchet_tree_extension(true)
        .ciphersuite(ciphersuite)
        .build();

    let group = MlsGroup::new_with_group_id(
        provider,
        &alice.signer,
        &create_config,
        id,
        alice.credential_with_key.clone(),
    )
    .expect("An unexpected error occurred.");

    (group, alice)
}
/// Adds a new "Bob" member to `alice_group` and returns a message authored by Bob.
///
/// Steps performed:
/// 1. Generate Bob's credential and a key package using `bob_provider`.
/// 2. Alice adds Bob (signing with `alice_signer`) and merges her pending commit.
/// 3. Bob joins from the resulting welcome, using the wire format policy that
///    Alice's group is configured with.
/// 4. Bob performs a self-update. The message it produces is converted to an
///    `MlsMessageIn` and returned for Alice to process.
fn receive_message(
    ciphersuite: Ciphersuite,
    alice_provider: &impl crate::storage::OpenMlsProvider,
    bob_provider: &impl crate::storage::OpenMlsProvider,
    alice_group: &mut MlsGroup,
    alice_signer: &impl Signer,
) -> MlsMessageIn {
    // Bob's identity, plus a key package Alice can use to invite him.
    let bob = generate_credential_with_key(
        "Bob".into(),
        ciphersuite.signature_algorithm(),
        bob_provider,
    );
    let bob_key_package =
        generate_key_package(ciphersuite, Extensions::empty(), bob_provider, bob.clone());

    // Alice invites Bob and applies her own commit right away.
    let (_commit, welcome_out, _group_info) = alice_group
        .add_members(
            alice_provider,
            alice_signer,
            core::slice::from_ref(bob_key_package.key_package()),
        )
        .expect("Could not add member.");
    alice_group
        .merge_pending_commit(alice_provider)
        .expect("error merging pending commit");

    // Bob joins with the same wire format policy as Alice's group.
    let join_config = MlsGroupJoinConfig::builder()
        .wire_format_policy(alice_group.configuration().wire_format_policy())
        .build();
    let welcome_in: MlsMessageIn = welcome_out.into();
    let welcome = welcome_in
        .into_welcome()
        .expect("expected message to be a welcome");
    let mut bob_group = StagedWelcome::new_from_welcome(bob_provider, &join_config, welcome, None)
        .expect("error creating bob's staged join from welcome")
        .into_group(bob_provider)
        .expect("error creating bob's group from staged join");

    // Bob's self-update yields the message handed back to the caller.
    let (update_message, _welcome, _group_info) = bob_group
        .self_update(bob_provider, &bob.signer, LeafNodeParameters::default())
        .expect("An unexpected error occurred.")
        .into_contents();
    update_message.into()
}
// For every policy in `WIRE_FORMAT_POLICIES`, Alice must be able to process a
// message that Bob produced in a group configured with that same policy.
#[openmls_test::openmls_test]
fn test_wire_policy_positive() {
    for policy in WIRE_FORMAT_POLICIES.iter().copied() {
        // New providers for each policy, so iterations share no stored state.
        let alice_provider = &Provider::default();
        let bob_provider = &Provider::default();

        let (mut alice_group, alice) = create_group(ciphersuite, alice_provider, policy);
        let incoming = receive_message(
            ciphersuite,
            alice_provider,
            bob_provider,
            &mut alice_group,
            &alice.signer,
        );

        let protocol_message = incoming.try_into_protocol_message().unwrap();
        alice_group
            .process_message(alice_provider, protocol_message)
            .expect("An unexpected error occurred.");
    }
}
// A policy whose outgoing and incoming halves disagree must make Alice reject
// Bob's message with `ProcessMessageError::IncompatibleWireFormat`. Bob joins with
// the same policy as Alice (see `receive_message`), so he sends in a format that
// Alice's incoming policy refuses.
#[openmls_test::openmls_test]
fn test_wire_policy_negative() {
    // A plain array is enough here; nothing needs a heap-allocated Vec.
    let incompatible_policies = [
        WireFormatPolicy::new(
            OutgoingWireFormatPolicy::AlwaysPlaintext,
            IncomingWireFormatPolicy::AlwaysCiphertext,
        ),
        WireFormatPolicy::new(
            OutgoingWireFormatPolicy::AlwaysCiphertext,
            IncomingWireFormatPolicy::AlwaysPlaintext,
        ),
    ];
    for wire_format_policy in incompatible_policies {
        // Fresh providers per policy, as in `test_wire_policy_positive`, so no
        // storage state carries over from one iteration to the next.
        let alice_provider = &Provider::default();
        let bob_provider = &Provider::default();

        let (mut alice_group, alice_credential_with_key_and_signer) =
            create_group(ciphersuite, alice_provider, wire_format_policy);
        let message = receive_message(
            ciphersuite,
            alice_provider,
            bob_provider,
            &mut alice_group,
            &alice_credential_with_key_and_signer.signer,
        );

        // `expect_err` panics only when processing *succeeds*, so its message
        // describes that situation.
        let err = alice_group
            .process_message(alice_provider, message.try_into_protocol_message().unwrap())
            .expect_err("message with incompatible wire format was unexpectedly accepted");
        assert!(
            matches!(err, ProcessMessageError::IncompatibleWireFormat),
            "expected ProcessMessageError::IncompatibleWireFormat, got {err:?}"
        );
    }
}