// stream_vbyte/encode/sse41.rs

1use std::{
2    arch::x86_64::{
3        __m128i, _mm_loadu_si128, _mm_min_epu8, _mm_mullo_epi32, _mm_shuffle_epi8, _mm_storeu_si128,
4    },
5    simd,
6};
7
8use crate::tables;
9
10use super::Encoder;
11
/// Encoder using SSE4.1 instructions.
pub struct Sse41;

// Every byte is 1; `_mm_min_epu8` against this clamps each nonzero input byte to 1.
const ONES: [u8; 16] = [1; 16];
// Multiplicand to achieve shifts by multiplication: `x * SHIFT == x + (x << 9) + (x << 18)`.
// For a clamped lane (only bits 0, 8, 16, 24 possibly set) the three terms never overlap, so
// no carries occur and the result is a pure OR of the shifted copies.
const SHIFT: u32 = 1 | (1 << 9) | (1 << 18);
const SHIFTS: [u32; 4] = [SHIFT; 4];
// Translate 3-bit bytemaps into lane codes. The last 8 entries will never be used.
// 0 = 1-byte encoded num, 1 = 2-byte, 2 = 3-byte, 3 = 4-byte.
// These are concatenated into the control byte, and also summed to find the total length.
// The ordering of these codes is determined by how the bytemap is calculated (index bit 0 is
// set when the top byte is nonzero, bit 1 for the third byte, bit 2 for the second byte).
#[rustfmt::skip]
const LANECODES: [u8; 16] = [
    0, 3, 2, 3,
    1, 3, 2, 3,
    128, 128, 128, 128,
    128, 128, 128, 128];
// Gather the high byte of each 32-bit lane (lane 3 first, lane 0 last) twice: once into
// bytes 0-3 and once into bytes 4-7. 128 (high bit set) makes `_mm_shuffle_epi8` write 0.
#[rustfmt::skip]
const GATHER_HI: [u8; 16] = [
    15, 11, 7, 3,
    15, 11, 7, 3,
    128, 128, 128, 128,
    128, 128, 128, 128];
