/// Flags every point that lies inside or on the sphere of squared radius
/// `radius_sq` around `center`.
///
/// On x86/x86_64 the widest available kernel is chosen at runtime, in this
/// order of preference: AVX-512F, AVX2, SSE2. Other targets, or x86 hosts
/// where none of the checks succeed, use [`distance_filter_scalar`].
///
/// The returned vector has one entry per input point, in input order.
pub fn distance_filter_l2(
    points: &[(f32, f32, f32)],
    center: (f32, f32, f32),
    radius_sq: f32,
) -> Vec<bool> {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        let dispatched = if is_avx512_supported() {
            // SAFETY: relies on `is_avx512_supported` reporting AVX-512F as
            // usable on this host.
            Some(unsafe { distance_filter_avx512(points, center, radius_sq) })
        } else if is_avx2_supported() {
            // SAFETY: relies on `is_avx2_supported` reporting AVX2 as usable.
            Some(unsafe { distance_filter_avx2(points, center, radius_sq) })
        } else if is_sse2_supported() {
            // SAFETY: relies on `is_sse2_supported` reporting SSE2 as usable.
            Some(unsafe { distance_filter_sse2(points, center, radius_sq) })
        } else {
            None
        };
        if let Some(flags) = dispatched {
            return flags;
        }
    }
    distance_filter_scalar(points, center, radius_sq)
}
/// Portable reference implementation of [`distance_filter_l2`].
///
/// Returns `true` for each point whose squared Euclidean distance to
/// `center`, computed as `(dx² + dy²) + dz²`, is `<= radius_sq`. The
/// boundary is included, and a NaN distance yields `false`.
pub fn distance_filter_scalar(
    points: &[(f32, f32, f32)],
    center: (f32, f32, f32),
    radius_sq: f32,
) -> Vec<bool> {
    let (ox, oy, oz) = center;
    let within = |&(px, py, pz): &(f32, f32, f32)| {
        let (ex, ey, ez) = (px - ox, py - oy, pz - oz);
        ex * ex + ey * ey + ez * ez <= radius_sq
    };
    points.iter().map(within).collect()
}
/// Reports whether AVX-512F instructions can be executed on this host.
///
/// Delegates to std's `is_x86_feature_detected!`. Unlike a bare CPUID
/// leaf-7 probe, std also verifies that the OS has enabled the extended
/// register state (XCR0) and that leaf 7 exists. It works on both x86 and
/// x86_64, and it keeps its own cached result, so no local cache is needed.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn is_avx512_supported() -> bool {
    is_x86_feature_detected!("avx512f")
}
/// Reports whether AVX2 instructions can be executed on this host.
///
/// Delegates to std's `is_x86_feature_detected!`. It checks CPU support,
/// confirms the OS saves YMM state, and guards against CPUs whose max CPUID
/// leaf is below 7. The CPUID-only probe this replaces did none of those
/// checks, and it only compiled on x86_64.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn is_avx2_supported() -> bool {
    is_x86_feature_detected!("avx2")
}
/// Reports whether SSE2 instructions can be executed on this host.
///
/// SSE2 is part of the x86_64 baseline. `is_x86_feature_detected!` returns
/// `true` immediately for features enabled at compile time, so on x86_64
/// this is always `true`. On 32-bit x86, std performs the runtime probe.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn is_sse2_supported() -> bool {
    is_x86_feature_detected!("sse2")
}
/// AVX-512F kernel for [`distance_filter_l2`]: processes 16 points per iteration.
///
/// Sixteen `(x, y, z)` tuples are 48 contiguous floats, loaded as three
/// 512-bit vectors. `_mm512_permutexvar_ps` moves each coordinate into its
/// final lane, and `_mm512_mask_blend_ps` merges the three partial vectors
/// into one vector per axis. Trailing points (fewer than 16) use scalar
/// code. Both paths evaluate `(dx² + dy²) + dz² <= radius_sq`, the same
/// expression and order as `distance_filter_scalar`.
///
/// # Safety
///
/// The caller must ensure the CPU and OS support AVX-512F.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn distance_filter_avx512(
    points: &[(f32, f32, f32)],
    center: (f32, f32, f32),
    radius_sq: f32,
) -> Vec<bool> {
    // The block is compiled for 32-bit x86 too, so import from the matching module.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // The raw `*const f32` loads below require each tuple to be three packed
    // f32s in x, y, z order. Tuple layout is unspecified (`repr(Rust)`), so
    // check the field order as well as the size.
    assert_eq!(std::mem::size_of::<(f32, f32, f32)>(), 12);
    assert_eq!(std::mem::offset_of!((f32, f32, f32), 0), 0);
    assert_eq!(std::mem::offset_of!((f32, f32, f32), 1), 4);
    assert_eq!(std::mem::offset_of!((f32, f32, f32), 2), 8);

    let (cx, cy, cz) = center;
    let cx_vec = _mm512_set1_ps(cx);
    let cy_vec = _mm512_set1_ps(cy);
    let cz_vec = _mm512_set1_ps(cz);
    let radius_vec = _mm512_set1_ps(radius_sq);

    // For a block of 16 points: r0 = floats 0..16, r1 = 16..32, r2 = 32..48,
    // and coordinate c of point k is float 3k + c. Each mask places the lanes
    // of one coordinate that a register holds at their final positions. Lanes
    // the following blend does not take from that register are don't-cares.
    // A 0 in a taken lane is a real index (e.g. y_mask_1 lane 5 is r1[0],
    // the y of point 5).
    let x_mask_0 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
    let x_mask_1 = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 2, 5, 8, 11, 14, 0, 0, 0, 0, 0);
    let x_mask_2 = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 4, 7, 10, 13);
    let y_mask_0 = _mm512_setr_epi32(1, 4, 7, 10, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
    let y_mask_1 = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 3, 6, 9, 12, 15, 0, 0, 0, 0, 0);
    let y_mask_2 = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 5, 8, 11, 14);
    let z_mask_0 = _mm512_setr_epi32(2, 5, 8, 11, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
    let z_mask_1 = _mm512_setr_epi32(0, 0, 0, 0, 0, 1, 4, 7, 10, 13, 0, 0, 0, 0, 0, 0);
    let z_mask_2 = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 6, 9, 12, 15);

    let mut result = Vec::with_capacity(points.len());
    let blocks = points.chunks_exact(16);
    let tail = blocks.remainder();
    for block in blocks {
        // `block` is exactly 16 tuples, i.e. 48 contiguous f32s (checked above).
        let base = block.as_ptr() as *const f32;
        let r0 = _mm512_loadu_ps(base);
        let r1 = _mm512_loadu_ps(base.add(16));
        let r2 = _mm512_loadu_ps(base.add(32));

        // x: lanes 0..=5 from r0, 6..=10 from r1, 11..=15 from r2.
        let p0_x = _mm512_permutexvar_ps(x_mask_0, r0);
        let p1_x = _mm512_permutexvar_ps(x_mask_1, r1);
        let p2_x = _mm512_permutexvar_ps(x_mask_2, r2);
        let p01_x = _mm512_mask_blend_ps(0b00000111_11000000, p0_x, p1_x);
        let x_vec = _mm512_mask_blend_ps(0b11111000_00000000, p01_x, p2_x);

        // y: lanes 0..=4 from r0, 5..=10 from r1, 11..=15 from r2.
        let p0_y = _mm512_permutexvar_ps(y_mask_0, r0);
        let p1_y = _mm512_permutexvar_ps(y_mask_1, r1);
        let p2_y = _mm512_permutexvar_ps(y_mask_2, r2);
        let p01_y = _mm512_mask_blend_ps(0b00000111_11100000, p0_y, p1_y);
        let y_vec = _mm512_mask_blend_ps(0b11111000_00000000, p01_y, p2_y);

        // z: lanes 0..=4 from r0, 5..=9 from r1, 10..=15 from r2.
        let p0_z = _mm512_permutexvar_ps(z_mask_0, r0);
        let p1_z = _mm512_permutexvar_ps(z_mask_1, r1);
        let p2_z = _mm512_permutexvar_ps(z_mask_2, r2);
        let p01_z = _mm512_mask_blend_ps(0b00000011_11100000, p0_z, p1_z);
        let z_vec = _mm512_mask_blend_ps(0b11111100_00000000, p01_z, p2_z);

        let dx = _mm512_sub_ps(x_vec, cx_vec);
        let dy = _mm512_sub_ps(y_vec, cy_vec);
        let dz = _mm512_sub_ps(z_vec, cz_vec);
        let dx2 = _mm512_mul_ps(dx, dx);
        let dy2 = _mm512_mul_ps(dy, dy);
        let dz2 = _mm512_mul_ps(dz, dz);
        let dist_sq = _mm512_add_ps(_mm512_add_ps(dx2, dy2), dz2);
        // Ordered compare: a NaN distance yields a cleared bit, matching scalar `<=`.
        let mask = _mm512_cmple_ps_mask(dist_sq, radius_vec);
        result.extend((0..16).map(|lane| (mask >> lane) & 1 != 0));
    }

    result.extend(tail.iter().map(|&(x, y, z)| {
        let dx = x - cx;
        let dy = y - cy;
        let dz = z - cz;
        dx * dx + dy * dy + dz * dz <= radius_sq
    }));
    result
}
/// AVX2 kernel for [`distance_filter_l2`]: processes 8 points per iteration.
///
/// Eight `(x, y, z)` tuples are 24 contiguous floats, loaded as three
/// 256-bit vectors. `_mm256_permutevar8x32_ps` moves each coordinate into
/// its final lane, and `_mm256_blend_ps` merges the three partial vectors
/// into one vector per axis. Trailing points (fewer than 8) use scalar code.
/// Both paths evaluate `(dx² + dy²) + dz² <= radius_sq`, the same expression
/// and order as `distance_filter_scalar`.
///
/// # Safety
///
/// The caller must ensure the CPU and OS support AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn distance_filter_avx2(
    points: &[(f32, f32, f32)],
    center: (f32, f32, f32),
    radius_sq: f32,
) -> Vec<bool> {
    // The block is compiled for 32-bit x86 too, so import from the matching module.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // The raw `*const f32` loads below require each tuple to be three packed
    // f32s in x, y, z order. Tuple layout is unspecified (`repr(Rust)`), so
    // check the field order as well as the size.
    assert_eq!(std::mem::size_of::<(f32, f32, f32)>(), 12);
    assert_eq!(std::mem::offset_of!((f32, f32, f32), 0), 0);
    assert_eq!(std::mem::offset_of!((f32, f32, f32), 1), 4);
    assert_eq!(std::mem::offset_of!((f32, f32, f32), 2), 8);

    let (cx, cy, cz) = center;
    let cx_vec = _mm256_set1_ps(cx);
    let cy_vec = _mm256_set1_ps(cy);
    let cz_vec = _mm256_set1_ps(cz);
    let radius_vec = _mm256_set1_ps(radius_sq);

    // For a block of 8 points: r0 = floats 0..8, r1 = 8..16, r2 = 16..24, and
    // coordinate c of point k is float 3k + c. Each mask places the lanes of
    // one coordinate that a register holds at their final positions. Lanes
    // the following blend does not take from that register are don't-cares.
    // A 0 in a taken lane is a real index (e.g. y_mask_2 lane 5 is r2[0],
    // the y of point 5).
    let x_mask_0 = _mm256_setr_epi32(0, 3, 6, 0, 0, 0, 0, 0);
    let x_mask_1 = _mm256_setr_epi32(0, 0, 0, 1, 4, 7, 0, 0);
    let x_mask_2 = _mm256_setr_epi32(0, 0, 0, 0, 0, 0, 2, 5);
    let y_mask_0 = _mm256_setr_epi32(1, 4, 7, 0, 0, 0, 0, 0);
    let y_mask_1 = _mm256_setr_epi32(0, 0, 0, 2, 5, 0, 0, 0);
    let y_mask_2 = _mm256_setr_epi32(0, 0, 0, 0, 0, 0, 3, 6);
    let z_mask_0 = _mm256_setr_epi32(2, 5, 0, 0, 0, 0, 0, 0);
    let z_mask_1 = _mm256_setr_epi32(0, 0, 0, 3, 6, 0, 0, 0);
    let z_mask_2 = _mm256_setr_epi32(0, 0, 0, 0, 0, 1, 4, 7);

    let mut result = Vec::with_capacity(points.len());
    let blocks = points.chunks_exact(8);
    let tail = blocks.remainder();
    for block in blocks {
        // `block` is exactly 8 tuples, i.e. 24 contiguous f32s (checked above).
        let base = block.as_ptr() as *const f32;
        let r0 = _mm256_loadu_ps(base);
        let r1 = _mm256_loadu_ps(base.add(8));
        let r2 = _mm256_loadu_ps(base.add(16));

        // x: lanes 0..=2 from r0, 3..=5 from r1, 6..=7 from r2.
        let p0_x = _mm256_permutevar8x32_ps(r0, x_mask_0);
        let p1_x = _mm256_permutevar8x32_ps(r1, x_mask_1);
        let p2_x = _mm256_permutevar8x32_ps(r2, x_mask_2);
        let p01_x = _mm256_blend_ps(p0_x, p1_x, 0b00111000);
        let x_vec = _mm256_blend_ps(p01_x, p2_x, 0b11000000);

        // y: lanes 0..=2 from r0, 3..=4 from r1, 5..=7 from r2.
        let p0_y = _mm256_permutevar8x32_ps(r0, y_mask_0);
        let p1_y = _mm256_permutevar8x32_ps(r1, y_mask_1);
        let p2_y = _mm256_permutevar8x32_ps(r2, y_mask_2);
        let p01_y = _mm256_blend_ps(p0_y, p1_y, 0b00011000);
        let y_vec = _mm256_blend_ps(p01_y, p2_y, 0b11100000);

        // z: lanes 0..=1 from r0, 2..=4 from r1, 5..=7 from r2.
        let p0_z = _mm256_permutevar8x32_ps(r0, z_mask_0);
        let p1_z = _mm256_permutevar8x32_ps(r1, z_mask_1);
        let p2_z = _mm256_permutevar8x32_ps(r2, z_mask_2);
        let p01_z = _mm256_blend_ps(p0_z, p1_z, 0b00011100);
        let z_vec = _mm256_blend_ps(p01_z, p2_z, 0b11100000);

        let dx = _mm256_sub_ps(x_vec, cx_vec);
        let dy = _mm256_sub_ps(y_vec, cy_vec);
        let dz = _mm256_sub_ps(z_vec, cz_vec);
        let dx2 = _mm256_mul_ps(dx, dx);
        let dy2 = _mm256_mul_ps(dy, dy);
        let dz2 = _mm256_mul_ps(dz, dz);
        let dist_sq = _mm256_add_ps(_mm256_add_ps(dx2, dy2), dz2);
        // Ordered compare: a NaN distance yields a cleared bit, matching scalar `<=`.
        let mask = _mm256_movemask_ps(_mm256_cmp_ps(dist_sq, radius_vec, _CMP_LE_OS));
        result.extend((0..8).map(|lane| (mask >> lane) & 1 != 0));
    }

    result.extend(tail.iter().map(|&(x, y, z)| {
        let dx = x - cx;
        let dy = y - cy;
        let dz = z - cz;
        dx * dx + dy * dy + dz * dz <= radius_sq
    }));
    result
}
/// SSE2 kernel for [`distance_filter_l2`]: processes 4 points per iteration.
///
/// Unlike the AVX kernels, this one does not reinterpret the tuple slice as
/// raw floats. It reads each coordinate through ordinary field access and
/// builds the per-axis vectors with `_mm_setr_ps`, so it makes no layout
/// assumptions. Trailing points (fewer than 4) use scalar code. Both paths
/// evaluate `(dx² + dy²) + dz² <= radius_sq`.
///
/// # Safety
///
/// The caller must ensure SSE2 is available (it is part of the x86_64 baseline).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn distance_filter_sse2(
    points: &[(f32, f32, f32)],
    center: (f32, f32, f32),
    radius_sq: f32,
) -> Vec<bool> {
    // The block is compiled for 32-bit x86 too, so import from the matching module.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let (cx, cy, cz) = center;
    let cx_vec = _mm_set1_ps(cx);
    let cy_vec = _mm_set1_ps(cy);
    let cz_vec = _mm_set1_ps(cz);
    let radius_vec = _mm_set1_ps(radius_sq);

    let mut result = Vec::with_capacity(points.len());
    let quads = points.chunks_exact(4);
    let tail = quads.remainder();
    for quad in quads {
        let (a, b, c, d) = (quad[0], quad[1], quad[2], quad[3]);
        // `setr` puts its first argument in lane 0, preserving point order.
        let x_vec = _mm_setr_ps(a.0, b.0, c.0, d.0);
        let y_vec = _mm_setr_ps(a.1, b.1, c.1, d.1);
        let z_vec = _mm_setr_ps(a.2, b.2, c.2, d.2);

        let dx = _mm_sub_ps(x_vec, cx_vec);
        let dy = _mm_sub_ps(y_vec, cy_vec);
        let dz = _mm_sub_ps(z_vec, cz_vec);
        let dx2 = _mm_mul_ps(dx, dx);
        let dy2 = _mm_mul_ps(dy, dy);
        let dz2 = _mm_mul_ps(dz, dz);
        let dist_sq = _mm_add_ps(_mm_add_ps(dx2, dy2), dz2);
        let mask = _mm_movemask_ps(_mm_cmple_ps(dist_sq, radius_vec));
        result.extend((0..4).map(|lane| (mask >> lane) & 1 != 0));
    }

    result.extend(tail.iter().map(|&(x, y, z)| {
        let dx = x - cx;
        let dy = y - cy;
        let dz = z - cz;
        dx * dx + dy * dy + dz * dz <= radius_sq
    }));
    result
}
/// Returns the indices, in ascending order, of the nodes whose `(x, y, z)`
/// position is within `radius` of `center` (boundary included).
///
/// The comparison is done on `radius * radius`, so a negative `radius`
/// behaves like its absolute value.
pub fn batch_spatial_filter_nodes(
    nodes: &[crate::algorithms::four_d::GraphNode4D],
    center: (f32, f32, f32),
    radius: f32,
) -> Vec<usize> {
    // Pack positions into a contiguous slice of tuples for the distance kernels.
    let positions: Vec<(f32, f32, f32)> = nodes
        .iter()
        .map(|node| (node.x, node.y, node.z))
        .collect();
    let hits = distance_filter_l2(&positions, center, radius * radius);
    hits.iter()
        .enumerate()
        .filter_map(|(idx, &hit)| hit.then_some(idx))
        .collect()
}
// Unit tests. Besides the dispatcher-level checks, every x86 kernel the host
// supports is compared directly against the scalar reference. The dispatcher
// alone only ever exercises the widest available kernel.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_distance_filter_scalar_basic() {
        let points = vec![(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (3.0, 0.0, 0.0)];
        let result = distance_filter_scalar(&points, (0.0, 0.0, 0.0), 4.0);
        assert_eq!(result.len(), 3);
        assert!(result[0]);
        assert!(result[1]);
        assert!(!result[2]);
    }
    #[test]
    fn test_distance_filter_equivalence() {
        let points: Vec<_> = (0..100)
            .map(|i| (i as f32 * 0.1, i as f32 * 0.2, i as f32 * 0.3))
            .collect();
        let center = (5.0, 5.0, 5.0);
        let radius_sq = 10.0;
        let scalar_result = distance_filter_scalar(&points, center, radius_sq);
        let auto_result = distance_filter_l2(&points, center, radius_sq);
        assert_eq!(
            scalar_result, auto_result,
            "SIMD and scalar must produce identical results"
        );
    }
    #[test]
    fn test_distance_filter_edge_cases() {
        let empty: Vec<(f32, f32, f32)> = vec![];
        let result = distance_filter_l2(&empty, (0.0, 0.0, 0.0), 1.0);
        assert!(result.is_empty());
        let points = vec![(1.0, 0.0, 0.0)];
        let result = distance_filter_l2(&points, (0.0, 0.0, 0.0), 1.0);
        assert!(result[0]);
        let points = vec![(1.0001, 0.0, 0.0)];
        let result = distance_filter_l2(&points, (0.0, 0.0, 0.0), 1.0);
        assert!(!result[0]);
    }
    /// Same boundary/NaN cases as above, but with 33 points so the vector
    /// loop of whichever kernel is dispatched sees them, not just the scalar tail.
    #[test]
    fn test_distance_filter_boundary_and_nan_in_vector_path() {
        let pattern = [
            (1.0, 0.0, 0.0),
            (0.0, 1.0, 0.0),
            (0.0, 0.0, 1.0),
            (f32::NAN, 0.0, 0.0),
            (1.0001, 0.0, 0.0),
        ];
        let points: Vec<(f32, f32, f32)> = (0..33).map(|i| pattern[i % pattern.len()]).collect();
        let result = distance_filter_l2(&points, (0.0, 0.0, 0.0), 1.0);
        let expected: Vec<bool> = (0..33).map(|i| i % 5 < 3).collect();
        assert_eq!(result, expected);
    }
    /// Deterministic pseudo-random points in [-4, 4)^3 with independent x, y
    /// and z, so a de-interleaving mistake in a kernel changes the output.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn pseudo_random_points(n: usize) -> Vec<(f32, f32, f32)> {
        let mut state: u32 = 0x2545_F491;
        let mut next = || {
            state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
            (state >> 8) as f32 / 16_777_216.0 * 8.0 - 4.0
        };
        (0..n).map(|_| (next(), next(), next())).collect()
    }
    /// Calls each kernel the host supports directly, for lengths that straddle
    /// the 4-, 8- and 16-point chunk widths, and compares with the scalar path.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[test]
    fn test_each_simd_kernel_matches_scalar() {
        let center = (0.5, -0.25, 1.0);
        let radius_sq = 9.0;
        for n in [0usize, 1, 3, 4, 5, 7, 8, 9, 15, 16, 17, 31, 32, 33, 100] {
            let points = pseudo_random_points(n);
            let expected = distance_filter_scalar(&points, center, radius_sq);
            if is_sse2_supported() {
                // SAFETY: SSE2 support was just checked.
                let got = unsafe { distance_filter_sse2(&points, center, radius_sq) };
                assert_eq!(got, expected, "sse2 kernel, n = {n}");
            }
            if is_avx2_supported() {
                // SAFETY: AVX2 support was just checked.
                let got = unsafe { distance_filter_avx2(&points, center, radius_sq) };
                assert_eq!(got, expected, "avx2 kernel, n = {n}");
            }
            if is_avx512_supported() {
                // SAFETY: AVX-512F support was just checked.
                let got = unsafe { distance_filter_avx512(&points, center, radius_sq) };
                assert_eq!(got, expected, "avx512 kernel, n = {n}");
            }
        }
    }
    #[test]
    fn test_batch_spatial_filter_nodes_matches_scalar() {
        use crate::algorithms::four_d::GraphNode4D;
        use std::collections::BTreeMap;
        let nodes: Vec<GraphNode4D> = (0..100)
            .map(|i| GraphNode4D {
                id: i as u64,
                x: i as f32 * 0.3,
                y: i as f32 * 0.2,
                z: i as f32 * 0.1,
                begin_ts: 0,
                end_ts: 100,
                properties: BTreeMap::new(),
                successors: vec![],
            })
            .collect();
        let center = (5.0_f32, 5.0_f32, 5.0_f32);
        let radius = 4.0_f32;
        let radius_sq = radius * radius;
        let expected: Vec<usize> = nodes
            .iter()
            .enumerate()
            .filter(|(_, n)| {
                let dx = n.x - center.0;
                let dy = n.y - center.1;
                let dz = n.z - center.2;
                dx * dx + dy * dy + dz * dz <= radius_sq
            })
            .map(|(i, _)| i)
            .collect();
        let result = batch_spatial_filter_nodes(&nodes, center, radius);
        assert_eq!(result, expected, "SIMD batch must match scalar reference");
    }
    #[test]
    fn test_batch_spatial_filter_nodes_empty() {
        use crate::algorithms::four_d::GraphNode4D;
        let nodes: Vec<GraphNode4D> = vec![];
        let result = batch_spatial_filter_nodes(&nodes, (0.0, 0.0, 0.0), 1.0);
        assert!(result.is_empty());
    }
    #[test]
    fn test_batch_spatial_filter_nodes_all_match() {
        use crate::algorithms::four_d::GraphNode4D;
        use std::collections::BTreeMap;
        let nodes: Vec<GraphNode4D> = (0..10)
            .map(|i| GraphNode4D {
                id: i as u64,
                x: 0.01 * i as f32,
                y: 0.01 * i as f32,
                z: 0.01 * i as f32,
                begin_ts: 0,
                end_ts: 100,
                properties: BTreeMap::new(),
                successors: vec![],
            })
            .collect();
        let result = batch_spatial_filter_nodes(&nodes, (0.0, 0.0, 0.0), 100.0);
        assert_eq!(result.len(), 10, "All nodes should match with large radius");
    }
    #[test]
    fn test_batch_spatial_filter_nodes_none_match() {
        use crate::algorithms::four_d::GraphNode4D;
        use std::collections::BTreeMap;
        let nodes: Vec<GraphNode4D> = (0..10)
            .map(|i| GraphNode4D {
                id: i as u64,
                x: 1000.0 + i as f32,
                y: 1000.0,
                z: 1000.0,
                begin_ts: 0,
                end_ts: 100,
                properties: BTreeMap::new(),
                successors: vec![],
            })
            .collect();
        let result = batch_spatial_filter_nodes(&nodes, (0.0, 0.0, 0.0), 1.0);
        assert!(result.is_empty(), "No nodes should match");
    }
}