1use std::mem;
3
4#[cfg(target_feature = "sse4.1")]
5use bytemuck::{cast, cast_slice_mut};
6use cryprot_core::block::Block;
7use expander::ExpanderCode;
8use fast_aes_rng::FastAesRng;
9use seq_macro::seq;
10
11use crate::Coeff;
12
13mod expander;
14mod expander_modd;
15mod fast_aes_rng;
16
/// Linear code whose dual encoding ([`ExConvCode::dual_encode`]) runs an accumulator over the
/// parity part and then applies an expander.
///
/// The `libote-compat` tests compare its output against libOTe's `ExConvCode`.
#[derive(Debug, Clone, Copy)]
pub struct ExConvCode {
    // Expander applied by `dual_encode` after accumulation; built from `conf.seed ^ CC_BLOCK`.
    expander: ExpanderCode,
    // Configuration with `code_size` already resolved (a `0` is replaced by
    // `2 * message_size` in `new_with_conf`).
    conf: ExConvCodeConfig,
    // Number of message elements, i.e. the length of the prefix split off in `dual_encode`.
    message_size: usize,
}
23
/// Parameters of an [`ExConvCode`].
#[derive(Debug, Clone, Copy)]
pub struct ExConvCodeConfig {
    /// Seed for the code's randomness. The accumulator passes are seeded with `seed` and
    /// (if `acc_twice`) `!seed`; the expander is seeded with `seed ^ CC_BLOCK`.
    pub seed: Block,
    /// Total code length. `0` means "derive it": `ExConvCode::new_with_conf` then uses
    /// `2 * message_size`.
    pub code_size: usize,
    /// Number of positions after each row that the accumulator may XOR that row into, one
    /// pseudo-random coefficient bit per position.
    pub accumulator_size: usize,
    /// Whether to run a second accumulator pass seeded with `!seed`.
    pub acc_twice: bool,
    /// Forwarded to `ExpanderCode::new`; presumably selects a regular (fixed row-weight)
    /// expander — confirm in the `expander` module.
    pub regular_expander: bool,
    /// Forwarded to `ExpanderCode::new`; presumably the number of nonzero entries per
    /// expander row — confirm in the `expander` module.
    pub expander_weight: usize,
}
34
35impl Default for ExConvCodeConfig {
36 fn default() -> Self {
37 Self {
38 seed: [56756745976768754, 9996754675674599].into(),
39 code_size: 0,
40 accumulator_size: 24,
41 acc_twice: true,
42 regular_expander: true,
43 expander_weight: 7,
44 }
45 }
46}
47
/// Mask XORed into `ExExConvCodeConfig::seed`-derived expander seeding: `new_with_conf` seeds the
/// expander with `conf.seed ^ CC_BLOCK`, while the accumulator passes use `conf.seed` and
/// `!conf.seed`. Every byte is `0xcc`.
const CC_BLOCK: Block = Block::new([0xcc; 16]);
49
50impl ExConvCode {
51 pub fn new(message_size: usize) -> Self {
54 Self::new_with_conf(message_size, ExConvCodeConfig::default())
55 }
56
57 pub fn new_with_conf(message_size: usize, mut conf: ExConvCodeConfig) -> Self {
59 if conf.code_size == 0 {
60 conf.code_size = 2 * message_size;
61 }
62 let expander = ExpanderCode::new(
63 conf.code_size - message_size,
64 conf.expander_weight,
65 conf.regular_expander,
66 conf.seed ^ CC_BLOCK,
67 );
68 Self {
69 expander,
70 message_size,
71 conf,
72 }
73 }
74
75 pub fn parity_rows(&self) -> usize {
76 self.conf.code_size - self.message_size
77 }
78
79 pub fn parity_cols(&self) -> usize {
80 self.conf.code_size
81 }
82
83 pub fn generator_rows(&self) -> usize {
84 self.message_size
85 }
86
87 pub fn generator_cols(&self) -> usize {
88 self.conf.code_size
89 }
90
91 pub fn message_size(&self) -> usize {
92 self.message_size
93 }
94
95 pub fn code_size(&self) -> usize {
96 self.conf.code_size
97 }
98
99 pub fn conf(&self) -> &ExConvCodeConfig {
100 &self.conf
101 }
102
103 pub fn dual_encode<T: Coeff>(&self, e: &mut [T]) {
112 assert_eq!(self.conf.code_size, e.len(), "e must have len of code_size");
113 let (prefix, suffix) = e.split_at_mut(self.message_size);
114 self.accumulate(suffix);
115 self.expander.expand(suffix, prefix);
116 }
117
118 fn accumulate<T: Coeff>(&self, x: &mut [T]) {
119 let size = self.conf.code_size - self.message_size;
120 debug_assert_eq!(size, x.len());
121
122 self.accumulate_fixed(x, self.conf.seed);
123 if self.conf.acc_twice {
124 self.accumulate_fixed(x, !self.conf.seed);
125 }
126 }
127
128 fn accumulate_fixed<T: Coeff>(&self, x: &mut [T], seed: Block) {
129 let mut rng = FastAesRng::new(seed);
130 let mut mtx_coeffs = rng.bytes();
131
132 let main = x.len() - 1 - self.conf.accumulator_size;
133 for i in 0..x.len() {
134 if mtx_coeffs.len() < self.conf.accumulator_size.div_ceil(8) {
135 rng.refill();
136 mtx_coeffs = rng.bytes();
137 }
138
139 if i < main {
140 self.acc_one_gen::<false, _>(x, i, mtx_coeffs);
141 } else {
142 self.acc_one_gen::<true, _>(x, i, mtx_coeffs);
143 }
144 mtx_coeffs = &mtx_coeffs[1..];
145 }
146 }
147
148 fn acc_one_gen<const RANGE_CHECK: bool, T: Coeff>(
149 &self,
150 x: &mut [T],
151 i: usize,
152 matrix_coeffs: &[u8],
153 ) {
154 let mut matrix_coeffs = matrix_coeffs.iter().copied();
155 let size = x.len();
156 let xi = x[i];
157 let mut j = i + 1;
158 if RANGE_CHECK && j >= size {
159 j -= size;
160 }
161
162 let mut k = 0;
163 while k + 7 < self.conf.accumulator_size {
164 let b = matrix_coeffs.next().expect("insufficient coeffs");
165 Self::acc_one_8::<RANGE_CHECK, _>(x, xi, j, b);
166
167 j += 8;
168 if RANGE_CHECK && j >= size {
169 j -= size;
170 }
171 k += 8;
172 }
173
174 while k < self.conf.accumulator_size {
175 let mut b = matrix_coeffs.next().expect("insufficient coeffs");
176 let mut p = 0;
177 while p < 8 && k < self.conf.accumulator_size {
178 if b & 1 != 0 {
179 x[j] ^= xi;
180 }
181 p += 1;
182 k += 1;
183 b >>= 1;
184 j += 1;
185 if RANGE_CHECK && j >= size {
186 j -= size;
187 }
188 }
189 k += 1;
190 }
191
192 x[j] ^= xi;
193 }
194
195 #[inline(always)]
196 fn acc_one_8_offsets<const RANGE_CHECK: bool, T: Coeff>(x: &mut [T], j: usize) -> [usize; 8] {
197 let size = x.len();
198 let mut js = [j, j + 1, j + 2, j + 3, j + 4, j + 5, j + 6, j + 7];
199 if !RANGE_CHECK {
200 debug_assert!(js[7] < x.len());
201 }
202
203 if RANGE_CHECK {
204 for j in js.iter_mut() {
205 if *j >= size {
206 *j -= size;
207 }
208 }
209 }
210 js
211 }
212
213 fn acc_one_8<const RANGE_CHECK: bool, T: Coeff>(x: &mut [T], xi: T, j: usize, b: u8) {
214 if mem::size_of::<T>() == 16 && mem::align_of::<T>() == 16 {
215 #[cfg(target_feature = "sse4.1")]
216 Self::acc_one_8_sse::<RANGE_CHECK>(cast_slice_mut(x), cast(xi), j, b);
217 #[cfg(not(target_feature = "sse4.1"))]
218 Self::acc_one_8_scalar::<RANGE_CHECK, _>(x, xi, j, b);
219 } else {
220 Self::acc_one_8_scalar::<RANGE_CHECK, _>(x, xi, j, b);
221 }
222 }
223
224 fn acc_one_8_scalar<const RANGE_CHECK: bool, T: Coeff>(x: &mut [T], xi: T, j: usize, b: u8) {
225 let js = Self::acc_one_8_offsets::<RANGE_CHECK, _>(x, j);
226
227 let b_bits = [b & 1, b & 2, b & 4, b & 8, b & 16, b & 32, b & 64, b & 128];
228
229 seq!(N in 0..8 {
232 if b_bits[N] != 0 {
233 x[js[N]] ^= xi;
234 }
235 });
236 }
237
238 #[cfg(target_feature = "sse4.1")]
239 #[inline(always)]
240 pub fn acc_one_8_sse<const RANGE_CHECK: bool>(x: &mut [Block], xi: Block, j: usize, b: u8) {
241 #[cfg(target_arch = "x86")]
242 use std::arch::x86::*;
243 #[cfg(target_arch = "x86_64")]
244 use std::arch::x86_64::*;
245
246 let js = Self::acc_one_8_offsets::<RANGE_CHECK, _>(x, j);
247 let rnd: __m128i = Block::splat(b).into();
248 let bb = unsafe {
250 let bshift = [
251 _mm_slli_epi32::<7>(rnd),
252 _mm_slli_epi32::<6>(rnd),
253 _mm_slli_epi32::<5>(rnd),
254 _mm_slli_epi32::<4>(rnd),
255 _mm_slli_epi32::<3>(rnd),
256 _mm_slli_epi32::<2>(rnd),
257 _mm_slli_epi32::<1>(rnd),
258 rnd,
259 ];
260 let xii: __m128 = bytemuck::cast(xi);
261 let zero = _mm_setzero_ps();
262 let mut bb: [__m128; 8] = bytemuck::cast(bshift);
263
264 seq!(N in 0..8 {
265 bb[N] = _mm_blendv_ps(zero, xii, bb[N]);
266 });
267 bb
268 };
269
270 #[cfg(debug_assertions)]
271 for (i, bb) in bb.iter().enumerate() {
272 let exp = if ((b >> i) & 1) != 0 { xi } else { Block::ZERO };
273 debug_assert_eq!(exp, bytemuck::cast(*bb));
274 }
275
276 seq!(N in 0..8 {
277 *unsafe { x.get_unchecked_mut(js[N]) } ^= bytemuck::cast(bb[N]);
280 });
281 }
282}
283
#[cfg(test)]
mod tests {
    #[cfg(feature = "libote-compat")]
    use bytemuck::cast_slice_mut;
    use cryprot_core::block::Block;
    #[cfg(feature = "libote-compat")]
    use rand::{Rng, SeedableRng, rngs::StdRng};

