use std::fmt;
use generic_array::{ArrayLength, GenericArray};
use serde::{Deserialize, Serialize};
use typenum::Unsigned;
use zeroize::Zeroize;
use crate::BigInt;
/// An elliptic curve, tying together a scalar type and a point type over that curve.
///
/// The point type is constrained so that its associated scalar type is exactly
/// [`Curve::Scalar`], which lets generic code multiply `Curve::Point` values by
/// `Curve::Scalar` values without any conversion.
pub trait Curve: PartialEq + Clone + fmt::Debug + Sync + Send + 'static {
    /// Human-readable name of the curve.
    const CURVE_NAME: &'static str;

    /// Scalar type used with this curve.
    type Scalar: ECScalar;

    /// Point type of this curve; it must use [`Curve::Scalar`] as its scalar type.
    type Point: ECPoint<Scalar = Self::Scalar>;
}
/// A scalar with arithmetic, integer conversion, and fixed-length byte serialization.
///
/// Implementors provide the required operations. [`ECScalar::is_zero`] and the in-place
/// `*_assign` helpers have default implementations built on top of them.
pub trait ECScalar: Clone + PartialEq + fmt::Debug + Send + Sync + 'static {
    /// Backend-specific value wrapped by this scalar.
    type Underlying;
    /// Byte length of the array produced by [`ECScalar::serialize`].
    type ScalarLength: ArrayLength<u8> + Unsigned;

    /// Produces a random scalar.
    fn random() -> Self;
    /// Produces the zero scalar.
    fn zero() -> Self;
    /// Reports whether `self` compares equal to [`ECScalar::zero`].
    fn is_zero(&self) -> bool {
        *self == Self::zero()
    }

    /// Builds a scalar from an integer.
    ///
    /// NOTE(review): presumably reduces `n` modulo [`ECScalar::group_order`] — confirm
    /// against each implementation.
    fn from_bigint(n: &BigInt) -> Self;
    /// Converts the scalar back into an integer.
    fn to_bigint(&self) -> BigInt;
    /// Encodes the scalar as exactly `ScalarLength` bytes.
    fn serialize(&self) -> GenericArray<u8, Self::ScalarLength>;
    /// Decodes a scalar from bytes, failing with [`DeserializationError`] on bad input.
    fn deserialize(bytes: &[u8]) -> Result<Self, DeserializationError>;

    /// Returns `self + other`.
    fn add(&self, other: &Self) -> Self;
    /// Returns `self * other`.
    fn mul(&self, other: &Self) -> Self;
    /// Returns `self - other`.
    fn sub(&self, other: &Self) -> Self;
    /// Returns `-self`.
    fn neg(&self) -> Self;
    /// Returns the multiplicative inverse, or `None` when there is none
    /// (presumably only for zero — confirm per implementation).
    fn invert(&self) -> Option<Self>;

    /// In-place form of [`ECScalar::add`].
    fn add_assign(&mut self, other: &Self) {
        let sum = self.add(other);
        *self = sum;
    }
    /// In-place form of [`ECScalar::mul`].
    fn mul_assign(&mut self, other: &Self) {
        let product = self.mul(other);
        *self = product;
    }
    /// In-place form of [`ECScalar::sub`].
    fn sub_assign(&mut self, other: &Self) {
        let difference = self.sub(other);
        *self = difference;
    }
    /// In-place form of [`ECScalar::neg`].
    fn neg_assign(&mut self) {
        let negated = self.neg();
        *self = negated;
    }

    /// Order of the group the scalars act on.
    fn group_order() -> &'static BigInt;
    /// Borrows the backend value.
    fn underlying_ref(&self) -> &Self::Underlying;
    /// Mutably borrows the backend value.
    fn underlying_mut(&mut self) -> &mut Self::Underlying;
    /// Wraps a backend value.
    fn from_underlying(u: Self::Underlying) -> Self;
}
/// A point on an elliptic curve, with group operations and (de)serialization.
///
/// Implementors provide the required operations. [`ECPoint::is_zero`],
/// [`ECPoint::check_point_order_equals_group_order`], [`ECPoint::generator_mul`] and the
/// in-place `*_assign` helpers have default implementations built on top of them.
pub trait ECPoint: Zeroize + Clone + PartialEq + fmt::Debug + Sync + Send + 'static {
    /// Scalar type that points are multiplied by.
    type Scalar: ECScalar;
    /// Backend-specific value wrapped by this point.
    type Underlying;
    /// Byte length of [`ECPoint::serialize_compressed`] output.
    type CompressedPointLength: ArrayLength<u8> + Unsigned;
    /// Byte length of [`ECPoint::serialize_uncompressed`] output.
    type UncompressedPointLength: ArrayLength<u8> + Unsigned;

    /// Produces the neutral point (presumably the point at infinity).
    fn zero() -> Self;
    /// Reports whether `self` compares equal to [`ECPoint::zero`].
    fn is_zero(&self) -> bool {
        *self == Self::zero()
    }
    /// Fixed generator point.
    fn generator() -> &'static Self;
    /// A second fixed base point.
    ///
    /// NOTE(review): its relationship to [`ECPoint::generator`] is not specified here — confirm.
    fn base_point2() -> &'static Self;

    /// Builds a point from affine coordinates, failing with [`NotOnCurve`] if they are invalid.
    fn from_coords(x: &BigInt, y: &BigInt) -> Result<Self, NotOnCurve>;
    /// x-coordinate, or `None` when unavailable (presumably for the neutral point).
    fn x_coord(&self) -> Option<BigInt>;
    /// y-coordinate, or `None` when unavailable (presumably for the neutral point).
    fn y_coord(&self) -> Option<BigInt>;
    /// Both coordinates, or `None` when unavailable.
    fn coords(&self) -> Option<PointCoords>;
    /// Compressed encoding of exactly `CompressedPointLength` bytes.
    fn serialize_compressed(&self) -> GenericArray<u8, Self::CompressedPointLength>;
    /// Uncompressed encoding of exactly `UncompressedPointLength` bytes.
    fn serialize_uncompressed(&self) -> GenericArray<u8, Self::UncompressedPointLength>;
    /// Decodes a point from bytes, failing with [`DeserializationError`] on bad input.
    fn deserialize(bytes: &[u8]) -> Result<Self, DeserializationError>;

    /// Returns `true` when `self` is non-zero and `q·self` is zero, where `q` is
    /// [`ECScalar::group_order`].
    ///
    /// `q·self` is computed as `(q - 1)·self + self` because `q` itself is not representable
    /// as a scalar below the group order. Assuming `q` is prime, this means the order of `self`
    /// is exactly `q`.
    fn check_point_order_equals_group_order(&self) -> bool {
        let q_minus_one = Self::Scalar::group_order() - 1;
        let scalar = Self::Scalar::from_bigint(&q_minus_one);
        let mut acc = self.scalar_mul(&scalar);
        acc.add_point_assign(self);
        !self.is_zero() && acc.is_zero()
    }

    /// Returns `scalar · self`.
    fn scalar_mul(&self, scalar: &Self::Scalar) -> Self;
    /// Returns `scalar · G`, where `G` is [`ECPoint::generator`].
    fn generator_mul(scalar: &Self::Scalar) -> Self {
        let base = Self::generator();
        base.scalar_mul(scalar)
    }
    /// Returns `self + other`.
    fn add_point(&self, other: &Self) -> Self;
    /// Returns `self - other`.
    fn sub_point(&self, other: &Self) -> Self;
    /// Returns `-self`.
    fn neg_point(&self) -> Self;

    /// In-place form of [`ECPoint::scalar_mul`].
    fn scalar_mul_assign(&mut self, scalar: &Self::Scalar) {
        let product = self.scalar_mul(scalar);
        *self = product;
    }
    /// In-place form of [`ECPoint::add_point`].
    fn add_point_assign(&mut self, other: &Self) {
        let sum = self.add_point(other);
        *self = sum;
    }
    /// In-place form of [`ECPoint::sub_point`].
    fn sub_point_assign(&mut self, other: &Self) {
        let difference = self.sub_point(other);
        *self = difference;
    }
    /// In-place form of [`ECPoint::neg_point`].
    fn neg_point_assign(&mut self) {
        let negated = self.neg_point();
        *self = negated;
    }

    /// Borrows the backend value.
    fn underlying_ref(&self) -> &Self::Underlying;
    /// Mutably borrows the backend value.
    fn underlying_mut(&mut self) -> &mut Self::Underlying;
    /// Wraps a backend value.
    fn from_underlying(u: Self::Underlying) -> Self;
}
/// Affine `(x, y)` coordinates of a curve point, as returned by [`ECPoint::coords`].
///
/// Field order is `x` then `y`. Keep it that way, because serde formats that encode structs
/// positionally depend on it.
#[derive(Deserialize, Serialize)]
pub struct PointCoords {
    /// The x-coordinate.
    pub x: BigInt,
    /// The y-coordinate.
    pub y: BigInt,
}
/// Error returned by [`ECScalar::deserialize`] and [`ECPoint::deserialize`] when the input
/// bytes cannot be decoded.
///
/// It is a unit value with no payload, so it is cheap to copy, compare and hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DeserializationError;

impl fmt::Display for DeserializationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Constant message: no formatting machinery needed.
        f.write_str("failed to deserialize the point/scalar")
    }
}

impl std::error::Error for DeserializationError {}
/// Error returned by [`ECPoint::from_coords`] when the given coordinates do not describe a
/// point on the curve.
///
/// It is a unit value with no payload, so it is cheap to copy, compare and hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NotOnCurve;

impl fmt::Display for NotOnCurve {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Constant message: no formatting machinery needed.
        f.write_str("point not on the curve")
    }
}

impl std::error::Error for NotOnCurve {}