use super::aes::Aes;
use crate::BlockCipher;
/// Multiplies an XTS tweak by the primitive element α of GF(2^128), in place.
///
/// The 16 tweak bytes are interpreted as a little-endian 128-bit integer
/// (byte 0 is least significant). The integer is shifted left by one bit; if
/// the top bit was set before the shift, the result is reduced by XOR-ing
/// `0x87` into the lowest byte (i.e. modulo x^128 + x^7 + x^2 + x + 1).
fn gf128_mul_alpha(t: &mut [u8; 16]) {
    let value = u128::from_le_bytes(*t);
    let overflowed = value >> 127 != 0;
    let mut shifted = value << 1;
    if overflowed {
        shifted ^= 0x87;
    }
    *t = shifted.to_le_bytes();
}
/// AES in XTS mode: a length-preserving, tweakable encryption of sector-sized
/// data units.
///
/// `k1` encrypts/decrypts the data blocks. `k2` is used only to encrypt the
/// caller-supplied tweak once per sector, producing the initial tweak value.
pub struct AesXts {
    k1: Aes,
    k2: Aes,
}

impl AesXts {
    /// XTS block size in bytes (one AES block).
    const BLOCK: usize = 16;

    /// Builds an XTS instance from a concatenated `k1 || k2` key.
    ///
    /// A 32-byte key is split into two 16-byte halves, and a 64-byte key into
    /// two 32-byte halves.
    ///
    /// Returns `None` if the key has any other length, or if the two halves are
    /// identical.
    pub fn new(key: &[u8]) -> Option<Self> {
        let half = match key.len() {
            32 => 16,
            64 => 32,
            _ => return None,
        };
        let (k1_bytes, k2_bytes) = key.split_at(half);
        if k1_bytes == k2_bytes {
            return None;
        }
        Some(Self {
            k1: <Aes as BlockCipher>::new(k1_bytes),
            k2: <Aes as BlockCipher>::new(k2_bytes),
        })
    }

    /// Encrypts one data unit (sector) in place.
    ///
    /// `tweak` is encrypted under `k2` to form the initial tweak. Each following
    /// block's tweak is the previous one multiplied by α. If `data.len()` is not
    /// a multiple of 16, the final partial block is handled with ciphertext
    /// stealing, so the output has exactly the same length as the input.
    ///
    /// If `data` is shorter than 16 bytes the call returns without modifying
    /// `data`: it is left as plaintext.
    /// NOTE(review): this silent no-op is easy to misuse; callers must ensure
    /// every sector is at least one block long.
    pub fn encrypt_sector(&self, tweak: &[u8; 16], data: &mut [u8]) {
        if data.len() < Self::BLOCK {
            return;
        }
        let mut t = self.initial_tweak(tweak);
        let full_blocks = data.len() / Self::BLOCK;
        let tail = data.len() % Self::BLOCK;
        // With a partial tail, the last full block takes part in ciphertext
        // stealing and is processed separately below.
        let main_blocks = if tail == 0 { full_blocks } else { full_blocks - 1 };
        let main_len = main_blocks * Self::BLOCK;

        for chunk in data[..main_len].chunks_exact_mut(Self::BLOCK) {
            let out = self.xex_encrypt(Self::load_block(chunk), &t);
            chunk.copy_from_slice(&out);
            gf128_mul_alpha(&mut t);
        }

        if tail == 0 {
            return;
        }

        // Ciphertext stealing. The last full plaintext block is encrypted under
        // the current tweak. Its leading `tail` bytes become the short final
        // ciphertext. Its trailing bytes pad the partial plaintext block, which
        // is then encrypted under the next tweak and stored in the last
        // full-block position.
        let tail_off = main_len + Self::BLOCK;
        let cc = self.xex_encrypt(Self::load_block(&data[main_len..tail_off]), &t);
        let mut pp = cc;
        pp[..tail].copy_from_slice(&data[tail_off..]);
        gf128_mul_alpha(&mut t);
        let stolen = self.xex_encrypt(pp, &t);
        data[main_len..tail_off].copy_from_slice(&stolen);
        data[tail_off..].copy_from_slice(&cc[..tail]);
    }

