use std::sync::atomic::{fence, Ordering};
use std::mem::{MaybeUninit, size_of};
use std::marker::PhantomData;
use std::ptr::NonNull;
use bytemuck::NoUninit;
/// XORs `size_of::<T>()` bytes at `data` with the bytes at `key`, in place.
///
/// Each result byte is stored back into `data` with a volatile write, and a
/// `SeqCst` fence is issued once all bytes have been processed (presumably so
/// the stores cannot be elided or reordered away by the optimizer).
///
/// `T` is used only for its size; no value of type `T` is read or produced.
///
/// # Safety
///
/// * `data` must be valid for reads and writes of `size_of::<T>()` bytes.
/// * `key` must be valid for reads of `size_of::<T>()` bytes.
/// * The bytes behind both pointers must be initialized.
///
/// The two pointers are allowed to be equal; in that case every byte of
/// `data` ends up as zero.
unsafe fn xor_chunks<T>(data: *mut u8, key: *const u8) {
    let len = size_of::<T>();
    let mut offset = 0;
    while offset < len {
        let dst = data.wrapping_add(offset);
        let src = key.wrapping_add(offset);
        // SAFETY: `offset < size_of::<T>()`, and the caller guarantees both
        // pointers cover that many initialized bytes.
        unsafe {
            let mixed = *dst ^ *src;
            dst.write_volatile(mixed);
        }
        offset += 1;
    }
    fence(Ordering::SeqCst);
}
/// A heap-allocated `T` whose bytes are stored XORed with a per-instance key.
///
/// The plain bytes are only present in memory while a closure passed to
/// `with_unmangled` is running; outside of that window the heap storage holds
/// `value ^ key`. On drop, both the storage and the key are overwritten with
/// zeros.
///
/// The `NoUninit` bound is what allows the value to be processed byte by byte.
pub struct MangledBox<T: NoUninit> {
    /// Heap storage holding the bytes of the value XORed with `key`.
    data: Box<MaybeUninit<T>>,
    /// XOR mask with the same size as `T`; filled with random bytes in `new`.
    key: MaybeUninit<T>,
}
impl<T: NoUninit> MangledBox<T> {
    /// Creates a box with zero-filled storage and a freshly generated random key.
    ///
    /// Because the stored bytes are zero and the key is random, the unmangled
    /// view initially equals the key bytes, i.e. arbitrary data. Write a value
    /// through [`with_unmangled`](Self::with_unmangled) before reading one.
    ///
    /// # Panics
    ///
    /// Panics with `"no keygen"` if `getrandom` fails to produce key bytes.
    pub fn new() -> Self {
        let storage: Box<MaybeUninit<T>> = Box::new_zeroed();
        let mut mask = MaybeUninit::<T>::uninit();
        getrandom::fill_uninit(mask.as_bytes_mut()).expect("no keygen");
        Self {
            data: storage,
            key: mask,
        }
    }

    /// Switches to a new random key without changing the logical value.
    ///
    /// A random delta is XORed into both the stored bytes and the key, so
    /// `data ^ key` is the same before and after the call.
    ///
    /// # Panics
    ///
    /// Panics with `"no keygen"` if `getrandom` fails to produce the delta.
    pub fn rekey(&mut self) {
        let mut delta = MaybeUninit::<T>::uninit();
        getrandom::fill_uninit(delta.as_bytes_mut()).expect("no keygen");

        let delta_bytes = delta.as_ptr().cast::<u8>();
        let data_bytes = Box::as_mut_ptr(&mut self.data).cast::<u8>();
        let key_bytes = self.key.as_mut_ptr().cast::<u8>();

        // SAFETY: all three buffers are `MaybeUninit<T>`, hence exactly
        // `size_of::<T>()` bytes long, and each was fully written before this
        // point (storage zeroed in `new`, key and delta filled by `getrandom`).
        unsafe {
            xor_chunks::<T>(data_bytes, delta_bytes);
            xor_chunks::<T>(key_bytes, delta_bytes);
        }
    }

