primitives/algebra/elliptic_curve/
point.rs1use std::{
2 hash::Hash,
3 iter::Sum,
4 mem::MaybeUninit,
5 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
6 sync::Arc,
7};
8
9use elliptic_curve::group::{Group, GroupEncoding};
10use rand::{
11 distributions::{Distribution, Standard},
12 RngCore,
13};
14use serde::{Deserialize, Serialize};
15use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
16use wincode::{ReadResult, WriteResult};
17
18use crate::{
19 algebra::elliptic_curve::{
20 curve::{FromExtendedEdwards, PointAtInfinityError, ToExtendedEdwards},
21 BaseFieldElement,
22 Curve,
23 Scalar,
24 ScalarAsExtension,
25 },
26 errors::PrimitiveError,
27 random::{CryptoRngCore, Random},
28 sharing::unauthenticated::AdditiveShares,
29};
30
/// A point of the elliptic-curve group of `C`, wrapping the curve's point type `C::Point`.
///
/// `#[repr(transparent)]` gives `Point<C>` exactly the layout of `C::Point`; the
/// `bytemuck::TransparentWrapper` impl for this type depends on that.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct Point<C: Curve>(pub(crate) C::Point);
35
// SAFETY: `Point<C>` is `#[repr(transparent)]` with `C::Point` as its only field, so the two
// types have identical size, alignment and ABI, which is what `TransparentWrapper` requires.
unsafe impl<C: Curve> bytemuck::TransparentWrapper<C::Point> for Point<C> {}
38
39impl<C: Curve> wincode::SchemaWrite for Point<C> {
40 type Src = Self;
41
42 fn size_of(_src: &Self::Src) -> WriteResult<usize> {
43 let repr = <C::Point as GroupEncoding>::Repr::default();
44 Ok(repr.as_ref().len())
45 }
46
47 fn write(writer: &mut impl wincode::io::Writer, src: &Self::Src) -> WriteResult<()> {
48 let bytes = src.0.to_bytes();
49 Ok(writer.write(bytes.as_ref())?)
50 }
51}
52
53impl<'de, C: Curve> wincode::SchemaRead<'de> for Point<C> {
54 type Dst = Self;
55
56 fn read(
57 reader: &mut impl wincode::io::Reader<'de>,
58 dst: &mut MaybeUninit<Self::Dst>,
59 ) -> ReadResult<()> {
60 let mut repr = <C::Point as GroupEncoding>::Repr::default();
61 let len = repr.as_ref().len();
62 let bytes = reader.fill_exact(len)?;
63 repr.as_mut().copy_from_slice(bytes);
64 reader.consume(len)?;
65
66 let point = Option::from(C::Point::from_bytes(&repr))
67 .ok_or(wincode::ReadError::Custom("invalid curve point encoding"))?;
68
69 dst.write(Point(point));
70 Ok(())
71 }
72}
73
// Declares `Point<C>` `Unpin` unconditionally, not only when `C::Point: Unpin`.
// Nothing in this file pins a `Point` or projects pins into its field.
impl<C: Curve> Unpin for Point<C> {}
75
76impl<C: Curve> Serialize for Point<C> {
77 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
78 let bytes = self.0.to_bytes();
79 serde_bytes::serialize(bytes.as_ref(), serializer)
80 }
81}
82
83impl<'de, C: Curve> Deserialize<'de> for Point<C> {
84 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
85 let bytes: &[u8] = serde_bytes::deserialize(deserializer)?;
86 let endian_bytes = if C::POINT_BIG_ENDIAN {
87 Point::from_be_bytes(bytes)
88 } else {
89 Point::from_le_bytes(bytes)
90 };
91 let point = endian_bytes.map_err(|err| {
92 serde::de::Error::custom(format!("Failed to deserialize curve point: {err:?}"))
93 })?;
94 Ok(point)
95 }
96}
97
98impl<C: Curve> Point<C> {
103 pub fn identity() -> Point<C> {
105 Point(C::Point::identity())
106 }
107
108 pub fn new(point: C::Point) -> Point<C> {
109 Point(point)
110 }
111
112 pub fn is_identity(&self) -> Choice {
114 self.ct_eq(&Point::identity())
115 }
116
117 pub fn inner(&self) -> C::Point {
119 self.0
120 }
121
122 pub fn generator() -> Point<C> {
124 Point(<C::Point as Group>::generator())
125 }
126
127 pub fn from_be_bytes(bytes: &[u8]) -> Result<Point<C>, PrimitiveError> {
129 let mut encoding = <C::Point as GroupEncoding>::Repr::default();
130 if bytes.len() != encoding.as_ref().len() {
131 return Err(PrimitiveError::DeserializationFailed(format!(
132 "Invalid point encoding length: expected {}, got {}",
133 encoding.as_ref().len(),
134 bytes.len()
135 )));
136 }
137
138 if C::POINT_BIG_ENDIAN {
139 encoding.as_mut().copy_from_slice(bytes);
140 } else {
141 encoding.as_mut().copy_from_slice(bytes);
142 encoding.as_mut().reverse();
143 }
144
145 let point = Option::from(C::Point::from_bytes(&encoding)).ok_or_else(|| {
146 PrimitiveError::DeserializationFailed("Invalid point encoding".to_string())
147 })?;
148 Ok(Point(point))
149 }
150
151 pub fn from_le_bytes(bytes: &[u8]) -> Result<Point<C>, PrimitiveError> {
154 let mut encoding = <C::Point as GroupEncoding>::Repr::default();
155 if bytes.len() != encoding.as_ref().len() {
156 return Err(PrimitiveError::DeserializationFailed(format!(
157 "Invalid point encoding length: expected {}, got {}",
158 encoding.as_ref().len(),
159 bytes.len()
160 )));
161 }
162
163 if C::POINT_BIG_ENDIAN {
164 encoding.as_mut().copy_from_slice(bytes);
165 encoding.as_mut().reverse();
166 } else {
167 encoding.as_mut().copy_from_slice(bytes);
168 }
169
170 let point = Option::from(C::Point::from_bytes(&encoding)).ok_or_else(|| {
171 PrimitiveError::DeserializationFailed("Invalid point encoding".to_string())
172 })?;
173 Ok(Point(point))
174 }
175
176 pub fn to_bytes(&self) -> Arc<[u8]> {
178 self.0.to_bytes().as_ref().into()
179 }
180
181 pub fn from_extended_edwards(coordinates: [BaseFieldElement<C>; 4]) -> Option<Point<C>> {
182 C::Point::from_extended_edwards(coordinates).map(Point)
183 }
184
185 pub fn to_extended_edwards(self) -> Result<[BaseFieldElement<C>; 4], PointAtInfinityError> {
186 self.0.to_extended_edwards()
187 }
188}
189
190impl<C: Curve> Random for Point<C> {
191 #[inline]
192 fn random(rng: impl CryptoRngCore) -> Self {
193 Point(C::Point::random(rng))
194 }
195}
196
197impl<C: Curve> Distribution<Point<C>> for Standard {
198 #[inline]
199 fn sample<R: RngCore + ?Sized>(&self, rng: &mut R) -> Point<C> {
200 Point(C::Point::random(rng))
201 }
202}
203
// Group addition `Point + &Point`.
// `op_variants` derives further operand forms (owned / borrowed / flipped) from this impl;
// NOTE(review): the exact expansion lives in the `macros` crate — confirm there.
#[macros::op_variants(owned, borrowed, flipped_commutative)]
impl<C: Curve> Add<&Point<C>> for Point<C> {
    type Output = Point<C>;

