#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// Copies `input` to `output`, substituting `value` wherever `mask` is non-zero (`f32`).
///
/// For every index `i < len`, `output[i]` becomes `value` when `mask[i] != 0` and
/// `input[i]` otherwise. Full groups of 16 elements are blended with AVX-512; any
/// remaining elements are handled one at a time.
///
/// # Safety
///
/// * The executing CPU must support the `avx512f` target feature.
/// * `input` and `mask` must be valid for reads of `len` elements.
/// * `output` must be valid for writes of `len` elements.
#[target_feature(enable = "avx512f")]
pub unsafe fn masked_fill_f32(
    input: *const f32,
    mask: *const u8,
    output: *mut f32,
    len: usize,
    value: f32,
) {
    const WIDTH: usize = 16;

    let broadcast = _mm512_set1_ps(value);
    // Largest multiple of WIDTH not exceeding `len`; everything before it is vectorised.
    let vector_end = len - len % WIDTH;

    let mut pos = 0usize;
    while pos < vector_end {
        let lanes = _mm512_loadu_ps(input.add(pos));
        let select = build_mask16(mask.add(pos));
        // A set mask bit picks the broadcast fill value, a clear bit keeps the input lane.
        _mm512_storeu_ps(output.add(pos), _mm512_mask_blend_ps(select, lanes, broadcast));
        pos += WIDTH;
    }

    // Scalar tail for the final `len % WIDTH` elements.
    while pos < len {
        let picked = match *mask.add(pos) {
            0 => *input.add(pos),
            _ => value,
        };
        *output.add(pos) = picked;
        pos += 1;
    }
}
/// Copies `input` to `output`, substituting `value` wherever `mask` is non-zero (`f64`).
///
/// For every index `i < len`, `output[i]` becomes `value` when `mask[i] != 0` and
/// `input[i]` otherwise. Full groups of 8 elements are blended with AVX-512; any
/// remaining elements are handled one at a time.
///
/// # Safety
///
/// * The executing CPU must support the `avx512f` target feature.
/// * `input` and `mask` must be valid for reads of `len` elements.
/// * `output` must be valid for writes of `len` elements.
#[target_feature(enable = "avx512f")]
pub unsafe fn masked_fill_f64(
    input: *const f64,
    mask: *const u8,
    output: *mut f64,
    len: usize,
    value: f64,
) {
    const WIDTH: usize = 8;

    let broadcast = _mm512_set1_pd(value);
    // Largest multiple of WIDTH not exceeding `len`; everything before it is vectorised.
    let vector_end = len - len % WIDTH;

    let mut pos = 0usize;
    while pos < vector_end {
        let lanes = _mm512_loadu_pd(input.add(pos));
        let select = build_mask8(mask.add(pos));
        // A set mask bit picks the broadcast fill value, a clear bit keeps the input lane.
        _mm512_storeu_pd(output.add(pos), _mm512_mask_blend_pd(select, lanes, broadcast));
        pos += WIDTH;
    }

    // Scalar tail for the final `len % WIDTH` elements.
    while pos < len {
        let picked = match *mask.add(pos) {
            0 => *input.add(pos),
            _ => value,
        };
        *output.add(pos) = picked;
        pos += 1;
    }
}
/// Packs the `f32` elements of `input` whose mask byte is non-zero into `output`.
///
/// Selected elements are written contiguously from the start of `output`, in their
/// original order. Returns how many elements were written.
///
/// # Safety
///
/// * The executing CPU must support the `avx512f` target feature.
/// * `input` and `mask` must be valid for reads of `len` elements.
/// * `output` must be valid for writes of one element per non-zero mask byte
///   (at most `len` elements).
#[target_feature(enable = "avx512f")]
pub unsafe fn masked_select_f32(
    input: *const f32,
    mask: *const u8,
    output: *mut f32,
    len: usize,
) -> usize {
    const WIDTH: usize = 16;

    // Largest multiple of WIDTH not exceeding `len`.
    let vector_end = len - len % WIDTH;
    let mut written = 0usize;
    let mut pos = 0usize;

    while pos < vector_end {
        let lanes = _mm512_loadu_ps(input.add(pos));
        let keep = build_mask16(mask.add(pos));
        // Skip the store entirely when no lane in this group is selected.
        if keep != 0 {
            _mm512_mask_compressstoreu_ps(output.add(written) as *mut _, keep, lanes);
            written += keep.count_ones() as usize;
        }
        pos += WIDTH;
    }

    // Scalar tail for the final `len % WIDTH` elements.
    while pos < len {
        if *mask.add(pos) != 0 {
            *output.add(written) = *input.add(pos);
            written += 1;
        }
        pos += 1;
    }

    written
}
/// Packs the `f64` elements of `input` whose mask byte is non-zero into `output`.
///
/// Selected elements are written contiguously from the start of `output`, in their
/// original order. Returns how many elements were written.
///
/// # Safety
///
/// * The executing CPU must support the `avx512f` target feature.
/// * `input` and `mask` must be valid for reads of `len` elements.
/// * `output` must be valid for writes of one element per non-zero mask byte
///   (at most `len` elements).
#[target_feature(enable = "avx512f")]
pub unsafe fn masked_select_f64(
    input: *const f64,
    mask: *const u8,
    output: *mut f64,
    len: usize,
) -> usize {
    const WIDTH: usize = 8;

    // Largest multiple of WIDTH not exceeding `len`.
    let vector_end = len - len % WIDTH;
    let mut written = 0usize;
    let mut pos = 0usize;

    while pos < vector_end {
        let lanes = _mm512_loadu_pd(input.add(pos));
        let keep = build_mask8(mask.add(pos));
        // Skip the store entirely when no lane in this group is selected.
        if keep != 0 {
            _mm512_mask_compressstoreu_pd(output.add(written) as *mut _, keep, lanes);
            written += keep.count_ones() as usize;
        }
        pos += WIDTH;
    }

    // Scalar tail for the final `len % WIDTH` elements.
    while pos < len {
        if *mask.add(pos) != 0 {
            *output.add(written) = *input.add(pos);
            written += 1;
        }
        pos += 1;
    }

    written
}
/// Counts the non-zero bytes among the first `len` bytes of `mask`.
///
/// Full groups of 64 bytes are compared against zero with a single AVX-512BW
/// instruction and the resulting bitmask is popcounted; the remaining bytes are
/// checked one at a time.
///
/// # Safety
///
/// * The executing CPU must support the `avx512f` and `avx512bw` target features.
/// * `mask` must be valid for reads of `len` bytes.
#[target_feature(enable = "avx512f", enable = "avx512bw")]
pub unsafe fn masked_count(mask: *const u8, len: usize) -> usize {
    const WIDTH: usize = 64;

    // Largest multiple of WIDTH not exceeding `len`.
    let vector_end = len - len % WIDTH;
    let zeros = _mm512_setzero_si512();
    let mut total = 0usize;
    let mut pos = 0usize;

    while pos < vector_end {
        let bytes = _mm512_loadu_si512(mask.add(pos) as *const _);
        // One bit per byte lane, set where the byte differs from zero.
        let nonzero: __mmask64 = _mm512_cmpneq_epi8_mask(bytes, zeros);
        total += nonzero.count_ones() as usize;
        pos += WIDTH;
    }

    // Scalar tail for the final `len % WIDTH` bytes.
    while pos < len {
        total += usize::from(*mask.add(pos) != 0);
        pos += 1;
    }

    total
}
/// Converts 16 consecutive mask bytes into a 16-bit AVX-512 lane mask.
///
/// Bit `j` of the result is set exactly when `mask[j] != 0`.
///
/// The bytes are fetched with one unaligned 128-bit load, zero-extended to sixteen
/// 32-bit lanes and compared against zero, so the mask is produced without a
/// per-byte scalar loop inside the callers' SIMD loops.
///
/// # Safety
///
/// * The executing CPU must support the `avx512f` target feature.
/// * `mask` must be valid for reads of 16 bytes (no alignment requirement).
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn build_mask16(mask: *const u8) -> __mmask16 {
    let bytes = _mm_loadu_si128(mask as *const __m128i);
    // Zero-extension keeps "non-zero byte" equivalent to "non-zero 32-bit lane".
    let widened = _mm512_cvtepu8_epi32(bytes);
    _mm512_cmpneq_epi32_mask(widened, _mm512_setzero_si512())
}
/// Converts 8 consecutive mask bytes into an 8-bit AVX-512 lane mask.
///
/// Bit `j` of the result is set exactly when `mask[j] != 0`.
///
/// The bytes are fetched with one unaligned 64-bit load, zero-extended to eight
/// 64-bit lanes and compared against zero, so the mask is produced without a
/// per-byte scalar loop inside the callers' SIMD loops.
///
/// # Safety
///
/// * The executing CPU must support the `avx512f` target feature.
/// * `mask` must be valid for reads of 8 bytes (no alignment requirement).
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn build_mask8(mask: *const u8) -> __mmask8 {
    // Loads only the low 64 bits; the upper half of the register is zeroed.
    let bytes = _mm_loadl_epi64(mask as *const __m128i);
    // Zero-extension keeps "non-zero byte" equivalent to "non-zero 64-bit lane".
    let widened = _mm512_cvtepu8_epi64(bytes);
    _mm512_cmpneq_epi64_mask(widened, _mm512_setzero_si512())
}
#[cfg(test)]
mod tests {
    //! Tests for the AVX-512 masked kernels.
    //!
    //! Each test returns early when the running CPU lacks the required target
    //! features. Lengths are chosen on both sides of the vector width so that the
    //! vectorised loop, the scalar tail, and the empty input are all exercised.

