use num_bigint::{BigInt, BigUint};
use num_traits::{One, Zero};
use super::bn::BigNumberHelper;
/// Short-Weierstrass curve domain parameters for `y² = x³ + a·x + b (mod p)`.
#[derive(Clone, Debug)]
pub struct EcGroup {
    /// Field prime.
    pub p: BigUint,
    /// Curve coefficient `a` (stored as a residue in `[0, p)`).
    pub a: BigUint,
    /// Curve coefficient `b`.
    pub b: BigUint,
    /// Base point X coordinate.
    pub g_x: BigUint,
    /// Base point Y coordinate.
    pub g_y: BigUint,
    /// Group order used to reduce scalars.
    pub n: BigUint,
    /// Cofactor (`1` for the curves built by `from_curve_name`).
    pub h: BigUint,
}
/// Named-curve identifiers accepted by `EcGroup::from_curve_name`.
// The variant mirrors the OpenSSL-style NID spelling, which is not UpperCamelCase.
#[allow(non_camel_case_types)]
#[derive(Debug)]
pub enum Nid {
    /// ANSI X9.62 prime256v1 (NIST P-256).
    X9_62_PRIME256V1,
}
impl EcGroup {
    /// Builds the domain parameters for the named curve `nid`.
    ///
    /// Only [`Nid::X9_62_PRIME256V1`] is supported. The hex constants below are the P-256
    /// prime `p`, coefficient `b`, base point `(g_x, g_y)` and order `n`; `a` is `p - 3`
    /// (i.e. `-3 mod p`) and the cofactor `h` is `1`.
    ///
    /// # Errors
    /// Currently never fails; the `Result` leaves room for curves that cannot be built.
    pub fn from_curve_name(nid: Nid) -> Result<Self, EcError> {
        // Parses a hard-coded hex constant. A failure here is a typo in this file rather than
        // a runtime condition, so panicking with a clear message is appropriate.
        fn hex(digits: &[u8]) -> BigUint {
            BigUint::parse_bytes(digits, 16).expect("hard-coded curve constant must be valid hex")
        }

        match nid {
            Nid::X9_62_PRIME256V1 => {
                let p = hex(b"FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF");
                // P-256 fixes a = -3; store it as its canonical residue p - 3.
                let a = p.clone() - 3u32;
                Ok(EcGroup {
                    a,
                    b: hex(b"5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B"),
                    g_x: hex(b"6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296"),
                    g_y: hex(b"4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5"),
                    n: hex(b"FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551"),
                    h: BigUint::one(),
                    p,
                })
            }
        }
    }

    /// Returns `true` if `(x, y)` satisfies `y² ≡ x³ + a·x + b (mod p)`.
    ///
    /// Both coordinates are reduced modulo `p` first, so values `>= p` are accepted as their
    /// residues. Callers that must reject non-canonical coordinates need a separate range check.
    pub fn check_point(&self, x: &BigUint, y: &BigUint) -> bool {
        let p = &self.p;
        let x = x % p;
        let y = y % p;
        trace!("Point checking:");
        trace!("x: {}", x);
        trace!("y: {}", y);
        let y_squared = (&y * &y) % p;
        // `modpow` already returns a value in [0, p); no extra reduction needed.
        let x_cubed = x.modpow(&BigUint::from(3u32), p);
        let ax = (&self.a * &x) % p;
        let right_side = (x_cubed + ax + &self.b) % p;
        trace!("y² mod p: {}", y_squared);
        trace!("right side: {}", right_side);
        y_squared == right_side
    }

    /// Returns the base point `(g_x, g_y)`, validated through [`EcPoint::new`].
    ///
    /// Takes `&'static self` because [`EcPoint`] stores a `'static` group reference.
    ///
    /// # Errors
    /// Returns [`EcError::InvalidPoint`] if the stored base point is not on the curve.
    pub fn generator(&'static self) -> Result<EcPoint, EcError> {
        EcPoint::new(self, self.g_x.clone(), self.g_y.clone())
    }

