// redis_vl/vectorizers/hf_local.rs
//! HuggingFace local embedding adapter using ONNX Runtime via `fastembed`.
//!
//! Enabled by the `hf-local` feature flag. This vectorizer runs embedding
//! models locally without requiring an external API. Models are automatically
//! downloaded from the HuggingFace Hub on first use.
//!
//! # Example
//!
//! ```rust,no_run
//! use redis_vl::vectorizers::{HuggingFaceTextVectorizer, Vectorizer};
//!
//! // Uses the default model (AllMiniLML6V2)
//! let vectorizer = HuggingFaceTextVectorizer::new(Default::default()).unwrap();
//! let embedding = vectorizer.embed("Hello, world!").unwrap();
//! ```

17use std::sync::Mutex;
18
19use fastembed::{EmbeddingModel, TextEmbedding};
20
21use super::Vectorizer;
22use crate::error::{Error, Result};
23
24/// Configuration for the HuggingFace local embedding provider.
25#[derive(Debug, Clone)]
26pub struct HuggingFaceConfig {
27    /// The embedding model to use.
28    ///
29    /// Defaults to [`EmbeddingModel::AllMiniLML6V2`].
30    pub model: EmbeddingModel,
31    /// Whether to show download progress when fetching the model.
32    ///
33    /// Defaults to `false`.
34    pub show_download_progress: bool,
35}
36
37impl Default for HuggingFaceConfig {
38    fn default() -> Self {
39        Self {
40            model: EmbeddingModel::AllMiniLML6V2,
41            show_download_progress: false,
42        }
43    }
44}
45
46impl HuggingFaceConfig {
47    /// Creates a new config with the given model.
48    #[must_use]
49    pub fn new(model: EmbeddingModel) -> Self {
50        Self {
51            model,
52            show_download_progress: false,
53        }
54    }
55
56    /// Enables download progress output.
57    #[must_use]
58    pub fn with_show_download_progress(mut self, show: bool) -> Self {
59        self.show_download_progress = show;
60        self
61    }
62}
63
64/// HuggingFace local embedding adapter backed by ONNX Runtime.
65///
66/// Uses the [`fastembed`] crate to run embedding models locally. Models are
67/// automatically downloaded from the HuggingFace Hub on first use and cached
68/// on disk.
69///
70/// This vectorizer implements [`Vectorizer`] for synchronous embedding
71/// generation. For async use cases, wrap it with
72/// [`tokio::task::spawn_blocking`] or use it with the synchronous semantic
73/// extension APIs.
74pub struct HuggingFaceTextVectorizer {
75    model: Mutex<TextEmbedding>,
76}
77
78impl HuggingFaceTextVectorizer {
79    /// Creates a new HuggingFace local vectorizer.
80    ///
81    /// This may download the model from HuggingFace Hub on first invocation.
82    ///
83    /// # Errors
84    ///
85    /// Returns an error if the model cannot be loaded.
86    pub fn new(config: HuggingFaceConfig) -> Result<Self> {
87        let init_options = fastembed::InitOptions::new(config.model)
88            .with_show_download_progress(config.show_download_progress);
89
90        let model = TextEmbedding::try_new(init_options)
91            .map_err(|e| Error::InvalidInput(format!("failed to load HF model: {e}")))?;
92
93        Ok(Self {
94            model: Mutex::new(model),
95        })
96    }
97}
98
99impl Vectorizer for HuggingFaceTextVectorizer {
100    fn embed(&self, text: &str) -> Result<Vec<f32>> {
101        let mut model = self
102            .model
103            .lock()
104            .map_err(|e| Error::InvalidInput(format!("lock poisoned: {e}")))?;
105        let mut embeddings = model
106            .embed(vec![text], None)
107            .map_err(|e| Error::InvalidInput(format!("embedding failed: {e}")))?;
108        Ok(embeddings.pop().unwrap_or_default())
109    }
110
111    fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
112        let mut model = self
113            .model
114            .lock()
115            .map_err(|e| Error::InvalidInput(format!("lock poisoned: {e}")))?;
116        model
117            .embed(texts.to_vec(), None)
118            .map_err(|e| Error::InvalidInput(format!("embedding failed: {e}")))
119    }
120}
121
// No manual `unsafe impl Send/Sync` is needed here: `Mutex<T>` is
// `Send + Sync` whenever `T: Send`, so the auto traits already make
// `HuggingFaceTextVectorizer` thread-safe. Hand-written unsafe impls would
// bypass that compiler check and become unsound if the inner `TextEmbedding`
// ever stopped being `Send`. The `vectorizer_is_send_sync` test below keeps
// the property enforced at compile time.
125
126impl std::fmt::Debug for HuggingFaceTextVectorizer {
127    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128        f.debug_struct("HuggingFaceTextVectorizer")
129            .finish_non_exhaustive()
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config_uses_all_mini_lm() {
        let config = HuggingFaceConfig::default();
        assert!(!config.show_download_progress);
        // `EmbeddingModel` lacks `PartialEq`; inspect the Debug rendering instead.
        let rendered = format!("{:?}", config.model);
        assert!(rendered.contains("AllMiniLML6V2"));
    }

    #[test]
    fn config_builder_chain() {
        let config = HuggingFaceConfig::new(EmbeddingModel::AllMiniLML6V2)
            .with_show_download_progress(true);
        assert!(config.show_download_progress);
    }

    #[test]
    fn vectorizer_is_send_sync() {
        fn require<T: Send + Sync>() {}
        require::<HuggingFaceTextVectorizer>();
    }

    #[test]
    fn debug_impl_does_not_panic() {
        // Constructing a real vectorizer would download a model; verifying
        // that the config's Debug output renders is enough for a unit test.
        let rendered = format!("{:?}", HuggingFaceConfig::default());
        assert!(rendered.contains("HuggingFaceConfig"));
    }
}