elastic_elgamal/group/
ristretto.rs1use rand_core::{CryptoRng, RngCore};
2
3use core::convert::TryInto;
4
5use crate::curve25519::{
6 constants::{RISTRETTO_BASEPOINT_POINT, RISTRETTO_BASEPOINT_TABLE},
7 ristretto::{CompressedRistretto, RistrettoPoint},
8 scalar::Scalar,
9 traits::{Identity, IsIdentity, MultiscalarMul, VartimeMultiscalarMul},
10};
11use crate::group::{ElementOps, Group, RandomBytesProvider, ScalarOps};
12
/// Zero-sized marker type that implements [`ScalarOps`], [`ElementOps`] and [`Group`]
/// for Ristretto points (`RistrettoPoint`) provided by the Curve25519 backend.
///
/// The single field is private, so the type can only be constructed inside this module;
/// users interact with it purely as a type parameter.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "curve25519-dalek", feature = "curve25519-dalek-ng")))
)]
pub struct Ristretto(());
20
21impl ScalarOps for Ristretto {
22 type Scalar = Scalar;
23
24 const SCALAR_SIZE: usize = 32;
25
26 fn generate_scalar<R: CryptoRng + RngCore>(rng: &mut R) -> Self::Scalar {
27 let mut scalar_bytes = [0_u8; 64];
28 rng.fill_bytes(&mut scalar_bytes[..]);
29 Scalar::from_bytes_mod_order_wide(&scalar_bytes)
30 }
31
32 fn scalar_from_random_bytes(source: RandomBytesProvider<'_>) -> Self::Scalar {
33 let mut scalar_bytes = [0_u8; 64];
34 source.fill_bytes(&mut scalar_bytes);
35 Scalar::from_bytes_mod_order_wide(&scalar_bytes)
36 }
37
38 fn invert_scalar(scalar: Self::Scalar) -> Self::Scalar {
39 scalar.invert()
40 }
41
42 fn invert_scalars(scalars: &mut [Self::Scalar]) {
43 Scalar::batch_invert(scalars);
44 }
45
46 fn serialize_scalar(scalar: &Self::Scalar, buffer: &mut [u8]) {
47 buffer.copy_from_slice(&scalar.to_bytes());
48 }
49
50 #[cfg(feature = "curve25519-dalek")]
51 fn deserialize_scalar(buffer: &[u8]) -> Option<Self::Scalar> {
52 let bytes: &[u8; 32] = buffer.try_into().expect("input has incorrect byte size");
53 Scalar::from_canonical_bytes(*bytes).into()
54 }
55
56 #[cfg(feature = "curve25519-dalek-ng")]
57 fn deserialize_scalar(buffer: &[u8]) -> Option<Self::Scalar> {
58 let bytes: &[u8; 32] = buffer.try_into().expect("input has incorrect byte size");
59 Scalar::from_canonical_bytes(*bytes)
60 }
61}
62
63impl ElementOps for Ristretto {
64 type Element = RistrettoPoint;
65
66 const ELEMENT_SIZE: usize = 32;
67
68 fn identity() -> Self::Element {
69 RistrettoPoint::identity()
70 }
71
72 fn is_identity(element: &Self::Element) -> bool {
73 element.is_identity()
74 }
75
76 fn generator() -> Self::Element {
77 RISTRETTO_BASEPOINT_POINT
78 }
79
80 fn serialize_element(element: &Self::Element, buffer: &mut [u8]) {
81 buffer.copy_from_slice(&element.compress().to_bytes());
82 }
83
84 #[cfg(feature = "curve25519-dalek")]
85 fn deserialize_element(buffer: &[u8]) -> Option<Self::Element> {
86 CompressedRistretto::from_slice(buffer).ok()?.decompress()
87 }
88
89 #[cfg(feature = "curve25519-dalek-ng")]
90 fn deserialize_element(buffer: &[u8]) -> Option<Self::Element> {
91 CompressedRistretto::from_slice(buffer).decompress()
92 }
93}
94
95impl Group for Ristretto {
96 #[cfg(feature = "curve25519-dalek")]
97 fn mul_generator(k: &Scalar) -> Self::Element {
98 k * RISTRETTO_BASEPOINT_TABLE
99 }
100
101 #[cfg(feature = "curve25519-dalek-ng")]
102 fn mul_generator(k: &Scalar) -> Self::Element {
103 k * &RISTRETTO_BASEPOINT_TABLE
104 }
105
106 fn vartime_mul_generator(k: &Scalar) -> Self::Element {
107 #[cfg(feature = "curve25519-dalek")]
108 let zero = Scalar::ZERO;
109 #[cfg(feature = "curve25519-dalek-ng")]
110 let zero = Scalar::zero();
111
112 RistrettoPoint::vartime_double_scalar_mul_basepoint(&zero, &RistrettoPoint::identity(), k)
113 }
114
115 fn multi_mul<'a, I, J>(scalars: I, elements: J) -> Self::Element
116 where
117 I: IntoIterator<Item = &'a Self::Scalar>,
118 J: IntoIterator<Item = Self::Element>,
119 {
120 RistrettoPoint::multiscalar_mul(scalars, elements)
121 }
122
123 fn vartime_double_mul_generator(
124 k: &Scalar,
125 k_element: Self::Element,
126 r: &Scalar,
127 ) -> Self::Element {
128 RistrettoPoint::vartime_double_scalar_mul_basepoint(k, &k_element, r)
129 }
130
131 fn vartime_multi_mul<'a, I, J>(scalars: I, elements: J) -> Self::Element
132 where
133 I: IntoIterator<Item = &'a Self::Scalar>,
134 J: IntoIterator<Item = Self::Element>,
135 {
136 RistrettoPoint::vartime_multiscalar_mul(scalars, elements)
137 }
138}
139
#[cfg(test)]
mod tests {
    use rand::thread_rng;

