use std::sync::Arc;
use num_bigint::BigInt;
use num_bigint::RandBigInt;
use num_traits::One;
use num_traits::Zero;
use serde::Deserialize;
use serde::Serialize;
use crate::symbolic::finite_field::PrimeField;
use crate::symbolic::finite_field::PrimeFieldElement;
/// A curve `y^2 = x^3 + a*x + b` whose coefficients and point coordinates are
/// elements of a prime field.
///
/// The equation above is exactly what [`EllipticCurve::is_on_curve`] checks.
#[derive(Clone, Serialize, Deserialize)]
pub struct EllipticCurve {
/// Coefficient of `x` in the curve equation.
pub a: PrimeFieldElement,
/// Constant term of the curve equation.
pub b: PrimeFieldElement,
/// Shared handle to the field the coefficients and coordinates belong to.
pub field: Arc<PrimeField>,
}
/// A point on an [`EllipticCurve`].
///
/// `Infinity` acts as the identity for [`EllipticCurve::add`]: adding it to any
/// point returns that point unchanged.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum CurvePoint {
/// The point at infinity (group identity).
Infinity,
/// A point given by its affine coordinates.
Affine {
x: PrimeFieldElement,
y: PrimeFieldElement,
},
}
/// An ECDH key pair: a secret scalar and the matching public point
/// `private_key * generator`, as produced by [`generate_keypair`].
///
/// `Debug` is implemented by hand so that the secret scalar is never written to
/// logs or panic messages. Only the public key is shown.
#[derive(Clone, Serialize, Deserialize)]
pub struct EcdhKeyPair {
    /// Secret scalar. Keep it confidential.
    pub private_key: BigInt,
    /// Public point derived from `private_key`.
    pub public_key: CurvePoint,
}

impl std::fmt::Debug for EcdhKeyPair {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EcdhKeyPair")
            .field("private_key", &"<redacted>")
            .field("public_key", &self.public_key)
            .finish()
    }
}
/// An ECDSA signature `(r, s)`, as produced by [`ecdsa_sign`] and checked by
/// [`ecdsa_verify`].
///
/// [`ecdsa_verify`] rejects any signature unless both `0 < r < order` and
/// `0 < s < order`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct EcdsaSignature {
/// x-coordinate of the nonce point `k * G`, reduced modulo the group order.
pub r: BigInt,
/// `k^-1 * (hash + r * private_key)` modulo the group order.
pub s: BigInt,
}
impl CurvePoint {
    /// Reports whether this is the point at infinity.
    #[must_use]
    pub const fn is_infinity(&self) -> bool {
        match self {
            | Self::Infinity => true,
            | Self::Affine { .. } => false,
        }
    }

    /// Borrows the affine x-coordinate, or `None` for the point at infinity.
    #[must_use]
    pub const fn x(&self) -> Option<&PrimeFieldElement> {
        if let Self::Affine { x, .. } = self {
            Some(x)
        } else {
            None
        }
    }

    /// Borrows the affine y-coordinate, or `None` for the point at infinity.
    #[must_use]
    pub const fn y(&self) -> Option<&PrimeFieldElement> {
        if let Self::Affine { y, .. } = self {
            Some(y)
        } else {
            None
        }
    }
}
impl EllipticCurve {
    /// Builds the curve `y^2 = x^3 + a*x + b` over the prime field of the given
    /// `modulus`.
    ///
    /// `a` and `b` are wrapped as elements of that newly created field.
    #[must_use]
    pub fn new(
        a: BigInt,
        b: BigInt,
        modulus: BigInt,
    ) -> Self {
        let field = PrimeField::new(modulus);
        Self {
            a: PrimeFieldElement::new(a, field.clone()),
            b: PrimeFieldElement::new(b, field.clone()),
            field,
        }
    }

