use std::sync::Arc;
use serde::{Deserialize, Serialize};
use crate::engine::EmbeddingEngine;
use crate::error::EmbeddingResult;
use crate::mock::{MockEmbeddingEngine, MockVectorMode};
use crate::ollama::OllamaEmbeddingEngine;
use crate::openai_compatible::OpenAICompatibleEmbeddingEngine;
use crate::provider::EmbeddingProvider;
#[cfg(feature = "onnx")]
use crate::onnx::OnnxEmbeddingEngine;
#[cfg(feature = "onnx")]
use std::path::PathBuf;
/// Vector size assumed when a model is not found by `known_model_dimensions`
/// and no explicit dimension override is available.
const FALLBACK_DIMENSIONS: usize = 384;
/// Looks up the output vector size of a well-known embedding model.
///
/// Matching is case-insensitive. Before the lookup, the model reference is
/// reduced to its bare name:
///
/// * everything up to and including the last `/` is dropped, so
///   `openai/text-embedding-3-small` and `BAAI/bge-small-en-v1.5` resolve;
/// * a trailing `:tag` (as used by Ollama model references such as
///   `nomic-embed-text:latest`) is dropped.
///
/// `provider` is accepted for interface stability but is not consulted; the
/// table is keyed on the model name alone.
///
/// Returns `None` when the bare model name is not in the table.
pub fn known_model_dimensions(provider: EmbeddingProvider, model: &str) -> Option<usize> {
    // The lookup does not depend on the provider.
    let _ = provider;

    let without_namespace = model.rsplit('/').next().unwrap_or(model);
    // Without this, tagged references like `name:latest` never match the table.
    let without_tag = without_namespace
        .split(':')
        .next()
        .unwrap_or(without_namespace);
    let key = without_tag.to_ascii_lowercase();

    match key.as_str() {
        "text-embedding-3-large" => Some(3072),
        "text-embedding-3-small" | "text-embedding-ada-002" => Some(1536),
        "bge-small-v1.5" | "bge-small-en-v1.5" | "all-minilm-l6-v2" => Some(384),
        "bge-base-en-v1.5" | "nomic-embed-text" => Some(768),
        "bge-large-en-v1.5" | "mxbai-embed-large" => Some(1024),
        // Mistral-7B based embedder; vector size taken from its model card.
        "sfr-embedding-mistral" => Some(4096),
        _ => None,
    }
}
/// Settings for the local ONNX embedding engine.
///
/// Only compiled when the `onnx` crate feature is enabled. The presets
/// `bge_small` and `minilm_l6` fill every field for a given model directory.
#[cfg(feature = "onnx")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnnxEmbeddingConfig {
/// Path to the `.onnx` model file.
pub model_path: PathBuf,
/// Path to the tokenizer JSON file.
pub tokenizer_path: PathBuf,
/// Model identifier (e.g. `bge-small-en-v1.5`).
pub model_name: String,
/// Vector size recorded for this model (384 for both built-in presets).
pub dimensions: usize,
/// Sequence-length limit recorded for this model (512 for the BGE-small
/// preset, 256 for the MiniLM-L6 preset).
pub max_sequence_length: usize,
/// Batch size setting; presets use 32 and `EmbeddingConfig::from_env`
/// can override it through `EMBEDDING_ONNX_BATCH_SIZE`.
pub batch_size: usize,
}
#[cfg(feature = "onnx")]
impl Default for OnnxEmbeddingConfig {
    /// Uses the BGE-small preset with model files under `./target/models`.
    fn default() -> Self {
        let model_dir = PathBuf::from("./target/models");
        Self::bge_small(model_dir)
    }
}
#[cfg(feature = "onnx")]
impl OnnxEmbeddingConfig {
    /// Preset for the quantized BGE-small v1.5 model.
    ///
    /// Expects `BGE-Small-v1.5-model_quantized.onnx` and
    /// `bge-small-tokenizer.json` inside `model_dir`.
    pub fn bge_small(model_dir: impl Into<PathBuf>) -> Self {
        let dir: PathBuf = model_dir.into();
        Self {
            model_path: dir.join("BGE-Small-v1.5-model_quantized.onnx"),
            tokenizer_path: dir.join("bge-small-tokenizer.json"),
            model_name: String::from("bge-small-en-v1.5"),
            dimensions: 384,
            max_sequence_length: 512,
            batch_size: 32,
        }
    }

