// batuta/oracle/rag/quantization/simd.rs
/// SIMD instruction-set backend used for quantized (i8) vector math.
///
/// Runtime preference (see `detect`): AVX-512 > AVX2 on x86_64, NEON on
/// aarch64, `Scalar` as the portable fallback everywhere else.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdBackend {
    /// x86_64 256-bit vectors (requires the `avx2` CPU feature).
    Avx2,
    /// x86_64 512-bit vectors (requires `avx512f` + `avx512bw`).
    Avx512,
    /// AArch64 128-bit Advanced SIMD (baseline on aarch64).
    Neon,
    /// Portable scalar fallback with no CPU feature requirements.
    Scalar,
}
27
28impl SimdBackend {
29 pub fn detect() -> Self {
31 #[cfg(target_arch = "x86_64")]
32 {
33 let has_avx512 =
34 is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw");
35 let has_avx2 = is_x86_feature_detected!("avx2");
36 return Self::from_x86_features(has_avx512, has_avx2);
37 }
38 #[cfg(target_arch = "aarch64")]
39 {
40 return Self::Neon;
42 }
43 #[allow(unreachable_code)]
44 Self::Scalar
45 }
46
47 #[cfg(target_arch = "x86_64")]
49 pub fn from_x86_features(has_avx512: bool, has_avx2: bool) -> Self {
50 if has_avx512 {
51 Self::Avx512
52 } else if has_avx2 {
53 Self::Avx2
54 } else {
55 Self::Scalar
56 }
57 }
58
59 #[allow(unsafe_code)]
67 pub fn dot_i8(&self, a: &[i8], b: &[i8]) -> i32 {
68 debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");
69
70 match self {
71 #[cfg(target_arch = "x86_64")]
72 Self::Avx2 => {
73 if is_x86_feature_detected!("avx2") {
74 return unsafe { dot_i8_avx2(a, b) };
76 }
77 dot_i8_scalar(a, b)
78 }
79 #[cfg(target_arch = "x86_64")]
80 Self::Avx512 => {
81 if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw") {
82 return unsafe { dot_i8_avx512(a, b) };
84 }
85 dot_i8_scalar(a, b)
86 }
87 #[cfg(target_arch = "aarch64")]
88 Self::Neon => {
89 unsafe { dot_i8_neon(a, b) }
93 }
94 _ => dot_i8_scalar(a, b),
95 }
96 }
97
98 pub fn dot_f32_i8(&self, query: &[f32], doc: &[i8], scale: f32) -> f32 {
102 debug_assert_eq!(query.len(), doc.len(), "Vectors must have same length");
103
104 let mut sum: f32 = 0.0;
106 for (&q, &d) in query.iter().zip(doc.iter()) {
107 sum += q * (d as f32 * scale);
108 }
109 sum
110 }
111}
112
/// Portable scalar i8 dot product; also the fallback used by the SIMD
/// dispatch path. Pairs elements up to the shorter of the two slices.
pub fn dot_i8_scalar(a: &[i8], b: &[i8]) -> i32 {
    let mut acc: i32 = 0;
    for (&x, &y) in a.iter().zip(b) {
        acc += i32::from(x) * i32::from(y);
    }
    acc
}
117
/// Scalar dot product of the trailing elements `a[start..]` / `b[start..]`.
/// Used by the SIMD kernels to finish off the leftover lanes that did not
/// fill a full vector register. Panics if `start` exceeds either length.
#[inline]
fn dot_i8_scalar_tail(a: &[i8], b: &[i8], start: usize) -> i32 {
    let tail_a = &a[start..];
    let tail_b = &b[start..];
    let mut acc: i32 = 0;
    for (&x, &y) in tail_a.iter().zip(tail_b.iter()) {
        acc += i32::from(x) * i32::from(y);
    }
    acc
}
126
/// AVX2 i8 dot product: 32 lanes per iteration, widening i8 -> i16 and
/// using `madd` (multiply + pairwise add) into eight i32 accumulators.
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g. via
/// `is_x86_feature_detected!("avx2")`) and that `b.len() >= a.len()` —
/// the loop bound comes from `a.len()` only.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn dot_i8_avx2(a: &[i8], b: &[i8]) -> i32 {
    use std::arch::x86_64::*;

    let n = a.len();
    // Eight i32 partial sums, one per 32-bit lane of the 256-bit register.
    let mut sum = _mm256_setzero_si256();

    let mut i = 0;
    // Main loop: 32 i8 elements per iteration.
    while i + 32 <= n {
        // Unaligned loads; `i + 32 <= n` keeps reads of `a` in bounds, and
        // the caller guarantees `b` is at least as long.
        let va = _mm256_loadu_si256(a[i..].as_ptr().cast::<__m256i>());
        let vb = _mm256_loadu_si256(b[i..].as_ptr().cast::<__m256i>());

        // Sign-extend each 128-bit half from 16xi8 to 16xi16.
        let lo_a = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 0));
        let lo_b = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 0));
        let hi_a = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 1));
        let hi_b = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 1));

        // madd: multiply corresponding i16 lanes and add adjacent pairs
        // into i32 lanes; cannot overflow since |i8 * i8| <= 16129.
        let prod_lo = _mm256_madd_epi16(lo_a, lo_b);
        let prod_hi = _mm256_madd_epi16(hi_a, hi_b);

        sum = _mm256_add_epi32(sum, prod_lo);
        sum = _mm256_add_epi32(sum, prod_hi);

        i += 32;
    }

    // Horizontal reduction of the eight i32 lanes: 256 -> 128 -> 64 -> 32.
    let sum128 = _mm_add_epi32(_mm256_extracti128_si256(sum, 0), _mm256_extracti128_si256(sum, 1));
    let sum64 = _mm_add_epi32(sum128, _mm_srli_si128(sum128, 8));
    let sum32 = _mm_add_epi32(sum64, _mm_srli_si128(sum64, 4));
    let result = _mm_cvtsi128_si32(sum32);

    // Finish the remaining < 32 elements with the scalar tail.
    result + dot_i8_scalar_tail(a, b, i)
}
169
/// AVX-512 i8 dot product: 64 lanes per iteration, widening i8 -> i16 and
/// using 512-bit `madd` into sixteen i32 accumulators.
///
/// # Safety
/// Caller must ensure the CPU supports `avx512f` and `avx512bw` (the
/// 512-bit `madd` is BW) and that `b.len() >= a.len()` — the loop bound
/// comes from `a.len()` only.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "avx512bw")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn dot_i8_avx512(a: &[i8], b: &[i8]) -> i32 {
    use std::arch::x86_64::*;

    let n = a.len();
    // Sixteen i32 partial sums across the 512-bit register.
    let mut sum = _mm512_setzero_si512();

    let mut i = 0;
    // Main loop: 64 i8 elements per iteration.
    while i + 64 <= n {
        // Unaligned loads; `i + 64 <= n` keeps reads of `a` in bounds, and
        // the caller guarantees `b` is at least as long.
        let va = _mm512_loadu_si512(a[i..].as_ptr().cast::<__m512i>());
        let vb = _mm512_loadu_si512(b[i..].as_ptr().cast::<__m512i>());

        // Sign-extend each 256-bit half from 32xi8 to 32xi16.
        let lo_a = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(va, 0));
        let lo_b = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(vb, 0));
        let hi_a = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(va, 1));
        let hi_b = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(vb, 1));

        // madd: multiply i16 lanes and add adjacent pairs into i32 lanes;
        // cannot overflow since |i8 * i8| <= 16129.
        let prod_lo = _mm512_madd_epi16(lo_a, lo_b);
        let prod_hi = _mm512_madd_epi16(hi_a, hi_b);

        sum = _mm512_add_epi32(sum, prod_lo);
        sum = _mm512_add_epi32(sum, prod_hi);

        i += 64;
    }

    // Horizontal sum of the sixteen i32 lanes.
    let result = _mm512_reduce_add_epi32(sum);

    // Finish the remaining < 64 elements with the scalar tail.
    result + dot_i8_scalar_tail(a, b, i)
}
209
/// NEON i8 dot product: 16 lanes per iteration, widening i8 -> i16 -> i32
/// via `vmovl`/`vmull` and accumulating in a 4-lane i32 register.
///
/// # Safety
/// Caller must ensure NEON is available (baseline on aarch64) and that
/// `b.len() >= a.len()` — the loop bound comes from `a.len()` only.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn dot_i8_neon(a: &[i8], b: &[i8]) -> i32 {
    use std::arch::aarch64::*;

    let n = a.len();
    // Four i32 partial sums.
    let mut sum = vdupq_n_s32(0);

    let mut i = 0;
    // Main loop: 16 i8 elements per iteration.
    while i + 16 <= n {
        // `i + 16 <= n` keeps reads of `a` in bounds; the caller
        // guarantees `b` is at least as long.
        let va = vld1q_s8(a[i..].as_ptr());
        let vb = vld1q_s8(b[i..].as_ptr());

        // Sign-extend each 8-lane half from i8 to i16.
        let lo_a = vmovl_s8(vget_low_s8(va));
        let lo_b = vmovl_s8(vget_low_s8(vb));
        let hi_a = vmovl_s8(vget_high_s8(va));
        let hi_b = vmovl_s8(vget_high_s8(vb));

        // Widening multiplies: i16 x i16 -> i32, four lanes per vmull.
        let prod_lo = vmull_s16(vget_low_s16(lo_a), vget_low_s16(lo_b));
        let prod_lo2 = vmull_s16(vget_high_s16(lo_a), vget_high_s16(lo_b));
        let prod_hi = vmull_s16(vget_low_s16(hi_a), vget_low_s16(hi_b));
        let prod_hi2 = vmull_s16(vget_high_s16(hi_a), vget_high_s16(hi_b));

        sum = vaddq_s32(sum, prod_lo);
        sum = vaddq_s32(sum, prod_lo2);
        sum = vaddq_s32(sum, prod_hi);
        sum = vaddq_s32(sum, prod_hi2);

        i += 16;
    }

    // Horizontal add of the four i32 lanes.
    let result = vaddvq_s32(sum);

    // Finish the remaining < 16 elements with the scalar tail.
    result + dot_i8_scalar_tail(a, b, i)
}