// curv/elliptic/curves/curve_ristretto.rs

1#![allow(non_snake_case)]
2/*
3    This file is part of Curv library
4    Copyright 2018 by Kzen Networks
5    (https://github.com/KZen-networks/curv)
6    License MIT: <https://github.com/KZen-networks/curv/blob/master/LICENSE>
7*/
8
9use std::convert::TryInto;
10use std::ptr;
11use std::sync::atomic;
12
13use curve25519_dalek::constants::{BASEPOINT_ORDER, RISTRETTO_BASEPOINT_POINT};
14use curve25519_dalek::ristretto::CompressedRistretto;
15use curve25519_dalek::traits::{Identity, IsIdentity};
16use generic_array::GenericArray;
17use rand::thread_rng;
18use sha2::{Digest, Sha256};
19use zeroize::{Zeroize, Zeroizing};
20
21use crate::arithmetic::*;
22use crate::elliptic::curves::traits::*;
23
24use super::traits::{ECPoint, ECScalar};
25
26lazy_static::lazy_static! {
27    static ref GROUP_ORDER: BigInt = RistrettoScalar {
28        purpose: "intermediate GROUP_ORDER",
29        fe: BASEPOINT_ORDER.into(),
30    }.to_bigint();
31
32    static ref GENERATOR: RistrettoPoint = RistrettoPoint {
33        purpose: "generator",
34        ge: RISTRETTO_BASEPOINT_POINT,
35    };
36
37    static ref BASE_POINT2: RistrettoPoint = {
38        let g = RistrettoPoint::generator();
39        let hash = Sha256::digest(g.serialize_compressed().as_ref());
40        RistrettoPoint {
41            purpose: "base_point2",
42            ge: RistrettoPoint::deserialize(&hash).unwrap().ge,
43        }
44    };
45}
46
/// Presumably the byte length of a serialized secret scalar (scalars in this module
/// serialize to 32 bytes).
pub const SECRET_KEY_SIZE: usize = 32;
/// Presumably the byte length of one serialized coordinate; `y_coord` in this module is
/// derived from a 32-byte compressed encoding.
pub const COOR_BYTE_SIZE: usize = 32;
// NOTE(review): the meaning of this constant is not established anywhere in this file;
// it may refer to the coordinate count of the underlying point representation — confirm.
pub const NUM_OF_COORDINATES: usize = 4;

/// Underlying secret-key / scalar type from `curve25519_dalek`.
pub type SK = curve25519_dalek::scalar::Scalar;
/// Underlying public-key / group-element type from `curve25519_dalek`.
pub type PK = curve25519_dalek::ristretto::RistrettoPoint;
53
/// Ristretto curve implementation based on [curve25519_dalek] library
///
/// `Ristretto` is an uninhabited marker type: it is never constructed and only serves to
/// select [`RistrettoPoint`] and [`RistrettoScalar`] through the `Curve` trait.
///
/// ## Implementation notes
/// * x coordinate
///
///   Underlying library intentionally doesn't expose x coordinate of curve point, therefore
///   `.x_coord()`, `.coords()` methods always return `None`, `from_coords()` constructor always
///   returns `Err(NotOnCurve)`
/// * serialization
///
///   Points have a single 32-byte encoding: `serialize_uncompressed()` returns the same
///   bytes as `serialize_compressed()`.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Ristretto {}
/// Scalar of the Ristretto group, wrapping a `curve25519_dalek` scalar.
///
/// The inner value is held in [`Zeroizing`] so that it is wiped when dropped.
#[derive(Clone, Debug)]
pub struct RistrettoScalar {
    // Debug-only label recording which constructor produced the value. It is never read
    // by code in this file (hence the `dead_code` allowance) and is ignored by `PartialEq`.
    #[allow(dead_code)]
    purpose: &'static str,
    fe: Zeroizing<SK>,
}
/// Element of the Ristretto group, wrapping a `curve25519_dalek` Ristretto point.
#[derive(Clone, Debug, Copy)]
pub struct RistrettoPoint {
    // Debug-only label recording which constructor produced the value. It is never read
    // by code in this file (hence the `dead_code` allowance) and is ignored by `PartialEq`.
    #[allow(dead_code)]
    purpose: &'static str,
    ge: PK,
}
/// Short alias for the Ristretto group-element type.
pub type GE = RistrettoPoint;
/// Short alias for the Ristretto scalar type.
pub type FE = RistrettoScalar;
78
// Binds the `Ristretto` marker type to its point and scalar implementations.
impl Curve for Ristretto {
    type Point = GE;
    type Scalar = FE;

