pub mod graph_ops;
pub mod kernel;
pub mod linalg;
use brainwires_core::SearchResult;
use kernel::{build_kernel_matrix, cross_column};
use linalg::{cholesky_extend, log_det_incremental};
use ndarray::Array2;
/// Tunable parameters for [`SpectralReranker`].
#[derive(Debug, Clone)]
pub struct SpectralSelectConfig {
    /// Optional selection size.
    ///
    /// NOTE(review): nothing in this file reads this field; the `k` argument of
    /// `rerank` is what drives selection. Confirm whether another module uses it.
    pub k: Option<usize>,
    /// Forwarded unchanged to `kernel::build_kernel_matrix`.
    ///
    /// NOTE(review): presumably the relevance/diversity trade-off weight; confirm
    /// against the kernel module.
    pub lambda: f32,
    /// Inputs with fewer results than this skip kernel-based selection and simply
    /// return the first `k` indices.
    pub min_candidates: usize,
    /// Forwarded unchanged to `kernel::build_kernel_matrix`.
    ///
    /// NOTE(review): presumably a small diagonal term for numerical stability;
    /// confirm against the kernel module.
    pub regularization: f32,
}

impl Default for SpectralSelectConfig {
    /// Defaults: no explicit `k`, `lambda = 0.5`, `min_candidates = 10`,
    /// `regularization = 1e-6`.
    fn default() -> Self {
        SpectralSelectConfig {
            regularization: 1e-6,
            min_candidates: 10,
            lambda: 0.5,
            k: None,
        }
    }
}
/// Strategy that picks (and orders) a subset of search results.
///
/// Implementors must be `Send + Sync`.
pub trait DiversityReranker: Send + Sync {
/// Returns indices into `results`.
///
/// * `results` - candidate search results.
/// * `embeddings` - candidate embeddings. NOTE(review): the implementations in this
///   file pair `embeddings[i]` with `results[i]`; confirm callers keep them aligned.
/// * `k` - requested number of indices. The implementations in this file return at
///   most `min(k, results.len())` indices.
fn rerank(&self, results: &[SearchResult], embeddings: &[Vec<f32>], k: usize) -> Vec<usize>;
}
/// Reranker whose [`DiversityReranker`] implementation builds a kernel matrix from
/// the candidates and runs a greedy log-determinant selection over it.
pub struct SpectralReranker {
    /// Parameters consulted on every `rerank` call.
    config: SpectralSelectConfig,
}

impl SpectralReranker {
    /// Creates a reranker that uses the supplied configuration.
    pub fn new(config: SpectralSelectConfig) -> Self {
        SpectralReranker { config }
    }

