use crate::{weierstrass::Curve, Error, FieldBytes};
use core::{
fmt::{self, Debug},
ops::Add,
};
use generic_array::{
typenum::{Unsigned, U1},
ArrayLength, GenericArray,
};
use subtle::{Choice, ConditionallySelectable};
#[cfg(feature = "alloc")]
use alloc::boxed::Box;
#[cfg(feature = "arithmetic")]
use crate::{
ff::PrimeField, weierstrass::point::Decompress, AffinePoint, ProjectiveArithmetic, Scalar,
};
#[cfg(all(feature = "arithmetic", feature = "zeroize"))]
use crate::{
group::{Curve as _, Group},
secret_key::SecretKey,
};
#[cfg(feature = "zeroize")]
use zeroize::Zeroize;
/// Size in bytes of a compressed point encoding for curve `C`: the curve's
/// field size plus one (the extra byte holds the leading tag).
pub type CompressedPointSize<C> = <<C as crate::Curve>::FieldSize as Add<U1>>::Output;
/// Size in bytes of an uncompressed point encoding for curve `C`:
/// [`UntaggedPointSize`] plus one (the extra byte holds the leading tag).
pub type UncompressedPointSize<C> = <UntaggedPointSize<C> as Add<U1>>::Output;
/// Size in bytes of an untagged point for curve `C`: the curve's field size
/// added to itself, i.e. room for two field-sized coordinates.
pub type UntaggedPointSize<C> = <<C as crate::Curve>::FieldSize as Add>::Output;
/// Tagged byte encoding of an elliptic curve point.
///
/// The first byte is a [`Tag`] that selects the encoding (identity,
/// compressed, or uncompressed); the coordinate bytes follow it.
///
/// The backing array is always sized for the largest (uncompressed)
/// encoding. Shorter encodings occupy a prefix of it, and the tag determines
/// how many bytes are meaningful (see `len` / `as_bytes`).
///
/// Note that the derived comparison traits operate on the whole backing
/// array, not only on the meaningful prefix.
#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
pub struct EncodedPoint<C>
where
C: Curve,
UntaggedPointSize<C>: Add<U1> + ArrayLength<u8>,
UncompressedPointSize<C>: ArrayLength<u8>,
{
// Tag byte followed by the coordinate bytes; any bytes past `len()` are
// left at their default (zero) value by the constructors in this file.
bytes: GenericArray<u8, UncompressedPointSize<C>>,
}
#[allow(clippy::len_without_is_empty)]
impl<C> EncodedPoint<C>
where
C: Curve,
UntaggedPointSize<C>: Add<U1> + ArrayLength<u8>,
UncompressedPointSize<C>: ArrayLength<u8>,
{
pub fn from_bytes(input: impl AsRef<[u8]>) -> Result<Self, Error> {
let input = input.as_ref();
let tag = input.first().cloned().ok_or(Error).and_then(Tag::from_u8)?;
let expected_len = tag.message_len(C::FieldSize::to_usize());
if input.len() != expected_len {
return Err(Error);
}
let mut bytes = GenericArray::default();
bytes[..expected_len].copy_from_slice(input);
Ok(Self { bytes })
}
pub fn from_untagged_bytes(bytes: &GenericArray<u8, UntaggedPointSize<C>>) -> Self {
let (x, y) = bytes.split_at(C::FieldSize::to_usize());
Self::from_affine_coordinates(x.into(), y.into(), false)
}
pub fn from_affine_coordinates(x: &FieldBytes<C>, y: &FieldBytes<C>, compress: bool) -> Self {
let tag = if compress {
Tag::compress_y(y.as_slice())
} else {
Tag::Uncompressed
};
let mut bytes = GenericArray::default();
bytes[0] = tag.into();
let element_size = C::FieldSize::to_usize();
bytes[1..(element_size + 1)].copy_from_slice(x);
if !compress {
bytes[(element_size + 1)..].copy_from_slice(y);
}
Self { bytes }
}
#[cfg(all(feature = "arithmetic", feature = "zeroize"))]
#[cfg_attr(docsrs, doc(cfg(feature = "arithmetic")))]
#[cfg_attr(docsrs, doc(cfg(feature = "zeroize")))]
pub fn from_secret_key(secret_key: &SecretKey<C>, compress: bool) -> Self
where
C: Curve + ProjectiveArithmetic,
FieldBytes<C>: From<Scalar<C>> + for<'r> From<&'r Scalar<C>>,
AffinePoint<C>: ToEncodedPoint<C>,
Scalar<C>: PrimeField<Repr = FieldBytes<C>> + Zeroize,
{
(C::ProjectivePoint::generator() * secret_key.secret_scalar())
.to_affine()
.to_encoded_point(compress)
}
pub fn identity() -> Self {
Self::from_bytes(&[0]).unwrap()
}
pub fn len(&self) -> usize {
self.tag().message_len(C::FieldSize::to_usize())
}
pub fn as_bytes(&self) -> &[u8] {
&self.bytes[..self.len()]
}
#[cfg(feature = "alloc")]
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
pub fn to_bytes(&self) -> Box<[u8]> {
self.as_bytes().to_vec().into_boxed_slice()
}
#[cfg(feature = "arithmetic")]
#[cfg_attr(docsrs, doc(cfg(feature = "arithmetic")))]
pub fn to_untagged_bytes(&self) -> Option<GenericArray<u8, UntaggedPointSize<C>>>
where
C: Curve + ProjectiveArithmetic,
FieldBytes<C>: From<Scalar<C>> + for<'r> From<&'r Scalar<C>>,
Scalar<C>: PrimeField<Repr = FieldBytes<C>>,
AffinePoint<C>: ConditionallySelectable + Default + Decompress<C> + ToEncodedPoint<C>,
{
self.decompress().map(|point| {
let mut bytes = GenericArray::<u8, UntaggedPointSize<C>>::default();
bytes.copy_from_slice(&point.as_bytes()[1..]);
bytes
})
}
pub fn is_identity(&self) -> bool {
self.tag().is_identity()
}
pub fn is_compressed(&self) -> bool {
self.tag().is_compressed()
}
pub fn compress(&self) -> Self {
match self.coordinates() {
Coordinates::Identity | Coordinates::Compressed { .. } => self.clone(),
Coordinates::Uncompressed { x, y } => Self::from_affine_coordinates(x, y, true),
}
}
#[cfg(feature = "arithmetic")]
#[cfg_attr(docsrs, doc(cfg(feature = "arithmetic")))]
pub fn decompress(&self) -> Option<Self>
where
C: Curve + ProjectiveArithmetic,
FieldBytes<C>: From<Scalar<C>> + for<'r> From<&'r Scalar<C>>,
Scalar<C>: PrimeField<Repr = FieldBytes<C>>,
AffinePoint<C>: ConditionallySelectable + Default + Decompress<C> + ToEncodedPoint<C>,
{
match self.coordinates() {
Coordinates::Identity => None,
Coordinates::Compressed { x, y_is_odd } => {
AffinePoint::<C>::decompress(x, Choice::from(y_is_odd as u8))
.map(|s| s.to_encoded_point(false))
.into()
}
Coordinates::Uncompressed { .. } => Some(self.clone()),
}
}
pub fn encode<T>(encodable: T, compress: bool) -> Self
where
T: ToEncodedPoint<C>,
{
encodable.to_encoded_point(compress)
}
pub fn decode<T>(&self) -> Result<T, Error>
where
T: FromEncodedPoint<C>,
{
T::from_encoded_point(self).ok_or(Error)
}
pub fn tag(&self) -> Tag {
Tag::from_u8(self.bytes[0]).expect("invalid tag")
}
#[inline]
pub fn coordinates(&self) -> Coordinates<'_, C> {
if self.is_identity() {
return Coordinates::Identity;
}
let (x, y) = self.bytes[1..].split_at(C::FieldSize::to_usize());
if self.is_compressed() {
Coordinates::Compressed {
x: x.into(),
y_is_odd: self.tag() as u8 & 1 == 1,
}
} else {
Coordinates::Uncompressed {
x: x.into(),
y: y.into(),
}
}
}
pub fn x(&self) -> Option<&FieldBytes<C>> {
match self.coordinates() {
Coordinates::Identity => None,
Coordinates::Compressed { x, .. } => Some(x),
Coordinates::Uncompressed { x, .. } => Some(x),
}
}
pub fn y(&self) -> Option<&FieldBytes<C>> {
match self.coordinates() {
Coordinates::Compressed { .. } | Coordinates::Identity => None,
Coordinates::Uncompressed { y, .. } => Some(y),
}
}
}
// Exposes the same bytes as `EncodedPoint::as_bytes`.
impl<C> AsRef<[u8]> for EncodedPoint<C>
where
    C: Curve,
    UntaggedPointSize<C>: Add<U1> + ArrayLength<u8>,
    UncompressedPointSize<C>: ArrayLength<u8>,
{
    /// Borrow the meaningful bytes of the encoding (tag byte included).
    #[inline]
    fn as_ref(&self) -> &[u8] {
        Self::as_bytes(self)
    }
}
impl<C> ConditionallySelectable for EncodedPoint<C>
where
    C: Curve,
    UntaggedPointSize<C>: Add<U1> + ArrayLength<u8>,
    UncompressedPointSize<C>: ArrayLength<u8>,
    <UncompressedPointSize<C> as ArrayLength<u8>>::ArrayType: Copy,
{
    /// Select between `a` and `b` byte by byte using
    /// `u8::conditional_select`, over the entire backing array (not only the
    /// bytes covered by `len`).
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        let mut selected = GenericArray::<u8, UncompressedPointSize<C>>::default();

        // All three arrays have the same fixed length, so zipping visits
        // every byte position exactly once.
        let pairs = a.bytes.iter().zip(b.bytes.iter());
        for (out, (lhs, rhs)) in selected.iter_mut().zip(pairs) {
            *out = u8::conditional_select(lhs, rhs, choice);
        }

        Self { bytes: selected }
    }
}
// `EncodedPoint` is `Copy` only when the array type backing
// `GenericArray<u8, UncompressedPointSize<C>>` is itself `Copy`; that is the
// purpose of the extra `ArrayType: Copy` bound below.
impl<C> Copy for EncodedPoint<C>
where
C: Curve,
UntaggedPointSize<C>: Add<U1> + ArrayLength<u8>,
UncompressedPointSize<C>: ArrayLength<u8>,
<UncompressedPointSize<C> as ArrayLength<u8>>::ArrayType: Copy,
{
}
impl<C> Debug for EncodedPoint<C>
where
    C: Curve,
    UntaggedPointSize<C>: Add<U1> + ArrayLength<u8>,
    UncompressedPointSize<C>: ArrayLength<u8>,
{
    /// Format as `EncodedPoint<curve>(bytes)`.
    ///
    /// The curve is rendered through the `Debug` output of `C::default()`,
    /// and the bytes shown are the entire backing array, including any
    /// bytes beyond `len()`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let curve = C::default();
        let raw = &self.bytes;
        write!(f, "EncodedPoint<{:?}>({:?})", curve, raw)
    }
}
// Only available with the `zeroize` feature.
#[cfg(feature = "zeroize")]
impl<C> Zeroize for EncodedPoint<C>
where
C: Curve,
UntaggedPointSize<C>: Add<U1> + ArrayLength<u8>,
UncompressedPointSize<C>: ArrayLength<u8>,
{
/// Zeroize the entire backing array.
///
/// Afterwards the leading byte is zero, which is the identity tag, so the
/// value reads back as the identity point.
fn zeroize(&mut self) {
self.bytes.zeroize()
}
}
/// Borrowed view of the coordinates stored in an [`EncodedPoint`], one
/// variant per kind of encoding.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Coordinates<'a, C: Curve> {
/// The identity point; it carries no coordinates.
Identity,
/// A compressed point: only `x` is stored, plus the parity of `y`.
Compressed {
/// The x-coordinate bytes.
x: &'a FieldBytes<C>,
/// Whether the tag marks the y-coordinate as odd.
y_is_odd: bool,
},
/// An uncompressed point: both coordinates are stored.
Uncompressed {
/// The x-coordinate bytes.
x: &'a FieldBytes<C>,
/// The y-coordinate bytes.
y: &'a FieldBytes<C>,
},
}
impl<'a, C: Curve> Coordinates<'a, C> {
    /// Return the [`Tag`] that corresponds to this kind of coordinates.
    ///
    /// Compressed coordinates map to the odd or even tag according to
    /// `y_is_odd`.
    pub fn tag(&self) -> Tag {
        match *self {
            Coordinates::Identity => Tag::Identity,
            Coordinates::Uncompressed { .. } => Tag::Uncompressed,
            Coordinates::Compressed { y_is_odd: true, .. } => Tag::CompressedOddY,
            Coordinates::Compressed {
                y_is_odd: false, ..
            } => Tag::CompressedEvenY,
        }
    }
}
/// Leading byte of an encoded point; it identifies which encoding follows.
///
/// The discriminants are the byte values as they appear in the encoding.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum Tag {
/// Identity point (`0x00`); no coordinate bytes follow.
Identity = 0,
/// Compressed point whose y-coordinate is even (`0x02`).
CompressedEvenY = 2,
/// Compressed point whose y-coordinate is odd (`0x03`).
CompressedOddY = 3,
/// Uncompressed point (`0x04`); both coordinates follow.
Uncompressed = 4,
}
impl Tag {
pub fn from_u8(byte: u8) -> Result<Self, Error> {
match byte {
0 => Ok(Tag::Identity),
2 => Ok(Tag::CompressedEvenY),
3 => Ok(Tag::CompressedOddY),
4 => Ok(Tag::Uncompressed),
_ => Err(Error),
}
}
pub fn is_identity(self) -> bool {
self == Tag::Identity
}
pub fn is_compressed(self) -> bool {
matches!(self, Tag::CompressedEvenY | Tag::CompressedOddY)
}
pub fn message_len(self, field_element_size: usize) -> usize {
1 + match self {
Tag::Identity => 0,
Tag::CompressedEvenY | Tag::CompressedOddY => field_element_size,
Tag::Uncompressed => field_element_size * 2,
}
}
fn compress_y(y: &[u8]) -> Self {
debug_assert!(!y.is_empty());
if y.as_ref().last().unwrap() & 1 == 1 {
Tag::CompressedOddY
} else {
Tag::CompressedEvenY
}
}
}
impl From<Tag> for u8 {
    /// Convert a tag into its byte value (the enum's `repr(u8)` discriminant).
    fn from(value: Tag) -> Self {
        value as u8
    }
}
/// Trait for types that can be constructed from an [`EncodedPoint`].
pub trait FromEncodedPoint<C>
where
C: Curve,
UntaggedPointSize<C>: Add<U1> + ArrayLength<u8>,
UncompressedPointSize<C>: ArrayLength<u8>,
Self: Sized,
{
/// Attempt to build `Self` from the given encoded point.
///
/// Returns `None` when the conversion fails; `EncodedPoint::decode` maps
/// that `None` to an `Error`.
fn from_encoded_point(public_key: &EncodedPoint<C>) -> Option<Self>;
}
/// Trait for types that can be serialized as an [`EncodedPoint`].
pub trait ToEncodedPoint<C>
where
C: Curve,
UntaggedPointSize<C>: Add<U1> + ArrayLength<u8>,
UncompressedPointSize<C>: ArrayLength<u8>,
{
/// Serialize `self` as an encoded point; `compress` requests the
/// compressed form.
fn to_encoded_point(&self, compress: bool) -> EncodedPoint<C>;
}
#[cfg(test)]
mod tests {
    use super::{Coordinates, Tag};
    use crate::{weierstrass, Curve};
    use generic_array::{typenum::U32, GenericArray};
    use hex_literal::hex;
    use subtle::ConditionallySelectable;

