ruvector-onnx-embeddings 0.1.0

ONNX-based embedding generation for RuVector - Reimagined embedding pipeline in pure Rust
Documentation
//! Configuration for the ONNX embedder

use crate::PretrainedModel;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;

/// Where the embedder obtains its ONNX model (and the matching tokenizer) from.
///
/// The default is [`ModelSource::Pretrained`] holding [`PretrainedModel::default`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum ModelSource {
    /// Fetch from the HuggingFace Hub; a download happens only when no cached copy exists.
    HuggingFace {
        /// Hub model identifier.
        model_id: String,
        /// Optional Hub revision to pin; `None` when no revision is requested.
        revision: Option<String>,
    },
    /// Read an ONNX model file and a tokenizer file from local paths.
    Local {
        /// Path to the ONNX model file.
        model_path: PathBuf,
        /// Path to the tokenizer file.
        tokenizer_path: PathBuf,
    },
    /// One of the crate's pre-configured models.
    Pretrained(PretrainedModel),
    /// Download the model and tokenizer from explicit URLs.
    Url {
        /// URL of the ONNX model.
        model_url: String,
        /// URL of the tokenizer.
        tokenizer_url: String,
    },
}

impl Default for ModelSource {
    /// Selects the default pretrained model.
    fn default() -> Self {
        PretrainedModel::default().into()
    }
}

impl From<PretrainedModel> for ModelSource {
    fn from(model: PretrainedModel) -> Self {
        Self::Pretrained(model)
    }
}

/// How per-token embeddings are reduced to a single sentence embedding.
///
/// Defaults to [`PoolingStrategy::Mean`].
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub enum PoolingStrategy {
    /// Average of all token embeddings (the most common choice).
    #[default]
    Mean,
    /// Embedding of the `[CLS]` token.
    Cls,
    /// Maximum over all token embeddings.
    Max,
    /// Mean pooling scaled by the square root of the sequence length.
    MeanSqrtLen,
    /// Embedding of the last token (intended for decoder models).
    LastToken,
    /// Mean weighted according to the attention mask.
    WeightedMean,
}

/// Execution provider (hardware backend) that ONNX Runtime uses for inference.
///
/// Defaults to [`ExecutionProvider::Cpu`]. Derives `PartialEq`/`Eq`, as
/// [`PoolingStrategy`] does, so callers can compare configured providers
/// (for example, to check whether a GPU backend was requested).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ExecutionProvider {
    /// CPU inference (default, always available)
    #[default]
    Cpu,
    /// CUDA GPU acceleration
    Cuda {
        /// Index of the GPU device to run on.
        device_id: i32,
    },
    /// TensorRT optimization
    TensorRt {
        /// Index of the GPU device to run on.
        device_id: i32,
    },
    /// CoreML on macOS
    CoreMl,
    /// DirectML on Windows
    DirectMl,
    /// ROCm for AMD GPUs
    Rocm {
        /// Index of the GPU device to run on.
        device_id: i32,
    },
}

/// Complete set of options controlling how the embedder loads a model and runs inference.
///
/// Build one with [`EmbedderConfig::default`], the convenience constructors, or
/// [`EmbedderConfig::builder`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EmbedderConfig {
    /// Where the model and tokenizer come from.
    pub model_source: ModelSource,
    /// How token embeddings are combined into one vector.
    pub pooling: PoolingStrategy,
    /// When `true`, embeddings are scaled to unit length.
    pub normalize: bool,
    /// Maximum sequence length; longer inputs are truncated.
    pub max_length: usize,
    /// Number of inputs grouped into one inference batch.
    pub batch_size: usize,
    /// Thread count used for CPU inference.
    pub num_threads: usize,
    /// ONNX Runtime backend used for inference.
    pub execution_provider: ExecutionProvider,
    /// Directory where downloaded models are cached.
    pub cache_dir: PathBuf,
    /// When `true`, progress is displayed while downloading.
    pub show_progress: bool,
    /// Run fp16 inference when it is available.
    pub use_fp16: bool,
    /// Apply ONNX graph optimization.
    pub optimize_graph: bool,
}

impl Default for EmbedderConfig {
    /// Baseline settings: default model source and pooling, normalized output,
    /// 256-token truncation, batches of 32, `num_cpus::get()` threads, default execution
    /// provider, cache under `default_cache_dir()`, download progress shown, fp16 off,
    /// graph optimization on.
    fn default() -> Self {
        let threads = num_cpus::get();
        let cache = default_cache_dir();

        Self {
            model_source: ModelSource::default(),
            pooling: PoolingStrategy::default(),
            normalize: true,
            max_length: 256,
            batch_size: 32,
            num_threads: threads,
            execution_provider: ExecutionProvider::default(),
            cache_dir: cache,
            show_progress: true,
            use_fp16: false,
            optimize_graph: true,
        }
    }
}

impl EmbedderConfig {
    /// Starts an [`EmbedderConfigBuilder`] seeded with the default settings.
    pub fn builder() -> EmbedderConfigBuilder {
        EmbedderConfigBuilder::default()
    }

    /// Config for a pretrained model.
    ///
    /// `max_length` and `normalize` are taken from the model (`max_seq_length()` and
    /// `normalize_output()`); all other fields keep their defaults.
    pub fn pretrained(model: PretrainedModel) -> Self {
        Self::builder().pretrained(model).build()
    }

