1pub mod graph_ops;
27pub mod kernel;
28pub mod linalg;
29
30use brainwires_core::SearchResult;
31use kernel::{build_kernel_matrix, cross_column};
32use linalg::{cholesky_extend, log_det_incremental};
33use ndarray::Array2;
34
/// Tunables for [`SpectralReranker`].
#[derive(Clone, Debug)]
pub struct SpectralSelectConfig {
    /// Optional selection size.
    ///
    /// NOTE(review): nothing in this file reads this field; `rerank` always
    /// uses the `k` argument it is given. Presumably callers consult it when
    /// deciding how many items to request — confirm against call sites.
    pub k: Option<usize>,

    /// Forwarded unchanged to `kernel::build_kernel_matrix` as its `lambda`
    /// argument.
    ///
    /// NOTE(review): presumably the relevance-vs-diversity trade-off (the
    /// tests in this file expect `1.0` to behave close to plain top-k) —
    /// confirm against the kernel module.
    pub lambda: f32,

    /// Candidate-count threshold: when fewer results than this are supplied,
    /// `SpectralReranker::rerank` skips spectral selection and returns the
    /// leading `k` indices.
    pub min_candidates: usize,

    /// Forwarded unchanged to `kernel::build_kernel_matrix` as its
    /// `regularization` argument.
    ///
    /// NOTE(review): presumably a small diagonal term for numerical
    /// stability — confirm against the kernel module.
    pub regularization: f32,
}
51
52impl Default for SpectralSelectConfig {
53 fn default() -> Self {
54 Self {
55 k: None,
56 lambda: 0.5,
57 min_candidates: 10,
58 regularization: 1e-6,
59 }
60 }
61}
62
63pub trait DiversityReranker: Send + Sync {
65 fn rerank(&self, results: &[SearchResult], embeddings: &[Vec<f32>], k: usize) -> Vec<usize>;
77}
78
79pub struct SpectralReranker {
81 config: SpectralSelectConfig,
82}
83
84impl SpectralReranker {
85 pub fn new(config: SpectralSelectConfig) -> Self {
87 Self { config }
88 }
89
90 pub fn with_defaults() -> Self {
92 Self::new(SpectralSelectConfig::default())
93 }
94}
95
96impl DiversityReranker for SpectralReranker {
97 fn rerank(&self, results: &[SearchResult], embeddings: &[Vec<f32>], k: usize) -> Vec<usize> {
98 let n = results.len();
99
100 if n == 0 {
102 return Vec::new();
103 }
104 if k >= n {
105 return (0..n).collect();
106 }
107 if k == 0 {
108 return Vec::new();
109 }
110
111 if n < self.config.min_candidates {
113 return (0..k.min(n)).collect();
114 }
115
116 let embedding_refs: Vec<&[f32]> = embeddings.iter().map(|e| e.as_slice()).collect();
118 let scores: Vec<f32> = results.iter().map(|r| r.score).collect();
119 let kernel = build_kernel_matrix(
120 &embedding_refs,
121 &scores,
122 self.config.lambda,
123 self.config.regularization,
124 );
125
126 greedy_log_det_select(&kernel, k)
127 }
128}
129
130fn greedy_log_det_select(kernel: &Array2<f32>, k: usize) -> Vec<usize> {
135 let n = kernel.nrows();
136 let mut selected: Vec<usize> = Vec::with_capacity(k);
137 let mut remaining: Vec<bool> = vec![true; n];
138
139 let mut chol_s: Option<Array2<f32>> = None;
141 let mut current_log_det: f32 = 0.0;
142
143 for round in 0..k {
144 let mut best_idx = usize::MAX;
145 let mut best_gain = f32::NEG_INFINITY;
146
147 for c in 0..n {
148 if !remaining[c] {
149 continue;
150 }
151
152 let gain = if round == 0 {
153 let diag = kernel[[c, c]];
155 if diag > 0.0 {
156 diag.ln()
157 } else {
158 f32::NEG_INFINITY
159 }
160 } else {
161 let cross = cross_column(kernel, &selected, c);
163 let diag_cc = kernel[[c, c]];
164 let new_ld = log_det_incremental(
165 chol_s.as_ref().expect(
166 "chol_s is initialized in round 0 before any incremental round runs",
167 ),
168 &cross,
169 diag_cc,
170 current_log_det,
171 );
172 new_ld - current_log_det
173 };
174
175 if gain > best_gain {
176 best_gain = gain;
177 best_idx = c;
178 }
179 }
180
181 if best_idx == usize::MAX || best_gain == f32::NEG_INFINITY {
182 break;
184 }
185
186 if round == 0 {
188 let diag = kernel[[best_idx, best_idx]];
189 let mut l = Array2::<f32>::zeros((1, 1));
190 l[[0, 0]] = diag.sqrt();
191 chol_s = Some(l);
192 current_log_det = diag.ln();
193 } else {
194 let cross = cross_column(kernel, &selected, best_idx);
195 let diag_cc = kernel[[best_idx, best_idx]];
196 chol_s = Some(
197 cholesky_extend(
198 chol_s.as_ref().expect(
199 "chol_s is initialized in round 0 before any incremental round runs",
200 ),
201 &cross,
202 diag_cc,
203 )
204 .expect("Cholesky extend failed after positive gain check"),
205 );
206 current_log_det += best_gain;
207 }
208
209 selected.push(best_idx);
210 remaining[best_idx] = false;
211 }
212
213 selected
214}
215
/// Settings for [`CrossEncoderReranker`].
#[derive(Clone, Debug)]
pub struct CrossEncoderConfig {
    /// Weight of the original result score in the blended score
    /// `alpha * score + (1 - alpha) * cosine`. `rerank` clamps it to
    /// `[0.0, 1.0]` before use.
    pub alpha: f32,

