use crate::error::StorageError;
use crate::types::{FrameID, Hash};
use blake3::Hasher;
use rs_merkle::{Hasher as RsHasher, MerkleTree};
use std::collections::BTreeSet;
/// Adapter that lets `rs_merkle` hash tree nodes with BLAKE3.
///
/// Each call starts from a fresh `blake3::Hasher` in its default mode and returns the
/// 32-byte digest of `data`.
#[derive(Clone, Debug)]
pub struct Blake3Hasher;

impl RsHasher for Blake3Hasher {
    type Hash = [u8; 32];

    /// Returns the BLAKE3 digest of `data` as a 32-byte array.
    fn hash(data: &[u8]) -> Self::Hash {
        let digest = Hasher::new().update(data).finalize();
        *digest.as_bytes()
    }
}
/// An ordered set of frame ids together with a cached Merkle root over those ids.
///
/// Members are stored sorted, so two sets holding the same ids produce the same root
/// regardless of the order in which the ids were added.
#[derive(Debug, Clone)]
pub struct FrameMerkleSet {
    /// Member ids; the `BTreeSet` ordering fixes the leaf order of the Merkle tree.
    frames: BTreeSet<FrameID>,
    /// Cached root over `frames` (the empty-set hash when `frames` is empty).
    root: Option<Hash>,
}
impl Default for FrameMerkleSet {
fn default() -> Self {
Self::new()
}
}
impl FrameMerkleSet {
pub fn new() -> Self {
FrameMerkleSet {
frames: BTreeSet::new(),
root: Some(compute_empty_set_hash()),
}
}
pub fn add_frame(&mut self, frame_id: FrameID) -> Result<Hash, StorageError> {
if self.frames.contains(&frame_id) {
return Ok(self.root.expect("Root should exist if frames exist"));
}
self.frames.insert(frame_id);
self.rebuild_tree()
}
pub fn remove_frame(&mut self, frame_id: FrameID) -> Result<Hash, StorageError> {
if !self.frames.contains(&frame_id) {
return Ok(self.root.expect("Root should exist"));
}
self.frames.remove(&frame_id);
self.rebuild_tree()
}
pub fn root(&self) -> Option<Hash> {
self.root
}
pub fn contains(&self, frame_id: &FrameID) -> bool {
self.frames.contains(frame_id)
}
pub fn len(&self) -> usize {
self.frames.len()
}
pub fn is_empty(&self) -> bool {
self.frames.is_empty()
}
pub fn frame_ids(&self) -> impl Iterator<Item = &FrameID> {
self.frames.iter()
}
fn rebuild_tree(&mut self) -> Result<Hash, StorageError> {
if self.frames.is_empty() {
let empty_hash = compute_empty_set_hash();
self.root = Some(empty_hash);
return Ok(empty_hash);
}
let leaves: Vec<[u8; 32]> = self
.frames
.iter()
.map(|frame_id| {
let mut hasher = Hasher::new();
hasher.update(b"frame_leaf");
hasher.update(frame_id);
*hasher.finalize().as_bytes()
})
.collect();
let tree = MerkleTree::<Blake3Hasher>::from_leaves(&leaves);
let root_opt = tree.root();
let root = root_opt.ok_or_else(|| {
StorageError::IoError(std::io::Error::new(
std::io::ErrorKind::Other,
"Failed to compute Merkle tree root",
))
})?;
self.root = Some(root);
Ok(root)
}
pub fn from_frame_ids<I>(frame_ids: I) -> Result<Self, StorageError>
where
I: IntoIterator<Item = FrameID>,
{
let mut set = FrameMerkleSet {
frames: BTreeSet::new(),
root: None,
};
for frame_id in frame_ids {
set.frames.insert(frame_id);
}
set.rebuild_tree()?;
Ok(set)
}
}
/// Root assigned to a set with no frames: BLAKE3 over the fixed tag `b"empty_frame_set"`.
///
/// The tag differs from the `b"frame_leaf"` prefix used when hashing individual frame ids.
fn compute_empty_set_hash() -> Hash {
    let digest = Hasher::new().update(b"empty_frame_set").finalize();
    *digest.as_bytes()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_set() {
        let set = FrameMerkleSet::new();
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);
        assert_eq!(set.frame_ids().count(), 0);
        // The empty set must commit to the well-known empty-set hash, not just any value.
        assert_eq!(set.root(), Some(compute_empty_set_hash()));
    }

    #[test]
    fn test_add_frame() {
        let mut set = FrameMerkleSet::new();
        let frame_id: FrameID = [1u8; 32];
        let root1 = set.add_frame(frame_id).unwrap();
        assert_eq!(set.len(), 1);
        assert!(set.contains(&frame_id));
        assert_ne!(root1, compute_empty_set_hash());
        assert_eq!(set.root(), Some(root1));

        // Re-adding an existing member is a no-op.
        let root2 = set.add_frame(frame_id).unwrap();
        assert_eq!(root1, root2);
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn test_deterministic_root() {
        let mut set1 = FrameMerkleSet::new();
        let mut set2 = FrameMerkleSet::new();
        let frame_id1: FrameID = [1u8; 32];
        let frame_id2: FrameID = [2u8; 32];
        let frame_id3: FrameID = [3u8; 32];
        set1.add_frame(frame_id1).unwrap();
        set1.add_frame(frame_id2).unwrap();
        set1.add_frame(frame_id3).unwrap();
        set2.add_frame(frame_id3).unwrap();
        set2.add_frame(frame_id1).unwrap();
        set2.add_frame(frame_id2).unwrap();
        assert_eq!(set1.root(), set2.root());
    }

    #[test]
    fn test_different_members_different_roots() {
        let a = FrameMerkleSet::from_frame_ids(vec![[1u8; 32], [2u8; 32]]).unwrap();
        let b = FrameMerkleSet::from_frame_ids(vec![[1u8; 32], [3u8; 32]]).unwrap();
        assert_ne!(a.root(), b.root());
    }

    #[test]
    fn test_remove_frame() {
        let mut set = FrameMerkleSet::new();
        let frame_id1: FrameID = [1u8; 32];
        let frame_id2: FrameID = [2u8; 32];
        set.add_frame(frame_id1).unwrap();
        set.add_frame(frame_id2).unwrap();
        assert_eq!(set.len(), 2);
        let root_before = set.root().unwrap();
        let root_after = set.remove_frame(frame_id1).unwrap();
        assert_eq!(set.len(), 1);
        assert!(!set.contains(&frame_id1));
        assert!(set.contains(&frame_id2));
        assert_ne!(root_before, root_after);

        // Removing a non-member is a no-op.
        let root_unchanged = set.remove_frame(frame_id1).unwrap();
        assert_eq!(root_after, root_unchanged);
    }

    #[test]
    fn test_remove_last_frame_restores_empty_root() {
        let mut set = FrameMerkleSet::new();
        let frame_id: FrameID = [7u8; 32];
        set.add_frame(frame_id).unwrap();
        let root = set.remove_frame(frame_id).unwrap();
        assert!(set.is_empty());
        assert_eq!(root, compute_empty_set_hash());
        assert_eq!(set.root(), Some(root));
    }

    #[test]
    fn test_from_frame_ids() {
        let frame_ids = vec![[1u8; 32], [2u8; 32], [3u8; 32]];
        let set = FrameMerkleSet::from_frame_ids(frame_ids).unwrap();
        assert_eq!(set.len(), 3);
        assert!(set.root().is_some());
    }

    #[test]
    fn test_from_frame_ids_matches_incremental() {
        // Includes a duplicate and out-of-order ids on purpose.
        let frame_ids: Vec<FrameID> = vec![[3u8; 32], [1u8; 32], [2u8; 32], [1u8; 32]];
        let bulk = FrameMerkleSet::from_frame_ids(frame_ids.clone()).unwrap();
        let mut incremental = FrameMerkleSet::new();
        for id in frame_ids {
            incremental.add_frame(id).unwrap();
        }
        assert_eq!(bulk.len(), 3);
        assert_eq!(bulk.root(), incremental.root());
    }

    #[test]
    fn test_from_empty_frame_ids() {
        let set = FrameMerkleSet::from_frame_ids(Vec::<FrameID>::new()).unwrap();
        assert!(set.is_empty());
        assert_eq!(set.root(), FrameMerkleSet::new().root());
    }

    #[test]
    fn test_frame_ids_ascending() {
        let set = FrameMerkleSet::from_frame_ids(vec![[9u8; 32], [4u8; 32]]).unwrap();
        let ids: Vec<FrameID> = set.frame_ids().copied().collect();
        assert_eq!(ids, vec![[4u8; 32], [9u8; 32]]);
    }

    #[test]
    fn test_empty_set_root_stable() {
        let set1 = FrameMerkleSet::new();
        let set2 = FrameMerkleSet::new();
        assert_eq!(set1.root(), set2.root());
    }
}