use arecibo::provider::PallasEngine;
use arecibo::traits::Engine;
use ff::Field;
use serde::{Deserialize, Serialize};
use super::NovaPoRError;
use crate::commitment::domain_tags;
use crate::commitment::poseidon_hash_tagged;
use crate::config;
use crate::utils::bytes31_to_field_le;
/// Field used for every leaf and internal node in this module: the scalar
/// field of arecibo's `PallasEngine`.
pub type F = <PallasEngine as Engine>::Scalar;
pub fn poseidon_hash_pair(left: F, right: F) -> F {
crate::commitment::poseidon_hash2(left, right)
}
/// Hashes two child nodes into their parent, using the `node` domain tag
/// so internal-node hashes are separated from other tagged hashes.
pub fn hash_node(left: F, right: F) -> F {
    let node_tag = domain_tags::node();
    poseidon_hash_tagged(node_tag, left, right)
}
/// Hashes two field elements under the `leaf` domain tag.
///
/// NOTE(review): not called by the tree-building code in this file (leaves
/// there come from `get_leaf_hash`, which does not hash) — confirm where
/// this is used.
pub fn hash_leaf_data(left: F, right: F) -> F {
    let leaf_tag = domain_tags::leaf();
    poseidon_hash_tagged(leaf_tag, left, right)
}
/// Converts one data chunk into the field element stored as a Merkle leaf.
///
/// Despite the name, no hash is applied. The chunk bytes are encoded
/// directly into a field element with `bytes31_to_field_le`, so the leaf
/// *is* the data. This keeps a prover from storing only a digest in place
/// of the retrievable chunk. An empty chunk maps to `F::ZERO`.
///
/// NOTE(review): an all-zero chunk presumably also encodes to `F::ZERO`,
/// which makes it indistinguishable from an empty chunk. Confirm against
/// `bytes31_to_field_le` if that distinction matters.
///
/// # Errors
///
/// Returns `NovaPoRError::InvalidInput` when `data` is longer than
/// `config::CHUNK_SIZE_BYTES` (presumably 31, per the error text — confirm).
pub fn get_leaf_hash(data: &[u8]) -> Result<F, NovaPoRError> {
    let max_len = config::CHUNK_SIZE_BYTES;
    match data.len() {
        n if n > max_len => Err(NovaPoRError::InvalidInput(format!(
            "Data chunk too large for secure PoR: {} bytes (max {}). \
             Larger chunks would be hashed, allowing provers to store only hashes \
             instead of retrievable data. Use 31-byte symbols from erasure encoding.",
            n, max_len
        ))),
        0 => Ok(F::ZERO),
        _ => Ok(bytes31_to_field_le::<F>(data)),
    }
}
/// A Merkle tree stored as explicit layers.
///
/// `layers[0]` holds the leaves and each following layer holds the parents
/// of the layer below. As built by `build_tree_from_leaves`, the last layer
/// contains only the root. The field is public and the type derives
/// `Deserialize`, so that shape is not enforced by the type itself.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MerkleTree {
    /// Layers ordered from the leaves (index 0) up to the root (last).
    pub layers: Vec<Vec<F>>,
}
impl MerkleTree {
    /// Returns the Merkle root, i.e. the first element of the top-most layer.
    ///
    /// Falls back to `F::ZERO` when there are no layers (for example
    /// `MerkleTree::default()`) or when the top layer is empty.
    pub fn root(&self) -> F {
        match self.layers.last().map(Vec::as_slice) {
            Some([root, ..]) => *root,
            _ => F::ZERO,
        }
    }
}
/// Builds a Merkle tree bottom-up from already-encoded leaves.
///
/// Adjacent nodes are combined with `hash_node`. When a layer has an odd
/// number of nodes, its last node is hashed with itself. An empty input
/// yields a single-layer tree `[[F::ZERO]]`, whose root is `F::ZERO`.
///
/// NOTE(review): because the trailing node is duplicated, leaf sequences
/// such as `[a, b, c]` and `[a, b, c, c]` produce the same root. This is
/// the pattern behind Bitcoin's CVE-2012-2459. The root alone therefore
/// does not commit to the leaf count; bind the count separately if that
/// matters.
///
/// # Errors
///
/// Never returns `Err` in practice; the `Result` is part of the API.
pub fn build_tree_from_leaves(leaves: &[F]) -> Result<MerkleTree, NovaPoRError> {
    if leaves.is_empty() {
        return Ok(MerkleTree {
            layers: vec![vec![F::ZERO]],
        });
    }

    let mut layers: Vec<Vec<F>> = Vec::new();
    let mut level = leaves.to_vec();
    while level.len() > 1 {
        let parents: Vec<F> = level
            .chunks(2)
            .map(|pair| {
                let left = pair[0];
                // Odd tail: a lone trailing node is paired with itself.
                let right = pair.get(1).copied().unwrap_or(left);
                hash_node(left, right)
            })
            .collect();
        layers.push(level);
        level = parents;
    }
    // `level` now holds exactly one node: the root.
    layers.push(level);

    Ok(MerkleTree { layers })
}
/// Encodes each data chunk as a leaf (via `get_leaf_hash`) and builds the
/// Merkle tree over those leaves, returning the tree together with its root.
///
/// An empty `data_chunks` yields the single-layer `[[F::ZERO]]` tree and a
/// `F::ZERO` root. `build_tree_from_leaves` produces exactly that for an
/// empty leaf slice.
///
/// # Errors
///
/// Propagates the first error from `get_leaf_hash` (e.g. an oversized
/// chunk) or any error from `build_tree_from_leaves`.
pub fn build_tree(data_chunks: &[Vec<u8>]) -> Result<(MerkleTree, F), NovaPoRError> {
    let mut leaves = Vec::with_capacity(data_chunks.len());
    for chunk in data_chunks {
        leaves.push(get_leaf_hash(chunk)?);
    }

    let tree = build_tree_from_leaves(&leaves)?;
    let root = tree.root();
    Ok((tree, root))
}
/// A Merkle authentication path with a fixed number of levels, as produced
/// by `get_padded_proof_for_leaf` and checked by
/// `verify_merkle_proof_in_place`.
///
/// NOTE(review): the name suggests this is the witness shape fed to the
/// Nova circuit — confirm. Do not reorder fields: order-sensitive serde
/// formats encode them positionally.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitMerkleProof {
    /// The value stored at `layers[0][leaf_index]` of the source tree.
    pub leaf: F,
    /// Sibling nodes from the leaf level upward; padded levels hold `F::ZERO`.
    pub siblings: Vec<F>,
    /// One bit per level: `true` if the running node is the right child.
    pub path_indices: Vec<bool>,
}
/// Builds a fixed-length Merkle authentication path for the leaf at
/// `leaf_index`.
///
/// Siblings and direction bits are collected bottom-up, one per layer below
/// the root. A trailing node of an odd-length layer uses itself as its
/// sibling. This mirrors `build_tree_from_leaves`, which hashes such a node
/// with itself.
///
/// The path is then resized to exactly `depth` entries:
/// * If `depth` exceeds the tree height (`layers.len() - 1`), the extra
///   levels get a `F::ZERO` sibling and a `false` (left-child) bit.
///   Verification then folds in `hash_node(acc, F::ZERO)` once per padded
///   level, so the recomputed value is an extended root, not `tree.root()`.
/// * If `depth` is smaller than the tree height, the upper levels are
///   dropped and the path stops below the root.
///
/// NOTE(review): both behaviours presumably match what the fixed-depth
/// circuit expects. Confirm before relying on `depth != height`.
///
/// # Errors
///
/// * `NovaPoRError::IndexOutOfBounds` if `leaf_index` is not a valid leaf,
///   including when the tree has no layers at all.
/// * `NovaPoRError::MerkleTree` if the tree is malformed, i.e. a layer is
///   too short to contain the node on the leaf's path. `MerkleTree` exposes
///   its layers publicly and is deserializable, so such trees are possible.
pub fn get_padded_proof_for_leaf(
    tree: &MerkleTree,
    leaf_index: usize,
    depth: usize,
) -> Result<CircuitMerkleProof, NovaPoRError> {
    let leaf = tree
        .layers
        .first()
        .and_then(|layer| layer.get(leaf_index))
        .copied()
        .ok_or_else(|| NovaPoRError::IndexOutOfBounds {
            index: leaf_index,
            length: tree.layers.first().map(|l| l.len()).unwrap_or(0),
        })?;

    let mut siblings = Vec::with_capacity(depth);
    let mut path_indices = Vec::with_capacity(depth);
    let mut current_index = leaf_index;

    // `layers` is non-empty here because the leaf lookup succeeded, so this
    // cannot underflow. The last layer is the root, which has no sibling.
    let layers_below_root = &tree.layers[..tree.layers.len() - 1];
    for (level, current_layer) in layers_below_root.iter().enumerate() {
        // Checked lookup: a malformed tree may have a layer too short for
        // this path; report it instead of panicking on an index.
        let node = current_layer.get(current_index).copied().ok_or_else(|| {
            NovaPoRError::MerkleTree(format!(
                "get_padded_proof_for_leaf: malformed tree - node index {} out of range \
                 at level {} (layer length {})",
                current_index,
                level,
                current_layer.len()
            ))
        })?;

        let is_right_node = current_index % 2 == 1;
        let sibling_index = if is_right_node {
            current_index - 1
        } else {
            current_index + 1
        };
        // A trailing node of an odd-length layer is paired with itself.
        let sibling = current_layer.get(sibling_index).copied().unwrap_or(node);

        siblings.push(sibling);
        path_indices.push(is_right_node);
        current_index /= 2;
    }

    // `resize` pads with zero siblings / left-child bits when the tree is
    // shorter than `depth`, and truncates when it is taller.
    siblings.resize(depth, F::ZERO);
    path_indices.resize(depth, false);

    Ok(CircuitMerkleProof {
        leaf,
        siblings,
        path_indices,
    })
}
/// Recomputes a Merkle root from `proof` and checks that it equals `root`.
///
/// The path is walked bottom-up. At each level, `path_indices[i] == true`
/// means the running hash is the right child (`hash_node(sibling, acc)`),
/// and `false` means it is the left child (`hash_node(acc, sibling)`).
///
/// Returns `false` for structurally malformed proofs, i.e. those whose
/// `siblings` and `path_indices` have different lengths. Missing direction
/// bits are not guessed.
///
/// NOTE(review): this only checks that the path hashes to `root`. It does
/// not check that `path_indices` encodes a particular (e.g. challenged)
/// leaf index; callers must bind the index separately if required.
pub fn verify_merkle_proof_in_place(root: F, proof: &CircuitMerkleProof) -> bool {
    // `get_padded_proof_for_leaf` always emits one bit per sibling; anything
    // else is malformed and must not be accepted.
    if proof.siblings.len() != proof.path_indices.len() {
        return false;
    }

    let computed_root = proof
        .siblings
        .iter()
        .zip(&proof.path_indices)
        .fold(proof.leaf, |acc, (&sibling, &is_right)| {
            if is_right {
                hash_node(sibling, acc)
            } else {
                hash_node(acc, sibling)
            }
        });

    computed_root == root
}