#[cfg(feature = "embeddings")]
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
#[cfg(feature = "embeddings")]
use std::sync::{Arc, Mutex, RwLock};
#[cfg(feature = "embeddings")]
use std::collections::HashMap;
#[cfg(feature = "embeddings")]
use std::mem::ManuallyDrop;
#[cfg(feature = "embeddings")]
use once_cell::sync::Lazy;
/// Raw-pointer handle to a heap-allocated [`TextEmbedding`].
///
/// The pointer is produced by `Box::into_raw` in `LeakedModel::new`. No `Drop`
/// implementation is visible in this part of the file, so dropping the wrapper
/// does not free the model.
// NOTE(review): the type name suggests the leak is deliberate (presumably to
// avoid running the model's destructor) — confirm the intent with the author.
#[cfg(feature = "embeddings")]
pub(crate) struct LeakedModel {
// Pointer obtained from `Box::into_raw`; dereferenced only in `get_mut`.
ptr: *mut TextEmbedding,
}
#[cfg(feature = "embeddings")]
impl LeakedModel {
/// Moves `model` onto the heap and keeps only the raw pointer to it.
fn new(model: TextEmbedding) -> Self {
Self {
// `Box::into_raw` gives up the box's ownership, so nothing frees the
// allocation automatically afterwards.
ptr: Box::into_raw(Box::new(model)),
}
}
/// Returns a mutable reference to the wrapped model from a shared reference.
///
/// # Safety
///
/// The caller must guarantee that no other reference to the model is alive
/// while the returned reference is in use. The only call site visible in this
/// file obtains it while holding the `Mutex` of a `CachedEmbedding`.
#[allow(unsafe_code, clippy::mut_from_ref)]
unsafe fn get_mut(&self) -> &mut TextEmbedding {
// SAFETY: `ptr` comes from `Box::into_raw` in `new` and is not freed anywhere
// in the visible code, so it is non-null and points to a live value.
// Exclusivity of the returned reference is the caller's obligation.
unsafe { &mut *self.ptr }
}
}
// The raw pointer field makes `LeakedModel` `!Send` by default; this impl opts
// back in so the handle can be stored in the global cache and shared via `Arc`.
// NOTE(review): soundness assumes `TextEmbedding` itself may be moved across
// threads — TODO confirm against the fastembed version in use.
#[cfg(feature = "embeddings")]
#[allow(unsafe_code)]
unsafe impl Send for LeakedModel {}
// The raw pointer field also makes `LeakedModel` `!Sync` by default.
// NOTE(review): `get_mut` hands out `&mut TextEmbedding` from `&self`, so this
// impl is only sound while every access is serialised externally; in the
// visible code that is done by the `Mutex` inside `CachedEmbedding` — verify no
// other access path exists.
#[cfg(feature = "embeddings")]
#[allow(unsafe_code)]
unsafe impl Sync for LeakedModel {}
/// Shared, lockable handle to a cached model, as stored in `MODEL_CACHE` and
/// returned by `get_or_init_model`.
#[cfg(feature = "embeddings")]
type CachedEmbedding = Arc<Mutex<LeakedModel>>;
/// Lazily created, process-wide map of initialised models.
///
/// Entries are inserted by `get_or_init_model`, which builds each key from the
/// model's `Debug` output and the cache directory path.
// NOTE(review): the `ManuallyDrop` wrapper presumably exists to make sure the
// map's destructor never runs — confirm the intent.
#[cfg(feature = "embeddings")]
static MODEL_CACHE: Lazy<ManuallyDrop<RwLock<HashMap<String, CachedEmbedding>>>> =
Lazy::new(|| ManuallyDrop::new(RwLock::new(HashMap::new())));
/// Builds the user-facing hint that is attached to "ONNX Runtime missing" errors.
///
/// The text is selected at compile time: Windows GNU (MinGW) targets get a
/// "not supported" notice, every other target gets per-platform installation
/// instructions.
#[cfg(feature = "embeddings")]
fn onnx_runtime_install_message() -> String {
    // Exactly one of the two constants below exists for any given target.
    #[cfg(all(windows, target_env = "gnu"))]
    const INSTALL_HINT: &str = "ONNX Runtime embeddings are not supported on Windows MinGW builds. \
ONNX Runtime requires MSVC toolchain. \
Please use Windows MSVC builds or disable embeddings feature.";

    #[cfg(not(all(windows, target_env = "gnu")))]
    const INSTALL_HINT: &str = "ONNX Runtime is required for embeddings functionality. \
Install: \
macOS: 'brew install onnxruntime', \
Linux (Ubuntu/Debian): 'apt install libonnxruntime libonnxruntime-dev', \
Linux (Fedora): 'dnf install onnxruntime onnxruntime-devel', \
Linux (Arch): 'pacman -S onnxruntime', \
Windows (MSVC): Download from https://github.com/microsoft/onnxruntime/releases and add to PATH. \
\
Alternatively, set ORT_DYLIB_PATH environment variable to the ONNX Runtime library path. \
\
For Docker/containers: Install via package manager in your base image. \
Verified packages: Ubuntu 22.04+, Fedora 38+, Arch Linux.";

    String::from(INSTALL_HINT)
}
/// Returns the shared handle for `model`, initialising and caching it on first use.
///
/// The cache key combines the `Debug` form of `model` with the cache directory,
/// so the same model requested with a different directory is a separate entry.
/// When `cache_dir` is `None`, `<current dir>/.kreuzberg/embeddings` is used
/// (with `.` standing in for the current directory if it cannot be determined).
///
/// A poisoned cache lock is not treated as fatal: the map is still read and
/// written through the poisoned guard.
///
/// # Errors
///
/// * `KreuzbergError::MissingDependency` when the text of the initialisation
///   error or panic contains one of the ONNX Runtime loading markers.
/// * `KreuzbergError::Plugin` for any other initialisation error or panic.
#[cfg(feature = "embeddings")]
#[allow(private_interfaces)]
pub fn get_or_init_model(
    model: EmbeddingModel,
    cache_dir: Option<std::path::PathBuf>,
) -> crate::Result<CachedEmbedding> {
    // Substrings that classify a failure as "ONNX Runtime could not be loaded".
    const ORT_LOAD_FAILURE_MARKERS: [&str; 8] = [
        "onnxruntime",
        "ORT",
        "libonnxruntime",
        "onnxruntime.dll",
        "Unable to load",
        "library load failed",
        "attempting to load",
        "An error occurred while",
    ];
    let looks_like_missing_ort = |text: &str| {
        ORT_LOAD_FAILURE_MARKERS
            .iter()
            .any(|marker| text.contains(*marker))
    };

    let cache_directory = match cache_dir {
        Some(explicit) => explicit,
        None => {
            let base = std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from("."));
            base.join(".kreuzberg").join("embeddings")
        }
    };
    let model_key = format!("{:?}_{}", model, cache_directory.display());

    // Fast path: look the model up under the read lock. The guard is released at
    // the end of this statement, before the write lock is requested below.
    let already_cached = match MODEL_CACHE.read() {
        Ok(guard) => guard.get(&model_key).map(Arc::clone),
        Err(poisoned) => poisoned.get_ref().get(&model_key).map(Arc::clone),
    };
    if let Some(hit) = already_cached {
        return Ok(hit);
    }

    // Slow path: the write lock stays held for the whole initialisation.
    let mut cache = MODEL_CACHE
        .write()
        .unwrap_or_else(|poisoned| poisoned.into_inner());

    // Another caller may have inserted the model between the two lock acquisitions.
    if let Some(hit) = cache.get(&model_key) {
        return Ok(Arc::clone(hit));
    }

    crate::ort_discovery::ensure_ort_available();

    // Panics raised during initialisation are caught and reported as errors.
    let init_attempt = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        TextEmbedding::try_new(InitOptions::new(model).with_cache_dir(cache_directory))
    }));

    let embedding_model = match init_attempt {
        Ok(Ok(ready)) => ready,
        Ok(Err(init_error)) => {
            let detail = init_error.to_string();
            let failure = if looks_like_missing_ort(&detail) {
                crate::KreuzbergError::MissingDependency(format!(
                    "ONNX Runtime - {}",
                    onnx_runtime_install_message()
                ))
            } else {
                crate::KreuzbergError::Plugin {
                    message: format!("Failed to initialize embedding model: {}", init_error),
                    plugin_name: "embeddings".to_string(),
                }
            };
            return Err(failure);
        }
        Err(panic_payload) => {
            // Panic payloads are usually `&str` or `String`; anything else gets a
            // fixed placeholder text.
            let detail = panic_payload
                .downcast_ref::<&str>()
                .map(|text| (*text).to_string())
                .or_else(|| panic_payload.downcast_ref::<String>().cloned())
                .unwrap_or_else(|| "Unknown panic during ONNX Runtime initialization".to_string());
            let failure = if looks_like_missing_ort(&detail) {
                crate::KreuzbergError::MissingDependency(format!(
                    "ONNX Runtime - {}",
                    onnx_runtime_install_message()
                ))
            } else {
                crate::KreuzbergError::Plugin {
                    message: format!("ONNX Runtime initialization panicked: {}", detail),
                    plugin_name: "embeddings".to_string(),
                }
            };
            return Err(failure);
        }
    };

    let handle = Arc::new(Mutex::new(LeakedModel::new(embedding_model)));
    cache.insert(model_key, Arc::clone(&handle));
    Ok(handle)
}
/// A named bundle of chunking parameters and an embedding model.
///
/// The model field differs by build: with the `embeddings` feature it is the
/// fastembed model enum (`model`), without it only the model's name is kept
/// (`model_name`).
#[derive(Debug, Clone)]
pub struct EmbeddingPreset {
/// Preset identifier; `get_preset` matches against this exactly.
pub name: &'static str,
/// Chunk size associated with the preset.
// NOTE(review): the unit (characters vs. tokens) is not visible here — TODO confirm.
pub chunk_size: usize,
/// Overlap between consecutive chunks, presumably in the same unit as `chunk_size`.
pub overlap: usize,
/// fastembed model used by this preset (only with the `embeddings` feature).
#[cfg(feature = "embeddings")]
pub model: EmbeddingModel,
/// Name of the model used by this preset (only without the `embeddings` feature).
#[cfg(not(feature = "embeddings"))]
pub model_name: &'static str,
/// Embedding dimension count recorded for the preset's model.
pub dimensions: usize,
/// Human-readable summary of the preset.
pub description: &'static str,
}
/// Built-in presets, searched by `get_preset` and enumerated by `list_presets`.
///
/// Each entry sets either `model` or `model_name` depending on whether the
/// `embeddings` feature is enabled; the string in `model_name` spells the same
/// identifier as the enum variant used for `model`.
pub const EMBEDDING_PRESETS: &[EmbeddingPreset] = &[
// "fast": smallest chunks, 384-dimensional quantized model.
EmbeddingPreset {
name: "fast",
chunk_size: 512,
overlap: 50,
#[cfg(feature = "embeddings")]
model: EmbeddingModel::AllMiniLML6V2Q,
#[cfg(not(feature = "embeddings"))]
model_name: "AllMiniLML6V2Q",
dimensions: 384,
description: "Fast embedding with quantized model (384 dims, ~22M params). Best for: Quick prototyping, development, resource-constrained environments.",
},
// "balanced": mid-sized chunks, 768-dimensional English model.
EmbeddingPreset {
name: "balanced",
chunk_size: 1024,
overlap: 100,
#[cfg(feature = "embeddings")]
model: EmbeddingModel::BGEBaseENV15,
#[cfg(not(feature = "embeddings"))]
model_name: "BGEBaseENV15",
dimensions: 768,
description: "Balanced quality and speed (768 dims, ~109M params). Best for: General-purpose RAG, production deployments, English documents.",
},
// "quality": largest chunks, 1024-dimensional model.
EmbeddingPreset {
name: "quality",
chunk_size: 2000,
overlap: 200,
#[cfg(feature = "embeddings")]
model: EmbeddingModel::BGELargeENV15,
#[cfg(not(feature = "embeddings"))]
model_name: "BGELargeENV15",
dimensions: 1024,
description: "High quality with larger context (1024 dims, ~335M params). Best for: Complex documents, maximum accuracy, sufficient compute resources.",
},
// "multilingual": same chunking as "balanced", multilingual 768-dimensional model.
EmbeddingPreset {
name: "multilingual",
chunk_size: 1024,
overlap: 100,
#[cfg(feature = "embeddings")]
model: EmbeddingModel::MultilingualE5Base,
#[cfg(not(feature = "embeddings"))]
model_name: "MultilingualE5Base",
dimensions: 768,
description: "Multilingual support (768 dims, 100+ languages). Best for: International documents, mixed-language content, global applications.",
},
];
/// Looks up a built-in preset by its exact (case-sensitive) name.
///
/// Returns `None` when no entry of `EMBEDDING_PRESETS` has that name.
pub fn get_preset(name: &str) -> Option<&'static EmbeddingPreset> {
    for preset in EMBEDDING_PRESETS {
        if preset.name == name {
            return Some(preset);
        }
    }
    None
}
pub fn list_presets() -> Vec<&'static str> {
EMBEDDING_PRESETS.iter().map(|p| p.name).collect()
}
/// Generates one embedding per chunk and stores it in `chunk.embedding`.
///
/// The model is resolved from `config.model` (a preset name or one of the four
/// supported fastembed model names), fetched through `get_or_init_model`, and
/// run over the chunk contents in batches of `config.batch_size`. When
/// `config.normalize` is set, each vector is scaled to unit L2 length; vectors
/// whose magnitude is not greater than zero are stored unscaled.
///
/// An empty `chunks` slice is a no-op and returns `Ok(())` without touching the
/// model.
///
/// # Errors
///
/// Returns `KreuzbergError::Plugin` (plugin name `"embeddings"`) when:
/// * the preset or fastembed model name is unknown,
/// * a custom ONNX model is requested (not supported),
/// * `config.batch_size` is zero,
/// * the model lock is poisoned,
/// * the backend fails to embed, or
/// * the backend returns a different number of embeddings than there are chunks.
///
/// Errors from `get_or_init_model` are passed through unchanged.
#[cfg(feature = "embeddings")]
pub fn generate_embeddings_for_chunks(
    chunks: &mut [crate::types::Chunk],
    config: &crate::core::config::EmbeddingConfig,
) -> crate::Result<()> {
    // Every failure raised directly by this function has the same shape.
    fn embeddings_error(message: String) -> crate::KreuzbergError {
        crate::KreuzbergError::Plugin {
            message,
            plugin_name: "embeddings".to_string(),
        }
    }

    if chunks.is_empty() {
        return Ok(());
    }

    let fastembed_model = match &config.model {
        crate::core::config::EmbeddingModelType::Preset { name } => get_preset(name)
            .ok_or_else(|| embeddings_error(format!("Unknown embedding preset: {}", name)))?
            .model
            .clone(),
        crate::core::config::EmbeddingModelType::FastEmbed { model, .. } => match model.as_str() {
            "AllMiniLML6V2Q" => fastembed::EmbeddingModel::AllMiniLML6V2Q,
            "BGEBaseENV15" => fastembed::EmbeddingModel::BGEBaseENV15,
            "BGELargeENV15" => fastembed::EmbeddingModel::BGELargeENV15,
            "MultilingualE5Base" => fastembed::EmbeddingModel::MultilingualE5Base,
            _ => {
                return Err(embeddings_error(format!("Unknown fastembed model: {}", model)));
            }
        },
        crate::core::config::EmbeddingModelType::Custom { .. } => {
            return Err(embeddings_error(
                "Custom ONNX models are not yet supported for embedding generation".to_string(),
            ));
        }
    };

    // Reject a zero batch size before touching the model. The backend splits the
    // input into batches of this size, and slice chunking panics on a size of
    // zero; that panic would unwind while the model mutex is held and poison it
    // for every later caller.
    if config.batch_size == 0 {
        return Err(embeddings_error(
            "Embedding batch size must be greater than zero".to_string(),
        ));
    }

    let model = get_or_init_model(fastembed_model, config.cache_dir.clone())?;
    let texts: Vec<String> = chunks.iter().map(|chunk| chunk.content.clone()).collect();

    let embeddings = {
        let locked_model = model
            .lock()
            .map_err(|e| embeddings_error(format!("Failed to acquire model lock: {}", e)))?;
        // SAFETY: the mutex guard outlives `model_mut`, so no other holder of this
        // `CachedEmbedding` can obtain a reference to the model at the same time
        // (all access in the visible code goes through this mutex).
        #[allow(unsafe_code)]
        let model_mut = unsafe { locked_model.get_mut() };
        model_mut
            .embed(texts, Some(config.batch_size))
            .map_err(|e| embeddings_error(format!("Failed to generate embeddings: {}", e)))?
    };

    // `zip` below would silently stop at the shorter side, leaving some chunks
    // without an embedding while still reporting success.
    if embeddings.len() != chunks.len() {
        return Err(embeddings_error(format!(
            "Embedding count mismatch: expected {} embeddings, got {}",
            chunks.len(),
            embeddings.len()
        )));
    }

    for (chunk, mut embedding) in chunks.iter_mut().zip(embeddings) {
        if config.normalize {
            let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
            // Skip scaling for zero (or NaN) magnitudes to avoid dividing by zero.
            if magnitude > 0.0 {
                embedding.iter_mut().for_each(|x| *x /= magnitude);
            }
        }
        chunk.embedding = Some(embedding);
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    // Names of the four built-in presets exercised by the tests below.
    const KNOWN_PRESETS: [&str; 4] = ["balanced", "fast", "quality", "multilingual"];

    /// Every built-in preset resolves; an unknown name does not.
    #[test]
    fn test_get_preset() {
        for name in KNOWN_PRESETS {
            assert!(get_preset(name).is_some());
        }
        assert!(get_preset("nonexistent").is_none());
    }

    /// `list_presets` reports exactly the four built-in names.
    #[test]
    fn test_list_presets() {
        let presets = list_presets();
        assert_eq!(presets.len(), 4);
        for name in KNOWN_PRESETS {
            assert!(presets.contains(&name));
        }
    }

    /// Dimension counts of the "balanced", "fast" and "quality" presets.
    #[test]
    fn test_preset_dimensions() {
        let expected: [(&str, usize); 3] = [("balanced", 768), ("fast", 384), ("quality", 1024)];
        for (name, dimensions) in expected {
            let preset = get_preset(name).unwrap();
            assert_eq!(preset.dimensions, dimensions);
        }
    }

    /// Chunk size and overlap of the "fast" and "quality" presets.
    #[test]
    fn test_preset_chunk_sizes() {
        let expected: [(&str, usize, usize); 2] = [("fast", 512, 50), ("quality", 2000, 200)];
        for (name, chunk_size, overlap) in expected {
            let preset = get_preset(name).unwrap();
            assert_eq!(preset.chunk_size, chunk_size);
            assert_eq!(preset.overlap, overlap);
        }
    }

    // NOTE(review): placeholder — the body is empty, so this test passes without
    // checking anything about lock-poisoning recovery.
    #[cfg(feature = "embeddings")]
    #[test]
    fn test_lock_poisoning_recovery_semantics() {}
}