use alloc::format;
use alloc::string::String;
use core::default::Default;
use core::ops::{Add, Mul, Sub};
use k256::{
elliptic_curve::{
bigint::U256, generic_array::GenericArray,
hash2curve::{ExpandMsgXmd, GroupDigest},
ops::Reduce,
sec1::{EncodedPoint, FromEncodedPoint, ModulusSize, ToEncodedPoint},
Field,
FieldSize,
NonZeroScalar,
ProjectiveArithmetic,
Scalar,
},
Secp256k1,
};
use rand_core::{CryptoRng, RngCore};
use sha2::{digest::Digest, Sha256};
use subtle::CtOption;
use zeroize::{DefaultIsZeroes, Zeroize};
#[cfg(feature = "serde-support")]
use k256::elliptic_curve::group::ff::PrimeField;
#[cfg(feature = "serde-support")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};
#[cfg(feature = "serde-support")]
use crate::serde_bytes::{
deserialize_with_encoding, serialize_with_encoding, Encoding, TryFromBytes,
};
/// The elliptic curve every type in this module is built on.
pub(crate) type CurveType = Secp256k1;
/// Byte length used for serialized scalars (the curve's field size).
pub(crate) type ScalarSize = FieldSize<CurveType>;
/// Byte length of a SEC1 compressed point encoding on `CurveType`.
pub(crate) type CompressedPointSize = <ScalarSize as ModulusSize>::CompressedPointSize;
/// Backend scalar type (may be zero).
type BackendScalar = Scalar<CurveType>;
/// Backend scalar type that is guaranteed to be non-zero.
pub(crate) type BackendNonZeroScalar = NonZeroScalar<CurveType>;
/// A scalar of the curve wrapping the backend scalar; unlike `NonZeroCurveScalar`,
/// this type may hold zero (its derived `Default` is the backend default).
#[derive(Clone, Copy, Debug, PartialEq, Default)]
pub struct CurveScalar(BackendScalar);

impl CurveScalar {
    /// Returns the multiplicative inverse wrapped in a `CtOption`,
    /// which is "none" when no inverse exists (i.e. for zero).
    pub(crate) fn invert(&self) -> CtOption<Self> {
        let inverse = self.0.invert();
        inverse.map(CurveScalar)
    }

    /// Returns the multiplicative identity scalar.
    pub(crate) fn one() -> Self {
        CurveScalar(BackendScalar::one())
    }

    /// Returns the backend byte representation of this scalar.
    pub(crate) fn to_array(self) -> k256::FieldBytes {
        let CurveScalar(inner) = self;
        inner.to_bytes()
    }
}
/// Serializes the scalar's backend byte representation (`to_bytes`) using hex encoding.
#[cfg(feature = "serde-support")]
impl Serialize for CurveScalar {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let encoded = self.0.to_bytes();
        serialize_with_encoding(&encoded, serializer, Encoding::Hex)
    }
}
/// Deserializes a scalar from a hex-encoded representation.
///
/// All decoding is delegated to `deserialize_with_encoding` with `Encoding::Hex`;
/// presumably it converts the decoded bytes through the `TryFromBytes` impl for
/// `CurveScalar` — confirm in `serde_bytes`.
#[cfg(feature = "serde-support")]
impl<'de> Deserialize<'de> for CurveScalar {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserialize_with_encoding(deserializer, Encoding::Hex)
    }
}
/// Parses a scalar from exactly `ScalarSize` bytes of backend representation.
///
/// Fails if the slice has the wrong length, or if the backend rejects the bytes
/// as a scalar representation (`from_repr` returns none).
#[cfg(feature = "serde-support")]
impl TryFromBytes for CurveScalar {
    type Error = String;

