// dakera-inference 0.11.80
//
// Embedded inference engine for Dakera - generates embeddings locally via ONNX Runtime
// Documentation
//! Model2Vec static embedding backend.
//!
//! Embeds text via a pre-distilled vocabulary matrix — no neural network
//! forward pass at all.  Each token maps directly to a pre-computed vector;
//! the final embedding is the mean pool of the token vectors followed by L2
//! normalisation.
//!
//! **Performance**: <0.1 ms per text on CPU, >50,000 embeddings/second.
//! **Memory**: ~30 MB vocab matrix vs. ~520 MB ONNX session pool.
//!
//! # Trade-off
//!
//! Quality is ~8–15% lower on MTEB relative to the full BGE-Large transformer.
//! This backend is intended for the *write path* (fast ingest) only.  The
//! [`TieredEngine`](crate::tiered::TieredEngine) uses the full transformer for
//! the *recall* query path, preserving recall quality.
//!
//! # Model artifact
//!
//! The distilled vocabulary matrix is downloaded at runtime from
//! `dakera-ai/bge-large-model2vec-256d` on HuggingFace Hub.  File:
//! `vocab_matrix.bin` (flat `f32` array, `vocab_size × dimension`).
//! File integrity is validated (length must be divisible by 4 bytes).

use crate::backend::{BackendKind, EmbeddingBackend};
use crate::error::{InferenceError, Result};
use crate::models::ModelConfig;
use async_trait::async_trait;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tracing::{debug, info, instrument};

/// Model2Vec static embedding backend.
///
/// Holds the whole vocabulary matrix in memory as a flat `Vec<f32>` behind an
/// `Arc`.  The tokenizer is loaded from the source model's `tokenizer.json`
/// (see [`StaticBackend::new`]), so the token IDs it produces are used directly
/// as row indices into the matrix.
// NOTE(review): earlier docs said the matrix is "optionally memory-mapped";
// nothing in this file memory-maps it — confirm whether that is planned.
pub struct StaticBackend {
    /// Flat row-major f32 array: `[vocab_size × dimension]`
    vocab_matrix: Arc<Vec<f32>>,
    /// Tokenizer whose token IDs index rows of `vocab_matrix`.
    tokenizer: Arc<Tokenizer>,
    /// Length of every embedding returned by this backend (one matrix row).
    dimension: usize,
    /// Number of rows in `vocab_matrix`, i.e. `vocab_matrix.len() / dimension`.
    /// Token IDs at or above this value are skipped during pooling.
    vocab_size: usize,
}

impl StaticBackend {
    /// Build a new `StaticBackend`.
    ///
    /// Downloads `vocab_matrix.bin` and the model tokenizer from HuggingFace
    /// on first run; subsequent calls use the local cache.
    ///
    /// The output dimension comes from [`Self::model2vec_dimension`].  The
    /// loaded matrix is handed to [`Self::from_matrix`], so a matrix whose
    /// length is not a whole number of `dimension`-sized rows is rejected
    /// instead of being silently truncated.
    ///
    /// # Errors
    ///
    /// * `InferenceError::HubError` — the tokenizer download failed or the
    ///   blocking download task did not complete.
    /// * `InferenceError::TokenizationError` — `tokenizer.json` could not be
    ///   loaded.
    /// * `InferenceError::InvalidInput` — the matrix length is not divisible by
    ///   the configured dimension.
    /// * Any error from the cache-directory lookup or from loading the matrix.
    #[instrument(skip_all)]
    pub async fn new(config: &ModelConfig) -> Result<Self> {
        info!("Initialising StaticBackend (Model2Vec)");

        let dim = Self::model2vec_dimension();

        // The tokenizer lives in the *source* model repo, not the Model2Vec repo.
        let model_id = config.model.model_id();
        let cache_dir = crate::backend::onnx::OnnxBackend::model_cache_dir(model_id)?;
        let tokenizer_path = cache_dir.join("tokenizer.json");

        // Download tokenizer if not cached
        if !tokenizer_path.exists() {
            let model_id_owned = model_id.to_string();
            let cache_dir_clone = cache_dir.clone();
            // The download helper is synchronous, so keep it off the async executor.
            tokio::task::spawn_blocking(move || {
                crate::backend::onnx::OnnxBackend::download_hf_file(
                    &model_id_owned,
                    "tokenizer.json",
                    &cache_dir_clone,
                )
                .map_err(InferenceError::HubError)
            })
            .await
            .map_err(|e| InferenceError::HubError(format!("Download panicked: {e}")))??;
        }

        let tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;

        // Load (downloading if needed) the vocabulary matrix, then reuse the
        // validating constructor rather than duplicating the field wiring.
        let vocab_matrix = Self::load_vocab_matrix(config, dim).await?;
        let backend = Self::from_matrix(vocab_matrix, tokenizer, dim)?;

        info!(
            "StaticBackend ready: vocab_size={}, dimension={}",
            backend.vocab_size, backend.dimension
        );

        Ok(backend)
    }

