use crate::asn1::oid::SM2P256V1;
use crate::asn1::{reader, writer};
use crate::sm2::curve::{Fn, Fp};
use crate::sm2::encrypt::{point_on_curve, projective_from_affine};
use crate::sm2::point::ProjectivePoint;
use alloc::vec::Vec;
use crypto_bigint::U256;
use subtle::ConstantTimeLess;
use zeroize::Zeroize;
/// Leading tag byte of a SEC1 uncompressed point encoding (`0x04 || X || Y`).
pub(crate) const SEC1_TAG_UNCOMPRESSED: u8 = 0x04;
/// Total length of a SEC1 uncompressed point: 1 tag byte + two 32-byte coordinates.
pub(crate) const SEC1_UNCOMPRESSED_LEN: usize = 65;
/// Value of the `version` INTEGER written by `encode` and required by `decode`
/// (`ecPrivkeyVer1` in SEC1 / RFC 5915).
const ECPRIVKEY_VER1: u8 = 1;
/// Serializes an affine point as the 65-byte SEC1 uncompressed form `0x04 || X || Y`.
///
/// Each coordinate is written as the 32 big-endian bytes of the integer returned by
/// `retrieve()` on the field element.
#[must_use]
pub(crate) fn encode_uncompressed_point(x: &Fp, y: &Fp) -> [u8; SEC1_UNCOMPRESSED_LEN] {
    let x_bytes = x.retrieve().to_be_bytes();
    let y_bytes = y.retrieve().to_be_bytes();
    let mut encoded = [0u8; SEC1_UNCOMPRESSED_LEN];
    // Split the buffer into tag | X | Y so each region is filled independently.
    let (tag, coords) = encoded.split_at_mut(1);
    tag[0] = SEC1_TAG_UNCOMPRESSED;
    let (x_slot, y_slot) = coords.split_at_mut(32);
    x_slot.copy_from_slice(&x_bytes);
    y_slot.copy_from_slice(&y_bytes);
    encoded
}
/// Parses a 65-byte SEC1 uncompressed point (`0x04 || X || Y`) into a projective point.
///
/// Returns `None` in any of these cases:
/// - the input is not exactly `SEC1_UNCOMPRESSED_LEN` bytes;
/// - the tag is not `SEC1_TAG_UNCOMPRESSED`;
/// - either coordinate is `>= p` (non-canonical);
/// - `(X, Y)` does not satisfy the curve equation per `point_on_curve`.
#[must_use]
pub(crate) fn decode_uncompressed_point(input: &[u8]) -> Option<ProjectivePoint> {
    // An empty slice has no tag byte and is rejected here.
    let (&tag, coords) = input.split_first()?;
    if input.len() != SEC1_UNCOMPRESSED_LEN || tag != SEC1_TAG_UNCOMPRESSED {
        return None;
    }
    let (x_be, y_be) = coords.split_at(32);

    let modulus = *Fp::MODULUS.as_ref();
    let x_raw = U256::from_be_slice(x_be);
    let y_raw = U256::from_be_slice(y_be);
    // Require canonical encodings: coordinates must already be reduced mod p.
    let both_reduced = bool::from(x_raw.ct_lt(&modulus)) && bool::from(y_raw.ct_lt(&modulus));
    if !both_reduced {
        return None;
    }

    let (x, y) = (Fp::new(&x_raw), Fp::new(&y_raw));
    point_on_curve(&x, &y).then(|| projective_from_affine(x, y))
}
/// Encodes an SM2 private key as a DER SEC1 / RFC 5915 `ECPrivateKey` structure.
///
/// Layout produced:
/// `SEQUENCE { version INTEGER (1), privateKey OCTET STRING (scalar_be),
///             [0] EXPLICIT OID (SM2P256V1), [1] EXPLICIT BIT STRING (public) OPTIONAL }`.
/// The `[0]` parameters are always written. `[1]` is written only when
/// `public_uncompressed` is `Some`.
///
/// # Arguments
/// * `scalar_be` - private scalar as 32 big-endian bytes. It is not range-checked here.
/// * `public_uncompressed` - optional SEC1 uncompressed public point (`0x04 || X || Y`).
///
/// The intermediate `body` buffer holding the scalar is zeroized before returning.
/// The returned DER also contains the scalar, and wiping it is the caller's responsibility.
#[must_use]
pub fn encode(
    scalar_be: &[u8; 32],
    public_uncompressed: Option<&[u8; SEC1_UNCOMPRESSED_LEN]>,
) -> Vec<u8> {
    // Pre-size `body` so it never reallocates. A reallocation would free a buffer
    // still holding the scalar without wiping it.
    // NOTE(review): 120 assumes the current OID length and short-form DER lengths;
    // re-check it if either changes.
    let mut body = Vec::with_capacity(120);
    writer::write_integer(&mut body, &[ECPRIVKEY_VER1]);
    writer::write_octet_string(&mut body, scalar_be);

    let mut params_inner = Vec::with_capacity(SM2P256V1.len() + 2);
    writer::write_oid(&mut params_inner, SM2P256V1);
    writer::write_context_tagged_explicit(&mut body, 0, &params_inner);

    if let Some(pk) = public_uncompressed {
        let mut pk_inner = Vec::with_capacity(SEC1_UNCOMPRESSED_LEN + 4);
        // 0 unused bits: the point encoding is a whole number of bytes.
        writer::write_bit_string(&mut pk_inner, 0, pk);
        writer::write_context_tagged_explicit(&mut body, 1, &pk_inner);
    }

    let mut out = Vec::with_capacity(body.len() + 4);
    writer::write_sequence(&mut out, &body);
    body.zeroize();
    out
}
/// A decoded SEC1 `ECPrivateKey` for SM2.
///
/// The scalar is zeroized when the value is dropped. `Debug` output redacts the scalar
/// so the key can be formatted (e.g. in logs or panic messages) without leaking it.
#[derive(Clone)]
pub struct EcPrivateKey {
    /// Private scalar as 32 big-endian bytes (secret).
    pub scalar_be: [u8; 32],
    /// Public point taken from the optional `[1] publicKey` field, if it was present.
    pub public: Option<ProjectivePoint>,
}