    #[inline]
    fn add(mut self, rhs: &Point<C>) -> Self::Output {
        // Accumulate into the by-value `self` and hand it back as the sum.
        self.0 += rhs.0;
        self
    }
}

// In-place group addition `Point += &Point`; `owned` presumably adds the `+= Point` form.
#[macros::op_variants(owned)]
impl<C: Curve> AddAssign<&Point<C>> for Point<C> {
    #[inline]
    fn add_assign(&mut self, rhs: &Point<C>) {
        self.0 += rhs.0;
    }
}
228
// Group subtraction `Point - &Point`.
// `op_variants` derives further operand forms (owned / borrowed / flipped) from this impl;
// NOTE(review): the exact expansion lives in the `macros` crate — confirm there.
#[macros::op_variants(owned, borrowed, flipped)]
impl<C: Curve> Sub<&Point<C>> for Point<C> {
    type Output = Point<C>;

    #[inline]
    fn sub(mut self, rhs: &Point<C>) -> Self::Output {
        // Subtract in place on the by-value `self` and return it.
        self.0 -= rhs.0;
        self
    }
}

// In-place group subtraction `Point -= &Point`; `owned` presumably adds the `-= Point` form.
#[macros::op_variants(owned)]
impl<C: Curve> SubAssign<&Point<C>> for Point<C> {
    #[inline]
    fn sub_assign(&mut self, rhs: &Point<C>) {
        self.0 -= rhs.0;
    }
}
249
// Group negation (additive inverse), delegated to the wrapped point's `Neg`.
// `borrowed` presumably adds `-&Point`; see the `macros` crate for the expansion.
#[macros::op_variants(borrowed)]
impl<C: Curve> Neg for Point<C> {
    type Output = Point<C>;

