// cryprot_codes/ex_conv.rs

1//! Expand-convolute code of [[RRT23](https://eprint.iacr.org/2023/882)].
2use std::mem;
3
4use bytemuck::{cast, cast_slice_mut};
5use cryprot_core::block::Block;
6use expander::ExpanderCode;
7use fast_aes_rng::FastAesRng;
8use seq_macro::seq;
9
10use crate::Coeff;
11
12mod expander;
13mod expander_modd;
14mod fast_aes_rng;
15
/// Expand-convolute linear code: a dual encoding first runs an accumulator
/// ("convolution") over the parity part and then an expander step.
#[derive(Debug, Clone, Copy)]
pub struct ExConvCode {
    /// Expander part of the code, seeded with `conf.seed ^ CC_BLOCK`.
    expander: ExpanderCode,
    /// Full configuration (seed, sizes, accumulator/expander settings).
    conf: ExConvCodeConfig,
    /// Number of message symbols; `conf.code_size` defaults to twice this.
    message_size: usize,
}
22
/// Configuration for the [`ExConvCode`].
#[derive(Debug, Clone, Copy)]
pub struct ExConvCodeConfig {
    /// Seed for the randomness of the accumulator (used directly) and the
    /// expander (used as `seed ^ CC_BLOCK`).
    pub seed: Block,
    /// Length of the codeword; 0 means "use `2 * message_size`" (resolved in
    /// [`ExConvCode::new_with_conf`]).
    pub code_size: usize,
    /// Number of coefficient positions each symbol is accumulated into.
    pub accumulator_size: usize,
    /// If set, the accumulation pass runs a second time with the bitwise
    /// negated seed.
    pub acc_twice: bool,
    /// Forwarded to [`ExpanderCode::new`]; presumably selects a regular
    /// (fixed row-weight) expander — see the `expander` module.
    pub regular_expander: bool,
    /// Forwarded to [`ExpanderCode::new`]; number of expander taps per output.
    pub expander_weight: usize,
}
33
34impl Default for ExConvCodeConfig {
35    fn default() -> Self {
36        Self {
37            seed: [56756745976768754, 9996754675674599].into(),
38            code_size: 0,
39            accumulator_size: 24,
40            acc_twice: true,
41            regular_expander: true,
42            expander_weight: 7,
43        }
44    }
45}
46
/// Constant XORed into the configured seed to derive a distinct seed for the
/// expander (the accumulator consumes the seed unmodified).
const CC_BLOCK: Block = Block::new([0xcc; 16]);
48
impl ExConvCode {
    /// Create a new code for the given `message_size`, by default `code_size =
    /// 2 * message_size`.
    pub fn new(message_size: usize) -> Self {
        Self::new_with_conf(message_size, ExConvCodeConfig::default())
    }

    /// Create a new code with the provided configuration.
    ///
    /// A `conf.code_size` of 0 is replaced by `2 * message_size`. The expander
    /// is seeded with `conf.seed ^ CC_BLOCK` so it draws different randomness
    /// than the accumulator, which uses `conf.seed` directly.
    pub fn new_with_conf(message_size: usize, mut conf: ExConvCodeConfig) -> Self {
        if conf.code_size == 0 {
            conf.code_size = 2 * message_size;
        }
        let expander = ExpanderCode::new(
            conf.code_size - message_size,
            conf.expander_weight,
            conf.regular_expander,
            conf.seed ^ CC_BLOCK,
        );
        Self {
            expander,
            message_size,
            conf,
        }
    }

    /// Number of rows of the parity-check matrix (`code_size - message_size`).
    pub fn parity_rows(&self) -> usize {
        self.conf.code_size - self.message_size
    }

    /// Number of columns of the parity-check matrix (`code_size`).
    pub fn parity_cols(&self) -> usize {
        self.conf.code_size
    }

    /// Number of rows of the generator matrix (`message_size`).
    pub fn generator_rows(&self) -> usize {
        self.message_size
    }

    /// Number of columns of the generator matrix (`code_size`).
    pub fn generator_cols(&self) -> usize {
        self.conf.code_size
    }

    /// The message size of this code.
    pub fn message_size(&self) -> usize {
        self.message_size
    }

    /// The codeword length of this code.
    pub fn code_size(&self) -> usize {
        self.conf.code_size
    }

    /// The configuration this code was created with (with `code_size`
    /// resolved).
    pub fn conf(&self) -> &ExConvCodeConfig {
        &self.conf
    }

