1use merlin::Transcript;
4use rand_core::{CryptoRng, RngCore};
5#[cfg(feature = "serde")]
6use serde::{Deserialize, Serialize};
7use zeroize::Zeroizing;
8
9use core::iter;
10
11#[cfg(feature = "serde")]
12use crate::serde::{ScalarHelper, VecHelper};
13use crate::{
14 alloc::Vec, group::Group, proofs::TranscriptForGroup, Ciphertext, CiphertextWithValue,
15 PublicKey, SecretKey, VerificationError,
16};
17
18#[derive(Debug, Clone)]
85#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
86#[cfg_attr(feature = "serde", serde(bound = ""))]
87pub struct SumOfSquaresProof<G: Group> {
88 #[cfg_attr(feature = "serde", serde(with = "ScalarHelper::<G>"))]
89 challenge: G::Scalar,
90 #[cfg_attr(feature = "serde", serde(with = "VecHelper::<ScalarHelper<G>, 2>"))]
91 ciphertext_responses: Vec<G::Scalar>,
92 #[cfg_attr(feature = "serde", serde(with = "ScalarHelper::<G>"))]
93 sum_response: G::Scalar,
94}
95
96impl<G: Group> SumOfSquaresProof<G> {
97 fn initialize_transcript(transcript: &mut Transcript, receiver: &PublicKey<G>) {
98 transcript.start_proof(b"sum_of_squares");
99 transcript.append_element_bytes(b"K", receiver.as_bytes());
100 }
101
102 #[allow(clippy::needless_collect)] pub fn new<'a, R: RngCore + CryptoRng>(
109 ciphertexts: impl Iterator<Item = &'a CiphertextWithValue<G>>,
110 sum_of_squares_ciphertext: &CiphertextWithValue<G>,
111 receiver: &PublicKey<G>,
112 transcript: &mut Transcript,
113 rng: &mut R,
114 ) -> Self {
115 Self::initialize_transcript(transcript, receiver);
116
117 let sum_scalar = SecretKey::<G>::generate(rng);
118 let mut sum_random_scalar = sum_of_squares_ciphertext.randomness().clone();
119
120 let partial_scalars: Vec<_> = ciphertexts
121 .map(|ciphertext| {
122 transcript.append_element::<G>(b"R_x", &ciphertext.inner().random_element);
123 transcript.append_element::<G>(b"X", &ciphertext.inner().blinded_element);
124
125 let random_scalar = SecretKey::<G>::generate(rng);
126 let random_commitment = G::mul_generator(random_scalar.expose_scalar());
127 transcript.append_element::<G>(b"[e_r]G", &random_commitment);
128 let value_scalar = SecretKey::<G>::generate(rng);
129 let value_commitment = G::mul_generator(value_scalar.expose_scalar())
130 + receiver.as_element() * random_scalar.expose_scalar();
131 transcript.append_element::<G>(b"[e_x]G + [e_r]K", &value_commitment);
132
133 let neg_value = Zeroizing::new(-*ciphertext.value());
134 sum_random_scalar += ciphertext.randomness() * &neg_value;
135 (ciphertext, random_scalar, value_scalar)
136 })
137 .collect();
138
139 let scalars = partial_scalars
140 .iter()
141 .map(|(_, _, value_scalar)| value_scalar.expose_scalar())
142 .chain(iter::once(sum_scalar.expose_scalar()));
143 let random_sum_commitment = {
144 let elements = partial_scalars
145 .iter()
146 .map(|(ciphertext, ..)| ciphertext.inner().random_element)
147 .chain(iter::once(G::generator()));
148 G::multi_mul(scalars.clone(), elements)
149 };
150 let value_sum_commitment = {
151 let elements = partial_scalars
152 .iter()
153 .map(|(ciphertext, ..)| ciphertext.inner().blinded_element)
154 .chain(iter::once(receiver.as_element()));
155 G::multi_mul(scalars, elements)
156 };
157
158 transcript.append_element::<G>(b"R_z", &sum_of_squares_ciphertext.inner().random_element);
159 transcript.append_element::<G>(b"Z", &sum_of_squares_ciphertext.inner().blinded_element);
160 transcript.append_element::<G>(b"[e_x]R_x + [e_z]G", &random_sum_commitment);
161 transcript.append_element::<G>(b"[e_x]X + [e_z]K", &value_sum_commitment);
162 let challenge = transcript.challenge_scalar::<G>(b"c");
163
164 let ciphertext_responses = partial_scalars
165 .into_iter()
166 .flat_map(|(ciphertext, random_scalar, value_scalar)| {
167 [
168 challenge * ciphertext.randomness().expose_scalar()
169 + random_scalar.expose_scalar(),
170 challenge * ciphertext.value() + value_scalar.expose_scalar(),
171 ]
172 })
173 .collect();
174 let sum_response =
175 challenge * sum_random_scalar.expose_scalar() + sum_scalar.expose_scalar();
176
177 Self {
178 challenge,
179 ciphertext_responses,
180 sum_response,
181 }
182 }
183
184 pub fn verify<'a>(
192 &self,
193 ciphertexts: impl Iterator<Item = &'a Ciphertext<G>> + Clone,
194 sum_of_squares_ciphertext: &Ciphertext<G>,
195 receiver: &PublicKey<G>,
196 transcript: &mut Transcript,
197 ) -> Result<(), VerificationError> {
198 let ciphertexts_count = ciphertexts.clone().count();
199 VerificationError::check_lengths(
200 "ciphertext responses",
201 self.ciphertext_responses.len(),
202 ciphertexts_count * 2,
203 )?;
204
205 Self::initialize_transcript(transcript, receiver);
206 let neg_challenge = -self.challenge;
207
208 for (response_chunk, ciphertext) in
209 self.ciphertext_responses.chunks(2).zip(ciphertexts.clone())
210 {
211 transcript.append_element::<G>(b"R_x", &ciphertext.random_element);
212 transcript.append_element::<G>(b"X", &ciphertext.blinded_element);
213
214 let r_response = &response_chunk[0];
215 let v_response = &response_chunk[1];
216 let random_commitment = G::vartime_double_mul_generator(
217 &-self.challenge,
218 ciphertext.random_element,
219 r_response,
220 );
221 transcript.append_element::<G>(b"[e_r]G", &random_commitment);
222 let value_commitment = G::vartime_multi_mul(
223 [v_response, r_response, &neg_challenge],
224 [
225 G::generator(),
226 receiver.as_element(),
227 ciphertext.blinded_element,
228 ],
229 );
230 transcript.append_element::<G>(b"[e_x]G + [e_r]K", &value_commitment);
231 }
232
233 let scalars = OddItems::new(self.ciphertext_responses.iter())
234 .chain([&self.sum_response, &neg_challenge]);
235 let random_sum_commitment = {
236 let elements = ciphertexts
237 .clone()
238 .map(|c| c.random_element)
239 .chain([G::generator(), sum_of_squares_ciphertext.random_element]);
240 G::vartime_multi_mul(scalars.clone(), elements)
241 };
242 let value_sum_commitment = {
243 let elements = ciphertexts.map(|c| c.blinded_element).chain([
244 receiver.as_element(),
245 sum_of_squares_ciphertext.blinded_element,
246 ]);
247 G::vartime_multi_mul(scalars, elements)
248 };
249
250 transcript.append_element::<G>(b"R_z", &sum_of_squares_ciphertext.random_element);
251 transcript.append_element::<G>(b"Z", &sum_of_squares_ciphertext.blinded_element);
252 transcript.append_element::<G>(b"[e_x]R_x + [e_z]G", &random_sum_commitment);
253 transcript.append_element::<G>(b"[e_x]X + [e_z]K", &value_sum_commitment);
254 let expected_challenge = transcript.challenge_scalar::<G>(b"c");
255
256 if expected_challenge == self.challenge {
257 Ok(())
258 } else {
259 Err(VerificationError::ChallengeMismatch)
260 }
261 }
262}
263
/// Iterator adapter that yields only the items at odd zero-based positions of the wrapped
/// iterator (the 2nd, 4th, 6th, ... items). Items at even positions are discarded.
///
/// Once the wrapped iterator returns `None`, this adapter returns `None` forever, even if
/// the wrapped iterator is not fused.
#[derive(Debug, Clone)]
struct OddItems<I> {
    iter: I,
    /// Set once the wrapped iterator has returned `None`.
    ended: bool,
}

