use crate::{
binary_tree::LeafNodeIndex,
framing::*,
group::{
tests_and_kats::utils::{generate_credential_with_key, generate_key_package},
*,
},
};
// Verifies that additional authenticated data (AAD) set on a group with
// `set_aad` is attached to the next outgoing message and is visible to the
// receiver via `aad()`. Three message kinds are checked:
// - an application message,
// - an add-member commit,
// - a remove-member commit.
// It also asserts that the sender's AAD reads back as empty after the
// message/commit has been created (and merged, for commits).
// The whole scenario runs once per wire format policy (pure plaintext and
// pure ciphertext).
#[openmls_test::openmls_test]
fn test_add_member_with_aad() {
    for wire_format_policy in [
        PURE_PLAINTEXT_WIRE_FORMAT_POLICY,
        PURE_CIPHERTEXT_WIRE_FORMAT_POLICY,
    ] {
        let alice_provider = &Provider::default();
        let bob_provider = &Provider::default();
        let charlie_provider = &Provider::default();

        let group_id = GroupId::random(alice_provider.rand());

        let alice_credential_with_key_and_signer = generate_credential_with_key(
            "Alice".into(),
            ciphersuite.signature_algorithm(),
            alice_provider,
        );
        let bob_credential_with_key_and_signer = generate_credential_with_key(
            "Bob".into(),
            ciphersuite.signature_algorithm(),
            bob_provider,
        );
        let charlie_credential_with_key_and_signer = generate_credential_with_key(
            "Charlie".into(),
            ciphersuite.signature_algorithm(),
            charlie_provider,
        );

        // Bob's and Charlie's credentials are not needed after this point,
        // so they can be moved into their key packages.
        let bob_key_package = generate_key_package(
            ciphersuite,
            Extensions::empty(),
            bob_provider,
            bob_credential_with_key_and_signer,
        );
        let charlie_key_package = generate_key_package(
            ciphersuite,
            Extensions::empty(),
            charlie_provider,
            charlie_credential_with_key_and_signer,
        );

        let mls_group_create_config = MlsGroupCreateConfig::builder()
            .ciphersuite(ciphersuite)
            .wire_format_policy(wire_format_policy)
            .build();

        let mut alice_group = MlsGroup::new_with_group_id(
            alice_provider,
            &alice_credential_with_key_and_signer.signer,
            &mls_group_create_config,
            group_id,
            alice_credential_with_key_and_signer
                .credential_with_key
                .clone(),
        )
        .expect("An unexpected error occurred.");

        // Setting the AAD must be observable through the getter.
        let aad = b"Test AAD".to_vec();
        alice_group.set_aad(aad.clone());
        assert_eq!(alice_group.aad(), &aad);

        // Alice adds Bob (with AAD set) and merges her own commit.
        let (_message, welcome, _group_info) = alice_group
            .add_members(
                alice_provider,
                &alice_credential_with_key_and_signer.signer,
                core::slice::from_ref(bob_key_package.key_package()),
            )
            .expect("An unexpected error occurred.");
        alice_group
            .merge_pending_commit(alice_provider)
            .expect("error merging pending commit");

        // Bob joins from the Welcome, using Alice's exported ratchet tree.
        let welcome: MlsMessageIn = welcome.into();
        let welcome = welcome
            .into_welcome()
            .expect("expected message to be a welcome");
        let mut bob_group = StagedWelcome::new_from_welcome(
            bob_provider,
            mls_group_create_config.join_config(),
            welcome,
            Some(alice_group.export_ratchet_tree().into()),
        )
        .expect("Error creating staged join from Welcome")
        .into_group(bob_provider)
        .expect("Error creating group from staged join");

        // --- Application message carries the AAD ---
        let message = b"Hello, World!".to_vec();
        alice_group.set_aad(aad.clone());
        let alice_message: MlsMessageIn = alice_group
            .create_message(
                alice_provider,
                &alice_credential_with_key_and_signer.signer,
                &message,
            )
            .expect("Error creating message")
            .into();
        assert!(alice_group.aad().is_empty());

        let bob_message = bob_group
            .process_message(
                bob_provider,
                alice_message
                    .into_protocol_message()
                    .expect("application message should be a protocol message"),
            )
            .expect("Error handling message");
        assert_eq!(bob_message.aad(), &aad);

        // --- Add-member commit carries the AAD ---
        alice_group.set_aad(aad.clone());
        let (commit, _welcome, _group_info) = alice_group
            .add_members(
                alice_provider,
                &alice_credential_with_key_and_signer.signer,
                core::slice::from_ref(charlie_key_package.key_package()),
            )
            .expect("An unexpected error occurred.");
        alice_group
            .merge_pending_commit(alice_provider)
            .expect("error merging pending commit");
        assert!(alice_group.aad().is_empty());

        let bob_processed_message = bob_group
            .process_message(
                bob_provider,
                commit
                    .into_protocol_message()
                    .expect("add commit should be a protocol message"),
            )
            .expect("Error handling message");
        // Check the commit's AAD before `into_content()` consumes the message.
        assert_eq!(bob_processed_message.aad(), &aad);
        match bob_processed_message.into_content() {
            ProcessedMessageContent::StagedCommitMessage(bob_staged_commit) => {
                bob_group
                    .merge_staged_commit(bob_provider, *bob_staged_commit)
                    .expect("error merging staged commit");
            }
            _ => panic!("Expected a StagedCommitMessage"),
        }

        // --- Remove-member commit carries the AAD ---
        // Leaf 2 is Charlie (Alice = 0, Bob = 1, Charlie added third).
        alice_group.set_aad(aad.clone());
        let (commit, _welcome, _group_info) = alice_group
            .remove_members(
                alice_provider,
                &alice_credential_with_key_and_signer.signer,
                &[LeafNodeIndex::new(2)],
            )
            .expect("An unexpected error occurred.");
        alice_group
            .merge_pending_commit(alice_provider)
            .expect("error merging pending commit");
        assert!(alice_group.aad().is_empty());

        let bob_processed_message = bob_group
            .process_message(
                bob_provider,
                commit
                    .into_protocol_message()
                    .expect("remove commit should be a protocol message"),
            )
            .expect("Error handling message");
        assert_eq!(bob_processed_message.aad(), &aad);
    }
}