use crate::trie::error::{Result, TrieError};
use crate::trie::nibbles::Nibbles;
use crate::trie::node::{Node, NodeHandle};
use sha3::{Digest, Keccak256};
use std::io::{Read, Write};
use std::path::Path;
/// Four-byte signature written at the very start of every cache payload;
/// `load` rejects input whose first four bytes differ.
const MAGIC: &[u8; 4] = b"MPTC";
/// Format version byte stored right after the magic; `load` rejects any other value.
const VERSION: u8 = 1;
/// Number of trailing checksum bytes appended after the payload (the leading
/// bytes of the payload's Keccak-256 digest, see `compute_checksum`).
const CHECKSUM_LEN: usize = 8;
pub(crate) fn save(
root: &NodeHandle,
sync_tag: u64,
root_hash: &[u8],
w: &mut impl Write,
) -> Result<()> {
let mut payload = Vec::new();
payload.extend_from_slice(MAGIC);
payload.push(VERSION);
payload.extend_from_slice(&sync_tag.to_le_bytes());
let hash_len = root_hash.len() as u32;
payload.extend_from_slice(&hash_len.to_le_bytes());
payload.extend_from_slice(root_hash);
let tree_data = serialize_handle(root);
payload.extend_from_slice(&tree_data);
let checksum = compute_checksum(&payload);
w.write_all(&payload).map_err(io_err)?;
w.write_all(&checksum).map_err(io_err)?;
Ok(())
}
pub(crate) fn load(r: &mut impl Read) -> Result<(NodeHandle, u64, Vec<u8>)> {
let mut all_data = Vec::new();
r.read_to_end(&mut all_data).map_err(io_err)?;
if all_data.len() < CHECKSUM_LEN {
return Err(TrieError::InvalidState("cache file too short".into()));
}
let (payload, stored_checksum) = all_data.split_at(all_data.len() - CHECKSUM_LEN);
let expected = compute_checksum(payload);
if stored_checksum != expected {
return Err(TrieError::InvalidState("cache checksum mismatch".into()));
}
if payload.len() < 5 {
return Err(TrieError::InvalidState("cache header too short".into()));
}
if &payload[0..4] != MAGIC {
return Err(TrieError::InvalidState("invalid cache magic".into()));
}
if payload[4] != VERSION {
return Err(TrieError::InvalidState(format!(
"unsupported cache version {}",
payload[4]
)));
}
let mut cursor = 5;
if cursor + 8 > payload.len() {
return Err(TrieError::InvalidState(
"unexpected EOF reading sync_tag".into(),
));
}
let sync_tag = u64::from_le_bytes(payload[cursor..cursor + 8].try_into().unwrap());
cursor += 8;
if cursor + 4 > payload.len() {
return Err(TrieError::InvalidState(
"unexpected EOF reading hash_len".into(),
));
}
let hash_len =
u32::from_le_bytes(payload[cursor..cursor + 4].try_into().unwrap()) as usize;
cursor += 4;
if cursor + hash_len > payload.len() {
return Err(TrieError::InvalidState(
"unexpected EOF reading root_hash".into(),
));
}
let root_hash = payload[cursor..cursor + hash_len].to_vec();
cursor += hash_len;
let root = deserialize_handle(payload, &mut cursor)?;
Ok((root, sync_tag, root_hash))
}
/// Creates (or truncates) the file at `path` and writes the snapshot into it
/// via `save`.
///
/// Failure to create the file is reported through `io_err`; any error from
/// `save` is returned unchanged.
pub(crate) fn save_to_file(
    root: &NodeHandle,
    sync_tag: u64,
    root_hash: &[u8],
    path: &Path,
) -> Result<()> {
    std::fs::File::create(path)
        .map_err(io_err)
        .and_then(|mut file| save(root, sync_tag, root_hash, &mut file))
}
/// Opens the file at `path` and decodes its contents with `load`.
///
/// Failure to open the file is reported through `io_err`; any error from
/// `load` is returned unchanged.
pub(crate) fn load_from_file(path: &Path) -> Result<(NodeHandle, u64, Vec<u8>)> {
    std::fs::File::open(path)
        .map_err(io_err)
        .and_then(|mut file| load(&mut file))
}
// Tag byte written first for every encoded handle (see `serialize_handle`).
/// Handle tag for `NodeHandle::InMemory`: followed directly by the encoded node.
const HANDLE_INMEMORY: u8 = 0x00;
/// Handle tag for `NodeHandle::Cached`: followed by the length-prefixed hash,
/// then the encoded node.
const HANDLE_CACHED: u8 = 0x01;
// Tag byte written first for every encoded node (see `serialize_node`).
/// Node tag for `Node::Null`; carries no further data.
const NODE_NULL: u8 = 0x00;
/// Node tag for `Node::Leaf`: followed by the path and the value.
const NODE_LEAF: u8 = 0x01;
/// Node tag for `Node::Extension`: followed by the path and the child handle.
const NODE_EXT: u8 = 0x02;
/// Node tag for `Node::Branch`: followed by the child bitmap, the optional
/// value, and the present children.
const NODE_BRANCH: u8 = 0x03;
/// Encodes `handle` and everything beneath it into a freshly allocated buffer.
///
/// The whole subtree is written into that single buffer; nested handles do
/// not allocate temporary vectors of their own.
fn serialize_handle(handle: &NodeHandle) -> Vec<u8> {
    let mut buf = Vec::new();
    serialize_handle_into(&mut buf, handle);
    buf
}

