#[cfg(feature = "alloc")]
use alloc::vec::Vec;
use subtle::ConstantTimeEq;
use super::cipher::{encrypt_block_raw, Sm4Key};
/// Encrypts `data` with SM4 in ECB mode.
///
/// The input is processed in 16-byte blocks. A trailing partial block is
/// zero-padded to 16 bytes before encryption. The output length is therefore
/// `data.len()` rounded up to a multiple of 16. No reversible padding scheme
/// (e.g. PKCS#7) is applied, so the original length cannot be recovered from
/// the output alone.
#[cfg(feature = "alloc")]
pub fn sm4_encrypt_ecb(key: &[u8; 16], data: &[u8]) -> Vec<u8> {
    let cipher = Sm4Key::new(key);
    let mut out = Vec::with_capacity(data.len().div_ceil(16) * 16);
    for piece in data.chunks(16) {
        let mut buf = [0u8; 16];
        buf[..piece.len()].copy_from_slice(piece);
        cipher.encrypt_block(&mut buf);
        out.extend_from_slice(&buf);
    }
    out
}
/// Decrypts `data` with SM4 in ECB mode.
///
/// Works block by block. A trailing partial block is zero-padded to 16 bytes
/// before decryption, which generally does not yield meaningful plaintext, so
/// callers should pass whole 16-byte blocks. No padding is removed from the
/// output.
#[cfg(feature = "alloc")]
pub fn sm4_decrypt_ecb(key: &[u8; 16], data: &[u8]) -> Vec<u8> {
    let cipher = Sm4Key::new(key);
    let mut out = Vec::with_capacity(data.len().div_ceil(16) * 16);
    for piece in data.chunks(16) {
        let mut buf = [0u8; 16];
        buf[..piece.len()].copy_from_slice(piece);
        cipher.decrypt_block(&mut buf);
        out.extend_from_slice(&buf);
    }
    out
}
/// Encrypts `plaintext` with SM4 in CBC mode.
///
/// Each block is XOR-ed with the previous ciphertext block before encryption.
/// The IV stands in for the previous block on the first block. A trailing
/// partial block is zero-padded to 16 bytes. The output is always a whole
/// number of blocks.
#[cfg(feature = "alloc")]
pub fn sm4_encrypt_cbc(key: &[u8; 16], iv: &[u8; 16], plaintext: &[u8]) -> Vec<u8> {
    let cipher = Sm4Key::new(key);
    let mut chain = *iv;
    let mut out = Vec::with_capacity(plaintext.len().div_ceil(16) * 16);
    for piece in plaintext.chunks(16) {
        let mut buf = [0u8; 16];
        buf[..piece.len()].copy_from_slice(piece);
        // Whiten with the previous ciphertext block (or the IV).
        buf.iter_mut().zip(chain.iter()).for_each(|(b, &c)| *b ^= c);
        cipher.encrypt_block(&mut buf);
        chain = buf;
        out.extend_from_slice(&buf);
    }
    out
}
/// Decrypts `ciphertext` with SM4 in CBC mode.
///
/// Each decrypted block is XOR-ed with the previous ciphertext block, using
/// the IV for the first block. A trailing partial block is zero-padded to
/// 16 bytes first. No padding is stripped from the result.
#[cfg(feature = "alloc")]
pub fn sm4_decrypt_cbc(key: &[u8; 16], iv: &[u8; 16], ciphertext: &[u8]) -> Vec<u8> {
    let cipher = Sm4Key::new(key);
    let mut chain = *iv;
    let mut out = Vec::with_capacity(ciphertext.len().div_ceil(16) * 16);
    for piece in ciphertext.chunks(16) {
        let mut ct_block = [0u8; 16];
        ct_block[..piece.len()].copy_from_slice(piece);
        let mut buf = ct_block;
        cipher.decrypt_block(&mut buf);
        buf.iter_mut().zip(chain.iter()).for_each(|(b, &c)| *b ^= c);
        out.extend_from_slice(&buf);
        // The (padded) ciphertext block chains into the next block.
        chain = ct_block;
    }
    out
}
/// Encrypts or decrypts `data` with SM4 in OFB mode. The operation is its own inverse.
///
/// The keystream is `O_i = E_K(O_{i-1})` with `O_0 = IV`, and it never depends
/// on the data. The output has exactly `data.len()` bytes; a short final chunk
/// consumes only a prefix of its keystream block.
#[cfg(feature = "alloc")]
pub fn sm4_crypt_ofb(key: &[u8; 16], iv: &[u8; 16], data: &[u8]) -> Vec<u8> {
    let cipher = Sm4Key::new(key);
    let mut state = *iv;
    let mut out = Vec::with_capacity(data.len());
    for piece in data.chunks(16) {
        cipher.encrypt_block(&mut state);
        out.extend(piece.iter().zip(state.iter()).map(|(&d, &k)| d ^ k));
    }
    out
}
/// Encrypts `data` with SM4 in full-block (128-bit) CFB mode.
///
/// Each keystream block is `E_K` of the previous ciphertext block, starting
/// from the IV. The output has exactly `data.len()` bytes.
#[cfg(feature = "alloc")]
pub fn sm4_encrypt_cfb(key: &[u8; 16], iv: &[u8; 16], data: &[u8]) -> Vec<u8> {
    let cipher = Sm4Key::new(key);
    let mut register = *iv;
    let mut out = Vec::with_capacity(data.len());
    for piece in data.chunks(16) {
        let mut pad = register;
        cipher.encrypt_block(&mut pad);
        let mut ct = [0u8; 16];
        for ((dst, &p), &k) in ct.iter_mut().zip(piece).zip(pad.iter()) {
            *dst = p ^ k;
        }
        out.extend_from_slice(&ct[..piece.len()]);
        // Feed the (zero-padded) ciphertext block back as the next register value.
        register = ct;
    }
    out
}
/// Decrypts `data` with SM4 in full-block (128-bit) CFB mode.
///
/// The keystream is derived from the incoming ciphertext blocks, starting
/// from the IV, so only the block encryption direction is needed. The output
/// has exactly `data.len()` bytes.
#[cfg(feature = "alloc")]
pub fn sm4_decrypt_cfb(key: &[u8; 16], iv: &[u8; 16], data: &[u8]) -> Vec<u8> {
    let cipher = Sm4Key::new(key);
    let mut register = *iv;
    let mut out = Vec::with_capacity(data.len());
    for piece in data.chunks(16) {
        let mut pad = register;
        cipher.encrypt_block(&mut pad);
        out.extend(piece.iter().zip(pad.iter()).map(|(&c, &k)| c ^ k));
        // Next register value is this ciphertext block, zero-padded if short.
        register = [0u8; 16];
        register[..piece.len()].copy_from_slice(piece);
    }
    out
}
/// Increments a 128-bit big-endian counter in place, wrapping to all zeros on overflow.
#[inline]
fn ctr_inc(counter: &mut [u8; 16]) {
    *counter = u128::from_be_bytes(*counter).wrapping_add(1).to_be_bytes();
}
/// Encrypts or decrypts `data` with SM4 in CTR mode. The operation is its own inverse.
///
/// `nonce` is the initial 128-bit counter block. It is incremented as a full
/// 128-bit big-endian integer (wrapping) after each block. The output has
/// exactly `data.len()` bytes.
#[cfg(feature = "alloc")]
pub fn sm4_crypt_ctr(key: &[u8; 16], nonce: &[u8; 16], data: &[u8]) -> Vec<u8> {
    let cipher = Sm4Key::new(key);
    let mut block_ctr = *nonce;
    let mut out = Vec::with_capacity(data.len());
    for piece in data.chunks(16) {
        let mut pad = block_ctr;
        cipher.encrypt_block(&mut pad);
        out.extend(piece.iter().zip(pad.iter()).map(|(&d, &k)| d ^ k));
        ctr_inc(&mut block_ctr);
    }
    out
}
/// Multiplies two elements of GF(2^128) using GCM's bit ordering (NIST SP 800-38D, Algorithm 1).
///
/// Bytes are read big-endian and bit 0 of the field element is the most
/// significant bit of byte 0. The reduction constant is `R = 0xE1 || 0^120`.
/// Every bit of `x` is processed with masks instead of data-dependent branches.
fn gf128_mul(x: &[u8; 16], y: &[u8; 16]) -> [u8; 16] {
    const R: u128 = 0xE1u128 << 120;
    let x_bits = u128::from_be_bytes(*x);
    let mut v = u128::from_be_bytes(*y);
    let mut acc: u128 = 0;
    // Walk x from its most significant bit (field coefficient 0) downwards.
    for shift in (0..128).rev() {
        let take = 0u128.wrapping_sub((x_bits >> shift) & 1);
        acc ^= v & take;
        // V <- V * x: shift right; if a bit falls off the end, reduce by R.
        let reduce = 0u128.wrapping_sub(v & 1);
        v = (v >> 1) ^ (R & reduce);
    }
    acc.to_be_bytes()
}
fn ghash(h: &[u8; 16], aad: &[u8], ciphertext: &[u8]) -> [u8; 16] {
let mut y = [0u8; 16];
for chunk in aad.chunks(16) {
let mut block = [0u8; 16];
block[..chunk.len()].copy_from_slice(chunk);
for i in 0..16 {
y[i] ^= block[i];
}
y = gf128_mul(&y, h);
}
for chunk in ciphertext.chunks(16) {
let mut block = [0u8; 16];
block[..chunk.len()].copy_from_slice(chunk);
for i in 0..16 {
y[i] ^= block[i];
}
y = gf128_mul(&y, h);
}
let mut len_block = [0u8; 16];
len_block[0..8].copy_from_slice(&((aad.len() as u64) * 8).to_be_bytes());
len_block[8..16].copy_from_slice(&((ciphertext.len() as u64) * 8).to_be_bytes());
for i in 0..16 {
y[i] ^= len_block[i];
}
gf128_mul(&y, h)
}
/// GCM `inc32`: increments the low 32 bits (bytes 12..16, big-endian) of the
/// counter block, wrapping modulo 2^32. It never carries into the first 12 bytes.
#[inline]
fn gcm_ctr_inc(counter: &mut [u8; 16]) {
    let low = &mut counter[12..];
    let next = u32::from_be_bytes([low[0], low[1], low[2], low[3]]).wrapping_add(1);
    low.copy_from_slice(&next.to_be_bytes());
}
/// Encrypts and authenticates `plaintext` with SM4-GCM using a 96-bit nonce.
///
/// Returns `(ciphertext, tag)`. The ciphertext has the same length as the
/// plaintext and the tag is 16 bytes. The pre-counter block is
/// `J0 = nonce || 0^31 || 1`, and data is encrypted from `inc32(J0)` onwards.
/// The tag is `GHASH_H(aad, ciphertext) xor E_K(J0)`.
///
/// NOTE(review): no check is made against the SP 800-38D plaintext length
/// limit. The 32-bit counter wraps after 2^32 blocks.
#[cfg(feature = "alloc")]
pub fn sm4_encrypt_gcm(
    key: &[u8; 16],
    nonce: &[u8; 12],
    aad: &[u8],
    plaintext: &[u8],
) -> (Vec<u8>, [u8; 16]) {
    let cipher = Sm4Key::new(key);
    let rk = cipher.round_keys();
    // Hash subkey H = E_K(0^128).
    let hash_key = encrypt_block_raw(rk, &[0u8; 16]);
    let mut j0 = [0u8; 16];
    j0[..12].copy_from_slice(nonce);
    j0[15] = 1;

    // GCTR starting at inc32(J0): bump the counter before each block.
    let mut counter = j0;
    let mut ciphertext = Vec::with_capacity(plaintext.len());
    for piece in plaintext.chunks(16) {
        gcm_ctr_inc(&mut counter);
        let pad = encrypt_block_raw(rk, &counter);
        ciphertext.extend(piece.iter().zip(pad.iter()).map(|(&p, &k)| p ^ k));
    }

    let digest = ghash(&hash_key, aad, &ciphertext);
    let mask = encrypt_block_raw(rk, &j0);
    let mut tag = [0u8; 16];
    for ((t, &g), &m) in tag.iter_mut().zip(digest.iter()).zip(mask.iter()) {
        *t = g ^ m;
    }
    (ciphertext, tag)
}
/// Verifies and decrypts an SM4-GCM ciphertext produced with a 96-bit nonce.
///
/// The tag is recomputed and compared in constant time before any plaintext
/// is produced.
///
/// # Errors
/// Returns `Error::AuthTagMismatch` if `tag` does not match `aad` and `ciphertext`.
#[cfg(feature = "alloc")]
pub fn sm4_decrypt_gcm(
    key: &[u8; 16],
    nonce: &[u8; 12],
    aad: &[u8],
    ciphertext: &[u8],
    tag: &[u8; 16],
) -> Result<Vec<u8>, crate::error::Error> {
    let cipher = Sm4Key::new(key);
    let rk = cipher.round_keys();
    let hash_key = encrypt_block_raw(rk, &[0u8; 16]);
    let mut j0 = [0u8; 16];
    j0[..12].copy_from_slice(nonce);
    j0[15] = 1;

    // Authenticate first; nothing is decrypted for a forged message.
    let digest = ghash(&hash_key, aad, ciphertext);
    let mask = encrypt_block_raw(rk, &j0);
    let mut expected_tag = [0u8; 16];
    for ((e, &g), &m) in expected_tag.iter_mut().zip(digest.iter()).zip(mask.iter()) {
        *e = g ^ m;
    }
    // Constant-time comparison so timing does not reveal how many bytes matched.
    if expected_tag.ct_eq(tag).unwrap_u8() == 0 {
        return Err(crate::error::Error::AuthTagMismatch);
    }

    let mut counter = j0;
    let mut plaintext = Vec::with_capacity(ciphertext.len());
    for piece in ciphertext.chunks(16) {
        gcm_ctr_inc(&mut counter);
        let pad = encrypt_block_raw(rk, &counter);
        plaintext.extend(piece.iter().zip(pad.iter()).map(|(&c, &k)| c ^ k));
    }
    Ok(plaintext)
}
/// Computes the raw CCM CBC-MAC value `T`, before masking with `S_0`.
///
/// Formatting follows NIST SP 800-38C / RFC 3610 with a 12-byte nonce, so
/// the message-length field is `q = 3` bytes. Only the first `tag_len` bytes
/// of the result are meaningful to callers.
///
/// # Errors
/// Returns `Error::InvalidInputLength` when:
/// * `tag_len` is not an even value in `4..=16`, since it must round-trip
///   through the 3-bit `(t - 2) / 2` flags field;
/// * `message` is `2^24` bytes or longer and does not fit the 3-byte length field;
/// * `aad` is longer than 510 bytes, the limit of the internal staging buffer.
fn ccm_cbc_mac(
    rk: &[u32; 32],
    nonce: &[u8; 12],
    aad: &[u8],
    message: &[u8],
    tag_len: usize,
) -> Result<[u8; 16], crate::error::Error> {
    // q: width in bytes of the message-length field; the nonce fills 15 - q = 12 bytes.
    const Q: usize = 3;
    // Staging buffer for the 2-byte length prefix plus AAD (so AAD <= 510 bytes).
    const AAD_BUF_LEN: usize = 512;

    // Without this check `tag_len - 2` underflows for tag_len < 2, and odd or
    // oversized values would silently encode a different tag length in B0.
    if !(4..=16).contains(&tag_len) || tag_len % 2 != 0 {
        return Err(crate::error::Error::InvalidInputLength);
    }
    // Longer messages would be silently truncated in B0, producing a wrong MAC.
    if (message.len() as u64) >= (1u64 << (8 * Q)) {
        return Err(crate::error::Error::InvalidInputLength);
    }

    let has_aad = !aad.is_empty();
    let flags = (u8::from(has_aad) << 6) | ((((tag_len - 2) / 2) as u8) << 3) | (Q as u8 - 1);
    let mut b0 = [0u8; 16];
    b0[0] = flags;
    b0[1..16 - Q].copy_from_slice(nonce);
    // The length was checked above to fit in Q bytes, so dropping the top byte is lossless.
    let msg_len = (message.len() as u32).to_be_bytes();
    b0[16 - Q..].copy_from_slice(&msg_len[4 - Q..]);
    let mut x = encrypt_block_raw(rk, &b0);

    if has_aad {
        let aad_len = aad.len();
        let prefix_len = 2 + aad_len;
        if prefix_len > AAD_BUF_LEN {
            return Err(crate::error::Error::InvalidInputLength);
        }
        let padded_len = prefix_len.div_ceil(16) * 16;
        let mut aad_buf = [0u8; AAD_BUF_LEN];
        // AAD shorter than 2^16 - 2^8 bytes is prefixed with its 2-byte big-endian length.
        aad_buf[..2].copy_from_slice(&(aad_len as u16).to_be_bytes());
        aad_buf[2..prefix_len].copy_from_slice(aad);
        for block in aad_buf[..padded_len].chunks_exact(16) {
            for (xi, &bi) in x.iter_mut().zip(block) {
                *xi ^= bi;
            }
            x = encrypt_block_raw(rk, &x);
        }
    }

    for chunk in message.chunks(16) {
        // A short final block is implicitly zero-padded: XOR only the bytes present.
        for (xi, &bi) in x.iter_mut().zip(chunk) {
            *xi ^= bi;
        }
        x = encrypt_block_raw(rk, &x);
    }
    Ok(x)
}
/// Encrypts and authenticates `plaintext` with SM4 in CCM mode.
///
/// Formatting follows NIST SP 800-38C with a 12-byte nonce, so the
/// length/counter field is 3 bytes. Returns the ciphertext followed by the
/// first `tag_len` bytes of the encrypted tag. The nonce must be unique per key.
///
/// # Errors
/// Returns `Error::InvalidInputLength` if `tag_len` is not an even value in
/// `4..=16`, if `plaintext` is `2^24` bytes or longer, or if `aad` exceeds
/// 510 bytes.
#[cfg(feature = "alloc")]
pub fn sm4_encrypt_ccm(
    key: &[u8; 16],
    nonce: &[u8; 12],
    aad: &[u8],
    plaintext: &[u8],
    tag_len: usize,
) -> Result<Vec<u8>, crate::error::Error> {
    // Largest message representable in the 3-byte length field of B0.
    const MAX_MSG_LEN: u64 = (1 << 24) - 1;

    // Report invalid parameters through the Result instead of panicking.
    if !(4..=16).contains(&tag_len) || tag_len % 2 != 0 {
        return Err(crate::error::Error::InvalidInputLength);
    }
    if (plaintext.len() as u64) > MAX_MSG_LEN {
        return Err(crate::error::Error::InvalidInputLength);
    }

    let sm4 = Sm4Key::new(key);
    let rk = sm4.round_keys();
    let t = ccm_cbc_mac(rk, nonce, aad, plaintext, tag_len)?;

    // A_0 = flags (q - 1 = 2) || nonce || counter 0; S_0 = E(A_0) masks the tag.
    let mut a0 = [0u8; 16];
    a0[0] = 2;
    a0[1..13].copy_from_slice(nonce);
    let s0 = encrypt_block_raw(rk, &a0);

    let mut out = Vec::with_capacity(plaintext.len() + tag_len);
    for (block_idx, chunk) in plaintext.chunks(16).enumerate() {
        // Counter blocks A_i start at i = 1; the length check keeps i within 3 bytes.
        let ctr = (block_idx as u32 + 1).to_be_bytes();
        let mut a_i = a0;
        a_i[13..].copy_from_slice(&ctr[1..]);
        let ks = encrypt_block_raw(rk, &a_i);
        out.extend(chunk.iter().zip(ks.iter()).map(|(&b, &k)| b ^ k));
    }
    out.extend(t[..tag_len].iter().zip(s0.iter()).map(|(&ti, &si)| ti ^ si));
    Ok(out)
}
/// Decrypts and verifies an SM4-CCM message laid out as `ciphertext || tag`.
///
/// Formatting follows NIST SP 800-38C with a 12-byte nonce. The plaintext is
/// returned only if the recomputed tag matches in constant time.
///
/// # Errors
/// * `Error::InvalidInputLength` if `tag_len` is not an even value in `4..=16`,
///   if the input is shorter than `tag_len`, if the ciphertext is `2^24` bytes
///   or longer, or if `aad` exceeds 510 bytes.
/// * `Error::AuthTagMismatch` if authentication fails.
#[cfg(feature = "alloc")]
pub fn sm4_decrypt_ccm(
    key: &[u8; 16],
    nonce: &[u8; 12],
    aad: &[u8],
    ciphertext_with_tag: &[u8],
    tag_len: usize,
) -> Result<Vec<u8>, crate::error::Error> {
    // Largest message representable in the 3-byte length field of B0.
    const MAX_MSG_LEN: u64 = (1 << 24) - 1;

    // Validate tag_len before anything else. tag_len == 0 would underflow
    // `tag_len - 2` in the MAC flags: it panics with overflow checks and wraps
    // without them. It would then compare two empty tags, which succeeds for
    // any ciphertext. tag_len > 16 would index past the 16-byte tag buffer.
    if !(4..=16).contains(&tag_len) || tag_len % 2 != 0 {
        return Err(crate::error::Error::InvalidInputLength);
    }
    let Some(ct_len) = ciphertext_with_tag.len().checked_sub(tag_len) else {
        return Err(crate::error::Error::InvalidInputLength);
    };
    let (ct, received_tag) = ciphertext_with_tag.split_at(ct_len);
    if (ct.len() as u64) > MAX_MSG_LEN {
        return Err(crate::error::Error::InvalidInputLength);
    }

    let sm4 = Sm4Key::new(key);
    let rk = sm4.round_keys();
    // A_0 = flags (q - 1 = 2) || nonce || counter 0; S_0 = E(A_0) masks the tag.
    let mut a0 = [0u8; 16];
    a0[0] = 2;
    a0[1..13].copy_from_slice(nonce);
    let s0 = encrypt_block_raw(rk, &a0);

    let mut plaintext = Vec::with_capacity(ct.len());
    for (block_idx, chunk) in ct.chunks(16).enumerate() {
        let ctr = (block_idx as u32 + 1).to_be_bytes();
        let mut a_i = a0;
        a_i[13..].copy_from_slice(&ctr[1..]);
        let ks = encrypt_block_raw(rk, &a_i);
        plaintext.extend(chunk.iter().zip(ks.iter()).map(|(&b, &k)| b ^ k));
    }

    // CCM authenticates the plaintext, so the MAC is computed after decryption.
    let t = ccm_cbc_mac(rk, nonce, aad, &plaintext, tag_len)?;
    let mut expected_tag = [0u8; 16];
    for ((e, &ti), &si) in expected_tag.iter_mut().zip(&t[..tag_len]).zip(s0.iter()) {
        *e = ti ^ si;
    }
    if expected_tag[..tag_len].ct_eq(received_tag).unwrap_u8() == 0 {
        return Err(crate::error::Error::AuthTagMismatch);
    }
    Ok(plaintext)
}
/// SM4-GCM encryption that returns `ciphertext || tag` (16-byte tag) in one buffer.
#[cfg(feature = "alloc")]
pub fn sm4_encrypt_gcm_combined(
    key: &[u8; 16],
    nonce: &[u8; 12],
    aad: &[u8],
    plaintext: &[u8],
) -> Vec<u8> {
    let (mut sealed, tag) = sm4_encrypt_gcm(key, nonce, aad, plaintext);
    sealed.extend(tag);
    sealed
}
/// SM4-GCM decryption of a `ciphertext || tag` buffer whose last 16 bytes are the tag.
///
/// # Errors
/// * `Error::InvalidInputLength` if the input is shorter than 16 bytes.
/// * Any error from `sm4_decrypt_gcm`, e.g. `Error::AuthTagMismatch`.
#[cfg(feature = "alloc")]
pub fn sm4_decrypt_gcm_combined(
    key: &[u8; 16],
    nonce: &[u8; 12],
    aad: &[u8],
    ciphertext_with_tag: &[u8],
) -> Result<Vec<u8>, crate::error::Error> {
    let Some(ct_len) = ciphertext_with_tag.len().checked_sub(16) else {
        return Err(crate::error::Error::InvalidInputLength);
    };
    let (ct, tag_bytes) = ciphertext_with_tag.split_at(ct_len);
    let tag: &[u8; 16] = tag_bytes
        .try_into()
        .expect("split_at leaves exactly 16 tag bytes");
    sm4_decrypt_gcm(key, nonce, aad, ct, tag)
}
/// SM4-CCM encryption with a full 16-byte tag, returning `ciphertext || tag`.
///
/// # Errors
/// Propagates any error from `sm4_encrypt_ccm`.
#[cfg(feature = "alloc")]
pub fn sm4_encrypt_ccm_combined(
    key: &[u8; 16],
    nonce: &[u8; 12],
    aad: &[u8],
    plaintext: &[u8],
) -> Result<Vec<u8>, crate::error::Error> {
    const FULL_TAG_LEN: usize = 16;
    sm4_encrypt_ccm(key, nonce, aad, plaintext, FULL_TAG_LEN)
}
/// SM4-CCM decryption of a `ciphertext || tag` buffer carrying a full 16-byte tag.
///
/// # Errors
/// Propagates any error from `sm4_decrypt_ccm`, e.g. a length or tag mismatch.
#[cfg(feature = "alloc")]
pub fn sm4_decrypt_ccm_combined(
    key: &[u8; 16],
    nonce: &[u8; 12],
    aad: &[u8],
    ciphertext_with_tag: &[u8],
) -> Result<Vec<u8>, crate::error::Error> {
    const FULL_TAG_LEN: usize = 16;
    sm4_decrypt_ccm(key, nonce, aad, ciphertext_with_tag, FULL_TAG_LEN)
}
/// Multiplies the XTS tweak by alpha in GF(2^128), in place.
///
/// The 16 bytes are read as one big-endian 128-bit value and shifted right by
/// one bit. If the bit shifted out was set, `0xE1` is XOR-ed into the most
/// significant byte. This is the bit-reflected convention used by GHASH. It is
/// not the IEEE P1619 one, which is little-endian, shifts left and reduces with
/// `0x87`.
/// NOTE(review): presumably chosen to match the GB/T 17964 SM4-XTS variant.
/// Confirm against the intended specification, because the two conventions do
/// not interoperate.
fn xts_mul_alpha(tweak: &mut [u8; 16]) {
    let value = u128::from_be_bytes(*tweak);
    // All-ones mask when the low bit is set, so the reduction is branch-free.
    let reduce = 0u128.wrapping_sub(value & 1) & (0xE1u128 << 120);
    *tweak = ((value >> 1) ^ reduce).to_be_bytes();
}
/// Encrypts one data unit with SM4-XTS.
///
/// `key2` encrypts `tweak_sector` to form the initial tweak. `key1` encrypts
/// each block as `E_{key1}(P xor T) xor T`. The tweak is multiplied by alpha
/// (see `xts_mul_alpha`) between blocks. Ciphertext stealing is not
/// supported.
///
/// # Errors
/// Returns `Error::InvalidInputLength` unless `data` is non-empty and a multiple of 16 bytes.
#[cfg(feature = "alloc")]
pub fn sm4_encrypt_xts(
    key1: &[u8; 16],
    key2: &[u8; 16],
    tweak_sector: &[u8; 16],
    data: &[u8],
) -> Result<Vec<u8>, crate::error::Error> {
    let whole_blocks = !data.is_empty() && data.len() % 16 == 0;
    if !whole_blocks {
        return Err(crate::error::Error::InvalidInputLength);
    }
    let data_key = Sm4Key::new(key1);
    let tweak_key = Sm4Key::new(key2);
    let mut t = *tweak_sector;
    tweak_key.encrypt_block(&mut t);

    let mut out = Vec::with_capacity(data.len());
    for piece in data.chunks_exact(16) {
        let mut buf = [0u8; 16];
        for ((b, &p), &w) in buf.iter_mut().zip(piece).zip(t.iter()) {
            *b = p ^ w;
        }
        data_key.encrypt_block(&mut buf);
        out.extend(buf.iter().zip(t.iter()).map(|(&b, &w)| b ^ w));
        xts_mul_alpha(&mut t);
    }
    Ok(out)
}
/// Decrypts one data unit with SM4-XTS. This is the inverse of `sm4_encrypt_xts`.
///
/// The tweak schedule is identical to encryption (`E_{key2}` of the sector
/// tweak, then multiplication by alpha per block). Only the data blocks are
/// decrypted with `key1`.
///
/// # Errors
/// Returns `Error::InvalidInputLength` unless `data` is non-empty and a multiple of 16 bytes.
#[cfg(feature = "alloc")]
pub fn sm4_decrypt_xts(
    key1: &[u8; 16],
    key2: &[u8; 16],
    tweak_sector: &[u8; 16],
    data: &[u8],
) -> Result<Vec<u8>, crate::error::Error> {
    let whole_blocks = !data.is_empty() && data.len() % 16 == 0;
    if !whole_blocks {
        return Err(crate::error::Error::InvalidInputLength);
    }
    let data_key = Sm4Key::new(key1);
    let tweak_key = Sm4Key::new(key2);
    let mut t = *tweak_sector;
    tweak_key.encrypt_block(&mut t);

    let mut out = Vec::with_capacity(data.len());
    for piece in data.chunks_exact(16) {
        let mut buf = [0u8; 16];
        for ((b, &c), &w) in buf.iter_mut().zip(piece).zip(t.iter()) {
            *b = c ^ w;
        }
        data_key.decrypt_block(&mut buf);
        out.extend(buf.iter().zip(t.iter()).map(|(&b, &w)| b ^ w));
        xts_mul_alpha(&mut t);
    }
    Ok(out)
}
// Unit tests for the SM4 modes. They combine roundtrip checks, a standard
// known-answer vector and structural checks (padding, counters, tag binding).
#[cfg(test)]
#[cfg(feature = "alloc")]
mod tests {
    use super::*;

