use crate::hash::Hash;
use crate::utils;
use anyhow::{Result, anyhow};
use ff::{Field, PrimeField};
use starkom_bluesky::Scalar;
use starkom_poly;
use std::marker::PhantomData;
use std::sync::LazyLock;
/// Polynomial type from `starkom_poly`, specialized to the `Scalar` field used in this file.
type Polynomial = starkom_poly::Polynomial<Scalar>;
/// Scalar derived from the label `starkom/fri/leaf` via `utils::hash_to_scalar`.
/// `hash_leaf` passes it as the first hash input (presumably a domain-separation tag).
static LEAF_DST: LazyLock<Scalar> = LazyLock::new(|| utils::hash_to_scalar(b"starkom/fri/leaf"));
/// Scalar derived from the label `starkom/fri/tree`; first argument of every
/// `H::hash_raw` call that combines two child hashes into a parent hash.
static TREE_DST: LazyLock<Scalar> = LazyLock::new(|| utils::hash_to_scalar(b"starkom/fri/tree"));
/// Scalar derived from the label `starkom/fri/fold`; first argument of the
/// `H::hash_raw` call that derives the folding coefficient from a Merkle root.
static FOLD_DST: LazyLock<Scalar> = LazyLock::new(|| utils::hash_to_scalar(b"starkom/fri/fold"));
/// Multiplicative inverse of `Scalar::MULTIPLICATIVE_GENERATOR`, computed once on first use.
/// The `unwrap` panics only if the generator has no inverse.
static GENERATOR_INV: LazyLock<Scalar> =
LazyLock::new(|| Scalar::MULTIPLICATIVE_GENERATOR.invert().unwrap());
/// Hashes the contents of one Merkle leaf with `H::hash_many`.
///
/// The hash input is, in order: `LEAF_DST`, the number of values in the leaf
/// (as a scalar), and then the values themselves.
fn hash_leaf<H: Hash<Scalar>>(values: &[Scalar]) -> Scalar {
    // Two header elements followed by the leaf payload.
    let mut preimage: Vec<Scalar> = Vec::with_capacity(values.len() + 2);
    preimage.push(*LEAF_DST);
    preimage.push(Scalar::from(values.len() as u64));
    preimage.extend(values.iter().cloned());
    H::hash_many(preimage.as_slice())
}
/// Fills in the inner nodes of a Merkle tree stored level by level in `values`.
///
/// On entry the first `n` entries of `values` hold the lowest level. Each pass
/// writes the `n / 2` parents of the current level directly after it and then
/// continues with those parents as the new level, until a single node remains.
/// The slice therefore needs room for at least `2 * n - 1` entries; the last of
/// those ends up holding the root.
///
/// # Panics
///
/// Panics if `n` is not a power of two or if `values` is too short.
pub fn merklify<H: Hash<Scalar>>(mut values: &mut [Scalar], mut n: usize) {
    assert!(n.is_power_of_two());
    while n > 1 {
        let parent_count = n / 2;
        // `children` is the current level; the parents live right behind it.
        let (children, rest) = values.split_at_mut(n);
        let parents = &mut rest[..parent_count];
        for (parent, pair) in parents.iter_mut().zip(children.chunks_exact(2)) {
            *parent = H::hash_raw(*TREE_DST, pair[0], pair[1]);
        }
        values = rest;
        n = parent_count;
    }
}
/// Ordered list of Merkle roots, one per committed tree.
///
/// `Prover::commit` fills it with the root of the main tree followed by the
/// roots of every folded tree; `Query::verify` reads entry `r` for round `r`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Commitment {
    roots: Vec<Scalar>,
}

impl Commitment {
    /// Number of roots stored in the commitment.
    pub fn len(&self) -> usize {
        let Self { roots } = self;
        roots.len()
    }

    /// All roots, in commitment order.
    pub fn roots(&self) -> &[Scalar] {
        &self.roots
    }

    /// The first root.
    ///
    /// # Panics
    ///
    /// Panics if the commitment holds no roots.
    pub fn root(&self) -> Scalar {
        self.roots.first().copied().unwrap()
    }
}
/// Merkle inclusion proof for a single leaf.
///
/// `leaf` holds the leaf's values and `path` the sibling hashes from the leaf
/// level up to (but excluding) the root. `H` selects the hash function.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LeafProof<H: Hash<Scalar>> {
    leaf: Vec<Scalar>,
    path: Vec<Scalar>,
    _data: PhantomData<H>,
}

