// arcium-primitives 0.4.2
//
// Arcium primitives
// Documentation
use std::{
    hash::Hash,
    iter::Sum,
    mem::MaybeUninit,
    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
    sync::Arc,
};

use elliptic_curve::group::{Group, GroupEncoding};
use rand::{
    distributions::{Distribution, Standard},
    RngCore,
};
use serde::{Deserialize, Serialize};
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
use wincode::{ReadResult, WriteResult};

use crate::{
    algebra::elliptic_curve::{
        curve::{FromExtendedEdwards, PointAtInfinityError, ToExtendedEdwards},
        BaseFieldElement,
        Curve,
        Scalar,
        ScalarAsExtension,
    },
    errors::PrimitiveError,
    random::{CryptoRngCore, Random},
    sharing::unauthenticated::AdditiveShares,
};

/// A point on a given curve.
///
/// Thin newtype around the curve's group element type `C::Point`. The wrapped
/// value is crate-visible; external code can obtain a copy through
/// [`Point::inner`] and build a wrapper through [`Point::new`].
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct Point<C: Curve>(pub(crate) C::Point);

/// Wincode encoding: a point is written as the raw bytes returned by the
/// wrapped element's `to_bytes`, with no length prefix added here.
impl<C: Curve> wincode::SchemaWrite for Point<C> {
    type Src = Self;

    /// Reports the encoded size, taken from the length of a default-constructed
    /// `GroupEncoding::Repr` buffer (the value of `_src` is not inspected).
    fn size_of(_src: &Self::Src) -> WriteResult<usize> {
        let encoded_len = <C::Point as GroupEncoding>::Repr::default()
            .as_ref()
            .len();
        Ok(encoded_len)
    }

    /// Writes the point's encoding to `writer`, propagating any writer error.
    fn write(writer: &mut impl wincode::io::Writer, src: &Self::Src) -> WriteResult<()> {
        let encoded = src.0.to_bytes();
        writer.write(encoded.as_ref())?;
        Ok(())
    }
}

/// Wincode decoding: reads exactly one point encoding from the reader and
/// validates it through `C::Point::from_bytes` before initialising `dst`.
impl<'de, C: Curve> wincode::SchemaRead<'de> for Point<C> {
    type Dst = Self;

    /// Reads one encoded point.
    ///
    /// # Errors
    /// Propagates reader errors, and returns `ReadError::Custom` when the bytes
    /// are rejected by `C::Point::from_bytes`. `dst` is only written on success.
    fn read(
        reader: &mut impl wincode::io::Reader<'de>,
        dst: &mut MaybeUninit<Self::Dst>,
    ) -> ReadResult<()> {
        let mut encoding = <C::Point as GroupEncoding>::Repr::default();
        let encoded_len = encoding.as_ref().len();

        // Copy the bytes out of the reader's buffer first, then mark them consumed.
        let pending = reader.fill_exact(encoded_len)?;
        encoding.as_mut().copy_from_slice(pending);
        reader.consume(encoded_len)?;

        let decoded: Option<C::Point> = Option::from(C::Point::from_bytes(&encoding));
        let Some(inner) = decoded else {
            return Err(wincode::ReadError::Custom("invalid curve point encoding"));
        };

        dst.write(Point(inner));
        Ok(())
    }
}

// Explicit marker impl: `Point<C>` is declared `Unpin` for every curve, without
// requiring `C::Point: Unpin` as the auto-trait rule otherwise would.
impl<C: Curve> Unpin for Point<C> {}

/// Serde encoding: the point is emitted as a byte string holding the output of
/// the wrapped element's `to_bytes`.
impl<C: Curve> Serialize for Point<C> {
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let encoded = self.0.to_bytes();
        let raw: &[u8] = encoded.as_ref();
        serde_bytes::serialize(raw, serializer)
    }
}