    fn try_from_bytes(bytes: &[u8]) -> Result<Self, Self::Error> {
        let repr = match GenericArray::<u8, ScalarSize>::from_exact_iter(bytes.iter().copied()) {
            Some(repr) => repr,
            None => return Err("Invalid length of a curve scalar".into()),
        };
        Option::<BackendScalar>::from(BackendScalar::from_repr(repr))
            .map(CurveScalar)
            .ok_or_else(|| "Invalid curve scalar representation".into())
    }
}
/// Marker impl: `CurveScalar` is `Copy + Default` (see its derive), which lets the
/// `zeroize` crate provide `Zeroize` by overwriting the value with its `Default`.
impl DefaultIsZeroes for CurveScalar {}
#[derive(Clone, Zeroize)]
pub(crate) struct NonZeroCurveScalar(BackendNonZeroScalar);
impl NonZeroCurveScalar {
pub(crate) fn random(rng: &mut (impl CryptoRng + RngCore)) -> Self {
Self(BackendNonZeroScalar::random(rng))
}
pub(crate) fn from_backend_scalar(source: BackendNonZeroScalar) -> Self {
Self(source)
}
pub(crate) fn as_backend_scalar(&self) -> &BackendNonZeroScalar {
&self.0
}
pub(crate) fn invert(&self) -> Self {
let inv = self.0.invert().unwrap();
Self(BackendNonZeroScalar::new(inv).unwrap())
}
pub(crate) fn from_digest(d: impl Digest<OutputSize = ScalarSize>) -> Self {
Self(<BackendNonZeroScalar as Reduce<U256>>::from_be_bytes_reduced(d.finalize()))
}
}
impl From<NonZeroCurveScalar> for CurveScalar {
fn from(source: NonZeroCurveScalar) -> Self {
CurveScalar(*source.0)
}
}
impl From<&NonZeroCurveScalar> for CurveScalar {
fn from(source: &NonZeroCurveScalar) -> Self {
CurveScalar(*source.0)
}
}
/// Projective point type provided by the backend curve implementation.
type BackendPoint = <CurveType as ProjectiveArithmetic>::ProjectivePoint;

/// A point on the curve (the identity included), wrapping the backend projective point.
#[derive(Clone, Copy, PartialEq, Debug)]
pub(crate) struct CurvePoint(BackendPoint);
impl CurvePoint {
    /// Wraps a copy of a backend projective point.
    pub(crate) fn from_backend_point(point: &BackendPoint) -> Self {
        Self(*point)
    }

    /// Borrows the wrapped backend projective point.
    pub(crate) fn as_backend_point(&self) -> &BackendPoint {
        &self.0
    }

    /// Returns the curve's generator point.
    pub(crate) fn generator() -> Self {
        Self(BackendPoint::GENERATOR)
    }

    /// Returns the identity element (point at infinity).
    pub(crate) fn identity() -> Self {
        Self(BackendPoint::IDENTITY)
    }

    /// Parses a point from its SEC1 *compressed* encoding.
    ///
    /// Only the compressed form (a `0x02`/`0x03` tag followed by the x-coordinate) is
    /// accepted. Uncompressed and compact encodings are rejected so that each point has
    /// a single accepted byte form. The identity encoding (a single `0x00` byte) is
    /// rejected too: `to_compressed_array` cannot represent it, so accepting it would let
    /// untrusted input produce a point that panics when re-serialized.
    ///
    /// # Errors
    ///
    /// Returns a message if `bytes` is not a well-formed SEC1 encoding, is not in
    /// compressed form, or does not decode to a point on the curve.
    pub(crate) fn try_from_compressed_bytes(bytes: &[u8]) -> Result<Self, String> {
        let ep = EncodedPoint::<CurveType>::from_bytes(bytes).map_err(|err| format!("{}", err))?;
        if !ep.is_compressed() {
            return Err("Curve point must be in compressed form".into());
        }
        let cp_opt: Option<BackendPoint> = BackendPoint::from_encoded_point(&ep).into();
        cp_opt
            .map(Self)
            .ok_or_else(|| "Invalid curve point representation".into())
    }

    /// Returns the SEC1 compressed encoding of this point.
    ///
    /// # Panics
    ///
    /// Panics if `self` is the identity: SEC1 encodes the identity as a single `0x00`
    /// byte, which does not fit the fixed compressed-point length.
    pub(crate) fn to_compressed_array(self) -> GenericArray<u8, CompressedPointSize> {
        let encoded = self.0.to_affine().to_encoded_point(true);
        GenericArray::from_exact_iter(encoded.as_bytes().iter().copied())
            .expect("the identity point has no fixed-length compressed encoding")
    }