impl<H: Hash<Scalar>> LeafProof<H> {
    /// The values stored in the proven leaf.
    pub fn leaf(&self) -> &[Scalar] {
        &self.leaf
    }

    /// Checks that the proven leaf holds exactly the `expected` values.
    ///
    /// # Errors
    ///
    /// Returns an error if the lengths differ or any value differs.
    pub fn check_leaf(&self, expected: &[Scalar]) -> Result<()> {
        if self.leaf.as_slice() == expected {
            Ok(())
        } else {
            Err(anyhow!("leaf value mismatch"))
        }
    }

    /// Number of sibling hashes in the path, i.e. the height of the proven tree.
    pub fn len(&self) -> usize {
        self.path.len()
    }

    /// Recomputes the root from the leaf and the sibling path and compares it
    /// with `root_hash`.
    ///
    /// `index` is the position of the leaf in the tree. Its bits, lowest first,
    /// decide at each level whether the running hash is the right child (bit
    /// set) or the left child (bit clear).
    ///
    /// # Errors
    ///
    /// Returns an error if `index` has bits left over once the path is consumed,
    /// or if the recomputed root differs from `root_hash`.
    pub fn verify(&self, mut index: usize, root_hash: Scalar) -> Result<()> {
        let mut node = hash_leaf::<H>(&self.leaf);
        for &sibling in self.path.iter() {
            let (lhs, rhs) = if index & 1 != 0 {
                (sibling, node)
            } else {
                (node, sibling)
            };
            node = H::hash_raw(*TREE_DST, lhs, rhs);
            index >>= 1;
        }
        if index != 0 {
            return Err(anyhow!("invalid index"));
        }
        if node != root_hash {
            return Err(anyhow!(
                "root hash mismatch (got {}, want {})",
                node,
                root_hash
            ));
        }
        Ok(())
    }

    /// Returns `true` when every sibling on the path equals the hash of a
    /// subtree built solely from copies of this leaf.
    ///
    /// That is the case exactly when all leaves under the path carry the same
    /// hash as the proven leaf.
    pub fn is_constant(&self) -> bool {
        let mut subtree = hash_leaf::<H>(&self.leaf);
        self.path.iter().all(|&sibling| {
            if sibling != subtree {
                return false;
            }
            // Parent of two identical subtrees.
            subtree = H::hash_raw(*TREE_DST, subtree, subtree);
            true
        })
    }
}
/// Merkle tree over leaves that each hold one value per committed polynomial.
///
/// `hashes` stores the tree level by level: the `n` leaf hashes first, then the
/// `n / 2` parents, and so on, `2 * n - 1` entries in total with the root last.
#[derive(Debug, Clone)]
pub struct Tree<H: Hash<Scalar>> {
    num_polys: usize,
    leaves: Vec<Vec<Scalar>>,
    hashes: Vec<Scalar>,
    _data: PhantomData<H>,
}

impl<H: Hash<Scalar>> Tree<H> {
    /// Builds a tree from ready-made leaves.
    ///
    /// # Panics
    ///
    /// Panics if `leaves` is empty, if its length is not a power of two, if the
    /// first leaf is empty, or if the leaves do not all have the same length.
    pub fn from_leaves(leaves: Vec<Vec<Scalar>>) -> Self {
        let num_polys = leaves[0].len();
        assert!(num_polys > 0);
        let leaf_count = leaves.len();
        assert!(leaf_count.is_power_of_two());
        let mut hashes = vec![Scalar::ZERO; leaf_count * 2 - 1];
        // The first `leaf_count` slots receive the leaf hashes.
        for (slot, leaf) in hashes.iter_mut().zip(leaves.iter()) {
            assert_eq!(leaf.len(), num_polys);
            *slot = hash_leaf::<H>(leaf.as_slice());
        }
        merklify::<H>(hashes.as_mut_slice(), leaf_count);
        Self {
            num_polys,
            leaves,
            hashes,
            _data: PhantomData,
        }
    }

