use pasta_curves::Fp;
use crate::fp_utils::{binary_search_records, read_fp, validate_all_fp_chunks};
use crate::{TIER2_LAYERS, TIER2_LEAF_BYTES, TIER2_LEAVES, TIER2_ROW_BYTES};
pub struct Tier2Row<'a> {
data: &'a [u8],
}
impl<'a> Tier2Row<'a> {
pub fn from_bytes(data: &'a [u8]) -> anyhow::Result<Self> {
anyhow::ensure!(
data.len() == TIER2_ROW_BYTES,
"Tier 2 row size mismatch: got {} bytes, expected {}",
data.len(),
TIER2_ROW_BYTES
);
validate_all_fp_chunks(data, "Tier 2 row")?;
Ok(Self { data })
}
pub fn leaf_record(&self, i: usize) -> (Fp, Fp, Fp) {
debug_assert!(i < TIER2_LEAVES);
let base = i * TIER2_LEAF_BYTES;
let nf_lo = read_fp(&self.data[base..base + 32]);
let nf_mid = read_fp(&self.data[base + 32..base + 64]);
let nf_hi = read_fp(&self.data[base + 64..base + 96]);
(nf_lo, nf_mid, nf_hi)
}
pub fn find_leaf(&self, value: Fp, valid_leaves: usize) -> Option<usize> {
debug_assert!(valid_leaves <= TIER2_LEAVES);
if valid_leaves == 0 {
return None;
}
let idx = binary_search_records(self.data, 0, valid_leaves, TIER2_LEAF_BYTES, 0, value)?;
let (nf_lo, nf_mid, nf_hi) = self.leaf_record(idx);
let offset = value - nf_lo;
let span = nf_hi - nf_lo;
if offset == Fp::zero() || offset >= span {
return None;
}
if value == nf_mid {
return None;
}
Some(idx)
}
pub fn extract_siblings(
&self,
leaf_idx: usize,
valid_leaves: usize,
hasher: &imt_tree::hasher::PoseidonHasher,
) -> [Fp; TIER2_LAYERS] {
debug_assert!(valid_leaves <= TIER2_LEAVES);
let empty_leaf = hasher.hash3(Fp::zero(), Fp::zero(), Fp::zero());
let mut current_level: Vec<Fp> = (0..TIER2_LEAVES)
.map(|i| {
if i < valid_leaves {
let (lo, mid, hi) = self.leaf_record(i);
hasher.hash3(lo, mid, hi)
} else {
empty_leaf
}
})
.collect();
let mut siblings = [Fp::default(); TIER2_LAYERS];
let mut pos = leaf_idx;
for level in 0..TIER2_LAYERS {
siblings[level] = current_level[pos ^ 1];
if level < TIER2_LAYERS - 1 {
let next_len = current_level.len() / 2;
let mut next_level = Vec::with_capacity(next_len);
for j in 0..next_len {
next_level.push(hasher.hash(current_level[2 * j], current_level[2 * j + 1]));
}
current_level = next_level;
}
pos >>= 1;
}
siblings
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::fp_utils::write_fp;
use imt_tree::hasher::PoseidonHasher;
#[test]
fn from_bytes_rejects_non_canonical_field_element() {
let mut row = vec![0u8; TIER2_ROW_BYTES];
row[0..32].fill(0xFF);
let err = Tier2Row::from_bytes(&row)
.err()
.expect("row should be rejected");
assert!(
err.to_string().contains("invalid field element"),
"unexpected error: {err}"
);
}
#[test]
fn punctured_leaf_record_round_trip() {
let mut row = vec![0u8; TIER2_ROW_BYTES];
let hasher = PoseidonHasher::new();
write_fp(&mut row[0..], Fp::from(10u64));
write_fp(&mut row[32..], Fp::from(20u64));
write_fp(&mut row[64..], Fp::from(30u64));
let tier2 = Tier2Row::from_bytes(&row).unwrap();
let (nf_lo, nf_mid, nf_hi) = tier2.leaf_record(0);
assert_eq!(nf_lo, Fp::from(10u64));
assert_eq!(nf_mid, Fp::from(20u64));
assert_eq!(nf_hi, Fp::from(30u64));
assert!(tier2.find_leaf(Fp::from(15u64), 1).is_some());
assert!(tier2.find_leaf(Fp::from(25u64), 1).is_some());
assert!(tier2.find_leaf(Fp::from(10u64), 1).is_none());
assert!(tier2.find_leaf(Fp::from(20u64), 1).is_none());
assert!(tier2.find_leaf(Fp::from(30u64), 1).is_none());
let sibs = tier2.extract_siblings(0, 1, &hasher);
assert_eq!(sibs[0], hasher.hash3(Fp::zero(), Fp::zero(), Fp::zero()));
}
}