geographdb-core 0.3.1

Geometric graph database core - 3D spatial indexing for code analysis
Documentation
//! SIMD-accelerated spatial distance filtering
//!
//! Provides x86/x86_64 SIMD implementations (SSE2, AVX2, AVX-512F) of 3D
//! squared-distance filtering, with runtime CPU feature detection and a
//! portable scalar fallback that is used on every other target.
//!
//! Ported from geographdb_prototype/acceleration/simd_backend.rs

/// Flags every point whose squared Euclidean (L2) distance from `center` is at most `radius_sq`.
///
/// On x86/x86_64 the fastest backend reported as supported is chosen at runtime,
/// trying AVX-512F first, then AVX2, then SSE2. On any other target (or if no probe
/// succeeds) the portable scalar implementation is used.
///
/// # Arguments
/// * `points` - Points to test, as `(x, y, z)` tuples
/// * `center` - Query center `(cx, cy, cz)`
/// * `radius_sq` - Squared inclusion radius; the comparison is inclusive (`<=`)
///
/// # Returns
/// A `Vec<bool>` parallel to `points`: `true` where the point lies within the radius.
pub fn distance_filter_l2(
    points: &[(f32, f32, f32)],
    center: (f32, f32, f32),
    radius_sq: f32,
) -> Vec<bool> {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        // SAFETY: each `#[target_feature]` backend is entered only after its probe
        // returned true; this relies on the probes reporting support only when the
        // CPU and OS can actually execute that backend's instructions.
        let accelerated = if is_avx512_supported() {
            Some(unsafe { distance_filter_avx512(points, center, radius_sq) })
        } else if is_avx2_supported() {
            Some(unsafe { distance_filter_avx2(points, center, radius_sq) })
        } else if is_sse2_supported() {
            Some(unsafe { distance_filter_sse2(points, center, radius_sq) })
        } else {
            None
        };

        if let Some(flags) = accelerated {
            return flags;
        }
    }

    // Portable path: non-x86 targets, or x86 with no usable SIMD extension.
    distance_filter_scalar(points, center, radius_sq)
}

/// Portable scalar reference implementation, available on every target.
///
/// For each point, computes `dx² + dy² + dz²` relative to `center` (summed left to
/// right) and reports whether it is `<= radius_sq`. A NaN distance compares false.
///
/// # Returns
/// A `Vec<bool>` with one entry per input point, in input order.
pub fn distance_filter_scalar(
    points: &[(f32, f32, f32)],
    center: (f32, f32, f32),
    radius_sq: f32,
) -> Vec<bool> {
    let mut flags = Vec::with_capacity(points.len());
    for &(px, py, pz) in points {
        let (ox, oy, oz) = (px - center.0, py - center.1, pz - center.2);
        flags.push(ox * ox + oy * oy + oz * oz <= radius_sq);
    }
    flags
}

/// Reports whether AVX-512 Foundation (`avx512f`) instructions can be executed.
///
/// Delegates to the standard library's `is_x86_feature_detected!`, which caches its
/// result and, unlike a bare CPUID bit test, also accounts for whether the OS has
/// enabled the extended register state the instructions need. A raw
/// `cpuid(7).ebx` check can report AVX-512 on an OS that does not save ZMM state.
/// Executing AVX-512 code there faults with an invalid-opcode error.
/// It also reads garbage on CPUs whose highest CPUID leaf is below 7.
/// The macro works on both x86 and x86_64, whereas `std::arch::x86_64::__cpuid` does
/// not exist on 32-bit x86.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn is_avx512_supported() -> bool {
    is_x86_feature_detected!("avx512f")
}

/// Reports whether AVX2 instructions can be executed.
///
/// Delegates to the standard library's `is_x86_feature_detected!`, which caches its
/// result and also accounts for OS support of the YMM register state. The former raw
/// `cpuid(7).ebx` bit-5 test skipped that check and the max-leaf check. It also only
/// compiled on x86_64, because it used the `std::arch::x86_64` path.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn is_avx2_supported() -> bool {
    is_x86_feature_detected!("avx2")
}

/// Reports whether SSE2 instructions can be executed.
///
/// `is_x86_feature_detected!` resolves to a compile-time `true` when SSE2 is part of
/// the target baseline (always on x86_64, and on typical i686 targets). It falls back
/// to a cached runtime probe otherwise, so no hand-rolled CPUID or cache is needed.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn is_sse2_supported() -> bool {
    is_x86_feature_detected!("sse2")
}

