1pub fn distance_filter_l2(
18 points: &[(f32, f32, f32)],
19 center: (f32, f32, f32),
20 radius_sq: f32,
21) -> Vec<bool> {
22 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
24 {
25 if is_avx512_supported() {
26 return unsafe { distance_filter_avx512(points, center, radius_sq) };
27 }
28 if is_avx2_supported() {
29 return unsafe { distance_filter_avx2(points, center, radius_sq) };
30 }
31 if is_sse2_supported() {
32 return unsafe { distance_filter_sse2(points, center, radius_sq) };
33 }
34 }
35
36 distance_filter_scalar(points, center, radius_sq)
38}
39
/// Portable reference implementation of [`distance_filter_l2`].
///
/// For every point, in input order, yields `true` when
/// `(x - cx)^2 + (y - cy)^2 + (z - cz)^2 <= radius_sq`. Any comparison
/// involving NaN yields `false`.
pub fn distance_filter_scalar(
    points: &[(f32, f32, f32)],
    center: (f32, f32, f32),
    radius_sq: f32,
) -> Vec<bool> {
    let mut within = Vec::with_capacity(points.len());
    for &(px, py, pz) in points {
        let (ox, oy, oz) = (px - center.0, py - center.1, pz - center.2);
        // Same left-to-right summation order as the SIMD kernels: (x² + y²) + z².
        within.push(ox * ox + oy * oy + oz * oz <= radius_sq);
    }
    within
}
58
/// Returns `true` when AVX-512F instructions can be executed on this machine.
///
/// Delegates to the standard library's runtime detection instead of reading
/// CPUID leaf 7 directly. The raw read skipped the max-leaf check and the
/// OS-support (XCR0 ZMM/opmask state) check, so it could report support that
/// would then fault. It also only compiled on x86_64. `std` caches the result,
/// so no local cache is needed.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn is_avx512_supported() -> bool {
    is_x86_feature_detected!("avx512f")
}
74
/// Returns `true` when AVX2 instructions can be executed on this machine.
///
/// Delegates to the standard library's runtime detection instead of reading
/// CPUID leaf 7 directly. The raw read skipped the max-leaf check and the
/// OSXSAVE/XCR0 YMM-state check, so it could report support that would then
/// fault. It also only compiled on x86_64. `std` caches the result, so no local
/// cache is needed.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn is_avx2_supported() -> bool {
    is_x86_feature_detected!("avx2")
}
90
/// Returns `true` when SSE2 instructions can be executed on this machine.
///
/// SSE2 is part of the x86_64 baseline, so the standard library's detection
/// resolves to `true` there. On 32-bit x86 it is checked at runtime, and `std`
/// caches the result.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn is_sse2_supported() -> bool {
    is_x86_feature_detected!("sse2")
}
112
113#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
114#[target_feature(enable = "avx512f")]
115unsafe fn distance_filter_avx512(
116 points: &[(f32, f32, f32)],
117 center: (f32, f32, f32),
118 radius_sq: f32,
119) -> Vec<bool> {
120 use std::arch::x86_64::*;
121
122 assert_eq!(std::mem::size_of::<(f32, f32, f32)>(), 12);
124 assert_eq!(std::mem::align_of::<(f32, f32, f32)>(), 4);
125
126 let (cx, cy, cz) = center;
127 let cx_vec = _mm512_set1_ps(cx);
128 let cy_vec = _mm512_set1_ps(cy);
129 let cz_vec = _mm512_set1_ps(cz);
130 let radius_vec = _mm512_set1_ps(radius_sq);
131
132 let mut result = Vec::with_capacity(points.len());
133 let mut i = 0;
134
135 let points_ptr = points.as_ptr() as *const f32;
136
137 let x_mask_0 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
139 let x_mask_1 = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 2, 5, 8, 11, 14, 0, 0, 0, 0, 0);
140 let x_mask_2 = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 4, 7, 10, 13);
141
142 let y_mask_0 = _mm512_setr_epi32(1, 4, 7, 10, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
143 let y_mask_1 = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 3, 6, 9, 12, 15, 0, 0, 0, 0, 0);
144 let y_mask_2 = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 5, 8, 11, 14);
145
146 let z_mask_0 = _mm512_setr_epi32(2, 5, 8, 11, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
147 let z_mask_1 = _mm512_setr_epi32(0, 0, 0, 0, 0, 1, 4, 7, 10, 13, 0, 0, 0, 0, 0, 0);
148 let z_mask_2 = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 6, 9, 12, 15);
149
150 while i + 16 <= points.len() {
152 let r0 = _mm512_loadu_ps(points_ptr.add(i * 3));
154 let r1 = _mm512_loadu_ps(points_ptr.add(i * 3 + 16));
155 let r2 = _mm512_loadu_ps(points_ptr.add(i * 3 + 32));
156
157 let p0_x = _mm512_permutexvar_ps(x_mask_0, r0);
159 let p1_x = _mm512_permutexvar_ps(x_mask_1, r1);
160 let p2_x = _mm512_permutexvar_ps(x_mask_2, r2);
161
162 let p01_x = _mm512_mask_blend_ps(0b00000111_11000000, p0_x, p1_x);
163 let x_vec = _mm512_mask_blend_ps(0b11111000_00000000, p01_x, p2_x);
164
165 let p0_y = _mm512_permutexvar_ps(y_mask_0, r0);
166 let p1_y = _mm512_permutexvar_ps(y_mask_1, r1);
167 let p2_y = _mm512_permutexvar_ps(y_mask_2, r2);
168
169 let p01_y = _mm512_mask_blend_ps(0b00000111_11100000, p0_y, p1_y);
170 let y_vec = _mm512_mask_blend_ps(0b11111000_00000000, p01_y, p2_y);
171
172 let p0_z = _mm512_permutexvar_ps(z_mask_0, r0);
173 let p1_z = _mm512_permutexvar_ps(z_mask_1, r1);
174 let p2_z = _mm512_permutexvar_ps(z_mask_2, r2);
175
176 let p01_z = _mm512_mask_blend_ps(0b00000011_11100000, p0_z, p1_z);
177 let z_vec = _mm512_mask_blend_ps(0b11111100_00000000, p01_z, p2_z);
178
179 let dx = _mm512_sub_ps(x_vec, cx_vec);
180 let dy = _mm512_sub_ps(y_vec, cy_vec);
181 let dz = _mm512_sub_ps(z_vec, cz_vec);
182
183 let dx2 = _mm512_mul_ps(dx, dx);
184 let dy2 = _mm512_mul_ps(dy, dy);
185 let dz2 = _mm512_mul_ps(dz, dz);
186
187 let dist_sq = _mm512_add_ps(_mm512_add_ps(dx2, dy2), dz2);
188 let mask = _mm512_cmple_ps_mask(dist_sq, radius_vec);
189
190 for j in 0..16 {
191 result.push((mask >> j) & 1 != 0);
192 }
193
194 i += 16;
195 }
196
197 while i < points.len() {
199 let (x, y, z) = points[i];
200 let dx = x - cx;
201 let dy = y - cy;
202 let dz = z - cz;
203 result.push(dx * dx + dy * dy + dz * dz <= radius_sq);
204 i += 1;
205 }
206
207 result
208}
209
/// Squared-distance filter using AVX2, processing 8 points per iteration.
///
/// Each group of 8 interleaved `(x, y, z)` points (24 consecutive `f32`s) is
/// loaded as three 256-bit vectors. These are de-interleaved into x/y/z vectors
/// with `_mm256_permutevar8x32_ps` plus immediate blends. The final
/// `points.len() % 8` points are handled with scalar code. Results match
/// `distance_filter_scalar` element for element.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2 (e.g. via
/// `is_x86_feature_detected!("avx2")`).
///
/// # Panics
///
/// Panics if `(f32, f32, f32)` is not stored as three consecutive `f32`s at
/// byte offsets 0, 4 and 8 (tuple layout is not guaranteed by the language).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn distance_filter_avx2(
    points: &[(f32, f32, f32)],
    center: (f32, f32, f32),
    radius_sq: f32,
) -> Vec<bool> {
    // `std::arch::x86_64` does not exist on 32-bit x86; pick the matching module.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // The SIMD loads below treat `points` as a flat `[f32]`. Rust does not
    // guarantee tuple field order, so verify the real field offsets (not just
    // the overall size) before relying on x, y, z being at bytes 0, 4, 8.
    let probe = (0.0_f32, 0.0_f32, 0.0_f32);
    let probe_addr = std::ptr::addr_of!(probe) as usize;
    assert_eq!(std::mem::size_of::<(f32, f32, f32)>(), 12);
    assert_eq!(std::mem::align_of::<(f32, f32, f32)>(), 4);
    assert_eq!(
        [
            std::ptr::addr_of!(probe.0) as usize - probe_addr,
            std::ptr::addr_of!(probe.1) as usize - probe_addr,
            std::ptr::addr_of!(probe.2) as usize - probe_addr,
        ],
        [0_usize, 4, 8],
        "(f32, f32, f32) fields are not stored as consecutive x, y, z"
    );

    let (cx, cy, cz) = center;
    let cx_vec = _mm256_set1_ps(cx);
    let cy_vec = _mm256_set1_ps(cy);
    let cz_vec = _mm256_set1_ps(cz);
    let radius_vec = _mm256_set1_ps(radius_sq);

    // For 8 points, r0 holds floats 0..8, r1 holds 8..16 and r2 holds 16..24.
    // Lane k of each index vector selects the float of point k that lives in
    // that register. Lanes not taken by the following blend are don't-care.
    let x_mask_0 = _mm256_setr_epi32(0, 3, 6, 0, 0, 0, 0, 0);
    let x_mask_1 = _mm256_setr_epi32(0, 0, 0, 1, 4, 7, 0, 0);
    let x_mask_2 = _mm256_setr_epi32(0, 0, 0, 0, 0, 0, 2, 5);

    let y_mask_0 = _mm256_setr_epi32(1, 4, 7, 0, 0, 0, 0, 0);
    let y_mask_1 = _mm256_setr_epi32(0, 0, 0, 2, 5, 0, 0, 0);
    let y_mask_2 = _mm256_setr_epi32(0, 0, 0, 0, 0, 0, 3, 6);

    let z_mask_0 = _mm256_setr_epi32(2, 5, 0, 0, 0, 0, 0, 0);
    let z_mask_1 = _mm256_setr_epi32(0, 0, 0, 3, 6, 0, 0, 0);
    let z_mask_2 = _mm256_setr_epi32(0, 0, 0, 0, 0, 1, 4, 7);

    let mut result = Vec::with_capacity(points.len());
    let chunks = points.chunks_exact(8);
    let tail = chunks.remainder();

    for chunk in chunks {
        let floats = chunk.as_ptr() as *const f32;
        // SAFETY: `chunk` holds exactly 8 points = 24 contiguous f32s (layout
        // verified above), so floats[0..8], [8..16] and [16..24] are in bounds.
        // The loads are unaligned, so no alignment requirement applies.
        let (r0, r1, r2) = unsafe {
            (
                _mm256_loadu_ps(floats),
                _mm256_loadu_ps(floats.add(8)),
                _mm256_loadu_ps(floats.add(16)),
            )
        };

        // `_mm256_blend_ps(a, b, imm)` takes lane k from `b` when bit k of `imm` is set.
        let p0_x = _mm256_permutevar8x32_ps(r0, x_mask_0);
        let p1_x = _mm256_permutevar8x32_ps(r1, x_mask_1);
        let p2_x = _mm256_permutevar8x32_ps(r2, x_mask_2);
        let p01_x = _mm256_blend_ps(p0_x, p1_x, 0b00111000);
        let x_vec = _mm256_blend_ps(p01_x, p2_x, 0b11000000);

        let p0_y = _mm256_permutevar8x32_ps(r0, y_mask_0);
        let p1_y = _mm256_permutevar8x32_ps(r1, y_mask_1);
        let p2_y = _mm256_permutevar8x32_ps(r2, y_mask_2);
        let p01_y = _mm256_blend_ps(p0_y, p1_y, 0b00011000);
        let y_vec = _mm256_blend_ps(p01_y, p2_y, 0b11100000);

        let p0_z = _mm256_permutevar8x32_ps(r0, z_mask_0);
        let p1_z = _mm256_permutevar8x32_ps(r1, z_mask_1);
        let p2_z = _mm256_permutevar8x32_ps(r2, z_mask_2);
        let p01_z = _mm256_blend_ps(p0_z, p1_z, 0b00011100);
        let z_vec = _mm256_blend_ps(p01_z, p2_z, 0b11100000);

        let dx = _mm256_sub_ps(x_vec, cx_vec);
        let dy = _mm256_sub_ps(y_vec, cy_vec);
        let dz = _mm256_sub_ps(z_vec, cz_vec);

        let dx2 = _mm256_mul_ps(dx, dx);
        let dy2 = _mm256_mul_ps(dy, dy);
        let dz2 = _mm256_mul_ps(dz, dz);

        // Same summation order as the scalar path: (x² + y²) + z².
        let dist_sq = _mm256_add_ps(_mm256_add_ps(dx2, dy2), dz2);
        // _CMP_LE_OS is false for NaN, matching scalar `<=`.
        let mask = _mm256_movemask_ps(_mm256_cmp_ps(dist_sq, radius_vec, _CMP_LE_OS));

        result.extend((0..8).map(|lane| (mask >> lane) & 1 != 0));
    }

    for &(x, y, z) in tail {
        let dx = x - cx;
        let dy = y - cy;
        let dz = z - cz;
        result.push(dx * dx + dy * dy + dz * dz <= radius_sq);
    }

    result
}
306
/// Squared-distance filter using SSE2, processing 4 points per iteration.
///
/// Each group of 4 points is transposed into x/y/z vectors with `_mm_setr_ps`
/// (lane k = point k). The last `points.len() % 4` points use scalar code.
/// Unlike the AVX kernels, this path reads the tuple fields by name, so it does
/// not depend on the in-memory layout of `(f32, f32, f32)`.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE2 (always the case on x86_64).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn distance_filter_sse2(
    points: &[(f32, f32, f32)],
    center: (f32, f32, f32),
    radius_sq: f32,
) -> Vec<bool> {
    // `std::arch::x86_64` does not exist on 32-bit x86; pick the matching module.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let (cx, cy, cz) = center;
    let cx_vec = _mm_set1_ps(cx);
    let cy_vec = _mm_set1_ps(cy);
    let cz_vec = _mm_set1_ps(cz);
    let radius_vec = _mm_set1_ps(radius_sq);

    let mut result = Vec::with_capacity(points.len());
    let chunks = points.chunks_exact(4);
    let tail = chunks.remainder();

    for quad in chunks {
        let x_vec = _mm_setr_ps(quad[0].0, quad[1].0, quad[2].0, quad[3].0);
        let y_vec = _mm_setr_ps(quad[0].1, quad[1].1, quad[2].1, quad[3].1);
        let z_vec = _mm_setr_ps(quad[0].2, quad[1].2, quad[2].2, quad[3].2);

        let dx = _mm_sub_ps(x_vec, cx_vec);
        let dy = _mm_sub_ps(y_vec, cy_vec);
        let dz = _mm_sub_ps(z_vec, cz_vec);

        let dx2 = _mm_mul_ps(dx, dx);
        let dy2 = _mm_mul_ps(dy, dy);
        let dz2 = _mm_mul_ps(dz, dz);

        // Same summation order as the scalar path: (x² + y²) + z².
        let dist_sq = _mm_add_ps(_mm_add_ps(dx2, dy2), dz2);
        let mask = _mm_movemask_ps(_mm_cmple_ps(dist_sq, radius_vec));

        result.extend((0..4).map(|lane| (mask >> lane) & 1 != 0));
    }

    for &(x, y, z) in tail {
        let dx = x - cx;
        let dy = y - cy;
        let dz = z - cz;
        result.push(dx * dx + dy * dy + dz * dz <= radius_sq);
    }

    result
}
371
372pub fn batch_spatial_filter_nodes(
378 nodes: &[crate::algorithms::four_d::GraphNode4D],
379 center: (f32, f32, f32),
380 radius: f32,
381) -> Vec<usize> {
382 let radius_sq = radius * radius;
383 let coords: Vec<(f32, f32, f32)> = nodes.iter().map(|n| (n.x, n.y, n.z)).collect();
384 let mask = distance_filter_l2(&coords, center, radius_sq);
385 mask.into_iter()
386 .enumerate()
387 .filter(|&(_, inside)| inside)
388 .map(|(i, _)| i)
389 .collect()
390}
391
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_distance_filter_scalar_basic() {
        let points = vec![(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (3.0, 0.0, 0.0)];

        let result = distance_filter_scalar(&points, (0.0, 0.0, 0.0), 4.0);
        assert_eq!(result.len(), 3);
        assert!(result[0]);
        assert!(result[1]);
        assert!(!result[2]);
    }

    #[test]
    fn test_distance_filter_equivalence() {
        let points: Vec<_> = (0..100)
            .map(|i| (i as f32 * 0.1, i as f32 * 0.2, i as f32 * 0.3))
            .collect();

        let center = (5.0, 5.0, 5.0);
        let radius_sq = 10.0;

        let scalar_result = distance_filter_scalar(&points, center, radius_sq);
        let auto_result = distance_filter_l2(&points, center, radius_sq);

        assert_eq!(
            scalar_result, auto_result,
            "SIMD and scalar must produce identical results"
        );
    }

    #[test]
    fn test_distance_filter_edge_cases() {
        let empty: Vec<(f32, f32, f32)> = vec![];
        let result = distance_filter_l2(&empty, (0.0, 0.0, 0.0), 1.0);
        assert!(result.is_empty());

        // A point exactly on the boundary is inside (`<=`).
        let points = vec![(1.0, 0.0, 0.0)];
        let result = distance_filter_l2(&points, (0.0, 0.0, 0.0), 1.0);
        assert!(result[0]);

        let points = vec![(1.0001, 0.0, 0.0)];
        let result = distance_filter_l2(&points, (0.0, 0.0, 0.0), 1.0);
        assert!(!result[0]);
    }

    /// Runs every SIMD kernel the host supports against the scalar reference
    /// over lengths 0..=50. This exercises each kernel's vector loop and every
    /// scalar-tail length, independent of which kernel `distance_filter_l2`
    /// would dispatch to on this machine.
    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_each_simd_kernel_matches_scalar_across_lengths() {
        let center = (1.0_f32, 0.5_f32, 2.0_f32);
        let radius_sq = 3.0_f32;

        for len in 0..=50_usize {
            // Distinct periodic patterns per axis, so an axis/lane mix-up in the
            // de-interleaving shuffles changes the outcome.
            let points: Vec<(f32, f32, f32)> = (0..len)
                .map(|i| ((i % 5) as f32 - 1.0, (i % 7) as f32 * 0.5, (i % 3) as f32 + 1.0))
                .collect();
            let expected = distance_filter_scalar(&points, center, radius_sq);

            if is_x86_feature_detected!("sse2") {
                let got = unsafe { distance_filter_sse2(&points, center, radius_sq) };
                assert_eq!(got, expected, "sse2 kernel, len = {len}");
            }
            if is_x86_feature_detected!("avx2") {
                let got = unsafe { distance_filter_avx2(&points, center, radius_sq) };
                assert_eq!(got, expected, "avx2 kernel, len = {len}");
            }
            if is_x86_feature_detected!("avx512f") {
                let got = unsafe { distance_filter_avx512(&points, center, radius_sq) };
                assert_eq!(got, expected, "avx512 kernel, len = {len}");
            }
        }
    }

    #[test]
    fn test_batch_spatial_filter_nodes_matches_scalar() {
        use crate::algorithms::four_d::GraphNode4D;
        use std::collections::BTreeMap;

        let nodes: Vec<GraphNode4D> = (0..100)
            .map(|i| GraphNode4D {
                id: i as u64,
                x: i as f32 * 0.3,
                y: i as f32 * 0.2,
                z: i as f32 * 0.1,
                begin_ts: 0,
                end_ts: 100,
                properties: BTreeMap::new(),
                successors: vec![],
            })
            .collect();

        let center = (5.0_f32, 5.0_f32, 5.0_f32);
        let radius = 4.0_f32;
        let radius_sq = radius * radius;

        let expected: Vec<usize> = nodes
            .iter()
            .enumerate()
            .filter(|(_, n)| {
                let dx = n.x - center.0;
                let dy = n.y - center.1;
                let dz = n.z - center.2;
                dx * dx + dy * dy + dz * dz <= radius_sq
            })
            .map(|(i, _)| i)
            .collect();

        let result = batch_spatial_filter_nodes(&nodes, center, radius);

        assert_eq!(result, expected, "SIMD batch must match scalar reference");
    }

    #[test]
    fn test_batch_spatial_filter_nodes_empty() {
        use crate::algorithms::four_d::GraphNode4D;
        let nodes: Vec<GraphNode4D> = vec![];
        let result = batch_spatial_filter_nodes(&nodes, (0.0, 0.0, 0.0), 1.0);
        assert!(result.is_empty());
    }

    #[test]
    fn test_batch_spatial_filter_nodes_all_match() {
        use crate::algorithms::four_d::GraphNode4D;
        use std::collections::BTreeMap;

        let nodes: Vec<GraphNode4D> = (0..10)
            .map(|i| GraphNode4D {
                id: i as u64,
                x: 0.01 * i as f32,
                y: 0.01 * i as f32,
                z: 0.01 * i as f32,
                begin_ts: 0,
                end_ts: 100,
                properties: BTreeMap::new(),
                successors: vec![],
            })
            .collect();

        let result = batch_spatial_filter_nodes(&nodes, (0.0, 0.0, 0.0), 100.0);
        assert_eq!(result.len(), 10, "All nodes should match with large radius");
    }

    #[test]
    fn test_batch_spatial_filter_nodes_none_match() {
        use crate::algorithms::four_d::GraphNode4D;
        use std::collections::BTreeMap;

        let nodes: Vec<GraphNode4D> = (0..10)
            .map(|i| GraphNode4D {
                id: i as u64,
                x: 1000.0 + i as f32,
                y: 1000.0,
                z: 1000.0,
                begin_ts: 0,
                end_ts: 100,
                properties: BTreeMap::new(),
                successors: vec![],
            })
            .collect();

        let result = batch_spatial_filter_nodes(&nodes, (0.0, 0.0, 0.0), 1.0);
        assert!(result.is_empty(), "No nodes should match");
    }
}