use crate::{field::Field, prng::derive_many_mod_p};
use crate::{MultilinearPolynomial, StreamingPolynomial, Transcript};
#[cfg(not(target_arch = "wasm32"))]
use rayon::prelude::*;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Demo polynomial `f(x1, x2) = x1 + x2 + 2·x1·x2` evaluated in `field`.
///
/// Both inputs are reduced modulo the field's modulus before use.
pub fn f_demo(field: &Field, x1: u64, x2: u64) -> u64 {
    let modulus = field.modulus();
    let (a, b) = (x1 % modulus, x2 % modulus);
    let linear_part = field.add(a, b);
    let cross_part = field.mul(2 % modulus, field.mul(a, b));
    field.add(linear_part, cross_part)
}
/// Brute-force sum of [`f_demo`] over the four points of `{0,1}^2`, reduced in `field`.
///
/// Points are visited with `x1` as the outer variable and `x2` as the inner one.
pub fn true_sum_demo(field: &Field) -> u64 {
    (0..2u64)
        .flat_map(|x1| (0..2u64).map(move |x2| (x1, x2)))
        .fold(0, |total, (x1, x2)| field.add(total, f_demo(field, x1, x2)))
}
/// Claim produced by the two-variable sum-check demo (`SumClaim::prove_demo`),
/// checked by `SumClaim::verify_demo`.
///
/// Both round polynomials are stored in affine form `g(X) = a·X + b`.
#[derive(Debug, Clone)]
pub struct SumClaim {
    /// Modulus of the field the demo runs in.
    pub p: u64,
    /// Claimed sum of `f_demo` over `{0,1}^2`.
    pub claimed_sum: u64,
    /// Slope of the first-round polynomial `g1(X) = g1_a·X + g1_b`.
    pub g1_a: u64,
    /// Constant term of `g1`, i.e. `g1(0)`.
    pub g1_b: u64,
    /// Slope of the second-round polynomial `g2(X) = g2_a·X + g2_b`.
    pub g2_a: u64,
    /// Constant term of `g2`, i.e. `g2(0)`.
    pub g2_b: u64,
    /// Number of derived `r2` points at which `verify_demo` compares `g2` with
    /// `f_demo(r1, ·)`; also absorbed into the challenge derivation.
    pub k: usize,
}
impl SumClaim {
    /// Derives the first-round challenge `r1` from the data fixed before round two.
    ///
    /// The two zero words sit where `g2_a`/`g2_b` appear in the second-round
    /// transcript: `r1` is drawn before `g2` has been computed.
    fn derive_r1(p: u64, claimed_sum: u64, g1_a: u64, g1_b: u64, k: usize) -> u64 {
        let base_transcript = [p, claimed_sum, g1_a, g1_b, 0u64, 0u64, k as u64];
        derive_many_mod_p(p, b"power_house:v1:sumcheck:r1", &base_transcript, 1)[0]
    }

    /// Honestly proves the hypercube sum of [`f_demo`] over `{0,1}^2`.
    ///
    /// `k` is stored in the claim and sets how many second-round spot checks
    /// [`SumClaim::verify_demo`] performs. Use `k >= 1`; claims with `k == 0` are
    /// rejected by the verifier.
    pub fn prove_demo(field: &Field, k: usize) -> Self {
        let p = field.modulus();
        let s = true_sum_demo(field);
        // Round 1: g1(X) = f(X, 0) + f(X, 1). Field addition is used instead of
        // `wrapping_add(..) % p`, which yields the wrong residue whenever the raw
        // sum of two residues overflows u64 (moduli above 2^63).
        let g1_0 = field.add(f_demo(field, 0, 0), f_demo(field, 0, 1));
        let g1_1 = field.add(f_demo(field, 1, 0), f_demo(field, 1, 1));
        let g1_a = field.sub(g1_1, g1_0);
        let g1_b = g1_0;
        let r1 = Self::derive_r1(p, s, g1_a, g1_b, k);
        // Round 2: g2(X) = f(r1, X).
        let g2_0 = f_demo(field, r1, 0);
        let g2_1 = f_demo(field, r1, 1);
        SumClaim {
            p,
            claimed_sum: s,
            g1_a,
            g1_b,
            g2_a: field.sub(g2_1, g2_0),
            g2_b: g2_0,
            k,
        }
    }

    /// Verifies the demo claim.
    ///
    /// Checks that `g1(0) + g1(1)` equals the claimed sum and that `g2(0) + g2(1)`
    /// equals `g1(r1)`. It then checks that `g2(r2) == f_demo(r1, r2)` for each of
    /// the `k` derived points `r2`.
    ///
    /// Returns `false` when `k == 0`. Without at least one spot check nothing ties
    /// `g2`, and hence the claimed sum, to `f_demo`, so any self-consistent forgery
    /// would pass.
    pub fn verify_demo(&self) -> bool {
        if self.k == 0 {
            return false;
        }
        let field = Field::new(self.p);
        let g1_0 = self.g1_b;
        let g1_1 = field.add(self.g1_a, self.g1_b);
        if field.add(g1_0, g1_1) != self.claimed_sum {
            return false;
        }
        let r1 = Self::derive_r1(self.p, self.claimed_sum, self.g1_a, self.g1_b, self.k);
        let s1 = field.add(field.mul(self.g1_a, r1), self.g1_b);
        let g2_0 = self.g2_b;
        let g2_1 = field.add(self.g2_a, self.g2_b);
        if field.add(g2_0, g2_1) != s1 {
            return false;
        }
        // The r2 challenges bind every transmitted coefficient.
        let transcript = [
            self.p,
            self.claimed_sum,
            self.g1_a,
            self.g1_b,
            self.g2_a,
            self.g2_b,
            self.k as u64,
        ];
        let r2s = derive_many_mod_p(self.p, b"power_house:v1:sumcheck:r2", &transcript, self.k);
        r2s.iter().all(|&r2| {
            let left = field.add(field.mul(self.g2_a, r2), self.g2_b);
            left == f_demo(&field, r1, r2)
        })
    }
}
/// Domain separator for the Fiat–Shamir `Transcript` used by the general
/// sum-check provers and verifiers below.
pub(crate) const GENERAL_SUMCHECK_DOMAIN: &[u8] = b"power_house:v2:sumcheck";
/// Domain separator passed to `derive_many_mod_p` when deriving the seeded
/// affine polynomial's constant and coefficients from a seed.
const SEEDED_AFFINE_DOMAIN: &[u8] = b"power_house:v1:seeded-affine";
/// Public part of a general sum-check proof: what is claimed, plus the
/// per-round polynomials the verifier checks.
#[derive(Debug, Clone)]
pub struct GeneralSumClaim {
    /// Field modulus; verifiers reject the claim if it differs from their field.
    pub p: u64,
    /// Number of variables of the polynomial whose hypercube sum is claimed.
    pub num_vars: usize,
    /// Claimed sum over `{0,1}^num_vars`, reduced modulo `p`.
    pub claimed_sum: u64,
    /// One `(a, b)` pair per round, representing `g(X) = a·X + b`. The verifiers
    /// require `g(0) + g(1)` to equal the running claim of that round.
    pub rounds: Vec<(u64, u64)>,
}
/// Values a verifier recomputes while checking a `GeneralSumClaim`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneralSumTrace {
    /// Fiat–Shamir challenge drawn after each round, in round order.
    pub challenges: Vec<u64>,
    /// Running claim at the start of each round (the first entry is the claimed sum).
    pub round_sums: Vec<u64>,
    /// Value the final running claim was checked against.
    pub final_evaluation: u64,
}
/// A `GeneralSumClaim` together with the prover's record of the same values a
/// verifier recomputes (see `GeneralSumTrace`).
#[derive(Debug, Clone)]
pub struct GeneralSumProof {
    /// The public claim and round polynomials.
    pub claim: GeneralSumClaim,
    /// Challenges the prover derived, one per round.
    pub challenges: Vec<u64>,
    /// Running sum at the start of each round, as computed by the prover.
    pub round_sums: Vec<u64>,
    /// Prover's value of the polynomial at the challenge point.
    pub final_evaluation: u64,
}
/// Wall-clock timings recorded by the `*_with_stats` provers.
#[derive(Debug, Clone)]
pub struct ProofStats {
    /// Elapsed time from a start point taken after the prover's initial setup
    /// until the proof is assembled. The exact start point differs between the
    /// dense and streaming provers.
    pub total_duration: Duration,
    /// One duration per sum-check round, in round order.
    pub round_durations: Vec<Duration>,
}
/// One proof in a `ChainedSumProof`.
#[derive(Debug, Clone)]
pub struct ChainLink {
    /// Final evaluation of the previous link's proof, or `None` for the first link.
    /// Verification requires this link's claimed sum to equal it.
    pub parent_final: Option<u64>,
    /// Sum-check proof for this link's polynomial.
    pub proof: GeneralSumProof,
}
/// Ordered sequence of sum-check proofs in which each link's claimed sum must
/// equal the previous link's final evaluation.
#[derive(Debug, Clone)]
pub struct ChainedSumProof {
    /// Links in proving order.
    links: Vec<ChainLink>,
}
impl GeneralSumClaim {
    /// Runs the dense prover on `poly` and returns only the public claim.
    pub fn prove(poly: &MultilinearPolynomial, field: &Field) -> Self {
        let proof = GeneralSumProof::prove(poly, field);
        proof.claim
    }

