use core::{
cell::UnsafeCell,
marker::PhantomData,
ops::Deref,
sync::atomic::{AtomicU8, Ordering},
};
use crate::{
Algorithm, ByteArray, Encrypted, STATE_DECRYPTED, STATE_DECRYPTING, STATE_UNENCRYPTED,
StringLiteral,
drop_strategy::{DropStrategy, Zeroize},
};
/// Builds the RC4 state table for `key` (the key-scheduling algorithm).
///
/// # Panics
///
/// Panics if `KEY_LEN == 0`: RC4 needs at least one key byte. When evaluated in a
/// `const` initializer this panic becomes a compile-time error.
const fn key_schedule<const KEY_LEN: usize>(key: &[u8; KEY_LEN]) -> [u8; 256] {
    assert!(KEY_LEN != 0, "RC4 key must not be empty");
    let mut s = [0u8; 256];
    let mut n = 0usize;
    while n < 256 {
        // `n < 256`, so the cast is lossless.
        s[n] = n as u8;
        n += 1;
    }
    let mut j: u8 = 0;
    let mut i = 0usize;
    while i < 256 {
        j = j.wrapping_add(s[i]).wrapping_add(key[i % KEY_LEN]);
        swap_entries(&mut s, i, j as usize);
        i += 1;
    }
    s
}

/// Swaps two entries of the RC4 state table. Written by hand so it stays usable
/// inside `const fn`.
const fn swap_entries(s: &mut [u8; 256], a: usize, b: usize) {
    let tmp = s[a];
    s[a] = s[b];
    s[b] = tmp;
}

/// XORs `data` in place with the RC4 keystream derived from `key`.
///
/// RC4 is a symmetric stream cipher: the same call encrypts and decrypts. Applying
/// it twice to the same bytes restores the original bytes.
///
/// # Panics
///
/// Panics if `KEY_LEN == 0` (see [`key_schedule`]).
const fn apply_keystream<const KEY_LEN: usize>(data: &mut [u8], key: &[u8; KEY_LEN]) {
    let mut s = key_schedule(key);
    let mut i: u8 = 0;
    let mut j: u8 = 0;
    let mut idx = 0usize;
    while idx < data.len() {
        i = i.wrapping_add(1);
        j = j.wrapping_add(s[i as usize]);
        swap_entries(&mut s, i as usize, j as usize);
        data[idx] ^= s[s[i as usize].wrapping_add(s[j as usize]) as usize];
        idx += 1;
    }
}

/// Drop strategy that re-applies the RC4 keystream of the stored key to the
/// buffer instead of clearing it.
///
/// Because the keystream XOR is its own inverse, running it over plaintext yields
/// the same ciphertext that [`Encrypted::new`] produced.
///
/// NOTE(review): the same operation would turn still-encrypted bytes *into*
/// plaintext. Presumably the caller only invokes `D::drop` once the buffer has been
/// decrypted — confirm against the `Drop` impl of `Encrypted`.
pub struct ReEncrypt<const KEY_LEN: usize>;

impl<const KEY_LEN: usize> DropStrategy for ReEncrypt<KEY_LEN> {
    type Extra = [u8; KEY_LEN];

    /// XORs `data` with the RC4 keystream of `key`.
    ///
    /// # Panics
    ///
    /// Panics if `KEY_LEN == 0`.
    fn drop(data: &mut [u8], key: &[u8; KEY_LEN]) {
        apply_keystream(data, key);
    }
}

/// RC4 stream-cipher [`Algorithm`] marker.
///
/// `KEY_LEN` is the key length in bytes. `D` is the [`DropStrategy`] applied to
/// the buffer on drop (for example [`ReEncrypt`]). The key itself is stored as the
/// algorithm's `Extra` data.
pub struct Rc4<const KEY_LEN: usize, D: DropStrategy = Zeroize>(PhantomData<D>);

impl<const KEY_LEN: usize, D: DropStrategy<Extra = [u8; KEY_LEN]>> Algorithm for Rc4<KEY_LEN, D> {
    type Drop = D;
    type Extra = [u8; KEY_LEN];
}

