stream_vbyte/encode/
sse41.rs1use std::{
2 arch::x86_64::{
3 __m128i, _mm_loadu_si128, _mm_min_epu8, _mm_mullo_epi32, _mm_shuffle_epi8, _mm_storeu_si128,
4 },
5 simd,
6};
7
8use crate::tables;
9
10use super::Encoder;
11
/// Stream VByte encoder backed by SSSE3/SSE4.1 intrinsics.
///
/// A stateless unit struct: it only selects the [`Encoder`] implementation below, so the common
/// traits are derived for free to make it easy to store, copy and debug-print.
#[derive(Debug, Clone, Copy, Default)]
pub struct Sse41;

/// All-ones bytes: `_mm_min_epu8(x, ONES)` turns every byte of `x` into 1 if it is non-zero and
/// 0 otherwise.
const ONES: [u8; 16] = [1; 16];

/// Multiplier `1 + 2^9 + 2^18`.
///
/// For a 32-bit lane whose bytes `b0..b3` are each 0 or 1, the product's high byte is
/// `b3 | b2 << 1 | b1 << 2`. All partial products land on distinct bits, so nothing carries, and
/// every byte of the product is at most 7.
const SHIFT: u32 = 1 | 1 << 9 | 1 << 18;

/// [`SHIFT`] broadcast to all four 32-bit lanes.
const SHIFTS: [u32; 4] = [SHIFT; 4];

/// `_mm_shuffle_epi8` lookup table indexed by the bytes of the `SHIFT` product.
///
/// For the high byte `b3 | b2 << 1 | b1 << 2`, it yields the index of the lane's highest non-zero
/// byte, i.e. the number of bytes needed to store the lane minus one (0..=3). Product bytes never
/// exceed 7, so entries 8..=15 are never read; they hold 128 as filler.
#[rustfmt::skip]
const LANECODES: [u8; 16] = [
    0, 3, 2, 3,
    1, 3, 2, 3,
    128, 128, 128, 128,
    128, 128, 128, 128];

/// `_mm_shuffle_epi8` mask copying the high byte of lanes 3, 2, 1, 0 (bytes 15, 11, 7, 3) into
/// bytes 0..=3 and again into bytes 4..=7; the 128 entries zero everything else.
///
/// Read as little-endian `u32`s, lanes 0 and 1 then both hold `c3 | c2 << 8 | c1 << 16 | c0 << 24`,
/// where `cN` is the code of input lane `N`.
#[rustfmt::skip]
const GATHER_HI: [u8; 16] = [
    15, 11, 7, 3,
    15, 11, 7, 3,
    128, 128, 128, 128,
    128, 128, 128, 128];

/// Multiplier `1 + 2^10 + 2^20 + 2^30`: applied to `c3 | c2 << 8 | c1 << 16 | c0 << 24`, its high
/// byte becomes the control byte `c0 | c1 << 2 | c2 << 4 | c3 << 6` (no carries reach it).
const CONCAT: u32 = 1 | 1 << 10 | 1 << 20 | 1 << 30;

/// Multiplier `1 + 2^8 + 2^16 + 2^24`: its high byte becomes `c0 + c1 + c2 + c3` (at most 12, so
/// no byte overflows).
const SUM: u32 = 1 | 1 << 8 | 1 << 16 | 1 << 24;

/// Per-lane multipliers for the gathered codes: lane 0 builds the control byte, lane 1 the sum of
/// the codes, lanes 2 and 3 are zeroed.
const AGGREGATORS: [u32; 4] = [CONCAT, SUM, 0, 0];