    /// Returns the group order `n`.
    pub fn order(&self) -> &BigUint {
        &self.n
    }

    /// Returns the field prime `p`.
    pub fn prime(&self) -> &BigUint {
        &self.p
    }
}
/// Errors produced by curve, point and key operations in this module.
#[derive(Debug)]
pub enum EcError {
    /// A key or curve failed a curve-membership check.
    InvalidCurve,
    /// A point failed validation or an arithmetic step had no inverse.
    InvalidPoint,
    /// The operation produced or received the point at infinity.
    PointAtInfinity,
    /// A private scalar was outside the accepted range.
    InvalidPrivateKey,
    /// A supplied public point does not match the private scalar.
    KeyPairMismatch,
}
/// Byte length of one big-endian P-256 field element (a single coordinate).
pub const P256_FIELD_SIZE: usize = 32;
/// Byte length of an uncompressed P-256 point: a `0x04` tag followed by X and Y.
pub const P256_POINT_BYTES: usize = 1 + 2 * P256_FIELD_SIZE;
/// Errors from decoding an [`EcPoint`] out of bytes.
#[derive(Debug)]
pub enum EcPointError {
    /// Input was not exactly `P256_POINT_BYTES` long.
    InvalidLength,
    /// Leading tag byte was not the uncompressed-form marker.
    InvalidFormat,
    /// Decoded coordinates did not form a valid curve point.
    InvalidPoint,
    /// The encoding denoted the point at infinity.
    PointAtInfinity,
}
/// A point on an [`EcGroup`] curve held as `(x : y : z)`.
///
/// `z == 0` marks the point at infinity. The constructors and arithmetic in this module
/// produce `z == 1`, so `(x, y)` are then already affine; `get_affine` recovers affine
/// coordinates as `(x / z, y / z) mod p` in general.
#[derive(Debug, Clone)]
pub struct EcPoint {
    /// X coordinate (affine when `z == 1`).
    pub x: BigUint,
    /// Y coordinate (affine when `z == 1`).
    pub y: BigUint,
    /// Scale factor; zero denotes the point at infinity.
    pub z: BigUint,
    /// Curve the point belongs to, borrowed for `'static`.
    pub group: &'static EcGroup,
}
/// Octet-string encoding selected by `EcPoint::to_bytes`.
#[derive(Clone, Copy, Debug)]
pub enum PointConversionForm {
    /// Parity tag (`0x02`/`0x03`) followed by X only.
    Compressed,
    /// `0x04` followed by X and Y.
    Uncompressed,
    /// Parity tag (`0x06`/`0x07`) followed by X and Y.
    Hybrid,
}
impl EcPoint {
    /// Returns the point at infinity (group identity) on `group`.
    ///
    /// It is encoded as `z = 0`; `x` and `y` are all-zero placeholders.
    fn point_at_infinity(group: &'static EcGroup) -> Self {
        Self {
            x: BigUint::zero(),
            y: BigUint::zero(),
            z: BigUint::zero(),
            group,
        }
    }

    /// Computes `(a - b) mod p` for non-negative operands without unsigned underflow.
    fn mod_sub(a: &BigUint, b: &BigUint, p: &BigUint) -> BigUint {
        ((a % p) + p - (b % p)) % p
    }

