use std::convert::TryFrom;
use std::convert::TryInto;
use byteorder::{ByteOrder, LittleEndian};
use cipher::generic_array::typenum::Unsigned;
use cipher::generic_array::GenericArray;
use cipher::{BlockCipher, BlockDecrypt, BlockEncrypt, BlockSizeUser};
/// XTS mode of operation over a block cipher with 128-bit blocks.
///
/// Two cipher instances are kept: one processes the data blocks of a sector, the other is used
/// solely to encrypt the tweak before the first block of each sector is handled.
pub struct Xts128<C>
where
    C: BlockEncrypt + BlockDecrypt + BlockCipher,
{
    /// Cipher applied to the data blocks.
    cipher_1: C,
    /// Cipher applied to the tweak at the start of every sector.
    cipher_2: C,
}
impl<C: BlockEncrypt + BlockDecrypt + BlockCipher> Xts128<C> {
    /// Builds an XTS instance from two cipher instances.
    ///
    /// `cipher_1` encrypts/decrypts the data blocks, `cipher_2` is only used to encrypt the tweak
    /// at the start of each sector.
    pub fn new(cipher_1: C, cipher_2: C) -> Xts128<C> {
        Xts128 { cipher_1, cipher_2 }
    }

    /// Encrypts a single sector in place.
    ///
    /// The tweak is first encrypted with `cipher_2`, then advanced with
    /// `galois_field_128_mul_le` after every block. When the sector length is not a multiple of
    /// 16 bytes, the last complete block and the trailing partial block are processed with
    /// ciphertext stealing, so the output has exactly the same length as the input.
    ///
    /// # Panics
    ///
    /// - If the cipher's block size is not 16 bytes.
    /// - If `sector` is shorter than 16 bytes.
    pub fn encrypt_sector(&self, sector: &mut [u8], mut tweak: [u8; 16]) {
        assert_eq!(
            <C as BlockSizeUser>::BlockSize::to_usize(),
            128 / 8,
            "Wrong block size"
        );
        assert!(
            sector.len() >= 16,
            "AES-XTS needs at least two blocks to perform stealing, or a single complete block"
        );

        let block_count = sector.len() / 16;
        // Bytes left after the last complete block; non-zero means stealing is required.
        let remaining = sector.len() % 16;
        let need_stealing = remaining != 0;

        // The tweak is always *encrypted*, for both directions.
        self.cipher_2
            .encrypt_block(GenericArray::from_mut_slice(&mut tweak));

        // With stealing, the last complete block is handled together with the partial tail.
        let nosteal_block_count = if need_stealing {
            block_count - 1
        } else {
            block_count
        };

        for block in sector.chunks_exact_mut(16).take(nosteal_block_count) {
            xor(block, &tweak);
            self.cipher_1
                .encrypt_block(GenericArray::from_mut_slice(block));
            xor(block, &tweak);
            tweak = galois_field_128_mul_le(tweak);
        }

        if need_stealing {
            let next_to_last_tweak = tweak;
            let last_tweak = galois_field_128_mul_le(tweak);

            // Encrypt the last complete block with the next-to-last tweak.
            let mut block: [u8; 16] = sector[16 * (block_count - 1)..16 * block_count]
                .try_into()
                .expect("range covers exactly one 16-byte block");
            xor(&mut block, &next_to_last_tweak);
            self.cipher_1
                .encrypt_block(GenericArray::from_mut_slice(&mut block));
            xor(&mut block, &next_to_last_tweak);

            // Pad the partial plaintext with the tail of that ciphertext ("stolen" bytes).
            let mut last_block = [0u8; 16];
            last_block[..remaining].copy_from_slice(&sector[16 * block_count..]);
            last_block[remaining..].copy_from_slice(&block[remaining..]);

            xor(&mut last_block, &last_tweak);
            self.cipher_1
                .encrypt_block(GenericArray::from_mut_slice(&mut last_block));
            xor(&mut last_block, &last_tweak);

            // The full-size result goes into the complete-block slot, the truncated first
            // ciphertext becomes the partial tail.
            sector[16 * (block_count - 1)..16 * block_count].copy_from_slice(&last_block);
            sector[16 * block_count..].copy_from_slice(&block[..remaining]);
        }
    }