/// Serde decoding: reads a byte string and decodes it with the byte order the
/// curve declares through `C::POINT_BIG_ENDIAN`.
impl<'de, C: Curve> Deserialize<'de> for Point<C> {
    /// Deserializes a point from its byte encoding.
    ///
    /// # Errors
    /// Returns a custom serde error when the byte string cannot be read or when
    /// `from_be_bytes` / `from_le_bytes` rejects it (wrong length or invalid
    /// encoding).
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        // Decode into an owned buffer. A borrowed `&'de [u8]` can only be produced by
        // deserializers able to lend out their input; formats that yield owned bytes or
        // a sequence of integers would fail even on data produced by `Serialize`.
        let bytes: Vec<u8> = serde_bytes::deserialize(deserializer)?;

        let decoded = if C::POINT_BIG_ENDIAN {
            Point::from_be_bytes(&bytes)
        } else {
            Point::from_le_bytes(&bytes)
        };

        decoded.map_err(|err| {
            serde::de::Error::custom(format!("Failed to deserialize curve point: {err:?}"))
        })
    }
}

// ------------------------
// | Misc Implementations |
// ------------------------

impl<C: Curve> Point<C> {
    /// The additive identity in the curve group
    pub fn identity() -> Point<C> {
        Point(C::Point::identity())
    }

    /// Wrap a raw group element of the underlying curve type.
    pub fn new(point: C::Point) -> Point<C> {
        Point(point)
    }

    /// Check whether the given point is the identity point in the group
    ///
    /// The comparison goes through [`ConstantTimeEq::ct_eq`] and is returned as
    /// a [`Choice`] rather than a `bool`.
    pub fn is_identity(&self) -> Choice {
        self.ct_eq(&Point::identity())
    }

    /// Return the wrapped type
    pub fn inner(&self) -> C::Point {
        self.0
    }

    /// The group generator
    pub fn generator() -> Point<C> {
        Point(<C::Point as Group>::generator())
    }

    /// Deserialize a point from a big-endian byte buffer.
    ///
    /// When the curve's own encoding is big-endian (`C::POINT_BIG_ENDIAN`) the
    /// bytes are decoded as given; otherwise they are reversed first.
    ///
    /// # Errors
    /// Returns [`PrimitiveError::DeserializationFailed`] if `bytes` does not have
    /// the exact encoding length or does not decode to a valid point.
    pub fn from_be_bytes(bytes: &[u8]) -> Result<Point<C>, PrimitiveError> {
        Self::decode_encoding(bytes, !C::POINT_BIG_ENDIAN)
    }

    /// Deserialize a point from a little-endian byte buffer.
    ///
    /// When the curve's own encoding is big-endian (`C::POINT_BIG_ENDIAN`) the
    /// bytes are reversed before decoding; otherwise they are decoded as given.
    ///
    /// # Errors
    /// Returns [`PrimitiveError::DeserializationFailed`] if `bytes` does not have
    /// the exact encoding length or does not decode to a valid point.
    ///
    /// TODO: Check this is constant-time
    pub fn from_le_bytes(bytes: &[u8]) -> Result<Point<C>, PrimitiveError> {
        Self::decode_encoding(bytes, C::POINT_BIG_ENDIAN)
    }

    /// Shared decoding path for [`Point::from_be_bytes`] and
    /// [`Point::from_le_bytes`]: validates the length, copies `bytes` into the
    /// curve's encoding buffer (reversing it when `reverse` is set) and decodes.
    fn decode_encoding(bytes: &[u8], reverse: bool) -> Result<Point<C>, PrimitiveError> {
        let mut encoding = <C::Point as GroupEncoding>::Repr::default();
        let expected_len = encoding.as_ref().len();
        if bytes.len() != expected_len {
            return Err(PrimitiveError::DeserializationFailed(format!(
                "Invalid point encoding length: expected {}, got {}",
                expected_len,
                bytes.len()
            )));
        }

        let buffer: &mut [u8] = encoding.as_mut();
        buffer.copy_from_slice(bytes);
        if reverse {
            buffer.reverse();
        }

        let decoded: Option<C::Point> = Option::from(C::Point::from_bytes(&encoding));
        let point = decoded.ok_or_else(|| {
            PrimitiveError::DeserializationFailed("Invalid point encoding".to_string())
        })?;
        Ok(Point(point))
    }

