use std::{borrow::Borrow, fmt::Debug};
use ark_ff::{FftField, Field};
use ark_poly::{EvaluationDomain, GeneralEvaluationDomain};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Valid};
use derive_more::{AsRef, Deref, From, Into};
use rs_merkle::{Hasher, MerkleTree};
/// Compile-time guard that the const parameter `N` is a power of two.
///
/// Evaluating `AssertPowerOfTwo::<N>::OK` (e.g. `let () = AssertPowerOfTwo::<N>::OK;`)
/// panics during const evaluation when `N` is not a power of two, turning the mistake
/// into a build error.
pub(crate) struct AssertPowerOfTwo<const N: usize>;
impl<const N: usize> AssertPowerOfTwo<N> {
    pub const OK: () = if !N.is_power_of_two() {
        panic!("`N` must be a power of two")
    };
}
/// Merkle proof newtype, layout-identical (`repr(transparent)`) to
/// [`rs_merkle::MerkleProof`].
///
/// It derefs to and converts from/into the wrapped proof. It exists so this crate can
/// provide `Clone`, `Debug`, `PartialEq` and arkworks (de)serialization for it.
#[derive(Deref, AsRef, Into, From)]
#[repr(transparent)]
pub struct MerkleProof<H: Hasher>(rs_merkle::MerkleProof<H>);
impl<H: Hasher> MerkleProof<H> {
    /// Builds a proof from its list of proof hashes.
    pub fn new(hashes: Vec<H::Hash>) -> Self {
        Self(rs_merkle::MerkleProof::new(hashes))
    }
}
impl<H: Hasher> Clone for MerkleProof<H> {
    /// Clones by copying the proof hashes into a freshly built proof.
    ///
    /// Written by hand rather than derived. Presumably this is because the wrapped
    /// `rs_merkle` proof does not implement `Clone` — TODO confirm.
    fn clone(&self) -> Self {
        Self::new(self.proof_hashes().to_vec())
    }
}
impl<H: Hasher> Debug for MerkleProof<H>
where
    H::Hash: Debug,
{
    /// Formats as `MerkleProof(h0, h1, ...)`, with one tuple field per proof hash.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = f.debug_tuple("MerkleProof");
        self.proof_hashes().iter().for_each(|hash| {
            tuple.field(hash);
        });
        tuple.finish()
    }
}
impl<H: Hasher> PartialEq for MerkleProof<H> {
    /// Two proofs are equal when their proof hashes are equal element-wise, in order.
    fn eq(&self, other: &Self) -> bool {
        let (ours, theirs) = (self.proof_hashes(), other.proof_hashes());
        ours == theirs
    }
}
impl<H: Hasher> CanonicalSerialize for MerkleProof<H>
where
    H::Hash: CanonicalSerialize,
{
    /// Serializes the proof as its slice of proof hashes.
    ///
    /// Uses `ark_serialize`'s slice encoding; the `CanonicalDeserialize` impl reads it back
    /// as a `Vec<H::Hash>`.
    fn serialize_with_mode<W: ark_serialize::Write>(
        &self,
        writer: W,
        compress: Compress,
    ) -> Result<(), ark_serialize::SerializationError> {
        let hashes = self.proof_hashes();
        hashes.serialize_with_mode(writer, compress)
    }

    /// Size in bytes of the encoding produced by `serialize_with_mode`.
    fn serialized_size(&self, compress: Compress) -> usize {
        let hashes = self.proof_hashes();
        hashes.serialized_size(compress)
    }
}
impl<H: Hasher> Valid for MerkleProof<H>
where
    H::Hash: Valid,
{
    /// A proof is valid when every one of its proof hashes passes `Valid` checking.
    fn check(&self) -> Result<(), ark_serialize::SerializationError> {
        let hashes = self.proof_hashes().iter();
        <H::Hash as Valid>::batch_check(hashes)
    }
}
impl<H: Hasher> CanonicalDeserialize for MerkleProof<H>
where
    H::Hash: CanonicalDeserialize,
{
    /// Reads a `Vec<H::Hash>` and wraps it as a proof.
    ///
    /// `compress` and `validate` are forwarded unchanged to the vector decoder.
    fn deserialize_with_mode<R: ark_serialize::Read>(
        reader: R,
        compress: Compress,
        validate: ark_serialize::Validate,
    ) -> Result<Self, ark_serialize::SerializationError> {
        let hashes = Vec::<H::Hash>::deserialize_with_mode(reader, compress, validate)?;
        Ok(Self::new(hashes))
    }
}
impl<H: Hasher> Borrow<rs_merkle::MerkleProof<H>> for MerkleProof<H> {
    /// Exposes the wrapped `rs_merkle` proof.
    fn borrow(&self) -> &rs_merkle::MerkleProof<H> {
        &self.0
    }
}
/// Extension methods for [`Hasher`] that hash arkworks-serializable values through their
/// compressed canonical serialization.
pub trait HasherExt: Hasher {
    /// Hashes the compressed serialization of `value`, using `buffer` as scratch space.
    ///
    /// Passing the same buffer to repeated calls lets its allocation be reused.
    fn hash_item_with<S: CanonicalSerialize + ?Sized>(
        value: &S,
        buffer: &mut Vec<u8>,
    ) -> Self::Hash;

