use crate::encrypt;
use crate::error::Sm2Error;
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use byteorder::{BigEndian, WriteBytesExt};
use libsm::sm2::ecc::{EccCtx, Point};
use libsm::sm2::signature::Signature;
use libsm::sm3::hash::Sm3Hash;
use num_bigint::BigUint;
use num_traits::{One, Zero};
use sha1::Digest;
use smcrypto::sm3;
/// Produces the body signature used by this crate.
///
/// The pipeline is: SHA-1 of `data` rendered as base64 ([`sha1_base64`]), then the
/// SM3 digest of that base64 text (`smcrypto::sm3::sm3_hash`), then
/// `crate::encrypt` of the digest with `secret`.
///
/// # Errors
/// Propagates any error returned by `crate::encrypt`.
pub fn sign_body(data: &str, secret: &str) -> Result<String, Sm2Error> {
    let sm3_digest = sm3::sm3_hash(sha1_base64(data).as_bytes());
    encrypt(&sm3_digest, secret)
}
/// Signs `msg` with SM2 and returns the signature in this crate's wire format.
///
/// * `pk` — hex-encoded public key, parsed by libsm `EccCtx::bytes_to_point`; it is
///   mixed into the Z_A pre-hash together with the fixed user ID.
/// * `sk` — hex-encoded private key, interpreted as a big-endian integer.
///
/// The returned string is the base64 encoding of the *hex text* of the DER-encoded
/// `(r, s)` pair, which is the layout [`verify`] decodes.
///
/// # Errors
/// Fails if either key is not valid hex, if `pk` is not a valid curve point, or if
/// a libsm operation inside hashing/signing reports an error.
pub fn sign(msg: &[u8], pk: &str, sk: &str) -> Result<String, Sm2Error> {
    let curve = EccCtx::new();
    let public_point = load_pubkey(&curve, &hex::decode(pk)?)?;
    let secret = BigUint::from_bytes_be(&hex::decode(sk)?);
    let digest = sign_hash(&curve, &public_point, msg)?;
    sign_raw(&curve, &digest, &secret).map(|der_hex| STANDARD.encode(der_hex))
}
/// Verifies an SM2 signature in the format produced by [`sign`].
///
/// * `msg` — the exact message bytes that were signed.
/// * `pk` — hex public key. The first two characters are skipped (presumably the `04`
///   uncompressed-point prefix — confirm with callers), and the remaining hex is handed
///   to `efficient_sm2::PublicKey::from_slice`.
/// * `sign` — base64 of the hex text of the DER `SEQUENCE { r, s }`.
///
/// Returns `Ok(true)` when the signature verifies and `Ok(false)` when it is well-formed
/// but does not verify.
///
/// # Errors
/// Returns an error if `sign` is not valid base64/hex/DER, if `pk` has fewer than two
/// leading characters to skip (or byte 2 is not a char boundary), if the key is not valid
/// hex, or if `efficient_sm2` rejects `r`/`s`.
pub fn verify(msg: &[u8], pk: &str, sign: &str) -> Result<bool, Sm2Error> {
    let der = hex::decode(STANDARD.decode(sign)?)?;
    let sig = Signature::der_decode(&der).map_err(|e| Sm2Error::LibSmError(format!("{}", e)))?;
    // `get(2..)` rather than `&pk[2..]`: indexing panics on short or non-ASCII input,
    // and `pk` comes from the caller.
    let pk_hex = pk
        .get(2..)
        .ok_or_else(|| Sm2Error::LibSmError("invalid public key".to_string()))?;
    // NOTE(review): `from_slice` is infallible here; confirm how efficient_sm2 handles a
    // key that is not 64 bytes of x||y.
    let pk = efficient_sm2::PublicKey::from_slice(&hex::decode(pk_hex)?);
    let signature = efficient_sm2::Signature::new(
        &sig.get_r().to_bytes_be(),
        &sig.get_s().to_bytes_be(),
    )
    .map_err(|e| Sm2Error::LibSmError(format!("{:?}", e)))?;
    Ok(signature.verify(&pk, msg).is_ok())
}
/// Returns the standard (padded) base64 encoding of the SHA-1 digest of the
/// UTF-8 bytes of `data`.
pub fn sha1_base64(data: &str) -> String {
    let digest = sha1::Sha1::digest(data.as_bytes());
    STANDARD.encode(digest)
}
/// Decodes `buf` into a point on `curve` via libsm `bytes_to_point`, converting a
/// libsm failure into `Sm2Error::LibSmError` carrying the library's message.
fn load_pubkey(curve: &EccCtx, buf: &[u8]) -> Result<Point, Sm2Error> {
    match curve.bytes_to_point(buf) {
        Ok(point) => Ok(point),
        Err(err) => Err(Sm2Error::LibSmError(err.to_string())),
    }
}
fn sign_hash(curve: &EccCtx, pk: &Point, msg: &[u8]) -> Result<[u8; 32], Sm2Error> {
let id = "1234567812345678";
let mut prepend: Vec<u8> = Vec::new();
prepend.write_u16::<BigEndian>((id.len() * 8) as u16).map_err(|e| Sm2Error::LibSmError(e.to_string()))?;
for c in id.bytes() {
prepend.push(c);
}
let mut a = curve.get_a().to_bytes();
let mut b = curve.get_b().to_bytes();
prepend.append(&mut a);
prepend.append(&mut b);
let (x_g, y_g) = curve.to_affine(
&curve.generator().map_err(|e| Sm2Error::LibSmError(e.to_string()))?
).map_err(|e| Sm2Error::LibSmError(e.to_string()))?;
let (mut x_g, mut y_g) = (x_g.to_bytes(), y_g.to_bytes());
prepend.append(&mut x_g);
prepend.append(&mut y_g);
let (x_a, y_a) = curve.to_affine(pk).map_err(|e| Sm2Error::LibSmError(e.to_string()))?;
let (mut x_a, mut y_a) = (x_a.to_bytes(), y_a.to_bytes());
prepend.append(&mut x_a);
prepend.append(&mut y_a);
let mut hasher = Sm3Hash::new(&prepend[..]);
let z_a = hasher.get_hash();
let mut prepended_msg: Vec<u8> = Vec::new();
prepended_msg.extend_from_slice(&z_a[..]);
prepended_msg.extend_from_slice(msg);
let mut hasher = Sm3Hash::new(&prepended_msg[..]);
Ok(hasher.get_hash())
}
/// Computes a raw SM2 signature `(r, s)` over the precomputed digest `e` (the output
/// of `sign_hash`) using private key `sk`. It follows the signing steps A3–A7 of
/// GB/T 32918.2.
///
/// Returns the hex encoding of the DER `SEQUENCE { INTEGER r, INTEGER s }`.
///
/// A fresh random nonce `k` is drawn for every attempt. The nonce must never be fixed
/// or reused: because `s = (1 + d)^-1 * (k - r*d) mod n`, anyone who knows `k` (or who
/// sees two signatures sharing one) can solve for the private key `d`.
///
/// # Errors
/// Fails if `1 + sk` is not invertible modulo `n`, or if a libsm point operation fails.
fn sign_raw(curve: &EccCtx, digest: &[u8], sk: &BigUint) -> Result<String, Sm2Error> {
    let n = curve.get_n();
    let e = BigUint::from_bytes_be(digest);
    // (1 + d)^-1 mod n does not depend on the nonce, so compute it once.
    let s1 = curve
        .inv_n(&(sk + BigUint::one()))
        .map_err(|e| Sm2Error::LibSmError(e.to_string()))?;

    let (r, s) = loop {
        // A3: fresh secret nonce for this attempt (libsm `random_uint`).
        let k = curve.random_uint();
        // A4: (x1, y1) = [k]G
        let p_1 = curve
            .g_mul(&k)
            .map_err(|e| Sm2Error::LibSmError(e.to_string()))?;
        let (x_1, _) = curve
            .to_affine(&p_1)
            .map_err(|e| Sm2Error::LibSmError(e.to_string()))?;
        // A5: r = (e + x1) mod n; pick a new k if r == 0 or r + k == n.
        let r = (&e + x_1.to_biguint()) % n;
        if r.is_zero() || &r + &k == *n {
            continue;
        }
        // A6: s = (1 + d)^-1 * (k - r*d) mod n, computed without going negative.
        let mut s2_1 = &r * sk;
        if s2_1 < k {
            s2_1 += n;
        }
        let mut s2 = s2_1 - &k;
        s2 %= n;
        let s2 = n - s2;
        let s = (&s1 * s2) % n;
        // The standard says to choose a new k when s == 0 rather than fail.
        if !s.is_zero() {
            break (r, s);
        }
    };

    // A7: encode (r, s) as DER, then hex.
    let der = yasna::construct_der(|writer| {
        writer.write_sequence(|writer| {
            writer.next().write_biguint(&r);
            writer.next().write_biguint(&s);
        })
    });
    Ok(hex::encode(der))
}
#[cfg(test)]
mod tests {
    #[test]
    fn test_sign() {
        let msg = r#"hello world"#;
        let pk = "04ff055e4349345eba0fc69362f483f4f408d876dda2520e8e424e81978129da56b19587538253a2406d035a8d9981efeeac60ec72b3308b9a07a5398b61d3d189";
        let sk = "6d7964184b735645ef49b3c1ee5a2c2efdbd15d6c9d851c57eef341ed0e1eb1b";
        // Signatures use a fresh random nonce, so check round-trip verification
        // instead of a fixed byte string.
        let signature = super::sign(msg.as_bytes(), pk, sk).unwrap();
        assert!(super::verify(msg.as_bytes(), pk, &signature).unwrap());
        // Two signatures of the same message must not share a nonce.
        let other = super::sign(msg.as_bytes(), pk, sk).unwrap();
        assert_ne!(signature, other);
        assert!(super::verify(msg.as_bytes(), pk, &other).unwrap());
        // A different message must not verify against the signature.
        assert!(!super::verify(b"hello world!", pk, &signature).unwrap());
        let secret = "ZTPkMS00ZNTP5NzPwNjAu";
        let signature = super::sign_body(msg, secret).unwrap();
        assert_eq!(signature, "ZTVlOGU0YzZmYjk0N2JiZDQxNDdmZjgyNTgwYTVhMzgxMjVmN2U5M2Q1MzA0NTg2YmJkNjljMmJiYWZlNWMyZWZlY2JiOWI0YmQzNWQ1YWE3OTZlYTkzY2Q0M2RmNmM2ZGEyMzA1NGJiOTEzMTJmMDE5YzI2YzVjOTZhYWVmYmNmMzkwYjMzZTNlY2Q3MzQzMjMwNWM1YzYzNTQ3ZmI0OQ==");
    }
    #[test]
    fn test_verify() {
        let signature = "MzA0NTAyMjEwMGVkOWQ3ZjY3YzhkNmU3NGVmODJjZDJjNDI2N2IwMDQ4ZDliZjc0NDcwNThmNGY4Mzc4NmUyZjI1OGVhNjRhYjQwMjIwMTE1NjM3ZTRmNGY5YWE4Yzg5NmQ3MTE0NjNkN2E3OGEzZGE1NjQ0NDQyOWU2NTlmNTk2NWMwMjJkNmVhNzMxYw==";
        let body = "M2YyZWFlOTU4MzBkZTUxMGQyOTNjNmUzYzA1ODg2NjM=";
        let secret = "ZTPkMS00ZNTP5NzPwNjAu";
        let spdb_pk = "04049bc3c83c5709b1b9d7fce408095809f20ee9cd16fde7944f95fa21392f109bd3c7caed077e41682126f383e547bd48899f9c279cff5f2f06ca0e41013abf11";
        let msg = crate::decrypt(&body, secret).unwrap();
        assert_eq!(msg, "hello world");
        let verify = super::verify(crate::sha1_base64(&msg).as_bytes(), spdb_pk, &signature).unwrap();
        assert_eq!(verify, true);
    }
}