Skip to main content

nodedb_vector/distance/simd/
runtime.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Runtime SIMD detection and dispatch table.
4
5use super::hamming::fast_hamming;
6use super::scalar::{scalar_cosine, scalar_ip, scalar_l2};
7use crate::distance::typed_scalar;
8
9#[cfg(target_arch = "x86_64")]
10use super::{avx2, avx512};
11
12#[cfg(target_arch = "aarch64")]
13use super::neon;
14
/// Function pointer type for half-precision byte-level distance kernels.
///
/// The two slices carry the encoded bytes of the vectors being compared.
/// NOTE(review): the trailing `usize` is presumably the element count
/// (dimension) — confirm against the `typed_scalar` kernels.
type HalfFn = fn(&[u8], &[u8], usize) -> f32;

/// Selected SIMD runtime — function pointers to the best available kernels.
///
/// Every field is either a plain `fn` pointer or a `&'static str`, so the
/// table is trivially copyable and can be printed with `{:?}` when
/// diagnosing which kernel set was picked.
#[derive(Debug, Clone, Copy)]
pub struct SimdRuntime {
    /// Squared L2 distance kernel over `f32` slices.
    pub l2_squared: fn(&[f32], &[f32]) -> f32,
    /// Cosine distance kernel over `f32` slices.
    pub cosine_distance: fn(&[f32], &[f32]) -> f32,
    /// Negated inner product kernel over `f32` slices.
    pub neg_inner_product: fn(&[f32], &[f32]) -> f32,
    /// Hamming distance kernel over byte slices.
    pub hamming: fn(&[u8], &[u8]) -> u32,
    /// Label of the selected kernel set (for example `"avx2+fma"` or
    /// `"scalar"`); used for logging.
    pub name: &'static str,
    /// F16 fused decode-and-compute kernels (no intermediate Vec<f32>).
    pub l2_squared_f16: HalfFn,
    /// F16 cosine distance kernel.
    pub cosine_distance_f16: HalfFn,
    /// F16 negated inner product kernel.
    pub neg_inner_product_f16: HalfFn,
    /// BF16 fused decode-and-compute kernels.
    pub l2_squared_bf16: HalfFn,
    /// BF16 cosine distance kernel.
    pub cosine_distance_bf16: HalfFn,
    /// BF16 negated inner product kernel.
    pub neg_inner_product_bf16: HalfFn,
}
34
35impl SimdRuntime {
36    /// Detect CPU features and select the best kernels.
37    pub fn detect() -> Self {
38        #[cfg(target_arch = "x86_64")]
39        {
40            let has_vpopcntdq = std::is_x86_feature_detected!("avx512vpopcntdq");
41
42            if std::is_x86_feature_detected!("avx512f") {
43                let name = if has_vpopcntdq {
44                    "avx512+vpopcntdq"
45                } else {
46                    "avx512"
47                };
48                let rt = Self {
49                    l2_squared: avx512::l2_squared,
50                    cosine_distance: avx512::cosine_distance,
51                    neg_inner_product: avx512::neg_inner_product,
52                    hamming: fast_hamming,
53                    name,
54                    l2_squared_f16: typed_scalar::l2_squared_f16,
55                    cosine_distance_f16: typed_scalar::cosine_f16,
56                    neg_inner_product_f16: typed_scalar::neg_inner_product_f16,
57                    l2_squared_bf16: typed_scalar::l2_squared_bf16,
58                    cosine_distance_bf16: typed_scalar::cosine_bf16,
59                    neg_inner_product_bf16: typed_scalar::neg_inner_product_bf16,
60                };
61                tracing::info!(kernel = rt.name, "vector SIMD kernel selected");
62                debug_assert!(
63                    !has_vpopcntdq || rt.name == "avx512+vpopcntdq",
64                    "AVX-512 VPOPCNTDQ available but kernel did not select it"
65                );
66                return rt;
67            }
68            if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
69                let rt = Self {
70                    l2_squared: avx2::l2_squared,
71                    cosine_distance: avx2::cosine_distance,
72                    neg_inner_product: avx2::neg_inner_product,
73                    hamming: fast_hamming,
74                    name: "avx2+fma",
75                    l2_squared_f16: typed_scalar::l2_squared_f16,
76                    cosine_distance_f16: typed_scalar::cosine_f16,
77                    neg_inner_product_f16: typed_scalar::neg_inner_product_f16,
78                    l2_squared_bf16: typed_scalar::l2_squared_bf16,
79                    cosine_distance_bf16: typed_scalar::cosine_bf16,
80                    neg_inner_product_bf16: typed_scalar::neg_inner_product_bf16,
81                };
82                tracing::info!(kernel = rt.name, "vector SIMD kernel selected");
83                return rt;
84            }
85        }
86        #[cfg(target_arch = "aarch64")]
87        {
88            let rt = Self {
89                l2_squared: neon::l2_squared,
90                cosine_distance: neon::cosine_distance,
91                neg_inner_product: neon::neg_inner_product,
92                hamming: fast_hamming,
93                name: "neon",
94                l2_squared_f16: typed_scalar::l2_squared_f16,
95                cosine_distance_f16: typed_scalar::cosine_f16,
96                neg_inner_product_f16: typed_scalar::neg_inner_product_f16,
97                l2_squared_bf16: typed_scalar::l2_squared_bf16,
98                cosine_distance_bf16: typed_scalar::cosine_bf16,
99                neg_inner_product_bf16: typed_scalar::neg_inner_product_bf16,
100            };
101            tracing::info!(kernel = rt.name, "vector SIMD kernel selected");
102            return rt;
103        }
104        #[allow(unreachable_code)]
105        {
106            let rt = Self {
107                l2_squared: scalar_l2,
108                cosine_distance: scalar_cosine,
109                neg_inner_product: scalar_ip,
110                hamming: fast_hamming,
111                name: "scalar",
112                l2_squared_f16: typed_scalar::l2_squared_f16,
113                cosine_distance_f16: typed_scalar::cosine_f16,
114                neg_inner_product_f16: typed_scalar::neg_inner_product_f16,
115                l2_squared_bf16: typed_scalar::l2_squared_bf16,
116                cosine_distance_bf16: typed_scalar::cosine_bf16,
117                neg_inner_product_bf16: typed_scalar::neg_inner_product_bf16,
118            };
119            tracing::info!(kernel = rt.name, "vector SIMD kernel selected");
120            rt
121        }
122    }
123}
124
125/// Global SIMD runtime — initialized once, used everywhere.
126static RUNTIME: std::sync::OnceLock<SimdRuntime> = std::sync::OnceLock::new();
127
128/// Get the global SIMD runtime (auto-detects on first call).
129pub fn runtime() -> &'static SimdRuntime {
130    RUNTIME.get_or_init(SimdRuntime::detect)
131}
132
#[cfg(test)]
mod tests {
    use super::super::hamming::fast_hamming;
    use super::super::scalar::{scalar_cosine, scalar_ip, scalar_l2};
    use super::*;

