#[cfg(not(feature = "std"))]
use alloc::string::ToString;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use rasn::types::ObjectIdentifier;
#[cfg(feature = "std")]
use rasn_pkix::Certificate;
use mbedtls::hash::Md;
use mbedtls::hash::Type as MdType;
#[cfg(feature = "std")]
use mbedtls::pk::ECDSA_MAX_LEN;
#[cfg(feature = "std")]
use mbedtls::rng::{CtrDrbg, OsEntropy};
#[cfg(feature = "std")]
use std::sync::Arc;
#[cfg(not(feature = "std"))]
use super::oids::gmssl_cms_data_oid;
#[cfg(feature = "std")]
use super::oids::{
gmssl_cms_data_oid, gmssl_cms_signed_data_oid, sm2_sign_with_sm3_oid, sm3_digest_oid,
};
use crate::error::Error;
#[cfg(feature = "std")]
use crate::key::sm2_pk_from_pkcs8_pem_with_pass;
/// Context-specific, constructed tag `[0]`; used both for EXPLICIT `[0]`
/// wrappers and for the `[0]` certificate set inside SignedData.
const TAG_EXPLICIT_0: u8 = 0xA0;
/// Universal, constructed `SEQUENCE` tag.
const TAG_SEQUENCE: u8 = 0x30;
/// Universal, primitive `OCTET STRING` tag.
const TAG_OCTET_STRING: u8 = 0x04;
/// Universal, constructed `SET` tag.
#[cfg(feature = "std")]
const TAG_SET: u8 = 0x31;
/// Universal, primitive `INTEGER` tag.
#[cfg(feature = "std")]
const TAG_INTEGER: u8 = 0x02;
/// Appends the DER definite-length encoding of `n` to `out`.
///
/// Lengths below 128 use the one-byte short form. Larger lengths use the long
/// form: a `0x80 | k` byte followed by `k` big-endian length octets with no
/// leading zero bytes.
fn push_der_len(n: usize, out: &mut Vec<u8>) {
    if n < 0x80 {
        out.push(n as u8);
        return;
    }
    let be = n.to_be_bytes();
    // n >= 128 guarantees at least one non-zero byte, so the fallback never triggers.
    let first_significant = be.iter().position(|&b| b != 0).unwrap_or(be.len());
    let length_octets = &be[first_significant..];
    out.push(0x80 | length_octets.len() as u8);
    out.extend_from_slice(length_octets);
}
/// Encodes a single DER TLV: `tag`, the definite length of `content`, then `content`.
#[cfg(feature = "std")]
fn tlv(tag: u8, content: &[u8]) -> Vec<u8> {
    // 1 tag byte + at most 1 + size_of::<usize>() length bytes.
    let mut encoded = Vec::with_capacity(content.len() + 2 + core::mem::size_of::<usize>());
    encoded.push(tag);
    push_der_len(content.len(), &mut encoded);
    encoded.extend_from_slice(content);
    encoded
}
/// Returns how many bytes the DER definite-length encoding of `n` occupies.
///
/// This must agree exactly with `push_der_len`. The short form takes one byte
/// below 128. Otherwise the encoding is one `0x80 | k` byte plus `k` minimal
/// big-endian length octets.
fn der_len_encoding_size(n: usize) -> usize {
    if n < 128 {
        1
    } else {
        // Count the big-endian bytes left after stripping leading zeros. A fixed
        // threshold table capped at 5 bytes under-counts lengths >= 2^32 on
        // 64-bit targets, where `push_der_len` emits more length octets.
        let length_octets = n.to_be_bytes().iter().skip_while(|&&b| b == 0).count();
        1 + length_octets
    }
}
/// Total encoded size of an `OCTET STRING` TLV carrying `data_len` content bytes:
/// one tag byte, the length octets, then the content itself.
fn octet_string_tlv_total_len(data_len: usize) -> usize {
    let tag_bytes = 1;
    let length_bytes = der_len_encoding_size(data_len);
    tag_bytes + length_bytes + data_len
}
fn encode_oid_tlv(oid: &ObjectIdentifier) -> Result<Vec<u8>, Error> {
rasn::der::encode(oid).map_err(|e| Error::RasnEncode(e.to_string()))
}
/// Builds an `AlgorithmIdentifier` `SEQUENCE` that contains only the algorithm
/// OID (the parameters field is omitted entirely).
#[cfg(feature = "std")]
fn algorithm_identifier_oid_only(oid: &ObjectIdentifier) -> Result<Vec<u8>, Error> {
    encode_oid_tlv(oid).map(|oid_der| tlv(TAG_SEQUENCE, &oid_der))
}
/// Returns the DER header bytes that precede a `data_len`-byte payload in
/// `SEQUENCE { data-OID, [0] EXPLICIT OCTET STRING payload }`.
///
/// Appending the payload to this prefix yields the complete encoding. That
/// makes it possible to hash the structure without first DER-encoding the
/// payload into a new buffer.
fn gmssl_content_hash_prefix(data_len: usize) -> Result<Vec<u8>, Error> {
    // Writes a tag byte followed by the DER length `len`.
    fn put_header(buf: &mut Vec<u8>, tag: u8, len: usize) {
        buf.push(tag);
        push_der_len(len, buf);
    }

    // Full size of `OCTET STRING { payload }`, i.e. the body of the `[0]` wrapper.
    let octets_total = octet_string_tlv_total_len(data_len);
    let data_oid = encode_oid_tlv(&gmssl_cms_data_oid())?;
    // Full size of the `[0]` wrapper: tag byte, its length octets, and its body.
    let wrapper_total = 1 + der_len_encoding_size(octets_total) + octets_total;

    let mut prefix = Vec::new();
    put_header(&mut prefix, TAG_SEQUENCE, data_oid.len() + wrapper_total);
    prefix.extend_from_slice(&data_oid);
    put_header(&mut prefix, TAG_EXPLICIT_0, octets_total);
    put_header(&mut prefix, TAG_OCTET_STRING, data_len);
    Ok(prefix)
}
/// Computes the SM3 digest over the DER encoding of
/// `SEQUENCE { data-OID, [0] EXPLICIT OCTET STRING data }`.
///
/// The digest covers the header bytes from `gmssl_content_hash_prefix`,
/// followed by `data` itself.
///
/// # Errors
/// Propagates OID-encoding failures and mbedtls hashing errors.
pub(crate) fn gmssl_cms_content_digest(data: &[u8]) -> Result<[u8; 32], Error> {
    let prefix = gmssl_content_hash_prefix(data.len())?;
    // Feed header and payload to the hasher separately. This avoids copying a
    // potentially large `data` into a second buffer just to hash it.
    let mut md = Md::new(MdType::SM3)?;
    md.update(&prefix)?;
    md.update(data)?;
    let mut out = [0u8; 32];
    md.finish(&mut out)?;
    Ok(out)
}
/// Encodes `ContentInfo { data-OID, [0] EXPLICIT OCTET STRING data }` as DER.
#[cfg(feature = "std")]
fn cms_content_info_data_to_der(data: &[u8]) -> Result<Vec<u8>, Error> {
    let explicit_content = tlv(TAG_EXPLICIT_0, &tlv(TAG_OCTET_STRING, data));
    let mut body = encode_oid_tlv(&gmssl_cms_data_oid())?;
    body.extend_from_slice(&explicit_content);
    Ok(tlv(TAG_SEQUENCE, &body))
}
/// Builds an `IssuerAndSerialNumber` `SEQUENCE` from a DER certificate.
///
/// The issuer `Name` and serial number are decoded with rasn and re-encoded as DER.
///
/// # Errors
/// Returns `Error::CmsSign` when the certificate cannot be decoded, or when the
/// issuer or serial cannot be re-encoded.
#[cfg(feature = "std")]
fn issuer_and_serial_from_cert(cert_der: &[u8]) -> Result<Vec<u8>, Error> {
    let cert: Certificate =
        rasn::der::decode(cert_der).map_err(|e| Error::CmsSign(format!("leaf cert: {e}")))?;
    let tbs = &cert.tbs_certificate;
    let mut body =
        rasn::der::encode(&tbs.issuer).map_err(|e| Error::CmsSign(format!("issuer: {e}")))?;
    let serial = rasn::der::encode(&tbs.serial_number)
        .map_err(|e| Error::CmsSign(format!("serial: {e}")))?;
    body.extend_from_slice(&serial);
    Ok(tlv(TAG_SEQUENCE, &body))
}
/// Assembles a version-1 `SignerInfo` `SEQUENCE` with no signed attributes.
///
/// The fields are the version INTEGER 1, then `issuer_serial`, `digest_alg`
/// and `sig_alg`, all pre-encoded DER spliced in verbatim. Last comes
/// `enc_digest`, wrapped in an `OCTET STRING`.
#[cfg(feature = "std")]
fn cms_signer_info_to_der(
    issuer_serial: &[u8],
    digest_alg: &[u8],
    sig_alg: &[u8],
    enc_digest: &[u8],
) -> Vec<u8> {
    let mut fields = tlv(TAG_INTEGER, &[1]);
    for encoded in [issuer_serial, digest_alg, sig_alg] {
        fields.extend_from_slice(encoded);
    }
    fields.extend(tlv(TAG_OCTET_STRING, enc_digest));
    tlv(TAG_SEQUENCE, &fields)
}
/// Wraps the already-encoded `cert_der` in a constructed context-specific `[0]`
/// TLV. The signer uses this for the SignedData `certificates` field.
#[cfg(feature = "std")]
fn implicit_explicit_0_raw(cert_der: &[u8]) -> Vec<u8> {
    let context_tag = TAG_EXPLICIT_0;
    tlv(context_tag, cert_der)
}
/// Wraps one pre-encoded `AlgorithmIdentifier` in a DER `SET`
/// (SignedData `digestAlgorithms`).
#[cfg(feature = "std")]
fn digest_algors_set(sm3_ai: &[u8]) -> Vec<u8> {
    let set_tag = TAG_SET;
    tlv(set_tag, sm3_ai)
}
/// Wraps one pre-encoded `SignerInfo` in a DER `SET` (SignedData `signerInfos`).
#[cfg(feature = "std")]
fn signer_infos_set(signer_info: &[u8]) -> Vec<u8> {
    let set_tag = TAG_SET;
    tlv(set_tag, signer_info)
}
/// Wraps an encoded SignedData `SEQUENCE` in the outer
/// `ContentInfo { signedData-OID, [0] EXPLICIT signed_data }`.
#[cfg(feature = "std")]
fn cms_content_info_signed_data_to_der(signed_data: &[u8]) -> Result<Vec<u8>, Error> {
    let mut body = encode_oid_tlv(&gmssl_cms_signed_data_oid())?;
    body.extend(tlv(TAG_EXPLICIT_0, signed_data));
    Ok(tlv(TAG_SEQUENCE, &body))
}
/// Produces an attached (content-embedding) CMS `ContentInfo(SignedData)` in DER,
/// signed with SM2/SM3 by the leaf key.
///
/// * `leaf_cert_der` — the signer certificate. It is embedded verbatim and its
///   issuer/serial identify the signer.
/// * `leaf_key_pem` / `leaf_key_pass` — the encrypted PKCS#8 PEM private key and its password.
/// * `data` — the payload, embedded as the encapsulated `data` content.
///
/// The signature is computed over `gmssl_cms_content_digest(data)`. The
/// `SignerInfo` carries no signed attributes.
///
/// # Errors
/// Propagates key loading, RNG, signing, certificate-parsing and DER-encoding failures.
#[cfg(feature = "std")]
pub fn sign_gmssl_cms_attached_native(
    leaf_cert_der: &[u8],
    leaf_key_pem: &str,
    leaf_key_pass: &str,
    data: &[u8],
) -> Result<Vec<u8>, Error> {
    // Sign the SM3 digest of the encapsulated content.
    let content_digest = gmssl_cms_content_digest(data)?;
    let signature = {
        let mut key = sm2_pk_from_pkcs8_pem_with_pass(leaf_key_pem, leaf_key_pass)?;
        let mut buf = vec![0u8; ECDSA_MAX_LEN];
        let entropy = Arc::new(OsEntropy::new());
        let mut drbg = CtrDrbg::new(entropy, None)?;
        let written = key.sign(MdType::SM3, &content_digest, &mut buf, &mut drbg)?;
        buf.truncate(written);
        buf
    };

    // Algorithm identifiers and the single SignerInfo.
    let digest_alg = algorithm_identifier_oid_only(&sm3_digest_oid())?;
    let signature_alg = algorithm_identifier_oid_only(&sm2_sign_with_sm3_oid())?;
    let signer_id = issuer_and_serial_from_cert(leaf_cert_der)?;
    let signer_info = cms_signer_info_to_der(&signer_id, &digest_alg, &signature_alg, &signature);

    // SignedData fields, in order: version, digestAlgorithms, encapContentInfo,
    // certificates, signerInfos.
    let encap_content = cms_content_info_data_to_der(data)?;
    let signed_data_fields = [
        tlv(TAG_INTEGER, &[1]),
        digest_algors_set(&digest_alg),
        encap_content,
        implicit_explicit_0_raw(leaf_cert_der),
        signer_infos_set(&signer_info),
    ];
    let signed_data = tlv(TAG_SEQUENCE, &signed_data_fields.concat());
    cms_content_info_signed_data_to_der(&signed_data)
}
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use std::path::Path;
    use std::process::Command;

    /// Cross-checks the CMS content digest against `gmssl sm3` run on the same bytes.
    ///
    /// The `gmssl` binary comes from `$GMSSL`, or else from a path relative to
    /// the crate manifest. If the binary is missing or cannot run, the test
    /// prints a skip notice and returns.
    #[test]
    fn cms_content_sm3_matches_gmssl_sm3_cli() {
        let default_bin =
            || format!("{}/../../GmSSL/build/bin/gmssl", env!("CARGO_MANIFEST_DIR"));
        let gmssl = std::env::var("GMSSL").unwrap_or_else(|_| default_bin());
        if !Path::new(&gmssl).exists() {
            eprintln!("skip cms_content_sm3_matches_gmssl_sm3_cli: no {gmssl}");
            return;
        }

        let data = b"hello\n";
        let mut encoded = gmssl_content_hash_prefix(data.len()).expect("prefix");
        encoded.extend_from_slice(data);

        let workdir = tempfile::tempdir().expect("tempdir");
        let msg_path = workdir.path().join("m");
        let digest_path = workdir.path().join("d");
        std::fs::write(&msg_path, &encoded).expect("write");

        let cli_ok = matches!(
            Command::new(&gmssl)
                .args(["sm3", "-bin", "-in"])
                .arg(&msg_path)
                .arg("-out")
                .arg(&digest_path)
                .status(),
            Ok(status) if status.success()
        );
        if !cli_ok {
            eprintln!(
                "skip cms_content_sm3_matches_gmssl_sm3_cli: gmssl sm3 not runnable (LD_LIBRARY_PATH?)"
            );
            return;
        }

        let expected = std::fs::read(&digest_path).expect("read digest");
        let actual = gmssl_cms_content_digest(data).expect("mbedtls sm3");
        assert_eq!(expected.as_slice(), actual.as_slice(), "CMS content SM3 mismatch");
    }
}