    /// Preset for the all-MiniLM-L6-v2 model.
    ///
    /// Expects `all-MiniLM-L6-v2.onnx` and `minilm-l6-tokenizer.json` inside
    /// `model_dir`.
    pub fn minilm_l6(model_dir: impl Into<PathBuf>) -> Self {
        let dir: PathBuf = model_dir.into();
        Self {
            model_path: dir.join("all-MiniLM-L6-v2.onnx"),
            tokenizer_path: dir.join("minilm-l6-tokenizer.json"),
            model_name: String::from("all-MiniLM-L6-v2"),
            dimensions: 384,
            max_sequence_length: 256,
            batch_size: 32,
        }
    }
}
/// Top-level configuration used to select and construct an embedding engine.
///
/// Build it with `Default`, `from_env`, or by deserializing; turn it into an
/// engine with `create_engine`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
/// Configured backend. `effective_provider` returns this unless `mock` is set.
pub provider: EmbeddingProvider,
/// Model identifier; also used to look up `dimensions` in `from_env`.
pub model: String,
/// Embedding vector size.
pub dimensions: usize,
/// Optional service endpoint; the OpenAI default is
/// `https://api.openai.com/v1`.
pub endpoint: Option<String>,
/// Optional API key.
// NOTE(review): this field is part of the derived `Debug` and `Serialize`
// output, so the key can end up in logs or serialized config — confirm
// that is acceptable.
pub api_key: Option<String>,
/// Optional API version string.
pub api_version: Option<String>,
/// Token limit setting (default 8191). Presumably a per-request limit
/// enforced by the engines — TODO confirm against the engine code.
pub max_completion_tokens: usize,
/// Batch size setting (default 36).
pub batch_size: usize,
/// When `true`, `effective_provider` reports `EmbeddingProvider::Mock`
/// regardless of `provider`.
pub mock: bool,
/// Vector mode handed to the mock engine; falls back to
/// `MockVectorMode::default()` when absent during deserialization.
#[serde(default)]
pub mock_mode: MockVectorMode,
/// ONNX engine settings, used when the provider is `Onnx` or `Fastembed`.
#[cfg(feature = "onnx")]
pub onnx: OnnxEmbeddingConfig,
/// Optional Hugging Face tokenizer identifier; its consumer is not visible
/// in this file.
pub huggingface_tokenizer: Option<String>,
}
impl Default for EmbeddingConfig {
    /// Platform-dependent defaults.
    ///
    /// * With the `onnx` feature on Android: the local ONNX provider, taking
    ///   model name and dimensions from `OnnxEmbeddingConfig::default()` and
    ///   no endpoint.
    /// * Everywhere else: OpenAI `text-embedding-3-small` against
    ///   `https://api.openai.com/v1`, with dimensions taken from the
    ///   known-model table (or `FALLBACK_DIMENSIONS` if it is missing there).
    fn default() -> Self {
        #[cfg(all(feature = "onnx", target_os = "android"))]
        let (provider, model, dimensions, endpoint) = {
            let local = OnnxEmbeddingConfig::default();
            (
                EmbeddingProvider::Onnx,
                local.model_name,
                local.dimensions,
                None,
            )
        };

        // Covers both "onnx but not Android" and "no onnx feature at all".
        #[cfg(not(all(feature = "onnx", target_os = "android")))]
        let (provider, model, dimensions, endpoint) = {
            let model_name = String::from("text-embedding-3-small");
            let vector_size = known_model_dimensions(EmbeddingProvider::OpenAi, &model_name)
                .unwrap_or(FALLBACK_DIMENSIONS);
            (
                EmbeddingProvider::OpenAi,
                model_name,
                vector_size,
                Some(String::from("https://api.openai.com/v1")),
            )
        };

        Self {
            provider,
            model,
            dimensions,
            endpoint,
            api_key: None,
            api_version: None,
            max_completion_tokens: 8191,
            batch_size: 36,
            mock: false,
            mock_mode: MockVectorMode::Zero,
            #[cfg(feature = "onnx")]
            onnx: OnnxEmbeddingConfig::default(),
            huggingface_tokenizer: None,
        }
    }
}
impl EmbeddingConfig {
    /// Builds a configuration from [`Default`] and applies environment
    /// overrides on top of it.
    ///
    /// Values are trimmed; a variable that is unset, blank, or not valid
    /// Unicode is treated as absent. Variables are applied in this order:
    ///
    /// 1. `MOCK_EMBEDDING` — `deterministic`/`hash` or `true`/`1`/`yes`
    ///    (case-insensitive) switch to the mock provider and return
    ///    immediately; no other variable is consulted.
    /// 2. `EMBEDDING_PROVIDER` — `onnx`, `fastembed`, `openai`,
    ///    `openai_compatible`, `ollama` or `mock` (case-insensitive). Any
    ///    other value is logged and ignored.
    /// 3. `EMBEDDING_MODEL` — overrides the model (Ollama otherwise defaults
    ///    to `avr/sfr-embedding-mistral:latest`).
    /// 4. `EMBEDDING_DIMENSIONS` — positive integer; when absent or invalid
    ///    the size is derived from the provider/model.
    /// 5. `EMBEDDING_ENDPOINT`, `EMBEDDING_API_KEY` (falling back to
    ///    `LLM_API_KEY`), `EMBEDDING_API_VERSION`,
    ///    `EMBEDDING_MAX_COMPLETION_TOKENS`, `EMBEDDING_BATCH_SIZE`
    ///    (positive), `EMBEDDING_ONNX_BATCH_SIZE` (positive, `onnx` feature
    ///    only) and `HUGGINGFACE_TOKENIZER`.
    pub fn from_env() -> Self {
        let mut config = Self::default();

        if let Some(flag) = env_trimmed("MOCK_EMBEDDING") {
            let requested_mode = match flag.to_lowercase().as_str() {
                "deterministic" | "hash" => Some(MockVectorMode::Deterministic),
                "true" | "1" | "yes" => Some(MockVectorMode::Zero),
                _ => None,
            };
            if let Some(mode) = requested_mode {
                config.mock = true;
                config.provider = EmbeddingProvider::Mock;
                config.mock_mode = mode;
                return config;
            }
        }

        if let Some(raw) = env_trimmed("EMBEDDING_PROVIDER") {
            match raw.to_lowercase().as_str() {
                "onnx" => config.provider = EmbeddingProvider::Onnx,
                "fastembed" => config.provider = EmbeddingProvider::Fastembed,
                "openai" => config.provider = EmbeddingProvider::OpenAi,
                "openai_compatible" => config.provider = EmbeddingProvider::OpenAiCompatible,
                "ollama" => config.provider = EmbeddingProvider::Ollama,
                "mock" => {
                    config.mock = true;
                    config.provider = EmbeddingProvider::Mock;
                }
                // A typo here used to be dropped silently, leaving the user on
                // the default provider with no hint as to why.
                unrecognized => tracing::warn!(
                    value = %unrecognized,
                    provider = ?config.provider,
                    "Unrecognized EMBEDDING_PROVIDER; keeping the default provider"
                ),
            }
        }

        if config.provider == EmbeddingProvider::Ollama {
            config.model = "avr/sfr-embedding-mistral:latest".to_string();
        }
        if let Some(model) = env_trimmed("EMBEDDING_MODEL") {
            config.model = model;
        }

        // Resolved only after provider and model are final, and lazily so the
        // "could not auto-derive" warning is not emitted when an explicit
        // override is present.
        config.dimensions = match env_positive_usize("EMBEDDING_DIMENSIONS") {
            Some(explicit) => explicit,
            None => config.auto_dimensions(),
        };

        if let Some(endpoint) = env_trimmed("EMBEDDING_ENDPOINT") {
            config.endpoint = Some(endpoint);
        }

        // A blank EMBEDDING_API_KEY counts as unset, so it must not mask
        // LLM_API_KEY.
        if let Some(key) =
            env_trimmed("EMBEDDING_API_KEY").or_else(|| env_trimmed("LLM_API_KEY"))
        {
            config.api_key = Some(key);
        }

        if let Some(version) = env_trimmed("EMBEDDING_API_VERSION") {
            config.api_version = Some(version);
        }
        if let Some(limit) = env_usize("EMBEDDING_MAX_COMPLETION_TOKENS") {
            config.max_completion_tokens = limit;
        }
        // Zero is rejected, matching the EMBEDDING_ONNX_BATCH_SIZE rule below.
        if let Some(size) = env_positive_usize("EMBEDDING_BATCH_SIZE") {
            config.batch_size = size;
        }
        #[cfg(feature = "onnx")]
        if let Some(size) = env_positive_usize("EMBEDDING_ONNX_BATCH_SIZE") {
            config.onnx.batch_size = size;
        }
        if let Some(tokenizer) = env_trimmed("HUGGINGFACE_TOKENIZER") {
            config.huggingface_tokenizer = Some(tokenizer);
        }

        config
    }

