use anyhow::{Context, Result};
use fastembed::{RerankInitOptions, RerankerModel, TextRerank};
use tracing::info;
use super::model_cache_dir;
/// Cross-encoder reranker wrapping a fastembed [`TextRerank`] model.
///
/// Construct it with [`CrossEncoderEngine::load`] (which loads the BGE reranker-base
/// model) and score query/document pairs with [`CrossEncoderEngine::score_batch`].
pub struct CrossEncoderEngine {
/// The loaded fastembed reranking model used for all scoring calls.
model: TextRerank,
}
impl CrossEncoderEngine {
pub fn load() -> Result<Self> {
if super::is_reranker_model_cached() {
info!("Loading reranker model...");
} else {
info!("Downloading reranker model (~1.1GB, first time only)...");
}
let model = TextRerank::try_new(
RerankInitOptions::new(RerankerModel::BGERerankerBase)
.with_cache_dir(model_cache_dir())
.with_show_download_progress(true),
)
.context("Failed to initialize cross-encoder model")?;
Ok(Self { model })
}
pub fn score_batch(&mut self, query: &str, documents: &[&str]) -> Result<Vec<f32>> {
if documents.is_empty() {
return Ok(Vec::new());
}
let results = self
.model
.rerank(query, documents, false, None)
.context("Cross-encoder batch scoring failed")?;
let mut scores = vec![0.0f32; documents.len()];
for r in &results {
scores[r.index] = r.score;
}
Ok(scores)
}
}