#![allow(unsafe_op_in_unsafe_fn)]
use arrow::bitmap::Bitmap;
use bytemuck::Pod;
use polars_utils::slice::load_padded_le_u64;
/// Compacts the elements of `v` whose bit is set in `m` into `out`, in
/// ascending index order (bit `i` of `m` selects `v[i]`).
///
/// The loop body is manually unrolled twice, and the second half runs
/// unconditionally. If `m` has an odd number of set bits, the second half of
/// the final iteration sees `m == 0`. `trailing_zeros()` is then 64, the
/// `% 64` turns that into index 0, and `v[0]` is stored one slot past the last
/// selected element.
///
/// # Safety
/// - For every set bit `i` of `m`, `v[i]` must be in bounds. When `m != 0`,
///   this also makes `v` non-empty, which the spill read of `v[0]` relies on.
/// - `out` must be valid for `m.count_ones() + 1` writes of `T`.
unsafe fn scalar_sparse_filter64<T: Pod>(v: &[T], mut m: u64, out: *mut T) {
    let mut written = 0usize;
    while m > 0 {
        let idx = m.trailing_zeros() as usize;
        *out.add(written) = *v.get_unchecked(idx);
        // Clear the lowest set bit and advance the output cursor.
        m &= m.wrapping_sub(1); written += 1;
        // `% 64` keeps the index in range when the line above cleared the last
        // set bit (trailing_zeros of 0 is 64); see the spill note above.
        let idx = (m.trailing_zeros() % 64) as usize;
        *out.add(written) = *v.get_unchecked(idx);
        // If `m` is already 0, `0 & u64::MAX` keeps it 0 and the loop exits.
        m &= m.wrapping_sub(1); written += 1;
    }
}
/// Branchless compaction of one 64-element chunk. Every element of `v[..64]`
/// is stored at `out.add(written)`, but `written` only advances when the
/// element's bit in `m` is set. An unselected element is therefore overwritten
/// by the next store.
///
/// # Safety
/// - `v.len() >= 64` must hold, because 64 elements are read through a raw
///   pointer without bounds checks.
/// - `out` must be valid for `m.count_ones() + 1` writes of `T`. The slot at
///   `out.add(m.count_ones())` may receive an unselected element.
unsafe fn scalar_dense_filter64<T: Pod>(v: &[T], mut m: u64, out: *mut T) {
    // NOTE(review): the raw-pointer `src` plus `out.add(written)` form looks
    // chosen for codegen rather than clarity; benchmark before rewriting it.
    let mut written = 0usize;
    let mut src = v.as_ptr();
    // 16 groups of 4 bits cover all 64 bits of `m`.
    for _ in 0..16 {
        for i in 0..4 {
            // Store unconditionally, then advance only if bit `i` is set.
            *out.add(written) = *src;
            written += ((m >> i) & 1) as usize;
            src = src.add(1);
        }
        // Shift the next 4 mask bits down into positions 0..4.
        m >>= 4;
    }
}
/// Applies the leading, not byte-aligned part of `mask` so that the remaining
/// filtering can work on byte-aligned mask data.
///
/// `mask.as_slice()` yields `(bytes, offset, len)`. When `offset > 0`, bit
/// `offset` of `bytes[0]` selects `values[0]`, bit `offset + 1` selects
/// `values[1]`, and so on up to bit 7. At most `len` of those bits are applied
/// here, and the partially used first byte is then dropped.
///
/// Returns `(values, mask_bytes, out)`:
/// - the values still to be filtered;
/// - the byte-aligned mask bytes for them, where bit 0 of `mask_bytes[0]`
///   selects the first returned value;
/// - the output pointer, advanced past the elements kept so far.
///
/// # Panics
/// Panics if `values.len() != mask.len()`.
///
/// # Safety
/// `out` must not overlap `values` and must be valid for `1 + k` writes of `T`,
/// where `k` is the number of set bits consumed here (at most the number of set
/// bits in `mask`). Each consumed value is stored unconditionally, so the slot
/// just past the last kept element may be overwritten.
pub unsafe fn scalar_filter_offset<'a, T: Pod>(
    values: &'a [T],
    mask: &'a Bitmap,
    mut out: *mut T,
) -> (&'a [T], &'a [u8], *mut T) {
    assert_eq!(values.len(), mask.len());

    let (mut mask_bytes, offset, len) = mask.as_slice();
    let mut value_idx = 0;
    if offset > 0 {
        let first_byte = mask_bytes[0];
        mask_bytes = &mask_bytes[1..];

        // `take(len)` stops early for bitmaps shorter than the rest of the
        // first byte; `zip(values)` keeps every read in bounds.
        for (bit_idx, &value) in (offset..8).take(len).zip(values) {
            let bit_is_set = ((first_byte >> bit_idx) & 1) != 0;
            // SAFETY: the caller guarantees room for one write past every kept
            // element, and `out` only advances when the bit is set.
            unsafe {
                *out = value;
                out = out.add(bit_is_set as usize);
            }
            value_idx += 1;
        }
    }

    (&values[value_idx..], mask_bytes, out)
}
/// Filters `values` by a byte-aligned bitmask, writing the kept elements
/// contiguously to `out`.
///
/// `values[i]` is kept iff bit `i % 8` of `mask_bytes[i / 8]` is set. Mask
/// bits at positions `>= values.len()` are ignored.
///
/// # Panics
/// Panics if `mask_bytes` holds fewer than `values.len()` bits.
///
/// # Safety
/// `out` must not overlap `values` and must be valid for `1 + k` writes of `T`,
/// where `k` is the number of set bits among the first `values.len()` bits of
/// `mask_bytes`. The filter kernels may store one element past the last kept
/// one, so the slot at `out.add(k)` may be overwritten.
pub unsafe fn scalar_filter<T: Pod>(values: &[T], mask_bytes: &[u8], mut out: *mut T) {
    assert!(mask_bytes.len() * 8 >= values.len());

    // Whole 64-element chunks are each described by exactly 8 mask bytes. The
    // assert above guarantees `mask_bytes` has at least `8 * n_chunks` bytes.
    let n_chunks = values.len() / 64;
    let (bulk_values, rest_values) = values.split_at(64 * n_chunks);
    let (bulk_mask, rest_mask) = mask_bytes.split_at(8 * n_chunks);

    let chunks = bulk_values.chunks_exact(64).zip(bulk_mask.chunks_exact(8));
    for (value_chunk, mask_chunk) in chunks {
        let mut word = [0u8; 8];
        word.copy_from_slice(mask_chunk);
        let m = u64::from_le_bytes(word);

        // Fast path: nothing selected in this chunk.
        if m == 0 {
            continue;
        }

        // SAFETY: `value_chunk` has exactly 64 elements, so every bit of `m`
        // indexes in bounds. Each branch writes at most `m.count_ones() + 1`
        // elements at `out`, and `out` then advances by the popcount only.
        // The total therefore stays within the caller's `1 + k` budget.
        unsafe {
            // Fast path: every element selected.
            if m == u64::MAX {
                core::ptr::copy_nonoverlapping(value_chunk.as_ptr(), out, 64);
                out = out.add(64);
                continue;
            }

            let m_popcnt = m.count_ones();
            if m_popcnt <= 16 {
                scalar_sparse_filter64(value_chunk, m, out);
            } else {
                scalar_dense_filter64(value_chunk, m, out);
            }
            out = out.add(m_popcnt as usize);
        }
    }

    if !rest_values.is_empty() {
        let rest_len = rest_values.len();
        debug_assert!(rest_len < 64);
        // Mask off bits past the end of `values`.
        let m = load_padded_le_u64(rest_mask) & ((1u64 << rest_len) - 1);
        // SAFETY: every set bit of `m` is `< rest_len`, so it indexes
        // `rest_values` in bounds; at most `popcnt + 1` writes remain in budget.
        unsafe {
            scalar_sparse_filter64(rest_values, m, out);
        }
    }
}