#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
/// Copies `input` into `output`, substituting `value` wherever `mask` is non-zero.
///
/// For every `i` in `0..len`, `output[i]` becomes `value` when `mask[i] != 0`
/// and `input[i]` otherwise. Whole groups of four elements are blended with a
/// NEON bit-select; the trailing `len % 4` elements are handled one at a time.
///
/// # Safety
///
/// * `input` and `mask` must be valid for reads of `len` elements.
/// * `output` must be valid for writes of `len` elements.
/// * `input` and `output` must be aligned for `f32` (the scalar tail
///   dereferences them directly).
/// * The executing CPU must support NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn masked_fill_f32(
    input: *const f32,
    mask: *const u8,
    output: *mut f32,
    len: usize,
    value: f32,
) {
    const LANES: usize = 4;
    // Number of leading elements that form complete 4-wide vectors.
    let vector_len = len - len % LANES;
    let fill = vdupq_n_f32(value);

    let mut base = 0;
    while base < vector_len {
        let source = vld1q_f32(input.add(base));

        // Expand each mask byte to an all-ones / all-zeros 32-bit lane:
        // `true` -> 1 -> wrapping_neg -> 0xFFFF_FFFF, `false` -> 0 -> 0.
        let mut lane_bits = [0u32; LANES];
        for (lane, bits) in lane_bits.iter_mut().enumerate() {
            *bits = u32::from(*mask.add(base + lane) != 0).wrapping_neg();
        }
        let select = vld1q_u32(lane_bits.as_ptr());

        // Bit-select: lanes with all bits set take `fill`, the rest keep `source`.
        vst1q_f32(output.add(base), vbslq_f32(select, fill, source));
        base += LANES;
    }

    // Scalar tail for the elements that do not fill a whole vector.
    for i in vector_len..len {
        let picked = if *mask.add(i) == 0 {
            *input.add(i)
        } else {
            value
        };
        *output.add(i) = picked;
    }
}
/// Copies `input` into `output`, substituting `value` wherever `mask` is non-zero.
///
/// For every `i` in `0..len`, `output[i]` becomes `value` when `mask[i] != 0`
/// and `input[i]` otherwise. Pairs of elements are blended with a NEON
/// bit-select; if `len` is odd the final element is handled on its own.
///
/// # Safety
///
/// * `input` and `mask` must be valid for reads of `len` elements.
/// * `output` must be valid for writes of `len` elements.
/// * `input` and `output` must be aligned for `f64` (the scalar tail
///   dereferences them directly).
/// * The executing CPU must support NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn masked_fill_f64(
    input: *const f64,
    mask: *const u8,
    output: *mut f64,
    len: usize,
    value: f64,
) {
    const LANES: usize = 2;
    // Number of leading elements that form complete 2-wide vectors.
    let vector_len = len - len % LANES;
    let fill = vdupq_n_f64(value);

    let mut base = 0;
    while base < vector_len {
        let source = vld1q_f64(input.add(base));

        // Expand each mask byte to an all-ones / all-zeros 64-bit lane:
        // `true` -> 1 -> wrapping_neg -> u64::MAX, `false` -> 0 -> 0.
        let mut lane_bits = [0u64; LANES];
        for (lane, bits) in lane_bits.iter_mut().enumerate() {
            *bits = u64::from(*mask.add(base + lane) != 0).wrapping_neg();
        }
        let select = vld1q_u64(lane_bits.as_ptr());

        // Bit-select: lanes with all bits set take `fill`, the rest keep `source`.
        vst1q_f64(output.add(base), vbslq_f64(select, fill, source));
        base += LANES;
    }

    // Scalar tail for the element (if any) that does not fill a whole vector.
    for i in vector_len..len {
        let picked = if *mask.add(i) == 0 {
            *input.add(i)
        } else {
            value
        };
        *output.add(i) = picked;
    }
}
/// Compacts the elements of `input` whose `mask` byte is non-zero into `output`.
///
/// Selected elements are written to the front of `output` in their original
/// order. Returns how many elements were written. This routine is entirely
/// scalar; no NEON intrinsics are used in its body.
///
/// # Safety
///
/// * `input` and `mask` must be valid for reads of `len` elements.
/// * `output` must be valid for writes of one element per non-zero mask byte
///   (at most `len` elements).
/// * `input` and `output` must be aligned for `f32`.
/// * The executing CPU must support NEON (required by the function's
///   `target_feature` attribute).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn masked_select_f32(
    input: *const f32,
    mask: *const u8,
    output: *mut f32,
    len: usize,
) -> usize {
    let mut kept: usize = 0;
    let mut pos: usize = 0;
    while pos < len {
        let selected = *mask.add(pos) != 0;
        if selected {
            // `kept <= pos` always holds, so the write cursor never passes the read cursor.
            output.add(kept).write(input.add(pos).read());
            kept += 1;
        }
        pos += 1;
    }
    kept
}
/// Compacts the elements of `input` whose `mask` byte is non-zero into `output`.
///
/// Selected elements are written to the front of `output` in their original
/// order. Returns how many elements were written. This routine is entirely
/// scalar; no NEON intrinsics are used in its body.
///
/// # Safety
///
/// * `input` and `mask` must be valid for reads of `len` elements.
/// * `output` must be valid for writes of one element per non-zero mask byte
///   (at most `len` elements).
/// * `input` and `output` must be aligned for `f64`.
/// * The executing CPU must support NEON (required by the function's
///   `target_feature` attribute).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn masked_select_f64(
    input: *const f64,
    mask: *const u8,
    output: *mut f64,
    len: usize,
) -> usize {
    let mut kept: usize = 0;
    let mut pos: usize = 0;
    while pos < len {
        let selected = *mask.add(pos) != 0;
        if selected {
            // `kept <= pos` always holds, so the write cursor never passes the read cursor.
            output.add(kept).write(input.add(pos).read());
            kept += 1;
        }
        pos += 1;
    }
    kept
}
/// Counts the non-zero bytes in `mask[0..len]`.
///
/// Full 16-byte chunks are tallied with NEON: each chunk adds 0 or 1 to every
/// `u8` lane of an accumulator. Because a `u8` lane wraps after 255
/// increments, the accumulator is widened into a 64-bit total and restarted
/// after at most 255 chunks. The trailing `len % 16` bytes are counted one at
/// a time.
///
/// # Safety
///
/// * `mask` must be valid for reads of `len` bytes.
/// * The executing CPU must support NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn masked_count(mask: *const u8, len: usize) -> usize {
    const LANES: usize = 16;
    // Each chunk adds at most 1 per lane, so a zeroed u8 lane can absorb
    // exactly 255 chunks before it would wrap around.
    const MAX_CHUNKS_PER_BLOCK: usize = 255;

    let chunks = len / LANES;
    let zero = vdupq_n_u8(0);
    let one = vdupq_n_u8(1);

    let mut count: u64 = 0;
    let mut chunk = 0;
    while chunk < chunks {
        let block_end = chunk + (chunks - chunk).min(MAX_CHUNKS_PER_BLOCK);

        // Fresh per-lane tallies for this block; guaranteed not to overflow.
        let mut lane_counts = vdupq_n_u8(0);
        while chunk < block_end {
            let bytes = vld1q_u8(mask.add(chunk * LANES));
            // 0xFF in lanes whose byte is non-zero, then reduced to 0/1.
            let is_set = vcgtq_u8(bytes, zero);
            lane_counts = vaddq_u8(lane_counts, vandq_u8(is_set, one));
            chunk += 1;
        }

        // Pairwise-widen u8 -> u16 -> u32 -> u64 and fold both halves into the total.
        let sum16 = vpaddlq_u8(lane_counts);
        let sum32 = vpaddlq_u16(sum16);
        let sum64 = vpaddlq_u32(sum32);
        count += vgetq_lane_u64(sum64, 0) + vgetq_lane_u64(sum64, 1);
    }

    // Scalar tail for the bytes that do not fill a whole 16-byte chunk.
    for i in (chunks * LANES)..len {
        if *mask.add(i) != 0 {
            count += 1;
        }
    }

    // `count <= len`, so the conversion back to `usize` cannot truncate.
    count as usize
}