1use std::mem;
3
4use bytemuck::{cast, cast_slice_mut};
5use cryprot_core::block::Block;
6use expander::ExpanderCode;
7use fast_aes_rng::FastAesRng;
8use seq_macro::seq;
9
10use crate::Coeff;
11
12mod expander;
13mod expander_modd;
14mod fast_aes_rng;
15
/// A linear code whose dual encoding (see `dual_encode`) first accumulates the
/// parity part of the input and then applies an [`ExpanderCode`] to it.
#[derive(Debug, Clone, Copy)]
pub struct ExConvCode {
    /// Expander built in `new_with_conf` from the parity size
    /// (`code_size - message_size`) and the expander settings of `conf`.
    expander: ExpanderCode,
    /// Effective configuration; `code_size` is already resolved (never the `0`
    /// placeholder when constructed through `new_with_conf` with a non-zero
    /// message size).
    conf: ExConvCodeConfig,
    /// Number of leading elements of the codeword that form the message part.
    message_size: usize,
}
22
/// Configuration for an [`ExConvCode`].
#[derive(Debug, Clone, Copy)]
pub struct ExConvCodeConfig {
    /// Seed for the accumulator's coefficient generator. The expander is seeded
    /// with this value XOR-ed with a fixed constant.
    pub seed: Block,
    /// Total length of a codeword. `0` means "use `2 * message_size`".
    pub code_size: usize,
    /// Number of pseudo-random coefficient bits consumed per element during
    /// accumulation (the width of the accumulation window).
    pub accumulator_size: usize,
    /// When `true`, accumulation runs a second pass seeded with the bitwise
    /// complement of `seed`.
    pub acc_twice: bool,
    /// Forwarded unchanged to `ExpanderCode::new`.
    // NOTE(review): presumably selects a regular (fixed row weight) expander — confirm in `expander`.
    pub regular_expander: bool,
    /// Forwarded unchanged to `ExpanderCode::new`.
    // NOTE(review): presumably the expander's row weight — confirm in `expander`.
    pub expander_weight: usize,
}
33
34impl Default for ExConvCodeConfig {
35 fn default() -> Self {
36 Self {
37 seed: [56756745976768754, 9996754675674599].into(),
38 code_size: 0,
39 accumulator_size: 24,
40 acc_twice: true,
41 regular_expander: true,
42 expander_weight: 7,
43 }
44 }
45}
46
/// Constant (every byte `0xcc`) XOR-ed into the configured seed to derive the
/// seed handed to the expander, so it differs from the accumulator's seed.
const CC_BLOCK: Block = Block::new([0xcc; 16]);
48
49impl ExConvCode {
50 pub fn new(message_size: usize) -> Self {
53 Self::new_with_conf(message_size, ExConvCodeConfig::default())
54 }
55
56 pub fn new_with_conf(message_size: usize, mut conf: ExConvCodeConfig) -> Self {
58 if conf.code_size == 0 {
59 conf.code_size = 2 * message_size;
60 }
61 let expander = ExpanderCode::new(
62 conf.code_size - message_size,
63 conf.expander_weight,
64 conf.regular_expander,
65 conf.seed ^ CC_BLOCK,
66 );
67 Self {
68 expander,
69 message_size,
70 conf,
71 }
72 }
73
74 pub fn parity_rows(&self) -> usize {
75 self.conf.code_size - self.message_size
76 }
77
78 pub fn parity_cols(&self) -> usize {
79 self.conf.code_size
80 }
81
82 pub fn generator_rows(&self) -> usize {
83 self.message_size
84 }
85
86 pub fn generator_cols(&self) -> usize {
87 self.conf.code_size
88 }
89
90 pub fn message_size(&self) -> usize {
91 self.message_size
92 }
93
94 pub fn code_size(&self) -> usize {
95 self.conf.code_size
96 }
97
98 pub fn conf(&self) -> &ExConvCodeConfig {
99 &self.conf
100 }
101
102 pub fn dual_encode<T: Coeff>(&self, e: &mut [T]) {
111 assert_eq!(self.conf.code_size, e.len(), "e must have len of code_size");
112 let (prefix, suffix) = e.split_at_mut(self.message_size);
113 self.accumulate(suffix);
114 self.expander.expand(suffix, prefix);
115 }
116
117 fn accumulate<T: Coeff>(&self, x: &mut [T]) {
118 let size = self.conf.code_size - self.message_size;
119 debug_assert_eq!(size, x.len());
120
121 self.accumulate_fixed(x, self.conf.seed);
122 if self.conf.acc_twice {
123 self.accumulate_fixed(x, !self.conf.seed);
124 }
125 }
126
127 fn accumulate_fixed<T: Coeff>(&self, x: &mut [T], seed: Block) {
128 let mut rng = FastAesRng::new(seed);
129 let mut mtx_coeffs = rng.bytes();
130
131 let main = x.len() - 1 - self.conf.accumulator_size;
132 for i in 0..x.len() {
133 if mtx_coeffs.len() < self.conf.accumulator_size.div_ceil(8) {
134 rng.refill();
135 mtx_coeffs = rng.bytes();
136 }
137
138 if i < main {
139 self.acc_one_gen::<false, _>(x, i, mtx_coeffs);
140 } else {
141 self.acc_one_gen::<true, _>(x, i, mtx_coeffs);
142 }
143 mtx_coeffs = &mtx_coeffs[1..];
144 }
145 }
146
147 fn acc_one_gen<const RANGE_CHECK: bool, T: Coeff>(
148 &self,
149 x: &mut [T],
150 i: usize,
151 matrix_coeffs: &[u8],
152 ) {
153 let mut matrix_coeffs = matrix_coeffs.iter().copied();
154 let size = x.len();
155 let xi = x[i];
156 let mut j = i + 1;
157 if RANGE_CHECK && j >= size {
158 j -= size;
159 }
160
161 let mut k = 0;
162 while k + 7 < self.conf.accumulator_size {
163 let b = matrix_coeffs.next().expect("insufficient coeffs");
164 Self::acc_one_8::<RANGE_CHECK, _>(x, xi, j, b);
165
166 j += 8;
167 if RANGE_CHECK && j >= size {
168 j -= size;
169 }
170 k += 8;
171 }
172
173 while k < self.conf.accumulator_size {
174 let mut b = matrix_coeffs.next().expect("insufficient coeffs");
175 let mut p = 0;
176 while p < 8 && k < self.conf.accumulator_size {
177 if b & 1 != 0 {
178 x[j] ^= xi;
179 }
180 p += 1;
181 k += 1;
182 b >>= 1;
183 j += 1;
184 if RANGE_CHECK && j >= size {
185 j -= size;
186 }
187 }
188 k += 1;
189 }
190
191 x[j] ^= xi;
192 }
193
194 #[inline(always)]
195 fn acc_one_8_offsets<const RANGE_CHECK: bool, T: Coeff>(x: &mut [T], j: usize) -> [usize; 8] {
196 let size = x.len();
197 let mut js = [j, j + 1, j + 2, j + 3, j + 4, j + 5, j + 6, j + 7];
198 if !RANGE_CHECK {
199 debug_assert!(js[7] < x.len());
200 }
201
202 if RANGE_CHECK {
203 for j in js.iter_mut() {
204 if *j >= size {
205 *j -= size;
206 }
207 }
208 }
209 js
210 }
211
212 fn acc_one_8<const RANGE_CHECK: bool, T: Coeff>(x: &mut [T], xi: T, j: usize, b: u8) {
213 if mem::size_of::<T>() == 16 && mem::align_of::<T>() == 16 {
214 #[cfg(target_feature = "sse4.1")]
215 Self::acc_one_8_sse::<RANGE_CHECK>(cast_slice_mut(x), cast(xi), j, b);
216 #[cfg(not(target_feature = "sse4.1"))]
217 Self::acc_one_8_scalar::<RANGE_CHECK, _>(x, xi, j, b);
218 } else {
219 Self::acc_one_8_scalar::<RANGE_CHECK, _>(x, xi, j, b);
220 }
221 }
222
223 fn acc_one_8_scalar<const RANGE_CHECK: bool, T: Coeff>(x: &mut [T], xi: T, j: usize, b: u8) {
224 let js = Self::acc_one_8_offsets::<RANGE_CHECK, _>(x, j);
225
226 let b_bits = [b & 1, b & 2, b & 4, b & 8, b & 16, b & 32, b & 64, b & 128];
227
228 seq!(N in 0..8 {
231 if b_bits[N] != 0 {
232 x[js[N]] ^= xi;
233 }
234 });
235 }
236
237 #[cfg(target_feature = "sse4.1")]
238 #[inline(always)]
239 pub fn acc_one_8_sse<const RANGE_CHECK: bool>(x: &mut [Block], xi: Block, j: usize, b: u8) {
240 #[cfg(target_arch = "x86")]
241 use std::arch::x86::*;
242 #[cfg(target_arch = "x86_64")]
243 use std::arch::x86_64::*;
244
245 let js = Self::acc_one_8_offsets::<RANGE_CHECK, _>(x, j);
246 let rnd: __m128i = Block::splat(b).into();
247 let bb = unsafe {
249 let bshift = [
250 _mm_slli_epi32::<7>(rnd),
251 _mm_slli_epi32::<6>(rnd),
252 _mm_slli_epi32::<5>(rnd),
253 _mm_slli_epi32::<4>(rnd),
254 _mm_slli_epi32::<3>(rnd),
255 _mm_slli_epi32::<2>(rnd),
256 _mm_slli_epi32::<1>(rnd),
257 rnd,
258 ];
259 let xii: __m128 = bytemuck::cast(xi);
260 let zero = _mm_setzero_ps();
261 let mut bb: [__m128; 8] = bytemuck::cast(bshift);
262
263 seq!(N in 0..8 {
264 bb[N] = _mm_blendv_ps(zero, xii, bb[N]);
265 });
266 bb
267 };
268
269 #[cfg(debug_assertions)]
270 for (i, bb) in bb.iter().enumerate() {
271 let exp = if ((b >> i) & 1) != 0 { xi } else { Block::ZERO };
272 debug_assert_eq!(exp, bytemuck::cast(*bb));
273 }
274
275 seq!(N in 0..8 {
276 *unsafe { x.get_unchecked_mut(js[N]) } ^= bytemuck::cast(bb[N]);
279 });
280 }
281}
282
#[cfg(test)]
mod tests {
    #[cfg(feature = "libote-compat")]
    use bytemuck::cast_slice_mut;
    use cryprot_core::block::Block;
    #[cfg(feature = "libote-compat")]
    use rand::{RngCore, SeedableRng, rngs::StdRng};