impl<const KEY_LEN: usize, D: DropStrategy<Extra = [u8; KEY_LEN]>, M, const N: usize>
    Encrypted<Rc4<KEY_LEN, D>, M, N>
{
    /// Encrypts `buffer` with RC4 under `key` and stores `key` for later
    /// decryption.
    ///
    /// When called in a `const` context the encryption happens at compile time.
    /// For `StringLiteral` values, `buffer` must contain valid UTF-8; this is not
    /// checked here.
    ///
    /// # Panics
    ///
    /// Panics if `KEY_LEN == 0`. In a `const` initializer this is a compile error.
    pub const fn new(mut buffer: [u8; N], key: [u8; KEY_LEN]) -> Self {
        apply_keystream(&mut buffer, &key);
        Encrypted {
            buffer: UnsafeCell::new(buffer),
            decryption_state: AtomicU8::new(STATE_UNENCRYPTED),
            extra: key,
            _phantom: PhantomData,
        }
    }

    /// Returns the plaintext bytes, decrypting the buffer in place on first use.
    ///
    /// The caller whose `compare_exchange` moves the state from
    /// `STATE_UNENCRYPTED` to `STATE_DECRYPTING` performs the decryption and then
    /// publishes `STATE_DECRYPTED` with `Release` ordering. Every other caller
    /// spins until it observes `STATE_DECRYPTED` with `Acquire` ordering.
    fn rc4_plaintext(&self) -> &[u8; N] {
        let state = &self.decryption_state;
        if state.load(Ordering::Acquire) != STATE_DECRYPTED {
            match state.compare_exchange(
                STATE_UNENCRYPTED,
                STATE_DECRYPTING,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => {
                    // SAFETY: winning the CAS makes this the only thread in the
                    // DECRYPTING phase. Every other reader in this module waits
                    // for STATE_DECRYPTED before touching the buffer, so this
                    // `&mut` is unique until the store below.
                    let data = unsafe { &mut *self.buffer.get() };
                    apply_keystream(data, &self.extra);
                    state.store(STATE_DECRYPTED, Ordering::Release);
                }
                Err(_) => {
                    while state.load(Ordering::Acquire) != STATE_DECRYPTED {
                        core::hint::spin_loop();
                    }
                }
            }
        }
        // SAFETY: every path above ends with STATE_DECRYPTED observed via
        // Acquire (or stored by this thread). The decrypted bytes are therefore
        // visible, and the writer's `&mut` is no longer live.
        unsafe { &*self.buffer.get() }
    }
}

impl<const KEY_LEN: usize, D: DropStrategy<Extra = [u8; KEY_LEN]>, const N: usize> Deref
    for Encrypted<Rc4<KEY_LEN, D>, ByteArray, N>
{
    type Target = [u8; N];

    /// Returns the decrypted bytes, decrypting them on first access.
    fn deref(&self) -> &Self::Target {
        self.rc4_plaintext()
    }
}

