Skip to main content

cryprot_codes/
ex_conv.rs

1//! Expand-convolute code of [[RRT23](https://eprint.iacr.org/2023/882)].
2use std::mem;
3
4#[cfg(target_feature = "sse4.1")]
5use bytemuck::{cast, cast_slice_mut};
6use cryprot_core::block::Block;
7use expander::ExpanderCode;
8use fast_aes_rng::FastAesRng;
9use seq_macro::seq;
10
11use crate::Coeff;
12
13mod expander;
14mod expander_modd;
15mod fast_aes_rng;
16
17#[derive(Debug, Clone, Copy)]
18pub struct ExConvCode {
19    expander: ExpanderCode,
20    conf: ExConvCodeConfig,
21    message_size: usize,
22}
23
24/// Configuration for the [`ExConvCode`].
25#[derive(Debug, Clone, Copy)]
26pub struct ExConvCodeConfig {
27    pub seed: Block,
28    pub code_size: usize,
29    pub accumulator_size: usize,
30    pub acc_twice: bool,
31    pub regular_expander: bool,
32    pub expander_weight: usize,
33}
34
35impl Default for ExConvCodeConfig {
36    fn default() -> Self {
37        Self {
38            seed: [56756745976768754, 9996754675674599].into(),
39            code_size: 0,
40            accumulator_size: 24,
41            acc_twice: true,
42            regular_expander: true,
43            expander_weight: 7,
44        }
45    }
46}
47
48const CC_BLOCK: Block = Block::new([0xcc; 16]);
49
50impl ExConvCode {
51    /// Create a new code for the given `message_size`, by default `code_size =
52    /// 2 * message_size`.
53    pub fn new(message_size: usize) -> Self {
54        Self::new_with_conf(message_size, ExConvCodeConfig::default())
55    }
56
57    /// Create a new code with the provided configuration.
58    pub fn new_with_conf(message_size: usize, mut conf: ExConvCodeConfig) -> Self {
59        if conf.code_size == 0 {
60            conf.code_size = 2 * message_size;
61        }
62        let expander = ExpanderCode::new(
63            conf.code_size - message_size,
64            conf.expander_weight,
65            conf.regular_expander,
66            conf.seed ^ CC_BLOCK,
67        );
68        Self {
69            expander,
70            message_size,
71            conf,
72        }
73    }
74
75    pub fn parity_rows(&self) -> usize {
76        self.conf.code_size - self.message_size
77    }
78
79    pub fn parity_cols(&self) -> usize {
80        self.conf.code_size
81    }
82
83    pub fn generator_rows(&self) -> usize {
84        self.message_size
85    }
86
87    pub fn generator_cols(&self) -> usize {
88        self.conf.code_size
89    }
90
    /// Returns the message size this code was constructed with.
    pub fn message_size(&self) -> usize {
        self.message_size
    }
94
    /// Returns the code size.
    ///
    /// If the configuration passed at construction had `code_size == 0`, this
    /// is the resolved default of `2 * message_size`.
    pub fn code_size(&self) -> usize {
        self.conf.code_size
    }
98
    /// Returns the configuration in use, with `code_size` already resolved.
    pub fn conf(&self) -> &ExConvCodeConfig {
        &self.conf
    }
102
103    /// Encode e.
104    ///
105    /// For maximum performance, the crate needs to be compiled with
106    /// `target_feature = sse4.1` enabled. Otherwise a slower scalar fallback is
107    /// used.
108    ///
109    /// # Panics
110    /// If `e.len() != self.code_size()`.
111    pub fn dual_encode<T: Coeff>(&self, e: &mut [T]) {
112        assert_eq!(self.conf.code_size, e.len(), "e must have len of code_size");
113        let (prefix, suffix) = e.split_at_mut(self.message_size);
114        self.accumulate(suffix);
115        self.expander.expand(suffix, prefix);
116    }
117
118    fn accumulate<T: Coeff>(&self, x: &mut [T]) {
119        let size = self.conf.code_size - self.message_size;
120        debug_assert_eq!(size, x.len());
121
122        self.accumulate_fixed(x, self.conf.seed);
123        if self.conf.acc_twice {
124            self.accumulate_fixed(x, !self.conf.seed);
125        }
126    }
127
128    fn accumulate_fixed<T: Coeff>(&self, x: &mut [T], seed: Block) {
129        let mut rng = FastAesRng::new(seed);
130        let mut mtx_coeffs = rng.bytes();
131
132        let main = x.len() - 1 - self.conf.accumulator_size;
133        for i in 0..x.len() {
134            if mtx_coeffs.len() < self.conf.accumulator_size.div_ceil(8) {
135                rng.refill();
136                mtx_coeffs = rng.bytes();
137            }
138
139            if i < main {
140                self.acc_one_gen::<false, _>(x, i, mtx_coeffs);
141            } else {
142                self.acc_one_gen::<true, _>(x, i, mtx_coeffs);
143            }
144            mtx_coeffs = &mtx_coeffs[1..];
145        }
146    }
147
148    fn acc_one_gen<const RANGE_CHECK: bool, T: Coeff>(
149        &self,
150        x: &mut [T],
151        i: usize,
152        matrix_coeffs: &[u8],
153    ) {
154        let mut matrix_coeffs = matrix_coeffs.iter().copied();
155        let size = x.len();
156        let xi = x[i];
157        let mut j = i + 1;
158        if RANGE_CHECK && j >= size {
159            j -= size;
160        }
161
162        let mut k = 0;
163        while k + 7 < self.conf.accumulator_size {
164            let b = matrix_coeffs.next().expect("insufficient coeffs");
165            Self::acc_one_8::<RANGE_CHECK, _>(x, xi, j, b);
166
167            j += 8;
168            if RANGE_CHECK && j >= size {
169                j -= size;
170            }
171            k += 8;
172        }
173
174        while k < self.conf.accumulator_size {
175            let mut b = matrix_coeffs.next().expect("insufficient coeffs");
176            let mut p = 0;
177            while p < 8 && k < self.conf.accumulator_size {
178                if b & 1 != 0 {
179                    x[j] ^= xi;
180                }
181                p += 1;
182                k += 1;
183                b >>= 1;
184                j += 1;
185                if RANGE_CHECK && j >= size {
186                    j -= size;
187                }
188            }
189            k += 1;
190        }
191
192        x[j] ^= xi;
193    }
194
195    #[inline(always)]
196    fn acc_one_8_offsets<const RANGE_CHECK: bool, T: Coeff>(x: &mut [T], j: usize) -> [usize; 8] {
197        let size = x.len();
198        let mut js = [j, j + 1, j + 2, j + 3, j + 4, j + 5, j + 6, j + 7];
199        if !RANGE_CHECK {
200            debug_assert!(js[7] < x.len());
201        }
202
203        if RANGE_CHECK {
204            for j in js.iter_mut() {
205                if *j >= size {
206                    *j -= size;
207                }
208            }
209        }
210        js
211    }
212
213    fn acc_one_8<const RANGE_CHECK: bool, T: Coeff>(x: &mut [T], xi: T, j: usize, b: u8) {
214        if mem::size_of::<T>() == 16 && mem::align_of::<T>() == 16 {
215            #[cfg(target_feature = "sse4.1")]
216            Self::acc_one_8_sse::<RANGE_CHECK>(cast_slice_mut(x), cast(xi), j, b);
217            #[cfg(not(target_feature = "sse4.1"))]
218            Self::acc_one_8_scalar::<RANGE_CHECK, _>(x, xi, j, b);
219        } else {
220            Self::acc_one_8_scalar::<RANGE_CHECK, _>(x, xi, j, b);
221        }
222    }
223
224    fn acc_one_8_scalar<const RANGE_CHECK: bool, T: Coeff>(x: &mut [T], xi: T, j: usize, b: u8) {
225        let js = Self::acc_one_8_offsets::<RANGE_CHECK, _>(x, j);
226
227        let b_bits = [b & 1, b & 2, b & 4, b & 8, b & 16, b & 32, b & 64, b & 128];
228
229        // I've tried replacing these index operations with unchecked ones, but there is
230        // no measurable performance boost
231        seq!(N in 0..8 {
232            if b_bits[N] != 0 {
233                x[js[N]] ^= xi;
234            }
235        });
236    }
237
238    #[cfg(target_feature = "sse4.1")]
239    #[inline(always)]
240    pub fn acc_one_8_sse<const RANGE_CHECK: bool>(x: &mut [Block], xi: Block, j: usize, b: u8) {
241        #[cfg(target_arch = "x86")]
242        use std::arch::x86::*;
243        #[cfg(target_arch = "x86_64")]
244        use std::arch::x86_64::*;
245
246        let js = Self::acc_one_8_offsets::<RANGE_CHECK, _>(x, j);
247        let rnd: __m128i = Block::splat(b).into();
248        // SAFETY: sse4.1 is available per cfg
249        let bb = unsafe {
250            let bshift = [
251                _mm_slli_epi32::<7>(rnd),
252                _mm_slli_epi32::<6>(rnd),
253                _mm_slli_epi32::<5>(rnd),
254                _mm_slli_epi32::<4>(rnd),
255                _mm_slli_epi32::<3>(rnd),
256                _mm_slli_epi32::<2>(rnd),
257                _mm_slli_epi32::<1>(rnd),
258                rnd,
259            ];
260            let xii: __m128 = bytemuck::cast(xi);
261            let zero = _mm_setzero_ps();
262            let mut bb: [__m128; 8] = bytemuck::cast(bshift);
263
264            seq!(N in 0..8 {
265                bb[N] = _mm_blendv_ps(zero, xii, bb[N]);
266            });
267            bb
268        };
269
270        #[cfg(debug_assertions)]
271        for (i, bb) in bb.iter().enumerate() {
272            let exp = if ((b >> i) & 1) != 0 { xi } else { Block::ZERO };
273            debug_assert_eq!(exp, bytemuck::cast(*bb));
274        }
275
276        seq!(N in 0..8 {
277            // SAFETY: if j < x.len() - 8, js returned by acc_one_8_offsets are always < x.len()
278            // if x.len() - 8 < j < x.len(), we are called with RANGE_CHECK true and the js are wrapped around
279            *unsafe { x.get_unchecked_mut(js[N]) } ^= bytemuck::cast(bb[N]);
280        });
281    }
282}
283
#[cfg(test)]
mod tests {
    #[cfg(feature = "libote-compat")]
    use bytemuck::cast_slice_mut;
    use cryprot_core::block::Block;
    #[cfg(feature = "libote-compat")]
    use rand::{Rng, SeedableRng, rngs::StdRng};

