fss_rs/utils.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright (C) 2023 Yulong Ming (myl7)

#[cfg(not(feature = "stable"))]
use std::simd::{u8x16, u8x32};
#[cfg(feature = "stable")]
use wide::{i8x32, u8x16};
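
/// Returns the bytewise XOR of all blocks in `xs`.
///
/// A minimal usage sketch (illustrative; the `fss_rs::utils` import path is an
/// assumption about where this module lives):
///
/// ```
/// # use fss_rs::utils::xor;
/// let a = [0x0fu8; 16];
/// let b = [0xf0u8; 16];
/// // XOR-ing complementary masks yields all ones.
/// assert_eq!(xor(&[&a, &b]), [0xffu8; 16]);
/// ```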
pub fn xor<const BLEN: usize>(xs: &[&[u8; BLEN]]) -> [u8; BLEN] {
    let mut res = [0; BLEN];
    xor_inplace(&mut res, xs);
    res
}

/// # Safety
///
/// The `unsafe` casts below are sound because:
///
/// - `u8` and `i8` have the same size (and alignment).
/// - Rust arrays are compact and contiguous in memory.
/// - The lengths on both sides of each cast match, since both come from slicing with the same range.
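///
/// As a minimal illustration of the reinterpretation (an added example, not
/// load-bearing): `u8 -> i8` preserves the bit pattern, and XOR is purely
/// bitwise, so XOR-ing `i8` lanes yields the same bytes.
///
/// ```
/// let bytes: [u8; 2] = [0x80, 0xff];
/// // Same bits, reinterpreted as signed.
/// let signed: [i8; 2] = unsafe { std::mem::transmute(bytes) };
/// assert_eq!(signed, [-128i8, -1i8]);
/// ```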
pub fn xor_inplace<const BLEN: usize>(lhs: &mut [u8; BLEN], rhss: &[&[u8; BLEN]]) {
    rhss.iter().fold(lhs, |lhs, &rhs| {
        // Both are `[u8; BLEN]`, so this is just a cheap sanity check.
        assert_eq!(lhs.len(), rhs.len());
        let mut i = 0;
        while i < BLEN {
            let left = BLEN - i;
            // if left >= 64 {
            //     let lhs_simd = u8x64::from_slice(&lhs[i..i + 64]);
            //     let rhs_simd = u8x64::from_slice(&rhs[i..i + 64]);
            //     lhs[i..i + 64].copy_from_slice((lhs_simd ^ rhs_simd).as_array());
            //     i += 64;
            //     continue;
            // }
            if left >= 32 {
                // XOR a 32-byte chunk with SIMD.
                #[cfg(not(feature = "stable"))]
                {
                    let lhs_simd = u8x32::from_slice(&lhs[i..i + 32]);
                    let rhs_simd = u8x32::from_slice(&rhs[i..i + 32]);
                    lhs[i..i + 32].copy_from_slice((lhs_simd ^ rhs_simd).as_array());
                }
                #[cfg(feature = "stable")]
                {
                    // Reinterpret the bytes as `i8` lanes. XOR is purely bitwise,
                    // so the resulting bytes are identical; see the safety notes.
                    let lhs_simd =
                        i8x32::from(unsafe { &*(&lhs[i..i + 32] as *const [u8] as *const [i8]) });
                    let rhs_simd =
                        i8x32::from(unsafe { &*(&rhs[i..i + 32] as *const [u8] as *const [i8]) });
                    let res_simd = lhs_simd ^ rhs_simd;
                    lhs[i..i + 32].copy_from_slice(unsafe {
                        &*(res_simd.as_array_ref() as *const [i8] as *const [u8])
                    });
                }
                i += 32;
                continue;
            }
            if left >= 16 {
                // XOR a 16-byte chunk with SIMD.
                #[cfg(not(feature = "stable"))]
                {
                    let lhs_simd = u8x16::from_slice(&lhs[i..i + 16]);
                    let rhs_simd = u8x16::from_slice(&rhs[i..i + 16]);
                    lhs[i..i + 16].copy_from_slice((lhs_simd ^ rhs_simd).as_array());
                }
                #[cfg(feature = "stable")]
                {
                    let lhs_simd = u8x16::from(&lhs[i..i + 16]);
                    let rhs_simd = u8x16::from(&rhs[i..i + 16]);
                    let res_simd = lhs_simd ^ rhs_simd;
                    lhs[i..i + 16].copy_from_slice(res_simd.as_array_ref());
                }
                i += 16;
                continue;
            }
            {
                // Since an AES block is 16 bytes and we usually use AES to construct
                // the PRG, there is no need to handle the case where BLEN % 16 != 0
                // specially. We just XOR the remaining bytes one by one in case some
                // weird situation makes the program enter here.
                lhs[i..]
                    .iter_mut()
                    .zip(&rhs[i..])
                    .for_each(|(l, r)| *l ^= r);
                break;
            }
        }
        lhs
    });
}
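
// Illustrative sanity check, added as an example (not from the original file):
// the SIMD fast paths should agree with a naive bytewise XOR. BLEN = 53
// exercises the 32-byte chunk, the 16-byte chunk, and the scalar tail.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn xor_inplace_matches_naive() {
        const BLEN: usize = 53;
        let mut a = [0u8; BLEN];
        let mut b = [0u8; BLEN];
        for j in 0..BLEN {
            a[j] = j as u8;
            b[j] = (j as u8).wrapping_mul(31).wrapping_add(7);
        }

        // Naive reference: XOR byte by byte.
        let mut expected = a;
        expected.iter_mut().zip(b.iter()).for_each(|(l, r)| *l ^= r);

        assert_eq!(xor(&[&a, &b]), expected);

        let mut lhs = a;
        xor_inplace(&mut lhs, &[&b]);
        assert_eq!(lhs, expected);
    }
}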