    /// Runs the streaming prover over `evaluator` (hypercube index -> value) and
    /// returns only the public claim.
    pub fn prove_streaming<F>(num_vars: usize, field: &Field, evaluator: F) -> Self
    where
        F: Fn(usize) -> u64 + Send + Sync + 'static,
    {
        let proof = GeneralSumProof::prove_streaming(num_vars, field, evaluator);
        proof.claim
    }

    /// Runs the streaming prover over a [`StreamingPolynomial`] and returns the claim.
    pub fn prove_streaming_poly(poly: &StreamingPolynomial, field: &Field) -> Self {
        let proof = GeneralSumProof::prove_streaming_poly(poly, field);
        proof.claim
    }

    /// Claim for the constant polynomial `value` over `num_vars` variables.
    pub fn prove_constant(num_vars: usize, field: &Field, value: u64) -> Self {
        let proof = GeneralSumProof::prove_constant(num_vars, field, value);
        proof.claim
    }

    /// Claim for the affine polynomial whose parameters are derived from `seed`.
    pub fn prove_seeded_affine(num_vars: usize, field: &Field, seed: &[u8]) -> Self {
        let proof = GeneralSumProof::prove_seeded_affine(num_vars, field, seed);
        proof.claim
    }

    /// Like [`GeneralSumClaim::prove`], but returns the whole proof, including the
    /// prover's challenges, running sums and final evaluation.
    pub fn prove_with_trace(poly: &MultilinearPolynomial, field: &Field) -> GeneralSumProof {
        GeneralSumProof::prove(poly, field)
    }

    /// `true` iff the claim verifies against the dense polynomial `poly`.
    pub fn verify(&self, poly: &MultilinearPolynomial, field: &Field) -> bool {
        let outcome = Self::verify_with_trace(self, poly, field);
        outcome.is_some()
    }

    /// `true` iff the claim verifies against the streaming polynomial `poly`.
    pub fn verify_streaming(&self, poly: &StreamingPolynomial, field: &Field) -> bool {
        let outcome = Self::verify_streaming_with_trace(self, poly, field);
        outcome.is_some()
    }

    /// Verifies against `poly` and returns the recomputed trace on success.
    pub fn verify_with_trace(
        &self,
        poly: &MultilinearPolynomial,
        field: &Field,
    ) -> Option<GeneralSumTrace> {
        let claim = self;
        verify_general_sum(claim, poly, field)
    }

    /// Verifies against a streaming `poly` and returns the recomputed trace on success.
    pub fn verify_streaming_with_trace(
        &self,
        poly: &StreamingPolynomial,
        field: &Field,
    ) -> Option<GeneralSumTrace> {
        let claim = self;
        verify_general_sum_streaming(claim, poly, field)
    }

    /// Verifies a constant-polynomial claim for `value`, returning the trace on success.
    pub fn verify_constant_with_trace(&self, field: &Field, value: u64) -> Option<GeneralSumTrace> {
        let claim = self;
        verify_constant_sum(claim, field, value)
    }

    /// `true` iff the claim is a valid constant-polynomial claim for `value`.
    pub fn verify_constant(&self, field: &Field, value: u64) -> bool {
        let outcome = Self::verify_constant_with_trace(self, field, value);
        outcome.is_some()
    }

    /// Verifies a seeded-affine claim for `seed`, returning the trace on success.
    pub fn verify_seeded_affine_with_trace(
        &self,
        field: &Field,
        seed: &[u8],
    ) -> Option<GeneralSumTrace> {
        let claim = self;
        verify_seeded_affine_sum(claim, field, seed)
    }

