use alloc::boxed::Box;
use aes::cipher::{BlockCipherDecrypt, BlockCipherEncrypt, KeyInit};
use aes::{Aes256, Block};
use crate::{
types::{VolumeCipher, VolumeCipherSupplier},
VckError, VckResult,
};
/// Number of AES blocks handed to the block cipher per `encrypt_blocks` /
/// `decrypt_blocks` call on the batched path (presumably so the AES backend can
/// work on several independent blocks at once — TODO confirm with profiling).
const BATCH: usize = 8;

/// AES-256 in XTS mode, operating on whole sectors in place.
pub struct XtsVolumeCipher {
    /// Key-1 cipher: transforms every tweaked data block.
    cipher_1: Aes256,
    /// Key-2 cipher: encrypts the sector number to produce the initial tweak.
    cipher_2: Aes256,
}
/// Trait adapter: every method forwards to the identically named inherent method
/// of [`XtsVolumeCipher`].
///
/// `Self::method(self, ..)` resolves to the inherent item, because inherent
/// associated items take precedence over trait items. These calls therefore
/// forward rather than recurse.
impl VolumeCipher for XtsVolumeCipher {
    fn encrypt_sector(&self, rel_sector: u64, sector: &mut [u8]) {
        Self::encrypt_sector(self, rel_sector, sector)
    }

    fn decrypt_sector(&self, rel_sector: u64, sector: &mut [u8]) {
        Self::decrypt_sector(self, rel_sector, sector)
    }

    fn encrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64) {
        Self::encrypt_area(self, buf, sector_size, first_rel_sector)
    }

    fn decrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64) {
        Self::decrypt_area(self, buf, sector_size, first_rel_sector)
    }
}
/// Cipher supplier that always builds its cipher from one fixed pair of XTS keys.
///
/// NOTE(review): the raw keys live in plain arrays, and no `Drop` impl here wipes
/// them when the supplier is dropped.
pub struct StaticCipherSupplier {
    /// First XTS key, used for the data-block cipher.
    key1: [u8; 32],
    /// Second XTS key, used for the tweak cipher.
    key2: [u8; 32],
}

