1use ndarray::{ArrayView2, Axis};
20use rayon::prelude::*;
21
mod simd {
    //! Horizontal max / argmax over `f32` slices, vectorized where the target allows.
    //!
    //! NaN handling: every implementation in this module ignores NaN elements, matching
    //! `f32::max` (which returns the non-NaN operand). An empty or all-NaN slice therefore
    //! has a maximum of `f32::NEG_INFINITY`.

    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    #[cfg(target_arch = "aarch64")]
    use std::arch::aarch64::*;

    /// Portable maximum: `f32::NEG_INFINITY` for an empty slice, NaN elements ignored.
    #[inline]
    fn scalar_max(slice: &[f32]) -> f32 {
        slice.iter().copied().fold(f32::NEG_INFINITY, f32::max)
    }

    /// Maximum element of `slice` (see the module docs for NaN / empty-slice behavior).
    ///
    /// Uses the AVX kernel when the CPU reports AVX support at runtime and the slice holds at
    /// least one full 8-lane vector; otherwise falls back to `scalar_max`.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    pub fn simd_max(slice: &[f32]) -> f32 {
        // Only AVX (not AVX2) instructions are used by the kernel.
        if slice.len() < 8 || !is_x86_feature_detected!("avx") {
            return scalar_max(slice);
        }
        // SAFETY: AVX support, the only requirement of `simd_max_avx`, was detected above.
        unsafe { simd_max_avx(slice) }
    }

    /// AVX kernel behind the x86_64 `simd_max`.
    ///
    /// Compiled with `#[target_feature(enable = "avx")]` so the intrinsics inline into
    /// straight-line vector code instead of becoming out-of-line calls.
    ///
    /// # Safety
    /// The caller must ensure the running CPU supports AVX.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx")]
    unsafe fn simd_max_avx(slice: &[f32]) -> f32 {
        let len = slice.len();
        let ptr = slice.as_ptr();

        // SAFETY: every load reads 8 floats starting at an index `j` with `j + 8 <= len`;
        // the prefetch address is computed with `wrapping_add` and is only a hint.
        unsafe {
            // `_mm256_max_ps(a, b)` returns `b` whenever a lane of either operand is NaN.
            // The accumulator is therefore always passed as `b`: NaN data lanes are skipped
            // and the accumulators (seeded with -inf) never become NaN, like `scalar_max`.
            let mut acc0 = _mm256_set1_ps(f32::NEG_INFINITY);
            let mut acc1 = acc0;
            let mut acc2 = acc0;
            let mut acc3 = acc0;
            let mut i = 0;

            // Four independent accumulators so consecutive max operations do not serialize.
            while i + 32 <= len {
                // `wrapping_add`: the prefetch target may lie past the end of the slice,
                // where `add` would be undefined behavior.
                _mm_prefetch(ptr.wrapping_add(i + 64) as *const i8, _MM_HINT_T0);
                acc0 = _mm256_max_ps(_mm256_loadu_ps(ptr.add(i)), acc0);
                acc1 = _mm256_max_ps(_mm256_loadu_ps(ptr.add(i + 8)), acc1);
                acc2 = _mm256_max_ps(_mm256_loadu_ps(ptr.add(i + 16)), acc2);
                acc3 = _mm256_max_ps(_mm256_loadu_ps(ptr.add(i + 24)), acc3);
                i += 32;
            }

            while i + 8 <= len {
                acc0 = _mm256_max_ps(_mm256_loadu_ps(ptr.add(i)), acc0);
                i += 8;
            }

            // Fold the four accumulators, then reduce 8 -> 4 -> 2 -> 1 lanes.
            let acc = _mm256_max_ps(_mm256_max_ps(acc0, acc1), _mm256_max_ps(acc2, acc3));
            let max4 = _mm_max_ps(_mm256_extractf128_ps(acc, 1), _mm256_castps256_ps128(acc));
            // 0b0100_1110 swaps the two 64-bit halves.
            let max2 = _mm_max_ps(max4, _mm_shuffle_ps(max4, max4, 0b0100_1110));
            // 0b0000_0001 moves lane 1 into lane 0.
            let max1 = _mm_max_ps(max2, _mm_shuffle_ps(max2, max2, 0b0000_0001));

            // Remaining (< 8) elements.
            slice[i..]
                .iter()
                .copied()
                .fold(_mm_cvtss_f32(max1), f32::max)
        }
    }

    /// Maximum element of `slice` (see the module docs for NaN / empty-slice behavior).
    ///
    /// NEON implementation; slices shorter than one 4-lane vector use `scalar_max`.
    #[cfg(target_arch = "aarch64")]
    #[inline]
    pub fn simd_max(slice: &[f32]) -> f32 {
        if slice.len() < 4 {
            return scalar_max(slice);
        }

        let len = slice.len();
        let ptr = slice.as_ptr();

        // SAFETY: assumes NEON is enabled, as it is on the standard aarch64 targets; every
        // load reads 4 floats starting at an index `j` with `j + 4 <= len`.
        unsafe {
            // `vmaxnmq_f32` is IEEE maxNum: a quiet-NaN operand yields the other operand,
            // the same rule `f32::max` follows, so NaN elements are ignored.
            let mut acc0 = vdupq_n_f32(f32::NEG_INFINITY);
            let mut acc1 = acc0;
            let mut acc2 = acc0;
            let mut acc3 = acc0;
            let mut i = 0;

            while i + 16 <= len {
                acc0 = vmaxnmq_f32(acc0, vld1q_f32(ptr.add(i)));
                acc1 = vmaxnmq_f32(acc1, vld1q_f32(ptr.add(i + 4)));
                acc2 = vmaxnmq_f32(acc2, vld1q_f32(ptr.add(i + 8)));
                acc3 = vmaxnmq_f32(acc3, vld1q_f32(ptr.add(i + 12)));
                i += 16;
            }

            while i + 4 <= len {
                acc0 = vmaxnmq_f32(acc0, vld1q_f32(ptr.add(i)));
                i += 4;
            }

            // Fold the accumulators, then reduce 4 -> 2 -> 1 lanes by rotating the vector.
            let acc = vmaxnmq_f32(vmaxnmq_f32(acc0, acc1), vmaxnmq_f32(acc2, acc3));
            let pair = vmaxnmq_f32(acc, vextq_f32(acc, acc, 2));
            let single = vmaxnmq_f32(pair, vextq_f32(pair, pair, 1));

            // Remaining (< 4) elements.
            slice[i..]
                .iter()
                .copied()
                .fold(vgetq_lane_f32(single, 0), f32::max)
        }
    }

    /// Maximum element of `slice`; portable fallback for targets without a SIMD kernel.
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    #[inline]
    pub fn simd_max(slice: &[f32]) -> f32 {
        scalar_max(slice)
    }

    /// Index of the first element equal to `simd_max(slice)`.
    ///
    /// NaN elements are ignored and ties resolve to the lowest index, whatever the slice
    /// length or `simd_max` path taken. Returns `0` for an empty or all-NaN slice.
    #[inline]
    pub fn simd_argmax(slice: &[f32]) -> usize {
        let max_val = simd_max(slice);
        // No match only when the slice is empty or all NaN (max is -inf, absent from slice).
        slice.iter().position(|&x| x == max_val).unwrap_or(0)
    }
}
212
213#[inline]
234pub fn maxsim_score(query: &ArrayView2<f32>, doc: &ArrayView2<f32>) -> f32 {
235 let q_len = query.nrows();
236 let d_len = doc.nrows();
237
238 if q_len * d_len < 256 {
240 return maxsim_score_simple(query, doc);
241 }
242
243 let scores = query.dot(&doc.t());
246
247 let mut total = 0.0f32;
249 for q_idx in 0..q_len {
250 let row = scores.row(q_idx);
251 let max_sim = simd::simd_max(row.as_slice().unwrap());
252 if max_sim > f32::NEG_INFINITY {
253 total += max_sim;
254 }
255 }
256
257 total
258}
259
260#[inline]
262fn maxsim_score_simple(query: &ArrayView2<f32>, doc: &ArrayView2<f32>) -> f32 {
263 let mut total = 0.0f32;
264
265 for q_row in query.axis_iter(Axis(0)) {
266 let mut max_sim = f32::NEG_INFINITY;
267 for d_row in doc.axis_iter(Axis(0)) {
268 let sim: f32 = q_row.dot(&d_row);
269 if sim > max_sim {
270 max_sim = sim;
271 }
272 }
273 if max_sim > f32::NEG_INFINITY {
274 total += max_sim;
275 }
276 }
277
278 total
279}
280
281pub fn assign_to_centroids(
295 embeddings: &ArrayView2<f32>,
296 centroids: &ArrayView2<f32>,
297) -> Vec<usize> {
298 let n = embeddings.nrows();
299 let k = centroids.nrows();
300
301 if n == 0 || k == 0 {
302 return vec![0; n];
303 }
304
305 if n * k < 1024 {
307 return embeddings
308 .axis_iter(Axis(0))
309 .map(|emb| {
310 let mut best_idx = 0;
311 let mut best_score = f32::NEG_INFINITY;
312 for (idx, centroid) in centroids.axis_iter(Axis(0)).enumerate() {
313 let score: f32 = emb.iter().zip(centroid.iter()).map(|(a, b)| a * b).sum();
314 if score > best_score {
315 best_score = score;
316 best_idx = idx;
317 }
318 }
319 best_idx
320 })
321 .collect();
322 }
323
324 let max_batch_by_memory = (2 * 1024 * 1024 * 1024) / (k * std::mem::size_of::<f32>());
327 let batch_size = max_batch_by_memory.clamp(1, 4096).min(n);
328
329 let mut all_codes = Vec::with_capacity(n);
330
331 for start in (0..n).step_by(batch_size) {
332 let end = (start + batch_size).min(n);
333 let batch = embeddings.slice(ndarray::s![start..end, ..]);
334
335 let scores = batch.dot(¢roids.t());
337
338 let batch_codes: Vec<usize> = scores
340 .axis_iter(Axis(0))
341 .into_par_iter()
342 .map(|row| simd::simd_argmax(row.as_slice().unwrap()))
343 .collect();
344
345 all_codes.extend(batch_codes);
346 }
347
348 all_codes
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    #[test]
    fn test_maxsim_score_basic() {
        // Two one-hot query rows against three document rows.
        let query =
            Array2::from_shape_vec((2, 4), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();

        let doc = Array2::from_shape_vec(
            (3, 4),
            vec![
                0.5, 0.5, 0.0, 0.0, 0.8, 0.2, 0.0, 0.0, 0.0, 0.9, 0.1, 0.0,
            ],
        )
        .unwrap();

        // Best matches: 0.8 for query row 0, 0.9 for query row 1.
        let score = maxsim_score(&query.view(), &doc.view());
        assert!((score - 1.7).abs() < 1e-5);
    }

    #[test]
    fn test_maxsim_score_matrix_path_matches_simple() {
        // 16 x 16 = 256 pairs, so `maxsim_score` takes the matrix-product path.
        let query = Array2::from_shape_fn((16, 4), |(i, j)| {
            ((i * 7 + j * 3) % 11) as f32 / 10.0 - 0.5
        });
        let doc = Array2::from_shape_fn((16, 4), |(i, j)| {
            ((i * 5 + j * 2) % 13) as f32 / 10.0 - 0.6
        });

        let fast = maxsim_score(&query.view(), &doc.view());
        let reference = maxsim_score_simple(&query.view(), &doc.view());
        assert!((fast - reference).abs() < 1e-4);
    }

    #[test]
    fn test_simd_max() {
        let data: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let max = simd::simd_max(&data);
        assert!((max - 99.0).abs() < 1e-5);

        let data2: Vec<f32> = (-50..50).map(|i| i as f32).collect();
        let max2 = simd::simd_max(&data2);
        assert!((max2 - 49.0).abs() < 1e-5);

        let small = vec![1.0, 5.0, 3.0];
        let max3 = simd::simd_max(&small);
        assert!((max3 - 5.0).abs() < 1e-5);
    }

    #[test]
    fn test_assign_to_centroids() {
        let centroids = Array2::from_shape_vec(
            (3, 4),
            vec![
                1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
            ],
        )
        .unwrap();

        let embeddings = Array2::from_shape_vec(
            (5, 4),
            vec![
                0.9, 0.1, 0.0, 0.0, 0.1, 0.9, 0.0, 0.0, 0.0, 0.1, 0.9, 0.0, 0.8, 0.2, 0.0, 0.0,
                0.0, 0.0, 0.8, 0.2,
            ],
        )
        .unwrap();

        let assignments = assign_to_centroids(&embeddings.view(), &centroids.view());

        assert_eq!(assignments.len(), 5);
        assert_eq!(assignments[0], 0);
        assert_eq!(assignments[1], 1);
        assert_eq!(assignments[2], 2);
        assert_eq!(assignments[3], 0);
        assert_eq!(assignments[4], 2);
    }

    #[test]
    fn test_assign_to_centroids_batched_path() {
        // 64 x 16 = 1024 pairs, so the batched matrix-product path is used.
        let dim = 16;
        let centroids = Array2::<f32>::eye(dim);
        // Row i peaks (1.0) at column i % dim; every other entry is at most 0.15.
        let embeddings = Array2::from_shape_fn((64, dim), |(i, j)| {
            if j == i % dim {
                1.0
            } else {
                j as f32 * 0.01
            }
        });

        let assignments = assign_to_centroids(&embeddings.view(), &centroids.view());
        let expected: Vec<usize> = (0..64).map(|i| i % dim).collect();
        assert_eq!(assignments, expected);
    }

    #[test]
    fn test_simd_argmax() {
        let data: Vec<f32> = vec![1.0, 5.0, 3.0, 2.0, 4.0];
        assert_eq!(simd::simd_argmax(&data), 1);

        let data2: Vec<f32> = (0..100).map(|i| i as f32).collect();
        assert_eq!(simd::simd_argmax(&data2), 99);

        let data3: Vec<f32> = (0..100).rev().map(|i| i as f32).collect();
        assert_eq!(simd::simd_argmax(&data3), 0);
    }
}
443}