use std::{collections::HashMap, fmt::Debug, marker::PhantomData};
use num_bigint::BigUint;
use crate::{
instructions::{hash::HashCPU, map::MapCPU},
utils::util::big_to_fe,
CircuitField,
};
pub(crate) const TREE_HEIGHT: u8 = 128;
#[derive(Clone, Debug)]
pub struct MapMt<F: CircuitField, H: HashCPU<F, F>> {
pub(crate) root: F,
nodes: HashMap<(u8, u128), F>,
map: HashMap<BigUint, F>,
default_nodes: [F; TREE_HEIGHT as usize],
_marker: PhantomData<H>,
}
impl<F: CircuitField, H: HashCPU<F, F>> PartialEq for MapMt<F, H> {
fn eq(&self, other: &Self) -> bool {
self.root == other.root
&& self.nodes == other.nodes
&& self.default_nodes == other.default_nodes
&& self.map == other.map
}
}
impl<F: CircuitField, H: HashCPU<F, F>> Eq for MapMt<F, H> {}
impl<F, H> MapCPU<F, F, F> for MapMt<F, H>
where
F: CircuitField,
H: HashCPU<F, F>,
{
fn new(default: &F) -> Self {
let nodes = HashMap::new();
let map = HashMap::new();
let mut default_nodes = [*default; TREE_HEIGHT as usize];
for i in 1..TREE_HEIGHT as usize {
default_nodes[i] =
<H as HashCPU<F, F>>::hash(&[default_nodes[i - 1], default_nodes[i - 1]]);
}
let root = <H as HashCPU<F, F>>::hash(&[default_nodes[127], default_nodes[127]]);
Self {
root,
nodes,
map,
default_nodes,
_marker: PhantomData,
}
}
fn succinct_repr(&self) -> F {
self.root
}
fn insert(&mut self, key: &F, value: &F) {
self.map.insert(key.to_biguint(), *value);
let mut child = *value;
let mut node_index = Self::compute_node_index(key);
for height in 0..TREE_HEIGHT {
self.nodes.insert((height, node_index), child);
let sibling = self.get_sibling(node_index, height);
let (x, y) = conditional_swap(node_index & 1 == 1, &child, &sibling);
child = <H as HashCPU<F, F>>::hash(&[x, y]);
node_index >>= 1;
}
self.root = child;
}
fn get(&self, key: &F) -> F {
self.map.get(&key.to_biguint()).copied().unwrap_or(self.default_nodes[0])
}
}
impl<F, H> IntoIterator for MapMt<F, H>
where
F: CircuitField,
H: HashCPU<F, F>,
{
type Item = (F, F);
type IntoIter = std::iter::Map<
std::collections::hash_map::IntoIter<BigUint, F>,
fn((BigUint, F)) -> (F, F),
>;
fn into_iter(self) -> Self::IntoIter {
self.map.into_iter().map(|(key, value)| (big_to_fe(key), value))
}
}
impl<F, H> MapMt<F, H>
where
F: CircuitField,
H: HashCPU<F, F>,
{
pub(crate) fn get_path(&self, key: &F) -> [F; TREE_HEIGHT as usize] {
let mut node_index = Self::compute_node_index(key);
let mut nodes = [F::ZERO; TREE_HEIGHT as usize];
for (i, val) in nodes.iter_mut().enumerate() {
*val = self.get_sibling(node_index, i as u8);
node_index >>= 1;
}
nodes
}
#[cfg(test)]
fn verify_mem_proof(root: &F, path: &[F; TREE_HEIGHT as usize], key: &F, value: F) -> bool {
let mut node_index = Self::compute_node_index(key);
let mut child = value;
for node in path {
let (x, y) = conditional_swap(node_index & 1 == 1, &child, node);
child = <H as HashCPU<F, F>>::hash(&[x, y]);
node_index >>= 1;
}
*root == child
}
fn get_sibling(&self, node_index: u128, height: u8) -> F {
assert!(height == 0 || (node_index < 1 << (TREE_HEIGHT - height)));
let sibling_index = node_index + 1 - 2 * (node_index & 1);
*self
.nodes
.get(&(height, sibling_index))
.unwrap_or(&self.default_nodes[height as usize])
}
fn compute_node_index(element: &F) -> u128 {
let hashed_value = <H as HashCPU<F, F>>::hash(&[*element, F::ZERO]);
let bytes = hashed_value.to_bytes_le();
u128::from_le_bytes(bytes[..TREE_HEIGHT as usize / 8].try_into().unwrap())
}
}
fn conditional_swap<F: CircuitField>(cond: bool, left_input: &F, right_input: &F) -> (F, F) {
if cond {
(*right_input, *left_input)
} else {
(*left_input, *right_input)
}
}
#[cfg(test)]
mod tests {
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use super::*;
use crate::hash::poseidon::{constants::PoseidonField, PoseidonChip};
fn test_map<F, H>()
where
F: CircuitField,
H: HashCPU<F, F> + Debug,
{
let mut rng = ChaCha8Rng::seed_from_u64(0xc0ffee);
let mut mt = MapMt::<F, H>::new(&F::ZERO);
for _ in 0..100 {
mt.insert(&F::random(&mut rng), &F::ONE);
}
mt.insert(&F::ONE, &F::ONE);
assert_eq!(mt.get(&F::ONE), F::ONE);
let old_mt = mt.clone();
mt.insert(&F::ONE, &F::ONE);
assert_eq!(old_mt, mt);
let one_path = mt.get_path(&F::ONE);
let member = MapMt::<F, H>::verify_mem_proof(&mt.root, &one_path, &F::ONE, F::ONE);
assert!(member);
let non_member = MapMt::<F, H>::verify_mem_proof(&mt.root, &one_path, &F::ONE, F::ZERO);
assert!(!non_member);
let new_value = F::random(&mut rng);
let path = mt.get_path(&new_value);
let non_member = MapMt::<F, H>::verify_mem_proof(&mt.root, &path, &new_value, F::ZERO);
assert!(non_member);
}
fn run_poseidon_test<F: PoseidonField>() {
test_map::<F, PoseidonChip<F>>();
}
#[test]
fn test_map_poseidon() {
run_poseidon_test::<midnight_curves::Fq>();
}
}