    /// Builds a tree from per-polynomial value vectors.
    ///
    /// `values[j][i]` becomes entry `j` of leaf `i`, i.e. the input is transposed.
    ///
    /// # Panics
    ///
    /// Panics if `values` is empty, if the vectors differ in length, or if the
    /// resulting leaves are rejected by [`Tree::from_leaves`].
    pub fn new(values: Vec<Vec<Scalar>>) -> Self {
        let poly_count = values.len();
        assert!(poly_count > 0);
        let domain_size = values[0].len();
        let mut leaves: Vec<Vec<Scalar>> = Vec::with_capacity(domain_size);
        for row in 0..domain_size {
            let mut leaf = Vec::with_capacity(poly_count);
            for column in values.iter() {
                assert_eq!(domain_size, column.len());
                leaf.push(column[row]);
            }
            leaves.push(leaf);
        }
        Self::from_leaves(leaves)
    }

    /// Number of values stored in each leaf.
    pub fn num_polys(&self) -> usize {
        self.num_polys
    }

    /// Number of leaves in the tree.
    pub fn num_leaves(&self) -> usize {
        self.leaves.len()
    }

    /// Root hash, i.e. the last entry of the level-ordered hash array.
    pub fn root_hash(&self) -> Scalar {
        let leaf_count = self.leaves.len();
        self.hashes[(leaf_count - 1) * 2]
    }

    /// Values of the leaf at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of range.
    pub fn leaf(&self, index: usize) -> &[Scalar] {
        &self.leaves[index]
    }

    /// Produces the inclusion proof for the leaf at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not smaller than the number of leaves.
    pub fn query(&self, mut index: usize) -> LeafProof<H> {
        let leaf_count = self.leaves.len();
        assert!(leaf_count.is_power_of_two());
        assert!(index < leaf_count);
        let leaf = self.leaves[index].clone();
        let mut path = Vec::with_capacity(leaf_count.trailing_zeros() as usize);
        // Walk the levels bottom-up; `level_start` is where the current level
        // begins inside `hashes` and `level_size` how many nodes it has.
        let mut level_start = 0;
        let mut level_size = leaf_count;
        while level_size > 1 {
            path.push(self.hashes[level_start + (index ^ 1)]);
            level_start += level_size;
            level_size /= 2;
            index >>= 1;
        }
        LeafProof {
            leaf,
            path,
            _data: PhantomData,
        }
    }

    /// Folds the tree into one with half as many leaves.
    ///
    /// Leaf `i` of the result combines leaf `i` (`pos`) and leaf `i + n / 2`
    /// (`neg`) of this tree entry by entry as
    /// `(pos + neg + alpha * omega_inv^i * (pos - neg)) / 2`, where `alpha` is
    /// derived from this tree's root hash and `omega_inv` is
    /// `ROOT_OF_UNITY_INV` raised to `2^(S - log2(n))`.
    fn fold(&self) -> Self {
        let leaf_count = self.leaves.len();
        assert!(leaf_count.is_power_of_two());
        let alpha = H::hash_raw(*FOLD_DST, self.hashes[(leaf_count - 1) * 2], Scalar::ZERO)
            * *GENERATOR_INV;
        let log_size = leaf_count.trailing_zeros();
        let omega_inv =
            Scalar::ROOT_OF_UNITY_INV.pow_vartime([1u64 << (Scalar::S - log_size), 0, 0, 0]);
        let half = leaf_count / 2;
        let (front, back) = self.leaves.split_at(half);
        let mut twiddle = Scalar::ONE;
        let mut folded = Vec::with_capacity(half);
        for (pos_leaf, neg_leaf) in front.iter().zip(back.iter()) {
            let mut leaf = Vec::with_capacity(pos_leaf.len());
            for (&pos, &neg) in pos_leaf.iter().zip(neg_leaf.iter()) {
                leaf.push((pos + neg + alpha * twiddle * (pos - neg)) * Scalar::TWO_INV);
            }
            folded.push(leaf);
            twiddle *= omega_inv;
        }
        Self::from_leaves(folded)
    }

    /// Folds the tree `times` times and returns the original tree followed by
    /// every intermediate result, `times + 1` trees in total.
    fn fold_all(self, times: usize) -> Vec<Self> {
        let mut trees = Vec::with_capacity(times + 1);
        trees.push(self);
        for _ in 0..times {
            let next = trees.last().unwrap().fold();
            trees.push(next);
        }
        trees
    }
}
/// Opening of one position of a commitment across all folding rounds.
///
/// `n` is the size of the initial domain and `index` the queried position in
/// it. `folds[r]` holds the two leaf proofs for round `r`: the leaf at the
/// current position and the leaf half a domain away from it.
#[derive(Debug, Clone)]
pub struct Query<H: Hash<Scalar>> {
    n: usize,
    index: usize,
    folds: Vec<(LeafProof<H>, LeafProof<H>)>,
    _data: PhantomData<H>,
}