    /// Hashes `data` to a curve point via the backend's hash-to-curve
    /// (`ExpandMsgXmd<Sha256>`), using `dst` as the domain separation tag.
    /// Returns `None` if the backend reports an error.
    pub(crate) fn from_data(dst: &[u8], data: &[u8]) -> Option<Self> {
        CurveType::hash_from_bytes::<ExpandMsgXmd<Sha256>>(&[data], dst)
            .ok()
            .map(Self)
    }
}
impl Default for CurvePoint {
fn default() -> Self {
CurvePoint::identity()
}
}
/// Serializes the point as its hex-encoded compressed byte array.
#[cfg(feature = "serde-support")]
impl Serialize for CurvePoint {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let compressed = self.to_compressed_array();
        serialize_with_encoding(&compressed, serializer, Encoding::Hex)
    }
}
/// Deserializes a point from a hex-encoded representation.
///
/// All decoding is delegated to `deserialize_with_encoding` with `Encoding::Hex`;
/// presumably it converts the decoded bytes through the `TryFromBytes` impl for
/// `CurvePoint` — confirm in `serde_bytes`.
#[cfg(feature = "serde-support")]
impl<'de> Deserialize<'de> for CurvePoint {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserialize_with_encoding(deserializer, Encoding::Hex)
    }
}
/// Byte conversion used by the serde helpers; defers entirely to
/// `CurvePoint::try_from_compressed_bytes`.
#[cfg(feature = "serde-support")]
impl TryFromBytes for CurvePoint {
    type Error = String;

    fn try_from_bytes(bytes: &[u8]) -> Result<Self, Self::Error> {
        CurvePoint::try_from_compressed_bytes(bytes)
    }
}
/// Marker impl: `CurvePoint` is `Copy` and its `Default` is the identity point, so the
/// `zeroize` crate can provide `Zeroize` by overwriting the value with that default.
impl DefaultIsZeroes for CurvePoint {}
/// Sum of two possibly-zero scalars.
impl Add<&CurveScalar> for &CurveScalar {
    type Output = CurveScalar;
    fn add(self, other: &CurveScalar) -> CurveScalar {
        let (lhs, rhs) = (self.0, other.0);
        CurveScalar(lhs + rhs)
    }
}
impl Add<&NonZeroCurveScalar> for &CurveScalar {
type Output = CurveScalar;
fn add(self, other: &NonZeroCurveScalar) -> CurveScalar {
CurveScalar(self.0.add(&(*other.0)))
}
}
/// Sum of two non-zero scalars; the result may be zero, so it is a plain `CurveScalar`.
impl Add<&NonZeroCurveScalar> for &NonZeroCurveScalar {
    type Output = CurveScalar;
    fn add(self, other: &NonZeroCurveScalar) -> CurveScalar {
        CurveScalar(*self.0 + *other.0)
    }
}
/// Point addition (the curve group operation).
impl Add<&CurvePoint> for &CurvePoint {
    type Output = CurvePoint;
    fn add(self, other: &CurvePoint) -> CurvePoint {
        let sum = self.0 + other.0;
        CurvePoint(sum)
    }
}
/// Difference of two possibly-zero scalars.
impl Sub<&CurveScalar> for &CurveScalar {
    type Output = CurveScalar;
    fn sub(self, other: &CurveScalar) -> CurveScalar {
        let (minuend, subtrahend) = (self.0, other.0);
        CurveScalar(minuend - subtrahend)
    }
}
/// Difference of two non-zero scalars; the result may be zero, so it is a plain `CurveScalar`.
impl Sub<&NonZeroCurveScalar> for &NonZeroCurveScalar {
    type Output = CurveScalar;
    fn sub(self, other: &NonZeroCurveScalar) -> CurveScalar {
        CurveScalar(*self.0 - *other.0)
    }
}
/// Scalar multiplication of a point by a possibly-zero scalar.
impl Mul<&CurveScalar> for &CurvePoint {
    type Output = CurvePoint;
    fn mul(self, other: &CurveScalar) -> CurvePoint {
        let product = self.0 * other.0;
        CurvePoint(product)
    }
}
impl Mul<&NonZeroCurveScalar> for &CurvePoint {
type Output = CurvePoint;
fn mul(self, other: &NonZeroCurveScalar) -> CurvePoint {
CurvePoint(self.0.mul(&(*other.0)))
}
}
/// Product of two possibly-zero scalars.
impl Mul<&CurveScalar> for &CurveScalar {
    type Output = CurveScalar;
    fn mul(self, other: &CurveScalar) -> CurveScalar {
        CurveScalar(self.0 * other.0)
    }
}
impl Mul<&NonZeroCurveScalar> for &CurveScalar {
type Output = CurveScalar;
fn mul(self, other: &NonZeroCurveScalar) -> CurveScalar {
CurveScalar(self.0.mul(&(*other.0)))
}
}
/// Product of two non-zero scalars, kept in the non-zero wrapper via the backend's
/// non-zero-by-non-zero multiplication.
impl Mul<&NonZeroCurveScalar> for &NonZeroCurveScalar {
    type Output = NonZeroCurveScalar;
    fn mul(self, other: &NonZeroCurveScalar) -> NonZeroCurveScalar {
        let product = self.0 * other.0;
        NonZeroCurveScalar(product)
    }
}