    /// Key and plaintext of the SM4 standard (GB/T 32907-2016) worked example.
    const STD_KEY: [u8; 16] = [
        0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54,
        0x32, 0x10,
    ];
    /// Expected single-block ciphertext of `STD_KEY` encrypted under `STD_KEY`.
    const STD_CT: [u8; 16] = [
        0x68, 0x1e, 0xdf, 0x34, 0xd2, 0x06, 0x96, 0x5e, 0x86, 0xb3, 0xe9, 0x4f, 0x53, 0x6e,
        0x42, 0x46,
    ];

    #[test]
    fn test_ecb_known_answer() {
        assert_eq!(
            sm4_encrypt_ecb(&STD_KEY, &STD_KEY),
            STD_CT,
            "ECB known-answer encryption failed"
        );
        assert_eq!(
            sm4_decrypt_ecb(&STD_KEY, &STD_CT),
            STD_KEY,
            "ECB known-answer decryption failed"
        );
    }

    #[test]
    fn test_ecb_zero_pads_partial_block() {
        let short = [0xAAu8; 5];
        let mut padded = [0u8; 16];
        padded[..5].copy_from_slice(&short);
        assert_eq!(
            sm4_encrypt_ecb(&STD_KEY, &short),
            sm4_encrypt_ecb(&STD_KEY, &padded),
            "a short final block must be zero-padded"
        );
    }