    /// Creates a reranker that uses `SpectralSelectConfig::default()`.
    pub fn with_defaults() -> Self {
        SpectralReranker {
            config: SpectralSelectConfig::default(),
        }
    }
}
impl DiversityReranker for SpectralReranker {
    /// Selects up to `k` indices into `results` via greedy log-determinant selection
    /// over the kernel produced by `build_kernel_matrix`.
    ///
    /// Behaviour by case:
    /// * no results or `k == 0`: returns an empty vector;
    /// * `k >= results.len()`: returns every index in input order;
    /// * fewer than `min_candidates` results, or `embeddings.len()` differs from
    ///   `results.len()`: returns the first `k` indices in input order (the kernel
    ///   cannot be built meaningfully without one embedding per result);
    /// * otherwise: returns the greedy selection, in pick order. If the greedy pass
    ///   stops early, the output is padded with not-yet-selected indices in input
    ///   order so that exactly `k` indices are returned.
    fn rerank(&self, results: &[SearchResult], embeddings: &[Vec<f32>], k: usize) -> Vec<usize> {
        let n = results.len();
        if n == 0 || k == 0 {
            return Vec::new();
        }
        if k >= n {
            return (0..n).collect();
        }

        // From here on `k < n`, so `0..k` is always a valid index range.
        if n < self.config.min_candidates || embeddings.len() != n {
            return (0..k).collect();
        }

        let embedding_refs: Vec<&[f32]> = embeddings.iter().map(Vec::as_slice).collect();
        let scores: Vec<f32> = results.iter().map(|r| r.score).collect();
        let kernel = build_kernel_matrix(
            &embedding_refs,
            &scores,
            self.config.lambda,
            self.config.regularization,
        );

        let mut selected = greedy_log_det_select(&kernel, k);

        // The greedy pass breaks out when no candidate has a finite gain. Rather than
        // dropping results, fill the remaining slots with unselected indices in input
        // order (the same order the small-input fallback uses).
        if selected.len() < k {
            let mut taken = vec![false; n];
            for &idx in &selected {
                if let Some(slot) = taken.get_mut(idx) {
                    *slot = true;
                }
            }
            let missing = k - selected.len();
            selected.extend((0..n).filter(|&i| !taken[i]).take(missing));
        }
        selected
    }
}
/// Greedily picks up to `k` row/column indices of `kernel`, one per round, each time
/// taking the unpicked candidate with the largest gain.
///
/// * First round: the gain of candidate `c` is `ln(kernel[c, c])` when the diagonal is
///   positive, otherwise negative infinity.
/// * Later rounds: the gain is `log_det_incremental(..) - current log-det`, computed
///   from the Cholesky factor of the already selected sub-kernel and the candidate's
///   cross column. NOTE(review): `log_det_incremental` is assumed to return the
///   log-determinant after adding the candidate; confirm in `linalg`.
///
/// Ties keep the lowest index (strict `>` comparison), and a NaN gain never wins.
/// Selection stops early when no candidate has a gain above negative infinity, so
/// fewer than `k` indices may be returned.
///
/// # Panics
///
/// Panics if `cholesky_extend` fails for the chosen candidate.
fn greedy_log_det_select(kernel: &Array2<f32>, k: usize) -> Vec<usize> {
    let n = kernel.nrows();
    let mut selected: Vec<usize> = Vec::with_capacity(k);
    let mut available = vec![true; n];
    // Cholesky factor of the selected sub-kernel; `None` until the first pick.
    let mut factor: Option<Array2<f32>> = None;
    let mut log_det_so_far: f32 = 0.0;

    for _ in 0..k {
        // Best (index, gain) seen this round; stays `None` while every gain is
        // negative infinity or NaN.
        let mut best: Option<(usize, f32)> = None;

        for candidate in (0..n).filter(|&c| available[c]) {
            let diag = kernel[[candidate, candidate]];
            let gain = match factor.as_ref() {
                None => {
                    if diag > 0.0 {
                        diag.ln()
                    } else {
                        f32::NEG_INFINITY
                    }
                }
                Some(l) => {
                    let cross = cross_column(kernel, &selected, candidate);
                    log_det_incremental(l, &cross, diag, log_det_so_far) - log_det_so_far
                }
            };

            let bar = best.map_or(f32::NEG_INFINITY, |(_, g)| g);
            if gain > bar {
                best = Some((candidate, gain));
            }
        }

        let (winner, winner_gain) = match best {
            Some(pick) => pick,
            None => break,
        };

        let diag = kernel[[winner, winner]];
        let grown = match factor.as_ref() {
            None => {
                // A 1x1 Cholesky factor is just the square root of the diagonal entry.
                log_det_so_far = diag.ln();
                Array2::<f32>::from_elem((1, 1), diag.sqrt())
            }
            Some(l) => {
                let cross = cross_column(kernel, &selected, winner);
                let extended = cholesky_extend(l, &cross, diag)
                    .expect("Cholesky extend failed after positive gain check");
                log_det_so_far += winner_gain;
                extended
            }
        };
        factor = Some(grown);

        selected.push(winner);
        available[winner] = false;
    }

    selected
}
/// Settings for [`CrossEncoderReranker`].
#[derive(Debug, Clone)]
pub struct CrossEncoderConfig {
    /// Weight of the original result score in the blended score. The reranker clamps
    /// it to `[0, 1]` and gives the remaining `1 - alpha` to the cosine similarity
    /// between the query embedding and the candidate embedding.
    pub alpha: f32,
    /// Embedding of the query. When empty, the reranker orders candidates by their
    /// original score only.
    pub query_embedding: Vec<f32>,
}

impl Default for CrossEncoderConfig {
    /// Defaults: `alpha = 0.5` and an empty query embedding.
    fn default() -> Self {
        CrossEncoderConfig {
            query_embedding: Vec::new(),
            alpha: 0.5,
        }
    }
}
/// Reranker that blends each result's original score with the cosine similarity
/// between a stored query embedding and the result's embedding.
pub struct CrossEncoderReranker {
    /// Blend weight and query embedding used on every `rerank` call.
    config: CrossEncoderConfig,
}

impl CrossEncoderReranker {
    /// Creates a reranker from a full configuration.
    pub fn new(config: CrossEncoderConfig) -> Self {
        CrossEncoderReranker { config }
    }