    /// Decrypts one data unit (sector) in place; the inverse of
    /// [`AesXts::encrypt_sector`] for the same `tweak`.
    ///
    /// If `data` is shorter than 16 bytes the call returns without modifying
    /// `data`.
    pub fn decrypt_sector(&self, tweak: &[u8; 16], data: &mut [u8]) {
        if data.len() < Self::BLOCK {
            return;
        }
        let mut t = self.initial_tweak(tweak);
        let full_blocks = data.len() / Self::BLOCK;
        let tail = data.len() % Self::BLOCK;
        let main_blocks = if tail == 0 { full_blocks } else { full_blocks - 1 };
        let main_len = main_blocks * Self::BLOCK;

        for chunk in data[..main_len].chunks_exact_mut(Self::BLOCK) {
            let out = self.xex_decrypt(Self::load_block(chunk), &t);
            chunk.copy_from_slice(&out);
            gf128_mul_alpha(&mut t);
        }

        if tail == 0 {
            return;
        }

        // Ciphertext stealing, reversed. Encryption wrote the last full-block
        // position using the *next* tweak, so that block is decrypted first
        // with `t * α`. Its trailing bytes then complete the short final
        // ciphertext, which is decrypted under the current tweak `t`.
        let tail_off = main_len + Self::BLOCK;
        let mut t_next = t;
        gf128_mul_alpha(&mut t_next);
        let pp = self.xex_decrypt(Self::load_block(&data[main_len..tail_off]), &t_next);
        let mut cc = pp;
        cc[..tail].copy_from_slice(&data[tail_off..]);
        let recovered = self.xex_decrypt(cc, &t);
        data[main_len..tail_off].copy_from_slice(&recovered);
        data[tail_off..].copy_from_slice(&pp[..tail]);
    }

    /// Computes the initial tweak for a sector: `tweak` encrypted under `k2`.
    fn initial_tweak(&self, tweak: &[u8; 16]) -> [u8; 16] {
        let mut t = *tweak;
        self.k2.encrypt_block(&mut t);
        t
    }

    /// Copies a 16-byte slice into an owned block (panics on any other length).
    fn load_block(bytes: &[u8]) -> [u8; 16] {
        let mut block = [0u8; 16];
        block.copy_from_slice(bytes);
        block
    }

    /// XORs `mask` into `block` byte by byte.
    fn xor_into(block: &mut [u8; 16], mask: &[u8; 16]) {
        block.iter_mut().zip(mask).for_each(|(b, m)| *b ^= m);
    }

    /// One XEX encryption step: `E_k1(block ^ t) ^ t`.
    fn xex_encrypt(&self, mut block: [u8; 16], t: &[u8; 16]) -> [u8; 16] {
        Self::xor_into(&mut block, t);
        self.k1.encrypt_block(&mut block);
        Self::xor_into(&mut block, t);
        block
    }

    /// One XEX decryption step: `D_k1(block ^ t) ^ t`.
    fn xex_decrypt(&self, mut block: [u8; 16], t: &[u8; 16]) -> [u8; 16] {
        Self::xor_into(&mut block, t);
        self.k1.decrypt_block(&mut block);
        Self::xor_into(&mut block, t);
        block
    }
}
#[cfg(test)]
mod tests {
    //! Known-answer tests from IEEE 1619 plus round-trip and key-validation checks.
    use super::*;

