use crate::asn1::oid::{ID_EC_PUBLIC_KEY, ID_HMAC_WITH_SM3, ID_PBKDF2, PBES2, SM2P256V1, SM4_CBC};
use crate::asn1::{reader, writer};
use crate::kdf::pbkdf2_hmac_sm3;
use crate::sec1;
use crate::sm2::Sm2PrivateKey;
use crate::sm4::mode_cbc;
use alloc::vec::Vec;
use crypto_bigint::U256;
use subtle::ConstantTimeEq;
use zeroize::Zeroize;
/// Failure reported by every fallible operation in this module.
///
/// There is a single variant, so callers cannot tell which stage (parsing,
/// key derivation, decryption or key validation) rejected the input.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Error {
    /// Encoding, decoding, key derivation, decryption or key validation failed.
    Failed,
}
/// PKCS#8 `version` value for a `PrivateKeyInfo` (v1).
const PKCS8_V1: u8 = 0;
/// PKCS#8 `version` value for a `OneAsymmetricKey` v2 (RFC 5958).
const PKCS8_V2: u8 = 1;
/// SM4 key size in bytes (128-bit key).
const SM4_KEY_LEN: usize = 128 / 8;
/// SM4 block size in bytes, which is also the CBC IV length.
const SM4_IV_LEN: usize = 128 / 8;
/// Largest PBKDF2 iteration count accepted from an encrypted key blob; larger
/// counts are rejected to bound the KDF work an untrusted blob can demand.
pub const PBKDF2_MAX_ITERATIONS: u32 = 10_000_000;
/// Serialises `key` as an unencrypted PKCS#8 `PrivateKeyInfo` (version 0).
///
/// The `privateKey` OCTET STRING holds a SEC 1 `ECPrivateKey` carrying both the
/// big-endian scalar and the uncompressed public point. The algorithm
/// identifier is `id-ecPublicKey` with the `sm2p256v1` named curve.
///
/// Intermediate buffers that hold the scalar are wiped before returning. The
/// returned DER still contains the secret, so the caller is responsible for it.
#[must_use]
pub fn encode(key: &Sm2PrivateKey) -> Vec<u8> {
    // SEC 1 ECPrivateKey: scalar plus uncompressed public point.
    let public_point =
        crate::sm2::Sm2PublicKey::from_point(key.public_key()).to_sec1_uncompressed();
    let mut secret = key.to_sec1_be();
    let mut ec_private_key = sec1::encode(&secret, Some(&public_point));
    secret.zeroize();

    // AlgorithmIdentifier ::= SEQUENCE { id-ecPublicKey, sm2p256v1 }
    let mut algorithm = Vec::with_capacity(ID_EC_PUBLIC_KEY.len() + SM2P256V1.len() + 4);
    writer::write_oid(&mut algorithm, ID_EC_PUBLIC_KEY);
    writer::write_oid(&mut algorithm, SM2P256V1);
    let mut algorithm_seq = Vec::with_capacity(algorithm.len() + 4);
    writer::write_sequence(&mut algorithm_seq, &algorithm);

    // PrivateKeyInfo ::= SEQUENCE { version, privateKeyAlgorithm, privateKey }
    let mut fields = Vec::with_capacity(ec_private_key.len() + algorithm_seq.len() + 8);
    writer::write_integer(&mut fields, &[PKCS8_V1]);
    fields.extend_from_slice(&algorithm_seq);
    writer::write_octet_string(&mut fields, &ec_private_key);
    ec_private_key.zeroize();

    let mut der = Vec::with_capacity(fields.len() + 4);
    writer::write_sequence(&mut der, &fields);
    fields.zeroize();
    der
}
pub fn decode(input: &[u8]) -> Result<Sm2PrivateKey, Error> {
let (body, rest) = reader::read_sequence(input).ok_or(Error::Failed)?;
if !rest.is_empty() {
return Err(Error::Failed);
}
let (version, body) = reader::read_integer(body).ok_or(Error::Failed)?;
if version != [PKCS8_V1] && version != [PKCS8_V2] {
return Err(Error::Failed);
}
let (alg_inner, body) = reader::read_sequence(body).ok_or(Error::Failed)?;
let (alg_oid, alg_inner) = reader::read_oid(alg_inner).ok_or(Error::Failed)?;
if alg_oid != ID_EC_PUBLIC_KEY {
return Err(Error::Failed);
}
let (curve_oid, alg_inner) = reader::read_oid(alg_inner).ok_or(Error::Failed)?;
if curve_oid != SM2P256V1 || !alg_inner.is_empty() {
return Err(Error::Failed);
}
let (inner_bytes, body) = reader::read_octet_string(body).ok_or(Error::Failed)?;
let mut inner = sec1::decode(inner_bytes).ok_or(Error::Failed)?;
if !body.is_empty() {
let mut tail = body;
while !tail.is_empty() {
if let Some((_, after)) = reader::read_context_tagged_explicit(tail, 0) {
tail = after;
continue;
}
if let Some((_, after)) = reader::read_context_tagged_explicit(tail, 1) {
tail = after;
continue;
}
inner.scalar_be.zeroize();
return Err(Error::Failed);
}
}
let d = U256::from_be_slice(&inner.scalar_be);
inner.scalar_be.zeroize();
let key = Sm2PrivateKey::new(d);
let key: Option<Sm2PrivateKey> = key.into();
let key = key.ok_or(Error::Failed)?;
if let Some(stored_pub) = inner.public {
let derived = key.public_key();
if !bool::from(stored_pub.ct_eq(&derived)) {
return Err(Error::Failed);
}
}
Ok(key)
}
pub fn encrypt(
key: &Sm2PrivateKey,
password: &[u8],
salt: &[u8],
iterations: u32,
iv: &[u8; SM4_IV_LEN],
) -> Result<Vec<u8>, Error> {
if iterations == 0 {
return Err(Error::Failed);
}
let mut inner = encode(key);
let mut sm4_key = [0u8; SM4_KEY_LEN];
pbkdf2_hmac_sm3(password, salt, iterations, &mut sm4_key).ok_or(Error::Failed)?;
let ciphertext = mode_cbc::encrypt(&sm4_key, iv, &inner);
inner.zeroize();
sm4_key.zeroize();
let pbes2_params = build_pbes2_params(salt, iterations, iv);
let mut alg_inner = Vec::with_capacity(PBES2.len() + pbes2_params.len() + 4);
writer::write_oid(&mut alg_inner, PBES2);
alg_inner.extend_from_slice(&pbes2_params);
let mut alg_seq = Vec::with_capacity(alg_inner.len() + 4);
writer::write_sequence(&mut alg_seq, &alg_inner);
let mut body = Vec::with_capacity(alg_seq.len() + ciphertext.len() + 8);
body.extend_from_slice(&alg_seq);
writer::write_octet_string(&mut body, &ciphertext);
let mut out = Vec::with_capacity(body.len() + 4);
writer::write_sequence(&mut out, &body);
Ok(out)
}
/// Builds the DER `PBES2-params` SEQUENCE for PBKDF2/HMAC-SM3 + SM4-CBC:
///
/// ```text
/// SEQUENCE {
///   SEQUENCE { id-PBKDF2,
///              SEQUENCE { salt OCTET STRING, iterationCount INTEGER,
///                         SEQUENCE { hmacWithSM3, NULL } } }
///   SEQUENCE { sm4-CBC, iv OCTET STRING }
/// }
/// ```
///
/// The optional `keyLength` field is omitted.
fn build_pbes2_params(salt: &[u8], iterations: u32, iv: &[u8; SM4_IV_LEN]) -> Vec<u8> {
    // PBKDF2-params: salt, iterationCount, then the PRF AlgorithmIdentifier.
    let mut pbkdf2_inner = Vec::with_capacity(salt.len() + 32);
    writer::write_octet_string(&mut pbkdf2_inner, salt);
    // NOTE(review): passes all four big-endian bytes; relies on
    // `writer::write_integer` to emit a minimal DER INTEGER — confirm.
    writer::write_integer(&mut pbkdf2_inner, &iterations.to_be_bytes());
    let mut prf_inner = Vec::with_capacity(ID_HMAC_WITH_SM3.len() + 4);
    writer::write_oid(&mut prf_inner, ID_HMAC_WITH_SM3);
    writer::write_null(&mut prf_inner);
    let mut prf_seq = Vec::with_capacity(prf_inner.len() + 4);
    writer::write_sequence(&mut prf_seq, &prf_inner);
    pbkdf2_inner.extend_from_slice(&prf_seq);
    let mut pbkdf2_seq = Vec::with_capacity(pbkdf2_inner.len() + 4);
    writer::write_sequence(&mut pbkdf2_seq, &pbkdf2_inner);

    // keyDerivationFunc ::= SEQUENCE { id-PBKDF2, PBKDF2-params }
    let mut kdf_inner = Vec::with_capacity(ID_PBKDF2.len() + pbkdf2_seq.len() + 4);
    writer::write_oid(&mut kdf_inner, ID_PBKDF2);
    kdf_inner.extend_from_slice(&pbkdf2_seq);
    let mut kdf_seq = Vec::with_capacity(kdf_inner.len() + 4);
    writer::write_sequence(&mut kdf_seq, &kdf_inner);

    // encryptionScheme ::= SEQUENCE { sm4-CBC, iv }
    let mut es_inner = Vec::with_capacity(SM4_CBC.len() + iv.len() + 4);
    writer::write_oid(&mut es_inner, SM4_CBC);
    writer::write_octet_string(&mut es_inner, iv);
    let mut es_seq = Vec::with_capacity(es_inner.len() + 4);
    writer::write_sequence(&mut es_seq, &es_inner);

    let mut params_inner = Vec::with_capacity(kdf_seq.len() + es_seq.len());
    params_inner.extend_from_slice(&kdf_seq);
    params_inner.extend_from_slice(&es_seq);
    let mut out = Vec::with_capacity(params_inner.len() + 4);
    writer::write_sequence(&mut out, &params_inner);
    out
}
pub fn decrypt(input: &[u8], password: &[u8]) -> Result<Sm2PrivateKey, Error> {
let parsed = parse_encrypted_blob(input).ok_or(Error::Failed)?;
let mut sm4_key = [0u8; SM4_KEY_LEN];
let derive_ok =
pbkdf2_hmac_sm3(password, parsed.salt, parsed.iterations, &mut sm4_key).is_some();
let _ = derive_ok;
let plaintext = mode_cbc::decrypt(&sm4_key, &parsed.iv, parsed.ciphertext);
sm4_key.zeroize();
let mut plaintext = plaintext.ok_or(Error::Failed)?;
let result = decode(&plaintext);
plaintext.zeroize();
result
}
/// Fields extracted from a PBES2 `EncryptedPrivateKeyInfo` by
/// `parse_encrypted_blob`; the slices borrow from the parsed input.
struct ParsedEncrypted<'a> {
    /// PBKDF2 iteration count, already range-checked by the parser.
    iterations: u32,
    /// PBKDF2 salt (the `specified` OCTET STRING).
    salt: &'a [u8],
    /// SM4-CBC initialisation vector.
    iv: [u8; SM4_IV_LEN],
    /// Contents of the `encryptedData` OCTET STRING.
    ciphertext: &'a [u8],
}
/// Parses a DER `EncryptedPrivateKeyInfo` whose algorithm must be PBES2 with
/// PBKDF2/HMAC-SM3 key derivation and SM4-CBC encryption.
///
/// Returns `None` on any deviation:
/// - trailing bytes at any nesting level;
/// - an unexpected OID;
/// - an iteration count of 0 or above [`PBKDF2_MAX_ITERATIONS`];
/// - an explicit `keyLength` other than the SM4 key size;
/// - a missing PRF, or a PRF other than HMAC-SM3;
/// - PRF parameters that are neither absent nor a single NULL;
/// - an IV whose length is not the SM4 block size.
///
/// The returned slices borrow from `input`.
fn parse_encrypted_blob(input: &[u8]) -> Option<ParsedEncrypted<'_>> {
    /// Reads INTEGER content bytes as an unsigned big-endian `u32`,
    /// left-padding with zeros; `None` when there are more than four bytes.
    fn integer_to_u32(content: &[u8]) -> Option<u32> {
        if content.len() > 4 {
            return None;
        }
        let mut buf = [0u8; 4];
        buf[4 - content.len()..].copy_from_slice(content);
        Some(u32::from_be_bytes(buf))
    }

    // EncryptedPrivateKeyInfo ::= SEQUENCE { encryptionAlgorithm, encryptedData }
    let (epki, trailing) = reader::read_sequence(input)?;
    if !trailing.is_empty() {
        return None;
    }
    let (algorithm, epki) = reader::read_sequence(epki)?;

    // encryptionAlgorithm ::= SEQUENCE { PBES2, PBES2-params }
    let (scheme_oid, algorithm) = reader::read_oid(algorithm)?;
    if scheme_oid != PBES2 {
        return None;
    }
    let (pbes2_params, leftover) = reader::read_sequence(algorithm)?;
    if !leftover.is_empty() {
        return None;
    }

    // PBES2-params ::= SEQUENCE { keyDerivationFunc, encryptionScheme }
    let (kdf_alg, after_kdf) = reader::read_sequence(pbes2_params)?;
    let (kdf_oid, kdf_alg) = reader::read_oid(kdf_alg)?;
    if kdf_oid != ID_PBKDF2 {
        return None;
    }
    let (pbkdf2_params, leftover) = reader::read_sequence(kdf_alg)?;
    if !leftover.is_empty() {
        return None;
    }

    // PBKDF2-params ::= SEQUENCE { salt, iterationCount, keyLength OPTIONAL, prf }
    let (salt, pbkdf2_params) = reader::read_octet_string(pbkdf2_params)?;
    let (count_bytes, mut pbkdf2_params) = reader::read_integer(pbkdf2_params)?;
    let iterations = integer_to_u32(count_bytes)?;
    if !(1..=PBKDF2_MAX_ITERATIONS).contains(&iterations) {
        return None;
    }
    if let Some((len_bytes, remaining)) = reader::read_integer(pbkdf2_params) {
        let key_length = integer_to_u32(len_bytes)? as usize;
        if key_length != SM4_KEY_LEN {
            return None;
        }
        pbkdf2_params = remaining;
    }

    // The PRF must be present: only HMAC-SM3 is supported.
    if pbkdf2_params.is_empty() {
        return None;
    }
    let (prf_alg, leftover) = reader::read_sequence(pbkdf2_params)?;
    if !leftover.is_empty() {
        return None;
    }
    let (prf_oid, prf_params) = reader::read_oid(prf_alg)?;
    if prf_oid != ID_HMAC_WITH_SM3 {
        return None;
    }
    // PRF parameters: either absent, or exactly one two-byte NULL.
    let prf_params_ok = prf_params.is_empty()
        || (reader::read_null(prf_params).is_some() && prf_params.len() == 2);
    if !prf_params_ok {
        return None;
    }

    // encryptionScheme ::= SEQUENCE { sm4-CBC, iv OCTET STRING }
    let (scheme, leftover) = reader::read_sequence(after_kdf)?;
    if !leftover.is_empty() {
        return None;
    }
    let (cipher_oid, scheme) = reader::read_oid(scheme)?;
    if cipher_oid != SM4_CBC {
        return None;
    }
    let (iv_bytes, leftover) = reader::read_octet_string(scheme)?;
    if !leftover.is_empty() || iv_bytes.len() != SM4_IV_LEN {
        return None;
    }
    let mut iv = [0u8; SM4_IV_LEN];
    iv.copy_from_slice(iv_bytes);

    // encryptedData OCTET STRING must close the outer SEQUENCE.
    let (ciphertext, leftover) = reader::read_octet_string(epki)?;
    if !leftover.is_empty() {
        return None;
    }

    Some(ParsedEncrypted {
        salt,
        iterations,
        iv,
        ciphertext,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crypto_bigint::U256;

    const D1_HEX: &str = "3945208F7B2144B13F36E38AC6D39F95889393692860B51A42FB81EF4DF7C5B8";
    const D2_HEX: &str = "1649AB77A00637BD5E2EFE283FBF353534AA7F7CB89463F208DDBC2920BB0DA0";
    const SALT: [u8; 16] = [0xAB; 16];
    const IV: [u8; SM4_IV_LEN] = [0xCD; SM4_IV_LEN];
    const ITERATIONS: u32 = 1024;

    /// Deterministic key shared by most tests.
    fn sample_key() -> Sm2PrivateKey {
        Sm2PrivateKey::new(U256::from_be_hex(D1_HEX)).expect("valid d")
    }

    /// True when both keys derive the same public point.
    fn same_public_key(a: &Sm2PrivateKey, b: &Sm2PrivateKey) -> bool {
        bool::from(a.public_key().ct_eq(&b.public_key()))
    }

    /// Hand-assembles a PKCS#8 v1 `PrivateKeyInfo` around an arbitrary SEC 1 blob.
    fn private_key_info(ec_private_key: &[u8]) -> Vec<u8> {
        let mut algorithm = Vec::new();
        writer::write_oid(&mut algorithm, ID_EC_PUBLIC_KEY);
        writer::write_oid(&mut algorithm, SM2P256V1);
        let mut algorithm_seq = Vec::new();
        writer::write_sequence(&mut algorithm_seq, &algorithm);
        let mut fields = Vec::new();
        writer::write_integer(&mut fields, &[PKCS8_V1]);
        fields.extend_from_slice(&algorithm_seq);
        writer::write_octet_string(&mut fields, ec_private_key);
        let mut der = Vec::new();
        writer::write_sequence(&mut der, &fields);
        der
    }

    /// Hand-assembles an `EncryptedPrivateKeyInfo` from raw PBES2 params and ciphertext.
    fn encrypted_private_key_info(pbes2_params: &[u8], ciphertext: &[u8]) -> Vec<u8> {
        let mut algorithm = Vec::new();
        writer::write_oid(&mut algorithm, PBES2);
        algorithm.extend_from_slice(pbes2_params);
        let mut algorithm_seq = Vec::new();
        writer::write_sequence(&mut algorithm_seq, &algorithm);
        let mut fields = Vec::new();
        fields.extend_from_slice(&algorithm_seq);
        writer::write_octet_string(&mut fields, ciphertext);
        let mut der = Vec::new();
        writer::write_sequence(&mut der, &fields);
        der
    }

    #[test]
    fn round_trip_unencrypted() {
        let key = sample_key();
        let recovered = decode(&encode(&key)).expect("decode");
        assert!(same_public_key(&recovered, &key));
    }

    #[test]
    fn unencrypted_rejects_trailing_bytes() {
        let mut der = encode(&sample_key());
        der.push(0x00);
        assert_eq!(decode(&der).err(), Some(Error::Failed));
    }

    #[test]
    fn unencrypted_rejects_public_key_mismatch() {
        let key1 = Sm2PrivateKey::new(U256::from_be_hex(D1_HEX)).expect("d1");
        let key2 = Sm2PrivateKey::new(U256::from_be_hex(D2_HEX)).expect("d2");
        let foreign_public =
            crate::sm2::Sm2PublicKey::from_point(key2.public_key()).to_sec1_uncompressed();
        let mismatched = sec1::encode(&key1.to_sec1_be(), Some(&foreign_public));
        assert_eq!(
            decode(&private_key_info(&mismatched)).err(),
            Some(Error::Failed)
        );
    }

    #[test]
    fn round_trip_encrypted() {
        let key = sample_key();
        let password = b"correct horse battery staple";
        let blob = encrypt(&key, password, &SALT, ITERATIONS, &IV).expect("encrypt");
        let recovered = decrypt(&blob, password).expect("decrypt with right password");
        assert!(same_public_key(&recovered, &key));
    }

    #[test]
    fn encrypted_wrong_password_fails() {
        let blob = encrypt(&sample_key(), b"right", &SALT, ITERATIONS, &IV).expect("encrypt");
        assert_eq!(decrypt(&blob, b"wrong").err(), Some(Error::Failed));
    }

    #[test]
    fn encrypted_zero_iterations_rejected() {
        assert_eq!(
            encrypt(&sample_key(), b"pw", &SALT, 0, &IV).err(),
            Some(Error::Failed)
        );
    }

    #[test]
    fn decrypt_rejects_truncated_blob() {
        let cases: [&[u8]; 2] = [&[], &[0x30, 0x00]];
        for truncated in cases {
            assert_eq!(decrypt(truncated, b"pw").err(), Some(Error::Failed));
        }
    }

    #[test]
    fn decrypt_rejects_excessive_iterations() {
        let blob = encrypt(&sample_key(), b"pw", &SALT, ITERATIONS, &IV).expect("encrypt");
        let ciphertext = parse_encrypted_blob(&blob)
            .expect("baseline parse")
            .ciphertext;
        let params = build_pbes2_params(&SALT, PBKDF2_MAX_ITERATIONS + 1, &IV);
        let tampered = encrypted_private_key_info(&params, ciphertext);
        assert_eq!(decrypt(&tampered, b"pw").err(), Some(Error::Failed));
    }
}