// elastic_elgamal/group/ristretto.rs

1use rand_core::{CryptoRng, RngCore};
2
3use core::convert::TryInto;
4
5use crate::curve25519::{
6    constants::{RISTRETTO_BASEPOINT_POINT, RISTRETTO_BASEPOINT_TABLE},
7    ristretto::{CompressedRistretto, RistrettoPoint},
8    scalar::Scalar,
9    traits::{Identity, IsIdentity, MultiscalarMul, VartimeMultiscalarMul},
10};
11use crate::group::{ElementOps, Group, RandomBytesProvider, ScalarOps};
12
/// Group based on the [Ristretto](https://ristretto.group/) transform of Curve25519
/// (also known as ristretto255).
///
/// This is a zero-sized marker type: it holds no data and exists only so that the
/// Ristretto-specific `ScalarOps`, `ElementOps` and `Group` implementations can be
/// selected by type.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "curve25519-dalek", feature = "curve25519-dalek-ng")))
)]
pub struct Ristretto(());
20
21impl ScalarOps for Ristretto {
22    type Scalar = Scalar;
23
24    const SCALAR_SIZE: usize = 32;
25
26    fn generate_scalar<R: CryptoRng + RngCore>(rng: &mut R) -> Self::Scalar {
27        let mut scalar_bytes = [0_u8; 64];
28        rng.fill_bytes(&mut scalar_bytes[..]);
29        Scalar::from_bytes_mod_order_wide(&scalar_bytes)
30    }
31
32    fn scalar_from_random_bytes(source: RandomBytesProvider<'_>) -> Self::Scalar {
33        let mut scalar_bytes = [0_u8; 64];
34        source.fill_bytes(&mut scalar_bytes);
35        Scalar::from_bytes_mod_order_wide(&scalar_bytes)
36    }
37
38    fn invert_scalar(scalar: Self::Scalar) -> Self::Scalar {
39        scalar.invert()
40    }
41
42    fn invert_scalars(scalars: &mut [Self::Scalar]) {
43        Scalar::batch_invert(scalars);
44    }
45
46    fn serialize_scalar(scalar: &Self::Scalar, buffer: &mut [u8]) {
47        buffer.copy_from_slice(&scalar.to_bytes());
48    }
49
50    #[cfg(feature = "curve25519-dalek")]
51    fn deserialize_scalar(buffer: &[u8]) -> Option<Self::Scalar> {
52        let bytes: &[u8; 32] = buffer.try_into().expect("input has incorrect byte size");
53        Scalar::from_canonical_bytes(*bytes).into()
54    }
55
56    #[cfg(feature = "curve25519-dalek-ng")]
57    fn deserialize_scalar(buffer: &[u8]) -> Option<Self::Scalar> {
58        let bytes: &[u8; 32] = buffer.try_into().expect("input has incorrect byte size");
59        Scalar::from_canonical_bytes(*bytes)
60    }
61}
62
63impl ElementOps for Ristretto {
64    type Element = RistrettoPoint;
65
66    const ELEMENT_SIZE: usize = 32;
67
68    fn identity() -> Self::Element {
69        RistrettoPoint::identity()
70    }
71
72    fn is_identity(element: &Self::Element) -> bool {
73        element.is_identity()
74    }
75
76    fn generator() -> Self::Element {
77        RISTRETTO_BASEPOINT_POINT
78    }
79
80    fn serialize_element(element: &Self::Element, buffer: &mut [u8]) {
81        buffer.copy_from_slice(&element.compress().to_bytes());
82    }
83
84    #[cfg(feature = "curve25519-dalek")]
85    fn deserialize_element(buffer: &[u8]) -> Option<Self::Element> {
86        CompressedRistretto::from_slice(buffer).ok()?.decompress()
87    }
88
89    #[cfg(feature = "curve25519-dalek-ng")]
90    fn deserialize_element(buffer: &[u8]) -> Option<Self::Element> {
91        CompressedRistretto::from_slice(buffer).decompress()
92    }
93}
94
95impl Group for Ristretto {
96    #[cfg(feature = "curve25519-dalek")]
97    fn mul_generator(k: &Scalar) -> Self::Element {
98        k * RISTRETTO_BASEPOINT_TABLE
99    }
100
101    #[cfg(feature = "curve25519-dalek-ng")]
102    fn mul_generator(k: &Scalar) -> Self::Element {
103        k * &RISTRETTO_BASEPOINT_TABLE
104    }
105
106    fn vartime_mul_generator(k: &Scalar) -> Self::Element {
107        #[cfg(feature = "curve25519-dalek")]
108        let zero = Scalar::ZERO;
109        #[cfg(feature = "curve25519-dalek-ng")]
110        let zero = Scalar::zero();
111
112        RistrettoPoint::vartime_double_scalar_mul_basepoint(&zero, &RistrettoPoint::identity(), k)
113    }
114
115    fn multi_mul<'a, I, J>(scalars: I, elements: J) -> Self::Element
116    where
117        I: IntoIterator<Item = &'a Self::Scalar>,
118        J: IntoIterator<Item = Self::Element>,
119    {
120        RistrettoPoint::multiscalar_mul(scalars, elements)
121    }
122
123    fn vartime_double_mul_generator(
124        k: &Scalar,
125        k_element: Self::Element,
126        r: &Scalar,
127    ) -> Self::Element {
128        RistrettoPoint::vartime_double_scalar_mul_basepoint(k, &k_element, r)
129    }
130
131    fn vartime_multi_mul<'a, I, J>(scalars: I, elements: J) -> Self::Element
132    where
133        I: IntoIterator<Item = &'a Self::Scalar>,
134        J: IntoIterator<Item = Self::Element>,
135    {
136        RistrettoPoint::vartime_multiscalar_mul(scalars, elements)
137    }
138}
139
#[cfg(test)]
mod tests {
    use rand::thread_rng;