    /// Creates the affine point `(x, y)` on `group` after validating it.
    ///
    /// Both coordinates must be canonical field elements (`< p`) and must satisfy
    /// `y² ≡ x³ + a·x + b (mod p)`. The returned point has `z = 1`.
    ///
    /// # Errors
    /// Returns [`EcError::InvalidPoint`] if a coordinate is `>= p` or the equation fails.
    pub fn new(group: &'static EcGroup, x: BigUint, y: BigUint) -> Result<Self, EcError> {
        let p = &group.p;
        // Range-check *before* any reduction. Reducing first would make this check unreachable
        // and silently accept non-canonical aliases such as `x + p`.
        if x >= *p || y >= *p {
            error!("Point coordinates out of range!");
            return Err(EcError::InvalidPoint);
        }
        let y_squared = (&y * &y) % p;
        let x_cubed = (((&x * &x) % p) * &x) % p;
        // Use the group's `a` rather than hard-coding a = -3, consistent with `is_on_curve`.
        let ax = (&group.a * &x) % p;
        let right = (&x_cubed + &ax + &group.b) % p;
        trace!("Detailed point validation:");
        trace!("x = {}", x);
        trace!("y = {}", y);
        trace!("x³ mod p = {}", x_cubed);
        trace!("ax mod p = {}", ax);
        trace!("b = {}", group.b);
        trace!("y² mod p = {}", y_squared);
        trace!("right side (x³ + ax + b mod p) = {}", right);
        if y_squared != right {
            error!("Invalid point: equation does not hold!");
            error!(
                "Difference: {}",
                if y_squared > right {
                    &y_squared - &right
                } else {
                    &right - &y_squared
                }
            );
            return Err(EcError::InvalidPoint);
        }
        Ok(Self {
            x,
            y,
            z: BigUint::one(),
            group,
        })
    }

    /// Returns `true` if this point satisfies `y² ≡ x³ + a·x + b (mod p)`.
    ///
    /// The point at infinity counts as on the curve. The check uses the affine coordinates
    /// from [`EcPoint::get_affine`], so it is also correct when `z != 1`.
    pub fn is_on_curve(&self) -> bool {
        if self.z.is_zero() {
            return true;
        }
        let (x, y) = match self.get_affine() {
            Ok(coords) => coords,
            Err(err) => {
                error!("Err(err) {:?}", err);
                return false;
            }
        };
        let p = self.group.prime();
        let x_squared = (&x * &x) % p;
        let x_cubed = (&x_squared * &x) % p;
        let a = &self.group.a % p;
        let ax = (&a * &x) % p;
        let sum1 = (&x_cubed + &ax) % p;
        let b = &self.group.b % p;
        let right = (&sum1 + b) % p;
        let y_squared = (&y * &y) % p;
        trace!("Detailed check values:");
        trace!("x mod p: {}", x);
        trace!("y mod p: {}", y);
        trace!("a mod p: {}", a);
        trace!("x^2 mod p: {}", x_squared);
        trace!("x^3 mod p: {}", x_cubed);
        trace!("ax mod p: {}", ax);
        trace!("sum1 (x^3 + ax) mod p: {}", sum1);
        trace!("right side (x^3 + ax + b) mod p: {}", right);
        trace!("y^2 mod p: {}", y_squared);
        right == y_squared
    }

    /// Decodes an uncompressed point (`0x04 || X || Y`, 32-byte big-endian coordinates).
    ///
    /// # Errors
    /// [`EcPointError::InvalidLength`] if `data` is not `P256_POINT_BYTES` long,
    /// [`EcPointError::InvalidFormat`] if the tag is not `0x04`, and
    /// [`EcPointError::InvalidPoint`] if [`EcPoint::new`] rejects the coordinates.
    pub fn from_bytes(group: &'static EcGroup, data: &[u8]) -> Result<Self, EcPointError> {
        if data.len() != P256_POINT_BYTES {
            return Err(EcPointError::InvalidLength);
        }
        if data[0] != 0x04 {
            return Err(EcPointError::InvalidFormat);
        }
        let (x_bytes, y_bytes) = data[1..].split_at(P256_FIELD_SIZE);
        let x = BigUint::from_bytes_be(x_bytes);
        let y = BigUint::from_bytes_be(y_bytes);
        Self::new(group, x, y).map_err(|_| EcPointError::InvalidPoint)
    }