    /// Derives the vector size from the current provider and model.
    ///
    /// With the `onnx` feature, the `Onnx` and `Fastembed` providers use
    /// `self.onnx.dimensions`. Everything else goes through the known-model
    /// table and falls back to `FALLBACK_DIMENSIONS` (with a warning) when
    /// the model is not listed.
    fn auto_dimensions(&self) -> usize {
        #[cfg(feature = "onnx")]
        if matches!(
            self.provider,
            EmbeddingProvider::Onnx | EmbeddingProvider::Fastembed
        ) {
            return self.onnx.dimensions;
        }

        known_model_dimensions(self.provider.clone(), &self.model).unwrap_or_else(|| {
            tracing::warn!(
                provider = ?self.provider,
                model = %self.model,
                fallback = FALLBACK_DIMENSIONS,
                "Could not auto-derive embedding dimensions; set \
                 EMBEDDING_DIMENSIONS explicitly if your embedder produces \
                 a different vector size, otherwise the first vector write \
                 will fail with a shape mismatch."
            );
            FALLBACK_DIMENSIONS
        })
    }

    /// Returns the provider that will actually be used: `Mock` whenever the
    /// `mock` flag is set, otherwise the configured `provider`.
    pub fn effective_provider(&self) -> EmbeddingProvider {
        if self.mock {
            return EmbeddingProvider::Mock;
        }
        self.provider.clone()
    }

