1use blake3::Hasher;
19use thiserror::Error;
20
21#[derive(Debug, Error, Clone, PartialEq, Eq)]
23pub enum SimdError {
24 #[error("Invalid input: {0}")]
26 InvalidInput(String),
27}
28
29pub type SimdResult<T> = Result<T, SimdError>;
31
32const MIN_PARALLEL_CHUNK: usize = 16 * 1024;
34
35pub fn xor_buffers(a: &[u8], b: &[u8], output: &mut [u8]) -> SimdResult<()> {
52 if a.len() != b.len() || a.len() != output.len() {
53 return Err(SimdError::InvalidInput(
54 "Buffer lengths must match for XOR operation".to_string(),
55 ));
56 }
57
58 let chunk_size = 32;
60 let chunks = a.len() / chunk_size;
61 let remainder = a.len() % chunk_size;
62
63 for i in 0..chunks {
65 let offset = i * chunk_size;
66 for j in 0..chunk_size {
67 output[offset + j] = a[offset + j] ^ b[offset + j];
68 }
69 }
70
71 let offset = chunks * chunk_size;
73 for i in 0..remainder {
74 output[offset + i] = a[offset + i] ^ b[offset + i];
75 }
76
77 Ok(())
78}
79
/// Reports whether `a` and `b` hold the same bytes.
///
/// Slices of different lengths compare unequal immediately. For equal-length
/// slices every byte pair is visited and the XOR differences are OR-ed
/// together, so the scan never stops early at the first mismatch.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }

    // Any differing bit in any position leaves a set bit in the accumulator.
    let accumulated = a
        .iter()
        .zip(b.iter())
        .fold(0u8, |acc, (&x, &y)| acc | (x ^ y));

    accumulated == 0
}
106
/// Alternative equality check that avoids a final `== 0` on the accumulator.
///
/// Slices of different lengths compare unequal immediately. Otherwise the XOR
/// of every byte pair is OR-ed into a 32-bit accumulator, all of its bits are
/// folded down into bit 0, and the slices are equal exactly when that bit is
/// clear.
#[allow(dead_code)]
pub fn constant_time_eq_v2(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }

    let mut folded: u32 = 0;
    for (&x, &y) in a.iter().zip(b) {
        folded |= u32::from(x) ^ u32::from(y);
    }

    // Smear every set bit down towards bit 0 by successive halving shifts.
    for shift in [16, 8, 4, 2, 1] {
        folded |= folded >> shift;
    }

    (folded & 1) == 0
}
133
/// Overwrites every byte of `data` with zero.
///
/// Each byte is stored through a volatile write and the stores are followed by
/// a `SeqCst` compiler fence, so the compiler is not allowed to drop the
/// writes as dead stores or move later memory accesses ahead of them.
pub fn secure_zero(data: &mut [u8]) {
    use std::ptr;
    use std::sync::atomic::{compiler_fence, Ordering};

    data.iter_mut().for_each(|slot| {
        // SAFETY: `slot` is an exclusive reference to one `u8` inside `data`,
        // so it is valid for writes and trivially aligned.
        unsafe { ptr::write_volatile(slot, 0) }
    });

    compiler_fence(Ordering::SeqCst);
}
153
154pub fn parallel_hash(data: &[u8]) -> [u8; 32] {
172 if data.len() < MIN_PARALLEL_CHUNK {
174 return blake3::hash(data).into();
175 }
176
177 let mut hasher = Hasher::new();
180 hasher.update(data);
181 hasher.finalize().into()
182}
183
184pub fn parallel_hash_with_threads(data: &[u8], num_threads: usize) -> [u8; 32] {
199 let _num_threads = num_threads.clamp(1, 16);
200
201 if data.len() < MIN_PARALLEL_CHUNK || num_threads == 1 {
203 return blake3::hash(data).into();
204 }
205
206 let mut hasher = Hasher::new();
210 hasher.update(data);
211 hasher.finalize().into()
212}
213
214pub fn xor_keystream(data: &[u8], keystream: &[u8], output: &mut [u8]) -> SimdResult<()> {
229 if data.len() != output.len() {
230 return Err(SimdError::InvalidInput(
231 "Data and output lengths must match".to_string(),
232 ));
233 }
234
235 if keystream.is_empty() {
236 return Err(SimdError::InvalidInput(
237 "Keystream cannot be empty".to_string(),
238 ));
239 }
240
241 let chunk_size = 4096; for (chunk_idx, data_chunk) in data.chunks(chunk_size).enumerate() {
244 let out_offset = chunk_idx * chunk_size;
245 for (i, &byte) in data_chunk.iter().enumerate() {
246 let key_idx = (out_offset + i) % keystream.len();
247 output[out_offset + i] = byte ^ keystream[key_idx];
248 }
249 }
250
251 Ok(())
252}
253
254pub fn batch_constant_time_eq(pairs: &[(&[u8], &[u8])]) -> Vec<bool> {
267 pairs.iter().map(|(a, b)| constant_time_eq(a, b)).collect()
268}
269
270pub fn secure_copy(src: &[u8], dst: &mut [u8]) -> SimdResult<()> {
284 if src.len() != dst.len() {
285 return Err(SimdError::InvalidInput(
286 "Source and destination lengths must match".to_string(),
287 ));
288 }
289
290 dst.copy_from_slice(src);
291 Ok(())
292}
293
/// Unit tests for the buffer, comparison and hashing helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_xor_buffers() {
        let a = [0x01, 0x02, 0x03, 0x04];
        let b = [0x05, 0x06, 0x07, 0x08];
        let mut output = [0u8; 4];

        xor_buffers(&a, &b, &mut output).unwrap();
        assert_eq!(output, [0x04, 0x04, 0x04, 0x0c]);
    }

    #[test]
    fn test_xor_buffers_large() {
        let a = vec![0xAA; 1024];
        let b = vec![0x55; 1024];
        let mut output = vec![0u8; 1024];

        xor_buffers(&a, &b, &mut output).unwrap();
        assert!(output.iter().all(|&x| x == 0xFF));
    }

    #[test]
    fn test_xor_buffers_unaligned_length() {
        // A length that is not a multiple of 32 exercises the tail bytes.
        let a: Vec<u8> = (0..70u32).map(|i| i as u8).collect();
        let b = vec![0xFF; 70];
        let mut output = vec![0u8; 70];

        xor_buffers(&a, &b, &mut output).unwrap();
        assert!(output.iter().zip(&a).all(|(&out, &src)| out == !src));
    }

    #[test]
    fn test_xor_buffers_length_mismatch() {
        let a = [1, 2, 3];
        let b = [4, 5];
        let mut output = [0u8; 3];

        assert!(xor_buffers(&a, &b, &mut output).is_err());
    }

    #[test]
    fn test_constant_time_eq() {
        let a = [1, 2, 3, 4, 5];
        let b = [1, 2, 3, 4, 5];
        assert!(constant_time_eq(&a, &b));

        let c = [1, 2, 3, 4, 6];
        assert!(!constant_time_eq(&a, &c));

        let d = [1, 2, 3, 4];
        assert!(!constant_time_eq(&a, &d));
    }

    #[test]
    fn test_constant_time_eq_v2() {
        let a = [1, 2, 3, 4, 5];
        let b = [1, 2, 3, 4, 5];
        assert!(constant_time_eq_v2(&a, &b));

        let c = [1, 2, 3, 4, 6];
        assert!(!constant_time_eq_v2(&a, &c));

        // Length mismatch must be rejected just like in `constant_time_eq`.
        let d = [1, 2, 3, 4];
        assert!(!constant_time_eq_v2(&a, &d));
    }

    #[test]
    fn test_secure_zero() {
        let mut data = vec![0xFF; 100];
        secure_zero(&mut data);
        assert!(data.iter().all(|&x| x == 0));
    }

    #[test]
    fn test_parallel_hash() {
        let data = vec![0x42; 1024];
        let hash1 = parallel_hash(&data);
        let hash2 = blake3::hash(&data);

        assert_eq!(hash1, *hash2.as_bytes());
    }

    #[test]
    fn test_parallel_hash_large() {
        let data = vec![0x42; 1024 * 1024];
        let hash1 = parallel_hash(&data);
        let hash2 = blake3::hash(&data);

        assert_eq!(hash1, *hash2.as_bytes());
    }

    #[test]
    fn test_parallel_hash_with_threads() {
        // Checking only `hash.len() == 32` is always true for `[u8; 32]`;
        // compare against the reference digest instead. Out-of-range thread
        // counts are included because the function clamps them.
        let data = vec![0x42; 100_000];
        let expected = *blake3::hash(&data).as_bytes();

        for num_threads in [0usize, 1, 2, 3, 4, 8, 16, 64] {
            let hash = parallel_hash_with_threads(&data, num_threads);
            assert_eq!(hash, expected, "num_threads = {}", num_threads);
        }
    }

    #[test]
    fn test_parallel_hash_with_threads_small_input() {
        let data = vec![0x42; 1024];
        let expected = *blake3::hash(&data).as_bytes();

        assert_eq!(parallel_hash_with_threads(&data, 4), expected);
    }

    #[test]
    fn test_xor_keystream() {
        let data = [0x01, 0x02, 0x03, 0x04, 0x05];
        let keystream = [0xFF, 0xAA];
        let mut output = [0u8; 5];

        xor_keystream(&data, &keystream, &mut output).unwrap();

        assert_eq!(output, [0xFE, 0xA8, 0xFC, 0xAE, 0xFA]);
    }

    #[test]
    fn test_xor_keystream_wraps_past_4096_bytes() {
        // The key length does not divide 4096, so the key position must carry
        // on unbroken across that offset.
        let data = vec![0u8; 5000];
        let keystream = [1u8, 2, 3, 4, 5, 6, 7];
        let mut output = vec![0u8; 5000];

        xor_keystream(&data, &keystream, &mut output).unwrap();

        for (i, &byte) in output.iter().enumerate() {
            assert_eq!(byte, keystream[i % keystream.len()]);
        }
    }

    #[test]
    fn test_xor_keystream_empty_key() {
        let data = [1, 2, 3];
        let keystream: [u8; 0] = [];
        let mut output = [0u8; 3];

        assert!(xor_keystream(&data, &keystream, &mut output).is_err());
    }

    #[test]
    fn test_xor_keystream_length_mismatch() {
        let data = [1, 2, 3];
        let keystream = [0xFF];
        let mut output = [0u8; 2];

        assert!(xor_keystream(&data, &keystream, &mut output).is_err());
    }

    #[test]
    fn test_batch_constant_time_eq() {
        let pairs: [(&[u8], &[u8]); 3] = [
            (&[1, 2, 3], &[1, 2, 3]),
            (&[4, 5, 6], &[4, 5, 6]),
            (&[7, 8, 9], &[7, 8, 0]),
        ];

        let results = batch_constant_time_eq(&pairs);
        assert_eq!(results, vec![true, true, false]);
    }

    #[test]
    fn test_secure_copy() {
        let src = [1, 2, 3, 4, 5];
        let mut dst = [0u8; 5];

        secure_copy(&src, &mut dst).unwrap();
        assert_eq!(src, dst);
    }

    #[test]
    fn test_secure_copy_length_mismatch() {
        let src = [1, 2, 3];
        let mut dst = [0u8; 5];

        assert!(secure_copy(&src, &mut dst).is_err());
    }
}