#![allow(unsafe_code)]
use async_trait::async_trait;
use candle_core::{Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
use tracing::{debug, info, warn};
use super::base::BaseReranker;
use crate::config::reranker::HuggingFaceRerankerConfig;
use crate::core::StoredFact;
use crate::error::{NeomemxError, Result};
/// Cross-encoder reranker backed by a local candle BERT model whose
/// artifacts are fetched from the HuggingFace Hub at construction time.
pub struct HuggingFaceReranker {
// User-supplied settings (model id, device preference, max_length,
// normalize flag, default top_k) — read throughout scoring/reranking.
config: HuggingFaceRerankerConfig,
// Loaded BERT encoder; NOTE(review): this is the bare encoder without a
// classification head — see `compute_score` for how a score is derived.
model: BertModel,
// Tokenizer matching the model checkpoint (tokenizer.json from the hub).
tokenizer: Tokenizer,
// Device all tensors are created on (CPU, or CUDA when requested and available).
device: Device,
}
impl HuggingFaceReranker {
pub fn new(config: HuggingFaceRerankerConfig) -> Result<Self> {
let device = match config.device.as_deref() {
Some("cuda") | Some("gpu") => Device::cuda_if_available(0).unwrap_or(Device::Cpu),
_ => Device::Cpu,
};
info!("Loading reranker model: {} on {:?}", config.model, device);
let api = Api::new().map_err(|e| {
NeomemxError::RerankerError(format!("Failed to create HuggingFace API: {}", e))
})?;
let repo = api.repo(Repo::new(config.model.clone(), RepoType::Model));
let tokenizer_path = repo
.get("tokenizer.json")
.map_err(|e| NeomemxError::RerankerError(format!("Failed to get tokenizer: {}", e)))?;
let tokenizer = Tokenizer::from_file(tokenizer_path)
.map_err(|e| NeomemxError::RerankerError(format!("Failed to load tokenizer: {}", e)))?;
let config_path = repo.get("config.json").map_err(|e| {
NeomemxError::RerankerError(format!("Failed to get model config: {}", e))
})?;
let config_str = std::fs::read_to_string(config_path)
.map_err(|e| NeomemxError::RerankerError(format!("Failed to read config: {}", e)))?;
let bert_config: BertConfig = serde_json::from_str(&config_str)
.map_err(|e| NeomemxError::RerankerError(format!("Failed to parse config: {}", e)))?;
let weights_path = repo
.get("model.safetensors")
.or_else(|_| repo.get("pytorch_model.bin"))
.map_err(|e| {
NeomemxError::RerankerError(format!("Failed to get model weights: {}", e))
})?;
let vb = if weights_path
.extension()
.map_or(false, |ext| ext == "safetensors")
{
unsafe {
VarBuilder::from_mmaped_safetensors(
&[weights_path],
candle_core::DType::F32,
&device,
)
.map_err(|e| {
NeomemxError::RerankerError(format!("Failed to load safetensors: {}", e))
})?
}
} else {
VarBuilder::from_pth(weights_path, candle_core::DType::F32, &device).map_err(|e| {
NeomemxError::RerankerError(format!("Failed to load pytorch weights: {}", e))
})?
};
let model = BertModel::load(vb, &bert_config).map_err(|e| {
NeomemxError::RerankerError(format!("Failed to load BERT model: {}", e))
})?;
info!("Reranker model loaded successfully");
Ok(Self {
config,
model,
tokenizer,
device,
})
}
fn compute_score(&self, query: &str, document: &str) -> Result<f32> {
let encoding = self
.tokenizer
.encode((query, document), true)
.map_err(|e| NeomemxError::RerankerError(format!("Tokenization failed: {}", e)))?;
let input_ids = encoding.get_ids();
let attention_mask = encoding.get_attention_mask();
let token_type_ids = encoding.get_type_ids();
let max_len = self.config.max_length.min(input_ids.len());
let input_ids: Vec<u32> = input_ids[..max_len].to_vec();
let attention_mask: Vec<u32> = attention_mask[..max_len].to_vec();
let token_type_ids: Vec<u32> = token_type_ids[..max_len].to_vec();
let input_ids = Tensor::new(&input_ids[..], &self.device)
.map_err(|e| NeomemxError::RerankerError(format!("Tensor creation failed: {}", e)))?
.unsqueeze(0)
.map_err(|e| NeomemxError::RerankerError(format!("Unsqueeze failed: {}", e)))?;
let attention_mask = Tensor::new(&attention_mask[..], &self.device)
.map_err(|e| NeomemxError::RerankerError(format!("Tensor creation failed: {}", e)))?
.unsqueeze(0)
.map_err(|e| NeomemxError::RerankerError(format!("Unsqueeze failed: {}", e)))?;
let token_type_ids = Tensor::new(&token_type_ids[..], &self.device)
.map_err(|e| NeomemxError::RerankerError(format!("Tensor creation failed: {}", e)))?
.unsqueeze(0)
.map_err(|e| NeomemxError::RerankerError(format!("Unsqueeze failed: {}", e)))?;
let output = self
.model
.forward(&input_ids, &token_type_ids, Some(&attention_mask))
.map_err(|e| NeomemxError::RerankerError(format!("Model forward failed: {}", e)))?;
let cls_output = output
.i((0, 0))
.map_err(|e| NeomemxError::RerankerError(format!("Failed to get CLS output: {}", e)))?;
let score: f32 = cls_output
.to_vec1()
.map_err(|e| NeomemxError::RerankerError(format!("Failed to convert to vec: {}", e)))?
.get(0)
.copied()
.unwrap_or(0.0);
Ok(score)
}
fn normalize_scores(scores: &mut [f32]) {
if scores.is_empty() {
return;
}
let min = scores.iter().cloned().fold(f32::INFINITY, f32::min);
let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let range = max - min + 1e-8;
for score in scores.iter_mut() {
*score = (*score - min) / range;
}
}
}
#[async_trait]
impl BaseReranker for HuggingFaceReranker {
    /// Scores each document against `query` with the cross-encoder and
    /// returns the documents sorted by descending relevance, each with
    /// `relevance_score` populated, truncated to `top_k` (argument wins over
    /// `config.top_k`). Per-document scoring failures are logged and degrade
    /// to a 0.0 score instead of failing the whole batch.
    ///
    /// NOTE(review): `compute_score` is synchronous CPU-bound work executed
    /// directly on the async task; for large batches consider offloading via
    /// `spawn_blocking` at the call site.
    async fn rerank(
        &self,
        query: &str,
        documents: Vec<StoredFact>,
        top_k: Option<usize>,
    ) -> Result<Vec<StoredFact>> {
        if documents.is_empty() {
            return Ok(documents);
        }
        debug!("Reranking {} documents", documents.len());
        let mut scores: Vec<f32> = Vec::with_capacity(documents.len());
        for doc in &documents {
            match self.compute_score(query, &doc.content) {
                Ok(score) => scores.push(score),
                Err(e) => {
                    warn!("Failed to compute score for document: {}", e);
                    scores.push(0.0);
                }
            }
        }
        if self.config.normalize {
            Self::normalize_scores(&mut scores);
        }
        let mut doc_scores: Vec<(StoredFact, f32)> =
            documents.into_iter().zip(scores).collect();
        // BUGFIX: partial_cmp(..).unwrap_or(Equal) is not a consistent
        // comparator when NaN scores occur (NaN compares "Equal" to every
        // value, breaking transitivity). total_cmp is a true total order over
        // f32, so the sort is always well-defined; NaNs sink to the end of
        // this descending order.
        doc_scores.sort_by(|a, b| b.1.total_cmp(&a.1));
        if let Some(k) = top_k.or(self.config.top_k) {
            doc_scores.truncate(k);
        }
        let results: Vec<StoredFact> = doc_scores
            .into_iter()
            .map(|(mut fact, score)| {
                fact.relevance_score = Some(score);
                fact
            })
            .collect();
        Ok(results)
    }
}