    /// Config for an ONNX model and tokenizer given by file paths; all other fields keep
    /// their defaults.
    pub fn local(model_path: impl Into<PathBuf>, tokenizer_path: impl Into<PathBuf>) -> Self {
        let source = ModelSource::Local {
            model_path: model_path.into(),
            tokenizer_path: tokenizer_path.into(),
        };
        Self::builder().model_source(source).build()
    }

    /// Config for a HuggingFace Hub model with no pinned revision (`revision: None`);
    /// all other fields keep their defaults.
    pub fn huggingface(model_id: impl Into<String>) -> Self {
        let source = ModelSource::HuggingFace {
            model_id: model_id.into(),
            revision: None,
        };
        Self::builder().model_source(source).build()
    }
}

/// Builder for [`EmbedderConfig`].
///
/// Starts from [`EmbedderConfig::default`] (via the derived `Default`) and overrides
/// individual fields through chained setter calls; finish with `build()`.
///
/// Every setter consumes the builder and returns the updated one, so discarding a
/// setter's return value silently drops that setting. `#[must_use]` turns that mistake
/// into a compiler warning. `Clone` allows a partially configured builder to serve as a
/// template for several configs.
#[must_use = "builder setters return an updated builder; call `.build()` to obtain the config"]
#[derive(Debug, Default, Clone)]
pub struct EmbedderConfigBuilder {
    /// Configuration accumulated so far; returned as-is by `build()`.
    config: EmbedderConfig,
}

impl EmbedderConfigBuilder {
    /// Applies `edit` to the config under construction and hands the builder back.
    /// Every public setter below delegates to this.
    fn update(mut self, edit: impl FnOnce(&mut EmbedderConfig)) -> Self {
        edit(&mut self.config);
        self
    }

    /// Replaces the model source; `max_length` and `normalize` are left unchanged.
    pub fn model_source(self, source: ModelSource) -> Self {
        self.update(|cfg| cfg.model_source = source)
    }

    /// Selects a pretrained model and sets `max_length` to its `max_seq_length()` and
    /// `normalize` to its `normalize_output()`.
    ///
    /// Setters called afterwards can still override those two values.
    pub fn pretrained(self, model: PretrainedModel) -> Self {
        self.update(|cfg| {
            cfg.model_source = ModelSource::Pretrained(model);
            cfg.max_length = model.max_seq_length();
            cfg.normalize = model.normalize_output();
        })
    }

    /// Sets how token embeddings are pooled.
    pub fn pooling(self, strategy: PoolingStrategy) -> Self {
        self.update(|cfg| cfg.pooling = strategy)
    }

    /// Enables or disables unit-length normalization of the output.
    pub fn normalize(self, normalize: bool) -> Self {
        self.update(|cfg| cfg.normalize = normalize)
    }

    /// Sets the maximum (truncation) sequence length.
    pub fn max_length(self, length: usize) -> Self {
        self.update(|cfg| cfg.max_length = length)
    }

    /// Sets the inference batch size.
    pub fn batch_size(self, size: usize) -> Self {
        self.update(|cfg| cfg.batch_size = size)
    }

    /// Sets the thread count for CPU inference.
    pub fn num_threads(self, threads: usize) -> Self {
        self.update(|cfg| cfg.num_threads = threads)
    }

    /// Sets the ONNX Runtime execution provider.
    pub fn execution_provider(self, provider: ExecutionProvider) -> Self {
        self.update(|cfg| cfg.execution_provider = provider)
    }

    /// Sets the directory used to cache downloaded models.
    pub fn cache_dir(self, dir: impl Into<PathBuf>) -> Self {
        let dir = dir.into();
        self.update(|cfg| cfg.cache_dir = dir)
    }

    /// Shows or hides progress output during downloads.
    pub fn show_progress(self, show: bool) -> Self {
        self.update(|cfg| cfg.show_progress = show)
    }

    /// Requests fp16 inference when it is available.
    pub fn use_fp16(self, use_fp16: bool) -> Self {
        self.update(|cfg| cfg.use_fp16 = use_fp16)
    }

    /// Enables or disables graph optimization.
    pub fn optimize_graph(self, optimize: bool) -> Self {
        self.update(|cfg| cfg.optimize_graph = optimize)
    }

    /// Consumes the builder and returns the assembled configuration.
    pub fn build(self) -> EmbedderConfig {
        let Self { config } = self;
        config
    }
}

/// Default model cache location: `<base>/ruvector/onnx-models`.
///
/// `<base>` is `dirs::cache_dir()`. When that returns `None`, `<base>` is `.`, so the
/// resulting path is relative to the current working directory.
fn default_cache_dir() -> PathBuf {
    let mut dir = match dirs::cache_dir() {
        Some(base) => base,
        None => PathBuf::from("."),
    };
    dir.push("ruvector");
    dir.push("onnx-models");
    dir
}

/// Default thread count for CPU work.
///
/// Returns `std::thread::available_parallelism()` when that query succeeds, and falls
/// back to 4 when it fails.
fn num_cpus_get() -> usize {
    const FALLBACK_THREADS: usize = 4;
    match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => FALLBACK_THREADS,
    }
}

mod num_cpus {
    pub fn get() -> usize {
        super::num_cpus_get()
    }
}