lossless_transform_utils/histogram/
histogram32.rs

1//! Implementation of a histogram using 32-bit unsigned integers as counters.
2//!
3//! This module provides an efficient way of creating a histogram which supplies you
4//! with the byte occurrences in data. It's particularly optimized for
5//! processing large amounts of data quickly.
6//!
7//! # Key Features
8//!
9//! - Fast histogram generation from byte slices
10//! - Optimized for different input sizes and hardware capabilities
11//! - Supports x86 and x86_64 specific optimizations (with nightly Rust)
12//!
13//! # Main Types
14//!
15//! - [Histogram32]: The primary struct representing a histogram with 32-bit counters.
16//!
17//! # Main Functions
18//!
19//! - [`histogram32_from_bytes`]: Efficiently creates a histogram from a byte slice.
20//!
21//! # Examples
22//!
23//! Basic usage:
24//!
25//! ```
26//! use lossless_transform_utils::histogram::histogram32_from_bytes;
27//! use lossless_transform_utils::histogram::Histogram32;
28//!
29//! let data = [1, 2, 3, 1, 2, 1];
30//! let mut histogram = Histogram32::default();
31//! histogram32_from_bytes(&data, &mut histogram);
32//!
33//! assert_eq!(histogram.inner.counter[1], 3); // Byte value 1 appears 3 times
34//! assert_eq!(histogram.inner.counter[2], 2); // Byte value 2 appears 2 times
35//! assert_eq!(histogram.inner.counter[3], 1); // Byte value 3 appears 1 time
36//! ```
37//!
38//! # Performance Considerations
39//!
40//! The implementations in this module are optimized for different input sizes:
41//!
42//! - Small inputs (< 64 bytes) use a simple, efficient implementation.
43//! - Larger inputs use batched processing with loop unrolling for better performance.
44//! - On x86_64 and x86 platforms (with nightly Rust), BMI1 instructions are utilized if available.
45//!
46//! Not optimized for non-x86 platforms, as I (Sewer) don't own any such hardware to test on.
47//!
48//! # Safety
49//!
50//! While some functions in this module use unsafe code internally for performance reasons,
51//! all public interfaces are safe to use from safe Rust code.
52
53use super::Histogram;
54use core::ops::{Deref, DerefMut};
55
/// Implementation of a histogram using unsigned 32 bit integers as the counter.
///
/// Max safe array size to pass is 4,294,967,295 (`u32::MAX`), naturally, as each
/// per-byte counter is a `u32`; though in practice it can be a bit bigger, as long
/// as no single byte value occurs more than `u32::MAX` times.
///
/// `#[repr(C)]` keeps the layout stable (a bare array of 256 `u32` counters).
#[repr(C)]
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
pub struct Histogram32 {
    /// The wrapped counter array; also reachable via `Deref`/`DerefMut`.
    pub inner: Histogram<u32>,
}
65
66impl Default for Histogram<u32> {
67    // Defaults to a zero'd array.
68    fn default() -> Self {
69        Histogram { counter: [0; 256] }
70    }
71}
72
73impl Deref for Histogram32 {
74    type Target = Histogram<u32>;
75
76    fn deref(&self) -> &Self::Target {
77        &self.inner
78    }
79}
80
81impl DerefMut for Histogram32 {
82    fn deref_mut(&mut self) -> &mut Self::Target {
83        &mut self.inner
84    }
85}
86
87impl Histogram32 {
88    /// This is a shortcut for [`histogram32_from_bytes`]
89    pub fn from_bytes(bytes: &[u8]) -> Self {
90        let mut histogram = Histogram32::default();
91        histogram32_from_bytes(bytes, &mut histogram);
92        histogram
93    }
94}
95
96/// Calculates a new histogram given a byte slice.
97///
98/// This function computes a histogram of byte occurrences in the input slice.
99/// It automatically selects the most efficient implementation based on the
100/// input size and available hardware features.
101///
102/// # Performance
103///
104/// - For small inputs (less than 64 bytes), it uses a simple reference implementation.
105/// - For larger inputs, it uses an optimized implementation with batched processing and loop unrolling.
106/// - On x86_64 and x86 (with nightly feature) platforms, it can utilize BMI1 instructions if available.
107///
108/// Not optimized for non-x86 platforms, as I (Sewer) don't own any hardware.
109///
110/// # Arguments
111///
112/// * `bytes` - A slice of bytes to process.
113///
114/// # Returns
115///
116/// Returns a [Histogram32] struct containing the computed histogram.
117/// Each element in the histogram represents the count of occurrences for a byte value (0-255).
118///
119/// # Example
120///
121/// ```
122/// use lossless_transform_utils::histogram::histogram32_from_bytes;
123/// use lossless_transform_utils::histogram::Histogram32;
124///
125/// let data = [1, 2, 3, 1, 2, 1];
126/// let mut histogram = Histogram32::default();
127/// histogram32_from_bytes(&data, &mut histogram);
128///
129/// assert_eq!(histogram.inner.counter[1], 3); // Byte value 1 appears 3 times
130/// assert_eq!(histogram.inner.counter[2], 2); // Byte value 2 appears 2 times
131/// assert_eq!(histogram.inner.counter[3], 1); // Byte value 3 appears 1 time
132/// ```
133///
134/// # Notes
135///
136/// - The function is optimized for different input sizes and hardware capabilities.
137/// - The threshold for switching between implementations (64 bytes) is based on
138///   benchmarks performed on an AMD Ryzen 9 5900X processor. This may vary on different hardware.
139///
140/// # Safety
141///
142/// While this function uses unsafe code internally for performance optimization,
143/// it is safe to call and use from safe Rust code.
144pub fn histogram32_from_bytes(bytes: &[u8], hist: &mut Histogram32) {
145    // Obtained by benching on a 5900X. May vary with different hardware.
146    if bytes.len() < 64 {
147        histogram32_reference(bytes, hist)
148    } else {
149        histogram32_generic_batched_unroll_4_u32(bytes, hist)
150    }
151}
152
/// Batched histogram core used by [`histogram32_from_bytes`] for larger inputs.
///
/// Processes the input in 16-byte groups (four `u32` reads per iteration),
/// dispatching at runtime to a BMI1 assembly routine on x86/x86_64 (when the
/// CPU supports BMI1 and the relevant crate features are enabled), otherwise
/// to a portable Rust routine. Bytes left over after the last full 16-byte
/// group are counted by a scalar tail loop.
///
/// Counts are accumulated into `histogram`; existing counts are not cleared.
pub(crate) fn histogram32_generic_batched_unroll_4_u32(bytes: &[u8], histogram: &mut Histogram32) {
    if bytes.is_empty() {
        return;
    }

    unsafe {
        let histo_ptr = histogram.inner.counter.as_mut_ptr();
        // Cursor over the input; the unrolled helpers advance it 16 bytes at a time.
        let mut current_ptr = bytes.as_ptr() as *const u32;
        let ptr_end = bytes.as_ptr().add(bytes.len());

        // We'll read 4 u32 values at a time, so adjust alignment accordingly
        // (length rounded down to a multiple of 16 bytes).
        let ptr_end_unroll = bytes
            .as_ptr()
            .add(bytes.len() & !(4 * size_of::<u32>() - 1))
            as *const u32;

        // x86_64 + std: runtime-detect BMI1 and take the asm fast path if present.
        #[cfg(all(target_arch = "x86_64", feature = "std"))]
        if std::is_x86_feature_detected!("bmi1") {
            if current_ptr < ptr_end_unroll {
                process_four_u32_bmi(histo_ptr, &mut current_ptr, ptr_end_unroll);
            }
        } else if current_ptr < ptr_end_unroll {
            process_four_u32_generic(histo_ptr, &mut current_ptr, ptr_end_unroll);
        }

        // 32-bit x86 additionally requires the `nightly` feature (naked-fn variant).
        #[cfg(all(target_arch = "x86", feature = "nightly", feature = "std"))]
        if std::is_x86_feature_detected!("bmi1") {
            if current_ptr < ptr_end_unroll {
                process_four_u32_bmi(histo_ptr, &mut current_ptr, ptr_end_unroll);
            }
        } else if current_ptr < ptr_end_unroll {
            process_four_u32_generic(histo_ptr, &mut current_ptr, ptr_end_unroll);
        }

        // All other targets (or builds without `std`): portable Rust path only.
        #[cfg(not(any(
            all(target_arch = "x86_64", feature = "std"),
            all(target_arch = "x86", feature = "nightly", feature = "std")
        )))]
        if current_ptr < ptr_end_unroll {
            process_four_u32_generic(histo_ptr, &mut current_ptr, ptr_end_unroll);
        }

        // Handle remaining bytes that didn't fit in the unrolled loop
        let mut current_ptr = current_ptr as *const u8;
        while current_ptr < ptr_end {
            let byte = *current_ptr;
            current_ptr = current_ptr.add(1);
            *histo_ptr.add(byte as usize) += 1;
        }
    }
}
204
/// x86_64 BMI1 variant of the four-at-a-time histogram core.
///
/// Consumes `u32` values from `*values_ptr` up to (but excluding)
/// `ptr_end_unroll` — 16 bytes per loop iteration — incrementing one
/// `histo_ptr` slot per byte. The advanced cursor is written back through
/// `values_ptr` via the `inout` operand.
///
/// # Safety
///
/// - The CPU must support BMI1 (`bextr`); the caller is expected to have
///   verified this (see the `is_x86_feature_detected!` dispatch in
///   `histogram32_generic_batched_unroll_4_u32`).
/// - `histo_ptr` must point to 256 writable `u32` counters.
/// - The region `[*values_ptr, ptr_end_unroll)` must be readable and a
///   non-zero multiple of 16 bytes long; the loop body executes at least once.
#[inline(never)]
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "bmi1")]
unsafe extern "sysv64" fn process_four_u32_bmi(
    histo_ptr: *mut u32,
    values_ptr: &mut *const u32,
    ptr_end_unroll: *const u32,
) {
    // NOTE(fix): `options(nostack)` was removed from this asm! block. The asm
    // pushes/pops `rbp`, and `nostack` promises the asm never pushes data to
    // the stack — combining the two is undefined behavior (it also lets the
    // compiler keep live data in the red zone that the pushes would clobber).
    core::arch::asm!(
        // Main loop. rbp is used as an extra scratch register (for value 1's
        // second byte), so it is saved/restored manually around the loop.
        "push rbp",
        "2:",
        "mov {eax:e}, [{cur_ptr}]",      // Load first value
        "mov {ebx:e}, [{cur_ptr} + 4]",  // Load second value
        "mov {ecx:e}, [{cur_ptr} + 8]",  // Load third value
        "mov {edx:e}, [{cur_ptr} + 12]", // Load fourth value
        "add {cur_ptr}, 16",               // Advance pointer by 16 bytes

        // Process first value
        "movzx {tmp_e:e}, {eax:l}",
        "movzx ebp, {eax:h}",
        "inc dword ptr [{hist_ptr} + 4*{tmp_e:r}]",
        "bextr {tmp_e:e}, {eax:e}, {bextr_pat:e}",
        "shr {eax:e}, 24",
        "inc dword ptr [{hist_ptr} + 4*rbp]",
        "inc dword ptr [{hist_ptr} + 4*{tmp_e:r}]",
        "inc dword ptr [{hist_ptr} + 4*{eax:r}]",

        // Process second value
        "movzx {eax:e}, {ebx:l}",
        "inc dword ptr [{hist_ptr} + 4*{eax:r}]",
        "movzx {eax:e}, {ebx:h}",
        "inc dword ptr [{hist_ptr} + 4*{eax:r}]",
        "bextr {eax:e}, {ebx:e}, {bextr_pat:e}",
        "shr {ebx:e}, 24",
        "inc dword ptr [{hist_ptr} + 4*{eax:r}]",
        "inc dword ptr [{hist_ptr} + 4*{ebx:r}]",

        // Process third value
        "movzx {eax:e}, {ecx:l}",
        "inc dword ptr [{hist_ptr} + 4*{eax:r}]",
        "movzx {eax:e}, {ecx:h}",
        "inc dword ptr [{hist_ptr} + 4*{eax:r}]",
        "bextr {eax:e}, {ecx:e}, {bextr_pat:e}",
        "shr {ecx:e}, 24",
        "inc dword ptr [{hist_ptr} + 4*{eax:r}]",
        "inc dword ptr [{hist_ptr} + 4*{ecx:r}]",

        // Process fourth value
        "movzx {eax:e}, {edx:l}",
        "inc dword ptr [{hist_ptr} + 4*{eax:r}]",
        "movzx {eax:e}, {edx:h}",
        "inc dword ptr [{hist_ptr} + 4*{eax:r}]",
        "bextr {eax:e}, {edx:e}, {bextr_pat:e}",
        "shr {edx:e}, 24",
        "inc dword ptr [{hist_ptr} + 4*{eax:r}]",
        "inc dword ptr [{hist_ptr} + 4*{edx:r}]",

        // Loop condition
        "cmp {cur_ptr}, {end_ptr}",
        "jb 2b",
        "pop rbp",

        cur_ptr = inout(reg) *values_ptr,
        hist_ptr = in(reg) histo_ptr,
        end_ptr = in(reg) ptr_end_unroll,
        // bextr control word 2064 = 0x0810: start bit 8*2=16, length 8
        // (extracts the third byte of each u32).
        bextr_pat = in(reg) 2064u32,
        eax = out(reg_abcd) _,
        ebx = out(reg_abcd) _,
        ecx = out(reg_abcd) _,
        edx = out(reg_abcd) _,
        tmp_e = out(reg) _,
    );
}
280
#[cfg(feature = "nightly")]
#[naked]
#[cfg(target_arch = "x86")]
#[target_feature(enable = "bmi1")]
/// 32-bit x86 BMI1 variant of the four-at-a-time histogram core.
///
/// From a i686 linux machine with native zen3 target (hand-tuned compiler
/// output, emitted as a naked function).
///
/// Consumes `u32` values from `*values_ptr` up to (but excluding)
/// `ptr_end_unroll`, 16 bytes per loop iteration, incrementing one `histo_ptr`
/// slot per byte, then writes the advanced cursor back through `values_ptr`.
///
/// # Safety
///
/// - The CPU must support BMI1 (`bextr`); the caller is expected to have
///   verified this first.
/// - `histo_ptr` must point to 256 writable `u32` counters.
/// - The region `[*values_ptr, ptr_end_unroll)` must be readable and a
///   non-zero multiple of 16 bytes long; the loop body runs at least once.
unsafe extern "stdcall" fn process_four_u32_bmi(
    histo_ptr: *mut u32,
    values_ptr: &mut *const u32,
    ptr_end_unroll: *const u32,
) {
    core::arch::naked_asm!(
        // Prologue - save registers
        "push ebp",
        "push ebx",
        "push edi",
        "push esi",
        "push eax", // Extra push for temporary storage
        // Initial setup - load pointers
        // (After 5 pushes + return address, the stdcall args sit at
        //  esp+24 = histo_ptr, esp+28 = values_ptr, esp+32 = ptr_end_unroll.)
        "mov eax, dword ptr [esp + 28]", // Load values_ptr
        "mov esi, dword ptr [esp + 24]", // Load histo_ptr
        "mov edx, dword ptr [eax]",      // Load current pointer value
        // Ensure 16-byte alignment for the loop
        ".p2align 4, 0x90",
        // Main processing loop
        "2:",
        // Load four 32-bit values
        "mov eax, dword ptr [edx]",      // Load first value
        "mov edi, dword ptr [edx + 12]", // Load fourth value
        "mov ecx, dword ptr [edx + 4]",  // Load second value
        "mov ebx, dword ptr [edx + 8]",  // Load third value
        "add edx, 16",                   // Advance pointer
        // Process first value (in eax)
        "movzx ebp, al",               // Extract low byte
        "mov dword ptr [esp], edi",    // Save fourth value
        "mov edi, 2064",               // bextr pattern (start=16, len=8)
        "inc dword ptr [esi + 4*ebp]", // Update histogram
        "movzx ebp, ah",
        "inc dword ptr [esi + 4*ebp]",
        "bextr ebp, eax, edi",
        "shr eax, 24",
        "inc dword ptr [esi + 4*ebp]",
        "inc dword ptr [esi + 4*eax]",
        // Process second value (in ecx)
        "movzx eax, cl",
        "inc dword ptr [esi + 4*eax]",
        "movzx eax, ch",
        "inc dword ptr [esi + 4*eax]",
        "bextr eax, ecx, edi",
        "shr ecx, 24",
        "inc dword ptr [esi + 4*eax]",
        "inc dword ptr [esi + 4*ecx]",
        // Process third value (in ebx)
        "mov ecx, dword ptr [esp]", // Restore fourth value
        "movzx eax, bl",
        "inc dword ptr [esi + 4*eax]",
        "movzx eax, bh",
        "inc dword ptr [esi + 4*eax]",
        "bextr eax, ebx, edi",
        "shr ebx, 24",
        "inc dword ptr [esi + 4*eax]",
        "inc dword ptr [esi + 4*ebx]",
        // Process fourth value (in ecx)
        "movzx eax, cl",
        "inc dword ptr [esi + 4*eax]",
        "movzx eax, ch",
        "inc dword ptr [esi + 4*eax]",
        "bextr eax, ecx, edi",
        "shr ecx, 24",
        "inc dword ptr [esi + 4*eax]",
        "inc dword ptr [esi + 4*ecx]",
        // Loop control
        "cmp edx, dword ptr [esp + 32]", // Compare with end pointer
        "jb 2b",                         // Loop if not at end
        // Store final pointer
        "mov eax, dword ptr [esp + 28]", // Load values_ptr
        "mov dword ptr [eax], edx",      // Store back final position
        // Epilogue - restore registers and return
        "add esp, 4", // Clean up temporary storage
        "pop esi",
        "pop edi",
        "pop ebx",
        "pop ebp",
        "ret 12", // stdcall return - clean up 12 bytes (3 params * 4 bytes)
    );
}
366
/// Portable fallback for the four-at-a-time histogram core.
///
/// Consumes `u32` values from `*values_ptr` up to (but excluding)
/// `ptr_end_unroll` — 16 bytes per loop iteration — incrementing one
/// `histo_ptr` slot per byte, then writes the advanced cursor back through
/// `values_ptr` so the caller can process the remaining tail bytes.
///
/// # Safety
///
/// - `histo_ptr` must point to 256 writable `u32` counters.
/// - The region `[*values_ptr, ptr_end_unroll)` must be readable and a
///   non-zero multiple of 16 bytes long; the loop body executes at least once
///   (do-while semantics, matching the caller's `current_ptr < ptr_end_unroll`
///   pre-check).
#[inline(never)]
unsafe extern "cdecl" fn process_four_u32_generic(
    histo_ptr: *mut u32,
    values_ptr: &mut *const u32,
    ptr_end_unroll: *const u32,
) {
    loop {
        let cur = *values_ptr;

        // FIX: the pointer originates from `&[u8]` and carries no u32
        // alignment guarantee, so plain derefs here would be UB on misaligned
        // input. `read_unaligned` is the sound equivalent (and compiles to the
        // same `mov` on x86).
        let value1 = cur.read_unaligned();
        let value2 = cur.add(1).read_unaligned();
        let value3 = cur.add(2).read_unaligned();
        let value4 = cur.add(3).read_unaligned();

        // Bump one counter per byte of each loaded value; the fixed-length
        // array loop is fully unrolled by the compiler.
        for value in [value1, value2, value3, value4] {
            *histo_ptr.add((value & 0xFF) as usize) += 1;
            *histo_ptr.add(((value >> 8) & 0xFF) as usize) += 1;
            *histo_ptr.add(((value >> 16) & 0xFF) as usize) += 1;
            *histo_ptr.add((value >> 24) as usize) += 1;
        }

        // Advance the shared cursor; stop once the unrollable region is done.
        *values_ptr = cur.add(4);
        if *values_ptr >= ptr_end_unroll {
            break;
        }
    }
}
408
409/// Generic, slower version of [`Histogram32`] generation that doesn't assume anything.
410/// This is the Rust fallback, reference implementation to run other tests against.
411pub(crate) fn histogram32_reference(bytes: &[u8], histogram: &mut Histogram32) {
412    let histo_ptr = histogram.inner.counter.as_mut_ptr();
413    let mut current_ptr = bytes.as_ptr();
414    let ptr_end = unsafe { current_ptr.add(bytes.len()) };
415
416    // Unroll the loop by fetching `usize` elements at once, then doing a shift.
417    // Although there is a data dependency in the shift.
418    unsafe {
419        while current_ptr < ptr_end {
420            let byte = *current_ptr;
421            current_ptr = current_ptr.add(1);
422            *histo_ptr.add(byte as usize) += 1;
423        }
424    }
425}
426
#[cfg(test)]
mod reference_tests {
    use super::*;
    use std::vec::Vec;

