lcpc_2d/
lib.rs

1// Copyright 2021 Riad S. Wahby <rsw@cs.stanford.edu>
2//
3// This file is part of lcpc-2d, which is part of lcpc.
4//
5// Licensed under the Apache License, Version 2.0 (see
6// LICENSE or https://www.apache.org/licenses/LICENSE-2.0).
7// This file may not be copied, modified, or distributed
8// except according to those terms.
9#![deny(missing_docs)]
10
11/*!
12lcpc2d is a polynomial commitment scheme based on linear codes
13*/
14
15use digest::{Digest, Output};
16use err_derive::Error;
17use ff::{Field, PrimeField};
18use merlin::Transcript;
19use rand::{
20    distributions::{Distribution, Uniform},
21    SeedableRng,
22};
23use rand_chacha::ChaCha20Rng;
24use rayon::prelude::*;
25use serde::{Deserialize, Deserializer, Serialize, Serializer};
26use std::iter::repeat_with;
27
28mod macros;
29
30#[cfg(test)]
31mod tests;
32
33/// Trait for a field element that can be hashed via [digest::Digest]
34pub trait FieldHash {
35    /// A representation of `Self` that can be converted to a slice of `u8`.
36    type HashRepr: AsRef<[u8]>;
37
38    /// Convert `Self` into a `HashRepr` for hashing
39    fn to_hash_repr(&self) -> Self::HashRepr;
40
41    /// Update the digest `d` with the `HashRepr` of `Self`
42    fn digest_update<D: Digest>(&self, d: &mut D) {
43        d.update(self.to_hash_repr())
44    }
45
46    /// Update the [merlin::Transcript] `t` with the `HashRepr` of `Self` with label `l`
47    fn transcript_update(&self, t: &mut Transcript, l: &'static [u8]) {
48        t.append_message(l, self.to_hash_repr().as_ref())
49    }
50}
51
52impl<T: PrimeField> FieldHash for T {
53    type HashRepr = T::Repr;
54
55    fn to_hash_repr(&self) -> Self::HashRepr {
56        PrimeField::to_repr(self)
57    }
58}
59
60/// Trait representing bit size information for a field
61pub trait SizedField {
62    /// Ceil of log2(cardinality)
63    const CLOG2: u32;
64    /// Floor of log2(cardinality)
65    const FLOG2: u32;
66}
67
68impl<T: PrimeField> SizedField for T {
69    const CLOG2: u32 = <T as PrimeField>::NUM_BITS;
70    const FLOG2: u32 = <T as PrimeField>::NUM_BITS - 1;
71}
72
73/// Trait for a linear encoding used by the polycommit
74pub trait LcEncoding: Clone + std::fmt::Debug + Sync {
75    /// Field over which coefficients are defined
76    type F: Field + FieldHash + std::fmt::Debug + Clone;
77
78    /// Domain separation label - degree test (see def_labels!())
79    const LABEL_DT: &'static [u8];
80    /// Domain separation label - random lin combs (see def_labels!())
81    const LABEL_PR: &'static [u8];
82    /// Domain separation label - eval comb (see def_labels!())
83    const LABEL_PE: &'static [u8];
84    /// Domain separation label - column openings (see def_labels!())
85    const LABEL_CO: &'static [u8];
86
87    /// Error type for encoding
88    type Err: std::fmt::Debug + std::error::Error + Send;
89
90    /// Encoding function
91    fn encode<T: AsMut<[Self::F]>>(&self, inp: T) -> Result<(), Self::Err>;
92
93    /// Get dimensions for this encoding instance on an input vector of length `len`
94    fn get_dims(&self, len: usize) -> (usize, usize, usize);
95
96    /// Check that supplied dimensions are compatible with this encoding
97    fn dims_ok(&self, n_per_row: usize, n_cols: usize) -> bool;
98
99    /// Get the number of column openings required for this encoding
100    fn get_n_col_opens(&self) -> usize;
101
102    /// Get the number of degree tests required for this encoding
103    fn get_n_degree_tests(&self) -> usize;
104}
105
// local accessors for enclosed types:
// `FldT<E>` is the coefficient field of encoding `E`, `ErrT<E>` is its encoding error type
type FldT<E> = <E as LcEncoding>::F;
type ErrT<E> = <E as LcEncoding>::Err;
109
/// Err variant for prover operations
///
/// The type parameter `ErrT` is the encoding's error type; it is unrelated to
/// (and shadows) the crate-private `ErrT<E>` alias.
#[derive(Debug, Error)]
pub enum ProverError<ErrT>
where
    ErrT: std::fmt::Debug + std::error::Error + 'static,
{
    /// size too big: `n_cols.checked_next_power_of_two()` overflowed while
    /// sizing the Merkle tree
    #[error(display = "n_cols is too large for this encoding")]
    TooBig,
    /// error encoding a vector (wraps the encoding's own error)
    #[error(display = "encoding error: {:?}", _0)]
    Encode(#[source] ErrT),
    /// inconsistent LcCommit fields: buffer sizes disagree with the stored
    /// dimensions, or the encoding rejects those dimensions
    #[error(display = "inconsistent commitment fields")]
    Commit,
    /// bad column number: the requested column is not less than `n_cols`
    #[error(display = "bad column number")]
    ColumnNumber,
    /// bad outer tensor: its length differs from the commitment's `n_rows`
    #[error(display = "outer tensor: wrong size")]
    OuterTensor,
}

/// result of a prover operation
pub type ProverResult<T, ErrT> = Result<T, ProverError<ErrT>>;
135
/// Err variant for verifier operations
///
/// The type parameter `ErrT` is the encoding's error type; it is unrelated to
/// (and shadows) the crate-private `ErrT<E>` alias.
#[derive(Debug, Error)]
pub enum VerifierError<ErrT>
where
    ErrT: std::fmt::Debug + std::error::Error + 'static,
{
    /// wrong number of column openings in proof (or the encoding asks for none)
    #[error(display = "wrong number of column openings in proof")]
    NumColOpens,
    /// failed to verify column merkle path
    #[error(display = "column verification: merkle path failed")]
    ColumnPath,
    /// failed to verify column dot product for poly eval
    #[error(display = "column verification: eval dot product failed")]
    ColumnEval,
    /// failed to verify column dot product for degree test
    #[error(display = "column verification: degree test dot product failed")]
    ColumnDegree,
    /// bad outer tensor: its length differs from the length of the opened columns
    #[error(display = "outer tensor: wrong size")]
    OuterTensor,
    /// bad inner tensor: its length differs from the length of `p_eval`
    #[error(display = "inner tensor: wrong size")]
    InnerTensor,
    /// encoding dimensions do not match proof (rejected by `LcEncoding::dims_ok`)
    #[error(display = "encoding dimension mismatch")]
    EncodingDims,
    /// error encoding a vector (wraps the encoding's own error)
    #[error(display = "encoding error: {:?}", _0)]
    Encode(#[source] ErrT),
}

/// result of a verifier operation
pub type VerifierResult<T, ErrT> = Result<T, VerifierError<ErrT>>;
170
/// a commitment
///
/// Prover-side state: the encoded matrix, the zero-padded coefficient matrix,
/// and the complete Merkle tree over the encoded columns.
#[derive(Debug, Clone)]
pub struct LcCommit<D, E>
where
    D: Digest,
    E: LcEncoding,
{
    /// encoded matrix, row-major: `n_rows` rows of `n_cols` entries
    comm: Vec<FldT<E>>,
    /// coefficient matrix, row-major: `n_rows` rows of `n_per_row` entries (zero-padded)
    coeffs: Vec<FldT<E>>,
    /// number of matrix rows
    n_rows: usize,
    /// number of entries in each encoded row
    n_cols: usize,
    /// number of coefficients in each unencoded row
    n_per_row: usize,
    /// flat Merkle tree of `2 * n_cols.next_power_of_two() - 1` nodes: column
    /// (leaf) hashes first, then each successive layer, with the root last
    hashes: Vec<Output<D>>,
}
185
// serde wire form of `LcCommit`: the same fields, except that every Merkle node
// is stored as raw bytes (`WrappedOutput`) so the struct does not depend on the
// digest type `D`
#[derive(Debug, Serialize, Deserialize)]
struct WrappedLcCommit<F>
where
    F: Serialize,
{
    comm: Vec<F>,
    coeffs: Vec<F>,
    n_rows: usize,
    n_cols: usize,
    n_per_row: usize,
    hashes: Vec<WrappedOutput>,
}
198
199impl<F> WrappedLcCommit<F>
200where
201    F: Serialize,
202{
203    /// turn a WrappedLcCommit into an LcCommit
204    fn unwrap<D, E>(self) -> LcCommit<D, E>
205    where
206        D: Digest,
207        E: LcEncoding<F = F>,
208    {
209        let hashes = self.hashes.into_iter().map(|c| c.unwrap::<D, E>().root).collect();
210
211        LcCommit {
212            comm: self.comm,
213            coeffs: self.coeffs,
214            n_rows: self.n_rows,
215            n_cols: self.n_cols,
216            n_per_row: self.n_per_row,
217            hashes,
218        }
219    }
220}
221
222impl<D, E> LcCommit<D, E>
223where
224    D: Digest,
225    E: LcEncoding,
226    E::F: Serialize,
227{
228    fn wrapped(&self) -> WrappedLcCommit<FldT<E>> {
229        let hashes_wrapped = self.hashes.iter().map(|h| WrappedOutput { bytes: h.to_vec() }).collect();
230
231        WrappedLcCommit {
232            comm: self.comm.clone(),
233            coeffs: self.coeffs.clone(),
234            n_rows: self.n_rows,
235            n_cols: self.n_cols,
236            n_per_row: self.n_per_row,
237            hashes: hashes_wrapped,
238        }
239    }
240}
241
242impl<D, E> Serialize for LcCommit<D, E>
243where
244    D: Digest,
245    E: LcEncoding,
246    E::F: Serialize,
247{
248    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
249    where
250        S: Serializer,
251    {
252        self.wrapped().serialize(serializer)
253    }
254}
255
256impl<'de, D, E> Deserialize<'de> for LcCommit<D, E>
257where
258    D: Digest,
259    E: LcEncoding,
260    E::F: Serialize + Deserialize<'de>,
261{
262    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
263    where
264        De: Deserializer<'de>,
265    {
266        Ok(WrappedLcCommit::<FldT<E>>::deserialize(deserializer)?.unwrap())
267    }
268}
269
270impl<D, E> LcCommit<D, E>
271where
272    D: Digest,
273    E: LcEncoding,
274{
275    /// returns the Merkle root of this polynomial commitment (which is the commitment itself)
276    pub fn get_root(&self) -> LcRoot<D, E> {
277        LcRoot {
278            root: self.hashes.last().cloned().unwrap(),
279            _p: Default::default(),
280        }
281    }
282
283    /// return the number of coefficients encoded in each matrix row
284    pub fn get_n_per_row(&self) -> usize {
285        self.n_per_row
286    }
287
288    /// return the number of columns in the encoded matrix
289    pub fn get_n_cols(&self) -> usize {
290        self.n_cols
291    }
292
293    /// return the number of rows in the encoded matrix
294    pub fn get_n_rows(&self) -> usize {
295        self.n_rows
296    }
297
298    /// generate a commitment to a polynomial
299    pub fn commit(coeffs: &[FldT<E>], enc: &E) -> ProverResult<Self, ErrT<E>> {
300        commit(coeffs, enc)
301    }
302
303    /// Generate an evaluation of a committed polynomial
304    pub fn prove(
305        &self,
306        outer_tensor: &[FldT<E>],
307        enc: &E,
308        tr: &mut Transcript,
309    ) -> ProverResult<LcEvalProof<D, E>, ErrT<E>> {
310        prove(self, outer_tensor, enc, tr)
311    }
312}
313
/// A Merkle root corresponding to a committed polynomial
///
/// `E` appears only through `PhantomData`, tying the root to the encoding
/// that produced it.
#[derive(Debug, Clone)]
pub struct LcRoot<D, E>
where
    D: Digest,
    E: LcEncoding,
{
    /// the root hash itself
    root: Output<D>,
    /// type-level marker for the encoding
    _p: std::marker::PhantomData<E>,
}
324
325impl<D, E> LcRoot<D, E>
326where
327    D: Digest,
328    E: LcEncoding,
329{
330    fn wrapped(&self) -> WrappedOutput {
331        WrappedOutput {
332            bytes: self.root.to_vec(),
333        }
334    }
335
336    /// Convert this value into a raw Output<D>
337    pub fn into_raw(self) -> Output<D> {
338        self.root
339    }
340}
341
342impl<D, E> AsRef<Output<D>> for LcRoot<D, E>
343where
344    D: Digest,
345    E: LcEncoding,
346{
347    fn as_ref(&self) -> &Output<D> {
348        &self.root
349    }
350}
351
// support impl for serializing and deserializing proofs:
// a digest output (Merkle root or node) stored as a plain byte vector, so the
// wire types do not depend on the digest type `D`
#[derive(Debug, Clone, Deserialize, Serialize)]
struct WrappedOutput {
    /// wrapped output (raw digest bytes; encoded via `serde_bytes`)
    #[serde(with = "serde_bytes")]
    pub bytes: Vec<u8>,
}
359
360impl WrappedOutput {
361    fn unwrap<D, E>(self) -> LcRoot<D, E>
362    where
363        D: Digest,
364        E: LcEncoding,
365    {
366        LcRoot {
367            root: self.bytes.into_iter().collect::<Output<D>>(),
368            _p: Default::default(),
369        }
370    }
371}
372
373impl<D, E> Serialize for LcRoot<D, E>
374where
375    D: Digest,
376    E: LcEncoding,
377{
378    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
379    where
380        S: Serializer,
381    {
382        self.wrapped().serialize(serializer)
383    }
384}
385
386impl<'de, D, E> Deserialize<'de> for LcRoot<D, E>
387where
388    D: Digest,
389    E: LcEncoding,
390{
391    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
392    where
393        De: Deserializer<'de>,
394    {
395        Ok(WrappedOutput::deserialize(deserializer)?.unwrap())
396    }
397}
398
/// A column opening and the corresponding Merkle path.
#[derive(Debug, Clone)]
pub struct LcColumn<D, E>
where
    D: Digest,
    E: LcEncoding,
{
    /// the column's entries, one per matrix row, from row 0 downward
    col: Vec<FldT<E>>,
    /// sibling hashes from the column's leaf up to (but not including) the root
    path: Vec<Output<D>>,
}
409
410impl<D, E> LcColumn<D, E>
411where
412    D: Digest,
413    E: LcEncoding,
414    E::F: Serialize,
415{
416    fn wrapped(&self) -> WrappedLcColumn<FldT<E>> {
417        let path_wrapped = (0..self.path.len())
418            .map(|i| WrappedOutput {
419                bytes: self.path[i].to_vec(),
420            })
421            .collect();
422
423        WrappedLcColumn {
424            col: self.col.clone(),
425            path: path_wrapped,
426        }
427    }
428}
429
// serde wire form of `LcColumn`: a column opening and the corresponding Merkle
// path, with each path node stored as raw bytes (`WrappedOutput`)
#[derive(Debug, Clone, Deserialize, Serialize)]
struct WrappedLcColumn<F>
where
    F: Serialize,
{
    col: Vec<F>,
    path: Vec<WrappedOutput>,
}
439
440impl<F> WrappedLcColumn<F>
441where
442    F: Serialize,
443{
444    /// turn WrappedLcColumn into LcColumn
445    fn unwrap<D, E>(self) -> LcColumn<D, E>
446    where
447        D: Digest,
448        E: LcEncoding<F = F>,
449    {
450        let col = self.col;
451        let path = self
452            .path
453            .into_iter()
454            .map(|v| v.bytes.into_iter().collect::<Output<D>>())
455            .collect();
456
457        LcColumn { col, path }
458    }
459}
460
461impl<D, E> Serialize for LcColumn<D, E>
462where
463    D: Digest,
464    E: LcEncoding,
465    E::F: Serialize,
466{
467    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
468    where
469        S: Serializer,
470    {
471        self.wrapped().serialize(serializer)
472    }
473}
474
475impl<'de, D, E> Deserialize<'de> for LcColumn<D, E>
476where
477    D: Digest,
478    E: LcEncoding,
479    E::F: Serialize + Deserialize<'de>,
480{
481    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
482    where
483        De: Deserializer<'de>,
484    {
485        Ok(WrappedLcColumn::<FldT<E>>::deserialize(deserializer)?.unwrap())
486    }
487}
488
/// An evaluation and proof of its correctness and of the low-degreeness of the commitment.
#[derive(Debug, Clone)]
pub struct LcEvalProof<D, E>
where
    D: Digest,
    E: LcEncoding,
{
    /// number of entries in an encoded row
    n_cols: usize,
    /// rows of the coefficient matrix combined with the outer tensor (`n_per_row` entries)
    p_eval: Vec<FldT<E>>,
    /// rows combined with a transcript-derived random tensor, one vector per degree test
    p_random_vec: Vec<Vec<FldT<E>>>,
    /// opened columns with their Merkle paths, in challenge order
    columns: Vec<LcColumn<D, E>>,
}
501
502impl<D, E> LcEvalProof<D, E>
503where
504    D: Digest,
505    E: LcEncoding,
506{
507    /// Get the number of elements in an encoded vector
508    pub fn get_n_cols(&self) -> usize {
509        self.n_cols
510    }
511
512    /// Get the number of elements in an unencoded vector
513    pub fn get_n_per_row(&self) -> usize {
514        self.p_eval.len()
515    }
516
517    /// Verify an evaluation proof and return the resulting evaluation
518    pub fn verify(
519        &self,
520        root: &Output<D>,
521        outer_tensor: &[FldT<E>],
522        inner_tensor: &[FldT<E>],
523        enc: &E,
524        tr: &mut Transcript,
525    ) -> VerifierResult<FldT<E>, ErrT<E>> {
526        verify(root, outer_tensor, inner_tensor, self, enc, tr)
527    }
528}
529
530impl<D, E> LcEvalProof<D, E>
531where
532    D: Digest,
533    E: LcEncoding,
534    E::F: Serialize,
535{
536    fn wrapped(&self) -> WrappedLcEvalProof<FldT<E>> {
537        let columns_wrapped = (0..self.columns.len())
538            .map(|i| self.columns[i].wrapped())
539            .collect();
540
541        WrappedLcEvalProof {
542            n_cols: self.n_cols,
543            p_eval: self.p_eval.clone(),
544            p_random_vec: self.p_random_vec.clone(),
545            columns: columns_wrapped,
546        }
547    }
548}
549
/// An evaluation and proof of its correctness and of the low-degreeness of the commitment.
///
/// This is the digest-independent serde wire form of `LcEvalProof`: the same
/// fields, with each opened column stored as a `WrappedLcColumn`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WrappedLcEvalProof<F>
where
    F: Serialize,
{
    n_cols: usize,
    p_eval: Vec<F>,
    p_random_vec: Vec<Vec<F>>,
    columns: Vec<WrappedLcColumn<F>>,
}
561
562impl<F> WrappedLcEvalProof<F>
563where
564    F: Serialize,
565{
566    /// turn a WrappedLcEvalProof into an LcEvalProof
567    fn unwrap<D, E>(self) -> LcEvalProof<D, E>
568    where
569        D: Digest,
570        E: LcEncoding<F = F>,
571    {
572        let columns = self.columns.into_iter().map(|c| c.unwrap()).collect();
573
574        LcEvalProof {
575            n_cols: self.n_cols,
576            p_eval: self.p_eval,
577            p_random_vec: self.p_random_vec,
578            columns,
579        }
580    }
581}
582
583impl<D, E> Serialize for LcEvalProof<D, E>
584where
585    D: Digest,
586    E: LcEncoding,
587    E::F: Serialize,
588{
589    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
590    where
591        S: Serializer,
592    {
593        self.wrapped().serialize(serializer)
594    }
595}
596
597impl<'de, D, E> Deserialize<'de> for LcEvalProof<D, E>
598where
599    D: Digest,
600    E: LcEncoding,
601    E::F: Serialize + Deserialize<'de>,
602{
603    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
604    where
605        De: Deserializer<'de>,
606    {
607        Ok(WrappedLcEvalProof::<FldT<E>>::deserialize(deserializer)?.unwrap())
608    }
609}
610
611/// Compute number of degree tests required for `lambda`-bit security
612/// for a code with `len`-length codewords over `flog2`-bit field
613pub fn n_degree_tests(lambda: usize, len: usize, flog2: usize) -> usize {
614    let den = flog2 - log2(len);
615    (lambda + den - 1) / den
616}
617
// parallelization limit when working on columns: the recursive helpers
// (hash_columns, merkle_layer, collapse_columns) stop splitting work across
// rayon tasks once a piece is at most 2^LOG_MIN_NCOLS = 32 columns/nodes wide
const LOG_MIN_NCOLS: usize = 5;
620
621/// Commit to a univariate polynomial whose coefficients are `coeffs` using encoding `enc`
622fn commit<D, E>(coeffs_in: &[FldT<E>], enc: &E) -> ProverResult<LcCommit<D, E>, ErrT<E>>
623where
624    D: Digest,
625    E: LcEncoding,
626{
627    let (n_rows, n_per_row, n_cols) = enc.get_dims(coeffs_in.len());
628
629    // check that parameters are ok
630    assert!(n_rows * n_per_row >= coeffs_in.len());
631    assert!((n_rows - 1) * n_per_row < coeffs_in.len());
632    assert!(enc.dims_ok(n_per_row, n_cols));
633
634    // matrix (encoded as a vector)
635    // XXX(zk) pad coeffs
636    let mut coeffs = vec![FldT::<E>::zero(); n_rows * n_per_row];
637    let mut comm = vec![FldT::<E>::zero(); n_rows * n_cols];
638
639    // local copy of coeffs with padding
640    coeffs
641        .par_chunks_mut(n_per_row)
642        .zip(coeffs_in.par_chunks(n_per_row))
643        .for_each(|(c, c_in)| {
644            c[..c_in.len()].copy_from_slice(c_in);
645        });
646
647    // now compute FFTs
648    comm.par_chunks_mut(n_cols)
649        .zip(coeffs.par_chunks(n_per_row))
650        .try_for_each(|(r, c)| {
651            r[..c.len()].copy_from_slice(c);
652            enc.encode(r)
653        })?;
654
655    // compute Merkle tree
656    let n_cols_np2 = n_cols
657        .checked_next_power_of_two()
658        .ok_or(ProverError::TooBig)?;
659    let mut ret = LcCommit {
660        comm,
661        coeffs,
662        n_rows,
663        n_cols,
664        n_per_row,
665        hashes: vec![<Output<D> as Default>::default(); 2 * n_cols_np2 - 1],
666    };
667    check_comm(&ret, enc)?;
668    merkleize(&mut ret);
669
670    Ok(ret)
671}
672
673fn check_comm<D, E>(comm: &LcCommit<D, E>, enc: &E) -> ProverResult<(), ErrT<E>>
674where
675    D: Digest,
676    E: LcEncoding,
677{
678    let comm_sz = comm.comm.len() != comm.n_rows * comm.n_cols;
679    let coeff_sz = comm.coeffs.len() != comm.n_rows * comm.n_per_row;
680    let hashlen = comm.hashes.len() != 2 * comm.n_cols.next_power_of_two() - 1;
681    let dims = !enc.dims_ok(comm.n_per_row, comm.n_cols);
682
683    if comm_sz || coeff_sz || hashlen || dims {
684        Err(ProverError::Commit)
685    } else {
686        Ok(())
687    }
688}
689
690fn merkleize<D, E>(comm: &mut LcCommit<D, E>)
691where
692    D: Digest,
693    E: LcEncoding,
694{
695    // step 1: hash each column of the commitment (we always reveal a full column)
696    let hashes = &mut comm.hashes[..comm.n_cols];
697    hash_columns::<D, E>(&comm.comm, hashes, comm.n_rows, comm.n_cols, 0);
698
699    // step 2: compute rest of Merkle tree
700    let len_plus_one = comm.hashes.len() + 1;
701    assert!(len_plus_one.is_power_of_two());
702    let (hin, hout) = comm.hashes.split_at_mut(len_plus_one / 2);
703    merkle_tree::<D>(hin, hout);
704}
705
/// Hash `hashes.len()` consecutive columns of the row-major matrix `comm`
/// (`n_rows` rows of `n_cols` entries), starting at column `offset`; the
/// digest of column `offset + i` is written to `hashes[i]`.
///
/// Each column digest first absorbs one all-zero block of digest-output size,
/// then the column's entries from row 0 to row `n_rows - 1`. The work is
/// halved and run with `rayon::join` until a piece covers at most
/// `2^LOG_MIN_NCOLS` columns.
fn hash_columns<D, E>(
    comm: &[FldT<E>],
    hashes: &mut [Output<D>],
    n_rows: usize,
    n_cols: usize,
    offset: usize,
) where
    D: Digest,
    E: LcEncoding,
{
    if hashes.len() <= (1 << LOG_MIN_NCOLS) {
        // base case: run the computation
        // 1. prepare the digests for each column
        let mut digests = Vec::with_capacity(hashes.len());
        for _ in 0..hashes.len() {
            // column hashes start with a block of 0's
            let mut dig = D::new();
            dig.update(<Output<D> as Default>::default());
            digests.push(dig);
        }
        // 2. for each row, update the digests for each column
        //    (row-outer / column-inner order reads each row's slice of `comm`
        //    contiguously while keeping every column's entries in row order)
        for row in 0..n_rows {
            for (col, digest) in digests.iter_mut().enumerate() {
                comm[row * n_cols + offset + col].digest_update(digest);
            }
        }
        // 3. finalize each digest and write the results back
        for (col, digest) in digests.into_iter().enumerate() {
            hashes[col] = digest.finalize();
        }
    } else {
        // recursive case: split and execute in parallel; the upper half starts
        // `half_cols` columns further right in the matrix
        let half_cols = hashes.len() / 2;
        let (lo, hi) = hashes.split_at_mut(half_cols);
        rayon::join(
            || hash_columns::<D, E>(comm, lo, n_rows, n_cols, offset),
            || hash_columns::<D, E>(comm, hi, n_rows, n_cols, offset + half_cols),
        );
    }
}
746
747fn merkle_tree<D>(ins: &[Output<D>], outs: &mut [Output<D>])
748where
749    D: Digest,
750{
751    // array should always be of length 2^k - 1
752    assert_eq!(ins.len(), outs.len() + 1);
753
754    let (outs, rems) = outs.split_at_mut((outs.len() + 1) / 2);
755    merkle_layer::<D>(ins, outs);
756
757    if !rems.is_empty() {
758        merkle_tree::<D>(outs, rems)
759    }
760}
761
762fn merkle_layer<D>(ins: &[Output<D>], outs: &mut [Output<D>])
763where
764    D: Digest,
765{
766    assert_eq!(ins.len(), 2 * outs.len());
767
768    if ins.len() <= (1 << LOG_MIN_NCOLS) {
769        // base case: just compute all of the hashes
770        let mut digest = D::new();
771        for idx in 0..outs.len() {
772            digest.update(ins[2 * idx].as_ref());
773            digest.update(ins[2 * idx + 1].as_ref());
774            outs[idx] = digest.finalize_reset();
775        }
776    } else {
777        // recursive case: split and compute
778        let (inl, inr) = ins.split_at(ins.len() / 2);
779        let (outl, outr) = outs.split_at_mut(outs.len() / 2);
780        rayon::join(
781            || merkle_layer::<D>(inl, outl),
782            || merkle_layer::<D>(inr, outr),
783        );
784    }
785}
786
787// Open the commitment to one column
788fn open_column<D, E>(
789    comm: &LcCommit<D, E>,
790    mut column: usize,
791) -> ProverResult<LcColumn<D, E>, ErrT<E>>
792where
793    D: Digest,
794    E: LcEncoding,
795{
796    // make sure arguments are well formed
797    if column >= comm.n_cols {
798        return Err(ProverError::ColumnNumber);
799    }
800
801    // column of values
802    let col = comm
803        .comm
804        .iter()
805        .skip(column)
806        .step_by(comm.n_cols)
807        .cloned()
808        .collect();
809
810    // Merkle path
811    let mut hashes = &comm.hashes[..];
812    let path_len = log2(comm.n_cols);
813    let mut path = Vec::with_capacity(path_len);
814    for _ in 0..path_len {
815        let other = (column & !1) | (!column & 1);
816        assert_eq!(other ^ column, 1);
817        path.push(hashes[other].clone());
818        let (_, hashes_new) = hashes.split_at((hashes.len() + 1) / 2);
819        hashes = hashes_new;
820        column >>= 1;
821    }
822    assert_eq!(column, 0);
823
824    Ok(LcColumn { col, path })
825}
826
/// `ceil(log2(v))`, computed as the exponent of `v.next_power_of_two()`
/// (so both 0 and 1 map to 0).
const fn log2(v: usize) -> usize {
    v.next_power_of_two().trailing_zeros() as usize
}
830
831/// Verify the evaluation of a committed polynomial and return the result
832fn verify<D, E>(
833    root: &Output<D>,
834    outer_tensor: &[FldT<E>],
835    inner_tensor: &[FldT<E>],
836    proof: &LcEvalProof<D, E>,
837    enc: &E,
838    tr: &mut Transcript,
839) -> VerifierResult<FldT<E>, ErrT<E>>
840where
841    D: Digest,
842    E: LcEncoding,
843{
844    // make sure arguments are well formed
845    let n_col_opens = enc.get_n_col_opens();
846    if n_col_opens != proof.columns.len() || n_col_opens == 0 {
847        return Err(VerifierError::NumColOpens);
848    }
849    let n_rows = proof.columns[0].col.len();
850    let n_cols = proof.get_n_cols();
851    let n_per_row = proof.get_n_per_row();
852    if inner_tensor.len() != n_per_row {
853        return Err(VerifierError::InnerTensor);
854    }
855    if outer_tensor.len() != n_rows {
856        return Err(VerifierError::OuterTensor);
857    }
858    if !enc.dims_ok(n_per_row, n_cols) {
859        return Err(VerifierError::EncodingDims);
860    }
861
862    // step 1: random tensor for degree test and random columns to test
863    // step 1a: extract random tensor from transcript
864    // we run multiple instances of this to boost soundness
865    let mut rand_tensor_vec: Vec<Vec<FldT<E>>> = Vec::new();
866    let mut p_random_fft: Vec<Vec<FldT<E>>> = Vec::new();
867    let n_degree_tests = enc.get_n_degree_tests();
868    for i in 0..n_degree_tests {
869        let rand_tensor: Vec<FldT<E>> = {
870            let mut key: <ChaCha20Rng as SeedableRng>::Seed = Default::default();
871            tr.challenge_bytes(E::LABEL_DT, &mut key);
872            let mut deg_test_rng = ChaCha20Rng::from_seed(key);
873            // XXX(optimization) could expand seed in parallel instead of in series
874            repeat_with(|| FldT::<E>::random(&mut deg_test_rng))
875                .take(n_rows)
876                .collect()
877        };
878
879        rand_tensor_vec.push(rand_tensor);
880
881        // step 1b: eval encoding of p_random
882        {
883            let mut tmp = Vec::with_capacity(n_cols);
884            tmp.extend_from_slice(&proof.p_random_vec[i][..]);
885            tmp.resize(n_cols, FldT::<E>::zero());
886            enc.encode(&mut tmp)?;
887            p_random_fft.push(tmp);
888        };
889
890        // step 1c: push p_random and p_eval into transcript
891        proof.p_random_vec[i]
892            .iter()
893            .for_each(|coeff| coeff.transcript_update(tr, E::LABEL_PR));
894    }
895
896    proof
897        .p_eval
898        .iter()
899        .for_each(|coeff| coeff.transcript_update(tr, E::LABEL_PE));
900
901    // step 1d: extract columns to open
902    let cols_to_open: Vec<usize> = {
903        let mut key: <ChaCha20Rng as SeedableRng>::Seed = Default::default();
904        tr.challenge_bytes(E::LABEL_CO, &mut key);
905        let mut cols_rng = ChaCha20Rng::from_seed(key);
906        // XXX(optimization) could expand seed in parallel instead of in series
907        let col_range = Uniform::new(0usize, n_cols);
908        repeat_with(|| col_range.sample(&mut cols_rng))
909            .take(n_col_opens)
910            .collect()
911    };
912
913    // step 2: p_eval fft for column checks
914    let p_eval_fft = {
915        let mut tmp = Vec::with_capacity(n_cols);
916        tmp.extend_from_slice(&proof.p_eval[..]);
917        tmp.resize(n_cols, FldT::<E>::zero());
918        enc.encode(&mut tmp)?;
919        tmp
920    };
921
922    // step 3: check p_random, p_eval, and col paths
923    cols_to_open
924        .par_iter()
925        .zip(&proof.columns[..])
926        .try_for_each(|(&col_num, column)| {
927            let rand = {
928                let mut rand = true;
929                for i in 0..n_degree_tests {
930                    rand &=
931                        verify_column_value(column, &rand_tensor_vec[i], &p_random_fft[i][col_num]);
932                }
933                rand
934            };
935
936            let eval = verify_column_value(column, outer_tensor, &p_eval_fft[col_num]);
937            let path = verify_column_path(column, col_num, root);
938            match (rand, eval, path) {
939                (false, _, _) => Err(VerifierError::ColumnDegree),
940                (_, false, _) => Err(VerifierError::ColumnEval),
941                (_, _, false) => Err(VerifierError::ColumnPath),
942                _ => Ok(()),
943            }
944        })?;
945
946    // step 4: evaluate and return
947    Ok(inner_tensor
948        .par_iter()
949        .zip(&proof.p_eval[..])
950        .fold(FldT::<E>::zero, |a, (t, e)| a + *t * e)
951        .reduce(FldT::<E>::zero, |a, v| a + v))
952}
953
954// Check a column opening
955fn verify_column_path<D, E>(column: &LcColumn<D, E>, col_num: usize, root: &Output<D>) -> bool
956where
957    D: Digest,
958    E: LcEncoding,
959{
960    let mut digest = D::new();
961    digest.update(<Output<D> as Default>::default());
962    for e in &column.col[..] {
963        e.digest_update(&mut digest);
964    }
965
966    // check Merkle path
967    let mut hash = digest.finalize_reset();
968    let mut col = col_num;
969    for p in &column.path[..] {
970        if col % 2 == 0 {
971            digest.update(&hash);
972            digest.update(p);
973        } else {
974            digest.update(p);
975            digest.update(&hash);
976        }
977        hash = digest.finalize_reset();
978        col >>= 1;
979    }
980
981    &hash == root
982}
983
984// check column value
985fn verify_column_value<D, E>(
986    column: &LcColumn<D, E>,
987    tensor: &[FldT<E>],
988    poly_eval: &FldT<E>,
989) -> bool
990where
991    D: Digest,
992    E: LcEncoding,
993{
994    let tensor_eval = tensor
995        .iter()
996        .zip(&column.col[..])
997        .fold(FldT::<E>::zero(), |a, (t, e)| a + *t * e);
998
999    poly_eval == &tensor_eval
1000}
1001
1002/// Evaluate the committed polynomial using the supplied "outer" tensor
1003/// and generate a proof of (1) low-degreeness and (2) correct evaluation.
1004fn prove<D, E>(
1005    comm: &LcCommit<D, E>,
1006    outer_tensor: &[FldT<E>],
1007    enc: &E,
1008    tr: &mut Transcript,
1009) -> ProverResult<LcEvalProof<D, E>, ErrT<E>>
1010where
1011    D: Digest,
1012    E: LcEncoding,
1013{
1014    // make sure arguments are well formed
1015    check_comm(comm, enc)?;
1016    if outer_tensor.len() != comm.n_rows {
1017        return Err(ProverError::OuterTensor);
1018    }
1019
1020    // first, evaluate the polynomial on a random tensor (low-degree test)
1021    // we repeat this to boost soundness
1022    let mut p_random_vec: Vec<Vec<FldT<E>>> = Vec::new();
1023    let n_degree_tests = enc.get_n_degree_tests();
1024    for _i in 0..n_degree_tests {
1025        let p_random = {
1026            let mut key: <ChaCha20Rng as SeedableRng>::Seed = Default::default();
1027            tr.challenge_bytes(E::LABEL_DT, &mut key);
1028            let mut deg_test_rng = ChaCha20Rng::from_seed(key);
1029            // XXX(optimization) could expand seed in parallel instead of in series
1030            let rand_tensor: Vec<FldT<E>> = repeat_with(|| FldT::<E>::random(&mut deg_test_rng))
1031                .take(comm.n_rows)
1032                .collect();
1033            let mut tmp = vec![FldT::<E>::zero(); comm.n_per_row];
1034            collapse_columns::<E>(
1035                &comm.coeffs,
1036                &rand_tensor,
1037                &mut tmp,
1038                comm.n_rows,
1039                comm.n_per_row,
1040                0,
1041            );
1042            tmp
1043        };
1044        // add p_random to the transcript
1045        p_random
1046            .iter()
1047            .for_each(|coeff| coeff.transcript_update(tr, E::LABEL_PR));
1048
1049        p_random_vec.push(p_random);
1050    }
1051
1052    // next, evaluate the polynomial using the supplied tensor
1053    let p_eval = {
1054        let mut tmp = vec![FldT::<E>::zero(); comm.n_per_row];
1055        collapse_columns::<E>(
1056            &comm.coeffs,
1057            outer_tensor,
1058            &mut tmp,
1059            comm.n_rows,
1060            comm.n_per_row,
1061            0,
1062        );
1063        tmp
1064    };
1065    // add p_eval to the transcript
1066    p_eval
1067        .iter()
1068        .for_each(|coeff| coeff.transcript_update(tr, E::LABEL_PE));
1069
1070    // now extract the column numbers to open
1071    let n_col_opens = enc.get_n_col_opens();
1072    let columns: Vec<LcColumn<D, E>> = {
1073        let mut key: <ChaCha20Rng as SeedableRng>::Seed = Default::default();
1074        tr.challenge_bytes(E::LABEL_CO, &mut key);
1075        let mut cols_rng = ChaCha20Rng::from_seed(key);
1076        // XXX(optimization) could expand seed in parallel instead of in series
1077        let col_range = Uniform::new(0usize, comm.n_cols);
1078        let cols_to_open: Vec<usize> = repeat_with(|| col_range.sample(&mut cols_rng))
1079            .take(n_col_opens)
1080            .collect();
1081        cols_to_open
1082            .par_iter()
1083            .map(|&col| open_column(comm, col))
1084            .collect::<ProverResult<Vec<LcColumn<D, E>>, ErrT<E>>>()?
1085    };
1086
1087    Ok(LcEvalProof {
1088        n_cols: comm.n_cols,
1089        p_eval,
1090        p_random_vec,
1091        columns,
1092    })
1093}
1094
1095fn collapse_columns<E>(
1096    coeffs: &[FldT<E>],
1097    tensor: &[FldT<E>],
1098    poly: &mut [FldT<E>],
1099    n_rows: usize,
1100    n_per_row: usize,
1101    offset: usize,
1102) where
1103    E: LcEncoding,
1104{
1105    if poly.len() <= (1 << LOG_MIN_NCOLS) {
1106        // base case: run the computation
1107        // row-by-row, compute elements of dot product
1108        for (row, tensor_val) in tensor.iter().enumerate() {
1109            for (col, val) in poly.iter_mut().enumerate() {
1110                let entry = row * n_per_row + offset + col;
1111                *val += coeffs[entry] * tensor_val;
1112            }
1113        }
1114    } else {
1115        // recursive case: split and execute in parallel
1116        let half_cols = poly.len() / 2;
1117        let (lo, hi) = poly.split_at_mut(half_cols);
1118        rayon::join(
1119            || collapse_columns::<E>(coeffs, tensor, lo, n_rows, n_per_row, offset),
1120            || collapse_columns::<E>(coeffs, tensor, hi, n_rows, n_per_row, offset + half_cols),
1121        );
1122    }
1123}
1124
1125// TESTING ONLY //
1126
#[cfg(test)]
// Serial reference construction of the Merkle tree stored in `comm.hashes`.
//
// The first `n_cols` entries are the leaves, one per column. Each leaf is
// H(default `Output<D>` block || that column's entries from row 0 down to row n_rows - 1).
// Every following layer is written directly after the layer it is built from, and each
// parent is H(left child || right child).
// NOTE(review): this assumes `hashes.len() == 2 * n_cols - 1` with `n_cols` a power of
// two, in which case the last entry ends up as the root. Confirm against the code that
// sizes `hashes`.
fn merkleize_ser<D, E>(comm: &mut LcCommit<D, E>)
where
    D: Digest,
    E: LcEncoding,
{
    let n_cols = comm.n_cols;
    let n_rows = comm.n_rows;
    // disjoint field borrows: read the matrix while writing the hash array
    let matrix = &comm.comm;
    let tree = &mut comm.hashes;

    // leaves: one hash per column of the matrix
    for (col, leaf) in tree.iter_mut().take(n_cols).enumerate() {
        let mut hasher = D::new();
        hasher.update(<Output<D> as Default>::default());
        for row in 0..n_rows {
            matrix[row * n_cols + col].digest_update(&mut hasher);
        }
        *leaf = hasher.finalize();
    }

    // interior layers: hash adjacent pairs of `layer` into the front of `rest`, then
    // advance so that the freshly written parents become the next `layer`
    let (mut layer, mut rest) = tree.split_at_mut(n_cols);
    while !rest.is_empty() {
        for (siblings, parent) in layer.chunks_exact(2).zip(rest.iter_mut()) {
            let mut hasher = D::new();
            hasher.update(siblings[0].as_ref());
            hasher.update(siblings[1].as_ref());
            *parent = hasher.finalize();
        }
        let (next_layer, next_rest) = rest.split_at_mut((rest.len() + 1) / 2);
        layer = next_layer;
        rest = next_rest;
    }
}
1159
#[cfg(test)]
// Check a column opening. It must pass `verify_column_path` against `root` and
// `verify_column_value` against `tensor`/`poly_eval`. The value check is skipped once the
// path check has failed.
fn verify_column<D, E>(
    column: &LcColumn<D, E>,
    col_num: usize,
    root: &Output<D>,
    tensor: &[FldT<E>],
    poly_eval: &FldT<E>,
) -> bool
where
    D: Digest,
    E: LcEncoding,
{
    if !verify_column_path(column, col_num, root) {
        return false;
    }
    verify_column_value(column, tensor, poly_eval)
}
1175
// Collapse the committed coefficient matrix along the "outer" tensor, using the parallel
// `collapse_columns`. Fails with `ProverError::OuterTensor` unless there is exactly one
// tensor entry per row.
#[cfg(test)]
fn eval_outer<D, E>(
    comm: &LcCommit<D, E>,
    tensor: &[FldT<E>],
) -> ProverResult<Vec<FldT<E>>, ErrT<E>>
where
    D: Digest,
    E: LcEncoding,
{
    if tensor.len() == comm.n_rows {
        // start from zeros: collapse_columns accumulates into its output buffer
        let mut collapsed = vec![FldT::<E>::zero(); comm.n_per_row];
        collapse_columns::<E>(
            &comm.coeffs,
            tensor,
            &mut collapsed,
            comm.n_rows,
            comm.n_per_row,
            0,
        );
        Ok(collapsed)
    } else {
        Err(ProverError::OuterTensor)
    }
}
1203
#[cfg(test)]
// Serial reference for `eval_outer`. It computes
// poly[c] = sum over rows r of coeffs[r * n_per_row + c] * tensor[r] with plain loops.
fn eval_outer_ser<D, E>(
    comm: &LcCommit<D, E>,
    tensor: &[FldT<E>],
) -> ProverResult<Vec<FldT<E>>, ErrT<E>>
where
    D: Digest,
    E: LcEncoding,
{
    if tensor.len() != comm.n_rows {
        return Err(ProverError::OuterTensor);
    }

    let width = comm.n_per_row;
    let mut poly = vec![FldT::<E>::zero(); width];
    for (row, weight) in tensor.iter().enumerate() {
        // this row's coefficients: coeffs[row * width .. (row + 1) * width]
        let row_coeffs = &comm.coeffs[row * width..][..width];
        for (acc, coeff) in poly.iter_mut().zip(row_coeffs) {
            *acc += *coeff * weight;
        }
    }

    Ok(poly)
}
1227
#[cfg(test)]
// Tensor-weighted sum of the rows of `comm.comm`, read as rows of `n_cols` entries. This
// is the same collapse as `eval_outer`, but over `comm.comm` instead of `comm.coeffs`.
// NOTE(review): `comm.comm` is presumably the encoded matrix; confirm.
fn eval_outer_fft<D, E>(
    comm: &LcCommit<D, E>,
    tensor: &[FldT<E>],
) -> ProverResult<Vec<FldT<E>>, ErrT<E>>
where
    D: Digest,
    E: LcEncoding,
{
    if tensor.len() != comm.n_rows {
        return Err(ProverError::OuterTensor);
    }

    let zeros = vec![FldT::<E>::zero(); comm.n_cols];
    let poly_fft = comm
        .comm
        .chunks(comm.n_cols)
        .zip(tensor)
        .fold(zeros, |mut acc, (encoded_row, weight)| {
            for (slot, entry) in acc.iter_mut().zip(encoded_row) {
                *slot += *entry * weight;
            }
            acc
        });

    Ok(poly_fft)
}