use std::{
arch::x86_64::{
__m128i, _mm_loadu_si128, _mm_min_epu8, _mm_mullo_epi32, _mm_shuffle_epi8, _mm_storeu_si128,
},
simd,
};
use crate::tables;
use super::Encoder;
/// Unit type selecting the SSE4.1-based implementation of `Encoder` defined in this module.
pub struct Sse41;
/// Per-byte clamp operand: `_mm_min_epu8(x, ONES)` turns every nonzero byte of `x` into 1
/// and leaves zero bytes at 0.
const ONES: [u8; 16] = [1; 16];

/// 32-bit multiplier for a lane whose bytes `b0..b3` (little-endian) are each 0 or 1.
///
/// The partial products land on disjoint bits, so the lane's top byte becomes
/// `b3 | b2 << 1 | b1 << 2`, a 3-bit index into `LANECODES`.
const SHIFT: u32 = 1 | (1 << 9) | (1 << 18);

/// `SHIFT` broadcast to all four 32-bit lanes.
const SHIFTS: [u32; 4] = [SHIFT; 4];

/// Table operand for `_mm_shuffle_epi8`, indexed by the 3-bit value produced via `SHIFT`.
///
/// Each entry is the 2-bit length code: the number of bytes the lane needs, minus one.
/// Only entries 0..=7 can be selected; the 128s are filler.
#[cfg_attr(rustfmt, rustfmt_skip)]
const LANECODES: [u8; 16] = [
    0, 3, 2, 3,
    1, 3, 2, 3,
    128, 128, 128, 128,
    128, 128, 128, 128,
];

/// `_mm_shuffle_epi8` mask that gathers the top byte of each 32-bit lane.
///
/// It reads bytes 15, 11, 7 and 3 (lane 3 first) into bytes 0..4, repeats them in bytes 4..8,
/// and zeroes bytes 8..16 (a mask byte with its high bit set yields 0).
#[cfg_attr(rustfmt, rustfmt_skip)]
const GATHER_HI: [u8; 16] = [
    15, 11, 7, 3,
    15, 11, 7, 3,
    128, 128, 128, 128,
    128, 128, 128, 128,
];

/// Multiplier for a lane holding codes `c3, c2, c1, c0` in bytes 0..4.
///
/// The product's top byte is `c0 | c1 << 2 | c2 << 4 | c3 << 6`: the packed control byte,
/// with the first number's code in the low bits.
const CONCAT: u32 = 1 | (1 << 10) | (1 << 20) | (1 << 30);

/// Multiplier whose product's top byte is `c0 + c1 + c2 + c3`, the sum of the length codes.
const SUM: u32 = 0x0101_0101;

