fss_rs/utils.rs
// SPDX-License-Identifier: Apache-2.0
// Copyright (C) 2023 Yulong Ming (myl7)

#[cfg(not(feature = "stable"))]
use std::simd::{u8x16, u8x32};
#[cfg(feature = "stable")]
use wide::{i8x32, u8x16};
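
/// XORs the byte blocks in `xs` together into a fresh zero-initialized block.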
pub fn xor<const BLEN: usize>(xs: &[&[u8; BLEN]]) -> [u8; BLEN] {
    let mut res = [0; BLEN];
    xor_inplace(&mut res, xs);
    res
}

/// # Safety
///
/// The unsafe casts inside are sound because:
///
/// - `u8` and `i8` have the same size (and alignment).
/// - Rust arrays are compact and contiguous in memory.
/// - The slice lengths on both sides of each cast match by construction.
pub fn xor_inplace<const BLEN: usize>(lhs: &mut [u8; BLEN], rhss: &[&[u8; BLEN]]) {
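    // Walk each block greedily: 32-byte SIMD lanes first, then a 16-byte lane,
    // and finally a scalar tail for any remaining bytes.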
    rhss.iter().fold(lhs, |lhs, &rhs| {
        assert_eq!(lhs.len(), rhs.len());
        let mut i = 0;
        while i < BLEN {
            let left = BLEN - i;
            // if left >= 64 {
            //     let lhs_simd = u8x64::from_slice(&lhs[i..i + 64]);
            //     let rhs_simd = u8x64::from_slice(&rhs[i..i + 64]);
            //     lhs[i..i + 64].copy_from_slice((lhs_simd ^ rhs_simd).as_array());
            //     i += 64;
            //     continue;
            // }
            if left >= 32 {
                #[cfg(not(feature = "stable"))]
                {
                    let lhs_simd = u8x32::from_slice(&lhs[i..i + 32]);
                    let rhs_simd = u8x32::from_slice(&rhs[i..i + 32]);
                    lhs[i..i + 32].copy_from_slice((lhs_simd ^ rhs_simd).as_array());
                }
                #[cfg(feature = "stable")]
                {
                    // `wide` offers `i8x32` but no `u8x32`, so reinterpret the byte
                    // slices as `i8` for the SIMD XOR and cast the result back.
                    let lhs_simd =
                        i8x32::from(unsafe { &*(&lhs[i..i + 32] as *const [u8] as *const [i8]) });
                    let rhs_simd =
                        i8x32::from(unsafe { &*(&rhs[i..i + 32] as *const [u8] as *const [i8]) });
                    let res_simd = lhs_simd ^ rhs_simd;
                    lhs[i..i + 32].copy_from_slice(unsafe {
                        &*(res_simd.as_array_ref() as *const [i8] as *const [u8])
                    });
                }
                i += 32;
                continue;
            }
            if left >= 16 {
                #[cfg(not(feature = "stable"))]
                {
                    let lhs_simd = u8x16::from_slice(&lhs[i..i + 16]);
                    let rhs_simd = u8x16::from_slice(&rhs[i..i + 16]);
                    lhs[i..i + 16].copy_from_slice((lhs_simd ^ rhs_simd).as_array());
                }
                #[cfg(feature = "stable")]
                {
                    let lhs_simd = u8x16::from(&lhs[i..i + 16]);
                    let rhs_simd = u8x16::from(&rhs[i..i + 16]);
                    let res_simd = lhs_simd ^ rhs_simd;
                    lhs[i..i + 16].copy_from_slice(res_simd.as_array_ref());
                }
                i += 16;
                continue;
            }
            {
                // Since an AES block is 16 bytes and we usually use AES to construct the PRG,
                // there is no need to specially handle the case where BLEN % 16 != 0.
                // Still, XOR the remaining bytes one by one in case some weird situation
                // brings the program here.
                lhs[i..]
                    .iter_mut()
                    .zip(&rhs[i..])
                    .for_each(|(l, r)| *l ^= r);
                break;
            }
        }
        lhs
    });
}
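
// A minimal sanity-check sketch, added as an illustration rather than taken from the
// original file (module and test names are ours): it compares `xor` and `xor_inplace`
// against a byte-by-byte reference XOR. The 51-byte block length is chosen to exercise
// the 32-byte SIMD lane, the 16-byte lane, and the scalar tail in a single call.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn xor_matches_scalar_reference() {
        let a: [u8; 51] = std::array::from_fn(|i| i as u8);
        let b: [u8; 51] = std::array::from_fn(|i| (i as u8).wrapping_mul(7).wrapping_add(3));
        let expected: [u8; 51] = std::array::from_fn(|i| a[i] ^ b[i]);

        assert_eq!(xor(&[&a, &b]), expected);

        let mut buf = a;
        xor_inplace(&mut buf, &[&b]);
        assert_eq!(buf, expected);

        // XOR is self-inverse, so a block XORed with itself is all zeros.
        assert_eq!(xor(&[&a, &a]), [0u8; 51]);
    }
}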