dakera-inference 0.11.74

Embedded inference engine for Dakera - generates embeddings locally via ONNX Runtime
Documentation
//! Model configurations for supported embedding models.
//!
//! Supported models:
//! - **BGE-large** (BAAI/bge-large-en-v1.5): Highest quality, 1024 dimensions (default)
//! - **MiniLM** (all-MiniLM-L6-v2): Fast, 384 dimensions, good for general use
//! - **BGE-small** (BAAI/bge-small-en-v1.5): Balanced, 384 dimensions, high quality
//! - **E5-small** (intfloat/e5-small-v2): Quality-focused, 384 dimensions

use serde::{Deserialize, Serialize};

/// Supported embedding models.
///
/// Serialized with serde in kebab-case: `bge-large`, `mini-l-m`, `bge-small`, `e5-small`.
/// `MiniLM` derives the awkward `mini-l-m` because kebab-case inserts a separator before
/// every uppercase letter, so the spellings people actually write (`minilm`, `mini-lm`, the
/// upstream model name) are accepted as deserialization-only aliases. The serialized form is
/// unchanged, so previously persisted values keep round-tripping.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum EmbeddingModel {
    /// BAAI/bge-large-en-v1.5 - Highest quality, 1024 dimensions (default)
    /// - Dimensions: 1024
    /// - Max tokens: 512
    /// - Speed: Slower than small models, but highest quality
    #[default]
    #[serde(alias = "bge-large-en-v1.5")]
    BgeLarge,

    /// all-MiniLM-L6-v2 - Fast and efficient, good for general use
    /// - Dimensions: 384
    /// - Max tokens: 256
    /// - Speed: Fastest
    #[serde(
        alias = "minilm",
        alias = "mini-lm",
        alias = "all-minilm-l6-v2",
        alias = "all-MiniLM-L6-v2"
    )]
    MiniLM,

    /// BAAI/bge-small-en-v1.5 - Balanced quality and speed
    /// - Dimensions: 384
    /// - Max tokens: 512
    /// - Speed: Medium
    #[serde(alias = "bge-small-en-v1.5")]
    BgeSmall,

    /// intfloat/e5-small-v2 - Higher quality embeddings
    /// - Dimensions: 384
    /// - Max tokens: 512
    /// - Speed: Medium
    #[serde(alias = "e5-small-v2")]
    E5Small,
}