    #[test]
    fn test_cbc_vector() {
        let key = [
            0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54,
            0x32, 0x10,
        ];
        let iv = [
            0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54,
            0x32, 0x10,
        ];
        let plain = [
            0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54,
            0x32, 0x10,
        ];
        let ct = sm4_encrypt_cbc(&key, &iv, &plain);
        let pt = sm4_decrypt_cbc(&key, &iv, &ct);
        assert_eq!(pt, plain, "CBC 往返解密失败");
    }

    #[test]
    fn test_cbc_first_block_matches_ecb() {
        // With plaintext == IV the first CBC input block is all zeros, so the
        // CBC output must equal the ECB encryption of 0^128.
        let ct = sm4_encrypt_cbc(&STD_KEY, &STD_KEY, &STD_KEY);
        assert_eq!(
            ct,
            sm4_encrypt_ecb(&STD_KEY, &[0u8; 16]),
            "CBC first block must be E(P xor IV)"
        );
    }

    #[test]
    fn test_cfb_roundtrip_partial_block() {
        let key = [0x5Au8; 16];
        let iv = [0x3Cu8; 16];
        let plain = b"cfb message spanning two blocks";
        let ct = sm4_encrypt_cfb(&key, &iv, plain);
        assert_eq!(ct.len(), plain.len(), "CFB must not expand the message");
        assert_eq!(sm4_decrypt_cfb(&key, &iv, &ct), plain, "CFB roundtrip failed");
    }