    /// Placeholder curve with a 32-byte field, used only to instantiate
    /// `EncodedPoint`; no arithmetic is performed on it.
    #[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord)]
    struct ExampleCurve;

    impl Curve for ExampleCurve {
        type FieldSize = U32;
    }

    impl weierstrass::Curve for ExampleCurve {}

    type EncodedPoint = super::EncodedPoint<ExampleCurve>;

    /// Identity encoding: a lone zero tag byte.
    const IDENTITY_BYTES: [u8; 1] = [0];

    /// Uncompressed encoding: tag `04`, then x (all `11`), then y (all `22`).
    const UNCOMPRESSED_BYTES: [u8; 65] = hex!("0411111111111111111111111111111111111111111111111111111111111111112222222222222222222222222222222222222222222222222222222222222222");

    /// Compressed (even y) encoding of the same point: tag `02`, then x.
    const COMPRESSED_BYTES: [u8; 33] =
        hex!("021111111111111111111111111111111111111111111111111111111111111111");

    #[test]
    fn decode_compressed_point() {
        // Even y-coordinate
        let compressed_even_y_bytes =
            hex!("020100000000000000000000000000000000000000000000000000000000000000");
        let compressed_even_y = EncodedPoint::from_bytes(&compressed_even_y_bytes[..]).unwrap();
        assert!(compressed_even_y.is_compressed());
        assert_eq!(compressed_even_y.tag(), Tag::CompressedEvenY);
        assert_eq!(compressed_even_y.len(), 33);
        assert_eq!(compressed_even_y.as_bytes(), &compressed_even_y_bytes[..]);
        assert_eq!(
            compressed_even_y.coordinates(),
            Coordinates::Compressed {
                x: &hex!("0100000000000000000000000000000000000000000000000000000000000000").into(),
                y_is_odd: false
            }
        );
        assert_eq!(
            compressed_even_y.x().unwrap(),
            &hex!("0100000000000000000000000000000000000000000000000000000000000000").into()
        );
        assert_eq!(compressed_even_y.y(), None);

        // Odd y-coordinate
        let compressed_odd_y_bytes =
            hex!("030200000000000000000000000000000000000000000000000000000000000000");
        let compressed_odd_y = EncodedPoint::from_bytes(&compressed_odd_y_bytes[..]).unwrap();
        assert!(compressed_odd_y.is_compressed());
        assert_eq!(compressed_odd_y.tag(), Tag::CompressedOddY);
        assert_eq!(compressed_odd_y.len(), 33);
        assert_eq!(compressed_odd_y.as_bytes(), &compressed_odd_y_bytes[..]);
        assert_eq!(
            compressed_odd_y.coordinates(),
            Coordinates::Compressed {
                x: &hex!("0200000000000000000000000000000000000000000000000000000000000000").into(),
                y_is_odd: true
            }
        );
        assert_eq!(
            compressed_odd_y.x().unwrap(),
            &hex!("0200000000000000000000000000000000000000000000000000000000000000").into()
        );
        assert_eq!(compressed_odd_y.y(), None);
    }

