1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
use crate::simd::{avx2::rejection_sample::shuffle_table::SHUFFLE_TABLE, traits::FIELD_MODULUS};
use libcrux_intrinsics::avx2::*;
/// Splits exactly 24 bytes into eight little-endian 24-bit integers, one per
/// 32-bit lane, and clears bit 23 of each lane so that every lane holds a
/// 23-bit candidate coefficient.
///
/// Panics if `serialized.len() != 24` (via the debug assertion, or via
/// `copy_from_slice` in release builds).
#[inline(always)]
fn bytestream_to_potential_coefficients(serialized: &[u8]) -> Vec256 {
    debug_assert_eq!(serialized.len(), 24);

    // Pad the 24 input bytes out to a full 256-bit load.
    let mut padded = [0u8; 32];
    padded[..24].copy_from_slice(serialized);
    let loaded = mm256_loadu_si256_u8(&padded);

    // Route input dwords 0..=2 (bytes 0..12) into the low 128-bit half and
    // dwords 3..=5 (bytes 12..24) into the high half. Lanes 3 and 7 receive a
    // don't-care dword: the byte shuffle below never reads bytes 12..16 of a half.
    let dword_indices = mm256_set_epi32(0, 5, 4, 3, 0, 2, 1, 0);
    let halves = mm256_permutevar8x32_epi32(loaded, dword_indices);

    // Inside each 128-bit half, give every 3-byte group its own 32-bit lane;
    // an index of -1 (high bit set) makes the shuffle write a zero byte.
    let byte_indices = mm256_set_epi8(
        -1, 11, 10, 9, -1, 8, 7, 6, -1, 5, 4, 3, -1, 2, 1, 0,
        -1, 11, 10, 9, -1, 8, 7, 6, -1, 5, 4, 3, -1, 2, 1, 0,
    );
    let widened = mm256_shuffle_epi8(halves, byte_indices);

    // Keep only the low 23 bits of every lane.
    const LOW_23_BITS: i32 = (1 << 23) - 1;
    mm256_and_si256(widened, mm256_set1_epi32(LOW_23_BITS))
}
/// Rejection-samples coefficients from one 24-byte chunk of input.
///
/// `input` is split into eight 23-bit candidates by
/// `bytestream_to_potential_coefficients`. A candidate is accepted iff it is
/// strictly less than `FIELD_MODULUS`. Accepted candidates are packed, in
/// their original order, into the front of `output`.
///
/// Returns the number of accepted coefficients, which is between 0 and 8.
/// Only `output[..returned]` holds sampled values. Whole 4-lane vectors are
/// stored, so some entries of `output[returned..returned + 4]` may also be
/// overwritten with whatever the shuffle leaves in the unused lanes.
///
/// # Panics
///
/// - If `input.len() != 24`.
/// - If `output` is shorter than (accepted lower-half count + 4). That count
///   is data-dependent and can reach 4, so callers should always supply at
///   least 8 elements.
#[inline(always)]
pub(crate) fn sample(input: &[u8], output: &mut [i32]) -> usize {
    let field_modulus = mm256_set1_epi32(FIELD_MODULUS);

    // The input bytes can be interpreted as a sequence of serialized 23-bit
    // (i.e. uncompressed) coefficients. Not all of them are necessarily less
    // than FIELD_MODULUS, though.
    let potential_coefficients = bytestream_to_potential_coefficients(input);

    // Suppose we view `potential_coefficients` as eight 32-bit lanes:
    //
    //   A B C D | E F G H
    //
    // If only A, D and H are less than FIELD_MODULUS, then
    // `compare_with_field_modulus` looks like:
    //
    //   0xFF..FF 0 0 0xFF..FF | 0 0 0 0xFF..FF
    let compare_with_field_modulus = mm256_cmpgt_epi32(field_modulus, potential_coefficients);

    // Every lane is either all zeros or all ones, so one bit per lane says
    // whether to keep it. Gather those 8 bits into one integer; bit i
    // corresponds to 32-bit lane i.
    let good = mm256_movemask_ps(mm256_castsi256_ps(compare_with_field_modulus));
    let good_lower_half = good & 0x0F;
    let good_upper_half = good >> 4;

    // Each set bit marks a lane we want to keep. All kept lanes should end up
    // adjacent, starting at lane 0, so they can be stored contiguously.
    // `SHUFFLE_TABLE`, indexed by a 4-bit mask, supplies the byte-level
    // shuffle indices that do this for one 128-bit half (four 32-bit lanes).
    //
    // For example, if the lower 4 bits of `good` are 0b0010, only lane 1 is
    // kept and it must move to lane 0. Lane 1 occupies bytes 4..8, so the
    // byte-level shuffle indices need to be 4 5 6 7 X X ... .
    let lower_shuffles = SHUFFLE_TABLE[good_lower_half as usize];

    // Shuffle the lower four 32-bit lanes accordingly ...
    let lower_shuffles = mm_loadu_si128(&lower_shuffles);
    let lower_coefficients = mm256_castsi256_si128(potential_coefficients);
    let lower_coefficients = mm_shuffle_epi8(lower_coefficients, lower_shuffles);

    // ... then write them out ...
    mm_storeu_si128_i32(&mut output[0..4], lower_coefficients);

    // ... and count the set bits of `good_lower_half` so we know how many
    // were actually sampled.
    let sampled_count = good_lower_half.count_ones() as usize;

    // Do the same for the upper four lanes. They are stored directly after
    // the accepted lower-half coefficients, overwriting any unused lanes left
    // by the previous store.
    let upper_shuffles = SHUFFLE_TABLE[good_upper_half as usize];
    let upper_shuffles = mm_loadu_si128(&upper_shuffles);
    let upper_coefficients = mm256_extracti128_si256::<1>(potential_coefficients);
    let upper_coefficients = mm_shuffle_epi8(upper_coefficients, upper_shuffles);
    mm_storeu_si128_i32(
        &mut output[sampled_count..sampled_count + 4],
        upper_coefficients,
    );

    sampled_count + (good_upper_half.count_ones() as usize)
}