    #[test]
    fn test_ctr_self_inverse_and_counter_carry() {
        let key = [0x77u8; 16];
        // A low byte of 0xff forces a carry into the next byte between blocks.
        let mut nonce = [0u8; 16];
        nonce[15] = 0xff;
        let plain = b"ctr mode carries across bytes!!!!";
        let ct = sm4_crypt_ctr(&key, &nonce, plain);
        assert_eq!(sm4_crypt_ctr(&key, &nonce, &ct), plain, "CTR must be self-inverse");
        // The second keystream block must come from counter ...01 00.
        let mut next = [0u8; 16];
        next[14] = 1;
        let ks2 = sm4_encrypt_ecb(&key, &next);
        assert_eq!(ct[16] ^ plain[16], ks2[0], "CTR counter must carry across bytes");
    }

    #[test]
    fn test_counter_increments_wrap() {
        let mut c = [0xffu8; 16];
        ctr_inc(&mut c);
        assert_eq!(c, [0u8; 16], "128-bit counter must wrap to zero");

        let mut g = [0xffu8; 16];
        gcm_ctr_inc(&mut g);
        let mut expected = [0xffu8; 16];
        expected[12..].copy_from_slice(&[0, 0, 0, 0]);
        assert_eq!(g, expected, "inc32 must not carry into the nonce bytes");
    }