    use super::*;

    /// Builds a [`Block`] whose 16 bytes all equal `value`.
    fn create_block(value: u8) -> Block {
        Block::new([value; 16])
    }

    /// An explicitly configured code keeps the values it was given.
    #[test]
    fn test_config_with_explicit_code_size() {
        let message_size = 100;
        let seed = create_block(0xAA);
        let conf = ExConvCodeConfig {
            seed,
            code_size: 250,
            accumulator_size: 24,
            expander_weight: 5,
            ..Default::default()
        };

        let code = ExConvCode::new_with_conf(message_size, conf);

        assert_eq!(code.message_size, message_size);
        assert_eq!(code.conf.code_size, 250);
        assert_eq!(code.conf.accumulator_size, 24);
        assert_eq!(code.conf.seed, seed);
    }

    /// Without an explicit code size, it defaults to twice the message size.
    #[test]
    fn test_config_with_default_code_size() {
        let message_size = 100;
        let code = ExConvCode::new(message_size);
        assert_eq!(code.conf.code_size, 2 * message_size);
    }

    /// Generator and parity dimensions follow from message and code size.
    #[test]
    fn test_generator_dimensions() {
        let message_size = 100;
        let code = ExConvCode::new(message_size);
        let code_size = code.conf.code_size;

        assert_eq!(code.generator_rows(), message_size);
        assert_eq!(code.generator_cols(), code_size);
        assert_eq!(code.parity_rows(), code_size - message_size);
        assert_eq!(code.parity_cols(), code_size);
    }