    const CURVE_NAME: &'static str = "ristretto";
}
85
86impl ECScalar for RistrettoScalar {
87    type Underlying = SK;
88
89    type ScalarLength = typenum::U32;
90
91    fn random() -> RistrettoScalar {
92        RistrettoScalar {
93            purpose: "random",
94            fe: SK::random(&mut thread_rng()).into(),
95        }
96    }
97
98    fn zero() -> RistrettoScalar {
99        RistrettoScalar {
100            purpose: "zero",
101            fe: SK::zero().into(),
102        }
103    }
104
105    fn from_bigint(n: &BigInt) -> RistrettoScalar {
106        let curve_order = RistrettoScalar::group_order();
107        let mut bytes = n
108            .modulus(curve_order)
109            .to_bytes_array::<32>()
110            .expect("n mod curve_order must be equal or less than 32 bytes");
111        bytes.reverse();
112        RistrettoScalar {
113            purpose: "from_bigint",
114            fe: SK::from_bytes_mod_order(bytes).into(),
115        }
116    }
117
118    fn to_bigint(&self) -> BigInt {
119        let mut t = self.fe.to_bytes();
120        t.reverse();
121        BigInt::from_bytes(&t)
122    }
123
124    fn serialize(&self) -> GenericArray<u8, Self::ScalarLength> {
125        GenericArray::from(self.fe.to_bytes())
126    }
127
128    fn deserialize(bytes: &[u8]) -> Result<Self, DeserializationError> {
129        let bytes: [u8; 32] = bytes.try_into().or(Err(DeserializationError))?;
130        Ok(RistrettoScalar {
131            purpose: "from_bigint",
132            fe: SK::from_canonical_bytes(bytes)
133                .ok_or(DeserializationError)?
134                .into(),
135        })
136    }
137
138    fn add(&self, other: &Self) -> RistrettoScalar {
139        RistrettoScalar {
140            purpose: "add",
141            fe: (*self.fe + *other.fe).into(),
142        }
143    }
144
145    fn mul(&self, other: &Self) -> RistrettoScalar {
146        RistrettoScalar {
147            purpose: "mul",
148            fe: (*self.fe * *other.fe).into(),
149        }
150    }
151
152    fn sub(&self, other: &Self) -> RistrettoScalar {
153        RistrettoScalar {
154            purpose: "sub",
155            fe: (*self.fe - *other.fe).into(),
156        }
157    }
158
159    fn neg(&self) -> Self {
160        RistrettoScalar {
161            purpose: "neg",
162            fe: (-&*self.fe).into(),
163        }
164    }
165
166    fn invert(&self) -> Option<RistrettoScalar> {
167        if self.is_zero() {
168            None
169        } else {
170            Some(RistrettoScalar {
171                purpose: "invert",
172                fe: self.fe.invert().into(),
173            })
174        }
175    }
176
177    fn add_assign(&mut self, other: &Self) {
178        *self.fe += &*other.fe;
179    }
180    fn mul_assign(&mut self, other: &Self) {
181        *self.fe *= &*other.fe;
182    }
183    fn sub_assign(&mut self, other: &Self) {
184        *self.fe -= &*other.fe;
185    }
186
187    fn group_order() -> &'static BigInt {
188        &GROUP_ORDER
189    }
190
191    fn underlying_ref(&self) -> &Self::Underlying {
192        &self.fe
193    }
194    fn underlying_mut(&mut self) -> &mut Self::Underlying {
195        &mut self.fe
196    }
197    fn from_underlying(fe: Self::Underlying) -> RistrettoScalar {
198        RistrettoScalar {
199            purpose: "from_underlying",
200            fe: fe.into(),
201        }
202    }
203}
204
205impl PartialEq for RistrettoScalar {
206    fn eq(&self, other: &RistrettoScalar) -> bool {
207        self.fe == other.fe
208    }
209}
210
211impl ECPoint for RistrettoPoint {
212    type Scalar = RistrettoScalar;
213    type Underlying = PK;
214
215    type CompressedPointLength = typenum::U32;
216    type UncompressedPointLength = typenum::U32;
217
218    fn zero() -> RistrettoPoint {
219        RistrettoPoint {
220            purpose: "zero",
221            ge: PK::identity(),
222        }
223    }
224
225    fn is_zero(&self) -> bool {
226        self.ge.is_identity()
227    }
228
229    fn generator() -> &'static RistrettoPoint {
230        &GENERATOR
231    }
232
233    fn base_point2() -> &'static RistrettoPoint {
234        &BASE_POINT2
235    }
236
237    fn from_coords(_x: &BigInt, _y: &BigInt) -> Result<RistrettoPoint, NotOnCurve> {
238        // Underlying library intentionally hides x coordinate. There's no way to match if `x`
239        // correspond to given `y`.
240        Err(NotOnCurve)
241    }
242
243    fn x_coord(&self) -> Option<BigInt> {
244        // Underlying library intentionally hides x coordinate. There's no way we can know x
245        // coordinate
246        None
247    }
248
249    fn y_coord(&self) -> Option<BigInt> {
250        let mut y = self.ge.compress().to_bytes();
251        y.reverse();
252        Some(BigInt::from_bytes(&y[..]))
253    }
254
255    fn coords(&self) -> Option<PointCoords> {
256        None
257    }
258
259    fn serialize_compressed(&self) -> GenericArray<u8, Self::CompressedPointLength> {
260        GenericArray::from(self.ge.compress().to_bytes())
261    }
262
263    fn serialize_uncompressed(&self) -> GenericArray<u8, Self::UncompressedPointLength> {
264        GenericArray::from(self.ge.compress().to_bytes())
265    }
266
267    fn deserialize(bytes: &[u8]) -> Result<RistrettoPoint, DeserializationError> {
268        let mut buffer = [0u8; 32];
269        let n = bytes.len();
270
271        if n == 0 || n > 32 {
272            return Err(DeserializationError);
273        }
274        buffer[32 - n..].copy_from_slice(bytes);
275
276        CompressedRistretto::from_slice(&buffer)
277            .decompress()
278            .ok_or(DeserializationError)
279            .map(|ge| RistrettoPoint {
280                purpose: "deserialize",
281                ge,
282            })
283    }
284
285    fn check_point_order_equals_group_order(&self) -> bool {
286        !self.is_zero()
287    }
288
289    fn scalar_mul(&self, fe: &Self::Scalar) -> RistrettoPoint {
290        RistrettoPoint {
291            purpose: "scalar_mul",
292            ge: self.ge * *fe.fe,
293        }
294    }
295
296    fn add_point(&self, other: &Self) -> RistrettoPoint {
297        RistrettoPoint {
298            purpose: "add_point",
299            ge: self.ge + other.ge,
300        }
301    }
302
303    fn sub_point(&self, other: &Self) -> RistrettoPoint {
304        RistrettoPoint {
305            purpose: "sub_point",
306            ge: self.ge - other.ge,
307        }
308    }
309
310    fn neg_point(&self) -> RistrettoPoint {
311        RistrettoPoint {
312            purpose: "sub_point",
313            ge: -self.ge,
314        }
315    }
316
317    fn scalar_mul_assign(&mut self, scalar: &Self::Scalar) {
318        self.ge *= &*scalar.fe
319    }
320    fn add_point_assign(&mut self, other: &Self) {
321        self.ge += &other.ge
322    }
323    fn sub_point_assign(&mut self, other: &Self) {
324        self.ge -= &other.ge
325    }
326    fn underlying_ref(&self) -> &Self::Underlying {
327        &self.ge
328    }
329    fn underlying_mut(&mut self) -> &mut Self::Underlying {
330        &mut self.ge
331    }
332    fn from_underlying(ge: Self::Underlying) -> RistrettoPoint {
333        RistrettoPoint {
334            purpose: "from_underlying",
335            ge,
336        }
337    }
338}
339
340impl PartialEq for RistrettoPoint {
341    fn eq(&self, other: &RistrettoPoint) -> bool {
342        self.ge == other.ge
343    }
344}
345
346impl Zeroize for RistrettoPoint {
347    fn zeroize(&mut self) {
348        unsafe { ptr::write_volatile(&mut self.ge, PK::default()) };
349        atomic::compiler_fence(atomic::Ordering::SeqCst);
350    }
351}