use crate::Result;
use crate::embedding::{DEFAULT_DIMENSIONS, Embedder};
use crate::error::StorageError;
use std::panic::{AssertUnwindSafe, catch_unwind};
use std::sync::OnceLock;
/// Process-wide embedding model, created lazily on first use and shared by every
/// `FastEmbedEmbedder`. The `Mutex` serializes model calls: each embedding call
/// locks it before invoking the model.
static EMBEDDING_MODEL: OnceLock<std::sync::Mutex<fastembed::TextEmbedding>> = OnceLock::new();
/// Text embedder backed by the fastembed BGE-M3 model.
///
/// The struct itself is a lightweight handle. The model is loaded lazily and
/// shared process-wide, so creating many embedders does not load many models.
#[derive(Debug, Clone)]
pub struct FastEmbedEmbedder {
    /// Human-readable model identifier reported by `model_name`.
    model_name: &'static str,
}
impl FastEmbedEmbedder {
    /// Creates an embedder handle for the BGE-M3 model.
    ///
    /// This does not load the model. Loading is deferred to the first call that
    /// needs it.
    ///
    /// # Errors
    ///
    /// Currently never fails. The `Result` return type is part of the public API.
    // NOTE(review): presumably kept non-`const` (hence the allow) so construction
    // can start doing real work later without an API change — confirm intent.
    #[allow(clippy::missing_const_for_fn)]
    pub fn new() -> Result<Self> {
        Ok(Self {
            model_name: "BGE-M3",
        })
    }

    /// Returns the process-wide embedding model, loading it on first use.
    ///
    /// Loading is serialized through a dedicated lock. Concurrent first callers
    /// therefore load the model once, rather than each loading (and possibly
    /// downloading) a private copy and discarding all but one.
    ///
    /// # Errors
    ///
    /// Returns `StorageError::Embedding` if the model cannot be loaded. A failed
    /// load is not cached, so the next call retries.
    fn get_model() -> Result<&'static std::sync::Mutex<fastembed::TextEmbedding>> {
        // Guards only the load step and protects no data, so poisoning is harmless.
        static INIT_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

        // Fast path: already loaded, no locking needed.
        if let Some(model) = EMBEDDING_MODEL.get() {
            return Ok(model);
        }

        let _init_guard = INIT_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        // Another thread may have finished loading while we waited for the lock.
        if let Some(model) = EMBEDDING_MODEL.get() {
            return Ok(model);
        }

        let options = fastembed::InitOptions::new(fastembed::EmbeddingModel::BGEM3)
            .with_show_download_progress(false);
        let model = fastembed::TextEmbedding::try_new(options)
            .map_err(|e| StorageError::Embedding(format!("Failed to load embedding model: {e}")))?;

