use std::{
hash::Hash,
iter::Sum,
mem::MaybeUninit,
ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
sync::Arc,
};
use elliptic_curve::group::{Group, GroupEncoding};
use rand::{
distributions::{Distribution, Standard},
RngCore,
};
use serde::{Deserialize, Serialize};
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
use wincode::{ReadResult, WriteResult};
use crate::{
algebra::elliptic_curve::{
curve::{FromExtendedEdwards, PointAtInfinityError, ToExtendedEdwards},
BaseFieldElement,
Curve,
Scalar,
ScalarAsExtension,
},
errors::PrimitiveError,
random::{CryptoRngCore, Random},
sharing::unauthenticated::AdditiveShares,
};
/// A point on the elliptic curve selected by `C`.
///
/// This is a thin newtype over `C::Point`; `#[repr(transparent)]` guarantees it has
/// exactly the same memory layout as the wrapped point.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Hash)]
pub struct Point<C: Curve>(pub(crate) C::Point);
impl<C: Curve> wincode::SchemaWrite for Point<C> {
    type Src = Self;

    /// Every point occupies exactly one fixed-width `GroupEncoding::Repr`,
    /// independent of the value being written.
    fn size_of(_src: &Self::Src) -> WriteResult<usize> {
        let width = <C::Point as GroupEncoding>::Repr::default().as_ref().len();
        Ok(width)
    }

    /// Emits the point's native `GroupEncoding` bytes with no length prefix.
    fn write(writer: &mut impl wincode::io::Writer, src: &Self::Src) -> WriteResult<()> {
        let encoded = src.0.to_bytes();
        writer.write(encoded.as_ref())?;
        Ok(())
    }
}
impl<'de, C: Curve> wincode::SchemaRead<'de> for Point<C> {
    type Dst = Self;

    /// Reads one fixed-width `GroupEncoding::Repr` from `reader` and rejects byte
    /// strings that `C::Point::from_bytes` does not accept as a valid point.
    fn read(
        reader: &mut impl wincode::io::Reader<'de>,
        dst: &mut MaybeUninit<Self::Dst>,
    ) -> ReadResult<()> {
        let mut encoding = <C::Point as GroupEncoding>::Repr::default();
        let width = encoding.as_ref().len();
        // Copy the bytes out of the reader's window before advancing past them.
        encoding.as_mut().copy_from_slice(reader.fill_exact(width)?);
        reader.consume(width)?;

        let decoded: Option<C::Point> = C::Point::from_bytes(&encoding).into();
        let point = decoded.ok_or(wincode::ReadError::Custom("invalid curve point encoding"))?;
        dst.write(Point(point));
        Ok(())
    }
}
// Unconditionally marks `Point<C>` as `Unpin` for every `C: Curve`, rather than
// leaving it to the auto-trait derivation from the wrapped `C::Point`. `Unpin` is a
// safe trait, so this impl carries no soundness obligation.
impl<C: Curve> Unpin for Point<C> {}
impl<C: Curve> Serialize for Point<C> {
    /// Serializes the point's native `GroupEncoding` bytes as a byte string via
    /// `serde_bytes`.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        serde_bytes::serialize(self.0.to_bytes().as_ref(), serializer)
    }
}
impl<'de, C: Curve> Deserialize<'de> for Point<C> {
    /// Decodes a point from the byte string produced by this type's `Serialize` impl.
    ///
    /// The bytes are read into an owned buffer. This makes deserialization work for
    /// deserializers that can lend a slice of their input, for those that only
    /// produce owned buffers, and for formats that represent bytes as a sequence of
    /// `u8` (e.g. human-readable formats).
    ///
    /// # Errors
    ///
    /// Returns a custom deserialization error if the bytes have the wrong length or
    /// do not encode a valid curve point.
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        // Deserializing into `&'de [u8]` fails whenever the input cannot be borrowed.
        // `Vec<u8>` accepts borrowed bytes, owned buffers and `u8` sequences alike.
        let bytes: Vec<u8> = serde_bytes::deserialize(deserializer)?;
        // `Serialize` writes the native encoding. Pick the constructor whose byte
        // order matches the native one, so the bytes are used without reversal.
        let decoded = if C::POINT_BIG_ENDIAN {
            Point::from_be_bytes(&bytes)
        } else {
            Point::from_le_bytes(&bytes)
        };
        decoded.map_err(|err| {
            serde::de::Error::custom(format!("Failed to deserialize curve point: {err:?}"))
        })
    }
}
impl<C: Curve> Point<C> {
    /// Returns the identity element of the group (the neutral element of `+`).
    pub fn identity() -> Point<C> {
        Point(C::Point::identity())
    }

