// elastic_elgamal/proofs/mul.rs

//! Proofs related to multiplication.

use merlin::Transcript;
use rand_core::{CryptoRng, RngCore};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use zeroize::Zeroizing;

use core::iter::{self, FusedIterator};

#[cfg(feature = "serde")]
use crate::serde::{ScalarHelper, VecHelper};
use crate::{
    alloc::Vec, group::Group, proofs::TranscriptForGroup, Ciphertext, CiphertextWithValue,
    PublicKey, SecretKey, VerificationError,
};

18/// Zero-knowledge proof that an ElGamal-encrypted value is equal to a sum of squares
19/// of one or more other ElGamal-encrypted values.
20///
21/// # Construction
22///
23/// Consider the case with a single sum element (i.e., proving that an encrypted value is
24/// a square of another encrypted value). The prover wants to prove the knowledge of scalars
25///
26/// ```text
27/// r_x, x, r_z:
28///   R_x = [r_x]G, X = [x]G + [r_x]K;
29///   R_z = [r_z]G, Z = [x^2]G + [r_z]K,
30/// ```
31///
32/// where
33///
34/// - `G` is the conventional generator of the considered prime-order group
35/// - `K` is a group element equivalent to the receiver's public key
36/// - `(R_x, X)` and `(R_z, Z)` are ElGamal ciphertexts of values `x` and `x^2`, respectively.
37///
38/// Observe that
39///
40/// ```text
41/// r'_z := r_z - x * r_x =>
42///   R_z = [r'_z]G + [x]R_x; Z = [x]X + [r'_z]K.
43/// ```
44///
45/// and that proving the knowledge of `(r_x, x, r'_z)` is equivalent to the initial problem.
46/// The new problem can be solved using a conventional sigma protocol:
47///
48/// 1. **Commitment.** The prover generates random scalars `e_r`, `e_x` and `e_z` and commits
49///    to them via `E_r = [e_r]G`, `E_x = [e_x]G + [e_r]K`, `E_rz = [e_x]R_x + [e_z]G` and
50///    `E_z = [e_x]X + [e_z]K`.
51/// 2. **Challenge.** The verifier sends to the prover random scalar `c`.
52/// 3. **Response.** The prover computes the following scalars and sends them to the verifier.
53///
54/// ```text
55/// s_r = e_r + c * r_x;
56/// s_x = e_x + c * x;
57/// s_z = e_z + c * (r_z - x * r_x);
58/// ```
59///
60/// The verification equations are
61///
62/// ```text
63/// [s_r]G ?= E_r + [c]R_x;
64/// [s_x]G + [s_r]K ?= E_x + [c]X;
65/// [s_x]R_x + [s_z]G ?= E_rz + [c]R_z;
66/// [s_x]X + [s_z]K ?= E_z + [c]Z.
67/// ```
68///
69/// The case with multiple squares is a straightforward generalization:
70///
71/// - `e_r`, `E_r`, `e_x`, `E_x`, `s_r` and `s_x` are independently defined for each
72///   partial ciphertext in the same way as above.
73/// - Commitments `E_rz` and `E_z` sum over `[e_x]R_x` and `[e_x]X` for all ciphertexts,
74///   respectively.
75/// - Response `s_z` similarly substitutes `x * r_x` with the corresponding sum.
76///
77/// A non-interactive version of the proof is obtained by applying [Fiat–Shamir transform][fst].
78/// As with [`LogEqualityProof`], it is more efficient to represent a proof as the challenge
79/// and responses; in this case, the proof size is `2n + 2` scalars, where `n` is the number of
80/// partial ciphertexts.
81///
82/// [fst]: https://en.wikipedia.org/wiki/Fiat%E2%80%93Shamir_heuristic
83/// [`LogEqualityProof`]: crate::LogEqualityProof
84#[derive(Debug, Clone)]
85#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
86#[cfg_attr(feature = "serde", serde(bound = ""))]
87pub struct SumOfSquaresProof<G: Group> {
88    #[cfg_attr(feature = "serde", serde(with = "ScalarHelper::<G>"))]
89    challenge: G::Scalar,
90    #[cfg_attr(feature = "serde", serde(with = "VecHelper::<ScalarHelper<G>, 2>"))]
91    ciphertext_responses: Vec<G::Scalar>,
92    #[cfg_attr(feature = "serde", serde(with = "ScalarHelper::<G>"))]
93    sum_response: G::Scalar,
94}
95
96impl<G: Group> SumOfSquaresProof<G> {
97    fn initialize_transcript(transcript: &mut Transcript, receiver: &PublicKey<G>) {
98        transcript.start_proof(b"sum_of_squares");
99        transcript.append_element_bytes(b"K", receiver.as_bytes());
100    }
101
102    /// Creates a new proof that squares of values encrypted in `ciphertexts` for `receiver` sum up
103    /// to a value encrypted in `sum_of_squares_ciphertext`.
104    ///
105    /// All provided ciphertexts must be encrypted for `receiver`; otherwise, the created proof
106    /// will not verify.
107    #[allow(clippy::needless_collect)] // false positive
108    pub fn new<'a, R: RngCore + CryptoRng>(
109        ciphertexts: impl Iterator<Item = &'a CiphertextWithValue<G>>,
110        sum_of_squares_ciphertext: &CiphertextWithValue<G>,
111        receiver: &PublicKey<G>,
112        transcript: &mut Transcript,
113        rng: &mut R,
114    ) -> Self {
115        Self::initialize_transcript(transcript, receiver);
116
117        let sum_scalar = SecretKey::<G>::generate(rng);
118        let mut sum_random_scalar = sum_of_squares_ciphertext.randomness().clone();
119
120        let partial_scalars: Vec<_> = ciphertexts
121            .map(|ciphertext| {
122                transcript.append_element::<G>(b"R_x", &ciphertext.inner().random_element);
123                transcript.append_element::<G>(b"X", &ciphertext.inner().blinded_element);
124
125                let random_scalar = SecretKey::<G>::generate(rng);
126                let random_commitment = G::mul_generator(random_scalar.expose_scalar());
127                transcript.append_element::<G>(b"[e_r]G", &random_commitment);
128                let value_scalar = SecretKey::<G>::generate(rng);
129                let value_commitment = G::mul_generator(value_scalar.expose_scalar())
130                    + receiver.as_element() * random_scalar.expose_scalar();
131                transcript.append_element::<G>(b"[e_x]G + [e_r]K", &value_commitment);
132
133                let neg_value = Zeroizing::new(-*ciphertext.value());
134                sum_random_scalar += ciphertext.randomness() * &neg_value;
135                (ciphertext, random_scalar, value_scalar)
136            })
137            .collect();
138
139        let scalars = partial_scalars
140            .iter()
141            .map(|(_, _, value_scalar)| value_scalar.expose_scalar())
142            .chain(iter::once(sum_scalar.expose_scalar()));
143        let random_sum_commitment = {
144            let elements = partial_scalars
145                .iter()
146                .map(|(ciphertext, ..)| ciphertext.inner().random_element)
147                .chain(iter::once(G::generator()));
148            G::multi_mul(scalars.clone(), elements)
149        };
150        let value_sum_commitment = {
151            let elements = partial_scalars
152                .iter()
153                .map(|(ciphertext, ..)| ciphertext.inner().blinded_element)
154                .chain(iter::once(receiver.as_element()));
155            G::multi_mul(scalars, elements)
156        };
157
158        transcript.append_element::<G>(b"R_z", &sum_of_squares_ciphertext.inner().random_element);
159        transcript.append_element::<G>(b"Z", &sum_of_squares_ciphertext.inner().blinded_element);
160        transcript.append_element::<G>(b"[e_x]R_x + [e_z]G", &random_sum_commitment);
161        transcript.append_element::<G>(b"[e_x]X + [e_z]K", &value_sum_commitment);
162        let challenge = transcript.challenge_scalar::<G>(b"c");
163
164        let ciphertext_responses = partial_scalars
165            .into_iter()
166            .flat_map(|(ciphertext, random_scalar, value_scalar)| {
167                [
168                    challenge * ciphertext.randomness().expose_scalar()
169                        + random_scalar.expose_scalar(),
170                    challenge * ciphertext.value() + value_scalar.expose_scalar(),
171                ]
172            })
173            .collect();
174        let sum_response =
175            challenge * sum_random_scalar.expose_scalar() + sum_scalar.expose_scalar();
176
177        Self {
178            challenge,
179            ciphertext_responses,
180            sum_response,
181        }
182    }
183
184    /// Verifies this proof against the provided partial ciphertexts and the ciphertext of the
185    /// sum of their squares. The order of partial ciphertexts must correspond to their order
186    /// when creating the proof.
187    ///
188    /// # Errors
189    ///
190    /// Returns an error if this proof does not verify.
191    pub fn verify<'a>(
192        &self,
193        ciphertexts: impl Iterator<Item = &'a Ciphertext<G>> + Clone,
194        sum_of_squares_ciphertext: &Ciphertext<G>,
195        receiver: &PublicKey<G>,
196        transcript: &mut Transcript,
197    ) -> Result<(), VerificationError> {
198        let ciphertexts_count = ciphertexts.clone().count();
199        VerificationError::check_lengths(
200            "ciphertext responses",
201            self.ciphertext_responses.len(),
202            ciphertexts_count * 2,
203        )?;
204
205        Self::initialize_transcript(transcript, receiver);
206        let neg_challenge = -self.challenge;
207
208        for (response_chunk, ciphertext) in
209            self.ciphertext_responses.chunks(2).zip(ciphertexts.clone())
210        {
211            transcript.append_element::<G>(b"R_x", &ciphertext.random_element);
212            transcript.append_element::<G>(b"X", &ciphertext.blinded_element);
213
214            let r_response = &response_chunk[0];
215            let v_response = &response_chunk[1];
216            let random_commitment = G::vartime_double_mul_generator(
217                &-self.challenge,
218                ciphertext.random_element,
219                r_response,
220            );
221            transcript.append_element::<G>(b"[e_r]G", &random_commitment);
222            let value_commitment = G::vartime_multi_mul(
223                [v_response, r_response, &neg_challenge],
224                [
225                    G::generator(),
226                    receiver.as_element(),
227                    ciphertext.blinded_element,
228                ],
229            );
230            transcript.append_element::<G>(b"[e_x]G + [e_r]K", &value_commitment);
231        }
232
233        let scalars = OddItems::new(self.ciphertext_responses.iter())
234            .chain([&self.sum_response, &neg_challenge]);
235        let random_sum_commitment = {
236            let elements = ciphertexts
237                .clone()
238                .map(|c| c.random_element)
239                .chain([G::generator(), sum_of_squares_ciphertext.random_element]);
240            G::vartime_multi_mul(scalars.clone(), elements)
241        };
242        let value_sum_commitment = {
243            let elements = ciphertexts.map(|c| c.blinded_element).chain([
244                receiver.as_element(),
245                sum_of_squares_ciphertext.blinded_element,
246            ]);
247            G::vartime_multi_mul(scalars, elements)
248        };
249
250        transcript.append_element::<G>(b"R_z", &sum_of_squares_ciphertext.random_element);
251        transcript.append_element::<G>(b"Z", &sum_of_squares_ciphertext.blinded_element);
252        transcript.append_element::<G>(b"[e_x]R_x + [e_z]G", &random_sum_commitment);
253        transcript.append_element::<G>(b"[e_x]X + [e_z]K", &value_sum_commitment);
254        let expected_challenge = transcript.challenge_scalar::<G>(b"c");
255
256        if expected_challenge == self.challenge {
257            Ok(())
258        } else {
259            Err(VerificationError::ChallengeMismatch)
260        }
261    }
262}
263
/// Thin wrapper around an iterator that drops its even-indexed elements, i.e. it yields
/// the elements at indices 1, 3, 5, and so on. This is necessary because
/// `Ristretto::vartime_multi_mul()` panics otherwise, which is caused by an imprecise
/// `Iterator::size_hint()` value.
///
/// Once the wrapped iterator returns `None`, this wrapper is fused. It keeps returning `None`
/// and reports an empty size hint, regardless of how the wrapped iterator behaves afterwards.
#[derive(Debug, Clone)]
struct OddItems<I> {
    iter: I,
    /// Set once the wrapped iterator has returned `None`.
    ended: bool,
}

