use crate::error::{Result, ZiporaError};
use crate::entropy::huffman::{HuffmanTree, HuffmanEncoder};
use crate::memory::simd_ops::SimdMemOps;
use crate::succinct::rank_select::bmi2_acceleration::Bmi2Capabilities;
use crate::system::cpu_features::{CpuFeatures, get_cpu_features};
/// Size, in bytes, associated with "small" inputs.
/// NOTE(review): no dispatch code visible in this file reads this constant.
const SMALL_BATCH_THRESHOLD: usize = 64;
/// Minimum input length, in bytes, that `encode` routes to the "medium" AVX2 entry points.
const MEDIUM_BATCH_THRESHOLD: usize = 1_024;
/// Minimum input length, in bytes, that `encode` routes to the "large" AVX2 entry points.
const LARGE_BATCH_THRESHOLD: usize = 8_192;
/// Instruction-set tier used to choose a Huffman encoding routine.
///
/// Variants are declared from the most demanding feature combination down to
/// the portable fallback.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HuffmanSimdTier {
    /// AVX2 combined with BMI2.
    Avx2Bmi2,
    /// AVX2 without BMI2.
    Avx2,
    /// SSE4.2 combined with BMI2.
    Sse42Bmi2,
    /// SSE4.2 without BMI2.
    Sse42,
    /// BMI2 only.
    Bmi2,
    /// No SIMD/BMI2 requirements; the base encoder is used.
    Scalar,
}
/// Tunables for [`SimdHuffmanEncoder`].
#[derive(Debug, Clone)]
pub struct SimdHuffmanConfig {
    /// Most capable tier the encoder is allowed to select.
    pub preferred_tier: HuffmanSimdTier,
    /// Batch-processing switch.
    /// NOTE(review): not read by any code visible in this file.
    pub enable_batch_processing: bool,
    /// Chunk length, in bytes, used by the AVX2-only encode routine.
    pub batch_size: usize,
    /// When `true`, the AVX2+BMI2 routine issues a prefetch hint per 32-byte chunk.
    pub enable_prefetching: bool,
    /// Cache-alignment switch.
    /// NOTE(review): not read by any code visible in this file.
    pub cache_aligned_buffers: bool,
}
impl Default for SimdHuffmanConfig {
fn default() -> Self {
Self {
preferred_tier: HuffmanSimdTier::Avx2Bmi2,
enable_batch_processing: true,
batch_size: 256,
enable_prefetching: true,
cache_aligned_buffers: true,
}
}
}
/// Huffman encoder that dispatches to a SIMD/BMI2 routine chosen at construction time.
pub struct SimdHuffmanEncoder {
    /// Underlying encoder; supplies the Huffman tree and the scalar encode path.
    base_encoder: HuffmanEncoder,
    /// Tier picked by `select_optimal_tier` when the encoder was built.
    tier: HuffmanSimdTier,
    /// BMI2 capability record obtained from `Bmi2Capabilities::get()`.
    bmi2_caps: &'static Bmi2Capabilities,
    /// CPU feature record obtained from `get_cpu_features()`.
    cpu_features: &'static CpuFeatures,
    /// Configuration supplied by the caller (or the default one).
    config: SimdHuffmanConfig,
    /// SIMD memory helper created with `SimdMemOps::new()`.
    /// NOTE(review): not used by any method visible in this file.
    simd_ops: SimdMemOps,
}
impl SimdHuffmanEncoder {
/// Creates an encoder for `data` using `SimdHuffmanConfig::default()`.
///
/// # Errors
/// Returns whatever error `with_config` reports.
pub fn new(data: &[u8]) -> Result<Self> {
    let config = SimdHuffmanConfig::default();
    Self::with_config(data, config)
}
/// Creates an encoder for `data` with an explicit configuration.
///
/// The SIMD tier is fixed here, from `config.preferred_tier` and the CPU
/// capability records, and does not change afterwards.
///
/// # Errors
/// Propagates the error from `HuffmanEncoder::new(data)`.
pub fn with_config(data: &[u8], config: SimdHuffmanConfig) -> Result<Self> {
    // The base encoder is the only fallible step, so it runs first.
    let base_encoder = HuffmanEncoder::new(data)?;
    let cpu_features = get_cpu_features();
    let bmi2_caps = Bmi2Capabilities::get();

    Ok(Self {
        tier: Self::select_optimal_tier(&config, cpu_features, bmi2_caps),
        simd_ops: SimdMemOps::new(),
        base_encoder,
        bmi2_caps,
        cpu_features,
        config,
    })
}
/// Picks the best tier that is both permitted by `config.preferred_tier` and
/// supported by the CPU.
///
/// The preferred tier acts as an upper bound. When it is unavailable, the
/// selection walks down through weaker tiers that only use instruction sets
/// the preferred tier already implies, so a preference for `Avx2` or `Sse42`
/// never silently enables BMI2. `Scalar` is returned when nothing else fits.
///
/// Fallback chains:
/// - `Avx2Bmi2` -> `Avx2` -> `Sse42Bmi2` -> `Sse42` -> `Bmi2` -> `Scalar`
/// - `Avx2` -> `Sse42` -> `Scalar`
/// - `Sse42Bmi2` -> `Sse42` -> `Bmi2` -> `Scalar`
/// - `Sse42` -> `Scalar`
/// - `Bmi2` -> `Scalar`
fn select_optimal_tier(
    config: &SimdHuffmanConfig,
    cpu_features: &CpuFeatures,
    bmi2_caps: &Bmi2Capabilities,
) -> HuffmanSimdTier {
    let avx2 = cpu_features.has_avx2;
    let sse42 = cpu_features.has_sse42;
    let bmi2 = bmi2_caps.has_bmi2;

    // Candidates in descending order of preference; `Scalar` is the implicit tail.
    let candidates: &[HuffmanSimdTier] = match config.preferred_tier {
        HuffmanSimdTier::Avx2Bmi2 => &[
            HuffmanSimdTier::Avx2Bmi2,
            HuffmanSimdTier::Avx2,
            HuffmanSimdTier::Sse42Bmi2,
            HuffmanSimdTier::Sse42,
            HuffmanSimdTier::Bmi2,
        ],
        HuffmanSimdTier::Avx2 => &[HuffmanSimdTier::Avx2, HuffmanSimdTier::Sse42],
        HuffmanSimdTier::Sse42Bmi2 => &[
            HuffmanSimdTier::Sse42Bmi2,
            HuffmanSimdTier::Sse42,
            HuffmanSimdTier::Bmi2,
        ],
        HuffmanSimdTier::Sse42 => &[HuffmanSimdTier::Sse42],
        HuffmanSimdTier::Bmi2 => &[HuffmanSimdTier::Bmi2],
        HuffmanSimdTier::Scalar => &[],
    };

    candidates
        .iter()
        .copied()
        .find(|tier| match tier {
            HuffmanSimdTier::Avx2Bmi2 => avx2 && bmi2,
            HuffmanSimdTier::Avx2 => avx2,
            HuffmanSimdTier::Sse42Bmi2 => sse42 && bmi2,
            HuffmanSimdTier::Sse42 => sse42,
            HuffmanSimdTier::Bmi2 => bmi2,
            HuffmanSimdTier::Scalar => true,
        })
        .unwrap_or(HuffmanSimdTier::Scalar)
}
/// Encodes `data` with the routine matching the tier chosen at construction.
///
/// Empty input yields an empty vector without touching any encoder. For the
/// two AVX2 tiers the entry point additionally depends on the input length
/// (`LARGE_BATCH_THRESHOLD` / `MEDIUM_BATCH_THRESHOLD`); the remaining tiers
/// have a single entry point each.
///
/// # Errors
/// Propagates the error of whichever encode routine is selected.
pub fn encode(&self, data: &[u8]) -> Result<Vec<u8>> {
    if data.is_empty() {
        return Ok(Vec::new());
    }

    let len = data.len();
    match self.tier {
        HuffmanSimdTier::Avx2Bmi2 => {
            if len >= LARGE_BATCH_THRESHOLD {
                self.encode_avx2_bmi2_large(data)
            } else if len >= MEDIUM_BATCH_THRESHOLD {
                self.encode_avx2_bmi2_medium(data)
            } else {
                self.encode_avx2_bmi2_small(data)
            }
        }
        HuffmanSimdTier::Avx2 => {
            if len >= LARGE_BATCH_THRESHOLD {
                self.encode_avx2_large(data)
            } else if len >= MEDIUM_BATCH_THRESHOLD {
                self.encode_avx2_medium(data)
            } else {
                self.encode_avx2_small(data)
            }
        }
        HuffmanSimdTier::Sse42Bmi2 => self.encode_sse42_bmi2(data),
        HuffmanSimdTier::Sse42 => self.encode_sse42(data),
        HuffmanSimdTier::Bmi2 => self.encode_bmi2(data),
        HuffmanSimdTier::Scalar => self.base_encoder.encode(data),
    }
}
/// AVX2+BMI2 entry point used by `encode` for inputs of at least
/// `LARGE_BATCH_THRESHOLD` bytes.
///
/// Runs `encode_avx2_bmi2_impl` when the CPU reports both AVX2 and BMI2 at
/// runtime; otherwise (and on non-x86_64 builds) defers to `encode_avx2_large`.
fn encode_avx2_bmi2_large(&self, data: &[u8]) -> Result<Vec<u8>> {
    #[cfg(target_arch = "x86_64")]
    {
        let cpu_ok = is_x86_feature_detected!("avx2") && is_x86_feature_detected!("bmi2");
        if cpu_ok {
            // SAFETY: both target features the callee is compiled with were
            // detected on the running CPU just above.
            return unsafe { self.encode_avx2_bmi2_impl(data) };
        }
    }
    self.encode_avx2_large(data)
}
/// AVX2+BMI2 entry point used by `encode` for inputs of at least
/// `MEDIUM_BATCH_THRESHOLD` bytes that are below the large threshold.
///
/// Runs `encode_avx2_bmi2_impl` when the CPU reports both AVX2 and BMI2 at
/// runtime; otherwise (and on non-x86_64 builds) defers to `encode_avx2_medium`.
fn encode_avx2_bmi2_medium(&self, data: &[u8]) -> Result<Vec<u8>> {
    #[cfg(target_arch = "x86_64")]
    {
        let cpu_ok = is_x86_feature_detected!("avx2") && is_x86_feature_detected!("bmi2");
        if cpu_ok {
            // SAFETY: both target features the callee is compiled with were
            // detected on the running CPU just above.
            return unsafe { self.encode_avx2_bmi2_impl(data) };
        }
    }
    self.encode_avx2_medium(data)
}
/// AVX2+BMI2 entry point used by `encode` for inputs shorter than
/// `MEDIUM_BATCH_THRESHOLD` bytes.
///
/// Runs `encode_avx2_bmi2_impl` when the CPU reports both AVX2 and BMI2 at
/// runtime; otherwise (and on non-x86_64 builds) defers to `encode_avx2_small`.
fn encode_avx2_bmi2_small(&self, data: &[u8]) -> Result<Vec<u8>> {
    #[cfg(target_arch = "x86_64")]
    {
        let cpu_ok = is_x86_feature_detected!("avx2") && is_x86_feature_detected!("bmi2");
        if cpu_ok {
            // SAFETY: both target features the callee is compiled with were
            // detected on the running CPU just above.
            return unsafe { self.encode_avx2_bmi2_impl(data) };
        }
    }
    self.encode_avx2_small(data)
}
/// AVX2 entry point for large inputs.
///
/// Runs `encode_avx2_impl` when AVX2 is detected at runtime; otherwise (and on
/// non-x86_64 builds) defers to `encode_sse42`.
fn encode_avx2_large(&self, data: &[u8]) -> Result<Vec<u8>> {
    #[cfg(target_arch = "x86_64")]
    {
        let cpu_ok = is_x86_feature_detected!("avx2");
        if cpu_ok {
            // SAFETY: AVX2, the only target feature the callee is compiled
            // with, was detected on the running CPU just above.
            return unsafe { self.encode_avx2_impl(data) };
        }
    }
    self.encode_sse42(data)
}
/// AVX2 entry point for medium-sized inputs.
///
/// Runs `encode_avx2_impl` when AVX2 is detected at runtime; otherwise (and on
/// non-x86_64 builds) defers to `encode_sse42`.
fn encode_avx2_medium(&self, data: &[u8]) -> Result<Vec<u8>> {
    #[cfg(target_arch = "x86_64")]
    {
        let cpu_ok = is_x86_feature_detected!("avx2");
        if cpu_ok {
            // SAFETY: AVX2, the only target feature the callee is compiled
            // with, was detected on the running CPU just above.
            return unsafe { self.encode_avx2_impl(data) };
        }
    }
    self.encode_sse42(data)
}
/// AVX2 entry point for small inputs.
///
/// Runs `encode_avx2_impl` when AVX2 is detected at runtime; otherwise (and on
/// non-x86_64 builds) defers to `encode_sse42`.
fn encode_avx2_small(&self, data: &[u8]) -> Result<Vec<u8>> {
    #[cfg(target_arch = "x86_64")]
    {
        let cpu_ok = is_x86_feature_detected!("avx2");
        if cpu_ok {
            // SAFETY: AVX2, the only target feature the callee is compiled
            // with, was detected on the running CPU just above.
            return unsafe { self.encode_avx2_impl(data) };
        }
    }
    self.encode_sse42(data)
}
/// SSE4.2+BMI2 entry point.
///
/// Runs `encode_sse42_bmi2_impl` when both SSE4.2 and BMI2 are detected at
/// runtime; otherwise (and on non-x86_64 builds) defers to `encode_sse42`.
fn encode_sse42_bmi2(&self, data: &[u8]) -> Result<Vec<u8>> {
    #[cfg(target_arch = "x86_64")]
    {
        let cpu_ok = is_x86_feature_detected!("sse4.2") && is_x86_feature_detected!("bmi2");
        if cpu_ok {
            // SAFETY: both target features the callee is compiled with were
            // detected on the running CPU just above.
            return unsafe { self.encode_sse42_bmi2_impl(data) };
        }
    }
    self.encode_sse42(data)
}
/// SSE4.2 entry point.
///
/// Runs `encode_sse42_impl` when SSE4.2 is detected at runtime; otherwise (and
/// on non-x86_64 builds) defers to `encode_bmi2`.
fn encode_sse42(&self, data: &[u8]) -> Result<Vec<u8>> {
    #[cfg(target_arch = "x86_64")]
    {
        let cpu_ok = is_x86_feature_detected!("sse4.2");
        if cpu_ok {
            // SAFETY: SSE4.2, the only target feature the callee is compiled
            // with, was detected on the running CPU just above.
            return unsafe { self.encode_sse42_impl(data) };
        }
    }
    self.encode_bmi2(data)
}
/// BMI2 entry point and last stop of the fallback chain.
///
/// Runs `encode_bmi2_impl` when BMI2 is detected at runtime; otherwise (and on
/// non-x86_64 builds) uses the base encoder directly.
fn encode_bmi2(&self, data: &[u8]) -> Result<Vec<u8>> {
    #[cfg(target_arch = "x86_64")]
    {
        let cpu_ok = is_x86_feature_detected!("bmi2");
        if cpu_ok {
            // SAFETY: BMI2, the only target feature the callee is compiled
            // with, was detected on the running CPU just above.
            return unsafe { self.encode_bmi2_impl(data) };
        }
    }
    self.base_encoder.encode(data)
}
/// Table-driven encoder for the AVX2+BMI2 tier.
///
/// A 256-entry table of packed codes is built once from the Huffman tree, then
/// every input byte is translated with a table lookup and appended to a
/// [`BitBuffer`]. Bit `i` of a packed word holds `code[i]`, and the buffer
/// emits the low bits first.
///
/// Codes of up to 64 bits are supported, which is the most
/// `BitBuffer::append_bits` accepts per call. Longer codes are reported with
/// their own error instead of being mistaken for missing symbols.
///
/// # Errors
/// - a byte of `data` has no code in the tree;
/// - a byte of `data` has a code longer than 64 bits;
/// - `BitBuffer::append_bits` fails.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2 and BMI2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,bmi2")]
unsafe fn encode_avx2_bmi2_impl(&self, data: &[u8]) -> Result<Vec<u8>> {
    use std::arch::x86_64::*;

    // Length marker for symbols whose code does not fit into 64 bits.
    // Real lengths never exceed 64, so the marker cannot collide with one.
    const CODE_TOO_LONG: u8 = u8::MAX;

    let mut symbol_codes = [0u64; 256];
    // A length of 0 means "symbol has no code".
    let mut symbol_lengths = [0u8; 256];
    for symbol in 0u8..=255 {
        let Some(code) = self.base_encoder.tree().get_code(symbol) else {
            continue;
        };
        let slot = usize::from(symbol);
        if code.len() > 64 {
            symbol_lengths[slot] = CODE_TOO_LONG;
            continue;
        }
        let packed = code
            .iter()
            .enumerate()
            .fold(0u64, |word, (i, &bit)| word | (u64::from(bit) << i));
        let length = code.len() as u8;
        // Mask once per symbol here rather than once per input byte below.
        // SAFETY: BMI2 is enabled for this function via `#[target_feature]`.
        symbol_codes[slot] = unsafe { _bzhi_u64(packed, u32::from(length)) };
        symbol_lengths[slot] = length;
    }

    let mut bit_buffer = BitBuffer::with_capacity(data.len() * 8);
    for chunk in data.chunks(32) {
        // Only full 32-byte chunks trigger a prefetch of the bytes that follow.
        if self.config.enable_prefetching && chunk.len() == 32 {
            // SAFETY: prefetching is only a hint and never dereferences the
            // address; `wrapping_add` keeps the pointer arithmetic itself safe.
            unsafe {
                _mm_prefetch(chunk.as_ptr().wrapping_add(32) as *const i8, _MM_HINT_T0);
            }
        }
        for &symbol in chunk {
            let slot = usize::from(symbol);
            match symbol_lengths[slot] {
                0 => {
                    return Err(ZiporaError::invalid_data(format!(
                        "Symbol {} not in Huffman tree",
                        symbol
                    )));
                }
                CODE_TOO_LONG => {
                    return Err(ZiporaError::invalid_data(format!(
                        "Huffman code for symbol {} exceeds 64 bits",
                        symbol
                    )));
                }
                length => bit_buffer.append_bits(symbol_codes[slot], length)?,
            }
        }
    }
    Ok(bit_buffer.into_bytes())
}
/// Table-driven encoder for the AVX2-only tier.
///
/// A 256-entry table of packed codes is built once from the Huffman tree, then
/// every input byte is translated with a table lookup and appended to a
/// [`BitBuffer`]. Bit `i` of a packed word holds `code[i]`, and the buffer
/// emits the low bits first.
///
/// The input is walked in chunks of `config.batch_size` bytes. A batch size of
/// 0 is treated as 1, because `slice::chunks` panics on a zero chunk size.
/// Codes of up to 64 bits are supported; longer ones get their own error.
///
/// # Errors
/// - a byte of `data` has no code in the tree;
/// - a byte of `data` has a code longer than 64 bits;
/// - `BitBuffer::append_bits` fails.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn encode_avx2_impl(&self, data: &[u8]) -> Result<Vec<u8>> {
    // Length marker for symbols whose code does not fit into 64 bits.
    // Real lengths never exceed 64, so the marker cannot collide with one.
    const CODE_TOO_LONG: u8 = u8::MAX;