    #[test]
    fn test_gf128_reduction() {
        // x^127 * x = x^128 = 1 + x + x^2 + x^7 -> 0xE1 in the first byte.
        let mut x127 = [0u8; 16];
        x127[15] = 1;
        let mut alpha = [0u8; 16];
        alpha[0] = 0x40;
        let mut r = [0u8; 16];
        r[0] = 0xE1;
        assert_eq!(gf128_mul(&x127, &alpha), r, "GF(2^128) reduction is wrong");
    }

    #[test]
    fn test_gcm_roundtrip() {
        let key = [0u8; 16];
        let nonce = [1u8; 12];
        let aad = b"additional data";
        let plain = b"hello gcm world!";
        let (ct, tag) = sm4_encrypt_gcm(&key, &nonce, aad, plain);
        let pt = sm4_decrypt_gcm(&key, &nonce, aad, &ct, &tag).unwrap();
        assert_eq!(pt, plain, "GCM 往返解密失败");
    }

    #[test]
    fn test_gcm_tag_tamper() {
        let key = [0u8; 16];
        let nonce = [0u8; 12];
        let (ct, mut tag) = sm4_encrypt_gcm(&key, &nonce, b"", b"secret");
        tag[0] ^= 1;
        assert!(
            sm4_decrypt_gcm(&key, &nonce, b"", &ct, &tag).is_err(),
            "篡改 tag 后应返回错误"
        );
    }