    #[inline]
    fn neg(self) -> Self::Output {
        Point(-self.0)
    }
}
261
// Scalar multiplication `Point * &ScalarAsExtension`, computed in place on the by-value `self`.
// `op_variants` derives further operand forms; NOTE(review): confirm the expansion in `macros`.
#[macros::op_variants(owned, borrowed, flipped)]
impl<C: Curve> Mul<&ScalarAsExtension<C>> for Point<C> {
    type Output = Point<C>;

    #[inline]
    fn mul(mut self, rhs: &ScalarAsExtension<C>) -> Self::Output {
        self.0 *= rhs.0;
        self
    }
}

// Scalar multiplication written scalar-first: `ScalarAsExtension * &Point`.
// The product is formed as `point * scalar` on the inner types.
#[macros::op_variants(owned, borrowed, flipped_commutative)]
impl<C: Curve> Mul<&Point<C>> for ScalarAsExtension<C> {
    type Output = Point<C>;

    #[inline]
    fn mul(self, rhs: &Point<C>) -> Self::Output {
        Point(rhs.0 * self.0)
    }
}

// Scalar multiplication `Point * &Scalar`.
#[macros::op_variants(owned, borrowed, flipped)]
impl<C: Curve> Mul<&Scalar<C>> for Point<C> {
    type Output = Point<C>;

    #[inline]
    fn mul(self, rhs: &Scalar<C>) -> Self::Output {
        Point(self.0 * rhs.0)
    }
}

// Scalar multiplication written scalar-first: `Scalar * &Point`.
// The product is formed as `point * scalar` on the inner types.
#[macros::op_variants(owned, borrowed, flipped_commutative)]
impl<C: Curve> Mul<&Point<C>> for Scalar<C> {
    type Output = Point<C>;

    #[inline]
    fn mul(self, rhs: &Point<C>) -> Self::Output {
        Point(rhs.0 * self.0)
    }
}
304
// In-place scalar multiplication `Point *= &ScalarAsExtension`.
// `owned` presumably adds the by-value `*= ScalarAsExtension` form.
#[macros::op_variants(owned)]
impl<C: Curve> MulAssign<&ScalarAsExtension<C>> for Point<C> {
    #[inline]
    fn mul_assign(&mut self, rhs: &ScalarAsExtension<C>) {
        self.0 *= rhs.0;
    }
}