    /// Detection must always yield a named kernel set.
    #[test]
    fn runtime_detects_features() {
        let rt = SimdRuntime::detect();
        assert!(!rt.name.is_empty());
        tracing::info!("SIMD runtime: {}", rt.name);
    }

    /// The dispatched squared-L2 kernel agrees with the scalar reference.
    #[test]
    fn l2_matches_scalar() {
        let rt = runtime();
        let a: Vec<f32> = (0..768).map(|i| (i as f32) * 0.01).collect();
        let b: Vec<f32> = (0..768).map(|i| (i as f32) * 0.01 + 0.001).collect();

        let simd_result = (rt.l2_squared)(&a, &b);
        let scalar_result = scalar_l2(&a, &b);

        // The expected value is tiny (768 terms of roughly 0.001², i.e. about
        // 7.7e-4), so an absolute tolerance such as 0.01 would accept a kernel
        // that returns 0. Compare relative to the reference instead.
        assert!(
            scalar_result > 0.0,
            "scalar reference should be positive, got {scalar_result}"
        );
        let tolerance = scalar_result.abs() * 1e-3;
        assert!(
            (simd_result - scalar_result).abs() <= tolerance,
            "simd={simd_result}, scalar={scalar_result}"
        );
    }

    /// The dispatched cosine kernel agrees with the scalar reference.
    #[test]
    fn cosine_matches_scalar() {
        let rt = runtime();
        let a: Vec<f32> = (0..768).map(|i| (i as f32).sin()).collect();
        let b: Vec<f32> = (0..768).map(|i| (i as f32).cos()).collect();

        let simd_result = (rt.cosine_distance)(&a, &b);
        let scalar_result = scalar_cosine(&a, &b);
        assert!(
            (simd_result - scalar_result).abs() < 0.001,
            "simd={simd_result}, scalar={scalar_result}"
        );
    }

    /// The dispatched inner-product kernel agrees with the scalar reference.
    #[test]
    fn ip_matches_scalar() {
        let rt = runtime();
        let a: Vec<f32> = (0..128).map(|i| (i as f32) * 0.1).collect();
        let b: Vec<f32> = (0..128).map(|i| (i as f32) * 0.2).collect();

        let simd_result = (rt.neg_inner_product)(&a, &b);
        let scalar_result = scalar_ip(&a, &b);
        assert!(
            (simd_result - scalar_result).abs() < 0.1,
            "simd={simd_result}, scalar={scalar_result}"
        );
    }

    /// Complementary bit patterns differ in every bit: 16 bytes → 128 bits.
    #[test]
    fn hamming_matches() {
        let a = vec![0b10101010u8; 16];
        let b = vec![0b01010101u8; 16];
        assert_eq!(fast_hamming(&a, &b), 128);

        // Also exercise the entry that callers reach through the table.
        let rt = runtime();
        assert_eq!((rt.hamming)(&a, &b), 128);
    }

    /// Inputs shorter than any SIMD lane width still produce the exact sum:
    /// (4-1)² + (5-2)² + (6-3)² = 27.
    #[test]
    fn small_vectors() {
        let rt = runtime();
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        let l2 = (rt.l2_squared)(&a, &b);
        assert!((l2 - 27.0).abs() < 0.01);
    }
}