use std::sync::Mutex;
use fastembed::{EmbeddingModel, TextEmbedding};
use super::Vectorizer;
use crate::error::{Error, Result};
/// Settings used to construct a [`HuggingFaceTextVectorizer`].
#[derive(Clone, Debug)]
pub struct HuggingFaceConfig {
    /// The `fastembed` embedding model to load.
    pub model: EmbeddingModel,
    /// Flag forwarded to the `with_show_download_progress` init option of `fastembed`.
    pub show_download_progress: bool,
}
impl Default for HuggingFaceConfig {
fn default() -> Self {
Self {
model: EmbeddingModel::AllMiniLML6V2,
show_download_progress: false,
}
}
}
impl HuggingFaceConfig {
    /// Creates a configuration for `model` with `show_download_progress` set to `false`.
    #[must_use]
    pub fn new(model: EmbeddingModel) -> Self {
        let show_download_progress = false;
        Self {
            model,
            show_download_progress,
        }
    }

    /// Consumes the configuration and returns it with `show_download_progress`
    /// replaced by `show`; the selected model is carried over unchanged.
    #[must_use]
    pub fn with_show_download_progress(self, show: bool) -> Self {
        Self {
            show_download_progress: show,
            ..self
        }
    }
}
/// [`Vectorizer`] implementation backed by a `fastembed` [`TextEmbedding`] model.
pub struct HuggingFaceTextVectorizer {
/// The embedding model, kept behind a [`Mutex`] that is locked around each embedding call.
model: Mutex<TextEmbedding>,
}
impl HuggingFaceTextVectorizer {
pub fn new(config: HuggingFaceConfig) -> Result<Self> {
let init_options = fastembed::InitOptions::new(config.model)
.with_show_download_progress(config.show_download_progress);
let model = TextEmbedding::try_new(init_options)
.map_err(|e| Error::InvalidInput(format!("failed to load HF model: {e}")))?;
Ok(Self {
model: Mutex::new(model),
})
}
}
impl Vectorizer for HuggingFaceTextVectorizer {
fn embed(&self, text: &str) -> Result<Vec<f32>> {
let mut model = self
.model
.lock()
.map_err(|e| Error::InvalidInput(format!("lock poisoned: {e}")))?;
let mut embeddings = model
.embed(vec![text], None)
.map_err(|e| Error::InvalidInput(format!("embedding failed: {e}")))?;
Ok(embeddings.pop().unwrap_or_default())
}
fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
let mut model = self
.model
.lock()
.map_err(|e| Error::InvalidInput(format!("lock poisoned: {e}")))?;
model
.embed(texts.to_vec(), None)
.map_err(|e| Error::InvalidInput(format!("embedding failed: {e}")))
}
}
// SAFETY: NOTE(review): no soundness argument is recorded for these impls. The
// struct's only field is `Mutex<TextEmbedding>`, and `std::sync::Mutex<T>` is
// already `Send + Sync` whenever `T: Send`. If `TextEmbedding: Send`, these
// impls are redundant and can be deleted; if it is not, they assert that the
// model may be moved to and used from other threads while the mutex is held —
// TODO confirm against the fastembed version in use.
unsafe impl Send for HuggingFaceTextVectorizer {}
unsafe impl Sync for HuggingFaceTextVectorizer {}
impl std::fmt::Debug for HuggingFaceTextVectorizer {
    /// Writes the type name followed by the non-exhaustive marker; no fields
    /// are printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("HuggingFaceTextVectorizer");
        builder.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration names `AllMiniLML6V2` and hides download progress.
    #[test]
    fn default_config_uses_all_mini_lm() {
        let HuggingFaceConfig {
            model,
            show_download_progress,
        } = HuggingFaceConfig::default();
        assert!(!show_download_progress);
        let model_name = format!("{:?}", model);
        assert!(model_name.contains("AllMiniLML6V2"));
    }

    /// `new` followed by `with_show_download_progress(true)` turns the flag on.
    #[test]
    fn config_builder_chain() {
        let base = HuggingFaceConfig::new(EmbeddingModel::AllMiniLML6V2);
        let cfg = base.with_show_download_progress(true);
        assert!(cfg.show_download_progress);
    }

    /// Compile-time check that the vectorizer satisfies `Send + Sync`.
    #[test]
    fn vectorizer_is_send_sync() {
        fn require_send_sync<T>()
        where
            T: Send + Sync,
        {
        }
        require_send_sync::<HuggingFaceTextVectorizer>();
    }

    /// Debug-formatting a configuration succeeds and mentions the type name.
    #[test]
    fn debug_impl_does_not_panic() {
        let rendered = format!("{:?}", HuggingFaceConfig::default());
        assert!(rendered.contains("HuggingFaceConfig"));
    }
}