elastic_elgamal/group/
generic.rs

1use elliptic_curve::{
2    ff::PrimeField,
3    generic_array::{typenum::Unsigned, GenericArray},
4    sec1::{EncodedPoint, FromEncodedPoint, ModulusSize, ToEncodedPoint},
5    CurveArithmetic, Field, FieldBytesSize, Group as _, ProjectivePoint, Scalar,
6};
7use rand_core::{CryptoRng, RngCore};
8use zeroize::Zeroize;
9
10use core::marker::PhantomData;
11
12use super::{ElementOps, Group, ScalarOps};
13
/// Generic [`Group`] implementation for elliptic curves defined in terms of the traits
/// from the [`elliptic-curve`] crate.
///
/// The type is a zero-sized marker: it only stores a [`PhantomData`] tying it to the curve `C`,
/// and all functionality is exposed via the [`ScalarOps`], [`ElementOps`] and [`Group`]
/// trait implementations.
///
/// # Assumptions
///
/// - Arithmetic operations required to be constant-time as per [`ScalarOps`] and [`ElementOps`]
///   contracts are indeed constant-time.
///
/// [`elliptic-curve`]: https://docs.rs/elliptic-curve/
#[derive(Debug)]
pub struct Generic<C>(PhantomData<C>);
25
// `Clone` and `Copy` are implemented manually rather than derived: the derive macros
// would add `C: Clone` / `C: Copy` bounds, which are unnecessary since `Generic<C>`
// only holds a `PhantomData<C>`.
impl<C> Clone for Generic<C> {
    fn clone(&self) -> Self {
        // `Generic<C>` is `Copy`, so cloning is a plain bitwise copy.
        *self
    }
}

impl<C> Copy for Generic<C> {}
33
34impl<C> ScalarOps for Generic<C>
35where
36    C: CurveArithmetic,
37    Scalar<C>: Zeroize,
38{
39    type Scalar = Scalar<C>;
40
41    const SCALAR_SIZE: usize = <FieldBytesSize<C> as Unsigned>::USIZE;
42
43    fn generate_scalar<R: CryptoRng + RngCore>(rng: &mut R) -> Self::Scalar {
44        Scalar::<C>::random(rng)
45    }
46
47    fn invert_scalar(scalar: Self::Scalar) -> Self::Scalar {
48        scalar.invert().unwrap()
49    }
50
51    fn serialize_scalar(scalar: &Self::Scalar, buffer: &mut [u8]) {
52        buffer.copy_from_slice(scalar.to_repr().as_ref());
53    }
54
55    fn deserialize_scalar(buffer: &[u8]) -> Option<Self::Scalar> {
56        // For most curves, cloning will be resolved as a copy.
57        Scalar::<C>::from_repr(GenericArray::from_slice(buffer).clone()).into()
58    }
59}
60
61impl<C> ElementOps for Generic<C>
62where
63    C: CurveArithmetic,
64    Scalar<C>: Zeroize,
65    FieldBytesSize<C>: ModulusSize,
66    ProjectivePoint<C>: ToEncodedPoint<C> + FromEncodedPoint<C>,
67{
68    type Element = ProjectivePoint<C>;
69
70    const ELEMENT_SIZE: usize = <FieldBytesSize<C> as Unsigned>::USIZE + 1;
71
72    #[inline]
73    fn identity() -> Self::Element {
74        C::ProjectivePoint::identity()
75    }
76
77    #[inline]
78    fn is_identity(element: &Self::Element) -> bool {
79        element.is_identity().into()
80    }
81
82    #[inline]
83    fn generator() -> Self::Element {
84        C::ProjectivePoint::generator()
85    }
86
87    fn serialize_element(element: &Self::Element, buffer: &mut [u8]) {
88        let encoded_point = element.to_encoded_point(true);
89        buffer.copy_from_slice(encoded_point.as_bytes());
90    }
91
92    fn deserialize_element(input: &[u8]) -> Option<Self::Element> {
93        let encoded_point = EncodedPoint::<C>::from_bytes(input).ok()?;
94        ProjectivePoint::<C>::from_encoded_point(&encoded_point).into()
95    }
96}
97
impl<C> Group for Generic<C>
where
    C: CurveArithmetic + 'static,
    Scalar<C>: Zeroize,
    FieldBytesSize<C>: ModulusSize,
    ProjectivePoint<C>: ToEncodedPoint<C> + FromEncodedPoint<C>,
{
    // Default implementations are fine: the impl body is intentionally empty, so every
    // `Group` item falls back to the default provided by the trait.
}
107
#[cfg(test)]
mod tests {
    use rand::thread_rng;

    use super::*;

    type K256 = Generic<k256::Secp256k1>;

    /// Number of random samples checked by each roundtrip test.
    const SAMPLE_COUNT: usize = 100;

    /// Random scalars must survive serialization followed by deserialization.
    #[test]
    fn scalar_roundtrip() {
        let mut rng = thread_rng();
        let mut bytes = [0_u8; K256::SCALAR_SIZE];
        let samples = core::iter::repeat_with(|| K256::generate_scalar(&mut rng));
        for original in samples.take(SAMPLE_COUNT) {
            K256::serialize_scalar(&original, &mut bytes);
            let restored = K256::deserialize_scalar(&bytes).unwrap();
            assert_eq!(restored, original);
        }
    }

    /// Random group elements must survive serialization followed by deserialization.
    #[test]
    fn point_roundtrip() {
        let mut rng = thread_rng();
        let mut bytes = [0_u8; K256::ELEMENT_SIZE];
        let samples =
            core::iter::repeat_with(|| K256::mul_generator(&K256::generate_scalar(&mut rng)));
        for original in samples.take(SAMPLE_COUNT) {
            K256::serialize_element(&original, &mut bytes);
            let restored = K256::deserialize_element(&bytes).unwrap();
            assert_eq!(restored, original);
        }
    }
}