1use merlin::Transcript;
4use rand_core::{CryptoRng, RngCore};
5#[cfg(feature = "serde")]
6use serde::{de::DeserializeOwned, Deserialize, Serialize};
7use zeroize::Zeroizing;
8
9use core::{fmt, iter, ops};
10
11use crate::{
12 alloc::{vec, Vec},
13 group::Group,
14 Ciphertext, CiphertextWithValue, LogEqualityProof, PublicKey, RingProof, RingProofBuilder,
15 VerificationError,
16};
17
18pub trait ProveSum<G: Group>: Clone + crate::sealed::Sealed {
23 #[cfg(not(feature = "serde"))]
25 type Proof: Sized;
26 #[cfg(feature = "serde")]
28 type Proof: Sized + Serialize + DeserializeOwned;
29
30 #[doc(hidden)]
31 fn prove<R: CryptoRng + RngCore>(
32 &self,
33 ciphertext: &CiphertextWithValue<G, u64>,
34 receiver: &PublicKey<G>,
35 rng: &mut R,
36 ) -> Self::Proof;
37
38 #[doc(hidden)]
39 fn verify(
40 &self,
41 ciphertext: &Ciphertext<G>,
42 proof: &Self::Proof,
43 receiver: &PublicKey<G>,
44 ) -> Result<(), ChoiceVerificationError>;
45}
46
47#[derive(Debug, Clone, Copy)]
53pub struct SingleChoice(());
54
55impl crate::sealed::Sealed for SingleChoice {}
56
57impl<G: Group> ProveSum<G> for SingleChoice {
58 type Proof = LogEqualityProof<G>;
59
60 fn prove<R: CryptoRng + RngCore>(
61 &self,
62 ciphertext: &CiphertextWithValue<G, u64>,
63 receiver: &PublicKey<G>,
64 rng: &mut R,
65 ) -> Self::Proof {
66 LogEqualityProof::new(
67 receiver,
68 ciphertext.randomness(),
69 (
70 ciphertext.inner().random_element,
71 ciphertext.inner().blinded_element - G::generator(),
72 ),
73 &mut Transcript::new(b"choice_encryption_sum"),
74 rng,
75 )
76 }
77
78 fn verify(
79 &self,
80 ciphertext: &Ciphertext<G>,
81 proof: &Self::Proof,
82 receiver: &PublicKey<G>,
83 ) -> Result<(), ChoiceVerificationError> {
84 let powers = (
85 ciphertext.random_element,
86 ciphertext.blinded_element - G::generator(),
87 );
88 proof
89 .verify(
90 receiver,
91 powers,
92 &mut Transcript::new(b"choice_encryption_sum"),
93 )
94 .map_err(ChoiceVerificationError::Sum)
95 }
96}
97
98#[derive(Debug, Clone, Copy)]
105pub struct MultiChoice(());
106
107impl crate::sealed::Sealed for MultiChoice {}
108
109impl<G: Group> ProveSum<G> for MultiChoice {
110 type Proof = ();
111
112 fn prove<R: CryptoRng + RngCore>(
113 &self,
114 _ciphertext: &CiphertextWithValue<G, u64>,
115 _receiver: &PublicKey<G>,
116 _rng: &mut R,
117 ) -> Self::Proof {
118 }
120
121 fn verify(
122 &self,
123 _ciphertext: &Ciphertext<G>,
124 _proof: &Self::Proof,
125 _receiver: &PublicKey<G>,
126 ) -> Result<(), ChoiceVerificationError> {
127 Ok(()) }
129}
130
131#[derive(Debug)]
133pub struct ChoiceParams<G: Group, S: ProveSum<G>> {
134 options_count: usize,
135 sum_prover: S,
136 receiver: PublicKey<G>,
137}
138
139impl<G: Group, S: ProveSum<G>> Clone for ChoiceParams<G, S> {
140 fn clone(&self) -> Self {
141 Self {
142 options_count: self.options_count,
143 sum_prover: self.sum_prover.clone(),
144 receiver: self.receiver.clone(),
145 }
146 }
147}
148
149impl<G: Group, S: ProveSum<G>> ChoiceParams<G, S> {
150 fn check_options_count(&self, actual_count: usize) -> Result<(), ChoiceVerificationError> {
151 if self.options_count == actual_count {
152 Ok(())
153 } else {
154 Err(ChoiceVerificationError::OptionsLenMismatch {
155 expected: self.options_count,
156 actual: actual_count,
157 })
158 }
159 }
160
161 pub fn receiver(&self) -> &PublicKey<G> {
163 &self.receiver
164 }
165
166 pub fn options_count(&self) -> usize {
168 self.options_count
169 }
170}
171
172impl<G: Group> ChoiceParams<G, SingleChoice> {
173 pub fn single(receiver: PublicKey<G>, options_count: usize) -> Self {
179 assert!(options_count > 0, "Number of options must be positive");
180 Self {
181 options_count,
182 sum_prover: SingleChoice(()),
183 receiver,
184 }
185 }
186}
187
188impl<G: Group> ChoiceParams<G, MultiChoice> {
189 pub fn multi(receiver: PublicKey<G>, options_count: usize) -> Self {
195 assert!(options_count > 0, "Number of options must be positive");
196 Self {
197 options_count,
198 sum_prover: MultiChoice(()),
199 receiver,
200 }
201 }
202}
203
/// Ballot holding one encrypted value per option, together with a range proof
/// and a sum proof whose form is determined by the `S` strategy.
///
/// Use [`EncryptedChoice::verify()`] to check the proofs before using the ciphertexts.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(bound = ""))]
pub struct EncryptedChoice<G: Group, S: ProveSum<G>> {
    /// Encrypted per-option values, in option order.
    choices: Vec<Ciphertext<G>>,
    /// Ring proof that each entry of `choices` encrypts either the identity or the generator.
    range_proof: RingProof<G>,
    /// Proof produced by the `S` sum prover (`()` for `MultiChoice`).
    sum_proof: S::Proof,
}
284
285impl<G: Group> EncryptedChoice<G, SingleChoice> {
286 pub fn single<R: CryptoRng + RngCore>(
292 params: &ChoiceParams<G, SingleChoice>,
293 choice: usize,
294 rng: &mut R,
295 ) -> Self {
296 assert!(
297 choice < params.options_count,
298 "invalid choice {choice}; expected a value in 0..{}",
299 params.options_count
300 );
301 let choices: Vec<_> = (0..params.options_count).map(|i| choice == i).collect();
302 Self::new(params, &Zeroizing::new(choices), rng)
303 }
304}
305
306#[allow(clippy::len_without_is_empty)] impl<G: Group, S: ProveSum<G>> EncryptedChoice<G, S> {
308 pub fn new<R: CryptoRng + RngCore>(
317 params: &ChoiceParams<G, S>,
318 choices: &[bool],
319 rng: &mut R,
320 ) -> Self {
321 assert!(!choices.is_empty(), "No choices provided");
322 assert_eq!(
323 choices.len(),
324 params.options_count,
325 "Mismatch between expected and actual number of choices"
326 );
327
328 let admissible_values = [G::identity(), G::generator()];
329 let mut ring_responses = vec![G::Scalar::default(); 2 * params.options_count];
330 let mut transcript = Transcript::new(b"encrypted_choice_ranges");
331 let mut proof_builder = RingProofBuilder::new(
332 ¶ms.receiver,
333 params.options_count,
334 &mut ring_responses,
335 &mut transcript,
336 rng,
337 );
338
339 let sum = choices.iter().map(|&flag| u64::from(flag)).sum::<u64>();
340 let choices: Vec<_> = choices
341 .iter()
342 .map(|&flag| proof_builder.add_value(&admissible_values, usize::from(flag)))
343 .collect();
344 let range_proof = RingProof::new(proof_builder.build(), ring_responses);
345
346 let sum_ciphertext = choices.iter().cloned().reduce(ops::Add::add).unwrap();
347 let sum_ciphertext = sum_ciphertext.with_value(sum);
348 let sum_proof = params
349 .sum_prover
350 .prove(&sum_ciphertext, ¶ms.receiver, rng);
351 Self {
352 choices: choices.into_iter().map(|choice| choice.inner).collect(),
353 range_proof,
354 sum_proof,
355 }
356 }
357
358 #[allow(clippy::missing_panics_doc)]
365 pub fn verify(
366 &self,
367 params: &ChoiceParams<G, S>,
368 ) -> Result<&[Ciphertext<G>], ChoiceVerificationError> {
369 params.check_options_count(self.choices.len())?;
370 let sum_of_ciphertexts = self.choices.iter().copied().reduce(ops::Add::add);
371 let sum_of_ciphertexts = sum_of_ciphertexts.unwrap();
372 params
374 .sum_prover
375 .verify(&sum_of_ciphertexts, &self.sum_proof, ¶ms.receiver)?;
376
377 let admissible_values = [G::identity(), G::generator()];
378 self.range_proof
379 .verify(
380 ¶ms.receiver,
381 iter::repeat(&admissible_values as &[_]).take(self.choices.len()),
382 self.choices.iter().copied(),
383 &mut Transcript::new(b"encrypted_choice_ranges"),
384 )
385 .map(|()| self.choices.as_slice())
386 .map_err(ChoiceVerificationError::Range)
387 }
388
389 pub fn len(&self) -> usize {
392 self.choices.len()
393 }
394
395 pub fn choices_unchecked(&self) -> &[Ciphertext<G>] {
397 &self.choices
398 }
399
400 pub fn range_proof(&self) -> &RingProof<G> {
402 &self.range_proof
403 }
404
405 pub fn sum_proof(&self) -> &S::Proof {
407 &self.sum_proof
408 }
409}
410
/// Error returned by [`EncryptedChoice::verify()`].
#[derive(Debug)]
#[non_exhaustive]
pub enum ChoiceVerificationError {
    /// The ballot has a different number of options than the parameters expect.
    OptionsLenMismatch {
        /// Number of options expected by the `ChoiceParams`.
        expected: usize,
        /// Number of options present in the ballot.
        actual: usize,
    },
    /// The sum proof failed to verify.
    Sum(VerificationError),
    /// The range proof failed to verify.
    Range(VerificationError),
}
427
428impl fmt::Display for ChoiceVerificationError {
429 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
430 match self {
431 Self::OptionsLenMismatch { expected, actual } => write!(
432 formatter,
433 "number of options in the ballot ({actual}) differs from expected ({expected})",
434 ),
435 Self::Sum(err) => write!(formatter, "cannot verify sum proof: {err}"),
436 Self::Range(err) => write!(formatter, "cannot verify range proofs: {err}"),
437 }
438 }
439}
440
/// Exposes the underlying proof-verification error (if any) as the error source.
#[cfg(feature = "std")]
impl std::error::Error for ChoiceVerificationError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Sum(err) | Self::Range(err) => Some(err),
            // Listed explicitly instead of `_`: within this crate `#[non_exhaustive]`
            // does not force a wildcard, so new variants must be handled here.
            Self::OptionsLenMismatch { .. } => None,
        }
    }
}
450
#[cfg(test)]
mod tests {
    use rand::thread_rng;

