Skip to main content

nodedb_vector/distance/simd/
runtime.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Runtime SIMD detection and dispatch table.
4
5use super::hamming::fast_hamming;
6use super::scalar::{scalar_cosine, scalar_ip, scalar_l2};
7
8#[cfg(target_arch = "x86_64")]
9use super::{avx2, avx512};
10
11#[cfg(target_arch = "aarch64")]
12use super::neon;
13
/// Selected SIMD runtime — function pointers to the best available kernels.
///
/// A value of this type is a plain dispatch table: four kernel function
/// pointers plus a label naming the kernel set they came from. It is built by
/// [`SimdRuntime::detect`]. Every field is a function pointer or a
/// `&'static str`, so the table is cheap to copy and can be debug-printed.
///
/// NOTE(review): the float kernels presumably expect both slices to have the
/// same length — confirm against the kernel modules, which are not visible here.
#[derive(Debug, Clone, Copy)]
pub struct SimdRuntime {
    /// Kernel stored for the squared-L2 distance slot.
    pub l2_squared: fn(&[f32], &[f32]) -> f32,
    /// Kernel stored for the cosine-distance slot.
    pub cosine_distance: fn(&[f32], &[f32]) -> f32,
    /// Kernel stored for the negated inner-product slot.
    pub neg_inner_product: fn(&[f32], &[f32]) -> f32,
    /// Kernel stored for the Hamming-distance slot; operates on byte slices.
    pub hamming: fn(&[u8], &[u8]) -> u32,
    /// Label of the selected kernel set (for example `"avx2+fma"`, `"neon"`
    /// or `"scalar"`); this is the value reported when a selection is logged.
    pub name: &'static str,
}
22
23impl SimdRuntime {
24    /// Detect CPU features and select the best kernels.
25    pub fn detect() -> Self {
26        #[cfg(target_arch = "x86_64")]
27        {
28            let has_vpopcntdq = std::is_x86_feature_detected!("avx512vpopcntdq");
29
30            if std::is_x86_feature_detected!("avx512f") {
31                let name = if has_vpopcntdq {
32                    "avx512+vpopcntdq"
33                } else {
34                    "avx512"
35                };
36                let rt = Self {
37                    l2_squared: avx512::l2_squared,
38                    cosine_distance: avx512::cosine_distance,
39                    neg_inner_product: avx512::neg_inner_product,
40                    hamming: fast_hamming,
41                    name,
42                };
43                tracing::info!(kernel = rt.name, "vector SIMD kernel selected");
44                debug_assert!(
45                    !has_vpopcntdq || rt.name == "avx512+vpopcntdq",
46                    "AVX-512 VPOPCNTDQ available but kernel did not select it"
47                );
48                return rt;
49            }
50            if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
51                let rt = Self {
52                    l2_squared: avx2::l2_squared,
53                    cosine_distance: avx2::cosine_distance,
54                    neg_inner_product: avx2::neg_inner_product,
55                    hamming: fast_hamming,
56                    name: "avx2+fma",
57                };
58                tracing::info!(kernel = rt.name, "vector SIMD kernel selected");
59                return rt;
60            }
61        }
62        #[cfg(target_arch = "aarch64")]
63        {
64            let rt = Self {
65                l2_squared: neon::l2_squared,
66                cosine_distance: neon::cosine_distance,
67                neg_inner_product: neon::neg_inner_product,
68                hamming: fast_hamming,
69                name: "neon",
70            };
71            tracing::info!(kernel = rt.name, "vector SIMD kernel selected");
72            return rt;
73        }
74        #[allow(unreachable_code)]
75        {
76            let rt = Self {
77                l2_squared: scalar_l2,
78                cosine_distance: scalar_cosine,
79                neg_inner_product: scalar_ip,
80                hamming: fast_hamming,
81                name: "scalar",
82            };
83            tracing::info!(kernel = rt.name, "vector SIMD kernel selected");
84            rt
85        }
86    }
87}
88
89/// Global SIMD runtime — initialized once, used everywhere.
90static RUNTIME: std::sync::OnceLock<SimdRuntime> = std::sync::OnceLock::new();
91
92/// Get the global SIMD runtime (auto-detects on first call).
93pub fn runtime() -> &'static SimdRuntime {
94    RUNTIME.get_or_init(SimdRuntime::detect)
95}
96
#[cfg(test)]
mod tests {
    use super::super::hamming::fast_hamming;
    use super::super::scalar::{scalar_cosine, scalar_ip, scalar_l2};
    use super::*;

    /// Fails the calling test unless `simd` is within `tol` (absolute
    /// difference) of the scalar reference value.
    #[track_caller]
    fn assert_close(simd: f32, scalar: f32, tol: f32) {
        assert!(
            (simd - scalar).abs() < tol,
            "simd={simd}, scalar={scalar}"
        );
    }

    /// Detection must always yield a table with a non-empty label.
    #[test]
    fn runtime_detects_features() {
        let rt = SimdRuntime::detect();
        assert!(!rt.name.is_empty());
        tracing::info!("SIMD runtime: {}", rt.name);
    }

    /// The dispatched squared-L2 kernel agrees with the scalar reference.
    #[test]
    fn l2_matches_scalar() {
        let rt = runtime();
        let a: Vec<f32> = (0..768).map(|i| (i as f32) * 0.01).collect();
        let b: Vec<f32> = (0..768).map(|i| (i as f32) * 0.01 + 0.001).collect();

        assert_close((rt.l2_squared)(&a, &b), scalar_l2(&a, &b), 0.01);
    }

    /// The dispatched cosine kernel agrees with the scalar reference.
    #[test]
    fn cosine_matches_scalar() {
        let rt = runtime();
        let a: Vec<f32> = (0..768).map(|i| (i as f32).sin()).collect();
        let b: Vec<f32> = (0..768).map(|i| (i as f32).cos()).collect();

        assert_close((rt.cosine_distance)(&a, &b), scalar_cosine(&a, &b), 0.001);
    }

    /// The dispatched inner-product kernel agrees with the scalar reference.
    #[test]
    fn ip_matches_scalar() {
        let rt = runtime();
        let a: Vec<f32> = (0..128).map(|i| (i as f32) * 0.1).collect();
        let b: Vec<f32> = (0..128).map(|i| (i as f32) * 0.2).collect();

        assert_close((rt.neg_inner_product)(&a, &b), scalar_ip(&a, &b), 0.1);
    }

    /// Hamming distance is checked both directly and through the dispatch
    /// table, so the `hamming` slot is exercised like the other kernels.
    #[test]
    fn hamming_matches() {
        let a = vec![0b10101010u8; 16];
        let b = vec![0b01010101u8; 16];
        // Every byte pair differs in all 8 bits: 16 * 8 = 128.
        assert_eq!(fast_hamming(&a, &b), 128);
        assert_eq!((runtime().hamming)(&a, &b), 128);
    }

    /// Inputs shorter than any SIMD lane width must still give the exact
    /// answer.
    #[test]
    fn small_vectors() {
        let rt = runtime();
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        let l2 = (rt.l2_squared)(&a, &b);
        // (4-1)^2 + (5-2)^2 + (6-3)^2 = 27.
        assert!((l2 - 27.0).abs() < 0.01, "l2={l2}");
    }
}