/// Appends the encoding of `handle` to `buf`: a handle tag byte, the
/// length-prefixed hash for cached handles, then the encoded node.
fn serialize_handle_into(buf: &mut Vec<u8>, handle: &NodeHandle) {
    match handle {
        NodeHandle::InMemory(node) => {
            buf.push(HANDLE_INMEMORY);
            serialize_node(buf, node);
        }
        NodeHandle::Cached(hash, node) => {
            buf.push(HANDLE_CACHED);
            write_bytes(buf, hash);
            serialize_node(buf, node);
        }
    }
}

/// Appends the encoding of `node` to `buf`.
///
/// Every node starts with its tag byte, followed by:
/// * `Null`: nothing;
/// * `Leaf`: path, then value;
/// * `Extension`: path, then the child handle;
/// * `Branch`: a little-endian `u16` bitmap with bit `i` set when child `i`
///   is present, a flag byte (`1` followed by the value, or `0` for no
///   value), then the present children in ascending index order.
fn serialize_node(buf: &mut Vec<u8>, node: &Node) {
    match node {
        Node::Null => buf.push(NODE_NULL),
        Node::Leaf { path, value } => {
            buf.push(NODE_LEAF);
            write_nibbles(buf, path);
            write_bytes(buf, value);
        }
        Node::Extension { path, child } => {
            buf.push(NODE_EXT);
            write_nibbles(buf, path);
            serialize_handle_into(buf, child);
        }
        Node::Branch { children, value } => {
            buf.push(NODE_BRANCH);
            let mut bitmap: u16 = 0;
            for (i, c) in children.iter().enumerate() {
                if c.is_some() {
                    bitmap |= 1 << i;
                }
            }
            buf.extend_from_slice(&bitmap.to_le_bytes());
            match value {
                Some(v) => {
                    buf.push(1);
                    write_bytes(buf, v);
                }
                None => buf.push(0),
            }
            // Only occupied slots are written; the bitmap tells the reader
            // which indices they belong to.
            for child in children.iter().flatten() {
                serialize_handle_into(buf, child);
            }
        }
    }
}
/// Decodes one handle from `data` starting at `*cursor`, advancing the cursor
/// past everything consumed.
///
/// The first byte selects the handle kind: `HANDLE_INMEMORY` is followed by a
/// node, `HANDLE_CACHED` by a length-prefixed hash and then a node. Any other
/// tag yields `TrieError::InvalidState`; errors from the nested reads are
/// propagated.
fn deserialize_handle(data: &[u8], cursor: &mut usize) -> Result<NodeHandle> {
    match read_u8(data, cursor)? {
        HANDLE_INMEMORY => {
            deserialize_node(data, cursor).map(|inner| NodeHandle::InMemory(Box::new(inner)))
        }
        HANDLE_CACHED => {
            // The hash precedes the node in the encoding, so read it first.
            let digest = read_bytes(data, cursor)?;
            let inner = Box::new(deserialize_node(data, cursor)?);
            Ok(NodeHandle::Cached(digest, inner))
        }
        tag => Err(TrieError::InvalidState(format!(
            "invalid handle tag: {tag}"
        ))),
    }
}
/// Decodes one node from `data` starting at `*cursor`, advancing the cursor
/// past everything consumed.
///
/// The first byte is the node tag; the rest mirrors what `serialize_node`
/// writes for that tag. For a branch: a little-endian `u16` child bitmap, a
/// value flag byte (`0` = no value, `1` = value follows), then one handle per
/// set bitmap bit in ascending bit order.
///
/// # Errors
///
/// Returns `TrieError::InvalidState` for an unknown node tag, a branch value
/// flag other than `0` or `1`, or input that ends before the node is complete.
fn deserialize_node(data: &[u8], cursor: &mut usize) -> Result<Node> {
    let tag = read_u8(data, cursor)?;
    match tag {
        NODE_NULL => Ok(Node::Null),
        NODE_LEAF => {
            let path = read_nibbles(data, cursor)?;
            let value = read_bytes(data, cursor)?;
            Ok(Node::Leaf { path, value })
        }
        NODE_EXT => {
            let path = read_nibbles(data, cursor)?;
            let child = deserialize_handle(data, cursor)?;
            Ok(Node::Extension { path, child })
        }
        NODE_BRANCH => {
            // Checked slicing: a truncated bitmap is an error, never a panic.
            let bitmap = match data.get(*cursor..).and_then(|rest| rest.get(..2)) {
                Some(&[lo, hi]) => u16::from_le_bytes([lo, hi]),
                _ => {
                    return Err(TrieError::InvalidState(
                        "unexpected EOF in branch bitmap".into(),
                    ));
                }
            };
            *cursor += 2;

            // The writer only ever emits 0 or 1 here; anything else means the
            // stream is malformed, so reject it like an unknown tag instead of
            // silently treating it as "no value".
            let value = match read_u8(data, cursor)? {
                0 => None,
                1 => Some(read_bytes(data, cursor)?),
                flag => {
                    return Err(TrieError::InvalidState(format!(
                        "invalid branch value flag: {flag}"
                    )));
                }
            };

            let mut children: Box<[Option<NodeHandle>; 16]> = Box::default();
            for (i, slot) in children.iter_mut().enumerate() {
                if bitmap & (1 << i) != 0 {
                    *slot = Some(deserialize_handle(data, cursor)?);
                }
            }
            Ok(Node::Branch { children, value })
        }
        _ => Err(TrieError::InvalidState(format!("invalid node tag: {tag}"))),
    }
}
/// Appends `n` to `buf` as a little-endian base-128 varint: seven payload bits
/// per byte, least-significant group first, with the high bit set on every
/// byte except the last.
fn write_varint(buf: &mut Vec<u8>, mut n: usize) {
    loop {
        let group = (n & 0x7F) as u8;
        n >>= 7;
        if n == 0 {
            // Final group: continuation bit stays clear.
            buf.push(group);
            return;
        }
        buf.push(group | 0x80);
    }
}
/// Decodes a little-endian base-128 varint from `data` at `*cursor`,
/// advancing the cursor past the bytes consumed.
///
/// Each byte contributes its low seven bits; a set high bit means another
/// byte follows.
///
/// # Errors
///
/// Returns `TrieError::InvalidState` if the input ends before the final byte
/// of the varint, or if the encoded value does not fit in a `usize`. The
/// latter includes the case where only some bits of the last group would be
/// shifted off the top, which would otherwise be dropped silently.
fn read_varint(data: &[u8], cursor: &mut usize) -> Result<usize> {
    let mut n: usize = 0;
    let mut shift: u32 = 0;
    loop {
        let b = *data
            .get(*cursor)
            .ok_or_else(|| TrieError::InvalidState("varint unexpected EOF".into()))?;
        *cursor += 1;

        let group = usize::from(b & 0x7F);
        // `checked_shl` only fails once `shift` reaches the bit width; the
        // round-trip comparison additionally catches groups whose high bits
        // would be lost by the shift.
        let shifted = group
            .checked_shl(shift)
            .filter(|v| v >> shift == group)
            .ok_or_else(|| TrieError::InvalidState("varint overflow".into()))?;
        n |= shifted;

        if b & 0x80 == 0 {
            return Ok(n);
        }
        shift += 7;
    }
}
/// Appends `bytes` to `buf` as a length-prefixed blob: the length as a varint,
/// followed by the raw bytes.
fn write_bytes(buf: &mut Vec<u8>, bytes: &[u8]) {
    let length = bytes.len();
    write_varint(buf, length);
    buf.extend(bytes);
}
/// Reads a length-prefixed blob (varint length, then that many bytes) from
/// `data` at `*cursor` and returns a copy of the bytes.
///
/// The cursor is advanced past the length prefix and, on success, past the
/// blob itself.
///
/// # Errors
///
/// Returns `TrieError::InvalidState` if the length prefix is invalid or the
/// declared length runs past the end of `data`. The end offset is computed
/// with overflow checking, so an absurdly large declared length is reported
/// as an error instead of causing a panic.
fn read_bytes(data: &[u8], cursor: &mut usize) -> Result<Vec<u8>> {
    let len = read_varint(data, cursor)?;
    let start = *cursor;
    let bytes = start
        .checked_add(len)
        .and_then(|end| data.get(start..end))
        .ok_or_else(|| TrieError::InvalidState("bytes unexpected EOF".into()))?
        .to_vec();
    *cursor = start + len;
    Ok(bytes)
}
/// Appends `nibbles` to `buf`: the length of its underlying slice as a
/// varint, followed by that slice verbatim.
fn write_nibbles(buf: &mut Vec<u8>, nibbles: &Nibbles) {
    let slice = nibbles.as_slice();
    let count = slice.len();
    write_varint(buf, count);
    buf.extend(slice);
}
/// Reads a length-prefixed nibble sequence (varint length, then that many
/// bytes) from `data` at `*cursor` and wraps it in `Nibbles`.
///
/// The cursor is advanced past the length prefix and, on success, past the
/// sequence itself.
///
/// # Errors
///
/// Returns `TrieError::InvalidState` if the length prefix is invalid or the
/// declared length runs past the end of `data`. The end offset is computed
/// with overflow checking, so an absurdly large declared length is reported
/// as an error instead of causing a panic.
fn read_nibbles(data: &[u8], cursor: &mut usize) -> Result<Nibbles> {
    let len = read_varint(data, cursor)?;
    let start = *cursor;
    let raw = start
        .checked_add(len)
        .and_then(|end| data.get(start..end))
        .ok_or_else(|| TrieError::InvalidState("nibbles unexpected EOF".into()))?
        .to_vec();
    *cursor = start + len;
    // NOTE(review): the bytes are passed on without inspecting their values;
    // presumably `from_nibbles_unsafe` performs no validation either — confirm
    // whether out-of-range entries from a malformed file need rejecting here.
    Ok(Nibbles::from_nibbles_unsafe(raw))
}
/// Reads the single byte at `*cursor` and moves the cursor one position
/// forward.
///
/// Returns `TrieError::InvalidState` (leaving the cursor untouched) when the
/// cursor is already at or beyond the end of `data`.
fn read_u8(data: &[u8], cursor: &mut usize) -> Result<u8> {
    match data.get(*cursor) {
        Some(&byte) => {
            *cursor += 1;
            Ok(byte)
        }
        None => Err(TrieError::InvalidState("unexpected EOF".into())),
    }
}
/// Converts an I/O failure into `TrieError::InvalidState`, keeping the
/// original error's display text in the message.
fn io_err(e: std::io::Error) -> TrieError {
    let message = format!("I/O error: {e}");
    TrieError::InvalidState(message)
}
/// Computes the cache checksum of `data`: the first `CHECKSUM_LEN` bytes of
/// its Keccak-256 digest.
fn compute_checksum(data: &[u8]) -> [u8; CHECKSUM_LEN] {
    let digest = Keccak256::digest(data);
    let mut prefix = [0u8; CHECKSUM_LEN];
    // Copy the leading digest bytes one by one into the fixed-size result.
    for (dst, src) in prefix.iter_mut().zip(digest[..CHECKSUM_LEN].iter()) {
        *dst = *src;
    }
    prefix
}