    use super::*;
    use crate::{
        app::{ChoiceParams, EncryptedChoice},
        group::Curve25519Subgroup,
        DiscreteLogTable,
    };

    type SecretKey = crate::SecretKey<Ristretto>;
    type Keypair = crate::Keypair<Ristretto>;

    /// Encrypting a random scalar and decrypting to an element must give `value * G`.
    #[test]
    fn encrypt_and_decrypt() {
        let mut rng = thread_rng();
        let keypair = Keypair::generate(&mut rng);
        let value = Ristretto::generate_scalar(&mut rng);
        let encrypted = keypair.public().encrypt(value, &mut rng);
        let decryption = keypair.secret().decrypt_to_element(encrypted);
        assert_eq!(decryption, Ristretto::vartime_mul_generator(&value));
    }

    /// A single-choice ballot over 5 options with option 3 selected must verify,
    /// yield one ciphertext per option, and decrypt to 1 only at index 3.
    #[test]
    fn encrypt_choice() {
        let mut rng = thread_rng();
        let (pk, sk) = Keypair::generate(&mut rng).into_tuple();
        let choice_params = ChoiceParams::single(pk, 5);
        let encrypted = EncryptedChoice::single(&choice_params, 3, &mut rng);
        let choices = encrypted.verify(&choice_params).unwrap();
        // Without this, a `verify` that returned too few ciphertexts (e.g. dropping
        // index 3) would make the loop below pass vacuously.
        assert_eq!(choices.len(), 5);

        let lookup_table = DiscreteLogTable::new(0..=1);
        for (i, &choice) in choices.iter().enumerate() {
            let decryption = sk.decrypt(choice, &lookup_table);
            assert_eq!(decryption.unwrap(), u64::from(i == 3));
        }
    }

    /// The same secret scalar must produce different public-key encodings in the
    /// Ristretto group and in the Curve25519 prime-order subgroup.
    #[test]
    fn edwards_and_ristretto_public_keys_differ() {
        type SubgroupSecretKey = crate::SecretKey<Curve25519Subgroup>;
        type SubgroupKeypair = crate::Keypair<Curve25519Subgroup>;

        // One RNG handle for the whole loop instead of re-fetching it each iteration.
        let mut rng = thread_rng();
        for _ in 0..1_000 {
            let secret_key = SecretKey::generate(&mut rng);
            let keypair = Keypair::from(secret_key.clone());
            let secret_key = SubgroupSecretKey::new(*secret_key.expose_scalar());
            let ed_keypair = SubgroupKeypair::from(secret_key);
            assert_ne!(keypair.public().as_bytes(), ed_keypair.public().as_bytes());
        }
    }
}
192}