    /// Constructs the embedding engine for [`Self::effective_provider`].
    ///
    /// # Errors
    ///
    /// * `Onnx`/`Fastembed` without the `onnx` crate feature yields
    ///   `EmbeddingError::NotImplemented`.
    /// * Errors from the selected engine's constructor are propagated.
    pub async fn create_engine(&self) -> EmbeddingResult<Arc<dyn EmbeddingEngine>> {
        match self.effective_provider() {
            #[cfg(feature = "onnx")]
            EmbeddingProvider::Onnx | EmbeddingProvider::Fastembed => {
                let engine = OnnxEmbeddingEngine::with_auto_download(self.onnx.clone()).await?;
                Ok(Arc::new(engine))
            }
            #[cfg(not(feature = "onnx"))]
            EmbeddingProvider::Onnx | EmbeddingProvider::Fastembed => {
                Err(crate::error::EmbeddingError::NotImplemented(
                    "ONNX embedding engine requires the `onnx` crate feature".to_string(),
                ))
            }
            EmbeddingProvider::OpenAi | EmbeddingProvider::OpenAiCompatible => {
                let engine = OpenAICompatibleEmbeddingEngine::new(self)?;
                Ok(Arc::new(engine))
            }
            EmbeddingProvider::Ollama => {
                let engine = OllamaEmbeddingEngine::new(self)?;
                Ok(Arc::new(engine))
            }
            EmbeddingProvider::Mock => Ok(Arc::new(
                MockEmbeddingEngine::new(self.dimensions).with_mode(self.mock_mode),
            )),
        }
    }
}

/// Reads `name` from the process environment with surrounding whitespace
/// removed. Returns `None` when the variable is unset, not valid Unicode, or
/// blank after trimming.
fn env_trimmed(name: &str) -> Option<String> {
    let raw = std::env::var(name).ok()?;
    let value = raw.trim();
    (!value.is_empty()).then(|| value.to_string())
}

/// Reads `name` as a `usize`. Logs a warning and returns `None` when the
/// variable is present but does not parse.
fn env_usize(name: &str) -> Option<usize> {
    let raw = env_trimmed(name)?;
    match raw.parse::<usize>() {
        Ok(parsed) => Some(parsed),
        Err(_) => {
            tracing::warn!(
                variable = name,
                value = %raw,
                "Ignoring environment variable: not a non-negative integer"
            );
            None
        }
    }
}