    use super::*;

    /// Returns a block whose 16 bytes all equal `byte`.
    fn create_block(byte: u8) -> Block {
        Block::new([byte; 16])
    }

    #[test]
    fn test_config_with_explicit_code_size() {
        let seed = create_block(0xAA);
        let conf = ExConvCodeConfig {
            seed,
            code_size: 250,
            accumulator_size: 24,
            expander_weight: 5,
            ..Default::default()
        };
        let code = ExConvCode::new_with_conf(100, conf);

        assert_eq!(code.message_size, 100);
        assert_eq!(code.conf.code_size, conf.code_size);
        assert_eq!(code.conf.accumulator_size, conf.accumulator_size);
        assert_eq!(code.conf.seed, seed);
    }

    #[test]
    fn test_config_with_default_code_size() {
        const MESSAGE_SIZE: usize = 100;
        let code = ExConvCode::new(MESSAGE_SIZE);
        assert_eq!(code.conf.code_size, MESSAGE_SIZE * 2);
    }

    #[test]
    fn test_generator_dimensions() {
        const MESSAGE_SIZE: usize = 100;
        let code = ExConvCode::new(MESSAGE_SIZE);
        let n = code.conf.code_size;

        assert_eq!(code.generator_rows(), MESSAGE_SIZE);
        assert_eq!(code.generator_cols(), n);
        assert_eq!(code.parity_rows(), n - MESSAGE_SIZE);
        assert_eq!(code.parity_cols(), n);
    }

