// polars-compute 0.53.0
//
// Private compute kernels for the Polars DataFrame library.
// See the crate documentation for details.
use arrow::bitmap::Bitmap;
use bytemuck::{Pod, cast_slice, cast_vec};
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
use polars_utils::cpuid::is_avx512_enabled;

#[cfg(all(target_arch = "x86_64", feature = "simd"))]
use super::avx512;
use super::boolean::filter_boolean_kernel;
use super::scalar::{scalar_filter, scalar_filter_offset};

/// Signature shared by the bulk filter stages.
///
/// A stage receives the remaining input values, the remaining mask bytes and the current output
/// write pointer. It returns the same three cursors after it has done its work, so the caller can
/// pass whatever is left on to the next stage. It is an `unsafe fn` presumably because it writes
/// through the raw output pointer; callers must uphold the kernel's contract.
type FilterFn<T> =
    for<'s> unsafe fn(&'s [T], &'s [u8], *mut T) -> (&'s [T], &'s [u8], *mut T);

/// Pass-through bulk stage used when no SIMD kernel is available.
///
/// Consumes nothing: the input slice, mask bytes and output pointer are returned unchanged, so
/// the following scalar stage processes all remaining elements.
fn nop_filter<'s, T>(
    values: &'s [T],
    mask: &'s [u8],
    out: *mut T,
) -> (&'s [T], &'s [u8], *mut T)
where
    T: Pod,
{
    (values, mask, out)
}

/// Returns the elements of `values` whose corresponding bit in `mask` is set.
///
/// If `T` has the same size *and* alignment as `u8`, `u16`, `u32` or `u64`, the values are
/// reinterpreted as that primitive so a width-specialized kernel can be selected. All other
/// layouts use the generic scalar path.
///
/// # Panics
/// Panics if `values.len() != mask.len()`.
pub fn filter_values<T: Pod>(values: &[T], mask: &Bitmap) -> Vec<T> {
    let size = size_of::<T>();
    let align = align_of::<T>();
    // Compare against each primitive's *actual* alignment, not its size. On some 32-bit targets
    // (e.g. i686 Linux) `u64` is only 4-byte aligned. If a size-8/align-8 type took the u64
    // branch there, `cast_vec` would panic because the source and destination alignments differ.
    match size {
        1 if align == align_of::<u8>() => cast_vec(filter_values_u8(cast_slice(values), mask)),
        2 if align == align_of::<u16>() => cast_vec(filter_values_u16(cast_slice(values), mask)),
        4 if align == align_of::<u32>() => cast_vec(filter_values_u32(cast_slice(values), mask)),
        8 if align == align_of::<u64>() => cast_vec(filter_values_u64(cast_slice(values), mask)),
        _ => filter_values_generic(values, mask, 1, nop_filter),
    }
}

/// Filters a `u8` slice by `mask`.
///
/// On x86_64 builds with the `simd` feature, the AVX-512 VBMI2 kernel is used when
/// `is_avx512_enabled()` is true and the CPU reports `avx512vbmi2`. Otherwise the generic path
/// with the pass-through bulk stage is used.
fn filter_values_u8(values: &[u8], mask: &Bitmap) -> Vec<u8> {
    #[cfg(all(target_arch = "x86_64", feature = "simd"))]
    {
        let vbmi2_available =
            is_avx512_enabled() && std::arch::is_x86_feature_detected!("avx512vbmi2");
        if vbmi2_available {
            // 64 is the number of u8 lanes in a 512-bit vector. It is reserved as extra output
            // capacity, presumably as slack for full-vector stores (confirm in `avx512`).
            return filter_values_generic(values, mask, 64, avx512::filter_u8_avx512vbmi2);
        }
    }

    filter_values_generic(values, mask, 1, nop_filter)
}

/// Filters a `u16` slice by `mask`.
///
/// On x86_64 builds with the `simd` feature, the AVX-512 VBMI2 kernel is used when
/// `is_avx512_enabled()` is true and the CPU reports `avx512vbmi2`. Otherwise the generic path
/// with the pass-through bulk stage is used.
fn filter_values_u16(values: &[u16], mask: &Bitmap) -> Vec<u16> {
    #[cfg(all(target_arch = "x86_64", feature = "simd"))]
    {
        let vbmi2_available =
            is_avx512_enabled() && std::arch::is_x86_feature_detected!("avx512vbmi2");
        if vbmi2_available {
            // 32 is the number of u16 lanes in a 512-bit vector. It is reserved as extra output
            // capacity, presumably as slack for full-vector stores (confirm in `avx512`).
            return filter_values_generic(values, mask, 32, avx512::filter_u16_avx512vbmi2);
        }
    }

    filter_values_generic(values, mask, 1, nop_filter)
}

