multiexp/batch.rs
1use std_shims::vec::Vec;
2
3use rand_core::{RngCore, CryptoRng};
4
5use zeroize::{Zeroize, Zeroizing};
6
7use ff::{Field, PrimeFieldBits};
8use group::Group;
9
10use crate::{multiexp, multiexp_vartime};
11
12// Flatten the contained statements to a single Vec.
13// Wrapped in Zeroizing in case any of the included statements contain private values.
14#[allow(clippy::type_complexity)]
15fn flat<Id: Copy + Zeroize, G: Group + Zeroize>(
16 slice: &[(Id, Vec<(G::Scalar, G)>)],
17) -> Zeroizing<Vec<(G::Scalar, G)>>
18where
19 <G as Group>::Scalar: PrimeFieldBits + Zeroize,
20{
21 Zeroizing::new(slice.iter().flat_map(|pairs| pairs.1.iter()).copied().collect::<Vec<_>>())
22}
23
24/// A batch verifier intended to verify a series of statements are each equivalent to zero.
25#[allow(clippy::type_complexity)]
26#[derive(Clone, Zeroize)]
27pub struct BatchVerifier<Id: Copy + Zeroize, G: Group + Zeroize>(
28 Zeroizing<Vec<(Id, Vec<(G::Scalar, G)>)>>,
29)
30where
31 <G as Group>::Scalar: PrimeFieldBits + Zeroize;
32
33impl<Id: Copy + Zeroize, G: Group + Zeroize> BatchVerifier<Id, G>
34where
35 <G as Group>::Scalar: PrimeFieldBits + Zeroize,
36{
37 /// Create a new batch verifier, expected to verify the following amount of statements.
38 ///
39 /// `capacity` is a size hint and is not required to be accurate.
40 pub fn new(capacity: usize) -> BatchVerifier<Id, G> {
41 BatchVerifier(Zeroizing::new(Vec::with_capacity(capacity)))
42 }
43
44 /// Queue a statement for batch verification.
45 pub fn queue<R: RngCore + CryptoRng, I: IntoIterator<Item = (G::Scalar, G)>>(
46 &mut self,
47 rng: &mut R,
48 id: Id,
49 pairs: I,
50 ) {
51 // Define a unique scalar factor for this set of variables so individual items can't overlap
52 let u = if self.0.is_empty() {
53 G::Scalar::ONE
54 } else {
55 let mut weight;
56 while {
57 // Generate a random scalar
58 weight = G::Scalar::random(&mut *rng);
59
60 // Clears half the bits, maintaining security, to minimize scalar additions
61 // Is not practically faster for whatever reason
62 /*
63 // Generate a random scalar
64 let mut repr = G::Scalar::random(&mut *rng).to_repr();
65
66 // Calculate the amount of bytes to clear. We want to clear less than half
67 let repr_len = repr.as_ref().len();
68 let unused_bits = (repr_len * 8) - usize::try_from(G::Scalar::CAPACITY).unwrap();
69 // Don't clear any partial bytes
70 let to_clear = (repr_len / 2) - ((unused_bits + 7) / 8);
71
72 // Clear a safe amount of bytes
73 for b in &mut repr.as_mut()[.. to_clear] {
74 *b = 0;
75 }
76
77 // Ensure these bits are used as the low bits so low scalars multiplied by this don't
78 // become large scalars
79 weight = G::Scalar::from_repr(repr).unwrap();
80 // Tests if any bit we supposedly just cleared is set, and if so, reverses it
81 // Not a security issue if this fails, just a minor performance hit at ~2^-120 odds
82 if weight.to_le_bits().iter().take(to_clear * 8).any(|bit| *bit) {
83 repr.as_mut().reverse();
84 weight = G::Scalar::from_repr(repr).unwrap();
85 }
86 */
87
88 // Ensure it's non-zero, as a zero scalar would cause this item to pass no matter what
89 weight.is_zero().into()
90 } {}
91 weight
92 };
93
94 self.0.push((id, pairs.into_iter().map(|(scalar, point)| (scalar * u, point)).collect()));
95 }
96
97 /// Perform batch verification, returning a boolean of if the statements equaled zero.
98 #[must_use]
99 pub fn verify(&self) -> bool {
100 multiexp(&flat(&self.0)).is_identity().into()
101 }
102
103 /// Perform batch verification in variable time.
104 #[must_use]
105 pub fn verify_vartime(&self) -> bool {
106 multiexp_vartime(&flat(&self.0)).is_identity().into()
107 }
108
109 /// Perform a binary search to identify which statement does not equal 0, returning None if all
110 /// statements do.
111 ///
112 /// This function will only return the ID of one invalid statement, even if multiple are invalid.
113 // A constant time variant may be beneficial for robust protocols
114 pub fn blame_vartime(&self) -> Option<Id> {
115 let mut slice = self.0.as_slice();
116 while slice.len() > 1 {
117 let split = slice.len() / 2;
118 if multiexp_vartime(&flat(&slice[.. split])).is_identity().into() {
119 slice = &slice[split ..];
120 } else {
121 slice = &slice[.. split];
122 }
123 }
124
125 slice
126 .get(0)
127 .filter(|(_, value)| !bool::from(multiexp_vartime(value).is_identity()))
128 .map(|(id, _)| *id)
129 }
130
131 /// Perform constant time batch verification, and if verification fails, identify one faulty
132 /// statement in variable time.
133 pub fn verify_with_vartime_blame(&self) -> Result<(), Id> {
134 if self.verify() {
135 Ok(())
136 } else {
137 Err(self.blame_vartime().unwrap())
138 }
139 }
140
141 /// Perform variable time batch verification, and if verification fails, identify one faulty
142 /// statement in variable time.
143 pub fn verify_vartime_with_vartime_blame(&self) -> Result<(), Id> {
144 if self.verify_vartime() {
145 Ok(())
146 } else {
147 Err(self.blame_vartime().unwrap())
148 }
149 }
150}