    #[test]
    fn test_gcm_empty_inputs() {
        let key = [0u8; 16];
        let nonce = [0u8; 12];
        let (ct, tag) = sm4_encrypt_gcm(&key, &nonce, b"", b"");
        assert!(ct.is_empty(), "empty plaintext must give empty ciphertext");
        assert!(
            sm4_decrypt_gcm(&key, &nonce, b"", &ct, &tag).unwrap().is_empty(),
            "empty GCM message must verify"
        );
    }

    #[test]
    fn test_gcm_combined_roundtrip_and_aad_binding() {
        let key = [0x10u8; 16];
        let nonce = [0x20u8; 12];
        let plain = b"payload of odd length";
        let sealed = sm4_encrypt_gcm_combined(&key, &nonce, b"hdr", plain);
        assert_eq!(sealed.len(), plain.len() + 16, "combined output is ct || 16-byte tag");
        assert_eq!(
            sm4_decrypt_gcm_combined(&key, &nonce, b"hdr", &sealed).unwrap(),
            plain,
            "GCM combined roundtrip failed"
        );
        assert!(
            sm4_decrypt_gcm_combined(&key, &nonce, b"HDR", &sealed).is_err(),
            "changing the AAD must fail authentication"
        );
        assert!(
            sm4_decrypt_gcm_combined(&key, &nonce, b"hdr", &sealed[..15]).is_err(),
            "input shorter than a tag must be rejected"
        );
    }