impl<I: Iterator> OddItems<I> {
    fn new(iter: I) -> Self {
        Self { iter, ended: false }
    }
}

impl<I: Iterator> Iterator for OddItems<I> {
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        if self.ended {
            return None;
        }
        // Skip the item at the even position. Running out here means there is no odd item left.
        if self.iter.next().is_none() {
            self.ended = true;
            return None;
        }

        let item = self.iter.next();
        self.ended = item.is_none();
        item
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // After the end, `next()` never yields again. Do not forward the wrapped iterator's
        // hint, which may be non-zero for non-fused iterators.
        if self.ended {
            return (0, Some(0));
        }
        let (min, max) = self.iter.size_hint();
        (min / 2, max.map(|max| max / 2))
    }
}

// `ended` makes the adapter return `None` forever after the first `None`.
impl<I: Iterator> iter::FusedIterator for OddItems<I> {}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{group::Ristretto, Keypair};

    use rand::thread_rng;

    /// Creates the transcript shared by the prover and the verifier in these tests.
    fn test_transcript() -> Transcript {
        Transcript::new(b"test")
    }

    /// Asserts that verification failed because the recomputed challenge did not match.
    fn assert_challenge_mismatch(result: Result<(), VerificationError>) {
        let err = result.unwrap_err();
        assert!(matches!(err, VerificationError::ChallengeMismatch));
    }

    #[test]
    fn sum_of_squares_proof_basics() {
        let mut rng = thread_rng();
        let (pk, _) = Keypair::<Ristretto>::generate(&mut rng).into_tuple();
        let value_ct = CiphertextWithValue::new(3_u64, &pk, &mut rng).generalize();
        let square_ct = CiphertextWithValue::new(9_u64, &pk, &mut rng).generalize();

        let proof = SumOfSquaresProof::new(
            [&value_ct].into_iter(),
            &square_ct,
            &pk,
            &mut test_transcript(),
            &mut rng,
        );

        let value_ct: Ciphertext<Ristretto> = value_ct.into();
        let square_ct: Ciphertext<Ristretto> = square_ct.into();
        let honest = proof.verify(
            [&value_ct].into_iter(),
            &square_ct,
            &pk,
            &mut test_transcript(),
        );
        honest.unwrap();

        // Replacing either ciphertext with an unrelated one must invalidate the proof.
        let unrelated_ct = pk.encrypt(8_u64, &mut rng);
        assert_challenge_mismatch(proof.verify(
            [&value_ct].into_iter(),
            &unrelated_ct,
            &pk,
            &mut test_transcript(),
        ));
        assert_challenge_mismatch(proof.verify(
            [&unrelated_ct].into_iter(),
            &square_ct,
            &pk,
            &mut test_transcript(),
        ));

        // So must a transcript with a different initial state.
        assert_challenge_mismatch(proof.verify(
            [&value_ct].into_iter(),
            &square_ct,
            &pk,
            &mut Transcript::new(b"other_transcript"),
        ));
    }

    #[test]
    fn sum_of_squares_proof_with_bogus_inputs() {
        let mut rng = thread_rng();
        let (pk, _) = Keypair::<Ristretto>::generate(&mut rng).into_tuple();
        let value_ct = CiphertextWithValue::new(3_u64, &pk, &mut rng).generalize();
        // 10 != 3 * 3, so the proof created below must not verify.
        let square_ct = CiphertextWithValue::new(10_u64, &pk, &mut rng).generalize();

        let proof = SumOfSquaresProof::new(
            [&value_ct].into_iter(),
            &square_ct,
            &pk,
            &mut test_transcript(),
            &mut rng,
        );

        let value_ct: Ciphertext<Ristretto> = value_ct.into();
        let square_ct: Ciphertext<Ristretto> = square_ct.into();
        assert_challenge_mismatch(proof.verify(
            [&value_ct].into_iter(),
            &square_ct,
            &pk,
            &mut test_transcript(),
        ));
    }

    #[test]
    fn sum_of_squares_proof_with_several_squares() {
        let mut rng = thread_rng();
        let (pk, _) = Keypair::<Ristretto>::generate(&mut rng).into_tuple();
        // 3^2 + 1^2 + 4^2 + 1^2 = 27
        let value_cts =
            [3_u64, 1, 4, 1].map(|x| CiphertextWithValue::new(x, &pk, &mut rng).generalize());
        let square_ct = CiphertextWithValue::new(27_u64, &pk, &mut rng).generalize();

        let proof = SumOfSquaresProof::new(
            value_cts.iter(),
            &square_ct,
            &pk,
            &mut test_transcript(),
            &mut rng,
        );

        let square_ct: Ciphertext<Ristretto> = square_ct.into();
        proof
            .verify(
                value_cts.iter().map(CiphertextWithValue::inner),
                &square_ct,
                &pk,
                &mut test_transcript(),
            )
            .unwrap();

        // The proof is bound to the order of the input ciphertexts.
        assert_challenge_mismatch(proof.verify(
            value_cts.iter().rev().map(CiphertextWithValue::inner),
            &square_ct,
            &pk,
            &mut test_transcript(),
        ));

        // Dropping inputs is rejected by the length check before any challenge is computed.
        let len_err = proof
            .verify(
                value_cts.iter().take(2).map(CiphertextWithValue::inner),
                &square_ct,
                &pk,
                &mut test_transcript(),
            )
            .unwrap_err();
        assert!(matches!(len_err, VerificationError::LenMismatch { .. }));
    }

    #[test]
    fn odd_items() {
        let from_chain = OddItems::new(iter::once(1).chain([2, 3, 4]));
        assert_eq!(from_chain.size_hint(), (2, Some(2)));
        assert_eq!(from_chain.collect::<Vec<_>>(), [2, 4]);

        let from_range = OddItems::new(0..7);
        assert_eq!(from_range.size_hint(), (3, Some(3)));
        assert_eq!(from_range.collect::<Vec<_>>(), [1, 3, 5]);
    }
}