    /// Encode e.
    ///
    /// The accumulator is run over the parity suffix `e[message_size..]`,
    /// which the expander then combines into the prefix
    /// `e[..message_size]` — see the `expander` module for the latter step.
    ///
    /// For maximum performance, the crate needs to be compiled with
    /// `target_feature = sse4.1` enabled. Otherwise a slower scalar fallback is
    /// used.
    ///
    /// # Panics
    /// If `e.len() != self.code_size()`.
    pub fn dual_encode<T: Coeff>(&self, e: &mut [T]) {
        assert_eq!(self.conf.code_size, e.len(), "e must have len of code_size");
        let (prefix, suffix) = e.split_at_mut(self.message_size);
        self.accumulate(suffix);
        self.expander.expand(suffix, prefix);
    }

    /// Run the accumulator over the parity part `x`, once with `conf.seed`
    /// and — if `conf.acc_twice` is set — a second time with the bitwise
    /// negated seed.
    fn accumulate<T: Coeff>(&self, x: &mut [T]) {
        let size = self.conf.code_size - self.message_size;
        debug_assert_eq!(size, x.len());

        self.accumulate_fixed(x, self.conf.seed);
        if self.conf.acc_twice {
            self.accumulate_fixed(x, !self.conf.seed);
        }
    }

    /// One accumulation pass over `x`, with coefficient bits drawn from an
    /// AES-based RNG seeded with `seed`.
    fn accumulate_fixed<T: Coeff>(&self, x: &mut [T], seed: Block) {
        let mut rng = FastAesRng::new(seed);
        let mut mtx_coeffs = rng.bytes();

        // For i < main even the farthest tap, i + 1 + accumulator_size, stays
        // in bounds, so the cheaper RANGE_CHECK = false variant can be used.
        let main = x.len() - 1 - self.conf.accumulator_size;
        for i in 0..x.len() {
            // acc_one_gen reads up to ceil(accumulator_size / 8) bytes from
            // the front of the window; refill when the buffer runs short.
            if mtx_coeffs.len() < self.conf.accumulator_size.div_ceil(8) {
                rng.refill();
                mtx_coeffs = rng.bytes();
            }

            if i < main {
                self.acc_one_gen::<false, _>(x, i, mtx_coeffs);
            } else {
                self.acc_one_gen::<true, _>(x, i, mtx_coeffs);
            }
            // The window only advances one byte per position, so consecutive
            // positions read overlapping coefficient bytes.
            mtx_coeffs = &mtx_coeffs[1..];
        }
    }

    /// XOR `x[i]` into the `accumulator_size` positions following `i`, as
    /// selected by the bits of `matrix_coeffs`, and unconditionally into the
    /// position directly after those.
    ///
    /// With `RANGE_CHECK` set, target indices are wrapped modulo `x.len()`;
    /// the caller uses the unchecked variant only when no tap can leave the
    /// slice (see `accumulate_fixed`).
    fn acc_one_gen<const RANGE_CHECK: bool, T: Coeff>(
        &self,
        x: &mut [T],
        i: usize,
        matrix_coeffs: &[u8],
    ) {
        let mut matrix_coeffs = matrix_coeffs.iter().copied();
        let size = x.len();
        let xi = x[i];
        let mut j = i + 1;
        if RANGE_CHECK && j >= size {
            j -= size;
        }

        // Fast path: consume 8 coefficient bits (one byte) at a time.
        let mut k = 0;
        while k + 7 < self.conf.accumulator_size {
            let b = matrix_coeffs.next().expect("insufficient coeffs");
            Self::acc_one_8::<RANGE_CHECK, _>(x, xi, j, b);

            j += 8;
            if RANGE_CHECK && j >= size {
                j -= size;
            }
            k += 8;
        }

        // Tail: the loop above exits with accumulator_size - k <= 7, so at
        // most 7 remaining bits are handled one by one here. (With the
        // default accumulator_size of 24 this loop is never entered.)
        while k < self.conf.accumulator_size {
            let mut b = matrix_coeffs.next().expect("insufficient coeffs");
            let mut p = 0;
            while p < 8 && k < self.conf.accumulator_size {
                if b & 1 != 0 {
                    x[j] ^= xi;
                }
                p += 1;
                k += 1;
                b >>= 1;
                j += 1;
                if RANGE_CHECK && j >= size {
                    j -= size;
                }
            }
            // Harmless: fewer than 8 bits remain on entry, so the inner loop
            // always terminates with k == accumulator_size and this increment
            // merely exits the outer loop.
            k += 1;
        }

        // Unconditional final tap at the position following the
        // coefficient-selected taps.
        x[j] ^= xi;
    }

