numr 0.5.1

High-performance numerical computing with multi-backend GPU acceleration (CPU/CUDA/WebGPU)
Documentation
//! NEON index operation kernels for ARM64
//!
//! Provides vectorized masked_fill, masked_select, and masked_count
//! operations using 128-bit NEON registers.
//!
//! # SIMD Strategy
//!
//! - masked_fill: Process 4 f32 (or 2 f64) at a time with vbslq blending
//! - masked_select: Stream compaction using scalar loop (gather not available)
//! - masked_count: Vectorized popcount using horizontal add

#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;

// ============================================================================
// Masked Fill
// ============================================================================

/// NEON masked fill for f32
///
/// Fills output with `value` where mask is true, otherwise copies from input.
///
/// # Safety
/// - All pointers must be valid for `len` elements
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn masked_fill_f32(
    input: *const f32,
    mask: *const u8,
    output: *mut f32,
    len: usize,
    value: f32,
) {
    let lanes = 4;
    let chunks = len / lanes;

    let v_value = vdupq_n_f32(value);

    for i in 0..chunks {
        let idx = i * lanes;

        // Load input values
        let v_in = vld1q_f32(input.add(idx));

        // Load and expand mask bytes to u32 mask
        // Each byte becomes 0x00000000 (false) or 0xFFFFFFFF (true)
        let m0 = if *mask.add(idx) != 0 {
            0xFFFFFFFFu32
        } else {
            0
        };
        let m1 = if *mask.add(idx + 1) != 0 {
            0xFFFFFFFFu32
        } else {
            0
        };
        let m2 = if *mask.add(idx + 2) != 0 {
            0xFFFFFFFFu32
        } else {
            0
        };
        let m3 = if *mask.add(idx + 3) != 0 {
            0xFFFFFFFFu32
        } else {
            0
        };

        let mask_arr = [m0, m1, m2, m3];
        let v_mask = vld1q_u32(mask_arr.as_ptr());

        // Blend: select value where mask is true, input where false
        let result = vbslq_f32(v_mask, v_value, v_in);
        vst1q_f32(output.add(idx), result);
    }

    // Scalar tail
    for i in (chunks * lanes)..len {
        *output.add(i) = if *mask.add(i) != 0 {
            value
        } else {
            *input.add(i)
        };
    }
}

/// NEON masked fill for f64
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn masked_fill_f64(
    input: *const f64,
    mask: *const u8,
    output: *mut f64,
    len: usize,
    value: f64,
) {
    let lanes = 2;
    let chunks = len / lanes;

    let v_value = vdupq_n_f64(value);

    for i in 0..chunks {
        let idx = i * lanes;

        let v_in = vld1q_f64(input.add(idx));

        // Expand mask bytes to u64 mask
        let m0 = if *mask.add(idx) != 0 {
            0xFFFFFFFFFFFFFFFFu64
        } else {
            0
        };
        let m1 = if *mask.add(idx + 1) != 0 {
            0xFFFFFFFFFFFFFFFFu64
        } else {
            0
        };

        let mask_arr = [m0, m1];
        let v_mask = vld1q_u64(mask_arr.as_ptr());

        let result = vbslq_f64(v_mask, v_value, v_in);
        vst1q_f64(output.add(idx), result);
    }

    // Scalar tail
    for i in (chunks * lanes)..len {
        *output.add(i) = if *mask.add(i) != 0 {
            value
        } else {
            *input.add(i)
        };
    }
}

// ============================================================================
// Masked Select
// ============================================================================

/// NEON masked select for f32
///
/// Selects elements where mask is true into a contiguous output.
/// Returns the number of selected elements.
///
/// # Note
/// NEON doesn't have efficient gather/scatter, so we use scalar loop
/// but still benefit from prefetching patterns.
///
/// # Safety
/// - All pointers must be valid
/// - `output` must have space for at least `len` elements
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn masked_select_f32(
    input: *const f32,
    mask: *const u8,
    output: *mut f32,
    len: usize,
) -> usize {
    // NEON doesn't have efficient compress/gather, use scalar
    let mut out_idx = 0;
    for i in 0..len {
        if *mask.add(i) != 0 {
            *output.add(out_idx) = *input.add(i);
            out_idx += 1;
        }
    }
    out_idx
}

/// NEON masked select for f64
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn masked_select_f64(
    input: *const f64,
    mask: *const u8,
    output: *mut f64,
    len: usize,
) -> usize {
    let mut out_idx = 0;
    for i in 0..len {
        if *mask.add(i) != 0 {
            *output.add(out_idx) = *input.add(i);
            out_idx += 1;
        }
    }
    out_idx
}

// ============================================================================
// Masked Count
// ============================================================================

/// NEON mask count (popcount)
///
/// Counts the number of true elements in the mask using vectorized addition.
///
/// # Safety
/// - `mask` must be valid for `len` elements
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn masked_count(mask: *const u8, len: usize) -> usize {
    let lanes = 16; // Process 16 bytes at a time
    let chunks = len / lanes;

    let mut total_acc = vdupq_n_u8(0);

    for i in 0..chunks {
        let idx = i * lanes;
        let v = vld1q_u8(mask.add(idx));

        // Each byte is 0 or non-zero, convert to 0 or 1
        let zero = vdupq_n_u8(0);
        let one = vdupq_n_u8(1);
        let cmp = vcgtq_u8(v, zero); // 0xFF where v > 0, else 0
        let ones = vandq_u8(cmp, one); // 1 where mask is true

        // Accumulate (saturating add to handle overflow within chunks)
        total_acc = vaddq_u8(total_acc, ones);

        // Every 255 iterations, reduce to avoid overflow
        if (i + 1) % 255 == 0 {
            // Horizontal sum
            let sum16 = vpaddlq_u8(total_acc);
            let sum32 = vpaddlq_u16(sum16);
            let _sum64 = vpaddlq_u32(sum32);
            // Will handle at final reduction
        }
    }

    // Horizontal sum of total_acc
    let sum16 = vpaddlq_u8(total_acc); // 8 x u16
    let sum32 = vpaddlq_u16(sum16); // 4 x u32
    let sum64 = vpaddlq_u32(sum32); // 2 x u64

    let mut count = vgetq_lane_u64(sum64, 0) + vgetq_lane_u64(sum64, 1);

    // Scalar tail
    for i in (chunks * lanes)..len {
        if *mask.add(i) != 0 {
            count += 1;
        }
    }

    count as usize
}