impl<H: Hash<Scalar>> Query<H> {
    /// The two positions opened in the first round: the queried index and the
    /// index half a domain away (modulo `n`).
    pub fn indices(&self) -> (usize, usize) {
        (self.index, (self.index + self.n / 2) % self.n)
    }

    /// Returns `Polynomial::coset_element2(index, n)`, presumably the domain
    /// element the queried index corresponds to.
    pub fn x(&self) -> Scalar {
        Polynomial::coset_element2(self.index, self.n)
    }

    /// Leaf values opened in the first round, as `(queried, opposite)`.
    ///
    /// # Panics
    ///
    /// Panics if the query holds no rounds.
    pub fn values(&self) -> (&[Scalar], &[Scalar]) {
        (self.folds[0].0.leaf(), self.folds[0].1.leaf())
    }

    /// Number of rounds contained in the query.
    pub fn len(&self) -> usize {
        self.folds.len()
    }

    /// Verifies the query against `commitment`.
    ///
    /// For every round the two Merkle proofs are checked against that round's
    /// root, the opened values are folded with the coefficient derived from the
    /// root, and the result must match the leaf opened in the next round.
    /// Finally both proofs of the last round must be constant
    /// (see `LeafProof::is_constant`).
    ///
    /// # Errors
    ///
    /// Returns an error if the proof is empty or has too many rounds, if its
    /// round count differs from the commitment's, if a Merkle proof has the
    /// wrong height or fails, if the two leaves of a round differ in size, if a
    /// folded value does not match the next round, or if the last round is not
    /// constant.
    ///
    /// # Panics
    ///
    /// Panics if `n` is not a power of two or `index` is not below `n`.
    pub fn verify(&self, commitment: &Commitment) -> Result<()> {
        assert!(self.n.is_power_of_two());
        assert!(self.index < self.n);
        let k = self.n.trailing_zeros();
        let folds = self.folds.as_slice();
        let h = folds.len();
        // An empty proof has no first or last round; reject it instead of
        // panicking on the lookups below.
        let (Some((first_left, _)), Some((last_left, last_right))) =
            (folds.first(), folds.last())
        else {
            return Err(anyhow!("invalid proof size"));
        };
        if h > k as usize + 1 {
            return Err(anyhow!("invalid proof size"));
        }
        if commitment.len() != h {
            return Err(anyhow!("wrong number of folding rounds"));
        }
        let mut m = self.n;
        let mut index = self.index;
        let mut pos = first_left.leaf().to_vec();
        let mut step = Scalar::ROOT_OF_UNITY_INV.pow_vartime([1u64 << (Scalar::S - k), 0, 0, 0]);
        for (fold, root) in folds.iter().zip(commitment.roots()) {
            let (left, right) = fold;
            let root_hash = *root;
            let alpha = H::hash_raw(*FOLD_DST, root_hash, Scalar::ZERO) * *GENERATOR_INV;
            let neg = right.leaf();
            // Compare heights directly: `1 << len` would overflow for an
            // oversized path supplied by an untrusted prover.
            let height = m.trailing_zeros() as usize;
            if left.len() != height {
                return Err(anyhow!(
                    "invalid left-hand side Merkle proof height (got {}, want {})",
                    left.len(),
                    m.trailing_zeros()
                ));
            }
            if right.len() != height {
                return Err(anyhow!(
                    "invalid right-hand side Merkle proof height (got {}, want {})",
                    right.len(),
                    m.trailing_zeros()
                ));
            }
            left.check_leaf(pos.as_slice())?;
            // Both leaves are combined entry by entry below, so they must agree
            // in size; a shorter right-hand leaf would otherwise index out of
            // bounds.
            if neg.len() != pos.len() {
                return Err(anyhow!("leaf size mismatch"));
            }
            left.verify(index, root_hash)?;
            right.verify((index + m / 2) % m, root_hash)?;
            let omega_inv_i = step.pow_vartime([index as u64, 0, 0, 0]);
            m /= 2;
            // After a round over a single-leaf tree `m` is 0; there is no
            // further round, so skip the reduction rather than divide by zero.
            if m > 0 {
                index %= m;
            }
            for (value, &other) in pos.iter_mut().zip(neg.iter()) {
                *value =
                    (*value + other + alpha * omega_inv_i * (*value - other)) * Scalar::TWO_INV;
            }
            step = step.square();
        }
        if !last_left.is_constant() || !last_right.is_constant() {
            return Err(anyhow!("final folded polynomial is not constant"));
        }
        Ok(())
    }
}
/// Prover side of the commitment: the Merkle tree over the extended
/// evaluations of the input polynomials plus every tree obtained by folding it.
#[derive(Debug, Clone)]
pub struct Prover<H: Hash<Scalar>> {
    degree_bound: usize,
    blowup_log2: usize,
    trees: Vec<Tree<H>>,
}

