use crate::info_print;
use anyhow::Result;
use fastembed::{RerankInitOptions, RerankerModel, TextRerank};
/// Weight given to the sigmoid-normalized neural reranker score when
/// `NeuralReranker::rerank_and_blend` combines it with the RRF score.
pub const RERANK_WEIGHT: f32 = 0.575;
/// Weight given to the min-max-normalized RRF score in the same blend.
/// Together with `RERANK_WEIGHT` the two weights sum to 1.0.
// NOTE(review): "RRF" presumably stands for Reciprocal Rank Fusion — confirm against the caller.
pub const RRF_WEIGHT: f32 = 0.425;
/// Wrapper around a fastembed `TextRerank` model that also remembers the name
/// of the model it was created with.
// NOTE(review): `dead_code` is presumably allowed because `model_name` is only read by the
// `model_name()` accessor, which is itself marked `#[allow(dead_code)]` — confirm.
#[allow(dead_code)] pub struct NeuralReranker {
/// Underlying fastembed reranker used to score documents against a query.
reranker: TextRerank,
/// String form (`to_string()`) of the `RerankerModel` passed to `with_model`.
model_name: String,
}
impl NeuralReranker {
pub fn new() -> Result<Self> {
Self::with_model(RerankerModel::JINARerankerV1TurboEn)
}
pub fn with_model(model: RerankerModel) -> Result<Self> {
let model_name = model.to_string();
info_print!("Loading reranker model: {}", model_name);
let mut options = RerankInitOptions::default();
options.model_name = model;
options.show_download_progress = false;
let reranker = TextRerank::try_new(options)?;
info_print!("Reranker model loaded successfully!");
Ok(Self {
reranker,
model_name,
})
}
#[allow(dead_code)] pub fn model_name(&self) -> &str {
&self.model_name
}
pub fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<(usize, f32)>> {
if documents.is_empty() {
return Ok(vec![]);
}
let doc_refs: Vec<&str> = documents.iter().map(|s| s.as_str()).collect();
let results = self.reranker.rerank(
query, doc_refs, false, None, )?;
Ok(results.into_iter().map(|r| (r.index, r.score)).collect())
}
pub fn rerank_and_blend(
&mut self,
query: &str,
documents: &[String],
rrf_scores: &[f32],
) -> Result<Vec<(usize, f32)>> {
if documents.is_empty() {
return Ok(vec![]);
}
assert_eq!(
documents.len(),
rrf_scores.len(),
"Documents and RRF scores must have same length"
);
let rerank_results = self.rerank(query, documents)?;
let normalized: Vec<(usize, f32)> = rerank_results
.iter()
.map(|(idx, score)| (*idx, sigmoid(*score)))
.collect();
let rrf_min = rrf_scores.iter().cloned().fold(f32::INFINITY, f32::min);
let rrf_max = rrf_scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let rrf_range = (rrf_max - rrf_min).max(0.0001);
let mut blended: Vec<(usize, f32)> = normalized
.into_iter()
.map(|(idx, rerank_norm)| {
let rrf_norm = (rrf_scores[idx] - rrf_min) / rrf_range;
let blended_score = RERANK_WEIGHT * rerank_norm + RRF_WEIGHT * rrf_norm;
(idx, blended_score)
})
.collect();
blended.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
Ok(blended)
}
}
/// Logistic sigmoid, `1 / (1 + e^(-x))`.
///
/// Maps `0.0` to `0.5`, large positive inputs towards `1.0` and large
/// negative inputs towards `0.0`; NaN inputs yield NaN.
fn sigmoid(x: f32) -> f32 {
    let decay = f32::exp(-x);
    (1.0 + decay).recip()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The sigmoid is centred on 0.5 and saturates towards 0 and 1.
    #[test]
    fn test_sigmoid() {
        let at_zero = sigmoid(0.0);
        assert!((at_zero - 0.5).abs() < 0.0001);

        let high = sigmoid(10.0);
        assert!(high > 0.99);

        let low = sigmoid(-10.0);
        assert!(low < 0.01);
    }

    /// Constructing the default reranker succeeds.
    // NOTE(review): ignored by default, presumably because it has to load the
    // real reranker model — confirm.
    #[test]
    #[ignore]
    fn test_reranker_creation() {
        let reranker = NeuralReranker::new();
        assert!(reranker.is_ok());
    }

    /// Reranking three snippets returns three results in non-increasing
    /// score order.
    // NOTE(review): ignored by default, presumably because it has to load the
    // real reranker model — confirm.
    #[test]
    #[ignore]
    fn test_rerank_basic() {
        let mut reranker = NeuralReranker::new().unwrap();
        let query = "How do I authenticate users?";
        let documents: Vec<String> = [
            "fn authenticate(user: &str, password: &str) -> bool { ... }",
            "fn calculate_sum(a: i32, b: i32) -> i32 { a + b }",
            "impl UserAuth for App { fn login(&self, credentials: Credentials) -> Result<Token> }",
        ]
        .iter()
        .map(|snippet| snippet.to_string())
        .collect();

        let results = reranker.rerank(query, &documents).unwrap();
        assert_eq!(results.len(), 3);

        // Every adjacent pair must be ordered by descending score.
        for pair in results.windows(2) {
            assert!(pair[0].1 >= pair[1].1);
        }
    }
}