Skip to main content

simd_kernels/
utils.rs

1// Copyright (c) 2025 SpaceCell Enterprises Ltd
2// SPDX-License-Identifier: AGPL-3.0-or-later
3// Commercial licensing available. See LICENSE and LICENSING.md.
4
5//! # **Utility Functions** - *SIMD Processing and Memory Management Utilities*
6//!
7//! Core utilities supporting SIMD kernel implementations with efficient memory handling,
8//! bitmask operations, and performance-critical helper functions.
9
10use std::simd::{Mask, MaskElement, SimdElement};
11
12use minarrow::{Bitmask, Vec64};
13
14/// Extracts a core::SIMD `Mask<M, N>` for a batch of N lanes from a Minarrow `Bitmask`.
15///
16/// - `mask_bytes`: packed Arrow validity bits (LSB=index 0, bit=1 means valid)
17/// - `offset`: starting index (bit offset into the mask)
18/// - `logical_len`: number of logical bits in the mask
19/// - `M`: SIMD mask type (e.g., i64 for f64, i32 for f32, i8 for i8)
20///
21/// Returns: SIMD Mask<M, N> representing validity for these N lanes.
22/// Bits outside the logical length (i.e., mask is shorter than offset+N)
23/// are treated as valid.
24#[inline(always)]
25pub fn bitmask_to_simd_mask<const N: usize, M>(
26    mask_bytes: &[u8],
27    offset: usize,
28    logical_len: usize,
29) -> Mask<M, N>
30where
31    M: MaskElement + SimdElement,
32{
33    let lane_limit = (offset + N).min(logical_len);
34    let n_lanes = lane_limit - offset;
35    let mut bits: u64 = 0;
36    for j in 0..n_lanes {
37        let idx = offset + j;
38        let byte = mask_bytes[idx >> 3];
39        if ((byte >> (idx & 7)) & 1) != 0 {
40            bits |= 1u64 << j;
41        }
42    }
43    if n_lanes < N {
44        bits |= !0u64 << n_lanes;
45    }
46    Mask::<M, N>::from_bitmask(bits)
47}
48
49/// Converts a SIMD `Mask<M, N>` to a Minarrow `Bitmask` for the given logical length.
50/// Used at the end of a block operation within SIMD-accelerated kernel functions.
51#[inline(always)]
52pub fn simd_mask_to_bitmask<const N: usize, M>(mask: Mask<M, N>, len: usize) -> Bitmask
53where
54    M: MaskElement + SimdElement,
55{
56    let mut bits = Vec64::with_capacity((len + 7) / 8);
57    bits.resize((len + 7) / 8, 0);
58
59    let word = mask.to_bitmask();
60    let bytes = word.to_le_bytes();
61
62    let n_bytes = (len + 7) / 8;
63    bits[..n_bytes].copy_from_slice(&bytes[..n_bytes]);
64
65    if len % 8 != 0 {
66        let last = n_bytes - 1;
67        let mask_byte = (1u8 << (len % 8)) - 1;
68        bits[last] &= mask_byte;
69    }
70
71    Bitmask {
72        bits: bits.into(),
73        len,
74    }
75}
76
77/// Bulk-ORs a local bitmask block (from a SIMD mask or similar) into the global Minarrow bitmask at the correct byte offset.
78/// The block (`block_mask`) is expected to contain at least ceil(n_lanes/8) bytes,
79/// with the bit-packed validity bits starting from position 0.
80///
81/// Used to streamline repetitive boilerplate and ensure consistency across kernel null-mask handling.
82///
83/// ### Parameters
84/// - `out_mask`: mutable reference to the output/global Bitmask
85/// - `block_mask`: reference to the local Bitmask containing the block's bits
86/// - `offset`: starting bit offset in the global mask
87/// - `n_lanes`: number of bits in this block (usually SIMD lane count)
88#[inline(always)]
89pub fn write_global_bitmask_block(
90    out_mask: &mut Bitmask,
91    block_mask: &Bitmask,
92    offset: usize,
93    n_lanes: usize,
94) {
95    let n_bytes = (n_lanes + 7) / 8;
96    let base = offset / 8;
97    let bit_off = offset % 8;
98
99    if bit_off == 0 {
100        for b in 0..n_bytes {
101            if base + b < out_mask.bits.len() {
102                out_mask.bits[base + b] |= block_mask.bits[b];
103            }
104        }
105    } else {
106        for b in 0..n_bytes {
107            let src = block_mask.bits[b];
108            if base + b < out_mask.bits.len() {
109                out_mask.bits[base + b] |= src << bit_off;
110            }
111            if base + b + 1 < out_mask.bits.len() {
112                out_mask.bits[base + b + 1] |= src >> (8 - bit_off);
113            }
114        }
115    }
116}
117
118/// Determines whether nulls are present given an optional null count and mask reference.
119/// Avoids computing mask cardinality to preserve performance guarantees.
120#[inline(always)]
121pub fn has_nulls(null_count: Option<usize>, mask: Option<&Bitmask>) -> bool {
122    match null_count {
123        Some(n) => n > 0,
124        None => mask.is_some(),
125    }
126}