    use super::*;
    use crate::{
        group::{Generic, Ristretto},
        Keypair,
    };

    /// Checks that an honest single-choice ballot verifies and that tampered ballots
    /// are rejected by the expected sub-proof.
    fn test_bogus_encrypted_choice_does_not_work<G: Group>() {
        let mut rng = thread_rng();
        let (receiver, _) = Keypair::<G>::generate(&mut rng).into_tuple();
        let params = ChoiceParams::single(receiver.clone(), 5);

        // Sanity check: an untampered ballot is accepted, so the failures below are meaningful.
        let choice = EncryptedChoice::single(&params, 2, &mut rng);
        assert!(choice.verify(&params).is_ok());

        // Two selected options: the summed ciphertext no longer encrypts 1.
        let mut choice = EncryptedChoice::single(&params, 2, &mut rng);
        let (encrypted_one, _) = receiver.encrypt_bool(true, &mut rng);
        choice.choices[0] = encrypted_one;
        assert!(matches!(
            choice.verify(&params),
            Err(ChoiceVerificationError::Sum(_))
        ));

        // No selected options: the summed ciphertext encrypts 0.
        let mut choice = EncryptedChoice::single(&params, 4, &mut rng);
        let (encrypted_zero, _) = receiver.encrypt_bool(false, &mut rng);
        choice.choices[4] = encrypted_zero;
        assert!(matches!(
            choice.verify(&params),
            Err(ChoiceVerificationError::Sum(_))
        ));

        // Sum preserved (+10 / -10) but individual values leave {0, 1}: only the range
        // proof can catch this.
        let mut choice = EncryptedChoice::single(&params, 4, &mut rng);
        choice.choices[4].blinded_element =
            choice.choices[4].blinded_element + G::mul_generator(&G::Scalar::from(10));
        choice.choices[3].blinded_element =
            choice.choices[3].blinded_element - G::mul_generator(&G::Scalar::from(10));
        assert!(matches!(
            choice.verify(&params),
            Err(ChoiceVerificationError::Range(_))
        ));
    }

    #[test]
    fn bogus_encrypted_choice_does_not_work_for_edwards() {
        test_bogus_encrypted_choice_does_not_work::<Ristretto>();
    }

    #[test]
    fn bogus_encrypted_choice_does_not_work_for_k256() {
        test_bogus_encrypted_choice_does_not_work::<Generic<k256::Secp256k1>>();
    }
}