Skip to main content

geographdb_core/spatial/
simd.rs

//! SIMD-accelerated spatial distance filtering
//!
//! Provides portable SIMD implementations for 3D distance calculations
//! with runtime CPU feature detection and scalar fallback.
//!
//! Ported from geographdb_prototype/acceleration/simd_backend.rs
7
8/// Filter points by L2 distance using best available SIMD implementation
9///
10/// # Arguments
11/// * `points` - Slice of (x, y, z) tuples
12/// * `center` - Query center point (cx, cy, cz)
13/// * `radius_sq` - Squared radius for inclusion
14///
15/// # Returns
16/// Vec<bool> where true means point is within radius
17pub fn distance_filter_l2(
18    points: &[(f32, f32, f32)],
19    center: (f32, f32, f32),
20    radius_sq: f32,
21) -> Vec<bool> {
22    // Runtime CPU feature dispatch
23    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
24    {
25        if is_avx512_supported() {
26            return unsafe { distance_filter_avx512(points, center, radius_sq) };
27        }
28        if is_avx2_supported() {
29            return unsafe { distance_filter_avx2(points, center, radius_sq) };
30        }
31        if is_sse2_supported() {
32            return unsafe { distance_filter_sse2(points, center, radius_sq) };
33        }
34    }
35
36    // Scalar fallback (portable, always available)
37    distance_filter_scalar(points, center, radius_sq)
38}
39
/// Portable reference implementation of the distance filter.
///
/// Uses plain `f32` arithmetic, so it is available on every platform.
/// Returns one `bool` per input point, in input order; `true` means the
/// point's squared distance to `center` is `<= radius_sq`.
pub fn distance_filter_scalar(
    points: &[(f32, f32, f32)],
    center: (f32, f32, f32),
    radius_sq: f32,
) -> Vec<bool> {
    let mut inside = Vec::with_capacity(points.len());
    for &(px, py, pz) in points {
        let off_x = px - center.0;
        let off_y = py - center.1;
        let off_z = pz - center.2;
        // Compare squared quantities so no square root is needed.
        let dist_sq = off_x * off_x + off_y * off_y + off_z * off_z;
        inside.push(dist_sq <= radius_sq);
    }
    inside
}
58
/// Report whether the AVX-512F kernel may be executed on this machine.
///
/// Delegates to the standard library's runtime feature detection instead of
/// reading CPUID by hand. The standard macro works on both 32-bit and 64-bit
/// x86 targets, takes operating-system support for the wider register state
/// into account, and caches its answer, so no local cache is needed.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn is_avx512_supported() -> bool {
    is_x86_feature_detected!("avx512f")
}
74
/// Report whether the AVX2 kernel may be executed on this machine.
///
/// Delegates to the standard library's runtime feature detection instead of
/// reading CPUID by hand. The standard macro works on both 32-bit and 64-bit
/// x86 targets, takes operating-system support for the wider register state
/// into account, and caches its answer, so no local cache is needed.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn is_avx2_supported() -> bool {
    is_x86_feature_detected!("avx2")
}
90
/// Report whether the SSE2 kernel may be executed on this machine.
///
/// Delegates to the standard library's runtime feature detection, the same
/// mechanism the other probes in this file should rely on. On targets that
/// are compiled with SSE2 enabled (the default for x86_64) the macro resolves
/// without a runtime query; on 32-bit x86 it performs and caches the check.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn is_sse2_supported() -> bool {
    is_x86_feature_detected!("sse2")
}
112
113#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
114#[target_feature(enable = "avx512f")]
115unsafe fn distance_filter_avx512(
116    points: &[(f32, f32, f32)],
117    center: (f32, f32, f32),
118    radius_sq: f32,
119) -> Vec<bool> {
120    use std::arch::x86_64::*;
121
122    // Verify tuple layout assumptions at runtime (highly optimized out by compiler)
123    assert_eq!(std::mem::size_of::<(f32, f32, f32)>(), 12);
124    assert_eq!(std::mem::align_of::<(f32, f32, f32)>(), 4);
125
126    let (cx, cy, cz) = center;
127    let cx_vec = _mm512_set1_ps(cx);
128    let cy_vec = _mm512_set1_ps(cy);
129    let cz_vec = _mm512_set1_ps(cz);
130    let radius_vec = _mm512_set1_ps(radius_sq);
131
132    let mut result = Vec::with_capacity(points.len());
133    let mut i = 0;
134
135    let points_ptr = points.as_ptr() as *const f32;
136
137    // Shuffle masks for AVX-512 de-interleaving AoS to SoA
138    let x_mask_0 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
139    let x_mask_1 = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 2, 5, 8, 11, 14, 0, 0, 0, 0, 0);
140    let x_mask_2 = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 4, 7, 10, 13);
141
142    let y_mask_0 = _mm512_setr_epi32(1, 4, 7, 10, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
143    let y_mask_1 = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 3, 6, 9, 12, 15, 0, 0, 0, 0, 0);
144    let y_mask_2 = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 5, 8, 11, 14);
145
146    let z_mask_0 = _mm512_setr_epi32(2, 5, 8, 11, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
147    let z_mask_1 = _mm512_setr_epi32(0, 0, 0, 0, 0, 1, 4, 7, 10, 13, 0, 0, 0, 0, 0, 0);
148    let z_mask_2 = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 6, 9, 12, 15);
149
150    // Process 16 points at a time
151    while i + 16 <= points.len() {
152        // Load 48 floats representing 16 points into three 512-bit registers
153        let r0 = _mm512_loadu_ps(points_ptr.add(i * 3));
154        let r1 = _mm512_loadu_ps(points_ptr.add(i * 3 + 16));
155        let r2 = _mm512_loadu_ps(points_ptr.add(i * 3 + 32));
156
157        // Permute to collect components
158        let p0_x = _mm512_permutexvar_ps(x_mask_0, r0);
159        let p1_x = _mm512_permutexvar_ps(x_mask_1, r1);
160        let p2_x = _mm512_permutexvar_ps(x_mask_2, r2);
161
162        let p01_x = _mm512_mask_blend_ps(0b00000111_11000000, p0_x, p1_x);
163        let x_vec = _mm512_mask_blend_ps(0b11111000_00000000, p01_x, p2_x);
164
165        let p0_y = _mm512_permutexvar_ps(y_mask_0, r0);
166        let p1_y = _mm512_permutexvar_ps(y_mask_1, r1);
167        let p2_y = _mm512_permutexvar_ps(y_mask_2, r2);
168
169        let p01_y = _mm512_mask_blend_ps(0b00000111_11100000, p0_y, p1_y);
170        let y_vec = _mm512_mask_blend_ps(0b11111000_00000000, p01_y, p2_y);
171
172        let p0_z = _mm512_permutexvar_ps(z_mask_0, r0);
173        let p1_z = _mm512_permutexvar_ps(z_mask_1, r1);
174        let p2_z = _mm512_permutexvar_ps(z_mask_2, r2);
175
176        let p01_z = _mm512_mask_blend_ps(0b00000011_11100000, p0_z, p1_z);
177        let z_vec = _mm512_mask_blend_ps(0b11111100_00000000, p01_z, p2_z);
178
179        let dx = _mm512_sub_ps(x_vec, cx_vec);
180        let dy = _mm512_sub_ps(y_vec, cy_vec);
181        let dz = _mm512_sub_ps(z_vec, cz_vec);
182
183        let dx2 = _mm512_mul_ps(dx, dx);
184        let dy2 = _mm512_mul_ps(dy, dy);
185        let dz2 = _mm512_mul_ps(dz, dz);
186
187        let dist_sq = _mm512_add_ps(_mm512_add_ps(dx2, dy2), dz2);
188        let mask = _mm512_cmple_ps_mask(dist_sq, radius_vec);
189
190        for j in 0..16 {
191            result.push((mask >> j) & 1 != 0);
192        }
193
194        i += 16;
195    }
196
197    // Handle remaining points with scalar
198    while i < points.len() {
199        let (x, y, z) = points[i];
200        let dx = x - cx;
201        let dy = y - cy;
202        let dz = z - cz;
203        result.push(dx * dx + dy * dy + dz * dz <= radius_sq);
204        i += 1;
205    }
206
207    result
208}
209
/// AVX2 kernel: tests 8 points per iteration against the query sphere.
///
/// The input is read as an array of structures (`x0 y0 z0 x1 y1 z1 ...`).
/// Each iteration loads 24 consecutive floats into three 256-bit registers,
/// then uses cross-lane permutes and immediate blends to gather the x, y and
/// z components into separate registers before computing squared distances.
/// Points left over after the last full group of 8 use scalar arithmetic.
///
/// Returns one `bool` per input point, in input order; `true` means the
/// point's squared distance to `center` is `<= radius_sq`.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2 (for example via
/// `is_x86_feature_detected!("avx2")`) before calling this function.
///
/// # Panics
/// Panics if `(f32, f32, f32)` is not 12 bytes in size with 4-byte alignment.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn distance_filter_avx2(
    points: &[(f32, f32, f32)],
    center: (f32, f32, f32),
    radius_sq: f32,
) -> Vec<bool> {
    // This function is compiled for both x86 flavours, so pick the intrinsics
    // module that actually exists on the current target.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    const LANES: usize = 8;

    // Guard the reinterpretation of the slice as packed `f32` values. Both
    // operands are compile-time constants, so the checks are expected to fold
    // away. NOTE(review): the permutes below also assume the tuple fields are
    // stored in x, y, z order, which these asserts do not verify.
    assert_eq!(std::mem::size_of::<(f32, f32, f32)>(), 12);
    assert_eq!(std::mem::align_of::<(f32, f32, f32)>(), 4);

    let (cx, cy, cz) = center;
    let cx_vec = _mm256_set1_ps(cx);
    let cy_vec = _mm256_set1_ps(cy);
    let cz_vec = _mm256_set1_ps(cz);
    let radius_vec = _mm256_set1_ps(radius_sq);

    let mut result = Vec::with_capacity(points.len());
    let points_ptr = points.as_ptr() as *const f32;

    // Permute indices for de-interleaving AoS to SoA. For each component,
    // mask N selects the lanes that live in register rN; the remaining lanes
    // hold a dummy index 0 and are discarded by the blends in the loop.
    let x_mask_0 = _mm256_setr_epi32(0, 3, 6, 0, 0, 0, 0, 0);
    let x_mask_1 = _mm256_setr_epi32(0, 0, 0, 1, 4, 7, 0, 0);
    let x_mask_2 = _mm256_setr_epi32(0, 0, 0, 0, 0, 0, 2, 5);

    let y_mask_0 = _mm256_setr_epi32(1, 4, 7, 0, 0, 0, 0, 0);
    let y_mask_1 = _mm256_setr_epi32(0, 0, 0, 2, 5, 0, 0, 0);
    let y_mask_2 = _mm256_setr_epi32(0, 0, 0, 0, 0, 0, 3, 6);

    let z_mask_0 = _mm256_setr_epi32(2, 5, 0, 0, 0, 0, 0, 0);
    let z_mask_1 = _mm256_setr_epi32(0, 0, 0, 3, 6, 0, 0, 0);
    let z_mask_2 = _mm256_setr_epi32(0, 0, 0, 0, 0, 1, 4, 7);

    // Number of leading points covered by full 8-point groups.
    let vector_len = points.len() - points.len() % LANES;

    let mut i = 0;
    while i < vector_len {
        // SAFETY: `i + 8 <= points.len()`, so the 24 floats starting at
        // float offset `3 * i` all lie inside the slice. Unaligned loads are
        // used because the slice is only guaranteed 4-byte alignment.
        let r0 = _mm256_loadu_ps(points_ptr.add(i * 3));
        let r1 = _mm256_loadu_ps(points_ptr.add(i * 3 + 8));
        let r2 = _mm256_loadu_ps(points_ptr.add(i * 3 + 16));

        // x: lanes 0-2 come from r0, 3-5 from r1, 6-7 from r2.
        let p0_x = _mm256_permutevar8x32_ps(r0, x_mask_0);
        let p1_x = _mm256_permutevar8x32_ps(r1, x_mask_1);
        let p2_x = _mm256_permutevar8x32_ps(r2, x_mask_2);
        let p01_x = _mm256_blend_ps(p0_x, p1_x, 0b00111000);
        let x_vec = _mm256_blend_ps(p01_x, p2_x, 0b11000000);

        // y: lanes 0-2 come from r0, 3-4 from r1, 5-7 from r2.
        let p0_y = _mm256_permutevar8x32_ps(r0, y_mask_0);
        let p1_y = _mm256_permutevar8x32_ps(r1, y_mask_1);
        let p2_y = _mm256_permutevar8x32_ps(r2, y_mask_2);
        let p01_y = _mm256_blend_ps(p0_y, p1_y, 0b00011000);
        let y_vec = _mm256_blend_ps(p01_y, p2_y, 0b11100000);

        // z: lanes 0-1 come from r0, 2-4 from r1, 5-7 from r2.
        let p0_z = _mm256_permutevar8x32_ps(r0, z_mask_0);
        let p1_z = _mm256_permutevar8x32_ps(r1, z_mask_1);
        let p2_z = _mm256_permutevar8x32_ps(r2, z_mask_2);
        let p01_z = _mm256_blend_ps(p0_z, p1_z, 0b00011100);
        let z_vec = _mm256_blend_ps(p01_z, p2_z, 0b11100000);

        let dx = _mm256_sub_ps(x_vec, cx_vec);
        let dy = _mm256_sub_ps(y_vec, cy_vec);
        let dz = _mm256_sub_ps(z_vec, cz_vec);

        let dx2 = _mm256_mul_ps(dx, dx);
        let dy2 = _mm256_mul_ps(dy, dy);
        let dz2 = _mm256_mul_ps(dz, dz);

        // Same operation order as the scalar code: (dx^2 + dy^2) + dz^2.
        let dist_sq = _mm256_add_ps(_mm256_add_ps(dx2, dy2), dz2);
        let mask = _mm256_movemask_ps(_mm256_cmp_ps(dist_sq, radius_vec, _CMP_LE_OS));

        // Bit `lane` of the sign mask is the verdict for point `i + lane`.
        result.extend((0..8u32).map(|lane| (mask >> lane) & 1 != 0));

        i += LANES;
    }

    // Fewer than 8 points remain; finish them with scalar arithmetic.
    result.extend(points[vector_len..].iter().map(|&(x, y, z)| {
        let dx = x - cx;
        let dy = y - cy;
        let dz = z - cz;
        dx * dx + dy * dy + dz * dz <= radius_sq
    }));

    result
}
306
/// SSE2 kernel: tests 4 points per iteration against the query sphere.
///
/// Each group of four points has its x, y and z components gathered into
/// three 128-bit registers, after which the squared distances are computed
/// and compared in one pass. Points left over after the last full group of
/// four use scalar arithmetic.
///
/// Returns one `bool` per input point, in input order; `true` means the
/// point's squared distance to `center` is `<= radius_sq`.
///
/// # Safety
/// The caller must ensure the running CPU supports SSE2 (for example via
/// `is_x86_feature_detected!("sse2")`) before calling this function.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn distance_filter_sse2(
    points: &[(f32, f32, f32)],
    center: (f32, f32, f32),
    radius_sq: f32,
) -> Vec<bool> {
    // This function is compiled for both x86 flavours, so pick the intrinsics
    // module that actually exists on the current target.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let (cx, cy, cz) = center;
    let cx_vec = _mm_set1_ps(cx);
    let cy_vec = _mm_set1_ps(cy);
    let cz_vec = _mm_set1_ps(cz);
    let radius_vec = _mm_set1_ps(radius_sq);

    let mut result = Vec::with_capacity(points.len());

    // `chunks_exact` yields only full groups of four; the leftover points
    // are retrieved from `remainder()` after the loop.
    let mut groups = points.chunks_exact(4);
    for group in groups.by_ref() {
        // `_mm_setr_ps` places its first argument in lane 0, so lane `k`
        // holds the component of `group[k]`.
        let x_vec = _mm_setr_ps(group[0].0, group[1].0, group[2].0, group[3].0);
        let y_vec = _mm_setr_ps(group[0].1, group[1].1, group[2].1, group[3].1);
        let z_vec = _mm_setr_ps(group[0].2, group[1].2, group[2].2, group[3].2);

        let dx = _mm_sub_ps(x_vec, cx_vec);
        let dy = _mm_sub_ps(y_vec, cy_vec);
        let dz = _mm_sub_ps(z_vec, cz_vec);

        let dx2 = _mm_mul_ps(dx, dx);
        let dy2 = _mm_mul_ps(dy, dy);
        let dz2 = _mm_mul_ps(dz, dz);

        // Same operation order as the scalar code: (dx^2 + dy^2) + dz^2.
        let dist_sq = _mm_add_ps(_mm_add_ps(dx2, dy2), dz2);
        let mask = _mm_movemask_ps(_mm_cmple_ps(dist_sq, radius_vec));

        // Bit `lane` of the sign mask is the verdict for `group[lane]`.
        result.extend((0..4u32).map(|lane| (mask >> lane) & 1 != 0));
    }

    // Fewer than 4 points remain; finish them with scalar arithmetic.
    result.extend(groups.remainder().iter().map(|&(x, y, z)| {
        let dx = x - cx;
        let dy = y - cy;
        let dz = z - cz;
        dx * dx + dy * dy + dz * dz <= radius_sq
    }));

    result
}
371
372/// Batch-filter `GraphNode4D` nodes by L2 distance using SIMD.
373///
374/// Returns indices of nodes whose `(x, y, z)` position falls within
375/// `radius` of `center`. Temporal filtering is NOT included — this
376/// is a pure spatial filter.
377pub fn batch_spatial_filter_nodes(
378    nodes: &[crate::algorithms::four_d::GraphNode4D],
379    center: (f32, f32, f32),
380    radius: f32,
381) -> Vec<usize> {
382    let radius_sq = radius * radius;
383    let coords: Vec<(f32, f32, f32)> = nodes.iter().map(|n| (n.x, n.y, n.z)).collect();
384    let mask = distance_filter_l2(&coords, center, radius_sq);
385    mask.into_iter()
386        .enumerate()
387        .filter(|&(_, inside)| inside)
388        .map(|(i, _)| i)
389        .collect()
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algorithms::four_d::GraphNode4D;
    use std::collections::BTreeMap;

    /// Query center shared by the tests that measure from the origin.
    const ORIGIN: (f32, f32, f32) = (0.0, 0.0, 0.0);

    /// Build a node at the given position. The lifetime (`0..=100`), the
    /// empty property map and the empty successor list are the same for
    /// every node used in this module.
    fn node_at(id: u64, x: f32, y: f32, z: f32) -> GraphNode4D {
        GraphNode4D {
            id,
            x,
            y,
            z,
            begin_ts: 0,
            end_ts: 100,
            properties: BTreeMap::new(),
            successors: vec![],
        }
    }

    #[test]
    fn test_distance_filter_scalar_basic() {
        // A squared radius of 4.0 corresponds to a radius of 2.
        let points: Vec<(f32, f32, f32)> = vec![
            (0.0, 0.0, 0.0), // the center itself: inside
            (1.0, 0.0, 0.0), // distance 1: inside
            (3.0, 0.0, 0.0), // distance 3: outside
        ];

        let result = distance_filter_scalar(&points, ORIGIN, 4.0);

        assert_eq!(result, vec![true, true, false]);
    }

    #[test]
    fn test_distance_filter_equivalence() {
        // 100 points along a line; with wider kernels this covers both full
        // vector groups and a scalar remainder.
        let mut points: Vec<(f32, f32, f32)> = Vec::with_capacity(100);
        for i in 0..100 {
            points.push((i as f32 * 0.1, i as f32 * 0.2, i as f32 * 0.3));
        }

        let center = (5.0, 5.0, 5.0);
        let radius_sq = 10.0;

        let auto_result = distance_filter_l2(&points, center, radius_sq);
        let scalar_result = distance_filter_scalar(&points, center, radius_sq);

        assert_eq!(
            scalar_result, auto_result,
            "SIMD and scalar must produce identical results"
        );
    }

    #[test]
    fn test_distance_filter_edge_cases() {
        // No points in, no flags out.
        let empty: Vec<(f32, f32, f32)> = Vec::new();
        assert!(distance_filter_l2(&empty, ORIGIN, 1.0).is_empty());

        // Exactly on the boundary: squared distance 1.0 vs radius_sq 1.0,
        // and the comparison is inclusive.
        let on_boundary: Vec<(f32, f32, f32)> = vec![(1.0, 0.0, 0.0)];
        let flags = distance_filter_l2(&on_boundary, ORIGIN, 1.0);
        assert!(flags[0]);

        // Slightly beyond the boundary.
        let just_outside: Vec<(f32, f32, f32)> = vec![(1.0001, 0.0, 0.0)];
        let flags = distance_filter_l2(&just_outside, ORIGIN, 1.0);
        assert!(!flags[0]);
    }

    #[test]
    fn test_batch_spatial_filter_nodes_matches_scalar() {
        let nodes: Vec<GraphNode4D> = (0..100)
            .map(|i| node_at(i as u64, i as f32 * 0.3, i as f32 * 0.2, i as f32 * 0.1))
            .collect();

        let center = (5.0_f32, 5.0_f32, 5.0_f32);
        let radius = 4.0_f32;
        let radius_sq = radius * radius;

        // Independent reference: test every node one at a time.
        let mut expected: Vec<usize> = Vec::new();
        for (index, node) in nodes.iter().enumerate() {
            let dx = node.x - center.0;
            let dy = node.y - center.1;
            let dz = node.z - center.2;
            if dx * dx + dy * dy + dz * dz <= radius_sq {
                expected.push(index);
            }
        }

        let result = batch_spatial_filter_nodes(&nodes, center, radius);

        assert_eq!(result, expected, "SIMD batch must match scalar reference");
    }

    #[test]
    fn test_batch_spatial_filter_nodes_empty() {
        let nodes: Vec<GraphNode4D> = Vec::new();
        let result = batch_spatial_filter_nodes(&nodes, ORIGIN, 1.0);
        assert!(result.is_empty());
    }

    #[test]
    fn test_batch_spatial_filter_nodes_all_match() {
        // Ten nodes clustered next to the origin, queried with a huge radius.
        let nodes: Vec<GraphNode4D> = (0..10)
            .map(|i| node_at(i as u64, 0.01 * i as f32, 0.01 * i as f32, 0.01 * i as f32))
            .collect();

        let result = batch_spatial_filter_nodes(&nodes, ORIGIN, 100.0);

        assert_eq!(result.len(), 10, "All nodes should match with large radius");
    }

    #[test]
    fn test_batch_spatial_filter_nodes_none_match() {
        // Ten nodes far away from the origin, queried with a unit radius.
        let nodes: Vec<GraphNode4D> = (0..10)
            .map(|i| node_at(i as u64, 1000.0 + i as f32, 1000.0, 1000.0))
            .collect();

        let result = batch_spatial_filter_nodes(&nodes, ORIGIN, 1.0);

        assert!(result.is_empty(), "No nodes should match");
    }
}