1use std::ops::{Add, AddAssign};
8
9use crate::Group;
10use utils::xor_inplace;
11
/// An element of the group of `LAMBDA`-byte strings under byte-wise XOR.
///
/// The wrapped array is public, so any byte pattern is a valid element.
/// Addition is XOR, the identity is the all-zero array, and every element is
/// its own additive inverse.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ByteGroup<const LAMBDA: usize>(pub [u8; LAMBDA]);
15
16impl<const LAMBDA: usize> Add for ByteGroup<LAMBDA> {
17 type Output = Self;
18
19 fn add(mut self, rhs: Self) -> Self::Output {
20 xor_inplace(&mut self.0, &[&rhs.0]);
21 self
22 }
23}
24
25impl<const LAMBDA: usize> AddAssign for ByteGroup<LAMBDA> {
26 fn add_assign(&mut self, rhs: Self) {
27 xor_inplace(&mut self.0, &[&rhs.0])
28 }
29}
30
31impl<const LAMBDA: usize> Group<LAMBDA> for ByteGroup<LAMBDA> {
32 fn zero() -> Self {
33 ByteGroup([0; LAMBDA])
34 }
35
36 fn add_inverse(self) -> Self {
37 self
38 }
39}
40
41impl<const LAMBDA: usize> From<[u8; LAMBDA]> for ByteGroup<LAMBDA> {
42 fn from(value: [u8; LAMBDA]) -> Self {
43 Self(value)
44 }
45}
46
47impl<const LAMBDA: usize> From<ByteGroup<LAMBDA>> for [u8; LAMBDA] {
48 fn from(value: ByteGroup<LAMBDA>) -> Self {
49 value.0
50 }
51}
52
53pub mod utils {
54 use std::simd::{u8x16, u8x32, u8x64};
55
56 pub fn xor<const LAMBDA: usize>(xs: &[&[u8; LAMBDA]]) -> [u8; LAMBDA] {
57 let mut res = [0; LAMBDA];
58 xor_inplace(&mut res, xs);
59 res
60 }
61
62 pub fn xor_inplace<const LAMBDA: usize>(lhs: &mut [u8; LAMBDA], rhss: &[&[u8; LAMBDA]]) {
63 rhss.iter().fold(lhs, |lhs, &rhs| {
64 let mut i = 0;
65 while i < LAMBDA {
66 let left = LAMBDA - i;
67 if left >= 64 {
68 let lhs_simd = u8x64::from_slice(&lhs[i..i + 64]);
69 let rhs_simd = u8x64::from_slice(&rhs[i..i + 64]);
70 lhs[i..i + 64].copy_from_slice((lhs_simd ^ rhs_simd).as_array());
71 i += 64;
72 } else if left >= 32 {
73 let lhs_simd = u8x32::from_slice(&lhs[i..i + 32]);
74 let rhs_simd = u8x32::from_slice(&rhs[i..i + 32]);
75 lhs[i..i + 32].copy_from_slice((lhs_simd ^ rhs_simd).as_array());
76 i += 32;
77 } else if left >= 16 {
78 let lhs_simd = u8x16::from_slice(&lhs[i..i + 16]);
79 let rhs_simd = u8x16::from_slice(&rhs[i..i + 16]);
80 lhs[i..i + 16].copy_from_slice((lhs_simd ^ rhs_simd).as_array());
81 i += 16;
82 } else {
83 lhs[i..]
87 .iter_mut()
88 .zip(&rhs[i..])
89 .for_each(|(l, r)| *l ^= r);
90 break;
91 }
92 }
93 lhs
94 });
95 }
96}