Skip to main content

brainwires_rag/spectral/
mod.rs

1//! MSS-inspired spectral subset selection for diverse RAG retrieval.
2//!
3//! Standard top-k retrieval by cosine similarity produces redundant results.
4//! This module implements a greedy log-determinant maximization algorithm
5//! (inspired by DPP / Marcus-Spielman-Srivastava interlacing polynomials)
6//! that selects k items that are both relevant AND collectively diverse.
7//!
8//! # Algorithm
9//!
10//! Given n candidate embeddings with relevance scores, we build a kernel matrix:
11//! ```text
12//! L_ij = (r_i^lambda) * (r_j^lambda) * cos_sim(v_i, v_j)
13//! ```
14//! and greedily select the subset S of size k that maximizes `log det(L_S)`.
15//!
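//! The objective `log det(L_S)` is submodular. When it is also monotone and
//! non-negative, greedy selection achieves the classic (1 - 1/e) ~ 0.63
//! approximation ratio, which is the best possible in polynomial time for
//! monotone submodular maximization under a cardinality constraint. In general
//! `log det(L_S)` can be negative and can shrink as items are added, so for
//! arbitrary kernels the greedy result is a strong heuristic rather than a
//! guaranteed approximation.
//!
//! # Complexity
//!
//! With incremental Cholesky updates, evaluating one candidate costs O(k^2)
//! per round instead of a full determinant computation. Across k rounds over
//! n candidates that gives O(n*k^3) overall, which is trivial for n <= 200,
//! k <= 20 (typical RAG retrieval sizes).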
25
26pub mod graph_ops;
27pub mod kernel;
28pub mod linalg;
29
30use brainwires_core::SearchResult;
31use kernel::{build_kernel_matrix, cross_column};
32use linalg::{cholesky_extend, log_det_incremental};
33use ndarray::Array2;
34
/// Tuning knobs for [`SpectralReranker`].
#[derive(Debug, Clone)]
pub struct SpectralSelectConfig {
    /// How many items to pick; `None` defers to the query limit.
    pub k: Option<usize>,
    /// Balance between relevance and diversity.
    /// - 0.0 = diversity only (relevance scores are ignored)
    /// - 0.5 = even balance (default)
    /// - 1.0 = relevance dominates (approaches standard top-k)
    pub lambda: f32,
    /// Smallest candidate count for which spectral selection runs.
    /// With fewer candidates, the first `k` results are returned in their
    /// incoming order.
    pub min_candidates: usize,
    /// Epsilon used to regularize the kernel diagonal for numerical stability.
    pub regularization: f32,
}

impl Default for SpectralSelectConfig {
    /// `k = None`, `lambda = 0.5`, `min_candidates = 10`, `regularization = 1e-6`.
    fn default() -> Self {
        let lambda = 0.5;
        let min_candidates = 10;
        let regularization = 1e-6;
        Self {
            k: None,
            lambda,
            min_candidates,
            regularization,
        }
    }
}
62
63/// Trait for diversity-aware reranking of search results.
64pub trait DiversityReranker: Send + Sync {
65    /// Rerank candidates, returning indices into `results` in selection order.
66    ///
67    /// # Arguments
68    ///
69    /// * `results` - Original search results with scores
70    /// * `embeddings` - Embedding vectors corresponding to each result
71    /// * `k` - Number of items to select
72    ///
73    /// # Returns
74    ///
75    /// Indices into `results`, ordered by selection round (first selected = most valuable).
76    fn rerank(&self, results: &[SearchResult], embeddings: &[Vec<f32>], k: usize) -> Vec<usize>;
77}
78
79/// Spectral reranker using greedy log-determinant maximization.
80pub struct SpectralReranker {
81    config: SpectralSelectConfig,
82}
83
84impl SpectralReranker {
85    /// Create a new spectral reranker with the given configuration.
86    pub fn new(config: SpectralSelectConfig) -> Self {
87        Self { config }
88    }
89
90    /// Create a spectral reranker with default settings.
91    pub fn with_defaults() -> Self {
92        Self::new(SpectralSelectConfig::default())
93    }
94}
95
96impl DiversityReranker for SpectralReranker {
97    fn rerank(&self, results: &[SearchResult], embeddings: &[Vec<f32>], k: usize) -> Vec<usize> {
98        let n = results.len();
99
100        // Edge cases
101        if n == 0 {
102            return Vec::new();
103        }
104        if k >= n {
105            return (0..n).collect();
106        }
107        if k == 0 {
108            return Vec::new();
109        }
110
111        // Skip spectral selection if too few candidates
112        if n < self.config.min_candidates {
113            return (0..k.min(n)).collect();
114        }
115
116        // Build kernel matrix
117        let embedding_refs: Vec<&[f32]> = embeddings.iter().map(|e| e.as_slice()).collect();
118        let scores: Vec<f32> = results.iter().map(|r| r.score).collect();
119        let kernel = build_kernel_matrix(
120            &embedding_refs,
121            &scores,
122            self.config.lambda,
123            self.config.regularization,
124        );
125
126        greedy_log_det_select(&kernel, k)
127    }
128}
129
130/// Greedy log-determinant maximization with incremental Cholesky updates.
131///
132/// Selects k indices from the n*n kernel matrix that (approximately) maximize
133/// `log det(L_S)`, achieving a (1-1/e) approximation ratio.
134fn greedy_log_det_select(kernel: &Array2<f32>, k: usize) -> Vec<usize> {
135    let n = kernel.nrows();
136    let mut selected: Vec<usize> = Vec::with_capacity(k);
137    let mut remaining: Vec<bool> = vec![true; n];
138
139    // Current Cholesky factor of the selected submatrix (starts empty)
140    let mut chol_s: Option<Array2<f32>> = None;
141    let mut current_log_det: f32 = 0.0;
142
143    for round in 0..k {
144        let mut best_idx = usize::MAX;
145        let mut best_gain = f32::NEG_INFINITY;
146
147        for c in 0..n {
148            if !remaining[c] {
149                continue;
150            }
151
152            let gain = if round == 0 {
153                // First selection: gain = log(L_{c,c})
154                let diag = kernel[[c, c]];
155                if diag > 0.0 {
156                    diag.ln()
157                } else {
158                    f32::NEG_INFINITY
159                }
160            } else {
161                // Incremental gain via Cholesky
162                let cross = cross_column(kernel, &selected, c);
163                let diag_cc = kernel[[c, c]];
164                let new_ld = log_det_incremental(
165                    chol_s.as_ref().expect(
166                        "chol_s is initialized in round 0 before any incremental round runs",
167                    ),
168                    &cross,
169                    diag_cc,
170                    current_log_det,
171                );
172                new_ld - current_log_det
173            };
174
175            if gain > best_gain {
176                best_gain = gain;
177                best_idx = c;
178            }
179        }
180
181        if best_idx == usize::MAX || best_gain == f32::NEG_INFINITY {
182            // No more valid candidates (degenerate kernel)
183            break;
184        }
185
186        // Update Cholesky factor
187        if round == 0 {
188            let diag = kernel[[best_idx, best_idx]];
189            let mut l = Array2::<f32>::zeros((1, 1));
190            l[[0, 0]] = diag.sqrt();
191            chol_s = Some(l);
192            current_log_det = diag.ln();
193        } else {
194            let cross = cross_column(kernel, &selected, best_idx);
195            let diag_cc = kernel[[best_idx, best_idx]];
196            chol_s = Some(
197                cholesky_extend(
198                    chol_s.as_ref().expect(
199                        "chol_s is initialized in round 0 before any incremental round runs",
200                    ),
201                    &cross,
202                    diag_cc,
203                )
204                .expect("Cholesky extend failed after positive gain check"),
205            );
206            current_log_det += best_gain;
207        }
208
209        selected.push(best_idx);
210        remaining[best_idx] = false;
211    }
212
213    selected
214}
215
216// ── Cross-encoder reranker ────────────────────────────────────────────────
217
/// Settings for the query-aware [`CrossEncoderReranker`].
#[derive(Debug, Clone)]
pub struct CrossEncoderConfig {
    /// Weight given to the original retrieval score, as opposed to the
    /// query-document cosine similarity, when the two are blended.
    ///
    /// - `1.0` → original retrieval score only (no re-ranking)
    /// - `0.5` → both signals weighted equally (default)
    /// - `0.0` → query-document cosine similarity only
    pub alpha: f32,
    /// Query embedding that forms the "query" half of the joint score.
    ///
    /// When empty, the reranker orders results by their original score, as if
    /// `alpha` were `1.0`.
    pub query_embedding: Vec<f32>,
}

impl Default for CrossEncoderConfig {
    /// `alpha = 0.5` with no query embedding.
    fn default() -> Self {
        let query_embedding = Vec::new();
        Self {
            alpha: 0.5,
            query_embedding,
        }
    }
}
242
243/// Query-aware reranker that blends the original retrieval score with a
244/// query-document cosine similarity for a joint re-scoring pass.
245///
246/// This is a lightweight, embedding-based approximation of a true cross-encoder
247/// that requires no additional model — it reuses the same embeddings already
248/// computed during retrieval.
249pub struct CrossEncoderReranker {
250    config: CrossEncoderConfig,
251}
252
253impl CrossEncoderReranker {
254    /// Create a new cross-encoder reranker with the given configuration.
255    pub fn new(config: CrossEncoderConfig) -> Self {
256        Self { config }
257    }
258
259    /// Convenience constructor — specify alpha and query embedding directly.
260    pub fn with_alpha(alpha: f32, query_embedding: Vec<f32>) -> Self {
261        Self::new(CrossEncoderConfig {
262            alpha,
263            query_embedding,
264        })
265    }
266}
267
268impl DiversityReranker for CrossEncoderReranker {
269    fn rerank(&self, results: &[SearchResult], embeddings: &[Vec<f32>], k: usize) -> Vec<usize> {
270        let n = results.len();
271        if n == 0 || k == 0 {
272            return Vec::new();
273        }
274        if k >= n {
275            return (0..n).collect();
276        }
277
278        // If no query embedding, fall back to score-descending order.
279        if self.config.query_embedding.is_empty() {
280            let mut indices: Vec<usize> = (0..n).collect();
281            indices.sort_by(|&a, &b| {
282                results[b]
283                    .score
284                    .partial_cmp(&results[a].score)
285                    .unwrap_or(std::cmp::Ordering::Equal)
286            });
287            return indices.into_iter().take(k).collect();
288        }
289
290        let query_emb = &self.config.query_embedding;
291        let alpha = self.config.alpha.clamp(0.0, 1.0);
292
293        let mut scored: Vec<(usize, f32)> = (0..n)
294            .map(|i| {
295                let cos = if i < embeddings.len() {
296                    kernel::cosine_similarity(query_emb, &embeddings[i])
297                } else {
298                    0.0
299                };
300                let joint = alpha * results[i].score + (1.0 - alpha) * cos;
301                (i, joint)
302            })
303            .collect();
304
305        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
306        scored.into_iter().take(k).map(|(i, _)| i).collect()
307    }
308}
309
310/// Select which reranker(s) to apply in [`crate::rag::client::RagClient::query_diverse`].
311pub enum RerankerKind {
312    /// Greedy log-determinant spectral reranker (diversity-focused).
313    Spectral(SpectralSelectConfig),
314    /// Query-aware cross-encoder reranker (relevance-focused).
315    CrossEncoder(CrossEncoderConfig),
316    /// Apply spectral reranking first (for diversity), then cross-encoder on
317    /// the selected subset (for final relevance ordering).
318    Both {
319        /// Config for the spectral (first) pass.
320        spectral: SpectralSelectConfig,
321        /// Config for the cross-encoder (second) pass.
322        cross_encoder: CrossEncoderConfig,
323    },
324}
325
#[cfg(test)]
mod tests {
    // Unit tests for `SpectralReranker` (greedy log-det selection) and
    // `CrossEncoderReranker` (score / query-similarity blending).
    use super::*;

    /// Builds a `SearchResult` whose `score` and `vector_score` are both
    /// `score`; every other field is an empty/zero placeholder.
    fn make_result(score: f32) -> SearchResult {
        SearchResult {
            file_path: String::new(),
            root_path: None,
            content: String::new(),
            score,
            vector_score: score,
            keyword_score: None,
            start_line: 0,
            end_line: 0,
            language: String::new(),
            project: None,
            indexed_at: 0,
        }
    }

    #[test]
    fn test_empty_input() {
        let reranker = SpectralReranker::with_defaults();
        let result = reranker.rerank(&[], &[], 5);
        assert!(result.is_empty());
    }

    #[test]
    fn test_k_zero() {
        let reranker = SpectralReranker::with_defaults();
        let results = vec![make_result(0.9)];
        let embeddings = vec![vec![1.0, 0.0]];
        let result = reranker.rerank(&results, &embeddings, 0);
        assert!(result.is_empty());
    }

    #[test]
    fn test_k_greater_than_n() {
        // k >= n short-circuits to returning every index.
        let reranker = SpectralReranker::with_defaults();
        let results = vec![make_result(0.9), make_result(0.8)];
        let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let result = reranker.rerank(&results, &embeddings, 10);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_below_min_candidates() {
        let config = SpectralSelectConfig {
            min_candidates: 20,
            ..Default::default()
        };
        let reranker = SpectralReranker::new(config);
        let results: Vec<SearchResult> =
            (0..5).map(|i| make_result(0.9 - i as f32 * 0.1)).collect();
        // Embedding 0 is the zero vector; that is harmless here because the
        // kernel is never built when n < min_candidates.
        let embeddings: Vec<Vec<f32>> = (0..5).map(|i| vec![i as f32, 0.0]).collect();
        let result = reranker.rerank(&results, &embeddings, 3);
        // Should return first 3 indices unchanged
        assert_eq!(result, vec![0, 1, 2]);
    }

    #[test]
    fn test_spectral_prefers_diverse() {
        // Create 10 near-duplicate vectors + 5 diverse vectors
        // The spectral reranker should prefer the diverse ones
        let mut results = Vec::new();
        let mut embeddings = Vec::new();

        // 10 near-duplicates (high score, very similar embeddings)
        for i in 0..10 {
            results.push(make_result(0.95));
            let mut emb = vec![1.0, 0.0, 0.0, 0.0, 0.0];
            emb[0] += i as f32 * 0.01; // tiny variation
            embeddings.push(emb);
        }

        // 5 diverse vectors (slightly lower score, orthogonal embeddings)
        let diverse_dirs = [
            vec![0.0, 1.0, 0.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 0.0, 1.0, 0.0],
            vec![0.0, 0.0, 0.0, 0.0, 1.0],
            vec![0.5, 0.5, 0.5, 0.0, 0.0],
        ];
        for dir in &diverse_dirs {
            results.push(make_result(0.85));
            embeddings.push(dir.clone());
        }

        let reranker = SpectralReranker::new(SpectralSelectConfig {
            min_candidates: 5,
            lambda: 0.3, // favor diversity
            ..Default::default()
        });

        let selected = reranker.rerank(&results, &embeddings, 5);
        assert_eq!(selected.len(), 5);

        // Count how many of the selected are from the diverse set (indices 10-14)
        let diverse_count = selected.iter().filter(|&&idx| idx >= 10).count();
        // With lambda=0.3 (diversity-favoring), we should pick at least 3 diverse items
        assert!(
            diverse_count >= 3,
            "Expected at least 3 diverse items, got {}. Selected: {:?}",
            diverse_count,
            selected
        );
    }

    #[test]
    fn test_lambda_one_approximates_topk() {
        // With lambda=1.0, relevance dominates -- should approximate top-k by score
        let mut results = Vec::new();
        let mut embeddings = Vec::new();

        for i in 0..15 {
            let score = 1.0 - i as f32 * 0.05;
            results.push(make_result(score));
            // Even with diverse embeddings, high lambda should prefer high scores
            // (indices 10..15 reuse the directions of 0..5).
            let mut emb = vec![0.0; 10];
            emb[i % 10] = 1.0;
            embeddings.push(emb);
        }

        let reranker = SpectralReranker::new(SpectralSelectConfig {
            min_candidates: 5,
            lambda: 1.0,
            ..Default::default()
        });

        let selected = reranker.rerank(&results, &embeddings, 5);
        assert_eq!(selected.len(), 5);

        // The top 5 by score are indices 0..5 and have mutually orthogonal
        // embeddings, so with lambda=1.0 the picks should stay among the
        // highest-scoring items. The assertion allows a little slack
        // (idx < 7) rather than requiring exactly indices 0..5.
        for &idx in &selected {
            assert!(
                idx < 7,
                "Expected top items, got index {}. Selected: {:?}",
                idx,
                selected
            );
        }
    }

    #[test]
    fn test_k_equals_one() {
        // k=1 should pick the single best item (highest diagonal = highest score * self-sim)
        let results = vec![make_result(0.5), make_result(0.9), make_result(0.7)];
        let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];

        let reranker = SpectralReranker::new(SpectralSelectConfig {
            min_candidates: 2,
            ..Default::default()
        });

        let selected = reranker.rerank(&results, &embeddings, 1);
        assert_eq!(selected.len(), 1);
        // Index 1 has highest score (0.9), should be selected
        assert_eq!(selected[0], 1);
    }

    #[test]
    fn test_greedy_determinism() {
        // Same input should always produce same output
        let results: Vec<SearchResult> = (0..12)
            .map(|i| make_result(0.9 - i as f32 * 0.05))
            .collect();
        let embeddings: Vec<Vec<f32>> = (0..12)
            .map(|i| {
                let mut e = vec![0.0; 5];
                e[i % 5] = 1.0;
                e
            })
            .collect();

        let reranker = SpectralReranker::new(SpectralSelectConfig {
            min_candidates: 5,
            ..Default::default()
        });

        let r1 = reranker.rerank(&results, &embeddings, 4);
        let r2 = reranker.rerank(&results, &embeddings, 4);
        assert_eq!(r1, r2);
    }

    #[test]
    fn test_performance_200_candidates() {
        // 200 candidates, 384-dim (all-MiniLM-L6-v2), k=20 should complete quickly
        // NOTE(review): this is a wall-clock bound; it may be flaky on slow or
        // heavily loaded CI runners or unoptimized builds — confirm the budget.
        let n = 200;
        let dim = 384;
        let k = 20;

        let results: Vec<SearchResult> = (0..n)
            .map(|i| make_result(1.0 - i as f32 / n as f32))
            .collect();

        // Create pseudo-random embeddings deterministically
        let embeddings: Vec<Vec<f32>> = (0..n)
            .map(|i| {
                (0..dim)
                    .map(|j| ((i * 7 + j * 13) % 100) as f32 / 100.0)
                    .collect()
            })
            .collect();

        let reranker = SpectralReranker::new(SpectralSelectConfig {
            min_candidates: 5,
            ..Default::default()
        });

        let start = std::time::Instant::now();
        let selected = reranker.rerank(&results, &embeddings, k);
        let elapsed = start.elapsed();

        assert_eq!(selected.len(), k);
        assert!(
            elapsed.as_millis() < 500,
            "Performance test: took {}ms, expected <500ms",
            elapsed.as_millis()
        );
    }

    // ── CrossEncoderReranker tests ────────────────────────────────────────

    #[test]
    fn test_cross_encoder_empty_input() {
        let r = CrossEncoderReranker::with_alpha(0.5, vec![1.0, 0.0]);
        assert!(r.rerank(&[], &[], 5).is_empty());
    }

    #[test]
    fn test_cross_encoder_k_zero() {
        let r = CrossEncoderReranker::with_alpha(0.5, vec![1.0, 0.0]);
        let results = vec![make_result(0.9)];
        let embeddings = vec![vec![1.0, 0.0]];
        assert!(r.rerank(&results, &embeddings, 0).is_empty());
    }

    #[test]
    fn test_cross_encoder_pure_cosine_alpha_zero() {
        // alpha=0.0 → pure cosine similarity.
        // query = [1, 0]; doc0 = [1, 0] (cos=1.0); doc1 = [0, 1] (cos=0.0)
        let query_emb = vec![1.0_f32, 0.0];
        let r = CrossEncoderReranker::with_alpha(0.0, query_emb);

        let results = vec![make_result(0.5), make_result(0.9)]; // doc1 has higher original score
        let embeddings = vec![vec![1.0_f32, 0.0], vec![0.0_f32, 1.0]]; // doc0 aligned, doc1 orthogonal

        let selected = r.rerank(&results, &embeddings, 2);
        // doc0 should rank first (cos=1.0 > cos=0.0)
        assert_eq!(selected[0], 0);
    }

    #[test]
    fn test_cross_encoder_pure_original_alpha_one() {
        // alpha=1.0 → use original scores unchanged.
        let r = CrossEncoderReranker::with_alpha(1.0, vec![1.0, 0.0]);
        let results = vec![make_result(0.3), make_result(0.9), make_result(0.6)];
        let embeddings = vec![vec![0.0_f32, 1.0]; 3];
        let selected = r.rerank(&results, &embeddings, 2);
        // Should be score-descending: indices 1, 2
        assert_eq!(selected[0], 1); // score 0.9
        assert_eq!(selected[1], 2); // score 0.6
    }

    #[test]
    fn test_cross_encoder_blend_changes_ranking() {
        // With alpha=0.5 and a query aligned to doc0, doc0 should beat doc1
        // even though doc1 has a higher original score.
        let query_emb = vec![1.0_f32, 0.0];
        let r = CrossEncoderReranker::with_alpha(0.5, query_emb);
        // doc0: score=0.3, cos=1.0  → joint = 0.5*0.3 + 0.5*1.0 = 0.65
        // doc1: score=0.9, cos=0.0  → joint = 0.5*0.9 + 0.5*0.0 = 0.45
        let results = vec![make_result(0.3), make_result(0.9)];
        let embeddings = vec![vec![1.0_f32, 0.0], vec![0.0_f32, 1.0]];
        let selected = r.rerank(&results, &embeddings, 2);
        assert_eq!(selected[0], 0); // doc0 wins with blend
    }

    #[test]
    fn test_cross_encoder_empty_query_embedding_falls_back_to_score_order() {
        // An empty query embedding means ranking by original score only.
        let r = CrossEncoderReranker::with_alpha(0.5, Vec::new());
        let results = vec![make_result(0.3), make_result(0.9), make_result(0.6)];
        let embeddings = vec![vec![1.0_f32, 0.0]; 3];
        let selected = r.rerank(&results, &embeddings, 2);
        assert_eq!(selected[0], 1); // highest original score
    }

    #[test]
    fn test_cross_encoder_k_gte_n_returns_all() {
        // Only the count is checked here, not the order.
        let r = CrossEncoderReranker::with_alpha(0.5, vec![1.0, 0.0]);
        let results = vec![make_result(0.8), make_result(0.5)];
        let embeddings = vec![vec![1.0_f32, 0.0]; 2];
        let selected = r.rerank(&results, &embeddings, 10);
        assert_eq!(selected.len(), 2);
    }
}