    /// Checks whether `point` satisfies `y^2 = x^3 + a*x + b`.
    ///
    /// The arithmetic is done with the field element operators. The point at
    /// infinity is always considered to be on the curve.
    #[must_use]
    pub fn is_on_curve(
        &self,
        point: &CurvePoint,
    ) -> bool {
        match point {
            | CurvePoint::Infinity => true,
            | CurvePoint::Affine { x, y } => {
                let lhs = y.clone() * y.clone();
                let rhs =
                    x.clone() * x.clone() * x.clone() + self.a.clone() * x.clone() + self.b.clone();
                lhs == rhs
            },
        }
    }

    /// Returns the additive inverse of `point`: `(x, -y)`. Infinity maps to itself.
    #[must_use]
    pub fn negate(
        &self,
        point: &CurvePoint,
    ) -> CurvePoint {
        match point {
            | CurvePoint::Infinity => CurvePoint::Infinity,
            | CurvePoint::Affine { x, y } => {
                CurvePoint::Affine {
                    x: x.clone(),
                    y: -y.clone(),
                }
            },
        }
    }

    /// Computes `2 * point` using the tangent-line slope `(3x^2 + a) / (2y)`.
    ///
    /// If `y` is zero the tangent is vertical, so the result is the point at
    /// infinity.
    #[must_use]
    pub fn double(
        &self,
        point: &CurvePoint,
    ) -> CurvePoint {
        match point {
            | CurvePoint::Infinity => CurvePoint::Infinity,
            | CurvePoint::Affine { x, y } => {
                if y.value.is_zero() {
                    return CurvePoint::Infinity;
                }
                let three = PrimeFieldElement::new(BigInt::from(3), self.field.clone());
                let two = PrimeFieldElement::new(BigInt::from(2), self.field.clone());
                let m = (three * x.clone() * x.clone() + self.a.clone()) / (two * y.clone());
                let x3 = m.clone() * m.clone() - x.clone() - x.clone();
                let y3 = m * (x.clone() - x3.clone()) - y.clone();
                CurvePoint::Affine { x: x3, y: y3 }
            },
        }
    }

    /// Adds two points using the chord-and-tangent rule.
    ///
    /// Assumes both points lie on this curve. Under that assumption, two distinct
    /// points sharing an x-coordinate are negations of each other, and their sum
    /// is the point at infinity.
    #[must_use]
    pub fn add(
        &self,
        p1: &CurvePoint,
        p2: &CurvePoint,
    ) -> CurvePoint {
        match (p1, p2) {
            | (CurvePoint::Infinity, p) | (p, CurvePoint::Infinity) => p.clone(),
            | (CurvePoint::Affine { x: x1, y: y1 }, CurvePoint::Affine { x: x2, y: y2 }) => {
                if x1 == x2 {
                    // Same x: either the same point (tangent) or P + (-P).
                    return if y1 == y2 {
                        self.double(p1)
                    } else {
                        CurvePoint::Infinity
                    };
                }
                let m = (y2.clone() - y1.clone()) / (x2.clone() - x1.clone());
                let x3 = m.clone() * m.clone() - x1.clone() - x2.clone();
                let y3 = m * (x1.clone() - x3.clone()) - y1.clone();
                CurvePoint::Affine { x: x3, y: y3 }
            },
        }
    }