        // Every loader goes through INIT_LOCK, so the cell is still empty here.
        // `get_or_init` stores our model and returns the 'static reference in one step.
        Ok(EMBEDDING_MODEL.get_or_init(|| std::sync::Mutex::new(model)))
    }

    /// Returns the identifier of the model this embedder uses.
    #[must_use]
    pub const fn model_name(&self) -> &'static str {
        self.model_name
    }
}
impl Embedder for FastEmbedEmbedder {
fn dimensions(&self) -> usize {
DEFAULT_DIMENSIONS
}
fn model_name(&self) -> &'static str {
self.model_name
}
fn embed(&self, text: &str) -> Result<Vec<f32>> {
if text.is_empty() {
return Err(crate::Error::Chunking(
crate::error::ChunkingError::InvalidConfig {
reason: "Cannot embed empty text".to_string(),
},
));
}
let model = Self::get_model()?;
let mut model = model
.lock()
.map_err(|e| StorageError::Embedding(format!("Failed to lock embedding model: {e}")))?;
let texts = [text];
let result = catch_unwind(AssertUnwindSafe(|| model.embed(texts, None)));
let embeddings = result
.map_err(|panic_info| {
let panic_msg = panic_info
.downcast_ref::<&str>()
.map(|s| (*s).to_string())
.or_else(|| panic_info.downcast_ref::<String>().cloned())
.unwrap_or_else(|| "unknown panic".to_string());
StorageError::Embedding(format!("ONNX runtime panic: {panic_msg}"))
})?
.map_err(|e| StorageError::Embedding(format!("Embedding failed: {e}")))?;
embeddings.into_iter().next().ok_or_else(|| {
StorageError::Embedding("No embedding returned from model".to_string()).into()
})
}
fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
if texts.iter().any(|t| t.is_empty()) {
return Err(crate::Error::Chunking(
crate::error::ChunkingError::InvalidConfig {
reason: "Cannot embed empty text".to_string(),
},
));
}
let model = Self::get_model()?;
let mut model = model
.lock()
.map_err(|e| StorageError::Embedding(format!("Failed to lock embedding model: {e}")))?;
let result = catch_unwind(AssertUnwindSafe(|| model.embed(texts, None)));
result
.map_err(|panic_info| {
let panic_msg = panic_info
.downcast_ref::<&str>()
.map(|s| (*s).to_string())
.or_else(|| panic_info.downcast_ref::<String>().cloned())
.unwrap_or_else(|| "unknown panic".to_string());
crate::Error::Storage(StorageError::Embedding(format!(
"ONNX runtime panic: {panic_msg}"
)))
})?
.map_err(|e| {
crate::Error::Storage(StorageError::Embedding(format!(
"Batch embedding failed: {e}"
)))
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an embedder, failing the test with a clear message if construction errors.
    fn new_embedder() -> FastEmbedEmbedder {
        FastEmbedEmbedder::new().expect("FastEmbedEmbedder::new should succeed")
    }

    /// Asserts that `result` is specifically the empty-text rejection.
    ///
    /// A bare `is_err()` would also pass on an unrelated failure, such as a model
    /// load or download error. That would hide a regression where input validation
    /// no longer runs before the model is touched.
    fn assert_rejects_empty_text<T: std::fmt::Debug>(result: &Result<T>) {
        match result {
            Err(crate::Error::Chunking(crate::error::ChunkingError::InvalidConfig { reason })) => {
                assert_eq!(reason, "Cannot embed empty text");
            }
            other => panic!("expected empty-text rejection, got {other:?}"),
        }
    }

    #[test]
    fn test_embedder_creation() {
        assert_eq!(new_embedder().dimensions(), DEFAULT_DIMENSIONS);
    }

    #[test]
    fn test_model_name() {
        let embedder = new_embedder();
        assert_eq!(embedder.model_name(), "BGE-M3");
        // The trait-level accessor must agree with the inherent one.
        assert_eq!(Embedder::model_name(&embedder), "BGE-M3");
    }

    #[test]
    #[ignore = "requires fastembed model download"]
    fn test_embed_success() {
        let embedding = new_embedder()
            .embed("Hello, world!")
            .expect("embedding a non-empty text should succeed");
        assert_eq!(embedding.len(), DEFAULT_DIMENSIONS);
    }

    #[test]
    #[ignore = "requires fastembed model download"]
    fn test_embed_batch_success() {
        let embeddings = new_embedder()
            .embed_batch(&["Hello", "World"])
            .expect("batch embedding of non-empty texts should succeed");
        assert_eq!(embeddings.len(), 2);
        assert!(embeddings.iter().all(|e| e.len() == DEFAULT_DIMENSIONS));
    }

    #[test]
    fn test_embed_empty_fails() {
        assert_rejects_empty_text(&new_embedder().embed(""));
    }

    #[test]
    fn test_embed_batch_empty_list() {
        let embeddings = new_embedder()
            .embed_batch(&[])
            .expect("an empty batch should succeed without loading the model");
        assert!(embeddings.is_empty());
    }

    #[test]
    fn test_embed_batch_with_empty_fails() {
        assert_rejects_empty_text(&new_embedder().embed_batch(&["Valid", "", "Also valid"]));
    }
}