nodedb_vector/distance/simd/
runtime.rs1use super::hamming::fast_hamming;
6use super::scalar::{scalar_cosine, scalar_ip, scalar_l2};
7use crate::distance::typed_scalar;
8
9#[cfg(target_arch = "x86_64")]
10use super::{avx2, avx512};
11
12#[cfg(target_arch = "aarch64")]
13use super::neon;
14
/// Signature shared by the half-precision kernel slots: two byte slices and a
/// `usize`, producing an `f32` distance.
///
/// NOTE(review): the byte slices presumably hold packed f16/bf16 values and the
/// `usize` the vector dimension — confirm against `typed_scalar`.
type HalfFn = fn(&[u8], &[u8], usize) -> f32;

/// Table of distance-kernel function pointers selected for the running CPU.
///
/// Built by `SimdRuntime::detect`. Every field is a plain function pointer or a
/// `&'static str`, so the table is trivially copyable and printable.
#[derive(Debug, Clone, Copy)]
pub struct SimdRuntime {
    /// Squared L2 distance kernel over `f32` slices.
    pub l2_squared: fn(&[f32], &[f32]) -> f32,
    /// Cosine distance kernel over `f32` slices.
    pub cosine_distance: fn(&[f32], &[f32]) -> f32,
    /// Negated inner-product kernel over `f32` slices.
    pub neg_inner_product: fn(&[f32], &[f32]) -> f32,
    /// Hamming distance kernel over byte slices.
    pub hamming: fn(&[u8], &[u8]) -> u32,
    /// Label of the selected backend (e.g. `"avx2+fma"`, `"neon"`, `"scalar"`).
    pub name: &'static str,
    /// Squared L2 distance kernel for the f16 slot.
    pub l2_squared_f16: HalfFn,
    /// Cosine distance kernel for the f16 slot.
    pub cosine_distance_f16: HalfFn,
    /// Negated inner-product kernel for the f16 slot.
    pub neg_inner_product_f16: HalfFn,
    /// Squared L2 distance kernel for the bf16 slot.
    pub l2_squared_bf16: HalfFn,
    /// Cosine distance kernel for the bf16 slot.
    pub cosine_distance_bf16: HalfFn,
    /// Negated inner-product kernel for the bf16 slot.
    pub neg_inner_product_bf16: HalfFn,
}
34
35impl SimdRuntime {
36 pub fn detect() -> Self {
38 #[cfg(target_arch = "x86_64")]
39 {
40 let has_vpopcntdq = std::is_x86_feature_detected!("avx512vpopcntdq");
41
42 if std::is_x86_feature_detected!("avx512f") {
43 let name = if has_vpopcntdq {
44 "avx512+vpopcntdq"
45 } else {
46 "avx512"
47 };
48 let rt = Self {
49 l2_squared: avx512::l2_squared,
50 cosine_distance: avx512::cosine_distance,
51 neg_inner_product: avx512::neg_inner_product,
52 hamming: fast_hamming,
53 name,
54 l2_squared_f16: typed_scalar::l2_squared_f16,
55 cosine_distance_f16: typed_scalar::cosine_f16,
56 neg_inner_product_f16: typed_scalar::neg_inner_product_f16,
57 l2_squared_bf16: typed_scalar::l2_squared_bf16,
58 cosine_distance_bf16: typed_scalar::cosine_bf16,
59 neg_inner_product_bf16: typed_scalar::neg_inner_product_bf16,
60 };
61 tracing::info!(kernel = rt.name, "vector SIMD kernel selected");
62 debug_assert!(
63 !has_vpopcntdq || rt.name == "avx512+vpopcntdq",
64 "AVX-512 VPOPCNTDQ available but kernel did not select it"
65 );
66 return rt;
67 }
68 if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
69 let rt = Self {
70 l2_squared: avx2::l2_squared,
71 cosine_distance: avx2::cosine_distance,
72 neg_inner_product: avx2::neg_inner_product,
73 hamming: fast_hamming,
74 name: "avx2+fma",
75 l2_squared_f16: typed_scalar::l2_squared_f16,
76 cosine_distance_f16: typed_scalar::cosine_f16,
77 neg_inner_product_f16: typed_scalar::neg_inner_product_f16,
78 l2_squared_bf16: typed_scalar::l2_squared_bf16,
79 cosine_distance_bf16: typed_scalar::cosine_bf16,
80 neg_inner_product_bf16: typed_scalar::neg_inner_product_bf16,
81 };
82 tracing::info!(kernel = rt.name, "vector SIMD kernel selected");
83 return rt;
84 }
85 }
86 #[cfg(target_arch = "aarch64")]
87 {
88 let rt = Self {
89 l2_squared: neon::l2_squared,
90 cosine_distance: neon::cosine_distance,
91 neg_inner_product: neon::neg_inner_product,
92 hamming: fast_hamming,
93 name: "neon",
94 l2_squared_f16: typed_scalar::l2_squared_f16,
95 cosine_distance_f16: typed_scalar::cosine_f16,
96 neg_inner_product_f16: typed_scalar::neg_inner_product_f16,
97 l2_squared_bf16: typed_scalar::l2_squared_bf16,
98 cosine_distance_bf16: typed_scalar::cosine_bf16,
99 neg_inner_product_bf16: typed_scalar::neg_inner_product_bf16,
100 };
101 tracing::info!(kernel = rt.name, "vector SIMD kernel selected");
102 return rt;
103 }
104 #[allow(unreachable_code)]
105 {
106 let rt = Self {
107 l2_squared: scalar_l2,
108 cosine_distance: scalar_cosine,
109 neg_inner_product: scalar_ip,
110 hamming: fast_hamming,
111 name: "scalar",
112 l2_squared_f16: typed_scalar::l2_squared_f16,
113 cosine_distance_f16: typed_scalar::cosine_f16,
114 neg_inner_product_f16: typed_scalar::neg_inner_product_f16,
115 l2_squared_bf16: typed_scalar::l2_squared_bf16,
116 cosine_distance_bf16: typed_scalar::cosine_bf16,
117 neg_inner_product_bf16: typed_scalar::neg_inner_product_bf16,
118 };
119 tracing::info!(kernel = rt.name, "vector SIMD kernel selected");
120 rt
121 }
122 }
123}
124
125static RUNTIME: std::sync::OnceLock<SimdRuntime> = std::sync::OnceLock::new();
127
128pub fn runtime() -> &'static SimdRuntime {
130 RUNTIME.get_or_init(SimdRuntime::detect)
131}
132
#[cfg(test)]
mod tests {
    use super::super::hamming::fast_hamming;
    use super::super::scalar::{scalar_cosine, scalar_ip, scalar_l2};
    use super::*;

