use std::slice::from_ref;
use openmls_basic_credential::SignatureKeyPair;
use tls_codec::{Deserialize, Serialize};
use crate::{
binary_tree::LeafNodeIndex,
ciphersuite::{
hash_ref::KeyPackageRef, hpke, signable::Signable, AeadKey, AeadNonce, Mac, Secret,
},
extensions::Extensions,
group::{
errors::WelcomeError, mls_group::tests_and_kats::utils::setup_client, GroupContext,
GroupId, MlsGroup, MlsGroupCreateConfig, ProcessedWelcome, StagedWelcome,
},
messages::{
group_info::{GroupInfoTBS, VerifiableGroupInfo},
ConfirmationTag, EncryptedGroupSecrets, GroupSecrets, GroupSecretsError, Welcome,
},
prelude::ExtensionType,
schedule::{
psk::{load_psks, store::ResumptionPskStore, PskSecret},
KeySchedule,
},
treesync::node::encryption_keys::EncryptionKeyPair,
};
// Tampers with the GroupInfo inside a Welcome and checks that Bob refuses to
// join from it, then checks that the untouched original Welcome still lets
// him join.
//
// The forgery is done by hand from Bob's side:
//   1. HPKE-decrypt the group secrets with Bob's init key,
//   2. run the key schedule to obtain the welcome key and nonce,
//   3. decrypt the GroupInfo, overwrite the ciphersuite in its group context,
//   4. re-seal it with the same key/nonce and put it back into the Welcome.
// The group secrets are HPKE-decrypted with `encrypted_group_info` as their
// context (see step 1). Replacing those bytes therefore breaks the binding,
// which is why the first join is expected to fail with
// `GroupSecretsError::DecryptionFailed`.
#[openmls_test::openmls_test]
fn test_welcome_context_mismatch() {
    let alice_provider = Provider::default();
    let bob_provider = Provider::default();

    // Any ciphersuite different from the one under test will do; it is the
    // value written into the forged group context below.
    let mismatched_ciphersuite = match ciphersuite {
        Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 => {
            Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519
        }
        _ => Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519,
    };

    let group_id = GroupId::random(alice_provider.rand());
    let mls_group_create_config = MlsGroupCreateConfig::builder()
        .ciphersuite(ciphersuite)
        .build();

    let (alice_credential_with_key, _alice_kpb, alice_signer, _alice_signature_key) =
        setup_client("Alice", ciphersuite, &alice_provider);
    let (_bob_credential, bob_kpb, _bob_signer, _bob_signature_key) =
        setup_client("Bob", ciphersuite, &bob_provider);
    let bob_kp = bob_kpb.key_package();
    let bob_private_key = bob_kpb.init_private_key();

    // Alice creates the group and adds Bob, producing the Welcome under test.
    let mut alice_group = MlsGroup::new_with_group_id(
        &alice_provider,
        &alice_signer,
        &mls_group_create_config,
        group_id,
        alice_credential_with_key,
    )
    .expect("An unexpected error occurred.");
    let (_queued_message, welcome, _group_info) = alice_group
        .add_members(&alice_provider, &alice_signer, from_ref(bob_kp))
        .expect("Could not add member to group.");
    alice_group
        .merge_pending_commit(&alice_provider)
        .expect("error merging pending commit");

    let mut welcome = welcome.into_welcome().expect("Unexpected message type.");
    // Untampered copy, used for the final join that must succeed.
    let original_welcome = welcome.clone();

    // Step 1: recover the group secrets with Bob's init key, using the
    // encrypted GroupInfo as the HPKE context.
    let egs = welcome.secrets[0].clone();
    let group_secrets_bytes = hpke::decrypt_with_label(
        bob_private_key,
        "Welcome",
        welcome.encrypted_group_info(),
        egs.encrypted_group_secrets(),
        ciphersuite,
        bob_provider.crypto(),
    )
    .expect("Could not decrypt group secrets.");
    let group_secrets = GroupSecrets::tls_deserialize(&mut group_secrets_bytes.as_slice())
        .expect("Could not deserialize group secrets.");
    let joiner_secret = group_secrets.joiner_secret;

    // No PSK IDs are loaded (`&[]`), so this is the PSK secret for an empty
    // PSK list.
    let psk_secret = {
        let resumption_psk_store = ResumptionPskStore::new(1024);
        let psks = load_psks(bob_provider.storage(), &resumption_psk_store, &[]).unwrap();
        PskSecret::new(bob_provider.crypto(), ciphersuite, psks).unwrap()
    };

    // Step 2: derive the welcome key and nonce from the joiner secret.
    let key_schedule = KeySchedule::init(
        ciphersuite,
        bob_provider.crypto(),
        &joiner_secret,
        psk_secret,
    )
    .expect("Could not create KeySchedule.");
    let (welcome_key, welcome_nonce) = key_schedule
        .welcome(bob_provider.crypto(), ciphersuite)
        .expect("Using the key schedule in the wrong state")
        .derive_welcome_key_nonce(bob_provider.crypto(), ciphersuite)
        .expect("Could not derive welcome key and nonce.");

    // Step 3: decrypt the GroupInfo and swap the ciphersuite in its context.
    let group_info_bytes = welcome_key
        .aead_open(
            bob_provider.crypto(),
            welcome.encrypted_group_info(),
            &[],
            &welcome_nonce,
        )
        .expect("Could not decrypt GroupInfo.");
    let mut verifiable_group_info =
        VerifiableGroupInfo::tls_deserialize(&mut group_info_bytes.as_slice()).unwrap();
    verifiable_group_info
        .payload_mut()
        .group_context_mut()
        .set_ciphersuite(mismatched_ciphersuite);

    // Step 4: re-seal with the same key and nonce, using Bob's provider like
    // every other step of the forgery, and splice it into the Welcome.
    let verifiable_group_info_bytes = verifiable_group_info.tls_serialize_detached().unwrap();
    let encrypted_verifiable_group_info = welcome_key
        .aead_seal(
            bob_provider.crypto(),
            &verifiable_group_info_bytes,
            &[],
            &welcome_nonce,
        )
        .unwrap();
    welcome.encrypted_group_info = encrypted_verifiable_group_info.into();

    // Copy of Bob's leaf encryption key pair; it is written back to storage
    // before the second join attempt below.
    let encryption_keypair = EncryptionKeyPair::from((
        bob_kpb.key_package().leaf_node().encryption_key().clone(),
        bob_kpb.private_encryption_key.clone(),
    ));

    // Joining from the tampered Welcome must be rejected.
    let err = StagedWelcome::new_from_welcome(
        &bob_provider,
        mls_group_create_config.join_config(),
        welcome,
        Some(alice_group.export_ratchet_tree().into()),
    )
    .expect_err("Created a staged join from an invalid Welcome.");
    assert!(matches!(
        err,
        WelcomeError::GroupSecrets(GroupSecretsError::DecryptionFailed)
    ));

    // Put the key package and encryption key pair back into Bob's storage
    // (presumably the failed attempt consumed them — TODO confirm), then join
    // from the untouched Welcome, which must succeed.
    bob_provider
        .storage()
        .write_key_package(&bob_kp.hash_ref(bob_provider.crypto()).unwrap(), &bob_kpb)
        .unwrap();
    encryption_keypair.write(bob_provider.storage()).unwrap();
    let _group = StagedWelcome::new_from_welcome(
        &bob_provider,
        mls_group_create_config.join_config(),
        original_welcome,
        Some(alice_group.export_ratchet_tree().into()),
    )
    .expect("Error creating staged join from a valid Welcome.")
    .into_group(&bob_provider)
    .expect("Error creating group from a valid staged join.");
}
// Round-trips a hand-built Welcome through TLS serialization. The decoded
// message must have:
//   - the same ciphersuite,
//   - exactly one encrypted group secrets entry, which still HPKE-decrypts to
//     the original plaintext,
//   - byte-identical encrypted GroupInfo.
#[openmls_test::openmls_test]
fn test_welcome_message() {
    let provider = &Provider::default();

    // A GroupInfo built from placeholder values and signed with a fresh key.
    let group_info_tbs = {
        let group_context = GroupContext::new(
            ciphersuite,
            GroupId::random(provider.rand()),
            123,
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9],
            vec![1, 1, 1],
            Extensions::empty(),
        );
        GroupInfoTBS::new(
            group_context,
            Extensions::empty(),
            ConfirmationTag(Mac {
                mac_value: vec![1, 2, 3, 4, 5].into(),
            }),
            LeafNodeIndex::new(1),
        )
        .unwrap()
    };
    let signer = SignatureKeyPair::new(ciphersuite.signature_algorithm()).unwrap();
    let group_info = group_info_tbs
        .sign(&signer)
        .expect("Error signing GroupInfo");

    // Random AEAD key/nonce for the GroupInfo and an HPKE key pair for the
    // single recipient.
    let welcome_key = AeadKey::random(ciphersuite, provider.rand());
    let welcome_nonce = AeadNonce::random(provider.rand());
    let receiver_key_pair = provider
        .crypto()
        .derive_hpke_keypair(
            ciphersuite.hpke_config(),
            Secret::random(ciphersuite, provider.rand())
                .expect("Not enough randomness.")
                .as_slice(),
        )
        .expect("Error deriving receiver key pair");

    let hpke_context = b"group info welcome test info";
    let group_secrets = b"these should be the group secrets";
    let new_member = KeyPackageRef::from_slice(&[0u8; 16]);
    let secrets = vec![EncryptedGroupSecrets {
        new_member: new_member.clone(),
        encrypted_group_secrets: hpke::encrypt_with_label(
            receiver_key_pair.public.as_slice(),
            "Welcome",
            hpke_context,
            group_secrets,
            ciphersuite,
            provider.crypto(),
        )
        .unwrap(),
    }];
    let encrypted_group_info = welcome_key
        .aead_seal(
            provider.crypto(),
            &group_info
                .tls_serialize_detached()
                .expect("An unexpected error occurred."),
            &[],
            &welcome_nonce,
        )
        .expect("An unexpected error occurred.");

    let msg = Welcome::new(ciphersuite, secrets, encrypted_group_info.clone());
    let msg_encoded = msg
        .tls_serialize_detached()
        .expect("An unexpected error occurred.");
    let msg_decoded = Welcome::tls_deserialize(&mut msg_encoded.as_slice())
        .expect("An unexpected error occurred.");

    assert_eq!(msg_decoded.cipher_suite, ciphersuite);

    // Exactly one entry was encoded. Without this check, a decoded message with
    // no secrets would make the loop below pass vacuously.
    assert_eq!(msg_decoded.secrets.len(), 1);
    for secret in msg_decoded.secrets.iter() {
        assert_eq!(new_member.as_slice(), secret.new_member.as_slice());
        let ptxt = hpke::decrypt_with_label(
            &receiver_key_pair.private,
            "Welcome",
            hpke_context,
            &secret.encrypted_group_secrets,
            ciphersuite,
            provider.crypto(),
        )
        .expect("Error decrypting valid ciphertext in Welcome message test.");
        assert_eq!(&group_secrets[..], &ptxt[..]);
    }

    assert_eq!(
        msg_decoded.encrypted_group_info.as_slice(),
        encrypted_group_info.as_slice()
    );
}
// Bob inspects a Welcome through `ProcessedWelcome` before joining. The
// unverified GroupInfo must name Alice's group. Its extensions must equal
// those of a GroupInfo Alice exports herself, with any `ExternalPub`
// extension removed. Completing the join afterwards must succeed.
#[openmls_test::openmls_test]
fn test_welcome_processing() {
    let alice_provider = &Provider::default();
    let bob_provider = &Provider::default();

    let group_id = GroupId::random(alice_provider.rand());
    let config = MlsGroupCreateConfig::builder()
        .ciphersuite(ciphersuite)
        .build();

    let (alice_credential, _alice_kpb, alice_signer, _alice_signature_key) =
        setup_client("Alice", ciphersuite, alice_provider);
    let (_bob_credential, bob_kpb, _bob_signer, _bob_signature_key) =
        setup_client("Bob", ciphersuite, bob_provider);

    // Alice sets up the group and adds Bob.
    let mut alice_group = MlsGroup::new_with_group_id(
        alice_provider,
        &alice_signer,
        &config,
        group_id,
        alice_credential,
    )
    .expect("An unexpected error occurred.");
    let (_commit, welcome_out, _group_info) = alice_group
        .add_members(alice_provider, &alice_signer, from_ref(bob_kpb.key_package()))
        .expect("Could not add member to group.");
    alice_group
        .merge_pending_commit(alice_provider)
        .expect("error merging pending commit");

    let welcome = welcome_out
        .into_welcome()
        .expect("Unexpected message type.");

    // First stage: process the Welcome and peek at the unverified GroupInfo.
    let processed =
        ProcessedWelcome::new_from_welcome(bob_provider, config.join_config(), welcome).unwrap();
    let preview = processed.unverified_group_info();
    assert_eq!(preview.group_id(), alice_group.group_id());

    // Reference extensions: Alice's exported GroupInfo minus ExternalPub.
    let mut expected_extensions = alice_group
        .export_group_info(alice_provider.crypto(), &alice_signer, false)
        .unwrap()
        .into_verifiable_group_info()
        .unwrap()
        .extensions()
        .clone();
    expected_extensions.remove(ExtensionType::ExternalPub);
    assert_eq!(preview.extensions(), &expected_extensions);

    // Second stage: finish the join with Alice's ratchet tree.
    let _group = processed
        .into_staged_welcome(bob_provider, Some(alice_group.export_ratchet_tree().into()))
        .unwrap()
        .into_group(bob_provider)
        .expect("Error creating group from a valid staged join.");
}
// The unverified GroupInfo carried by a Welcome from `add_members` must not
// contain an `ExternalPub` extension.
#[openmls_test::openmls_test]
fn no_external_pub_in_welcome() {
    let alice_provider = &Provider::default();
    let bob_provider = &Provider::default();
    let config = MlsGroupCreateConfig::builder()
        .ciphersuite(ciphersuite)
        .build();

    let (alice_credential, _alice_kpb, alice_signer, _alice_signature_key) =
        setup_client("Alice", ciphersuite, alice_provider);
    let (_bob_credential, bob_kpb, _bob_signer, _bob_signature_key) =
        setup_client("Bob", ciphersuite, bob_provider);

    let mut alice_group = MlsGroup::new(alice_provider, &alice_signer, &config, alice_credential)
        .expect("An unexpected error occurred.");

    // The pending commit is not merged; only the resulting Welcome is inspected.
    let (_commit, welcome_out, _group_info) = alice_group
        .add_members(alice_provider, &alice_signer, from_ref(bob_kpb.key_package()))
        .expect("Could not add member to group.");
    let welcome = welcome_out
        .into_welcome()
        .expect("Unexpected message type.");

    let processed =
        ProcessedWelcome::new_from_welcome(bob_provider, config.join_config(), welcome).unwrap();
    let has_external_pub = processed
        .unverified_group_info()
        .extensions()
        .contains(ExtensionType::ExternalPub);
    assert!(!has_external_pub);
}
/// Feeds a hard-coded malformed Welcome encoding to `Welcome::tls_deserialize`
/// and requires decoding to fail rather than yield a message.
#[test]
fn invalid_welcomes() {
    const MALFORMED_WELCOME: &[u8] = &[
        2, 0, 2, 0, 0, 0, 90, 4, 0, 0, 0, 0, 0, 32, 183, 76, 159, 248, 180, 5, 79, 86, 242, 165,
        206, 103, 47, 8, 110, 250, 81, 48, 206, 185, 186, 104, 220, 181, 245, 106, 134, 32, 97,
        233, 141, 26, 0, 49, 13, 203, 68, 119, 97, 90, 172, 36, 170, 239, 80, 191, 63, 146, 177,
        211, 151, 152, 93, 117, 192, 136, 96, 22, 168, 213, 67, 165, 244, 165, 183, 228, 88, 62,
        232, 36, 220, 224, 93, 216, 155, 210, 167, 34, 112, 7, 73, 42, 2, 0, 0, 0, 71, 254, 148,
        190, 32, 30, 92, 51, 15, 16, 11, 46, 196, 65, 132, 142, 111, 177, 115, 21, 218, 71, 51,
        118, 228, 188, 12, 134, 23, 216, 51, 20, 138, 215, 232, 62, 216, 119, 242, 93, 164, 250,
        100, 223, 214, 94, 85, 139, 159, 205, 193, 153, 181, 243, 139, 12, 78, 253, 200, 47, 207,
        79, 86, 82, 63, 217, 126, 204, 178, 24, 199, 49,
    ];

    // `&[u8]` implements `Read`, so the slice itself acts as the reader.
    let mut reader = MALFORMED_WELCOME;
    assert!(Welcome::tls_deserialize(&mut reader).is_err());
}