impl<H: Hash<Scalar>> Prover<H> {
    /// Commits to `polynomials`.
    ///
    /// Each polynomial is evaluated with `shifted_lde2` on a domain of size
    /// `degree_bound << blowup_log2`; the resulting tree is then folded
    /// `log2(degree_bound)` times.
    ///
    /// # Panics
    ///
    /// Panics if `degree_bound` is not a power of two, if any polynomial's
    /// degree bound exceeds `degree_bound`, if the domain size exceeds
    /// `2^Scalar::S`, or if `polynomials` is empty.
    pub fn new(polynomials: Vec<Polynomial>, degree_bound: usize, blowup_log2: usize) -> Self {
        assert!(degree_bound.is_power_of_two());
        assert!(
            polynomials
                .iter()
                .all(|polynomial| degree_bound >= polynomial.degree_bound())
        );
        let n = degree_bound << blowup_log2;
        assert!(n as u64 <= 1u64 << Scalar::S);
        let main_tree = Tree::<H>::new(
            polynomials
                .into_iter()
                .map(|polynomial| polynomial.shifted_lde2(n))
                .collect(),
        );
        let trees = main_tree.fold_all(degree_bound.trailing_zeros() as usize);
        Self {
            degree_bound,
            blowup_log2,
            trees,
        }
    }

    /// The degree bound passed to [`Prover::new`].
    pub fn degree_bound(&self) -> usize {
        self.degree_bound
    }

    /// Size of the evaluation domain, `degree_bound << blowup_log2`.
    pub fn extended_domain_size(&self) -> usize {
        self.degree_bound << self.blowup_log2
    }

    /// Same value as [`Prover::extended_domain_size`].
    pub fn size(&self) -> usize {
        self.extended_domain_size()
    }

    /// Root hash of the main (unfolded) tree.
    pub fn root_hash(&self) -> Scalar {
        self.trees[0].root_hash()
    }

    /// Collects the root hashes of the main tree and all folded trees, in order.
    pub fn commit(&self) -> Commitment {
        Commitment {
            roots: self.trees.iter().map(|tree| tree.root_hash()).collect(),
        }
    }