impl core::fmt::Debug for EcPrivateKey {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Never print the secret scalar; only the (public) point is shown.
        f.debug_struct("EcPrivateKey")
            .field("scalar_be", &"<redacted>")
            .field("public", &self.public)
            .finish()
    }
}

impl Drop for EcPrivateKey {
    fn drop(&mut self) {
        // Wipe the secret scalar; `public` is not secret and is left as-is.
        self.scalar_be.zeroize();
    }
}
/// Parses a DER SEC1 / RFC 5915 `ECPrivateKey` for SM2.
///
/// Accepted input:
/// - a single `SEQUENCE` with no trailing bytes after it;
/// - `version` INTEGER equal to 1;
/// - `privateKey` OCTET STRING of exactly 32 bytes;
/// - an optional `[0]` field, which must contain exactly the `SM2P256V1` OID and nothing else;
/// - an optional `[1]` field, which must be a BIT STRING with 0 unused bits holding a valid
///   SEC1 uncompressed, on-curve point;
/// - no further data inside the SEQUENCE.
///
/// Returns `None` on any violation. The scalar is not range-checked here (see
/// `validate_scalar`). Nothing checks that `public` corresponds to the scalar.
///
/// The scalar copy is wiped on every failure path, including `?` early returns.
#[must_use]
pub fn decode(input: &[u8]) -> Option<EcPrivateKey> {
    let (body, rest) = reader::read_sequence(input)?;
    if !rest.is_empty() {
        return None;
    }
    let (version, body) = reader::read_integer(body)?;
    if version != [ECPRIVKEY_VER1] {
        return None;
    }
    let (scalar_bytes, mut body) = reader::read_octet_string(body)?;
    if scalar_bytes.len() != 32 {
        return None;
    }

    // Keep the secret inside `EcPrivateKey` from the moment it is copied: its `Drop`
    // impl zeroizes the scalar, so every early exit below (including `?`) wipes it.
    let mut key = EcPrivateKey {
        scalar_be: [0u8; 32],
        public: None,
    };
    key.scalar_be.copy_from_slice(scalar_bytes);

    if let Some((params_inner, after)) = reader::read_context_tagged_explicit(body, 0) {
        let (oid, params_rest) = reader::read_oid(params_inner)?;
        if !params_rest.is_empty() || oid != SM2P256V1 {
            return None;
        }
        body = after;
    }

    if let Some((pk_inner, after)) = reader::read_context_tagged_explicit(body, 1) {
        let (unused, pk_bytes, pk_rest) = reader::read_bit_string(pk_inner)?;
        if unused != 0 || !pk_rest.is_empty() {
            return None;
        }
        key.public = Some(decode_uncompressed_point(pk_bytes)?);
        body = after;
    }

    if !body.is_empty() {
        return None;
    }
    Some(key)
}
/// Checks that `scalar_be` encodes a private scalar `d` with `1 <= d < n`, where `n` is
/// `Fn::MODULUS`, and converts it into the scalar field on success.
///
/// Returns `None` if the value is zero or at/above the modulus.
#[must_use]
// `dead_code` is allowed because, within this file, only the tests call this function.
#[allow(dead_code)]
pub(crate) fn validate_scalar(scalar_be: &[u8; 32]) -> Option<Fn> {
    let candidate = U256::from_be_slice(scalar_be);
    let order = *Fn::MODULUS.as_ref();
    let nonzero = candidate != U256::ZERO;
    let below_order = bool::from(candidate.ct_lt(&order));
    (nonzero && below_order).then(|| Fn::new(&candidate))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sm2::point::ProjectivePoint;

    #[test]
    fn uncompressed_point_round_trip_generator() {
        let g = ProjectivePoint::generator();
        let (x, y) = g.to_affine().expect("G finite");
        let bytes = encode_uncompressed_point(&x, &y);
        assert_eq!(bytes[0], 0x04);
        let recovered = decode_uncompressed_point(&bytes).expect("decode");
        let (rx, ry) = recovered.to_affine().expect("recovered finite");
        assert_eq!(rx.retrieve(), x.retrieve());
        assert_eq!(ry.retrieve(), y.retrieve());
    }

    #[test]
    fn uncompressed_point_rejects_wrong_length() {
        assert!(decode_uncompressed_point(&[]).is_none());
        assert!(decode_uncompressed_point(&[0x04]).is_none());
        assert!(decode_uncompressed_point(&[0x04; 64]).is_none());
        assert!(decode_uncompressed_point(&[0x04; 66]).is_none());
    }

    #[test]
    fn uncompressed_point_rejects_compressed_tag() {
        let mut bytes = [0u8; 65];
        bytes[0] = 0x02;
        assert!(decode_uncompressed_point(&bytes).is_none());
        bytes[0] = 0x03;
        assert!(decode_uncompressed_point(&bytes).is_none());
    }

    #[test]
    fn uncompressed_point_rejects_off_curve() {
        let mut bytes = [0u8; 65];
        bytes[0] = 0x04;
        bytes[1] = 1;
        bytes[33] = 1;
        assert!(decode_uncompressed_point(&bytes).is_none());
    }

    #[test]
    fn uncompressed_point_rejects_x_at_or_above_p() {
        let g = ProjectivePoint::generator();
        let (_x, y) = g.to_affine().expect("G finite");
        let p = *Fp::MODULUS.as_ref();
        let mut bytes = [0u8; SEC1_UNCOMPRESSED_LEN];
        bytes[0] = 0x04;
        bytes[1..33].copy_from_slice(&p.to_be_bytes());
        bytes[33..65].copy_from_slice(&y.retrieve().to_be_bytes());
        assert!(
            decode_uncompressed_point(&bytes).is_none(),
            "X = p must be rejected"
        );
    }

    #[test]
    fn ecprivatekey_round_trip_with_public() {
        let scalar_be: [u8; 32] = [
            0x39, 0x45, 0x20, 0x8F, 0x7B, 0x21, 0x44, 0xB1, 0x3F, 0x36, 0xE3, 0x8A, 0xC6, 0xD3,
            0x9F, 0x95, 0x88, 0x93, 0x93, 0x69, 0x28, 0x60, 0xB5, 0x1A, 0x42, 0xFB, 0x81, 0xEF,
            0x4D, 0xF7, 0xC5, 0xB8,
        ];
        let d = U256::from_be_slice(&scalar_be);
        let key = crate::sm2::Sm2PrivateKey::from_scalar_inner(d).expect("valid d");
        let (x, y) = key.public_key().to_affine().expect("finite");
        let pk = encode_uncompressed_point(&x, &y);
        let der = encode(&scalar_be, Some(&pk));
        let recovered = decode(&der).expect("decode");
        assert_eq!(recovered.scalar_be, scalar_be);
        assert!(recovered.public.is_some());
        let (rx, ry) = recovered.public.unwrap().to_affine().expect("finite");
        assert_eq!(rx.retrieve(), x.retrieve());
        assert_eq!(ry.retrieve(), y.retrieve());
    }

    #[test]
    fn ecprivatekey_round_trip_minimal() {
        let scalar_be: [u8; 32] = [0x42; 32];
        let der = encode(&scalar_be, None);
        let recovered = decode(&der).expect("decode");
        assert_eq!(recovered.scalar_be, scalar_be);
        assert!(recovered.public.is_none());
    }

    #[test]
    fn ecprivatekey_round_trip_params_only() {
        let scalar_be: [u8; 32] = [0x11; 32];
        let der = encode(&scalar_be, None);
        let recovered = decode(&der).expect("decode");
        let der2 = encode(
            &recovered.scalar_be,
            recovered
                .public
                .as_ref()
                .map(|p| {
                    let (x, y) = p.to_affine().expect("finite");
                    encode_uncompressed_point(&x, &y)
                })
                .as_ref(),
        );
        assert_eq!(der, der2);
    }

    #[test]
    fn ecprivatekey_rejects_wrong_version() {
        // SEQUENCE { INTEGER 2, OCTET STRING (empty) }
        let bad = [0x30, 0x05, 0x02, 0x01, 0x02, 0x04, 0x00];
        assert!(decode(&bad).is_none());
    }

    #[test]
    fn ecprivatekey_rejects_short_scalar() {
        // SEQUENCE { INTEGER 1, OCTET STRING (1 byte) }
        let bad = [0x30, 0x06, 0x02, 0x01, 0x01, 0x04, 0x01, 0xAB];
        assert!(decode(&bad).is_none());
    }

    #[test]
    fn ecprivatekey_rejects_wrong_curve_oid() {
        let mut body = Vec::new();
        writer::write_integer(&mut body, &[1]);
        let scalar = [0u8; 32];
        writer::write_octet_string(&mut body, &scalar);
        // prime256v1 / NIST P-256 (1.2.840.10045.3.1.7)
        let p256_oid = &[0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x03, 0x01, 0x07];
        let mut params = Vec::new();
        writer::write_oid(&mut params, p256_oid);
        writer::write_context_tagged_explicit(&mut body, 0, &params);
        let mut der = Vec::new();
        writer::write_sequence(&mut der, &body);
        assert!(
            decode(&der).is_none(),
            "non-SM2 namedCurve must be rejected"
        );
    }

    #[test]
    fn ecprivatekey_rejects_off_curve_public_key() {
        let mut bad_point = [0u8; SEC1_UNCOMPRESSED_LEN];
        bad_point[0] = SEC1_TAG_UNCOMPRESSED;
        bad_point[1] = 1;
        bad_point[33] = 1;
        let der = encode(&[0x42; 32], Some(&bad_point));
        assert!(
            decode(&der).is_none(),
            "off-curve public key must be rejected"
        );
    }

    #[test]
    fn ecprivatekey_rejects_trailing_bytes() {
        let scalar_be: [u8; 32] = [0x42; 32];
        let mut der = encode(&scalar_be, None);
        der.push(0x00);
        assert!(decode(&der).is_none(), "trailing byte must be rejected");
    }

    #[test]
    fn validate_scalar_rejects_zero() {
        let zero = [0u8; 32];
        assert!(validate_scalar(&zero).is_none());
    }

    #[test]
    fn validate_scalar_accepts_one() {
        let mut one = [0u8; 32];
        one[31] = 1;
        assert!(validate_scalar(&one).is_some());
    }

    #[test]
    fn validate_scalar_rejects_n() {
        let n = *Fn::MODULUS.as_ref();
        let n_bytes = n.to_be_bytes();
        let mut buf = [0u8; 32];
        buf.copy_from_slice(&n_bytes);
        assert!(validate_scalar(&buf).is_none());
    }
}