/// Filters a `u32` slice by `mask`.
///
/// On x86_64 builds with the `simd` feature, the AVX-512F kernel is used when
/// `is_avx512_enabled()` is true. Otherwise the generic path with the pass-through bulk stage is
/// used.
fn filter_values_u32(values: &[u32], mask: &Bitmap) -> Vec<u32> {
    #[cfg(all(target_arch = "x86_64", feature = "simd"))]
    {
        let avx512_available = is_avx512_enabled();
        if avx512_available {
            // 16 is the number of u32 lanes in a 512-bit vector. It is reserved as extra output
            // capacity, presumably as slack for full-vector stores (confirm in `avx512`).
            return filter_values_generic(values, mask, 16, avx512::filter_u32_avx512f);
        }
    }

    filter_values_generic(values, mask, 1, nop_filter)
}

/// Filters a `u64` slice by `mask`.
///
/// On x86_64 builds with the `simd` feature, the AVX-512F kernel is used when
/// `is_avx512_enabled()` is true. Otherwise the generic path with the pass-through bulk stage is
/// used.
fn filter_values_u64(values: &[u64], mask: &Bitmap) -> Vec<u64> {
    #[cfg(all(target_arch = "x86_64", feature = "simd"))]
    {
        let avx512_available = is_avx512_enabled();
        if avx512_available {
            // 8 is the number of u64 lanes in a 512-bit vector. It is reserved as extra output
            // capacity, presumably as slack for full-vector stores (confirm in `avx512`).
            return filter_values_generic(values, mask, 8, avx512::filter_u64_avx512f);
        }
    }

    filter_values_generic(values, mask, 1, nop_filter)
}

/// Core filter driver: builds a vector of the elements of `values` selected by `mask`.
///
/// The output length is exactly `mask.set_bits()`. The work runs as three stages, each passing
/// the remaining values, remaining mask bytes and output write pointer to the next:
/// 1. `scalar_filter_offset` takes the `Bitmap` and yields raw mask bytes for the later
///    stages. It presumably handles any leading bit offset of the bitmap (confirm).
/// 2. `bulk_filter` is a width-specialized kernel that may consume a prefix of what is left.
///    `nop_filter` consumes nothing.
/// 3. `scalar_filter` processes whatever remains.
///
/// `pad` is extra output capacity reserved beyond the number of kept elements. Presumably
/// `bulk_filter` may write up to `pad` elements past the last kept one; confirm against the
/// kernel contracts.
///
/// # Panics
/// Panics if `values.len() != mask.len()`.
fn filter_values_generic<T: Pod>(
    values: &[T],
    mask: &Bitmap,
    pad: usize,
    bulk_filter: FilterFn<T>,
) -> Vec<T> {
    assert_eq!(values.len(), mask.len());
    let mask_bits_set = mask.set_bits();
    // Room for every kept element plus `pad` elements of slack for the bulk stage.
    let mut out = Vec::with_capacity(mask_bits_set + pad);
    // SAFETY (assumed, not checked here): the three stages together initialize exactly
    // `mask_bits_set` elements starting at `out.as_mut_ptr()`, and never write past
    // `mask_bits_set + pad` elements. Under that assumption, `set_len` only exposes initialized
    // memory. `set_len` must remain the last step.
    unsafe {
        let (values, mask_bytes, out_ptr) = scalar_filter_offset(values, mask, out.as_mut_ptr());
        let (values, mask_bytes, out_ptr) = bulk_filter(values, mask_bytes, out_ptr);
        scalar_filter(values, mask_bytes, out_ptr);
        out.set_len(mask_bits_set);
    }
    out
}

/// Filters `values` by `mask` and, if a validity bitmap is given, filters it by the same mask.
///
/// Returns the filtered values paired with the filtered validity (`None` when `validity` is
/// `None`).
pub fn filter_values_and_validity<T: Pod>(
    values: &[T],
    validity: Option<&Bitmap>,
    mask: &Bitmap,
) -> (Vec<T>, Option<Bitmap>) {
    let kept_values = filter_values(values, mask);
    let kept_validity = validity.map(|bitmap| filter_boolean_kernel(bitmap, mask));
    (kept_values, kept_validity)
}