use super::Histogram;
use core::ops::{Deref, DerefMut};
/// A histogram whose counters are `u32` values.
///
/// This is a thin `#[repr(C)]` wrapper around [`Histogram<u32>`]. The
/// wrapped value is exposed through the public `inner` field, and the
/// counting routines in this file index its `counter` table by byte value.
#[repr(C)]
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
pub struct Histogram32 {
/// The wrapped histogram that owns the counter table.
pub inner: Histogram<u32>,
}
impl Default for Histogram<u32> {
fn default() -> Self {
Histogram { counter: [0; 256] }
}
}
impl Deref for Histogram32 {
    type Target = Histogram<u32>;

    /// Borrows the wrapped [`Histogram<u32>`] immutably.
    fn deref(&self) -> &Histogram<u32> {
        let Self { inner } = self;
        inner
    }
}
impl DerefMut for Histogram32 {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl Histogram32 {
    /// Builds a histogram of the byte values found in `bytes`.
    ///
    /// Starts from an all-zero histogram and delegates the counting to
    /// [`histogram32_from_bytes`].
    pub fn from_bytes(bytes: &[u8]) -> Self {
        let mut counts = Self::default();
        histogram32_from_bytes(bytes, &mut counts);
        counts
    }
}
/// Adds the byte-value counts of `bytes` to `hist`.
///
/// Inputs shorter than 64 bytes go through the plain byte-at-a-time
/// reference loop; everything else uses the 4x-unrolled `u32` batched
/// implementation. Existing counts in `hist` are kept and incremented.
pub fn histogram32_from_bytes(bytes: &[u8], hist: &mut Histogram32) {
    /// Inputs strictly shorter than this use the reference implementation.
    const BATCHED_MIN_LEN: usize = 64;

    match bytes.len() {
        len if len < BATCHED_MIN_LEN => histogram32_reference(bytes, hist),
        _ => histogram32_generic_batched_unroll_4_u32(bytes, hist),
    }
}
/// Adds the byte-value counts of `bytes` to `histogram`, consuming the input
/// sixteen bytes (four `u32` words) per iteration.
///
/// The largest prefix whose length is a multiple of 16 is handed to one of
/// the `process_four_u32_*` kernels. On x86/x86_64 builds with the relevant
/// crate features, the BMI1 kernel is picked at runtime when the CPU reports
/// `bmi1`; otherwise the portable kernel is used. Bytes left after that
/// prefix are counted one at a time.
pub(crate) fn histogram32_generic_batched_unroll_4_u32(bytes: &[u8], histogram: &mut Histogram32) {
    /// Number of input bytes consumed by one unrolled kernel iteration.
    const CHUNK_BYTES: usize = 4 * size_of::<u32>();

    let total_len = bytes.len();
    if total_len == 0 {
        return;
    }

    unsafe {
        let counters = histogram.inner.counter.as_mut_ptr();
        let input_start = bytes.as_ptr();
        let input_end = input_start.add(total_len);
        // End of the prefix whose length is a multiple of `CHUNK_BYTES`.
        let chunked_end = input_start.add(total_len & !(CHUNK_BYTES - 1)) as *const u32;
        let mut word_cursor = input_start as *const u32;

        // The kernels run their body before checking the end pointer, so they
        // must only be entered when at least one full chunk exists. Exactly
        // one of the three cfg-gated statements below is compiled in.
        if word_cursor < chunked_end {
            #[cfg(all(target_arch = "x86_64", feature = "std"))]
            if std::is_x86_feature_detected!("bmi1") {
                process_four_u32_bmi(counters, &mut word_cursor, chunked_end);
            } else {
                process_four_u32_generic(counters, &mut word_cursor, chunked_end);
            }

            #[cfg(all(target_arch = "x86", feature = "nightly", feature = "std"))]
            if std::is_x86_feature_detected!("bmi1") {
                process_four_u32_bmi(counters, &mut word_cursor, chunked_end);
            } else {
                process_four_u32_generic(counters, &mut word_cursor, chunked_end);
            }

            #[cfg(not(any(
                all(target_arch = "x86_64", feature = "std"),
                all(target_arch = "x86", feature = "nightly", feature = "std")
            )))]
            process_four_u32_generic(counters, &mut word_cursor, chunked_end);
        }

        // Count whatever is left after the chunked prefix, byte by byte.
        let mut tail_cursor = word_cursor as *const u8;
        while tail_cursor < input_end {
            let value = *tail_cursor;
            tail_cursor = tail_cursor.add(1);
            *counters.add(usize::from(value)) += 1;
        }
    }
}
/// BMI1 kernel: counts the bytes of the input sixteen at a time, using
/// `bextr` to extract the third byte of every 32-bit word.
///
/// On return `*values_ptr` has been advanced past every chunk processed.
///
/// # Safety
///
/// - The CPU must support the `bmi1` feature.
/// - `histo_ptr` must point to 256 writable `u32` counters.
/// - `[*values_ptr, ptr_end_unroll)` must be readable memory. The loop body
///   runs once before the end pointer is compared, so `*values_ptr` must be
///   strictly below `ptr_end_unroll`, and the distance between them must be a
///   multiple of 16 bytes.
#[inline(never)]
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "bmi1")]
unsafe extern "sysv64" fn process_four_u32_bmi(
    histo_ptr: *mut u32,
    values_ptr: &mut *const u32,
    ptr_end_unroll: *const u32,
) {
    // The block pushes and pops `rbp` itself, so it must NOT be declared
    // `nostack`: that option promises the asm neither pushes to the stack nor
    // touches the red zone, and the compiler could otherwise keep live data
    // right below `rsp`, where the push would overwrite it.
    //
    // NOTE(review): `rbp` is presumably used as the extra scratch register
    // because a high-byte source such as `ah` cannot be combined with a
    // register that needs a REX prefix; confirm before replacing it with a
    // compiler-allocated operand.
    core::arch::asm!(
        // `rbp` is clobbered as a scratch index below; save it manually.
        "push rbp",
        "2:",
        // Load four 32-bit words and advance the read cursor by 16 bytes.
        "mov {eax:e}, [{cur_ptr}]",
        "mov {ebx:e}, [{cur_ptr} + 4]",
        "mov {ecx:e}, [{cur_ptr} + 8]",
        "mov {edx:e}, [{cur_ptr} + 12]",
        "add {cur_ptr}, 16",
        // Word 1: bytes 0 and 1 via the low/high byte registers, byte 2 via
        // `bextr`, byte 3 via a shift by 24.
        "movzx {tmp_e:e}, {eax:l}",
        "movzx ebp, {eax:h}",
        "inc dword ptr [{hist_ptr} + 4*{tmp_e:r}]",
        "bextr {tmp_e:e}, {eax:e}, {bextr_pat:e}",
        "shr {eax:e}, 24",
        "inc dword ptr [{hist_ptr} + 4*rbp]",
        "inc dword ptr [{hist_ptr} + 4*{tmp_e:r}]",
        "inc dword ptr [{hist_ptr} + 4*{eax:r}]",
        // Word 2 (the first word's register is now free and is reused as the
        // index scratch register).
        "movzx {eax:e}, {ebx:l}",
        "inc dword ptr [{hist_ptr} + 4*{eax:r}]",
        "movzx {eax:e}, {ebx:h}",
        "inc dword ptr [{hist_ptr} + 4*{eax:r}]",
        "bextr {eax:e}, {ebx:e}, {bextr_pat:e}",
        "shr {ebx:e}, 24",
        "inc dword ptr [{hist_ptr} + 4*{eax:r}]",
        "inc dword ptr [{hist_ptr} + 4*{ebx:r}]",
        // Word 3.
        "movzx {eax:e}, {ecx:l}",
        "inc dword ptr [{hist_ptr} + 4*{eax:r}]",
        "movzx {eax:e}, {ecx:h}",
        "inc dword ptr [{hist_ptr} + 4*{eax:r}]",
        "bextr {eax:e}, {ecx:e}, {bextr_pat:e}",
        "shr {ecx:e}, 24",
        "inc dword ptr [{hist_ptr} + 4*{eax:r}]",
        "inc dword ptr [{hist_ptr} + 4*{ecx:r}]",
        // Word 4.
        "movzx {eax:e}, {edx:l}",
        "inc dword ptr [{hist_ptr} + 4*{eax:r}]",
        "movzx {eax:e}, {edx:h}",
        "inc dword ptr [{hist_ptr} + 4*{eax:r}]",
        "bextr {eax:e}, {edx:e}, {bextr_pat:e}",
        "shr {edx:e}, 24",
        "inc dword ptr [{hist_ptr} + 4*{eax:r}]",
        "inc dword ptr [{hist_ptr} + 4*{edx:r}]",
        // Loop while the cursor is still below the end of the chunked region.
        "cmp {cur_ptr}, {end_ptr}",
        "jb 2b",
        "pop rbp",
        cur_ptr = inout(reg) *values_ptr,
        hist_ptr = in(reg) histo_ptr,
        end_ptr = in(reg) ptr_end_unroll,
        // BEXTR control word 0x0810: start at bit 16, extract 8 bits.
        bextr_pat = in(reg) 2064u32,
        // The word registers need addressable high-byte halves (`ah`, `bh`,
        // ...), hence the `reg_abcd` class.
        eax = out(reg_abcd) _,
        ebx = out(reg_abcd) _,
        ecx = out(reg_abcd) _,
        edx = out(reg_abcd) _,
        tmp_e = out(reg) _,
    );
}
/// 32-bit x86 BMI1 kernel: counts the bytes of the input sixteen at a time,
/// using `bextr` to extract the third byte of every 32-bit word.
///
/// Written as a naked function, so the body manages its own stack frame and
/// reads its three `stdcall` arguments directly from the stack. On return,
/// the advanced read cursor has been written back through `values_ptr`.
///
/// # Safety
///
/// - The CPU must support the `bmi1` feature.
/// - `histo_ptr` must point to 256 writable `u32` counters.
/// - `[*values_ptr, ptr_end_unroll)` must be readable memory. The loop body
///   runs once before the end pointer is compared, so `*values_ptr` must be
///   strictly below `ptr_end_unroll`, and the distance between them must be a
///   multiple of 16 bytes.
#[cfg(feature = "nightly")]
#[unsafe(naked)]
#[cfg(target_arch = "x86")]
#[target_feature(enable = "bmi1")]
unsafe extern "stdcall" fn process_four_u32_bmi(
histo_ptr: *mut u32,
values_ptr: &mut *const u32,
ptr_end_unroll: *const u32,
) {
core::arch::naked_asm!(
// Save the registers this routine overwrites.
"push ebp",
"push ebx",
"push edi",
"push esi",
// `push eax` reserves a 4-byte scratch slot at [esp]. After these five
// pushes (20 bytes) plus the return address, the arguments sit at
// [esp + 24] = histo_ptr, [esp + 28] = values_ptr, [esp + 32] = ptr_end_unroll.
// esi = histogram base, edx = current read pointer (*values_ptr).
"push eax", "mov eax, dword ptr [esp + 28]", "mov esi, dword ptr [esp + 24]", "mov edx, dword ptr [eax]", ".p2align 4, 0x90",
"2:",
// Load four words (eax, ecx, ebx, edi) and advance the read pointer by 16.
// The fourth word is spilled to the scratch slot so edi can hold the
// BEXTR control word 2064 (0x0810: start at bit 16, extract 8 bits).
// Counting of word 1 starts here with bytes 0 and 1.
"mov eax, dword ptr [edx]", "mov edi, dword ptr [edx + 12]", "mov ecx, dword ptr [edx + 4]", "mov ebx, dword ptr [edx + 8]", "add edx, 16", "movzx ebp, al", "mov dword ptr [esp], edi", "mov edi, 2064", "inc dword ptr [esi + 4*ebp]", "movzx ebp, ah",
"inc dword ptr [esi + 4*ebp]",
// Word 1, bytes 2 (via bextr) and 3 (via shift).
"bextr ebp, eax, edi",
"shr eax, 24",
"inc dword ptr [esi + 4*ebp]",
"inc dword ptr [esi + 4*eax]",
// Word 2 (held in ecx).
"movzx eax, cl",
"inc dword ptr [esi + 4*eax]",
"movzx eax, ch",
"inc dword ptr [esi + 4*eax]",
"bextr eax, ecx, edi",
"shr ecx, 24",
"inc dword ptr [esi + 4*eax]",
"inc dword ptr [esi + 4*ecx]",
// Reload the spilled fourth word into ecx, then count word 3 (held in ebx).
"mov ecx, dword ptr [esp]", "movzx eax, bl",
"inc dword ptr [esi + 4*eax]",
"movzx eax, bh",
"inc dword ptr [esi + 4*eax]",
"bextr eax, ebx, edi",
"shr ebx, 24",
"inc dword ptr [esi + 4*eax]",
"inc dword ptr [esi + 4*ebx]",
// Word 4 (now in ecx).
"movzx eax, cl",
"inc dword ptr [esi + 4*eax]",
"movzx eax, ch",
"inc dword ptr [esi + 4*eax]",
"bextr eax, ecx, edi",
"shr ecx, 24",
"inc dword ptr [esi + 4*eax]",
"inc dword ptr [esi + 4*ecx]",
// Loop while the read pointer is below ptr_end_unroll, then store the
// advanced pointer back through values_ptr, drop the scratch slot and
// restore the saved registers.
"cmp edx, dword ptr [esp + 32]", "jb 2b", "mov eax, dword ptr [esp + 28]", "mov dword ptr [eax], edx", "add esp, 4", "pop esi",
"pop edi",
"pop ebx",
"pop ebp",
// Return and pop the three 4-byte arguments (12 bytes) off the stack.
"ret 12", );
}
/// Portable kernel: counts the bytes of the input sixteen at a time by
/// reading four unaligned `u32` words and tallying each of their four bytes.
///
/// On return `*values_ptr` has been advanced past every chunk processed.
///
/// # Safety
///
/// - `histo_ptr` must point to 256 writable `u32` counters.
/// - `[*values_ptr, ptr_end_unroll)` must be readable memory. The loop body
///   runs once before the end pointer is compared, so `*values_ptr` must be
///   strictly below `ptr_end_unroll`, and the distance between them must be a
///   multiple of 16 bytes.
#[inline(never)]
unsafe extern "C" fn process_four_u32_generic(
    histo_ptr: *mut u32,
    values_ptr: &mut *const u32,
    ptr_end_unroll: *const u32,
) {
    /// Increments the counter of each of the four bytes packed into `word`.
    #[inline(always)]
    unsafe fn tally_word(counters: *mut u32, word: u32) {
        *counters.add((word & 0xFF) as usize) += 1;
        *counters.add(((word >> 8) & 0xFF) as usize) += 1;
        *counters.add(((word >> 16) & 0xFF) as usize) += 1;
        *counters.add((word >> 24) as usize) += 1;
    }

    loop {
        let chunk = *values_ptr;

        // Read the whole chunk up front, before touching any counter.
        let first = chunk.read_unaligned();
        let second = chunk.add(1).read_unaligned();
        let third = chunk.add(2).read_unaligned();
        let fourth = chunk.add(3).read_unaligned();

        tally_word(histo_ptr, first);
        tally_word(histo_ptr, second);
        tally_word(histo_ptr, third);
        tally_word(histo_ptr, fourth);

        *values_ptr = chunk.add(4);
        // The end check comes after the body: at least one chunk is always
        // processed.
        if *values_ptr >= ptr_end_unroll {
            break;
        }
    }
}
/// Reference implementation: adds the byte-value counts of `bytes` to
/// `histogram`, one byte at a time.
///
/// Existing counts are kept and incremented. An empty input leaves the
/// histogram untouched.
pub(crate) fn histogram32_reference(bytes: &[u8], histogram: &mut Histogram32) {
    // The counter table has one slot per possible byte value, so indexing by
    // a `u8` can never go out of bounds and no `unsafe` is required.
    let counters = &mut histogram.inner.counter;
    for &byte in bytes {
        counters[usize::from(byte)] += 1;
    }
}
#[cfg(test)]
mod reference_tests {
    use super::*;
    use std::vec::Vec;