    /// Hashes the compressed serialization of `value` using a buffer sized for it up front.
    fn hash_item<S: CanonicalSerialize + ?Sized>(value: &S) -> Self::Hash {
        let mut scratch = Vec::with_capacity(value.compressed_size());
        Self::hash_item_with(value, &mut scratch)
    }

    /// Hashes every element of `values`, in order, sharing a single scratch buffer.
    fn hash_many<S: CanonicalSerialize>(values: &[S]) -> Vec<Self::Hash> {
        // Size the shared buffer from the first element; later ones may still grow it.
        let first_size = values.first().map_or(0, S::compressed_size);
        let mut scratch = Vec::with_capacity(first_size);
        values
            .iter()
            .map(|value| Self::hash_item_with(value, &mut scratch))
            .collect()
    }
}
impl<H: Hasher> HasherExt for H {
    /// Clears `buffer`, writes the compressed serialization of `value` into it, and hashes
    /// exactly those bytes.
    ///
    /// # Panics
    ///
    /// Panics if serializing `value` into the in-memory buffer fails.
    fn hash_item_with<S: CanonicalSerialize + ?Sized>(
        value: &S,
        buffer: &mut Vec<u8>,
    ) -> Self::Hash {
        // Drop leftovers from a previous call so they never become part of the hash input.
        buffer.clear();
        let written = value.serialize_compressed(&mut *buffer);
        written.expect("Serialization failed");
        H::hash(buffer.as_slice())
    }
}
/// Constructors for Merkle trees committing to a slice of serializable evaluations.
pub(crate) trait MerkleTreeExt {
    /// Builds a tree with one leaf per element of `evaluations`.
    fn from_evaluations<T: CanonicalSerialize>(evaluations: &[T]) -> Self;
}
impl<H: Hasher> MerkleTreeExt for MerkleTree<H> {
    /// Each leaf is the hash of the corresponding evaluation's compressed serialization.
    fn from_evaluations<T: CanonicalSerialize>(evaluations: &[T]) -> Self {
        Self::from_leaves(&H::hash_many(evaluations))
    }
}
/// Evaluates the polynomial with coefficients `polynomial` (lowest degree first) over the
/// FFT evaluation domain of size `domain_size`.
///
/// The coefficient vector is consumed and reused as the output buffer.
///
/// # Panics
///
/// - Panics if the field does not support an FFT domain of `domain_size`.
/// - In debug builds, panics if `domain_size` is not a power of two.
/// - In debug builds, panics if the polynomial has a nonzero coefficient at index
///   `domain_size` or beyond. Such coefficients would otherwise be silently dropped.
#[inline]
pub fn to_evaluations<F: FftField>(mut polynomial: Vec<F>, domain_size: usize) -> Vec<F> {
    debug_assert!(
        domain_size.is_power_of_two(),
        "Domain size must be a power of two"
    );
    // ark-poly's `fft_in_place` resizes the vector to the domain size, truncating any
    // higher-degree coefficients. Make sure nothing nonzero is lost that way.
    debug_assert!(
        polynomial.iter().skip(domain_size).all(|c| *c == F::ZERO),
        "Degree of polynomial is not bound by {domain_size}"
    );
    let domain = GeneralEvaluationDomain::<F>::new(domain_size)
        .expect("field does not support an FFT domain of the requested size");
    domain.fft_in_place(&mut polynomial);
    polynomial
}
/// Interpolates `evaluations` over the FFT domain of size `evaluations.len()` and returns
/// the first `degree_bound` coefficients (lowest degree first).
///
/// If `degree_bound` is at least the domain size, every coefficient is returned.
///
/// # Panics
///
/// - Panics if the field does not support an FFT domain of this size.
/// - In debug builds, panics if `evaluations.len()` is not a power of two.
/// - In debug builds, panics if any coefficient at index `degree_bound` or beyond is
///   nonzero, i.e. the interpolated polynomial is not bounded by `degree_bound`.
#[inline]
pub fn to_polynomial<F: FftField>(mut evaluations: Vec<F>, degree_bound: usize) -> Vec<F> {
    debug_assert!(
        evaluations.len().is_power_of_two(),
        "Domain size must be a power of two"
    );
    let domain = GeneralEvaluationDomain::<F>::new(evaluations.len())
        .expect("field does not support an FFT domain of the requested size");
    domain.ifft_in_place(&mut evaluations);
    // `skip` rather than slicing with `[degree_bound..]`: slicing panics when
    // `degree_bound` exceeds the domain size. That input is otherwise valid, because the
    // `truncate` below is then a no-op.
    debug_assert!(
        evaluations.iter().skip(degree_bound).all(|c| *c == F::ZERO),
        "Degree of polynomial is not bound by {degree_bound}"
    );
    evaluations.truncate(degree_bound);
    evaluations
}
/// Evaluates the polynomial whose coefficients are yielded by `coeffs` (lowest degree
/// first) at the point `alpha`, using Horner's rule.
///
/// An empty coefficient sequence evaluates to zero.
#[inline]
pub fn horner_evaluate<F: Field, I: IntoIterator>(coeffs: I, alpha: F) -> F
where
    I::IntoIter: DoubleEndedIterator,
    I::Item: Borrow<F>,
{
    let mut acc = F::ZERO;
    // Walk from the highest-degree coefficient down: acc <- acc * alpha + c_i.
    for coeff in coeffs.into_iter().rev() {
        acc = acc * alpha + coeff.borrow();
    }
    acc
}