simd_kernels/
utils.rs

1// Copyright Peter Bower 2025. All Rights Reserved.
2// Licensed under Mozilla Public License (MPL) 2.0.
3
4//! # **Utility Functions** - *SIMD Processing and Memory Management Utilities*
5//!
6//! Core utilities supporting SIMD kernel implementations with efficient memory handling,
7//! bitmask operations, and performance-critical helper functions.
8
9#[cfg(feature = "str_arithmetic")]
10use std::mem::MaybeUninit;
11use std::{
12    collections::HashSet,
13    simd::{LaneCount, Mask, MaskElement, SimdElement, SupportedLaneCount},
14};
15
16use minarrow::{Bitmask, CategoricalArray, Integer, MaskedArray, StringArray, Vec64};
17#[cfg(feature = "str_arithmetic")]
18use ryu::Float;
19
20use crate::errors::KernelError;
21
22/// Extracts a core::SIMD `Mask<M, N>` for a batch of N lanes from a Minarrow `Bitmask`.
23///
24/// - `mask_bytes`: packed Arrow validity bits (LSB=index 0, bit=1 means valid)
25/// - `offset`: starting index (bit offset into the mask)
26/// - `logical_len`: number of logical bits in the mask
27/// - `M`: SIMD mask type (e.g., i64 for f64, i32 for f32, i8 for i8)
28///
29/// Returns: SIMD Mask<M, N> representing validity for these N lanes.
30/// Bits outside the logical length (i.e., mask is shorter than offset+N)
31/// are treated as valid.
32#[inline(always)]
33pub fn bitmask_to_simd_mask<const N: usize, M>(
34    mask_bytes: &[u8],
35    offset: usize,
36    logical_len: usize,
37) -> Mask<M, N>
38where
39    LaneCount<N>: SupportedLaneCount,
40    M: MaskElement + SimdElement,
41{
42    let lane_limit = (offset + N).min(logical_len);
43    let n_lanes = lane_limit - offset;
44    let mut bits: u64 = 0;
45    for j in 0..n_lanes {
46        let idx = offset + j;
47        let byte = mask_bytes[idx >> 3];
48        if ((byte >> (idx & 7)) & 1) != 0 {
49            bits |= 1u64 << j;
50        }
51    }
52    if n_lanes < N {
53        bits |= !0u64 << n_lanes;
54    }
55    Mask::<M, N>::from_bitmask(bits)
56}
57
58/// Converts a SIMD `Mask<M, N>` to a Minarrow `Bitmask` for the given logical length.
59/// Used at the end of a block operation within SIMD-accelerated kernel functions.
60#[inline(always)]
61pub fn simd_mask_to_bitmask<const N: usize, M>(mask: Mask<M, N>, len: usize) -> Bitmask
62where
63    LaneCount<N>: SupportedLaneCount,
64    M: MaskElement + SimdElement,
65{
66    let mut bits = Vec64::with_capacity((len + 7) / 8);
67    bits.resize((len + 7) / 8, 0);
68
69    let word = mask.to_bitmask();
70    let bytes = word.to_le_bytes();
71
72    let n_bytes = (len + 7) / 8;
73    bits[..n_bytes].copy_from_slice(&bytes[..n_bytes]);
74
75    if len % 8 != 0 {
76        let last = n_bytes - 1;
77        let mask_byte = (1u8 << (len % 8)) - 1;
78        bits[last] &= mask_byte;
79    }
80
81    Bitmask {
82        bits: bits.into(),
83        len,
84    }
85}
86
87/// Bulk-ORs a local bitmask block (from a SIMD mask or similar) into the global Minarrow bitmask at the correct byte offset.
88/// The block (`block_mask`) is expected to contain at least ceil(n_lanes/8) bytes,
89/// with the bit-packed validity bits starting from position 0.
90///
91/// Used to streamline repetitive boilerplate and ensure consistency across kernel null-mask handling.
92///
93/// ### Parameters
94/// - `out_mask`: mutable reference to the output/global Bitmask
95/// - `block_mask`: reference to the local Bitmask containing the block's bits
96/// - `offset`: starting bit offset in the global mask
97/// - `n_lanes`: number of bits in this block (usually SIMD lane count)
98#[inline(always)]
99pub fn write_global_bitmask_block(
100    out_mask: &mut Bitmask,
101    block_mask: &Bitmask,
102    offset: usize,
103    n_lanes: usize,
104) {
105    let n_bytes = (n_lanes + 7) / 8;
106    let base = offset / 8;
107    let block_bytes = &block_mask.bits[..n_bytes];
108    for b in 0..n_bytes {
109        if base + b < out_mask.bits.len() {
110            out_mask.bits[base + b] |= block_bytes[b];
111        }
112    }
113}
114
115/// Determines whether nulls are present given an optional null count and mask reference.
116/// Avoids computing mask cardinality to preserve performance guarantees.
117#[inline(always)]
118pub fn has_nulls(null_count: Option<usize>, mask: Option<&Bitmask>) -> bool {
119    match null_count {
120        Some(n) => n > 0,
121        None => mask.is_some(),
122    }
123}
124
125/// Creates a SIMD mask from a bitmask window for vectorised conditional operations.
126/// 
127/// Converts a contiguous section of a bitmask into a SIMD mask. 
128/// The resulting mask can be used to selectively enable/disable SIMD lanes during
129/// computation, providing efficient support for sparse or conditional operations.
130/// 
131/// # Type Parameters
132/// - `T`: Mask element type implementing `MaskElement` (typically i8, i16, i32, or i64)
133/// - `N`: Number of SIMD lanes, must match the SIMD vector width for the target operation
134/// 
135/// # Parameters
136/// - `mask`: Source bitmask containing validity information
137/// - `offset`: Starting bit offset within the bitmask
138/// - `len`: Maximum number of bits to consider (bounds checking)
139/// 
140/// # Returns
141/// A `Mask<T, N>` where each lane corresponds to the validity of the corresponding input element.
142/// Lanes beyond `len` are set to false for safety.
143/// 
144/// # Usage Example
145/// ```rust,ignore
146/// use simd_kernels::utils::simd_mask;
147/// 
148/// // Create 8-lane mask for conditional SIMD operations  
149/// let mask: Mask<i32, 8> = simd_mask(&bitmask, 0, 64);
150/// let result = simd_vector.select(mask, default_vector);
151/// ```
152#[inline(always)]
153pub fn simd_mask<T: MaskElement, const N: usize>(
154    mask: &Bitmask,
155    offset: usize,
156    len: usize,
157) -> Mask<T, N>
158where
159    LaneCount<N>: SupportedLaneCount,
160{
161    let mut bits = [false; N];
162    for l in 0..N {
163        let idx = offset + l;
164        bits[l] = idx < len && unsafe { mask.get_unchecked(idx) };
165    }
166    Mask::from_array(bits)
167}
168
169/// Merge two optional Bitmasks into a new output mask, computing per-row AND.
170/// Returns None if both inputs are None (output is dense).
171#[inline]
172pub fn merge_bitmasks_to_new(
173    lhs: Option<&Bitmask>,
174    rhs: Option<&Bitmask>,
175    len: usize,
176) -> Option<Bitmask> {
177    match (lhs, rhs) {
178        (None, None) => None,
179        (Some(l), None) | (None, Some(l)) => {
180            debug_assert!(l.len() >= len, "Bitmask too short in merge");
181            let mut out = Bitmask::new_set_all(len, true);
182            for i in 0..len {
183                out.set(i, l.get(i));
184            }
185            Some(out)
186        }
187        (Some(l), Some(r)) => {
188            debug_assert!(l.len() >= len, "Left Bitmask too short in merge");
189            debug_assert!(r.len() >= len, "Right Bitmask too short in merge");
190            let mut out = Bitmask::new_set_all(len, true);
191            for i in 0..len {
192                out.set(i, l.get(i) && r.get(i));
193            }
194            Some(out)
195        }
196    }
197}
198
199/// Checks the mask capacity is large enough
200/// Used so we can avoid bounds checks in the hot loop
201#[inline(always)]
202pub fn confirm_mask_capacity(cmp_len: usize, mask: Option<&Bitmask>) -> Result<(), KernelError> {
203    if let Some(m) = mask {
204        confirm_capacity("mask (Bitmask)", m.capacity(), cmp_len)?;
205    }
206    Ok(())
207}
208
/// Formats a finite float with `ryu` into `buf`, dropping a trailing `".0"`.
///
/// Used when concatenating numbers onto strings, so that `'Hello1.0'` becomes `'Hello1'`.
/// Values with a non-zero fractional part (e.g. `1.5`) are returned as formatted.
/// The returned `&str` borrows from `buf`.
///
/// `f` is expected to be finite. The underlying `write_to_ryu_buffer` routine is the one
/// `ryu` uses for finite values, and NaN/infinity are not special-cased here.
#[inline]
#[cfg(feature = "str_arithmetic")]
pub fn format_finite<F: Float>(buf: &mut [MaybeUninit<u8>; 24], f: F) -> &str {
    let capacity = buf.len();
    let start = buf.as_mut_ptr() as *mut u8;
    // SAFETY: `buf` is 24 bytes, the same size as ryu's own output buffer. ryu writes
    // `written` bytes of ASCII into it, so that prefix is initialised and valid UTF-8.
    let formatted = unsafe {
        let written = f.write_to_ryu_buffer(start);
        debug_assert!(written <= capacity);
        core::str::from_utf8_unchecked(core::slice::from_raw_parts(start, written))
    };
    formatted.strip_suffix(".0").unwrap_or(formatted)
}
230
231/// Estimate cardinality ratio on a sample from a CategoricalArray.
232/// Used to quickly figure out the optimal strategy when comparing
233/// StringArray and CategoricalArrays.
234#[inline(always)]
235pub fn estimate_categorical_cardinality(cat: &CategoricalArray<u32>, sample_size: usize) -> f64 {
236    let len = cat.data.len();
237    if len == 0 {
238        return 0.0;
239    }
240    let mut seen = HashSet::with_capacity(sample_size.min(len));
241    let step = (len / sample_size.max(1)).max(1);
242    for i in (0..len).step_by(step) {
243        let s = unsafe { cat.get_str_unchecked(i) };
244        seen.insert(s);
245        if seen.len() >= sample_size {
246            break;
247        }
248    }
249    (seen.len() as f64) / (sample_size.min(len) as f64)
250}
251
252/// Estimate cardinality ratio on a sample from a StringArray.
253/// Used to quickly figure out the optimal strategy when comparing
254/// StringArray and CategoricalArrays.
255#[inline(always)]
256pub fn estimate_string_cardinality<T: Integer>(arr: &StringArray<T>, sample_size: usize) -> f64 {
257    let len = arr.len();
258    if len == 0 {
259        return 0.0;
260    }
261    let mut seen = HashSet::with_capacity(sample_size.min(len));
262    let step = (len / sample_size.max(1)).max(1);
263    for i in (0..len).step_by(step) {
264        let s = unsafe { arr.get_str_unchecked(i) };
265        seen.insert(s);
266        if seen.len() >= sample_size {
267            break;
268        }
269    }
270    (seen.len() as f64) / (sample_size.min(len) as f64)
271}
272
273/// Validates that actual capacity matches expected capacity for kernel operations.
274/// 
275/// Essential validation function used throughout the kernel library to ensure data structure
276/// capacities are correct before performing operations. Prevents buffer overruns and ensures
277/// memory safety by catching capacity mismatches early with descriptive error messages.
278/// 
279/// # Parameters
280/// - `label`: Descriptive label for the validation context (used in error messages)
281/// - `actual`: The actual capacity of the data structure being validated
282/// - `expected`: The expected capacity required for the operation
283/// 
284/// # Returns
285/// `Ok(())` if capacities match, otherwise `KernelError::InvalidArguments` with detailed message.
286/// 
287/// # Error Conditions
288/// Returns `KernelError::InvalidArguments` when `actual != expected`, providing a clear
289/// error message indicating the mismatch and context.
290#[inline(always)]
291pub fn confirm_capacity(label: &str, actual: usize, expected: usize) -> Result<(), KernelError> {
292    if actual != expected {
293        return Err(KernelError::InvalidArguments(format!(
294            "{}: capacity mismatch (expected {}, got {})",
295            label, expected, actual
296        )));
297    }
298    Ok(())
299}
300
301/// Validates that two lengths are equal for binary kernel operations.
302/// 
303/// Critical validation function ensuring input arrays have matching lengths before performing
304/// binary operations like comparisons, arithmetic, or logical operations. Prevents undefined
305/// behaviour and provides clear error diagnostics when length mismatches occur.
306/// 
307/// # Parameters
308/// - `label`: Descriptive context label for error reporting (e.g., "compare numeric")
309/// - `a`: Length of the first input array or data structure
310/// - `b`: Length of the second input array or data structure
311/// 
312/// # Returns
313/// `Ok(())` if lengths are equal, otherwise `KernelError::LengthMismatch` with diagnostic details.
314#[inline(always)]
315pub fn confirm_equal_len(label: &str, a: usize, b: usize) -> Result<(), KernelError> {
316    if a != b {
317        return Err(KernelError::LengthMismatch(format!(
318            "{}: length mismatch (lhs: {}, rhs: {})",
319            label, a, b
320        )));
321    }
322    Ok(())
323}
324
/// Reports whether `slice` starts on a 64-byte boundary, as required for aligned SIMD loads.
///
/// An empty slice is always reported as aligned, whatever its (possibly dangling) pointer.
#[inline(always)]
pub fn is_simd_aligned<T>(slice: &[T]) -> bool {
    const SIMD_ALIGN: usize = 64;
    slice.is_empty() || (slice.as_ptr() as usize) % SIMD_ALIGN == 0
}