use crypto_bigint::U256;
use subtle::{Choice, ConditionallySelectable};
use crate::error::Error;
use crate::sm2::field::{
fp_add, fp_from_bytes, fp_inv, fp_mul, fp_neg, fp_square, fp_sub, fp_to_bytes, Fp, CURVE_A,
CURVE_B, FIELD_MODULUS, GX, GY,
};
/// A curve point in affine coordinates `(x, y)`.
///
/// This type has no encoding for the point at infinity;
/// [`JacobianPoint::to_affine`] reports that case as an error instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AffinePoint {
    /// Affine x-coordinate as a field element.
    pub x: Fp,
    /// Affine y-coordinate as a field element.
    pub y: Fp,
}
/// A curve point in Jacobian projective coordinates `(X, Y, Z)`.
///
/// The affine point it represents is `(X / Z², Y / Z³)` (see `to_affine`);
/// `Z == 0` denotes the point at infinity.
///
/// `PartialEq` is deliberately not derived here: one affine point has many
/// Jacobian representations, so field-wise equality would be misleading.
#[derive(Debug, Clone, Copy)]
pub struct JacobianPoint {
    /// Projective X coordinate.
    pub(crate) x: Fp,
    /// Projective Y coordinate.
    pub(crate) y: Fp,
    /// Projective Z coordinate; zero marks the point at infinity.
    pub(crate) z: Fp,
}
impl ConditionallySelectable for JacobianPoint {
fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
JacobianPoint {
x: Fp::conditional_select(&a.x, &b.x, choice),
y: Fp::conditional_select(&a.y, &b.y, choice),
z: Fp::conditional_select(&a.z, &b.z, choice),
}
}
}
impl JacobianPoint {
pub const INFINITY: Self = JacobianPoint {
x: Fp::ONE,
y: Fp::ONE,
z: Fp::ZERO,
};
pub fn from_affine(p: &AffinePoint) -> Self {
JacobianPoint {
x: p.x,
y: p.y,
z: Fp::ONE,
}
}
/// Converts this point to affine coordinates `(X / Z², Y / Z³)`.
///
/// # Errors
///
/// Returns [`Error::PointAtInfinity`] when `Z == 0`, or when `fp_inv`
/// reports that `Z` has no inverse.
pub fn to_affine(&self) -> Result<AffinePoint, Error> {
    if self.is_infinity() {
        return Err(Error::PointAtInfinity);
    }

    // A single inversion, then Z⁻² and Z⁻³ by multiplication.
    let z_inv = fp_inv(&self.z).ok_or(Error::PointAtInfinity)?;
    let z_inv_sq = fp_square(&z_inv);
    let z_inv_cu = fp_mul(&z_inv_sq, &z_inv);

    let x = fp_mul(&self.x, &z_inv_sq);
    let y = fp_mul(&self.y, &z_inv_cu);
    Ok(AffinePoint { x, y })
}
/// Returns `true` when this is the point at infinity (`Z == 0`).
///
/// This is the `bool` form of `ct_is_infinity`.
pub fn is_infinity(&self) -> bool {
    self.ct_is_infinity().into()
}
/// Returns a `Choice` that is set when `Z` serialises to 32 zero bytes,
/// i.e. when this is the point at infinity.
fn ct_is_infinity(&self) -> Choice {
    use subtle::ConstantTimeEq;

    let z_bytes = fp_to_bytes(&self.z);
    z_bytes.ct_eq(&[0u8; 32])
}
pub fn double(&self) -> Self {
let (x1, y1, z1) = (&self.x, &self.y, &self.z);
let delta = fp_square(z1); let gamma = fp_square(y1); let beta = fp_mul(x1, &gamma);
let alpha = fp_mul(&fp_sub(x1, &delta), &fp_add(x1, &delta));
let alpha = fp_add(&fp_add(&alpha, &alpha), &alpha);
let x3 = fp_sub(&fp_square(&alpha), &double2(&double1(&beta)));
let z3 = fp_sub(&fp_sub(&fp_square(&fp_add(y1, z1)), &gamma), &delta);
let gamma2 = fp_square(&gamma);
let y3 = fp_sub(
&fp_mul(&alpha, &fp_sub(&double2(&beta), &x3)),
&double2(&double1(&gamma2)),
);
let d = JacobianPoint {
x: x3,
y: y3,
z: z3,
};
JacobianPoint::conditional_select(&d, self, self.ct_is_infinity())
}
/// Adds two points given in Jacobian coordinates.
///
/// The generic addition formula, `p.double()`, and the special-case answers
/// are all computed unconditionally; the result is then chosen with
/// conditional selects rather than branches. Later selects override earlier
/// ones, in this order:
///
/// 1. the generic sum;
/// 2. infinity, when the points share an x-coordinate but differ in y
///    (`q == −p`);
/// 3. `p.double()`, when the points are equal;
/// 4. `p`, when `q` is the point at infinity;
/// 5. `q`, when `p` is the point at infinity.
pub fn add(p: &JacobianPoint, q: &JacobianPoint) -> JacobianPoint {
    use subtle::ConstantTimeEq;

    // Cross-multiply by the other point's Z powers so the coordinates share
    // a common denominator: U = X·Z'², S = Y·Z'³.
    let p_z_sq = fp_square(&p.z);
    let q_z_sq = fp_square(&q.z);
    let u1 = fp_mul(&p.x, &q_z_sq);
    let u2 = fp_mul(&q.x, &p_z_sq);
    let q_z_cu = fp_mul(&q.z, &q_z_sq);
    let s1 = fp_mul(&p.y, &q_z_cu);
    let p_z_cu = fp_mul(&p.z, &p_z_sq);
    let s2 = fp_mul(&q.y, &p_z_cu);

    let h = fp_sub(&u2, &u1);
    let r = fp_sub(&s2, &s1);
    let same_x = fp_to_bytes(&h).ct_eq(&[0u8; 32]);
    let same_y = fp_to_bytes(&r).ct_eq(&[0u8; 32]);

    // Generic case.
    let h_sq = fp_square(&h);
    let h_cu = fp_mul(&h, &h_sq);
    let u1_h_sq = fp_mul(&u1, &h_sq);
    let r_sq = fp_square(&r);
    let x3 = fp_sub(&fp_sub(&r_sq, &h_cu), &double1(&u1_h_sq));
    let y3 = fp_sub(
        &fp_mul(&r, &fp_sub(&u1_h_sq, &x3)),
        &fp_mul(&s1, &h_cu),
    );
    let z3 = fp_mul(&fp_mul(&h, &p.z), &q.z);

    let doubled = p.double();

    let mut sum = JacobianPoint {
        x: x3,
        y: y3,
        z: z3,
    };
    // The order matters: each assignment overrides the ones before it.
    sum.conditional_assign(&JacobianPoint::INFINITY, same_x & !same_y);
    sum.conditional_assign(&doubled, same_x & same_y);
    sum.conditional_assign(p, q.ct_is_infinity());
    sum.conditional_assign(q, p.ct_is_infinity());
    sum
}
/// Computes `k · p` by scanning all 256 bits of `k` from most to least
/// significant.
///
/// Every iteration doubles the accumulator and also computes
/// `accumulator + p`; the scalar bit only decides, through
/// `conditional_select`, which of the two values is kept, so control flow
/// does not branch on the bits of `k`.
pub fn scalar_mul(k: &U256, p: &JacobianPoint) -> JacobianPoint {
    let k_bytes = k.to_be_bytes();
    let mut acc = JacobianPoint::INFINITY;
    for byte in &k_bytes {
        for shift in (0..8).rev() {
            let doubled = acc.double();
            let doubled_plus_p = JacobianPoint::add(&doubled, p);
            let bit = Choice::from((byte >> shift) & 1);
            acc = JacobianPoint::conditional_select(&doubled, &doubled_plus_p, bit);
        }
    }
    acc
}
/// Computes `k · G` for the curve generator `G`.
///
/// Thin public entry point that delegates to `scalar_mul_g_window`.
pub fn scalar_mul_g(k: &U256) -> JacobianPoint {
    scalar_mul_g_window(k)
}
}
/// Returns `2·a` in the field, computed as `a + a`.
#[inline]
fn double1(a: &Fp) -> Fp {
    fp_add(a, a)
}
/// Returns `4·a` in the field by doubling twice.
#[inline]
fn double2(a: &Fp) -> Fp {
    double1(&double1(a))
}
/// Adds a Jacobian point `p` and an affine point `q` (mixed addition).
///
/// Because `q` has an implicit `Z = 1`, only `q`'s coordinates need scaling
/// by powers of `p.z`. The result is chosen with conditional selects, later
/// ones overriding earlier ones:
///
/// 1. the generic sum;
/// 2. infinity, when the x-coordinates match but the y-coordinates differ;
/// 3. `p.double()`, when both coordinates match;
/// 4. `q` lifted to Jacobian form, when `p` is the point at infinity.
///
/// `AffinePoint` cannot encode infinity, so there is no "`q` is infinity"
/// case.
fn add_mixed(p: &JacobianPoint, q: &AffinePoint) -> JacobianPoint {
    use subtle::ConstantTimeEq;

    // Scale q into p's projective frame: U2 = q.x·Z1², S2 = q.y·Z1³.
    let z1_sq = fp_square(&p.z);
    let z1_cu = fp_mul(&p.z, &z1_sq);
    let u2 = fp_mul(&q.x, &z1_sq);
    let s2 = fp_mul(&q.y, &z1_cu);

    let h = fp_sub(&u2, &p.x);
    let r = fp_sub(&s2, &p.y);
    let same_x = fp_to_bytes(&h).ct_eq(&[0u8; 32]);
    let same_y = fp_to_bytes(&r).ct_eq(&[0u8; 32]);

    // Generic case.
    let h_sq = fp_square(&h);
    let h_cu = fp_mul(&h, &h_sq);
    let x1_h_sq = fp_mul(&p.x, &h_sq);
    let r_sq = fp_square(&r);
    let x3 = fp_sub(&fp_sub(&r_sq, &h_cu), &double1(&x1_h_sq));
    let y3 = fp_sub(
        &fp_mul(&r, &fp_sub(&x1_h_sq, &x3)),
        &fp_mul(&p.y, &h_cu),
    );
    let z3 = fp_mul(&h, &p.z);

    let doubled = p.double();
    let q_lifted = JacobianPoint::from_affine(q);

    let mut sum = JacobianPoint {
        x: x3,
        y: y3,
        z: z3,
    };
    // The order matters: each assignment overrides the ones before it.
    sum.conditional_assign(&JacobianPoint::INFINITY, same_x & !same_y);
    sum.conditional_assign(&doubled, same_x & same_y);
    sum.conditional_assign(&q_lifted, p.ct_is_infinity());
    sum
}
fn scalar_mul_g_window(k: &U256) -> JacobianPoint {
use subtle::ConstantTimeEq;
let g_aff = AffinePoint { x: GX, y: GY };
let g_jac = JacobianPoint::from_affine(&g_aff);
let mut table = [JacobianPoint::INFINITY; 16];
table[1] = g_jac;
for i in 2..=15usize {
table[i] = add_mixed(&table[i - 1], &g_aff);
}
let mut result = JacobianPoint::INFINITY;
for byte in &k.to_be_bytes() {
for _ in 0..4 {
result = result.double();
}
let window = byte >> 4;
let mut sel = JacobianPoint::INFINITY;
for j in 1u8..=15 {
let eq = window.ct_eq(&j);
sel = JacobianPoint::conditional_select(&sel, &table[j as usize], eq);
}
result = JacobianPoint::add(&result, &sel);
for _ in 0..4 {
result = result.double();
}
let window = byte & 0xF;
let mut sel = JacobianPoint::INFINITY;
for j in 1u8..=15 {
let eq = window.ct_eq(&j);
sel = JacobianPoint::conditional_select(&sel, &table[j as usize], eq);
}
result = JacobianPoint::add(&result, &sel);
}
result
}
impl AffinePoint {
pub fn generator() -> Self {
AffinePoint { x: GX, y: GY }
}
/// Checks the curve equation `y² = x³ + a·x + b` for this point, with
/// `a = CURVE_A` and `b = CURVE_B`.
pub fn is_on_curve(&self) -> bool {
    let x = &self.x;

    let x_cubed = fp_mul(&fp_square(x), x);
    let a_times_x = fp_mul(&CURVE_A, x);
    let rhs = fp_add(&fp_add(&x_cubed, &a_times_x), &CURVE_B);

    let lhs = fp_square(&self.y);
    lhs == rhs
}
pub fn from_bytes(bytes: &[u8; 65]) -> Result<Self, Error> {
if bytes[0] != 0x04 {
return Err(Error::InvalidPublicKey);
}
let x_bytes: [u8; 32] = bytes[1..33].try_into().unwrap();
let y_bytes: [u8; 32] = bytes[33..65].try_into().unwrap();
use crypto_bigint::subtle::ConstantTimeGreater;
let x_val = U256::from_be_slice(&x_bytes);
let y_val = U256::from_be_slice(&y_bytes);
if bool::from(x_val.ct_gt(&FIELD_MODULUS))
|| x_val == FIELD_MODULUS
|| bool::from(y_val.ct_gt(&FIELD_MODULUS))
|| y_val == FIELD_MODULUS
{
return Err(Error::InvalidPublicKey);
}
let p = AffinePoint {
x: fp_from_bytes(&x_bytes),
y: fp_from_bytes(&y_bytes),
};
if !p.is_on_curve() {
return Err(Error::InvalidPublicKey);
}
Ok(p)
}
/// Serialises the point in uncompressed form: `0x04 || X || Y`, with each
/// coordinate written as 32 bytes produced by `fp_to_bytes`.
pub fn to_bytes(&self) -> [u8; 65] {
    let mut encoded = [0u8; 65];
    encoded[0] = 0x04;

    let (x_out, y_out) = encoded[1..].split_at_mut(32);
    x_out.copy_from_slice(&fp_to_bytes(&self.x));
    y_out.copy_from_slice(&fp_to_bytes(&self.y));

    encoded
}
pub fn decompress(bytes: &[u8; 33]) -> Result<Self, Error> {
let prefix = bytes[0];
if prefix != 0x02 && prefix != 0x03 {
return Err(Error::InvalidPublicKey);
}
let x_bytes: [u8; 32] = bytes[1..33].try_into().unwrap();
use crypto_bigint::subtle::ConstantTimeGreater;
let x_val = U256::from_be_slice(&x_bytes);
if bool::from(x_val.ct_gt(&FIELD_MODULUS)) || x_val == FIELD_MODULUS {
return Err(Error::InvalidPublicKey);
}
let x = fp_from_bytes(&x_bytes);
let x2 = fp_square(&x);
let x3 = fp_mul(&x2, &x);
let ax = fp_mul(&CURVE_A, &x);
let y2 = fp_add(&fp_add(&x3, &ax), &CURVE_B);
let y = crate::sm2::field::fp_sqrt(&y2).ok_or(Error::InvalidPublicKey)?;
let y_lsb = fp_to_bytes(&y)[31] & 1;
let want_odd = prefix & 1;
let y_final = if y_lsb == want_odd { y } else { fp_neg(&y) };
Ok(AffinePoint { x, y: y_final })
}
}
/// Computes `u·G + v·Q` with one shared double-and-add pass over the bits of
/// both scalars.
///
/// `G + Q` is precomputed once. For every bit position, most significant
/// first, the accumulator is doubled and then `G`, `Q`, `G + Q`, or nothing is
/// added, according to the bit pair `(u_i, v_i)`.
///
/// This routine branches on the bits of `u` and `v`, so its running time
/// depends on the scalars.
/// NOTE(review): presumably only used with public scalars (e.g. signature
/// verification) — confirm against the callers.
///
/// # Errors
///
/// Returns [`Error::PointAtInfinity`] (via `to_affine`) when the result is the
/// point at infinity.
pub fn multi_scalar_mul(u: &U256, v: &U256, q: &AffinePoint) -> Result<AffinePoint, Error> {
    let g = AffinePoint::generator();
    let g_plus_q = JacobianPoint::add(
        &JacobianPoint::from_affine(&g),
        &JacobianPoint::from_affine(q),
    );

    let u_bytes = u.to_be_bytes();
    let v_bytes = v.to_be_bytes();

    let mut acc = JacobianPoint::INFINITY;
    for (&u_byte, &v_byte) in u_bytes.iter().zip(v_bytes.iter()) {
        for shift in (0..8).rev() {
            acc = acc.double();
            let u_bit = ((u_byte >> shift) & 1) == 1;
            let v_bit = ((v_byte >> shift) & 1) == 1;
            acc = match (u_bit, v_bit) {
                (true, false) => add_mixed(&acc, &g),
                (false, true) => add_mixed(&acc, q),
                (true, true) => JacobianPoint::add(&acc, &g_plus_q),
                (false, false) => acc,
            };
        }
    }
    acc.to_affine()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::sm2::field::{fp_to_bytes, GX, GY};
#[test]
fn test_generator_on_curve() {
    // The base point constants must satisfy the curve equation.
    let generator = AffinePoint::generator();
    assert!(generator.is_on_curve());
}
#[test]
fn test_double_stays_on_curve() {
    // 2·G, normalised back to affine form, must still be on the curve.
    let generator = JacobianPoint::from_affine(&AffinePoint::generator());
    let doubled = generator.double().to_affine().unwrap();
    assert!(doubled.is_on_curve());
}
#[test]
fn test_add_commutativity() {
    // 2G + G and G + 2G must normalise to the same on-curve affine point.
    let g = JacobianPoint::from_affine(&AffinePoint::generator());
    let two_g = g.double();

    let left = JacobianPoint::add(&two_g, &g).to_affine().unwrap();
    let right = JacobianPoint::add(&g, &two_g).to_affine().unwrap();

    assert_eq!(fp_to_bytes(&left.x), fp_to_bytes(&right.x));
    assert_eq!(fp_to_bytes(&left.y), fp_to_bytes(&right.y));
    assert!(left.is_on_curve());
}
#[test]
fn test_scalar_mul_one_is_g() {
    // 1·G through the windowed fixed-base path must return G itself.
    let product = JacobianPoint::scalar_mul_g(&U256::ONE)
        .to_affine()
        .unwrap();
    assert_eq!(fp_to_bytes(&product.x), fp_to_bytes(&GX));
    assert_eq!(fp_to_bytes(&product.y), fp_to_bytes(&GY));
}
#[test]
fn test_serialization_roundtrip() {
    // Encode G in uncompressed form, decode it, and compare coordinates.
    let original = AffinePoint::generator();
    let encoded = original.to_bytes();
    assert_eq!(encoded[0], 0x04);

    let decoded = AffinePoint::from_bytes(&encoded).unwrap();
    assert_eq!(fp_to_bytes(&original.x), fp_to_bytes(&decoded.x));
    assert_eq!(fp_to_bytes(&original.y), fp_to_bytes(&decoded.y));
}
#[test]
fn test_keypair_on_curve() {
    // Decode a fixed 32-byte scalar from hex, two characters per byte.
    let k_hex = "f927525e176ae5607c628bc508ec0465ef285b74415bf876130a8a5d004c789e";
    let mut k_bytes = [0u8; 32];
    for (slot, pair) in k_bytes.iter_mut().zip(k_hex.as_bytes().chunks(2)) {
        let digits = core::str::from_utf8(pair).unwrap();
        *slot = u8::from_str_radix(digits, 16).unwrap();
    }
    let k = U256::from_be_slice(&k_bytes);

    // The derived public key k·G must be on the curve.
    let public = JacobianPoint::scalar_mul_g(&k).to_affine().unwrap();
    assert!(public.is_on_curve());

    // Re-derive the curve equation by hand as an independent check.
    let x_cubed = fp_mul(&fp_square(&public.x), &public.x);
    let a_times_x = fp_mul(&CURVE_A, &public.x);
    let rhs = fp_add(&fp_add(&x_cubed, &a_times_x), &CURVE_B);
    assert_eq!(rhs, fp_square(&public.y));
}
#[test]
fn test_add_degenerate_cases() {
    let g = JacobianPoint::from_affine(&AffinePoint::generator());
    let infinity = JacobianPoint::INFINITY;

    // ∞ + G = G (message: "wrong x / y coordinate for ∞ + G").
    let left_identity = JacobianPoint::add(&infinity, &g).to_affine().unwrap();
    assert_eq!(
        fp_to_bytes(&left_identity.x),
        fp_to_bytes(&GX),
        "∞ + G 的 x 坐标错误"
    );
    assert_eq!(
        fp_to_bytes(&left_identity.y),
        fp_to_bytes(&GY),
        "∞ + G 的 y 坐标错误"
    );

    // G + ∞ = G (message: "wrong x coordinate for G + ∞").
    let right_identity = JacobianPoint::add(&g, &infinity).to_affine().unwrap();
    assert_eq!(
        fp_to_bytes(&right_identity.x),
        fp_to_bytes(&GX),
        "G + ∞ 的 x 坐标错误"
    );

    // add(G, G) must take the doubling path and agree with double(G).
    let sum_gg = JacobianPoint::add(&g, &g).to_affine().unwrap();
    let two_g = g.double().to_affine().unwrap();
    assert_eq!(
        fp_to_bytes(&sum_gg.x),
        fp_to_bytes(&two_g.x),
        "add(G,G) != double(G) 的 x 坐标"
    );
    assert_eq!(
        fp_to_bytes(&sum_gg.y),
        fp_to_bytes(&two_g.y),
        "add(G,G) != double(G) 的 y 坐标"
    );

    // G + (−G) = ∞ (message: "G + (-G) should be the point at infinity").
    let minus_g = JacobianPoint {
        y: fp_neg(&g.y),
        ..g
    };
    assert!(
        JacobianPoint::add(&g, &minus_g).is_infinity(),
        "G + (-G) 应为无穷远点"
    );
}
}