    /// Opens position `index` in every tree.
    ///
    /// For each tree the leaf at the current position and the leaf half a tree
    /// away are opened; the position is then reduced modulo the size of the
    /// next, half-sized tree.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not below the extended domain size, or if the
    /// proofs for the last tree are not constant.
    pub fn query(&self, index: usize) -> Query<H> {
        let d = self.degree_bound;
        assert!(d.is_power_of_two());
        let n = self.extended_domain_size();
        assert!(index < n);
        let mut m = n;
        let mut i = index;
        let mut folds = Vec::with_capacity(self.trees.len());
        for tree in &self.trees {
            folds.push((tree.query(i), tree.query((i + m / 2) % m)));
            m /= 2;
            // With `blowup_log2 == 0` the last tree has a single leaf and `m`
            // drops to 0 here; nothing follows, so do not reduce modulo zero.
            if m > 0 {
                i %= m;
            }
        }
        {
            let (left, right) = folds.last().unwrap();
            assert!(left.is_constant());
            assert!(right.is_constant());
        }
        Query {
            n,
            index,
            folds,
            _data: Default::default(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hash;
    use crate::utils::parse_scalar;

    type Poseidon2Hash = hash::Poseidon2Hash<Scalar>;
    type Sha2Hash = hash::Sha2Hash<Scalar>;

    /// Converts small integers into field elements.
    fn scalars(raw: &[u64]) -> Vec<Scalar> {
        raw.iter().map(|&value| Scalar::from(value)).collect()
    }

    /// Builds a polynomial from small integer coefficients.
    fn poly(coefficients: &[u64]) -> Polynomial {
        Polynomial::with_coefficients(scalars(coefficients))
    }

    /// Lays out `leaves` followed by zeroed parent slots and runs `merklify` on them.
    fn merklified<H: Hash<Scalar>>(leaves: &[u64]) -> Vec<Scalar> {
        let n = leaves.len();
        let mut values = scalars(leaves);
        values.resize(2 * n - 1, Scalar::ZERO);
        merklify::<H>(&mut values, n);
        values
    }

    // A single node is its own root: nothing is hashed.
    #[test]
    fn test_merklify_one_sha2() {
        assert_eq!(merklified::<Sha2Hash>(&[12]), scalars(&[12]));
    }

    #[test]
    fn test_merklify_one_poseidon2() {
        assert_eq!(merklified::<Poseidon2Hash>(&[12]), scalars(&[12]));
    }

    #[test]
    fn test_merklify_two_sha2() {
        let mut expected = scalars(&[34, 56]);
        expected.push(parse_scalar(
            "0x4e92e96500d26aa6e159670815c01b01c89f3385627027e52b20c3be995d9cb4",
        ));
        assert_eq!(merklified::<Sha2Hash>(&[34, 56]), expected);
    }

    #[test]
    fn test_merklify_two_poseidon2() {
        let mut expected = scalars(&[34, 56]);
        expected.push(parse_scalar(
            "0x460d694c3fc49199a27c631df8a837d5b64566c40075981ff5cb0396cf52a80b",
        ));
        assert_eq!(merklified::<Poseidon2Hash>(&[34, 56]), expected);
    }

    #[test]
    fn test_merklify_four_sha2() {
        let mut expected = scalars(&[78, 90, 12, 34]);
        expected.extend([
            parse_scalar("0x4ba5bb5405a8d200b4c1b2fe1240daa6be892eb58048020e0a03f5fb6e009dec"),
            parse_scalar("0x1f4cbe6657a61b9852cb8c219f5bf3a42d6404902560ef5dd14f91a414fff307"),
            parse_scalar("0x58d1fe70cb37a8e745391c570e3cda9e0c24e74464fb5119a29d01f2b64af357"),
        ]);
        assert_eq!(merklified::<Sha2Hash>(&[78, 90, 12, 34]), expected);
    }

    #[test]
    fn test_merklify_four_poseidon2() {
        let mut expected = scalars(&[78, 90, 12, 34]);
        expected.extend([
            parse_scalar("0x09c10aba0c59772b51adb65ac6780471b94bf18f63aa121901fb3a428f171064"),
            parse_scalar("0x2786758795737449218651c7a13e09e40159eb361293bbd7526c26a110f4b733"),
            parse_scalar("0x183ba165b9bd525fddf2be1420f9087f9ebfbf0bedbdb9e3bf4ec7a785325b13"),
        ]);
        assert_eq!(merklified::<Poseidon2Hash>(&[78, 90, 12, 34]), expected);
    }

    /// Builds a tree from `leaves`, checks its shape and root, and verifies an
    /// inclusion proof for every leaf.
    fn test_merkle_tree<H: Hash<Scalar>>(leaves: Vec<Vec<Scalar>>, expected_root_hash: Scalar) {
        let tree = Tree::<H>::from_leaves(leaves.clone());
        assert_eq!(tree.num_polys(), leaves[0].len());
        assert_eq!(tree.num_leaves(), leaves.len());
        assert_eq!(tree.root_hash(), expected_root_hash);
        for (index, leaf) in leaves.iter().enumerate() {
            let proof = tree.query(index);
            assert!(proof.verify(index, expected_root_hash).is_ok());
            assert_eq!(proof.leaf().len(), leaf.len());
            assert!(
                proof
                    .leaf()
                    .iter()
                    .zip(leaf.iter())
                    .all(|(got, want)| got == want)
            );
        }
    }

    #[test]
    fn test_merkle_tree_one_leaf_1() {
        let leaves = vec![scalars(&[12])];
        test_merkle_tree::<Sha2Hash>(
            leaves.clone(),
            parse_scalar("0x563171d1d29fc71a8e64c1996982ba9391b948c0f8e53c06f49dd50a935195bd"),
        );
        test_merkle_tree::<Poseidon2Hash>(
            leaves,
            parse_scalar("0x2cdc51a32dac2ed86403822d494776d4512920a3790544b4be3ebf2cbde92171"),
        );
    }

    #[test]
    fn test_merkle_tree_one_leaf_2() {
        let leaves = vec![scalars(&[34])];
        test_merkle_tree::<Sha2Hash>(
            leaves.clone(),
            parse_scalar("0x0a6461fb4b46a4cbf7855d0f8b2221b476c8fa54d510d03c5e0f1b7add3720d6"),
        );
        test_merkle_tree::<Poseidon2Hash>(
            leaves,
            parse_scalar("0x7bd38f78c7b116426eb1c3ce88929882f997ba95fd38128105171586b32a8db0"),
        );
    }

    #[test]
    fn test_merkle_tree_one_leaf_two_polynomials_1() {
        let leaves = vec![scalars(&[12, 34])];
        test_merkle_tree::<Sha2Hash>(
            leaves.clone(),
            parse_scalar("0x12bca773e3d548e97bc3c09698887d8aa79a2a224741aca93ac1f748bf9d0a76"),
        );
        test_merkle_tree::<Poseidon2Hash>(
            leaves,
            parse_scalar("0x7266dbf17f81908d1abfcc37f1ac92cbdbbc0ddf8ed65dd8c56c6a7d4d6d23cf"),
        );
    }

    #[test]
    fn test_merkle_tree_one_leaf_two_polynomials_2() {
        let leaves = vec![scalars(&[34, 12])];
        test_merkle_tree::<Sha2Hash>(
            leaves.clone(),
            parse_scalar("0x54ee34331f32bd339abb4fc82eb2779e696b593bf4c186ca2e993eb0cd3711c3"),
        );
        test_merkle_tree::<Poseidon2Hash>(
            leaves,
            parse_scalar("0x2b81b71d5988e82d27fb48f0b4b2c8f8ff3ba99751f2449997d840d7518dc11b"),
        );
    }

    #[test]
    fn test_merkle_tree_one_leaf_three_polynomials_1() {
        let leaves = vec![scalars(&[12, 34, 56])];
        test_merkle_tree::<Sha2Hash>(
            leaves.clone(),
            parse_scalar("0x285ebc787db855722846ffd14565aa60215c39953648ea0555d493d6d998c634"),
        );
        test_merkle_tree::<Poseidon2Hash>(
            leaves,
            parse_scalar("0x0c3b7d987cb1e3d95e6e0f95fcc37f64ebe76006c7d060cc88d2f6e77ac8ee9c"),
        );
    }

    #[test]
    fn test_merkle_tree_one_leaf_three_polynomials_2() {
        let leaves = vec![scalars(&[34, 12, 78])];
        test_merkle_tree::<Sha2Hash>(
            leaves.clone(),
            parse_scalar("0x54b1edfda64b33dacfd52f04514907c6692cceb22c85737851b192e1b6fc230e"),
        );
        test_merkle_tree::<Poseidon2Hash>(
            leaves,
            parse_scalar("0x06e1ad1622dcc06212ca8b23ca3fd56cdc4d8b065d987412a9cb9fdf1d0e155d"),
        );
    }

    #[test]
    fn test_merkle_tree_two_leaves_1() {
        let leaves = vec![scalars(&[12]), scalars(&[34])];
        test_merkle_tree::<Sha2Hash>(
            leaves.clone(),
            parse_scalar("0x20e65b4345db52cd8249ed9c1797f859c4f3dff7c5d9374eb4b89118cd39b643"),
        );
        test_merkle_tree::<Poseidon2Hash>(
            leaves,
            parse_scalar("0x2f0c2ee238a5c8f3f9380fa9cdd59d4c1774ef7659554bf37d2e40b1bfda0f3d"),
        );
    }

    #[test]
    fn test_merkle_tree_two_leaves_2() {
        let leaves = vec![scalars(&[34]), scalars(&[56])];
        test_merkle_tree::<Sha2Hash>(
            leaves.clone(),
            parse_scalar("0x1cc4e046101296f69bed2fc83482ce4056cd50d36b18cb4b08920225144bcaa6"),
        );
        test_merkle_tree::<Poseidon2Hash>(
            leaves,
            parse_scalar("0x14ff951575b5892afaf39760090ee44f2de980a45528a69700570aa0321338ab"),
        );
    }

    #[test]
    fn test_merkle_tree_two_leaves_two_polynomials_1() {
        let leaves = vec![scalars(&[12, 34]), scalars(&[56, 78])];
        test_merkle_tree::<Sha2Hash>(
            leaves.clone(),
            parse_scalar("0x65a2f27eccdf81249652273e3df595841ac0d7398b1a866fce9bd6fe3c891dbc"),
        );
        test_merkle_tree::<Poseidon2Hash>(
            leaves,
            parse_scalar("0x3233da25a14a69d02937ebb8f7ca4831e1aa53a069c14bdbddb138d0488b3827"),
        );
    }

    #[test]
    fn test_merkle_tree_two_leaves_two_polynomials_2() {
        let leaves = vec![scalars(&[78, 56]), scalars(&[34, 12])];
        test_merkle_tree::<Sha2Hash>(
            leaves.clone(),
            parse_scalar("0x40383f40f001699d6bec77a7ef72289ed6461d28659b94c1156d3d8cad226141"),
        );
        test_merkle_tree::<Poseidon2Hash>(
            leaves,
            parse_scalar("0x73b93dac865870e7515af085294a34ebf1bb2f1d86faf619013a9855df6caebb"),
        );
    }

    /// Commits to `polynomials` and checks that a query for every position of
    /// the extended domain has the expected shape and verifies.
    fn test_prover_impl<H: Hash<Scalar>>(
        polynomials: Vec<Polynomial>,
        degree_bound: usize,
        blowup_log2: usize,
    ) {
        let prover = Prover::<H>::new(polynomials, degree_bound, blowup_log2);
        assert_eq!(prover.degree_bound(), degree_bound);
        let domain_size = degree_bound << blowup_log2;
        assert_eq!(prover.extended_domain_size(), domain_size);
        let commitment = prover.commit();
        for position in 0..domain_size {
            let query = prover.query(position);
            assert_eq!(
                query.indices(),
                (position, (position + domain_size / 2) % domain_size)
            );
            assert_eq!(query.len(), degree_bound.trailing_zeros() as usize + 1);
            assert!(query.verify(&commitment).is_ok());
        }
    }

    /// Runs `test_prover_impl` with both hash functions for blowup exponents 1, 2 and 3.
    fn test_prover(polynomials: Vec<Polynomial>, degree_bound: usize) {
        for blowup_log2 in 1..=3 {
            test_prover_impl::<Sha2Hash>(polynomials.clone(), degree_bound, blowup_log2);
            test_prover_impl::<Poseidon2Hash>(polynomials.clone(), degree_bound, blowup_log2);
        }
    }

    #[test]
    fn test_one_constant_polynomial() {
        test_prover(vec![poly(&[12])], 1);
        test_prover(vec![poly(&[34])], 1);
    }

    #[test]
    fn test_two_constant_polynomials() {
        test_prover(vec![poly(&[12]), poly(&[34])], 1);
    }

    #[test]
    fn test_three_constant_polynomials() {
        test_prover(vec![poly(&[34]), poly(&[56]), poly(&[78])], 1);
    }

    #[test]
    fn test_one_polynomial_degree_one() {
        test_prover(vec![poly(&[12, 34])], 2);
        test_prover(vec![poly(&[56, 78])], 2);
    }

    #[test]
    fn test_two_polynomials_degree_one() {
        test_prover(vec![poly(&[12, 34]), poly(&[56, 78])], 2);
    }

    #[test]
    fn test_three_polynomials_degree_one() {
        test_prover(
            vec![poly(&[34, 56]), poly(&[56, 78]), poly(&[78, 90])],
            2,
        );
    }

    #[test]
    fn test_one_polynomial_degree_three() {
        test_prover(vec![poly(&[12, 34, 56, 78])], 4);
        test_prover(vec![poly(&[42, 43, 44, 45])], 4);
    }

    #[test]
    fn test_two_polynomials_degree_three() {
        test_prover(
            vec![poly(&[12, 34, 56, 78]), poly(&[42, 43, 44, 45])],
            4,
        );
    }

    #[test]
    fn test_three_polynomials_degree_three() {
        test_prover(
            vec![
                poly(&[42, 43, 44, 45]),
                poly(&[12, 34, 56, 78]),
                poly(&[34, 56, 78, 90]),
            ],
            4,
        );
    }
}