    /// Decodes a hex string, ignoring any whitespace.
    fn hex(s: &str) -> Vec<u8> {
        let s: String = s.chars().filter(|c| !c.is_whitespace()).collect();
        assert!(s.len() % 2 == 0);
        (0..s.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
            .collect()
    }

    /// Decodes a hex string into a fixed-size array, asserting the length.
    fn hex_arr<const N: usize>(s: &str) -> [u8; N] {
        let v = hex(s);
        assert_eq!(v.len(), N);
        let mut out = [0u8; N];
        out.copy_from_slice(&v);
        out
    }

    #[test]
    fn ieee1619_vector_1_all_zero() {
        // Vector 1 uses k1 == k2 == 0, which `new` rejects, so build directly.
        let zero = [0u8; 16];
        let xts = AesXts {
            k1: <Aes as BlockCipher>::new(&zero),
            k2: <Aes as BlockCipher>::new(&zero),
        };
        let tweak: [u8; 16] = [0u8; 16];
        let mut data = [0u8; 32];
        xts.encrypt_sector(&tweak, &mut data);
        let expected = hex("917cf69ebd68b2ec9b9fe9a3eadda692\
                            cd43d2f59598ed858c02c2652fbf922e");
        assert_eq!(data.to_vec(), expected);
        xts.decrypt_sector(&tweak, &mut data);
        assert_eq!(data, [0u8; 32]);
    }

    #[test]
    fn ieee1619_vector_2() {
        let key = hex("11111111111111111111111111111111\
                       22222222222222222222222222222222");
        let tweak: [u8; 16] = hex_arr("33333333330000000000000000000000");
        let mut data = hex("44444444444444444444444444444444\
                            44444444444444444444444444444444");
        let xts = AesXts::new(&key).unwrap();
        xts.encrypt_sector(&tweak, &mut data);
        let expected = hex("c454185e6a16936e39334038acef838b\
                            fb186fff7480adc4289382ecd6d394f0");
        assert_eq!(data, expected);
        xts.decrypt_sector(&tweak, &mut data);
        let original = hex("44444444444444444444444444444444\
                            44444444444444444444444444444444");
        assert_eq!(data, original);
    }

    #[test]
    fn ieee1619_vector_3() {
        let key = hex("fffefdfcfbfaf9f8f7f6f5f4f3f2f1f0\
                       22222222222222222222222222222222");
        let tweak: [u8; 16] = hex_arr("33333333330000000000000000000000");
        let mut data = hex("44444444444444444444444444444444\
                            44444444444444444444444444444444");
        let xts = AesXts::new(&key).unwrap();
        xts.encrypt_sector(&tweak, &mut data);
        let expected = hex("af85336b597afc1a900b2eb21ec949d2\
                            92df4c047e0b21532186a5971a227a89");
        assert_eq!(data, expected);
    }

    /// Key and tweak shared by IEEE 1619 vectors 15-18 (ciphertext stealing).
    fn ieee1619_cts_setup() -> (AesXts, [u8; 16]) {
        let key = hex("fffefdfcfbfaf9f8f7f6f5f4f3f2f1f0\
                       bfbebdbcbbbab9b8b7b6b5b4b3b2b1b0");
        let tweak: [u8; 16] = hex_arr("9a785634120000000000000000000000");
        (AesXts::new(&key).unwrap(), tweak)
    }

    #[test]
    fn ieee1619_vector_15_ciphertext_stealing_17_bytes() {
        // Known answer: round-trip tests alone cannot catch a tweak-ordering
        // mistake that is mirrored in both encrypt and decrypt.
        let (xts, tweak) = ieee1619_cts_setup();
        let original = hex("000102030405060708090a0b0c0d0e0f10");
        let mut data = original.clone();
        xts.encrypt_sector(&tweak, &mut data);
        assert_eq!(data, hex("6c1625db4671522d3d7599601de7ca09ed"));
        xts.decrypt_sector(&tweak, &mut data);
        assert_eq!(data, original);
    }

    #[test]
    fn ieee1619_vector_16_ciphertext_stealing_18_bytes() {
        let (xts, tweak) = ieee1619_cts_setup();
        let original = hex("000102030405060708090a0b0c0d0e0f1011");
        let mut data = original.clone();
        xts.encrypt_sector(&tweak, &mut data);
        assert_eq!(data, hex("d069444b7a7e0cab09e24447d24deb1fedbf"));
        xts.decrypt_sector(&tweak, &mut data);
        assert_eq!(data, original);
    }

    /// A 32-byte key whose two halves differ, so `AesXts::new` accepts it.
    fn distinct_key_32() -> [u8; 32] {
        let mut k = [0u8; 32];
        for i in 0..16 {
            k[i] = i as u8 ^ 0x42;
        }
        for i in 16..32 {
            k[i] = i as u8 ^ 0xa5;
        }
        k
    }

    #[test]
    fn xts_multi_block_roundtrip() {
        let key = distinct_key_32();
        let tweak = [0xa5u8; 16];
        let original: Vec<u8> = (0..64).map(|i| i as u8).collect();
        let mut data = original.clone();
        let xts = AesXts::new(&key).unwrap();
        xts.encrypt_sector(&tweak, &mut data);
        assert_ne!(data, original);
        xts.decrypt_sector(&tweak, &mut data);
        assert_eq!(data, original);
    }

    #[test]
    fn xts_ciphertext_stealing_17_bytes() {
        let key = distinct_key_32();
        let tweak = [0xa5u8; 16];
        let original: Vec<u8> = (0..17).map(|i| (i * 11) as u8).collect();
        let mut data = original.clone();
        let xts = AesXts::new(&key).unwrap();
        xts.encrypt_sector(&tweak, &mut data);
        assert_eq!(data.len(), 17, "XTS must be length-preserving");
        assert_ne!(data, original);
        xts.decrypt_sector(&tweak, &mut data);
        assert_eq!(data, original);
    }

    #[test]
    fn xts_ciphertext_stealing_31_bytes() {
        let key = distinct_key_32();
        let tweak = [0xa5u8; 16];
        let original: Vec<u8> = (0..31).map(|i| (i * 7) as u8).collect();
        let mut data = original.clone();
        let xts = AesXts::new(&key).unwrap();
        xts.encrypt_sector(&tweak, &mut data);
        assert_eq!(data.len(), 31);
        xts.decrypt_sector(&tweak, &mut data);
        assert_eq!(data, original);
    }

    #[test]
    fn xts_ciphertext_stealing_100_bytes() {
        let key = distinct_key_32();
        let tweak = [0xa5u8; 16];
        let original: Vec<u8> = (0..100).map(|i| (i ^ 0x5a) as u8).collect();
        let mut data = original.clone();
        let xts = AesXts::new(&key).unwrap();
        xts.encrypt_sector(&tweak, &mut data);
        assert_eq!(data.len(), 100);
        xts.decrypt_sector(&tweak, &mut data);
        assert_eq!(data, original);
    }

    #[test]
    fn xts_different_tweaks_differ() {
        let key = distinct_key_32();
        let tweak1 = [0u8; 16];
        let mut tweak2 = [0u8; 16];
        tweak2[0] = 1;
        let original: Vec<u8> = (0..32).map(|i| i as u8).collect();
        let xts = AesXts::new(&key).unwrap();
        let mut data1 = original.clone();
        xts.encrypt_sector(&tweak1, &mut data1);
        let mut data2 = original.clone();
        xts.encrypt_sector(&tweak2, &mut data2);
        assert_ne!(data1, data2);
    }

    #[test]
    fn xts_aes256_roundtrip() {
        let mut key = [0u8; 64];
        for i in 0..32 {
            key[i] = i as u8;
        }
        for i in 32..64 {
            key[i] = (i + 0x80) as u8;
        }
        let tweak = [0x11u8; 16];
        let original: Vec<u8> = (0..48).map(|i| i as u8).collect();
        let mut data = original.clone();
        let xts = AesXts::new(&key).unwrap();
        xts.encrypt_sector(&tweak, &mut data);
        xts.decrypt_sector(&tweak, &mut data);
        assert_eq!(data, original);
    }

    #[test]
    fn xts_rejects_invalid_key_lengths() {
        for bad_len in [0, 1, 15, 16, 17, 24, 31, 33, 48, 63, 65, 128] {
            let key = vec![0u8; bad_len];
            assert!(AesXts::new(&key).is_none(), "key length {} should be rejected", bad_len);
        }
        let mut k32 = [0u8; 32];
        k32[16] = 1;
        assert!(AesXts::new(&k32).is_some());
        let mut k64 = [0u8; 64];
        k64[32] = 1;
        assert!(AesXts::new(&k64).is_some());
    }

    #[test]
    fn xts_rejects_k1_eq_k2() {
        let k32 = [0x11u8; 32];
        assert!(AesXts::new(&k32).is_none());
        let k64 = [0x11u8; 64];
        assert!(AesXts::new(&k64).is_none());
    }

    #[test]
    fn gf128_mul_alpha_basic() {
        let mut t = [0u8; 16];
        t[0] = 0x01;
        gf128_mul_alpha(&mut t);
        let mut expected = [0u8; 16];
        expected[0] = 0x02;
        assert_eq!(t, expected);
        let mut t = [0u8; 16];
        t[15] = 0x80;
        gf128_mul_alpha(&mut t);
        let mut expected = [0u8; 16];
        expected[0] = 0x87;
        assert_eq!(t, expected);
    }
}