use crate::error::WSError;
use sha2::{Digest, Sha256};
/// Domain-separation byte prepended to leaf data before hashing, so a leaf
/// hash can never be confused with an interior-node hash (RFC 6962 style).
const LEAF_PREFIX: u8 = 0;
/// Domain-separation byte prepended to a pair of child hashes when hashing an
/// interior node (RFC 6962 style).
const NODE_PREFIX: u8 = 1;
/// Hashes a Merkle tree leaf.
///
/// Returns `SHA-256(LEAF_PREFIX || data)`. The leading prefix byte separates
/// leaf hashes from interior-node hashes.
pub fn compute_leaf_hash(data: &[u8]) -> [u8; 32] {
    let mut sha = Sha256::new();
    // Feed the domain-separation prefix first, then the raw leaf bytes.
    for part in [&[LEAF_PREFIX][..], data] {
        sha.update(part);
    }
    sha.finalize().into()
}
/// Hashes an interior Merkle tree node from its two child hashes.
///
/// Returns `SHA-256(NODE_PREFIX || left || right)`. Argument order matters:
/// swapping `left` and `right` yields a different hash.
pub fn compute_node_hash(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] {
    let mut sha = Sha256::new();
    // Prefix byte, then the left child, then the right child.
    for part in [&[NODE_PREFIX][..], &left[..], &right[..]] {
        sha.update(part);
    }
    sha.finalize().into()
}
/// Verifies a Merkle tree inclusion proof (RFC 6962 §2.1.1 audit path).
///
/// Recomputes the tree root from `leaf_hash` and the sibling hashes in
/// `proof_hashes`, then checks that it equals `expected_root`.
///
/// # Arguments
/// * `leaf_index` - zero-based index of the leaf within the tree.
/// * `tree_size` - number of leaves in the tree the proof was produced for.
/// * `leaf_hash` - hash of the leaf (expected to be the output of
///   [`compute_leaf_hash`]).
/// * `proof_hashes` - the audit path, ordered from the leaf's immediate sibling
///   up to the sibling just below the root (RFC 6962 / RFC 9162 order).
/// * `expected_root` - the trusted root hash of a tree with `tree_size` leaves.
///
/// # Errors
/// Returns [`WSError::RekorError`] in any of these cases:
/// * `tree_size` is zero.
/// * `leaf_index` is not below `tree_size`.
/// * The number of proof hashes does not match the depth of `leaf_index` in a
///   tree of `tree_size` leaves.
/// * The recomputed root differs from `expected_root`.
pub fn verify_inclusion_proof(
    leaf_index: u64,
    tree_size: u64,
    leaf_hash: &[u8; 32],
    proof_hashes: &[[u8; 32]],
    expected_root: &[u8; 32],
) -> Result<(), WSError> {
    if tree_size == 0 {
        return Err(WSError::RekorError("Tree size cannot be zero".to_string()));
    }
    if leaf_index >= tree_size {
        return Err(WSError::RekorError(format!(
            "Leaf index {} is out of range for tree size {}",
            leaf_index, tree_size
        )));
    }

    // Descend from the root towards the leaf. At each split (k = largest power
    // of two below the subtree size, as in RFC 6962 §2.1), record whether the
    // leaf lies in the right-hand subtree. A u64-sized tree has at most 64
    // levels, so this is also the exact audit-path length for this leaf.
    let mut leaf_in_right: Vec<bool> = Vec::with_capacity(64);
    let (mut index, mut size) = (leaf_index, tree_size);
    while size > 1 {
        let split = largest_power_of_two_less_than(size);
        if index < split {
            leaf_in_right.push(false);
            size = split;
        } else {
            leaf_in_right.push(true);
            index -= split;
            size -= split;
        }
    }

    if proof_hashes.len() != leaf_in_right.len() {
        return Err(WSError::RekorError(format!(
            "Inclusion proof has {} hashes, expected {} for leaf index {} in tree size {}",
            proof_hashes.len(),
            leaf_in_right.len(),
            leaf_index,
            tree_size
        )));
    }

    // The audit path runs leaf-to-root, while `leaf_in_right` was recorded
    // root-to-leaf. Pair proof_hashes[0] with the deepest split by iterating
    // the directions in reverse.
    let computed_root = proof_hashes
        .iter()
        .zip(leaf_in_right.iter().rev())
        .fold(*leaf_hash, |acc, (sibling, &in_right)| {
            if in_right {
                compute_node_hash(sibling, &acc)
            } else {
                compute_node_hash(&acc, sibling)
            }
        });

    if &computed_root != expected_root {
        return Err(WSError::RekorError(format!(
            "Computed root hash does not match expected root. Computed: {}, Expected: {}",
            hex::encode(computed_root),
            hex::encode(expected_root)
        )));
    }
    Ok(())
}
/// Returns the largest power of two strictly less than `n`, for `n >= 2`.
///
/// This is the split point `k` that divides a tree of `n` leaves into a
/// complete left subtree of `k` leaves and a right subtree of `n - k` leaves
/// (RFC 6962 §2.1).
///
/// For `n <= 1` no such power exists, and `n` is returned unchanged
/// (`0 -> 0`, `1 -> 1`).
///
/// The result is computed from the bit length of `n - 1`, so it is O(1).
/// It is defined for every `u64`, unlike a doubling loop: `power * 2`
/// overflows once `n > 2^63`.
fn largest_power_of_two_less_than(n: u64) -> u64 {
    if n <= 1 {
        return n;
    }
    // For n >= 2, the highest set bit of n - 1 is the largest power of two
    // strictly below n. (n - 1) >= 1, so leading_zeros() <= 63 and the shift
    // amount stays within 0..=63.
    1u64 << (63 - (n - 1).leading_zeros())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference RFC 6962 Merkle Tree Hash over a non-empty list of leaf
    /// hashes, splitting at the largest power of two below the leaf count.
    fn merkle_root(leaf_hashes: &[[u8; 32]]) -> [u8; 32] {
        match leaf_hashes {
            [] => panic!("merkle_root requires at least one leaf"),
            [only] => *only,
            _ => {
                let split = largest_power_of_two_less_than(leaf_hashes.len() as u64) as usize;
                compute_node_hash(
                    &merkle_root(&leaf_hashes[..split]),
                    &merkle_root(&leaf_hashes[split..]),
                )
            }
        }
    }

    #[test]
    fn test_leaf_hash_computation() {
        let data = b"test data";
        let hash = compute_leaf_hash(data);
        let mut expected = Sha256::new();
        expected.update([0x00]);
        expected.update(data);
        let expected_hash: [u8; 32] = expected.finalize().into();
        assert_eq!(hash, expected_hash);
    }

    #[test]
    fn test_node_hash_computation() {
        let left = [1u8; 32];
        let right = [2u8; 32];
        let hash = compute_node_hash(&left, &right);
        let mut expected = Sha256::new();
        expected.update([0x01]);
        expected.update(left);
        expected.update(right);
        let expected_hash: [u8; 32] = expected.finalize().into();
        assert_eq!(hash, expected_hash);
    }

    #[test]
    fn test_single_leaf_tree() {
        let leaf_hash = [0x42u8; 32];
        let result = verify_inclusion_proof(0, 1, &leaf_hash, &[], &leaf_hash);
        assert!(result.is_ok());
    }

    #[test]
    fn test_single_leaf_tree_wrong_root() {
        let leaf_hash = [0x42u8; 32];
        let wrong_root = [0x43u8; 32];
        let result = verify_inclusion_proof(0, 1, &leaf_hash, &[], &wrong_root);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_leaf_index() {
        let leaf_hash = [0x42u8; 32];
        let result = verify_inclusion_proof(5, 3, &leaf_hash, &[], &leaf_hash);
        assert!(result.is_err());
    }

    #[test]
    fn test_largest_power_of_two() {
        assert_eq!(largest_power_of_two_less_than(1), 1);
        assert_eq!(largest_power_of_two_less_than(2), 1);
        assert_eq!(largest_power_of_two_less_than(3), 2);
        assert_eq!(largest_power_of_two_less_than(4), 2);
        assert_eq!(largest_power_of_two_less_than(5), 4);
        assert_eq!(largest_power_of_two_less_than(7), 4);
        assert_eq!(largest_power_of_two_less_than(8), 4);
        assert_eq!(largest_power_of_two_less_than(9), 8);
        assert_eq!(largest_power_of_two_less_than(15), 8);
        assert_eq!(largest_power_of_two_less_than(16), 8);
        assert_eq!(largest_power_of_two_less_than(17), 16);
    }

    #[test]
    fn test_two_leaf_tree() {
        let leaf0_data = b"leaf0";
        let leaf1_data = b"leaf1";
        let leaf0_hash = compute_leaf_hash(leaf0_data);
        let leaf1_hash = compute_leaf_hash(leaf1_data);
        let root = compute_node_hash(&leaf0_hash, &leaf1_hash);
        let result = verify_inclusion_proof(0, 2, &leaf0_hash, &[leaf1_hash], &root);
        assert!(result.is_ok(), "Failed to verify leaf 0");
        let result = verify_inclusion_proof(1, 2, &leaf1_hash, &[leaf0_hash], &root);
        assert!(result.is_ok(), "Failed to verify leaf 1");
    }

    /// Checks the root of every prefix (tree sizes 1 through 8) of the
    /// RFC 6962 reference (Google CT) test vectors. Also verifies one
    /// inclusion proof against the 3-leaf root.
    #[test]
    fn test_google_ct_test_vectors() {
        let inputs: [&[u8]; 8] = [
            &[],
            &[0x00],
            &[0x10],
            &[0x20, 0x21],
            &[0x30, 0x31],
            &[0x40, 0x41, 0x42, 0x43],
            &[0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57],
            &[
                0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c,
                0x6d, 0x6e, 0x6f,
            ],
        ];
        // expected_roots[i] is the root of the tree made of inputs[0..=i].
        let expected_roots = [
            "6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d",
            "fac54203e7cc696cf0dfcb42c92a1d9dbaf70ad9e621f4bd8d98662f00e3c125",
            "aeb6bcfe274b70a14fb067a5e5578264db0fa9b51af5e0ba159158f329e06e77",
            "d37ee418976dd95753c1c73862b9398fa2a2cf9b4ff0fdfe8b30cd95209614b7",
            "4e3bbb1f7b478dcfe71fb631631519a3bca12c9aefca1612bfce4c13a86264d4",
            "76e67dadbcdf1e10e1b74ddc608abd2f98dfb16fbce75277b5232a127f2087ef",
            "ddb89be403809e325750d3d263cd78929c2942b7942a34b77e122c9594a74c8c",
            "5dc9da79a70659a9ad559cb701ded9a2ab9d823aad2f4960cfe370eff4604328",
        ];
        let leaf_hashes: Vec<[u8; 32]> = inputs
            .iter()
            .map(|data| compute_leaf_hash(data))
            .collect();

        for (size, expected_hex) in (1..=leaf_hashes.len()).zip(expected_roots.iter()) {
            let expected = hex::decode(expected_hex).unwrap();
            let root = merkle_root(&leaf_hashes[..size]);
            assert_eq!(
                &root[..],
                &expected[..],
                "Root mismatch for tree size {}",
                size
            );
        }

        let root3 = merkle_root(&leaf_hashes[..3]);
        let proof = [leaf_hashes[1], leaf_hashes[2]];
        let result = verify_inclusion_proof(0, 3, &leaf_hashes[0], &proof, &root3);
        assert!(
            result.is_ok(),
            "Failed to verify leaf 0 in 3-leaf tree: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_empty_leaf_hash() {
        let empty_leaf_hash = compute_leaf_hash(&[]);
        let expected =
            hex::decode("6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d")
                .unwrap();
        assert_eq!(
            &empty_leaf_hash[..],
            &expected[..],
            "Empty leaf hash mismatch. Expected SHA-256(0x00 || '')"
        );
    }
}
// Kani model-checking harnesses for the Merkle helpers in this file.
// Compiled only when the `kani` cfg is set.
#[cfg(kani)]
mod proofs {
    use super::*;

    // For every n in 2..=1024, checks that the result is a power of two,
    // strictly less than n, and that doubling it reaches at least n (i.e. it
    // is the largest such power).
    // `unwind(65)` bounds loop unrolling during verification.
    // NOTE(review): the `n <= 1024` assumption limits the checked input range,
    // presumably to keep verification tractable; values above 1024 are not
    // covered by this harness.
    #[kani::proof]
    #[kani::unwind(65)]
    fn proof_largest_power_of_two_properties() {
        let n: u64 = kani::any();
        kani::assume(n > 1);
        kani::assume(n <= 1024);
        let result = largest_power_of_two_less_than(n);
        assert!(result.is_power_of_two(), "Result {} is not a power of 2", result);
        assert!(result < n, "Result {} is not less than n={}", result, n);
        assert!(result * 2 >= n, "Result {} is not the largest power < n={}", result, n);
    }

    // Pins concrete values, including n = 0 and n = 1, which are returned
    // unchanged rather than being a power of two below n.
    #[kani::proof]
    fn proof_largest_power_of_two_edge_cases() {
        assert_eq!(largest_power_of_two_less_than(0), 0);
        assert_eq!(largest_power_of_two_less_than(1), 1);
        assert_eq!(largest_power_of_two_less_than(2), 1);
        assert_eq!(largest_power_of_two_less_than(3), 2);
    }

    // Checks that hashing the same 64 bytes as a leaf and as an interior node
    // (split into two 32-byte halves) gives different results, because of the
    // differing prefix bytes.
    // Only data[0] and data[1] are symbolic; the remaining 62 bytes are zero.
    // NOTE(review): this asserts inequality of two SHA-256 outputs, which holds
    // by collision resistance rather than by construction; it may be expensive
    // for the solver.
    #[kani::proof]
    fn proof_leaf_node_domain_separation() {
        let mut data = [0u8; 64];
        data[0] = kani::any();
        data[1] = kani::any();
        let leaf_hash = compute_leaf_hash(&data);
        let left: [u8; 32] = data[0..32].try_into().unwrap();
        let right: [u8; 32] = data[32..64].try_into().unwrap();
        let node_hash = compute_node_hash(&left, &right);
        assert_ne!(leaf_hash, node_hash,
            "Leaf hash and node hash collided — domain separation broken");
    }

    // Checks that hashing the same two-byte symbolic leaf twice gives equal
    // results.
    #[kani::proof]
    fn proof_leaf_hash_deterministic() {
        let b0: u8 = kani::any();
        let b1: u8 = kani::any();
        let data = [b0, b1];
        let hash1 = compute_leaf_hash(&data);
        let hash2 = compute_leaf_hash(&data);
        assert_eq!(hash1, hash2);
    }

    // Checks that hashing the same pair of children twice gives equal results.
    // Note: `[kani::any(); 32]` evaluates `kani::any()` once and repeats that
    // value, so `l` and `r` are each arrays of 32 identical bytes. Only two
    // symbolic bytes are explored in total, not arbitrary 32-byte arrays.
    #[kani::proof]
    fn proof_node_hash_deterministic() {
        let l: [u8; 32] = [kani::any(); 32];
        let r: [u8; 32] = [kani::any(); 32];
        let hash1 = compute_node_hash(&l, &r);
        let hash2 = compute_node_hash(&l, &r);
        assert_eq!(hash1, hash2);
    }
}