use std::{fmt, iter};
use crate::elliptic::curves::traits::*;
use crate::BigInt;
use super::{
error::{MismatchedPointOrder, PointFromBytesError, PointFromCoordsError, ZeroPointError},
EncodedPoint, Generator,
};
use crate::elliptic::curves::wrappers::encoded_point::EncodedPointChoice;
/// Point on the elliptic curve `E`, wrapping the curve-specific raw point `E::Point`.
///
/// Invariant: the wrapped point is either zero or a point whose order equals the curve's
/// group order. The safe constructors (`Point::from_raw`, `Point::from_raw_ref`,
/// `Point::from_coords`, `Point::from_bytes`) check it. The `unsafe` `*_unchecked`
/// constructors leave it to the caller and only re-check it in debug builds.
// `repr(transparent)` gives `Point<E>` the same layout as `E::Point`.
// `Point::from_raw_ref_unchecked` depends on this to cast `&E::Point` to `&Point<E>`.
#[repr(transparent)]
pub struct Point<E: Curve> {
    // The wrapped curve-specific point. It is read through `as_raw` (borrowed) and
    // `into_raw` (by value).
    raw_point: E::Point,
}
impl<E: Curve> Point<E> {
    /// Checks that this point is not the zero point.
    ///
    /// # Errors
    ///
    /// Returns [`ZeroPointError`] when [`is_zero`](Self::is_zero) is `true`.
    pub fn ensure_nonzero(&self) -> Result<(), ZeroPointError> {
        if self.is_zero() {
            Err(ZeroPointError::new())
        } else {
            Ok(())
        }
    }

    /// Returns the curve generator (built with `Generator::default()`).
    pub fn generator() -> Generator<E> {
        Generator::default()
    }

    /// Returns a `'static` reference to the second base point supplied by
    /// `E::Point::base_point2()`.
    pub fn base_point2() -> &'static Self {
        let p = E::Point::base_point2();
        // SAFETY: this relies on the curve implementation returning a point that is zero
        // or has the group's order. Debug builds re-check it in `from_raw_ref_unchecked`.
        unsafe { Self::from_raw_ref_unchecked(p) }
    }

    /// Returns the zero point of the curve.
    pub fn zero() -> Self {
        // SAFETY: the invariant explicitly allows the zero point.
        unsafe { Self::from_raw_unchecked(E::Point::zero()) }
    }

    /// Returns `true` if this is the zero point.
    pub fn is_zero(&self) -> bool {
        self.as_raw().is_zero()
    }

    /// Returns the point's coordinates as reported by the raw point.
    ///
    /// The result is `None` whenever the raw point reports none (presumably for the zero point).
    pub fn coords(&self) -> Option<PointCoords> {
        self.as_raw().coords()
    }

    /// Returns the x coordinate as reported by the raw point, or `None` if it has none.
    pub fn x_coord(&self) -> Option<BigInt> {
        self.as_raw().x_coord()
    }

    /// Returns the y coordinate as reported by the raw point, or `None` if it has none.
    pub fn y_coord(&self) -> Option<BigInt> {
        self.as_raw().y_coord()
    }

    /// Builds a point from affine coordinates `(x, y)`.
    ///
    /// # Errors
    ///
    /// * [`PointFromCoordsError::NotOnCurve`] if the coordinates are rejected by
    ///   `E::Point::from_coords`.
    /// * [`PointFromCoordsError::InvalidPoint`] if the resulting point violates the order
    ///   invariant (see [`from_raw`](Self::from_raw)).
    pub fn from_coords(x: &BigInt, y: &BigInt) -> Result<Self, PointFromCoordsError> {
        let raw_point = E::Point::from_coords(x, y)
            .map_err(|_: NotOnCurve| PointFromCoordsError::NotOnCurve)?;
        Self::from_raw(raw_point).map_err(PointFromCoordsError::InvalidPoint)
    }

    /// Deserializes a point from `bytes` using `E::Point::deserialize`.
    ///
    /// # Errors
    ///
    /// * [`PointFromBytesError::DeserializationError`] if the bytes cannot be decoded.
    /// * [`PointFromBytesError::InvalidPoint`] if the decoded point violates the order
    ///   invariant (see [`from_raw`](Self::from_raw)).
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, PointFromBytesError> {
        let p = E::Point::deserialize(bytes)
            .map_err(|_: DeserializationError| PointFromBytesError::DeserializationError)?;
        Self::from_raw(p).map_err(PointFromBytesError::InvalidPoint)
    }

    /// Serializes the point into compressed (`compressed == true`) or uncompressed form.
    pub fn to_bytes(&self, compressed: bool) -> EncodedPoint<E> {
        let raw = self.as_raw();
        let choice = if compressed {
            EncodedPointChoice::Compressed(raw.serialize_compressed())
        } else {
            EncodedPointChoice::Uncompressed(raw.serialize_uncompressed())
        };
        EncodedPoint(choice)
    }

    /// Wraps a raw point after checking the invariant.
    ///
    /// # Errors
    ///
    /// Returns [`MismatchedPointOrder`] if `raw_point` is non-zero and its order differs
    /// from the group order.
    pub fn from_raw(raw_point: E::Point) -> Result<Self, MismatchedPointOrder> {
        if Self::raw_point_is_valid(&raw_point) {
            Ok(Self { raw_point })
        } else {
            Err(MismatchedPointOrder::new())
        }
    }

    /// Reinterprets a borrowed raw point as a borrowed `Point` after checking the invariant.
    ///
    /// # Errors
    ///
    /// Returns [`MismatchedPointOrder`] if `raw_point` is non-zero and its order differs
    /// from the group order.
    pub fn from_raw_ref(raw_point: &E::Point) -> Result<&Self, MismatchedPointOrder> {
        if Self::raw_point_is_valid(raw_point) {
            // SAFETY: the invariant was verified on the line above.
            let reference = unsafe { Self::from_raw_ref_unchecked(raw_point) };
            Ok(reference)
        } else {
            Err(MismatchedPointOrder::new())
        }
    }

    /// Wraps a raw point without checking the invariant in release builds.
    ///
    /// # Safety
    ///
    /// `raw_point` must be zero or have order equal to the curve's group order. This is
    /// checked only by a `debug_assert!`. Violating it produces a `Point` that breaks the
    /// invariant other code relies on.
    pub unsafe fn from_raw_unchecked(raw_point: E::Point) -> Self {
        debug_assert!(Self::raw_point_is_valid(&raw_point));
        Self { raw_point }
    }

    /// Reinterprets `&E::Point` as `&Point<E>` without checking the invariant in release
    /// builds.
    ///
    /// # Safety
    ///
    /// `raw_point` must be zero or have order equal to the curve's group order. This is
    /// checked only by a `debug_assert!`.
    pub unsafe fn from_raw_ref_unchecked(raw_point: &E::Point) -> &Self {
        debug_assert!(Self::raw_point_is_valid(raw_point));
        // SAFETY (cast): `Point<E>` is `#[repr(transparent)]` over `E::Point`, so both types
        // share a layout. The returned reference keeps the input's lifetime.
        &*(raw_point as *const E::Point as *const Self)
    }

    /// Borrows the underlying raw point.
    pub fn as_raw(&self) -> &E::Point {
        &self.raw_point
    }

    /// Consumes the wrapper and returns the underlying raw point.
    pub fn into_raw(self) -> E::Point {
        self.raw_point
    }

    /// The single definition of the invariant every `Point<E>` upholds: the raw point is
    /// zero, or its order equals the group order. Zero is tested first so the
    /// (potentially more expensive) order check is skipped for it.
    fn raw_point_is_valid(raw_point: &E::Point) -> bool {
        raw_point.is_zero() || raw_point.check_point_order_equals_group_order()
    }
}
impl<E: Curve> PartialEq for Point<E> {
    /// Points compare equal exactly when their raw representations compare equal.
    fn eq(&self, other: &Self) -> bool {
        self.as_raw() == other.as_raw()
    }
}

