/// XORs `size_of::<T>()` bytes at `data` in place with the bytes at `key`,
/// one byte at a time, using a hand-written x86_64 assembly loop.
///
/// When `T` is zero-sized the loop exits before touching memory.
///
/// # Safety
///
/// * `data` and `key` must both be aligned for `T` (checked only in debug
///   builds).
/// * `data` must be valid for reads and writes of `size_of::<T>()` bytes.
/// * `key` must be valid for reads of `size_of::<T>()` bytes.
#[cfg(target_arch = "x86_64")]
pub unsafe fn xor_chunks_intrinsic_baseline<T>(data: *mut u8, key: *const u8) {
    use std::arch::asm;

    let byte_count = std::mem::size_of::<T>();
    let required_bits: u32 = std::mem::align_of::<T>().trailing_zeros();
    // OR-ing the addresses keeps the lowest set bit of either one, so its
    // trailing-zero count is the smaller of the two individual counts.
    let shared_bits = (data.addr() | key.addr()).trailing_zeros();
    debug_assert!(
        shared_bits >= required_bits,
        "first safety precondition: data and key must be aligned for T"
    );

    unsafe {
        asm!(
            // Loop head: leave once the offset has reached the byte count.
            "2:",
            "cmp {offset}, {len}",
            "jae 3f",
            // Fetch one key byte and fold it into the matching data byte.
            "mov {scratch}, byte ptr [{src} + {offset}]",
            "xor byte ptr [{dst} + {offset}], {scratch}",
            "add {offset}, 1",
            "jmp 2b",
            "3:",
            // The offset starts at zero and is advanced by the loop, so it is
            // an in/out operand whose final value is discarded.
            offset = inout(reg) 0usize => _,
            len = in(reg) byte_count,
            dst = in(reg) data,
            src = in(reg) key,
            scratch = out(reg_byte) _,
            options(nostack),
        );
    }
}
/// XORs `size_of::<T>()` bytes at `data` in place with the bytes at `key`,
/// one byte at a time, using a hand-written AArch64 assembly loop.
///
/// When `T` is zero-sized the loop is skipped and no memory is touched.
///
/// # Safety
///
/// * `data` and `key` must both be aligned for `T` (checked only in debug
///   builds).
/// * `data` must be valid for reads and writes of `size_of::<T>()` bytes.
/// * `key` must be valid for reads of `size_of::<T>()` bytes.
#[cfg(target_arch = "aarch64")]
pub unsafe fn xor_chunks_intrinsic_baseline<T>(data: *mut u8, key: *const u8) {
    use std::arch::asm;

    let size = std::mem::size_of::<T>();
    let min_alignment = std::mem::align_of::<T>();
    let min_alignment_bits: u32 = min_alignment.trailing_zeros();
    // Number of low zero bits shared by both addresses, i.e. log2 of the
    // largest power of two that both pointers are aligned to.
    let co_aligned_bits = data
        .addr()
        .trailing_zeros()
        .min(key.addr().trailing_zeros());
    debug_assert!(
        co_aligned_bits >= min_alignment_bits,
        "first safety precondition: data and key must be aligned for T"
    );

    unsafe {
        asm!(
            // Skip the loop entirely for a zero byte count.
            "cbz {size}, 2f",
            "1:",
            // Load one key byte and post-increment the key pointer.
            "ldrb {key_byte:w}, [{key}], 1",
            "ldrb {tmp:w}, [{data}]",
            "eor {tmp}, {tmp}, {key_byte}",
            // Store the result and post-increment the data pointer.
            "strb {tmp:w}, [{data}], 1",
            "subs {size}, {size}, #1",
            "bne 1b",
            "2:",
            key_byte = out(reg) _,
            tmp = out(reg) _,
            // The loop decrements `size` and advances both pointers, so these
            // registers do not hold their input values on exit. They must be
            // `inout` (result discarded): an `in` register is required to be
            // unchanged when the asm block finishes.
            size = inout(reg) size => _,
            data = inout(reg) data => _,
            key = inout(reg) key => _,
            options(nostack),
        );
    }
}
// Unit tests for `xor_chunks_intrinsic_baseline`. They are compiled out under
// Miri (presumably because Miri does not execute inline assembly).
#[cfg(all(test, not(miri)))]
mod tests {
    use super::*;

    /// `repr(C)` struct with a padding byte between `a` and `b`.
    #[derive(Default)]
    #[repr(C)]
    struct Foo {
        a: u8,
        b: u16,
    }