impl StaticCipherSupplier {
    /// Captures both 32-byte XTS keys by value.
    pub fn new(key1: [u8; 32], key2: [u8; 32]) -> Self {
        StaticCipherSupplier { key1, key2 }
    }
}
impl VolumeCipherSupplier for StaticCipherSupplier {
    /// Builds a fresh [`XtsVolumeCipher`] from the stored keys on every call.
    ///
    /// Returns `None` if `XtsVolumeCipher::new` reports an error; the error value
    /// itself is discarded.
    fn get_cipher(&self) -> Option<Box<dyn VolumeCipher>> {
        match XtsVolumeCipher::new(&self.key1, &self.key2) {
            Ok(cipher) => {
                let boxed: Box<dyn VolumeCipher> = Box::new(cipher);
                Some(boxed)
            }
            Err(_) => None,
        }
    }
}
/// Multiplies an XTS tweak by the primitive element α (i.e. by `x`) in GF(2^128).
///
/// The 16 tweak bytes are read as one little-endian 128-bit integer and shifted
/// left by one bit. If a bit falls off the top (bit 127), it is folded back in by
/// XOR-ing the reduction constant `0x87` (x^7 + x^2 + x + 1). This reduces modulo
/// x^128 + x^7 + x^2 + x + 1.
#[inline(always)]
fn gf128_mul(t: Block) -> Block {
    let bytes: [u8; 16] = t[..].try_into().expect("AES block is 16 bytes");
    let value = u128::from_le_bytes(bytes);
    // `value >> 127` is 0 or 1, so this selects 0 or 0x87 without a branch.
    let reduction = (value >> 127) * 0x87;
    Block::from(((value << 1) ^ reduction).to_le_bytes())
}
impl XtsVolumeCipher {
pub fn new(key1: &[u8; 32], key2: &[u8; 32]) -> VckResult<Self> {
let cipher_1 =
Aes256::new_from_slice(key1).map_err(|_| VckError::CryptoFailed("invalid XTS key1"))?;
let cipher_2 =
Aes256::new_from_slice(key2).map_err(|_| VckError::CryptoFailed("invalid XTS key2"))?;
Ok(Self { cipher_1, cipher_2 })
}
pub fn encrypt_sector(&self, rel_sector: u64, sector: &mut [u8]) {
self.encrypt_sector_inner(rel_sector, sector);
}
pub fn decrypt_sector(&self, rel_sector: u64, sector: &mut [u8]) {
self.decrypt_sector_inner(rel_sector, sector);
}
pub fn encrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64) {
for (si, sector) in buf.chunks_mut(sector_size).enumerate() {
self.encrypt_sector_inner(first_rel_sector + si as u64, sector);
}
}
pub fn decrypt_area(&self, buf: &mut [u8], sector_size: usize, first_rel_sector: u64) {
for (si, sector) in buf.chunks_mut(sector_size).enumerate() {
self.decrypt_sector_inner(first_rel_sector + si as u64, sector);
}
}
#[inline(never)]
fn encrypt_sector_inner(&self, rel_sector: u64, sector: &mut [u8]) {
let mut tw: Block = (rel_sector as u128).to_le_bytes().into();
self.cipher_2.encrypt_block(&mut tw);
let n = sector.len() / 16;
let mut off = 0;
while off + BATCH <= n {
let mut ts = [Block::default(); BATCH];
ts[0] = tw;
for i in 1..BATCH {
ts[i] = gf128_mul(ts[i - 1]);
}
tw = gf128_mul(ts[BATCH - 1]);
let mut batch = [Block::default(); BATCH];
for i in 0..BATCH {
let src = §or[(off + i) * 16..(off + i + 1) * 16];
for j in 0..16 {
batch[i][j] = src[j] ^ ts[i][j];
}
}
self.cipher_1.encrypt_blocks(&mut batch);
for i in 0..BATCH {
let dst = &mut sector[(off + i) * 16..(off + i + 1) * 16];
for j in 0..16 {
dst[j] = batch[i][j] ^ ts[i][j];
}
}
off += BATCH;
}
while off < n {
let block = &mut sector[off * 16..(off + 1) * 16];
for j in 0..16 {
block[j] ^= tw[j];
}
let mut ga: Block = Block::try_from(&block[..]).unwrap();
self.cipher_1.encrypt_block(&mut ga);
block.copy_from_slice(&ga);
for j in 0..16 {
block[j] ^= tw[j];
}
tw = gf128_mul(tw);
off += 1;
}
}
#[inline(never)]
fn decrypt_sector_inner(&self, rel_sector: u64, sector: &mut [u8]) {
let mut tw: Block = (rel_sector as u128).to_le_bytes().into();
self.cipher_2.encrypt_block(&mut tw);
let n = sector.len() / 16;
let mut off = 0;
while off + BATCH <= n {
let mut ts = [Block::default(); BATCH];
ts[0] = tw;
for i in 1..BATCH {
ts[i] = gf128_mul(ts[i - 1]);
}
tw = gf128_mul(ts[BATCH - 1]);
let mut batch = [Block::default(); BATCH];
for i in 0..BATCH {
let src = §or[(off + i) * 16..(off + i + 1) * 16];
for j in 0..16 {
batch[i][j] = src[j] ^ ts[i][j];
}
}
self.cipher_1.decrypt_blocks(&mut batch);
for i in 0..BATCH {
let dst = &mut sector[(off + i) * 16..(off + i + 1) * 16];
for j in 0..16 {
dst[j] = batch[i][j] ^ ts[i][j];
}
}
off += BATCH;
}
while off < n {
let block = &mut sector[off * 16..(off + 1) * 16];
for j in 0..16 {
block[j] ^= tw[j];
}
let mut ga: Block = Block::try_from(&block[..]).unwrap();
self.cipher_1.decrypt_block(&mut ga);
block.copy_from_slice(&ga);
for j in 0..16 {
block[j] ^= tw[j];
}
tw = gf128_mul(tw);
off += 1;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec::Vec;
    use xts_mode::{get_tweak_default, Xts128};

    const KEY1: [u8; 32] = [0x11; 32];
    const KEY2: [u8; 32] = [0x22; 32];

    /// Reference AES-256-XTS from the `xts-mode` crate, keyed like the cipher under test.
    fn reference() -> Xts128<Aes256> {
        let c1 = Aes256::new_from_slice(&KEY1).unwrap();
        let c2 = Aes256::new_from_slice(&KEY2).unwrap();
        Xts128::new(c1, c2)
    }

    /// Encrypts each full `sector_size` chunk of `buf` with the reference implementation.
    ///
    /// Chunk `i` uses sector number `first + i`.
    fn reference_encrypt_area(buf: &mut [u8], sector_size: usize, first: u64) {
        let xts = reference();
        for (i, chunk) in buf.chunks_exact_mut(sector_size).enumerate() {
            xts.encrypt_sector(chunk, get_tweak_default(u128::from(first + i as u64)));
        }
    }

    #[test]
    fn sector_roundtrip() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let plain: Vec<u8> = (0..512).map(|i| i as u8).collect();
        let mut buf = plain.clone();
        c.encrypt_sector(42, &mut buf);
        assert_ne!(buf, plain, "ciphertext must differ from plaintext");
        c.decrypt_sector(42, &mut buf);
        assert_eq!(buf, plain);
    }

    #[test]
    fn tweak_depends_on_sector() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let plain = [0xABu8; 512];
        let mut a = plain;
        let mut b = plain;
        c.encrypt_sector(0, &mut a);
        c.encrypt_sector(1, &mut b);
        assert_ne!(a, b, "same plaintext at different sectors must differ");
    }

    #[test]
    fn matches_xts_mode_reference() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let sector_size = 512usize;
        let first = 7u64;
        let plain: Vec<u8> = (0..sector_size * 3).map(|i| (i * 7) as u8).collect();
        let mut ours = plain.clone();
        c.encrypt_area(&mut ours, sector_size, first);
        let mut refer = plain.clone();
        reference_encrypt_area(&mut refer, sector_size, first);
        assert_eq!(ours, refer, "parallel XTS must match xts-mode reference");
        c.decrypt_area(&mut ours, sector_size, first);
        assert_eq!(ours, plain);
    }

    /// Uses a block count that is not a multiple of `BATCH`. Each sector then goes
    /// through both the batched path and the one-block-at-a-time remainder path,
    /// and both must agree with the reference.
    #[test]
    fn matches_reference_with_partial_batch() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let sector_size = 16 * (BATCH + 5);
        let first = 3u64;
        let plain: Vec<u8> = (0..sector_size * 2).map(|i| (i * 13) as u8).collect();
        let mut ours = plain.clone();
        c.encrypt_area(&mut ours, sector_size, first);
        let mut refer = plain.clone();
        reference_encrypt_area(&mut refer, sector_size, first);
        assert_eq!(ours, refer, "batched + remainder path must match xts-mode reference");
        c.decrypt_area(&mut ours, sector_size, first);
        assert_eq!(ours, plain);
    }

    #[test]
    fn small_sector_roundtrip() {
        let c = XtsVolumeCipher::new(&KEY1, &KEY2).unwrap();
        let sector_size = 64usize;
        let plain: Vec<u8> = (0..sector_size * 5).map(|i| i as u8).collect();
        let mut buf = plain.clone();
        c.encrypt_area(&mut buf, sector_size, 0);
        assert_ne!(buf, plain);
        c.decrypt_area(&mut buf, sector_size, 0);
        assert_eq!(buf, plain);
    }
}