use core::{
cell::UnsafeCell,
marker::PhantomData,
ops::Deref,
sync::atomic::{AtomicU8, Ordering},
};
use crate::{
Algorithm, ByteArray, Encrypted, STATE_DECRYPTED, STATE_DECRYPTING, STATE_UNENCRYPTED,
StringLiteral,
drop_strategy::{DropStrategy, Zeroize},
};
/// Drop strategy that XORs every byte of the buffer with `KEY` once more.
///
/// XOR with a fixed key is its own inverse, so when `KEY` matches the key of the
/// [`Xor`] algorithm and the buffer currently holds decrypted plaintext, this turns
/// it back into ciphertext instead of wiping it.
///
/// NOTE(review): `KEY` is not tied to the `Xor` key by the type system, and running
/// this on a buffer that was never decrypted would turn ciphertext into plaintext.
/// Confirm the owning type only invokes it on decrypted data with a matching key.
pub struct ReEncrypt<const KEY: u8>;

impl<const KEY: u8> DropStrategy for ReEncrypt<KEY> {
    /// No per-value state is needed; the key is a const parameter.
    type Extra = ();

    fn drop(data: &mut [u8], _extra: &()) {
        data.iter_mut().for_each(|byte| *byte ^= KEY);
    }
}
/// Single-byte XOR cipher keyed by the const parameter `KEY`.
///
/// `D` is the [`DropStrategy`] this algorithm reports as its `Algorithm::Drop`
/// type; it defaults to [`Zeroize`]. The type is zero-sized: it only carries its
/// parameters at the type level.
pub struct Xor<const KEY: u8, D = Zeroize>(PhantomData<D>)
where
    D: DropStrategy;

