1#![deny(missing_docs)]
10
11use digest::{Digest, Output};
16use err_derive::Error;
17use ff::{Field, PrimeField};
18use merlin::Transcript;
19use rand::{
20 distributions::{Distribution, Uniform},
21 SeedableRng,
22};
23use rand_chacha::ChaCha20Rng;
24use rayon::prelude::*;
25use serde::{Deserialize, Deserializer, Serialize, Serializer};
26use std::iter::repeat_with;
27
28mod macros;
29
30#[cfg(test)]
31mod tests;
32
/// Trait for values that can be absorbed into a [`Digest`] or a [`Transcript`].
pub trait FieldHash {
    /// Byte representation used when hashing a value.
    type HashRepr: AsRef<[u8]>;

    /// Convert `self` into its hash representation.
    fn to_hash_repr(&self) -> Self::HashRepr;

    /// Absorb `self` into the digest `d`.
    fn digest_update<D: Digest>(&self, d: &mut D) {
        d.update(self.to_hash_repr())
    }

    /// Absorb `self` into the transcript `t` under label `l`.
    fn transcript_update(&self, t: &mut Transcript, l: &'static [u8]) {
        t.append_message(l, self.to_hash_repr().as_ref())
    }
}
51
// Every prime field hashes via its canonical `PrimeField` byte representation.
impl<T: PrimeField> FieldHash for T {
    type HashRepr = T::Repr;

    fn to_hash_repr(&self) -> Self::HashRepr {
        PrimeField::to_repr(self)
    }
}
59
/// Trait reporting the base-2 logarithm of a field's size.
pub trait SizedField {
    /// Ceiling of log2 of the number of field elements.
    const CLOG2: u32;
    /// Floor of log2 of the number of field elements.
    const FLOG2: u32;
}
67
// For a prime field, NUM_BITS is ceil(log2(modulus)); the floor is one less.
impl<T: PrimeField> SizedField for T {
    const CLOG2: u32 = <T as PrimeField>::NUM_BITS;
    const FLOG2: u32 = <T as PrimeField>::NUM_BITS - 1;
}
72
/// A linear code usable with this polynomial commitment scheme.
pub trait LcEncoding: Clone + std::fmt::Debug + Sync {
    /// The field over which the code is defined.
    type F: Field + FieldHash + std::fmt::Debug + Clone;

    /// Transcript label for degree-test challenges.
    const LABEL_DT: &'static [u8];
    /// Transcript label for the degree-test (random-combination) row.
    const LABEL_PR: &'static [u8];
    /// Transcript label for the evaluation row.
    const LABEL_PE: &'static [u8];
    /// Transcript label for column-opening challenges.
    const LABEL_CO: &'static [u8];

    /// Error type returned by [`LcEncoding::encode`].
    type Err: std::fmt::Debug + std::error::Error + Send;

    /// Encode in place: on entry the message occupies a prefix of `inp`;
    /// on success the whole buffer holds the codeword.
    fn encode<T: AsMut<[Self::F]>>(&self, inp: T) -> Result<(), Self::Err>;

    /// Compute `(n_rows, n_per_row, n_cols)` for a polynomial of length `len`.
    fn get_dims(&self, len: usize) -> (usize, usize, usize);

    /// Check that `(n_per_row, n_cols)` are valid dimensions for this code.
    fn dims_ok(&self, n_per_row: usize, n_cols: usize) -> bool;

    /// Number of column openings required for soundness.
    fn get_n_col_opens(&self) -> usize;

