pub mod engine;
use ahash::AHashMap;
use std::sync::{Arc, RwLock};
use once_cell::sync::Lazy;
use engine::EmbeddingEngine;
/// Shared, reference-counted handle to a fully initialized embedding engine.
type CachedEngine = Arc<EmbeddingEngine>;
/// Process-wide engine cache keyed by "{repo}_{model_file}_{cache_dir}" so
/// repeated embedding calls reuse the already-loaded tokenizer + ONNX session.
static ENGINE_CACHE: Lazy<RwLock<AHashMap<String, CachedEngine>>> = Lazy::new(|| RwLock::new(AHashMap::new()));
/// Caps how many blocking embedding jobs run at once; sized from the
/// configured thread budget so CPU-bound inference does not oversubscribe cores.
#[cfg(feature = "tokio-runtime")]
static EMBED_SEMAPHORE: Lazy<Arc<tokio::sync::Semaphore>> = Lazy::new(|| {
    let budget = crate::core::config::concurrency::resolve_thread_budget(None);
    Arc::new(tokio::sync::Semaphore::new(budget))
});
/// One built-in embedding configuration: a Hugging Face model paired with
/// chunking parameters tuned for it.
#[derive(Debug, Clone)]
pub struct EmbeddingPreset {
    /// Preset identifier used for lookups (e.g. "fast", "balanced").
    pub name: &'static str,
    /// Recommended chunk size (unit not stated here — presumably characters or tokens; TODO confirm against chunker).
    pub chunk_size: usize,
    /// Recommended overlap between consecutive chunks, same unit as `chunk_size`.
    pub overlap: usize,
    /// Hugging Face repository id hosting the model files.
    pub model_repo: &'static str,
    /// Pooling strategy name: "cls" maps to CLS pooling, anything else falls back to mean pooling.
    pub pooling: &'static str,
    /// Path of the ONNX model file within the repository.
    pub model_file: &'static str,
    /// Dimensionality of the produced embedding vectors.
    pub dimensions: usize,
    /// Human-readable summary shown when listing presets.
    pub description: &'static str,
}
/// The built-in embedding presets, looked up by [`get_preset`] and listed by
/// [`list_presets`]. Order here determines listing order.
pub const EMBEDDING_PRESETS: &[EmbeddingPreset] = &[
    // Smallest/fastest option: quantized MiniLM.
    EmbeddingPreset {
        name: "fast",
        chunk_size: 512,
        overlap: 50,
        model_repo: "Xenova/all-MiniLM-L6-v2",
        pooling: "mean",
        model_file: "onnx/model_quantized.onnx",
        dimensions: 384,
        description: "Fast embedding with quantized model (384 dims, ~22M params). Best for: Quick prototyping, development, resource-constrained environments.",
    },
    // Default-quality English model.
    EmbeddingPreset {
        name: "balanced",
        chunk_size: 1024,
        overlap: 100,
        model_repo: "Xenova/bge-base-en-v1.5",
        pooling: "cls",
        model_file: "onnx/model.onnx",
        dimensions: 768,
        description: "Balanced quality and speed (768 dims, ~109M params). Best for: General-purpose RAG, production deployments, English documents.",
    },
    // Largest English model; highest accuracy, highest cost.
    EmbeddingPreset {
        name: "quality",
        chunk_size: 2000,
        overlap: 200,
        model_repo: "Xenova/bge-large-en-v1.5",
        pooling: "cls",
        model_file: "onnx/model.onnx",
        dimensions: 1024,
        description: "High quality with larger context (1024 dims, ~335M params). Best for: Complex documents, maximum accuracy, sufficient compute resources.",
    },
    // Multilingual E5 variant for non-English / mixed-language content.
    EmbeddingPreset {
        name: "multilingual",
        chunk_size: 1024,
        overlap: 100,
        model_repo: "intfloat/multilingual-e5-base",
        pooling: "mean",
        model_file: "onnx/model.onnx",
        dimensions: 768,
        description: "Multilingual support (768 dims, 100+ languages). Best for: International documents, mixed-language content, global applications.",
    },
];
/// Looks up a built-in preset by `name`; returns `None` for unknown names.
pub fn get_preset(name: &str) -> Option<&'static EmbeddingPreset> {
    for preset in EMBEDDING_PRESETS {
        if preset.name == name {
            return Some(preset);
        }
    }
    None
}
pub fn list_presets() -> Vec<&'static str> {
EMBEDDING_PRESETS.iter().map(|p| p.name).collect()
}
/// Builds a platform-appropriate message explaining how to obtain ONNX Runtime.
fn onnx_runtime_install_message() -> String {
    // `cfg!` keeps both branches type-checked (they are plain string literals)
    // while still selecting the MinGW-specific message on Windows GNU targets.
    let message = if cfg!(all(windows, target_env = "gnu")) {
        "ONNX Runtime embeddings are not supported on Windows MinGW builds. \
         ONNX Runtime requires MSVC toolchain. \
         Please use Windows MSVC builds or disable embeddings feature."
    } else {
        "ONNX Runtime is required for embeddings functionality. \
         Install: \
         macOS: 'brew install onnxruntime', \
         Linux (Ubuntu/Debian): 'apt install libonnxruntime libonnxruntime-dev', \
         Linux (Fedora): 'dnf install onnxruntime onnxruntime-devel', \
         Linux (Arch): 'pacman -S onnxruntime', \
         Windows (MSVC): Download from https://github.com/microsoft/onnxruntime/releases and add to PATH. \
         \
         Alternatively, set ORT_DYLIB_PATH environment variable to the ONNX Runtime library path. \
         \
         For Docker/containers: Install via package manager in your base image. \
         Verified packages: Ubuntu 22.04+, Fedora 38+, Arch Linux."
    };
    message.to_string()
}
/// Returns the explicit cache directory when provided, otherwise the
/// application-default "embeddings" cache location.
fn resolve_cache_dir(cache_dir: Option<std::path::PathBuf>) -> std::path::PathBuf {
    match cache_dir {
        Some(dir) => dir,
        None => crate::cache_dir::resolve_cache_dir("embeddings"),
    }
}
/// Resolves a configured model type to `(repo id, model file path, pooling)`.
///
/// # Errors
/// Fails for unknown preset names and for `Llm` model types, which are served
/// by a separate (feature-gated) code path.
fn resolve_model_info(
    model_type: &crate::core::config::EmbeddingModelType,
) -> crate::Result<(&str, &str, engine::Pooling)> {
    use crate::core::config::EmbeddingModelType;
    match model_type {
        EmbeddingModelType::Preset { name } => {
            let Some(preset) = get_preset(name) else {
                return Err(crate::KreuzbergError::embedding(format!("Unknown embedding preset: {name}")));
            };
            // Only "cls" selects CLS pooling; every other value means mean pooling.
            let pooling = if preset.pooling == "cls" {
                engine::Pooling::Cls
            } else {
                engine::Pooling::Mean
            };
            Ok((preset.model_repo, preset.model_file, pooling))
        }
        // Custom models assume the conventional ONNX export layout and mean pooling.
        EmbeddingModelType::Custom { model_id, .. } => {
            Ok((model_id.as_str(), "onnx/model.onnx", engine::Pooling::Mean))
        }
        EmbeddingModelType::Llm { .. } => Err(crate::KreuzbergError::embedding(
            "LLM-based embeddings require the 'liter-llm' feature and are handled by a separate code path",
        )),
    }
}
/// Loads and configures a Hugging Face tokenizer from locally downloaded files.
///
/// Reads `config.json` and `tokenizer_config.json` to derive padding and
/// truncation settings, applies batch-longest padding plus truncation (clamped
/// to the model's own `model_max_length`), and registers special tokens from
/// `special_tokens_map.json` — that file is optional and read failures are
/// silently ignored.
///
/// # Errors
/// Returns an error if the required JSON files cannot be read or parsed, the
/// tokenizer file fails to load, or padding/truncation configuration fails.
fn load_tokenizer(
    tokenizer_path: &std::path::Path,
    config_path: &std::path::Path,
    special_tokens_path: &std::path::Path,
    tokenizer_config_path: &std::path::Path,
    max_length: usize,
) -> crate::Result<tokenizers::Tokenizer> {
    use tokenizers::{AddedToken, PaddingParams, PaddingStrategy, TruncationParams};
    let config: serde_json::Value = serde_json::from_slice(
        &std::fs::read(config_path)
            .map_err(|e| crate::KreuzbergError::embedding(format!("Failed to read config.json: {e}")))?,
    )
    .map_err(|e| crate::KreuzbergError::embedding(format!("Failed to parse config.json: {e}")))?;
    let tokenizer_config: serde_json::Value = serde_json::from_slice(
        &std::fs::read(tokenizer_config_path)
            .map_err(|e| crate::KreuzbergError::embedding(format!("Failed to read tokenizer_config.json: {e}")))?,
    )
    .map_err(|e| crate::KreuzbergError::embedding(format!("Failed to parse tokenizer_config.json: {e}")))?;
    let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
        .map_err(|e| crate::KreuzbergError::embedding(format!("Failed to load tokenizer: {e}")))?;
    // Never truncate beyond what the model itself supports; defaults to 512
    // when the config omits model_max_length.
    let model_max_length = tokenizer_config["model_max_length"].as_f64().unwrap_or(512.0) as usize;
    let max_length = max_length.min(model_max_length);
    // Padding falls back to BERT-style defaults: id 0 and the "[PAD]" token.
    let pad_id = config["pad_token_id"].as_u64().unwrap_or(0) as u32;
    let pad_token = tokenizer_config["pad_token"].as_str().unwrap_or("[PAD]").to_string();
    tokenizer
        .with_padding(Some(PaddingParams {
            strategy: PaddingStrategy::BatchLongest,
            pad_token,
            pad_id,
            ..Default::default()
        }))
        .with_truncation(Some(TruncationParams {
            max_length,
            ..Default::default()
        }))
        .map_err(|e| crate::KreuzbergError::embedding(format!("Failed to configure tokenizer: {e}")))?;
    // special_tokens_map.json entries come in two shapes — bare strings or
    // full token objects; register both so the tokenizer recognizes them.
    if let Ok(special_tokens_data) = std::fs::read(special_tokens_path)
        && let Ok(serde_json::Value::Object(map)) = serde_json::from_slice(&special_tokens_data)
    {
        for (_, value) in &map {
            if let Some(content) = value.as_str() {
                // Bare string form: token text only, defaults for all flags.
                tokenizer.add_special_tokens(&[AddedToken {
                    content: content.to_string(),
                    special: true,
                    ..Default::default()
                }]);
            } else if value.is_object()
                && let (Some(content), Some(single_word), Some(lstrip), Some(rstrip), Some(normalized)) = (
                    value["content"].as_str(),
                    value["single_word"].as_bool(),
                    value["lstrip"].as_bool(),
                    value["rstrip"].as_bool(),
                    value["normalized"].as_bool(),
                )
            {
                // Object form: all flags must be present, else the entry is skipped.
                tokenizer.add_special_tokens(&[AddedToken {
                    content: content.to_string(),
                    special: true,
                    single_word,
                    lstrip,
                    rstrip,
                    normalized,
                }]);
            }
        }
    }
    Ok(tokenizer)
}
/// Downloads the model plus tokenizer artifacts from the Hugging Face Hub
/// into `cache_directory`.
///
/// Returns `(model, tokenizer.json, config.json, special_tokens_map.json,
/// tokenizer_config.json)` paths. The last two files are optional; when they
/// cannot be fetched an empty `PathBuf` is returned in their slot.
///
/// # Errors
/// Fails when the API client cannot be built or a required file (model,
/// tokenizer.json, config.json) cannot be downloaded.
fn download_model_files(
    repo_name: &str,
    model_file: &str,
    cache_directory: &std::path::Path,
) -> crate::Result<(
    std::path::PathBuf,
    std::path::PathBuf,
    std::path::PathBuf,
    std::path::PathBuf,
    std::path::PathBuf,
)> {
    let api = hf_hub::api::sync::ApiBuilder::from_env()
        .with_cache_dir(cache_directory.to_path_buf())
        .with_progress(true)
        .build()
        .map_err(|e| crate::KreuzbergError::embedding(format!("Failed to create HF API client: {e}")))?;
    let repo = api.model(repo_name.to_string());
    // Required files abort with a descriptive error on failure.
    let required = |file: &str| {
        repo.get(file)
            .map_err(|e| crate::KreuzbergError::embedding(format!("Failed to download {file}: {e}")))
    };
    // Optional files degrade to an empty path; callers treat that as "absent".
    let optional = |file: &str| repo.get(file).unwrap_or_default();
    let model_path = required(model_file)?;
    let tokenizer_path = required("tokenizer.json")?;
    let config_path = required("config.json")?;
    let special_tokens_path = optional("special_tokens_map.json");
    let tokenizer_config_path = optional("tokenizer_config.json");
    Ok((
        model_path,
        tokenizer_path,
        config_path,
        special_tokens_path,
        tokenizer_config_path,
    ))
}
/// Returns a cached embedding engine for `(repo, model file, cache dir)`,
/// creating and caching one on first use.
///
/// Uses double-checked locking on `ENGINE_CACHE`: an optimistic read, then a
/// write lock with a re-check before the expensive download + session build.
/// Lock poisoning is deliberately tolerated — the cache only ever holds
/// fully-constructed engines, so a poisoned lock's contents remain usable.
///
/// NOTE(review): `pooling` is not part of the cache key, so two callers using
/// the same repo/file/cache-dir with *different* pooling would share one
/// engine. Presets pair pooling with distinct repos, so confirm a Custom
/// model can never collide with a preset repo before relying on this.
fn get_or_init_engine(
    repo_name: &str,
    model_file: &str,
    pooling: engine::Pooling,
    cache_dir: Option<std::path::PathBuf>,
) -> crate::Result<Arc<EmbeddingEngine>> {
    let cache_directory = resolve_cache_dir(cache_dir);
    let engine_key = format!(
        "{repo_name}_{model_file}_{cache_directory}",
        cache_directory = cache_directory.display()
    );
    // Fast path: engine already cached.
    {
        match ENGINE_CACHE.read() {
            Ok(cache) => {
                if let Some(cached) = cache.get(&engine_key) {
                    return Ok(Arc::clone(cached));
                }
            }
            Err(poison_error) => {
                // Poisoned read lock: the map is still readable via the guard.
                let cache = poison_error.get_ref();
                if let Some(cached) = cache.get(&engine_key) {
                    return Ok(Arc::clone(cached));
                }
            }
        }
    }
    // Slow path: hold the write lock across download + session init so
    // concurrent callers for the same model don't duplicate the work.
    {
        let mut cache = match ENGINE_CACHE.write() {
            Ok(guard) => guard,
            Err(poison_error) => poison_error.into_inner(),
        };
        // Re-check: another thread may have initialized while we waited.
        if let Some(cached) = cache.get(&engine_key) {
            return Ok(Arc::clone(cached));
        }
        crate::ort_discovery::ensure_ort_available();
        // The download path can panic when the native ONNX Runtime is broken;
        // convert panics into typed errors instead of unwinding into callers.
        let (model_path, tokenizer_path, config_path, special_tokens_path, tokenizer_config_path) =
            std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                download_model_files(repo_name, model_file, &cache_directory)
            }))
            .map_err(|panic_payload| {
                let panic_msg = panic_to_string(panic_payload);
                if looks_like_ort_error(&panic_msg) {
                    crate::KreuzbergError::MissingDependency(format!(
                        "ONNX Runtime - {}",
                        onnx_runtime_install_message()
                    ))
                } else {
                    crate::KreuzbergError::embedding(format!("Model download panicked: {panic_msg}"))
                }
            })??;
        // 512 is a conservative default max sequence length; load_tokenizer
        // clamps it to the model's own model_max_length.
        let tokenizer = load_tokenizer(
            &tokenizer_path,
            &config_path,
            &special_tokens_path,
            &tokenizer_config_path,
            512,
        )?;
        let thread_budget = crate::core::config::concurrency::resolve_thread_budget(None);
        // Session creation is also wrapped in catch_unwind for the same reason
        // as the download above; errors are classified as missing-ORT when the
        // message looks runtime-related.
        let session = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let mut builder = ort::session::Session::builder()?;
            builder = builder
                .with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)
                .map_err(|e| ort::Error::new(e.message()))?;
            // All intra-op parallelism, single inter-op thread.
            builder = builder
                .with_intra_threads(thread_budget)
                .map_err(|e| ort::Error::new(e.message()))?;
            builder = builder
                .with_inter_threads(1)
                .map_err(|e| ort::Error::new(e.message()))?;
            builder.commit_from_file(&model_path)
        }))
        .map_err(|panic_payload| {
            let panic_msg = panic_to_string(panic_payload);
            if looks_like_ort_error(&panic_msg) {
                crate::KreuzbergError::MissingDependency(format!("ONNX Runtime - {}", onnx_runtime_install_message()))
            } else {
                crate::KreuzbergError::embedding(format!("ONNX Runtime initialization panicked: {panic_msg}"))
            }
        })?
        .map_err(|e| {
            let error_msg = e.to_string();
            if looks_like_ort_error(&error_msg) {
                crate::KreuzbergError::MissingDependency(format!("ONNX Runtime - {}", onnx_runtime_install_message()))
            } else {
                crate::KreuzbergError::embedding(format!("Failed to create ONNX session: {e}"))
            }
        })?;
        let new_engine = Arc::new(EmbeddingEngine::new(tokenizer, session, pooling));
        cache.insert(engine_key, Arc::clone(&new_engine));
        Ok(new_engine)
    }
}
/// Heuristically decides whether an error or panic message points at a missing
/// or broken ONNX Runtime installation (vs. an unrelated failure).
fn looks_like_ort_error(msg: &str) -> bool {
    // NOTE(review): the last two markers are broad and could match unrelated
    // loader errors; kept as-is to preserve the existing classification.
    const ORT_MARKERS: [&str; 8] = [
        "onnxruntime",
        "ORT",
        "libonnxruntime",
        "onnxruntime.dll",
        "Unable to load",
        "library load failed",
        "attempting to load",
        "An error occurred while",
    ];
    ORT_MARKERS.iter().any(|marker| msg.contains(marker))
}
/// Extracts a human-readable message from a panic payload.
///
/// Panic payloads are almost always `&str` or `String`; anything else is
/// rendered as "Unknown panic".
fn panic_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
    match payload.downcast::<String>() {
        Ok(owned) => *owned,
        Err(payload) => payload
            .downcast_ref::<&str>()
            .map(|s| (*s).to_string())
            .unwrap_or_else(|| "Unknown panic".to_string()),
    }
}
/// Eagerly resolves and initializes the engine for `model_type` so that later
/// embedding calls hit the cache instead of paying download/startup cost.
pub fn warm_model(
    model_type: &crate::core::config::EmbeddingModelType,
    cache_dir: Option<std::path::PathBuf>,
) -> crate::Result<()> {
    let (repo, model_file, pooling) = resolve_model_info(model_type)?;
    let _engine = get_or_init_engine(repo, model_file, pooling, cache_dir)?;
    Ok(())
}
pub fn download_model(
model_type: &crate::core::config::EmbeddingModelType,
cache_dir: Option<std::path::PathBuf>,
) -> crate::Result<()> {
let (repo_name, model_file, _pooling) = resolve_model_info(model_type)?;
let cache_directory = resolve_cache_dir(cache_dir);
let files = &[
model_file,
"tokenizer.json",
"config.json",
"special_tokens_map.json",
"tokenizer_config.json",
];
tracing::info!(repo = %repo_name, "Downloading embedding model files (no ONNX init)");
let api = hf_hub::api::sync::ApiBuilder::from_env()
.with_cache_dir(cache_directory)
.with_progress(true)
.build()
.map_err(|e| crate::KreuzbergError::embedding(format!("Failed to create HF API client: {e}")))?;
let repo = api.model(repo_name.to_string());
for file in files {
match repo.get(file) {
Ok(path) => tracing::debug!(file = %file, path = ?path, "Downloaded"),
Err(e) => {
if *file == model_file || *file == "tokenizer.json" {
return Err(crate::KreuzbergError::embedding(format!(
"Failed to download {file}: {e}"
)));
}
tracing::debug!(file = %file, error = %e, "Optional file not found, skipping");
}
}
}
tracing::info!(repo = %repo_name, "Embedding model files downloaded successfully");
Ok(())
}
/// Scales `embedding` to unit L2 norm in place; near-zero vectors are left
/// untouched to avoid dividing by (almost) zero.
fn normalize_in_place(embedding: &mut [f32]) {
    let sum_of_squares: f32 = embedding.iter().map(|x| x * x).sum();
    let magnitude = sum_of_squares.sqrt();
    if magnitude <= f32::EPSILON {
        return;
    }
    // Multiply by the reciprocal rather than dividing per element.
    let scale = magnitude.recip();
    for value in embedding.iter_mut() {
        *value *= scale;
    }
}
/// Normalizes every embedding to unit length, using rayon for large batches.
fn normalize_embeddings(embeddings: &mut [Vec<f32>]) {
    // Rayon's scheduling overhead only pays off for larger batches.
    const PARALLEL_THRESHOLD: usize = 64;
    if embeddings.len() < PARALLEL_THRESHOLD {
        for embedding in embeddings.iter_mut() {
            normalize_in_place(embedding);
        }
    } else {
        use rayon::prelude::*;
        embeddings.par_iter_mut().for_each(|v| normalize_in_place(v));
    }
}
/// Computes embeddings for each chunk's content and stores them on the chunks.
///
/// No-op for an empty slice. Chunk order determines embedding order.
///
/// # Errors
/// Propagates any failure from [`embed_texts`] (empty text, model/init or
/// inference errors).
pub fn generate_embeddings_for_chunks(
    chunks: &mut [crate::types::Chunk],
    config: &crate::core::config::EmbeddingConfig,
) -> crate::Result<()> {
    if chunks.is_empty() {
        return Ok(());
    }
    let mut texts = Vec::with_capacity(chunks.len());
    for chunk in chunks.iter() {
        texts.push(chunk.content.as_str());
    }
    let embeddings = embed_texts(&texts, config)?;
    for (chunk, vector) in chunks.iter_mut().zip(embeddings) {
        chunk.embedding = Some(vector);
    }
    Ok(())
}
/// Generates one embedding vector per input text, synchronously.
///
/// Validates that every text is non-empty, then dispatches on the configured
/// model: LLM-backed embeddings (behind the `liter-llm` feature) or the local
/// ONNX engine, optionally L2-normalizing the results.
///
/// # Errors
/// Returns an error when any text is empty, the model cannot be resolved or
/// initialized, or inference fails. Without `liter-llm`, LLM model types
/// yield a `MissingDependency` error.
pub fn embed_texts<T: AsRef<str>>(
    texts: &[T],
    config: &crate::core::config::EmbeddingConfig,
) -> crate::Result<Vec<Vec<f32>>> {
    if texts.is_empty() {
        return Ok(Vec::new());
    }
    // Reject empty strings up front; positions are reported 1-based to users.
    for (i, t) in texts.iter().enumerate() {
        if t.as_ref().is_empty() {
            return Err(crate::KreuzbergError::embedding(format!(
                "Text at position {pos} is empty. All texts must be non-empty.",
                pos = i + 1
            )));
        }
    }
    match &config.model {
        #[cfg(feature = "liter-llm")]
        crate::core::config::EmbeddingModelType::Llm { llm } => {
            let normalize = config.normalize;
            // Drive the async LLM path from this sync entry point with a
            // throwaway current-thread runtime.
            // NOTE(review): block_on panics if this is ever called from inside
            // an existing tokio runtime — confirm callers are sync-only.
            let rt = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .map_err(|e| {
                    crate::KreuzbergError::embedding(format!("Failed to create runtime for LLM embeddings: {e}"))
                })?;
            rt.block_on(crate::llm::vlm_embeddings::embed_via_llm(texts, llm, normalize))
        }
        #[cfg(not(feature = "liter-llm"))]
        crate::core::config::EmbeddingModelType::Llm { .. } => Err(crate::KreuzbergError::MissingDependency(
            "LLM embeddings require the 'liter-llm' feature. Rebuild with --features liter-llm".into(),
        )),
        // Preset and Custom models both go through the cached ONNX engine.
        _ => {
            let chunk_count = texts.len();
            let (repo, model_file, pooling) = resolve_model_info(&config.model)?;
            let engine = get_or_init_engine(repo, model_file, pooling, config.cache_dir.clone())?;
            let text_refs: Vec<&str> = texts.iter().map(|t| t.as_ref()).collect();
            let mut embeddings = engine.embed(&text_refs, config.batch_size).map_err(|e| {
                crate::KreuzbergError::embedding(format!(
                    "Failed to generate embeddings for {chunk_count} texts (model={:?}, batch_size={}): {e}",
                    config.model, config.batch_size
                ))
            })?;
            if config.normalize {
                normalize_embeddings(&mut embeddings);
            }
            Ok(embeddings)
        }
    }
}
/// Async counterpart of [`embed_texts`].
///
/// LLM-backed models are awaited directly. ONNX inference is CPU-bound, so it
/// is gated by the global `EMBED_SEMAPHORE` and executed on tokio's blocking
/// thread pool via `spawn_blocking`.
///
/// # Errors
/// Same failure modes as [`embed_texts`], plus errors when the semaphore is
/// closed or the blocking task panics.
#[cfg(feature = "tokio-runtime")]
pub async fn embed_texts_async<T: AsRef<str> + Send + 'static>(
    texts: Vec<T>,
    config: &crate::core::config::EmbeddingConfig,
) -> crate::Result<Vec<Vec<f32>>> {
    if texts.is_empty() {
        return Ok(Vec::new());
    }
    // LLM path is natively async — no semaphore, no blocking pool.
    #[cfg(feature = "liter-llm")]
    if let crate::core::config::EmbeddingModelType::Llm { llm } = &config.model {
        return crate::llm::vlm_embeddings::embed_via_llm(&texts, llm, config.normalize).await;
    }
    #[cfg(not(feature = "liter-llm"))]
    if let crate::core::config::EmbeddingModelType::Llm { .. } = &config.model {
        return Err(crate::KreuzbergError::MissingDependency(
            "LLM embeddings require the 'liter-llm' feature. Rebuild with --features liter-llm".into(),
        ));
    }
    // The permit lives until this function returns, i.e. it is held across
    // the spawn_blocking await below — this is what caps concurrency.
    let _permit = EMBED_SEMAPHORE
        .acquire()
        .await
        .map_err(|_| crate::KreuzbergError::embedding("Embedding semaphore closed".to_string()))?;
    // Clone the config so the 'static blocking closure can own it.
    let config = Arc::new(config.clone());
    tokio::task::spawn_blocking(move || embed_texts(&texts, &config))
        .await
        .map_err(|e| crate::KreuzbergError::embedding(format!("Embedding task panicked: {e}")))?
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every built-in preset resolves; unknown names do not.
    #[test]
    fn test_get_preset() {
        for known in ["balanced", "fast", "quality", "multilingual"] {
            assert!(get_preset(known).is_some());
        }
        assert!(get_preset("nonexistent").is_none());
    }

    /// Listing returns exactly the four built-in preset names.
    #[test]
    fn test_list_presets() {
        let presets = list_presets();
        assert_eq!(presets.len(), 4);
        for expected in ["fast", "balanced", "quality", "multilingual"] {
            assert!(presets.contains(&expected));
        }
    }

    /// Preset output dimensionalities match their models.
    #[test]
    fn test_preset_dimensions() {
        for (name, dims) in [("balanced", 768), ("fast", 384), ("quality", 1024)] {
            assert_eq!(get_preset(name).unwrap().dimensions, dims);
        }
    }

    /// Chunking parameters are pinned per preset.
    #[test]
    fn test_preset_chunk_sizes() {
        for (name, chunk_size, overlap) in [("fast", 512, 50), ("quality", 2000, 200)] {
            let preset = get_preset(name).unwrap();
            assert_eq!(preset.chunk_size, chunk_size);
            assert_eq!(preset.overlap, overlap);
        }
    }

    /// Model repo / pooling / file wiring is pinned per preset.
    #[test]
    fn test_preset_model_repos() {
        let fast = get_preset("fast").unwrap();
        assert_eq!(fast.model_repo, "Xenova/all-MiniLM-L6-v2");
        assert_eq!(fast.pooling, "mean");
        assert_eq!(fast.model_file, "onnx/model_quantized.onnx");
        let balanced = get_preset("balanced").unwrap();
        assert_eq!(balanced.model_repo, "Xenova/bge-base-en-v1.5");
        assert_eq!(balanced.pooling, "cls");
    }

    /// Empty texts are rejected with a 1-based position before any model work.
    #[test]
    fn test_embed_texts_rejects_empty_string() {
        let config = crate::core::config::EmbeddingConfig::default();
        let err = embed_texts(&["valid", ""], &config).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("position 2"),
            "Error should identify the empty text position, got: {msg}"
        );
        assert!(msg.contains("empty"), "Error should mention empty text, got: {msg}");
    }

    /// An empty input slice short-circuits to an empty result.
    #[test]
    fn test_embed_texts_empty_list_returns_empty() {
        let config = crate::core::config::EmbeddingConfig::default();
        let no_texts: Vec<&str> = Vec::new();
        assert!(embed_texts(&no_texts, &config).unwrap().is_empty());
    }

    /// The very first text being empty is reported as position 1.
    #[test]
    fn test_embed_texts_rejects_first_empty_string() {
        let config = crate::core::config::EmbeddingConfig::default();
        let err = embed_texts(&[""], &config).unwrap_err();
        assert!(err.to_string().contains("position 1"));
    }
}