use openmls_traits::{crypto::OpenMlsCrypto, random::OpenMlsRand, OpenMlsProvider};
use serde::{self, Deserialize, Serialize};
use tls_codec::{Deserialize as TlsDeserializeTrait, Serialize as TlsSerializeTrait};
#[cfg(test)]
use crate::test_utils::{read, write};
use crate::{
binary_tree::array_representation::LeafNodeIndex,
framing::{mls_auth_content::AuthenticatedContent, *},
group::{
tests_and_kats::utils::{generate_credential_with_key, randombytes},
*,
},
messages::*,
schedule::*,
test_utils::*,
};
/// Path of the test-vector file that `read_test_vectors_transcript` loads.
const TEST_VECTOR_PATH_READ: &str = "test_vectors/transcript-hashes.json";
/// Path that `write_test_vectors` passes to `write` for the freshly generated vectors.
const TEST_VECTOR_PATH_WRITE: &str = "test_vectors/transcript-hashes-new.json";
/// Number of test vectors `write_test_vectors` generates for each supported ciphersuite.
const NUM_TESTS: usize = 100;
/// One transcript-hash test vector, as stored in the JSON test-vector files.
///
/// All byte fields go through `hex::serde` when (de)serialized with serde.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct TranscriptTestVector {
/// Numeric ciphersuite identifier; converted with `Ciphersuite::try_from`.
pub cipher_suite: u16,
/// Secret bytes from which the `ConfirmationKey` is built.
#[serde(with = "hex::serde")]
pub confirmation_key: Vec<u8>,
/// TLS-serialized authenticated content; `run_test_vector` requires it to
/// carry a commit.
#[serde(with = "hex::serde")]
pub authenticated_content: Vec<u8>,
/// Interim transcript hash fed into the confirmed transcript hash calculation.
#[serde(with = "hex::serde")]
pub interim_transcript_hash_before: Vec<u8>,
/// Expected confirmed transcript hash after processing the content.
#[serde(with = "hex::serde")]
pub confirmed_transcript_hash_after: Vec<u8>,
/// Expected interim transcript hash after processing the content.
#[serde(with = "hex::serde")]
pub interim_transcript_hash_after: Vec<u8>,
}
/// Loads every vector from `TEST_VECTOR_PATH_READ` and checks each one with
/// `run_test_vector`, in file order.
#[test]
fn read_test_vectors_transcript() {
    let vectors: Vec<TranscriptTestVector> = read(TEST_VECTOR_PATH_READ);
    vectors.into_iter().for_each(run_test_vector);
}
pub fn run_test_vector(test_vector: TranscriptTestVector) {
let provider = OpenMlsRustCrypto::default();
let ciphersuite = Ciphersuite::try_from(test_vector.cipher_suite).unwrap();
if provider.crypto().supports(ciphersuite).is_err() {
log::debug!("Skipping unsupported ciphersuite `{ciphersuite:?}`.");
return;
}
let authenticated_content = AuthenticatedContent::from(
AuthenticatedContentIn::tls_deserialize_exact(test_vector.authenticated_content).unwrap(),
);
assert!(matches!(
authenticated_content.content(),
FramedContentBody::Commit(_)
));
let confirmation_key =
ConfirmationKey::from_secret(Secret::from_slice(&test_vector.confirmation_key));
let got_confirmation_tag = confirmation_key
.tag(
provider.crypto(),
ciphersuite,
&test_vector.confirmed_transcript_hash_after,
)
.unwrap();
assert_eq!(
got_confirmation_tag,
*authenticated_content.confirmation_tag().unwrap()
);
let got_confirmed_transcript_hash_after = {
let input = ConfirmedTranscriptHashInput::try_from(&authenticated_content).unwrap();
input
.calculate_confirmed_transcript_hash(
provider.crypto(),
ciphersuite,
&test_vector.interim_transcript_hash_before,
)
.unwrap()
};
assert_eq!(
test_vector.confirmed_transcript_hash_after,
got_confirmed_transcript_hash_after
);
let got_interim_transcript_hash_after = {
let input = InterimTranscriptHashInput::from(&got_confirmation_tag);
input
.calculate_interim_transcript_hash(
provider.crypto(),
ciphersuite,
&got_confirmed_transcript_hash_after,
)
.unwrap()
};
assert_eq!(
test_vector.interim_transcript_hash_after,
got_interim_transcript_hash_after
);
}
/// Generates `NUM_TESTS` vectors for every ciphersuite the default crypto
/// provider reports as supported and writes them to `TEST_VECTOR_PATH_WRITE`.
#[test]
fn write_test_vectors() {
    let provider = OpenMlsRustCrypto::default();
    let tests: Vec<TranscriptTestVector> = provider
        .crypto()
        .supported_ciphersuites()
        .iter()
        .flat_map(|&ciphersuite| (0..NUM_TESTS).map(move |_| generate_test_vector(ciphersuite)))
        .collect();
    write(TEST_VECTOR_PATH_WRITE, &tests);
}
/// Builds a fresh transcript-hash test vector for `ciphersuite`.
///
/// A random confirmation key and a random "before" interim transcript hash
/// are drawn, then an empty commit (no proposals, no path) is created as a
/// public message from member leaf 7 in a group context filled with random
/// values. From these the confirmed transcript hash, the confirmation tag
/// and the "after" interim transcript hash are derived; the tag is stored in
/// the content before it is serialized.
///
/// # Panics
///
/// Panics if random generation, commit creation, hashing, tagging or
/// serialization fails.
pub fn generate_test_vector(ciphersuite: Ciphersuite) -> TranscriptTestVector {
    let provider = OpenMlsRustCrypto::default();
    let rng = provider.rand();
    let crypto = provider.crypto();
    let hash_len = ciphersuite.hash_length();

    let confirmation_key = ConfirmationKey::random(ciphersuite, rng);
    let interim_transcript_hash_before = randombytes(hash_len);

    // Framing inputs for the commit.
    let aad = rng.random_vec(48).unwrap();
    let framing_parameters = FramingParameters::new(&aad, WireFormat::PublicMessage);
    let sender = Sender::build_member(LeafNodeIndex::new(7));
    let empty_commit = Commit {
        proposals: vec![],
        path: None,
    };

    // Group context populated with random hashes, group id and epoch.
    let tree_hash_before = rng.random_vec(hash_len).unwrap();
    let confirmed_transcript_hash_before = rng.random_vec(hash_len).unwrap();
    let group_context = GroupContext::new(
        ciphersuite,
        GroupId::random(rng),
        random_u64(),
        tree_hash_before,
        confirmed_transcript_hash_before,
        Extensions::empty(),
    );

    // Only the signer half of the generated credential bundle is needed.
    let signer = generate_credential_with_key(
        b"Alice".to_vec(),
        ciphersuite.signature_algorithm(),
        &provider,
    )
    .signer;

    let mut authenticated_content = AuthenticatedContent::commit(
        framing_parameters,
        sender,
        empty_commit,
        &group_context,
        &signer,
    )
    .unwrap();

    // Confirmed hash -> confirmation tag -> interim hash, each feeding the next.
    let confirmed_transcript_hash_after =
        ConfirmedTranscriptHashInput::try_from(&authenticated_content)
            .unwrap()
            .calculate_confirmed_transcript_hash(
                crypto,
                ciphersuite,
                &interim_transcript_hash_before,
            )
            .unwrap();

    let tag = confirmation_key
        .tag(crypto, ciphersuite, &confirmed_transcript_hash_after)
        .unwrap();
    authenticated_content.set_confirmation_tag(tag.clone());

    let interim_transcript_hash_after = InterimTranscriptHashInput::from(&tag)
        .calculate_interim_transcript_hash(crypto, ciphersuite, &confirmed_transcript_hash_after)
        .unwrap();

    TranscriptTestVector {
        cipher_suite: (&ciphersuite).into(),
        confirmation_key: confirmation_key.as_slice().to_vec(),
        authenticated_content: authenticated_content.tls_serialize_detached().unwrap(),
        interim_transcript_hash_before,
        confirmed_transcript_hash_after,
        interim_transcript_hash_after,
    }
}