use crate::errors::{Result, ShiaError};
use crate::utils::{read_varint, write_varint, double_sha256};
use byteorder::{ReadBytesExt, WriteBytesExt};
use std::io::{Read, Write};
use anyhow::anyhow;
/// One node entry within a level of a [`Bump`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Leaf {
    /// Position of this node within its level; the sibling of offset `n` is `n ^ 1`.
    pub offset: u64,
    /// Encoding flag. `0`: a hash follows. `1`: duplicate of its sibling, so no hash is stored.
    /// `2`: a hash follows, and this is the kind of leaf searched for the target hash when
    /// computing a root. `Bump::deserialize` rejects any other value.
    pub flags: u8,
    /// The node hash: `Some` for flags `0`/`2`, `None` for flag `1`.
    pub hash: Option<[u8; 32]>,
}
/// A Merkle inclusion proof for one or more transactions at a single block height.
///
/// NOTE(review): the wire layout handled by `deserialize`/`serialize` looks like the BSV
/// Unified Merkle Path (BUMP, BRC-74) format — confirm against the spec.
#[derive(Debug, Clone)]
pub struct Bump {
    /// Block height this proof belongs to; `merge` refuses to combine proofs whose
    /// block heights differ.
    pub block_height: u64,
    /// Number of tree levels in the proof. `deserialize` rejects values above 64 and reads
    /// exactly this many levels into `levels`.
    pub tree_height: u8,
    /// Leaves per level, bottom level first. Index 0 is where the target hash is looked up.
    /// `merge` leaves every level sorted by offset.
    pub levels: Vec<Vec<Leaf>>,
}
impl Bump {
pub fn deserialize(reader: &mut impl Read) -> Result<Self> {
let block_height = read_varint(reader)?;
let tree_height = reader.read_u8()?;
if tree_height > 64 {
return Err(ShiaError::InvalidTreeHeight(tree_height));
}
let mut levels = Vec::with_capacity(tree_height as usize);
for _ in 0..tree_height {
let n_leaves = read_varint(reader)? as usize;
let mut leaves = Vec::with_capacity(n_leaves);
for _ in 0..n_leaves {
let offset = read_varint(reader)?;
let flags = reader.read_u8()?;
let hash = if flags == 0 || flags == 2 {
let mut h = [0u8; 32];
reader.read_exact(&mut h)?;
Some(h)
} else if flags == 1 {
None
} else {
return Err(ShiaError::InvalidFlags(flags));
};
leaves.push(Leaf { offset, flags, hash });
}
levels.push(leaves);
}
Ok(Self { block_height, tree_height, levels })
}
pub fn serialize(&self, writer: &mut impl Write) -> Result<()> {
write_varint(writer, self.block_height)?;
writer.write_u8(self.tree_height)?;
for level in &self.levels {
write_varint(writer, level.len() as u64)?;
for leaf in level {
write_varint(writer, leaf.offset)?;
writer.write_u8(leaf.flags)?;
if let Some(h) = leaf.hash {
writer.write_all(&h)?;
}
}
}
Ok(())
}
pub fn compute_merkle_root_for_hash(&self, leaf_hash: [u8; 32]) -> Result<[u8; 32]> {
if self.levels.is_empty() {
return Ok(leaf_hash); }
let level0 = &self.levels[0];
let leaf = level0
.iter()
.find(|l| l.flags == 2 && l.hash == Some(leaf_hash))
.ok_or(ShiaError::LeafNotFound)?;
let mut current_offset = leaf.offset;
let mut working = leaf_hash;
for level_idx in 0..self.tree_height as usize {
let current_level = &self.levels[level_idx];
let sibling_offset = current_offset ^ 1;
let sibling_leaf = current_level.iter().find(|l| l.offset == sibling_offset);
let sibling_hash = match sibling_leaf {
Some(leaf) => match leaf.flags {
1 => working, 0 | 2 => leaf.hash.ok_or(anyhow!("Hash missing for non-duplicate"))?,
_ => return Err(ShiaError::InvalidFlags(leaf.flags)),
},
None => working, };
let concat = if current_offset % 2 == 0 {
[&working[..], &sibling_hash[..]].concat() } else {
[&sibling_hash[..], &working[..]].concat() };
working = double_sha256(&concat);
current_offset /= 2;
}
Ok(working)
}
pub fn merge(&mut self, other: &Bump) -> Result<()> {
if self.block_height != other.block_height || self.tree_height != other.tree_height {
return Err(ShiaError::MergeMismatch("Heights differ"));
}
for (self_level, other_level) in self.levels.iter_mut().zip(other.levels.iter()) {
for other_leaf in other_level {
if let Some(existing) = self_level.iter_mut().find(|l| l.offset == other_leaf.offset) {
if existing.flags != other_leaf.flags || existing.hash != other_leaf.hash {
return Err(ShiaError::MergeMismatch("Conflicting leaf"));
}
} else {
self_level.push(other_leaf.clone());
}
}
self_level.sort_by_key(|l| l.offset); }
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Parent-node rule used by the proofs: double-SHA256 of `left || right`.
    fn parent(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] {
        double_sha256(&[&left[..], &right[..]].concat())
    }

    /// Encodes a proof with tree height 1 whose single level holds `leaves`, given as
    /// `(offset, flags, hash)` triples that each carry a hash.
    fn encode_single_level(block_height: u64, leaves: &[(u64, u8, [u8; 32])]) -> Vec<u8> {
        let mut out = Vec::new();
        write_varint(&mut out, block_height).unwrap();
        out.write_u8(1).unwrap();
        write_varint(&mut out, leaves.len() as u64).unwrap();
        for (offset, flags, hash) in leaves {
            write_varint(&mut out, *offset).unwrap();
            out.write_u8(*flags).unwrap();
            out.extend_from_slice(hash);
        }
        out
    }

    /// Builds an in-memory one-level proof at block height 1.
    fn one_level(leaves: Vec<Leaf>) -> Bump {
        Bump { block_height: 1, tree_height: 1, levels: vec![leaves] }
    }

    #[test]
    fn test_bump_compute_merkle_root() {
        let (tx1, tx2) = ([1u8; 32], [2u8; 32]);
        let encoded = encode_single_level(1, &[(0, 2, tx1), (1, 0, tx2)]);
        let bump = Bump::deserialize(&mut Cursor::new(encoded)).unwrap();
        assert_eq!(bump.compute_merkle_root_for_hash(tx1).unwrap(), parent(&tx1, &tx2));
    }

    #[test]
    fn test_bump_duplicate_mirror() {
        let only = [1u8; 32];
        let encoded = encode_single_level(1, &[(0, 2, only)]);
        let bump = Bump::deserialize(&mut Cursor::new(encoded)).unwrap();
        assert_eq!(bump.compute_merkle_root_for_hash(only).unwrap(), parent(&only, &only));
    }

    #[test]
    fn test_bump_merge() {
        let target = [1u8; 32];
        let sibling = [2u8; 32];
        let mut merged = one_level(vec![Leaf { offset: 0, flags: 2, hash: Some(target) }]);
        let other = one_level(vec![Leaf { offset: 1, flags: 0, hash: Some(sibling) }]);
        merged.merge(&other).unwrap();

        let offsets: Vec<u64> = merged.levels[0].iter().map(|l| l.offset).collect();
        assert_eq!(offsets, vec![0, 1]);
        assert_eq!(
            merged.compute_merkle_root_for_hash(target).unwrap(),
            parent(&target, &sibling)
        );
    }

    #[test]
    fn test_bump_merge_conflict() {
        let mut merged = one_level(vec![Leaf { offset: 0, flags: 2, hash: Some([1u8; 32]) }]);
        let other = one_level(vec![Leaf { offset: 0, flags: 2, hash: Some([3u8; 32]) }]);
        assert!(merged.merge(&other).is_err());
    }
}