    /// Convenience constructor taking the blend weight and query embedding directly.
    pub fn with_alpha(alpha: f32, query_embedding: Vec<f32>) -> Self {
        let config = CrossEncoderConfig {
            alpha,
            query_embedding,
        };
        CrossEncoderReranker { config }
    }
}
impl DiversityReranker for CrossEncoderReranker {
    /// Returns the indices of the `min(k, results.len())` best results, best first.
    ///
    /// Each result gets the score
    /// `alpha * result.score + (1 - alpha) * cosine(query_embedding, embedding)`,
    /// with `alpha` clamped to `[0, 1]`. A result without a matching entry in
    /// `embeddings` gets a cosine of `0.0`. When the configured query embedding is
    /// empty, the original `score` alone is used.
    ///
    /// The output is always ranked, including when `k >= results.len()`, so the
    /// reranker still takes effect when the caller asks for every candidate.
    /// Equal scores keep their input order, and NaN scores are placed last.
    fn rerank(&self, results: &[SearchResult], embeddings: &[Vec<f32>], k: usize) -> Vec<usize> {
        // Descending order with NaN sorted to the end. Unlike
        // `partial_cmp(..).unwrap_or(Equal)`, this stays a consistent ordering when
        // NaN is present, which `sort_by` requires for a well-defined result.
        fn descending_nan_last(a: f32, b: f32) -> std::cmp::Ordering {
            use std::cmp::Ordering;
            match (a.is_nan(), b.is_nan()) {
                (true, true) => Ordering::Equal,
                (true, false) => Ordering::Greater,
                (false, true) => Ordering::Less,
                // Both are ordinary numbers, so `partial_cmp` cannot fail here.
                (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
            }
        }

        if results.is_empty() || k == 0 {
            return Vec::new();
        }

        let query_emb = &self.config.query_embedding;
        let alpha = self.config.alpha.clamp(0.0, 1.0);

        let mut scored: Vec<(usize, f32)> = results
            .iter()
            .enumerate()
            .map(|(i, result)| {
                let joint = if query_emb.is_empty() {
                    // No query embedding: fall back to the original score order.
                    result.score
                } else {
                    let cos = embeddings
                        .get(i)
                        .map_or(0.0, |emb| kernel::cosine_similarity(query_emb, emb));
                    alpha * result.score + (1.0 - alpha) * cos
                };
                (i, joint)
            })
            .collect();

        // Stable sort: ties keep their input order.
        scored.sort_by(|a, b| descending_nan_last(a.1, b.1));
        scored.into_iter().take(k).map(|(i, _)| i).collect()
    }
}
/// Names a reranking strategy together with its configuration.
///
/// NOTE(review): nothing in this file constructs or dispatches on this enum;
/// confirm how the variants are consumed elsewhere.
#[derive(Debug, Clone)]
pub enum RerankerKind {
    /// Use the spectral (kernel-based) reranker with this configuration.
    Spectral(SpectralSelectConfig),
    /// Use the score/cosine blending reranker with this configuration.
    CrossEncoder(CrossEncoderConfig),
    /// Carries a configuration for each reranker.
    ///
    /// NOTE(review): the order in which the two are applied is decided by the
    /// consumer of this enum, which is not visible here.
    Both {
        /// Configuration for the spectral reranker.
        spectral: SpectralSelectConfig,
        /// Configuration for the score/cosine blending reranker.
        cross_encoder: CrossEncoderConfig,
    },
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SearchResult` whose only meaningful fields are the scores.
    fn make_result(score: f32) -> SearchResult {
        SearchResult {
            file_path: String::new(),
            root_path: None,
            content: String::new(),
            score,
            vector_score: score,
            keyword_score: None,
            start_line: 0,
            end_line: 0,
            language: String::new(),
            project: None,
            indexed_at: 0,
        }
    }

    #[test]
    fn test_empty_input() {
        let reranker = SpectralReranker::with_defaults();
        let result = reranker.rerank(&[], &[], 5);
        assert!(result.is_empty());
    }

    #[test]
    fn test_k_zero() {
        let reranker = SpectralReranker::with_defaults();
        let results = vec![make_result(0.9)];
        let embeddings = vec![vec![1.0, 0.0]];
        let result = reranker.rerank(&results, &embeddings, 0);
        assert!(result.is_empty());
    }

    #[test]
    fn test_k_greater_than_n() {
        let reranker = SpectralReranker::with_defaults();
        let results = vec![make_result(0.9), make_result(0.8)];
        let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let result = reranker.rerank(&results, &embeddings, 10);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_below_min_candidates() {
        // 5 candidates < min_candidates (20): expect the plain first-k fallback.
        let config = SpectralSelectConfig {
            min_candidates: 20,
            ..Default::default()
        };
        let reranker = SpectralReranker::new(config);
        let results: Vec<SearchResult> =
            (0..5).map(|i| make_result(0.9 - i as f32 * 0.1)).collect();
        let embeddings: Vec<Vec<f32>> = (0..5).map(|i| vec![i as f32, 0.0]).collect();
        let result = reranker.rerank(&results, &embeddings, 3);
        assert_eq!(result, vec![0, 1, 2]);
    }

    #[test]
    fn test_spectral_prefers_diverse() {
        let mut results = Vec::new();
        let mut embeddings = Vec::new();
        // Indices 0..10: high-scoring candidates that all point along the first axis.
        for i in 0..10 {
            results.push(make_result(0.95));
            let mut emb = vec![1.0, 0.0, 0.0, 0.0, 0.0];
            emb[0] += i as f32 * 0.01;
            embeddings.push(emb);
        }
        // Indices 10..15: slightly lower scores, spread over different directions.
        let diverse_dirs = [
            vec![0.0, 1.0, 0.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 0.0, 1.0, 0.0],
            vec![0.0, 0.0, 0.0, 0.0, 1.0],
            vec![0.5, 0.5, 0.5, 0.0, 0.0],
        ];
        for dir in &diverse_dirs {
            results.push(make_result(0.85));
            embeddings.push(dir.clone());
        }
        let reranker = SpectralReranker::new(SpectralSelectConfig {
            min_candidates: 5,
            lambda: 0.3,
            ..Default::default()
        });
        let selected = reranker.rerank(&results, &embeddings, 5);
        assert_eq!(selected.len(), 5);
        let diverse_count = selected.iter().filter(|&&idx| idx >= 10).count();
        assert!(
            diverse_count >= 3,
            "Expected at least 3 diverse items, got {}. Selected: {:?}",
            diverse_count,
            selected
        );
    }

    #[test]
    fn test_lambda_one_approximates_topk() {
        let mut results = Vec::new();
        let mut embeddings = Vec::new();
        for i in 0..15 {
            let score = 1.0 - i as f32 * 0.05;
            results.push(make_result(score));
            let mut emb = vec![0.0; 10];
            emb[i % 10] = 1.0;
            embeddings.push(emb);
        }
        let reranker = SpectralReranker::new(SpectralSelectConfig {
            min_candidates: 5,
            lambda: 1.0,
            ..Default::default()
        });
        let selected = reranker.rerank(&results, &embeddings, 5);
        assert_eq!(selected.len(), 5);
        for &idx in &selected {
            assert!(
                idx < 7,
                "Expected top items, got index {}. Selected: {:?}",
                idx,
                selected
            );
        }
    }

    #[test]
    fn test_k_equals_one() {
        let results = vec![make_result(0.5), make_result(0.9), make_result(0.7)];
        let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
        let reranker = SpectralReranker::new(SpectralSelectConfig {
            min_candidates: 2,
            ..Default::default()
        });
        let selected = reranker.rerank(&results, &embeddings, 1);
        assert_eq!(selected.len(), 1);
        assert_eq!(selected[0], 1);
    }

    #[test]
    fn test_greedy_determinism() {
        let results: Vec<SearchResult> = (0..12)
            .map(|i| make_result(0.9 - i as f32 * 0.05))
            .collect();
        let embeddings: Vec<Vec<f32>> = (0..12)
            .map(|i| {
                let mut e = vec![0.0; 5];
                e[i % 5] = 1.0;
                e
            })
            .collect();
        let reranker = SpectralReranker::new(SpectralSelectConfig {
            min_candidates: 5,
            ..Default::default()
        });
        let r1 = reranker.rerank(&results, &embeddings, 4);
        let r2 = reranker.rerank(&results, &embeddings, 4);
        assert_eq!(r1, r2);
    }

    #[test]
    fn test_performance_200_candidates() {
        // NOTE: this is a wall-clock bound, so it can be sensitive to unoptimised
        // builds and heavily loaded machines.
        let n = 200;
        let dim = 384;
        let k = 20;
        let results: Vec<SearchResult> = (0..n)
            .map(|i| make_result(1.0 - i as f32 / n as f32))
            .collect();
        let embeddings: Vec<Vec<f32>> = (0..n)
            .map(|i| {
                (0..dim)
                    .map(|j| ((i * 7 + j * 13) % 100) as f32 / 100.0)
                    .collect()
            })
            .collect();
        let reranker = SpectralReranker::new(SpectralSelectConfig {
            min_candidates: 5,
            ..Default::default()
        });
        let start = std::time::Instant::now();
        let selected = reranker.rerank(&results, &embeddings, k);
        let elapsed = start.elapsed();
        assert_eq!(selected.len(), k);
        assert!(
            elapsed.as_millis() < 500,
            "Performance test: took {}ms, expected <500ms",
            elapsed.as_millis()
        );
    }

    #[test]
    fn test_cross_encoder_empty_input() {
        let r = CrossEncoderReranker::with_alpha(0.5, vec![1.0, 0.0]);
        assert!(r.rerank(&[], &[], 5).is_empty());
    }

    #[test]
    fn test_cross_encoder_k_zero() {
        let r = CrossEncoderReranker::with_alpha(0.5, vec![1.0, 0.0]);
        let results = vec![make_result(0.9)];
        let embeddings = vec![vec![1.0, 0.0]];
        assert!(r.rerank(&results, &embeddings, 0).is_empty());
    }

    #[test]
    fn test_cross_encoder_pure_cosine_alpha_zero() {
        let query_emb = vec![1.0_f32, 0.0];
        let r = CrossEncoderReranker::with_alpha(0.0, query_emb);
        // Index 1 has the higher original score, index 0 the embedding aligned with
        // the query; with alpha = 0 only the cosine term should matter.
        let results = vec![make_result(0.5), make_result(0.9)];
        let embeddings = vec![vec![1.0_f32, 0.0], vec![0.0_f32, 1.0]];
        // k < n so the scoring path is exercised regardless of how the `k >= n`
        // case is handled.
        let selected = r.rerank(&results, &embeddings, 1);
        assert_eq!(selected, vec![0]);
    }

    #[test]
    fn test_cross_encoder_pure_original_alpha_one() {
        let r = CrossEncoderReranker::with_alpha(1.0, vec![1.0, 0.0]);
        let results = vec![make_result(0.3), make_result(0.9), make_result(0.6)];
        let embeddings = vec![vec![0.0_f32, 1.0]; 3];
        let selected = r.rerank(&results, &embeddings, 2);
        // With alpha = 1 only the original score matters: 0.9 first, then 0.6.
        assert_eq!(selected[0], 1);
        assert_eq!(selected[1], 2);
    }

    #[test]
    fn test_cross_encoder_blend_changes_ranking() {
        let query_emb = vec![1.0_f32, 0.0];
        let r = CrossEncoderReranker::with_alpha(0.5, query_emb);
        // Blended scores (assuming standard cosine similarity):
        //   index 0: 0.5 * 0.3 + 0.5 * 1.0 = 0.65
        //   index 1: 0.5 * 0.9 + 0.5 * 0.0 = 0.45
        // so the blend must overturn the original score order.
        let results = vec![make_result(0.3), make_result(0.9)];
        let embeddings = vec![vec![1.0_f32, 0.0], vec![0.0_f32, 1.0]];
        // k < n so the scoring path is exercised regardless of how the `k >= n`
        // case is handled.
        let selected = r.rerank(&results, &embeddings, 1);
        assert_eq!(selected, vec![0]);
    }

    #[test]
    fn test_cross_encoder_empty_query_embedding_falls_back_to_score_order() {
        let r = CrossEncoderReranker::with_alpha(0.5, Vec::new());
        let results = vec![make_result(0.3), make_result(0.9), make_result(0.6)];
        let embeddings = vec![vec![1.0_f32, 0.0]; 3];
        let selected = r.rerank(&results, &embeddings, 2);
        // Highest original score wins when there is no query embedding.
        assert_eq!(selected[0], 1);
    }

    #[test]
    fn test_cross_encoder_k_gte_n_returns_all() {
        let r = CrossEncoderReranker::with_alpha(0.5, vec![1.0, 0.0]);
        let results = vec![make_result(0.8), make_result(0.5)];
        let embeddings = vec![vec![1.0_f32, 0.0]; 2];
        let selected = r.rerank(&results, &embeddings, 10);
        assert_eq!(selected.len(), 2);
    }
}