group_math/
byte.rs

//! Group of bytes
//!
//! - Associative operation: XOR
//! - Identity element: the all-zero array
//! - Inverse element: `x` itself (XOR is its own inverse)
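//!
//! For example (an illustrative sketch, so the block is marked `ignore`):
//!
//! ```ignore
//! let x = ByteGroup([0x5au8; 16]);
//! // x + (-x) == 0 and x + 0 == x
//! assert_eq!(x.clone() + x.clone().add_inverse(), ByteGroup::zero());
//! assert_eq!(x.clone() + ByteGroup::zero(), x);
//! ```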

use std::ops::{Add, AddAssign};

use crate::Group;
use utils::xor_inplace;

/// See [`self`]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ByteGroup<const LAMBDA: usize>(pub [u8; LAMBDA]);

impl<const LAMBDA: usize> Add for ByteGroup<LAMBDA> {
    type Output = Self;

    fn add(mut self, rhs: Self) -> Self::Output {
        xor_inplace(&mut self.0, &[&rhs.0]);
        self
    }
}

impl<const LAMBDA: usize> AddAssign for ByteGroup<LAMBDA> {
    fn add_assign(&mut self, rhs: Self) {
        xor_inplace(&mut self.0, &[&rhs.0])
    }
}

impl<const LAMBDA: usize> Group<LAMBDA> for ByteGroup<LAMBDA> {
    fn zero() -> Self {
        ByteGroup([0; LAMBDA])
    }

    fn add_inverse(self) -> Self {
        self
    }
}

impl<const LAMBDA: usize> From<[u8; LAMBDA]> for ByteGroup<LAMBDA> {
    fn from(value: [u8; LAMBDA]) -> Self {
        Self(value)
    }
}

impl<const LAMBDA: usize> From<ByteGroup<LAMBDA>> for [u8; LAMBDA] {
    fn from(value: ByteGroup<LAMBDA>) -> Self {
        value.0
    }
}

/// XOR helpers used by `ByteGroup`, built on `std::simd`
/// (which currently requires the nightly `portable_simd` feature).
pub mod utils {
    use std::simd::{u8x16, u8x32, u8x64};

    /// XOR all the arrays in `xs` together and return the result.
    pub fn xor<const LAMBDA: usize>(xs: &[&[u8; LAMBDA]]) -> [u8; LAMBDA] {
        let mut res = [0; LAMBDA];
        xor_inplace(&mut res, xs);
        res
    }

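    /// XOR every array in `rhss` into `lhs` in place.
    ///
    /// Processes 64-, 32-, and 16-byte chunks with portable SIMD and falls
    /// back to a byte-wise loop for any remaining tail.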
    pub fn xor_inplace<const LAMBDA: usize>(lhs: &mut [u8; LAMBDA], rhss: &[&[u8; LAMBDA]]) {
        for &rhs in rhss {
            let mut i = 0;
            while i < LAMBDA {
                let left = LAMBDA - i;
                if left >= 64 {
                    let lhs_simd = u8x64::from_slice(&lhs[i..i + 64]);
                    let rhs_simd = u8x64::from_slice(&rhs[i..i + 64]);
                    lhs[i..i + 64].copy_from_slice((lhs_simd ^ rhs_simd).as_array());
                    i += 64;
                } else if left >= 32 {
                    let lhs_simd = u8x32::from_slice(&lhs[i..i + 32]);
                    let rhs_simd = u8x32::from_slice(&rhs[i..i + 32]);
                    lhs[i..i + 32].copy_from_slice((lhs_simd ^ rhs_simd).as_array());
                    i += 32;
                } else if left >= 16 {
                    let lhs_simd = u8x16::from_slice(&lhs[i..i + 16]);
                    let rhs_simd = u8x16::from_slice(&rhs[i..i + 16]);
                    lhs[i..i + 16].copy_from_slice((lhs_simd ^ rhs_simd).as_array());
                    i += 16;
                } else {
                    // Since an AES block is 16 bytes and the PRG is usually built from AES,
                    // LAMBDA % 16 != 0 should not occur in practice. XOR the remaining
                    // bytes one by one in case an unusual LAMBDA lands here anyway.
                    lhs[i..]
                        .iter_mut()
                        .zip(&rhs[i..])
                        .for_each(|(l, r)| *l ^= r);
                    break;
                }
            }
        }
    }
}
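
// A small sanity check one could add here: the SIMD-accelerated XOR should agree
// with a plain byte-wise XOR, including a length that is not a multiple of 16 so
// the scalar tail path is exercised. The test name and values are illustrative.
#[cfg(test)]
mod tests {
    use super::utils::{xor, xor_inplace};

    #[test]
    fn xor_matches_scalar_reference() {
        // 35 = 32 + 3: covers the 32-byte SIMD lane and the scalar tail.
        let a: [u8; 35] = core::array::from_fn(|i| i as u8);
        let b: [u8; 35] = core::array::from_fn(|i| (i as u8).wrapping_mul(7).wrapping_add(1));

        let mut expected = [0u8; 35];
        for i in 0..35 {
            expected[i] = a[i] ^ b[i];
        }

        // `xor` returns a fresh result; `xor_inplace` accumulates into `lhs`.
        assert_eq!(xor(&[&a, &b]), expected);

        let mut c = a;
        xor_inplace(&mut c, &[&b]);
        assert_eq!(c, expected);
    }
}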