use gm_sm4::Sm4Cipher;
use crate::crypter::{
CipherAlgorithmBaseTrait, CipherAlgorithmTrait, CipherAlgorithmType, IVKeyNewTrait,
};
use crate::*;
/// Length, in bytes, of the initialization vector accepted by [`Sm4CipherAlgorithm`].
pub const SM4_IV_LENGTH: usize = 16;
/// Length, in bytes, of the key accepted by [`Sm4CipherAlgorithm`].
pub const SM4_KEY_LENGTH: usize = 16;
/// Size, in bytes, of one block processed per cipher invocation.
pub const SM4_BLOCK_LENGTH: usize = 16;
/// SM4 block cipher wrapped with block-chaining state.
///
/// Each block is XOR-ed with the previously produced output block before it
/// is encrypted, so the instance carries state from one `crypt` call to the
/// next.
pub struct Sm4CipherAlgorithm {
/// Underlying SM4 block cipher, constructed from the key.
m_cipher: Sm4Cipher,
/// Copy of the initialization vector supplied at construction.
/// NOTE(review): no code visible in this file reads this field after it is set.
m_iv: [u8; SM4_IV_LENGTH],
/// Chaining value XOR-ed into the next input block; starts as the IV and is
/// replaced by the last output block after every successful `crypt` call.
m_prev_block: [u8; SM4_IV_LENGTH],
}
/// Static parameters of the SM4 algorithm: IV and key sizes in bytes, and the
/// declaration that it is a block cipher operating on `SM4_BLOCK_LENGTH`-byte blocks.
impl CipherAlgorithmBaseTrait for Sm4CipherAlgorithm {
// Required IV size; checked against the IV slice length in `new`.
const IV_LENGTH: usize = SM4_IV_LENGTH;
// Required key size; checked against the key slice length in `new`.
const KEY_LENGTH: usize = SM4_KEY_LENGTH;
// Block variant carrying the block size in bytes.
const CIPHER_ALGORITHM_TYPE: CipherAlgorithmType = CipherAlgorithmType::Block(SM4_BLOCK_LENGTH);
}
impl CipherAlgorithmTrait for Sm4CipherAlgorithm {
    /// Transforms `src_data` into `dst_data` one block at a time, chaining blocks together.
    ///
    /// Every input block is XOR-ed with the previous output block (the IV for
    /// the very first block) and then run through the SM4 block cipher. The
    /// chaining value is kept on `self`, so consecutive calls continue the
    /// same chain.
    ///
    /// NOTE(review): only the cipher's `encrypt` direction is invoked here; how
    /// callers reverse the transformation is not visible in this file.
    ///
    /// # Errors
    ///
    /// Returns `CrypterError::CryptionFailed` when the two buffers differ in
    /// length, when the length is not a multiple of `SM4_BLOCK_LENGTH`, or
    /// when the underlying cipher reports an error. In the last case blocks
    /// already written to `dst_data` stay written, but the stored chaining
    /// value is left unchanged.
    fn crypt(&mut self, src_data: &[u8], dst_data: &mut [u8]) -> Result<()> {
        let same_length = src_data.len() == dst_data.len();
        let block_aligned = src_data.len() % SM4_BLOCK_LENGTH == 0;
        if !(same_length && block_aligned) {
            return Err(CrypterError::CryptionFailed.into());
        }

        // Work on a local copy so the stored chain only advances on full success.
        let mut chain = self.m_prev_block;
        let block_pairs = src_data
            .chunks_exact(SM4_BLOCK_LENGTH)
            .zip(dst_data.chunks_exact_mut(SM4_BLOCK_LENGTH));

        for (input, output) in block_pairs {
            // Start from the chaining value and fold the input block into it.
            let mut mixed = chain;
            mixed
                .iter_mut()
                .zip(input)
                .for_each(|(mixed_byte, input_byte)| *mixed_byte ^= *input_byte);

            let encrypted = self
                .m_cipher
                .encrypt(&mixed)
                .map_err(|_| CrypterError::CryptionFailed)?;

            output.copy_from_slice(&encrypted);
            chain.copy_from_slice(&encrypted);
        }

        self.m_prev_block = chain;
        Ok(())
    }
}
impl IVKeyNewTrait for Sm4CipherAlgorithm {
    /// Builds a cipher instance from an initialization vector and a key.
    ///
    /// The chaining state is seeded with the IV.
    ///
    /// # Errors
    ///
    /// * `CrypterError::InvalidIVLength` if `iv` is not `IV_LENGTH` bytes long
    ///   (checked first).
    /// * `CrypterError::InvalidKeyLength` if `key` is not `KEY_LENGTH` bytes long.
    /// * `CrypterError::CryptionFailed` if the underlying SM4 cipher rejects the key.
    fn new(iv: &[u8], key: &[u8]) -> Result<Self>
    where
        Self: Sized,
    {
        let iv_ok = iv.len() == Self::IV_LENGTH;
        let key_ok = key.len() == Self::KEY_LENGTH;
        match (iv_ok, key_ok) {
            (false, _) => return Err(CrypterError::InvalidIVLength.into()),
            (true, false) => return Err(CrypterError::InvalidKeyLength.into()),
            (true, true) => {}
        }

        // Length was validated above, so this copy cannot panic.
        let mut initial_vector = [0u8; SM4_IV_LENGTH];
        initial_vector.copy_from_slice(iv);

        let block_cipher = Sm4Cipher::new(key).map_err(|_| CrypterError::CryptionFailed)?;

        Ok(Self {
            m_cipher: block_cipher,
            m_iv: initial_vector,
            m_prev_block: initial_vector,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypter::{StringCrypter, StringCrypterTrait};

    /// Second argument passed to both `encrypt` and `decrypt` in every test.
    const SECRET: &str = "qwerty";

    /// A short string must survive an encrypt/decrypt round trip unchanged.
    #[test]
    fn test_sm4() {
        const MESSAGE: &str = "1234567890ABCDEFGHIJKLMNOPQRSTUVWXYZ";

        let crypter = StringCrypter::<Sm4CipherAlgorithm>::default();

        let ciphertext = crypter.encrypt(MESSAGE, SECRET).unwrap();
        println!("ciphertext: {ciphertext}");

        let plaintext = crypter.decrypt(&ciphertext, SECRET).unwrap();
        println!("plaintext: {plaintext}");

        assert_eq!(plaintext, MESSAGE);
    }

    /// A 10,000-character string must survive an encrypt/decrypt round trip unchanged.
    #[test]
    fn test_sm4_long_string() {
        let crypter = StringCrypter::<Sm4CipherAlgorithm>::default();

        // Repeat the ASCII pattern cyclically until exactly `target_length`
        // characters have been produced.
        let pattern = "1234567890abcdefghij";
        let target_length = 10_000;
        let test_string: String = pattern.chars().cycle().take(target_length).collect();
        println!("测试字符串长度: {}", test_string.len());

        let ciphertext = crypter.encrypt(&test_string, SECRET).unwrap();
        println!("密文长度: {}", ciphertext.len());
        println!("密文: {}", &ciphertext[..]);

        let plaintext = crypter.decrypt(&ciphertext, SECRET).unwrap();
        println!("解密后字符串长度: {}", plaintext.len());

        assert_eq!(plaintext, test_string);
        println!("SM4长字符串加密解密测试通过!");
    }
}