use crate::{FieldElement, authenticator::AuthenticatorPublicKeySet, serde_utils::hex_u64};
use ark_bn254::Fr;
use serde::{Deserialize, Deserializer, Serialize, Serializer, de::Error as _};
mod array_serde {
use super::*;
pub fn serialize<S, T, const N: usize>(array: &[T; N], serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
T: Serialize,
{
array.as_slice().serialize(serializer)
}
pub fn deserialize<'de, D, T, const N: usize>(deserializer: D) -> Result<[T; N], D::Error>
where
D: Deserializer<'de>,
T: Deserialize<'de>,
{
let vec = Vec::<T>::deserialize(deserializer)?;
vec.try_into().map_err(|v: Vec<_>| {
D::Error::custom(format!("Expected array of size {}, got {}", N, v.len()))
})
}
}
/// Inclusion proof for a leaf of a binary Merkle tree of depth `TREE_DEPTH`.
///
/// The proof carries the root it is expected to hash up to, the position of the leaf, and
/// one sibling node per level, ordered from the leaf level upwards.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MerkleInclusionProof<const TREE_DEPTH: usize> {
    /// Root the path is expected to hash up to.
    pub root: FieldElement,
    /// Position of the leaf in the tree.
    ///
    /// Bit `i` selects the hashing order at level `i`. The field is (de)serialized through
    /// `serde_utils::hex_u64`.
    #[serde(with = "hex_u64")]
    pub leaf_index: u64,
    /// Sibling nodes, from the leaf level (index 0) up to the level just below the root.
    #[serde(with = "array_serde")]
    pub siblings: [FieldElement; TREE_DEPTH],
}

impl<const TREE_DEPTH: usize> MerkleInclusionProof<TREE_DEPTH> {
    /// Creates a proof from its raw parts. No validation is performed.
    #[must_use]
    pub const fn new(
        root: FieldElement,
        leaf_index: u64,
        siblings: [FieldElement; TREE_DEPTH],
    ) -> Self {
        Self {
            root,
            leaf_index,
            siblings,
        }
    }

    /// Returns `true` if hashing `leaf` up the tree along this proof's path yields `root`.
    ///
    /// At level `i`, bit `i` of `leaf_index` decides the input order of the compression:
    /// - bit 0: the running node is the left input and the sibling the right input;
    /// - bit 1: the order is swapped.
    ///
    /// A `leaf_index` that does not fit in `TREE_DEPTH` bits is rejected. Such an index cannot
    /// address a leaf of this tree. Without the check, `i + k * 2^TREE_DEPTH` would verify
    /// along the same path as `i`.
    #[must_use]
    pub fn is_valid(&self, leaf: FieldElement) -> bool {
        // Every bit at or above `TREE_DEPTH` must be zero for the index to be in range.
        if shr_or_zero(self.leaf_index, TREE_DEPTH) != 0 {
            return false;
        }
        let computed = self
            .siblings
            .iter()
            .enumerate()
            .fold(leaf.0, |node, (level, sibling)| {
                // `shr_or_zero` avoids the shift-overflow panic when `level >= 64`.
                if shr_or_zero(self.leaf_index, level) & 1 == 0 {
                    poseidon2_compress(node, sibling.0)
                } else {
                    poseidon2_compress(sibling.0, node)
                }
            });
        computed == self.root.0
    }
}

/// Computes `value >> shift`, yielding 0 instead of overflowing when `shift >= 64`.
fn shr_or_zero(value: u64, shift: usize) -> u64 {
    u32::try_from(shift)
        .ok()
        .and_then(|s| value.checked_shr(s))
        .unwrap_or(0)
}
/// A Merkle inclusion proof bundled with the account's authenticator public keys.
///
/// Because of `#[serde(flatten)]`, the fields of the inner [`MerkleInclusionProof`] are
/// (de)serialized at the same level as `authenticator_pubkeys` rather than nested.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccountInclusionProof<const TREE_DEPTH: usize> {
    /// The underlying Merkle inclusion proof, flattened into this struct's serialized form.
    #[serde(flatten)]
    pub inclusion_proof: MerkleInclusionProof<TREE_DEPTH>,
    /// Authenticator public keys carried alongside the proof.
    pub authenticator_pubkeys: AuthenticatorPublicKeySet,
}

impl<const TREE_DEPTH: usize> AccountInclusionProof<TREE_DEPTH> {
    /// Bundles an inclusion proof with a set of authenticator public keys.
    ///
    /// No validation is performed.
    #[must_use]
    pub const fn new(
        inclusion_proof: MerkleInclusionProof<TREE_DEPTH>,
        authenticator_pubkeys: AuthenticatorPublicKeySet,
    ) -> Self {
        Self {
            inclusion_proof,
            authenticator_pubkeys,
        }
    }
}
/// Two-to-one compression built on the BN254 Poseidon2 `t2` permutation.
///
/// Applies the permutation to `[left, right]` and returns the first output lane with `left`
/// added back in (feed-forward).
fn poseidon2_compress(left: Fr, right: Fr) -> Fr {
    let permuted = poseidon2::bn254::t2::permutation(&[left, right]);
    permuted[0] + left
}