    /// Serialize the point to a byte buffer
    ///
    /// The bytes are exactly what the wrapped element's `to_bytes` returns,
    /// copied into a shared slice; no byte-order conversion is applied here.
    pub fn to_bytes(&self) -> Arc<[u8]> {
        self.0.to_bytes().as_ref().into()
    }

    /// Build a point from four base-field coordinates by delegating to
    /// `C::Point::from_extended_edwards`.
    ///
    /// Returns `None` when the underlying conversion does. The coordinate order
    /// is defined by the `FromExtendedEdwards` implementation (not visible here).
    pub fn from_extended_edwards(coordinates: [BaseFieldElement<C>; 4]) -> Option<Point<C>> {
        C::Point::from_extended_edwards(coordinates).map(Point)
    }

    /// Convert the point into four base-field coordinates by delegating to the
    /// wrapped element's `to_extended_edwards`.
    ///
    /// # Errors
    /// Forwards the [`PointAtInfinityError`] reported by the underlying
    /// conversion.
    pub fn to_extended_edwards(self) -> Result<[BaseFieldElement<C>; 4], PointAtInfinityError> {
        self.0.to_extended_edwards()
    }
}

/// Random sampling through the crate's `Random` trait.
impl<C: Curve> Random for Point<C> {
    /// Draws a group element via `C::Point::random` and wraps it.
    #[inline]
    fn random(rng: impl CryptoRngCore) -> Self {
        let sampled = C::Point::random(rng);
        Self(sampled)
    }
}

/// Lets `Standard` produce points, so they can be drawn with the `rand`
/// distribution API.
impl<C: Curve> Distribution<Point<C>> for Standard {
    /// Draws a group element via `C::Point::random` and wraps it.
    #[inline]
    fn sample<R: RngCore + ?Sized>(&self, rng: &mut R) -> Point<C> {
        let sampled = C::Point::random(rng);
        Point(sampled)
    }
}

// ------------------------------------
// | Curve Arithmetic Implementations |
// ------------------------------------

// === Addition === //

// Point addition, written for `Point + &Point`.
// NOTE(review): the `op_variants` attribute presumably derives the other
// owned/borrowed operand combinations listed in its arguments; its expansion is
// defined in the `macros` crate and is not visible from this file.
#[macros::op_variants(owned, borrowed, flipped_commutative)]
impl<C: Curve> Add<&Point<C>> for Point<C> {
    type Output = Point<C>;

    // Adds `rhs` into the wrapped group element in place and returns the
    // updated point.
    #[inline]
    fn add(mut self, rhs: &Point<C>) -> Self::Output {
        self.0 += rhs.0;
        self
    }
}

// In-place point addition, written for `Point += &Point`.
// NOTE(review): `op_variants(owned)` presumably also derives the by-value
// right-hand side; confirm against the `macros` crate.
#[macros::op_variants(owned)]
impl<C: Curve> AddAssign<&Point<C>> for Point<C> {
    // Adds the wrapped element of `rhs` into the wrapped element of `self`.
    #[inline]
    fn add_assign(&mut self, rhs: &Point<C>) {
        self.0 += rhs.0;
    }
}

// === Subtraction === //

// Point subtraction, written for `Point - &Point`.
// NOTE(review): the `op_variants` attribute presumably derives the other
// operand combinations listed in its arguments (here `flipped` rather than
// `flipped_commutative`); confirm against the `macros` crate.
#[macros::op_variants(owned, borrowed, flipped)]
impl<C: Curve> Sub<&Point<C>> for Point<C> {
    type Output = Point<C>;