    /// `true` iff the claim is a valid seeded-affine claim for `seed`.
    pub fn verify_seeded_affine(&self, field: &Field, seed: &[u8]) -> bool {
        let outcome = Self::verify_seeded_affine_with_trace(self, field, seed);
        outcome.is_some()
    }
}
impl GeneralSumProof {
pub fn prove(poly: &MultilinearPolynomial, field: &Field) -> Self {
Self::prove_with_stats(poly, field).0
}
pub fn prove_streaming<F>(num_vars: usize, field: &Field, evaluator: F) -> Self
where
F: Fn(usize) -> u64 + Send + Sync + 'static,
{
Self::prove_streaming_with_stats(num_vars, field, evaluator).0
}
pub fn prove_streaming_poly(poly: &StreamingPolynomial, field: &Field) -> Self {
Self::prove_streaming_with_stats_poly(poly, field).0
}
pub fn prove_constant(num_vars: usize, field: &Field, value: u64) -> Self {
prove_constant_inner(num_vars, field, value)
}
pub fn prove_seeded_affine(num_vars: usize, field: &Field, seed: &[u8]) -> Self {
prove_seeded_affine_inner(num_vars, field, seed)
}
pub fn prove_with_stats(poly: &MultilinearPolynomial, field: &Field) -> (Self, ProofStats) {
let p = field.modulus();
let num_vars = poly.num_vars();
let mut layer = poly.evaluations_mod_p(field);
let claimed_sum = poly.sum_over_hypercube(field);
let mut transcript = Transcript::new(GENERAL_SUMCHECK_DOMAIN);
transcript.append(p);
transcript.append(num_vars as u64);
transcript.append(claimed_sum);
let total_start = Instant::now();
let mut rounds = Vec::with_capacity(num_vars);
let mut challenges = Vec::with_capacity(num_vars);
let mut round_sums = Vec::with_capacity(num_vars);
let mut round_durations = Vec::with_capacity(num_vars);
let mut running_sum = claimed_sum;
for _ in 0..num_vars {
round_sums.push(running_sum);
let round_start = Instant::now();
let mut g0_sum = 0u64;
let mut g1_sum = 0u64;
for chunk in layer.chunks(2) {
let v0 = chunk[0];
let v1 = chunk[1];
g0_sum = field.add(g0_sum, v0);
g1_sum = field.add(g1_sum, v1);
}
let a = field.sub(g1_sum, g0_sum);
let b = g0_sum;
rounds.push((a, b));
transcript.append(a);
transcript.append(b);
let r = transcript.challenge(field);
challenges.push(r);
let mut next_layer = Vec::with_capacity(layer.len() / 2);
let mut next_sum = 0u64;
for chunk in layer.chunks(2) {
let v0 = chunk[0];
let v1 = chunk[1];
let diff = field.sub(v1, v0);
let eval = field.add(field.mul(diff, r), v0);
next_sum = field.add(next_sum, eval);
next_layer.push(eval);
}
layer = next_layer;
running_sum = next_sum;
round_durations.push(round_start.elapsed());
}
assert_eq!(
layer.len(),
1,
"folding a multilinear polynomial must end with a single value"
);
let final_evaluation = layer[0];
let claim = GeneralSumClaim {
p,
num_vars,
claimed_sum,
rounds,
};
let proof = GeneralSumProof {
claim,
challenges: challenges.clone(),
round_sums: round_sums.clone(),
final_evaluation,
};
let stats = ProofStats {
total_duration: total_start.elapsed(),
round_durations,
};
(proof, stats)
}
pub fn prove_streaming_with_stats_poly(
poly: &StreamingPolynomial,
field: &Field,
) -> (Self, ProofStats) {
assert_eq!(poly.modulus(), field.modulus(), "field mismatch");
prove_streaming_with_stats_inner(poly.num_vars(), field, poly.evaluator())
}
pub fn prove_streaming_with_stats<F>(
num_vars: usize,
field: &Field,
evaluator: F,
) -> (Self, ProofStats)
where
F: Fn(usize) -> u64 + Send + Sync + 'static,
{
let eval: Arc<dyn Fn(usize) -> u64 + Send + Sync> = Arc::new(evaluator);
prove_streaming_with_stats_inner(num_vars, field, eval)
}
pub fn verify(&self, poly: &MultilinearPolynomial, field: &Field) -> bool {
self.verify_with_trace(poly, field).is_some()
}
pub fn verify_with_trace(
&self,
poly: &MultilinearPolynomial,
field: &Field,
) -> Option<GeneralSumTrace> {
let trace = self.claim.verify_with_trace(poly, field)?;
if trace.challenges != self.challenges
|| trace.round_sums != self.round_sums
|| trace.final_evaluation != self.final_evaluation
{
return None;
}
Some(trace)
}
pub fn verify_streaming(&self, poly: &StreamingPolynomial, field: &Field) -> bool {
self.verify_streaming_with_trace(poly, field).is_some()
}
pub fn verify_streaming_with_trace(
&self,
poly: &StreamingPolynomial,
field: &Field,
) -> Option<GeneralSumTrace> {
let trace = self.claim.verify_streaming_with_trace(poly, field)?;
if trace.challenges != self.challenges
|| trace.round_sums != self.round_sums
|| trace.final_evaluation != self.final_evaluation
{
return None;
}
Some(trace)
}
pub fn verify_constant_with_trace(&self, field: &Field, value: u64) -> Option<GeneralSumTrace> {
let trace = self.claim.verify_constant_with_trace(field, value)?;
if trace.challenges != self.challenges
|| trace.round_sums != self.round_sums
|| trace.final_evaluation != self.final_evaluation
{
return None;
}
Some(trace)
}
pub fn verify_constant(&self, field: &Field, value: u64) -> bool {
self.verify_constant_with_trace(field, value).is_some()
}
pub fn verify_seeded_affine_with_trace(
&self,
field: &Field,
seed: &[u8],
) -> Option<GeneralSumTrace> {
let trace = self.claim.verify_seeded_affine_with_trace(field, seed)?;
if trace.challenges != self.challenges
|| trace.round_sums != self.round_sums
|| trace.final_evaluation != self.final_evaluation
{
return None;
}
Some(trace)
}
pub fn verify_seeded_affine(&self, field: &Field, seed: &[u8]) -> bool {
self.verify_seeded_affine_with_trace(field, seed).is_some()
}
}
/// Builds a sum-check proof for the constant polynomial `value mod p` over
/// `num_vars` variables without enumerating the hypercube.
///
/// Every round polynomial is constant (`a = 0`). Its value is the constant times
/// `2^(variables still free after this round)`.
///
/// # Panics
/// Panics if `num_vars == 0`.
fn prove_constant_inner(num_vars: usize, field: &Field, value: u64) -> GeneralSumProof {
    assert!(num_vars >= 1, "num_vars must be at least 1");
    let num_vars_word = u64::try_from(num_vars).expect("num_vars must fit in transcript word");
    let p = field.modulus();
    let constant = value % p;
    let claimed_sum = field.mul(constant, field.pow(2, num_vars_word));

    let mut transcript = Transcript::new(GENERAL_SUMCHECK_DOMAIN);
    for word in [p, num_vars_word, claimed_sum] {
        transcript.append(word);
    }

    let mut rounds = Vec::with_capacity(num_vars);
    let mut challenges = Vec::with_capacity(num_vars);
    let mut round_sums = Vec::with_capacity(num_vars);
    let mut current = claimed_sum;
    for round in 0..num_vars {
        round_sums.push(current);
        let free_after = num_vars - round - 1;
        let half = field.mul(constant, field.pow(2, free_after as u64));
        rounds.push((0u64, half));
        transcript.append(0u64);
        transcript.append(half);
        challenges.push(transcript.challenge(field));
        // g is constant, so g(r) = half regardless of the challenge.
        current = half;
    }
    debug_assert_eq!(current, constant);

    GeneralSumProof {
        claim: GeneralSumClaim {
            p,
            num_vars,
            claimed_sum,
            rounds,
        },
        challenges,
        round_sums,
        final_evaluation: constant,
    }
}
/// Builds a sum-check proof for the affine polynomial `c + Σ_i a_i·x_i`, whose
/// constant `c` and coefficients `a_i` are derived from `seed`.
///
/// Round polynomials come from the closed form in `seeded_affine_round`, so the
/// hypercube is never enumerated.
///
/// # Panics
/// Panics if `num_vars == 0`.
fn prove_seeded_affine_inner(num_vars: usize, field: &Field, seed: &[u8]) -> GeneralSumProof {
    assert!(num_vars >= 1, "num_vars must be at least 1");
    let num_vars_word = u64::try_from(num_vars).expect("num_vars must fit in transcript word");
    let p = field.modulus();
    let parameters = derive_seeded_affine_parameters(num_vars, field, seed);
    // parameters = [constant, a_0, ..., a_{num_vars-1}]. The slice was mis-encoded
    // as `¶meters` (a mangled `&para` entity), which does not compile.
    let constant = parameters[0];
    let coefficients = &parameters[1..];
    let claimed_sum = seeded_affine_claimed_sum(num_vars, field, constant, coefficients);
    let mut transcript = Transcript::new(GENERAL_SUMCHECK_DOMAIN);
    transcript.append(p);
    transcript.append(num_vars_word);
    transcript.append(claimed_sum);
    let mut rounds = Vec::with_capacity(num_vars);
    let mut challenges = Vec::with_capacity(num_vars);
    let mut round_sums = Vec::with_capacity(num_vars);
    let mut running_sum = claimed_sum;
    // Value of the polynomial with the variables fixed so far substituted in.
    let mut prefix_evaluation = constant;
    // Sum of the coefficients of the variables not yet fixed.
    let mut suffix_sum = coefficients
        .iter()
        .fold(0u64, |acc, &coefficient| field.add(acc, coefficient));
    for (round_idx, &coefficient) in coefficients.iter().enumerate() {
        round_sums.push(running_sum);
        // Exclude the current variable; suffix_sum now covers only later variables.
        suffix_sum = field.sub(suffix_sum, coefficient);
        let remaining_after = num_vars - round_idx - 1;
        let (a, b) = seeded_affine_round(
            field,
            prefix_evaluation,
            coefficient,
            suffix_sum,
            remaining_after,
        );
        debug_assert_eq!(field.add(b, field.add(a, b)), running_sum);
        rounds.push((a, b));
        transcript.append(a);
        transcript.append(b);
        let challenge = transcript.challenge(field);
        challenges.push(challenge);
        prefix_evaluation = field.add(prefix_evaluation, field.mul(coefficient, challenge));
        running_sum = field.add(field.mul(a, challenge), b);
    }
    debug_assert_eq!(running_sum, prefix_evaluation);
    let claim = GeneralSumClaim {
        p,
        num_vars,
        claimed_sum,
        rounds,
    };
    GeneralSumProof {
        claim,
        challenges,
        round_sums,
        final_evaluation: prefix_evaluation,
    }
}
/// Encodes `seed` as transcript words.
///
/// The first word is the seed length in bytes. It is followed by the seed packed
/// into big-endian 8-byte words, with the final partial word zero-padded on the right.
fn seed_to_transcript_words(seed: &[u8]) -> Vec<u64> {
    let length_word =
        u64::try_from(seed.len()).expect("seed length must fit in transcript word");
    let packed = seed.chunks(8).map(|chunk| {
        let mut padded = [0u8; 8];
        padded[..chunk.len()].copy_from_slice(chunk);
        u64::from_be_bytes(padded)
    });
    std::iter::once(length_word).chain(packed).collect()
}
/// Derives `num_vars + 1` field elements from `seed`: the constant term followed by
/// one coefficient per variable.
///
/// The derivation input is the encoded seed with `num_vars` appended, hashed under
/// `SEEDED_AFFINE_DOMAIN`.
fn derive_seeded_affine_parameters(num_vars: usize, field: &Field, seed: &[u8]) -> Vec<u64> {
    let transcript_words: Vec<u64> = seed_to_transcript_words(seed)
        .into_iter()
        .chain(std::iter::once(
            u64::try_from(num_vars).expect("num_vars must fit in transcript word"),
        ))
        .collect();
    let parameter_count = num_vars + 1;
    derive_many_mod_p(
        field.modulus(),
        SEEDED_AFFINE_DOMAIN,
        &transcript_words,
        parameter_count,
    )
}
/// Closed-form hypercube sum of `constant + Σ_i coefficients[i]·x_i` over
/// `{0,1}^num_vars`.
///
/// The constant appears at all `2^n` points. Each `x_i` equals 1 at exactly
/// `2^(n-1)` points. For `n = 0` there are no linear terms and the sum is just the
/// constant; the previous `num_vars_word - 1` underflowed (panicking in debug builds).
fn seeded_affine_claimed_sum(
    num_vars: usize,
    field: &Field,
    constant: u64,
    coefficients: &[u64],
) -> u64 {
    debug_assert_eq!(coefficients.len(), num_vars);
    let num_vars_word = u64::try_from(num_vars).expect("num_vars must fit in transcript word");
    let coefficient_sum = coefficients
        .iter()
        .fold(0u64, |acc, &coefficient| field.add(acc, coefficient));
    let constant_term = field.mul(constant, field.pow(2, num_vars_word));
    let linear_term = match num_vars_word.checked_sub(1) {
        Some(half_exponent) => field.mul(coefficient_sum, field.pow(2, half_exponent)),
        None => 0,
    };
    field.add(constant_term, linear_term)
}
/// Round polynomial `(a, b)`, meaning `g(X) = a·X + b`, for one round of the
/// seeded-affine sum-check.
///
/// With `m = remaining_after` variables still free after this round:
/// `a = coefficient·2^m` and `b = prefix_evaluation·2^m + suffix_sum·2^(m-1)`.
/// The second term of `b` is dropped when `m = 0`.
fn seeded_affine_round(
    field: &Field,
    prefix_evaluation: u64,
    coefficient: u64,
    suffix_sum: u64,
    remaining_after: usize,
) -> (u64, u64) {
    let weight = field.pow(2, remaining_after as u64);
    let tail = match remaining_after {
        0 => 0,
        free => field.mul(suffix_sum, field.pow(2, (free - 1) as u64)),
    };
    let slope = field.mul(coefficient, weight);
    let intercept = field.add(field.mul(prefix_evaluation, weight), tail);
    (slope, intercept)
}
/// Streaming sum-check prover shared by the `prove_streaming*` entry points.
///
/// `evaluator(idx)` supplies the polynomial's value at hypercube index `idx` for
/// `idx` in `0..2^num_vars`; values are reduced mod `p` here. During round one the
/// evaluator is called twice per index: once to build the round polynomial and the
/// claimed sum, and once to fold at the first challenge. Later rounds work on the
/// stored folded layer of `2^(num_vars-1)` values.
///
/// On non-wasm32 targets, large inputs are processed with rayon when more than one
/// worker thread is available.
///
/// # Panics
/// Panics if `num_vars == 0` or if `2^num_vars` does not fit in `usize`.
fn prove_streaming_with_stats_inner(
    num_vars: usize,
    field: &Field,
    evaluator: Arc<dyn Fn(usize) -> u64 + Send + Sync>,
) -> (GeneralSumProof, ProofStats) {
    assert!(num_vars >= 1, "num_vars must be at least 1");
    // `1usize << num_vars` overflows for num_vars >= usize::BITS. That panics in
    // debug builds and silently wraps in release, so fail with a clear message instead.
    assert!(
        num_vars < usize::BITS as usize,
        "num_vars must be less than {} so that 2^num_vars fits in usize",
        usize::BITS
    );
    let p = field.modulus();
    let size = 1usize << num_vars;
    let field = *field;
    let use_parallel = {
        #[cfg(not(target_arch = "wasm32"))]
        {
            const PARALLEL_THRESHOLD: usize = 1 << 16;
            size >= PARALLEL_THRESHOLD && rayon::current_num_threads() > 1
        }
        #[cfg(target_arch = "wasm32")]
        {
            false
        }
    };
    let mut transcript = Transcript::new(GENERAL_SUMCHECK_DOMAIN);
    transcript.append(p);
    transcript.append(num_vars as u64);
    let mut round_sums = Vec::with_capacity(num_vars);
    let mut rounds = Vec::with_capacity(num_vars);
    let mut challenges = Vec::with_capacity(num_vars);
    let mut round_durations = Vec::with_capacity(num_vars);
    let total_start = Instant::now();
    // First pass: total sum, g(0) (even indices) and g(1) (odd indices).
    let (claimed_sum, g0_sum, g1_sum) = if use_parallel {
        #[cfg(not(target_arch = "wasm32"))]
        {
            // The tuple order must match the binding above:
            // (claimed_sum, g0_sum, g1_sum). It previously produced
            // (Σv0, Σv1, Σ(v0+v1)), so large proofs carried a wrong claimed sum
            // and first round.
            (0..size / 2)
                .into_par_iter()
                .map(|pair| {
                    let idx = pair * 2;
                    let v0 = evaluator(idx) % p;
                    let v1 = evaluator(idx + 1) % p;
                    (field.add(v0, v1), v0, v1)
                })
                .reduce(
                    || (0u64, 0u64, 0u64),
                    |lhs, rhs| {
                        (
                            field.add(lhs.0, rhs.0),
                            field.add(lhs.1, rhs.1),
                            field.add(lhs.2, rhs.2),
                        )
                    },
                )
        }
        #[cfg(target_arch = "wasm32")]
        {
            // Unreachable: `use_parallel` is always false on wasm32.
            (0u64, 0u64, 0u64)
        }
    } else {
        let mut claimed_sum = 0u64;
        let mut g0_sum = 0u64;
        let mut g1_sum = 0u64;
        for idx in (0..size).step_by(2) {
            let v0 = evaluator(idx) % p;
            let v1 = evaluator(idx + 1) % p;
            g0_sum = field.add(g0_sum, v0);
            g1_sum = field.add(g1_sum, v1);
            claimed_sum = field.add(claimed_sum, field.add(v0, v1));
        }
        (claimed_sum, g0_sum, g1_sum)
    };
    transcript.append(claimed_sum);
    round_sums.push(claimed_sum);
    let round_start = Instant::now();
    let first_a = field.sub(g1_sum, g0_sum);
    let first_b = g0_sum;
    rounds.push((first_a, first_b));
    transcript.append(first_a);
    transcript.append(first_b);
    let mut r = transcript.challenge(&field);
    challenges.push(r);
    // Second pass: fold the hypercube at the first challenge into a stored layer.
    let (mut layer, mut current_sum) = if use_parallel {
        #[cfg(not(target_arch = "wasm32"))]
        {
            let layer: Vec<u64> = (0..size / 2)
                .into_par_iter()
                .map(|pair| {
                    let idx = pair * 2;
                    let v0 = evaluator(idx) % p;
                    let v1 = evaluator(idx + 1) % p;
                    let diff = field.sub(v1, v0);
                    field.add(field.mul(diff, r), v0)
                })
                .collect();
            let current_sum = layer
                .par_iter()
                .cloned()
                .reduce(|| 0u64, |acc, v| field.add(acc, v));
            (layer, current_sum)
        }
        #[cfg(target_arch = "wasm32")]
        {
            (Vec::new(), 0u64)
        }
    } else {
        let mut layer = Vec::with_capacity(size / 2);
        let mut current_sum = 0u64;
        for idx in (0..size).step_by(2) {
            let v0 = evaluator(idx) % p;
            let v1 = evaluator(idx + 1) % p;
            let diff = field.sub(v1, v0);
            let val = field.add(field.mul(diff, r), v0);
            current_sum = field.add(current_sum, val);
            layer.push(val);
        }
        (layer, current_sum)
    };
    round_durations.push(round_start.elapsed());
    // Remaining rounds operate on the stored layer, halving it each time.
    for _round in 1..num_vars {
        round_sums.push(current_sum);
        let round_start = Instant::now();
        let use_parallel_layer = {
            #[cfg(not(target_arch = "wasm32"))]
            {
                const PARALLEL_LAYER_THRESHOLD: usize = 1 << 14;
                use_parallel && layer.len() >= PARALLEL_LAYER_THRESHOLD
            }
            #[cfg(target_arch = "wasm32")]
            {
                false
            }
        };
        let (g0_sum, g1_sum) = if use_parallel_layer {
            #[cfg(not(target_arch = "wasm32"))]
            {
                layer
                    .par_chunks(2)
                    .map(|chunk| (chunk[0], chunk[1]))
                    .reduce(
                        || (0u64, 0u64),
                        |acc, (v0, v1)| (field.add(acc.0, v0), field.add(acc.1, v1)),
                    )
            }
            #[cfg(target_arch = "wasm32")]
            {
                (0u64, 0u64)
            }
        } else {
            let mut g0_sum = 0u64;
            let mut g1_sum = 0u64;
            for chunk in layer.chunks(2) {
                g0_sum = field.add(g0_sum, chunk[0]);
                g1_sum = field.add(g1_sum, chunk[1]);
            }
            (g0_sum, g1_sum)
        };
        let a = field.sub(g1_sum, g0_sum);
        let b = g0_sum;
        rounds.push((a, b));
        transcript.append(a);
        transcript.append(b);
        r = transcript.challenge(&field);
        challenges.push(r);
        let (next_layer, next_sum) = if use_parallel_layer {
            #[cfg(not(target_arch = "wasm32"))]
            {
                let next_layer: Vec<u64> = layer
                    .par_chunks(2)
                    .map(|chunk| {
                        let v0 = chunk[0];
                        let v1 = chunk[1];
                        let diff = field.sub(v1, v0);
                        field.add(field.mul(diff, r), v0)
                    })
                    .collect();
                let next_sum = next_layer
                    .par_iter()
                    .cloned()
                    .reduce(|| 0u64, |acc, v| field.add(acc, v));
                (next_layer, next_sum)
            }
            #[cfg(target_arch = "wasm32")]
            {
                (Vec::new(), 0u64)
            }
        } else {
            let mut next_layer = Vec::with_capacity(layer.len() / 2);
            let mut next_sum = 0u64;
            for chunk in layer.chunks(2) {
                let v0 = chunk[0];
                let v1 = chunk[1];
                let diff = field.sub(v1, v0);
                let val = field.add(field.mul(diff, r), v0);
                next_sum = field.add(next_sum, val);
                next_layer.push(val);
            }
            (next_layer, next_sum)
        };
        layer = next_layer;
        current_sum = next_sum;
        round_durations.push(round_start.elapsed());
    }
    let final_evaluation = layer[0];
    let proof = GeneralSumProof {
        claim: GeneralSumClaim {
            p,
            num_vars,
            claimed_sum,
            rounds,
        },
        challenges,
        round_sums,
        final_evaluation,
    };
    let stats = ProofStats {
        total_duration: total_start.elapsed(),
        round_durations,
    };
    (proof, stats)
}
impl ChainedSumProof {
    /// Proves each polynomial in order and chains every proof to its predecessor.
    ///
    /// Delegates to [`ChainedSumProof::prove_with_stats`] and discards the timings.
    /// `GeneralSumProof::prove` is itself `prove_with_stats(..).0`, so the links are
    /// identical to the previously duplicated loop.
    ///
    /// # Panics
    /// Panics with `"chained proof mismatch: ..."` if a polynomial's hypercube sum
    /// differs from the previous proof's final evaluation.
    pub fn prove(polynomials: &[MultilinearPolynomial], field: &Field) -> Self {
        let (chain, _stats) = Self::prove_with_stats(polynomials, field);
        chain
    }