    let mut symbol_codes = [0u64; 256];
    // A length of 0 means "symbol has no code".
    let mut symbol_lengths = [0u8; 256];
    for symbol in 0u8..=255 {
        let Some(code) = self.base_encoder.tree().get_code(symbol) else {
            continue;
        };
        let slot = usize::from(symbol);
        if code.len() > 64 {
            symbol_lengths[slot] = CODE_TOO_LONG;
            continue;
        }
        symbol_codes[slot] = code
            .iter()
            .enumerate()
            .fold(0u64, |word, (i, &bit)| word | (u64::from(bit) << i));
        symbol_lengths[slot] = code.len() as u8;
    }

    let mut bit_buffer = BitBuffer::with_capacity(data.len() * 8);
    // Guard against a zero batch size, which would make `chunks` panic.
    let batch_size = self.config.batch_size.max(1);
    for chunk in data.chunks(batch_size) {
        for &symbol in chunk {
            let slot = usize::from(symbol);
            match symbol_lengths[slot] {
                0 => {
                    return Err(ZiporaError::invalid_data(format!(
                        "Symbol {} not in Huffman tree",
                        symbol
                    )));
                }
                CODE_TOO_LONG => {
                    return Err(ZiporaError::invalid_data(format!(
                        "Huffman code for symbol {} exceeds 64 bits",
                        symbol
                    )));
                }
                length => bit_buffer.append_bits(symbol_codes[slot], length)?,
            }
        }
    }
    Ok(bit_buffer.into_bytes())
}
/// SSE4.2+BMI2 routine; currently a thin wrapper around `encode_bmi2_impl`.
///
/// # Safety
/// The caller must ensure the running CPU supports SSE4.2 and BMI2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.2,bmi2")]
unsafe fn encode_sse42_bmi2_impl(&self, data: &[u8]) -> Result<Vec<u8>> {
    // SAFETY: the callee only needs BMI2, which this function's own contract
    // already requires from the caller.
    unsafe { self.encode_bmi2_impl(data) }
}
/// SSE4.2 routine; currently delegates straight to the base encoder.
///
/// # Safety
/// The caller must ensure the running CPU supports SSE4.2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.2")]
unsafe fn encode_sse42_impl(&self, data: &[u8]) -> Result<Vec<u8>> {
    let encoder = &self.base_encoder;
    encoder.encode(data)
}
/// BMI2 routine: looks up each input byte's code in the tree, packs it into a
/// `u64` (bit `i` holds `code[i]`), masks it with `_bzhi_u64` and appends it
/// to a [`BitBuffer`].
///
/// # Errors
/// - a byte of `data` has no code in the tree;
/// - a code is longer than 64 bits;
/// - `BitBuffer::append_bits` fails.
///
/// # Safety
/// The caller must ensure the running CPU supports BMI2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "bmi2")]
unsafe fn encode_bmi2_impl(&self, data: &[u8]) -> Result<Vec<u8>> {
    use std::arch::x86_64::*;

    let mut bit_buffer = BitBuffer::with_capacity(data.len() * 8);
    for &symbol in data {
        let Some(code) = self.base_encoder.tree().get_code(symbol) else {
            return Err(ZiporaError::invalid_data(format!(
                "Symbol {} not in Huffman tree",
                symbol
            )));
        };
        if code.len() > 64 {
            return Err(ZiporaError::invalid_data(
                "Code too long for BMI2 optimization"
            ));
        }

        let packed_code = code
            .iter()
            .enumerate()
            .fold(0u64, |word, (i, &bit)| if bit { word | (1u64 << i) } else { word });
        let length = code.len() as u8;
        // SAFETY: BMI2 is enabled for this function via `#[target_feature]`.
        let final_code = unsafe { _bzhi_u64(packed_code, u32::from(length)) };
        bit_buffer.append_bits(final_code, length)?;
    }
    Ok(bit_buffer.into_bytes())
}
/// Returns the tier that was selected when this encoder was constructed.
pub fn tier(&self) -> HuffmanSimdTier {
    let selected = self.tier;
    selected
}
/// Borrows the Huffman tree owned by the base encoder.
pub fn tree(&self) -> &HuffmanTree {
    let encoder = &self.base_encoder;
    encoder.tree()
}
/// Returns the base encoder's compression-ratio estimate for `data`.
pub fn estimate_compression_ratio(&self, data: &[u8]) -> f64 {
    let encoder = &self.base_encoder;
    encoder.estimate_compression_ratio(data)
}
}
/// Accumulates a bit stream into bytes, filling each byte from its lowest bit upwards.
struct BitBuffer {
    /// Bytes that are already complete.
    bytes: Vec<u8>,
    /// Byte currently being filled.
    current_byte: u8,
    /// Number of bits already stored in `current_byte` (0..=7 between calls).
    bit_count: usize,
}
impl BitBuffer {
fn with_capacity(estimated_bits: usize) -> Self {
let estimated_bytes = (estimated_bits + 7) / 8;
Self {
bytes: Vec::with_capacity(estimated_bytes),
current_byte: 0,
bit_count: 0,
}
}
fn append_bits(&mut self, bits: u64, length: u8) -> Result<()> {
if length > 64 {
return Err(ZiporaError::invalid_data("Too many bits"));
}
let mut remaining_bits = length as usize;
let mut current_bits = bits;
while remaining_bits > 0 {
let bits_to_add = std::cmp::min(remaining_bits, 8 - self.bit_count);
let mask = (1u64 << bits_to_add) - 1;
let bits_chunk = (current_bits & mask) as u8;
self.current_byte |= bits_chunk << self.bit_count;
self.bit_count += bits_to_add;
remaining_bits -= bits_to_add;
current_bits >>= bits_to_add;
if self.bit_count == 8 {
self.bytes.push(self.current_byte);
self.current_byte = 0;
self.bit_count = 0;
}
}
Ok(())
}
fn into_bytes(mut self) -> Vec<u8> {
if self.bit_count > 0 {
self.bytes.push(self.current_byte);
}
self.bytes
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding a short, repetitive input with the default configuration yields output.
    #[test]
    fn test_simd_huffman_basic() -> Result<()> {
        let data = b"aaaaaabbbbbbccccccddddddeeeeeeffffffgggggg";
        let encoder = SimdHuffmanEncoder::new(data)?;

        let encoded = encoder.encode(data)?;
        assert!(!encoded.is_empty());

        println!("Encoded {} bytes to {} bytes (tier: {:?})",
            data.len(), encoded.len(), encoder.tier());
        Ok(())
    }

    /// Every requested tier must produce output, whatever tier ends up selected.
    #[test]
    fn test_simd_huffman_tiers() -> Result<()> {
        let data = b"test data for simd tier testing";
        let preferred_tiers = [
            HuffmanSimdTier::Avx2Bmi2,
            HuffmanSimdTier::Avx2,
            HuffmanSimdTier::Sse42,
            HuffmanSimdTier::Scalar,
        ];

        for preferred_tier in preferred_tiers {
            let config = SimdHuffmanConfig {
                preferred_tier,
                ..Default::default()
            };
            let encoder = SimdHuffmanEncoder::with_config(data, config)?;
            let encoded = encoder.encode(data)?;
            assert!(!encoded.is_empty());
            println!("Tier {:?}: {} bytes -> {} bytes",
                encoder.tier(), data.len(), encoded.len());
        }
        Ok(())
    }

    /// Eight bits appended in four pieces land in one byte, lowest bits first.
    #[test]
    fn test_bit_buffer() -> Result<()> {
        let mut buffer = BitBuffer::with_capacity(100);
        for (bits, length) in [(0b1010u64, 4u8), (0b11, 2), (0b0, 1), (0b1, 1)] {
            buffer.append_bits(bits, length)?;
        }

        let bytes = buffer.into_bytes();
        assert_eq!(bytes.len(), 1);
        println!("Actual byte: {} = 0b{:08b}", bytes[0], bytes[0]);
        assert_eq!(bytes[0], 186);
        Ok(())
    }

    /// An input well above the large-batch threshold still encodes to non-empty output.
    #[test]
    fn test_large_data_encoding() -> Result<()> {
        let text = "This is a test message for large data encoding with SIMD Huffman compression. It has sufficient data volume to trigger large data processing paths in the encoder implementation.".repeat(100);
        let large_data = text.as_bytes();

        let encoder = SimdHuffmanEncoder::new(large_data)?;
        let encoded = encoder.encode(large_data)?;
        assert!(!encoded.is_empty());

        println!("Large data: {} bytes -> {} bytes (tier: {:?})",
            large_data.len(), encoded.len(), encoder.tier());
        Ok(())
    }
}