    /// Over-aligned struct; its fields are never read, only its size and
    /// alignment matter to the tests.
    #[expect(dead_code)]
    #[derive(Default)]
    #[repr(align(16))]
    struct Align16 {
        a: u64,
        b: u8,
        c: u64,
    }

    /// Fills a `T`-sized data buffer with `0xAA` and a key buffer with `0x55`,
    /// then checks XOR, the round trip back, and XOR of the buffer with itself.
    fn test_xor_chunks_for_type<T: Default>() {
        let mut data = T::default();
        let mut key = T::default();
        let size = std::mem::size_of::<T>();
        let data_ptr = (&raw mut data).cast::<u8>();
        let key_ptr = (&raw mut key).cast::<u8>();
        unsafe {
            std::ptr::write_bytes(data_ptr, 0xAA, size);
            std::ptr::write_bytes(key_ptr, 0x55, size);

            // 0xAA ^ 0x55 == 0xFF for every byte.
            xor_chunks_intrinsic_baseline::<T>(data_ptr, key_ptr);
            for i in 0..size {
                assert_eq!(data_ptr.add(i).read(), 0xFF);
            }

            // XOR-ing with the same key again restores the original bytes.
            xor_chunks_intrinsic_baseline::<T>(data_ptr, key_ptr);
            for i in 0..size {
                assert_eq!(data_ptr.add(i).read(), 0xAA);
            }

            // XOR-ing a buffer with itself clears it.
            xor_chunks_intrinsic_baseline::<T>(data_ptr, data_ptr);
            for i in 0..size {
                assert_eq!(data_ptr.add(i).read(), 0);
            }

            // Put proper `T` values back before the locals go out of scope.
            data_ptr.cast::<T>().write(T::default());
            key_ptr.cast::<T>().write(T::default());
        }
    }

    #[test]
    fn test_bytewise() {
        test_xor_chunks_for_type::<()>();
        test_xor_chunks_for_type::<u8>();
        test_xor_chunks_for_type::<u16>();
        test_xor_chunks_for_type::<u32>();
        test_xor_chunks_for_type::<u64>();
        test_xor_chunks_for_type::<Foo>();
        test_xor_chunks_for_type::<Align16>();
        test_xor_chunks_for_type::<(u8, u32, (u16, u8, u16, u64))>();
    }

    /// 256 `u16`s with 8-byte alignment, so that element offsets that are a
    /// multiple of four are aligned for `u64`.
    #[derive(Clone)]
    #[repr(align(8))]
    struct PinnedArray([u16; 256]);