impl EmbeddingModel {
    /// Get the HuggingFace model ID of the upstream (non-ONNX) repository.
    pub fn model_id(&self) -> &'static str {
        match self {
            EmbeddingModel::BgeLarge => "BAAI/bge-large-en-v1.5",
            EmbeddingModel::MiniLM => "sentence-transformers/all-MiniLM-L6-v2",
            EmbeddingModel::BgeSmall => "BAAI/bge-small-en-v1.5",
            EmbeddingModel::E5Small => "intfloat/e5-small-v2",
        }
    }

    /// Get the embedding dimension (length of each output vector) for this model.
    pub fn dimension(&self) -> usize {
        match self {
            EmbeddingModel::BgeLarge => 1024,
            EmbeddingModel::MiniLM => 384,
            EmbeddingModel::BgeSmall => 384,
            EmbeddingModel::E5Small => 384,
        }
    }

    /// Get the maximum sequence length (in tokens).
    pub fn max_seq_length(&self) -> usize {
        match self {
            EmbeddingModel::BgeLarge => 512,
            EmbeddingModel::MiniLM => 256,
            EmbeddingModel::BgeSmall => 512,
            EmbeddingModel::E5Small => 512,
        }
    }

    /// Get the query prefix for models that require it.
    ///
    /// Some models like E5 require a prefix for queries vs documents; returns `None` for
    /// models that take raw text.
    pub fn query_prefix(&self) -> Option<&'static str> {
        match self {
            EmbeddingModel::BgeLarge => None,
            EmbeddingModel::MiniLM => None,
            EmbeddingModel::BgeSmall => None,
            EmbeddingModel::E5Small => Some("query: "),
        }
    }

    /// Get the document/passage prefix for models that require it (`None` otherwise).
    pub fn document_prefix(&self) -> Option<&'static str> {
        match self {
            EmbeddingModel::BgeLarge => None,
            EmbeddingModel::MiniLM => None,
            EmbeddingModel::BgeSmall => None,
            EmbeddingModel::E5Small => Some("passage: "),
        }
    }

    /// Whether this model uses mean pooling (vs CLS token).
    ///
    /// NOTE(review): the upstream BAAI model cards for bge-*-en-v1.5 specify CLS-token
    /// pooling, so `true` for the BGE variants deviates from the reference setup. Flipping it
    /// would change every vector produced and make existing indexes incomparable, so confirm
    /// with the callers and plan a re-index before changing this.
    pub fn use_mean_pooling(&self) -> bool {
        match self {
            EmbeddingModel::BgeLarge => true,
            EmbeddingModel::MiniLM => true,
            EmbeddingModel::BgeSmall => true,
            EmbeddingModel::E5Small => true,
        }
    }

    /// Whether embeddings should be normalized.
    pub fn normalize_embeddings(&self) -> bool {
        true // All supported models use normalized embeddings
    }

    /// Get approximate tokens per second on CPU (hard-coded figures, for estimation only).
    pub fn tokens_per_second_cpu(&self) -> usize {
        match self {
            EmbeddingModel::BgeLarge => 1000,
            EmbeddingModel::MiniLM => 5000,
            EmbeddingModel::BgeSmall => 3000,
            EmbeddingModel::E5Small => 3000,
        }
    }

    /// Get the HuggingFace repository ID hosting the ONNX INT8 model for this embedding model.
    ///
    /// These are Xenova-hosted Optimum ONNX exports — quantized INT8, pre-built, no conversion
    /// needed. Approximate INT8 file sizes: BGE-large ~335 MB (≈335M parameters at one byte
    /// each), MiniLM 23 MB, BGE-small 35 MB, E5-small 35 MB.
    pub fn onnx_repo_id(&self) -> &'static str {
        match self {
            EmbeddingModel::BgeLarge => "Xenova/bge-large-en-v1.5",
            EmbeddingModel::MiniLM => "Xenova/all-MiniLM-L6-v2",
            EmbeddingModel::BgeSmall => "Xenova/bge-small-en-v1.5",
            EmbeddingModel::E5Small => "Xenova/e5-small-v2",
        }
    }

    /// Get the ONNX model filename for CPU inference (INT8 quantized).
    pub fn onnx_filename(&self) -> &'static str {
        "onnx/model_quantized.onnx"
    }

    /// Get the ONNX model filename for GPU (CUDA EP) inference.
    ///
    /// Returns the FP32 model (`onnx/model.onnx`) instead of INT8. The INT8 quantized
    /// model has 336 Memcpy CPU↔GPU round-trips caused by ORT falling back to CPU EP
    /// for every unsupported INT8 op — making CUDA 24× slower than pure CPU inference.
    /// The FP32 model contains no unsupported ops and runs entirely on-device.
    pub fn onnx_filename_gpu(&self) -> &'static str {
        "onnx/model.onnx"
    }

    /// List all available models.
    pub fn all() -> &'static [EmbeddingModel] {
        &[
            EmbeddingModel::BgeLarge,
            EmbeddingModel::MiniLM,
            EmbeddingModel::BgeSmall,
            EmbeddingModel::E5Small,
        ]
    }

    /// Parse a model from a string (case-insensitive; surrounding whitespace is ignored).
    ///
    /// Accepts the short aliases listed below, the serde name `mini-l-m`, every model's
    /// `Display` name (so `parse(&m.to_string()) == Some(m)` for every model), and the full
    /// HuggingFace id returned by [`model_id`](Self::model_id). Returns `None` otherwise.
    pub fn parse(s: &str) -> Option<Self> {
        let key = s.trim().to_lowercase();
        let by_alias = match key.as_str() {
            "bge-large" | "bge-large-en" | "bge-large-en-v1.5" => Some(EmbeddingModel::BgeLarge),
            "minilm" | "all-minilm-l6-v2" | "mini-lm" | "mini-l-m" => {
                Some(EmbeddingModel::MiniLM)
            }
            "bge-small" | "bge" | "bge-small-en" | "bge-small-en-v1.5" => {
                Some(EmbeddingModel::BgeSmall)
            }
            "e5-small" | "e5" | "e5-small-v2" => Some(EmbeddingModel::E5Small),
            _ => None,
        };
        // Fall back to the canonical names so new variants round-trip without extra aliases.
        by_alias.or_else(|| {
            Self::all().iter().copied().find(|model| {
                model.model_id().to_lowercase() == key || model.to_string().to_lowercase() == key
            })
        })
    }
}

impl std::fmt::Display for EmbeddingModel {
    /// Writes the short model name: the upstream repository name without its
    /// organisation prefix (e.g. `all-MiniLM-L6-v2`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match *self {
            EmbeddingModel::BgeLarge => "bge-large-en-v1.5",
            EmbeddingModel::MiniLM => "all-MiniLM-L6-v2",
            EmbeddingModel::BgeSmall => "bge-small-en-v1.5",
            EmbeddingModel::E5Small => "e5-small-v2",
        };
        f.write_str(name)
    }
}