    /// Encoding random `Block` data matches the libOTE implementation.
    #[cfg(all(feature = "libote-compat", target_os = "linux"))]
    #[test]
    fn test_compare_to_libote() {
        let message_size = 200;
        let exconv = ExConvCode::new(message_size);
        let code_size = exconv.conf.code_size;

        let mut rng = StdRng::seed_from_u64(2423);
        let mut data = vec![Block::ZERO; code_size];
        for _ in 0..100 {
            rng.fill_bytes(cast_slice_mut(&mut data));
            let mut expected = data.clone();

            exconv.dual_encode(&mut data);

            let mut libote_exconv = libote_codes::ExConvCode::new(
                message_size as u64,
                code_size as u64,
                exconv.conf.expander_weight as u64,
                exconv.conf.accumulator_size as u64,
            );
            libote_exconv.dual_encode_block(cast_slice_mut(&mut expected));

            assert_eq!(data, expected);
        }
    }

    /// Encoding random byte data matches the libOTE implementation.
    #[cfg(all(feature = "libote-compat", target_os = "linux"))]
    #[test]
    fn test_compare_to_libote_bytes() {
        let message_size = 200;
        let exconv = ExConvCode::new(message_size);
        let code_size = exconv.conf.code_size;

        let mut rng = StdRng::seed_from_u64(2423);
        let mut data = vec![u8::ZERO; code_size];
        for _ in 0..100 {
            rng.fill_bytes(cast_slice_mut(&mut data));
            let mut expected = data.clone();

            exconv.dual_encode(&mut data);

            let mut libote_exconv = libote_codes::ExConvCode::new(
                message_size as u64,
                code_size as u64,
                exconv.conf.expander_weight as u64,
                exconv.conf.accumulator_size as u64,
            );
            libote_exconv.dual_encode_byte(&mut expected);

            assert_eq!(data, expected);
        }
    }
}