use std::collections::HashSet;
use ::serde::Deserialize;
use openmls_traits::{crypto::OpenMlsCrypto, OpenMlsProvider};
use tls_codec::Deserialize as TlsDeserializeTrait;
use crate::{
binary_tree::array_representation::TreeNodeIndex,
group::GroupId,
test_utils::*,
treesync::{RatchetTreeIn, TreeSync},
};
/// Expected tree hash of a single node; stored hex-encoded in the test vector
/// JSON and decoded into raw bytes on deserialization.
#[derive(Deserialize)]
struct TreeHash(#[serde(with = "hex")] Vec<u8>);
/// One entry of the tree-validation test vector file
/// (`test_vectors/tree-validation.json`).
///
/// Byte-string fields are hex-encoded in the JSON and decoded on load.
#[derive(Deserialize)]
struct TestElement {
/// Numeric ciphersuite identifier; converted to a `Ciphersuite` by `run_test_vector`.
cipher_suite: u16,
/// TLS-serialized ratchet tree under test.
#[serde(with = "hex")]
tree: Vec<u8>,
/// Group ID the ratchet tree is verified against.
#[serde(with = "hex")]
group_id: Vec<u8>,
/// Expected resolution of each node, indexed by node index; each resolution
/// is a list of node indices.
resolutions: Vec<Vec<u32>>,
/// Expected tree hash of each node, indexed by node index.
tree_hashes: Vec<TreeHash>,
}
/// Checks a single tree-validation test vector against OpenMLS.
///
/// The serialized ratchet tree is decoded, verified against the vector's group
/// ID and loaded into a [`TreeSync`]. Then, for every node index of the tree,
/// the computed resolution and tree hash are compared with the expected
/// values from the vector.
///
/// Vectors whose ciphersuite the provider does not support are skipped and
/// reported as success.
///
/// # Errors
///
/// Returns a human-readable description of the first failure:
/// - an unknown ciphersuite;
/// - a tree that cannot be decoded or verified;
/// - a missing expected value;
/// - a mismatching resolution or tree hash.
fn run_test_vector(test: TestElement, provider: &impl OpenMlsProvider) -> Result<(), String> {
    let ciphersuite = Ciphersuite::try_from(test.cipher_suite)
        .map_err(|e| format!("Unknown ciphersuite {}: {e:?}", test.cipher_suite))?;
    if !provider
        .crypto()
        .supported_ciphersuites()
        .contains(&ciphersuite)
    {
        log::debug!("Unsupported ciphersuite {0:?} ...", test.cipher_suite);
        return Ok(());
    }

    let group_id = &GroupId::from_slice(test.group_id.as_slice());
    let ratchet_tree = RatchetTreeIn::tls_deserialize_exact(test.tree)
        .map_err(|e| format!("Error while decoding ratchet tree: {e:?}"))?
        .into_verified(ciphersuite, provider.crypto(), group_id)
        .map_err(|e| format!("Error while verifying ratchet tree: {e:?}"))?;

    // Read the node count before handing the tree over, so the tree can be
    // moved into `TreeSync` instead of being cloned.
    let node_count = ratchet_tree.0.len();
    let treesync = TreeSync::from_ratchet_tree(provider.crypto(), ciphersuite, ratchet_tree)
        .map_err(|e| format!("Error while creating tree sync: {e:?}"))?;
    let diff = treesync.empty_diff();

    for index in 0..node_count {
        let node_index = u32::try_from(index)
            .map_err(|_| format!("Node index {index} does not fit into a u32"))?;
        let tree_node_index = TreeNodeIndex::test_new(node_index);

        // Look up the expected values with `get` so that a short vector is
        // reported as an error instead of an out-of-bounds panic.
        let expected_resolution = test
            .resolutions
            .get(index)
            .ok_or_else(|| format!("Missing expected resolution for node {index}"))?;
        let expected_tree_hash = test
            .tree_hashes
            .get(index)
            .ok_or_else(|| format!("Missing expected tree hash for node {index}"))?;

        let resolution = diff
            .resolution(tree_node_index, &HashSet::new())
            .into_iter()
            .map(|(index, _)| index.test_u32())
            .collect::<Vec<_>>();
        if resolution != *expected_resolution {
            return Err(format!(
                "Resolution mismatch for node {index}: expected {expected_resolution:?}, got {resolution:?}"
            ));
        }

        let tree_hash = diff
            .compute_tree_hash(
                provider.crypto(),
                ciphersuite,
                tree_node_index,
                &HashSet::new(),
            )
            .map_err(|e| format!("Error while computing tree hash of node {index}: {e:?}"))?;
        if tree_hash != expected_tree_hash.0 {
            return Err(format!(
                "Tree hash mismatch for node {index}: expected {:?}, got {:?}",
                expected_tree_hash.0, tree_hash
            ));
        }
    }
    Ok(())
}
// Runs every entry of the tree-validation test vectors through
// `run_test_vector` and panics on the first failing entry, reporting its
// position in the file together with the error description.
#[openmls_test::openmls_test]
fn read_test_vectors_tree_validation() {
    let provider = &Provider::default();
    let _ = pretty_env_logger::try_init();
    log::debug!("Reading test vectors ...");

    let tests: Vec<TestElement> = read_json!("../../../../test_vectors/tree-validation.json");
    for (vector_index, test_vector) in tests.into_iter().enumerate() {
        if let Err(e) = run_test_vector(test_vector, provider) {
            // Display (not Debug) formatting keeps the error's own newlines
            // intact instead of escaping them inside a quoted string.
            panic!("Error while checking tree validation test vector {vector_index}.\n{e}");
        }
    }
    log::trace!("Finished test vector verification");
}