    /// Returns the affine coordinates `(x·z⁻¹ mod p, y·z⁻¹ mod p)`.
    ///
    /// # Errors
    /// [`EcError::PointAtInfinity`] if `z == 0`; [`EcError::InvalidPoint`] if `z` has no
    /// inverse modulo `p`.
    pub fn get_affine(&self) -> Result<(BigUint, BigUint), EcError> {
        if self.z.is_zero() {
            return Err(EcError::PointAtInfinity);
        }
        let p = &self.group.p;
        let z_inv = mod_inverse(&self.z, p).ok_or(EcError::InvalidPoint)?;
        let x_affine = (&self.x * &z_inv) % p;
        let y_affine = (&self.y * &z_inv) % p;
        Ok((x_affine, y_affine))
    }

    /// Encodes the point in the requested [`PointConversionForm`].
    ///
    /// Coordinates are written as 32-byte big-endian values. Compressed and hybrid forms
    /// encode the parity of `y` in the tag byte.
    ///
    /// # Errors
    /// Propagates [`EcPoint::get_affine`] errors (e.g. for the point at infinity) and
    /// [`EcError::InvalidPoint`] if a coordinate does not fit in 32 bytes.
    pub fn to_bytes(&self, form: PointConversionForm) -> Result<Vec<u8>, EcError> {
        let (x, y) = self.get_affine()?;
        let y_is_odd = y.bit(0);
        let (tag, include_y) = match form {
            PointConversionForm::Uncompressed => (0x04u8, true),
            PointConversionForm::Compressed => (if y_is_odd { 0x03 } else { 0x02 }, false),
            PointConversionForm::Hybrid => (if y_is_odd { 0x07 } else { 0x06 }, true),
        };
        let mut out = Vec::with_capacity(P256_POINT_BYTES);
        out.push(tag);
        out.extend(pad_to_32_bytes(&x)?);
        if include_y {
            out.extend(pad_to_32_bytes(&y)?);
        }
        Ok(out)
    }

    /// Writes the affine coordinates into the caller's `BigNumberHelper` outputs.
    ///
    /// # Errors
    /// Propagates [`EcPoint::get_affine`] errors.
    pub fn affine_coordinates_gfp(
        &self,
        x: &mut BigNumberHelper,
        y: &mut BigNumberHelper,
    ) -> Result<(), EcError> {
        let (affine_x, affine_y) = self.get_affine()?;
        *x = BigNumberHelper::from_bytes(&affine_x.to_bytes_be());
        *y = BigNumberHelper::from_bytes(&affine_y.to_bytes_be());
        Ok(())
    }

    /// Computes `k·self` with `k = scalar mod n`, using left-to-right double-and-add.
    ///
    /// NOTE(review): the loop branches on each scalar bit, so running time depends on the
    /// scalar. This is not a constant-time implementation.
    ///
    /// # Errors
    /// [`EcError::PointAtInfinity`] if `k ≡ 0 (mod n)` or the result is the point at
    /// infinity; otherwise propagates errors from [`EcPoint::double`] / [`EcPoint::add`].
    pub fn scalar_mul(&self, scalar: &BigUint) -> Result<Self, EcError> {
        let scalar = scalar % self.group.order();
        if scalar.is_zero() {
            return Err(EcError::PointAtInfinity);
        }
        // The most significant bit is always set. Seeding the accumulator with `self`
        // consumes that bit, so the loop starts one bit lower. Adding `self` again for the
        // top bit would double-count it and yield (k + 2^(bits-1))·P.
        let mut result = self.clone();
        for i in (0..scalar.bits() - 1).rev() {
            result = result.double()?;
            if scalar.bit(i) {
                result = result.add(self)?;
            }
        }
        if result.z.is_zero() {
            return Err(EcError::PointAtInfinity);
        }
        Ok(result)
    }

