Skip to main content

primitives/algebra/elliptic_curve/
point.rs

1use std::{
2    hash::Hash,
3    iter::Sum,
4    mem::MaybeUninit,
5    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
6    sync::Arc,
7};
8
9use elliptic_curve::group::{Group, GroupEncoding};
10use rand::{
11    distributions::{Distribution, Standard},
12    RngCore,
13};
14use serde::{Deserialize, Serialize};
15use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
16use wincode::{ReadResult, WriteResult};
17
18use crate::{
19    algebra::elliptic_curve::{
20        curve::{FromExtendedEdwards, PointAtInfinityError, ToExtendedEdwards},
21        BaseFieldElement,
22        Curve,
23        Scalar,
24        ScalarAsExtension,
25    },
26    errors::PrimitiveError,
27    random::{CryptoRngCore, Random},
28    sharing::unauthenticated::AdditiveShares,
29};
30
/// A point on the elliptic curve `C`.
///
/// Thin newtype over the curve's group element type `C::Point`; `#[repr(transparent)]`
/// gives it the same in-memory layout as `C::Point`, and every derived trait below
/// (including `Default`) delegates to the corresponding impl on `C::Point`.
///
/// NOTE(review): `Point::default()` is whatever `C::Point::default()` returns — prefer
/// `Point::identity()` when the group identity is meant, unless the two are confirmed equal.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct Point<C: Curve>(pub(crate) C::Point);
35
36impl<C: Curve> wincode::SchemaWrite for Point<C> {
37    type Src = Self;
38
39    fn size_of(_src: &Self::Src) -> WriteResult<usize> {
40        let repr = <C::Point as GroupEncoding>::Repr::default();
41        Ok(repr.as_ref().len())
42    }
43
44    fn write(writer: &mut impl wincode::io::Writer, src: &Self::Src) -> WriteResult<()> {
45        let bytes = src.0.to_bytes();
46        Ok(writer.write(bytes.as_ref())?)
47    }
48}
49
50impl<'de, C: Curve> wincode::SchemaRead<'de> for Point<C> {
51    type Dst = Self;
52
53    fn read(
54        reader: &mut impl wincode::io::Reader<'de>,
55        dst: &mut MaybeUninit<Self::Dst>,
56    ) -> ReadResult<()> {
57        let mut repr = <C::Point as GroupEncoding>::Repr::default();
58        let len = repr.as_ref().len();
59        let bytes = reader.fill_exact(len)?;
60        repr.as_mut().copy_from_slice(bytes);
61        reader.consume(len)?;
62
63        let point = Option::from(C::Point::from_bytes(&repr))
64            .ok_or(wincode::ReadError::Custom("invalid curve point encoding"))?;
65
66        dst.write(Point(point));
67        Ok(())
68    }
69}
70
// Explicitly mark `Point<C>` as `Unpin` for every curve `C`.
// NOTE(review): presumably added so `Unpin` does not hinge on the auto-trait being
// inferred for `C::Point`; confirm this is the intent before relying on it.
impl<C: Curve> Unpin for Point<C> {}
72
73impl<C: Curve> Serialize for Point<C> {
74    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
75        let bytes = self.0.to_bytes();
76        serde_bytes::serialize(bytes.as_ref(), serializer)
77    }
78}
79
80impl<'de, C: Curve> Deserialize<'de> for Point<C> {
81    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
82        let bytes: &[u8] = serde_bytes::deserialize(deserializer)?;
83        let endian_bytes = if C::POINT_BIG_ENDIAN {
84            Point::from_be_bytes(bytes)
85        } else {
86            Point::from_le_bytes(bytes)
87        };
88        let point = endian_bytes.map_err(|err| {
89            serde::de::Error::custom(format!("Failed to deserialize curve point: {err:?}"))
90        })?;
91        Ok(point)
92    }
93}
94
95// ------------------------
96// | Misc Implementations |
97// ------------------------
98
99impl<C: Curve> Point<C> {
100    /// The additive identity in the curve group
101    pub fn identity() -> Point<C> {
102        Point(C::Point::identity())
103    }
104
105    pub fn new(point: C::Point) -> Point<C> {
106        Point(point)
107    }
108
109    /// Check whether the given point is the identity point in the group
110    pub fn is_identity(&self) -> Choice {
111        self.ct_eq(&Point::identity())
112    }
113
114    /// Return the wrapped type
115    pub fn inner(&self) -> C::Point {
116        self.0
117    }
118
119    /// The group generator
120    pub fn generator() -> Point<C> {
121        Point(<C::Point as Group>::generator())
122    }
123
124    /// Deserialize a point from a byte buffer
125    pub fn from_be_bytes(bytes: &[u8]) -> Result<Point<C>, PrimitiveError> {
126        let mut encoding = <C::Point as GroupEncoding>::Repr::default();
127        if bytes.len() != encoding.as_ref().len() {
128            return Err(PrimitiveError::DeserializationFailed(format!(
129                "Invalid point encoding length: expected {}, got {}",
130                encoding.as_ref().len(),
131                bytes.len()
132            )));
133        }
134
135        if C::POINT_BIG_ENDIAN {
136            encoding.as_mut().copy_from_slice(bytes);
137        } else {
138            encoding.as_mut().copy_from_slice(bytes);
139            encoding.as_mut().reverse();
140        }
141
142        let point = Option::from(C::Point::from_bytes(&encoding)).ok_or_else(|| {
143            PrimitiveError::DeserializationFailed("Invalid point encoding".to_string())
144        })?;
145        Ok(Point(point))
146    }
147
148    /// Deserialize a point from a byte buffer
149    /// TODO: Check this is constant-time
150    pub fn from_le_bytes(bytes: &[u8]) -> Result<Point<C>, PrimitiveError> {
151        let mut encoding = <C::Point as GroupEncoding>::Repr::default();
152        if bytes.len() != encoding.as_ref().len() {
153            return Err(PrimitiveError::DeserializationFailed(format!(
154                "Invalid point encoding length: expected {}, got {}",
155                encoding.as_ref().len(),
156                bytes.len()
157            )));
158        }
159
160        if C::POINT_BIG_ENDIAN {
161            encoding.as_mut().copy_from_slice(bytes);
162            encoding.as_mut().reverse();
163        } else {
164            encoding.as_mut().copy_from_slice(bytes);
165        }
166
167        let point = Option::from(C::Point::from_bytes(&encoding)).ok_or_else(|| {
168            PrimitiveError::DeserializationFailed("Invalid point encoding".to_string())
169        })?;
170        Ok(Point(point))
171    }
172
173    /// Serialize the point to a byte buffer
174    pub fn to_bytes(&self) -> Arc<[u8]> {
175        self.0.to_bytes().as_ref().into()
176    }
177
178    pub fn from_extended_edwards(coordinates: [BaseFieldElement<C>; 4]) -> Option<Point<C>> {
179        C::Point::from_extended_edwards(coordinates).map(Point)
180    }
181
182    pub fn to_extended_edwards(self) -> Result<[BaseFieldElement<C>; 4], PointAtInfinityError> {
183        self.0.to_extended_edwards()
184    }
185}
186
187impl<C: Curve> Random for Point<C> {
188    #[inline]
189    fn random(rng: impl CryptoRngCore) -> Self {
190        Point(C::Point::random(rng))
191    }
192}
193
194impl<C: Curve> Distribution<Point<C>> for Standard {
195    #[inline]
196    fn sample<R: RngCore + ?Sized>(&self, rng: &mut R) -> Point<C> {
197        Point(C::Point::random(rng))
198    }
199}
200
201// ------------------------------------
202// | Curve Arithmetic Implementations |
203// ------------------------------------
204
205// === Addition === //
206
207#[macros::op_variants(owned, borrowed, flipped_commutative)]
208impl<C: Curve> Add<&Point<C>> for Point<C> {
209    type Output = Point<C>;
210
211    #[inline]
212    fn add(mut self, rhs: &Point<C>) -> Self::Output {
213        self.0 += rhs.0;
214        self
215    }
216}
217
218#[macros::op_variants(owned)]
219impl<C: Curve> AddAssign<&Point<C>> for Point<C> {
220    #[inline]
221    fn add_assign(&mut self, rhs: &Point<C>) {
222        self.0 += rhs.0;
223    }
224}
225
226// === Subtraction === //
227
228#[macros::op_variants(owned, borrowed, flipped)]
229impl<C: Curve> Sub<&Point<C>> for Point<C> {
230    type Output = Point<C>;
231
232    #[inline]
233    fn sub(mut self, rhs: &Point<C>) -> Self::Output {
234        self.0 -= rhs.0;
235        self
236    }
237}
238
239#[macros::op_variants(owned)]
240impl<C: Curve> SubAssign<&Point<C>> for Point<C> {
241    #[inline]
242    fn sub_assign(&mut self, rhs: &Point<C>) {
243        self.0 -= rhs.0;
244    }
245}
246
247// === Negation === //
248
249#[macros::op_variants(borrowed)]
250impl<C: Curve> Neg for Point<C> {
251    type Output = Point<C>;
252
253    #[inline]
254    fn neg(self) -> Self::Output {
255        Point(-self.0)
256    }
257}
258
259// === Scalar Multiplication === //
260
261#[macros::op_variants(owned, borrowed, flipped)]
262impl<C: Curve> Mul<&ScalarAsExtension<C>> for Point<C> {
263    type Output = Point<C>;
264
265    #[inline]
266    fn mul(mut self, rhs: &ScalarAsExtension<C>) -> Self::Output {
267        self.0 *= rhs.0;
268        self
269    }
270}
271
272#[macros::op_variants(owned, borrowed, flipped_commutative)]
273impl<C: Curve> Mul<&Point<C>> for ScalarAsExtension<C> {
274    type Output = Point<C>;
275
276    #[inline]
277    fn mul(self, rhs: &Point<C>) -> Self::Output {
278        Point(rhs.0 * self.0)
279    }
280}
281
282#[macros::op_variants(owned, borrowed, flipped)]
283impl<C: Curve> Mul<&Scalar<C>> for Point<C> {
284    type Output = Point<C>;
285
286    #[inline]
287    fn mul(self, rhs: &Scalar<C>) -> Self::Output {
288        Point(self.0 * rhs.0)
289    }
290}
291
292#[macros::op_variants(owned, borrowed, flipped_commutative)]
293impl<C: Curve> Mul<&Point<C>> for Scalar<C> {
294    type Output = Point<C>;
295
296    #[inline]
297    fn mul(self, rhs: &Point<C>) -> Self::Output {
298        Point(rhs.0 * self.0)
299    }
300}
301
302// === MulAssign === //
303
304#[macros::op_variants(owned)]
305impl<C: Curve> MulAssign<&ScalarAsExtension<C>> for Point<C> {
306    #[inline]
307    fn mul_assign(&mut self, rhs: &ScalarAsExtension<C>) {
308        self.0 *= rhs.0;
309    }
310}
311
312#[macros::op_variants(owned)]
313impl<C: Curve> MulAssign<&Scalar<C>> for Point<C> {
314    #[inline]
315    fn mul_assign(&mut self, rhs: &Scalar<C>) {
316        self.0 *= rhs.0;
317    }
318}
319
320// === Equality === //
321
322impl<C: Curve> ConstantTimeEq for Point<C> {
323    #[inline]
324    fn ct_eq(&self, other: &Self) -> Choice {
325        self.0.ct_eq(&other.0)
326    }
327}
328
329impl<C: Curve> ConditionallySelectable for Point<C> {
330    #[inline]
331    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
332        let selected = C::Point::conditional_select(&a.0, &b.0, choice);
333        Point(selected)
334    }
335}
336
337// === Other === //
338
// Opt `Point<C>` into `AdditiveShares` using only the trait's default items (the
// impl body is empty). NOTE(review): presumably relies on the `Add`/`Sub`/`Sum`
// impls in this file — confirm against the trait's bounds.
impl<C: Curve> AdditiveShares for Point<C> {}
340
341// === Iterator traits === //
342
343impl<C: Curve> Sum for Point<C> {
344    #[inline]
345    fn sum<I: Iterator<Item = Point<C>>>(iter: I) -> Self {
346        iter.fold(Point::identity(), |acc, x| acc + x)
347    }
348}
349
350impl<'a, C: Curve> Sum<&'a Point<C>> for Point<C> {
351    #[inline]
352    fn sum<I: Iterator<Item = &'a Point<C>>>(iter: I) -> Self {
353        iter.fold(Point::identity(), |acc, x| acc + x)
354    }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::elliptic_curve::Curve25519Ristretto;

    #[test]
    fn test_point_serialization() {
        let point = Point::<Curve25519Ristretto>::generator();
        let bytes = point.to_bytes();
        let deserialized_point = Point::<Curve25519Ristretto>::from_le_bytes(&bytes).unwrap();
        assert_eq!(point, deserialized_point);

        let bytes = bytes.as_ref()[1..].to_vec(); // Invalid length
        let result = Point::<Curve25519Ristretto>::from_le_bytes(&bytes);
        assert!(result.is_err());
    }

    /// `from_be_bytes` and `from_le_bytes` must each accept the point's encoding in
    /// their own byte order, whatever the curve's native order is.
    #[test]
    fn test_point_byte_order_roundtrip() {
        let point = Point::<Curve25519Ristretto>::generator();
        let native = point.to_bytes().to_vec();
        let mut swapped = native.clone();
        swapped.reverse();

        // `to_bytes` emits the native order; pick which buffer is BE and which is LE.
        let (big_endian, little_endian) = if <Curve25519Ristretto as Curve>::POINT_BIG_ENDIAN {
            (native, swapped)
        } else {
            (swapped, native)
        };

        assert_eq!(
            Point::<Curve25519Ristretto>::from_be_bytes(&big_endian).unwrap(),
            point
        );
        assert_eq!(
            Point::<Curve25519Ristretto>::from_le_bytes(&little_endian).unwrap(),
            point
        );
    }

    /// A valid point must survive a wincode round trip unchanged.
    #[test]
    fn test_wincode_roundtrip() {
        let point = Point::<Curve25519Ristretto>::generator();
        let buf = wincode::serialize(&point).unwrap();
        let decoded = wincode::deserialize::<Point<Curve25519Ristretto>>(&buf).unwrap();
        assert_eq!(point, decoded);
    }

    /// Wincode must reject bytes that don't encode a valid curve point: the manual
    /// `SchemaRead` impl decodes through `GroupEncoding::from_bytes` and maps a
    /// failed decode to an error instead of trusting the raw bytes.
    #[test]
    fn test_wincode_rejects_invalid_point() {
        let valid = Point::<Curve25519Ristretto>::generator();
        let mut buf = wincode::serialize(&valid).unwrap();

        // Overwrite the trailing point encoding with all 0xFF bytes, which is not a
        // valid Ristretto encoding (not on the curve).
        let encoding_len = valid.to_bytes().len();
        let len = buf.len();
        buf[len - encoding_len..].fill(0xFF);

        let result = wincode::deserialize::<Point<Curve25519Ristretto>>(&buf);
        assert!(
            result.is_err(),
            "wincode deserialized an invalid curve point \
             without returning an error"
        );
    }
}