    /// Wraps a raw `C::Point` without any conversion.
    pub fn new(point: C::Point) -> Point<C> {
        Point(point)
    }

    /// Returns a truthy [`Choice`] iff `self` equals the identity.
    ///
    /// The comparison uses [`ConstantTimeEq`] rather than a branching `==`.
    pub fn is_identity(&self) -> Choice {
        self.ct_eq(&Point::identity())
    }

    /// Returns a copy of the wrapped `C::Point`.
    pub fn inner(&self) -> C::Point {
        self.0
    }

    /// Returns the generator reported by `C::Point`'s [`Group`] implementation.
    pub fn generator() -> Point<C> {
        Point(<C::Point as Group>::generator())
    }

    /// Decodes a point from a big-endian byte string.
    ///
    /// # Errors
    ///
    /// Returns [`PrimitiveError::DeserializationFailed`] in two cases: `bytes` does
    /// not have the curve's encoding length, or it does not decode to a valid point.
    pub fn from_be_bytes(bytes: &[u8]) -> Result<Point<C>, PrimitiveError> {
        Self::decode_with_byte_order(bytes, true)
    }

    /// Decodes a point from a little-endian byte string.
    ///
    /// # Errors
    ///
    /// Returns [`PrimitiveError::DeserializationFailed`] in two cases: `bytes` does
    /// not have the curve's encoding length, or it does not decode to a valid point.
    pub fn from_le_bytes(bytes: &[u8]) -> Result<Point<C>, PrimitiveError> {
        Self::decode_with_byte_order(bytes, false)
    }

    /// Shared decoder behind `from_be_bytes` / `from_le_bytes`.
    ///
    /// It checks the length and converts `bytes` from the caller's byte order into
    /// the curve's native `GroupEncoding` order. It then validates the result with
    /// `C::Point::from_bytes`.
    fn decode_with_byte_order(
        bytes: &[u8],
        input_is_big_endian: bool,
    ) -> Result<Point<C>, PrimitiveError> {
        let mut encoding = <C::Point as GroupEncoding>::Repr::default();
        let expected_len = encoding.as_ref().len();
        if bytes.len() != expected_len {
            return Err(PrimitiveError::DeserializationFailed(format!(
                "Invalid point encoding length: expected {}, got {}",
                expected_len,
                bytes.len()
            )));
        }
        encoding.as_mut().copy_from_slice(bytes);
        // The native encoding is big-endian iff `C::POINT_BIG_ENDIAN`. Reverse only
        // when the caller's byte order disagrees with it.
        if input_is_big_endian != C::POINT_BIG_ENDIAN {
            encoding.as_mut().reverse();
        }
        let point = Option::from(C::Point::from_bytes(&encoding)).ok_or_else(|| {
            PrimitiveError::DeserializationFailed("Invalid point encoding".to_string())
        })?;
        Ok(Point(point))
    }

    /// Returns the point's `GroupEncoding` bytes in the curve's native order.
    ///
    /// That order is big-endian when `C::POINT_BIG_ENDIAN`, little-endian otherwise.
    /// The matching `from_be_bytes` / `from_le_bytes` constructor accepts these bytes
    /// unchanged.
    pub fn to_bytes(&self) -> Arc<[u8]> {
        self.0.to_bytes().as_ref().into()
    }

    /// Builds a point from four extended-Edwards coordinates.
    ///
    /// Returns `None` when the underlying `FromExtendedEdwards` conversion rejects
    /// them.
    pub fn from_extended_edwards(coordinates: [BaseFieldElement<C>; 4]) -> Option<Point<C>> {
        C::Point::from_extended_edwards(coordinates).map(Point)
    }