    use super::*;

    /// Builds a block whose 16 bytes all equal `value`.
    fn create_block(value: u8) -> Block {
        let bytes = [value; 16];
        Block::new(bytes)
    }

    /// An explicitly configured code keeps the values it was given.
    #[test]
    fn test_config_with_explicit_code_size() {
        let message_size = 100;
        let requested = ExConvCodeConfig {
            seed: create_block(0xAA),
            code_size: 250,
            accumulator_size: 24,
            expander_weight: 5,
            ..Default::default()
        };

        let code = ExConvCode::new_with_conf(message_size, requested);

        assert_eq!(code.message_size, message_size);
        assert_eq!(code.conf.code_size, requested.code_size);
        assert_eq!(code.conf.accumulator_size, requested.accumulator_size);
        assert_eq!(code.conf.seed, requested.seed);
    }

    /// Without an explicit code size, it defaults to twice the message size.
    #[test]
    fn test_config_with_default_code_size() {
        let message_size = 100;
        let expected_code_size = 2 * message_size;

        let code = ExConvCode::new(message_size);

        assert_eq!(code.conf.code_size, expected_code_size);
    }

    /// The matrix dimension getters agree with message and code size.
    #[test]
    fn test_generator_dimensions() {
        let message_size = 100;
        let code = ExConvCode::new(message_size);
        let code_size = code.conf.code_size;
        let parity_size = code_size - message_size;

        assert_eq!(code.generator_rows(), message_size);
        assert_eq!(code.generator_cols(), code_size);
        assert_eq!(code.parity_rows(), parity_size);
        assert_eq!(code.parity_cols(), code_size);
    }