    #[test]
    fn decode_uncompressed_point() {
        let uncompressed_point = EncodedPoint::from_bytes(&UNCOMPRESSED_BYTES[..]).unwrap();
        assert!(!uncompressed_point.is_compressed());
        assert_eq!(uncompressed_point.tag(), Tag::Uncompressed);
        assert_eq!(uncompressed_point.len(), 65);
        assert_eq!(uncompressed_point.as_bytes(), &UNCOMPRESSED_BYTES[..]);
        assert_eq!(
            uncompressed_point.coordinates(),
            Coordinates::Uncompressed {
                x: &hex!("1111111111111111111111111111111111111111111111111111111111111111").into(),
                y: &hex!("2222222222222222222222222222222222222222222222222222222222222222").into()
            }
        );
        assert_eq!(
            uncompressed_point.x().unwrap(),
            &hex!("1111111111111111111111111111111111111111111111111111111111111111").into()
        );
        assert_eq!(
            uncompressed_point.y().unwrap(),
            &hex!("2222222222222222222222222222222222222222222222222222222222222222").into()
        );
    }

    #[test]
    fn decode_identity() {
        let identity_point = EncodedPoint::from_bytes(&IDENTITY_BYTES[..]).unwrap();
        assert!(identity_point.is_identity());
        assert_eq!(identity_point.tag(), Tag::Identity);
        assert_eq!(identity_point.len(), 1);
        assert_eq!(identity_point.as_bytes(), &IDENTITY_BYTES[..]);
        assert_eq!(identity_point.coordinates(), Coordinates::Identity);
        assert_eq!(identity_point.x(), None);
        assert_eq!(identity_point.y(), None);
    }