    /// Returns `2·self` using the affine tangent formula.
    ///
    /// λ = (3x² + a) / (2y), x₃ = λ² − 2x, y₃ = λ(x − x₃) − y. The point at infinity doubles
    /// to itself. A point with `y = 0` is its own negation and doubles to infinity.
    ///
    /// # Errors
    /// Propagates [`EcPoint::get_affine`] errors. Returns [`EcError::InvalidPoint`] if the
    /// result fails [`EcGroup::check_point`].
    pub fn double(&self) -> Result<Self, EcError> {
        if self.z.is_zero() {
            return Ok(self.clone());
        }
        let p = self.group.prime();
        let (x, y) = self.get_affine()?;
        // Vertical tangent: 2y has no inverse, and P = -P means 2P = O.
        if y.is_zero() {
            return Ok(Self::point_at_infinity(self.group));
        }
        let two = BigUint::from(2u32);
        let x_squared = (&x * &x) % p;
        // Use the group's `a` instead of hard-coding a = -3.
        let numerator = (BigUint::from(3u32) * &x_squared + &self.group.a) % p;
        let two_y = (&two * &y) % p;
        let two_y_inv = mod_inverse(&two_y, p).ok_or(EcError::InvalidPoint)?;
        let lambda = (&numerator * &two_y_inv) % p;
        let lambda_squared = (&lambda * &lambda) % p;
        let x3 = Self::mod_sub(&lambda_squared, &((&two * &x) % p), p);
        let y3 = Self::mod_sub(&((&lambda * Self::mod_sub(&x, &x3, p)) % p), &y, p);
        trace!("Double operation details:");
        trace!("Input (x,y): ({}, {})", x, y);
        trace!("λ calculation:");
        trace!(" numerator = 3x² + a = {}", numerator);
        trace!(" denominator = 2y = {}", two_y);
        trace!(" λ = {}", lambda);
        trace!("New point calculation:");
        trace!(" x₃ = λ² - 2x = {}", x3);
        trace!(" y₃ = λ(x - x₃) - y = {}", y3);
        if !self.group.check_point(&x3, &y3) {
            error!("Double result validation failed!");
            error!("Point ({}, {}) is not on curve!", x3, y3);
            return Err(EcError::InvalidPoint);
        }
        Ok(Self {
            x: x3,
            y: y3,
            z: BigUint::one(),
            group: self.group,
        })
    }

