primitives/algebra/elliptic_curve/
point.rs1use std::{
2 hash::Hash,
3 iter::Sum,
4 mem::MaybeUninit,
5 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
6 sync::Arc,
7};
8
9use elliptic_curve::group::{Group, GroupEncoding};
10use rand::{
11 distributions::{Distribution, Standard},
12 RngCore,
13};
14use serde::{Deserialize, Serialize};
15use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
16use wincode::{ReadResult, WriteResult};
17
18use crate::{
19 algebra::elliptic_curve::{
20 curve::{FromExtendedEdwards, PointAtInfinityError, ToExtendedEdwards},
21 BaseFieldElement,
22 Curve,
23 Scalar,
24 ScalarAsExtension,
25 },
26 errors::PrimitiveError,
27 random::{CryptoRngCore, Random},
28 sharing::unauthenticated::AdditiveShares,
29};
30
31#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
33#[repr(transparent)]
34pub struct Point<C: Curve>(pub(crate) C::Point);
35
36impl<C: Curve> wincode::SchemaWrite for Point<C> {
37 type Src = Self;
38
39 fn size_of(_src: &Self::Src) -> WriteResult<usize> {
40 let repr = <C::Point as GroupEncoding>::Repr::default();
41 Ok(repr.as_ref().len())
42 }
43
44 fn write(writer: &mut impl wincode::io::Writer, src: &Self::Src) -> WriteResult<()> {
45 let bytes = src.0.to_bytes();
46 Ok(writer.write(bytes.as_ref())?)
47 }
48}
49
50impl<'de, C: Curve> wincode::SchemaRead<'de> for Point<C> {
51 type Dst = Self;
52
53 fn read(
54 reader: &mut impl wincode::io::Reader<'de>,
55 dst: &mut MaybeUninit<Self::Dst>,
56 ) -> ReadResult<()> {
57 let mut repr = <C::Point as GroupEncoding>::Repr::default();
58 let len = repr.as_ref().len();
59 let bytes = reader.fill_exact(len)?;
60 repr.as_mut().copy_from_slice(bytes);
61 reader.consume(len)?;
62
63 let point = Option::from(C::Point::from_bytes(&repr))
64 .ok_or(wincode::ReadError::Custom("invalid curve point encoding"))?;
65
66 dst.write(Point(point));
67 Ok(())
68 }
69}
70
71impl<C: Curve> Unpin for Point<C> {}
72
73impl<C: Curve> Serialize for Point<C> {
74 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
75 let bytes = self.0.to_bytes();
76 serde_bytes::serialize(bytes.as_ref(), serializer)
77 }
78}
79
80impl<'de, C: Curve> Deserialize<'de> for Point<C> {
81 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
82 let bytes: &[u8] = serde_bytes::deserialize(deserializer)?;
83 let endian_bytes = if C::POINT_BIG_ENDIAN {
84 Point::from_be_bytes(bytes)
85 } else {
86 Point::from_le_bytes(bytes)
87 };
88 let point = endian_bytes.map_err(|err| {
89 serde::de::Error::custom(format!("Failed to deserialize curve point: {err:?}"))
90 })?;
91 Ok(point)
92 }
93}
94
95impl<C: Curve> Point<C> {
100 pub fn identity() -> Point<C> {
102 Point(C::Point::identity())
103 }
104
105 pub fn new(point: C::Point) -> Point<C> {
106 Point(point)
107 }
108
109 pub fn is_identity(&self) -> Choice {
111 self.ct_eq(&Point::identity())
112 }
113
114 pub fn inner(&self) -> C::Point {
116 self.0
117 }
118
119 pub fn generator() -> Point<C> {
121 Point(<C::Point as Group>::generator())
122 }
123
124 pub fn from_be_bytes(bytes: &[u8]) -> Result<Point<C>, PrimitiveError> {
126 let mut encoding = <C::Point as GroupEncoding>::Repr::default();
127 if bytes.len() != encoding.as_ref().len() {
128 return Err(PrimitiveError::DeserializationFailed(format!(
129 "Invalid point encoding length: expected {}, got {}",
130 encoding.as_ref().len(),
131 bytes.len()
132 )));
133 }
134
135 if C::POINT_BIG_ENDIAN {
136 encoding.as_mut().copy_from_slice(bytes);
137 } else {
138 encoding.as_mut().copy_from_slice(bytes);
139 encoding.as_mut().reverse();
140 }
141
142 let point = Option::from(C::Point::from_bytes(&encoding)).ok_or_else(|| {
143 PrimitiveError::DeserializationFailed("Invalid point encoding".to_string())
144 })?;
145 Ok(Point(point))
146 }
147
148 pub fn from_le_bytes(bytes: &[u8]) -> Result<Point<C>, PrimitiveError> {
151 let mut encoding = <C::Point as GroupEncoding>::Repr::default();
152 if bytes.len() != encoding.as_ref().len() {
153 return Err(PrimitiveError::DeserializationFailed(format!(
154 "Invalid point encoding length: expected {}, got {}",
155 encoding.as_ref().len(),
156 bytes.len()
157 )));
158 }
159
160 if C::POINT_BIG_ENDIAN {
161 encoding.as_mut().copy_from_slice(bytes);
162 encoding.as_mut().reverse();
163 } else {
164 encoding.as_mut().copy_from_slice(bytes);
165 }
166
167 let point = Option::from(C::Point::from_bytes(&encoding)).ok_or_else(|| {
168 PrimitiveError::DeserializationFailed("Invalid point encoding".to_string())
169 })?;
170 Ok(Point(point))
171 }
172
173 pub fn to_bytes(&self) -> Arc<[u8]> {
175 self.0.to_bytes().as_ref().into()
176 }
177
178 pub fn from_extended_edwards(coordinates: [BaseFieldElement<C>; 4]) -> Option<Point<C>> {
179 C::Point::from_extended_edwards(coordinates).map(Point)
180 }
181
182 pub fn to_extended_edwards(self) -> Result<[BaseFieldElement<C>; 4], PointAtInfinityError> {
183 self.0.to_extended_edwards()
184 }
185}
186
187impl<C: Curve> Random for Point<C> {
188 #[inline]
189 fn random(rng: impl CryptoRngCore) -> Self {
190 Point(C::Point::random(rng))
191 }
192}
193
194impl<C: Curve> Distribution<Point<C>> for Standard {
195 #[inline]
196 fn sample<R: RngCore + ?Sized>(&self, rng: &mut R) -> Point<C> {
197 Point(C::Point::random(rng))
198 }
199}
200
201#[macros::op_variants(owned, borrowed, flipped_commutative)]
208impl<C: Curve> Add<&Point<C>> for Point<C> {
209 type Output = Point<C>;
210
211 #[inline]
212 fn add(mut self, rhs: &Point<C>) -> Self::Output {
213 self.0 += rhs.0;
214 self
215 }
216}
217
218#[macros::op_variants(owned)]
219impl<C: Curve> AddAssign<&Point<C>> for Point<C> {
220 #[inline]
221 fn add_assign(&mut self, rhs: &Point<C>) {
222 self.0 += rhs.0;
223 }
224}
225
226#[macros::op_variants(owned, borrowed, flipped)]
229impl<C: Curve> Sub<&Point<C>> for Point<C> {
230 type Output = Point<C>;
231
232 #[inline]
233 fn sub(mut self, rhs: &Point<C>) -> Self::Output {
234 self.0 -= rhs.0;
235 self
236 }
237}
238
239#[macros::op_variants(owned)]
240impl<C: Curve> SubAssign<&Point<C>> for Point<C> {
241 #[inline]
242 fn sub_assign(&mut self, rhs: &Point<C>) {
243 self.0 -= rhs.0;
244 }
245}
246
247#[macros::op_variants(borrowed)]
250impl<C: Curve> Neg for Point<C> {
251 type Output = Point<C>;
252
253 #[inline]
254 fn neg(self) -> Self::Output {
255 Point(-self.0)
256 }
257}
258
259#[macros::op_variants(owned, borrowed, flipped)]
262impl<C: Curve> Mul<&ScalarAsExtension<C>> for Point<C> {
263 type Output = Point<C>;
264
265 #[inline]
266 fn mul(mut self, rhs: &ScalarAsExtension<C>) -> Self::Output {
267 self.0 *= rhs.0;
268 self
269 }
270}
271
272#[macros::op_variants(owned, borrowed, flipped_commutative)]
273impl<C: Curve> Mul<&Point<C>> for ScalarAsExtension<C> {
274 type Output = Point<C>;
275
276 #[inline]
277 fn mul(self, rhs: &Point<C>) -> Self::Output {
278 Point(rhs.0 * self.0)
279 }
280}
281
282#[macros::op_variants(owned, borrowed, flipped)]
283impl<C: Curve> Mul<&Scalar<C>> for Point<C> {
284 type Output = Point<C>;
285
286 #[inline]
287 fn mul(self, rhs: &Scalar<C>) -> Self::Output {
288 Point(self.0 * rhs.0)
289 }
290}
291
292#[macros::op_variants(owned, borrowed, flipped_commutative)]
293impl<C: Curve> Mul<&Point<C>> for Scalar<C> {
294 type Output = Point<C>;
295
296 #[inline]
297 fn mul(self, rhs: &Point<C>) -> Self::Output {
298 Point(rhs.0 * self.0)
299 }
300}
301
302#[macros::op_variants(owned)]
305impl<C: Curve> MulAssign<&ScalarAsExtension<C>> for Point<C> {
306 #[inline]
307 fn mul_assign(&mut self, rhs: &ScalarAsExtension<C>) {
308 self.0 *= rhs.0;
309 }
310}
311
312#[macros::op_variants(owned)]
313impl<C: Curve> MulAssign<&Scalar<C>> for Point<C> {
314 #[inline]
315 fn mul_assign(&mut self, rhs: &Scalar<C>) {
316 self.0 *= rhs.0;
317 }
318}
319
320impl<C: Curve> ConstantTimeEq for Point<C> {
323 #[inline]
324 fn ct_eq(&self, other: &Self) -> Choice {
325 self.0.ct_eq(&other.0)
326 }
327}
328
329impl<C: Curve> ConditionallySelectable for Point<C> {
330 #[inline]
331 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
332 let selected = C::Point::conditional_select(&a.0, &b.0, choice);
333 Point(selected)
334 }
335}
336
337impl<C: Curve> AdditiveShares for Point<C> {}
340
341impl<C: Curve> Sum for Point<C> {
344 #[inline]
345 fn sum<I: Iterator<Item = Point<C>>>(iter: I) -> Self {
346 iter.fold(Point::identity(), |acc, x| acc + x)
347 }
348}
349
350impl<'a, C: Curve> Sum<&'a Point<C>> for Point<C> {
351 #[inline]
352 fn sum<I: Iterator<Item = &'a Point<C>>>(iter: I) -> Self {
353 iter.fold(Point::identity(), |acc, x| acc + x)
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::elliptic_curve::Curve25519Ristretto;

    /// Point type exercised by the tests in this module.
    type TestPoint = Point<Curve25519Ristretto>;

    /// The generator survives a `to_bytes` / `from_le_bytes` round trip, and an
    /// input that is one byte short is rejected.
    #[test]
    fn test_point_serialization() {
        let generator = TestPoint::generator();
        let encoded = generator.to_bytes();
        let decoded = TestPoint::from_le_bytes(&encoded).unwrap();
        assert_eq!(generator, decoded);

        // Drop the first byte so the length no longer matches the encoding size.
        let truncated = encoded.as_ref()[1..].to_vec();
        let outcome = TestPoint::from_le_bytes(&truncated);
        assert!(outcome.is_err());
    }

    /// Overwriting the trailing 32 bytes of a serialized generator with `0xFF`
    /// must make wincode deserialization fail.
    #[test]
    fn test_wincode_rejects_invalid_point() {
        let generator = TestPoint::generator();
        let mut encoded = wincode::serialize(&generator).unwrap();

        let tail_start = encoded.len() - 32;
        encoded[tail_start..].fill(0xFF);

        let outcome = wincode::deserialize::<TestPoint>(&encoded);
        assert!(
            outcome.is_err(),
            "wincode deserialized an invalid curve point \
             without returning an error"
        );
    }
}
395}