//! HuggingFace reranker implementation using Candle
//!
//! Note: This module uses unsafe code for memory-mapped safetensors loading,
//! which is required by the candle library for efficient model loading.
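//!
//! # Usage sketch
//!
//! A minimal, illustrative sketch (not a runnable doc-test): it assumes a
//! `HuggingFaceRerankerConfig` built elsewhere and candidate `StoredFact`s
//! produced by an earlier retrieval step.
//!
//! ```ignore
//! let reranker = HuggingFaceReranker::new(config)?;
//! // Re-score the candidates against the query and keep the best five.
//! let top = reranker
//!     .rerank("how do I rotate an API key?", candidates, Some(5))
//!     .await?;
//! ```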

#![allow(unsafe_code)]

use async_trait::async_trait;
use candle_core::{Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
use tracing::{debug, info, warn};

use super::base::BaseReranker;
use crate::config::reranker::HuggingFaceRerankerConfig;
use crate::core::StoredFact;
use crate::error::{NeomemxError, Result};

/// HuggingFace reranker using cross-encoder models
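///
/// A cross-encoder scores each (query, document) pair in a single joint
/// forward pass; this is slower than a bi-encoder but usually ranks candidates
/// more accurately, so it is applied only to the short list returned by retrieval.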
pub struct HuggingFaceReranker {
    config: HuggingFaceRerankerConfig,
    model: BertModel,
    tokenizer: Tokenizer,
    device: Device,
}

impl HuggingFaceReranker {
    /// Create a new HuggingFace reranker
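    ///
    /// Downloads `tokenizer.json`, `config.json`, and the model weights
    /// (`model.safetensors`, falling back to `pytorch_model.bin`) from the
    /// Hugging Face Hub; `hf-hub` caches the files locally after the first fetch.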
    pub fn new(config: HuggingFaceRerankerConfig) -> Result<Self> {
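        // Pick the compute device: CUDA if requested and available, otherwise CPU.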
        let device = match config.device.as_deref() {
            Some("cuda") | Some("gpu") => Device::cuda_if_available(0).unwrap_or(Device::Cpu),
            _ => Device::Cpu,
        };

        info!("Loading reranker model: {} on {:?}", config.model, device);

        // Load model from HuggingFace Hub
        let api = Api::new().map_err(|e| {
            NeomemxError::RerankerError(format!("Failed to create HuggingFace API: {}", e))
        })?;

        let repo = api.repo(Repo::new(config.model.clone(), RepoType::Model));

        // Load tokenizer
        let tokenizer_path = repo
            .get("tokenizer.json")
            .map_err(|e| NeomemxError::RerankerError(format!("Failed to get tokenizer: {}", e)))?;

        let tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| NeomemxError::RerankerError(format!("Failed to load tokenizer: {}", e)))?;

        // Load model config
        let config_path = repo.get("config.json").map_err(|e| {
            NeomemxError::RerankerError(format!("Failed to get model config: {}", e))
        })?;

        let config_str = std::fs::read_to_string(config_path)
            .map_err(|e| NeomemxError::RerankerError(format!("Failed to read config: {}", e)))?;

        let bert_config: BertConfig = serde_json::from_str(&config_str)
            .map_err(|e| NeomemxError::RerankerError(format!("Failed to parse config: {}", e)))?;

        // Load model weights
        let weights_path = repo
            .get("model.safetensors")
            .or_else(|_| repo.get("pytorch_model.bin"))
            .map_err(|e| {
                NeomemxError::RerankerError(format!("Failed to get model weights: {}", e))
            })?;

        let vb = if weights_path
            .extension()
            .is_some_and(|ext| ext == "safetensors")
        {
            unsafe {
                VarBuilder::from_mmaped_safetensors(
                    &[weights_path],
                    candle_core::DType::F32,
                    &device,
                )
                .map_err(|e| {
                    NeomemxError::RerankerError(format!("Failed to load safetensors: {}", e))
                })?
            }
        } else {
            VarBuilder::from_pth(weights_path, candle_core::DType::F32, &device).map_err(|e| {
                NeomemxError::RerankerError(format!("Failed to load pytorch weights: {}", e))
            })?
        };

        let model = BertModel::load(vb, &bert_config).map_err(|e| {
            NeomemxError::RerankerError(format!("Failed to load BERT model: {}", e))
        })?;

        info!("Reranker model loaded successfully");

        Ok(Self {
            config,
            model,
            tokenizer,
            device,
        })
    }

    /// Compute relevance score for a query-document pair
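    ///
    /// The query and document are tokenized together as one sequence
    /// (for BERT-style tokenizers, `[CLS] query [SEP] document [SEP]`),
    /// truncated to `max_length`, and scored in a single forward pass.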
    fn compute_score(&self, query: &str, document: &str) -> Result<f32> {
        let encoding = self
            .tokenizer
            .encode((query, document), true)
            .map_err(|e| NeomemxError::RerankerError(format!("Tokenization failed: {}", e)))?;

        let input_ids = encoding.get_ids();
        let attention_mask = encoding.get_attention_mask();
        let token_type_ids = encoding.get_type_ids();

        // Truncate to max length
        let max_len = self.config.max_length.min(input_ids.len());
        let input_ids: Vec<u32> = input_ids[..max_len].to_vec();
        let attention_mask: Vec<u32> = attention_mask[..max_len].to_vec();
        let token_type_ids: Vec<u32> = token_type_ids[..max_len].to_vec();

        // Convert to tensors
        let input_ids = Tensor::new(&input_ids[..], &self.device)
            .map_err(|e| NeomemxError::RerankerError(format!("Tensor creation failed: {}", e)))?
            .unsqueeze(0)
            .map_err(|e| NeomemxError::RerankerError(format!("Unsqueeze failed: {}", e)))?;

        let attention_mask = Tensor::new(&attention_mask[..], &self.device)
            .map_err(|e| NeomemxError::RerankerError(format!("Tensor creation failed: {}", e)))?
            .unsqueeze(0)
            .map_err(|e| NeomemxError::RerankerError(format!("Unsqueeze failed: {}", e)))?;

        let token_type_ids = Tensor::new(&token_type_ids[..], &self.device)
            .map_err(|e| NeomemxError::RerankerError(format!("Tensor creation failed: {}", e)))?
            .unsqueeze(0)
            .map_err(|e| NeomemxError::RerankerError(format!("Unsqueeze failed: {}", e)))?;

        // Forward pass
        let output = self
            .model
            .forward(&input_ids, &token_type_ids, Some(&attention_mask))
            .map_err(|e| NeomemxError::RerankerError(format!("Model forward failed: {}", e)))?;

        // The forward pass returns the sequence of hidden states with shape
        // (batch, seq_len, hidden); take the CLS token (first position) of the
        // single batch item as the cross-encoder representation of the pair.
        let cls_output = output
            .i((0, 0))
            .map_err(|e| NeomemxError::RerankerError(format!("Failed to get CLS output: {}", e)))?;

        // Use the first component of the CLS hidden state as the raw relevance score
        let score: f32 = cls_output
            .to_vec1::<f32>()
            .map_err(|e| NeomemxError::RerankerError(format!("Failed to convert to vec: {}", e)))?
            .first()
            .copied()
            .unwrap_or(0.0);

        Ok(score)
    }

    /// Normalize scores to [0, 1] range
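    ///
    /// Uses min-max scaling, `(x - min) / (max - min + 1e-8)`; the epsilon avoids
    /// division by zero when all scores are equal. For example, raw scores
    /// `[2.0, 5.0, 8.0]` map to approximately `[0.0, 0.5, 1.0]`.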
    fn normalize_scores(scores: &mut [f32]) {
        if scores.is_empty() {
            return;
        }

        let min = scores.iter().cloned().fold(f32::INFINITY, f32::min);
        let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let range = max - min + 1e-8;

        for score in scores.iter_mut() {
            *score = (*score - min) / range;
        }
    }
}

#[async_trait]
impl BaseReranker for HuggingFaceReranker {
    async fn rerank(
        &self,
        query: &str,
        documents: Vec<StoredFact>,
        top_k: Option<usize>,
    ) -> Result<Vec<StoredFact>> {
        if documents.is_empty() {
            return Ok(documents);
        }

        debug!("Reranking {} documents", documents.len());

        // Compute scores for all documents
        let mut scores: Vec<f32> = Vec::with_capacity(documents.len());

        for doc in &documents {
            match self.compute_score(query, &doc.content) {
                Ok(score) => scores.push(score),
                Err(e) => {
                    warn!("Failed to compute score for document: {}", e);
                    scores.push(0.0);
                }
            }
        }

        // Normalize scores if configured
        if self.config.normalize {
            Self::normalize_scores(&mut scores);
        }

        // Combine documents with scores and sort
        let mut doc_scores: Vec<(StoredFact, f32)> =
            documents.into_iter().zip(scores).collect();

        doc_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Apply top_k limit
        let top_k = top_k.or(self.config.top_k);
        if let Some(k) = top_k {
            doc_scores.truncate(k);
        }

        // Update relevance scores in stored facts
        let results: Vec<StoredFact> = doc_scores
            .into_iter()
            .map(|(mut fact, score)| {
                fact.relevance_score = Some(score);
                fact
            })
            .collect();

        Ok(results)
    }
}