    /// Returns `self + other` using the affine chord formula.
    ///
    /// Infinity is the identity. Equal points are delegated to [`EcPoint::double`]. Points
    /// sharing `x` with different `y` sum to infinity. The result uses `self.group`.
    ///
    /// # Errors
    /// Propagates [`EcPoint::get_affine`] / [`EcPoint::double`] errors, and returns
    /// [`EcError::InvalidPoint`] if `x₂ − x₁` has no inverse.
    pub fn add(&self, other: &Self) -> Result<Self, EcError> {
        if self.z.is_zero() {
            return Ok(other.clone());
        }
        if other.z.is_zero() {
            return Ok(self.clone());
        }
        let p = self.group.prime();
        let (x1, y1) = self.get_affine()?;
        let (x2, y2) = other.get_affine()?;
        if x1 == x2 {
            // Same x: either the same point (tangent case) or P + (-P) = O.
            return if y1 == y2 {
                self.double()
            } else {
                Ok(Self::point_at_infinity(self.group))
            };
        }
        let y_diff = Self::mod_sub(&y2, &y1, p);
        let x_diff = Self::mod_sub(&x2, &x1, p);
        let x_diff_inv = mod_inverse(&x_diff, p).ok_or(EcError::InvalidPoint)?;
        let lambda = (&y_diff * &x_diff_inv) % p;
        let lambda_squared = (&lambda * &lambda) % p;
        let x3 = Self::mod_sub(&Self::mod_sub(&lambda_squared, &x1, p), &x2, p);
        let y3 = Self::mod_sub(&((&lambda * Self::mod_sub(&x1, &x3, p)) % p), &y1, p);
        trace!("Add operation details:");
        trace!(" Input points: ({}, {}), ({}, {})", x1, y1, x2, y2);
        trace!(" λ = {}", lambda);
        trace!(" Result: ({}, {})", x3, y3);
        Ok(Self {
            x: x3,
            y: y3,
            z: BigUint::one(),
            group: self.group,
        })
    }
}
impl PartialEq for EcPoint {
    /// Two points are equal when both are at infinity, or when both are finite and have
    /// identical affine coordinates. Any failure to compute affine coordinates compares unequal.
    fn eq(&self, other: &Self) -> bool {
        match (self.z.is_zero(), other.z.is_zero()) {
            (true, true) => true,
            (true, false) | (false, true) => false,
            (false, false) => matches!(
                (self.get_affine(), other.get_affine()),
                (Ok(lhs), Ok(rhs)) if lhs == rhs
            ),
        }
    }
}
fn mod_inverse(a: &BigUint, m: &BigUint) -> Option<BigUint> {
let a = a % m;
if a.is_zero() {
return None;
}
let mut t = BigInt::zero();
let mut newt = BigInt::one();
let mut r = BigInt::from(m.clone());
let mut newr = BigInt::from(a.clone());
while !newr.is_zero() {
let quotient = &r / &newr;
(t, newt) = (newt.clone(), t - "ient * &newt);
(r, newr) = (newr.clone(), r - quotient * newr);
}
if r > BigInt::one() {
return None;
}
while t < BigInt::zero() {
t = t + BigInt::from(m.clone());
}
let result = t.to_biguint().unwrap() % m;
if (&a * &result) % m != BigUint::one() {
return None;
}
Some(result)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for small `BigUint` literals.
    fn big(value: u32) -> BigUint {
        BigUint::from(value)
    }

    /// Independently checks `a * inv ≡ 1 (mod m)` using signed arithmetic.
    fn assert_inverse_holds(a: &BigUint, inv: &BigUint, m: &BigUint) {
        let product =
            (BigInt::from(a.clone()) * BigInt::from(inv.clone())) % BigInt::from(m.clone());
        assert_eq!(product, BigInt::one());
    }

    #[test]
    fn test_mod_inverse() {
        // 3 · 4 = 12 ≡ 1 (mod 11).
        let inv = mod_inverse(&big(3), &big(11)).unwrap();
        assert_eq!(inv, big(4));
        assert_inverse_holds(&big(3), &inv, &big(11));

        // Zero never has an inverse.
        assert_eq!(mod_inverse(&BigUint::zero(), &big(5)), None);

        // gcd(2, 4) = 2, so 2 is not invertible modulo 4.
        assert_eq!(mod_inverse(&big(2), &big(4)), None);

        let inv = mod_inverse(&big(17), &big(23)).unwrap();
        assert_inverse_holds(&big(17), &inv, &big(23));
    }

    #[test]
    fn test_edge_cases() {
        // 1 is its own inverse.
        assert_eq!(mod_inverse(&big(1), &big(7)).unwrap(), big(1));
        // m - 1 ≡ -1 is self-inverse.
        assert_eq!(mod_inverse(&big(6), &big(7)).unwrap(), big(6));
        // No arithmetic modulo zero.
        assert_eq!(mod_inverse(&big(5), &BigUint::zero()), None);
    }
}
/// Type-level tag distinguishing private from public-only EC keys.
pub trait KeyType {}

/// Tag for keys that carry a private scalar.
#[derive(Clone, Debug)]
pub struct Private;
impl KeyType for Private {}

/// Tag for public-only keys.
#[derive(Clone, Debug)]
pub struct Public;
impl KeyType for Public {}
use std::marker::PhantomData;
/// An EC key on a `'static` group, tagged at the type level with a [`KeyType`].
///
/// `private_key` is `Some` when a private scalar is held and `None` for public-only keys.
#[derive(Clone, Debug)]
pub struct EcKey<T: KeyType> {
    group: &'static EcGroup,
    public_key: EcPoint,
    /// Secret scalar, if this key holds one.
    pub private_key: Option<BigUint>,
    _phantom: PhantomData<T>,
}
impl<T: KeyType> EcKey<T> {
    /// Returns the curve this key is defined on.
    pub fn group(&self) -> &'static EcGroup {
        let Self { group, .. } = self;
        *group
    }

    /// Returns the public point of this key.
    pub fn public_key(&self) -> &EcPoint {
        let Self { public_key, .. } = self;
        public_key
    }
}
impl EcKey<Private> {
    /// Generates a fresh key pair on `group`.
    ///
    /// The private scalar is drawn from `[1, n)` via `generate_random_below`. A zero draw is
    /// rejected and redrawn, because `0·G` is the point at infinity. The public key is `d·G`.
    ///
    /// # Errors
    /// Propagates errors from random sampling, [`EcGroup::generator`] and
    /// [`EcPoint::scalar_mul`].
    pub fn generate(group: &'static EcGroup) -> Result<Self, EcError> {
        let private_key = loop {
            let candidate = generate_random_below(group.order())?;
            if !candidate.is_zero() {
                break candidate;
            }
        };
        let public_key = group.generator()?.scalar_mul(&private_key)?;
        Ok(Self {
            group,
            public_key,
            private_key: Some(private_key),
            _phantom: PhantomData,
        })
    }