    #[test]
    fn decode_invalid_tag() {
        // Byte arrays are `Copy`, so a plain binding yields a mutable copy
        // of each constant.
        let mut compressed_bytes = COMPRESSED_BYTES;
        let mut uncompressed_bytes = UNCOMPRESSED_BYTES;

        for bytes in &mut [&mut compressed_bytes[..], &mut uncompressed_bytes[..]] {
            for tag in 0..=0xFF {
                // Tags 2, 3 and 4 are skipped as valid tags. Tag 0 (identity)
                // is deliberately kept: it is rejected here because these
                // inputs are longer than a one-byte identity message.
                if tag == 2 || tag == 3 || tag == 4 {
                    continue;
                }

                (*bytes)[0] = tag;
                let decode_result = EncodedPoint::from_bytes(&*bytes);
                assert!(decode_result.is_err());
            }
        }
    }

    #[test]
    fn decode_truncated_point() {
        // Every strict prefix of a valid message must be rejected.
        for bytes in &[&COMPRESSED_BYTES[..], &UNCOMPRESSED_BYTES[..]] {
            for len in 0..bytes.len() {
                let decode_result = EncodedPoint::from_bytes(&bytes[..len]);
                assert!(decode_result.is_err());
            }
        }
    }

    #[test]
    fn from_untagged_point() {
        let untagged_bytes = hex!("11111111111111111111111111111111111111111111111111111111111111112222222222222222222222222222222222222222222222222222222222222222");
        let uncompressed_point =
            EncodedPoint::from_untagged_bytes(GenericArray::from_slice(&untagged_bytes[..]));
        assert_eq!(uncompressed_point.as_bytes(), &UNCOMPRESSED_BYTES[..]);
    }

