nodedb_vector/distance/simd/
runtime.rs1use super::hamming::fast_hamming;
6use super::scalar::{scalar_cosine, scalar_ip, scalar_l2};
7
8#[cfg(target_arch = "x86_64")]
9use super::{avx2, avx512};
10
11#[cfg(target_arch = "aarch64")]
12use super::neon;
13
/// Table of distance kernels plus a label naming the instruction set behind them.
///
/// Instances are produced by [`SimdRuntime::detect`]; every field is a plain
/// function pointer or a `&'static str`, so the table is cheap to copy and can
/// be debug-printed.
#[derive(Debug, Clone, Copy)]
pub struct SimdRuntime {
    /// Squared L2 kernel over two `f32` slices
    /// (for `[1, 2, 3]` and `[4, 5, 6]` the expected result is `27.0`).
    pub l2_squared: fn(&[f32], &[f32]) -> f32,
    /// Cosine-distance kernel over two `f32` slices.
    /// NOTE(review): presumably `1 - cosine similarity`; confirm against the kernels.
    pub cosine_distance: fn(&[f32], &[f32]) -> f32,
    /// Inner-product kernel over two `f32` slices.
    /// NOTE(review): by its name the result is the negated dot product; confirm.
    pub neg_inner_product: fn(&[f32], &[f32]) -> f32,
    /// Hamming-distance kernel over two byte slices, returning a `u32` count.
    pub hamming: fn(&[u8], &[u8]) -> u32,
    /// Label of the selected kernel set: `"avx512+vpopcntdq"`, `"avx512"`,
    /// `"avx2+fma"`, `"neon"` or `"scalar"`.
    pub name: &'static str,
}
22
23impl SimdRuntime {
24 pub fn detect() -> Self {
26 #[cfg(target_arch = "x86_64")]
27 {
28 let has_vpopcntdq = std::is_x86_feature_detected!("avx512vpopcntdq");
29
30 if std::is_x86_feature_detected!("avx512f") {
31 let name = if has_vpopcntdq {
32 "avx512+vpopcntdq"
33 } else {
34 "avx512"
35 };
36 let rt = Self {
37 l2_squared: avx512::l2_squared,
38 cosine_distance: avx512::cosine_distance,
39 neg_inner_product: avx512::neg_inner_product,
40 hamming: fast_hamming,
41 name,
42 };
43 tracing::info!(kernel = rt.name, "vector SIMD kernel selected");
44 debug_assert!(
45 !has_vpopcntdq || rt.name == "avx512+vpopcntdq",
46 "AVX-512 VPOPCNTDQ available but kernel did not select it"
47 );
48 return rt;
49 }
50 if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
51 let rt = Self {
52 l2_squared: avx2::l2_squared,
53 cosine_distance: avx2::cosine_distance,
54 neg_inner_product: avx2::neg_inner_product,
55 hamming: fast_hamming,
56 name: "avx2+fma",
57 };
58 tracing::info!(kernel = rt.name, "vector SIMD kernel selected");
59 return rt;
60 }
61 }
62 #[cfg(target_arch = "aarch64")]
63 {
64 let rt = Self {
65 l2_squared: neon::l2_squared,
66 cosine_distance: neon::cosine_distance,
67 neg_inner_product: neon::neg_inner_product,
68 hamming: fast_hamming,
69 name: "neon",
70 };
71 tracing::info!(kernel = rt.name, "vector SIMD kernel selected");
72 return rt;
73 }
74 #[allow(unreachable_code)]
75 {
76 let rt = Self {
77 l2_squared: scalar_l2,
78 cosine_distance: scalar_cosine,
79 neg_inner_product: scalar_ip,
80 hamming: fast_hamming,
81 name: "scalar",
82 };
83 tracing::info!(kernel = rt.name, "vector SIMD kernel selected");
84 rt
85 }
86 }
87}
88
89static RUNTIME: std::sync::OnceLock<SimdRuntime> = std::sync::OnceLock::new();
91
92pub fn runtime() -> &'static SimdRuntime {
94 RUNTIME.get_or_init(SimdRuntime::detect)
95}
96
#[cfg(test)]
mod tests {
    use super::super::hamming::fast_hamming;
    use super::super::scalar::{scalar_cosine, scalar_ip, scalar_l2};
    use super::*;

    /// Fails, printing both values, unless `simd` is strictly within
    /// `tolerance` of `scalar`.
    fn assert_close(simd: f32, scalar: f32, tolerance: f32) {
        assert!(
            (simd - scalar).abs() < tolerance,
            "simd={simd}, scalar={scalar}"
        );
    }

    /// Detection always yields a non-empty kernel label.
    #[test]
    fn runtime_detects_features() {
        let rt = SimdRuntime::detect();
        assert!(!rt.name.is_empty());
        tracing::info!("SIMD runtime: {}", rt.name);
    }

    /// The selected L2 kernel agrees with the scalar reference on 768 dimensions.
    #[test]
    fn l2_matches_scalar() {
        let rt = runtime();
        let a: Vec<f32> = (0..768).map(|i| (i as f32) * 0.01).collect();
        let b: Vec<f32> = (0..768).map(|i| (i as f32) * 0.01 + 0.001).collect();

        let simd_result = (rt.l2_squared)(&a, &b);
        let scalar_result = scalar_l2(&a, &b);

        // Each pair differs by roughly 0.001, so the squared distance is on the
        // order of 768 * 1e-6. An absolute tolerance of 0.01 would also accept a
        // kernel that returns 0, so compare relative to the reference instead.
        assert!(
            scalar_result > 0.0,
            "scalar reference should be positive, got {scalar_result}"
        );
        assert_close(simd_result, scalar_result, scalar_result * 1e-3);
    }

    /// The selected cosine kernel agrees with the scalar reference.
    #[test]
    fn cosine_matches_scalar() {
        let rt = runtime();
        let a: Vec<f32> = (0..768).map(|i| (i as f32).sin()).collect();
        let b: Vec<f32> = (0..768).map(|i| (i as f32).cos()).collect();

        let simd_result = (rt.cosine_distance)(&a, &b);
        let scalar_result = scalar_cosine(&a, &b);
        assert_close(simd_result, scalar_result, 0.001);
    }

    /// The selected inner-product kernel agrees with the scalar reference.
    #[test]
    fn ip_matches_scalar() {
        let rt = runtime();
        let a: Vec<f32> = (0..128).map(|i| (i as f32) * 0.1).collect();
        let b: Vec<f32> = (0..128).map(|i| (i as f32) * 0.2).collect();

        let simd_result = (rt.neg_inner_product)(&a, &b);
        let scalar_result = scalar_ip(&a, &b);
        assert_close(simd_result, scalar_result, 0.1);
    }

    /// Complementary bit patterns differ in every bit: 16 bytes -> 128 bits.
    /// Checked both directly and through the runtime's dispatch table.
    #[test]
    fn hamming_matches() {
        let a = vec![0b10101010u8; 16];
        let b = vec![0b01010101u8; 16];
        assert_eq!(fast_hamming(&a, &b), 128);
        assert_eq!((runtime().hamming)(&a, &b), 128);
    }

    /// Inputs shorter than any SIMD register width still give the exact answer:
    /// (1-4)^2 + (2-5)^2 + (3-6)^2 = 27.
    #[test]
    fn small_vectors() {
        let rt = runtime();
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        let l2 = (rt.l2_squared)(&a, &b);
        assert!((l2 - 27.0).abs() < 0.01);
    }
}