use blake3::Hasher;
use thiserror::Error;
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum SimdError {
#[error("Invalid input: {0}")]
InvalidInput(String),
}
pub type SimdResult<T> = Result<T, SimdError>;
const MIN_PARALLEL_CHUNK: usize = 16 * 1024;
/// XORs `a` and `b` byte-by-byte, writing `a[i] ^ b[i]` into `output[i]`.
///
/// # Errors
///
/// Returns [`SimdError::InvalidInput`] if the three slices do not all have the
/// same length. `output` is not modified in that case.
pub fn xor_buffers(a: &[u8], b: &[u8], output: &mut [u8]) -> SimdResult<()> {
    if a.len() != b.len() || a.len() != output.len() {
        return Err(SimdError::InvalidInput(
            "Buffer lengths must match for XOR operation".to_string(),
        ));
    }
    // A zipped iterator has no per-element bounds checks, so LLVM can
    // auto-vectorize it; hand-rolled fixed-size chunking added nothing here.
    for ((out, &x), &y) in output.iter_mut().zip(a).zip(b) {
        *out = x ^ y;
    }
    Ok(())
}
/// Compares two byte slices for equality without an early exit on the first
/// differing byte.
///
/// Slices of different lengths are reported unequal immediately, so length is
/// not hidden. For equal-length slices every byte pair is XORed and the
/// differences are OR-accumulated before a single final comparison.
///
/// NOTE(review): no optimization barrier is used, so the timing guarantee
/// relies on the compiler not introducing a short-circuit of its own.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let accumulated = a
            .iter()
            .zip(b)
            .fold(0u8, |acc, (&lhs, &rhs)| acc | (lhs ^ rhs));
        accumulated == 0
    } else {
        false
    }
}
/// Alternative equal-length byte comparison that reduces the accumulated
/// difference to a single bit instead of comparing it against zero.
///
/// Returns `false` immediately when the lengths differ.
// NOTE(review): `dead_code` is allowed presumably because this variant has no
// callers outside the tests; confirm before removing the attribute.
#[allow(dead_code)]
pub fn constant_time_eq_v2(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let folded: u32 = a
        .iter()
        .zip(b.iter())
        .map(|(&x, &y)| u32::from(x) ^ u32::from(y))
        .fold(0, |acc, d| acc | d);
    // Smear every set bit down toward bit 0: after shifting by 16, 8, 4, 2, 1
    // in turn, bit 0 is set exactly when any bit of `folded` was set.
    let smeared = [16u32, 8, 4, 2, 1]
        .iter()
        .fold(folded, |z, &shift| z | (z >> shift));
    (smeared & 1) == 0
}
/// Overwrites every byte of `data` with zero using volatile stores, then
/// issues a `SeqCst` compiler fence.
///
/// Volatile writes are not elided by the compiler even if `data` is never read
/// again, which is why they are used instead of a plain `fill(0)`.
pub fn secure_zero(data: &mut [u8]) {
    use std::ptr;
    use std::sync::atomic::{compiler_fence, Ordering};

    data.iter_mut().for_each(|slot| {
        // SAFETY: `slot` is a `&mut u8` borrowed from `data`, so it is
        // non-null, properly aligned, and valid for a one-byte write.
        unsafe { ptr::write_volatile(slot, 0) }
    });
    compiler_fence(Ordering::SeqCst);
}
/// Computes the 32-byte BLAKE3 digest of `data`.
///
/// Inputs shorter than `MIN_PARALLEL_CHUNK` use one-shot `blake3::hash`.
/// Larger inputs feed an incremental `Hasher`. Both paths run on the calling
/// thread and produce the same digest.
pub fn parallel_hash(data: &[u8]) -> [u8; 32] {
    let digest = if data.len() < MIN_PARALLEL_CHUNK {
        blake3::hash(data)
    } else {
        let mut state = Hasher::new();
        state.update(data);
        state.finalize()
    };
    digest.into()
}
/// Computes the 32-byte BLAKE3 digest of `data`, given a requested thread count.
///
/// `num_threads` is clamped to `1..=16`, so `0` is treated as `1`.
///
/// NOTE: no multi-threaded hashing is performed yet. Every path hashes on the
/// calling thread and yields the same digest as `blake3::hash(data)`. The
/// clamped thread count only selects between the one-shot and
/// incremental-`Hasher` paths.
pub fn parallel_hash_with_threads(data: &[u8], num_threads: usize) -> [u8; 32] {
    // Use the clamped value for the decision below; previously the clamp was
    // computed and discarded, so `0` was not treated as a single thread.
    let num_threads = num_threads.clamp(1, 16);
    if data.len() < MIN_PARALLEL_CHUNK || num_threads == 1 {
        return blake3::hash(data).into();
    }
    let mut hasher = Hasher::new();
    hasher.update(data);
    hasher.finalize().into()
}
/// XORs `data` with a repeating `keystream`, writing the result into `output`.
///
/// Byte `i` of `output` is `data[i] ^ keystream[i % keystream.len()]`.
///
/// # Errors
///
/// Returns [`SimdError::InvalidInput`] if `data` and `output` differ in length
/// (checked first) or if `keystream` is empty. `output` is not modified on error.
pub fn xor_keystream(data: &[u8], keystream: &[u8], output: &mut [u8]) -> SimdResult<()> {
    if data.len() != output.len() {
        return Err(SimdError::InvalidInput(
            "Data and output lengths must match".to_string(),
        ));
    }
    if keystream.is_empty() {
        return Err(SimdError::InvalidInput(
            "Keystream cannot be empty".to_string(),
        ));
    }
    // `cycle()` repeats the keystream indefinitely, pairing byte `i` with
    // `keystream[i % len]` without a per-byte modulo or manual chunk offsets.
    for ((out, &byte), &key) in output
        .iter_mut()
        .zip(data)
        .zip(keystream.iter().cycle())
    {
        *out = byte ^ key;
    }
    Ok(())
}
/// Applies [`constant_time_eq`] to every `(a, b)` pair, returning one result
/// per pair in input order.
pub fn batch_constant_time_eq(pairs: &[(&[u8], &[u8])]) -> Vec<bool> {
    let mut results = Vec::with_capacity(pairs.len());
    for &(lhs, rhs) in pairs {
        results.push(constant_time_eq(lhs, rhs));
    }
    results
}
/// Copies all of `src` into `dst`; both slices must have the same length.
///
/// NOTE(review): despite the name, this is an ordinary `copy_from_slice`; it
/// does not zero `src` or otherwise harden the copy.
///
/// # Errors
///
/// Returns [`SimdError::InvalidInput`] when the lengths differ. `dst` is left
/// unmodified in that case.
pub fn secure_copy(src: &[u8], dst: &mut [u8]) -> SimdResult<()> {
    if src.len() == dst.len() {
        dst.copy_from_slice(src);
        Ok(())
    } else {
        Err(SimdError::InvalidInput(
            "Source and destination lengths must match".to_string(),
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_xor_buffers() {
        let a = [0x01, 0x02, 0x03, 0x04];
        let b = [0x05, 0x06, 0x07, 0x08];
        let mut output = [0u8; 4];
        xor_buffers(&a, &b, &mut output).unwrap();
        assert_eq!(output, [0x04, 0x04, 0x04, 0x0c]);
    }

    #[test]
    fn test_xor_buffers_large() {
        let a = vec![0xAA; 1024];
        let b = vec![0x55; 1024];
        let mut output = vec![0u8; 1024];
        xor_buffers(&a, &b, &mut output).unwrap();
        assert!(output.iter().all(|&x| x == 0xFF));
    }

    #[test]
    fn test_xor_buffers_unaligned_length() {
        // 70 is not a multiple of 32, so a tail after any 32-byte blocks is exercised.
        let a: Vec<u8> = (0..70u8).collect();
        let b: Vec<u8> = (0..70u8).map(|x| x.wrapping_mul(7)).collect();
        let mut output = vec![0u8; 70];
        xor_buffers(&a, &b, &mut output).unwrap();
        let expected: Vec<u8> = a.iter().zip(&b).map(|(x, y)| x ^ y).collect();
        assert_eq!(output, expected);
    }

    #[test]
    fn test_xor_buffers_length_mismatch() {
        let a = [1, 2, 3];
        let b = [4, 5];
        let mut output = [0u8; 3];
        assert_eq!(
            xor_buffers(&a, &b, &mut output),
            Err(SimdError::InvalidInput(
                "Buffer lengths must match for XOR operation".to_string()
            ))
        );
    }

    #[test]
    fn test_constant_time_eq() {
        let a = [1, 2, 3, 4, 5];
        let b = [1, 2, 3, 4, 5];
        assert!(constant_time_eq(&a, &b));
        let c = [1, 2, 3, 4, 6];
        assert!(!constant_time_eq(&a, &c));
        let d = [1, 2, 3, 4];
        assert!(!constant_time_eq(&a, &d));
    }

    #[test]
    fn test_constant_time_eq_v2() {
        let a = [1, 2, 3, 4, 5];
        let b = [1, 2, 3, 4, 5];
        assert!(constant_time_eq_v2(&a, &b));
        let c = [1, 2, 3, 4, 6];
        assert!(!constant_time_eq_v2(&a, &c));
        let d = [1, 2, 3, 4];
        assert!(!constant_time_eq_v2(&a, &d));
        // A difference only in the top bit must still be detected by the bit fold.
        assert!(!constant_time_eq_v2(&[0x80], &[0x00]));
    }

    #[test]
    fn test_secure_zero() {
        let mut data = vec![0xFF; 100];
        secure_zero(&mut data);
        assert!(data.iter().all(|&x| x == 0));
    }

    #[test]
    fn test_parallel_hash() {
        let data = vec![0x42; 1024];
        let hash1 = parallel_hash(&data);
        let hash2 = blake3::hash(&data);
        assert_eq!(hash1, *hash2.as_bytes());
    }

    #[test]
    fn test_parallel_hash_large() {
        let data = vec![0x42; 1024 * 1024];
        let hash1 = parallel_hash(&data);
        let hash2 = blake3::hash(&data);
        assert_eq!(hash1, *hash2.as_bytes());
    }

    #[test]
    fn test_parallel_hash_with_threads() {
        let data = vec![0x42; 100_000];
        let expected = *blake3::hash(&data).as_bytes();
        // 0 and 64 lie outside the documented 1..=16 clamp range.
        for &num_threads in &[0usize, 1, 2, 4, 8, 16, 64] {
            assert_eq!(parallel_hash_with_threads(&data, num_threads), expected);
        }
        let small = vec![0x42; 16];
        assert_eq!(
            parallel_hash_with_threads(&small, 4),
            *blake3::hash(&small).as_bytes()
        );
    }

    #[test]
    fn test_xor_keystream() {
        let data = [0x01, 0x02, 0x03, 0x04, 0x05];
        let keystream = [0xFF, 0xAA];
        let mut output = [0u8; 5];
        xor_keystream(&data, &keystream, &mut output).unwrap();
        assert_eq!(output, [0xFE, 0xA8, 0xFC, 0xAE, 0xFA]);
    }

    #[test]
    fn test_xor_keystream_long_input() {
        // Longer than 4 KiB and not a multiple of the key length.
        let data: Vec<u8> = (0..10_000usize).map(|i| (i % 251) as u8).collect();
        let keystream = [1u8, 2, 3, 4, 5, 6, 7];
        let mut output = vec![0u8; data.len()];
        xor_keystream(&data, &keystream, &mut output).unwrap();
        for (i, (&o, &d)) in output.iter().zip(&data).enumerate() {
            assert_eq!(o, d ^ keystream[i % keystream.len()]);
        }
    }

    #[test]
    fn test_xor_keystream_empty_key() {
        let data = [1, 2, 3];
        let keystream: [u8; 0] = [];
        let mut output = [0u8; 3];
        assert_eq!(
            xor_keystream(&data, &keystream, &mut output),
            Err(SimdError::InvalidInput("Keystream cannot be empty".to_string()))
        );
    }

    #[test]
    fn test_xor_keystream_length_mismatch() {
        let data = [1, 2];
        let keystream = [0xFF];
        let mut output = [0u8; 3];
        assert_eq!(
            xor_keystream(&data, &keystream, &mut output),
            Err(SimdError::InvalidInput(
                "Data and output lengths must match".to_string()
            ))
        );
    }

    #[test]
    fn test_batch_constant_time_eq() {
        let pairs = [
            ([1, 2, 3].as_slice(), [1, 2, 3].as_slice()),
            ([4, 5, 6].as_slice(), [4, 5, 6].as_slice()),
            ([7, 8, 9].as_slice(), [7, 8, 0].as_slice()),
        ];
        let results = batch_constant_time_eq(&pairs);
        assert_eq!(results, vec![true, true, false]);
    }

    #[test]
    fn test_secure_copy() {
        let src = [1, 2, 3, 4, 5];
        let mut dst = [0u8; 5];
        secure_copy(&src, &mut dst).unwrap();
        assert_eq!(src, dst);
    }

    #[test]
    fn test_secure_copy_length_mismatch() {
        let src = [1, 2, 3];
        let mut dst = [0u8; 5];
        assert!(secure_copy(&src, &mut dst).is_err());
    }
}