    /// Converts the point to four extended-Edwards coordinates.
    ///
    /// # Errors
    ///
    /// Returns [`PointAtInfinityError`] when the underlying `ToExtendedEdwards`
    /// conversion fails. This is presumably for the point at infinity.
    pub fn to_extended_edwards(self) -> Result<[BaseFieldElement<C>; 4], PointAtInfinityError> {
        self.0.to_extended_edwards()
    }
}
impl<C: Curve> Random for Point<C> {
    /// Draws a random point by delegating to `C::Point::random` with the given RNG.
    #[inline]
    fn random(rng: impl CryptoRngCore) -> Self {
        let sampled = C::Point::random(rng);
        Self(sampled)
    }
}
impl<C: Curve> Distribution<Point<C>> for Standard {
    /// Lets `rng.gen::<Point<C>>()` work by drawing from `C::Point::random`.
    #[inline]
    fn sample<R: RngCore + ?Sized>(&self, rng: &mut R) -> Point<C> {
        let drawn = C::Point::random(rng);
        Point(drawn)
    }
}
// Point addition `Point + &Point`. `op_variants` presumably derives the owned,
// fully-borrowed and argument-flipped forms from this impl (flipped relying on
// commutativity) — TODO confirm against the `macros` crate.
#[macros::op_variants(owned, borrowed, flipped_commutative)]
impl<C: Curve> Add<&Point<C>> for Point<C> {
type Output = Point<C>;
#[inline]
fn add(mut self, rhs: &Point<C>) -> Self::Output {
// Add in place on the wrapped curve point and hand `self` back.
self.0 += rhs.0;
self
}
}
// In-place addition `point += &other`; `owned` presumably also provides `+= Point`.
#[macros::op_variants(owned)]
impl<C: Curve> AddAssign<&Point<C>> for Point<C> {
#[inline]
fn add_assign(&mut self, rhs: &Point<C>) {
self.0 += rhs.0;
}
}
// Point subtraction `Point - &Point` (computes `self - rhs`). Subtraction is not
// commutative, so the variant list uses plain `flipped` rather than
// `flipped_commutative`; exact expansion is defined by the `macros` crate.
#[macros::op_variants(owned, borrowed, flipped)]
impl<C: Curve> Sub<&Point<C>> for Point<C> {
type Output = Point<C>;
#[inline]
fn sub(mut self, rhs: &Point<C>) -> Self::Output {
self.0 -= rhs.0;
self
}
}
// In-place subtraction `point -= &other`; `owned` presumably also provides `-= Point`.
#[macros::op_variants(owned)]
impl<C: Curve> SubAssign<&Point<C>> for Point<C> {
#[inline]
fn sub_assign(&mut self, rhs: &Point<C>) {
self.0 -= rhs.0;
}
}
// Point negation (additive inverse of the wrapped curve point). The `borrowed`
// variant presumably adds `-&point` — TODO confirm against the `macros` crate.
#[macros::op_variants(borrowed)]
impl<C: Curve> Neg for Point<C> {
type Output = Point<C>;
#[inline]
fn neg(self) -> Self::Output {
Point(-self.0)
}
}
// Scalar multiplication `Point * &ScalarAsExtension`, computed in place on the
// wrapped curve point using the extension wrapper's inner scalar (`rhs.0`).
#[macros::op_variants(owned, borrowed, flipped)]
impl<C: Curve> Mul<&ScalarAsExtension<C>> for Point<C> {
type Output = Point<C>;
#[inline]
fn mul(mut self, rhs: &ScalarAsExtension<C>) -> Self::Output {
self.0 *= rhs.0;
self
}
}
// Scalar-on-the-left form `ScalarAsExtension * &Point`; delegates to
// `point * scalar` on the inner values.
#[macros::op_variants(owned, borrowed, flipped_commutative)]
impl<C: Curve> Mul<&Point<C>> for ScalarAsExtension<C> {
type Output = Point<C>;
#[inline]
fn mul(self, rhs: &Point<C>) -> Self::Output {
Point(rhs.0 * self.0)
}
}
// Scalar multiplication `Point * &Scalar` on the inner values.
#[macros::op_variants(owned, borrowed, flipped)]
impl<C: Curve> Mul<&Scalar<C>> for Point<C> {
type Output = Point<C>;
#[inline]
fn mul(self, rhs: &Scalar<C>) -> Self::Output {
Point(self.0 * rhs.0)
}
}
// Scalar-on-the-left form `Scalar * &Point`; always evaluated as `point * scalar`.
#[macros::op_variants(owned, borrowed, flipped_commutative)]
impl<C: Curve> Mul<&Point<C>> for Scalar<C> {
type Output = Point<C>;
#[inline]
fn mul(self, rhs: &Point<C>) -> Self::Output {
Point(rhs.0 * self.0)
}
}
// In-place scalar multiplication `point *= &ScalarAsExtension`.
#[macros::op_variants(owned)]
impl<C: Curve> MulAssign<&ScalarAsExtension<C>> for Point<C> {
#[inline]
fn mul_assign(&mut self, rhs: &ScalarAsExtension<C>) {
self.0 *= rhs.0;
}
}
// In-place scalar multiplication `point *= &Scalar`.
#[macros::op_variants(owned)]
impl<C: Curve> MulAssign<&Scalar<C>> for Point<C> {
#[inline]
fn mul_assign(&mut self, rhs: &Scalar<C>) {
self.0 *= rhs.0;
}
}
impl<C: Curve> ConstantTimeEq for Point<C> {
    /// Constant-time equality, delegated to the wrapped curve point's own
    /// `ConstantTimeEq` implementation.
    #[inline]
    fn ct_eq(&self, other: &Self) -> Choice {
        <C::Point as ConstantTimeEq>::ct_eq(&self.0, &other.0)
    }
}
impl<C: Curve> ConditionallySelectable for Point<C> {
    /// Selects `a` or `b` according to `choice` by delegating to the wrapped curve
    /// point's `ConditionallySelectable` implementation.
    #[inline]
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        Point(<C::Point as ConditionallySelectable>::conditional_select(&a.0, &b.0, choice))
    }
}
impl<C: Curve> AdditiveShares for Point<C> {}