/// Per-lane multipliers: lane 0 builds the control byte, lane 1 the length-code sum,
/// and lanes 2 and 3 are zeroed (unused).
const AGGREGATORS: [u32; 4] = [CONCAT, SUM, 0, 0];
impl Encoder for Sse41 {
    /// Encodes quads (groups of 4 numbers) from `input` using SSE4.1.
    ///
    /// Writes one control byte per quad into `control_bytes` and the packed numbers into
    /// `output`.
    ///
    /// Only the first `control_bytes.len().saturating_sub(3)` control bytes are filled; the
    /// remaining (up to) 3 are left untouched. Each quad is written with an unaligned 16-byte
    /// store even though its data occupies only `4..=16` bytes, so the store may also overwrite
    /// up to 12 bytes past the encoded data. NOTE(review): holding back 3 quads presumably
    /// keeps those extra bytes inside space the remaining quads (at least 4 bytes each) will
    /// need. Confirm with the caller.
    ///
    /// Returns `(numbers encoded, bytes of encoded data written to output)`. The byte count
    /// does not include the extra bytes written by the final store.
    ///
    /// # Panics
    ///
    /// Panics in either of these cases:
    /// - `input` holds fewer than `4 * control_bytes.len().saturating_sub(3)` numbers.
    /// - `output` does not have 16 bytes available at the start of every encoded quad.
    fn encode_quads(input: &[u32], control_bytes: &mut [u8], output: &mut [u8]) -> (usize, usize) {
        let mut nums_encoded: usize = 0;
        let mut bytes_encoded: usize = 0;

        // SAFETY: every constant below is exactly 16 bytes, and `_mm_loadu_si128` has no
        // alignment requirement, so each load stays in bounds.
        // NOTE(review): this function has no `#[target_feature]`. The intrinsics here and in
        // the loop need SSE4.1 (and SSSE3 for `_mm_shuffle_epi8`) on the running CPU, so this
        // relies on the module only being built/selected for such targets. Confirm the gating.
        let ones = unsafe { _mm_loadu_si128(ONES.as_ptr() as *const __m128i) };
        let shifts = unsafe { _mm_loadu_si128(SHIFTS.as_ptr() as *const __m128i) };
        let lanecodes = unsafe { _mm_loadu_si128(LANECODES.as_ptr() as *const __m128i) };
        let gather_hi = unsafe { _mm_loadu_si128(GATHER_HI.as_ptr() as *const __m128i) };
        let aggregators = unsafe { _mm_loadu_si128(AGGREGATORS.as_ptr() as *const __m128i) };

        let control_byte_limit = control_bytes.len().saturating_sub(3);
        for control_byte in control_bytes[..control_byte_limit].iter_mut() {
            let quad = &input[nums_encoded..nums_encoded + 4];
            // SAFETY: `quad` is exactly 4 u32s (16 bytes); unaligned load.
            let to_encode = unsafe { _mm_loadu_si128(quad.as_ptr() as *const __m128i) };

            // Each byte becomes 1 if nonzero, else 0.
            let mins = unsafe { _mm_min_epu8(to_encode, ones) };
            // Per lane, the top byte becomes `b3 | b2 << 1 | b1 << 2` (bi = "byte i nonzero").
            let bytemaps = unsafe { _mm_mullo_epi32(mins, shifts) };
            // Map that index to the lane's 2-bit length code (bytes needed - 1).
            let shuffled_lanecodes = unsafe { _mm_shuffle_epi8(lanecodes, bytemaps) };
            // Collect the four lane codes (lane 3 first) into u32 lanes 0 and 1.
            let hi_bytes = unsafe { _mm_shuffle_epi8(shuffled_lanecodes, gather_hi) };
            // Lane 0's top byte becomes the packed control byte.
            // Lane 1's top byte becomes the sum of the four codes.
            let code_and_length = unsafe { _mm_mullo_epi32(hi_bytes, aggregators) };

            let bytes = simd::u8x16::from(code_and_length);
            let code = bytes[3];
            // The sum of four 2-bit codes is at most 12, so adding 4 cannot overflow a u8.
            let length = bytes[7] + 4;

            let mask_bytes = tables::X86_ENCODE_SHUFFLE_TABLE[code as usize];
            // SAFETY: assumes each shuffle-table entry is 16 bytes long. TODO confirm in `tables`.
            let encode_mask = unsafe { _mm_loadu_si128(mask_bytes.as_ptr() as *const __m128i) };
            let encoded = unsafe { _mm_shuffle_epi8(to_encode, encode_mask) };

            // Slicing bounds-checks that 16 bytes are available. The pointer is taken from a
            // `&mut` borrow, so writing through it is permitted; a pointer derived from
            // `as_ptr()` (a shared borrow) would make this write undefined behaviour.
            let dest = &mut output[bytes_encoded..bytes_encoded + 16];
            // SAFETY: `dest` is 16 writable bytes; `_mm_storeu_si128` is unaligned.
            unsafe { _mm_storeu_si128(dest.as_mut_ptr() as *mut __m128i, encoded) };

            *control_byte = code;
            bytes_encoded += length as usize;
            nums_encoded += 4;
        }

        (nums_encoded, bytes_encoded)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    /// Runs `encode_quads` for every control-byte count from 0 to 8 and checks that it:
    /// - encodes all but the last 3 quads;
    /// - leaves the skipped control bytes untouched;
    /// - reports a byte count that agrees with `cumulative_encoded_len`;
    /// - leaves the rest of the final 16-byte store zeroed;
    /// - writes nothing beyond that store.
    #[test]
    fn encodes_all_but_last_3_control_bytes() {
        // Powers of two from 1 << 0 to 1 << 31 need 1, 2, 3 and 4 bytes in turn, so every
        // length code is exercised.
        let nums: Vec<u32> = (0..32).map(|i| 1 << i).collect();
        let mut encoded = Vec::new();

        for control_bytes_len in 0..(nums.len() / 4 + 1) {
            // Fill with a 0xFF sentinel so bytes the encoder never touched are detectable.
            encoded.clear();
            encoded.resize(nums.len() * 5, 0xFF);

            let (nums_encoded, bytes_written) = {
                let (control_bytes, num_bytes) = encoded.split_at_mut(control_bytes_len);
                Sse41::encode_quads(&nums[0..4 * control_bytes_len], control_bytes, num_bytes)
            };

            // The property this test is named after: exactly the last 3 quads are skipped.
            let control_bytes_written = control_bytes_len.saturating_sub(3);
            assert_eq!(4 * control_bytes_written, nums_encoded);

            // Skipped control bytes must keep the sentinel.
            assert!(encoded[control_bytes_written..control_bytes_len]
                .iter()
                .all(|&b| b == 0xFF));

            assert_eq!(
                cumulative_encoded_len(&encoded[0..control_bytes_written]),
                bytes_written
            );

            let length_before_final_control_byte =
                cumulative_encoded_len(&encoded[0..control_bytes_written.saturating_sub(1)]);
            let bytes_written_for_final_control_byte =
                bytes_written - length_before_final_control_byte;
            // The final 16-byte store extends past the last quad's data; that tail must be zero.
            let trailing_zero_len = if control_bytes_written > 0 {
                16 - bytes_written_for_final_control_byte
            } else {
                0
            };

            let data_end = control_bytes_len + bytes_written;
            let store_end = data_end + trailing_zero_len;
            assert!(encoded[data_end..store_end].iter().all(|&b| b == 0));
            // Nothing beyond the final store may be touched.
            assert!(encoded[store_end..].iter().all(|&b| b == 0xFF));
        }
    }
}