    /// Returns the private scalar, if present.
    pub fn private_key(&self) -> Option<&BigUint> {
        self.private_key.as_ref()
    }

    /// Returns a public-only copy of this key (the private scalar is dropped).
    pub fn to_public_key(&self) -> EcKey<Public> {
        EcKey {
            group: self.group,
            public_key: self.public_key.clone(),
            private_key: None,
            _phantom: PhantomData,
        }
    }

    /// Builds a key pair from a private scalar and its claimed public point.
    ///
    /// # Errors
    /// [`EcError::InvalidPrivateKey`] if the scalar is zero or `>= n`;
    /// [`EcError::KeyPairMismatch`] if `d·G != public_point`. Other errors are propagated
    /// from [`EcGroup::generator`] and [`EcPoint::scalar_mul`].
    pub fn from_private_components(
        group: &'static EcGroup,
        private_key: &BigNumberHelper,
        public_point: &EcPoint,
    ) -> Result<Self, EcError> {
        let private_biguint = BigUint::from_bytes_be(&private_key.to_bytes());
        // Valid private scalars lie in [1, n-1]. Zero would otherwise surface as a confusing
        // PointAtInfinity error from scalar_mul.
        if private_biguint.is_zero() || private_biguint >= *group.order() {
            return Err(EcError::InvalidPrivateKey);
        }
        let computed_public = group.generator()?.scalar_mul(&private_biguint)?;
        if computed_public != *public_point {
            return Err(EcError::KeyPairMismatch);
        }
        Ok(Self {
            group,
            public_key: public_point.clone(),
            private_key: Some(private_biguint),
            _phantom: PhantomData,
        })
    }
}
impl EcKey<Public> {
    /// Wraps `public_key` as a public-only key on `group`.
    ///
    /// # Errors
    /// [`EcError::InvalidCurve`] if the point fails [`EcPoint::is_on_curve`].
    pub fn from_public_key(group: &'static EcGroup, public_key: EcPoint) -> Result<Self, EcError> {
        if public_key.is_on_curve() {
            Ok(Self {
                group,
                public_key,
                private_key: None,
                _phantom: PhantomData,
            })
        } else {
            Err(EcError::InvalidCurve)
        }
    }

