use std::convert::TryFrom;
#[cfg(test)]
use crate::test_utils::{read, write};
use crate::{
ciphersuite::hash_ref::KeyPackageRef, ciphersuite::signable::*, credentials::*, framing::*,
group::*, messages::*, schedule::*, test_utils::*, versions::ProtocolVersion,
};
use openmls_rust_crypto::OpenMlsRustCrypto;
use openmls_traits::{random::OpenMlsRand, types::SignatureScheme, OpenMlsCryptoProvider};
use serde::{self, Deserialize, Serialize};
use thiserror::Error;
use tls_codec::{Deserialize as TlsDeserialize, Serialize as TlsSerializeTrait};
/// A known-answer test vector for the MLS transcript hashes.
///
/// Every `String` field holds hex-encoded bytes; `epoch` is stored as a plain integer.
/// The declaration order of the fields is also the key order of the serialized JSON,
/// so it should not be rearranged casually.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct TranscriptTestVector {
    /// Numeric ciphersuite identifier (`Ciphersuite as u16`).
    pub cipher_suite: u16,
    /// Group id.
    group_id: String,
    /// Group epoch.
    epoch: u64,
    /// Tree hash placed in the group context before the commit.
    tree_hash_before: String,
    /// Confirmed transcript hash placed in the group context before the commit.
    confirmed_transcript_hash_before: String,
    /// Interim transcript hash that the commit's confirmed transcript hash builds on.
    interim_transcript_hash_before: String,
    /// TLS-serialized credential of the committer.
    credential: String,
    /// Membership key secret used for the commit's membership tag.
    membership_key: String,
    /// Confirmation key secret used for the commit's confirmation tag.
    confirmation_key: String,
    /// TLS-serialized commit plaintext.
    commit: String,
    /// TLS-serialized group context.
    group_context: String,
    /// Expected confirmed transcript hash after processing the commit.
    confirmed_transcript_hash_after: String,
    /// Expected interim transcript hash after processing the commit.
    interim_transcript_hash_after: String,
}
/// Builds one randomized transcript-hash test vector for `ciphersuite`.
///
/// A fresh [`OpenMlsRustCrypto`] provider supplies every random input:
/// - the group id and epoch,
/// - the "before" tree hash and confirmed/interim transcript hashes,
/// - the membership and confirmation keys,
/// - the AAD and the sender's key package reference.
///
/// An empty commit is then framed as an `MlsPlaintext` and processed as follows:
///
/// 1. the confirmed transcript hash is updated over the commit content,
/// 2. the confirmation tag is computed from it and attached to the commit,
/// 3. the interim transcript hash is updated over the commit's auth data,
/// 4. the membership tag is set using the serialized group context.
///
/// All byte values are hex-encoded into the returned [`TranscriptTestVector`].
///
/// # Panics
///
/// Panics if drawing randomness, creating the credential, framing, hashing, tagging or
/// serialization fails.
pub fn generate_test_vector(ciphersuite: Ciphersuite) -> TranscriptTestVector {
    let crypto = OpenMlsRustCrypto::default();

    // Random byte string as long as the ciphersuite's hash output.
    let random_hash = || {
        crypto
            .rand()
            .random_vec(ciphersuite.hash_length())
            .expect("An unexpected error occurred.")
    };
    // Random secret for deriving the epoch keys below.
    let random_secret = || {
        Secret::random(ciphersuite, &crypto, None).expect("Not enough randomness.")
    };

    // Random group state before the commit.
    let group_id = GroupId::random(&crypto);
    let epoch = random_u64();
    let tree_hash_before = random_hash();
    let confirmed_transcript_hash_before = random_hash();
    let interim_transcript_hash_before = random_hash();

    // Random epoch keys that tag the commit.
    let membership_key = MembershipKey::from_secret(random_secret());
    let confirmation_key = ConfirmationKey::from_secret(random_secret());

    let credential_bundle = CredentialBundle::new(
        b"client".to_vec(),
        CredentialType::Basic,
        SignatureScheme::from(ciphersuite),
        &crypto,
    )
    .expect("An unexpected error occurred.");

    let context = GroupContext::new(
        group_id.clone(),
        epoch,
        tree_hash_before.clone(),
        confirmed_transcript_hash_before.clone(),
        &[],
    );

    // Frame an empty commit, sent by a randomly identified member.
    let aad = crypto
        .rand()
        .random_vec(48)
        .expect("An unexpected error occurred.");
    let framing_parameters = FramingParameters::new(&aad, WireFormat::MlsPlaintext);
    let sender_ref_bytes = crypto
        .rand()
        .random_vec(16)
        .expect("Error getting randomnes");
    let sender = Sender::build_member(&KeyPackageRef::from_slice(&sender_ref_bytes));
    let empty_commit = Commit {
        proposals: vec![].into(),
        path: None,
    };
    let mut commit = MlsPlaintext::commit(
        framing_parameters,
        sender,
        empty_commit,
        &credential_bundle,
        &context,
        &crypto,
    )
    .expect("An unexpected error occurred.");

    // The confirmed transcript hash is computed before the confirmation tag is attached.
    let confirmed_transcript_hash_after = update_confirmed_transcript_hash(
        ciphersuite,
        &crypto,
        &MlsPlaintextCommitContent::try_from(&commit).expect("An unexpected error occurred."),
        &interim_transcript_hash_before,
    )
    .expect("Error updating confirmed transcript hash");
    commit.set_confirmation_tag(
        confirmation_key
            .tag(&crypto, &confirmed_transcript_hash_after)
            .expect("Could not compute confirmation tag."),
    );

    // The interim transcript hash is computed after the confirmation tag is in place.
    let interim_transcript_hash_after = update_interim_transcript_hash(
        ciphersuite,
        &crypto,
        &MlsPlaintextCommitAuthData::try_from(&commit).expect("An unexpected error occurred."),
        &confirmed_transcript_hash_after,
    )
    .expect("Error updating interim transcript hash");

    // The serialized context is used for the membership tag and is also exported.
    let serialized_context = context
        .tls_serialize_detached()
        .expect("An unexpected error occurred.");
    commit
        .set_membership_tag(&crypto, &serialized_context, &membership_key)
        .expect("Error adding membership tag");

    let credential = credential_bundle
        .credential()
        .tls_serialize_detached()
        .expect("An unexpected error occurred.");
    let commit_bytes = commit
        .tls_serialize_detached()
        .expect("Error encoding commit");

    TranscriptTestVector {
        cipher_suite: ciphersuite as u16,
        group_id: bytes_to_hex(group_id.as_slice()),
        epoch,
        tree_hash_before: bytes_to_hex(&tree_hash_before),
        confirmed_transcript_hash_before: bytes_to_hex(&confirmed_transcript_hash_before),
        interim_transcript_hash_before: bytes_to_hex(&interim_transcript_hash_before),
        credential: bytes_to_hex(&credential),
        membership_key: bytes_to_hex(membership_key.as_slice()),
        confirmation_key: bytes_to_hex(confirmation_key.as_slice()),
        commit: bytes_to_hex(&commit_bytes),
        group_context: bytes_to_hex(&serialized_context),
        confirmed_transcript_hash_after: bytes_to_hex(&confirmed_transcript_hash_after),
        interim_transcript_hash_after: bytes_to_hex(&interim_transcript_hash_after),
    }
}
/// Generates `NUM_TESTS` vectors per supported ciphersuite and writes them to
/// `test_vectors/kat_transcripts-new.json`, grouped by ciphersuite in provider order.
#[test]
fn write_test_vectors() {
    use openmls_traits::crypto::OpenMlsCrypto;
    const NUM_TESTS: usize = 100;

    let provider = OpenMlsRustCrypto::default();
    let tests: Vec<TranscriptTestVector> = provider
        .crypto()
        .supported_ciphersuites()
        .iter()
        .flat_map(|&ciphersuite| (0..NUM_TESTS).map(move |_| generate_test_vector(ciphersuite)))
        .collect();

    write("test_vectors/kat_transcripts-new.json", &tests);
}
pub fn run_test_vector(
test_vector: TranscriptTestVector,
backend: &impl OpenMlsCryptoProvider,
) -> Result<(), TranscriptTestVectorError> {
let ciphersuite = Ciphersuite::try_from(test_vector.cipher_suite).expect("Invalid ciphersuite");
log::debug!("Testing test vector for ciphersuite {:?}", ciphersuite);
log::trace!(" {:?}", test_vector);
let group_id = GroupId {
value: hex_to_bytes(&test_vector.group_id).into(),
};
let epoch = test_vector.epoch;
let tree_hash_before = hex_to_bytes(&test_vector.tree_hash_before);
let confirmed_transcript_hash_before =
hex_to_bytes(&test_vector.confirmed_transcript_hash_before);
let interim_transcript_hash_before = hex_to_bytes(&test_vector.interim_transcript_hash_before);
let membership_key = MembershipKey::from_secret(Secret::from_slice(
&hex_to_bytes(&test_vector.membership_key),
ProtocolVersion::default(),
ciphersuite,
));
let confirmation_key = ConfirmationKey::from_secret(Secret::from_slice(
&hex_to_bytes(&test_vector.confirmation_key),
ProtocolVersion::default(),
ciphersuite,
));
let credential =
Credential::tls_deserialize(&mut hex_to_bytes(&test_vector.credential).as_slice())
.expect("An unexpected error occurred.");
let commit_bytes = hex_to_bytes(&test_vector.commit);
let mut commit = VerifiableMlsPlaintext::tls_deserialize(&mut commit_bytes.as_slice())
.expect("Error decoding commit");
let context = GroupContext::new(
group_id,
epoch,
tree_hash_before,
confirmed_transcript_hash_before,
&[], );
let expected_group_context = hex_to_bytes(&test_vector.group_context);
if context
.tls_serialize_detached()
.expect("An unexpected error occurred.")
!= expected_group_context
{
log::error!(" Group context mismatch");
log::debug!(
" Computed: {:x?}",
context
.tls_serialize_detached()
.expect("An unexpected error occurred.")
);
log::debug!(" Expected: {:x?}", expected_group_context);
if cfg!(test) {
panic!("Group context mismatch");
}
return Err(TranscriptTestVectorError::GroupContextMismatch);
}
commit.set_context(
context
.tls_serialize_detached()
.expect("An unexpected error occurred."),
);
if commit.verify_membership(backend, &membership_key).is_err() {
if cfg!(test) {
panic!("Invalid membership tag");
}
return Err(TranscriptTestVectorError::MembershipTagVerificationError);
}
let commit: MlsPlaintext = commit
.verify(backend, &credential)
.expect("Invalid signature on MlsPlaintext commit");
let confirmed_transcript_hash_after =
hex_to_bytes(&test_vector.confirmed_transcript_hash_after);
let my_confirmation_tag = confirmation_key
.tag(backend, &confirmed_transcript_hash_after)
.expect("Could not compute confirmation tag.");
if &my_confirmation_tag
!= commit
.confirmation_tag()
.expect("Confirmation tag is missing")
{
log::error!(" Confirmation tag mismatch");
log::debug!(" Computed: {:x?}", my_confirmation_tag);
log::debug!(
" Expected: {:x?}",
commit
.confirmation_tag()
.expect("An unexpected error occurred.")
);
if cfg!(test) {
panic!("Invalid confirmation tag");
}
return Err(TranscriptTestVectorError::ConfirmationTagMismatch);
}
let my_confirmed_transcript_hash_after = update_confirmed_transcript_hash(
ciphersuite,
backend,
&MlsPlaintextCommitContent::try_from(&commit).expect("An unexpected error occurred."),
&interim_transcript_hash_before,
)
.expect("Error updating confirmed transcript hash");
if my_confirmed_transcript_hash_after != confirmed_transcript_hash_after {
log::debug!(" Confirmed transcript hash mismatch");
log::debug!(" Got: {:x?}", my_confirmed_transcript_hash_after);
log::debug!(" Expected: {:x?}", confirmed_transcript_hash_after);
if cfg!(test) {
panic!("Confirmed transcript hash mismatch");
}
return Err(TranscriptTestVectorError::ConfirmedTranscriptHashMismatch);
}
let interim_transcript_hash_after = hex_to_bytes(&test_vector.interim_transcript_hash_after);
let my_interim_transcript_hash_after = update_interim_transcript_hash(
ciphersuite,
backend,
&MlsPlaintextCommitAuthData::try_from(&commit).expect("An unexpected error occurred."),
&my_confirmed_transcript_hash_after,
)
.expect("Error updating interim transcript hash");
if my_interim_transcript_hash_after != interim_transcript_hash_after {
log::debug!(" Interim transcript hash mismatch");
log::debug!(" Got: {:x?}", my_interim_transcript_hash_after);
log::debug!(" Expected: {:x?}", interim_transcript_hash_after);
if cfg!(test) {
panic!("Interim transcript hash mismatch");
}
return Err(TranscriptTestVectorError::InterimTranscriptHashMismatch);
}
log::debug!(" Finished transcript test vector verification");
Ok(())
}
/// Checks every vector in `test_vectors/kat_transcripts.json` against `backend`.
///
/// Stops at the first vector that fails, panicking with the returned error.
#[apply(backends)]
fn read_test_vectors_transcript(backend: &impl OpenMlsCryptoProvider) {
    let tests: Vec<TranscriptTestVector> = read("test_vectors/kat_transcripts.json");
    tests.into_iter().for_each(|test_vector| {
        if let Err(e) = run_test_vector(test_vector, backend) {
            panic!("Error while checking transcript test vector.\n{:?}", e);
        }
    });
}
/// Ways in which a transcript test vector can fail verification.
#[derive(Error, Debug, PartialEq, Clone)]
pub enum TranscriptTestVectorError {
    /// The commit's membership tag did not verify under the vector's membership key.
    #[error("Membership tag could not be verified.")]
    MembershipTagVerificationError,
    /// The reconstructed group context differs from the one in the vector.
    #[error("The group context does not match")]
    GroupContextMismatch,
    /// The recomputed confirmation tag differs from the commit's confirmation tag.
    #[error("The computed confirmation tag doesn't match the one in the test vector.")]
    ConfirmationTagMismatch,
    /// The recomputed confirmed transcript hash differs from the vector's value.
    #[error("The computed transcript hash doesn't match the one in the test vector.")]
    ConfirmedTranscriptHashMismatch,
    /// The recomputed interim transcript hash differs from the vector's value.
    #[error("The computed interim transcript hash doesn't match the one in the test vector.")]
    InterimTranscriptHashMismatch,
}