use super::MerkleTreeOverlay;
use crate::error::{Error, Result};
use crate::field::{Composite, Node, Primitive};
use crate::tree_arithmetic::zeroed::{left_most_leaf, subtree_index_to_general};
use crate::tree_arithmetic::{log_base_two, next_power_of_two};
use crate::types::{FixedVector, VariableList};
use crate::{NodeIndex, Path, BYTES_PER_CHUNK};
use ethereum_types::U256;
use typenum::Unsigned;
/// Implements [`MerkleTreeOverlay`] for a fixed-width basic type.
///
/// `$bit_size` is the width of `$type` in **bits**. A basic value has height 0 and is
/// described by a single primitive at index 0, so the only valid path into it is the
/// empty path; any further path segment is rejected with `Error::InvalidPath`.
macro_rules! impl_merkle_overlay_for_basic_type {
    ($type: ident, $bit_size: expr) => {
        impl MerkleTreeOverlay for $type {
            fn height() -> u64 {
                0
            }

            /// Serialized width of the value in bytes.
            fn min_repr_size() -> u64 {
                ($bit_size / 8) as u64
            }

            fn get_node(path: Vec<Path>) -> Result<Node> {
                match path.first() {
                    None => Ok(Node::Primitive(vec![Primitive {
                        ident: "".to_string(),
                        index: 0,
                        // `size` is a byte count (see the collection leaves and tests), so
                        // derive it from `min_repr_size` rather than dividing bits by 32.
                        size: Self::min_repr_size() as u8,
                        offset: 0,
                    }])),
                    Some(first) => Err(Error::InvalidPath(first.clone())),
                }
            }
        }
    };
}
// The second argument is the width in bits.
impl_merkle_overlay_for_basic_type!(bool, 8);
impl_merkle_overlay_for_basic_type!(u8, 8);
impl_merkle_overlay_for_basic_type!(u16, 16);
impl_merkle_overlay_for_basic_type!(u32, 32);
impl_merkle_overlay_for_basic_type!(u64, 64);
impl_merkle_overlay_for_basic_type!(u128, 128);
impl_merkle_overlay_for_basic_type!(U256, 256);
// `size_of` reports bytes, but the macro expects bits.
impl_merkle_overlay_for_basic_type!(usize, std::mem::size_of::<usize>() * 8);
/// Implements [`MerkleTreeOverlay`] for a collection `$type<T, N>` holding up to `N`
/// elements of type `T`.
///
/// Elements are packed into `BYTES_PER_CHUNK`-byte leaf chunks according to
/// `T::min_repr_size()`. When `$is_variable_length` is true the tree gains one extra level,
/// and a `"len"` path resolves to a length node at general index 2.
macro_rules! impl_merkle_overlay_for_collection_type {
    ($type: ident, $is_variable_length: expr) => {
        impl<T: MerkleTreeOverlay, N: Unsigned> MerkleTreeOverlay for $type<T, N> {
            /// Height of the data subtree, plus one level for variable-length collections.
            fn height() -> u64 {
                // Number of `T` values packed into one chunk. Clamped to 1 so an element
                // wider than a chunk gets a leaf of its own instead of dividing by zero.
                let items_per_chunk = (BYTES_PER_CHUNK as u64 / T::min_repr_size()).max(1);
                // Round up: a partially filled trailing chunk still needs its own leaf
                // (e.g. 33 `u8`s occupy 2 chunks, not 1).
                let num_chunks = (N::to_u64() + items_per_chunk - 1) / items_per_chunk;
                let data_tree_height = log_base_two(next_power_of_two(num_chunks));
                if $is_variable_length {
                    data_tree_height + 1
                } else {
                    data_tree_height
                }
            }

            /// A collection spanning more than one chunk is represented by a single
            /// 32-byte chunk; otherwise by its packed contents.
            fn min_repr_size() -> u64 {
                if Self::height() > 0 {
                    32
                } else {
                    T::min_repr_size() * N::to_u64()
                }
            }

            /// Resolves `path` to a node of this collection's tree.
            ///
            /// # Errors
            /// * `Error::EmptyPath` if `path` is empty.
            /// * `Error::IndexOutOfBounds` if an index is `>= N`.
            /// * `Error::InvalidPath` for an unknown identifier, or for any segment that
            ///   follows the `"len"` node.
            /// * Any error produced by `T::get_node` for the remaining path.
            fn get_node(path: Vec<Path>) -> Result<Node> {
                match path.first() {
                    Some(Path::Index(position)) => {
                        if *position >= N::to_u64() {
                            return Err(Error::IndexOutOfBounds(*position));
                        }
                        let first_leaf = left_most_leaf(0, Self::height());
                        // Must match the chunk arithmetic used by `height`.
                        let items_per_chunk =
                            (BYTES_PER_CHUNK as u64 / T::min_repr_size()).max(1);
                        let leaf_index = first_leaf + position / items_per_chunk;
                        if path.len() == 1 {
                            Ok(generate_leaf::<Self, T>(leaf_index))
                        } else {
                            // Resolve the rest of the path inside the element's subtree, then
                            // translate its index into this tree's general index.
                            let node = T::get_node(path[1..].to_vec())?;
                            let index = subtree_index_to_general(leaf_index, node.get_index());
                            Ok(replace_index(node, index))
                        }
                    }
                    Some(Path::Ident(i)) if $is_variable_length && i == "len" => {
                        match path.get(1) {
                            // The length is a single primitive; nothing lies beneath it.
                            Some(extra) => Err(Error::InvalidPath(extra.clone())),
                            None => Ok(Node::Length(Primitive {
                                ident: "len".to_string(),
                                index: 2,
                                size: 32,
                                offset: 0,
                            })),
                        }
                    }
                    Some(Path::Ident(_)) => Err(Error::InvalidPath(path[0].clone())),
                    None => Err(Error::EmptyPath()),
                }
            }
        }
    };
}
// Fixed-length vectors: the tree holds only the packed elements.
impl_merkle_overlay_for_collection_type!(FixedVector, false);
// Variable-length lists: one extra level, plus a `len` node.
impl_merkle_overlay_for_collection_type!(VariableList, true);
/// Builds the node for the leaf chunk at general index `index` of collection `S`, whose
/// elements have type `T`.
///
/// If `T::get_node` accepts the empty path (as the basic types in this module do), the
/// chunk holds several packed `T` values and a `Node::Primitive` with one entry per slot is
/// returned; each entry's `ident` is the element's position within the collection.
/// Otherwise the leaf holds a single `T` and a `Node::Composite` of height `T::height()`
/// is returned.
fn generate_leaf<S: MerkleTreeOverlay, T: MerkleTreeOverlay>(index: NodeIndex) -> Node {
    let first_leaf = left_most_leaf(0, S::height());
    // Zero-based position of this leaf among the leaves of `S`.
    let leaf_offset = index - first_leaf;
    match T::get_node(vec![]) {
        Ok(_) => {
            // Use the same per-item size as the collection's chunk arithmetic so the
            // identifiers produced here agree with the leaf `get_node` selected.
            let item_size = T::min_repr_size() as u8;
            let items_per_chunk = BYTES_PER_CHUNK as u8 / item_size;
            let values = (0..items_per_chunk)
                .map(|slot| Primitive {
                    ident: (leaf_offset * u64::from(items_per_chunk) + u64::from(slot))
                        .to_string(),
                    index,
                    size: item_size,
                    offset: slot * item_size,
                })
                .collect();
            Node::Primitive(values)
        }
        Err(_) => Node::Composite(Composite {
            ident: leaf_offset.to_string(),
            index,
            height: T::height(),
        }),
    }
}
/// Returns `node` with its general index replaced by `index`.
///
/// Only the index changes: identifiers, heights, sizes and offsets are carried over as-is.
/// For `Node::Primitive`, every packed value is re-pointed at `index`.
pub fn replace_index(node: Node, index: NodeIndex) -> Node {
    match node {
        Node::Composite(composite) => Node::Composite(Composite { index, ..composite }),
        // `node` is owned, so move the values instead of cloning each one.
        Node::Primitive(values) => Node::Primitive(
            values
                .into_iter()
                .map(|value| Primitive { index, ..value })
                .collect(),
        ),
        Node::Length(length) => Node::Length(Primitive { index, ..length }),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use typenum::{U1, U16, U2, U32, U4, U8};

    /// A single 32-byte primitive filling a whole chunk at general index `index`.
    fn build_node(ident: &str, index: u64) -> Node {
        Node::Primitive(vec![Primitive {
            ident: ident.to_string(),
            index,
            size: 32,
            offset: 0,
        }])
    }

    /// The expected `len` node of a variable-length list at general index `index`.
    fn length_node(index: u64) -> Node {
        Node::Length(Primitive {
            ident: "len".to_string(),
            index,
            size: 32,
            offset: 0,
        })
    }

    fn ident_path(ident: &str) -> Vec<Path> {
        vec![Path::Ident(ident.to_string())]
    }

    fn index_path(index: u64) -> Vec<Path> {
        vec![Path::Index(index)]
    }

    #[test]
    fn basic_type_overlay() {
        assert_eq!(u8::height(), 0);
        assert_eq!(u8::min_repr_size(), 1);
        assert_eq!(u64::min_repr_size(), 8);
        assert_eq!(U256::min_repr_size(), 32);
        // A basic value has nothing beneath it.
        assert_eq!(
            u64::get_node(index_path(0)),
            Err(Error::InvalidPath(Path::Index(0)))
        );
        assert_eq!(
            U256::get_node(ident_path("len")),
            Err(Error::InvalidPath(Path::Ident("len".to_string())))
        );
    }

    #[test]
    fn variable_list_overlay() {
        type T = VariableList<U256, U8>;
        assert_eq!(T::height(), 4);
        assert_eq!(T::get_node(ident_path("len")), Ok(length_node(2)));
        assert_eq!(T::get_node(index_path(0)), Ok(build_node("0", 15)));
        assert_eq!(T::get_node(index_path(3)), Ok(build_node("3", 18)));
        assert_eq!(T::get_node(index_path(7)), Ok(build_node("7", 22)));
        // `N` itself is the first out-of-bounds position.
        assert_eq!(T::get_node(index_path(8)), Err(Error::IndexOutOfBounds(8)));
        assert_eq!(T::get_node(index_path(9)), Err(Error::IndexOutOfBounds(9)));
        assert_eq!(T::get_node(vec![]), Err(Error::EmptyPath()));
        assert_eq!(
            T::get_node(ident_path("length")),
            Err(Error::InvalidPath(Path::Ident("length".to_string())))
        );
    }

    #[test]
    fn nested_variable_list_overlay() {
        type T = VariableList<VariableList<VariableList<U256, U2>, U2>, U4>;
        assert_eq!(T::get_node(ident_path("len")), Ok(length_node(2)));
        assert_eq!(
            T::get_node(vec![Path::Index(0), Path::Ident("len".to_string())]),
            Ok(length_node(16))
        );
        assert_eq!(
            T::get_node(vec![Path::Index(3), Path::Ident("len".to_string())]),
            Ok(length_node(22))
        );
        assert_eq!(
            T::get_node(vec![Path::Index(0), Path::Index(1), Path::Index(0)]),
            Ok(build_node("0", 131))
        );
        assert_eq!(
            T::get_node(vec![Path::Index(2), Path::Index(1), Path::Index(0)]),
            Ok(build_node("0", 163))
        );
        assert_eq!(
            T::get_node(vec![Path::Index(3), Path::Index(0), Path::Index(1)]),
            Ok(build_node("1", 176))
        );
        assert_eq!(T::get_node(index_path(4)), Err(Error::IndexOutOfBounds(4)));
        assert_eq!(
            T::get_node(vec![Path::Index(3), Path::Index(2)]),
            Err(Error::IndexOutOfBounds(2))
        );
        assert_eq!(
            T::get_node(vec![Path::Index(3), Path::Index(1), Path::Index(2)]),
            Err(Error::IndexOutOfBounds(2))
        );
    }

    #[test]
    fn simple_fixed_vector() {
        type T = FixedVector<U256, U8>;
        assert_eq!(T::height(), 3);
        // Leaves of a height-3 tree start at general index 7.
        for position in 0..8 {
            assert_eq!(
                T::get_node(index_path(position)),
                Ok(build_node(&position.to_string(), position + 7))
            );
        }
        assert_eq!(T::get_node(index_path(8)), Err(Error::IndexOutOfBounds(8)));
        assert_eq!(T::get_node(vec![]), Err(Error::EmptyPath()));
        assert_eq!(
            T::get_node(ident_path("len")),
            Err(Error::InvalidPath(Path::Ident("len".to_string())))
        );
    }

    #[test]
    fn another_simple_fixed_vector() {
        type T = FixedVector<u8, U32>;
        assert_eq!(T::height(), 0);
        // 32 one-byte items pack exactly into a single chunk.
        assert_eq!(T::min_repr_size(), 32);
        let node = Node::Primitive(
            (0..32u8)
                .map(|slot| Primitive {
                    ident: slot.to_string(),
                    index: 0,
                    size: 1,
                    offset: slot,
                })
                .collect(),
        );
        for position in 0..32 {
            assert_eq!(T::get_node(index_path(position)), Ok(node.clone()));
        }
    }

    #[test]
    fn nested_fixed_vector() {
        type T = FixedVector<FixedVector<FixedVector<U256, U16>, U2>, U1>;
        assert_eq!(T::height(), 0);
        assert_eq!(
            T::get_node(index_path(0)),
            Ok(Node::Composite(Composite {
                ident: 0.to_string(),
                index: 0,
                height: 1,
            }))
        );
        for i in 0..2 {
            assert_eq!(
                T::get_node(vec![Path::Index(0), Path::Index(i)]),
                Ok(Node::Composite(Composite {
                    ident: i.to_string(),
                    index: i + 1,
                    height: 4,
                }))
            );
            for j in 0..16 {
                assert_eq!(
                    T::get_node(vec![Path::Index(0), Path::Index(i), Path::Index(j)]),
                    Ok(Node::Primitive(vec![Primitive {
                        ident: j.to_string(),
                        index: j + 31 + (i * 16),
                        offset: 0,
                        size: 32,
                    }]))
                );
            }
        }
        assert_eq!(T::get_node(index_path(1)), Err(Error::IndexOutOfBounds(1)));
        assert_eq!(
            T::get_node(vec![Path::Index(0), Path::Index(2)]),
            Err(Error::IndexOutOfBounds(2))
        );
        assert_eq!(
            T::get_node(vec![Path::Index(0), Path::Index(0), Path::Index(16)]),
            Err(Error::IndexOutOfBounds(16))
        );
        assert_eq!(
            T::get_node(vec![Path::Index(0), Path::Index(1), Path::Index(16)]),
            Err(Error::IndexOutOfBounds(16))
        );
    }
}