    /// Feeding every byte value exactly once must leave each counter at 1.
    #[test]
    fn verify_full_range_in_reference_impl() {
        let every_byte: Vec<u8> = (u8::MIN..=u8::MAX).collect();
        let mut tally = Histogram32::default();

        histogram32_reference(&every_byte, &mut tally);

        for &count in &tally.inner.counter {
            assert_eq!(count, 1);
        }
    }
}
#[cfg(test)]
mod alternative_implementation_tests {
    use super::*;
    use crate::histogram::histogram32_private::*;
    use rstest::rstest;
    use std::vec::Vec;

    /// Produces `size` bytes that cycle through 0, 1, ..., 255, 0, 1, ...
    fn generate_test_data(size: usize) -> Vec<u8> {
        (0..=u8::MAX).cycle().take(size).collect()
    }

    /// Every alternative implementation must match the reference
    /// implementation for all input lengths from 0 to 767 bytes.
    #[rstest]
    #[case::batched_u32(histogram32_generic_batched_u32)]
    #[case::batched_u64(histogram32_generic_batched_u64)]
    #[case::batched_unroll2_u32(histogram32_generic_batched_unroll_2_u32)]
    #[case::batched_unroll2_u64(histogram32_generic_batched_unroll_2_u64)]
    #[case::batched_unroll4_u32(histogram32_generic_batched_unroll_4_u32)]
    #[case::batched_unroll4_u64(histogram32_generic_batched_unroll_4_u64)]
    #[case::nonaliased_withruns(histogram_nonaliased_withruns_core)]
    fn test_against_reference(#[case] implementation: fn(&[u8], &mut Histogram32)) {
        for size in 0..=767 {
            let input = generate_test_data(size);

            let mut expected = Histogram32::default();
            histogram32_reference(&input, &mut expected);

            let mut actual = Histogram32::default();
            implementation(&input, &mut actual);

            assert_eq!(
                actual.inner.counter, expected.inner.counter,
                "Implementation failed for size {size}"
            );
        }
    }
}