Skip to main content

primitives/algebra/elliptic_curve/
point.rs

1use std::{
2    hash::Hash,
3    iter::Sum,
4    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
5    sync::Arc,
6};
7
8use elliptic_curve::group::{Group, GroupEncoding};
9use rand::{distributions::Standard, prelude::Distribution, RngCore};
10use serde::{Deserialize, Serialize};
11use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
12
13use crate::{
14    algebra::elliptic_curve::{
15        curve::{FromExtendedEdwards, PointAtInfinityError, ToExtendedEdwards},
16        BaseFieldElement,
17        Curve,
18        Scalar,
19        ScalarAsExtension,
20    },
21    errors::PrimitiveError,
22    sharing::unauthenticated::AdditiveShares,
23};
24
25/// A point on a given curve.
26#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
27pub struct Point<C: Curve>(pub(crate) C::Point);
28impl<C: Curve> Unpin for Point<C> {}
29
30impl<C: Curve> Serialize for Point<C> {
31    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
32        let bytes = self.0.to_bytes();
33        serde_bytes::serialize(bytes.as_ref(), serializer)
34    }
35}
36
37impl<'de, C: Curve> Deserialize<'de> for Point<C> {
38    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
39        let bytes: &[u8] = serde_bytes::deserialize(deserializer)?;
40        let endian_bytes = if C::POINT_BIG_ENDIAN {
41            Point::from_be_bytes(bytes)
42        } else {
43            Point::from_le_bytes(bytes)
44        };
45        let point = endian_bytes.map_err(|err| {
46            serde::de::Error::custom(format!("Failed to deserialize curve point: {err:?}"))
47        })?;
48        Ok(point)
49    }
50}
51
52// ------------------------
53// | Misc Implementations |
54// ------------------------
55
56impl<C: Curve> Point<C> {
57    /// The additive identity in the curve group
58    pub fn identity() -> Point<C> {
59        Point(C::Point::identity())
60    }
61
62    pub fn new(point: C::Point) -> Point<C> {
63        Point(point)
64    }
65
66    /// Check whether the given point is the identity point in the group
67    pub fn is_identity(&self) -> Choice {
68        self.ct_eq(&Point::identity())
69    }
70
71    /// Return the wrapped type
72    pub fn inner(&self) -> C::Point {
73        self.0
74    }
75
76    /// The group generator
77    pub fn generator() -> Point<C> {
78        Point(<C::Point as Group>::generator())
79    }
80
81    /// Deserialize a point from a byte buffer
82    pub fn from_be_bytes(bytes: &[u8]) -> Result<Point<C>, PrimitiveError> {
83        let mut encoding = <C::Point as GroupEncoding>::Repr::default();
84
85        if C::POINT_BIG_ENDIAN {
86            encoding.as_mut().copy_from_slice(bytes);
87        } else {
88            encoding.as_mut().copy_from_slice(bytes);
89            encoding.as_mut().reverse();
90        }
91
92        let point = Option::from(C::Point::from_bytes(&encoding)).ok_or_else(|| {
93            PrimitiveError::DeserializationFailed("Invalid point encoding".to_string())
94        })?;
95        Ok(Point(point))
96    }
97
98    /// Deserialize a point from a byte buffer
99    /// TODO: Check this is constant-time
100    pub fn from_le_bytes(bytes: &[u8]) -> Result<Point<C>, PrimitiveError> {
101        let mut encoding = <C::Point as GroupEncoding>::Repr::default();
102
103        if C::POINT_BIG_ENDIAN {
104            encoding.as_mut().copy_from_slice(bytes);
105            encoding.as_mut().reverse();
106        } else {
107            encoding.as_mut().copy_from_slice(bytes);
108        }
109
110        let point = Option::from(C::Point::from_bytes(&encoding)).ok_or_else(|| {
111            PrimitiveError::DeserializationFailed("Invalid point encoding".to_string())
112        })?;
113        Ok(Point(point))
114    }
115
116    /// Serialize the point to a byte buffer
117    pub fn to_bytes(&self) -> Arc<[u8]> {
118        self.0.to_bytes().as_ref().into()
119    }
120
121    pub fn from_extended_edwards(coordinates: [BaseFieldElement<C>; 4]) -> Option<Point<C>> {
122        C::Point::from_extended_edwards(coordinates).map(Point)
123    }
124
125    pub fn to_extended_edwards(self) -> Result<[BaseFieldElement<C>; 4], PointAtInfinityError> {
126        self.0.to_extended_edwards()
127    }
128}
129
130impl<C: Curve> Distribution<Point<C>> for Standard {
131    #[inline]
132    fn sample<R: RngCore + ?Sized>(&self, rng: &mut R) -> Point<C> {
133        Point(C::Point::random(rng))
134    }
135}
136
137// ------------------------------------
138// | Curve Arithmetic Implementations |
139// ------------------------------------
140
141// === Addition === //
142
143#[macros::op_variants(owned, borrowed, flipped_commutative)]
144impl<C: Curve> Add<&Point<C>> for Point<C> {
145    type Output = Point<C>;
146
147    #[inline]
148    fn add(mut self, rhs: &Point<C>) -> Self::Output {
149        self.0 += rhs.0;
150        self
151    }
152}
153
154#[macros::op_variants(owned)]
155impl<C: Curve> AddAssign<&Point<C>> for Point<C> {
156    #[inline]
157    fn add_assign(&mut self, rhs: &Point<C>) {
158        self.0 += rhs.0;
159    }
160}
161
162// === Subtraction === //
163
164#[macros::op_variants(owned, borrowed, flipped)]
165impl<C: Curve> Sub<&Point<C>> for Point<C> {
166    type Output = Point<C>;
167
168    #[inline]
169    fn sub(mut self, rhs: &Point<C>) -> Self::Output {
170        self.0 -= rhs.0;
171        self
172    }
173}
174
175#[macros::op_variants(owned)]
176impl<C: Curve> SubAssign<&Point<C>> for Point<C> {
177    #[inline]
178    fn sub_assign(&mut self, rhs: &Point<C>) {
179        self.0 -= rhs.0;
180    }
181}
182
183// === Negation === //
184
185#[macros::op_variants(borrowed)]
186impl<C: Curve> Neg for Point<C> {
187    type Output = Point<C>;
188
189    #[inline]
190    fn neg(self) -> Self::Output {
191        Point(-self.0)
192    }
193}
194
195// === Scalar Multiplication === //
196
197#[macros::op_variants(owned, borrowed, flipped)]
198impl<C: Curve> Mul<&ScalarAsExtension<C>> for Point<C> {
199    type Output = Point<C>;
200
201    #[inline]
202    fn mul(mut self, rhs: &ScalarAsExtension<C>) -> Self::Output {
203        self.0 *= rhs.0;
204        self
205    }
206}
207
208#[macros::op_variants(owned, borrowed, flipped_commutative)]
209impl<C: Curve> Mul<&Point<C>> for ScalarAsExtension<C> {
210    type Output = Point<C>;
211
212    #[inline]
213    fn mul(self, rhs: &Point<C>) -> Self::Output {
214        Point(rhs.0 * self.0)
215    }
216}
217
218#[macros::op_variants(owned, borrowed, flipped)]
219impl<C: Curve> Mul<&Scalar<C>> for Point<C> {
220    type Output = Point<C>;
221
222    #[inline]
223    fn mul(self, rhs: &Scalar<C>) -> Self::Output {
224        Point(self.0 * rhs.0)
225    }
226}
227
228#[macros::op_variants(owned, borrowed, flipped_commutative)]
229impl<C: Curve> Mul<&Point<C>> for Scalar<C> {
230    type Output = Point<C>;
231
232    #[inline]
233    fn mul(self, rhs: &Point<C>) -> Self::Output {
234        Point(rhs.0 * self.0)
235    }
236}
237
238// === MulAssign === //
239
240#[macros::op_variants(owned)]
241impl<C: Curve> MulAssign<&ScalarAsExtension<C>> for Point<C> {
242    #[inline]
243    fn mul_assign(&mut self, rhs: &ScalarAsExtension<C>) {
244        self.0 *= rhs.0;
245    }
246}
247
248#[macros::op_variants(owned)]
249impl<C: Curve> MulAssign<&Scalar<C>> for Point<C> {
250    #[inline]
251    fn mul_assign(&mut self, rhs: &Scalar<C>) {
252        self.0 *= rhs.0;
253    }
254}
255
256// === Equality === //
257
258impl<C: Curve> ConstantTimeEq for Point<C> {
259    #[inline]
260    fn ct_eq(&self, other: &Self) -> Choice {
261        self.0.ct_eq(&other.0)
262    }
263}
264
265impl<C: Curve> ConditionallySelectable for Point<C> {
266    #[inline]
267    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
268        let selected = C::Point::conditional_select(&a.0, &b.0, choice);
269        Point(selected)
270    }
271}
272
273// === Other === //
274
275impl<C: Curve> AdditiveShares for Point<C> {}
276
277// === Iterator traits === //
278
279impl<C: Curve> Sum for Point<C> {
280    #[inline]
281    fn sum<I: Iterator<Item = Point<C>>>(iter: I) -> Self {
282        iter.fold(Point::identity(), |acc, x| acc + x)
283    }
284}
285
286impl<'a, C: Curve> Sum<&'a Point<C>> for Point<C> {
287    #[inline]
288    fn sum<I: Iterator<Item = &'a Point<C>>>(iter: I) -> Self {
289        iter.fold(Point::identity(), |acc, x| acc + x)
290    }
291}