simd_kernels/utils.rs
1// Copyright (c) 2025 SpaceCell Enterprises Ltd
2// SPDX-License-Identifier: AGPL-3.0-or-later
3// Commercial licensing available. See LICENSE and LICENSING.md.
4
5//! # **Utility Functions** - *SIMD Processing and Memory Management Utilities*
6//!
7//! Core utilities supporting SIMD kernel implementations with efficient memory handling,
8//! bitmask operations, and performance-critical helper functions.
9
10use std::simd::{Mask, MaskElement, SimdElement};
11
12use minarrow::{Bitmask, Vec64};
13
14/// Extracts a core::SIMD `Mask<M, N>` for a batch of N lanes from a Minarrow `Bitmask`.
15///
16/// - `mask_bytes`: packed Arrow validity bits (LSB=index 0, bit=1 means valid)
17/// - `offset`: starting index (bit offset into the mask)
18/// - `logical_len`: number of logical bits in the mask
19/// - `M`: SIMD mask type (e.g., i64 for f64, i32 for f32, i8 for i8)
20///
21/// Returns: SIMD Mask<M, N> representing validity for these N lanes.
22/// Bits outside the logical length (i.e., mask is shorter than offset+N)
23/// are treated as valid.
24#[inline(always)]
25pub fn bitmask_to_simd_mask<const N: usize, M>(
26 mask_bytes: &[u8],
27 offset: usize,
28 logical_len: usize,
29) -> Mask<M, N>
30where
31 M: MaskElement + SimdElement,
32{
33 let lane_limit = (offset + N).min(logical_len);
34 let n_lanes = lane_limit - offset;
35 let mut bits: u64 = 0;
36 for j in 0..n_lanes {
37 let idx = offset + j;
38 let byte = mask_bytes[idx >> 3];
39 if ((byte >> (idx & 7)) & 1) != 0 {
40 bits |= 1u64 << j;
41 }
42 }
43 if n_lanes < N {
44 bits |= !0u64 << n_lanes;
45 }
46 Mask::<M, N>::from_bitmask(bits)
47}
48
49/// Converts a SIMD `Mask<M, N>` to a Minarrow `Bitmask` for the given logical length.
50/// Used at the end of a block operation within SIMD-accelerated kernel functions.
51#[inline(always)]
52pub fn simd_mask_to_bitmask<const N: usize, M>(mask: Mask<M, N>, len: usize) -> Bitmask
53where
54 M: MaskElement + SimdElement,
55{
56 let mut bits = Vec64::with_capacity((len + 7) / 8);
57 bits.resize((len + 7) / 8, 0);
58
59 let word = mask.to_bitmask();
60 let bytes = word.to_le_bytes();
61
62 let n_bytes = (len + 7) / 8;
63 bits[..n_bytes].copy_from_slice(&bytes[..n_bytes]);
64
65 if len % 8 != 0 {
66 let last = n_bytes - 1;
67 let mask_byte = (1u8 << (len % 8)) - 1;
68 bits[last] &= mask_byte;
69 }
70
71 Bitmask {
72 bits: bits.into(),
73 len,
74 }
75}
76
77/// Bulk-ORs a local bitmask block (from a SIMD mask or similar) into the global Minarrow bitmask at the correct byte offset.
78/// The block (`block_mask`) is expected to contain at least ceil(n_lanes/8) bytes,
79/// with the bit-packed validity bits starting from position 0.
80///
81/// Used to streamline repetitive boilerplate and ensure consistency across kernel null-mask handling.
82///
83/// ### Parameters
84/// - `out_mask`: mutable reference to the output/global Bitmask
85/// - `block_mask`: reference to the local Bitmask containing the block's bits
86/// - `offset`: starting bit offset in the global mask
87/// - `n_lanes`: number of bits in this block (usually SIMD lane count)
88#[inline(always)]
89pub fn write_global_bitmask_block(
90 out_mask: &mut Bitmask,
91 block_mask: &Bitmask,
92 offset: usize,
93 n_lanes: usize,
94) {
95 let n_bytes = (n_lanes + 7) / 8;
96 let base = offset / 8;
97 let bit_off = offset % 8;
98
99 if bit_off == 0 {
100 for b in 0..n_bytes {
101 if base + b < out_mask.bits.len() {
102 out_mask.bits[base + b] |= block_mask.bits[b];
103 }
104 }
105 } else {
106 for b in 0..n_bytes {
107 let src = block_mask.bits[b];
108 if base + b < out_mask.bits.len() {
109 out_mask.bits[base + b] |= src << bit_off;
110 }
111 if base + b + 1 < out_mask.bits.len() {
112 out_mask.bits[base + b + 1] |= src >> (8 - bit_off);
113 }
114 }
115 }
116}
117
118/// Determines whether nulls are present given an optional null count and mask reference.
119/// Avoids computing mask cardinality to preserve performance guarantees.
120#[inline(always)]
121pub fn has_nulls(null_count: Option<usize>, mask: Option<&Bitmask>) -> bool {
122 match null_count {
123 Some(n) => n > 0,
124 None => mask.is_some(),
125 }
126}