    #[test]
    fn from_affine_coordinates() {
        let x = hex!("1111111111111111111111111111111111111111111111111111111111111111");
        let y = hex!("2222222222222222222222222222222222222222222222222222222222222222");

        let uncompressed_point = EncodedPoint::from_affine_coordinates(&x.into(), &y.into(), false);
        assert_eq!(uncompressed_point.as_bytes(), &UNCOMPRESSED_BYTES[..]);

        let compressed_point = EncodedPoint::from_affine_coordinates(&x.into(), &y.into(), true);
        assert_eq!(compressed_point.as_bytes(), &COMPRESSED_BYTES[..]);
    }

    #[test]
    fn compress() {
        let uncompressed_point = EncodedPoint::from_bytes(&UNCOMPRESSED_BYTES[..]).unwrap();
        let compressed_point = uncompressed_point.compress();
        assert_eq!(compressed_point.as_bytes(), &COMPRESSED_BYTES[..]);
    }

    #[test]
    fn conditional_select() {
        let a = EncodedPoint::from_bytes(&COMPRESSED_BYTES[..]).unwrap();
        let b = EncodedPoint::from_bytes(&UNCOMPRESSED_BYTES[..]).unwrap();

        // A choice of 0 selects the first argument, 1 the second.
        let a_selected = EncodedPoint::conditional_select(&a, &b, 0.into());
        assert_eq!(a, a_selected);

        let b_selected = EncodedPoint::conditional_select(&a, &b, 1.into());
        assert_eq!(b, b_selected);
    }

    #[test]
    fn identity() {
        let identity_point = EncodedPoint::identity();
        assert_eq!(identity_point.tag(), Tag::Identity);
        assert_eq!(identity_point.len(), 1);
        assert_eq!(identity_point.as_bytes(), &IDENTITY_BYTES[..]);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn to_bytes() {
        let uncompressed_point = EncodedPoint::from_bytes(&UNCOMPRESSED_BYTES[..]).unwrap();
        assert_eq!(&*uncompressed_point.to_bytes(), &UNCOMPRESSED_BYTES[..]);
    }
}