    #[test]
    fn test_offsetted() {
        let mut data = PinnedArray(std::array::from_fn(|i| i as u16));
        let mut manual_data = data.clone();
        let key = PinnedArray([
            248, 230, 123, 176, 35, 3, 156, 13, 204, 19, 196, 124, 160, 184, 59, 232, 107, 98, 197,
            117, 61, 97, 94, 172, 155, 68, 182, 72, 5, 108, 221, 228, 142, 114, 58, 211, 41, 21,
            22, 168, 169, 189, 158, 52, 183, 136, 171, 56, 50, 223, 207, 226, 175, 144, 205, 234,
            254, 40, 251, 9, 148, 213, 238, 30, 163, 16, 209, 55, 135, 244, 11, 212, 194, 216, 29,
            233, 60, 153, 26, 141, 146, 152, 7, 210, 64, 36, 191, 147, 180, 208, 243, 104, 165, 89,
            224, 10, 125, 24, 131, 6, 115, 38, 195, 187, 70, 231, 198, 130, 78, 80, 139, 229, 250,
            214, 154, 63, 54, 113, 120, 76, 67, 242, 235, 77, 48, 88, 225, 105, 170, 166, 20, 0,
            134, 82, 57, 86, 102, 109, 25, 133, 239, 37, 157, 245, 137, 85, 53, 111, 192, 174, 218,
            185, 240, 203, 96, 101, 12, 51, 201, 110, 143, 116, 150, 119, 2, 140, 186, 66, 83, 39,
            18, 188, 252, 237, 199, 118, 69, 215, 255, 93, 247, 132, 45, 49, 217, 99, 4, 84, 90,
            100, 121, 126, 128, 75, 177, 8, 42, 246, 28, 202, 74, 32, 31, 81, 23, 167, 151, 220,
            193, 178, 14, 241, 138, 219, 190, 103, 179, 122, 79, 129, 44, 112, 46, 1, 95, 222, 91,
            162, 73, 127, 33, 145, 27, 71, 249, 253, 92, 34, 47, 15, 173, 161, 62, 149, 227, 181,
            236, 106, 206, 200, 159, 43, 87, 164, 65, 17_u16,
        ]);

        /// XORs `size_of::<S>()` bytes starting at element `d` of `data` with
        /// the bytes starting at element `k` of `key`, applies the same XOR
        /// element-wise to `manual_data`, and checks that both arrays agree.
        fn test<S>(
            data: &mut [u16; 256],
            manual_data: &mut [u16; 256],
            key: &[u16; 256],
            d: usize,
            k: usize,
        ) {
            let s = std::mem::size_of::<S>();
            let mult = std::mem::align_of::<u16>();
            // Hard asserts (not `debug_assert!`): these are the only guard
            // keeping the raw-pointer XOR below inside the arrays, and they
            // must also hold when the tests are built in release mode.
            assert!(d * mult + s <= data.len() * mult);
            assert!(k * mult + s <= key.len() * mult);
            unsafe {
                let data_ptr = data.as_mut_ptr().add(d).cast::<u8>();
                let key_ptr = key.as_ptr().add(k).cast::<u8>();
                xor_chunks_intrinsic_baseline::<S>(data_ptr, key_ptr);
            }
            for i in 0..s / mult {
                manual_data[d + i] ^= key[k + i];
            }
            assert_eq!(data, manual_data);
        }

        test::<[u8; 38]>(&mut data.0, &mut manual_data.0, &key.0, 0, 0);
        test::<[u8; 24]>(&mut data.0, &mut manual_data.0, &key.0, 0, 0);
        test::<[u8; 24]>(&mut data.0, &mut manual_data.0, &key.0, 0, 16);
        test::<[u8; 24]>(&mut data.0, &mut manual_data.0, &key.0, 3, 0);
        test::<[u16; 24]>(&mut data.0, &mut manual_data.0, &key.0, 4, 0);
        test::<[u16; 24]>(&mut data.0, &mut manual_data.0, &key.0, 4, 40);
        test::<[u64; 9]>(&mut data.0, &mut manual_data.0, &key.0, 8, 0);
        test::<[u16; 215]>(&mut data.0, &mut manual_data.0, &key.0, 40, 0);
    }

    #[test]
    fn test_structurewise() {
        let mut data = [0xAAu8, 0xBB];
        let key = [0xFFu8, 0xEE];
        unsafe {
            xor_chunks_intrinsic_baseline::<[u8; 2]>(data.as_mut_ptr(), key.as_ptr());
        }
        assert_eq!(data, [0xAA ^ 0xFF, 0xBB ^ 0xEE]);

        #[derive(PartialEq, Eq, Debug)]
        #[repr(C)]
        struct Padded {
            a: u8,
            b: u32,
        }

        /// Key bytes with explicit alignment. A plain byte buffer is only
        /// guaranteed 1-byte alignment, which would not satisfy the
        /// function's alignment precondition for `Padded`.
        #[repr(align(8))]
        struct AlignedKey([u8; 8]);

        let mut data = Padded {
            a: 0x12,
            b: 0x3456789A,
        };
        // Byte 0 hits `a`, bytes 1..4 hit the padding, bytes 4..8 hit `b`.
        let key = AlignedKey([0xFF, 0x00, 0x00, 0x00, 0xEE, 0xDD, 0xCC, 0xBB]);
        unsafe {
            xor_chunks_intrinsic_baseline::<Padded>((&raw mut data).cast::<u8>(), key.0.as_ptr());
        }
        assert_eq!(data.a, 0x12 ^ 0xFF);
        // Interpret the key bytes in native byte order so the expectation
        // holds regardless of the target's endianness.
        assert_eq!(
            data.b,
            0x3456789A ^ u32::from_ne_bytes([0xEE, 0xDD, 0xCC, 0xBB])
        );

        // A second XOR with the same key restores the original value.
        unsafe {
            xor_chunks_intrinsic_baseline::<[u8; 8]>((&raw mut data).cast::<u8>(), key.0.as_ptr());
        }
        assert_eq!(
            data,
            Padded {
                a: 0x12,
                b: 0x3456789A
            }
        );
    }
}