// AVX2 kernels for masked fill / select / count over raw pointers. Every
// function here uses x86_64 intrinsics, so the whole module is arch-gated.
#![cfg(target_arch = "x86_64")]

use std::arch::x86_64::*;

/// Writes `output[i] = if mask[i] != 0 { value } else { input[i] }` for `i < len`.
///
/// # Safety
/// The caller must ensure AVX2 is available and that `input`, `mask`, and
/// `output` are valid for reads/writes of `len` elements.
#[target_feature(enable = "avx2")]
pub unsafe fn masked_fill_f32(
    input: *const f32,
    mask: *const u8,
    output: *mut f32,
    len: usize,
    value: f32,
) {
    const LANES: usize = 8; // f32 lanes in a 256-bit vector
    let chunks = len / LANES;
    let fill_vec = _mm256_set1_ps(value);
    for i in 0..chunks {
        let offset = i * LANES;
        let input_vec = _mm256_loadu_ps(input.add(offset));
        // Expand 8 mask bytes into 8 all-ones/all-zeros f32 lanes.
        let mask_vec = expand_mask_u8_to_f32_avx2(mask.add(offset));
        // blendv picks `fill_vec` wherever the mask lane's sign bit is set.
        let result = _mm256_blendv_ps(input_vec, fill_vec, mask_vec);
        _mm256_storeu_ps(output.add(offset), result);
    }
    // Scalar tail for the remaining `len % LANES` elements.
    let start = chunks * LANES;
    for i in start..len {
        *output.add(i) = if *mask.add(i) != 0 {
            value
        } else {
            *input.add(i)
        };
    }
}

/// Writes `output[i] = if mask[i] != 0 { value } else { input[i] }` for `i < len`.
///
/// # Safety
/// The caller must ensure AVX2 is available and that `input`, `mask`, and
/// `output` are valid for reads/writes of `len` elements.
#[target_feature(enable = "avx2")]
pub unsafe fn masked_fill_f64(
    input: *const f64,
    mask: *const u8,
    output: *mut f64,
    len: usize,
    value: f64,
) {
    const LANES: usize = 4; // f64 lanes in a 256-bit vector
    let chunks = len / LANES;
    let fill_vec = _mm256_set1_pd(value);
    for i in 0..chunks {
        let offset = i * LANES;
        let input_vec = _mm256_loadu_pd(input.add(offset));
        // Expand 4 mask bytes into 4 all-ones/all-zeros f64 lanes.
        let mask_vec = expand_mask_u8_to_f64_avx2(mask.add(offset));
        let result = _mm256_blendv_pd(input_vec, fill_vec, mask_vec);
        _mm256_storeu_pd(output.add(offset), result);
    }
    // Scalar tail for the remaining `len % LANES` elements.
    let start = chunks * LANES;
    for i in start..len {
        *output.add(i) = if *mask.add(i) != 0 {
            value
        } else {
            *input.add(i)
        };
    }
}

/// Compacts the elements of `input` whose mask byte is nonzero into `output`,
/// preserving order; returns the number of elements written.
///
/// # Safety
/// The caller must ensure AVX2 is available, that `input` and `mask` are valid
/// for `len` elements, and that `output` has room for every selected element.
#[target_feature(enable = "avx2")]
pub unsafe fn masked_select_f32(
    input: *const f32,
    mask: *const u8,
    output: *mut f32,
    len: usize,
) -> usize {
    const LANES: usize = 8;
    let chunks = len / LANES;
    let mut out_idx = 0;
    for i in 0..chunks {
        let offset = i * LANES;
        let input_vec = _mm256_loadu_ps(input.add(offset));
        // Pack the 8 mask bytes into one bitmask: bit j set = lane j selected.
        let mut mask_bits: u32 = 0;
        for j in 0..LANES {
            if *mask.add(offset + j) != 0 {
                mask_bits |= 1 << j;
            }
        }
        let count = mask_bits.count_ones() as usize;
        if count == 0 {
            continue;
        }
        if count == LANES {
            // All lanes selected: store the whole vector contiguously.
            _mm256_storeu_ps(output.add(out_idx), input_vec);
        } else {
            // Compress each 128-bit half (4 lanes) separately via an index LUT.
            let lo = _mm256_castps256_ps128(input_vec);
            let hi = _mm256_extractf128_ps::<1>(input_vec);
            let lo_mask = (mask_bits & 0x0F) as usize;
            let hi_mask = ((mask_bits >> 4) & 0x0F) as usize;
            let lo_count = compress_store_f32_128(lo, lo_mask, output.add(out_idx));
            compress_store_f32_128(hi, hi_mask, output.add(out_idx + lo_count));
        }
        out_idx += count;
    }
    // Scalar tail for the remaining `len % LANES` elements.
    let start = chunks * LANES;
    for i in start..len {
        if *mask.add(i) != 0 {
            *output.add(out_idx) = *input.add(i);
            out_idx += 1;
        }
    }
    out_idx
}

/// Compacts the elements of `input` whose mask byte is nonzero into `output`,
/// preserving order; returns the number of elements written.
///
/// # Safety
/// The caller must ensure AVX2 is available, that `input` and `mask` are valid
/// for `len` elements, and that `output` has room for every selected element.
#[target_feature(enable = "avx2")]
pub unsafe fn masked_select_f64(
    input: *const f64,
    mask: *const u8,
    output: *mut f64,
    len: usize,
) -> usize {
    const LANES: usize = 4;
    let chunks = len / LANES;
    let mut out_idx = 0;
    for i in 0..chunks {
        let offset = i * LANES;
        let input_vec = _mm256_loadu_pd(input.add(offset));
        // Pack the 4 mask bytes into one bitmask: bit j set = lane j selected.
        let mut mask_bits: u32 = 0;
        for j in 0..LANES {
            if *mask.add(offset + j) != 0 {
                mask_bits |= 1 << j;
            }
        }
        let count = mask_bits.count_ones() as usize;
        if count == 0 {
            continue;
        }
        if count == LANES {
            // All lanes selected: store the whole vector contiguously.
            _mm256_storeu_pd(output.add(out_idx), input_vec);
        } else {
            // Compress each 128-bit half (2 lanes) separately.
            let lo = _mm256_castpd256_pd128(input_vec);
            let hi = _mm256_extractf128_pd::<1>(input_vec);
            let lo_mask = (mask_bits & 0x03) as usize;
            let hi_mask = ((mask_bits >> 2) & 0x03) as usize;
            let lo_count = compress_store_f64_128(lo, lo_mask, output.add(out_idx));
            compress_store_f64_128(hi, hi_mask, output.add(out_idx + lo_count));
        }
        out_idx += count;
    }
    // Scalar tail for the remaining `len % LANES` elements.
    let start = chunks * LANES;
    for i in start..len {
        if *mask.add(i) != 0 {
            *output.add(out_idx) = *input.add(i);
            out_idx += 1;
        }
    }
    out_idx
}

/// Counts the nonzero bytes in `mask[..len]`.
///
/// # Safety
/// The caller must ensure AVX2 is available and `mask` is valid for `len` bytes.
#[target_feature(enable = "avx2")]
pub unsafe fn masked_count(mask: *const u8, len: usize) -> usize {
    const LANES: usize = 32; // mask bytes per 256-bit vector
    let chunks = len / LANES;
    let mut count = 0usize;
    let zero = _mm256_setzero_si256();
    for i in 0..chunks {
        let offset = i * LANES;
        let mask_vec = _mm256_loadu_si256(mask.add(offset) as *const __m256i);
        // `zero_mask` has bit j set where byte j is zero; invert and popcount
        // to count the nonzero bytes in this 32-byte chunk.
        let cmp = _mm256_cmpeq_epi8(mask_vec, zero);
        let zero_mask = _mm256_movemask_epi8(cmp) as u32;
        count += (!zero_mask).count_ones() as usize;
    }
    // Scalar tail for the remaining `len % LANES` bytes.
    let start = chunks * LANES;
    for i in start..len {
        if *mask.add(i) != 0 {
            count += 1;
        }
    }
    count
}

/// Expands 8 mask bytes into an `__m256` whose lanes are all-ones where the
/// byte is nonzero and all-zeros otherwise (suitable for `_mm256_blendv_ps`).
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn expand_mask_u8_to_f32_avx2(mask: *const u8) -> __m256 {
    // Load 8 bytes, zero-extend to 8 x i32, compare against zero, then invert
    // so that nonzero bytes become all-ones lanes.
    let mask_bytes = _mm_loadl_epi64(mask as *const __m128i);
    let mask_i32 = _mm256_cvtepu8_epi32(mask_bytes);
    let zero = _mm256_setzero_si256();
    let cmp = _mm256_cmpeq_epi32(mask_i32, zero);
    let inverted = _mm256_xor_si256(cmp, _mm256_set1_epi32(-1));
    _mm256_castsi256_ps(inverted)
}

/// Expands 4 mask bytes into an `__m256d` whose lanes are all-ones where the
/// byte is nonzero (suitable for `_mm256_blendv_pd`); built scalar-by-scalar
/// for simplicity.
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn expand_mask_u8_to_f64_avx2(mask: *const u8) -> __m256d {
    let m0 = if *mask.add(0) != 0 { -1i64 } else { 0i64 };
    let m1 = if *mask.add(1) != 0 { -1i64 } else { 0i64 };
    let m2 = if *mask.add(2) != 0 { -1i64 } else { 0i64 };
    let m3 = if *mask.add(3) != 0 { -1i64 } else { 0i64 };
    // `_mm256_set_epi64x` takes its arguments from the high lane down.
    let mask_i64 = _mm256_set_epi64x(m3, m2, m1, m0);
    _mm256_castsi256_pd(mask_i64)
}

/// Scalar-compresses the lanes of `vec` selected by the low 4 bits of `mask`
/// into `output`, preserving lane order; returns the number of lanes written.
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn compress_store_f32_128(vec: __m128, mask: usize, output: *mut f32) -> usize {
    // For each 4-bit mask, the lane indices to keep, in order; 0xFF marks
    // unused slots, which are never read because we only take `count` entries.
    static SHUFFLE_LUT: [[u8; 4]; 16] = [
        [0xFF, 0xFF, 0xFF, 0xFF], // 0b0000
        [0, 0xFF, 0xFF, 0xFF],    // 0b0001
        [1, 0xFF, 0xFF, 0xFF],    // 0b0010
        [0, 1, 0xFF, 0xFF],       // 0b0011
        [2, 0xFF, 0xFF, 0xFF],    // 0b0100
        [0, 2, 0xFF, 0xFF],       // 0b0101
        [1, 2, 0xFF, 0xFF],       // 0b0110
        [0, 1, 2, 0xFF],          // 0b0111
        [3, 0xFF, 0xFF, 0xFF],    // 0b1000
        [0, 3, 0xFF, 0xFF],       // 0b1001
        [1, 3, 0xFF, 0xFF],       // 0b1010
        [0, 1, 3, 0xFF],          // 0b1011
        [2, 3, 0xFF, 0xFF],       // 0b1100
        [0, 2, 3, 0xFF],          // 0b1101
        [1, 2, 3, 0xFF],          // 0b1110
        [0, 1, 2, 3],             // 0b1111
    ];
    let count = (mask as u32).count_ones() as usize;
    if count == 0 {
        return 0;
    }
    let indices = &SHUFFLE_LUT[mask & 0xF];
    let arr: [f32; 4] = std::mem::transmute(vec);
    for (j, &idx) in indices.iter().take(count).enumerate() {
        *output.add(j) = arr[idx as usize];
    }
    count
}

/// Scalar-compresses the two lanes of `vec` selected by the low 2 bits of
/// `mask` into `output`; returns the number of lanes written.
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn compress_store_f64_128(vec: __m128d, mask: usize, output: *mut f64) -> usize {
    let count = (mask as u32 & 0x3).count_ones() as usize;
    if count == 0 {
        return 0;
    }
    let arr: [f64; 2] = std::mem::transmute(vec);
    let mut out_idx = 0;
    if mask & 1 != 0 {
        *output.add(out_idx) = arr[0];
        out_idx += 1;
    }
    if mask & 2 != 0 {
        *output.add(out_idx) = arr[1];
    }
    count
}

#[cfg(test)]
mod tests {
    use super::*;

    fn has_avx2() -> bool {
        is_x86_feature_detected!("avx2")
    }

    #[test]
    fn test_masked_fill_f32_avx2() {
        if !has_avx2() {
            return;
        }
        let input: Vec<f32> = (0..32).map(|i| i as f32).collect();
        let mask: Vec<u8> = (0..32).map(|i| if i % 2 == 0 { 1 } else { 0 }).collect();
        let mut output = vec![0.0f32; 32];
        let fill_value = -1.0f32;
        unsafe {
            masked_fill_f32(
                input.as_ptr(),
                mask.as_ptr(),
                output.as_mut_ptr(),
                32,
                fill_value,
            );
        }
        for i in 0..32 {
            let expected = if i % 2 == 0 { fill_value } else { i as f32 };
            assert_eq!(output[i], expected, "mismatch at index {}", i);
        }
    }
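
    // A minimal companion test for the f64 path (an addition mirroring the
    // f32 test above): fill where the mask byte is nonzero, copy otherwise.
    #[test]
    fn test_masked_fill_f64_avx2() {
        if !has_avx2() {
            return;
        }
        let input: Vec<f64> = (0..16).map(|i| i as f64).collect();
        let mask: Vec<u8> = (0..16).map(|i| if i % 2 == 0 { 1 } else { 0 }).collect();
        let mut output = vec![0.0f64; 16];
        unsafe {
            masked_fill_f64(input.as_ptr(), mask.as_ptr(), output.as_mut_ptr(), 16, -1.0);
        }
        for i in 0..16 {
            let expected = if i % 2 == 0 { -1.0 } else { i as f64 };
            assert_eq!(output[i], expected, "mismatch at index {}", i);
        }
    }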

    #[test]
    fn test_masked_select_f32_avx2() {
        if !has_avx2() {
            return;
        }
        let input: Vec<f32> = (0..32).map(|i| i as f32).collect();
        let mask: Vec<u8> = (0..32).map(|i| if i % 3 == 0 { 1 } else { 0 }).collect();
        let mut output = vec![0.0f32; 32];
        let count =
            unsafe { masked_select_f32(input.as_ptr(), mask.as_ptr(), output.as_mut_ptr(), 32) };
        assert_eq!(count, 11); // indices 0, 3, ..., 30
        let expected: Vec<f32> = (0..32).filter(|i| i % 3 == 0).map(|i| i as f32).collect();
        for (j, &exp) in expected.iter().enumerate() {
            assert_eq!(output[j], exp, "mismatch at output index {}", j);
        }
    }
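
    // A companion test for the f64 select path (an addition): len = 19 is
    // deliberately not a multiple of 4, so the scalar tail runs as well.
    #[test]
    fn test_masked_select_f64_avx2() {
        if !has_avx2() {
            return;
        }
        let input: Vec<f64> = (0..19).map(|i| i as f64).collect();
        let mask: Vec<u8> = (0..19).map(|i| if i % 3 == 0 { 1 } else { 0 }).collect();
        let mut output = vec![0.0f64; 19];
        let count =
            unsafe { masked_select_f64(input.as_ptr(), mask.as_ptr(), output.as_mut_ptr(), 19) };
        assert_eq!(count, 7); // indices 0, 3, 6, 9, 12, 15, 18
        let expected: Vec<f64> = (0..19).filter(|i| i % 3 == 0).map(|i| i as f64).collect();
        for (j, &exp) in expected.iter().enumerate() {
            assert_eq!(output[j], exp, "mismatch at output index {}", j);
        }
    }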

    #[test]
    fn test_masked_count_avx2() {
        if !has_avx2() {
            return;
        }
        let mask: Vec<u8> = (0..128).map(|i| if i % 7 == 0 { 1 } else { 0 }).collect();
        let count = unsafe { masked_count(mask.as_ptr(), 128) };
        assert_eq!(count, 19); // indices 0, 7, ..., 126
    }
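
    // An added edge-case sketch: a length that is not a multiple of any of
    // the vector widths used above, so every scalar tail path is exercised.
    #[test]
    fn test_unaligned_length_avx2() {
        if !has_avx2() {
            return;
        }
        let len = 13usize;
        let input: Vec<f32> = (0..len).map(|i| i as f32).collect();
        let mask: Vec<u8> = (0..len).map(|i| (i % 2) as u8).collect();
        let mut output = vec![0.0f32; len];
        unsafe {
            masked_fill_f32(input.as_ptr(), mask.as_ptr(), output.as_mut_ptr(), len, 9.0);
        }
        for i in 0..len {
            let expected = if i % 2 == 1 { 9.0 } else { i as f32 };
            assert_eq!(output[i], expected, "mismatch at index {}", i);
        }
        let count = unsafe { masked_count(mask.as_ptr(), len) };
        assert_eq!(count, 6); // odd indices 1, 3, 5, 7, 9, 11
    }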
}