Skip to main content

primitives/algebra/elliptic_curve/
point.rs

1use std::{
2    hash::Hash,
3    iter::Sum,
4    mem::MaybeUninit,
5    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
6    sync::Arc,
7};
8
9use elliptic_curve::group::{Group, GroupEncoding};
10use rand::{
11    distributions::{Distribution, Standard},
12    RngCore,
13};
14use serde::{Deserialize, Serialize};
15use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
16use wincode::{ReadResult, WriteResult};
17
18use crate::{
19    algebra::elliptic_curve::{
20        curve::{FromExtendedEdwards, PointAtInfinityError, ToExtendedEdwards},
21        BaseFieldElement,
22        Curve,
23        Scalar,
24        ScalarAsExtension,
25    },
26    errors::PrimitiveError,
27    random::{CryptoRngCore, Random},
28    sharing::unauthenticated::AdditiveShares,
29};
30
31/// A point on a given curve.
32#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
33#[repr(transparent)]
34pub struct Point<C: Curve>(pub(crate) C::Point);
35
36// SAFETY: Point<C> is #[repr(transparent)] over C::Point.
37unsafe impl<C: Curve> bytemuck::TransparentWrapper<C::Point> for Point<C> {}
38
39impl<C: Curve> wincode::SchemaWrite for Point<C> {
40    type Src = Self;
41
42    fn size_of(_src: &Self::Src) -> WriteResult<usize> {
43        let repr = <C::Point as GroupEncoding>::Repr::default();
44        Ok(repr.as_ref().len())
45    }
46
47    fn write(writer: &mut impl wincode::io::Writer, src: &Self::Src) -> WriteResult<()> {
48        let bytes = src.0.to_bytes();
49        Ok(writer.write(bytes.as_ref())?)
50    }
51}
52
53impl<'de, C: Curve> wincode::SchemaRead<'de> for Point<C> {
54    type Dst = Self;
55
56    fn read(
57        reader: &mut impl wincode::io::Reader<'de>,
58        dst: &mut MaybeUninit<Self::Dst>,
59    ) -> ReadResult<()> {
60        let mut repr = <C::Point as GroupEncoding>::Repr::default();
61        let len = repr.as_ref().len();
62        let bytes = reader.fill_exact(len)?;
63        repr.as_mut().copy_from_slice(bytes);
64        reader.consume(len)?;
65
66        let point = Option::from(C::Point::from_bytes(&repr))
67            .ok_or(wincode::ReadError::Custom("invalid curve point encoding"))?;
68
69        dst.write(Point(point));
70        Ok(())
71    }
72}
73
74impl<C: Curve> Unpin for Point<C> {}
75
76impl<C: Curve> Serialize for Point<C> {
77    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
78        let bytes = self.0.to_bytes();
79        serde_bytes::serialize(bytes.as_ref(), serializer)
80    }
81}
82
83impl<'de, C: Curve> Deserialize<'de> for Point<C> {
84    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
85        let bytes: &[u8] = serde_bytes::deserialize(deserializer)?;
86        let endian_bytes = if C::POINT_BIG_ENDIAN {
87            Point::from_be_bytes(bytes)
88        } else {
89            Point::from_le_bytes(bytes)
90        };
91        let point = endian_bytes.map_err(|err| {
92            serde::de::Error::custom(format!("Failed to deserialize curve point: {err:?}"))
93        })?;
94        Ok(point)
95    }
96}
97
98// ------------------------
99// | Misc Implementations |
100// ------------------------
101
102impl<C: Curve> Point<C> {
103    /// The additive identity in the curve group
104    pub fn identity() -> Point<C> {
105        Point(C::Point::identity())
106    }
107
108    pub fn new(point: C::Point) -> Point<C> {
109        Point(point)
110    }
111
112    /// Check whether the given point is the identity point in the group
113    pub fn is_identity(&self) -> Choice {
114        self.ct_eq(&Point::identity())
115    }
116
117    /// Return the wrapped type
118    pub fn inner(&self) -> C::Point {
119        self.0
120    }
121
122    /// The group generator
123    pub fn generator() -> Point<C> {
124        Point(<C::Point as Group>::generator())
125    }
126
127    /// Deserialize a point from a byte buffer
128    pub fn from_be_bytes(bytes: &[u8]) -> Result<Point<C>, PrimitiveError> {
129        let mut encoding = <C::Point as GroupEncoding>::Repr::default();
130        if bytes.len() != encoding.as_ref().len() {
131            return Err(PrimitiveError::DeserializationFailed(format!(
132                "Invalid point encoding length: expected {}, got {}",
133                encoding.as_ref().len(),
134                bytes.len()
135            )));
136        }
137
138        if C::POINT_BIG_ENDIAN {
139            encoding.as_mut().copy_from_slice(bytes);
140        } else {
141            encoding.as_mut().copy_from_slice(bytes);
142            encoding.as_mut().reverse();
143        }
144
145        let point = Option::from(C::Point::from_bytes(&encoding)).ok_or_else(|| {
146            PrimitiveError::DeserializationFailed("Invalid point encoding".to_string())
147        })?;
148        Ok(Point(point))
149    }
150
151    /// Deserialize a point from a byte buffer
152    /// TODO: Check this is constant-time
153    pub fn from_le_bytes(bytes: &[u8]) -> Result<Point<C>, PrimitiveError> {
154        let mut encoding = <C::Point as GroupEncoding>::Repr::default();
155        if bytes.len() != encoding.as_ref().len() {
156            return Err(PrimitiveError::DeserializationFailed(format!(
157                "Invalid point encoding length: expected {}, got {}",
158                encoding.as_ref().len(),
159                bytes.len()
160            )));
161        }
162
163        if C::POINT_BIG_ENDIAN {
164            encoding.as_mut().copy_from_slice(bytes);
165            encoding.as_mut().reverse();
166        } else {
167            encoding.as_mut().copy_from_slice(bytes);
168        }
169
170        let point = Option::from(C::Point::from_bytes(&encoding)).ok_or_else(|| {
171            PrimitiveError::DeserializationFailed("Invalid point encoding".to_string())
172        })?;
173        Ok(Point(point))
174    }
175
176    /// Serialize the point to a byte buffer
177    pub fn to_bytes(&self) -> Arc<[u8]> {
178        self.0.to_bytes().as_ref().into()
179    }
180
181    pub fn from_extended_edwards(coordinates: [BaseFieldElement<C>; 4]) -> Option<Point<C>> {
182        C::Point::from_extended_edwards(coordinates).map(Point)
183    }
184
185    pub fn to_extended_edwards(self) -> Result<[BaseFieldElement<C>; 4], PointAtInfinityError> {
186        self.0.to_extended_edwards()
187    }
188}
189
190impl<C: Curve> Random for Point<C> {
191    #[inline]
192    fn random(rng: impl CryptoRngCore) -> Self {
193        Point(C::Point::random(rng))
194    }
195}
196
197impl<C: Curve> Distribution<Point<C>> for Standard {
198    #[inline]
199    fn sample<R: RngCore + ?Sized>(&self, rng: &mut R) -> Point<C> {
200        Point(C::Point::random(rng))
201    }
202}
203
204// ------------------------------------
205// | Curve Arithmetic Implementations |
206// ------------------------------------
207
208// === Addition === //
209
210#[macros::op_variants(owned, borrowed, flipped_commutative)]
211impl<C: Curve> Add<&Point<C>> for Point<C> {
212    type Output = Point<C>;
213
214    #[inline]
215    fn add(mut self, rhs: &Point<C>) -> Self::Output {
216        self.0 += rhs.0;
217        self
218    }
219}
220
221#[macros::op_variants(owned)]
222impl<C: Curve> AddAssign<&Point<C>> for Point<C> {
223    #[inline]
224    fn add_assign(&mut self, rhs: &Point<C>) {
225        self.0 += rhs.0;
226    }
227}
228
229// === Subtraction === //
230
231#[macros::op_variants(owned, borrowed, flipped)]
232impl<C: Curve> Sub<&Point<C>> for Point<C> {
233    type Output = Point<C>;
234
235    #[inline]
236    fn sub(mut self, rhs: &Point<C>) -> Self::Output {
237        self.0 -= rhs.0;
238        self
239    }
240}
241
242#[macros::op_variants(owned)]
243impl<C: Curve> SubAssign<&Point<C>> for Point<C> {
244    #[inline]
245    fn sub_assign(&mut self, rhs: &Point<C>) {
246        self.0 -= rhs.0;
247    }
248}
249
250// === Negation === //
251
252#[macros::op_variants(borrowed)]
253impl<C: Curve> Neg for Point<C> {
254    type Output = Point<C>;
255
256    #[inline]
257    fn neg(self) -> Self::Output {
258        Point(-self.0)
259    }
260}
261
262// === Scalar Multiplication === //
263
264#[macros::op_variants(owned, borrowed, flipped)]
265impl<C: Curve> Mul<&ScalarAsExtension<C>> for Point<C> {
266    type Output = Point<C>;
267
268    #[inline]
269    fn mul(mut self, rhs: &ScalarAsExtension<C>) -> Self::Output {
270        self.0 *= rhs.0;
271        self
272    }
273}
274
275#[macros::op_variants(owned, borrowed, flipped_commutative)]
276impl<C: Curve> Mul<&Point<C>> for ScalarAsExtension<C> {
277    type Output = Point<C>;
278
279    #[inline]
280    fn mul(self, rhs: &Point<C>) -> Self::Output {
281        Point(rhs.0 * self.0)
282    }
283}
284
285#[macros::op_variants(owned, borrowed, flipped)]
286impl<C: Curve> Mul<&Scalar<C>> for Point<C> {
287    type Output = Point<C>;
288
289    #[inline]
290    fn mul(self, rhs: &Scalar<C>) -> Self::Output {
291        Point(self.0 * rhs.0)
292    }
293}
294
295#[macros::op_variants(owned, borrowed, flipped_commutative)]
296impl<C: Curve> Mul<&Point<C>> for Scalar<C> {
297    type Output = Point<C>;
298
299    #[inline]
300    fn mul(self, rhs: &Point<C>) -> Self::Output {
301        Point(rhs.0 * self.0)
302    }
303}
304
305// === MulAssign === //
306
307#[macros::op_variants(owned)]
308impl<C: Curve> MulAssign<&ScalarAsExtension<C>> for Point<C> {
309    #[inline]
310    fn mul_assign(&mut self, rhs: &ScalarAsExtension<C>) {
311        self.0 *= rhs.0;
312    }
313}
314
315#[macros::op_variants(owned)]
316impl<C: Curve> MulAssign<&Scalar<C>> for Point<C> {
317    #[inline]
318    fn mul_assign(&mut self, rhs: &Scalar<C>) {
319        self.0 *= rhs.0;
320    }
321}
322
323// === Equality === //
324
325impl<C: Curve> ConstantTimeEq for Point<C> {
326    #[inline]
327    fn ct_eq(&self, other: &Self) -> Choice {
328        self.0.ct_eq(&other.0)
329    }
330}
331
332impl<C: Curve> ConditionallySelectable for Point<C> {
333    #[inline]
334    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
335        let selected = C::Point::conditional_select(&a.0, &b.0, choice);
336        Point(selected)
337    }
338}
339
340// === Other === //
341
342impl<C: Curve> AdditiveShares for Point<C> {}
343
344// === Iterator traits === //
345
346impl<C: Curve> Sum for Point<C> {
347    #[inline]
348    fn sum<I: Iterator<Item = Point<C>>>(iter: I) -> Self {
349        iter.fold(Point::identity(), |acc, x| acc + x)
350    }
351}
352
353impl<'a, C: Curve> Sum<&'a Point<C>> for Point<C> {
354    #[inline]
355    fn sum<I: Iterator<Item = &'a Point<C>>>(iter: I) -> Self {
356        iter.fold(Point::identity(), |acc, x| acc + x)
357    }
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::elliptic_curve::Curve25519Ristretto;

    /// The generator must survive a `to_bytes` / `from_le_bytes` round trip,
    /// and a buffer of the wrong length must be rejected.
    #[test]
    fn test_point_serialization() {
        let point = Point::<Curve25519Ristretto>::generator();
        let bytes = point.to_bytes();
        let deserialized_point = Point::<Curve25519Ristretto>::from_le_bytes(&bytes).unwrap();
        assert_eq!(point, deserialized_point);

        let bytes = bytes.as_ref()[1..].to_vec(); // Invalid length
        let result = Point::<Curve25519Ristretto>::from_le_bytes(&bytes);
        assert!(result.is_err());
    }

    /// Wincode must reject bytes that don't encode a valid curve
    /// point: the `SchemaRead` implementation for `Point` decodes the
    /// bytes through `from_bytes` and reports an error when that fails.
    #[test]
    fn test_wincode_rejects_invalid_point() {
        let valid = Point::<Curve25519Ristretto>::generator();
        let mut buf = wincode::serialize(&valid).unwrap();

        // Positive control: the untouched bytes must decode back to the
        // original point, so the rejection below cannot pass vacuously.
        let round_tripped = wincode::deserialize::<Point<Curve25519Ristretto>>(&buf).unwrap();
        assert_eq!(round_tripped, valid);

        // Corrupt the serialized point bytes to produce an invalid
        // Ristretto encoding (all 0xFF bytes is not on the curve).
        let len = buf.len();
        buf[len - 32..].fill(0xFF);

        let result = wincode::deserialize::<Point<Curve25519Ristretto>>(&buf);
        assert!(
            result.is_err(),
            "wincode deserialized an invalid curve point \
             without returning an error"
        );
    }
}