    /// Proves each polynomial in order, returning the chain and per-proof timings.
    ///
    /// # Panics
    /// Panics with `"chained proof mismatch: ..."` if a polynomial's hypercube sum
    /// differs from the previous proof's final evaluation.
    pub fn prove_with_stats(
        polynomials: &[MultilinearPolynomial],
        field: &Field,
    ) -> (Self, Vec<ProofStats>) {
        let mut stats = Vec::with_capacity(polynomials.len());
        let mut links = Vec::with_capacity(polynomials.len());
        let mut previous_final: Option<u64> = None;
        for poly in polynomials {
            let parent_for_this = previous_final;
            let (proof, proof_stats) = GeneralSumProof::prove_with_stats(poly, field);
            stats.push(proof_stats);
            if let Some(expected_sum) = parent_for_this {
                if field.sub(proof.claim.claimed_sum, expected_sum) != 0 {
                    panic!(
                        "chained proof mismatch: expected sum {} but found {}",
                        expected_sum, proof.claim.claimed_sum
                    );
                }
            }
            previous_final = Some(proof.final_evaluation);
            links.push(ChainLink {
                parent_final: parent_for_this,
                proof,
            });
        }
        (Self { links }, stats)
    }

    /// Links in proving order.
    pub fn links(&self) -> &[ChainLink] {
        &self.links
    }