/// AVX-512F backend: evaluates 16 points per iteration.
///
/// Each full chunk of 16 points is transposed from `(x, y, z)` tuples into per-axis
/// lane arrays. The squared distance is then computed as `(dx² + dy²) + dz²`, the same
/// operation order as the scalar path. It is compared with an ordered `<=`
/// (`_mm512_cmple_ps_mask`), so a NaN distance yields `false`. Points left over after
/// the last full chunk use the scalar formula.
///
/// # Safety
/// The caller must ensure the running CPU and OS support `avx512f` (for example via
/// `is_x86_feature_detected!("avx512f")`). Calling it otherwise is undefined behavior.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn distance_filter_avx512(
    points: &[(f32, f32, f32)],
    center: (f32, f32, f32),
    radius_sq: f32,
) -> Vec<bool> {
    // Import the intrinsics for the architecture actually being compiled; a
    // hard-coded `std::arch::x86_64` path does not exist on 32-bit x86.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    const LANES: usize = 16;

    let (cx, cy, cz) = center;
    let cx_vec = _mm512_set1_ps(cx);
    let cy_vec = _mm512_set1_ps(cy);
    let cz_vec = _mm512_set1_ps(cz);
    let radius_vec = _mm512_set1_ps(radius_sq);

    let mut result = Vec::with_capacity(points.len());
    let chunks = points.chunks_exact(LANES);
    let tail = chunks.remainder();

    for chunk in chunks {
        // Gather each axis into its own lane array so it fills one vector register.
        let mut xs = [0.0f32; LANES];
        let mut ys = [0.0f32; LANES];
        let mut zs = [0.0f32; LANES];
        for (lane, &(x, y, z)) in chunk.iter().enumerate() {
            xs[lane] = x;
            ys[lane] = y;
            zs[lane] = z;
        }

        let dx = _mm512_sub_ps(_mm512_loadu_ps(xs.as_ptr()), cx_vec);
        let dy = _mm512_sub_ps(_mm512_loadu_ps(ys.as_ptr()), cy_vec);
        let dz = _mm512_sub_ps(_mm512_loadu_ps(zs.as_ptr()), cz_vec);

        let dx2 = _mm512_mul_ps(dx, dx);
        let dy2 = _mm512_mul_ps(dy, dy);
        let dz2 = _mm512_mul_ps(dz, dz);

        let dist_sq = _mm512_add_ps(_mm512_add_ps(dx2, dy2), dz2);
        // One mask bit per lane: bit `lane` is set when that point is inside.
        let mask = _mm512_cmple_ps_mask(dist_sq, radius_vec);
        result.extend((0..LANES).map(|lane| (mask >> lane) & 1 != 0));
    }

    // Fewer than LANES points remain: use the scalar formula.
    result.extend(tail.iter().map(|&(x, y, z)| {
        let (dx, dy, dz) = (x - cx, y - cy, z - cz);
        dx * dx + dy * dy + dz * dz <= radius_sq
    }));

    result
}