    // Subtracts `rhs` from the wrapped group element in place and returns the
    // updated point.
    #[inline]
    fn sub(mut self, rhs: &Point<C>) -> Self::Output {
        self.0 -= rhs.0;
        self
    }
}

// In-place point subtraction, written for `Point -= &Point`.
// NOTE(review): `op_variants(owned)` presumably also derives the by-value
// right-hand side; confirm against the `macros` crate.
#[macros::op_variants(owned)]
impl<C: Curve> SubAssign<&Point<C>> for Point<C> {
    // Subtracts the wrapped element of `rhs` from the wrapped element of `self`.
    #[inline]
    fn sub_assign(&mut self, rhs: &Point<C>) {
        self.0 -= rhs.0;
    }
}

// === Negation === //

// Point negation, written for an owned `Point`.
// NOTE(review): `op_variants(borrowed)` presumably also derives negation on a
// reference; confirm against the `macros` crate.
#[macros::op_variants(borrowed)]
impl<C: Curve> Neg for Point<C> {
    type Output = Point<C>;

    // Negates the wrapped group element and wraps the result in a new point.
    #[inline]
    fn neg(self) -> Self::Output {
        Point(-self.0)
    }
}

// === Scalar Multiplication === //

// Scalar multiplication, written for `Point * &ScalarAsExtension`.
// NOTE(review): the `op_variants` attribute presumably derives the other
// operand combinations listed in its arguments; confirm against the `macros`
// crate.
#[macros::op_variants(owned, borrowed, flipped)]
impl<C: Curve> Mul<&ScalarAsExtension<C>> for Point<C> {
    type Output = Point<C>;

    // Multiplies the wrapped group element in place by the value wrapped in
    // `rhs` and returns the updated point.
    #[inline]
    fn mul(mut self, rhs: &ScalarAsExtension<C>) -> Self::Output {
        self.0 *= rhs.0;
        self
    }
}

// Scalar multiplication with the scalar on the left, written for
// `ScalarAsExtension * &Point`.
// NOTE(review): the `op_variants` attribute presumably derives the other
// operand combinations listed in its arguments; confirm against the `macros`
// crate.
#[macros::op_variants(owned, borrowed, flipped_commutative)]
impl<C: Curve> Mul<&Point<C>> for ScalarAsExtension<C> {
    type Output = Point<C>;

    // The underlying product is evaluated with the group element on the left
    // and the scalar value on the right, then wrapped in a new point.
    #[inline]
    fn mul(self, rhs: &Point<C>) -> Self::Output {
        Point(rhs.0 * self.0)
    }
}

// Scalar multiplication, written for `Point * &Scalar`.
// NOTE(review): the `op_variants` attribute presumably derives the other
// operand combinations listed in its arguments; confirm against the `macros`
// crate.
#[macros::op_variants(owned, borrowed, flipped)]
impl<C: Curve> Mul<&Scalar<C>> for Point<C> {
    type Output = Point<C>;

    // Multiplies the wrapped group element by the value wrapped in `rhs` and
    // wraps the product in a new point.
    #[inline]
    fn mul(self, rhs: &Scalar<C>) -> Self::Output {
        Point(self.0 * rhs.0)
    }
}

// Scalar multiplication with the scalar on the left, written for
// `Scalar * &Point`.
// NOTE(review): the `op_variants` attribute presumably derives the other
// operand combinations listed in its arguments; confirm against the `macros`
// crate.
#[macros::op_variants(owned, borrowed, flipped_commutative)]
impl<C: Curve> Mul<&Point<C>> for Scalar<C> {
    type Output = Point<C>;

    // The underlying product is evaluated with the group element on the left
    // and the scalar value on the right, then wrapped in a new point.
    #[inline]
    fn mul(self, rhs: &Point<C>) -> Self::Output {
        Point(rhs.0 * self.0)
    }
}