    /// Detection must always produce a labelled backend, whatever the host is.
    #[test]
    fn runtime_detects_features() {
        let rt = SimdRuntime::detect();
        assert!(!rt.name.is_empty());
        tracing::info!("SIMD runtime: {}", rt.name);
    }

    /// The selected squared-L2 kernel must agree with the scalar reference.
    #[test]
    fn l2_matches_scalar() {
        let rt = runtime();
        let a: Vec<f32> = (0..768).map(|i| (i as f32) * 0.01).collect();
        let b: Vec<f32> = (0..768).map(|i| (i as f32) * 0.01 + 0.001).collect();

        let simd_result = (rt.l2_squared)(&a, &b);
        let scalar_result = scalar_l2(&a, &b);

        // The expected distance is only about 768 * 0.001^2 ~= 7.7e-4, so an
        // absolute tolerance of 0.01 would also accept a kernel that returns 0.
        // Compare relative to the reference instead.
        assert!(
            scalar_result > 0.0,
            "reference distance should be non-zero, got {scalar_result}"
        );
        assert!(
            (simd_result - scalar_result).abs() <= scalar_result * 1e-3,
            "simd={simd_result}, scalar={scalar_result}"
        );
    }

    /// The selected cosine kernel must agree with the scalar reference.
    #[test]
    fn cosine_matches_scalar() {
        let rt = runtime();
        let a: Vec<f32> = (0..768).map(|i| (i as f32).sin()).collect();
        let b: Vec<f32> = (0..768).map(|i| (i as f32).cos()).collect();

        let simd_result = (rt.cosine_distance)(&a, &b);
        let scalar_result = scalar_cosine(&a, &b);
        assert!(
            (simd_result - scalar_result).abs() < 0.001,
            "simd={simd_result}, scalar={scalar_result}"
        );
    }

    /// The selected inner-product kernel must agree with the scalar reference.
    #[test]
    fn ip_matches_scalar() {
        let rt = runtime();
        let a: Vec<f32> = (0..128).map(|i| (i as f32) * 0.1).collect();
        let b: Vec<f32> = (0..128).map(|i| (i as f32) * 0.2).collect();

        let simd_result = (rt.neg_inner_product)(&a, &b);
        let scalar_result = scalar_ip(&a, &b);
        assert!(
            (simd_result - scalar_result).abs() < 0.1,
            "simd={simd_result}, scalar={scalar_result}"
        );
    }

    /// Complementary bit patterns differ in every bit: 16 bytes * 8 bits.
    #[test]
    fn hamming_matches() {
        let a = vec![0b10101010u8; 16];
        let b = vec![0b01010101u8; 16];
        assert_eq!(fast_hamming(&a, &b), 128);
    }

    /// Three-element inputs: each coordinate differs by 3, so 3 * 3^2 = 27.
    #[test]
    fn small_vectors() {
        let rt = runtime();
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        let l2 = (rt.l2_squared)(&a, &b);
        assert!((l2 - 27.0).abs() < 0.01);
    }
}