/// Like [`env_usize`], but additionally rejects zero (with a warning).
fn env_positive_usize(name: &str) -> Option<usize> {
    let parsed = env_usize(name)?;
    if parsed == 0 {
        tracing::warn!(
            variable = name,
            "Ignoring environment variable: value must be greater than zero"
        );
        return None;
    }
    Some(parsed)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the embedding configuration.
    //!
    //! Every test that mutates process environment variables is marked
    //! `#[serial]` so those tests do not run concurrently with one another;
    //! that is what the `unsafe` `set_var`/`remove_var` calls rely on.
    //! The `from_env` tests also assume the ambient environment does not
    //! already define the `EMBEDDING_*` / `MOCK_EMBEDDING` variables.

    use super::*;
    use serial_test::serial;

    // ---- Default / effective_provider -------------------------------------

    #[test]
    #[cfg(all(feature = "onnx", target_os = "android"))]
    fn test_default_is_onnx_on_android() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.provider, EmbeddingProvider::Onnx);
        assert_eq!(config.dimensions, 384);
        assert_eq!(config.batch_size, 36);
        assert_eq!(config.max_completion_tokens, 8191);
        assert!(!config.mock);
    }

    #[test]
    #[cfg(not(target_os = "android"))]
    fn test_default_is_openai_off_android() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.provider, EmbeddingProvider::OpenAi);
        assert_eq!(config.model, "text-embedding-3-small");
        assert_eq!(config.dimensions, 1536);
        assert_eq!(
            config.endpoint.as_deref(),
            Some("https://api.openai.com/v1")
        );
        assert!(!config.mock);
    }

    #[test]
    fn test_effective_provider_mock_override() {
        let config = EmbeddingConfig {
            mock: true,
            ..Default::default()
        };
        assert_eq!(config.effective_provider(), EmbeddingProvider::Mock);
    }

    #[test]
    #[cfg(all(feature = "onnx", target_os = "android"))]
    fn test_effective_provider_passthrough_onnx() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.effective_provider(), EmbeddingProvider::Onnx);
    }

    #[test]
    #[cfg(not(target_os = "android"))]
    fn test_effective_provider_passthrough_openai() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.effective_provider(), EmbeddingProvider::OpenAi);
    }

    // ---- from_env: mock switches ------------------------------------------

    #[test]
    #[serial]
    fn test_from_env_mock_embedding_true() {
        unsafe { std::env::set_var("MOCK_EMBEDDING", "true") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("MOCK_EMBEDDING") };
        assert!(config.mock);
        assert_eq!(config.effective_provider(), EmbeddingProvider::Mock);
    }

    #[test]
    #[serial]
    fn test_from_env_mock_embedding_numeric() {
        unsafe { std::env::set_var("MOCK_EMBEDDING", "1") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("MOCK_EMBEDDING") };
        assert!(config.mock);
        assert_eq!(config.mock_mode, MockVectorMode::Zero);
    }

    // Previously `#[ignore]`d because it mutates global env vars; `#[serial]`
    // addresses that the same way as for the sibling tests, so it now runs.
    #[test]
    #[serial]
    fn test_from_env_mock_embedding_deterministic() {
        unsafe { std::env::set_var("MOCK_EMBEDDING", "deterministic") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("MOCK_EMBEDDING") };
        assert!(config.mock);
        assert_eq!(config.effective_provider(), EmbeddingProvider::Mock);
        assert_eq!(config.mock_mode, MockVectorMode::Deterministic);
    }

    // ---- from_env: individual overrides -----------------------------------

    #[test]
    #[serial]
    fn test_from_env_provider() {
        unsafe { std::env::set_var("EMBEDDING_PROVIDER", "openai") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("EMBEDDING_PROVIDER") };
        assert_eq!(config.provider, EmbeddingProvider::OpenAi);
    }

    #[test]
    #[serial]
    fn test_from_env_fastembed_alias() {
        unsafe { std::env::set_var("EMBEDDING_PROVIDER", "fastembed") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("EMBEDDING_PROVIDER") };
        assert_eq!(config.provider, EmbeddingProvider::Fastembed);
    }

    #[test]
    #[serial]
    fn test_from_env_dimensions() {
        unsafe { std::env::set_var("EMBEDDING_DIMENSIONS", "1536") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("EMBEDDING_DIMENSIONS") };
        assert_eq!(config.dimensions, 1536);
    }

    #[test]
    #[serial]
    fn test_from_env_api_key_fallback() {
        unsafe { std::env::remove_var("EMBEDDING_API_KEY") };
        unsafe { std::env::set_var("LLM_API_KEY", "my-llm-key") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("LLM_API_KEY") };
        assert_eq!(config.api_key, Some("my-llm-key".to_string()));
    }

    #[test]
    #[serial]
    fn test_from_env_api_key_prefers_embedding() {
        unsafe { std::env::set_var("EMBEDDING_API_KEY", "embed-key") };
        unsafe { std::env::set_var("LLM_API_KEY", "llm-key") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("EMBEDDING_API_KEY") };
        unsafe { std::env::remove_var("LLM_API_KEY") };
        assert_eq!(config.api_key, Some("embed-key".to_string()));
    }

    #[test]
    #[cfg(feature = "onnx")]
    #[serial]
    fn from_env_onnx_batch_size_override() {
        unsafe { std::env::set_var("EMBEDDING_ONNX_BATCH_SIZE", "8") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("EMBEDDING_ONNX_BATCH_SIZE") };
        assert_eq!(config.onnx.batch_size, 8);
    }

    // ---- ONNX presets ------------------------------------------------------

    #[test]
    #[cfg(feature = "onnx")]
    fn test_onnx_config_bge_small() {
        let cfg = OnnxEmbeddingConfig::bge_small("/models");
        assert_eq!(cfg.dimensions, 384);
        assert_eq!(cfg.max_sequence_length, 512);
        assert_eq!(cfg.model_name, "bge-small-en-v1.5");
    }

    #[test]
    #[cfg(feature = "onnx")]
    fn test_onnx_config_minilm_l6() {
        let cfg = OnnxEmbeddingConfig::minilm_l6("/models");
        assert_eq!(cfg.dimensions, 384);
        assert_eq!(cfg.max_sequence_length, 256);
        assert_eq!(cfg.model_name, "all-MiniLM-L6-v2");
    }

    // ---- known_model_dimensions -------------------------------------------

    #[test]
    fn known_dims_openai_large() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "text-embedding-3-large"),
            Some(3072),
        );
    }

    #[test]
    fn known_dims_openai_small() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "text-embedding-3-small"),
            Some(1536),
        );
    }

    #[test]
    fn known_dims_ada_002() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "text-embedding-ada-002"),
            Some(1536),
        );
    }

    #[test]
    fn known_dims_prefix_stripped() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "openai/text-embedding-3-small"),
            Some(1536),
        );
        assert_eq!(
            known_model_dimensions(
                EmbeddingProvider::OpenAiCompatible,
                "azure/text-embedding-3-large"
            ),
            Some(3072),
        );
    }

    #[test]
    fn known_dims_bge_small() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::Onnx, "bge-small-en-v1.5"),
            Some(384),
        );
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::Onnx, "BGE-Small-v1.5"),
            Some(384),
        );
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::Fastembed, "BAAI/bge-small-en-v1.5"),
            Some(384),
        );
    }

    #[test]
    fn known_dims_bge_large() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::Fastembed, "bge-large-en-v1.5"),
            Some(1024),
        );
    }

    #[test]
    fn known_dims_unknown_returns_none() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "some-unknown-model"),
            None,
        );
    }

    // ---- from_env: dimension resolution -----------------------------------

    #[test]
    #[serial]
    fn from_env_explicit_override_wins() {
        unsafe {
            std::env::set_var("EMBEDDING_PROVIDER", "openai");
            std::env::set_var("EMBEDDING_MODEL", "text-embedding-3-large");
            std::env::set_var("EMBEDDING_DIMENSIONS", "999");
        }
        let config = EmbeddingConfig::from_env();
        unsafe {
            std::env::remove_var("EMBEDDING_PROVIDER");
            std::env::remove_var("EMBEDDING_MODEL");
            std::env::remove_var("EMBEDDING_DIMENSIONS");
        }
        assert_eq!(config.dimensions, 999);
    }

    #[test]
    #[serial]
    fn from_env_model_change_resolves() {
        unsafe {
            std::env::set_var("EMBEDDING_PROVIDER", "openai");
            std::env::set_var("EMBEDDING_MODEL", "text-embedding-3-large");
            std::env::remove_var("EMBEDDING_DIMENSIONS");
        }
        let config = EmbeddingConfig::from_env();
        unsafe {
            std::env::remove_var("EMBEDDING_PROVIDER");
            std::env::remove_var("EMBEDDING_MODEL");
        }
        assert_eq!(config.dimensions, 3072);
    }

    #[test]
    #[serial]
    fn from_env_unknown_falls_back() {
        unsafe {
            std::env::set_var("EMBEDDING_PROVIDER", "openai");
            std::env::set_var("EMBEDDING_MODEL", "some-unknown-model-xyz");
            std::env::remove_var("EMBEDDING_DIMENSIONS");
        }
        let config = EmbeddingConfig::from_env();
        unsafe {
            std::env::remove_var("EMBEDDING_PROVIDER");
            std::env::remove_var("EMBEDDING_MODEL");
        }
        assert_eq!(config.dimensions, FALLBACK_DIMENSIONS);
    }
}