    /// Compute the 8 consecutive target indices `j..j + 8`, wrapped modulo
    /// `x.len()` when `RANGE_CHECK` is set. (`x` is only read for its length.)
    #[inline(always)]
    fn acc_one_8_offsets<const RANGE_CHECK: bool, T: Coeff>(x: &mut [T], j: usize) -> [usize; 8] {
        let size = x.len();
        let mut js = [j, j + 1, j + 2, j + 3, j + 4, j + 5, j + 6, j + 7];
        if !RANGE_CHECK {
            debug_assert!(js[7] < x.len());
        }

        if RANGE_CHECK {
            for j in js.iter_mut() {
                if *j >= size {
                    *j -= size;
                }
            }
        }
        js
    }

    /// XOR `xi` into `x[j + n]` (wrapped if `RANGE_CHECK`) for every set bit
    /// `n` of `b`.
    ///
    /// Dispatches to the SSE implementation when `T` has the size and
    /// alignment of a 16-byte [`Block`] and sse4.1 is enabled at compile
    /// time; otherwise falls back to the scalar version.
    fn acc_one_8<const RANGE_CHECK: bool, T: Coeff>(x: &mut [T], xi: T, j: usize, b: u8) {
        if mem::size_of::<T>() == 16 && mem::align_of::<T>() == 16 {
            #[cfg(target_feature = "sse4.1")]
            Self::acc_one_8_sse::<RANGE_CHECK>(cast_slice_mut(x), cast(xi), j, b);
            #[cfg(not(target_feature = "sse4.1"))]
            Self::acc_one_8_scalar::<RANGE_CHECK, _>(x, xi, j, b);
        } else {
            Self::acc_one_8_scalar::<RANGE_CHECK, _>(x, xi, j, b);
        }
    }

    /// Scalar fallback for [`Self::acc_one_8`]: one conditional XOR per bit
    /// of `b`.
    fn acc_one_8_scalar<const RANGE_CHECK: bool, T: Coeff>(x: &mut [T], xi: T, j: usize, b: u8) {
        let js = Self::acc_one_8_offsets::<RANGE_CHECK, _>(x, j);

        let b_bits = [b & 1, b & 2, b & 4, b & 8, b & 16, b & 32, b & 64, b & 128];

        // I've tried replacing these index operations with unchecked ones, but there is
        // no measurable performance boost
        seq!(N in 0..8 {
            if b_bits[N] != 0 {
                x[js[N]] ^= xi;
            }
        });
    }

    /// SSE4.1 variant of [`Self::acc_one_8_scalar`] for 16-byte [`Block`]s.
    ///
    /// Each 32-bit lane of `rnd` contains the byte `b` four times; shifting
    /// the lanes left by `7 - N` moves bit `N` of the topmost byte into the
    /// lane's sign bit, which `_mm_blendv_ps` uses to select between `xi` and
    /// zero — i.e. `bb[N]` is `xi` iff bit `N` of `b` is set. The masks are
    /// verified by the debug assertions below.
    #[cfg(target_feature = "sse4.1")]
    #[inline(always)]
    pub fn acc_one_8_sse<const RANGE_CHECK: bool>(x: &mut [Block], xi: Block, j: usize, b: u8) {
        #[cfg(target_arch = "x86")]
        use std::arch::x86::*;
        #[cfg(target_arch = "x86_64")]
        use std::arch::x86_64::*;

        let js = Self::acc_one_8_offsets::<RANGE_CHECK, _>(x, j);
        let rnd: __m128i = Block::splat(b).into();
        // SAFETY: sse4.1 is available per cfg
        let bb = unsafe {
            let bshift = [
                _mm_slli_epi32::<7>(rnd),
                _mm_slli_epi32::<6>(rnd),
                _mm_slli_epi32::<5>(rnd),
                _mm_slli_epi32::<4>(rnd),
                _mm_slli_epi32::<3>(rnd),
                _mm_slli_epi32::<2>(rnd),
                _mm_slli_epi32::<1>(rnd),
                rnd,
            ];
            let xii: __m128 = bytemuck::cast(xi);
            let zero = _mm_setzero_ps();
            let mut bb: [__m128; 8] = bytemuck::cast(bshift);

            seq!(N in 0..8 {
                bb[N] = _mm_blendv_ps(zero, xii, bb[N]);
            });
            bb
        };

        // Debug-only check that each mask equals xi exactly when the
        // corresponding bit of b is set.
        #[cfg(debug_assertions)]
        for (i, bb) in bb.iter().enumerate() {
            let exp = if ((b >> i) & 1) != 0 { xi } else { Block::ZERO };
            debug_assert_eq!(exp, bytemuck::cast(*bb));
        }

        seq!(N in 0..8 {
            // SAFETY: if j < x.len() - 8, js returned by acc_one_8_offsets are always < x.len()
            // if x.len() - 8 < j < x.len(), we are called with RANGE_CHECK true and the js are wrapped around
            *unsafe { x.get_unchecked_mut(js[N]) } ^= bytemuck::cast(bb[N]);
        });
    }
}
282
#[cfg(test)]
mod tests {
    #[cfg(feature = "libote-compat")]
    use bytemuck::cast_slice_mut;
    use cryprot_core::block::Block;
    #[cfg(feature = "libote-compat")]
    use rand::{RngCore, SeedableRng, rngs::StdRng};