// Mul-shift magic.
// Concatenate the four 2-bit lane codes into the high byte (lane 0 in the low bits).
const CONCAT: u32 = 1 | (1 << 10) | (1 << 20) | (1 << 30);
// Sum the four lane codes into the high byte.
const SUM: u32 = 1 | (1 << 8) | (1 << 16) | (1 << 24);
const AGGREGATORS: [u32; 4] = [CONCAT, SUM, 0, 0];
42
43impl Encoder for Sse41 {
44    fn encode_quads(input: &[u32], control_bytes: &mut [u8], output: &mut [u8]) -> (usize, usize) {
45        let mut nums_encoded: usize = 0;
46        let mut bytes_encoded: usize = 0;
47
48        // TODO these load unaligned once https://github.com/rust-lang/rust/issues/33626
49        // hits stable
50        let ones = unsafe { _mm_loadu_si128(ONES.as_ptr() as *const __m128i) };
51        let shifts = unsafe { _mm_loadu_si128(SHIFTS.as_ptr() as *const __m128i) };
52        let lanecodes = unsafe { _mm_loadu_si128(LANECODES.as_ptr() as *const __m128i) };
53        let gather_hi = unsafe { _mm_loadu_si128(GATHER_HI.as_ptr() as *const __m128i) };
54        let aggregators = unsafe { _mm_loadu_si128(AGGREGATORS.as_ptr() as *const __m128i) };
55
56        // Encoding writes 16 bytes at a time, but if numbers are encoded with 1 byte each, that
57        // means the last 3 quads could write past what is actually necessary. So, don't process
58        // the last few control bytes.
59        let control_byte_limit = control_bytes.len().saturating_sub(3);
60
61        for control_byte in &mut control_bytes[0..control_byte_limit].iter_mut() {
62            let to_encode = unsafe {
63                _mm_loadu_si128(input[nums_encoded..(nums_encoded + 4)].as_ptr() as *const __m128i)
64            };
65
66            // clamp each byte to 1 if nonzero
67            let mins = unsafe { _mm_min_epu8(to_encode, ones) };
68
69            // Apply shifts to clamped bytes. e.g. u32::max_value() would be (little endian):
70            // 00000001 00000001 00000001 00000001
71            // and after multiplication aka shifting:
72            // 00000001 00000011 00000111 00000111
73            // 1 << 16 | 1 would be:
74            // 00000001 00000000 00000001 00000000
75            // and shifted:
76            // 00000001 00000010 00000101 00000010
77            // At most the bottom 3 bits of each byte will be set by shifting.
78            // What we care about is the bottom 3 bits of the high byte in each num.
79            // A 1-byte number (clamped to 0x01000000) will accumulate to 0x00 in the top byte
80            // because there isn't a 3-byte shift to get that set bit into the top byte.
81            // A 2-byte number (clamped to 0x00010000) will accumulate to 0x04 in the top byte
82            // because the set bit would have been shifted 2 bytes + 2 bits higher.
83            // A 3-byte number will have the 0x02 bit set in the top byte, and possibly the 0x04
84            // bit set as well if the 2nd byte was non-zero.
85            // A 4-byte number will have the 0x01 bit set in the top byte, and possibly 0x02 and
86            // 0x04.
87            // In summary, byte lengths -> high byte:
88            // 1-byte -> 0x00
89            // 2-byte -> 0x04
90            // 3-byte -> 0x02, 0x06
91            // 4-byte -> 0x01, 0x05, 0x03, 0x07
92            let bytemaps = unsafe { _mm_mullo_epi32(mins, shifts) };
93
94            // Map high bytes to the corresponding lane codes. (Other bytes are mapped as well
95            // but are not used.)
96            let shuffled_lanecodes = unsafe { _mm_shuffle_epi8(lanecodes, bytemaps) };
97
98            // Assemble 2 copies of the high byte from each of the 4 numbers.
99            // The first copy will be used to calculate the control byte, the second the length.
100            let hi_bytes = unsafe { _mm_shuffle_epi8(shuffled_lanecodes, gather_hi) };
101
102            // use CONCAT to shift the lane code bits from bytes 0-3 into 1 byte (byte 3)
103            // use SUM to sum lane code bits from bytes 4-7 into 1 byte (byte 7)
104            let code_and_length = unsafe { _mm_mullo_epi32(hi_bytes, aggregators) };
105
106            let bytes = simd::u8x16::from(code_and_length);
107            let code = bytes[3];
108            let length = bytes[7] + 4;
109
110            let mask_bytes = tables::X86_ENCODE_SHUFFLE_TABLE[code as usize];
111            let encode_mask = unsafe { _mm_loadu_si128(mask_bytes.as_ptr() as *const __m128i) };
112
113            let encoded = unsafe { _mm_shuffle_epi8(to_encode, encode_mask) };
114
115            unsafe {
116                _mm_storeu_si128(
117                    output[bytes_encoded..(bytes_encoded + 16)].as_ptr() as *mut __m128i,
118                    encoded,
119                );
120            }
121
122            *control_byte = code;
123
124            bytes_encoded += length as usize;
125            nums_encoded += 4;
126        }
127
128        (nums_encoded, bytes_encoded)
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    /// Checks the following properties of `encode_quads`:
    /// - It encodes every quad except those for the last 3 control bytes.
    /// - Its reported byte count agrees with the control bytes it wrote.
    /// - The tail of its final 16-byte store contains only zeros.
    /// - Nothing beyond that store is touched.
    #[test]
    fn encodes_all_but_last_3_control_bytes() {
        // 1 << i for i in 0..32 covers the whole byte length range (1 to 4 bytes).
        let nums: Vec<u32> = (0..32).map(|i| 1 << i).collect();
        let mut encoded = Vec::new();

        for control_bytes_len in 0..=(nums.len() / 4) {
            encoded.clear();
            encoded.resize(nums.len() * 5, 0xFF);

            let (nums_encoded, bytes_written) = {
                let (control_bytes, num_bytes) = encoded.split_at_mut(control_bytes_len);

                Sse41::encode_quads(&nums[0..4 * control_bytes_len], control_bytes, num_bytes)
            };

            // The property this test is named after: the last 3 control bytes are skipped.
            assert_eq!(nums_encoded, 4 * control_bytes_len.saturating_sub(3));

            let control_bytes_written = nums_encoded / 4;

            assert_eq!(
                cumulative_encoded_len(&encoded[0..control_bytes_written]),
                bytes_written
            );

            // The last control byte written may not have populated all 16 output bytes with
            // encoded nums, depending on the length required. Any unused trailing bytes will
            // have had 0 written, but nothing beyond that 16 should be touched.
            let length_before_final_control_byte =
                cumulative_encoded_len(&encoded[0..control_bytes_written.saturating_sub(1)]);

            let bytes_written_for_final_control_byte =
                bytes_written - length_before_final_control_byte;
            let trailing_zero_len = if control_bytes_written > 0 {
                16 - bytes_written_for_final_control_byte
            } else {
                0
            };

            let zeros_start = control_bytes_len + bytes_written;
            let zeros_end = zeros_start + trailing_zero_len;
            assert!(encoded[zeros_start..zeros_end].iter().all(|&b| b == 0));
            assert!(encoded[zeros_end..].iter().all(|&b| b == 0xFF));
        }
    }
}