    use super::*;

    /// Lengths around the 16-lane `f32` vector width, including empty and ragged inputs.
    const F32_LENGTHS: [usize; 7] = [0, 1, 15, 16, 17, 64, 67];
    /// Lengths around the 64-byte vector width used by `masked_count`.
    const BYTE_LENGTHS: [usize; 7] = [0, 1, 63, 64, 65, 256, 259];

    fn has_avx512() -> bool {
        is_x86_feature_detected!("avx512f")
    }

    fn has_avx512bw() -> bool {
        is_x86_feature_detected!("avx512bw")
    }

    #[test]
    fn test_masked_fill_f32_avx512() {
        if !has_avx512() {
            return;
        }
        let fill_value = -1.0f32;
        for &len in &F32_LENGTHS {
            let input: Vec<f32> = (0..len).map(|i| i as f32).collect();
            let mask: Vec<u8> = (0..len).map(|i| u8::from(i % 2 == 0)).collect();
            let mut output = vec![0.0f32; len];
            // SAFETY: AVX-512F was detected above and all three buffers hold `len` elements.
            unsafe {
                masked_fill_f32(
                    input.as_ptr(),
                    mask.as_ptr(),
                    output.as_mut_ptr(),
                    len,
                    fill_value,
                );
            }
            for (i, &got) in output.iter().enumerate() {
                let expected = if i % 2 == 0 { fill_value } else { i as f32 };
                assert_eq!(got, expected, "mismatch at index {} (len {})", i, len);
            }
        }
    }

