1#![cfg_attr(docsrs, feature(doc_auto_cfg))]
2#![cfg_attr(not(feature = "std"), no_std)]
3#![doc = include_str!("../README.md")]
4
5use core::ops::Deref;
6
7use rand_core::{RngCore, CryptoRng};
8
9use zeroize::{Zeroize, Zeroizing};
10
11use transcript::Transcript;
12
13use ff::{Field, PrimeField};
14use group::prime::PrimeGroup;
15
16#[cfg(feature = "serialize")]
17use std::io::{self, Error, Read, Write};
18
19#[cfg(feature = "experimental")]
22pub mod cross_group;
23
24#[cfg(test)]
25mod tests;
26
27pub(crate) fn challenge<T: Transcript, F: PrimeField>(transcript: &mut T) -> F {
29 let mut challenge = F::ZERO;
36
37 let target_bytes = ((usize::try_from(F::NUM_BITS).unwrap() + 7) / 8) * 2;
41 let mut challenge_bytes = transcript.challenge(b"challenge");
42 let challenge_bytes_len = challenge_bytes.as_ref().len();
43 let needed_challenges = (target_bytes + (challenge_bytes_len - 1)) / challenge_bytes_len;
45
46 let mut handled_bytes = 0;
49 'outer: for _ in 0 ..= needed_challenges {
50 let mut b = 0;
52 while b < challenge_bytes_len {
53 let chunk_bytes = (target_bytes - handled_bytes).min(8).min(challenge_bytes_len - b);
57
58 let mut chunk = 0;
59 for _ in 0 .. chunk_bytes {
60 chunk <<= 8;
61 chunk |= u64::from(challenge_bytes.as_ref()[b]);
62 b += 1;
63 }
64 challenge += F::from(chunk);
66
67 handled_bytes += chunk_bytes;
68 if handled_bytes == target_bytes {
70 break 'outer;
71 }
72
73 let next_chunk_bytes = (target_bytes - handled_bytes).min(8).min(challenge_bytes_len);
75 for _ in 0 .. (next_chunk_bytes * 8) {
76 challenge = challenge.double();
77 }
78 }
79
80 challenge_bytes = transcript.challenge(b"challenge_extension");
82 }
83
84 challenge
85}
86
#[cfg(feature = "serialize")]
/// Read one scalar in its canonical `Repr` encoding.
///
/// Returns an `io::Error` if the reader runs dry or if the bytes aren't a valid encoding of a
/// field element.
fn read_scalar<R: Read, F: PrimeField>(r: &mut R) -> io::Result<F> {
  let mut encoded = F::Repr::default();
  r.read_exact(encoded.as_mut())?;

  let candidate = F::from_repr(encoded);
  if bool::from(candidate.is_some()) {
    Ok(candidate.unwrap())
  } else {
    Err(Error::other("invalid scalar"))
  }
}
98
/// Errors which may occur when verifying a DLEq proof.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DLEqError {
  /// The proof was malformed for the provided statement or failed to verify.
  InvalidProof,
}
105
106#[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroize)]
108pub struct DLEqProof<G: PrimeGroup<Scalar: Zeroize>> {
109 c: G::Scalar,
110 s: G::Scalar,
111}
112
113#[allow(non_snake_case)]
114impl<G: PrimeGroup<Scalar: Zeroize>> DLEqProof<G> {
115 fn transcript<T: Transcript>(transcript: &mut T, generator: G, nonce: G, point: G) {
116 transcript.append_message(b"generator", generator.to_bytes());
117 transcript.append_message(b"nonce", nonce.to_bytes());
118 transcript.append_message(b"point", point.to_bytes());
119 }
120
121 pub fn prove<R: RngCore + CryptoRng, T: Transcript>(
124 rng: &mut R,
125 transcript: &mut T,
126 generators: &[G],
127 scalar: &Zeroizing<G::Scalar>,
128 ) -> DLEqProof<G> {
129 let r = Zeroizing::new(G::Scalar::random(rng));
130
131 transcript.domain_separate(b"dleq");
132 for generator in generators {
133 Self::transcript(transcript, *generator, *generator * r.deref(), *generator * scalar.deref());
135 }
136
137 let c = challenge(transcript);
138 let s = (c * scalar.deref()) + r.deref();
140
141 DLEqProof { c, s }
142 }
143
144 fn verify_statement<T: Transcript>(
147 transcript: &mut T,
148 generator: G,
149 point: G,
150 c: G::Scalar,
151 s: G::Scalar,
152 ) {
153 Self::transcript(transcript, generator, (generator * s) - (point * c), point);
157 }
158
159 pub fn verify<T: Transcript>(
161 &self,
162 transcript: &mut T,
163 generators: &[G],
164 points: &[G],
165 ) -> Result<(), DLEqError> {
166 if generators.len() != points.len() {
167 Err(DLEqError::InvalidProof)?;
168 }
169
170 transcript.domain_separate(b"dleq");
171 for (generator, point) in generators.iter().zip(points) {
172 Self::verify_statement(transcript, *generator, *point, self.c, self.s);
173 }
174
175 if self.c != challenge(transcript) {
176 Err(DLEqError::InvalidProof)?;
177 }
178
179 Ok(())
180 }
181
182 #[cfg(feature = "serialize")]
184 pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
185 w.write_all(self.c.to_repr().as_ref())?;
186 w.write_all(self.s.to_repr().as_ref())
187 }
188
189 #[cfg(feature = "serialize")]
191 pub fn read<R: Read>(r: &mut R) -> io::Result<DLEqProof<G>> {
192 Ok(DLEqProof { c: read_scalar(r)?, s: read_scalar(r)? })
193 }
194
195 #[cfg(feature = "serialize")]
197 pub fn serialize(&self) -> Vec<u8> {
198 let mut res = vec![];
199 self.write(&mut res).unwrap();
200 res
201 }
202}
203
#[cfg(feature = "std")]
/// A proof of several discrete logarithms, each shared across its own series of generators, all
/// bound together under a single challenge.
#[derive(Debug, Clone, PartialEq, Eq, Zeroize)]
pub struct MultiDLEqProof<G: PrimeGroup<Scalar: Zeroize>> {
  /// Challenge shared by every discrete logarithm.
  c: G::Scalar,
  /// One response per discrete logarithm, in the order they were proven.
  s: Vec<G::Scalar>,
}
214
#[cfg(feature = "std")]
#[allow(non_snake_case)]
impl<G: PrimeGroup<Scalar: Zeroize>> MultiDLEqProof<G> {
  /// Transcript the index of the discrete logarithm about to be proven/verified.
  ///
  /// The index is encoded as a fixed-width little-endian u64. `usize::to_le_bytes` would be 4
  /// bytes on 32-bit targets and 8 bytes on 64-bit targets, making proofs non-portable across
  /// platforms. The u64 encoding is byte-identical to the 64-bit `usize` encoding.
  fn transcript_index<T: Transcript>(transcript: &mut T, i: usize) {
    let i = u64::try_from(i).expect("more than 2**64 discrete logarithms");
    transcript.append_message(b"discrete_logarithm", i.to_le_bytes());
  }