    /// Builds a public-only key from big-endian affine coordinates.
    ///
    /// # Errors
    /// Propagates [`EcPoint::new`] validation errors and [`EcKey::from_public_key`] errors.
    pub fn from_public_key_affine_coordinates(
        group: &'static EcGroup,
        x: &BigNumberHelper,
        y: &BigNumberHelper,
    ) -> Result<Self, EcError> {
        let [x_coord, y_coord] = [x, y].map(|helper| BigUint::from_bytes_be(&helper.to_bytes()));
        EcPoint::new(group, x_coord, y_coord).and_then(|point| Self::from_public_key(group, point))
    }
}
/// Big-endian encodes `num` left-padded with zeros to exactly 32 bytes.
///
/// # Errors
/// [`EcError::InvalidPoint`] if `num` needs more than 32 bytes.
pub fn pad_to_32_bytes(num: &BigUint) -> Result<Vec<u8>, EcError> {
    const WIDTH: usize = 32;
    let digits = num.to_bytes_be();
    let padding = WIDTH
        .checked_sub(digits.len())
        .ok_or(EcError::InvalidPoint)?;
    let mut out = Vec::with_capacity(WIDTH);
    out.resize(padding, 0u8);
    out.extend_from_slice(&digits);
    Ok(out)
}
fn generate_random_below(max: &BigUint) -> Result<BigUint, EcError> {
use rand::{thread_rng, RngCore};
let mut rng = thread_rng();
let byte_length = (max.bits() + 7) / 8;
let mut bytes = vec![0u8; byte_length as usize];
loop {
rng.fill_bytes(&mut bytes);
let value = BigUint::from_bytes_be(&bytes);
if value < *max {
return Ok(value);
}
}
}
/// Runtime tag recording whether a [`PKey`] blob contains a private scalar.
#[derive(Debug)]
pub enum PKeyType {
    /// Public coordinates only.
    Public,
    /// Private scalar followed by public coordinates.
    Private,
}
/// Serialized key material tagged at the type level with `T`.
#[derive(Debug)]
pub struct PKey<T> {
    /// Raw key bytes.
    pub key_data: Vec<u8>,
    /// Whether `key_data` includes a private scalar.
    key_type: PKeyType,
    _marker: PhantomData<T>,
}
/// Marker trait required by the `PKey` constructors; implemented for both key tags.
pub trait KeyTypeMarker {}

impl KeyTypeMarker for Private {}
impl KeyTypeMarker for Public {}
impl<T: KeyTypeMarker> PKey<T> {
    /// Wraps already-serialized `key_data` with its runtime `key_type` tag.
    pub fn new(key_data: Vec<u8>, key_type: PKeyType) -> Self {
        Self {
            key_data,
            key_type,
            _marker: PhantomData,
        }
    }

    /// Borrows the raw key bytes.
    pub fn as_bytes(&self) -> &[u8] {
        self.key_data.as_slice()
    }
}
impl<T: KeyType + KeyTypeMarker> PKey<T> {
pub fn from_ec_key(ec_key: EcKey<T>) -> Result<Self, ece::Error> {
let mut key_data = Vec::new();
if let Some(priv_key) = ec_key.private_key.as_ref() {
if priv_key >= ec_key.group().order() {
return Err(ece::Error::CryptoError);
}
let priv_bytes = pad_to_32_bytes(priv_key).map_err(|_| ece::Error::CryptoError)?;
key_data.extend_from_slice(&priv_bytes);
}
let public_key = ec_key.public_key();
let (x, y) = public_key.get_affine().map_err(|e| {
error!("Affine transformation error: {:?}", e);
ece::Error::CryptoError
})?;
let group = ec_key.group();
if !group.check_point(&x, &y) {
error!("Public key is not on curve!");
return Err(ece::Error::CryptoError);
}
let x_bytes = pad_to_32_bytes(&x).map_err(|_| ece::Error::CryptoError)?;
let y_bytes = pad_to_32_bytes(&y).map_err(|_| ece::Error::CryptoError)?;
key_data.extend_from_slice(&x_bytes);
key_data.extend_from_slice(&y_bytes);
Ok(Self {
key_data,
key_type: if ec_key.private_key.is_some() {
PKeyType::Private
} else {
PKeyType::Public
},
_marker: PhantomData,
})
}
}