/// AVX2 backend: evaluates 8 points per iteration.
///
/// Each full chunk of 8 points is transposed from `(x, y, z)` tuples into per-axis
/// lane arrays. The squared distance is then computed as `(dx² + dy²) + dz²`, the same
/// operation order as the scalar path. It is compared with the ordered predicate
/// `_CMP_LE_OS`, so a NaN distance yields `false`. Points left over after the last
/// full chunk use the scalar formula.
///
/// # Safety
/// The caller must ensure the running CPU and OS support `avx2` (for example via
/// `is_x86_feature_detected!("avx2")`). Calling it otherwise is undefined behavior.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn distance_filter_avx2(
    points: &[(f32, f32, f32)],
    center: (f32, f32, f32),
    radius_sq: f32,
) -> Vec<bool> {
    // Import the intrinsics for the architecture actually being compiled; a
    // hard-coded `std::arch::x86_64` path does not exist on 32-bit x86.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    const LANES: usize = 8;

    let (cx, cy, cz) = center;
    let cx_vec = _mm256_set1_ps(cx);
    let cy_vec = _mm256_set1_ps(cy);
    let cz_vec = _mm256_set1_ps(cz);
    let radius_vec = _mm256_set1_ps(radius_sq);

    let mut result = Vec::with_capacity(points.len());
    let chunks = points.chunks_exact(LANES);
    let tail = chunks.remainder();

    for chunk in chunks {
        // Gather each axis into its own lane array so it fills one vector register.
        let mut xs = [0.0f32; LANES];
        let mut ys = [0.0f32; LANES];
        let mut zs = [0.0f32; LANES];
        for (lane, &(x, y, z)) in chunk.iter().enumerate() {
            xs[lane] = x;
            ys[lane] = y;
            zs[lane] = z;
        }

        let dx = _mm256_sub_ps(_mm256_loadu_ps(xs.as_ptr()), cx_vec);
        let dy = _mm256_sub_ps(_mm256_loadu_ps(ys.as_ptr()), cy_vec);
        let dz = _mm256_sub_ps(_mm256_loadu_ps(zs.as_ptr()), cz_vec);

        let dx2 = _mm256_mul_ps(dx, dx);
        let dy2 = _mm256_mul_ps(dy, dy);
        let dz2 = _mm256_mul_ps(dz, dz);

        let dist_sq = _mm256_add_ps(_mm256_add_ps(dx2, dy2), dz2);
        // movemask packs each lane's sign bit (all-ones on a true compare) into bit `lane`.
        let mask = _mm256_movemask_ps(_mm256_cmp_ps(dist_sq, radius_vec, _CMP_LE_OS));
        result.extend((0..LANES).map(|lane| (mask >> lane) & 1 != 0));
    }

    // Fewer than LANES points remain: use the scalar formula.
    result.extend(tail.iter().map(|&(x, y, z)| {
        let (dx, dy, dz) = (x - cx, y - cy, z - cz);
        dx * dx + dy * dy + dz * dz <= radius_sq
    }));

    result
}