// Declares point equality to be a full equivalence relation (marker impl only).
impl<E: Curve> Eq for Point<E> {}
impl<E: Curve> PartialEq<Generator<E>> for Point<E> {
    /// Compares this point with a generator by their raw representations.
    fn eq(&self, other: &Generator<E>) -> bool {
        let generator_raw = other.as_raw();
        self.as_raw() == generator_raw
    }
}
impl<E: Curve> Clone for Point<E> {
    /// Clones the raw point and re-wraps it. Debug builds re-check the invariant through
    /// `from_raw_unchecked`.
    fn clone(&self) -> Self {
        let raw_copy = self.raw_point.clone();
        // SAFETY: `self` already satisfies the point invariant, and its clone is the same
        // point, so the invariant carries over.
        unsafe { Self::from_raw_unchecked(raw_copy) }
    }
}
impl<E: Curve> fmt::Debug for Point<E> {
    /// Formats the point exactly like its raw representation, forwarding the caller's
    /// formatter unchanged.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Debug::fmt(self.as_raw(), f)
    }
}
impl<E: Curve> From<Generator<E>> for Point<E> {
    /// Turns a generator into an owned point by cloning its raw representation.
    fn from(g: Generator<E>) -> Self {
        let raw_generator = g.as_raw().clone();
        // SAFETY: the generator is assumed to satisfy the point invariant (TODO confirm in
        // the curve implementation). Debug builds re-check it in `from_raw_unchecked`.
        unsafe { Self::from_raw_unchecked(raw_generator) }
    }
}
impl<E: Curve> iter::Sum for Point<E> {
    /// Adds the points left to right, starting from the zero point. An empty iterator
    /// yields zero.
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        let zero = Self::zero();
        iter.fold(zero, std::ops::Add::add)
    }
}
impl<'p, E: Curve> iter::Sum<&'p Point<E>> for Point<E> {
    /// Adds borrowed points left to right into an owned accumulator that starts at the
    /// zero point. An empty iterator yields zero.
    fn sum<I: Iterator<Item = &'p Point<E>>>(iter: I) -> Self {
        let zero = Self::zero();
        iter.fold(zero, std::ops::Add::add)
    }
}