#![no_std]
use core::convert::TryFrom;
pub use cipher::array::{self, Array};
use cipher::consts::U16;
use cipher::{BlockCipherDecrypt, BlockCipherEncrypt, BlockSizeUser};
/// XTS (XEX-based tweaked-codebook mode with ciphertext stealing) built on a
/// block cipher with 16-byte blocks.
///
/// Holds two instances of the same cipher type. XTS expects them to be keyed
/// independently; this type does not check that.
pub struct Xts128<C> {
    /// Cipher applied to the data blocks of each sector.
    cipher_1: C,
    /// Cipher used to encrypt the caller-supplied tweak before a sector is processed.
    cipher_2: C,
}
impl<C> Xts128<C>
where
    C: BlockSizeUser<BlockSize = U16>,
{
    /// Builds an XTS instance from its two ciphers.
    ///
    /// `cipher_1` processes the sector data. `cipher_2` is only used to encrypt
    /// the per-sector tweak.
    pub fn new(cipher_1: C, cipher_2: C) -> Self {
        Self { cipher_1, cipher_2 }
    }
}
impl<C> Xts128<C>
where
    C: BlockSizeUser<BlockSize = U16> + BlockCipherEncrypt,
{
    /// Encrypts a single sector in place.
    ///
    /// `tweak` is the raw (not yet encrypted) tweak for this sector, for example
    /// the value returned by [`get_tweak_default`]. It is first encrypted with
    /// `cipher_2`. Each 16-byte block is then XORed with the tweak, encrypted
    /// with `cipher_1`, and XORed again. Between blocks the tweak is multiplied
    /// by `x` in GF(2^128). A trailing partial block is handled with ciphertext
    /// stealing, so the ciphertext has the same length as the plaintext.
    ///
    /// # Panics
    ///
    /// Panics if `sector` is shorter than 16 bytes.
    pub fn encrypt_sector(&self, sector: &mut [u8], mut tweak: Array<u8, U16>) {
        assert!(
            sector.len() >= 16,
            "AES-XTS needs at least one complete block in each sector"
        );
        // The effective tweak is the caller's value encrypted under the second cipher.
        self.cipher_2.encrypt_block(&mut tweak);
        let block_count = sector.len() / 16;
        // All full blocks except the last one: XOR / encrypt / XOR, then advance the tweak.
        // The slices are exactly 16 bytes long, so the array conversions cannot fail.
        for i in (0..sector.len()).step_by(16).take(block_count - 1) {
            let block = Array::slice_as_mut_array(&mut sector[i..i + 16]).unwrap();
            xor(block, &tweak);
            self.cipher_1.encrypt_block(block);
            xor(block, &tweak);
            tweak = galois_field_128_mul_le(tweak);
        }
        // Last full block. The tweak is not advanced here; the stealing step
        // below derives its own tweak from the current one.
        {
            let block =
                Array::slice_as_mut_array(&mut sector[16 * (block_count - 1)..16 * block_count])
                    .unwrap();
            xor(block, &tweak);
            self.cipher_1.encrypt_block(block);
            xor(block, &tweak);
        }
        let remainder = sector.len() % 16;
        if remainder != 0 {
            // Ciphertext stealing. The partial plaintext is padded with the tail
            // of the block just encrypted and encrypted under the next tweak.
            // That result replaces the last full block, and the head of the
            // earlier ciphertext becomes the short final block.
            let last_tweak = galois_field_128_mul_le(tweak);
            let (full_block, partial_block) = sector[16 * (block_count - 1)..].split_at_mut(16);
            let (last_ciphertext, cipher_plaintext) = full_block.split_at(remainder);
            let mut last_block = Array([0u8; 16]);
            last_block[..remainder].copy_from_slice(partial_block);
            last_block[remainder..].copy_from_slice(cipher_plaintext);
            xor(&mut last_block, &last_tweak);
            self.cipher_1.encrypt_block(&mut last_block);
            xor(&mut last_block, &last_tweak);
            partial_block.copy_from_slice(last_ciphertext);
            full_block.copy_from_slice(&last_block);
        }
    }

    /// Encrypts a contiguous run of sectors in place.
    ///
    /// `area` is split into `sector_size`-byte sectors. A shorter trailing
    /// piece, if any, is encrypted as its own final sector. The sector at
    /// zero-based offset `i` uses the tweak
    /// `get_tweak_fn(first_sector_index + i)`.
    ///
    /// # Panics
    ///
    /// Panics if `sector_size` is zero, if any sector (including the trailing
    /// piece) is shorter than 16 bytes, or if `first_sector_index + i`
    /// overflows `u128`.
    pub fn encrypt_area(
        &self,
        area: &mut [u8],
        sector_size: usize,
        first_sector_index: u128,
        get_tweak_fn: impl Fn(u128) -> Array<u8, U16>,
    ) {
        // `chunks_mut` yields every full sector first, followed by the (possibly
        // shorter) remainder. The enumeration index is therefore the sector
        // offset in both cases.
        for (i, sector) in area.chunks_mut(sector_size).enumerate() {
            let offset = u128::try_from(i).expect("usize cannot be bigger than u128");
            // Checked so that release builds cannot silently wrap around and
            // reuse another sector's tweak.
            let sector_index = first_sector_index
                .checked_add(offset)
                .expect("sector index overflowed u128");
            self.encrypt_sector(sector, get_tweak_fn(sector_index));
        }
    }
}
impl<C> Xts128<C>
where
    C: BlockSizeUser<BlockSize = U16> + BlockCipherEncrypt + BlockCipherDecrypt,
{
    /// Decrypts a single sector in place, reversing [`Self::encrypt_sector`].
    ///
    /// `tweak` must be the same raw tweak that was used for encryption. It is
    /// encrypted with `cipher_2` (encryption, not decryption, as in the
    /// encrypt path). The data blocks are decrypted with `cipher_1`.
    ///
    /// # Panics
    ///
    /// Panics if `sector` is shorter than 16 bytes.
    pub fn decrypt_sector(&self, sector: &mut [u8], mut tweak: Array<u8, U16>) {
        assert!(
            sector.len() >= 16,
            "AES-XTS needs at least one complete block in each sector"
        );
        self.cipher_2.encrypt_block(&mut tweak);
        let block_count = sector.len() / 16;
        // All full blocks except the last one; the slices are exactly 16 bytes long.
        for i in (0..sector.len()).step_by(16).take(block_count - 1) {
            let block = Array::slice_as_mut_array(&mut sector[i..i + 16]).unwrap();
            xor(block, &tweak);
            self.cipher_1.decrypt_block(block);
            xor(block, &tweak);
            tweak = galois_field_128_mul_le(tweak);
        }
        let remainder = sector.len() % 16;
        if remainder != 0 {
            // The encrypt path wrote the last full block under the *next* tweak.
            // Decrypt it with `last_tweak` first. Its tail then completes the
            // partial block, which is decrypted under the current tweak.
            let next_to_last_tweak = tweak;
            let last_tweak = galois_field_128_mul_le(tweak);
            let (full_block, partial_block) = sector[16 * (block_count - 1)..].split_at_mut(16);
            let full_block = Array::slice_as_mut_array(full_block).unwrap();
            xor(full_block, &last_tweak);
            self.cipher_1.decrypt_block(full_block);
            xor(full_block, &last_tweak);
            let (last_plaintext, cipher_plaintext) = full_block.split_at(remainder);
            let mut last_block = Array([0u8; 16]);
            last_block[..remainder].copy_from_slice(partial_block);
            last_block[remainder..].copy_from_slice(cipher_plaintext);
            xor(&mut last_block, &next_to_last_tweak);
            self.cipher_1.decrypt_block(&mut last_block);
            xor(&mut last_block, &next_to_last_tweak);
            partial_block.copy_from_slice(last_plaintext);
            full_block.copy_from_slice(&last_block);
        } else {
            // No stealing: the final block is an ordinary full block.
            let block = Array::slice_as_mut_array(&mut sector[16 * (block_count - 1)..]).unwrap();
            xor(block, &tweak);
            self.cipher_1.decrypt_block(block);
            xor(block, &tweak);
        }
    }

    /// Decrypts a contiguous run of sectors in place.
    ///
    /// `area` is split into `sector_size`-byte sectors. A shorter trailing
    /// piece, if any, is decrypted as its own final sector. The sector at
    /// zero-based offset `i` uses the tweak
    /// `get_tweak_fn(first_sector_index + i)`.
    ///
    /// # Panics
    ///
    /// Panics if `sector_size` is zero, if any sector (including the trailing
    /// piece) is shorter than 16 bytes, or if `first_sector_index + i`
    /// overflows `u128`.
    pub fn decrypt_area(
        &self,
        area: &mut [u8],
        sector_size: usize,
        first_sector_index: u128,
        get_tweak_fn: impl Fn(u128) -> Array<u8, U16>,
    ) {
        // `chunks_mut` yields every full sector first, followed by the (possibly
        // shorter) remainder. The enumeration index is therefore the sector
        // offset in both cases.
        for (i, sector) in area.chunks_mut(sector_size).enumerate() {
            let offset = u128::try_from(i).expect("usize cannot be bigger than u128");
            // Checked so that release builds cannot silently wrap around and
            // use another sector's tweak.
            let sector_index = first_sector_index
                .checked_add(offset)
                .expect("sector index overflowed u128");
            self.decrypt_sector(sector, get_tweak_fn(sector_index));
        }
    }
}
/// Default tweak derivation: the sector index encoded as 16 little-endian bytes.
///
/// Suitable as the `get_tweak_fn` argument of the `*_area` methods when
/// sectors are identified by their plain numeric index.
pub fn get_tweak_default(sector_index: u128) -> Array<u8, U16> {
    let le_bytes: [u8; 16] = sector_index.to_le_bytes();
    Array(le_bytes)
}
/// XORs `key` into `buf` element by element, in place.
///
/// Both slices are expected to have the same length. This is only checked in
/// debug builds; otherwise only the common prefix is processed.
#[inline(always)]
fn xor(buf: &mut [u8], key: &[u8]) {
    debug_assert_eq!(buf.len(), key.len());
    buf.iter_mut()
        .zip(key.iter())
        .for_each(|(dst, &src)| *dst ^= src);
}
/// Multiplies a tweak by `x` in GF(2^128), treating the 16 bytes as a
/// little-endian 128-bit integer.
///
/// The value is shifted left by one bit. If a bit falls off the top, the
/// result is reduced by XORing in `0x87`, the low terms of the polynomial
/// x^128 + x^7 + x^2 + x + 1. The reduction term is selected with a bit mask.
fn galois_field_128_mul_le(tweak_source: Array<u8, U16>) -> Array<u8, U16> {
    let value = u128::from_le_bytes(tweak_source.0);
    // All ones when the top bit is set, all zeros otherwise.
    let overflow_mask = 0u128.wrapping_sub(value >> 127);
    let product = (value << 1) ^ (overflow_mask & 0x87);
    Array(product.to_le_bytes())
}