impl<I: Iterator> OddItems<I> {
    fn new(iter: I) -> Self {
        Self { iter, ended: false }
    }
}

impl<I: Iterator> Iterator for OddItems<I> {
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        if self.ended {
            return None;
        }
        // Discard the even-indexed element of the current pair.
        if self.iter.next().is_none() {
            self.ended = true;
            return None;
        }

        let item = self.iter.next();
        self.ended = item.is_none();
        item
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // After exhaustion, the wrapped iterator is never polled again. Its own hint may be
        // stale for non-fused iterators, so it must not be forwarded.
        if self.ended {
            return (0, Some(0));
        }
        // `next()` always consumes elements in pairs, so out of `n` remaining elements
        // exactly `n / 2` (rounded down) will be yielded.
        let (min, max) = self.iter.size_hint();
        (min / 2, max.map(|max| max / 2))
    }
}

// `next()` never polls the wrapped iterator after it has returned `None`.
impl<I: Iterator> FusedIterator for OddItems<I> {}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{group::Ristretto, Keypair};

    use rand::thread_rng;

    /// Runs `SumOfSquaresProof::verify()` on `proof` with a fresh transcript labeled `label`.
    fn check_proof<'a>(
        proof: &SumOfSquaresProof<Ristretto>,
        partials: impl Iterator<Item = &'a Ciphertext<Ristretto>> + Clone,
        sum_of_squares: &Ciphertext<Ristretto>,
        receiver: &PublicKey<Ristretto>,
        label: &'static [u8],
    ) -> Result<(), VerificationError> {
        let mut transcript = Transcript::new(label);
        proof.verify(partials, sum_of_squares, receiver, &mut transcript)
    }

    #[test]
    fn sum_of_squares_proof_basics() {
        let mut rng = thread_rng();
        let (receiver, _) = Keypair::<Ristretto>::generate(&mut rng).into_tuple();
        let partial = CiphertextWithValue::new(3_u64, &receiver, &mut rng).generalize();
        let square = CiphertextWithValue::new(9_u64, &receiver, &mut rng).generalize();

        let proof = SumOfSquaresProof::new(
            iter::once(&partial),
            &square,
            &receiver,
            &mut Transcript::new(b"test"),
            &mut rng,
        );

        let partial: Ciphertext<Ristretto> = partial.into();
        let square: Ciphertext<Ristretto> = square.into();
        check_proof(&proof, iter::once(&partial), &square, &receiver, b"test").unwrap();

        let unrelated = receiver.encrypt(8_u64, &mut rng);
        let bad_attempts = [
            // The sum-of-squares ciphertext is substituted.
            check_proof(&proof, iter::once(&partial), &unrelated, &receiver, b"test"),
            // The partial ciphertext is substituted.
            check_proof(&proof, iter::once(&unrelated), &square, &receiver, b"test"),
            // The transcript is started with a different label.
            check_proof(&proof, iter::once(&partial), &square, &receiver, b"other_transcript"),
        ];
        for outcome in bad_attempts {
            let err = outcome.unwrap_err();
            assert!(matches!(err, VerificationError::ChallengeMismatch));
        }
    }

    #[test]
    fn sum_of_squares_proof_with_bogus_inputs() {
        let mut rng = thread_rng();
        let (receiver, _) = Keypair::<Ristretto>::generate(&mut rng).into_tuple();
        let partial = CiphertextWithValue::new(3_u64, &receiver, &mut rng).generalize();
        // 3^2 != 10, so the resulting proof must be rejected.
        let bogus_square = CiphertextWithValue::new(10_u64, &receiver, &mut rng).generalize();

        let proof = SumOfSquaresProof::new(
            iter::once(&partial),
            &bogus_square,
            &receiver,
            &mut Transcript::new(b"test"),
            &mut rng,
        );

        let partial: Ciphertext<Ristretto> = partial.into();
        let bogus_square: Ciphertext<Ristretto> = bogus_square.into();
        let err = check_proof(&proof, iter::once(&partial), &bogus_square, &receiver, b"test")
            .unwrap_err();
        assert!(matches!(err, VerificationError::ChallengeMismatch));
    }

    #[test]
    fn sum_of_squares_proof_with_several_squares() {
        let mut rng = thread_rng();
        let (receiver, _) = Keypair::<Ristretto>::generate(&mut rng).into_tuple();
        let partials =
            [3_u64, 1, 4, 1].map(|x| CiphertextWithValue::new(x, &receiver, &mut rng).generalize());
        // 3^2 + 1^2 + 4^2 + 1^2 = 27
        let square = CiphertextWithValue::new(27_u64, &receiver, &mut rng).generalize();

        let proof = SumOfSquaresProof::new(
            partials.iter(),
            &square,
            &receiver,
            &mut Transcript::new(b"test"),
            &mut rng,
        );

        let square: Ciphertext<Ristretto> = square.into();
        check_proof(
            &proof,
            partials.iter().map(CiphertextWithValue::inner),
            &square,
            &receiver,
            b"test",
        )
        .unwrap();

        // The proof will not verify if ciphertexts are rearranged.
        let reordered = check_proof(
            &proof,
            partials.iter().rev().map(CiphertextWithValue::inner),
            &square,
            &receiver,
            b"test",
        );
        assert!(matches!(
            reordered.unwrap_err(),
            VerificationError::ChallengeMismatch
        ));

        // Too few partial ciphertexts are caught by the length check.
        let truncated = check_proof(
            &proof,
            partials.iter().take(2).map(CiphertextWithValue::inner),
            &square,
            &receiver,
            b"test",
        );
        assert!(matches!(
            truncated.unwrap_err(),
            VerificationError::LenMismatch { .. }
        ));
    }

    #[test]
    fn odd_items() {
        let chained = OddItems::new(iter::once(1).chain([2, 3, 4]));
        assert_eq!(chained.size_hint(), (2, Some(2)));
        assert_eq!(chained.collect::<Vec<_>>(), [2, 4]);

        let from_range = OddItems::new(0..7);
        assert_eq!(from_range.size_hint(), (3, Some(3)));
        assert_eq!(from_range.collect::<Vec<_>>(), [1, 3, 5]);
    }
}