    /// Number of degree tests required for soundness.
    fn get_n_degree_tests(&self) -> usize;
}
105
/// Shorthand for an encoding's field type.
type FldT<E> = <E as LcEncoding>::F;
/// Shorthand for an encoding's error type.
type ErrT<E> = <E as LcEncoding>::Err;
109
/// Errors that can arise while committing or proving.
#[derive(Debug, Error)]
pub enum ProverError<ErrT>
where
    ErrT: std::fmt::Debug + std::error::Error + 'static,
{
    /// The matrix has too many columns for this encoding.
    #[error(display = "n_cols is too large for this encoding")]
    TooBig,
    /// The underlying encoding reported an error.
    #[error(display = "encoding error: {:?}", _0)]
    Encode(#[source] ErrT),
    /// The commitment's internal fields are mutually inconsistent.
    #[error(display = "inconsistent commitment fields")]
    Commit,
    /// A requested column index was out of range.
    #[error(display = "bad column number")]
    ColumnNumber,
    /// The outer tensor's length does not match the commitment.
    #[error(display = "outer tensor: wrong size")]
    OuterTensor,
}
132
/// Result type for prover operations.
pub type ProverResult<T, ErrT> = Result<T, ProverError<ErrT>>;
135
/// Errors that can arise while verifying an evaluation proof.
#[derive(Debug, Error)]
pub enum VerifierError<ErrT>
where
    ErrT: std::fmt::Debug + std::error::Error + 'static,
{
    /// The proof does not contain the expected number of column openings.
    #[error(display = "wrong number of column openings in proof")]
    NumColOpens,
    /// An opened column's Merkle path did not verify against the root.
    #[error(display = "column verification: merkle path failed")]
    ColumnPath,
    /// An opened column failed the evaluation dot-product check.
    #[error(display = "column verification: eval dot product failed")]
    ColumnEval,
    /// An opened column failed a degree-test dot-product check.
    #[error(display = "column verification: degree test dot product failed")]
    ColumnDegree,
    /// The outer tensor's length does not match the proof.
    #[error(display = "outer tensor: wrong size")]
    OuterTensor,
    /// The inner tensor's length does not match the proof.
    #[error(display = "inner tensor: wrong size")]
    InnerTensor,
    /// The proof's dimensions are invalid for this encoding.
    #[error(display = "encoding dimension mismatch")]
    EncodingDims,
    /// The underlying encoding reported an error.
    #[error(display = "encoding error: {:?}", _0)]
    Encode(#[source] ErrT),
}
167
/// Result type for verifier operations.
pub type VerifierResult<T, ErrT> = Result<T, VerifierError<ErrT>>;
170
/// A commitment to a polynomial: the encoded coefficient matrix plus a
/// Merkle tree over its columns.
#[derive(Debug, Clone)]
pub struct LcCommit<D, E>
where
    D: Digest,
    E: LcEncoding,
{
    // encoded matrix, row-major: n_rows x n_cols
    comm: Vec<FldT<E>>,
    // coefficient matrix, row-major: n_rows x n_per_row
    coeffs: Vec<FldT<E>>,
    n_rows: usize,
    n_cols: usize,
    n_per_row: usize,
    // Merkle tree nodes, leaves first; built over n_cols.next_power_of_two()
    // leaves (2 * np2 - 1 entries total), root in the last slot
    hashes: Vec<Output<D>>,
}
185
/// Serialization-friendly mirror of [`LcCommit`]: hashes stored as raw bytes.
#[derive(Debug, Serialize, Deserialize)]
struct WrappedLcCommit<F>
where
    F: Serialize,
{
    comm: Vec<F>,
    coeffs: Vec<F>,
    n_rows: usize,
    n_cols: usize,
    n_per_row: usize,
    hashes: Vec<WrappedOutput>,
}
198
199impl<F> WrappedLcCommit<F>
200where
201 F: Serialize,
202{
203 fn unwrap<D, E>(self) -> LcCommit<D, E>
205 where
206 D: Digest,
207 E: LcEncoding<F = F>,
208 {
209 let hashes = self.hashes.into_iter().map(|c| c.unwrap::<D, E>().root).collect();
210
211 LcCommit {
212 comm: self.comm,
213 coeffs: self.coeffs,
214 n_rows: self.n_rows,
215 n_cols: self.n_cols,
216 n_per_row: self.n_per_row,
217 hashes,
218 }
219 }
220}
221
222impl<D, E> LcCommit<D, E>
223where
224 D: Digest,
225 E: LcEncoding,
226 E::F: Serialize,
227{
228 fn wrapped(&self) -> WrappedLcCommit<FldT<E>> {
229 let hashes_wrapped = self.hashes.iter().map(|h| WrappedOutput { bytes: h.to_vec() }).collect();
230
231 WrappedLcCommit {
232 comm: self.comm.clone(),
233 coeffs: self.coeffs.clone(),
234 n_rows: self.n_rows,
235 n_cols: self.n_cols,
236 n_per_row: self.n_per_row,
237 hashes: hashes_wrapped,
238 }
239 }
240}
241
// Serialize via the byte-level wrapper type.
impl<D, E> Serialize for LcCommit<D, E>
where
    D: Digest,
    E: LcEncoding,
    E::F: Serialize,
{
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        self.wrapped().serialize(serializer)
    }
}
255
// Deserialize via the byte-level wrapper type, then rehydrate.
impl<'de, D, E> Deserialize<'de> for LcCommit<D, E>
where
    D: Digest,
    E: LcEncoding,
    E::F: Serialize + Deserialize<'de>,
{
    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
    where
        De: Deserializer<'de>,
    {
        Ok(WrappedLcCommit::<FldT<E>>::deserialize(deserializer)?.unwrap())
    }
}
269
impl<D, E> LcCommit<D, E>
where
    D: Digest,
    E: LcEncoding,
{
    /// Return the Merkle root of this commitment.
    pub fn get_root(&self) -> LcRoot<D, E> {
        LcRoot {
            // the root is the last node of the hash vector
            root: self.hashes.last().cloned().unwrap(),
            _p: Default::default(),
        }
    }

    /// Number of coefficients per matrix row.
    pub fn get_n_per_row(&self) -> usize {
        self.n_per_row
    }

    /// Number of columns in the encoded matrix.
    pub fn get_n_cols(&self) -> usize {
        self.n_cols
    }

    /// Number of rows in the matrix.
    pub fn get_n_rows(&self) -> usize {
        self.n_rows
    }

    /// Commit to the polynomial with coefficient vector `coeffs` using
    /// encoding `enc`.
    pub fn commit(coeffs: &[FldT<E>], enc: &E) -> ProverResult<Self, ErrT<E>> {
        commit(coeffs, enc)
    }

    /// Generate an evaluation proof for this commitment under the row
    /// combination `outer_tensor`, using Fiat-Shamir transcript `tr`.
    pub fn prove(
        &self,
        outer_tensor: &[FldT<E>],
        enc: &E,
        tr: &mut Transcript,
    ) -> ProverResult<LcEvalProof<D, E>, ErrT<E>> {
        prove(self, outer_tensor, enc, tr)
    }
}
313
/// The Merkle root of a commitment, tagged with its encoding type.
#[derive(Debug, Clone)]
pub struct LcRoot<D, E>
where
    D: Digest,
    E: LcEncoding,
{
    root: Output<D>,
    // zero-sized tag tying this root to encoding E
    _p: std::marker::PhantomData<E>,
}
324
325impl<D, E> LcRoot<D, E>
326where
327 D: Digest,
328 E: LcEncoding,
329{
330 fn wrapped(&self) -> WrappedOutput {
331 WrappedOutput {
332 bytes: self.root.to_vec(),
333 }
334 }
335
336 pub fn into_raw(self) -> Output<D> {
338 self.root
339 }
340}
341
// Borrow the underlying digest output directly.
impl<D, E> AsRef<Output<D>> for LcRoot<D, E>
where
    D: Digest,
    E: LcEncoding,
{
    fn as_ref(&self) -> &Output<D> {
        &self.root
    }
}
351
/// Serialization-friendly wrapper around a digest output: raw bytes.
#[derive(Debug, Clone, Deserialize, Serialize)]
struct WrappedOutput {
    // serde_bytes gives compact byte-string encoding instead of a seq of u8
    #[serde(with = "serde_bytes")]
    pub bytes: Vec<u8>,
}
359
360impl WrappedOutput {
361 fn unwrap<D, E>(self) -> LcRoot<D, E>
362 where
363 D: Digest,
364 E: LcEncoding,
365 {
366 LcRoot {
367 root: self.bytes.into_iter().collect::<Output<D>>(),
368 _p: Default::default(),
369 }
370 }
371}
372
// Serialize via the byte wrapper.
impl<D, E> Serialize for LcRoot<D, E>
where
    D: Digest,
    E: LcEncoding,
{
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        self.wrapped().serialize(serializer)
    }
}
385
// Deserialize via the byte wrapper, then rehydrate.
impl<'de, D, E> Deserialize<'de> for LcRoot<D, E>
where
    D: Digest,
    E: LcEncoding,
{
    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
    where
        De: Deserializer<'de>,
    {
        Ok(WrappedOutput::deserialize(deserializer)?.unwrap())
    }
}
398
/// One opened column of the encoded matrix plus its Merkle authentication path.
#[derive(Debug, Clone)]
pub struct LcColumn<D, E>
where
    D: Digest,
    E: LcEncoding,
{
    // the column entries, one per matrix row
    col: Vec<FldT<E>>,
    // sibling hashes from leaf to root
    path: Vec<Output<D>>,
}
409
410impl<D, E> LcColumn<D, E>
411where
412 D: Digest,
413 E: LcEncoding,
414 E::F: Serialize,
415{
416 fn wrapped(&self) -> WrappedLcColumn<FldT<E>> {
417 let path_wrapped = (0..self.path.len())
418 .map(|i| WrappedOutput {
419 bytes: self.path[i].to_vec(),
420 })
421 .collect();
422
423 WrappedLcColumn {
424 col: self.col.clone(),
425 path: path_wrapped,
426 }
427 }
428}
429
/// Serialization-friendly mirror of [`LcColumn`].
#[derive(Debug, Clone, Deserialize, Serialize)]
struct WrappedLcColumn<F>
where
    F: Serialize,
{
    col: Vec<F>,
    path: Vec<WrappedOutput>,
}
439
440impl<F> WrappedLcColumn<F>
441where
442 F: Serialize,
443{
444 fn unwrap<D, E>(self) -> LcColumn<D, E>
446 where
447 D: Digest,
448 E: LcEncoding<F = F>,
449 {
450 let col = self.col;
451 let path = self
452 .path
453 .into_iter()
454 .map(|v| v.bytes.into_iter().collect::<Output<D>>())
455 .collect();
456
457 LcColumn { col, path }
458 }
459}
460
// Serialize via the wrapper type.
impl<D, E> Serialize for LcColumn<D, E>
where
    D: Digest,
    E: LcEncoding,
    E::F: Serialize,
{
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        self.wrapped().serialize(serializer)
    }
}
474
// Deserialize via the wrapper type, then rehydrate.
impl<'de, D, E> Deserialize<'de> for LcColumn<D, E>
where
    D: Digest,
    E: LcEncoding,
    E::F: Serialize + Deserialize<'de>,
{
    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
    where
        De: Deserializer<'de>,
    {
        Ok(WrappedLcColumn::<FldT<E>>::deserialize(deserializer)?.unwrap())
    }
}
488
/// An evaluation proof: degree-test rows, the evaluation row, and the
/// opened columns that spot-check them.
#[derive(Debug, Clone)]
pub struct LcEvalProof<D, E>
where
    D: Digest,
    E: LcEncoding,
{
    n_cols: usize,
    // coefficients of the outer-tensor row combination
    p_eval: Vec<FldT<E>>,
    // one combined row per degree test
    p_random_vec: Vec<Vec<FldT<E>>>,
    columns: Vec<LcColumn<D, E>>,
}
501
impl<D, E> LcEvalProof<D, E>
where
    D: Digest,
    E: LcEncoding,
{
    /// Number of columns in the encoded matrix (codeword length).
    pub fn get_n_cols(&self) -> usize {
        self.n_cols
    }

    /// Number of coefficients per matrix row.
    pub fn get_n_per_row(&self) -> usize {
        self.p_eval.len()
    }

    /// Verify this proof against commitment root `root`, returning the
    /// evaluation defined by `outer_tensor` and `inner_tensor` on success.
    pub fn verify(
        &self,
        root: &Output<D>,
        outer_tensor: &[FldT<E>],
        inner_tensor: &[FldT<E>],
        enc: &E,
        tr: &mut Transcript,
    ) -> VerifierResult<FldT<E>, ErrT<E>> {
        verify(root, outer_tensor, inner_tensor, self, enc, tr)
    }
}
529
530impl<D, E> LcEvalProof<D, E>
531where
532 D: Digest,
533 E: LcEncoding,
534 E::F: Serialize,
535{
536 fn wrapped(&self) -> WrappedLcEvalProof<FldT<E>> {
537 let columns_wrapped = (0..self.columns.len())
538 .map(|i| self.columns[i].wrapped())
539 .collect();
540
541 WrappedLcEvalProof {
542 n_cols: self.n_cols,
543 p_eval: self.p_eval.clone(),
544 p_random_vec: self.p_random_vec.clone(),
545 columns: columns_wrapped,
546 }
547 }
548}
549
/// Serialization-friendly mirror of [`LcEvalProof`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WrappedLcEvalProof<F>
where
    F: Serialize,
{
    n_cols: usize,
    p_eval: Vec<F>,
    p_random_vec: Vec<Vec<F>>,
    columns: Vec<WrappedLcColumn<F>>,
}
561
562impl<F> WrappedLcEvalProof<F>
563where
564 F: Serialize,
565{
566 fn unwrap<D, E>(self) -> LcEvalProof<D, E>
568 where
569 D: Digest,
570 E: LcEncoding<F = F>,
571 {
572 let columns = self.columns.into_iter().map(|c| c.unwrap()).collect();
573
574 LcEvalProof {
575 n_cols: self.n_cols,
576 p_eval: self.p_eval,
577 p_random_vec: self.p_random_vec,
578 columns,
579 }
580 }
581}
582
// Serialize via the wrapper type.
impl<D, E> Serialize for LcEvalProof<D, E>
where
    D: Digest,
    E: LcEncoding,
    E::F: Serialize,
{
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        self.wrapped().serialize(serializer)
    }
}
596
// Deserialize via the wrapper type, then rehydrate.
impl<'de, D, E> Deserialize<'de> for LcEvalProof<D, E>
where
    D: Digest,
    E: LcEncoding,
    E::F: Serialize + Deserialize<'de>,
{
    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
    where
        De: Deserializer<'de>,
    {
        Ok(WrappedLcEvalProof::<FldT<E>>::deserialize(deserializer)?.unwrap())
    }
}
610
611pub fn n_degree_tests(lambda: usize, len: usize, flog2: usize) -> usize {
614 let den = flog2 - log2(len);
615 (lambda + den - 1) / den
616}
617
/// Parallel-recursion cutoff: switch to serial work at 2^5 = 32 items.
const LOG_MIN_NCOLS: usize = 5;
620
/// Commit to the polynomial with coefficients `coeffs_in` using encoding `enc`:
/// lay the coefficients out as an `n_rows x n_per_row` matrix (zero-padded),
/// encode each row into an `n_rows x n_cols` codeword matrix, and Merkleize
/// the columns of the result.
fn commit<D, E>(coeffs_in: &[FldT<E>], enc: &E) -> ProverResult<LcCommit<D, E>, ErrT<E>>
where
    D: Digest,
    E: LcEncoding,
{
    let (n_rows, n_per_row, n_cols) = enc.get_dims(coeffs_in.len());

    // the matrix must hold all coefficients, with only the last row partial
    assert!(n_rows * n_per_row >= coeffs_in.len());
    assert!((n_rows - 1) * n_per_row < coeffs_in.len());
    assert!(enc.dims_ok(n_per_row, n_cols));

    // zero-initialized coefficient and codeword matrices
    let mut coeffs = vec![FldT::<E>::zero(); n_rows * n_per_row];
    let mut comm = vec![FldT::<E>::zero(); n_rows * n_cols];

    // copy the input into the coefficient matrix row by row
    // (the final input chunk may be short; the tail stays zero)
    coeffs
        .par_chunks_mut(n_per_row)
        .zip(coeffs_in.par_chunks(n_per_row))
        .for_each(|(c, c_in)| {
            c[..c_in.len()].copy_from_slice(c_in);
        });

    // encode each row: copy coefficients into the row prefix, then
    // encode the full row in place
    comm.par_chunks_mut(n_cols)
        .zip(coeffs.par_chunks(n_per_row))
        .try_for_each(|(r, c)| {
            r[..c.len()].copy_from_slice(c);
            enc.encode(r)
        })?;

    // hash tree over np2 = next_power_of_two(n_cols) leaves: 2*np2 - 1 nodes
    let n_cols_np2 = n_cols
        .checked_next_power_of_two()
        .ok_or(ProverError::TooBig)?;
    let mut ret = LcCommit {
        comm,
        coeffs,
        n_rows,
        n_cols,
        n_per_row,
        hashes: vec![<Output<D> as Default>::default(); 2 * n_cols_np2 - 1],
    };
    check_comm(&ret, enc)?;
    merkleize(&mut ret);

    Ok(ret)
}
672
673fn check_comm<D, E>(comm: &LcCommit<D, E>, enc: &E) -> ProverResult<(), ErrT<E>>
674where
675 D: Digest,
676 E: LcEncoding,
677{
678 let comm_sz = comm.comm.len() != comm.n_rows * comm.n_cols;
679 let coeff_sz = comm.coeffs.len() != comm.n_rows * comm.n_per_row;
680 let hashlen = comm.hashes.len() != 2 * comm.n_cols.next_power_of_two() - 1;
681 let dims = !enc.dims_ok(comm.n_per_row, comm.n_cols);
682
683 if comm_sz || coeff_sz || hashlen || dims {
684 Err(ProverError::Commit)
685 } else {
686 Ok(())
687 }
688}
689
/// Build the Merkle tree over the columns of `comm.comm`, filling `comm.hashes`.
fn merkleize<D, E>(comm: &mut LcCommit<D, E>)
where
    D: Digest,
    E: LcEncoding,
{
    // step 1: hash each real column into a leaf; leaves past n_cols keep
    // their Default (zero) value as padding
    let hashes = &mut comm.hashes[..comm.n_cols];
    hash_columns::<D, E>(&comm.comm, hashes, comm.n_rows, comm.n_cols, 0);

    // step 2: combine pairwise up to the root; the leaf layer is the first
    // half of the hash vector, the interior layers are the second half
    let len_plus_one = comm.hashes.len() + 1;
    assert!(len_plus_one.is_power_of_two());
    let (hin, hout) = comm.hashes.split_at_mut(len_plus_one / 2);
    merkle_tree::<D>(hin, hout);
}
705
/// Hash columns `offset..offset + hashes.len()` of the row-major matrix
/// `comm` (`n_rows` rows of `n_cols` entries), writing one digest per column.
/// Splits and recurses in parallel until at most `2^LOG_MIN_NCOLS` columns.
fn hash_columns<D, E>(
    comm: &[FldT<E>],
    hashes: &mut [Output<D>],
    n_rows: usize,
    n_cols: usize,
    offset: usize,
) where
    D: Digest,
    E: LcEncoding,
{
    if hashes.len() <= (1 << LOG_MIN_NCOLS) {
        // base case: seed one digest per column with a zero block
        // (the same leaf prefix is recomputed in verify_column_path)
        let mut digests = Vec::with_capacity(hashes.len());
        for _ in 0..hashes.len() {
            let mut dig = D::new();
            dig.update(<Output<D> as Default>::default());
            digests.push(dig);
        }
        // walk row-by-row so matrix reads stay mostly sequential
        for row in 0..n_rows {
            for (col, digest) in digests.iter_mut().enumerate() {
                comm[row * n_cols + offset + col].digest_update(digest);
            }
        }
        // finalize each column digest into its leaf slot
        for (col, digest) in digests.into_iter().enumerate() {
            hashes[col] = digest.finalize();
        }
    } else {
        // recursive case: split the column bank in half and hash in parallel
        let half_cols = hashes.len() / 2;
        let (lo, hi) = hashes.split_at_mut(half_cols);
        rayon::join(
            || hash_columns::<D, E>(comm, lo, n_rows, n_cols, offset),
            || hash_columns::<D, E>(comm, hi, n_rows, n_cols, offset + half_cols),
        );
    }
}
746
/// Compute the remaining Merkle layers: `ins` is the current layer and
/// `outs` holds every layer above it up to the root (`ins.len() - 1` nodes).
fn merkle_tree<D>(ins: &[Output<D>], outs: &mut [Output<D>])
where
    D: Digest,
{
    // a layer of n nodes produces n/2 parents, then n/4, ... , 1 root
    assert_eq!(ins.len(), outs.len() + 1);

    // carve off the next layer (half the size of ins) and fill it
    let (outs, rems) = outs.split_at_mut((outs.len() + 1) / 2);
    merkle_layer::<D>(ins, outs);

    // recurse until only the root remains
    if !rems.is_empty() {
        merkle_tree::<D>(outs, rems)
    }
}
761
/// Compute one Merkle layer: each output is the hash of two adjacent inputs.
fn merkle_layer<D>(ins: &[Output<D>], outs: &mut [Output<D>])
where
    D: Digest,
{
    assert_eq!(ins.len(), 2 * outs.len());

    if ins.len() <= (1 << LOG_MIN_NCOLS) {
        // base case: hash each (left, right) pair serially, reusing one digest
        let mut digest = D::new();
        for idx in 0..outs.len() {
            digest.update(ins[2 * idx].as_ref());
            digest.update(ins[2 * idx + 1].as_ref());
            outs[idx] = digest.finalize_reset();
        }
    } else {
        // recursive case: split inputs and outputs in half, work in parallel
        let (inl, inr) = ins.split_at(ins.len() / 2);
        let (outl, outr) = outs.split_at_mut(outs.len() / 2);
        rayon::join(
            || merkle_layer::<D>(inl, outl),
            || merkle_layer::<D>(inr, outr),
        );
    }
}
786
787fn open_column<D, E>(
789 comm: &LcCommit<D, E>,
790 mut column: usize,
791) -> ProverResult<LcColumn<D, E>, ErrT<E>>
792where
793 D: Digest,
794 E: LcEncoding,
795{
796 if column >= comm.n_cols {
798 return Err(ProverError::ColumnNumber);
799 }
800
801 let col = comm
803 .comm
804 .iter()
805 .skip(column)
806 .step_by(comm.n_cols)
807 .cloned()
808 .collect();
809
810 let mut hashes = &comm.hashes[..];
812 let path_len = log2(comm.n_cols);
813 let mut path = Vec::with_capacity(path_len);
814 for _ in 0..path_len {
815 let other = (column & !1) | (!column & 1);
816 assert_eq!(other ^ column, 1);
817 path.push(hashes[other].clone());
818 let (_, hashes_new) = hashes.split_at((hashes.len() + 1) / 2);
819 hashes = hashes_new;
820 column >>= 1;
821 }
822 assert_eq!(column, 0);
823
824 Ok(LcColumn { col, path })
825}
826
/// Ceiling of the base-2 logarithm: the smallest `k` with `2^k >= v`
/// (returns 0 for both 0 and 1).
const fn log2(v: usize) -> usize {
    // next_power_of_two rounds up to 2^k; its trailing-zero count is k
    v.next_power_of_two().trailing_zeros() as usize
}
830
/// Verify an evaluation proof against the commitment root `root`.
///
/// Re-derives every Fiat-Shamir challenge in the same order as the prover,
/// then checks each opened column's degree tests, evaluation consistency,
/// and Merkle path; returns the evaluation value on success.
fn verify<D, E>(
    root: &Output<D>,
    outer_tensor: &[FldT<E>],
    inner_tensor: &[FldT<E>],
    proof: &LcEvalProof<D, E>,
    enc: &E,
    tr: &mut Transcript,
) -> VerifierResult<FldT<E>, ErrT<E>>
where
    D: Digest,
    E: LcEncoding,
{
    // the proof must contain exactly the required number of column openings
    let n_col_opens = enc.get_n_col_opens();
    if n_col_opens != proof.columns.len() || n_col_opens == 0 {
        return Err(VerifierError::NumColOpens);
    }
    // dimensions must agree with the tensors and the encoding
    let n_rows = proof.columns[0].col.len();
    let n_cols = proof.get_n_cols();
    let n_per_row = proof.get_n_per_row();
    if inner_tensor.len() != n_per_row {
        return Err(VerifierError::InnerTensor);
    }
    if outer_tensor.len() != n_rows {
        return Err(VerifierError::OuterTensor);
    }
    if !enc.dims_ok(n_per_row, n_cols) {
        return Err(VerifierError::EncodingDims);
    }

    // degree tests: re-derive each random row combination and encode the
    // prover's combined row for later per-column spot checks
    let mut rand_tensor_vec: Vec<Vec<FldT<E>>> = Vec::new();
    let mut p_random_fft: Vec<Vec<FldT<E>>> = Vec::new();
    let n_degree_tests = enc.get_n_degree_tests();
    for i in 0..n_degree_tests {
        let rand_tensor: Vec<FldT<E>> = {
            // challenge is drawn BEFORE absorbing p_random_vec[i],
            // mirroring the prover's transcript order
            let mut key: <ChaCha20Rng as SeedableRng>::Seed = Default::default();
            tr.challenge_bytes(E::LABEL_DT, &mut key);
            let mut deg_test_rng = ChaCha20Rng::from_seed(key);
            repeat_with(|| FldT::<E>::random(&mut deg_test_rng))
                .take(n_rows)
                .collect()
        };

        rand_tensor_vec.push(rand_tensor);

        {
            // encode the prover's combined row into a full codeword
            let mut tmp = Vec::with_capacity(n_cols);
            tmp.extend_from_slice(&proof.p_random_vec[i][..]);
            tmp.resize(n_cols, FldT::<E>::zero());
            enc.encode(&mut tmp)?;
            p_random_fft.push(tmp);
        };

        // absorb the prover's combined row into the transcript
        proof.p_random_vec[i]
            .iter()
            .for_each(|coeff| coeff.transcript_update(tr, E::LABEL_PR));
    }

    // absorb the claimed evaluation row
    proof
        .p_eval
        .iter()
        .for_each(|coeff| coeff.transcript_update(tr, E::LABEL_PE));

    // re-derive the column-opening challenges
    let cols_to_open: Vec<usize> = {
        let mut key: <ChaCha20Rng as SeedableRng>::Seed = Default::default();
        tr.challenge_bytes(E::LABEL_CO, &mut key);
        let mut cols_rng = ChaCha20Rng::from_seed(key);
        let col_range = Uniform::new(0usize, n_cols);
        repeat_with(|| col_range.sample(&mut cols_rng))
            .take(n_col_opens)
            .collect()
    };

    // encode the claimed evaluation row into a full codeword
    let p_eval_fft = {
        let mut tmp = Vec::with_capacity(n_cols);
        tmp.extend_from_slice(&proof.p_eval[..]);
        tmp.resize(n_cols, FldT::<E>::zero());
        enc.encode(&mut tmp)?;
        tmp
    };

    // check every opened column: all degree tests, the evaluation dot
    // product, and the Merkle path against the commitment root
    cols_to_open
        .par_iter()
        .zip(&proof.columns[..])
        .try_for_each(|(&col_num, column)| {
            let rand = {
                let mut rand = true;
                for i in 0..n_degree_tests {
                    rand &=
                        verify_column_value(column, &rand_tensor_vec[i], &p_random_fft[i][col_num]);
                }
                rand
            };

            let eval = verify_column_value(column, outer_tensor, &p_eval_fft[col_num]);
            let path = verify_column_path(column, col_num, root);
            match (rand, eval, path) {
                (false, _, _) => Err(VerifierError::ColumnDegree),
                (_, false, _) => Err(VerifierError::ColumnEval),
                (_, _, false) => Err(VerifierError::ColumnPath),
                _ => Ok(()),
            }
        })?;

    // the evaluation is the inner product of inner_tensor with p_eval
    Ok(inner_tensor
        .par_iter()
        .zip(&proof.p_eval[..])
        .fold(FldT::<E>::zero, |a, (t, e)| a + *t * e)
        .reduce(FldT::<E>::zero, |a, v| a + v))
}
953
/// Check the Merkle inclusion path of `column` as column index `col_num`
/// against the commitment root `root`.
fn verify_column_path<D, E>(column: &LcColumn<D, E>, col_num: usize, root: &Output<D>) -> bool
where
    D: Digest,
    E: LcEncoding,
{
    // recompute the leaf: a zero block followed by the column entries
    // (must match the leaf derivation in hash_columns)
    let mut digest = D::new();
    digest.update(<Output<D> as Default>::default());
    for e in &column.col[..] {
        e.digest_update(&mut digest);
    }

    // climb the path: at each level hash (left child, right child) in index
    // order; the node is the left child when its index is even
    let mut hash = digest.finalize_reset();
    let mut col = col_num;
    for p in &column.path[..] {
        if col % 2 == 0 {
            digest.update(&hash);
            digest.update(p);
        } else {
            digest.update(p);
            digest.update(&hash);
        }
        hash = digest.finalize_reset();
        col >>= 1;
    }

    // accept iff we arrived at the committed root
    &hash == root
}
983
984fn verify_column_value<D, E>(
986 column: &LcColumn<D, E>,
987 tensor: &[FldT<E>],
988 poly_eval: &FldT<E>,
989) -> bool
990where
991 D: Digest,
992 E: LcEncoding,
993{
994 let tensor_eval = tensor
995 .iter()
996 .zip(&column.col[..])
997 .fold(FldT::<E>::zero(), |a, (t, e)| a + *t * e);
998
999 poly_eval == &tensor_eval
1000}
1001
/// Generate an evaluation proof for `comm` under the row combination
/// `outer_tensor`, driving all challenges from Fiat-Shamir transcript `tr`.
fn prove<D, E>(
    comm: &LcCommit<D, E>,
    outer_tensor: &[FldT<E>],
    enc: &E,
    tr: &mut Transcript,
) -> ProverResult<LcEvalProof<D, E>, ErrT<E>>
where
    D: Digest,
    E: LcEncoding,
{
    // make sure the commitment and arguments are well formed
    check_comm(comm, enc)?;
    if outer_tensor.len() != comm.n_rows {
        return Err(ProverError::OuterTensor);
    }

    // degree tests: for each test, draw a random row combination from the
    // transcript and collapse the coefficient matrix along it
    let mut p_random_vec: Vec<Vec<FldT<E>>> = Vec::new();
    let n_degree_tests = enc.get_n_degree_tests();
    for _i in 0..n_degree_tests {
        let p_random = {
            let mut key: <ChaCha20Rng as SeedableRng>::Seed = Default::default();
            tr.challenge_bytes(E::LABEL_DT, &mut key);
            let mut deg_test_rng = ChaCha20Rng::from_seed(key);
            let rand_tensor: Vec<FldT<E>> = repeat_with(|| FldT::<E>::random(&mut deg_test_rng))
                .take(comm.n_rows)
                .collect();
            let mut tmp = vec![FldT::<E>::zero(); comm.n_per_row];
            collapse_columns::<E>(
                &comm.coeffs,
                &rand_tensor,
                &mut tmp,
                comm.n_rows,
                comm.n_per_row,
                0,
            );
            tmp
        };
        // absorb the combined row AFTER drawing the challenge; the verifier
        // mirrors this transcript order exactly
        p_random
            .iter()
            .for_each(|coeff| coeff.transcript_update(tr, E::LABEL_PR));

        p_random_vec.push(p_random);
    }

    // evaluation: collapse the coefficient matrix along outer_tensor
    let p_eval = {
        let mut tmp = vec![FldT::<E>::zero(); comm.n_per_row];
        collapse_columns::<E>(
            &comm.coeffs,
            outer_tensor,
            &mut tmp,
            comm.n_rows,
            comm.n_per_row,
            0,
        );
        tmp
    };
    // absorb the evaluation row into the transcript
    p_eval
        .iter()
        .for_each(|coeff| coeff.transcript_update(tr, E::LABEL_PE));

    // column openings: draw the challenge column indices, then open each
    let n_col_opens = enc.get_n_col_opens();
    let columns: Vec<LcColumn<D, E>> = {
        let mut key: <ChaCha20Rng as SeedableRng>::Seed = Default::default();
        tr.challenge_bytes(E::LABEL_CO, &mut key);
        let mut cols_rng = ChaCha20Rng::from_seed(key);
        let col_range = Uniform::new(0usize, comm.n_cols);
        let cols_to_open: Vec<usize> = repeat_with(|| col_range.sample(&mut cols_rng))
            .take(n_col_opens)
            .collect();
        cols_to_open
            .par_iter()
            .map(|&col| open_column(comm, col))
            .collect::<ProverResult<Vec<LcColumn<D, E>>, ErrT<E>>>()?
    };

    Ok(LcEvalProof {
        n_cols: comm.n_cols,
        p_eval,
        p_random_vec,
        columns,
    })
}
1094
/// Accumulate the tensor-weighted row sum of `coeffs` into `poly`: for every
/// row, add `tensor[row] * coeffs[row * n_per_row + offset + col]` to
/// `poly[col]`. Splits and recurses in parallel over column ranges.
fn collapse_columns<E>(
    coeffs: &[FldT<E>],
    tensor: &[FldT<E>],
    poly: &mut [FldT<E>],
    n_rows: usize,
    n_per_row: usize,
    offset: usize,
) where
    E: LcEncoding,
{
    if poly.len() <= (1 << LOG_MIN_NCOLS) {
        // base case: walk rows in order so coeffs reads stay sequential
        for (row, tensor_val) in tensor.iter().enumerate() {
            for (col, val) in poly.iter_mut().enumerate() {
                let entry = row * n_per_row + offset + col;
                *val += coeffs[entry] * tensor_val;
            }
        }
    } else {
        // recursive case: split the column range in half, work in parallel
        let half_cols = poly.len() / 2;
        let (lo, hi) = poly.split_at_mut(half_cols);
        rayon::join(
            || collapse_columns::<E>(coeffs, tensor, lo, n_rows, n_per_row, offset),
            || collapse_columns::<E>(coeffs, tensor, hi, n_rows, n_per_row, offset + half_cols),
        );
    }
}
1124
/// Serial reference implementation of [`merkleize`], used by tests to
/// cross-check the parallel version.
#[cfg(test)]
fn merkleize_ser<D, E>(comm: &mut LcCommit<D, E>)
where
    D: Digest,
    E: LcEncoding,
{
    let hashes = &mut comm.hashes;

    // hash each column into its leaf (same derivation as hash_columns:
    // a zero block followed by the column entries)
    for (col, hash) in hashes.iter_mut().enumerate().take(comm.n_cols) {
        let mut digest = D::new();
        digest.update(<Output<D> as Default>::default());
        for row in 0..comm.n_rows {
            comm.comm[row * comm.n_cols + col].digest_update(&mut digest);
        }
        *hash = digest.finalize();
    }

    // combine layer by layer up to the root
    // NOTE(review): this splits the leaf layer at n_cols, whereas merkleize
    // splits at n_cols.next_power_of_two(); the two agree only when n_cols
    // is a power of two — TODO confirm the encodings used in tests
    // always produce power-of-two n_cols
    let (mut ins, mut outs) = hashes.split_at_mut(comm.n_cols);
    while !outs.is_empty() {
        for idx in 0..ins.len() / 2 {
            let mut digest = D::new();
            digest.update(ins[2 * idx].as_ref());
            digest.update(ins[2 * idx + 1].as_ref());
            outs[idx] = digest.finalize();
        }
        let (new_ins, new_outs) = outs.split_at_mut((outs.len() + 1) / 2);
        ins = new_ins;
        outs = new_outs;
    }
}
1159
/// Test helper: check both the Merkle path and the evaluation dot product
/// for one opened column.
#[cfg(test)]
fn verify_column<D, E>(
    column: &LcColumn<D, E>,
    col_num: usize,
    root: &Output<D>,
    tensor: &[FldT<E>],
    poly_eval: &FldT<E>,
) -> bool
where
    D: Digest,
    E: LcEncoding,
{
    verify_column_path(column, col_num, root) && verify_column_value(column, tensor, poly_eval)
}
1175
/// Test helper: collapse the coefficient matrix along `tensor` (parallel
/// version, via collapse_columns).
#[cfg(test)]
fn eval_outer<D, E>(
    comm: &LcCommit<D, E>,
    tensor: &[FldT<E>],
) -> ProverResult<Vec<FldT<E>>, ErrT<E>>
where
    D: Digest,
    E: LcEncoding,
{
    if tensor.len() != comm.n_rows {
        return Err(ProverError::OuterTensor);
    }

    // weighted sum of the matrix rows, one output per coefficient column
    let mut poly = vec![FldT::<E>::zero(); comm.n_per_row];
    collapse_columns::<E>(
        &comm.coeffs,
        tensor,
        &mut poly,
        comm.n_rows,
        comm.n_per_row,
        0,
    );

    Ok(poly)
}
1203
/// Test helper: serial reference for [`eval_outer`].
#[cfg(test)]
fn eval_outer_ser<D, E>(
    comm: &LcCommit<D, E>,
    tensor: &[FldT<E>],
) -> ProverResult<Vec<FldT<E>>, ErrT<E>>
where
    D: Digest,
    E: LcEncoding,
{
    if tensor.len() != comm.n_rows {
        return Err(ProverError::OuterTensor);
    }

    // straightforward doubly-nested weighted row sum
    let mut poly = vec![FldT::<E>::zero(); comm.n_per_row];
    for (row, tensor_val) in tensor.iter().enumerate() {
        for (col, val) in poly.iter_mut().enumerate() {
            let entry = row * comm.n_per_row + col;
            *val += comm.coeffs[entry] * tensor_val;
        }
    }

    Ok(poly)
}
1227
/// Test helper: collapse the ENCODED matrix (rather than the coefficients)
/// along `tensor`, producing a combined codeword of length n_cols.
#[cfg(test)]
fn eval_outer_fft<D, E>(
    comm: &LcCommit<D, E>,
    tensor: &[FldT<E>],
) -> ProverResult<Vec<FldT<E>>, ErrT<E>>
where
    D: Digest,
    E: LcEncoding,
{
    if tensor.len() != comm.n_rows {
        return Err(ProverError::OuterTensor);
    }

    // weighted sum of the encoded rows
    let mut poly_fft = vec![FldT::<E>::zero(); comm.n_cols];
    for (coeffs, tensorval) in comm.comm.chunks(comm.n_cols).zip(tensor.iter()) {
        for (coeff, polyval) in coeffs.iter().zip(poly_fft.iter_mut()) {
            *polyval += *coeff * tensorval;
        }
    }

    Ok(poly_fft)
}