#![allow(unused_qualifications)]
#[cfg(not(any(feature = "hashbrown")))]
use std::collections::HashMap;
use std::collections::VecDeque;
use std::convert::TryFrom;
use std::path::Path;
use crate::Array;
#[cfg(feature = "hashbrown")]
use hashbrown::HashMap;
use crate::traits::{
Branch, Data, Database, Decode, Encode, Exception, Hasher, Leaf, Node, NodeVariant,
};
use crate::utils::tree_cell::TreeCell;
use crate::utils::tree_ref::TreeRef;
use crate::utils::tree_utils::{
calc_min_split_index, check_descendants, choose_zero, generate_leaf_map,
generate_tree_ref_queue, split_pairs,
};
/// Result type returned by every fallible tree operation in this module; failures are reported
/// as an [`Exception`].
pub type BinaryMerkleTreeResult<T> = std::result::Result<T, Exception>;
/// Bundles the concrete types a [`MerkleBIT`] is built from.
///
/// `N` is the size parameter of the `Array<N>` type used for both keys and node locations.
pub trait MerkleTree<const N: usize> {
    /// Node wrapper; its branch, leaf and data types must be the ones chosen below.
    type Node: Node<N, Branch = Self::Branch, Leaf = Self::Leaf, Data = Self::Data>;
    /// Storage backend that holds `Self::Node` values keyed by location.
    type Database: Database<N, Self::Node>;
    /// Internal node type with a `zero` and a `one` child.
    type Branch: Branch<N>;
    /// Leaf node type (a key plus the location of its data node).
    type Leaf: Leaf<N>;
    /// Data node type holding an encoded value.
    type Data: Data;
    /// Hash function used to compute node locations.
    type Hasher: Hasher<N>;
    /// User value type; encoded before storage and decoded on reads.
    type Value: Decode + Encode;
}
/// A binary Merkle tree whose nodes are stored in `M::Database`.
pub struct MerkleBIT<M, const N: usize>
where
    M: MerkleTree<N>,
{
    /// Node storage backing every tree root handled by this instance.
    db: M::Database,
    /// Maximum depth a traversal may reach before failing with a depth error.
    depth: usize,
}
impl<M: MerkleTree<N>, const N: usize> MerkleBIT<M, N> {
#[inline]
pub fn new(path: &Path, depth: usize) -> BinaryMerkleTreeResult<Self> {
let db = Database::open(path)?;
Ok(Self { db, depth })
}
#[inline]
pub const fn from_db(db: M::Database, depth: usize) -> BinaryMerkleTreeResult<Self> {
Ok(Self { db, depth })
}
#[inline]
pub fn get(
&self,
root_hash: &Array<N>,
keys: &mut [Array<N>],
) -> BinaryMerkleTreeResult<HashMap<Array<N>, Option<M::Value>>> {
if keys.is_empty() {
return Ok(HashMap::new());
}
let mut leaf_map = generate_leaf_map(keys);
keys.sort_unstable();
let root_node = if let Some(n) = self.db.get_node(*root_hash)? {
n
} else {
return Ok(leaf_map);
};
let mut cell_queue = VecDeque::with_capacity(keys.len());
let root_cell =
TreeCell::new::<M::Branch, M::Leaf, M::Data>(*root_hash, keys, root_node, 0);
cell_queue.push_front(root_cell);
while let Some(tree_cell) = cell_queue.pop_front() {
if tree_cell.depth > self.depth {
return Err(Exception::new("Depth of merkle tree exceeded"));
}
let node = tree_cell.node;
match node.get_variant() {
NodeVariant::Branch(branch) => {
let (_, zero, one, branch_split_index, branch_key) = branch.decompose();
let min_split_index = calc_min_split_index(tree_cell.keys, &branch_key)?;
let descendants = check_descendants(
tree_cell.keys,
branch_split_index,
&branch_key,
min_split_index,
)?;
if descendants.is_empty() {
continue;
}
let (zeros, ones) = split_pairs(descendants, branch_split_index)?;
self.push_cell_if_node(&mut cell_queue, tree_cell.depth, one, ones)?;
self.push_cell_if_node(&mut cell_queue, tree_cell.depth, zero, zeros)?;
}
NodeVariant::Leaf(n) => {
if let Some(d) = self.db.get_node(*n.get_data())? {
if let NodeVariant::Data(data) = d.get_variant() {
let value = M::Value::decode(data.get_value())?;
if let Ok(index) = keys.binary_search(n.get_key()) {
leaf_map.insert(keys[index], Some(value));
}
} else {
return Err(Exception::new(
"Corrupt merkle tree: Found non data node after leaf",
));
}
} else {
return Err(Exception::new(
"Corrupt merkle tree: Failed to get leaf node from DB",
));
}
}
NodeVariant::Data(_) => {
return Err(Exception::new(
"Corrupt merkle tree: Found data node while traversing tree",
));
}
}
}
Ok(leaf_map)
}
fn push_cell_if_node<'keys>(
&self,
cell_queue: &mut VecDeque<TreeCell<'keys, M::Node, N>>,
depth: usize,
location: Array<N>,
locations: &'keys [Array<N>],
) -> BinaryMerkleTreeResult<()> {
if let Some(node) = self.db.get_node(location)? {
if !locations.is_empty() {
let new_cell = TreeCell::new::<M::Branch, M::Leaf, M::Data>(
location,
locations,
node,
depth + 1,
);
cell_queue.push_front(new_cell);
}
}
Ok(())
}
#[inline]
pub fn insert(
&mut self,
previous_root: Option<&Array<N>>,
keys: &mut [Array<N>],
values: &[M::Value],
) -> BinaryMerkleTreeResult<Array<N>> {
if keys.len() != values.len() {
return Err(Exception::new("Keys and values have different lengths"));
}
if keys.is_empty() || values.is_empty() {
return Err(Exception::new("Keys or values are empty"));
}
let mut value_map = HashMap::new();
for (&key, value) in keys.iter().zip(values.iter()) {
value_map.insert(key, value);
}
keys.sort_unstable();
let nodes = self.insert_leaves(keys, &value_map)?;
let mut tree_refs = Vec::with_capacity(keys.len());
let mut key_map = HashMap::new();
for (loc, &key) in nodes.into_iter().zip(keys.iter()) {
key_map.insert(key, loc);
let tree_ref = TreeRef::new(key, loc, 1, 1);
tree_refs.push(tree_ref);
}
if let Some(root) = previous_root {
let mut proof_nodes = self.generate_treerefs(root, keys, &key_map)?;
tree_refs.append(&mut proof_nodes);
}
let new_root = self.create_tree(tree_refs)?;
Ok(new_root)
}
fn generate_treerefs(
&mut self,
root: &Array<N>,
keys: &mut [Array<N>],
key_map: &HashMap<Array<N>, Array<N>>,
) -> BinaryMerkleTreeResult<Vec<TreeRef<N>>> {
let mut proof_nodes = Vec::with_capacity(keys.len());
let root_node = if let Some(m) = self.db.get_node(*root)? {
m
} else {
return Err(Exception::new("Could not find root"));
};
let mut cell_queue = VecDeque::with_capacity(keys.len());
let root_cell: TreeCell<M::Node, N> =
TreeCell::new::<M::Branch, M::Leaf, M::Data>(*root, keys, root_node, 0);
cell_queue.push_front(root_cell);
self.traverse_tree(key_map, &mut proof_nodes, &mut cell_queue)?;
Ok(proof_nodes)
}
fn traverse_tree(
&mut self,
key_map: &HashMap<Array<N>, Array<N>>,
proof_nodes: &mut Vec<TreeRef<N>>,
cell_queue: &mut VecDeque<TreeCell<M::Node, N>>,
) -> BinaryMerkleTreeResult<()> {
while let Some(tree_cell) = cell_queue.pop_front() {
if tree_cell.depth > self.depth {
return Err(Exception::new("Depth of merkle tree exceeded"));
}
let node = tree_cell.node;
let depth = tree_cell.depth;
let location = tree_cell.location;
let mut refs = node.get_references();
let branch = match node.get_variant() {
NodeVariant::Branch(n) => n,
NodeVariant::Leaf(n) => {
let key = n.get_key();
let mut update = false;
if let Some(loc) = key_map.get(key) {
update = loc == &location;
if !update {
continue;
}
}
self.insert_leaf(&location)?;
if update {
continue;
}
let tree_ref = TreeRef::new(*key, location, 1, 1);
proof_nodes.push(tree_ref);
continue;
}
NodeVariant::Data(_) => {
return Err(Exception::new(
"Corrupt merkle tree: Found data node while traversing tree",
));
}
};
let (branch_count, branch_zero, branch_one, branch_split_index, branch_key) =
branch.decompose();
let min_split_index = calc_min_split_index(tree_cell.keys, &branch_key)?;
let mut descendants = tree_cell.keys;
if min_split_index < branch_split_index {
descendants = check_descendants(
tree_cell.keys,
branch_split_index,
&branch_key,
min_split_index,
)?;
if descendants.is_empty() {
let mut new_branch = M::Branch::new();
new_branch.set_count(branch_count);
new_branch.set_zero(branch_zero);
new_branch.set_one(branch_one);
new_branch.set_split_index(branch_split_index);
new_branch.set_key(branch_key);
let tree_ref = TreeRef::new(branch_key, tree_cell.location, branch_count, 1);
refs += 1;
let mut new_node = M::Node::new(NodeVariant::Branch(new_branch));
new_node.set_references(refs);
self.db.insert(tree_ref.location, new_node)?;
proof_nodes.push(tree_ref);
continue;
}
}
let (zeros, ones) = split_pairs(descendants, branch_split_index)?;
{
match self.split_nodes(depth, branch_one, ones)? {
SplitNodeType::Ref(tree_ref) => proof_nodes.push(tree_ref),
SplitNodeType::Cell(cell) => cell_queue.push_front(cell),
}
}
{
match self.split_nodes(depth, branch_zero, zeros)? {
SplitNodeType::Ref(tree_ref) => proof_nodes.push(tree_ref),
SplitNodeType::Cell(cell) => cell_queue.push_front(cell),
}
}
}
Ok(())
}
fn insert_leaf(&mut self, location: &Array<N>) -> BinaryMerkleTreeResult<()> {
if let Some(mut l) = self.db.get_node(*location)? {
let leaf_refs = l.get_references() + 1;
l.set_references(leaf_refs);
self.db.insert(*location, l)?;
return Ok(());
}
Err(Exception::new(
"Corrupt merkle tree: Failed to update leaf references",
))
}
fn split_nodes<'node_list>(
&mut self,
depth: usize,
branch: Array<N>,
node_list: &'node_list [Array<N>],
) -> Result<SplitNodeType<'node_list, M::Node, N>, Exception> {
if let Some(node) = self.db.get_node(branch)? {
return if node_list.is_empty() {
let other_key;
let count;
let refs = node.get_references() + 1;
let mut new_node;
match node.get_variant() {
NodeVariant::Branch(b) => {
count = b.get_count();
other_key = *b.get_key();
new_node = M::Node::new(NodeVariant::Branch(b));
}
NodeVariant::Leaf(l) => {
count = 1;
other_key = *l.get_key();
new_node = M::Node::new(NodeVariant::Leaf(l));
}
NodeVariant::Data(_) => {
return Err(Exception::new(
"Corrupt merkle tree: Found data node while traversing tree",
));
}
}
new_node.set_references(refs);
self.db.insert(branch, new_node)?;
let tree_ref = TreeRef::new(other_key, branch, count, 1);
Ok(SplitNodeType::Ref(tree_ref))
} else {
let new_cell = TreeCell::new::<M::Branch, M::Leaf, M::Data>(
branch,
node_list,
node,
depth + 1,
);
Ok(SplitNodeType::Cell(new_cell))
};
}
Err(Exception::new("Failed to find node in database."))
}
fn insert_leaves(
&mut self,
keys: &[Array<N>],
values: &HashMap<Array<N>, &M::Value>,
) -> BinaryMerkleTreeResult<Vec<Array<N>>> {
let mut nodes = Vec::with_capacity(keys.len());
for k in keys.iter() {
let key = k.as_ref();
let mut data = M::Data::new();
data.set_value(&(values[k].encode()?));
let mut data_hasher = M::Hasher::new(key.len());
data_hasher.update(b"d");
data_hasher.update(key);
data_hasher.update(data.get_value());
let data_node_location = data_hasher.finalize();
let mut data_node = M::Node::new(NodeVariant::Data(data));
data_node.set_references(1);
let mut leaf = M::Leaf::new();
leaf.set_data(data_node_location);
leaf.set_key(*k);
let mut leaf_hasher = M::Hasher::new(key.len());
leaf_hasher.update(b"l");
leaf_hasher.update(key.as_ref());
leaf_hasher.update(leaf.get_data().as_ref());
let leaf_node_location = leaf_hasher.finalize();
let mut leaf_node = M::Node::new(NodeVariant::Leaf(leaf));
leaf_node.set_references(1);
if let Some(n) = self.db.get_node(data_node_location)? {
let references = n.get_references() + 1;
data_node.set_references(references);
}
if let Some(n) = self.db.get_node(leaf_node_location)? {
let references = n.get_references() + 1;
leaf_node.set_references(references);
}
self.db.insert(data_node_location, data_node)?;
self.db.insert(leaf_node_location, leaf_node)?;
nodes.push(leaf_node_location);
}
Ok(nodes)
}
fn create_tree(&mut self, mut tree_refs: Vec<TreeRef<N>>) -> BinaryMerkleTreeResult<Array<N>> {
if tree_refs.is_empty() {
return Err(Exception::new("tree_refs should not be empty!"));
}
if tree_refs.len() == 1 {
self.db.batch_write()?;
let node = tree_refs.remove(0);
return Ok(node.location);
}
tree_refs.sort();
let mut tree_ref_queue = HashMap::new();
let unique_split_bits = generate_tree_ref_queue(&mut tree_refs, &mut tree_ref_queue)?;
let mut indices = unique_split_bits.into_iter().collect::<Vec<_>>();
indices.sort_unstable();
let mut root = None;
for i in indices.into_iter().rev() {
if let Some(level) = tree_ref_queue.remove(&i) {
root = self.merge_nodes(&mut tree_refs, level)?;
} else {
return Err(Exception::new("Level should not be empty."));
}
}
root.map_or_else(|| Err(Exception::new("Failed to get root.")), Ok)
}
fn merge_nodes(
&mut self,
tree_refs: &mut [TreeRef<N>],
level: Vec<(usize, usize, usize)>,
) -> BinaryMerkleTreeResult<Option<Array<N>>> {
#[cfg(feature = "serde")]
let mut root = Array::default();
#[cfg(not(any(feature = "serde")))]
let mut root = [0; N];
for (split_index, tree_ref_pointer, next_tree_ref_pointer) in level {
let mut branch = M::Branch::new();
let tree_ref_key = tree_refs[tree_ref_pointer].key;
let tree_ref_location = tree_refs[tree_ref_pointer].location;
let tree_ref_count = tree_refs[tree_ref_pointer].node_count;
let mut lookahead_count;
let mut lookahead_tree_ref_pointer: usize;
{
let mut count_ = tree_refs[next_tree_ref_pointer].count;
if count_ > 1 {
lookahead_tree_ref_pointer = tree_ref_pointer + usize::try_from(count_)?;
lookahead_count = tree_refs[lookahead_tree_ref_pointer].count;
while lookahead_count > count_ {
count_ = lookahead_count;
lookahead_tree_ref_pointer = tree_ref_pointer + usize::try_from(count_)?;
lookahead_count = tree_refs[lookahead_tree_ref_pointer].count;
}
} else {
lookahead_count = count_;
lookahead_tree_ref_pointer = next_tree_ref_pointer;
}
}
let next_tree_ref_location = tree_refs[lookahead_tree_ref_pointer].location;
let count = tree_ref_count + tree_refs[lookahead_tree_ref_pointer].node_count;
let branch_node_location;
{
let mut branch_hasher = M::Hasher::new(root.len());
branch_hasher.update(b"b");
branch_hasher.update(&tree_ref_location[..]);
branch_hasher.update(&next_tree_ref_location[..]);
branch_node_location = branch_hasher.finalize();
branch.set_zero(tree_ref_location);
branch.set_one(next_tree_ref_location);
branch.set_count(count);
branch.set_split_index(split_index);
branch.set_key(tree_ref_key);
}
let mut branch_node = M::Node::new(NodeVariant::Branch(branch));
branch_node.set_references(1);
self.db.insert(branch_node_location, branch_node)?;
{
tree_refs[lookahead_tree_ref_pointer].key = tree_ref_key;
tree_refs[lookahead_tree_ref_pointer].location = branch_node_location;
tree_refs[lookahead_tree_ref_pointer].count =
lookahead_count + tree_refs[tree_ref_pointer].count;
tree_refs[lookahead_tree_ref_pointer].node_count = count;
tree_refs[tree_ref_pointer] = tree_refs[lookahead_tree_ref_pointer];
}
root = branch_node_location;
}
self.db.batch_write()?;
Ok(Some(root))
}
#[inline]
pub fn remove(&mut self, root_hash: &Array<N>) -> BinaryMerkleTreeResult<()> {
let mut nodes = VecDeque::with_capacity(128);
nodes.push_front(*root_hash);
while !nodes.is_empty() {
let node_location;
if let Some(location) = nodes.pop_front() {
node_location = location;
} else {
return Err(Exception::new("Nodes should not be empty."));
}
let node = if let Some(n) = self.db.get_node(node_location)? {
n
} else {
continue;
};
let mut refs = node.get_references();
refs = refs.saturating_sub(1);
let mut new_node;
match node.get_variant() {
NodeVariant::Branch(b) => {
if refs == 0 {
let zero = *b.get_zero();
let one = *b.get_one();
nodes.push_back(zero);
nodes.push_back(one);
self.db.remove(&node_location)?;
continue;
}
new_node = M::Node::new(NodeVariant::Branch(b));
}
NodeVariant::Leaf(l) => {
if refs == 0 {
let data = *l.get_data();
nodes.push_back(data);
self.db.remove(&node_location)?;
continue;
}
new_node = M::Node::new(NodeVariant::Leaf(l));
}
NodeVariant::Data(d) => {
if refs == 0 {
self.db.remove(&node_location)?;
continue;
}
new_node = M::Node::new(NodeVariant::Data(d));
}
}
new_node.set_references(refs);
self.db.insert(node_location, new_node)?;
}
self.db.batch_write()?;
Ok(())
}
#[inline]
pub fn generate_inclusion_proof(
&self,
root: &Array<N>,
key: Array<N>,
) -> BinaryMerkleTreeResult<Vec<(Array<N>, bool)>> {
let mut nodes = VecDeque::with_capacity(self.depth);
nodes.push_front(*root);
let mut proof = Vec::with_capacity(self.depth);
let mut found_leaf = false;
let mut depth = 0;
while let Some(location) = nodes.pop_front() {
if depth > self.depth {
return Err(Exception::new("Depth limit exceeded"));
}
depth += 1;
if let Some(node) = self.db.get_node(location)? {
match node.get_variant() {
NodeVariant::Branch(b) => {
if found_leaf {
return Err(Exception::new("Corrupt Merkle Tree"));
}
let index = b.get_split_index();
let b_key = b.get_key();
let min_split_index = calc_min_split_index(&[key], b_key)?;
let keys = &[key];
let descendants = check_descendants(keys, index, b_key, min_split_index)?;
if descendants.is_empty() {
return Err(Exception::new("Key not found in tree"));
}
if choose_zero(key, index)? {
proof.push((*b.get_one(), true));
nodes.push_back(*b.get_zero());
} else {
proof.push((*b.get_zero(), false));
nodes.push_back(*b.get_one());
}
}
NodeVariant::Leaf(l) => {
if found_leaf {
return Err(Exception::new("Corrupt Merkle Tree"));
}
if *l.get_key() != key {
return Err(Exception::new("Key not found in tree"));
}
let mut leaf_hasher = M::Hasher::new(location.len());
leaf_hasher.update(b"l");
leaf_hasher.update(&l.get_key()[..]);
leaf_hasher.update(&l.get_data()[..]);
let leaf_node_location = leaf_hasher.finalize();
proof.push((leaf_node_location, false));
nodes.push_back(*l.get_data());
found_leaf = true;
}
NodeVariant::Data(d) => {
if !found_leaf {
return Err(Exception::new("Corrupt Merkle Tree"));
}
let mut data_hasher = M::Hasher::new(location.len());
data_hasher.update(b"d");
data_hasher.update(&key[..]);
data_hasher.update(d.get_value());
let data_node_location = data_hasher.finalize();
proof.push((data_node_location, false));
}
}
} else {
return Err(Exception::new("Failed to find node"));
}
}
proof.reverse();
Ok(proof)
}
#[inline]
pub fn verify_inclusion_proof(
root: &Array<N>,
key: Array<N>,
value: &M::Value,
proof: &[(Array<N>, bool)],
) -> BinaryMerkleTreeResult<()> {
if proof.len() < 2 {
return Err(Exception::new("Proof is too short to be valid"));
}
let key_len = root.len();
let mut data_hasher = M::Hasher::new(key_len);
data_hasher.update(b"d");
data_hasher.update(&key[..]);
data_hasher.update(&value.encode()?);
let data_hash = data_hasher.finalize();
if data_hash != proof[0].0 {
return Err(Exception::new("Proof is invalid"));
}
let mut leaf_hasher = M::Hasher::new(key_len);
leaf_hasher.update(b"l");
leaf_hasher.update(&key[..]);
leaf_hasher.update(&data_hash[..]);
let leaf_hash = leaf_hasher.finalize();
if leaf_hash != proof[1].0 {
return Err(Exception::new("Proof is invalid"));
}
let mut current_hash = leaf_hash;
for item in proof.iter().skip(2) {
let mut branch_hasher = M::Hasher::new(key_len);
branch_hasher.update(b"b");
if item.1 {
branch_hasher.update(¤t_hash[..]);
branch_hasher.update(&item.0[..]);
} else {
branch_hasher.update(&item.0[..]);
branch_hasher.update(¤t_hash[..]);
}
let branch_hash = branch_hasher.finalize();
current_hash = branch_hash;
}
if *root != current_hash {
return Err(Exception::new("Proof is invalid"));
}
Ok(())
}
#[inline]
pub fn get_one(
&self,
root: &Array<N>,
key: &Array<N>,
) -> BinaryMerkleTreeResult<Option<M::Value>> {
let mut nodes = VecDeque::with_capacity(3);
nodes.push_front(*root);
let mut found_leaf = false;
let mut depth = 0;
while let Some(location) = nodes.pop_front() {
if depth > self.depth {
return Err(Exception::new("Depth limit exceeded"));
}
depth += 1;
if let Some(node) = self.db.get_node(location)? {
match node.get_variant() {
NodeVariant::Branch(b) => {
if found_leaf {
return Err(Exception::new("Corrupt Merkle Tree"));
}
let index = b.get_split_index();
let b_key = b.get_key();
let min_split_index = calc_min_split_index(&[*key], b_key)?;
let keys = &[*key];
let descendants = check_descendants(keys, index, b_key, min_split_index)?;
if descendants.is_empty() {
return Ok(None);
}
if choose_zero(*key, index)? {
nodes.push_back(*b.get_zero());
} else {
nodes.push_back(*b.get_one());
}
}
NodeVariant::Leaf(l) => {
if found_leaf {
return Err(Exception::new("Corrupt Merkle Tree"));
}
if l.get_key() != key {
return Ok(None);
}
found_leaf = true;
nodes.push_back(*l.get_data());
}
NodeVariant::Data(d) => {
if !found_leaf {
return Err(Exception::new("Corrupt Merkle Tree"));
}
let buffer = d.get_value();
let value = M::Value::decode(buffer)?;
return Ok(Some(value));
}
}
}
}
Ok(None)
}
#[inline]
pub fn insert_one(
&mut self,
previous_root: Option<&Array<N>>,
key: &Array<N>,
value: &M::Value,
) -> BinaryMerkleTreeResult<Array<N>> {
let mut value_map = HashMap::new();
value_map.insert(*key, value);
let leaf_location = self.insert_leaves(&[*key], &value_map)?[0];
let mut tree_refs = Vec::with_capacity(1);
let mut key_map = HashMap::new();
key_map.insert(*key, leaf_location);
let tree_ref = TreeRef::new(*key, leaf_location, 1, 1);
tree_refs.push(tree_ref);
if let Some(root) = previous_root {
let mut proof_nodes = self.generate_treerefs(root, &mut [*key], &key_map)?;
tree_refs.append(&mut proof_nodes);
}
let new_root = self.create_tree(tree_refs)?;
Ok(new_root)
}
#[inline]
pub fn decompose(self) -> (M::Database, usize) {
(self.db, self.depth)
}
}
/// What `split_nodes` decided for one child of a branch during insertion.
enum SplitNodeType<'keys, NodeType, const N: usize>
where
    NodeType: Node<N>,
{
    /// No keys route to the child: reuse the existing subtree as-is.
    Ref(TreeRef<N>),
    /// Keys still route to the child: keep traversing into it.
    Cell(TreeCell<'keys, NodeType, N>),
}
/// Unit tests for the bit-selection (`choose_zero`) and key-partitioning (`split_pairs`)
/// helpers used by the tree traversal.
#[allow(clippy::panic_in_result_fn)]
#[cfg(test)]
pub mod tests {
    use crate::utils::tree_utils::choose_zero;

    use super::*;

    const KEY_LEN: usize = 32;

    #[test]
    fn it_chooses_the_right_branch_easy() -> Result<(), Exception> {
        // 0x0F = 0b0000_1111: per these expectations, bit `i` is counted from the most
        // significant bit, and `choose_zero` is true when that bit is 0.
        let key = [0x0F_u8; KEY_LEN];
        for i in 0..8 {
            let expected_branch = i < 4;
            let branch = choose_zero(key.into(), i)?;
            assert_eq!(branch, expected_branch);
        }
        Ok(())
    }

    #[test]
    fn it_chooses_the_right_branch_medium() -> Result<(), Exception> {
        {
            // 0x55 = 0b0101_0101: even bit positions are 0.
            let key = [0x55; KEY_LEN];
            for i in 0..8 {
                let expected_branch = i % 2 == 0;
                let branch = choose_zero(key.into(), i)?;
                assert_eq!(branch, expected_branch);
            }
        }
        // 0xAA = 0b1010_1010: odd bit positions are 0.
        let key = [0xAA; KEY_LEN];
        for i in 0..8 {
            let expected_branch = i % 2 != 0;
            let branch = choose_zero(key.into(), i)?;
            assert_eq!(branch, expected_branch);
        }
        Ok(())
    }

    #[test]
    fn it_chooses_the_right_branch_hard() -> Result<(), Exception> {
        {
            // 0x68 = 0b0110_1000: bits 1, 2 and 4 are set.
            let key = [0x68; KEY_LEN];
            for i in 0..8 {
                let expected_branch = !(i == 1 || i == 2 || i == 4);
                let branch = choose_zero(key.into(), i)?;
                assert_eq!(branch, expected_branch);
            }
        }
        // 0xAB = 0b1010_1011: bits 0, 2, 4, 6 and 7 are set.
        let key = [0xAB; KEY_LEN];
        for i in 0..8 {
            let expected_branch = !(i == 0 || i == 2 || i == 4 || i == 6 || i == 7);
            let branch = choose_zero(key.into(), i)?;
            assert_eq!(branch, expected_branch);
        }
        Ok(())
    }

    #[test]
    fn it_splits_an_all_zeros_sorted_list_of_pairs() -> Result<(), Exception> {
        #[cfg(feature = "serde")]
        let zero_key = Array([0x00_u8; KEY_LEN]);
        #[cfg(not(any(feature = "serde")))]
        let zero_key = [0x00_u8; KEY_LEN];
        let keys = vec![
            zero_key, zero_key, zero_key, zero_key, zero_key, zero_key, zero_key, zero_key,
            zero_key, zero_key,
        ];
        let result = split_pairs(&keys, 0)?;
        assert_eq!(result.0.len(), 10);
        assert_eq!(result.1.len(), 0);
        for &res in result.0 {
            #[cfg(feature = "serde")]
            assert_eq!(res, [0x00_u8; KEY_LEN].into());
            #[cfg(not(any(feature = "serde")))]
            assert_eq!(res, [0x00_u8; KEY_LEN]);
        }
        Ok(())
    }

    #[test]
    fn it_splits_an_all_ones_sorted_list_of_pairs() -> Result<(), Exception> {
        #[cfg(feature = "serde")]
        let one_key = Array([0xFF_u8; KEY_LEN]);
        #[cfg(not(any(feature = "serde")))]
        let one_key = [0xFF_u8; KEY_LEN];
        let keys = vec![
            one_key, one_key, one_key, one_key, one_key, one_key, one_key, one_key, one_key,
            one_key,
        ];
        let result = split_pairs(&keys, 0)?;
        assert_eq!(result.0.len(), 0);
        assert_eq!(result.1.len(), 10);
        for &res in result.1 {
            #[cfg(feature = "serde")]
            assert_eq!(res, [0xFF_u8; KEY_LEN].into());
            #[cfg(not(any(feature = "serde")))]
            assert_eq!(res, [0xFF_u8; KEY_LEN]);
        }
        Ok(())
    }

    #[test]
    fn it_splits_an_even_length_sorted_list_of_pairs() -> Result<(), Exception> {
        #[cfg(feature = "serde")]
        let zero_key = Array([0x00_u8; KEY_LEN]);
        #[cfg(not(any(feature = "serde")))]
        let zero_key = [0x00_u8; KEY_LEN];
        #[cfg(feature = "serde")]
        let one_key = Array([0xFF_u8; KEY_LEN]);
        #[cfg(not(any(feature = "serde")))]
        let one_key = [0xFF_u8; KEY_LEN];
        let keys = vec![
            zero_key, zero_key, zero_key, zero_key, zero_key, one_key, one_key, one_key, one_key,
            one_key,
        ];
        let result = split_pairs(&keys, 0)?;
        assert_eq!(result.0.len(), 5);
        assert_eq!(result.1.len(), 5);
        for &res in result.0 {
            #[cfg(feature = "serde")]
            assert_eq!(res, [0x00_u8; KEY_LEN].into());
            #[cfg(not(any(feature = "serde")))]
            assert_eq!(res, [0x00_u8; KEY_LEN]);
        }
        for &res in result.1 {
            #[cfg(feature = "serde")]
            assert_eq!(res, [0xFF_u8; KEY_LEN].into());
            #[cfg(not(any(feature = "serde")))]
            assert_eq!(res, [0xFF_u8; KEY_LEN]);
        }
        Ok(())
    }

    #[test]
    fn it_splits_an_odd_length_sorted_list_of_pairs_with_more_zeros() -> Result<(), Exception> {
        #[cfg(feature = "serde")]
        let zero_key = Array([0x00_u8; KEY_LEN]);
        #[cfg(not(any(feature = "serde")))]
        let zero_key = [0x00_u8; KEY_LEN];
        #[cfg(feature = "serde")]
        let one_key = Array([0xFF_u8; KEY_LEN]);
        #[cfg(not(any(feature = "serde")))]
        let one_key = [0xFF_u8; KEY_LEN];
        let keys = vec![
            zero_key, zero_key, zero_key, zero_key, zero_key, zero_key, one_key, one_key, one_key,
            one_key, one_key,
        ];
        let result = split_pairs(&keys, 0)?;
        assert_eq!(result.0.len(), 6);
        assert_eq!(result.1.len(), 5);
        // Check contents in both feature configurations, not only with `serde`.
        for &res in result.0 {
            #[cfg(feature = "serde")]
            assert_eq!(res, [0x00_u8; KEY_LEN].into());
            #[cfg(not(any(feature = "serde")))]
            assert_eq!(res, [0x00_u8; KEY_LEN]);
        }
        for &res in result.1 {
            #[cfg(feature = "serde")]
            assert_eq!(res, [0xFF_u8; KEY_LEN].into());
            #[cfg(not(any(feature = "serde")))]
            assert_eq!(res, [0xFF_u8; KEY_LEN]);
        }
        Ok(())
    }

    #[test]
    fn it_splits_an_odd_length_sorted_list_of_pairs_with_more_ones() -> Result<(), Exception> {
        #[cfg(feature = "serde")]
        let zero_key = Array([0x00_u8; KEY_LEN]);
        #[cfg(not(any(feature = "serde")))]
        let zero_key = [0x00_u8; KEY_LEN];
        #[cfg(feature = "serde")]
        let one_key = Array([0xFF_u8; KEY_LEN]);
        #[cfg(not(any(feature = "serde")))]
        let one_key = [0xFF_u8; KEY_LEN];
        let keys = vec![
            zero_key, zero_key, zero_key, zero_key, zero_key, one_key, one_key, one_key, one_key,
            one_key, one_key,
        ];
        let result = split_pairs(&keys, 0)?;
        assert_eq!(result.0.len(), 5);
        assert_eq!(result.1.len(), 6);
        // Check contents in both feature configurations, not only with `serde`.
        for &res in result.0 {
            #[cfg(feature = "serde")]
            assert_eq!(res, [0x00_u8; KEY_LEN].into());
            #[cfg(not(any(feature = "serde")))]
            assert_eq!(res, [0x00_u8; KEY_LEN]);
        }
        for &res in result.1 {
            #[cfg(feature = "serde")]
            assert_eq!(res, [0xFF_u8; KEY_LEN].into());
            #[cfg(not(any(feature = "serde")))]
            assert_eq!(res, [0xFF_u8; KEY_LEN]);
        }
        Ok(())
    }
}