/// Flags which points lie inside (or on) a sphere, using the fastest available backend.
///
/// Returns one `bool` per entry of `points`, in order: `true` when the squared
/// Euclidean distance from the point to `center` is `<= radius_sq` (note that
/// `radius_sq` is the *squared* radius).
///
/// On x86/x86_64 the best SIMD kernel is chosen at runtime (AVX-512F, then AVX2,
/// then SSE2). Every other target, or an x86 CPU with none of those features,
/// uses [`distance_filter_scalar`].
pub fn distance_filter_l2(
    points: &[(f32, f32, f32)],
    center: (f32, f32, f32),
    radius_sq: f32,
) -> Vec<bool> {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        // Probes run in priority order; the first one that succeeds selects the kernel.
        let simd_mask = if is_avx512_supported() {
            // SAFETY: only reached after the AVX-512F runtime probe returned true.
            Some(unsafe { distance_filter_avx512(points, center, radius_sq) })
        } else if is_avx2_supported() {
            // SAFETY: only reached after the AVX2 runtime probe returned true.
            Some(unsafe { distance_filter_avx2(points, center, radius_sq) })
        } else if is_sse2_supported() {
            // SAFETY: only reached after the SSE2 runtime probe returned true.
            Some(unsafe { distance_filter_sse2(points, center, radius_sq) })
        } else {
            None
        };
        if let Some(mask) = simd_mask {
            return mask;
        }
    }
    distance_filter_scalar(points, center, radius_sq)
}
/// Portable reference implementation of the sphere-membership filter.
///
/// For each point, in order, yields `true` when
/// `(dx*dx + dy*dy) + dz*dz <= radius_sq`, where `d = point - center`.
/// A NaN coordinate makes the comparison false. The SIMD kernels use the same
/// operation order so they can be checked against this function for exact equality.
pub fn distance_filter_scalar(
    points: &[(f32, f32, f32)],
    center: (f32, f32, f32),
    radius_sq: f32,
) -> Vec<bool> {
    let mut inside = Vec::with_capacity(points.len());
    for &(px, py, pz) in points {
        let off_x = px - center.0;
        let off_y = py - center.1;
        let off_z = pz - center.2;
        let dist_sq = off_x * off_x + off_y * off_y + off_z * off_z;
        inside.push(dist_sq <= radius_sq);
    }
    inside
}
/// Reports whether AVX-512F instructions can be executed on this machine.
///
/// Uses the standard library's runtime detection, which:
/// - checks the CPUID feature bit, and only after confirming that CPUID leaf 7 exists;
/// - checks that the operating system has enabled the extended register state
///   (XCR0). Without that check, AVX-512 code can fault with an illegal
///   instruction even when the CPU advertises the feature;
/// - works on both `x86` and `x86_64`;
/// - caches its result, so repeated calls are cheap.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn is_avx512_supported() -> bool {
    is_x86_feature_detected!("avx512f")
}
/// Reports whether AVX2 instructions can be executed on this machine.
///
/// Uses the standard library's runtime detection, which:
/// - checks the CPUID feature bit, guarded by the maximum supported leaf;
/// - checks that the OS has enabled YMM register state (XCR0), avoiding
///   illegal-instruction faults when the CPU supports AVX2 but the OS does not;
/// - works on both `x86` and `x86_64`;
/// - caches its result.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn is_avx2_supported() -> bool {
    is_x86_feature_detected!("avx2")
}
/// Reports whether SSE2 instructions can be executed on this machine.
///
/// This is always true on `x86_64`, where SSE2 is part of the baseline instruction
/// set. On 32-bit `x86` it is detected at runtime by the standard library. That
/// detection also copes with CPUs that lack the CPUID instruction entirely,
/// whereas a raw `__cpuid` call would fault on them. The result is cached.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn is_sse2_supported() -> bool {
    is_x86_feature_detected!("sse2")
}
/// AVX-512F kernel for the sphere-membership filter: evaluates 16 points per vector step.
///
/// Computes `(dx*dx + dy*dy) + dz*dz` in the same order as the scalar formula,
/// and compares with `<=` (ordered, so NaN yields `false`). Fewer than 16
/// leftover points are handled with the scalar formula.
///
/// NOTE(review): AVX-512 intrinsics are only stable on recent Rust toolchains —
/// confirm this matches the project's minimum supported Rust version.
///
/// # Safety
/// The caller must ensure the running CPU and OS support AVX-512F, e.g. by
/// checking `is_x86_feature_detected!("avx512f")` first.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn distance_filter_avx512(
    points: &[(f32, f32, f32)],
    center: (f32, f32, f32),
    radius_sq: f32,
) -> Vec<bool> {
    // Import from the module matching the target: `std::arch::x86_64` does not
    // exist when compiling for 32-bit x86.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    const LANES: usize = 16;
    let (cx, cy, cz) = center;
    let cx_vec = _mm512_set1_ps(cx);
    let cy_vec = _mm512_set1_ps(cy);
    let cz_vec = _mm512_set1_ps(cz);
    let radius_vec = _mm512_set1_ps(radius_sq);
    let mut result = Vec::with_capacity(points.len());

    let chunks = points.chunks_exact(LANES);
    let tail = chunks.remainder();
    for chunk in chunks {
        // Transpose the (x, y, z) tuples into one contiguous array per coordinate.
        let mut xs = [0.0f32; LANES];
        let mut ys = [0.0f32; LANES];
        let mut zs = [0.0f32; LANES];
        for (j, &(x, y, z)) in chunk.iter().enumerate() {
            xs[j] = x;
            ys[j] = y;
            zs[j] = z;
        }
        let dx = _mm512_sub_ps(_mm512_loadu_ps(xs.as_ptr()), cx_vec);
        let dy = _mm512_sub_ps(_mm512_loadu_ps(ys.as_ptr()), cy_vec);
        let dz = _mm512_sub_ps(_mm512_loadu_ps(zs.as_ptr()), cz_vec);
        let dx2 = _mm512_mul_ps(dx, dx);
        let dy2 = _mm512_mul_ps(dy, dy);
        let dz2 = _mm512_mul_ps(dz, dz);
        let dist_sq = _mm512_add_ps(_mm512_add_ps(dx2, dy2), dz2);
        // Bit j of the 16-bit mask is set when lane j satisfies dist_sq <= radius_sq.
        let mask = _mm512_cmple_ps_mask(dist_sq, radius_vec);
        result.extend((0..LANES).map(|j| (mask >> j) & 1 != 0));
    }

    // Leftover points: same formula and operation order as the vector path.
    result.extend(tail.iter().map(|&(x, y, z)| {
        let dx = x - cx;
        let dy = y - cy;
        let dz = z - cz;
        dx * dx + dy * dy + dz * dz <= radius_sq
    }));
    result
}
/// AVX2 kernel for the sphere-membership filter: evaluates 8 points per vector step.
///
/// Computes `(dx*dx + dy*dy) + dz*dz` in the same order as the scalar formula.
/// It compares with the ordered predicate `_CMP_LE_OS`, so NaN yields `false`.
/// Fewer than 8 leftover points are handled with the scalar formula.
///
/// # Safety
/// The caller must ensure the running CPU and OS support AVX2, e.g. by
/// checking `is_x86_feature_detected!("avx2")` first.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn distance_filter_avx2(
    points: &[(f32, f32, f32)],
    center: (f32, f32, f32),
    radius_sq: f32,
) -> Vec<bool> {
    // Import from the module matching the target: `std::arch::x86_64` does not
    // exist when compiling for 32-bit x86.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    const LANES: usize = 8;
    let (cx, cy, cz) = center;
    let cx_vec = _mm256_set1_ps(cx);
    let cy_vec = _mm256_set1_ps(cy);
    let cz_vec = _mm256_set1_ps(cz);
    let radius_vec = _mm256_set1_ps(radius_sq);
    let mut result = Vec::with_capacity(points.len());

    let chunks = points.chunks_exact(LANES);
    let tail = chunks.remainder();
    for chunk in chunks {
        // Transpose the (x, y, z) tuples into one contiguous array per coordinate.
        let mut xs = [0.0f32; LANES];
        let mut ys = [0.0f32; LANES];
        let mut zs = [0.0f32; LANES];
        for (j, &(x, y, z)) in chunk.iter().enumerate() {
            xs[j] = x;
            ys[j] = y;
            zs[j] = z;
        }
        let dx = _mm256_sub_ps(_mm256_loadu_ps(xs.as_ptr()), cx_vec);
        let dy = _mm256_sub_ps(_mm256_loadu_ps(ys.as_ptr()), cy_vec);
        let dz = _mm256_sub_ps(_mm256_loadu_ps(zs.as_ptr()), cz_vec);
        let dx2 = _mm256_mul_ps(dx, dx);
        let dy2 = _mm256_mul_ps(dy, dy);
        let dz2 = _mm256_mul_ps(dz, dz);
        let dist_sq = _mm256_add_ps(_mm256_add_ps(dx2, dy2), dz2);
        // movemask packs each lane's comparison result into bit j of an integer.
        let mask = _mm256_movemask_ps(_mm256_cmp_ps(dist_sq, radius_vec, _CMP_LE_OS));
        result.extend((0..LANES).map(|j| (mask >> j) & 1 != 0));
    }

    // Leftover points: same formula and operation order as the vector path.
    result.extend(tail.iter().map(|&(x, y, z)| {
        let dx = x - cx;
        let dy = y - cy;
        let dz = z - cz;
        dx * dx + dy * dy + dz * dz <= radius_sq
    }));
    result
}
/// SSE2 kernel for the sphere-membership filter: evaluates 4 points per vector step.
///
/// Computes `(dx*dx + dy*dy) + dz*dz` in the same order as the scalar formula.
/// It compares with `_mm_cmple_ps`, an ordered comparison, so NaN yields `false`.
/// Fewer than 4 leftover points are handled with the scalar formula.
///
/// # Safety
/// The caller must ensure the running CPU supports SSE2. This always holds on
/// `x86_64`; on 32-bit `x86` it must be checked at runtime.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn distance_filter_sse2(
    points: &[(f32, f32, f32)],
    center: (f32, f32, f32),
    radius_sq: f32,
) -> Vec<bool> {
    // Import from the module matching the target: `std::arch::x86_64` does not
    // exist when compiling for 32-bit x86.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    const LANES: usize = 4;
    let (cx, cy, cz) = center;
    let cx_vec = _mm_set1_ps(cx);
    let cy_vec = _mm_set1_ps(cy);
    let cz_vec = _mm_set1_ps(cz);
    let radius_vec = _mm_set1_ps(radius_sq);
    let mut result = Vec::with_capacity(points.len());

    let chunks = points.chunks_exact(LANES);
    let tail = chunks.remainder();
    for chunk in chunks {
        // Transpose the (x, y, z) tuples into one contiguous array per coordinate.
        let mut xs = [0.0f32; LANES];
        let mut ys = [0.0f32; LANES];
        let mut zs = [0.0f32; LANES];
        for (j, &(x, y, z)) in chunk.iter().enumerate() {
            xs[j] = x;
            ys[j] = y;
            zs[j] = z;
        }
        let dx = _mm_sub_ps(_mm_loadu_ps(xs.as_ptr()), cx_vec);
        let dy = _mm_sub_ps(_mm_loadu_ps(ys.as_ptr()), cy_vec);
        let dz = _mm_sub_ps(_mm_loadu_ps(zs.as_ptr()), cz_vec);
        let dx2 = _mm_mul_ps(dx, dx);
        let dy2 = _mm_mul_ps(dy, dy);
        let dz2 = _mm_mul_ps(dz, dz);
        let dist_sq = _mm_add_ps(_mm_add_ps(dx2, dy2), dz2);
        // movemask packs each lane's comparison result into bit j of an integer.
        let mask = _mm_movemask_ps(_mm_cmple_ps(dist_sq, radius_vec));
        result.extend((0..LANES).map(|j| (mask >> j) & 1 != 0));
    }

    // Leftover points: same formula and operation order as the vector path.
    result.extend(tail.iter().map(|&(x, y, z)| {
        let dx = x - cx;
        let dy = y - cy;
        let dz = z - cz;
        dx * dx + dy * dy + dz * dz <= radius_sq
    }));
    result
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Points along the ray `i * (0.1, 0.2, 0.3)`.
    ///
    /// With center `(2, 4, 6)` and `radius_sq = 1.0`, indices 18..=22 fall inside
    /// and all others outside, so the data contains both hits and misses.
    fn sample_points(n: usize) -> Vec<(f32, f32, f32)> {
        (0..n)
            .map(|i| (i as f32 * 0.1, i as f32 * 0.2, i as f32 * 0.3))
            .collect()
    }

    const CENTER: (f32, f32, f32) = (2.0, 4.0, 6.0);
    const RADIUS_SQ: f32 = 1.0;

    #[test]
    fn test_distance_filter_scalar_basic() {
        let points = vec![(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (3.0, 0.0, 0.0)];
        let result = distance_filter_scalar(&points, (0.0, 0.0, 0.0), 4.0);
        assert_eq!(result.len(), 3);
        assert!(result[0]);
        assert!(result[1]);
        assert!(!result[2]);
    }

    #[test]
    fn test_distance_filter_equivalence() {
        let points = sample_points(100);
        let scalar_result = distance_filter_scalar(&points, CENTER, RADIUS_SQ);
        // Guard against a vacuous comparison (e.g. every entry `false`).
        assert!(scalar_result.contains(&true) && scalar_result.contains(&false));
        let auto_result = distance_filter_l2(&points, CENTER, RADIUS_SQ);
        assert_eq!(
            scalar_result, auto_result,
            "SIMD and scalar must produce identical results"
        );
    }

    /// The dispatcher only exercises the best kernel for the host CPU, so call
    /// every supported kernel directly. Lengths 0..=40 cover every partial-tail
    /// size for 4-, 8- and 16-lane kernels; one NaN coordinate per input checks
    /// that NaN handling agrees with the scalar path.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[test]
    fn test_each_simd_kernel_matches_scalar() {
        for n in 0..=40 {
            let mut points = sample_points(n);
            if n > 0 {
                points[n / 2].0 = f32::NAN;
            }
            let expected = distance_filter_scalar(&points, CENTER, RADIUS_SQ);
            if is_x86_feature_detected!("avx512f") {
                // SAFETY: AVX-512F support was verified at runtime just above.
                let got = unsafe { distance_filter_avx512(&points, CENTER, RADIUS_SQ) };
                assert_eq!(got, expected, "avx512 mismatch for n = {}", n);
            }
            if is_x86_feature_detected!("avx2") {
                // SAFETY: AVX2 support was verified at runtime just above.
                let got = unsafe { distance_filter_avx2(&points, CENTER, RADIUS_SQ) };
                assert_eq!(got, expected, "avx2 mismatch for n = {}", n);
            }
            if is_x86_feature_detected!("sse2") {
                // SAFETY: SSE2 support was verified at runtime just above.
                let got = unsafe { distance_filter_sse2(&points, CENTER, RADIUS_SQ) };
                assert_eq!(got, expected, "sse2 mismatch for n = {}", n);
            }
        }
    }

    #[test]
    fn test_distance_filter_edge_cases() {
        let empty: Vec<(f32, f32, f32)> = vec![];
        let result = distance_filter_l2(&empty, (0.0, 0.0, 0.0), 1.0);
        assert!(result.is_empty());

        // A point exactly on the sphere counts as inside (`<=`).
        let points = vec![(1.0, 0.0, 0.0)];
        let result = distance_filter_l2(&points, (0.0, 0.0, 0.0), 1.0);
        assert!(result[0]);

        let points = vec![(1.0001, 0.0, 0.0)];
        let result = distance_filter_l2(&points, (0.0, 0.0, 0.0), 1.0);
        assert!(!result[0]);

        // NaN coordinates never satisfy the comparison.
        let points = vec![(f32::NAN, 0.0, 0.0)];
        let result = distance_filter_l2(&points, (0.0, 0.0, 0.0), 1.0);
        assert!(!result[0]);
    }
}