impl<C: Curve> Sum for Point<C> {
    /// Adds up every point yielded by `iter`, starting from the identity, so an
    /// empty iterator sums to the identity.
    #[inline]
    fn sum<I: Iterator<Item = Point<C>>>(iter: I) -> Self {
        let mut total = Self::identity();
        for term in iter {
            total += &term;
        }
        total
    }
}

impl<'a, C: Curve> Sum<&'a Point<C>> for Point<C> {
    /// Borrowing counterpart of the owned `Sum`: accumulates from the identity.
    #[inline]
    fn sum<I: Iterator<Item = &'a Point<C>>>(iter: I) -> Self {
        let mut total = Self::identity();
        for term in iter {
            total += term;
        }
        total
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::elliptic_curve::Curve25519Ristretto;

    type P = Point<Curve25519Ristretto>;

    #[test]
    fn test_point_serialization() {
        let point = P::generator();
        let bytes = point.to_bytes();
        let deserialized_point = P::from_le_bytes(&bytes).unwrap();
        assert_eq!(point, deserialized_point);

        // A truncated encoding must be rejected by the length check.
        let truncated = bytes.as_ref()[1..].to_vec();
        let result = P::from_le_bytes(&truncated);
        assert!(result.is_err());
    }

    #[test]
    fn test_be_bytes_is_byte_reversal_of_le_bytes() {
        // Ristretto's native encoding is accepted by `from_le_bytes` (see above), so
        // the same bytes reversed must be accepted by `from_be_bytes`.
        let point = P::generator();
        let mut reversed = point.to_bytes().to_vec();
        reversed.reverse();
        assert_eq!(P::from_be_bytes(&reversed).unwrap(), point);
    }

    #[test]
    fn test_wincode_roundtrip() {
        let point = P::generator() + &P::generator();
        let buf = wincode::serialize(&point).unwrap();
        assert_eq!(wincode::deserialize::<P>(&buf).ok(), Some(point));
    }

    #[test]
    fn test_wincode_rejects_invalid_point() {
        let valid = P::generator();
        let mut buf = wincode::serialize(&valid).unwrap();
        let len = buf.len();
        buf[len - 32..].fill(0xFF);
        let result = wincode::deserialize::<P>(&buf);
        assert!(
            result.is_err(),
            "wincode deserialized an invalid curve point \
             without returning an error"
        );
    }
}