dakera-inference 0.11.54

Embedded inference engine for Dakera - generates embeddings locally via ONNX Runtime
//! Cross-encoder reranker for improving retrieval precision.
//!
//! Uses BAAI/bge-reranker-base (Xenova ONNX INT8 quantized) to score
//! (query, passage) pairs for relevance. More accurate than bi-encoder
//! vector similarity but slower — used as a second-stage reranker after
//! ANN candidate retrieval.
//!
//! # Architecture
//!
//! ```text
//! query + passage → [CLS] query [SEP] passage [SEP]
//!                       ↓ BERT forward pass
//!                   logits [batch, 1]
//!                       ↓ sigmoid
//!                   relevance scores ∈ [0, 1]
//! ```
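//!
//! # Example
//!
//! A minimal usage sketch (the import path below is assumed, and the first
//! call downloads the model, so it needs a Tokio runtime and network access
//! on first run):
//!
//! ```ignore
//! use dakera_inference::CrossEncoderEngine;
//!
//! // Load (or download on first use) the model; `None` selects the default cache dir.
//! let engine = CrossEncoderEngine::new(None).await?;
//!
//! let passages = vec![
//!     "Rust is a systems programming language.".to_string(),
//!     "Bananas are rich in potassium.".to_string(),
//! ];
//! let scores = engine.score_pairs("What language is Rust?", &passages).await?;
//!
//! // One sigmoid relevance score per passage, each in [0, 1].
//! assert_eq!(scores.len(), passages.len());
//! ```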

use crate::engine::EmbeddingEngine;
use crate::error::{InferenceError, Result};
use ort::inputs;
use ort::session::builder::GraphOptimizationLevel;
use ort::session::Session;
use ort::value::Tensor;
use parking_lot::Mutex;
use std::path::PathBuf;
use std::sync::Arc;
use tokenizers::{
    EncodeInput, InputSequence, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams,
};
use tracing::{info, instrument, warn};

/// HuggingFace repo ID for the reranker model (Xenova ONNX INT8 export).
const RERANKER_REPO_ID: &str = "Xenova/bge-reranker-base";
/// ONNX quantized model filename within the repo.
const RERANKER_ONNX_FILE: &str = "onnx/model_quantized.onnx";
/// Maximum token length for cross-encoder input (query + passage combined).
const MAX_SEQ_LENGTH: usize = 512;

/// Cross-encoder reranking engine.
///
/// Thread-safe — can be wrapped in `Arc` and shared across tasks.
pub struct CrossEncoderEngine {
    session: Arc<Mutex<Session>>,
    tokenizer: Arc<Tokenizer>,
    /// Whether the loaded ONNX model expects a `token_type_ids` input tensor.
    /// bge-reranker-base only has `input_ids` + `attention_mask`; some other
    /// cross-encoders include `token_type_ids`. Determined at load time.
    has_token_type_ids: bool,
}

impl CrossEncoderEngine {
    /// Load or download the reranker model.
    ///
    /// Downloads the `Xenova/bge-reranker-base` ONNX INT8 model from the HuggingFace Hub
    /// if not already cached.
    #[instrument(skip_all)]
    pub async fn new(cache_dir: Option<String>) -> Result<Self> {
        info!("Initializing cross-encoder reranker: {}", RERANKER_REPO_ID);

        let (tokenizer_path, onnx_path) =
            tokio::task::spawn_blocking(move || download_reranker_files(cache_dir))
                .await
                .map_err(|e| InferenceError::ModelLoadError(format!("spawn_blocking: {e}")))?
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;

        info!("Loading reranker tokenizer from {:?}", tokenizer_path);
        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;

        // Configure padding + truncation for uniform batch shapes
        let padding = PaddingParams {
            strategy: PaddingStrategy::BatchLongest,
            pad_id: tokenizer.get_padding().map_or(0, |p| p.pad_id),
            pad_token: tokenizer
                .get_padding()
                .map_or("[PAD]".to_string(), |p| p.pad_token.clone()),
            ..Default::default()
        };
        tokenizer.with_padding(Some(padding));
        let truncation = TruncationParams {
            max_length: MAX_SEQ_LENGTH,
            ..Default::default()
        };
        tokenizer
            .with_truncation(Some(truncation))
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;

        info!("Loading reranker ONNX model from {:?}", onnx_path);
        let session = Session::builder()
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
            .with_intra_threads(4)
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
            .commit_from_file(&onnx_path)
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;

        // Inspect model inputs to determine if token_type_ids is required.
        // bge-reranker-base (Xenova ONNX) only has input_ids + attention_mask.
        let has_token_type_ids = session
            .inputs()
            .iter()
            .any(|i| i.name() == "token_type_ids");
        info!(
            has_token_type_ids,
            "Cross-encoder reranker loaded successfully"
        );
        Ok(Self {
            session: Arc::new(Mutex::new(session)),
            tokenizer: Arc::new(tokenizer),
            has_token_type_ids,
        })
    }

    /// Score a batch of (query, passage) pairs.
    ///
    /// Returns a relevance score in `[0, 1]` for each passage.
    /// Higher scores indicate greater relevance to the query.
    ///
    /// Each pair is tokenized as `[CLS] query [SEP] passage [SEP]`.
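    ///
    /// A sketch of typical second-stage use (the `engine`, `query`, and
    /// `candidates` bindings are hypothetical):
    ///
    /// ```ignore
    /// let scores = engine.score_pairs(query, &candidates).await?;
    /// // Pair each candidate index with its score and sort best-first.
    /// let mut ranked: Vec<(usize, f32)> = scores.into_iter().enumerate().collect();
    /// ranked.sort_by(|a, b| b.1.total_cmp(&a.1));
    /// let top_k: Vec<&String> = ranked.iter().take(5).map(|&(i, _)| &candidates[i]).collect();
    /// ```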
    #[instrument(skip(self, passages), fields(n_passages = passages.len()))]
    pub async fn score_pairs(&self, query: &str, passages: &[String]) -> Result<Vec<f32>> {
        if passages.is_empty() {
            return Ok(Vec::new());
        }

        let query = query.to_string();
        let passages = passages.to_vec();
        let tokenizer = Arc::clone(&self.tokenizer);
        let session = Arc::clone(&self.session);
        let has_token_type_ids = self.has_token_type_ids;

        tokio::task::spawn_blocking(move || {
            score_pairs_blocking(&session, &tokenizer, &query, &passages, has_token_type_ids)
        })
        .await
        .map_err(|e| InferenceError::InferenceError(format!("spawn_blocking: {e}")))?
    }
}

/// Blocking cross-encoder inference — runs inside `spawn_blocking`.
fn score_pairs_blocking(
    session: &Arc<Mutex<Session>>,
    tokenizer: &Tokenizer,
    query: &str,
    passages: &[String],
    has_token_type_ids: bool,
) -> Result<Vec<f32>> {
    let batch_size = passages.len();

    // Build EncodeInput pairs: [CLS] query [SEP] passage [SEP]
    let inputs: Vec<EncodeInput> = passages
        .iter()
        .map(|p| EncodeInput::Dual(InputSequence::from(query), InputSequence::from(p.as_str())))
        .collect();

    let encodings = tokenizer
        .encode_batch(inputs, true)
        .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;

    let seq_len = encodings.first().map(|e| e.get_ids().len()).unwrap_or(0);
    if seq_len == 0 {
        // Degenerate case: the tokenizer produced empty encodings, so return a
        // neutral 0.5 score for every passage instead of failing the batch.
        return Ok(vec![0.5; batch_size]);
    }

    // Flatten to i64 arrays (ORT BERT models expect int64)
    let mut input_ids = Vec::with_capacity(batch_size * seq_len);
    let mut attention_mask = Vec::with_capacity(batch_size * seq_len);
    let mut token_type_ids = Vec::with_capacity(batch_size * seq_len);

    for enc in &encodings {
        input_ids.extend(enc.get_ids().iter().map(|&id| id as i64));
        attention_mask.extend(enc.get_attention_mask().iter().map(|&m| m as i64));
        let type_ids = enc.get_type_ids();
        if type_ids.is_empty() {
            token_type_ids.extend(std::iter::repeat_n(0i64, seq_len));
        } else {
            token_type_ids.extend(type_ids.iter().map(|&t| t as i64));
        }
    }

    // Build ORT tensors
    let input_ids_tensor = Tensor::<i64>::from_array(([batch_size, seq_len], input_ids))
        .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
    let attention_mask_tensor = Tensor::<i64>::from_array(([batch_size, seq_len], attention_mask))
        .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
    let token_type_ids_tensor = Tensor::<i64>::from_array(([batch_size, seq_len], token_type_ids))
        .map_err(|e| InferenceError::InferenceError(e.to_string()))?;

    // Run inference and extract scores in one scoped block so `sess` and `outputs`
    // are dropped before we return (avoids session borrow escaping the mutex guard).
    let scores: Vec<f32> = {
        let mut sess = session.lock();
        let outputs = if has_token_type_ids {
            sess.run(inputs![
                "input_ids" => input_ids_tensor,
                "attention_mask" => attention_mask_tensor,
                "token_type_ids" => token_type_ids_tensor
            ])
            .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?
        } else {
            sess.run(inputs![
                "input_ids" => input_ids_tensor,
                "attention_mask" => attention_mask_tensor
            ])
            .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?
        };

        // Extract logits — bge-reranker-base output shape is [batch_size, 1]
        let (out_shape, logits_slice) = outputs[0]
            .try_extract_tensor::<f32>()
            .map_err(|e| InferenceError::InferenceError(e.to_string()))?;

        if out_shape.is_empty() || out_shape[0] as usize != batch_size {
            warn!(
                "Reranker output shape mismatch: expected [{}, 1], got {:?}",
                batch_size, out_shape
            );
        }

        // Apply sigmoid → owned Vec<f32> so the borrow on outputs/sess ends here
        logits_slice.iter().map(|&logit| sigmoid(logit)).collect()
        // outputs and sess drop here in the correct order
    };

    if scores.len() != batch_size {
        warn!(
            "Reranker score count mismatch: expected {}, got {}",
            batch_size,
            scores.len()
        );
        let mut padded = scores;
        padded.resize(batch_size, 0.5);
        return Ok(padded);
    }

    Ok(scores)
}

/// Sigmoid activation: 1 / (1 + exp(-x))
#[inline]
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// Download tokenizer and ONNX model files for the reranker.
/// Reuses `EmbeddingEngine::download_hf_file_pub` for redirect-aware caching.
fn download_reranker_files(
    cache_dir: Option<String>,
) -> std::result::Result<(PathBuf, PathBuf), InferenceError> {
    let cache = match cache_dir {
        Some(dir) => {
            let p = PathBuf::from(dir);
            std::fs::create_dir_all(&p)
                .map_err(|e| InferenceError::ModelLoadError(format!("cache_dir create: {e}")))?;
            p
        }
        None => {
            let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
            PathBuf::from(home)
                .join(".cache")
                .join("huggingface")
                .join("dakera")
                .join(RERANKER_REPO_ID.replace('/', "--"))
        }
    };

    std::fs::create_dir_all(&cache)
        .map_err(|e| InferenceError::ModelLoadError(format!("create cache dir: {e}")))?;

    let files = [
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        RERANKER_ONNX_FILE,
    ];

    for filename in &files {
        EmbeddingEngine::download_hf_file_pub(RERANKER_REPO_ID, filename, &cache)
            .map_err(|e| InferenceError::HubError(format!("download {filename}: {e}")))?;
    }

    let tokenizer_path = cache.join("tokenizer.json");
    let onnx_path = cache.join(RERANKER_ONNX_FILE);
    Ok((tokenizer_path, onnx_path))
}

impl std::fmt::Debug for CrossEncoderEngine {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CrossEncoderEngine")
            .field("model", &RERANKER_REPO_ID)
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sigmoid() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(sigmoid(10.0) > 0.99);
        assert!(sigmoid(-10.0) < 0.01);
    }
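
    // Sanity check on a known identity: the logistic function satisfies
    // sigmoid(x) + sigmoid(-x) = 1, so a logit and its negation always
    // produce scores that sum to one (up to float error).
    #[test]
    fn test_sigmoid_symmetry() {
        for &x in &[-5.0f32, -1.0, 0.0, 0.5, 3.0] {
            assert!((sigmoid(x) + sigmoid(-x) - 1.0).abs() < 1e-6);
        }
    }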
}