    #[cfg(all(feature = "libote-compat", target_os = "linux"))]
    #[test]
    fn test_compare_to_libote() {
        let exconv = ExConvCode::new(200);
        let message_size = exconv.message_size;
        let code_size = exconv.conf.code_size;
        let mut rng = StdRng::seed_from_u64(2423);
        let mut ours = vec![Block::ZERO; code_size];

        for _round in 0..100 {
            rng.fill_bytes(cast_slice_mut(&mut ours));
            let mut theirs = ours.clone();
            exconv.dual_encode(&mut ours);

            let mut libote_exconv = libote_codes::ExConvCode::new(
                message_size as u64,
                code_size as u64,
                exconv.conf.expander_weight as u64,
                exconv.conf.accumulator_size as u64,
            );
            libote_exconv.dual_encode_block(cast_slice_mut(&mut theirs));

            assert_eq!(ours, theirs);
        }
    }

    #[cfg(all(feature = "libote-compat", target_os = "linux"))]
    #[test]
    fn test_compare_to_libote_bytes() {
        let exconv = ExConvCode::new(200);
        let message_size = exconv.message_size;
        let code_size = exconv.conf.code_size;
        let mut rng = StdRng::seed_from_u64(2423);
        let mut ours = vec![u8::ZERO; code_size];

        for _round in 0..100 {
            rng.fill_bytes(cast_slice_mut(&mut ours));
            let mut theirs = ours.clone();
            exconv.dual_encode(&mut ours);

            let mut libote_exconv = libote_codes::ExConvCode::new(
                message_size as u64,
                code_size as u64,
                exconv.conf.expander_weight as u64,
                exconv.conf.accumulator_size as u64,
            );
            libote_exconv.dual_encode_byte(&mut theirs);

            assert_eq!(ours, theirs);
        }
    }
}
393}