    /// Dual encoding of random blocks matches the libOTE implementation.
    #[cfg(all(feature = "libote-compat", target_os = "linux"))]
    #[test]
    fn test_compare_to_libote() {
        let message_size = 200;
        let exconv = ExConvCode::new(message_size);
        let ExConvCodeConfig {
            code_size,
            expander_weight,
            accumulator_size,
            ..
        } = exconv.conf;

        let mut rng = StdRng::seed_from_u64(2423);
        let mut ours = vec![Block::ZERO; code_size];
        for _ in 0..100 {
            rng.fill_bytes(cast_slice_mut(&mut ours));
            let mut theirs = ours.clone();
            exconv.dual_encode(&mut ours);

            let mut reference = libote_codes::ExConvCode::new(
                message_size as u64,
                code_size as u64,
                expander_weight as u64,
                accumulator_size as u64,
            );
            reference.dual_encode_block(cast_slice_mut(&mut theirs));

            assert_eq!(ours, theirs);
        }
    }

    /// Dual encoding of random bytes matches the libOTE implementation.
    #[cfg(all(feature = "libote-compat", target_os = "linux"))]
    #[test]
    fn test_compare_to_libote_bytes() {
        let message_size = 200;
        let exconv = ExConvCode::new(message_size);
        let ExConvCodeConfig {
            code_size,
            expander_weight,
            accumulator_size,
            ..
        } = exconv.conf;

        let mut rng = StdRng::seed_from_u64(2423);
        let mut ours = vec![u8::ZERO; code_size];
        for _ in 0..100 {
            rng.fill_bytes(ours.as_mut_slice());
            let mut theirs = ours.clone();
            exconv.dual_encode(&mut ours);

            let mut reference = libote_codes::ExConvCode::new(
                message_size as u64,
                code_size as u64,
                expander_weight as u64,
                accumulator_size as u64,
            );
            reference.dual_encode_byte(&mut theirs);

            assert_eq!(ours, theirs);
        }
    }
}