    #[test]
    fn test_masked_select_f32_avx512() {
        if !has_avx512() {
            return;
        }
        for &len in &F32_LENGTHS {
            let input: Vec<f32> = (0..len).map(|i| i as f32).collect();
            let mask: Vec<u8> = (0..len).map(|i| u8::from(i % 3 == 0)).collect();
            let mut output = vec![0.0f32; len];
            // SAFETY: AVX-512F was detected above; `output` has room for `len` elements,
            // which bounds the number of selected elements.
            let count = unsafe {
                masked_select_f32(input.as_ptr(), mask.as_ptr(), output.as_mut_ptr(), len)
            };
            let expected: Vec<f32> = (0..len)
                .filter(|i| i % 3 == 0)
                .map(|i| i as f32)
                .collect();
            assert_eq!(count, expected.len(), "wrong count for len {}", len);
            assert_eq!(
                &output[..count],
                expected.as_slice(),
                "wrong selection for len {}",
                len
            );
        }
    }

    #[test]
    fn test_masked_count_avx512() {
        if !has_avx512() || !has_avx512bw() {
            return;
        }
        for &len in &BYTE_LENGTHS {
            let mask: Vec<u8> = (0..len).map(|i| u8::from(i % 5 == 0)).collect();
            // SAFETY: AVX-512F and AVX-512BW were detected above and `mask` holds `len` bytes.
            let count = unsafe { masked_count(mask.as_ptr(), len) };
            let expected = (0..len).filter(|i| i % 5 == 0).count();
            assert_eq!(count, expected, "wrong count for len {}", len);
        }
    }
}