    /// Runs `f` with a pointer to the temporarily unmangled value.
    ///
    /// The storage is XORed with the key before `f` is called and XORed again
    /// once `f` has finished. The second XOR is performed by a drop guard, so
    /// it also happens when `f` unwinds. After this method returns, the bytes
    /// behind the pointer handed to `f` are mangled again.
    ///
    /// Returns whatever `f` returns.
    pub fn with_unmangled<F, R>(&mut self, f: F) -> R
    where
        F: FnOnce(NonNull<T>) -> R,
    {
        /// Re-applies the XOR mask to the storage when it goes out of scope.
        struct RemangleGuard<T> {
            data: *mut u8,
            key: *const u8,
            token: PhantomData<T>,
        }

        impl<T> Drop for RemangleGuard<T> {
            fn drop(&mut self) {
                // SAFETY: the guard is only built below from pointers into a
                // live `MangledBox<T>` that stays mutably borrowed for the
                // guard's whole lifetime.
                unsafe { xor_chunks::<T>(self.data, self.key) }
            }
        }

        let masked = Box::as_mut_ptr(&mut self.data).cast::<u8>();
        let mask = self.key.as_ptr().cast::<u8>();
        let exposed: NonNull<T> = NonNull::new(masked).unwrap().cast();

        // SAFETY: storage and key are both `size_of::<T>()` bytes and fully
        // written (see `new`).
        unsafe {
            xor_chunks::<T>(masked, mask);
        }

        // From here on the plain bytes are visible; the guard hides them again
        // on every exit path.
        let _remangle = RemangleGuard::<T> {
            data: masked,
            key: mask,
            token: PhantomData,
        };

        f(exposed)
    }
}
impl<T: NoUninit> Default for MangledBox<T> {
fn default() -> Self {
Self::new()
}
}
/// Wipes both the masked storage and the key before the memory is released.
impl<T: NoUninit> Drop for MangledBox<T> {
    fn drop(&mut self) {
        let storage_bytes = Box::as_mut_ptr(&mut self.data).cast::<u8>();
        let mask_bytes = self.key.as_mut_ptr().cast::<u8>();

        // XORing a buffer with itself yields all zeros, and `xor_chunks`
        // stores every byte with a volatile write.
        //
        // SAFETY: each pointer refers to a `MaybeUninit<T>` owned by `self`,
        // so it is valid for `size_of::<T>()` bytes of reads and writes.
        unsafe { xor_chunks::<T>(storage_bytes, storage_bytes) };
        // SAFETY: as above, for the inline key.
        unsafe { xor_chunks::<T>(mask_bytes, mask_bytes) };
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only if `T` is `Send`.
    fn ensure_send<T>(_v: &T)
    where
        T: Send,
    {
    }

    /// Compiles only if `T` is `Sync`.
    fn ensure_sync<T>(_v: &T)
    where
        T: Sync,
    {
    }

    /// A zero-sized payload can be constructed, is `Send + Sync`, and can be
    /// accessed without doing anything.
    #[test]
    fn zst() {
        let mut unit_box: MangledBox<()> = MangledBox::new();
        ensure_send(&unit_box);
        ensure_sync(&unit_box);
        unit_box.with_unmangled(|_| {});
    }

    /// Zero-sized type with a 64-byte alignment requirement.
    #[derive(bytemuck::NoUninit, Clone, Copy)]
    #[repr(C, align(64))]
    struct Align64;

    /// The pointer handed to the closure must respect the payload's alignment
    /// even when the payload is zero-sized.
    #[test]
    fn overaligned_zst() {
        let mut aligned_box: MangledBox<Align64> = MangledBox::new();
        ensure_send(&aligned_box);
        ensure_sync(&aligned_box);
        aligned_box.with_unmangled(|ptr| {
            let misalignment = ptr.as_ptr().align_offset(64);
            assert_eq!(
                misalignment, 0,
                "alignment not preserved on overaligned ZST type"
            );
        });
    }

    /// A written byte reads back unchanged, both before and after `rekey`.
    #[test]
    fn data_u8_preserved() {
        const STORED: u8 = 42;

        let expect_stored = |target: &mut MangledBox<u8>| {
            target.with_unmangled(|ptr| {
                let seen = unsafe { ptr.read() };
                assert_eq!(seen, STORED);
            });
        };

        let mut byte_box: MangledBox<u8> = MangledBox::new();
        byte_box.with_unmangled(|ptr| unsafe { ptr.write(STORED) });
        expect_stored(&mut byte_box);

        byte_box.rekey();
        // Two reads after rekeying: each access must remangle correctly.
        expect_stored(&mut byte_box);
        expect_stored(&mut byte_box);
    }

    /// A multi-byte value survives a mangle/unmangle round trip.
    #[test]
    fn data_u64_preserved() {
        let expected: u64 = 0x123456789abcdef;
        let mut word_box: MangledBox<u64> = MangledBox::new();
        word_box.with_unmangled(|ptr| unsafe { ptr.write(expected) });
        word_box.with_unmangled(|ptr| {
            let seen = unsafe { ptr.read() };
            assert_eq!(seen, expected);
        });
    }
}