    /// Mutable access to the links; any tampering is caught by verification.
    pub fn links_mut(&mut self) -> &mut [ChainLink] {
        &mut self.links
    }

    /// Number of links.
    pub fn len(&self) -> usize {
        self.links.len()
    }

    /// `true` if the chain has no links.
    pub fn is_empty(&self) -> bool {
        self.links.is_empty()
    }

    /// Verifies every link against the matching polynomial and the chaining rule.
    ///
    /// Returns `None` in any of these cases:
    /// - the link count differs from `polynomials.len()`;
    /// - a link's recorded `parent_final` is not the previous verified final evaluation;
    /// - a proof fails;
    /// - a claimed sum breaks the chain.
    pub fn verify_with_traces(
        &self,
        polynomials: &[MultilinearPolynomial],
        field: &Field,
    ) -> Option<Vec<GeneralSumTrace>> {
        if self.links.len() != polynomials.len() {
            return None;
        }
        let mut traces = Vec::with_capacity(self.links.len());
        let mut previous_final: Option<u64> = None;
        for (link, poly) in self.links.iter().zip(polynomials) {
            if link.parent_final != previous_final {
                return None;
            }
            let trace = link.proof.verify_with_trace(poly, field)?;
            if let Some(expected_sum) = previous_final {
                if field.sub(link.proof.claim.claimed_sum, expected_sum) != 0 {
                    return None;
                }
            }
            previous_final = Some(trace.final_evaluation);
            traces.push(trace);
        }
        Some(traces)
    }

