use crate::sm4::Sm4Key;
/// Number of Feistel rounds applied by `fnr_encrypt` / `fnr_decrypt` to inputs wider than one bit
/// (1-bit inputs bypass the Feistel network entirely).
const N_ROUND: usize = 7;
/// Reads bit `i` of `data` as `0` or `1`.
///
/// Bits are numbered MSB-first: bit 0 is the most-significant bit of `data[0]`,
/// bit 127 the least-significant bit of `data[15]`.
///
/// Panics if `i >= 128`.
#[inline]
fn get_bit(data: &[u8; 16], i: usize) -> u8 {
    let mask = 0x80_u8 >> (i % 8);
    u8::from(data[i / 8] & mask != 0)
}
/// Sets bit `i` of `data` to the least-significant bit of `val`.
///
/// Bits are numbered MSB-first: bit 0 is the most-significant bit of `data[0]`.
/// Only `val & 1` is used, so a stray higher bit in `val` can no longer spill
/// into a neighbouring bit of the same byte. All other bits of `data` are
/// left unchanged.
///
/// Panics if `i >= 128`.
#[inline]
fn set_bit(data: &mut [u8; 16], i: usize, val: u8) {
    let mask = 0x80_u8 >> (i % 8);
    let byte = &mut data[i / 8];
    if val & 1 == 0 {
        *byte &= !mask;
    } else {
        *byte |= mask;
    }
}
pub(super) fn clear_high_bits(data: &mut [u8; 16], n: usize) {
for i in n..128 {
set_bit(data, i, 0);
}
}
/// Returns `a XOR b` restricted to the first `n` bits (MSB-first); every later bit is zero.
///
/// Panics if `n` is large enough to index past the 16-byte buffers.
fn xor_bits(a: &[u8; 16], b: &[u8; 16], n: usize) -> [u8; 16] {
    let whole = n / 8;
    let mut result = [0u8; 16];
    for ((dst, &x), &y) in result[..whole]
        .iter_mut()
        .zip(&a[..whole])
        .zip(&b[..whole])
    {
        *dst = x ^ y;
    }
    let tail = n % 8;
    if tail > 0 {
        // Keep only the top `tail` bits of the final, partially covered byte.
        let keep = 0xFF_u8 << (8 - tail);
        result[whole] = (a[whole] ^ b[whole]) & keep;
    }
    result
}
/// Feistel round function: derives `out_bits` pseudo-random bits from one half of the state.
///
/// The SM4 input block is built as follows:
/// - `tweak` is copied into bytes 0..15.
/// - The round counter goes into byte 15.
/// - The leading `ceil(half_bits / 8)` bytes of `half` (capped at 16) are XORed over the front
///   of the block.
///
/// The block is then encrypted with `key`, and every bit from index `out_bits` onward is cleared.
///
/// NOTE(review): `half` is XORed over the tweak bytes instead of occupying its own bytes. Any two
/// (tweak, half) pairs with the same `tweak ^ half` therefore yield the same round output.
/// Separating them would change every ciphertext, so it is only flagged here.
fn round_fn(
    key: &Sm4Key,
    tweak: &[u8; 15],
    half: &[u8; 16],
    half_bits: usize,
    out_bits: usize,
    round: usize,
) -> [u8; 16] {
    let mut input = [0u8; 16];
    let (tweak_part, round_part) = input.split_at_mut(15);
    tweak_part.copy_from_slice(tweak);
    round_part[0] = round as u8;

    let used_bytes = half_bits.div_ceil(8).min(16);
    input
        .iter_mut()
        .zip(&half[..used_bytes])
        .for_each(|(dst, src)| *dst ^= *src);

    key.encrypt_block(&mut input);
    clear_high_bits(&mut input, out_bits);
    input
}
/// Encrypts the first `num_bits` bits of `data` in place (MSB-first bit numbering).
///
/// Bits at index `num_bits` and above are left untouched. A 1-bit value is handled by
/// `fnr_1bit`. Wider values go through an `N_ROUND`-round Feistel network whose halves start
/// as `num_bits / 2` and `num_bits - num_bits / 2` bits wide.
///
/// Each round maps `(A, B)` to `(B, A ^ F(B))`, so the two halves exchange widths every round.
/// This keeps odd widths a bijection. The previous fixed-width layout truncated a bit per round
/// and could not be decrypted. For even widths the output is identical to the old code.
///
/// # Panics
///
/// Panics if `num_bits > 128`.
pub fn fnr_encrypt(key: &Sm4Key, tweak: &[u8; 15], data: &mut [u8; 16], num_bits: usize) {
    assert!(
        num_bits <= 128,
        "fnr_encrypt: num_bits must be <= 128, got {num_bits}"
    );
    if num_bits == 1 {
        fnr_1bit(key, tweak, data, true);
        return;
    }
    let mut a_bits = num_bits / 2;
    let mut b_bits = num_bits - a_bits;
    let mut a = extract_bits(data, 0, a_bits);
    let mut b = extract_bits(data, a_bits, b_bits);
    for round in 0..N_ROUND {
        // (A, B) -> (B, A ^ F(B)): the new left half has B's width, the new right half A's width.
        let f = round_fn(key, tweak, &b, b_bits, a_bits, round);
        let mixed = xor_bits(&a, &f, a_bits);
        a = b;
        b = mixed;
        std::mem::swap(&mut a_bits, &mut b_bits);
    }
    insert_bits(data, 0, &a, a_bits);
    insert_bits(data, a_bits, &b, b_bits);
}
/// Decrypts, in place, a value produced by `fnr_encrypt` with the same key, tweak and `num_bits`.
///
/// Bits at index `num_bits` and above are left untouched.
///
/// # Panics
///
/// Panics if `num_bits > 128`.
pub fn fnr_decrypt(key: &Sm4Key, tweak: &[u8; 15], data: &mut [u8; 16], num_bits: usize) {
    assert!(
        num_bits <= 128,
        "fnr_decrypt: num_bits must be <= 128, got {num_bits}"
    );
    if num_bits == 1 {
        fnr_1bit(key, tweak, data, false);
        return;
    }
    let left_bits = num_bits / 2;
    let right_bits = num_bits - left_bits;
    // Every encryption round swaps the half widths, so after an odd number of rounds the
    // ciphertext's first half is `right_bits` wide.
    let (mut a_bits, mut b_bits) = if N_ROUND % 2 == 0 {
        (left_bits, right_bits)
    } else {
        (right_bits, left_bits)
    };
    let mut a = extract_bits(data, 0, a_bits);
    let mut b = extract_bits(data, a_bits, b_bits);
    for round in (0..N_ROUND).rev() {
        // Undo (A, B) -> (B, A ^ F(B)): here `a` holds B and `b` holds A ^ F(B).
        let f = round_fn(key, tweak, &a, a_bits, b_bits, round);
        let prev_a = xor_bits(&b, &f, b_bits);
        b = a;
        a = prev_a;
        std::mem::swap(&mut a_bits, &mut b_bits);
    }
    insert_bits(data, 0, &a, a_bits);
    insert_bits(data, a_bits, &b, b_bits);
}
/// Copies `len` bits of `data`, starting at bit `start`, into the leading bits of a zeroed buffer.
fn extract_bits(data: &[u8; 16], start: usize, len: usize) -> [u8; 16] {
    let mut out = [0u8; 16];
    for i in 0..len {
        set_bit(&mut out, i, get_bit(data, start + i));
    }
    out
}
/// Writes the leading `len` bits of `src` into `data` starting at bit `start`;
/// every other bit of `data` is left unchanged.
fn insert_bits(data: &mut [u8; 16], start: usize, src: &[u8; 16], len: usize) {
    for i in 0..len {
        set_bit(data, start + i, get_bit(src, i));
    }
}
/// Transforms a 1-bit value (bit 0 of `data`, MSB-first) in place; all other bits are untouched.
///
/// The bit is XORed with the top bit of `SM4(tweak || 0xFF)`. XOR with a fixed key-derived bit
/// is its own inverse, so encryption and decryption are the same operation and `_encrypt` is
/// ignored.
///
/// The trailing `0xFF` byte is presumably chosen to differ from the round counters that
/// `round_fn` writes into byte 15 — TODO confirm intent.
fn fnr_1bit(key: &Sm4Key, tweak: &[u8; 15], data: &mut [u8; 16], _encrypt: bool) {
    let mut keystream = [0xFF_u8; 16];
    keystream[..15].copy_from_slice(tweak);
    key.encrypt_block(&mut keystream);
    let pad = keystream[0] >> 7;
    let flipped = get_bit(data, 0) ^ pad;
    set_bit(data, 0, flipped);
}