  /// Prove knowledge of `scalars[i]` as the discrete logarithm shared across `generators[i]`,
  /// for every `i`, under one challenge.
  ///
  /// # Panics
  ///
  /// Panics if `generators` and `scalars` differ in length.
  pub fn prove<R: RngCore + CryptoRng, T: Transcript>(
    rng: &mut R,
    transcript: &mut T,
    generators: &[Vec<G>],
    scalars: &[Zeroizing<G::Scalar>],
  ) -> MultiDLEqProof<G> {
    assert_eq!(
      generators.len(),
      scalars.len(),
      "amount of series of generators doesn't match the amount of scalars"
    );

    transcript.domain_separate(b"multi_dleq");

    // Pre-size so the vector never reallocates: a reallocation would move the nonces and free
    // the old buffer without zeroizing it, leaving copies of secret nonces in memory
    let mut nonces = Vec::with_capacity(scalars.len());
    for (i, (scalar, generators)) in scalars.iter().zip(generators).enumerate() {
      Self::transcript_index(transcript, i);

      let nonce = Zeroizing::new(G::Scalar::random(&mut *rng));
      for generator in generators {
        DLEqProof::transcript(
          transcript,
          *generator,
          *generator * nonce.deref(),
          *generator * scalar.deref(),
        );
      }
      nonces.push(nonce);
    }

    let c = challenge(transcript);

    // s_i = c * x_i + r_i; each nonce is dropped (and zeroized) as soon as it's consumed
    let s = scalars
      .iter()
      .zip(nonces)
      .map(|(scalar, nonce)| (c * scalar.deref()) + nonce.deref())
      .collect();

    MultiDLEqProof { c, s }
  }

  /// Verify this proof against `points[i]` over `generators[i]`, for every `i`.
  ///
  /// Returns `DLEqError::InvalidProof` on any length mismatch (series count, response count, or
  /// per-series point count) or if the recomputed challenge doesn't match.
  pub fn verify<T: Transcript>(
    &self,
    transcript: &mut T,
    generators: &[Vec<G>],
    points: &[Vec<G>],
  ) -> Result<(), DLEqError> {
    if points.len() != generators.len() {
      Err(DLEqError::InvalidProof)?;
    }
    if self.s.len() != generators.len() {
      Err(DLEqError::InvalidProof)?;
    }

    transcript.domain_separate(b"multi_dleq");
    for (i, ((generators, points), s)) in
      generators.iter().zip(points).zip(&self.s).enumerate()
    {
      if points.len() != generators.len() {
        Err(DLEqError::InvalidProof)?;
      }

      Self::transcript_index(transcript, i);
      for (generator, point) in generators.iter().zip(points) {
        DLEqProof::verify_statement(transcript, *generator, *point, self.c, *s);
      }
    }

    if self.c != challenge(transcript) {
      Err(DLEqError::InvalidProof)?;
    }

    Ok(())
  }

  /// Write the proof as `c || s_0 || s_1 || ...`, each scalar in its canonical `Repr` encoding.
  #[cfg(feature = "serialize")]
  pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
    w.write_all(self.c.to_repr().as_ref())?;
    for s in &self.s {
      w.write_all(s.to_repr().as_ref())?;
    }
    Ok(())
  }

  /// Read a proof for `discrete_logs` discrete logarithms, as produced by `write`.
  #[cfg(feature = "serialize")]
  pub fn read<R: Read>(r: &mut R, discrete_logs: usize) -> io::Result<MultiDLEqProof<G>> {
    let c = read_scalar(r)?;
    // Deliberately not pre-allocated from `discrete_logs`, which may not be trusted
    let mut s = vec![];
    for _ in 0 .. discrete_logs {
      s.push(read_scalar(r)?);
    }
    Ok(MultiDLEqProof { c, s })
  }

  /// Serialize the proof into a freshly allocated buffer.
  #[cfg(feature = "serialize")]
  pub fn serialize(&self) -> Vec<u8> {
    let mut res = vec![];
    self.write(&mut res).unwrap();
    res
  }
}