use alloc::vec::Vec;
use subtle::ConstantTimeEq;
use zeroize::Zeroize;
use super::cipher::{BLOCK_SIZE, KEY_SIZE, Sm4Cipher};
/// Length in bytes of a combined XTS key: Key1 (data cipher) immediately followed by Key2
/// (tweak cipher).
pub const XTS_KEY_SIZE: usize = KEY_SIZE * 2;
/// Largest accepted data unit / sector: 2^20 cipher blocks.
const MAX_LEN: usize = BLOCK_SIZE << 20;
/// Advances an XTS tweak to the next block position ("multiply by α"), in place.
///
/// The whole block is treated as one big value and shifted right by one bit. Byte 0 is the
/// most significant byte, and each byte's low bit moves into the next byte's high bit. If a
/// bit falls off the end of the last byte, the reduction constant 0xE1 is folded into byte 0.
fn mul_alpha(t: &mut [u8; BLOCK_SIZE]) {
    // The fold threads the carry from byte to byte and yields the bit shifted out of the end.
    let spilled = t.iter_mut().fold(0u8, |carry_in, byte| {
        let carry_out = *byte & 1;
        *byte = (carry_in << 7) | (*byte >> 1);
        carry_out
    });
    // 0 -> 0x00, 1 -> 0xFF: apply the reduction through a mask instead of branching on the
    // tweak value.
    let reduce_mask = 0u8.wrapping_sub(spilled);
    t[0] ^= reduce_mask & 0xE1;
}
/// XORs the block `b` into the block `a` in place (`a ^= b`, byte by byte).
fn xor16(a: &mut [u8; BLOCK_SIZE], b: &[u8; BLOCK_SIZE]) {
    a.iter_mut().zip(b.iter()).for_each(|(dst, src)| *dst ^= *src);
}
fn split_keys(key: &[u8; XTS_KEY_SIZE], len: usize) -> Option<(Sm4Cipher, Sm4Cipher)> {
if !(BLOCK_SIZE..=MAX_LEN).contains(&len) {
return None;
}
if bool::from(key[..KEY_SIZE].ct_eq(&key[KEY_SIZE..])) {
return None;
}
let mut key1 = [0u8; KEY_SIZE];
let mut key2 = [0u8; KEY_SIZE];
key1.copy_from_slice(&key[..KEY_SIZE]);
key2.copy_from_slice(&key[KEY_SIZE..]);
let ciphers = (Sm4Cipher::new(&key1), Sm4Cipher::new(&key2));
key1.zeroize();
key2.zeroize();
Some(ciphers)
}
/// Encrypts a single XTS data unit with SM4.
///
/// `key` is Key1 (data cipher) followed by Key2 (tweak cipher). `tweak` is encrypted under
/// Key2 to form the initial tweak. If `data_unit` does not end on a block boundary, the tail
/// is handled with ciphertext stealing, so the returned ciphertext always has exactly
/// `data_unit.len()` bytes.
///
/// Returns `None` in either of these cases:
/// - `data_unit` is shorter than one block or longer than 2^20 blocks;
/// - the two key halves are equal.
#[must_use]
pub fn encrypt(
    key: &[u8; XTS_KEY_SIZE],
    tweak: &[u8; BLOCK_SIZE],
    data_unit: &[u8],
) -> Option<Vec<u8>> {
    split_keys(key, data_unit.len())
        .map(|(data_cipher, tweak_cipher)| xts_encrypt(&data_cipher, &tweak_cipher, tweak, data_unit))
}
/// Decrypts a single XTS data unit produced by [`encrypt`] with the same key and tweak.
///
/// The returned plaintext has exactly `data_unit.len()` bytes. A trailing partial block is
/// recovered by reversing ciphertext stealing.
///
/// Returns `None` in either of these cases:
/// - `data_unit` is shorter than one block or longer than 2^20 blocks;
/// - the two key halves are equal.
#[must_use]
pub fn decrypt(
    key: &[u8; XTS_KEY_SIZE],
    tweak: &[u8; BLOCK_SIZE],
    data_unit: &[u8],
) -> Option<Vec<u8>> {
    split_keys(key, data_unit.len())
        .map(|(data_cipher, tweak_cipher)| xts_decrypt(&data_cipher, &tweak_cipher, tweak, data_unit))
}
/// Encrypts `buf` in place as a run of consecutive XTS sectors.
///
/// `buf` is cut into `sector_size`-byte sectors, and sector `i` uses the tweak
/// `(start_sector + i)` in little-endian form.
///
/// Returns `None`, with `buf` left untouched, in any of these cases:
/// - `sector_size` is not a multiple of the block size between one block and 2^20 blocks;
/// - `buf.len()` is not a multiple of `sector_size` (an empty `buf` is accepted);
/// - the last sector number would overflow `u128`;
/// - the two key halves are equal.
#[must_use]
pub fn encrypt_sectors(
    key: &[u8; XTS_KEY_SIZE],
    sector_size: usize,
    start_sector: u128,
    buf: &mut [u8],
) -> Option<()> {
    let encrypting = true;
    process_sectors(key, sector_size, start_sector, buf, encrypting)
}
/// Decrypts `buf` in place as a run of consecutive XTS sectors.
///
/// This is the inverse of `encrypt_sectors` called with the same key, sector size and start
/// sector. Sector `i` uses the tweak `(start_sector + i)` in little-endian form.
///
/// Returns `None`, with `buf` left untouched, in any of these cases:
/// - the sector geometry is invalid;
/// - the sector numbers would overflow `u128`;
/// - the two key halves are equal.
#[must_use]
pub fn decrypt_sectors(
    key: &[u8; XTS_KEY_SIZE],
    sector_size: usize,
    start_sector: u128,
    buf: &mut [u8],
) -> Option<()> {
    let encrypting = false;
    process_sectors(key, sector_size, start_sector, buf, encrypting)
}
/// Shared driver for `encrypt_sectors` / `decrypt_sectors`; `encrypt` selects the direction.
///
/// All validation happens before `buf` is touched, so a `None` result never leaves `buf`
/// partially processed. The checks are:
/// - the sector geometry;
/// - the sector-number range;
/// - the key halves.
///
/// The key schedules are built once. Each sector is then processed in place with the tweak
/// `(start_sector + index)` in little-endian form.
fn process_sectors(
    key: &[u8; XTS_KEY_SIZE],
    sector_size: usize,
    start_sector: u128,
    buf: &mut [u8],
    encrypt: bool,
) -> Option<()> {
    // Sectors never use ciphertext stealing, so they must be whole blocks within the limit.
    if !(BLOCK_SIZE..=MAX_LEN).contains(&sector_size) || sector_size % BLOCK_SIZE != 0 {
        return None;
    }
    if buf.len() % sector_size != 0 {
        return None;
    }
    let sector_count = buf.len() / sector_size;
    // Reject up front if the last sector number would overflow. This also guarantees that the
    // `wrapping_add` below never actually wraps.
    if let Some(last) = sector_count.checked_sub(1) {
        start_sector.checked_add(last as u128)?;
    }
    let (c1, c2) = split_keys(key, sector_size)?;
    let nblocks = sector_size / BLOCK_SIZE;
    // Scratch buffers reused by every sector so each sector needs no fresh allocation.
    let mut blocks: Vec<[u8; BLOCK_SIZE]> = Vec::with_capacity(nblocks);
    let mut tweaks: Vec<[u8; BLOCK_SIZE]> = Vec::with_capacity(nblocks);
    for (i, sector) in buf.chunks_exact_mut(sector_size).enumerate() {
        let tweak = start_sector.wrapping_add(i as u128).to_le_bytes();
        xts_sector_in_place(&c1, &c2, &tweak, sector, &mut blocks, &mut tweaks, encrypt);
    }
    Some(())
}
/// Encrypts (`encrypt == true`) or decrypts one block-aligned sector in place.
///
/// `tweak` is encrypted with `c2` to form the initial tweak, which is advanced with
/// `mul_alpha` after every block. `blocks` and `tweaks` are caller-provided scratch vectors.
/// They are cleared on entry so a single allocation can serve many sectors. Only whole
/// `BLOCK_SIZE` chunks of `sector` are processed; callers pass block-aligned sectors.
///
/// On return, the scratch blocks and all derived tweak values have been zeroized. On
/// decryption the scratch blocks would otherwise still hold plaintext.
fn xts_sector_in_place(
    c1: &Sm4Cipher,
    c2: &Sm4Cipher,
    tweak: &[u8; BLOCK_SIZE],
    sector: &mut [u8],
    blocks: &mut Vec<[u8; BLOCK_SIZE]>,
    tweaks: &mut Vec<[u8; BLOCK_SIZE]>,
    encrypt: bool,
) {
    let mut t = *tweak;
    c2.encrypt_block(&mut t);
    blocks.clear();
    tweaks.clear();
    // Pre-whiten every block with its tweak and remember each tweak for post-whitening.
    for chunk in sector.chunks_exact(BLOCK_SIZE) {
        let mut blk = [0u8; BLOCK_SIZE];
        blk.copy_from_slice(chunk);
        xor16(&mut blk, &t);
        blocks.push(blk);
        tweaks.push(t);
        mul_alpha(&mut t);
    }
    // Run the data cipher over the whole sector in one batch.
    if encrypt {
        c1.encrypt_blocks(blocks);
    } else {
        c1.decrypt_blocks(blocks);
    }
    // Post-whiten each block and write it back into the sector.
    for ((chunk, blk), tw) in sector
        .chunks_exact_mut(BLOCK_SIZE)
        .zip(blocks.iter_mut())
        .zip(tweaks.iter())
    {
        xor16(blk, tw);
        chunk.copy_from_slice(&blk[..]);
    }
    // Scrub the derived tweaks and the scratch blocks. On decryption the scratch holds
    // plaintext, and that buffer is reused and eventually freed.
    t.zeroize();
    for tw in tweaks.iter_mut() {
        tw.zeroize();
    }
    for blk in blocks.iter_mut() {
        blk.zeroize();
    }
}
/// Encrypts one XTS data unit. Ciphertext stealing is used when `data` does not end on a
/// block boundary.
///
/// - `c1` is the data cipher (Key1) and `c2` the tweak cipher (Key2).
/// - The starting tweak is `tweak` encrypted under `c2`, advanced with `mul_alpha` once per
///   block.
/// - Callers must ensure `data.len() >= BLOCK_SIZE`.
/// - The returned ciphertext has exactly `data.len()` bytes.
fn xts_encrypt(c1: &Sm4Cipher, c2: &Sm4Cipher, tweak: &[u8; BLOCK_SIZE], data: &[u8]) -> Vec<u8> {
    let total = data.len();
    let whole = total / BLOCK_SIZE;
    let tail_len = total % BLOCK_SIZE;
    // A partial tail borrows the last whole block for ciphertext stealing, so that block is
    // left out of the bulk pass.
    let bulk_count = if tail_len == 0 { whole } else { whole - 1 };
    let (bulk, stolen_region) = data.split_at(bulk_count * BLOCK_SIZE);

    let mut tw = *tweak;
    c2.encrypt_block(&mut tw);

    let mut used_tweaks: Vec<[u8; BLOCK_SIZE]> = Vec::with_capacity(bulk_count);
    let mut scratch: Vec<[u8; BLOCK_SIZE]> = Vec::with_capacity(bulk_count);
    for chunk in bulk.chunks_exact(BLOCK_SIZE) {
        let mut block = [0u8; BLOCK_SIZE];
        block.copy_from_slice(chunk);
        xor16(&mut block, &tw);
        scratch.push(block);
        used_tweaks.push(tw);
        mul_alpha(&mut tw);
    }
    c1.encrypt_blocks(&mut scratch);

    let mut ciphertext = Vec::with_capacity(total);
    for (block, t) in scratch.iter_mut().zip(used_tweaks.iter()) {
        xor16(block, t);
        ciphertext.extend_from_slice(&block[..]);
    }

    if tail_len != 0 {
        let (last_whole, partial) = stolen_region.split_at(BLOCK_SIZE);
        // `tw` now belongs to `last_whole`; `tw_next` is the tweak one position later.
        let mut tw_next = tw;
        mul_alpha(&mut tw_next);

        // CC: the last whole plaintext block encrypted under `tw`.
        let mut cc = [0u8; BLOCK_SIZE];
        cc.copy_from_slice(last_whole);
        xor16(&mut cc, &tw);
        c1.encrypt_block(&mut cc);
        xor16(&mut cc, &tw);

        // PP: the partial plaintext padded with CC's trailing bytes, encrypted under `tw_next`.
        let mut pp = [0u8; BLOCK_SIZE];
        pp[..tail_len].copy_from_slice(partial);
        pp[tail_len..].copy_from_slice(&cc[tail_len..]);
        xor16(&mut pp, &tw_next);
        c1.encrypt_block(&mut pp);
        xor16(&mut pp, &tw_next);

        // Output order: the full block PP first, then the truncated head of CC.
        ciphertext.extend_from_slice(&pp);
        ciphertext.extend_from_slice(&cc[..tail_len]);
        tw_next.zeroize();
    }

    tw.zeroize();
    for t in used_tweaks.iter_mut() {
        t.zeroize();
    }
    ciphertext
}
/// Decrypts one XTS data unit produced by `xts_encrypt`, undoing ciphertext stealing for a
/// trailing partial block.
///
/// - `c1` is the data cipher (Key1) and `c2` the tweak cipher (Key2).
/// - Callers must ensure `data.len() >= BLOCK_SIZE`.
/// - Returns the plaintext, exactly `data.len()` bytes long.
///
/// The only plaintext copy that outlives this function is the returned `Vec`. Every
/// intermediate buffer that held plaintext or tweak material is zeroized first.
fn xts_decrypt(c1: &Sm4Cipher, c2: &Sm4Cipher, tweak: &[u8; BLOCK_SIZE], data: &[u8]) -> Vec<u8> {
    let len = data.len();
    let full = len / BLOCK_SIZE;
    let rem = len % BLOCK_SIZE;
    // With a partial tail, the last full block takes part in ciphertext stealing and is
    // handled separately below.
    let normal = if rem == 0 { full } else { full - 1 };
    let mut t = *tweak;
    c2.encrypt_block(&mut t);
    let mut tweaks: Vec<[u8; BLOCK_SIZE]> = Vec::with_capacity(normal);
    let mut blocks: Vec<[u8; BLOCK_SIZE]> = Vec::with_capacity(normal);
    for chunk in data[..normal * BLOCK_SIZE].chunks_exact(BLOCK_SIZE) {
        let mut blk = [0u8; BLOCK_SIZE];
        blk.copy_from_slice(chunk);
        xor16(&mut blk, &t);
        blocks.push(blk);
        tweaks.push(t);
        mul_alpha(&mut t);
    }
    c1.decrypt_blocks(&mut blocks);
    let mut out = Vec::with_capacity(len);
    for (blk, tw) in blocks.iter_mut().zip(tweaks.iter()) {
        xor16(blk, tw);
        out.extend_from_slice(&blk[..]);
    }
    if rem != 0 {
        let mut t_last = t;
        let mut t_steal = t;
        mul_alpha(&mut t_steal);
        let (last_full, partial) = data[normal * BLOCK_SIZE..].split_at(BLOCK_SIZE);
        // The last full ciphertext block was encrypted with the later tweak, so decrypt it
        // first with `t_steal`...
        let mut pp = [0u8; BLOCK_SIZE];
        pp.copy_from_slice(last_full);
        xor16(&mut pp, &t_steal);
        c1.decrypt_block(&mut pp);
        xor16(&mut pp, &t_steal);
        // ...then rebuild the penultimate block from the partial ciphertext plus the stolen
        // bytes and decrypt it with `t_last`.
        let mut cc = [0u8; BLOCK_SIZE];
        cc[..rem].copy_from_slice(partial);
        cc[rem..].copy_from_slice(&pp[rem..]);
        xor16(&mut cc, &t_last);
        c1.decrypt_block(&mut cc);
        xor16(&mut cc, &t_last);
        out.extend_from_slice(&cc);
        out.extend_from_slice(&pp[..rem]);
        // `cc` is a full plaintext block and `pp[..rem]` is the trailing plaintext.
        cc.zeroize();
        pp.zeroize();
        t_last.zeroize();
        t_steal.zeroize();
    }
    t.zeroize();
    for tw in &mut tweaks {
        tw.zeroize();
    }
    // The scratch blocks hold decrypted plaintext; scrub them before the Vec is freed.
    for blk in &mut blocks {
        blk.zeroize();
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mul_alpha_no_carry_is_plain_right_shift() {
        let mut t = [0u8; 16];
        t[0] = 0x02;
        mul_alpha(&mut t);
        let mut expected = [0u8; 16];
        expected[0] = 0x01;
        assert_eq!(t, expected);
    }

    #[test]
    fn mul_alpha_carry_xors_0xe1() {
        let mut t = [0u8; 16];
        t[15] = 0x01;
        mul_alpha(&mut t);
        let mut expected = [0u8; 16];
        expected[0] = 0xE1;
        assert_eq!(t, expected);
    }

    const KEY: [u8; XTS_KEY_SIZE] = [
        0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32,
        0x10, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d,
        0x0e, 0x0f,
    ];
    const TWEAK: [u8; 16] = [0x11; 16];

    #[test]
    // `i < 80`, so the `as u8` cast below never truncates.
    #[allow(clippy::cast_possible_truncation)]
    fn xts_round_trip_all_lengths() {
        for len in 16..=80usize {
            let pt: Vec<u8> = (0..len).map(|i| (i as u8) ^ 0xA5).collect();
            let ct = encrypt(&KEY, &TWEAK, &pt).expect("valid");
            assert_eq!(ct.len(), pt.len(), "len-preserving at {len}");
            let rt = decrypt(&KEY, &TWEAK, &ct).expect("valid");
            assert_eq!(rt, pt, "round-trip at length {len}");
        }
    }

    #[test]
    fn xts_rejects_short_long_and_weak_key() {
        assert!(encrypt(&KEY, &TWEAK, &[0u8; 15]).is_none(), "len 15 < 16");
        assert!(encrypt(&KEY, &TWEAK, &[]).is_none(), "len 0");
        assert!(
            encrypt(&KEY, &TWEAK, &alloc::vec![0u8; MAX_LEN + 1]).is_none(),
            "len > 16 MiB"
        );
        let mut weak = KEY;
        // Make Key2 an exact copy of Key1.
        weak.copy_within(0..16, 16);
        assert!(encrypt(&weak, &TWEAK, &[0u8; 16]).is_none(), "Key1 == Key2");
        assert!(decrypt(&weak, &TWEAK, &[0u8; 16]).is_none(), "Key1 == Key2");
    }

    #[test]
    fn sectors_round_trip_and_match_single_unit_api() {
        const SECTOR: usize = 64;
        let pt: Vec<u8> = (0..=191u8).collect();
        assert_eq!(pt.len(), 3 * SECTOR);
        let start = 7u128;
        let mut buf = pt.clone();
        encrypt_sectors(&KEY, SECTOR, start, &mut buf).expect("valid");
        assert_ne!(buf, pt, "ciphertext differs from plaintext");
        // Each sector must equal a single-unit encryption keyed by its little-endian number.
        for (i, sector) in buf.chunks(SECTOR).enumerate() {
            let tweak = (start + i as u128).to_le_bytes();
            let unit = &pt[i * SECTOR..(i + 1) * SECTOR];
            let expected = encrypt(&KEY, &tweak, unit).expect("valid");
            assert_eq!(sector, &expected[..], "sector {i}");
        }
        decrypt_sectors(&KEY, SECTOR, start, &mut buf).expect("valid");
        assert_eq!(buf, pt);
    }

    #[test]
    fn sectors_reject_bad_geometry_overflow_and_weak_key() {
        let mut buf = [0u8; 64];
        assert!(encrypt_sectors(&KEY, 15, 0, &mut buf).is_none(), "sector < 16");
        assert!(
            encrypt_sectors(&KEY, 24, 0, &mut buf[..48]).is_none(),
            "sector not block-aligned"
        );
        assert!(
            encrypt_sectors(&KEY, 32, 0, &mut buf[..48]).is_none(),
            "buf not a multiple of sector"
        );
        assert!(
            encrypt_sectors(&KEY, 32, u128::MAX, &mut buf).is_none(),
            "sector number overflow"
        );
        assert!(
            encrypt_sectors(&KEY, 64, u128::MAX, &mut buf).is_some(),
            "single sector at u128::MAX is fine"
        );
        let mut weak = KEY;
        weak.copy_within(0..16, 16);
        assert!(decrypt_sectors(&weak, 32, 0, &mut buf).is_none(), "Key1 == Key2");
    }
}