    #[test]
    fn test_ccm_roundtrip() {
        let key = [0u8; 16];
        let nonce = [2u8; 12];
        let aad = b"ccm aad";
        let plain = b"ccm plaintext!!!";
        let ct = sm4_encrypt_ccm(&key, &nonce, aad, plain, 16).unwrap();
        let pt = sm4_decrypt_ccm(&key, &nonce, aad, &ct, 16).unwrap();
        assert_eq!(pt, plain, "CCM 往返解密失败");
    }

    #[test]
    fn test_ccm_tag_tamper() {
        let key = [0u8; 16];
        let nonce = [0u8; 12];
        let mut ct = sm4_encrypt_ccm(&key, &nonce, b"", b"secret data here", 16).unwrap();
        let last = ct.len() - 1;
        ct[last] ^= 1;
        assert!(
            sm4_decrypt_ccm(&key, &nonce, b"", &ct, 16).is_err(),
            "篡改 CCM tag 后应返回错误"
        );
    }

    #[test]
    fn test_ccm_combined_and_short_tag() {
        let key = [0x33u8; 16];
        let nonce = [0x44u8; 12];
        let sealed = sm4_encrypt_ccm_combined(&key, &nonce, b"", b"").unwrap();
        assert_eq!(sealed.len(), 16, "empty CCM message is just the tag");
        assert!(
            sm4_decrypt_ccm_combined(&key, &nonce, b"", &sealed).unwrap().is_empty(),
            "empty CCM message must verify"
        );

        let plain = b"seventeen bytes!!";
        let ct = sm4_encrypt_ccm(&key, &nonce, b"a", plain, 4).unwrap();
        assert_eq!(ct.len(), plain.len() + 4, "4-byte tag must be appended");
        assert_eq!(
            sm4_decrypt_ccm(&key, &nonce, b"a", &ct, 4).unwrap(),
            plain,
            "CCM short-tag roundtrip failed"
        );
    }