/// Configuration for model loading and inference.
///
/// Build one with `ModelConfig::new(model)` or `ModelConfig::default()` and adjust it with the
/// `with_*` builder methods. The `Default` impl reads the `DAKERA_ONNX_POOL_SIZE` and
/// `DAKERA_ONNX_BATCH_SIZE` environment variables, so defaults depend on the process
/// environment at construction time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// The embedding model to use.
    pub model: EmbeddingModel,

    /// Custom cache directory for model files.
    /// If None, uses HuggingFace default cache.
    pub cache_dir: Option<String>,

    /// Maximum batch size for inference.
    ///
    /// Defaults to 32; override with `DAKERA_ONNX_BATCH_SIZE` env var at startup.
    pub max_batch_size: usize,

    /// Whether to use GPU acceleration if available.
    pub use_gpu: bool,

    /// Number of threads for CPU inference.
    ///
    /// Defaults to `None`; NOTE(review): presumably `None` defers to the ONNX Runtime's own
    /// thread-count default — confirm against the session builder.
    pub num_threads: Option<usize>,

    /// Number of parallel ONNX sessions in the session pool.
    ///
    /// Each session holds its own ORT context. Pool members serve batches
    /// concurrently via `spawn_blocking`, eliminating Mutex head-of-line
    /// blocking when multiple callers embed text simultaneously.
    /// Defaults to 4; override with `DAKERA_ONNX_POOL_SIZE` env var at startup.
    pub session_pool_size: usize,
}

impl Default for ModelConfig {
    /// Build the default configuration.
    ///
    /// `session_pool_size` defaults to 4 and `max_batch_size` to 32; both can be overridden
    /// at startup via `DAKERA_ONNX_POOL_SIZE` and `DAKERA_ONNX_BATCH_SIZE`. An override that
    /// is unset, not an unsigned integer (after trimming surrounding whitespace), or zero
    /// falls back to the built-in default.
    fn default() -> Self {
        // Read `name` as a positive `usize`, or return `fallback` if it is unset, malformed,
        // or zero. Trimming tolerates stray whitespace from shell/env files.
        fn positive_env(name: &str, fallback: usize) -> usize {
            std::env::var(name)
                .ok()
                .and_then(|raw| raw.trim().parse::<usize>().ok())
                .filter(|&n| n >= 1)
                .unwrap_or(fallback)
        }

        // DAK-5746: pool=4 restored. PR#488 regressed LME ingest: pool=1 serializes all
        // ONNX calls onto session[0]. With 4 concurrent HTTP requests × 7 sub-batches each,
        // pool=1 produces ~28 serial ONNX calls vs pool=4's 7 parallel chains — ~4× throughput
        // regression measured at 2761ms/50-text batch on prod. OOM root causes (unbounded HNSW,
        // RocksDB cache) fixed by PR#488 other changes; pool=4 × BGE-Large INT8 ≈ 1.6GB fits
        // comfortably on the 8GB server. pool_size 4→1 downgrade was the wrong OOM fix.
        let pool_size = positive_env("DAKERA_ONNX_POOL_SIZE", 4);
        // DAK-5716: no length-sorting (PR#476 proved sorted batching regresses INT8
        // quantization quality). DAK-5953: default raised 8→32 — amortises per-call ONNX
        // overhead 4× with no quality impact (size-only change, not order). Bench sets
        // DAKERA_ONNX_BATCH_SIZE=128; 32 is a safe default for CPU-only deployments.
        let max_batch_size = positive_env("DAKERA_ONNX_BATCH_SIZE", 32);
        Self {
            model: EmbeddingModel::default(),
            cache_dir: None,
            max_batch_size,
            use_gpu: false,
            num_threads: None,
            session_pool_size: pool_size,
        }
    }
}

impl ModelConfig {
    /// Create a new config for `model`; every other field comes from
    /// `ModelConfig::default()` (including its environment-variable overrides).
    pub fn new(model: EmbeddingModel) -> Self {
        Self {
            model,
            ..Default::default()
        }
    }

    /// Set the cache directory.
    pub fn with_cache_dir(mut self, dir: impl Into<String>) -> Self {
        self.cache_dir = Some(dir.into());
        self
    }

    /// Set the maximum batch size.
    ///
    /// Values below 1 are clamped to 1 — a batch size of zero is never meaningful, and this
    /// matches both the `n >= 1` filter applied to `DAKERA_ONNX_BATCH_SIZE` in `Default` and
    /// the clamping done by [`with_session_pool_size`](Self::with_session_pool_size).
    pub fn with_max_batch_size(mut self, size: usize) -> Self {
        self.max_batch_size = size.max(1);
        self
    }

    /// Enable or disable GPU acceleration.
    pub fn with_gpu(mut self, use_gpu: bool) -> Self {
        self.use_gpu = use_gpu;
        self
    }

    /// Set the number of CPU threads.
    pub fn with_num_threads(mut self, threads: usize) -> Self {
        self.num_threads = Some(threads);
        self
    }