// In-place scalar multiplication `Point *= &Scalar`.
// `owned` presumably adds the by-value `*= Scalar` form.
#[macros::op_variants(owned)]
impl<C: Curve> MulAssign<&Scalar<C>> for Point<C> {
    #[inline]
    fn mul_assign(&mut self, rhs: &Scalar<C>) {
        self.0 *= rhs.0;
    }
}
322
323impl<C: Curve> ConstantTimeEq for Point<C> {
326 #[inline]
327 fn ct_eq(&self, other: &Self) -> Choice {
328 self.0.ct_eq(&other.0)
329 }
330}
331
332impl<C: Curve> ConditionallySelectable for Point<C> {
333 #[inline]
334 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
335 let selected = C::Point::conditional_select(&a.0, &b.0, choice);
336 Point(selected)
337 }
338}
339
// Opts `Point<C>` into the `AdditiveShares` trait.
// The impl body is empty, so only the trait's provided items are used.
impl<C: Curve> AdditiveShares for Point<C> {}
343
344impl<C: Curve> Sum for Point<C> {
347 #[inline]
348 fn sum<I: Iterator<Item = Point<C>>>(iter: I) -> Self {
349 iter.fold(Point::identity(), |acc, x| acc + x)
350 }
351}
352
353impl<'a, C: Curve> Sum<&'a Point<C>> for Point<C> {
354 #[inline]
355 fn sum<I: Iterator<Item = &'a Point<C>>>(iter: I) -> Self {
356 iter.fold(Point::identity(), |acc, x| acc + x)
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::elliptic_curve::Curve25519Ristretto;

    /// Native bytes round-trip through `from_le_bytes`; truncated input is rejected.
    #[test]
    fn test_point_serialization() {
        let point = Point::<Curve25519Ristretto>::generator();
        let bytes = point.to_bytes();
        let deserialized_point = Point::<Curve25519Ristretto>::from_le_bytes(&bytes).unwrap();
        assert_eq!(point, deserialized_point);

        // Dropping one byte must trip the length check.
        let truncated = bytes.as_ref()[1..].to_vec();
        let result = Point::<Curve25519Ristretto>::from_le_bytes(&truncated);
        assert!(result.is_err());
    }

    /// `from_be_bytes` on reversed input must agree with `from_le_bytes` on the original.
    #[test]
    fn test_be_and_le_decoding_agree() {
        let point = Point::<Curve25519Ristretto>::generator();
        let bytes = point.to_bytes();
        let mut reversed = bytes.to_vec();
        reversed.reverse();

        let from_le = Point::<Curve25519Ristretto>::from_le_bytes(&bytes).unwrap();
        let from_be = Point::<Curve25519Ristretto>::from_be_bytes(&reversed).unwrap();
        assert_eq!(from_le, from_be);
    }

    /// Serde deserialization must not require the deserializer to lend borrowed bytes.
    #[test]
    fn test_serde_deserialize_from_non_borrowing_formats() {
        use serde::de::value;

        let point = Point::<Curve25519Ristretto>::generator();
        let bytes = point.to_bytes();

        // Transient (`visit_bytes`) payload, as produced by reader-based decoders.
        let de: value::BytesDeserializer<'_, value::Error> = value::BytesDeserializer::new(&bytes);
        let decoded =
            <Point<Curve25519Ristretto> as serde::Deserialize>::deserialize(de).unwrap();
        assert_eq!(point, decoded);

        // Sequence of integers, as produced by formats such as JSON.
        let de = value::SeqDeserializer::<_, value::Error>::new(bytes.iter().copied());
        let decoded =
            <Point<Curve25519Ristretto> as serde::Deserialize>::deserialize(de).unwrap();
        assert_eq!(point, decoded);
    }

    /// Valid points survive a wincode round trip unchanged.
    #[test]
    fn test_wincode_roundtrip() {
        for point in [
            Point::<Curve25519Ristretto>::identity(),
            Point::<Curve25519Ristretto>::generator(),
        ] {
            let buf = wincode::serialize(&point).unwrap();
            let decoded = wincode::deserialize::<Point<Curve25519Ristretto>>(&buf).unwrap();
            assert_eq!(point, decoded);
        }
    }

    /// wincode must reject bytes that are not a valid point encoding.
    #[test]
    fn test_wincode_rejects_invalid_point() {
        let valid = Point::<Curve25519Ristretto>::generator();
        let mut buf = wincode::serialize(&valid).unwrap();

        // Overwrite the 32-byte Ristretto encoding with a non-canonical value.
        let len = buf.len();
        buf[len - 32..].fill(0xFF);

        let result = wincode::deserialize::<Point<Curve25519Ristretto>>(&buf);
        assert!(
            result.is_err(),
            "wincode deserialized an invalid curve point \
             without returning an error"
        );
    }
}