mod fnr;
use crate::error::Error;
use crate::sm4::Sm4Key;
use fnr::{clear_high_bits, fnr_decrypt, fnr_encrypt};
use zeroize::ZeroizeOnDrop;
/// Expanded tweak consumed by [`FpeKey::encrypt`] and [`FpeKey::decrypt`].
///
/// Wraps the 15 bytes produced by [`FpeKey::expand_tweak`]; the inner array is
/// private, so values can only be obtained through that method.
#[derive(Clone, Copy)]
pub struct FpeTweak([u8; 15]);
/// Format-preserving encryption key: an SM4 key paired with the bit width of
/// the values it encrypts.
///
/// Derives `ZeroizeOnDrop` so the key material is wiped when the value is
/// dropped (relies on the field types supporting zeroization).
#[derive(ZeroizeOnDrop)]
pub struct FpeKey {
/// SM4 key built from the 16 raw key bytes passed to `FpeKey::new`.
key: Sm4Key,
/// Number of leading bits of each 16-byte block that are transformed;
/// `FpeKey::new` only accepts values in `1..=128`.
num_bits: usize,
}
impl FpeKey {
pub fn new(key: &[u8; 16], num_bits: usize) -> Result<Self, Error> {
if num_bits == 0 || num_bits > 128 {
return Err(Error::InvalidInputLength);
}
Ok(FpeKey {
key: Sm4Key::new(key),
num_bits,
})
}
pub fn expand_tweak(&self, tweak: &[u8]) -> FpeTweak {
let mut state = [0u8; 16];
state[0] = (self.num_bits >> 8) as u8;
state[1] = self.num_bits as u8;
for chunk in tweak.chunks(16) {
let mut block = state;
for (i, &b) in chunk.iter().enumerate() {
block[i] ^= b;
}
self.key.encrypt_block(&mut block);
state = block;
}
self.key.encrypt_block(&mut state);
let mut out = [0u8; 15];
out.copy_from_slice(&state[..15]);
FpeTweak(out)
}
pub fn encrypt(&self, tweak: &FpeTweak, data: &mut [u8; 16]) {
let saved = save_high_bits(data, self.num_bits);
clear_high_bits(data, self.num_bits);
fnr_encrypt(&self.key, &tweak.0, data, self.num_bits);
restore_high_bits(data, &saved, self.num_bits);
}
pub fn decrypt(&self, tweak: &FpeTweak, data: &mut [u8; 16]) {
let saved = save_high_bits(data, self.num_bits);
clear_high_bits(data, self.num_bits);
fnr_decrypt(&self.key, &tweak.0, data, self.num_bits);
restore_high_bits(data, &saved, self.num_bits);
}
pub fn num_bits(&self) -> usize {
self.num_bits
}
}
/// Copies every bit of `data` that lies outside the first `n` bits.
///
/// Bits are counted most-significant-first within each byte: for a width that
/// ends inside a byte, the top `n % 8` bits of that byte belong to the window
/// and only its remaining low bits are saved. Positions inside the window are
/// returned as zero.
///
/// A width of 128 bits or more covers the whole block, so the result is all
/// zeros; larger values saturate instead of panicking on an out-of-range slice.
fn save_high_bits(data: &[u8; 16], n: usize) -> [u8; 16] {
    let mut saved = [0u8; 16];
    let full_bytes = n / 8;
    if full_bytes >= 16 {
        // The window spans the entire block; there is nothing outside it.
        return saved;
    }

    let rem = n % 8;
    let mut start = full_bytes;
    if rem != 0 {
        // Partial byte: keep only the bits below the window boundary.
        saved[full_bytes] = data[full_bytes] & (0xFF_u8 >> rem);
        start += 1;
    }
    saved[start..].copy_from_slice(&data[start..]);
    saved
}
/// Writes the bits of `saved` that lie outside the first `n` bits back into
/// `data`, leaving the first `n` bits of `data` as they are.
///
/// Uses the same most-significant-first bit layout as `save_high_bits`: when
/// the width ends inside a byte, the top `n % 8` bits of that byte are kept
/// from `data` and the remaining low bits are taken from `saved`.
///
/// A width of 128 bits or more covers the whole block, so `data` is left
/// untouched; larger values saturate instead of panicking on an out-of-range
/// slice.
fn restore_high_bits(data: &mut [u8; 16], saved: &[u8; 16], n: usize) {
    let full_bytes = n / 8;
    if full_bytes >= 16 {
        // The window spans the entire block; nothing to restore.
        return;
    }

    let rem = n % 8;
    let mut start = full_bytes;
    if rem != 0 {
        // Partial byte: merge the window bits from `data` with the saved low bits.
        let mask = 0xFF_u8 >> rem;
        data[full_bytes] = (data[full_bytes] & !mask) | (saved[full_bytes] & mask);
        start += 1;
    }
    data[start..].copy_from_slice(&saved[start..]);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a zeroed 16-byte block whose leading bytes are `prefix`.
    fn block_with_prefix(prefix: &[u8]) -> [u8; 16] {
        let mut block = [0u8; 16];
        block[..prefix.len()].copy_from_slice(prefix);
        block
    }

    /// Widths at both ends of the accepted range, and one in between, succeed.
    #[test]
    fn test_fpe_new_valid() {
        for width in [1usize, 32, 128] {
            assert!(FpeKey::new(&[0u8; 16], width).is_ok());
        }
    }

    /// Widths just outside the accepted range are rejected.
    #[test]
    fn test_fpe_new_invalid() {
        for width in [0usize, 129] {
            assert!(FpeKey::new(&[0u8; 16], width).is_err());
        }
    }

    /// A 32-bit value changes under encryption and is recovered by decryption.
    #[test]
    fn test_fpe_encrypt_decrypt_roundtrip_32bits() {
        let fpe = FpeKey::new(&[0x01u8; 16], 32).unwrap();
        let tweak = fpe.expand_tweak(b"test-tweak");
        let original = block_with_prefix(&12345678u32.to_be_bytes());

        let mut data = original;
        fpe.encrypt(&tweak, &mut data);
        assert_ne!(&data[..4], &original[..4], "加密后数据应变化");

        fpe.decrypt(&tweak, &mut data);
        assert_eq!(&data[..4], &original[..4], "解密后应恢复原始明文");
    }

    /// Every possible 8-bit value survives an encrypt/decrypt round trip.
    #[test]
    fn test_fpe_encrypt_decrypt_roundtrip_8bits() {
        let fpe = FpeKey::new(&[0xABu8; 16], 8).unwrap();
        let tweak = fpe.expand_tweak(b"tweak");

        for val in u8::MIN..=u8::MAX {
            let original = block_with_prefix(&[val]);
            let mut data = original;
            fpe.encrypt(&tweak, &mut data);
            fpe.decrypt(&tweak, &mut data);
            assert_eq!(data[0], original[0], "8位加解密往返应还原 val={}", val);
        }
    }

    /// Both values of a single-bit domain (the top bit of byte 0) round-trip.
    #[test]
    fn test_fpe_encrypt_decrypt_roundtrip_1bit() {
        let fpe = FpeKey::new(&[0x99u8; 16], 1).unwrap();
        let tweak = fpe.expand_tweak(b"");

        for val in [0u8, 0x80u8] {
            let original = block_with_prefix(&[val]);
            let mut data = original;
            fpe.encrypt(&tweak, &mut data);
            fpe.decrypt(&tweak, &mut data);
            assert_eq!(data[0] & 0x80, original[0] & 0x80, "1位加解密往返应还原");
        }
    }

    /// A full 128-bit block round-trips.
    #[test]
    fn test_fpe_encrypt_decrypt_roundtrip_128bits() {
        let fpe = FpeKey::new(&[0x55u8; 16], 128).unwrap();
        let tweak = fpe.expand_tweak(b"full block");

        let mut original = [0u8; 16];
        for (idx, slot) in original.iter_mut().enumerate() {
            *slot = idx as u8 * 17;
        }

        let mut data = original;
        fpe.encrypt(&tweak, &mut data);
        fpe.decrypt(&tweak, &mut data);
        assert_eq!(data, original, "128位加解密往返应还原");
    }

    /// The same plaintext under two different tweaks yields different ciphertexts.
    #[test]
    fn test_fpe_different_tweaks_different_output() {
        let fpe = FpeKey::new(&[0x42u8; 16], 32).unwrap();
        let tweak1 = fpe.expand_tweak(b"tweak1");
        let tweak2 = fpe.expand_tweak(b"tweak2");

        let mut d1 = block_with_prefix(&[0xDE, 0xAD, 0xBE, 0xEF]);
        let mut d2 = d1;
        fpe.encrypt(&tweak1, &mut d1);
        fpe.encrypt(&tweak2, &mut d2);
        assert_ne!(&d1[..4], &d2[..4], "不同 tweak 应产生不同密文");
    }

    /// With a 4-bit width, only the top nibble of byte 0 may change; the low
    /// nibble and bytes 1..16 must pass through encryption untouched.
    #[test]
    fn test_fpe_high_bits_preserved() {
        let fpe = FpeKey::new(&[0x11u8; 16], 4).unwrap();
        let tweak = fpe.expand_tweak(b"t");

        let mut data = [0u8; 16];
        data[0] = 0b1010_0101;
        for (idx, slot) in data.iter_mut().enumerate().skip(1) {
            *slot = idx as u8;
        }

        let saved_low = data[0] & 0x0F;
        let mut saved_rest = [0u8; 15];
        saved_rest.copy_from_slice(&data[1..]);

        fpe.encrypt(&tweak, &mut data);
        assert_eq!(data[0] & 0x0F, saved_low, "低4位应不变");
        assert_eq!(&data[1..], &saved_rest[..], "字节1~15应不变");

        fpe.decrypt(&tweak, &mut data);
        assert_eq!(data[0] & 0xF0, 0b1010_0000, "解密后高4位应恢复");
    }
}