use std::{
f64::{self, consts::LOG2_10},
fmt,
ops::Neg,
};
use ark_ff::{AdditiveGroup, FftField, Field};
use ark_std::rand::{CryptoRng, RngCore};
use ordered_float::OrderedFloat;
use serde::{Deserialize, Serialize};
#[cfg(feature = "tracing")]
use tracing::instrument;
use crate::{
algebra::{
dot,
embedding::Embedding,
fields::FieldWithSize,
lift,
linear_form::UnivariateEvaluation,
mixed_univariate_evaluate,
ntt::{self, interleaved_rs_encode},
},
engines::EngineId,
hash::Hash,
protocols::{challenge_indices::challenge_indices, matrix_commit},
transcript::{
Codec, Decoding, DuplexSpongeInterface, ProverMessage, ProverState, VerificationResult,
VerifierMessage, VerifierState,
},
type_info::Typed,
utils::zip_strict,
verify,
};
#[derive(Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
#[serde(bound = "M: Embedding, M::Source: FftField")]
pub struct Config<M>
where
M: Embedding,
M::Source: FftField,
{
pub embedding: Typed<M>,
pub num_vectors: usize,
pub vector_size: usize,
pub codeword_length: usize,
pub interleaving_depth: usize,
pub matrix_commit: matrix_commit::Config<M::Source>,
pub johnson_slack: OrderedFloat<f64>,
pub in_domain_samples: usize,
pub out_domain_samples: usize,
pub deduplicate_in_domain: bool,
}
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Default, Serialize, Deserialize)]
#[must_use]
pub struct Witness<F: FftField, G: Field> {
pub matrix: Vec<F>,
pub matrix_witness: matrix_commit::Witness,
pub out_of_domain: Evaluations<G>,
}
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Default, Serialize, Deserialize)]
#[must_use]
pub struct Commitment<G: Field> {
matrix_commitment: matrix_commit::Commitment,
out_of_domain: Evaluations<G>,
}
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize, Deserialize, Default)]
pub struct Evaluations<F> {
pub points: Vec<F>,
pub matrix: Vec<F>,
}
impl<M: Embedding> Config<M>
where
M::Source: FftField,
{
pub fn new(
security_target: f64,
unique_decoding: bool,
hash_id: EngineId,
num_vectors: usize,
vector_size: usize,
interleaving_depth: usize,
rate: f64,
) -> Self
where
M: Default,
{
assert!(vector_size.is_multiple_of(interleaving_depth));
assert!(rate > 0. && rate <= 1.);
let message_length = vector_size / interleaving_depth;
#[allow(clippy::cast_sign_loss)]
let codeword_length = (message_length as f64 / rate).ceil() as usize;
let rate = message_length as f64 / codeword_length as f64;
let johnson_slack = johnson_slack(unique_decoding, rate);
#[allow(clippy::cast_sign_loss)]
let out_domain_samples = if unique_decoding {
0
} else {
let field_size_bits = M::Target::field_size_bits();
let list_size = 1. / (2. * johnson_slack * rate.sqrt());
let l_choose_2 = list_size * (list_size - 1.) / 2.;
let log_per_sample = field_size_bits - ((vector_size - 1) as f64).log2();
assert!(log_per_sample > 0.);
((security_target + l_choose_2.log2()) / log_per_sample)
.ceil()
.max(1.) as usize
};
#[allow(clippy::cast_sign_loss)]
let in_domain_samples = {
let per_sample = if unique_decoding {
f64::midpoint(1., rate)
} else {
rate.sqrt() + johnson_slack
};
(security_target / (-per_sample.log2())).ceil() as usize
};
debug_assert_eq!(
in_domain_samples,
num_in_domain_queries(unique_decoding, security_target, rate)
);
Self {
embedding: Typed::<M>::default(),
num_vectors,
vector_size,
codeword_length,
interleaving_depth,
matrix_commit: matrix_commit::Config::with_hash(
hash_id,
codeword_length,
interleaving_depth * num_vectors,
),
johnson_slack: OrderedFloat(johnson_slack),
in_domain_samples,
out_domain_samples,
deduplicate_in_domain: false,
}
}
pub const fn num_cols(&self) -> usize {
self.matrix_commit.num_cols
}
pub const fn size(&self) -> usize {
self.matrix_commit.size()
}
pub fn embedding(&self) -> &M {
&self.embedding
}
pub fn generator(&self) -> M::Source {
ntt::generator(self.codeword_length).expect("Subgroup of requested size not found")
}
pub fn message_length(&self) -> usize {
assert!(self.vector_size.is_multiple_of(self.interleaving_depth));
self.vector_size / self.interleaving_depth
}
pub fn rate(&self) -> f64 {
self.message_length() as f64 / self.codeword_length as f64
}
pub fn unique_decoding(&self) -> bool {
self.out_domain_samples == 0 && self.johnson_slack == 0.0
}
pub fn list_size(&self) -> f64 {
if self.unique_decoding() {
1.
} else {
1. / (2. * self.johnson_slack.into_inner() * self.rate().sqrt())
}
}
pub fn rbr_ood_sample(&self) -> f64 {
let list_size = self.list_size();
let log_field_size = M::Target::field_size_bits();
let l_choose_2 = list_size * (list_size - 1.) / 2.;
let log_per_sample = ((self.vector_size - 1) as f64).log2() - log_field_size;
-l_choose_2.log2() - self.out_domain_samples as f64 * log_per_sample
}
pub fn rbr_queries(&self) -> f64 {
let per_sample = if self.unique_decoding() {
f64::midpoint(1., self.rate())
} else {
self.rate().sqrt() + self.johnson_slack.into_inner()
};
self.in_domain_samples as f64 * per_sample.log2().neg()
}
pub fn rbr_soundness_fold_prox_gaps(&self) -> f64 {
let log_field_size = M::Target::field_size_bits();
let log_inv_rate = self.rate().log2().neg();
let log_k = (self.message_length() as f64).log2();
let error = if self.unique_decoding() {
log_k + log_inv_rate
} else {
let log_eta = self.johnson_slack.into_inner().log2();
assert!(log_eta >= -(0.5 * log_inv_rate + LOG2_10 + 1.0) - 1e-6);
7. * LOG2_10 + 3.5 * log_inv_rate + 2. * log_k
};
log_field_size - error
}
#[cfg_attr(feature = "tracing", instrument(skip_all, fields(self = %self)))]
pub fn commit<H, R>(
&self,
prover_state: &mut ProverState<H, R>,
vectors: &[&[M::Source]],
) -> Witness<M::Source, M::Target>
where
H: DuplexSpongeInterface,
R: RngCore + CryptoRng,
M::Target: Codec<[H::U]>,
Hash: ProverMessage<[H::U]>,
{
assert!((self.vector_size).is_multiple_of(self.interleaving_depth));
assert_eq!(self.matrix_commit.num_rows(), self.codeword_length);
assert_eq!(
self.matrix_commit.num_cols,
self.num_vectors * self.interleaving_depth
);
assert_eq!(vectors.len(), self.num_vectors);
assert!(vectors.iter().all(|p| p.len() == self.vector_size));
let matrix = interleaved_rs_encode(vectors, self.codeword_length, self.interleaving_depth);
let matrix_witness = self.matrix_commit.commit(prover_state, &matrix);
let oods_points: Vec<M::Target> =
prover_state.verifier_message_vec(self.out_domain_samples);
let mut oods_matrix = Vec::with_capacity(self.out_domain_samples * self.num_vectors);
for &point in &oods_points {
for &vector in vectors {
let value = mixed_univariate_evaluate(&*self.embedding, vector, point);
prover_state.prover_message(&value);
oods_matrix.push(value);
}
}
Witness {
matrix,
matrix_witness,
out_of_domain: Evaluations {
points: oods_points,
matrix: oods_matrix,
},
}
}
#[cfg_attr(feature = "tracing", instrument(skip_all, fields(self = %self)))]
pub fn receive_commitment<H>(
&self,
verifier_state: &mut VerifierState<H>,
) -> VerificationResult<Commitment<M::Target>>
where
H: DuplexSpongeInterface,
Hash: ProverMessage<[H::U]>,
M::Target: Codec<[H::U]>,
{
let matrix_commitment = self.matrix_commit.receive_commitment(verifier_state)?;
let oods_points: Vec<M::Target> =
verifier_state.verifier_message_vec(self.out_domain_samples);
let oods_matrix =
verifier_state.prover_messages_vec(self.out_domain_samples * self.num_vectors)?;
Ok(Commitment {
matrix_commitment,
out_of_domain: Evaluations {
points: oods_points,
matrix: oods_matrix,
},
})
}
#[cfg_attr(feature = "tracing", instrument(skip_all, fields(self = %self)))]
pub fn open<H, R>(
&self,
prover_state: &mut ProverState<H, R>,
witnesses: &[&Witness<M::Source, M::Target>],
) -> Evaluations<M::Source>
where
H: DuplexSpongeInterface,
R: RngCore + CryptoRng,
u8: Decoding<[H::U]>,
Hash: ProverMessage<[H::U]>,
{
for witness in witnesses {
assert_eq!(witness.matrix.len(), self.size());
assert_eq!(witness.out_of_domain.points.len(), self.out_domain_samples);
assert_eq!(
witness.out_of_domain.matrix.len(),
self.out_domain_samples * self.num_vectors
);
}
let (indices, points) = self.in_domain_challenges(prover_state);
self.open_inner(prover_state, witnesses, &indices, points)
}
#[cfg_attr(feature = "tracing", instrument(skip_all, fields(self = %self)))]
pub fn verify<H>(
&self,
verifier_state: &mut VerifierState<H>,
commitments: &[&Commitment<M::Target>],
) -> VerificationResult<Evaluations<M::Source>>
where
H: DuplexSpongeInterface,
u8: Decoding<[H::U]>,
Hash: ProverMessage<[H::U]>,
{
for commitment in commitments {
verify!(commitment.out_of_domain.points.len() == self.out_domain_samples);
verify!(
commitment.out_of_domain.matrix.len() == self.num_vectors * self.out_domain_samples
);
}
let (indices, points) = self.in_domain_challenges(verifier_state);
self.verify_inner(verifier_state, commitments, &indices, points)
}
pub fn open_at_indices<H, R>(
&self,
prover_state: &mut ProverState<H, R>,
witnesses: &[&Witness<M::Source, M::Target>],
indices: &[usize],
) -> Evaluations<M::Source>
where
H: DuplexSpongeInterface,
R: RngCore + CryptoRng,
Hash: ProverMessage<[H::U]>,
{
assert!(
indices.iter().all(|&i| i < self.codeword_length),
"index out of bounds: all indices must be < codeword_length ({})",
self.codeword_length
);
let generator = self.generator();
let points: Vec<M::Source> = indices.iter().map(|&i| generator.pow([i as u64])).collect();
self.open_inner(prover_state, witnesses, indices, points)
}
pub fn verify_at_indices<H>(
&self,
verifier_state: &mut VerifierState<H>,
commitments: &[&Commitment<M::Target>],
indices: &[usize],
) -> VerificationResult<Evaluations<M::Source>>
where
H: DuplexSpongeInterface,
u8: Decoding<[H::U]>,
Hash: ProverMessage<[H::U]>,
{
assert!(
indices.iter().all(|&i| i < self.codeword_length),
"index out of bounds: all indices must be < codeword_length ({})",
self.codeword_length
);
let generator = self.generator();
let points: Vec<M::Source> = indices.iter().map(|&i| generator.pow([i as u64])).collect();
self.verify_inner(verifier_state, commitments, indices, points)
}
fn open_inner<H, R>(
&self,
prover_state: &mut ProverState<H, R>,
witnesses: &[&Witness<M::Source, M::Target>],
indices: &[usize],
points: Vec<M::Source>,
) -> Evaluations<M::Source>
where
H: DuplexSpongeInterface,
R: RngCore + CryptoRng,
Hash: ProverMessage<[H::U]>,
{
let stride = witnesses.len() * self.num_cols();
let mut matrix = vec![M::Source::ZERO; indices.len() * stride];
let mut submatrix = Vec::with_capacity(indices.len() * self.num_cols());
let mut matrix_col_offset = 0;
for witness in witnesses {
submatrix.clear();
for (point_index, &code_index) in indices.iter().enumerate() {
let row = &witness.matrix
[code_index * self.num_cols()..(code_index + 1) * self.num_cols()];
submatrix.extend_from_slice(row);
let matrix_row = &mut matrix[point_index * stride..(point_index + 1) * stride];
matrix_row[matrix_col_offset..matrix_col_offset + self.num_cols()]
.copy_from_slice(row);
}
prover_state.prover_hint_ark(&submatrix);
self.matrix_commit
.open(prover_state, &witness.matrix_witness, indices);
matrix_col_offset += self.num_cols();
}
Evaluations { points, matrix }
}
fn verify_inner<H>(
&self,
verifier_state: &mut VerifierState<H>,
commitments: &[&Commitment<M::Target>],
indices: &[usize],
points: Vec<M::Source>,
) -> VerificationResult<Evaluations<M::Source>>
where
H: DuplexSpongeInterface,
u8: Decoding<[H::U]>,
Hash: ProverMessage<[H::U]>,
{
let stride = commitments.len() * self.num_cols();
let mut matrix = vec![M::Source::ZERO; indices.len() * stride];
let mut matrix_col_offset = 0;
for commitment in commitments {
let submatrix: Vec<M::Source> = verifier_state.prover_hint_ark()?;
self.matrix_commit.verify(
verifier_state,
&commitment.matrix_commitment,
indices,
&submatrix,
)?;
if stride != 0 && self.num_cols() != 0 {
for (dst, src) in zip_strict(
matrix.chunks_exact_mut(stride),
submatrix.chunks_exact(self.num_cols()),
) {
dst[matrix_col_offset..matrix_col_offset + self.num_cols()]
.copy_from_slice(src);
}
}
matrix_col_offset += self.num_cols();
}
Ok(Evaluations { points, matrix })
}
fn in_domain_challenges<T>(&self, transcript: &mut T) -> (Vec<usize>, Vec<M::Source>)
where
T: VerifierMessage,
u8: Decoding<[T::U]>,
{
let indices = challenge_indices(
transcript,
self.codeword_length,
self.in_domain_samples,
self.deduplicate_in_domain,
);
let generator = self.generator();
let points = indices
.iter()
.map(|index| generator.pow([*index as u64]))
.collect::<Vec<_>>();
(indices, points)
}
}
impl<G: Field> Commitment<G> {
pub const fn out_of_domain(&self) -> &Evaluations<G> {
&self.out_of_domain
}
pub fn num_vectors(&self) -> usize {
self.out_of_domain().num_columns()
}
}
impl<F: FftField, G: Field> Witness<F, G> {
pub const fn out_of_domain(&self) -> &Evaluations<G> {
&self.out_of_domain
}
pub fn num_vectors(&self) -> usize {
self.out_of_domain().num_columns()
}
}
impl<F: Field> Evaluations<F> {
pub const fn num_points(&self) -> usize {
self.points.len()
}
pub fn num_columns(&self) -> usize {
self.matrix
.len()
.checked_div(self.num_points())
.unwrap_or_default()
}
pub fn rows(&self) -> impl Iterator<Item = &[F]> {
let cols = self.num_columns();
(0..self.num_points()).map(move |i| &self.matrix[i * cols..(i + 1) * cols])
}
pub fn lift<M>(&self, embedding: &M) -> Evaluations<M::Target>
where
M: Embedding<Source = F>,
{
Evaluations {
points: lift(embedding, &self.points),
matrix: lift(embedding, &self.matrix),
}
}
pub fn evaluators(&self, size: usize) -> impl '_ + Iterator<Item = UnivariateEvaluation<F>> {
self.points
.iter()
.map(move |&point| UnivariateEvaluation::new(point, size))
}
pub fn values<'a>(&'a self, weights: &'a [F]) -> impl 'a + Iterator<Item = F> {
self.rows().map(|row| dot(weights, row))
}
}
impl<M: Embedding> fmt::Display for Config<M>
where
M::Source: FftField,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"size {}×{}/{}",
self.num_vectors, self.vector_size, self.interleaving_depth,
)?;
write!(f, " rate 2⁻{:.2}", -self.rate().log2())?;
write!(
f,
" samples {} in- {} out-domain",
self.in_domain_samples, self.out_domain_samples
)
}
}
pub(crate) fn johnson_slack(unique_decoding: bool, rate: f64) -> f64 {
if unique_decoding {
0.0
} else {
rate.sqrt() / 20.
}
}
#[allow(clippy::cast_sign_loss)]
pub(crate) fn num_in_domain_queries(
unique_decoding: bool,
security_target: f64,
rate: f64,
) -> usize {
let slack = johnson_slack(unique_decoding, rate);
let per_sample = if unique_decoding {
f64::midpoint(1., rate)
} else {
rate.sqrt() + slack
};
(security_target / (-per_sample.log2())).ceil() as usize
}
#[cfg(test)]
mod tests {
use ark_std::rand::{
distributions::Standard, prelude::Distribution, rngs::StdRng, Rng, SeedableRng,
};
use proptest::{bool, prelude::Strategy, proptest, sample::select, strategy::Just};
use super::*;
use crate::{
algebra::{
embedding::{Basefield, Compose, Frobenius, Identity},
fields, univariate_evaluate,
},
transcript::{codecs::U64, DomainSeparator},
};
pub fn config<M: Embedding + Clone>(
embedding: M,
num_vectors: usize,
vector_size: usize,
interleaving_depth: usize,
) -> impl Strategy<Value = Config<M>>
where
M::Source: FftField,
{
assert!(interleaving_depth != 0);
assert!(vector_size.is_multiple_of(interleaving_depth));
let message_length = vector_size / interleaving_depth;
let valid_codeword_lengths = (1..=30)
.map(|n| n * message_length)
.filter(|&n| ntt::generator::<M::Source>(n).is_some())
.collect::<Vec<_>>();
let codeword_length = select(valid_codeword_lengths);
let codeword_matrix = codeword_length.prop_flat_map(move |codeword_length| {
(
Just(codeword_length),
matrix_commit::tests::config::<M::Source>(
codeword_length,
interleaving_depth * num_vectors,
),
)
});
(codeword_matrix, 0_usize..=10, 0_usize..=10, bool::ANY).prop_map(
move |(
(codeword_length, matrix_commit),
in_domain_samples,
out_domain_samples,
deduplicate_in_domain,
)| Config {
embedding: Typed::new(embedding.clone()),
num_vectors,
vector_size,
codeword_length,
interleaving_depth,
matrix_commit,
johnson_slack: OrderedFloat::default(),
in_domain_samples,
out_domain_samples,
deduplicate_in_domain,
},
)
}
fn test<M: Embedding>(seed: u64, config: &Config<M>)
where
M::Source: FftField + ProverMessage,
M::Target: Codec,
Standard: Distribution<M::Source> + Distribution<M::Target>,
{
crate::tests::init();
let instance = U64(seed);
let ds = DomainSeparator::protocol(config)
.session(&format!("Test at {}:{}", file!(), line!()))
.instance(&instance);
let mut rng = StdRng::seed_from_u64(seed);
let vectors = (0..config.num_vectors)
.map(|_| {
(0..config.vector_size)
.map(|_| rng.gen::<M::Source>())
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
let mut prover_state = ProverState::new_std(&ds);
let witness = config.commit(
&mut prover_state,
&vectors.iter().map(|p| p.as_slice()).collect::<Vec<_>>(),
);
assert_eq!(
witness.out_of_domain().points.len(),
config.out_domain_samples
);
assert_eq!(
witness.out_of_domain().matrix.len(),
config.out_domain_samples * config.num_vectors
);
if config.num_vectors > 0 {
for (point, evals) in zip_strict(
witness.out_of_domain().points.iter(),
witness
.out_of_domain()
.matrix
.chunks_exact(config.num_vectors),
) {
for (vector, expected) in zip_strict(vectors.iter(), evals.iter()) {
assert_eq!(
mixed_univariate_evaluate(config.embedding(), vector, *point),
*expected
);
}
}
}
let in_domain_evals = config.open(&mut prover_state, &[&witness]);
if config.deduplicate_in_domain {
assert!(in_domain_evals.points.len() <= config.in_domain_samples);
assert!({
let mut unique = in_domain_evals.points.clone();
unique.sort_unstable();
unique.dedup();
unique.len() == in_domain_evals.points.len()
});
} else {
assert_eq!(in_domain_evals.points.len(), config.in_domain_samples);
}
assert_eq!(
in_domain_evals.matrix.len(),
in_domain_evals.points.len() * config.num_vectors * config.interleaving_depth
);
if config.num_vectors > 0 {
let base = config.vector_size / config.interleaving_depth;
for (point, evals) in zip_strict(
&in_domain_evals.points,
in_domain_evals
.matrix
.chunks_exact(config.num_vectors * config.interleaving_depth),
) {
let expected_iter = vectors.iter().flat_map(|poly| {
(0..config.interleaving_depth).map(|j| {
let start = j * base;
let coeffs: Vec<_> = poly.iter().copied().skip(start).take(base).collect();
univariate_evaluate(&coeffs, *point)
})
});
for (expected, got) in zip_strict(expected_iter, evals.iter()) {
assert_eq!(expected, *got);
}
}
}
let proof = prover_state.proof();
let mut verifier_state = VerifierState::new_std(&ds, &proof);
let commitment = config.receive_commitment(&mut verifier_state).unwrap();
assert_eq!(commitment.out_of_domain(), witness.out_of_domain());
let verifier_in_domain_evals = config.verify(&mut verifier_state, &[&commitment]).unwrap();
assert_eq!(&verifier_in_domain_evals, &in_domain_evals);
verifier_state.check_eof().unwrap();
}
fn proptest<M: Embedding>(embedding: &M)
where
M::Source: FftField + ProverMessage,
M::Target: FftField + Codec,
Standard: Distribution<M::Source> + Distribution<M::Target>,
{
let valid_sizes = (1..=1024)
.filter(|&n| ntt::generator::<M::Source>(n).is_some())
.collect::<Vec<_>>();
let size = select(valid_sizes);
let config = (0_usize..=3, size, 1_usize..=10).prop_flat_map(
|(num_vectors, size, interleaving_depth)| {
config(
embedding.clone(),
num_vectors,
size * interleaving_depth,
interleaving_depth,
)
},
);
proptest!(|(
seed: u64,
config in config,
)| {
test(seed, &config);
});
}
#[test]
fn test_field64_1() {
proptest(&Identity::<fields::Field64>::new());
}
#[test]
#[ignore = "Somewhat expensive and redundant"]
fn test_field64_2() {
proptest(&Identity::<fields::Field64_2>::new());
}
#[test]
#[ignore = "Somewhat expensive and redundant"]
fn test_field64_3() {
proptest(&Identity::<fields::Field64_3>::new());
}
#[test]
#[ignore = "Somewhat expensive and redundant"]
fn test_field128() {
proptest(&Identity::<fields::Field128>::new());
}
#[test]
#[ignore = "Somewhat expensive and redundant"]
fn test_field192() {
proptest(&Identity::<fields::Field192>::new());
}
#[test]
#[ignore = "Somewhat expensive and redundant"]
fn test_field256() {
proptest(&Identity::<fields::Field256>::new());
}
#[test]
fn test_basefield_field64_2() {
proptest(&Basefield::<fields::Field64_2>::new());
}
#[test]
#[ignore = "Somewhat expensive and redundant"]
fn test_basefield_field64_3() {
proptest(&Basefield::<fields::Field64_3>::new());
}
#[test]
#[ignore = "Somewhat expensive and redundant"]
fn test_base_frob_field64_3() {
let embedding = Compose::new(Basefield::<fields::Field64_3>::new(), Frobenius::new(2));
proptest(&embedding);
}
}