    /// Build from a pre-loaded vocabulary matrix (useful for tests).
    ///
    /// `matrix` must be a flat row-major array of shape `[vocab_size × dimension]`.
    /// The vocabulary size is derived as `matrix.len() / dimension`.
    ///
    /// # Errors
    ///
    /// Returns `InferenceError::InvalidInput` when `dimension` is zero or when
    /// `matrix.len()` is not a multiple of `dimension`.
    pub fn from_matrix(matrix: Vec<f32>, tokenizer: Tokenizer, dimension: usize) -> Result<Self> {
        // Guard first: an empty matrix is a "multiple" of 0, which would
        // otherwise fall through to a divide-by-zero panic below.
        if dimension == 0 {
            return Err(InferenceError::InvalidInput(
                "dimension must be greater than zero".to_string(),
            ));
        }
        if !matrix.len().is_multiple_of(dimension) {
            return Err(InferenceError::InvalidInput(format!(
                "vocab_matrix length {} is not divisible by dimension {}",
                matrix.len(),
                dimension
            )));
        }
        let vocab_size = matrix.len() / dimension;
        Ok(Self {
            vocab_matrix: Arc::new(matrix),
            tokenizer: Arc::new(tokenizer),
            dimension,
            vocab_size,
        })
    }

    /// Output dimension used by the Model2Vec backend.
    ///
    /// Reads the `DAKERA_MRL_DIM` environment variable on every call.  The
    /// value is used only if it parses as a `usize` greater than zero; an
    /// unset, unparsable, or zero value yields the default of 256.
    pub fn model2vec_dimension() -> usize {
        const FALLBACK_DIM: usize = 256;

        match std::env::var("DAKERA_MRL_DIM") {
            Ok(raw) => match raw.parse::<usize>() {
                Ok(parsed) if parsed > 0 => parsed,
                _ => FALLBACK_DIM,
            },
            Err(_) => FALLBACK_DIM,
        }
    }

    /// Embed a single text via token-lookup + mean pooling.
    #[instrument(skip(self, text), fields(text_len = text.len()))]
    fn embed_single(&self, text: &str) -> Vec<f32> {
        // Tokenize — encode returns token IDs
        let encoding = match self.tokenizer.encode(text, false) {
            Ok(enc) => enc,
            Err(_) => return vec![0.0; self.dimension],
        };

        let ids = encoding.get_ids();
        if ids.is_empty() {
            return vec![0.0; self.dimension];
        }

        // Mean pool token vectors
        let mut result = vec![0.0f32; self.dimension];
        let mut valid_tokens = 0usize;

        for &id in ids {
            let idx = id as usize;
            if idx >= self.vocab_size {
                // OOV: skip (treat as zero vector contribution)
                continue;
            }
            let offset = idx * self.dimension;
            let row = &self.vocab_matrix[offset..offset + self.dimension];
            for (r, v) in result.iter_mut().zip(row.iter()) {
                *r += v;
            }
            valid_tokens += 1;
        }

        if valid_tokens == 0 {
            return vec![0.0; self.dimension];
        }

        let n = valid_tokens as f32;
        for v in result.iter_mut() {
            *v /= n;
        }

        // L2 normalise
        let norm: f32 = result.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
        for v in result.iter_mut() {
            *v /= norm;
        }

        result
    }