    /// Computes `k * p` by double-and-add, scanning `k` from its least
    /// significant bit.
    ///
    /// A negative `k` is handled via `(-k) * p == k * (-p)`. A zero `k` yields the
    /// point at infinity.
    ///
    /// Not constant-time: whether an addition happens depends on each bit of `k`.
    #[must_use]
    pub fn scalar_mult(
        &self,
        k: &BigInt,
        p: &CurvePoint,
    ) -> CurvePoint {
        // Fold the sign of `k` into the point so the loop only sees a non-negative
        // scalar. Previously, any negative `k` silently produced infinity.
        let (mut remaining, mut addend) = if *k < BigInt::zero() {
            (-k, self.negate(p))
        } else {
            (k.clone(), p.clone())
        };
        let mut result = CurvePoint::Infinity;
        while remaining > BigInt::zero() {
            if &remaining % 2 != BigInt::zero() {
                result = self.add(&result, &addend);
            }
            remaining >>= 1;
            // Skip the doubling after the top bit; its result would never be used.
            if remaining > BigInt::zero() {
                addend = self.double(&addend);
            }
        }
        result
    }
}
/// Generates a random ECDH key pair on `curve`.
///
/// The private key is sampled uniformly from `[1, curve.field.modulus)` with the
/// thread-local RNG. The public key is `private_key * generator`.
///
/// NOTE(review): ECDH private keys are normally drawn from `[1, n)`, where `n` is
/// the order of `generator`. This function has no access to `n` and bounds the
/// key by the field modulus instead. Confirm that this is acceptable for the
/// curves in use.
///
/// # Panics
///
/// `gen_bigint_range` requires a non-empty range, so this is expected to panic
/// when `curve.field.modulus <= 1`.
#[must_use]
pub fn generate_keypair(
    curve: &EllipticCurve,
    generator: &CurvePoint,
) -> EcdhKeyPair {
    let lower_bound = BigInt::one();
    let private_key = rand::thread_rng().gen_bigint_range(&lower_bound, &curve.field.modulus);
    EcdhKeyPair {
        public_key: curve.scalar_mult(&private_key, generator),
        private_key,
    }
}
#[must_use]
pub fn generate_shared_secret(
curve: &EllipticCurve,
own_private_key: &BigInt,
other_public_key: &CurvePoint,
) -> CurvePoint {
curve.scalar_mult(own_private_key, other_public_key)
}
/// Compresses an affine point into its x-coordinate plus a flag saying whether y
/// is odd. The result can be expanded again with [`point_decompress`].
///
/// Returns `None` for the point at infinity, which has no affine coordinates.
/// The parity is taken from `y.value % 2`. This assumes the value is stored in
/// canonical non-negative form (TODO confirm for `PrimeFieldElement`).
#[must_use]
pub fn point_compress(point: &CurvePoint) -> Option<(BigInt, bool)> {
    if let CurvePoint::Affine { x, y } = point {
        let y_parity_is_odd = &y.value % 2 != BigInt::zero();
        Some((x.value.clone(), y_parity_is_odd))
    } else {
        None
    }
}
/// Recovers an affine point from its x-coordinate and the parity of its
/// y-coordinate, as produced by [`point_compress`].
///
/// Solves `y^2 = x^3 + a*x + b` in the curve's field with Tonelli–Shanks, which
/// works for any odd prime modulus. A `p ≡ 3 (mod 4)` fast path is taken when it
/// applies. Of the two roots `y` and `p - y`, the one whose parity matches
/// `is_y_odd` is returned.
///
/// Returns `None` when:
/// - the right-hand side has no square root, so no point with this x exists;
/// - the only root is `y = 0` but an odd y was requested;
/// - the modulus is smaller than 2.
#[must_use]
pub fn point_decompress(
    x: BigInt,
    is_y_odd: bool,
    curve: &EllipticCurve,
) -> Option<CurvePoint> {
    /// Square root of `n` modulo the prime `p`, or `None` if `n` is a non-residue.
    fn sqrt_mod_prime(
        n: &BigInt,
        p: &BigInt,
    ) -> Option<BigInt> {
        let zero = BigInt::zero();
        let one = BigInt::one();
        let two = BigInt::from(2);
        if *p < two {
            return None;
        }
        let n = ((n % p) + p) % p;
        if n == zero || *p == two {
            return Some(n);
        }
        let p_minus_one = p - &one;
        let half = &p_minus_one / &two;
        // Euler's criterion: n is a quadratic residue iff n^((p-1)/2) == 1.
        if n.modpow(&half, p) != one {
            return None;
        }
        // Fast path: for p ≡ 3 (mod 4), n^((p+1)/4) is a root.
        if p % BigInt::from(4) == BigInt::from(3) {
            return Some(n.modpow(&((p + &one) / BigInt::from(4)), p));
        }
        // Tonelli–Shanks. Write p - 1 = q * 2^s with q odd.
        let mut q = p_minus_one.clone();
        let mut s = 0usize;
        while &q % &two == zero {
            q /= &two;
            s += 1;
        }
        // Find any quadratic non-residue z. The bound guards against a composite p.
        let mut z = two.clone();
        while z.modpow(&half, p) != p_minus_one {
            z += &one;
            if z >= *p {
                return None;
            }
        }
        let mut m = s;
        let mut c = z.modpow(&q, p);
        let mut t = n.modpow(&q, p);
        let mut r = n.modpow(&((&q + &one) / &two), p);
        while t != one {
            // Least i in (0, m) with t^(2^i) == 1.
            let mut i = 0usize;
            let mut probe = t.clone();
            while probe != one {
                probe = (&probe * &probe) % p;
                i += 1;
                if i == m {
                    return None;
                }
            }
            // b = c^(2^(m - i - 1))
            let mut b = c.clone();
            for _ in 0..(m - i - 1) {
                b = (&b * &b) % p;
            }
            m = i;
            c = (&b * &b) % p;
            t = (&t * &c) % p;
            r = (&r * &b) % p;
        }
        Some(r)
    }

    let x_elem = PrimeFieldElement::new(x, curve.field.clone());
    let y_squared = x_elem.clone() * x_elem.clone() * x_elem.clone()
        + curve.a.clone() * x_elem.clone()
        + curve.b.clone();
    let modulus = &curve.field.modulus;
    let root = sqrt_mod_prime(&y_squared.value, modulus)?;
    // Defensive re-check. It only matters if the modulus is not actually prime,
    // in which case the algorithm's assumptions do not hold.
    let target = ((&y_squared.value % modulus) + modulus) % modulus;
    if (&root * &root) % modulus != target {
        return None;
    }
    let root_is_odd = &root % 2 != BigInt::zero();
    let y_final = if root_is_odd == is_y_odd {
        root
    } else if root.is_zero() {
        // y = 0 is its own negation and is even, so no odd y exists for this x.
        return None;
    } else {
        modulus - &root
    };
    Some(CurvePoint::Affine {
        x: x_elem,
        y: PrimeFieldElement::new(y_final, curve.field.clone()),
    })
}
/// Produces an ECDSA signature over `message_hash` with `private_key`.
///
/// A fresh random nonce `k` in `[1, order)` is drawn with the thread-local RNG.
/// The signature is then `r = x(k * generator) mod order` and
/// `s = k^-1 * (message_hash + r * private_key) mod order`, with `s` normalised
/// into `[0, order)`.
///
/// A nonce that yields the point at infinity, `r == 0` or `s == 0` is discarded
/// and a new one is drawn. Up to `MAX_NONCE_ATTEMPTS` nonces are tried.
///
/// Returns `None` when:
/// - `order <= 1`;
/// - a nonce is not invertible modulo `order`, i.e. `order` is not prime;
/// - every attempted nonce was degenerate.
///
/// NOTE(review): `message_hash` is used as given. It is not truncated to the bit
/// length of `order`, so callers must pass an already-prepared integer. Not
/// constant-time.
#[must_use]
pub fn ecdsa_sign(
    message_hash: &BigInt,
    private_key: &BigInt,
    curve: &EllipticCurve,
    generator: &CurvePoint,
    order: &BigInt,
) -> Option<EcdsaSignature> {
    // Degenerate nonces have probability about 1/order each, so this bound is
    // only reached on broken parameters (e.g. `generator` is infinity).
    const MAX_NONCE_ATTEMPTS: usize = 64;
    let one = BigInt::one();
    if *order <= one {
        // Empty nonce range; `gen_bigint_range` would panic.
        return None;
    }
    let mut rng = rand::thread_rng();
    for _ in 0..MAX_NONCE_ATTEMPTS {
        let k = rng.gen_bigint_range(&one, order);
        let r = match curve.scalar_mult(&k, generator) {
            | CurvePoint::Affine { x, .. } => x.value % order,
            | CurvePoint::Infinity => continue,
        };
        if r.is_zero() {
            continue;
        }
        let k_inv = mod_inverse(&k, order)?;
        // Normalise so a negative hash or key cannot yield a negative `s`.
        let s = ((&k_inv * (message_hash + &r * private_key)) % order + order) % order;
        if s.is_zero() {
            continue;
        }
        return Some(EcdsaSignature { r, s });
    }
    None
}
/// Verifies an ECDSA `signature` over `message_hash` against `public_key`.
///
/// Accepts iff:
/// - `0 < r, s < order`;
/// - `public_key` is a finite point on `curve`;
/// - `s` is invertible modulo `order`;
/// - `x(u1 * generator + u2 * public_key) mod order == r`, where
///   `w = s^-1`, `u1 = hash * w` and `u2 = r * w` (all mod `order`).
#[must_use]
pub fn ecdsa_verify(
    message_hash: &BigInt,
    signature: &EcdsaSignature,
    public_key: &CurvePoint,
    curve: &EllipticCurve,
    generator: &CurvePoint,
    order: &BigInt,
) -> bool {
    let zero = BigInt::zero();
    let in_scalar_range = |v: &BigInt| *v > zero && v < order;
    if !in_scalar_range(&signature.r) || !in_scalar_range(&signature.s) {
        return false;
    }
    // The identity and off-curve points are never valid public keys. Verifying
    // against them would accept signatures nobody made with a real key.
    if public_key.is_infinity() || !curve.is_on_curve(public_key) {
        return false;
    }
    let w = match mod_inverse(&signature.s, order) {
        | Some(w) => w,
        | None => return false,
    };
    // Reduce into [0, order) so a negative hash never becomes a negative scalar.
    let reduce = |v: BigInt| ((v % order) + order) % order;
    let u1 = reduce(message_hash * &w);
    let u2 = reduce(&signature.r * &w);
    let r_prime = curve.add(
        &curve.scalar_mult(&u1, generator),
        &curve.scalar_mult(&u2, public_key),
    );
    match r_prime {
        | CurvePoint::Infinity => false,
        | CurvePoint::Affine { x, .. } => x.value % order == signature.r,
    }
}
/// Returns the multiplicative inverse of `a` modulo `m`, in `[0, m)`.
///
/// `a` may be negative or larger than `m`; it is reduced first. Returns `None`
/// when `m <= 0` or when `gcd(a, m) != 1`.
pub(crate) fn mod_inverse(
    a: &BigInt,
    m: &BigInt,
) -> Option<BigInt> {
    if *m <= BigInt::zero() {
        // No meaningful modulus; this also avoids a division by zero below.
        return None;
    }
    // Reduce into [0, m) so the gcd comes back non-negative. With a negative `a`,
    // the truncating Euclid can end at g = -1 and wrongly report "no inverse".
    let a = ((a % m) + m) % m;
    let (g, x, _) = extended_gcd(&a, m);
    if g != BigInt::one() {
        return None;
    }
    Some(((x % m) + m) % m)
}
/// Extended Euclidean algorithm.
///
/// Returns `(g, x, y)` with `a * x + b * y == g`, where `g` is the last non-zero
/// remainder of Euclid's algorithm under `BigInt`'s truncating `/` and `%`.
/// `g` can be negative when an input is negative. For `b == 0` the result is
/// `(a, 1, 0)`.
fn extended_gcd(
    a: &BigInt,
    b: &BigInt,
) -> (BigInt, BigInt, BigInt) {
    // Invariants: prev_r == a*prev_s + b*prev_t  and  r == a*s + b*t.
    let (mut prev_r, mut r) = (a.clone(), b.clone());
    let (mut prev_s, mut s) = (BigInt::one(), BigInt::zero());
    let (mut prev_t, mut t) = (BigInt::zero(), BigInt::one());
    while !r.is_zero() {
        let quotient = &prev_r / &r;
        let next_r = &prev_r - &quotient * &r;
        prev_r = r;
        r = next_r;
        let next_s = &prev_s - &quotient * &s;
        prev_s = s;
        s = next_s;
        let next_t = &prev_t - &quotient * &t;
        prev_t = t;
        t = next_t;
    }
    (prev_r, prev_s, prev_t)
}