    /// Embedding each candidate embedding is compared against via cosine
    /// similarity. When empty, `rerank` orders candidates by their original
    /// score alone.
    pub query_embedding: Vec<f32>,
}
233
234impl Default for CrossEncoderConfig {
235 fn default() -> Self {
236 Self {
237 alpha: 0.5,
238 query_embedding: Vec::new(),
239 }
240 }
241}
242
243pub struct CrossEncoderReranker {
250 config: CrossEncoderConfig,
251}
252
253impl CrossEncoderReranker {
254 pub fn new(config: CrossEncoderConfig) -> Self {
256 Self { config }
257 }
258
259 pub fn with_alpha(alpha: f32, query_embedding: Vec<f32>) -> Self {
261 Self::new(CrossEncoderConfig {
262 alpha,
263 query_embedding,
264 })
265 }
266}
267
268impl DiversityReranker for CrossEncoderReranker {
269 fn rerank(&self, results: &[SearchResult], embeddings: &[Vec<f32>], k: usize) -> Vec<usize> {
270 let n = results.len();
271 if n == 0 || k == 0 {
272 return Vec::new();
273 }
274 if k >= n {
275 return (0..n).collect();
276 }
277
278 if self.config.query_embedding.is_empty() {
280 let mut indices: Vec<usize> = (0..n).collect();
281 indices.sort_by(|&a, &b| {
282 results[b]
283 .score
284 .partial_cmp(&results[a].score)
285 .unwrap_or(std::cmp::Ordering::Equal)
286 });
287 return indices.into_iter().take(k).collect();
288 }
289
290 let query_emb = &self.config.query_embedding;
291 let alpha = self.config.alpha.clamp(0.0, 1.0);
292
293 let mut scored: Vec<(usize, f32)> = (0..n)
294 .map(|i| {
295 let cos = if i < embeddings.len() {
296 kernel::cosine_similarity(query_emb, &embeddings[i])
297 } else {
298 0.0
299 };
300 let joint = alpha * results[i].score + (1.0 - alpha) * cos;
301 (i, joint)
302 })
303 .collect();
304
305 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
306 scored.into_iter().take(k).map(|(i, _)| i).collect()
307 }
308}
309
310pub enum RerankerKind {
312 Spectral(SpectralSelectConfig),
314 CrossEncoder(CrossEncoderConfig),
316 Both {
319 spectral: SpectralSelectConfig,
321 cross_encoder: CrossEncoderConfig,
323 },
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SearchResult` whose `score` and `vector_score` are both
    /// `score`; every other field is blank or zero.
    fn make_result(score: f32) -> SearchResult {
        SearchResult {
            file_path: String::new(),
            root_path: None,
            content: String::new(),
            score,
            vector_score: score,
            keyword_score: None,
            start_line: 0,
            end_line: 0,
            language: String::new(),
            project: None,
            indexed_at: 0,
        }
    }

    // ---- SpectralReranker ------------------------------------------------

    #[test]
    fn test_empty_input() {
        let reranker = SpectralReranker::with_defaults();
        let result = reranker.rerank(&[], &[], 5);
        assert!(result.is_empty());
    }

    #[test]
    fn test_k_zero() {
        let reranker = SpectralReranker::with_defaults();
        let results = vec![make_result(0.9)];
        let embeddings = vec![vec![1.0, 0.0]];
        let result = reranker.rerank(&results, &embeddings, 0);
        assert!(result.is_empty());
    }

    #[test]
    fn test_k_greater_than_n() {
        let reranker = SpectralReranker::with_defaults();
        let results = vec![make_result(0.9), make_result(0.8)];
        let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let result = reranker.rerank(&results, &embeddings, 10);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_below_min_candidates() {
        let config = SpectralSelectConfig {
            min_candidates: 20,
            ..Default::default()
        };
        let reranker = SpectralReranker::new(config);
        let results: Vec<SearchResult> =
            (0..5).map(|i| make_result(0.9 - i as f32 * 0.1)).collect();
        let embeddings: Vec<Vec<f32>> = (0..5).map(|i| vec![i as f32, 0.0]).collect();
        let result = reranker.rerank(&results, &embeddings, 3);
        // Below the threshold the leading `k` indices come back untouched.
        assert_eq!(result, vec![0, 1, 2]);
    }

    #[test]
    fn test_spectral_prefers_diverse() {
        let mut results = Vec::new();
        let mut embeddings = Vec::new();

        // Ten near-duplicate candidates with a high score.
        for i in 0..10 {
            results.push(make_result(0.95));
            let mut emb = vec![1.0, 0.0, 0.0, 0.0, 0.0];
            emb[0] += i as f32 * 0.01;
            embeddings.push(emb);
        }

        // Five candidates pointing in different directions, slightly lower score.
        let diverse_dirs = [
            vec![0.0, 1.0, 0.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 0.0, 1.0, 0.0],
            vec![0.0, 0.0, 0.0, 0.0, 1.0],
            vec![0.5, 0.5, 0.5, 0.0, 0.0],
        ];
        for dir in &diverse_dirs {
            results.push(make_result(0.85));
            embeddings.push(dir.clone());
        }

        let reranker = SpectralReranker::new(SpectralSelectConfig {
            min_candidates: 5,
            lambda: 0.3,
            ..Default::default()
        });

        let selected = reranker.rerank(&results, &embeddings, 5);
        assert_eq!(selected.len(), 5);

        let diverse_count = selected.iter().filter(|&&idx| idx >= 10).count();
        assert!(
            diverse_count >= 3,
            "Expected at least 3 diverse items, got {}. Selected: {:?}",
            diverse_count,
            selected
        );
    }

    #[test]
    fn test_lambda_one_approximates_topk() {
        let mut results = Vec::new();
        let mut embeddings = Vec::new();

        for i in 0..15 {
            let score = 1.0 - i as f32 * 0.05;
            results.push(make_result(score));
            let mut emb = vec![0.0; 10];
            emb[i % 10] = 1.0;
            embeddings.push(emb);
        }

        let reranker = SpectralReranker::new(SpectralSelectConfig {
            min_candidates: 5,
            lambda: 1.0,
            ..Default::default()
        });

        let selected = reranker.rerank(&results, &embeddings, 5);
        assert_eq!(selected.len(), 5);

        for &idx in &selected {
            assert!(
                idx < 7,
                "Expected top items, got index {}. Selected: {:?}",
                idx,
                selected
            );
        }
    }

    #[test]
    fn test_k_equals_one() {
        let results = vec![make_result(0.5), make_result(0.9), make_result(0.7)];
        let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];

        let reranker = SpectralReranker::new(SpectralSelectConfig {
            min_candidates: 2,
            ..Default::default()
        });

        let selected = reranker.rerank(&results, &embeddings, 1);
        assert_eq!(selected.len(), 1);
        assert_eq!(selected[0], 1);
    }

    #[test]
    fn test_greedy_determinism() {
        let results: Vec<SearchResult> = (0..12)
            .map(|i| make_result(0.9 - i as f32 * 0.05))
            .collect();
        let embeddings: Vec<Vec<f32>> = (0..12)
            .map(|i| {
                let mut e = vec![0.0; 5];
                e[i % 5] = 1.0;
                e
            })
            .collect();

        let reranker = SpectralReranker::new(SpectralSelectConfig {
            min_candidates: 5,
            ..Default::default()
        });

        let r1 = reranker.rerank(&results, &embeddings, 4);
        let r2 = reranker.rerank(&results, &embeddings, 4);
        assert_eq!(r1, r2);
    }

    // NOTE: this asserts on wall-clock time and can be sensitive to machine
    // load and build profile.
    #[test]
    fn test_performance_200_candidates() {
        let n = 200;
        let dim = 384;
        let k = 20;

        let results: Vec<SearchResult> = (0..n)
            .map(|i| make_result(1.0 - i as f32 / n as f32))
            .collect();

        let embeddings: Vec<Vec<f32>> = (0..n)
            .map(|i| {
                (0..dim)
                    .map(|j| ((i * 7 + j * 13) % 100) as f32 / 100.0)
                    .collect()
            })
            .collect();

        let reranker = SpectralReranker::new(SpectralSelectConfig {
            min_candidates: 5,
            ..Default::default()
        });

        let start = std::time::Instant::now();
        let selected = reranker.rerank(&results, &embeddings, k);
        let elapsed = start.elapsed();

        assert_eq!(selected.len(), k);
        assert!(
            elapsed.as_millis() < 500,
            "Performance test: took {}ms, expected <500ms",
            elapsed.as_millis()
        );
    }

    // ---- CrossEncoderReranker --------------------------------------------

    #[test]
    fn test_cross_encoder_empty_input() {
        let r = CrossEncoderReranker::with_alpha(0.5, vec![1.0, 0.0]);
        assert!(r.rerank(&[], &[], 5).is_empty());
    }

    #[test]
    fn test_cross_encoder_k_zero() {
        let r = CrossEncoderReranker::with_alpha(0.5, vec![1.0, 0.0]);
        let results = vec![make_result(0.9)];
        let embeddings = vec![vec![1.0, 0.0]];
        assert!(r.rerank(&results, &embeddings, 0).is_empty());
    }

    #[test]
    fn test_cross_encoder_pure_cosine_alpha_zero() {
        let query_emb = vec![1.0_f32, 0.0];
        let r = CrossEncoderReranker::with_alpha(0.0, query_emb);

        // Index 0: lower score, embedding aligned with the query.
        // Index 1: higher score, embedding orthogonal to the query.
        let results = vec![make_result(0.5), make_result(0.9)];
        let embeddings = vec![vec![1.0_f32, 0.0], vec![0.0_f32, 1.0]];

        // `k` must be smaller than the number of results: with `k >= n` the
        // reranker returns every index in input order without scoring, which
        // would make this assertion vacuous.
        let selected = r.rerank(&results, &embeddings, 1);
        assert_eq!(selected.len(), 1);
        assert_eq!(selected[0], 0);
    }

    #[test]
    fn test_cross_encoder_pure_original_alpha_one() {
        let r = CrossEncoderReranker::with_alpha(1.0, vec![1.0, 0.0]);
        let results = vec![make_result(0.3), make_result(0.9), make_result(0.6)];
        let embeddings = vec![vec![0.0_f32, 1.0]; 3];
        let selected = r.rerank(&results, &embeddings, 2);
        assert_eq!(selected[0], 1);
        assert_eq!(selected[1], 2);
    }

    #[test]
    fn test_cross_encoder_blend_changes_ranking() {
        let query_emb = vec![1.0_f32, 0.0];
        let r = CrossEncoderReranker::with_alpha(0.5, query_emb);

        // Index 0 blends to 0.5 * 0.3 + 0.5 * 1.0 = 0.65,
        // index 1 blends to 0.5 * 0.9 + 0.5 * 0.0 = 0.45.
        let results = vec![make_result(0.3), make_result(0.9)];
        let embeddings = vec![vec![1.0_f32, 0.0], vec![0.0_f32, 1.0]];

        // `k < n` so the blended scoring path actually runs (see the
        // alpha-zero test for why `k == n` would not exercise it).
        let selected = r.rerank(&results, &embeddings, 1);
        assert_eq!(selected.len(), 1);
        assert_eq!(selected[0], 0);
    }

    #[test]
    fn test_cross_encoder_empty_query_embedding_falls_back_to_score_order() {
        let r = CrossEncoderReranker::with_alpha(0.5, Vec::new());
        let results = vec![make_result(0.3), make_result(0.9), make_result(0.6)];
        let embeddings = vec![vec![1.0_f32, 0.0]; 3];
        let selected = r.rerank(&results, &embeddings, 2);
        assert_eq!(selected[0], 1);
    }

    #[test]
    fn test_cross_encoder_k_gte_n_returns_all() {
        let r = CrossEncoderReranker::with_alpha(0.5, vec![1.0, 0.0]);
        let results = vec![make_result(0.8), make_result(0.5)];
        let embeddings = vec![vec![1.0_f32, 0.0]; 2];
        let selected = r.rerank(&results, &embeddings, 10);
        assert_eq!(selected.len(), 2);
    }
}