// === MulAssign === //

// In-place scalar multiplication, written for `Point *= &ScalarAsExtension`.
// NOTE(review): `op_variants(owned)` presumably also derives the by-value
// right-hand side; confirm against the `macros` crate.
#[macros::op_variants(owned)]
impl<C: Curve> MulAssign<&ScalarAsExtension<C>> for Point<C> {
    // Multiplies the wrapped group element by the value wrapped in `rhs`.
    #[inline]
    fn mul_assign(&mut self, rhs: &ScalarAsExtension<C>) {
        self.0 *= rhs.0;
    }
}

// In-place scalar multiplication, written for `Point *= &Scalar`.
// NOTE(review): `op_variants(owned)` presumably also derives the by-value
// right-hand side; confirm against the `macros` crate.
#[macros::op_variants(owned)]
impl<C: Curve> MulAssign<&Scalar<C>> for Point<C> {
    // Multiplies the wrapped group element by the value wrapped in `rhs`.
    #[inline]
    fn mul_assign(&mut self, rhs: &Scalar<C>) {
        self.0 *= rhs.0;
    }
}

// === Equality === //

impl<C: Curve> ConstantTimeEq for Point<C> {
    #[inline]
    fn ct_eq(&self, other: &Self) -> Choice {
        self.0.ct_eq(&other.0)
    }
}

impl<C: Curve> ConditionallySelectable for Point<C> {
    #[inline]
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        let selected = C::Point::conditional_select(&a.0, &b.0, choice);
        Point(selected)
    }
}

// === Other === //

// Opts `Point<C>` into the crate's `AdditiveShares` trait using only the
// trait's provided items (the impl body is empty).
impl<C: Curve> AdditiveShares for Point<C> {}

// === Iterator traits === //

impl<C: Curve> Sum for Point<C> {
    #[inline]
    fn sum<I: Iterator<Item = Point<C>>>(iter: I) -> Self {
        iter.fold(Point::identity(), |acc, x| acc + x)
    }
}

impl<'a, C: Curve> Sum<&'a Point<C>> for Point<C> {
    #[inline]
    fn sum<I: Iterator<Item = &'a Point<C>>>(iter: I) -> Self {
        iter.fold(Point::identity(), |acc, x| acc + x)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::elliptic_curve::Curve25519Ristretto;

    /// `to_bytes` output must decode back to the same point through
    /// `from_le_bytes`, and a buffer of the wrong length must be rejected.
    #[test]
    fn test_point_serialization() {
        let point = Point::<Curve25519Ristretto>::generator();
        let bytes = point.to_bytes();
        let deserialized_point = Point::<Curve25519Ristretto>::from_le_bytes(&bytes).unwrap();
        assert_eq!(point, deserialized_point);

        let bytes = bytes.as_ref()[1..].to_vec(); // Invalid length
        let result = Point::<Curve25519Ristretto>::from_le_bytes(&bytes);
        assert!(result.is_err());
    }

    /// Wincode must reject bytes that don't encode a valid curve
    /// point. The manual `SchemaRead` impl for `Point` decodes through
    /// `C::Point::from_bytes` and returns an error when that fails;
    /// this test guards that validation against regressions.
    #[test]
    fn test_wincode_rejects_invalid_point() {
        let valid = Point::<Curve25519Ristretto>::generator();
        let mut buf = wincode::serialize(&valid).unwrap();

        // Overwrite the trailing 32 bytes (assumed to be the point
        // encoding) with 0xFF, which is expected to be an invalid
        // Ristretto encoding.
        let len = buf.len();
        buf[len - 32..].fill(0xFF);

        let result = wincode::deserialize::<Point<Curve25519Ristretto>>(&buf);
        assert!(
            result.is_err(),
            "wincode deserialized an invalid curve point \
             without returning an error"
        );
    }
}