impl<const KEY_LEN: usize, D: DropStrategy<Extra = [u8; KEY_LEN]>, const N: usize> Deref
    for Encrypted<Rc4<KEY_LEN, D>, StringLiteral, N>
{
    type Target = str;

    /// Returns the decrypted string, decrypting it on first access.
    fn deref(&self) -> &Self::Target {
        let bytes = self.rc4_plaintext();
        // Checked only in debug builds. The state is already published as
        // DECRYPTED at this point, so a failure here cannot leave other threads
        // spinning.
        debug_assert!(
            core::str::from_utf8(bytes).is_ok(),
            "Rc4 StringLiteral plaintext is not valid UTF-8"
        );
        // SAFETY: relies on the bytes passed to `new` being valid UTF-8.
        // NOTE(review): `new` is a safe `pub const fn` that does not verify this.
        // Presumably construction is limited to string literals (e.g. via a
        // macro) — confirm, otherwise this is reachable UB from safe code.
        unsafe { core::str::from_utf8_unchecked(bytes) }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        ByteArray, StringLiteral,
        drop_strategy::{DropStrategy, NoOp, Zeroize},
        rc4::Rc4,
    };
    use alloc::vec;
    use alloc::vec::Vec;
    use core::sync::atomic::AtomicUsize;
    use std::sync::Arc;
    use std::thread;

    const RC4_KEY: [u8; 5] = *b"mykey";
    const RC4_KEY2: [u8; 16] = *b"sixteen-byte-key";
    const CONST_ENCRYPTED: Encrypted<Rc4<5, Zeroize<[u8; 5]>>, ByteArray, 5> =
        Encrypted::<Rc4<5, Zeroize<[u8; 5]>>, ByteArray, 5>::new(*b"hello", RC4_KEY);
    const CONST_ENCRYPTED_STR: Encrypted<Rc4<5, Zeroize<[u8; 5]>>, StringLiteral, 5> =
        Encrypted::<Rc4<5, Zeroize<[u8; 5]>>, StringLiteral, 5>::new(*b"hello", RC4_KEY);
    const CONST_ENCRYPTED_16: Encrypted<Rc4<16, Zeroize<[u8; 16]>>, ByteArray, 8> =
        Encrypted::<Rc4<16, Zeroize<[u8; 16]>>, ByteArray, 8>::new(*b"longdata", RC4_KEY2);

    // Standard RC4 known-answer vectors (key, plaintext -> ciphertext). The
    // round-trip tests below cannot catch a wrong keystream, because any XOR
    // stream decrypts its own output; these pin the exact RC4 ciphertext.
    const KAT_KEY_CIPHERTEXT: [u8; 9] = [0xBB, 0xF3, 0x16, 0xE8, 0xD9, 0x40, 0xAF, 0x0A, 0xD3];
    const KAT_WIKI_CIPHERTEXT: [u8; 5] = [0x10, 0x21, 0xBF, 0x04, 0x20];
    const KAT_SECRET_CIPHERTEXT: [u8; 14] = [
        0x45, 0xA0, 0x1F, 0x64, 0x5F, 0xC3, 0x5B, 0x38, 0x35, 0x52, 0x54, 0x4B, 0x9B, 0xF5,
    ];

    #[test]
    fn test_rc4_buffer_is_encrypted_before_deref() {
        let encrypted = CONST_ENCRYPTED;
        let raw = unsafe { &*encrypted.buffer.get() };
        assert_ne!(raw, b"hello", "buffer must NOT be plaintext before deref");
        assert_eq!(encrypted.extra, RC4_KEY, "key should be stored in extra");
    }

    #[test]
    fn test_rc4_bytearray_deref_decrypts() {
        let encrypted = CONST_ENCRYPTED;
        let plain: &[u8; 5] = &*encrypted;
        assert_eq!(plain, b"hello");
    }

    #[test]
    fn test_rc4_string_deref_decrypts() {
        let encrypted = CONST_ENCRYPTED_STR;
        let plain: &str = &*encrypted;
        assert_eq!(plain, "hello");
    }

    #[test]
    fn test_rc4_multiple_derefs_are_idempotent() {
        let encrypted = CONST_ENCRYPTED;
        let first: &[u8; 5] = &*encrypted;
        let second: &[u8; 5] = &*encrypted;
        assert_eq!(first, b"hello");
        assert_eq!(second, b"hello");
    }

    #[test]
    fn test_rc4_different_key_length() {
        let encrypted = CONST_ENCRYPTED_16;
        let plain: &[u8; 8] = &*encrypted;
        assert_eq!(plain, b"longdata");
    }

    #[test]
    fn test_rc4_encrypted_is_sync() {
        const fn assert_sync<T: Sync>() {}
        const fn check() {
            assert_sync::<Encrypted<Rc4<5, Zeroize<[u8; 5]>>, ByteArray, 8>>();
            assert_sync::<Encrypted<Rc4<16, Zeroize<[u8; 16]>>, StringLiteral, 10>>();
            assert_sync::<Encrypted<Rc4<32, NoOp<[u8; 32]>>, ByteArray, 16>>();
        }
        check();
    }

    #[test]
    fn test_rc4_concurrent_deref_same_value() {
        const SHARED: Encrypted<Rc4<5, Zeroize<[u8; 5]>>, StringLiteral, 5> =
            Encrypted::<Rc4<5, Zeroize<[u8; 5]>>, StringLiteral, 5>::new(*b"hello", RC4_KEY);
        let shared = Arc::new(SHARED);
        let mut handles: Vec<thread::JoinHandle<()>> = vec![];
        for _ in 0..10 {
            let shared_clone = Arc::clone(&shared);
            let handle = thread::spawn(move || {
                let decrypted: &str = &*shared_clone;
                assert_eq!(decrypted, "hello");
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn test_rc4_concurrent_deref_bytearray() {
        const SHARED: Encrypted<Rc4<16, Zeroize<[u8; 16]>>, ByteArray, 4> =
            Encrypted::<Rc4<16, Zeroize<[u8; 16]>>, ByteArray, 4>::new([1, 2, 3, 4], RC4_KEY2);
        let shared = Arc::new(SHARED);
        let mut handles: Vec<thread::JoinHandle<()>> = vec![];
        for _ in 0..20 {
            let shared_clone = Arc::clone(&shared);
            let handle = thread::spawn(move || {
                let decrypted: &[u8; 4] = &*shared_clone;
                assert_eq!(decrypted, &[1, 2, 3, 4]);
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn test_rc4_concurrent_deref_race_condition() {
        const SHARED: Encrypted<Rc4<5, Zeroize<[u8; 5]>>, StringLiteral, 8> =
            Encrypted::<Rc4<5, Zeroize<[u8; 5]>>, StringLiteral, 8>::new(*b"racetest", RC4_KEY);
        let shared = Arc::new(SHARED);
        let results = Arc::new(AtomicUsize::new(0));
        let mut handles: Vec<thread::JoinHandle<()>> = vec![];
        for _ in 0..50 {
            let shared_clone = Arc::clone(&shared);
            let results_clone = Arc::clone(&results);
            let handle = thread::spawn(move || {
                let decrypted: &str = &*shared_clone;
                if decrypted == "racetest" {
                    results_clone.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        let success_count = results.load(core::sync::atomic::Ordering::Relaxed);
        assert_eq!(success_count, 50, "all threads should see correct plaintext");
    }

    #[test]
    fn test_rc4_single_byte() {
        const ENCRYPTED: Encrypted<Rc4<5, Zeroize<[u8; 5]>>, ByteArray, 1> =
            Encrypted::<Rc4<5, Zeroize<[u8; 5]>>, ByteArray, 1>::new([42], RC4_KEY);
        let plain: &[u8; 1] = &*ENCRYPTED;
        assert_eq!(plain, &[42]);
    }

    #[test]
    fn test_rc4_all_zeros() {
        const ENCRYPTED: Encrypted<Rc4<5, Zeroize<[u8; 5]>>, ByteArray, 4> =
            Encrypted::<Rc4<5, Zeroize<[u8; 5]>>, ByteArray, 4>::new([0, 0, 0, 0], RC4_KEY);
        let plain: &[u8; 4] = &*ENCRYPTED;
        assert_eq!(plain, &[0, 0, 0, 0]);
    }

    #[test]
    fn test_rc4_reencrypt_drop() {
        use crate::rc4::ReEncrypt;
        const SHARED: Encrypted<Rc4<5, ReEncrypt<5>>, StringLiteral, 5> =
            Encrypted::<Rc4<5, ReEncrypt<5>>, StringLiteral, 5>::new(*b"hello", RC4_KEY);
        let shared = Arc::new(SHARED);
        let mut handles: Vec<thread::JoinHandle<()>> = vec![];
        for _ in 0..10 {
            let shared_clone = Arc::clone(&shared);
            let handle = thread::spawn(move || {
                let decrypted: &str = &*shared_clone;
                assert_eq!(decrypted, "hello");
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn test_rc4_new_matches_known_answer_vectors() {
        let key_vec =
            Encrypted::<Rc4<3, Zeroize<[u8; 3]>>, ByteArray, 9>::new(*b"Plaintext", *b"Key");
        assert_eq!(unsafe { &*key_vec.buffer.get() }, &KAT_KEY_CIPHERTEXT);
        assert_eq!(&*key_vec, b"Plaintext");

        let wiki_vec = Encrypted::<Rc4<4, Zeroize<[u8; 4]>>, ByteArray, 5>::new(*b"pedia", *b"Wiki");
        assert_eq!(unsafe { &*wiki_vec.buffer.get() }, &KAT_WIKI_CIPHERTEXT);
        assert_eq!(&*wiki_vec, b"pedia");

        let secret_vec = Encrypted::<Rc4<6, Zeroize<[u8; 6]>>, StringLiteral, 14>::new(
            *b"Attack at dawn",
            *b"Secret",
        );
        assert_eq!(unsafe { &*secret_vec.buffer.get() }, &KAT_SECRET_CIPHERTEXT);
        assert_eq!(&*secret_vec, "Attack at dawn");
    }

    #[test]
    fn test_rc4_reencrypt_strategy_matches_known_answer_vector() {
        use crate::rc4::ReEncrypt;
        let mut data = *b"Attack at dawn";
        <ReEncrypt<6> as DropStrategy>::drop(&mut data, b"Secret");
        assert_eq!(data, KAT_SECRET_CIPHERTEXT);
        // The keystream XOR is an involution: applying it again restores plaintext.
        <ReEncrypt<6> as DropStrategy>::drop(&mut data, b"Secret");
        assert_eq!(&data, b"Attack at dawn");
    }
}