    /// Load (or download) the Model2Vec vocabulary matrix.
    ///
    /// Looks for `vocab_matrix.bin` in the cache directory of the Model2Vec
    /// repo named by `config`; if the file is missing it is downloaded first.
    /// The file is then read as a flat array of little-endian `f32` values.
    ///
    /// `dim` is the expected row width.  The number of values must be a
    /// non-zero multiple of `dim`.  This cannot detect a `dim` that merely
    /// divides the real row width; it only rejects sizes that cannot form
    /// whole rows.
    ///
    /// # Errors
    ///
    /// * `InferenceError::HubError` — the download failed or the blocking
    ///   download task did not complete.
    /// * `InferenceError::ModelLoadError` — the file size is not a multiple of
    ///   4 bytes, the file is empty, or the value count is not a multiple of
    ///   `dim`.
    /// * Any error from the cache-directory lookup or from reading the file.
    async fn load_vocab_matrix(config: &ModelConfig, dim: usize) -> Result<Vec<f32>> {
        // Determine cache path
        let model2vec_repo = config.model.model2vec_repo_id();
        let cache_dir = crate::backend::onnx::OnnxBackend::model_cache_dir(model2vec_repo)?;
        let matrix_path = cache_dir.join("vocab_matrix.bin");

        if !matrix_path.exists() {
            info!("Downloading Model2Vec vocab matrix from {}", model2vec_repo);
            let repo = model2vec_repo.to_string();
            let cache = cache_dir.clone();
            // The download helper is synchronous, so keep it off the async executor.
            tokio::task::spawn_blocking(move || {
                crate::backend::onnx::OnnxBackend::download_hf_file(
                    &repo,
                    "vocab_matrix.bin",
                    &cache,
                )
                .map_err(InferenceError::HubError)
            })
            .await
            .map_err(|e| InferenceError::HubError(format!("Download panicked: {e}")))??;
        }

        // Read as raw f32 bytes
        info!("Loading vocab matrix from {:?}", matrix_path);
        let bytes = std::fs::read(&matrix_path)?;
        if bytes.len() % 4 != 0 {
            return Err(InferenceError::ModelLoadError(format!(
                "vocab_matrix.bin size {} is not a multiple of 4 bytes",
                bytes.len()
            )));
        }

        // Validate the shape before converting, so a bad file fails fast.
        let value_count = bytes.len() / 4;
        if value_count == 0 {
            return Err(InferenceError::ModelLoadError(
                "vocab_matrix.bin is empty".to_string(),
            ));
        }
        // `dim == 0` is checked first so the modulo below can never divide by zero.
        if dim == 0 || value_count % dim != 0 {
            return Err(InferenceError::ModelLoadError(format!(
                "vocab_matrix.bin holds {} f32 values, which is not a multiple of dimension {}",
                value_count, dim
            )));
        }

        let floats: Vec<f32> = bytes
            .chunks_exact(4)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect();

        debug!("Vocab matrix loaded: {} f32 values", floats.len());
        Ok(floats)
    }
}