    /// Feeds every possible byte value (0..=255) exactly once through the
    /// reference implementation and checks each counter reads exactly 1.
    /// This is sufficient to confirm the full index range is reachable,
    /// which the unrolled implementations are later compared against.
    #[test]
    fn verify_full_range_in_reference_impl() {
        let input: Vec<u8> = (0..=255).collect();
        let mut histogram = Histogram32::default();
        histogram32_reference(&input, &mut histogram);

        for (byte_value, count) in histogram.inner.counter.iter().enumerate() {
            assert_eq!(*count, 1, "byte value {} should appear once", byte_value);
        }
    }
}
446
#[cfg(test)]
mod alternative_implementation_tests {
    use super::*;
    use crate::histogram::histogram32_private::*;
    use rstest::rstest;
    use std::vec::Vec;

    /// Produces `size` bytes cycling through the values 0..=255.
    fn generate_test_data(size: usize) -> Vec<u8> {
        (0..size).map(|i| (i % 256) as u8).collect()
    }

    /// Every alternative implementation must agree with the reference
    /// implementation on all input sizes from 0 to 767 bytes, covering the
    /// empty case, tail-handling paths, and multiple full unrolled batches.
    #[rstest]
    #[case::batched_u32(histogram32_generic_batched_u32)]
    #[case::batched_u64(histogram32_generic_batched_u64)]
    #[case::batched_unroll2_u32(histogram32_generic_batched_unroll_2_u32)]
    #[case::batched_unroll2_u64(histogram32_generic_batched_unroll_2_u64)]
    #[case::batched_unroll4_u32(histogram32_generic_batched_unroll_4_u32)]
    #[case::batched_unroll4_u64(histogram32_generic_batched_unroll_4_u64)]
    #[case::nonaliased_withruns(histogram_nonaliased_withruns_core)]
    fn test_against_reference(#[case] implementation: fn(&[u8], &mut Histogram32)) {
        for size in 0..=767 {
            let test_data = generate_test_data(size);

            let mut actual = Histogram32::default();
            implementation(&test_data, &mut actual);

            let mut expected = Histogram32::default();
            histogram32_reference(&test_data, &mut expected);

            assert_eq!(
                actual.inner.counter, expected.inner.counter,
                "Implementation failed for size {}",
                size
            );
        }
    }
}