    /// `true` iff [`ChainedSumProof::verify_with_traces`] succeeds.
    pub fn verify(&self, polynomials: &[MultilinearPolynomial], field: &Field) -> bool {
        self.verify_with_traces(polynomials, field).is_some()
    }
}
/// Verifies `claim` against the dense multilinear polynomial `poly`.
///
/// Because the verifier holds the whole polynomial, each claimed round polynomial
/// `g_i(X) = a_i·X + b_i` is compared (mod p) with the honest one computed from the
/// current evaluation layer: `b_i = Σ even entries`, `a_i = Σ odd - Σ even`. The
/// running claim carried into the next round is the prover's `g_i(r_i)`, as in the
/// standard sum-check verifier. The final claim is checked against an independent
/// evaluation of `poly` at the challenge point.
///
/// Previously the running claim was replaced by the verifier's own folded sum.
/// Nothing then linked round i to round i+1, so a prover could claim any sum by
/// forging only the first round so that `2·b_1 + a_1` equalled it.
///
/// Returns `None` if any of these checks fails:
/// - the modulus, variable count or round count is wrong;
/// - a round polynomial differs from the honest one;
/// - a running sum does not match;
/// - the final evaluation does not match.
fn verify_general_sum(
    claim: &GeneralSumClaim,
    poly: &MultilinearPolynomial,
    field: &Field,
) -> Option<GeneralSumTrace> {
    if claim.p != field.modulus() {
        return None;
    }
    if claim.num_vars != poly.num_vars() {
        return None;
    }
    if claim.rounds.len() != claim.num_vars {
        return None;
    }
    let p = claim.p;
    let mut transcript = Transcript::new(GENERAL_SUMCHECK_DOMAIN);
    transcript.append(claim.p);
    transcript.append(claim.num_vars as u64);
    transcript.append(claim.claimed_sum);
    let mut layer = poly.evaluations_mod_p(field);
    let mut running_claim = claim.claimed_sum;
    let mut challenges = Vec::with_capacity(claim.num_vars);
    let mut round_sums = Vec::with_capacity(claim.num_vars);
    for &(a, b) in &claim.rounds {
        round_sums.push(running_claim);
        // Honest round polynomial for the current layer.
        let (g0_sum, g1_sum) = layer.chunks(2).fold((0u64, 0u64), |(s0, s1), chunk| {
            (field.add(s0, chunk[0]), field.add(s1, chunk[1]))
        });
        if a % p != field.sub(g1_sum, g0_sum) || b % p != g0_sum {
            return None;
        }
        // g(0) + g(1) must equal the running claim; in round one this binds the
        // claimed sum to the polynomial's true sum.
        let sum_check = field.add(b, field.add(a, b));
        if sum_check != running_claim {
            return None;
        }
        transcript.append(a);
        transcript.append(b);
        let r = transcript.challenge(field);
        challenges.push(r);
        layer = layer
            .chunks(2)
            .map(|chunk| {
                let (v0, v1) = (chunk[0], chunk[1]);
                field.add(field.mul(field.sub(v1, v0), r), v0)
            })
            .collect();
        // Carry the prover's own g(r) forward.
        running_claim = field.add(field.mul(a, r), b);
    }
    if layer.len() != 1 {
        return None;
    }
    let final_evaluation = poly.evaluate(field, &challenges);
    if final_evaluation != running_claim {
        return None;
    }
    Some(GeneralSumTrace {
        challenges,
        round_sums,
        final_evaluation,
    })
}
/// Verifies `claim` against a `StreamingPolynomial` by recomputing every round.
///
/// The evaluator is streamed once to compute the true hypercube sum and the first
/// round's `g(0)` (even indices) and `g(1)` (odd indices). Each claimed `(a, b)`
/// must then equal `(g(1) - g(0), g(0))` for the current layer before the layer is
/// folded at the Fiat–Shamir challenge. The first fold re-reads values from the
/// evaluator; later folds use the stored, halving layer.
///
/// Returns `None` if the modulus or arity disagrees, or if any recomputed value differs.
fn verify_general_sum_streaming(
    claim: &GeneralSumClaim,
    poly: &StreamingPolynomial,
    field: &Field,
) -> Option<GeneralSumTrace> {
    if claim.p != field.modulus() || claim.p != poly.modulus() {
        return None;
    }
    if claim.num_vars != poly.num_vars() || claim.rounds.len() != claim.num_vars {
        return None;
    }
    let p = claim.p;
    let num_vars = claim.num_vars;
    let size = 1usize << num_vars;
    let eval = poly.evaluator();
    let mut transcript = Transcript::new(GENERAL_SUMCHECK_DOMAIN);
    transcript.append(p);
    transcript.append(num_vars as u64);
    transcript.append(claim.claimed_sum);
    let mut round_sums = Vec::with_capacity(num_vars);
    let mut challenges = Vec::with_capacity(num_vars);
    // First pass: true total plus round-one g(0)/g(1).
    let mut computed_sum = 0u64;
    let mut g0_sum = 0u64;
    let mut g1_sum = 0u64;
    for idx in (0..size).step_by(2) {
        let v0 = eval(idx) % p;
        let v1 = eval(idx + 1) % p;
        g0_sum = field.add(g0_sum, v0);
        g1_sum = field.add(g1_sum, v1);
        computed_sum = field.add(computed_sum, field.add(v0, v1));
    }
    if computed_sum != claim.claimed_sum {
        return None;
    }
    round_sums.push(computed_sum);
    let mut layer = Vec::with_capacity(size / 2);
    let mut running_sum = computed_sum;
    for (round_idx, &(a, b)) in claim.rounds.iter().enumerate() {
        // NOTE(review): `b` is compared after reduction mod p, while `a` must already
        // be canonical; the other verifiers reduce both.
        if b % p != g0_sum || field.sub(g1_sum, g0_sum) != a {
            return None;
        }
        transcript.append(a);
        transcript.append(b);
        let r = transcript.challenge(field);
        challenges.push(r);
        let mut next_layer = Vec::with_capacity(if round_idx == 0 {
            size / 2
        } else {
            layer.len() / 2
        });
        let mut next_sum = 0u64;
        if round_idx == 0 {
            // Round one folds straight from the evaluator (no stored layer yet).
            for idx in (0..size).step_by(2) {
                let v0 = eval(idx) % p;
                let v1 = eval(idx + 1) % p;
                let diff = field.sub(v1, v0);
                let val = field.add(field.mul(diff, r), v0);
                next_sum = field.add(next_sum, val);
                next_layer.push(val);
            }
        } else {
            for chunk in layer.chunks(2) {
                let v0 = chunk[0];
                let v1 = chunk[1];
                let diff = field.sub(v1, v0);
                let val = field.add(field.mul(diff, r), v0);
                next_sum = field.add(next_sum, val);
                next_layer.push(val);
            }
        }
        layer = next_layer;
        running_sum = next_sum;
        // Prepare next round's expected g(0)/g(1) from the freshly folded layer.
        if round_idx + 1 < num_vars {
            round_sums.push(running_sum);
            g0_sum = 0u64;
            g1_sum = 0u64;
            for chunk in layer.chunks(2) {
                g0_sum = field.add(g0_sum, chunk[0]);
                g1_sum = field.add(g1_sum, chunk[1]);
            }
        }
    }
    if layer.len() != 1 {
        return None;
    }
    // The folded single value is the polynomial evaluated at the challenges.
    let final_evaluation = layer[0];
    if final_evaluation != running_sum {
        return None;
    }
    Some(GeneralSumTrace {
        challenges,
        round_sums,
        final_evaluation,
    })
}
/// Verifies a claim about the constant polynomial `value mod p` over
/// `claim.num_vars` variables without enumerating the hypercube.
///
/// The claimed sum must equal `constant·2^n`. Each round must have `a ≡ 0` and
/// `b ≡ constant·2^(variables free after the round)`, and must satisfy
/// `g(0) + g(1) = running claim`.
///
/// Returns `None` on a field mismatch, zero variables, a wrong round count, or any
/// failed check.
fn verify_constant_sum(
    claim: &GeneralSumClaim,
    field: &Field,
    value: u64,
) -> Option<GeneralSumTrace> {
    let well_formed = claim.p == field.modulus()
        && claim.rounds.len() == claim.num_vars
        && claim.num_vars != 0;
    if !well_formed {
        return None;
    }
    let p = claim.p;
    let num_vars_word = u64::try_from(claim.num_vars).ok()?;
    let constant = value % p;
    let expected_total = field.mul(constant, field.pow(2, num_vars_word));
    if claim.claimed_sum != expected_total {
        return None;
    }

    let mut transcript = Transcript::new(GENERAL_SUMCHECK_DOMAIN);
    for word in [p, num_vars_word, claim.claimed_sum] {
        transcript.append(word);
    }

    let mut challenges = Vec::with_capacity(claim.num_vars);
    let mut round_sums = Vec::with_capacity(claim.num_vars);
    let mut current = expected_total;
    for (round_idx, &(a, b)) in claim.rounds.iter().enumerate() {
        round_sums.push(current);
        let free_after = claim.num_vars - round_idx - 1;
        let expected_b = field.mul(constant, field.pow(2, free_after as u64));
        let coefficients_ok = a % p == 0 && b % p == expected_b;
        if !coefficients_ok || field.add(b, field.add(a, b)) != current {
            return None;
        }
        transcript.append(a);
        transcript.append(b);
        challenges.push(transcript.challenge(field));
        // g is constant, so g(r) = b for every challenge.
        current = b % p;
    }

    (current == constant).then(|| GeneralSumTrace {
        challenges,
        round_sums,
        final_evaluation: constant,
    })
}
/// Verifies a claim about the seeded affine polynomial without enumerating the
/// hypercube.
///
/// The verifier proceeds in three steps:
/// 1. Re-derive the constant and coefficients from `seed`.
/// 2. Check the claimed sum against the closed form.
/// 3. Require every round's `(a, b)` to match (mod p) what `seeded_affine_round`
///    predicts for the prefix fixed by the earlier challenges.
///
/// Returns `None` on a field mismatch, zero variables, a wrong round count, or any
/// failed check.
fn verify_seeded_affine_sum(
    claim: &GeneralSumClaim,
    field: &Field,
    seed: &[u8],
) -> Option<GeneralSumTrace> {
    if claim.p != field.modulus() || claim.rounds.len() != claim.num_vars || claim.num_vars == 0 {
        return None;
    }
    let p = claim.p;
    let num_vars = claim.num_vars;
    let num_vars_word = u64::try_from(num_vars).ok()?;
    let parameters = derive_seeded_affine_parameters(num_vars, field, seed);
    // parameters = [constant, a_0, ..., a_{num_vars-1}]. The slice was mis-encoded
    // as `¶meters` (a mangled `&para` entity), which does not compile.
    let constant = parameters[0];
    let coefficients = &parameters[1..];
    let mut running_claim = seeded_affine_claimed_sum(num_vars, field, constant, coefficients);
    if claim.claimed_sum != running_claim {
        return None;
    }
    let mut transcript = Transcript::new(GENERAL_SUMCHECK_DOMAIN);
    transcript.append(p);
    transcript.append(num_vars_word);
    transcript.append(claim.claimed_sum);
    let mut challenges = Vec::with_capacity(num_vars);
    let mut round_sums = Vec::with_capacity(num_vars);
    // Value of the polynomial with the variables fixed so far substituted in.
    let mut prefix_evaluation = constant;
    // Sum of the coefficients of the variables not yet fixed.
    let mut suffix_sum = coefficients
        .iter()
        .fold(0u64, |acc, &coefficient| field.add(acc, coefficient));
    for (round_idx, (&(a, b), &coefficient)) in claim.rounds.iter().zip(coefficients).enumerate() {
        round_sums.push(running_claim);
        suffix_sum = field.sub(suffix_sum, coefficient);
        let remaining_after = num_vars - round_idx - 1;
        let (expected_a, expected_b) = seeded_affine_round(
            field,
            prefix_evaluation,
            coefficient,
            suffix_sum,
            remaining_after,
        );
        if a % p != expected_a || b % p != expected_b {
            return None;
        }
        if field.add(b, field.add(a, b)) != running_claim {
            return None;
        }
        transcript.append(a);
        transcript.append(b);
        let challenge = transcript.challenge(field);
        challenges.push(challenge);
        prefix_evaluation = field.add(prefix_evaluation, field.mul(coefficient, challenge));
        running_claim = field.add(field.mul(a, challenge), b);
    }
    if running_claim != prefix_evaluation {
        return None;
    }
    Some(GeneralSumTrace {
        challenges,
        round_sums,
        final_evaluation: prefix_evaluation,
    })
}
#[cfg(test)]
mod tests {
use super::*;
use crate::Field;
#[test]
fn test_demo_true_sum() {
let field = Field::new(101);
let sum = true_sum_demo(&field);
assert_eq!(sum, 6);
}
#[test]
fn test_prove_and_verify() {
let field = Field::new(101);
let claim = SumClaim::prove_demo(&field, 8);
assert!(claim.verify_demo());
}
#[test]
fn test_cheating_prover_fails() {
let field = Field::new(101);
let honest = SumClaim::prove_demo(&field, 4);
let mut forged = honest.clone();
forged.g2_a = field.add(forged.g2_a, 1);
let base_transcript = [
forged.p,
forged.claimed_sum,
forged.g1_a,
forged.g1_b,
0u64,
0u64,
forged.k as u64,
];
let r1 = derive_many_mod_p(forged.p, b"power_house:v1:sumcheck:r1", &base_transcript, 1)[0];
let s1 = field.add(field.mul(forged.g1_a, r1), forged.g1_b);
let inv2 = field.inv(2);
forged.g2_b = field.mul(field.sub(s1, forged.g2_a), inv2);
assert!(!forged.verify_demo());
}
fn sample_poly(field: &Field) -> MultilinearPolynomial {
let mut evals = Vec::with_capacity(8);
for x2 in 0..=1u64 {
for x1 in 0..=1u64 {
for x0 in 0..=1u64 {
let mut val = 0;
val = field.add(val, x0);
val = field.add(val, field.mul(2, x1));
val = field.add(val, field.mul(3, x2));
let triple = field.mul(x0, field.mul(x1, x2));
val = field.add(val, field.mul(5, triple));
evals.push(val);
}
}
}
MultilinearPolynomial::from_evaluations(3, evals)
}
#[test]
fn test_general_sumcheck_prove_verify() {
let field = Field::new(101);
let poly = sample_poly(&field);
let claim = GeneralSumClaim::prove(&poly, &field);
assert!(claim.verify(&poly, &field));
}
#[test]
fn test_general_sumcheck_rejects_tampering() {
let field = Field::new(101);
let poly = sample_poly(&field);
let mut claim = GeneralSumClaim::prove(&poly, &field);
assert!(claim.verify(&poly, &field));
if let Some((a, b)) = claim.rounds.get_mut(0) {
*a = field.add(*a, 1);
*b = field.add(*b, 1);
}
assert!(!claim.verify(&poly, &field));
}
#[test]
fn test_general_sumproof_trace_matches() {
let field = Field::new(101);
let poly = sample_poly(&field);
let proof = GeneralSumProof::prove(&poly, &field);
let trace = proof
.verify_with_trace(&poly, &field)
.expect("proof should verify");
assert_eq!(trace.challenges, proof.challenges);
assert_eq!(trace.round_sums, proof.round_sums);
assert_eq!(trace.final_evaluation, proof.final_evaluation);
}
/// `prove_with_stats` must record exactly one round duration per polynomial variable.
#[test]
fn test_general_sumproof_stats() {
    let gf = Field::new(101);
    let fixture = sample_poly(&gf);
    let stats = GeneralSumProof::prove_with_stats(&fixture, &gf).1;
    let expected_rounds = fixture.num_vars();
    assert_eq!(stats.round_durations.len(), expected_rounds);
}
/// A `StreamingPolynomial` backed by the dense evaluation table of `sample_poly` must yield
/// the same round polynomials and final evaluation as the dense prover, and the streaming
/// proof must verify against that same streaming source.
#[test]
fn test_streaming_matches_standard() {
    let gf = Field::new(101);
    let dense = sample_poly(&gf);
    let table = dense.evaluations().to_vec();
    let oracle = StreamingPolynomial::new(dense.num_vars(), gf.modulus(), move |idx| table[idx]);
    let (streamed, _) = GeneralSumProof::prove_streaming_with_stats_poly(&oracle, &gf);
    let reference = GeneralSumProof::prove(&dense, &gf);
    assert_eq!(streamed.claim.rounds, reference.claim.rounds);
    assert_eq!(streamed.final_evaluation, reference.final_evaluation);
    assert!(streamed.verify_streaming(&oracle, &gf));
}
/// A constant polynomial over 70 variables (2^70 points, more than 10^21) is proven and
/// verified by the constant-specialized prover, and its claimed sum equals
/// `constant * 2^num_vars` in the field.
#[test]
fn test_constant_sumcheck_verifies_sextillion_domain() {
    let gf = Field::new(1_000_000_007);
    let num_vars = 70;
    let constant = 173;
    let sextillion = 10u128.pow(21);
    let domain_size = 1u128 << num_vars;
    assert!(domain_size > sextillion);
    let expected_sum = gf.mul(constant, gf.pow(2, num_vars as u64));
    let proof = GeneralSumProof::prove_constant(num_vars, &gf, constant);
    assert_eq!(proof.claim.num_vars, num_vars);
    assert_eq!(proof.claim.rounds.len(), 70);
    assert!(proof.verify_constant(&gf, constant));
    assert_eq!(proof.claim.claimed_sum, expected_sum);
}
/// Incrementing the second component of round 12 of a valid constant-polynomial proof
/// must make `verify_constant` reject it.
#[test]
fn test_constant_sumcheck_rejects_tampering() {
    let gf = Field::new(1_000_000_007);
    let constant = 173;
    let mut proof = GeneralSumProof::prove_constant(70, &gf, constant);
    assert!(proof.verify_constant(&gf, constant));
    let round = &mut proof.claim.rounds[12];
    round.1 = gf.add(round.1, 1);
    assert!(!proof.verify_constant(&gf, constant));
}
/// Materializes the seeded affine polynomial as a dense evaluation table, so that the dense
/// prover can be cross-checked against `GeneralSumProof::prove_seeded_affine`.
///
/// `derive_seeded_affine_parameters` yields the constant term at index 0, followed by the
/// coefficients. The evaluation at hypercube index `idx` is the constant plus every
/// coefficient `i` whose bit `i` is set in `idx`, with all additions performed in `field`.
fn dense_seeded_affine(num_vars: usize, field: &Field, seed: &[u8]) -> MultilinearPolynomial {
    let parameters = derive_seeded_affine_parameters(num_vars, field, seed);
    let constant = parameters[0];
    let coefficients = &parameters[1..];
    let evaluations: Vec<u64> = (0..(1usize << num_vars))
        .map(|idx| {
            coefficients
                .iter()
                .enumerate()
                .filter(|&(bit, _)| (idx >> bit) & 1 == 1)
                .fold(constant, |acc, (_, &coefficient)| field.add(acc, coefficient))
        })
        .collect();
    MultilinearPolynomial::from_evaluations(num_vars, evaluations)
}
/// For 5 variables, the seeded affine prover must agree with the dense prover run on the
/// materialized polynomial: same claimed sum, rounds, challenges, and final evaluation.
/// The seeded proof must also verify.
#[test]
fn test_seeded_affine_sumcheck_matches_dense_prover() {
    let gf = Field::new(1_000_000_007);
    let seed = b"power-house seeded affine equivalence test";
    let materialized = dense_seeded_affine(5, &gf, seed);
    let reference = GeneralSumProof::prove(&materialized, &gf);
    let specialized = GeneralSumProof::prove_seeded_affine(5, &gf, seed);
    assert_eq!(specialized.claim.claimed_sum, reference.claim.claimed_sum);
    assert_eq!(specialized.claim.rounds, reference.claim.rounds);
    assert_eq!(specialized.challenges, reference.challenges);
    assert_eq!(specialized.final_evaluation, reference.final_evaluation);
    assert!(specialized.verify_seeded_affine(&gf, seed));
}
/// A 1024-variable seeded affine proof must pass the traced verifier. It must have 1024
/// rounds and 1024 challenges, and at least one round must have a non-zero first component.
#[test]
fn test_seeded_affine_sumcheck_verifies_astronomical_domain() {
    let gf = Field::new(1_000_000_007);
    let seed = b"power-house 2^1024 seeded affine certificate";
    let proof = GeneralSumProof::prove_seeded_affine(1024, &gf, seed);
    let trace = proof
        .verify_seeded_affine_with_trace(&gf, seed)
        .expect("seeded affine proof must verify");
    let rounds = &proof.claim.rounds;
    assert_eq!(proof.claim.num_vars, 1024);
    assert_eq!(rounds.len(), 1024);
    assert_eq!(trace.challenges.len(), 1024);
    let has_nonzero_leading = rounds.iter().any(|&(leading, _)| leading != 0);
    assert!(has_nonzero_leading);
}
/// A valid seeded affine proof must be rejected when it is verified against a different
/// seed, and also when the first component of round 37 is incremented.
#[test]
fn test_seeded_affine_sumcheck_rejects_wrong_seed_and_tampering() {
    let gf = Field::new(1_000_000_007);
    let seed = b"power-house seeded affine tamper test";
    let honest = GeneralSumProof::prove_seeded_affine(128, &gf, seed);
    assert!(honest.verify_seeded_affine(&gf, seed));
    assert!(!honest.verify_seeded_affine(&gf, b"wrong public seed"));
    let mut tampered = honest.clone();
    let round = &mut tampered.claim.rounds[37];
    round.0 = gf.add(round.0, 1);
    assert!(!tampered.verify_seeded_affine(&gf, seed));
}
/// Dense 5-variable fixture over `field`:
/// f = 1 + 3*x0 + 5*x1 + 7*x2 + 11*x3 + 13*x4 + 17*x0*x1 + 19*x1*x2 + 23*x3*x4 + 29*x0*x2*x4.
///
/// Evaluations are listed with `x0` varying fastest: index `i` uses x_k = bit k of `i`.
fn sample_poly_highdim(field: &Field) -> MultilinearPolynomial {
    // Coefficient of x_k, for k = 0..5.
    const LINEAR: [u64; 5] = [3, 5, 7, 11, 13];
    // (i, j, coef) contributes coef * x_i * x_j.
    const PAIRS: [(usize, usize, u64); 3] = [(0, 1, 17), (1, 2, 19), (3, 4, 23)];

    let evals: Vec<u64> = (0..32u64)
        .map(|point| {
            let bit = |k: u64| (point >> k) & 1;
            let vars = [bit(0), bit(1), bit(2), bit(3), bit(4)];
            let linear = LINEAR
                .iter()
                .zip(vars.iter())
                .fold(1u64, |acc, (&coef, &var)| field.add(acc, field.mul(coef, var)));
            let with_pairs = PAIRS.iter().fold(linear, |acc, &(i, j, coef)| {
                field.add(acc, field.mul(coef, field.mul(vars[i], vars[j])))
            });
            let triple = field.mul(vars[0], field.mul(vars[2], vars[4]));
            field.add(with_pairs, field.mul(29, triple))
        })
        .collect();
    MultilinearPolynomial::from_evaluations(5, evals)
}
/// The general sum-check claim must also verify for the 5-variable `sample_poly_highdim`
/// fixture over modulus 149.
#[test]
fn test_general_sumcheck_highdimensional() {
    let gf = Field::new(149);
    let fixture = sample_poly_highdim(&gf);
    let accepted = GeneralSumClaim::prove(&fixture, &gf).verify(&fixture, &gf);
    assert!(accepted);
}
/// Chains three sum-check instances. After the first, each polynomial is a constant
/// polynomial whose hypercube sum equals the previous proof's final evaluation. The chained
/// prover must report one stats entry per polynomial, and the chain must verify.
#[test]
fn test_chained_sum_proof_roundtrip() {
    let field = Field::new(197);
    let poly_a = sample_poly(&field);
    let first = GeneralSumProof::prove(&poly_a, &field);
    let poly_b = constant_polynomial(first.final_evaluation, 4, &field);
    let second = GeneralSumProof::prove(&poly_b, &field);
    let poly_c = constant_polynomial(second.final_evaluation, 3, &field);
    // The individual polynomials are not used again, so move them rather than cloning.
    let polynomials = vec![poly_a, poly_b, poly_c];
    let (chain, stats) = ChainedSumProof::prove_with_stats(&polynomials, &field);
    assert_eq!(stats.len(), polynomials.len());
    assert!(chain.verify(&polynomials, &field));
}
/// Builds a two-link chain and bumps the second link's recorded `parent_final` by one.
/// The chain must then fail verification.
#[test]
fn test_chained_sum_proof_detects_tampering() {
    let field = Field::new(211);
    let poly_a = sample_poly(&field);
    let first = GeneralSumProof::prove(&poly_a, &field);
    let poly_b = constant_polynomial(first.final_evaluation, 4, &field);
    // Not used individually afterwards, so move instead of cloning.
    let polynomials = vec![poly_a, poly_b];
    let (mut chain, _stats) = ChainedSumProof::prove_with_stats(&polynomials, &field);
    {
        // Require the link being tampered with to exist. Otherwise the final assertion
        // would run against an untouched chain and fail for the wrong reason.
        let links = chain.links_mut();
        let link = links
            .get_mut(1)
            .expect("a two-polynomial chain must have a second link");
        let parent = link
            .parent_final
            .expect("the second link must record its parent's final evaluation");
        link.parent_final = Some(field.add(parent, 1));
    }
    assert!(!chain.verify(&polynomials, &field));
}
/// Builds a `num_vars`-variable polynomial that takes the same value at every point of the
/// Boolean hypercube. The value is chosen so that the hypercube sum equals `target_sum`
/// modulo the field's modulus: `constant = target_sum * (2^num_vars)^{-1}`.
///
/// # Panics
/// - If `2^num_vars` does not fit in `usize`. The shift would otherwise overflow, panicking
///   in debug builds and producing a wrong size in release builds.
/// - If `2^num_vars` is congruent to 0 modulo the field's modulus. Zero has no
///   multiplicative inverse, so no such constant can be computed.
fn constant_polynomial(
    target_sum: u64,
    num_vars: usize,
    field: &Field,
) -> MultilinearPolynomial {
    assert!(
        num_vars < usize::BITS as usize,
        "constant_polynomial: 2^{} points do not fit in usize",
        num_vars
    );
    let points = 1usize << num_vars;
    let modulus = field.modulus();
    let points_mod_p = points as u64 % modulus;
    assert_ne!(
        points_mod_p, 0,
        "constant_polynomial: 2^{} is not invertible modulo {}",
        num_vars, modulus
    );
    let inv_points = field.inv(points_mod_p);
    let constant = field.mul(target_sum % modulus, inv_points);
    MultilinearPolynomial::from_evaluations(num_vars, vec![constant; points])
}
}