#[async_trait]
impl EmbeddingBackend for StaticBackend {
    /// Embed every text in `texts`, returning one vector per input in the
    /// same order.  An empty slice yields an empty result.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        // Lookup + pooling is plain CPU work done inline on the calling task;
        // nothing here awaits or hands off to a blocking thread.
        let mut embeddings = Vec::with_capacity(texts.len());
        for text in texts {
            embeddings.push(self.embed_single(text));
        }
        Ok(embeddings)
    }

    /// Length of each embedding vector produced by this backend.
    fn dimension(&self) -> usize {
        self.dimension
    }

    /// Identifies this backend as the static (Model2Vec) implementation.
    fn backend_kind(&self) -> BackendKind {
        BackendKind::Static
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tokenizers::models::wordlevel::WordLevel;
    use tokenizers::pre_tokenizers::whitespace::Whitespace;

    /// Whitespace-split word-level tokenizer; each word's ID is its position
    /// in `words`.
    fn make_test_tokenizer(words: &[&str]) -> Tokenizer {
        let vocab: std::collections::HashMap<String, u32> = words
            .iter()
            .enumerate()
            .map(|(id, word)| (word.to_string(), id as u32))
            .collect();
        let word_level = WordLevel::builder()
            .vocab(vocab)
            .unk_token("[UNK]".to_string())
            .build()
            .unwrap();
        let mut tokenizer = Tokenizer::new(word_level);
        tokenizer.with_pre_tokenizer(Some(Whitespace {}));
        tokenizer
    }

    /// Flat `[vocab_size × dim]` matrix where row `i` has a single 1.0 at
    /// column `i % dim`.
    fn make_identity_matrix(vocab_size: usize, dim: usize) -> Vec<f32> {
        let mut flat = vec![0.0f32; vocab_size * dim];
        (0..vocab_size).for_each(|token| flat[token * dim + (token % dim)] = 1.0);
        flat
    }

    /// Build a backend over `words` and `matrix`, panicking if the shape is invalid.
    fn build_backend(words: &[&str], matrix: Vec<f32>, dim: usize) -> StaticBackend {
        StaticBackend::from_matrix(matrix, make_test_tokenizer(words), dim).unwrap()
    }

    #[test]
    fn test_static_backend_from_matrix_dimension() {
        let backend = build_backend(
            &["[UNK]", "hello", "world", "test", "foo"],
            make_identity_matrix(5, 4),
            4,
        );
        assert_eq!(backend.dimension(), 4);
    }

    #[test]
    fn test_static_backend_from_matrix_vocab_size() {
        let backend = build_backend(&["[UNK]", "a", "b", "c"], make_identity_matrix(4, 8), 8);
        assert_eq!(backend.vocab_size, 4);
    }

    #[test]
    fn test_static_backend_kind() {
        let backend = build_backend(&["[UNK]", "hello"], vec![0.0f32; 2 * 4], 4);
        assert_eq!(backend.backend_kind(), BackendKind::Static);
    }

    #[test]
    fn test_static_embed_empty_text_returns_zeros() {
        // Matrix is all ones, so any pooled token would give a non-zero output.
        let backend = build_backend(&["[UNK]", "hello"], vec![1.0f32; 2 * 4], 4);
        let embedding = backend.embed_single("");
        // Empty text → no tokens → zero vector
        assert_eq!(embedding.len(), 4);
        for component in &embedding {
            assert!(component.abs() < 1e-6);
        }
    }

    #[test]
    fn test_static_embed_single_token_normalized() {
        // "hello" is token 1; give its row the value [1, 0, 0, 0].
        let mut matrix = vec![0.0f32; 3 * 4];
        matrix[4] = 1.0;
        let backend = build_backend(&["[UNK]", "hello", "world"], matrix, 4);
        let embedding = backend.embed_single("hello");
        assert_eq!(embedding.len(), 4);
        // A unit vector is unchanged by L2 normalisation.
        assert!((embedding[0] - 1.0).abs() < 1e-5);
        assert!(embedding[1].abs() < 1e-5);
    }

    #[test]
    fn test_static_embed_invalid_matrix_dimension_error() {
        let tokenizer = make_test_tokenizer(&["[UNK]", "hello"]);
        // Five values cannot form whole rows of width 4.
        let outcome = StaticBackend::from_matrix(vec![1.0f32; 5], tokenizer, 4);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_model2vec_dimension_default() {
        // With the variable unset the default of 256 applies.
        std::env::remove_var("DAKERA_MRL_DIM");
        assert_eq!(StaticBackend::model2vec_dimension(), 256);
    }

    #[tokio::test]
    async fn test_static_embed_batch_empty() {
        let backend = build_backend(&["[UNK]", "hello"], vec![0.0f32; 2 * 4], 4);
        let embeddings = backend.embed_batch(&[]).await.unwrap();
        assert!(embeddings.is_empty());
    }

    #[tokio::test]
    async fn test_static_embed_batch_multiple() {
        let backend = build_backend(&["[UNK]", "hello", "world"], make_identity_matrix(3, 4), 4);
        let inputs = vec!["hello".to_string(), "world".to_string()];
        let embeddings = backend.embed_batch(&inputs).await.unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), 4);
        assert_eq!(embeddings[1].len(), 4);
    }

    #[tokio::test]
    async fn test_static_embed_batch_preserves_order() {
        // Row 1 ("hello") = [1, 0, 0, 0]; row 2 ("world") = [0, 1, 0, 0].
        let mut matrix = vec![0.0f32; 3 * 4];
        matrix[4] = 1.0;
        matrix[9] = 1.0;
        let backend = build_backend(&["[UNK]", "hello", "world"], matrix, 4);
        let inputs = vec!["hello".to_string(), "world".to_string()];
        let embeddings = backend.embed_batch(&inputs).await.unwrap();
        let (hello, world) = (&embeddings[0], &embeddings[1]);
        // First output belongs to "hello": dim 0 dominates.
        assert!(hello[0] > hello[1]);
        // Second output belongs to "world": dim 1 dominates.
        assert!(world[1] > world[0]);
    }
}