    use super::*;
    use crate::{
        app::{ChoiceParams, EncryptedChoice},
        group::Curve25519Subgroup,
        DiscreteLogTable,
    };

    type SecretKey = crate::SecretKey<Ristretto>;
    type Keypair = crate::Keypair<Ristretto>;

    /// Encrypting a random scalar and decrypting to an element yields `value * G`.
    #[test]
    fn encrypt_and_decrypt() {
        let mut rng = thread_rng();
        let keypair = Keypair::generate(&mut rng);
        let value = Ristretto::generate_scalar(&mut rng);

        let ciphertext = keypair.public().encrypt(value, &mut rng);
        let decrypted = keypair.secret().decrypt_to_element(ciphertext);

        let expected = Ristretto::vartime_mul_generator(&value);
        assert_eq!(decrypted, expected);
    }

    /// A single-choice encryption with option 3 of 5 verifies and decrypts to a
    /// one-hot vector.
    #[test]
    fn encrypt_choice() {
        let mut rng = thread_rng();
        let (public_key, secret_key) = Keypair::generate(&mut rng).into_tuple();
        let params = ChoiceParams::single(public_key, 5);
        let encrypted = EncryptedChoice::single(&params, 3, &mut rng);
        let verified_choices = encrypted.verify(&params).unwrap();

        let table = DiscreteLogTable::new(0..=1);
        for (index, ciphertext) in verified_choices.iter().copied().enumerate() {
            let expected = u64::from(index == 3);
            let decrypted = secret_key.decrypt(ciphertext, &table);
            assert_eq!(decrypted.unwrap(), expected);
        }
    }

    /// The same secret scalar yields different public-key bytes in the Ristretto group
    /// and in the Curve25519 prime-order subgroup.
    #[test]
    fn edwards_and_ristretto_public_keys_differ() {
        type SubgroupSecretKey = crate::SecretKey<Curve25519Subgroup>;
        type SubgroupKeypair = crate::Keypair<Curve25519Subgroup>;

        let mut rng = thread_rng();
        for _ in 0..1_000 {
            let ristretto_secret = SecretKey::generate(&mut rng);
            let ristretto_keypair = Keypair::from(ristretto_secret.clone());

            let edwards_secret = SubgroupSecretKey::new(*ristretto_secret.expose_scalar());
            let edwards_keypair = SubgroupKeypair::from(edwards_secret);

            assert_ne!(
                ristretto_keypair.public().as_bytes(),
                edwards_keypair.public().as_bytes()
            );
        }
    }
}