    /// Decrypts a single sector in place.
    ///
    /// Mirrors `encrypt_sector`: the tweak is encrypted with `cipher_2` and advanced after every
    /// block. When the sector length is not a multiple of 16 bytes, the last complete block is
    /// decrypted with the *last* tweak and the reassembled block with the *next-to-last* tweak,
    /// i.e. the two tweaks are used in the opposite order compared to encryption.
    ///
    /// # Panics
    ///
    /// - If the cipher's block size is not 16 bytes.
    /// - If `sector` is shorter than 16 bytes.
    pub fn decrypt_sector(&self, sector: &mut [u8], mut tweak: [u8; 16]) {
        assert_eq!(
            <C as BlockSizeUser>::BlockSize::to_usize(),
            128 / 8,
            "Wrong block size"
        );
        assert!(
            sector.len() >= 16,
            "AES-XTS needs at least two blocks to perform stealing, or a single complete block"
        );

        let block_count = sector.len() / 16;
        // Bytes left after the last complete block; non-zero means stealing is required.
        let remaining = sector.len() % 16;
        let need_stealing = remaining != 0;

        // The tweak is always *encrypted*, for both directions.
        self.cipher_2
            .encrypt_block(GenericArray::from_mut_slice(&mut tweak));

        // With stealing, the last complete block is handled together with the partial tail.
        let nosteal_block_count = if need_stealing {
            block_count - 1
        } else {
            block_count
        };

        for block in sector.chunks_exact_mut(16).take(nosteal_block_count) {
            xor(block, &tweak);
            self.cipher_1
                .decrypt_block(GenericArray::from_mut_slice(block));
            xor(block, &tweak);
            tweak = galois_field_128_mul_le(tweak);
        }

        if need_stealing {
            let next_to_last_tweak = tweak;
            let last_tweak = galois_field_128_mul_le(tweak);

            // The complete-block slot holds the block that was encrypted last.
            let mut block: [u8; 16] = sector[16 * (block_count - 1)..16 * block_count]
                .try_into()
                .expect("range covers exactly one 16-byte block");
            xor(&mut block, &last_tweak);
            self.cipher_1
                .decrypt_block(GenericArray::from_mut_slice(&mut block));
            xor(&mut block, &last_tweak);

            // Rebuild the first ciphertext: the partial tail plus the stolen bytes just recovered.
            let mut last_block = [0u8; 16];
            last_block[..remaining].copy_from_slice(&sector[16 * block_count..]);
            last_block[remaining..].copy_from_slice(&block[remaining..]);

            xor(&mut last_block, &next_to_last_tweak);
            self.cipher_1
                .decrypt_block(GenericArray::from_mut_slice(&mut last_block));
            xor(&mut last_block, &next_to_last_tweak);

            sector[16 * (block_count - 1)..16 * block_count].copy_from_slice(&last_block);
            sector[16 * block_count..].copy_from_slice(&block[..remaining]);
        }
    }

    /// Encrypts a whole area in place, one sector of `sector_size` bytes at a time.
    ///
    /// The tweak for the `i`-th sector of the area is `get_tweak_fn(i + first_sector_index)`.
    /// If the area length is not a multiple of `sector_size`, the shorter trailing part is
    /// encrypted as a sector of its own, using the next index.
    ///
    /// # Panics
    ///
    /// - If `sector_size` is 0.
    /// - If any sector, including a shorter trailing one, is less than 16 bytes long.
    /// - If the cipher's block size is not 16 bytes.
    pub fn encrypt_area(
        &self,
        area: &mut [u8],
        sector_size: usize,
        first_sector_index: u128,
        get_tweak_fn: impl Fn(u128) -> [u8; 16],
    ) {
        // `chunks_mut` yields every full sector followed by the shorter trailing one, if any.
        for (i, chunk) in area.chunks_mut(sector_size).enumerate() {
            let tweak = get_tweak_fn(
                u128::try_from(i).expect("usize cannot be bigger than u128") + first_sector_index,
            );
            self.encrypt_sector(chunk, tweak);
        }
    }

    /// Decrypts a whole area in place, one sector of `sector_size` bytes at a time.
    ///
    /// The tweak for the `i`-th sector of the area is `get_tweak_fn(i + first_sector_index)`.
    /// If the area length is not a multiple of `sector_size`, the shorter trailing part is
    /// decrypted as a sector of its own, using the next index.
    ///
    /// # Panics
    ///
    /// - If `sector_size` is 0.
    /// - If any sector, including a shorter trailing one, is less than 16 bytes long.
    /// - If the cipher's block size is not 16 bytes.
    pub fn decrypt_area(
        &self,
        area: &mut [u8],
        sector_size: usize,
        first_sector_index: u128,
        get_tweak_fn: impl Fn(u128) -> [u8; 16],
    ) {
        // `chunks_mut` yields every full sector followed by the shorter trailing one, if any.
        for (i, chunk) in area.chunks_mut(sector_size).enumerate() {
            let tweak = get_tweak_fn(
                u128::try_from(i).expect("usize cannot be bigger than u128") + first_sector_index,
            );
            self.decrypt_sector(chunk, tweak);
        }
    }
}
/// Default tweak derivation: the sector index encoded as 16 little-endian bytes.
pub fn get_tweak_default(sector_index: u128) -> [u8; 16] {
    u128::to_le_bytes(sector_index)
}
/// XORs `key` into `buf` in place, byte by byte.
///
/// Both slices are expected to have the same length; this is only checked in debug builds.
#[inline(always)]
fn xor(buf: &mut [u8], key: &[u8]) {
    debug_assert_eq!(buf.len(), key.len());
    buf.iter_mut()
        .zip(key.iter())
        .for_each(|(dst, src)| *dst ^= *src);
}
/// Advances a tweak by multiplying it by `x` in GF(2^128).
///
/// The 16 bytes are read as one little-endian 128-bit value and shifted left by one bit. If a
/// bit is shifted out of the top, the value is reduced by XOR-ing `0x87` into the lowest byte.
///
/// The reduction is applied through a mask rather than a conditional, so the code does not
/// branch on the tweak's contents.
fn galois_field_128_mul_le(tweak_source: [u8; 16]) -> [u8; 16] {
    let value = u128::from_le_bytes(tweak_source);
    // All ones when the top bit is set (it is about to be shifted out), all zeros otherwise.
    let reduction_mask = 0u128.wrapping_sub(value >> 127);
    ((value << 1) ^ (reduction_mask & 0x87)).to_le_bytes()
}