#[cfg(test)]
use crate::test_utils::*;
use crate::tree::{index::*, treemath::*};
use serde::{self, Deserialize, Serialize};
/// Known-answer test vector for the tree math functions.
///
/// `root` holds one entry per tree size: entry `k` is the root index of a tree
/// with `k + 1` leaves, for `k` in `0..n_leaves`. `left`, `right`, `parent` and
/// `sibling` hold one entry per node index in `0..n_nodes` of a tree with
/// `n_leaves` leaves. An entry is `None` where the corresponding tree math
/// function returned an error for that node.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TreeMathTestVector {
    /// Number of leaves of the tree described by the node arrays.
    n_leaves: u32,
    /// Number of nodes (`node_width`) of a tree with `n_leaves` leaves.
    n_nodes: u32,
    /// Root index for each tree size `1..=n_leaves`.
    root: Vec<u32>,
    /// Left child of each node, or `None` if `left` returned an error.
    left: Vec<Option<u32>>,
    /// Right child of each node, or `None` if `right` returned an error.
    right: Vec<Option<u32>>,
    /// Parent of each node, or `None` if `parent` returned an error.
    parent: Vec<Option<u32>>,
    /// Sibling of each node, or `None` if `sibling` returned an error.
    sibling: Vec<Option<u32>>,
}
/// Turns the `Result` of a tree math call into the `Option<u32>` stored in a
/// [`TreeMathTestVector`]: the index as `u32` on success, `None` on any error.
macro_rules! convert {
    ($r:expr) => {
        $r.ok().map(|index| index.as_u32())
    };
}
/// Builds a tree math test vector for a tree with `n_leaves` leaves.
///
/// `root[k]` is the root index of a tree with `k + 1` leaves. The `left`,
/// `right`, `parent` and `sibling` arrays cover every node index of the
/// `n_leaves`-leaf tree, with `None` wherever the tree math call failed.
#[cfg(any(feature = "test-utils", test))]
pub fn generate_test_vector(n_leaves: u32) -> TreeMathTestVector {
    let leaves = LeafIndex::from(n_leaves);
    let n_nodes = node_width(leaves.as_usize()) as u32;

    // One root per tree size 1, 2, ..., n_leaves.
    let roots: Vec<u32> = (1..=n_leaves)
        .map(|size| root(LeafIndex::from(size)).as_u32())
        .collect();

    // Fresh iterator over all node indices of the tree, reused for each column.
    let node_indices = || (0..n_nodes).map(NodeIndex::from);

    TreeMathTestVector {
        n_leaves,
        n_nodes,
        root: roots,
        left: node_indices().map(|node| convert!(left(node))).collect(),
        right: node_indices()
            .map(|node| convert!(right(node, leaves)))
            .collect(),
        parent: node_indices()
            .map(|node| convert!(parent(node, leaves)))
            .collect(),
        sibling: node_indices()
            .map(|node| convert!(sibling(node, leaves)))
            .collect(),
    }
}
/// Regenerates tree math test vectors for trees of 1 through 98 leaves and
/// writes them with `write` to `test_vectors/kat_treemath_openmls-new.json`.
#[test]
fn write_test_vectors() {
    let tests: Vec<TreeMathTestVector> = (1..99).map(generate_test_vector).collect();
    write("test_vectors/kat_treemath_openmls-new.json", &tests);
}
/// Checks a tree math test vector against this crate's tree math functions.
///
/// The vector must contain:
/// - one `root` entry per tree size `1..=n_leaves`, and
/// - one `left`/`right`/`parent`/`sibling` entry per node index of a tree with
///   `n_leaves` leaves, where `None` means the computation is expected to fail.
///
/// # Errors
///
/// Returns the [`TmTestVectorError`] variant of the first check that fails:
/// - `TreeSizeMismatch` if `n_nodes` is not the node width for `n_leaves`;
/// - otherwise the `*IndexMismatch` variant of the first array whose entry
///   differs from the computed value.
///
/// An array too short to hold an expected entry is reported as a mismatch for
/// that array rather than panicking on an out-of-bounds index. Extra trailing
/// entries are ignored.
#[cfg(any(feature = "test-utils", test))]
pub fn run_test_vector(test_vector: TreeMathTestVector) -> Result<(), TmTestVectorError> {
    let n_leaves = test_vector.n_leaves as usize;
    let n_nodes = node_width(n_leaves);
    let leaves = LeafIndex::from(n_leaves);
    if test_vector.n_nodes != node_width(leaves.as_usize()) as u32 {
        return Err(TmTestVectorError::TreeSizeMismatch);
    }

    for i in 0..n_leaves {
        // `root[i]` is the root of a tree with `i + 1` leaves.
        let expected = root(LeafIndex::from(i + 1)).as_u32();
        // `get` keeps a truncated vector from panicking; a missing entry is a mismatch.
        if test_vector.root.get(i) != Some(&expected) {
            return Err(TmTestVectorError::RootIndexMismatch);
        }
    }

    for i in 0..n_nodes {
        if test_vector.left.get(i) != Some(&convert!(left(NodeIndex::from(i)))) {
            return Err(TmTestVectorError::LeftIndexMismatch);
        }
        if test_vector.right.get(i) != Some(&convert!(right(NodeIndex::from(i), leaves))) {
            return Err(TmTestVectorError::RightIndexMismatch);
        }
        if test_vector.parent.get(i) != Some(&convert!(parent(NodeIndex::from(i), leaves))) {
            return Err(TmTestVectorError::ParentIndexMismatch);
        }
        if test_vector.sibling.get(i) != Some(&convert!(sibling(NodeIndex::from(i), leaves))) {
            return Err(TmTestVectorError::SiblingIndexMismatch);
        }
    }
    Ok(())
}
/// Checks the bundled tree math known-answer vectors.
///
/// It checks every vector in `test_vectors/kat_treemath_openmls.json`, then the
/// single vector in `test_vectors/mlspp/mlspp_treemath.json`. It panics on the
/// first failing vector.
#[test]
fn read_test_vectors() {
    let tests: Vec<TreeMathTestVector> = read("test_vectors/kat_treemath_openmls.json");
    for test_vector in tests {
        if let Err(e) = run_test_vector(test_vector) {
            panic!("Error while checking tree math test vector.\n{:?}", e);
        }
    }

    let tv: TreeMathTestVector = read("test_vectors/mlspp/mlspp_treemath.json");
    run_test_vector(tv).expect("Error while checking tree math test vector.");
}
/// Reason a [`TreeMathTestVector`] failed verification in `run_test_vector`.
// NOTE(review): `Error` (the derive with `#[error(...)]`) has no import among
// this file's `use` lines except the `cfg(test)`-only `crate::test_utils::*`
// glob. Confirm that it resolves when built with `feature = "test-utils"`
// outside of tests.
#[cfg(any(feature = "test-utils", test))]
#[derive(Error, Debug, PartialEq, Clone)]
pub enum TmTestVectorError {
    /// The vector's `n_nodes` differs from `node_width(n_leaves)`.
    #[error("The computed tree size doesn't match the one in the test vector.")]
    TreeSizeMismatch,
    /// A `root` entry differs from the computed root for that tree size.
    #[error("The computed root index doesn't match the one in the test vector.")]
    RootIndexMismatch,
    /// A `left` entry differs from the computed left child.
    #[error("A computed left child index doesn't match the one in the test vector.")]
    LeftIndexMismatch,
    /// A `right` entry differs from the computed right child.
    #[error("A computed right child index doesn't match the one in the test vector.")]
    RightIndexMismatch,
    /// A `parent` entry differs from the computed parent.
    #[error("A computed parent index doesn't match the one in the test vector.")]
    ParentIndexMismatch,
    /// A `sibling` entry differs from the computed sibling.
    #[error("A computed sibling index doesn't match the one in the test vector.")]
    SiblingIndexMismatch,
}