    #[test]
    fn test_ccm_aad_too_long() {
        let key = [0u8; 16];
        let nonce = [0u8; 12];
        let big_aad = [0u8; 511];
        assert!(
            sm4_encrypt_ccm(&key, &nonce, &big_aad, b"data", 16).is_err(),
            "AAD 超过 510 字节时应返回 InvalidInputLength"
        );
    }

    #[test]
    fn test_ccm_aad_at_limit() {
        let key = [0u8; 16];
        let nonce = [0u8; 12];
        let aad = [0x5Au8; 510];
        let ct = sm4_encrypt_ccm(&key, &nonce, &aad, b"data", 16).unwrap();
        assert_eq!(
            sm4_decrypt_ccm(&key, &nonce, &aad, &ct, 16).unwrap(),
            b"data",
            "510-byte AAD is the largest supported and must work"
        );
    }

    #[test]
    fn test_xts_roundtrip() {
        let key1 = [0x11u8; 16];
        let key2 = [0x22u8; 16];
        let tweak = [0u8; 16];
        let plain = [0x42u8; 32];
        let ct = sm4_encrypt_xts(&key1, &key2, &tweak, &plain).unwrap();
        let pt = sm4_decrypt_xts(&key1, &key2, &tweak, &ct).unwrap();
        assert_eq!(pt, plain, "XTS 往返解密失败");
        // Equal plaintext blocks must encrypt differently because the tweak advances.
        assert_ne!(&ct[..16], &ct[16..], "XTS tweak must change per block");
    }

    #[test]
    fn test_xts_non_aligned_rejected() {
        let key1 = [0u8; 16];
        let key2 = [0u8; 16];
        let tweak = [0u8; 16];
        assert!(
            sm4_encrypt_xts(&key1, &key2, &tweak, b"").is_err(),
            "空输入应返回 InvalidInputLength"
        );
        assert!(
            sm4_encrypt_xts(&key1, &key2, &tweak, b"not-aligned-data").is_ok(),
            "正好 16 字节不应返回错误"
        );
        assert!(
            sm4_encrypt_xts(&key1, &key2, &tweak, &[0u8; 17]).is_err(),
            "17 字节应返回 InvalidInputLength"
        );
        assert!(
            sm4_decrypt_xts(&key1, &key2, &tweak, &[0u8; 15]).is_err(),
            "15 字节应返回 InvalidInputLength"
        );
    }

    #[test]
    fn test_ofb_self_inverse() {
        let key = [0xABu8; 16];
        let iv = [0x12u8; 16];
        let plain = b"ofb test message";
        let ct = sm4_crypt_ofb(&key, &iv, plain);
        let pt = sm4_crypt_ofb(&key, &iv, &ct);
        assert_eq!(pt, plain, "OFB 应为自反模式");
    }
}