    /// Set the number of parallel ONNX sessions in the pool (values below 1 become 1).
    pub fn with_session_pool_size(mut self, size: usize) -> Self {
        self.session_pool_size = size.max(1);
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_ids() {
        assert_eq!(
            EmbeddingModel::BgeLarge.model_id(),
            "BAAI/bge-large-en-v1.5"
        );
        assert_eq!(
            EmbeddingModel::MiniLM.model_id(),
            "sentence-transformers/all-MiniLM-L6-v2"
        );
        assert_eq!(
            EmbeddingModel::BgeSmall.model_id(),
            "BAAI/bge-small-en-v1.5"
        );
        assert_eq!(EmbeddingModel::E5Small.model_id(), "intfloat/e5-small-v2");
    }

    #[test]
    fn test_dimensions() {
        assert_eq!(EmbeddingModel::BgeLarge.dimension(), 1024);
        assert_eq!(EmbeddingModel::MiniLM.dimension(), 384);
        assert_eq!(EmbeddingModel::BgeSmall.dimension(), 384);
        assert_eq!(EmbeddingModel::E5Small.dimension(), 384);
        // Verify each model reports a usable dimension
        for model in EmbeddingModel::all() {
            assert!(model.dimension() > 0);
        }
    }

    #[test]
    fn test_max_seq_length() {
        assert_eq!(EmbeddingModel::BgeLarge.max_seq_length(), 512);
        assert_eq!(EmbeddingModel::MiniLM.max_seq_length(), 256);
        assert_eq!(EmbeddingModel::BgeSmall.max_seq_length(), 512);
        assert_eq!(EmbeddingModel::E5Small.max_seq_length(), 512);
    }

    #[test]
    fn test_default_model_is_bge_large() {
        assert_eq!(EmbeddingModel::default(), EmbeddingModel::BgeLarge);
        assert_eq!(EmbeddingModel::all()[0], EmbeddingModel::default());
    }

    #[test]
    fn test_display_names() {
        assert_eq!(EmbeddingModel::BgeLarge.to_string(), "bge-large-en-v1.5");
        assert_eq!(EmbeddingModel::MiniLM.to_string(), "all-MiniLM-L6-v2");
        assert_eq!(EmbeddingModel::BgeSmall.to_string(), "bge-small-en-v1.5");
        assert_eq!(EmbeddingModel::E5Small.to_string(), "e5-small-v2");
    }

    #[test]
    fn test_from_str() {
        assert_eq!(
            EmbeddingModel::parse("bge-large"),
            Some(EmbeddingModel::BgeLarge)
        );
        assert_eq!(
            EmbeddingModel::parse("minilm"),
            Some(EmbeddingModel::MiniLM)
        );
        assert_eq!(
            EmbeddingModel::parse("BGE-SMALL"),
            Some(EmbeddingModel::BgeSmall)
        );
        assert_eq!(EmbeddingModel::parse("e5"), Some(EmbeddingModel::E5Small));
        assert_eq!(EmbeddingModel::parse("unknown"), None);
    }

    #[test]
    fn test_e5_prefixes() {
        assert_eq!(EmbeddingModel::E5Small.query_prefix(), Some("query: "));
        assert_eq!(EmbeddingModel::E5Small.document_prefix(), Some("passage: "));
        assert_eq!(EmbeddingModel::MiniLM.query_prefix(), None);
    }

    #[test]
    fn test_prefixes_only_for_e5() {
        for model in EmbeddingModel::all() {
            let expects_prefix = *model == EmbeddingModel::E5Small;
            assert_eq!(model.query_prefix().is_some(), expects_prefix);
            assert_eq!(model.document_prefix().is_some(), expects_prefix);
        }
    }

    #[test]
    fn test_onnx_repos_are_xenova() {
        for model in EmbeddingModel::all() {
            assert!(model.onnx_repo_id().starts_with("Xenova/"));
        }
    }

    #[test]
    fn test_onnx_filenames() {
        // INT8 model for CPU — all models use the same quantized file
        for model in EmbeddingModel::all() {
            assert_eq!(model.onnx_filename(), "onnx/model_quantized.onnx");
        }
        // FP32 model for GPU — no Memcpy fallback ops
        for model in EmbeddingModel::all() {
            assert_eq!(model.onnx_filename_gpu(), "onnx/model.onnx");
        }
        // Sanity: GPU and CPU filenames are distinct
        assert_ne!(
            EmbeddingModel::BgeLarge.onnx_filename(),
            EmbeddingModel::BgeLarge.onnx_filename_gpu()
        );
    }

    #[test]
    fn test_session_pool_size_clamped() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(0);
        assert_eq!(cfg.session_pool_size, 1);
        assert_eq!(cfg.model, EmbeddingModel::MiniLM);
        let cfg = cfg.with_session_pool_size(6);
        assert_eq!(cfg.session_pool_size, 6);
    }
}