next_plaid/maxsim.rs

//! High-performance MaxSim scoring for late-interaction (ColBERT) workflows.
//!
//! This module provides optimized CPU implementations of MaxSim scoring using:
//! - SIMD instructions (AVX2 on x86_64, NEON on ARM) for fast max reduction
//! - BLAS-accelerated matrix multiplication via ndarray (when the `accelerate` or `openblas` feature is enabled)
//!
//! # Credits
//!
//! The SIMD optimization techniques in this module are adapted from the
//! [maxsim-cpu](https://github.com/mixedbread-ai/maxsim-cpu/tree/main) project,
//! which provides high-performance MaxSim computation for ColBERT-style late-interaction models.
//!
//! # Platform Support
//!
//! - **macOS ARM**: Uses NEON SIMD + Apple Accelerate (with the `accelerate` feature)
//! - **Linux x86_64**: Uses AVX2 SIMD + OpenBLAS (with the `openblas` feature)
//! - **Other platforms**: Falls back to scalar operations
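//!
//! # Definition
//!
//! For query embeddings `Q` of shape `[m, dim]` and document embeddings `D` of
//! shape `[n, dim]`, the score sums, over query tokens, the best similarity
//! against any document token (this restates what [`maxsim_score`] computes):
//!
//! ```text
//! score(Q, D) = sum_i  max_j  Q[i] · D[j]
//! ```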

use ndarray::{ArrayView2, Axis};
use rayon::prelude::*;

// ============================================================================
// SIMD Module - Platform-specific fast max/argmax
// Adapted from https://github.com/lightonai/maxsim-cpu
// ============================================================================

mod simd {
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    #[cfg(target_arch = "aarch64")]
    use std::arch::aarch64::*;

    /// Scalar fallback for max - used when SIMD is unavailable or slice is small.
    #[inline]
    #[allow(dead_code)] // Used conditionally on x86_64 without AVX2
    fn scalar_max(slice: &[f32]) -> f32 {
        slice.iter().copied().fold(f32::NEG_INFINITY, f32::max)
    }

    /// Scalar fallback for argmax - used when SIMD is unavailable or slice is small.
    #[inline]
    #[allow(dead_code)] // Used conditionally on x86_64 without AVX2
    fn scalar_argmax(slice: &[f32]) -> usize {
        slice
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(idx, _)| idx)
            .unwrap_or(0)
    }

    /// Find max value in slice using AVX2 SIMD with prefetching.
    /// Falls back to scalar if AVX2 is not available (e.g., emulation).
    #[cfg(target_arch = "x86_64")]
    #[inline]
    pub fn simd_max(slice: &[f32]) -> f32 {
        if slice.len() < 8 || !is_x86_feature_detected!("avx2") {
            return scalar_max(slice);
        }

        unsafe {
            // Use 4 vectors for better ILP (Instruction Level Parallelism)
            let mut max_vec0 = _mm256_set1_ps(f32::NEG_INFINITY);
            let mut max_vec1 = _mm256_set1_ps(f32::NEG_INFINITY);
            let mut max_vec2 = _mm256_set1_ps(f32::NEG_INFINITY);
            let mut max_vec3 = _mm256_set1_ps(f32::NEG_INFINITY);

            let mut i = 0;

            // Process 32 elements at a time (4x8) for better ILP
            while i + 32 <= slice.len() {
                _mm_prefetch(slice.as_ptr().add(i + 64) as *const i8, _MM_HINT_T0);

                let data0 = _mm256_loadu_ps(slice.as_ptr().add(i));
                let data1 = _mm256_loadu_ps(slice.as_ptr().add(i + 8));
                let data2 = _mm256_loadu_ps(slice.as_ptr().add(i + 16));
                let data3 = _mm256_loadu_ps(slice.as_ptr().add(i + 24));

                max_vec0 = _mm256_max_ps(max_vec0, data0);
                max_vec1 = _mm256_max_ps(max_vec1, data1);
                max_vec2 = _mm256_max_ps(max_vec2, data2);
                max_vec3 = _mm256_max_ps(max_vec3, data3);

                i += 32;
            }

            // Process remaining groups of 8
            while i + 8 <= slice.len() {
                let data = _mm256_loadu_ps(slice.as_ptr().add(i));
                max_vec0 = _mm256_max_ps(max_vec0, data);
                i += 8;
            }

            // Combine the 4 vectors
            max_vec0 = _mm256_max_ps(max_vec0, max_vec1);
            max_vec2 = _mm256_max_ps(max_vec2, max_vec3);
            max_vec0 = _mm256_max_ps(max_vec0, max_vec2);

            // Horizontal max within the final vector
            let high = _mm256_extractf128_ps(max_vec0, 1);
            let low = _mm256_castps256_ps128(max_vec0);
            let max128 = _mm_max_ps(high, low);

            let shuffled = _mm_shuffle_ps(max128, max128, 0b01001110);
            let max64 = _mm_max_ps(max128, shuffled);
            let shuffled2 = _mm_shuffle_ps(max64, max64, 0b00000001);
            let final_max = _mm_max_ps(max64, shuffled2);

            let mut result = _mm_cvtss_f32(final_max);

            // Handle remaining elements
            for &val in &slice[i..] {
                result = result.max(val);
            }

            result
        }
    }

    /// Find max value in slice using ARM NEON SIMD.
    #[cfg(target_arch = "aarch64")]
    #[inline]
    pub fn simd_max(slice: &[f32]) -> f32 {
        if slice.len() < 4 {
            return slice.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        }

        unsafe {
            // Initialize 4 vectors for better ILP
            let mut max_vec0 = vdupq_n_f32(f32::NEG_INFINITY);
            let mut max_vec1 = vdupq_n_f32(f32::NEG_INFINITY);
            let mut max_vec2 = vdupq_n_f32(f32::NEG_INFINITY);
            let mut max_vec3 = vdupq_n_f32(f32::NEG_INFINITY);

            let mut i = 0;

            // Process 16 elements at a time (4x4)
            while i + 16 <= slice.len() {
                let data0 = vld1q_f32(slice.as_ptr().add(i));
                let data1 = vld1q_f32(slice.as_ptr().add(i + 4));
                let data2 = vld1q_f32(slice.as_ptr().add(i + 8));
                let data3 = vld1q_f32(slice.as_ptr().add(i + 12));

                max_vec0 = vmaxq_f32(max_vec0, data0);
                max_vec1 = vmaxq_f32(max_vec1, data1);
                max_vec2 = vmaxq_f32(max_vec2, data2);
                max_vec3 = vmaxq_f32(max_vec3, data3);

                i += 16;
            }

            // Process remaining groups of 4
            while i + 4 <= slice.len() {
                let data = vld1q_f32(slice.as_ptr().add(i));
                max_vec0 = vmaxq_f32(max_vec0, data);
                i += 4;
            }

            // Combine the 4 vectors
            max_vec0 = vmaxq_f32(max_vec0, max_vec1);
            max_vec2 = vmaxq_f32(max_vec2, max_vec3);
            max_vec0 = vmaxq_f32(max_vec0, max_vec2);

            // Horizontal max within the final vector
            let max_pair = vmaxq_f32(max_vec0, vextq_f32(max_vec0, max_vec0, 2));
            let max_val = vmaxq_f32(max_pair, vextq_f32(max_pair, max_pair, 1));
            let mut result = vgetq_lane_f32(max_val, 0);

            // Handle remaining elements
            for &val in &slice[i..] {
                result = result.max(val);
            }

            result
        }
    }

    /// Fallback for unsupported architectures.
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    #[inline]
    pub fn simd_max(slice: &[f32]) -> f32 {
        scalar_max(slice)
    }

    /// Find argmax (index of maximum value) in slice.
    /// Uses SIMD to find the max value, then scans for its index.
    #[inline]
    pub fn simd_argmax(slice: &[f32]) -> usize {
        if slice.is_empty() {
            return 0;
        }

        // Check for SIMD availability at runtime (x86_64 only)
        #[cfg(target_arch = "x86_64")]
        if slice.len() < 8 || !is_x86_feature_detected!("avx2") {
            return scalar_argmax(slice);
        }

        #[cfg(not(target_arch = "x86_64"))]
        if slice.len() < 8 {
            return scalar_argmax(slice);
        }

        // Find the max value using SIMD
        let max_val = simd_max(slice);

        // Scan for the index (first occurrence)
        slice.iter().position(|&x| x == max_val).unwrap_or(0)
    }
}

// ============================================================================
// Public API
// ============================================================================

/// Compute MaxSim score for a single query-document pair.
///
/// For each query token, finds the maximum similarity with any document token,
/// then sums across all query tokens.
///
/// Uses BLAS-accelerated matrix multiplication (when available) and SIMD for
/// the max reduction.
///
/// # Arguments
///
/// * `query` - Query embeddings of shape `[num_query_tokens, dim]`
/// * `doc` - Document embeddings of shape `[num_doc_tokens, dim]`
///
/// # Returns
///
/// The MaxSim score (sum of per-query-token max similarities)
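///
/// # Example
///
/// A minimal sketch of calling this function; the `use` path below assumes the
/// module is exposed as `next_plaid::maxsim`, so adjust it to the actual crate
/// layout (the example is marked `ignore` for that reason).
///
/// ```ignore
/// use ndarray::array;
/// use next_plaid::maxsim::maxsim_score;
///
/// // Two query tokens and two document tokens, dim 2.
/// let query = array![[1.0_f32, 0.0], [0.0, 1.0]];
/// let doc = array![[0.6_f32, 0.8], [1.0, 0.0]];
///
/// // Per-query-token maxima: max(0.6, 1.0) = 1.0 and max(0.8, 0.0) = 0.8.
/// let score = maxsim_score(&query.view(), &doc.view());
/// assert!((score - 1.8).abs() < 1e-6);
/// ```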
#[inline]
pub fn maxsim_score(query: &ArrayView2<f32>, doc: &ArrayView2<f32>) -> f32 {
    let q_len = query.nrows();
    let d_len = doc.nrows();

    // For small matrices, use simple approach to avoid GEMM overhead
    if q_len * d_len < 256 {
        return maxsim_score_simple(query, doc);
    }

    // Compute similarity matrix using BLAS-accelerated dot product
    // scores[i, j] = query[i] · doc[j]
    let scores = query.dot(&doc.t());

    // Find max per query token and sum using SIMD
    let mut total = 0.0f32;
    for q_idx in 0..q_len {
        let row = scores.row(q_idx);
        let max_sim = simd::simd_max(row.as_slice().unwrap());
        if max_sim > f32::NEG_INFINITY {
            total += max_sim;
        }
    }

    total
}

/// Simple MaxSim implementation for small matrices.
#[inline]
fn maxsim_score_simple(query: &ArrayView2<f32>, doc: &ArrayView2<f32>) -> f32 {
    let mut total = 0.0f32;

    for q_row in query.axis_iter(Axis(0)) {
        let mut max_sim = f32::NEG_INFINITY;
        for d_row in doc.axis_iter(Axis(0)) {
            let sim: f32 = q_row.dot(&d_row);
            if sim > max_sim {
                max_sim = sim;
            }
        }
        if max_sim > f32::NEG_INFINITY {
            total += max_sim;
        }
    }

    total
}

/// Assign embeddings to their nearest centroids using batched GEMM.
///
/// Uses batched matrix multiplication to compute all similarities at once,
/// with SIMD-accelerated argmax for finding the best centroid per embedding.
///
/// # Arguments
///
/// * `embeddings` - Embeddings to assign, shape `[N, dim]`
/// * `centroids` - Centroids, shape `[K, dim]`
///
/// # Returns
///
/// Vector of centroid indices, one per embedding
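///
/// # Example
///
/// A minimal sketch; as above, the `use` path assumes the module is exposed as
/// `next_plaid::maxsim`, and the example is marked `ignore` because that path
/// is an assumption.
///
/// ```ignore
/// use ndarray::array;
/// use next_plaid::maxsim::assign_to_centroids;
///
/// let centroids = array![[1.0_f32, 0.0], [0.0, 1.0]];
/// let embeddings = array![[0.9_f32, 0.1], [0.2, 0.8]];
///
/// // Each embedding is assigned to the centroid with the highest dot product.
/// let codes = assign_to_centroids(&embeddings.view(), &centroids.view());
/// assert_eq!(codes, vec![0, 1]);
/// ```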
pub fn assign_to_centroids(
    embeddings: &ArrayView2<f32>,
    centroids: &ArrayView2<f32>,
) -> Vec<usize> {
    let n = embeddings.nrows();
    let k = centroids.nrows();

    if n == 0 || k == 0 {
        return vec![0; n];
    }

    // For small inputs, use simple approach
    if n * k < 1024 {
        return embeddings
            .axis_iter(Axis(0))
            .map(|emb| {
                let mut best_idx = 0;
                let mut best_score = f32::NEG_INFINITY;
                for (idx, centroid) in centroids.axis_iter(Axis(0)).enumerate() {
                    let score: f32 = emb.iter().zip(centroid.iter()).map(|(a, b)| a * b).sum();
                    if score > best_score {
                        best_score = score;
                        best_idx = idx;
                    }
                }
                best_idx
            })
            .collect();
    }

    // Batched approach: compute [N, K] score matrix, then argmax per row.
    // Cap the batch so each [batch, K] scores matrix stays within a ~2 GB
    // budget, then clamp to at most 4096 rows per batch.
    let max_batch_by_memory = (2 * 1024 * 1024 * 1024) / (k * std::mem::size_of::<f32>());
    let batch_size = max_batch_by_memory.clamp(1, 4096).min(n);

    let mut all_codes = Vec::with_capacity(n);

    for start in (0..n).step_by(batch_size) {
        let end = (start + batch_size).min(n);
        let batch = embeddings.slice(ndarray::s![start..end, ..]);

        // Batch matrix multiplication: [batch, dim] @ [dim, K] -> [batch, K]
        let scores = batch.dot(&centroids.t());

        // Parallel argmax over each row using SIMD
        let batch_codes: Vec<usize> = scores
            .axis_iter(Axis(0))
            .into_par_iter()
            .map(|row| simd::simd_argmax(row.as_slice().unwrap()))
            .collect();

        all_codes.extend(batch_codes);
    }

    all_codes
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    #[test]
    fn test_maxsim_score_basic() {
        // Query with 2 tokens, dim 4
        let query =
            Array2::from_shape_vec((2, 4), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();

        // Document with 3 tokens
        let doc = Array2::from_shape_vec(
            (3, 4),
            vec![
                0.5, 0.5, 0.0, 0.0, // sim with q0: 0.5, sim with q1: 0.5
                0.8, 0.2, 0.0, 0.0, // sim with q0: 0.8, sim with q1: 0.2
                0.0, 0.9, 0.1, 0.0, // sim with q0: 0.0, sim with q1: 0.9
            ],
        )
        .unwrap();

        let score = maxsim_score(&query.view(), &doc.view());
        // q0 max: 0.8 (from token 1), q1 max: 0.9 (from token 2)
        // Total: 0.8 + 0.9 = 1.7
        assert!((score - 1.7).abs() < 1e-5);
    }

    #[test]
    fn test_simd_max() {
        let data: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let max = simd::simd_max(&data);
        assert!((max - 99.0).abs() < 1e-5);

        // Test with negative values
        let data2: Vec<f32> = (-50..50).map(|i| i as f32).collect();
        let max2 = simd::simd_max(&data2);
        assert!((max2 - 49.0).abs() < 1e-5);

        // Test small slice
        let small = vec![1.0, 5.0, 3.0];
        let max3 = simd::simd_max(&small);
        assert!((max3 - 5.0).abs() < 1e-5);
    }

    #[test]
    fn test_assign_to_centroids() {
        // 3 centroids in 4D space
        let centroids = Array2::from_shape_vec(
            (3, 4),
            vec![
                1.0, 0.0, 0.0, 0.0, // centroid 0: points in +x direction
                0.0, 1.0, 0.0, 0.0, // centroid 1: points in +y direction
                0.0, 0.0, 1.0, 0.0, // centroid 2: points in +z direction
            ],
        )
        .unwrap();

        // 5 embeddings
        let embeddings = Array2::from_shape_vec(
            (5, 4),
            vec![
                0.9, 0.1, 0.0, 0.0, // should match centroid 0
                0.1, 0.9, 0.0, 0.0, // should match centroid 1
                0.0, 0.1, 0.9, 0.0, // should match centroid 2
                0.8, 0.2, 0.0, 0.0, // should match centroid 0
                0.0, 0.0, 0.8, 0.2, // should match centroid 2
            ],
        )
        .unwrap();

        let assignments = assign_to_centroids(&embeddings.view(), &centroids.view());

        assert_eq!(assignments.len(), 5);
        assert_eq!(assignments[0], 0);
        assert_eq!(assignments[1], 1);
        assert_eq!(assignments[2], 2);
        assert_eq!(assignments[3], 0);
        assert_eq!(assignments[4], 2);
    }

    #[test]
    fn test_simd_argmax() {
        let data: Vec<f32> = vec![1.0, 5.0, 3.0, 2.0, 4.0];
        assert_eq!(simd::simd_argmax(&data), 1);

        let data2: Vec<f32> = (0..100).map(|i| i as f32).collect();
        assert_eq!(simd::simd_argmax(&data2), 99);

        let data3: Vec<f32> = (0..100).rev().map(|i| i as f32).collect();
        assert_eq!(simd::simd_argmax(&data3), 0);
    }
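
    // The two tests below are additional cross-checks (a sketch of extra
    // coverage, not adapted from the upstream maxsim-cpu suite): the GEMM +
    // SIMD fast path should agree with the simple scalar path once
    // `q_len * d_len >= 256`, and `simd_argmax` should report the first
    // occurrence when the maximum value is repeated.

    #[test]
    fn test_maxsim_fast_path_matches_simple() {
        // 16 query tokens x 64 doc tokens = 1024 pairs, enough to take the
        // batched GEMM + SIMD branch in `maxsim_score`.
        let query = Array2::from_shape_fn((16, 8), |(i, j)| ((i * 7 + j * 3) % 13) as f32 * 0.1);
        let doc = Array2::from_shape_fn((64, 8), |(i, j)| ((i * 5 + j * 11) % 17) as f32 * 0.1);

        let fast = maxsim_score(&query.view(), &doc.view());
        let simple = maxsim_score_simple(&query.view(), &doc.view());
        assert!((fast - simple).abs() < 1e-3);
    }

    #[test]
    fn test_simd_argmax_first_occurrence() {
        // The maximum (9.0) appears twice; the scan should report index 2.
        let mut data = vec![0.0f32; 100];
        data[2] = 9.0;
        data[50] = 9.0;
        assert_eq!(simd::simd_argmax(&data), 2);
    }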
}