    use super::*;

    /// Construct a [`Block`] with all 16 bytes set to `value`.
    fn create_block(value: u8) -> Block {
        Block::new([value; 16])
    }

    #[test]
    fn test_config_with_explicit_code_size() {
        let seed = create_block(0xAA);
        let conf = ExConvCodeConfig {
            seed,
            code_size: 250,
            accumulator_size: 24,
            expander_weight: 5,
            ..Default::default()
        };
        let code = ExConvCode::new_with_conf(100, conf);

        // An explicit code_size must be kept instead of the 2 * message_size
        // default.
        assert_eq!(100, code.message_size);
        assert_eq!(250, code.conf.code_size);
        assert_eq!(24, code.conf.accumulator_size);
        assert_eq!(seed, code.conf.seed);
    }

    #[test]
    fn test_config_with_default_code_size() {
        let message_size = 100;
        let code = ExConvCode::new(message_size);
        // code_size == 0 in the default config resolves to 2 * message_size.
        assert_eq!(2 * message_size, code.conf.code_size);
    }

    #[test]
    fn test_generator_dimensions() {
        let code = ExConvCode::new(100);
        let n = code.conf.code_size;

        assert_eq!(100, code.generator_rows());
        assert_eq!(n, code.generator_cols());
        assert_eq!(n - 100, code.parity_rows());
        assert_eq!(n, code.parity_cols());
    }

    #[cfg(all(feature = "libote-compat", target_os = "linux"))]
    #[test]
    fn test_compare_to_libote() {
        let message_size = 200;
        let ours = ExConvCode::new(message_size);
        let code_size = ours.conf.code_size;

        let mut rng = StdRng::seed_from_u64(2423);
        let mut data = vec![Block::ZERO; code_size];
        for _ in 0..100 {
            rng.fill_bytes(cast_slice_mut(&mut data));
            let mut libote_data = data.clone();

            ours.dual_encode(&mut data);
            let mut libote = libote_codes::ExConvCode::new(
                message_size as u64,
                code_size as u64,
                ours.conf.expander_weight as u64,
                ours.conf.accumulator_size as u64,
            );
            libote.dual_encode_block(cast_slice_mut(&mut libote_data));

            // Our Block encoding must match libOTe bit for bit.
            assert_eq!(data, libote_data);
        }
    }

    #[cfg(all(feature = "libote-compat", target_os = "linux"))]
    #[test]
    fn test_compare_to_libote_bytes() {
        let message_size = 200;
        let ours = ExConvCode::new(message_size);
        let code_size = ours.conf.code_size;

        let mut rng = StdRng::seed_from_u64(2423);
        let mut data = vec![u8::ZERO; code_size];
        for _ in 0..100 {
            rng.fill_bytes(cast_slice_mut(&mut data));
            let mut libote_data = data.clone();

            ours.dual_encode(&mut data);
            let mut libote = libote_codes::ExConvCode::new(
                message_size as u64,
                code_size as u64,
                ours.conf.expander_weight as u64,
                ours.conf.accumulator_size as u64,
            );
            libote.dual_encode_byte(&mut libote_data);

            // Our byte-wise encoding must match libOTe bit for bit.
            assert_eq!(data, libote_data);
        }
    }
}