43impl Encoder for Sse41 {
44 fn encode_quads(input: &[u32], control_bytes: &mut [u8], output: &mut [u8]) -> (usize, usize) {
45 let mut nums_encoded: usize = 0;
46 let mut bytes_encoded: usize = 0;
47
48 let ones = unsafe { _mm_loadu_si128(ONES.as_ptr() as *const __m128i) };
51 let shifts = unsafe { _mm_loadu_si128(SHIFTS.as_ptr() as *const __m128i) };
52 let lanecodes = unsafe { _mm_loadu_si128(LANECODES.as_ptr() as *const __m128i) };
53 let gather_hi = unsafe { _mm_loadu_si128(GATHER_HI.as_ptr() as *const __m128i) };
54 let aggregators = unsafe { _mm_loadu_si128(AGGREGATORS.as_ptr() as *const __m128i) };
55
56 let control_byte_limit = control_bytes.len().saturating_sub(3);
60
61 for control_byte in &mut control_bytes[0..control_byte_limit].iter_mut() {
62 let to_encode = unsafe {
63 _mm_loadu_si128(input[nums_encoded..(nums_encoded + 4)].as_ptr() as *const __m128i)
64 };
65
66 let mins = unsafe { _mm_min_epu8(to_encode, ones) };
68
69 let bytemaps = unsafe { _mm_mullo_epi32(mins, shifts) };
93
94 let shuffled_lanecodes = unsafe { _mm_shuffle_epi8(lanecodes, bytemaps) };
97
98 let hi_bytes = unsafe { _mm_shuffle_epi8(shuffled_lanecodes, gather_hi) };
101
102 let code_and_length = unsafe { _mm_mullo_epi32(hi_bytes, aggregators) };
105
106 let bytes = simd::u8x16::from(code_and_length);
107 let code = bytes[3];
108 let length = bytes[7] + 4;
109
110 let mask_bytes = tables::X86_ENCODE_SHUFFLE_TABLE[code as usize];
111 let encode_mask = unsafe { _mm_loadu_si128(mask_bytes.as_ptr() as *const __m128i) };
112
113 let encoded = unsafe { _mm_shuffle_epi8(to_encode, encode_mask) };
114
115 unsafe {
116 _mm_storeu_si128(
117 output[bytes_encoded..(bytes_encoded + 16)].as_ptr() as *mut __m128i,
118 encoded,
119 );
120 }
121
122 *control_byte = code;
123
124 bytes_encoded += length as usize;
125 nums_encoded += 4;
126 }
127
128 (nums_encoded, bytes_encoded)
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    /// Checks that `encode_quads` fills all but the last three control bytes, that the reported
    /// byte count matches the control bytes, that the tail of the final 16-byte store is zeroed,
    /// and that nothing beyond that store is touched.
    #[test]
    fn encodes_all_but_last_3_control_bytes() {
        // One set bit per value, so the quads cover every encoded width from 1 to 4 bytes.
        let nums: Vec<u32> = (0..32).map(|i| 1 << i).collect();
        let mut encoded = Vec::new();

        for control_bytes_len in 0..=(nums.len() / 4) {
            // Fill with a sentinel so bytes the encoder never wrote can be detected.
            encoded.clear();
            encoded.resize(nums.len() * 5, 0xFF);

            let (nums_encoded, bytes_written) = {
                let (control_bytes, num_bytes) = encoded.split_at_mut(control_bytes_len);

                Sse41::encode_quads(&nums[..4 * control_bytes_len], control_bytes, num_bytes)
            };

            let control_bytes_written = nums_encoded / 4;

            // The behaviour the test is named after: the last three quads are left unencoded.
            assert_eq!(control_bytes_written, control_bytes_len.saturating_sub(3));

            assert_eq!(
                cumulative_encoded_len(&encoded[..control_bytes_written]),
                bytes_written
            );

            // Data length of every quad except the last encoded one.
            let length_before_final_control_byte =
                cumulative_encoded_len(&encoded[..control_bytes_written.saturating_sub(1)]);

            // The last quad was written with a 16-byte store; whatever it did not use is padding.
            let bytes_written_for_final_control_byte =
                bytes_written - length_before_final_control_byte;
            let trailing_zero_len = if control_bytes_written > 0 {
                16 - bytes_written_for_final_control_byte
            } else {
                0
            };

            let data_end = control_bytes_len + bytes_written;
            let padding_end = data_end + trailing_zero_len;

            assert!(encoded[data_end..padding_end].iter().all(|&b| b == 0));
            assert!(encoded[padding_end..].iter().all(|&b| b == 0xFF));
        }
    }
}