impl<const KEY: u8, D> Algorithm for Xor<KEY, D>
where
    D: DropStrategy<Extra = ()>,
{
    type Drop = D;
    /// XOR needs no runtime state beyond the compile-time key.
    type Extra = ();
}
impl<const KEY: u8, D, M, const N: usize> Encrypted<Xor<KEY, D>, M, N>
where
    D: DropStrategy<Extra = ()>,
{
    /// Creates an encrypted value by XOR-ing every byte of `buffer` with `KEY`.
    ///
    /// Being a `const fn`, this can run at compile time, so only the XOR-ed bytes
    /// need to appear in the binary. The value starts in `STATE_UNENCRYPTED`, the
    /// state that `Deref` transitions away from when it decrypts the buffer in place.
    pub const fn new(mut buffer: [u8; N]) -> Self {
        // Iterators are not usable in a `const fn`, so walk the array by index,
        // from the last byte down to the first.
        let mut remaining = N;
        while remaining > 0 {
            remaining -= 1;
            buffer[remaining] ^= KEY;
        }

        Self {
            buffer: UnsafeCell::new(buffer),
            decryption_state: AtomicU8::new(STATE_UNENCRYPTED),
            extra: (),
            _phantom: PhantomData,
        }
    }
}
impl<const KEY: u8, D, const N: usize> Deref for Encrypted<Xor<KEY, D>, ByteArray, N>
where
    D: DropStrategy<Extra = ()>,
{
    type Target = [u8; N];

    /// Returns the plaintext bytes, decrypting the buffer in place on first access.
    ///
    /// Exactly one caller wins the `STATE_UNENCRYPTED -> STATE_DECRYPTING` transition
    /// and XORs the buffer with `KEY`. Every other caller that arrives before the
    /// result is published spins until `STATE_DECRYPTED` is observed.
    fn deref(&self) -> &Self::Target {
        let state = &self.decryption_state;

        if state.load(Ordering::Acquire) != STATE_DECRYPTED {
            let claimed = state
                .compare_exchange(
                    STATE_UNENCRYPTED,
                    STATE_DECRYPTING,
                    Ordering::AcqRel,
                    Ordering::Acquire,
                )
                .is_ok();

            if claimed {
                // SAFETY: the successful compare_exchange moved the state out of
                // STATE_UNENCRYPTED, which only one caller can do. Until
                // STATE_DECRYPTED is stored below, every other caller of this `deref`
                // is spinning and holds no reference into the buffer, so this
                // exclusive borrow does not alias.
                let ciphertext = unsafe { &mut *self.buffer.get() };
                ciphertext.iter_mut().for_each(|byte| *byte ^= KEY);
                state.store(STATE_DECRYPTED, Ordering::Release);
            } else {
                // Another caller is decrypting (or has just finished); wait for it.
                while state.load(Ordering::Acquire) != STATE_DECRYPTED {
                    core::hint::spin_loop();
                }
            }
        }

        // SAFETY: the state is STATE_DECRYPTED here (observed with Acquire, or stored
        // by this thread), so the in-place XOR happens-before this read, and nothing
        // in this impl writes to the buffer after that point.
        unsafe { &*self.buffer.get() }
    }
}
impl<const KEY: u8, D: DropStrategy<Extra = ()>, const N: usize> Deref
    for Encrypted<Xor<KEY, D>, StringLiteral, N>
{
    type Target = str;

    /// Returns the plaintext string, decrypting the buffer in place on first access.
    ///
    /// Exactly one caller wins the `STATE_UNENCRYPTED -> STATE_DECRYPTING` transition
    /// and XORs the buffer with `KEY`. Every other caller that arrives before the
    /// result is published spins until `STATE_DECRYPTED` is observed.
    ///
    /// # Panics
    ///
    /// Panics if the decrypted bytes are not valid UTF-8. `Encrypted::new` is a safe
    /// function that accepts arbitrary bytes for any marker type, so UTF-8 validity
    /// is not guaranteed by the type. Validating here (O(N) per call) means no safe
    /// caller can observe a non-UTF-8 `&str`, which would be undefined behavior.
    fn deref(&self) -> &Self::Target {
        if self.decryption_state.load(Ordering::Acquire) != STATE_DECRYPTED {
            match self.decryption_state.compare_exchange(
                STATE_UNENCRYPTED,
                STATE_DECRYPTING,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => {
                    // SAFETY: the successful compare_exchange moved the state out of
                    // STATE_UNENCRYPTED, which only one caller can do. Until
                    // STATE_DECRYPTED is stored below, every other caller of this
                    // `deref` is spinning and holds no reference into the buffer, so
                    // this exclusive borrow does not alias.
                    let data = unsafe { &mut *self.buffer.get() };
                    for byte in data.iter_mut() {
                        *byte ^= KEY;
                    }
                    self.decryption_state.store(STATE_DECRYPTED, Ordering::Release);
                }
                Err(_) => {
                    // Another caller is decrypting (or has just finished); wait for it.
                    while self.decryption_state.load(Ordering::Acquire) != STATE_DECRYPTED {
                        core::hint::spin_loop();
                    }
                }
            }
        }

        // SAFETY: the state is STATE_DECRYPTED here (observed with Acquire, or stored
        // by this thread), so the in-place XOR happens-before this read, and nothing
        // in this impl writes to the buffer after that point.
        let bytes = unsafe { &*self.buffer.get() };
        core::str::from_utf8(bytes)
            .expect("Xor-encrypted string literal did not decrypt to valid UTF-8")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        ByteArray, StringLiteral,
        align::{Aligned8, Aligned16},
        drop_strategy::{NoOp, Zeroize},
        xor::Xor,
    };
    use alloc::vec::Vec;
    use core::{mem::size_of, sync::atomic::AtomicUsize};
    use std::sync::{Arc, Barrier};
    use std::thread;

    /// Runs `check` against one shared `value` from `threads` threads at once.
    ///
    /// Every worker blocks on a [`Barrier`] before calling `check`, so all of them
    /// reach `deref` at roughly the same moment. Without the barrier, early threads
    /// typically finish before later ones are spawned, and the contended
    /// `compare_exchange` / spin-wait path is rarely exercised. A panic in any
    /// worker is propagated by the `join`.
    fn deref_concurrently<T, F>(value: T, threads: usize, check: F)
    where
        T: Send + Sync + 'static,
        F: Fn(&T) + Send + Sync + 'static,
    {
        let value = Arc::new(value);
        let check = Arc::new(check);
        let barrier = Arc::new(Barrier::new(threads));
        let handles: Vec<thread::JoinHandle<()>> = (0..threads)
            .map(|_| {
                let value = Arc::clone(&value);
                let check = Arc::clone(&check);
                let barrier = Arc::clone(&barrier);
                thread::spawn(move || {
                    barrier.wait();
                    (*check)(&*value);
                })
            })
            .collect();
        for handle in handles {
            handle.join().expect("worker thread panicked");
        }
    }

    #[test]
    fn test_size() {
        assert_eq!(17, size_of::<Encrypted<Xor<0xAA, Zeroize>, ByteArray, 16>>());
        assert_eq!(17, size_of::<Encrypted<Xor<0xAA, NoOp>, ByteArray, 16>>());
        assert_eq!(17, size_of::<Encrypted<Xor<0xAA, ReEncrypt<0xAA>>, ByteArray, 16>>());
        assert_eq!(24, size_of::<Aligned8<Encrypted<Xor<0xAA, ReEncrypt<0xAA>>, ByteArray, 16>>>());
        assert_eq!(
            32,
            size_of::<Aligned16<Encrypted<Xor<0xAA, ReEncrypt<0xAA>>, ByteArray, 16>>>()
        );
    }

    const CONST_ENCRYPTED: Encrypted<Xor<0xAA, Zeroize>, ByteArray, 5> =
        Encrypted::<Xor<0xAA, Zeroize>, ByteArray, 5>::new(*b"hello");
    const CONST_ENCRYPTED_STR: Encrypted<Xor<0xFF, Zeroize>, StringLiteral, 3> =
        Encrypted::<Xor<0xFF, Zeroize>, StringLiteral, 3>::new(*b"abc");
    const CONST_ENCRYPTED_SINGLE: Encrypted<Xor<0xFF, Zeroize>, ByteArray, 1> =
        Encrypted::<Xor<0xFF, Zeroize>, ByteArray, 1>::new([42]);
    const CONST_ENCRYPTED_ZEROS: Encrypted<Xor<0xAA, Zeroize>, ByteArray, 4> =
        Encrypted::<Xor<0xAA, Zeroize>, ByteArray, 4>::new([0, 0, 0, 0]);
    const CONST_ENCRYPTED_NOOP_KEY: Encrypted<Xor<0x00, Zeroize>, ByteArray, 3> =
        Encrypted::<Xor<0x00, Zeroize>, ByteArray, 3>::new(*b"abc");

    #[test]
    fn test_new_in_const_context() {
        let plain: &[u8; 5] = &*CONST_ENCRYPTED;
        assert_eq!(plain, b"hello");
    }

    #[test]
    fn test_buffer_is_encrypted_before_deref() {
        let encrypted = CONST_ENCRYPTED;
        let raw = unsafe { &*encrypted.buffer.get() };
        let expected = [b'h' ^ 0xAA, b'e' ^ 0xAA, b'l' ^ 0xAA, b'l' ^ 0xAA, b'o' ^ 0xAA];
        assert_eq!(raw, &expected, "buffer should be XOR-encrypted before deref");
        assert_ne!(raw, b"hello", "buffer must NOT be plaintext before deref");
    }

    #[test]
    fn test_string_buffer_is_encrypted_before_deref() {
        let encrypted = CONST_ENCRYPTED_STR;
        let raw = unsafe { &*encrypted.buffer.get() };
        let expected = [b'a' ^ 0xFF, b'b' ^ 0xFF, b'c' ^ 0xFF];
        assert_eq!(raw, &expected, "string buffer should be XOR-encrypted before deref");
        assert_ne!(raw, b"abc");
    }

    #[test]
    fn test_bytearray_deref_decrypts() {
        let encrypted = CONST_ENCRYPTED;
        let plain: &[u8; 5] = &*encrypted;
        assert_eq!(plain, b"hello");
    }

    #[test]
    fn test_bytearray_deref_single_byte() {
        let pre_deref = CONST_ENCRYPTED_SINGLE;
        let raw = unsafe { &*pre_deref.buffer.get() };
        assert_eq!(raw, &[42 ^ 0xFF]);
        let encrypted = CONST_ENCRYPTED_SINGLE;
        let plain: &[u8; 1] = &*encrypted;
        assert_eq!(plain, &[42]);
    }

    #[test]
    fn test_bytearray_deref_all_zeros() {
        let pre_deref = CONST_ENCRYPTED_ZEROS;
        let raw = unsafe { &*pre_deref.buffer.get() };
        assert_eq!(raw, &[0xAA, 0xAA, 0xAA, 0xAA]);
        let encrypted = CONST_ENCRYPTED_ZEROS;
        let plain: &[u8; 4] = &*encrypted;
        assert_eq!(plain, &[0, 0, 0, 0]);
    }

    #[test]
    fn test_bytearray_deref_key_zero_is_identity() {
        let pre_deref = CONST_ENCRYPTED_NOOP_KEY;
        let raw = unsafe { &*pre_deref.buffer.get() };
        assert_eq!(raw, b"abc", "key 0x00 should leave buffer unchanged");
        let encrypted = CONST_ENCRYPTED_NOOP_KEY;
        let plain: &[u8; 3] = &*encrypted;
        assert_eq!(plain, b"abc");
    }

    #[test]
    fn test_bytearray_multiple_derefs_are_idempotent() {
        let encrypted = CONST_ENCRYPTED;
        let first: &[u8; 5] = &*encrypted;
        let second: &[u8; 5] = &*encrypted;
        assert_eq!(first, b"hello");
        assert_eq!(second, b"hello");
    }

    #[test]
    fn test_encrypted_is_sync() {
        const fn assert_sync<T: Sync>() {}
        const fn check() {
            assert_sync::<Encrypted<Xor<0xAA, Zeroize>, ByteArray, 5>>();
            assert_sync::<Encrypted<Xor<0xBB, ReEncrypt<0xBB>>, StringLiteral, 5>>();
            assert_sync::<Encrypted<Xor<0xCC, NoOp>, ByteArray, 8>>();
        }
        check();
    }

    #[test]
    fn test_concurrent_deref_same_value() {
        const SHARED: Encrypted<Xor<0xAA, Zeroize>, StringLiteral, 5> =
            Encrypted::<Xor<0xAA, Zeroize>, StringLiteral, 5>::new(*b"hello");
        deref_concurrently(SHARED, 10, |shared| {
            let decrypted: &str = &**shared;
            assert_eq!(decrypted, "hello");
        });
    }

    #[test]
    fn test_concurrent_deref_bytearray() {
        const SHARED: Encrypted<Xor<0xFF, Zeroize>, ByteArray, 4> =
            Encrypted::<Xor<0xFF, Zeroize>, ByteArray, 4>::new([1, 2, 3, 4]);
        deref_concurrently(SHARED, 20, |shared| {
            let decrypted: &[u8; 4] = &**shared;
            assert_eq!(decrypted, &[1, 2, 3, 4]);
        });
    }

    #[test]
    fn test_concurrent_deref_reencrypt() {
        const SHARED: Encrypted<Xor<0xBB, ReEncrypt<0xBB>>, StringLiteral, 6> =
            Encrypted::<Xor<0xBB, ReEncrypt<0xBB>>, StringLiteral, 6>::new(*b"secret");
        deref_concurrently(SHARED, 15, |shared| {
            let decrypted: &str = &**shared;
            assert_eq!(decrypted, "secret");
        });
    }

    #[test]
    fn test_concurrent_deref_race_condition() {
        const SHARED: Encrypted<Xor<0x42, Zeroize>, StringLiteral, 8> =
            Encrypted::<Xor<0x42, Zeroize>, StringLiteral, 8>::new(*b"racetest");
        let results = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&results);
        deref_concurrently(SHARED, 50, move |shared| {
            let decrypted: &str = &**shared;
            if decrypted == "racetest" {
                counter.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
            }
        });
        let success_count = results.load(core::sync::atomic::Ordering::Relaxed);
        assert_eq!(success_count, 50, "all threads should see correct plaintext");
    }

    #[test]
    fn test_concurrent_multiple_values() {
        const SECRET1: Encrypted<Xor<0xAA, Zeroize>, StringLiteral, 5> =
            Encrypted::<Xor<0xAA, Zeroize>, StringLiteral, 5>::new(*b"hello");
        const SECRET2: Encrypted<Xor<0xFF, Zeroize>, ByteArray, 4> =
            Encrypted::<Xor<0xFF, Zeroize>, ByteArray, 4>::new([1, 2, 3, 4]);
        const THREADS: usize = 20;
        let secret1 = Arc::new(SECRET1);
        let secret2 = Arc::new(SECRET2);
        // One barrier across both groups so the two values are decrypted concurrently.
        let barrier = Arc::new(Barrier::new(THREADS));
        let mut handles: Vec<thread::JoinHandle<()>> = Vec::with_capacity(THREADS);
        for i in 0..THREADS {
            let barrier = Arc::clone(&barrier);
            let handle = if i % 2 == 0 {
                let secret = Arc::clone(&secret1);
                thread::spawn(move || {
                    barrier.wait();
                    let decrypted: &str = &*secret;
                    assert_eq!(decrypted, "hello");
                })
            } else {
                let secret = Arc::clone(&secret2);
                thread::spawn(move || {
                    barrier.wait();
                    let decrypted: &[u8; 4] = &*secret;
                    assert_eq!(decrypted, &[1, 2, 3, 4]);
                })
            };
            handles.push(handle);
        }
        for handle in handles {
            handle.join().expect("worker thread panicked");
        }
    }
}