/// SSE2 backend: evaluates 4 points per iteration.
///
/// Each full chunk of 4 points is transposed from `(x, y, z)` tuples into per-axis
/// lane arrays. The squared distance is then computed as `(dx² + dy²) + dz²`, the same
/// operation order as the scalar path. It is compared with the ordered `_mm_cmple_ps`,
/// so a NaN distance yields `false`. Points left over after the last full chunk use
/// the scalar formula.
///
/// # Safety
/// The caller must ensure the running CPU supports `sse2` (for example via
/// `is_x86_feature_detected!("sse2")`). Calling it otherwise is undefined behavior.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn distance_filter_sse2(
    points: &[(f32, f32, f32)],
    center: (f32, f32, f32),
    radius_sq: f32,
) -> Vec<bool> {
    // Import the intrinsics for the architecture actually being compiled; a
    // hard-coded `std::arch::x86_64` path does not exist on 32-bit x86.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    const LANES: usize = 4;

    let (cx, cy, cz) = center;
    let cx_vec = _mm_set1_ps(cx);
    let cy_vec = _mm_set1_ps(cy);
    let cz_vec = _mm_set1_ps(cz);
    let radius_vec = _mm_set1_ps(radius_sq);

    let mut result = Vec::with_capacity(points.len());
    let chunks = points.chunks_exact(LANES);
    let tail = chunks.remainder();

    for chunk in chunks {
        // Gather each axis into its own lane array so it fills one vector register.
        let mut xs = [0.0f32; LANES];
        let mut ys = [0.0f32; LANES];
        let mut zs = [0.0f32; LANES];
        for (lane, &(x, y, z)) in chunk.iter().enumerate() {
            xs[lane] = x;
            ys[lane] = y;
            zs[lane] = z;
        }

        let dx = _mm_sub_ps(_mm_loadu_ps(xs.as_ptr()), cx_vec);
        let dy = _mm_sub_ps(_mm_loadu_ps(ys.as_ptr()), cy_vec);
        let dz = _mm_sub_ps(_mm_loadu_ps(zs.as_ptr()), cz_vec);

        let dx2 = _mm_mul_ps(dx, dx);
        let dy2 = _mm_mul_ps(dy, dy);
        let dz2 = _mm_mul_ps(dz, dz);

        let dist_sq = _mm_add_ps(_mm_add_ps(dx2, dy2), dz2);
        // movemask packs each lane's sign bit (all-ones on a true compare) into bit `lane`.
        let mask = _mm_movemask_ps(_mm_cmple_ps(dist_sq, radius_vec));
        result.extend((0..LANES).map(|lane| (mask >> lane) & 1 != 0));
    }

    // Fewer than LANES points remain: use the scalar formula.
    result.extend(tail.iter().map(|&(x, y, z)| {
        let (dx, dy, dz) = (x - cx, y - cy, z - cz);
        dx * dx + dy * dy + dz * dz <= radius_sq
    }));

    result
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_distance_filter_scalar_basic() {
        let points = vec![
            (0.0, 0.0, 0.0), // Center - should match
            (1.0, 0.0, 0.0), // Distance 1 - should match (radius 2)
            (3.0, 0.0, 0.0), // Distance 3 - should NOT match
        ];

        let result = distance_filter_scalar(&points, (0.0, 0.0, 0.0), 4.0);
        assert_eq!(result.len(), 3);
        assert!(result[0]); // Center point
        assert!(result[1]); // Distance 1 <= 2
        assert!(!result[2]); // Distance 3 > 2
    }

    #[test]
    fn test_distance_filter_equivalence() {
        // Generate test points
        let points: Vec<_> = (0..100)
            .map(|i| (i as f32 * 0.1, i as f32 * 0.2, i as f32 * 0.3))
            .collect();

        let center = (5.0, 5.0, 5.0);
        let radius_sq = 10.0;

        let scalar_result = distance_filter_scalar(&points, center, radius_sq);
        let auto_result = distance_filter_l2(&points, center, radius_sq);

        assert_eq!(
            scalar_result, auto_result,
            "SIMD and scalar must produce identical results"
        );
    }

    #[test]
    fn test_distance_filter_edge_cases() {
        // Empty input
        let empty: Vec<(f32, f32, f32)> = vec![];
        let result = distance_filter_l2(&empty, (0.0, 0.0, 0.0), 1.0);
        assert!(result.is_empty());

        // Single point exactly at radius boundary
        let points = vec![(1.0, 0.0, 0.0)];
        let result = distance_filter_l2(&points, (0.0, 0.0, 0.0), 1.0);
        assert!(result[0]); // Distance squared = 1.0, radius_sq = 1.0, should be <=

        // Point just outside
        let points = vec![(1.0001, 0.0, 0.0)];
        let result = distance_filter_l2(&points, (0.0, 0.0, 0.0), 1.0);
        assert!(!result[0]);
    }

    /// `distance_filter_l2` only ever runs the single best backend on the host CPU,
    /// so lower-tier backends would otherwise go untested on AVX-512/AVX2 machines.
    /// This checks every backend the CPU supports directly against the scalar
    /// reference. Every prefix length is used, which covers all full-chunk/tail splits.
    /// The input includes a NaN point and a point exactly on the radius.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[test]
    fn test_each_simd_backend_matches_scalar() {
        let mut points: Vec<(f32, f32, f32)> = (0..53)
            .map(|i| {
                let t = i as f32;
                (t * 0.1 - 2.0, t * 0.05, 1.0 - t * 0.02)
            })
            .collect();
        // NaN must be rejected by every backend, as it is by the scalar `<=`.
        points[5] = (f32::NAN, 0.0, 0.0);
        // Exactly on the boundary (d^2 == radius_sq) must be accepted.
        points[50] = (1.0, 0.0, 0.0);

        let center = (0.0, 0.0, 0.0);
        let radius_sq = 1.0;
        let expected = distance_filter_scalar(&points, center, radius_sq);
        assert!(!expected[5]);
        assert!(expected[50]);

        for len in 0..=points.len() {
            let subset = &points[..len];
            let reference = &expected[..len];

            if is_x86_feature_detected!("sse2") {
                // SAFETY: guarded by the runtime feature check above.
                let got = unsafe { distance_filter_sse2(subset, center, radius_sq) };
                assert_eq!(got, reference, "sse2 backend, len {len}");
            }
            if is_x86_feature_detected!("avx2") {
                // SAFETY: guarded by the runtime feature check above.
                let got = unsafe { distance_filter_avx2(subset, center, radius_sq) };
                assert_eq!(got, reference, "avx2 backend, len {len}");
            }
            if is_x86_feature_detected!("avx512f") {
                // SAFETY: guarded by the runtime feature check above.
                let got = unsafe { distance_filter_avx512(subset, center, radius_sq) };
                assert_eq!(got, reference, "avx512 backend, len {len}");
            }
        }
    }
}