redis_vl/vectorizers/
hf_local.rs1use std::sync::Mutex;
18
19use fastembed::{EmbeddingModel, TextEmbedding};
20
21use super::Vectorizer;
22use crate::error::{Error, Result};
23
24#[derive(Debug, Clone)]
26pub struct HuggingFaceConfig {
27 pub model: EmbeddingModel,
31 pub show_download_progress: bool,
35}
36
37impl Default for HuggingFaceConfig {
38 fn default() -> Self {
39 Self {
40 model: EmbeddingModel::AllMiniLML6V2,
41 show_download_progress: false,
42 }
43 }
44}
45
46impl HuggingFaceConfig {
47 #[must_use]
49 pub fn new(model: EmbeddingModel) -> Self {
50 Self {
51 model,
52 show_download_progress: false,
53 }
54 }
55
56 #[must_use]
58 pub fn with_show_download_progress(mut self, show: bool) -> Self {
59 self.show_download_progress = show;
60 self
61 }
62}
63
64pub struct HuggingFaceTextVectorizer {
75 model: Mutex<TextEmbedding>,
76}
77
78impl HuggingFaceTextVectorizer {
79 pub fn new(config: HuggingFaceConfig) -> Result<Self> {
87 let init_options = fastembed::InitOptions::new(config.model)
88 .with_show_download_progress(config.show_download_progress);
89
90 let model = TextEmbedding::try_new(init_options)
91 .map_err(|e| Error::InvalidInput(format!("failed to load HF model: {e}")))?;
92
93 Ok(Self {
94 model: Mutex::new(model),
95 })
96 }
97}
98
99impl Vectorizer for HuggingFaceTextVectorizer {
100 fn embed(&self, text: &str) -> Result<Vec<f32>> {
101 let mut model = self
102 .model
103 .lock()
104 .map_err(|e| Error::InvalidInput(format!("lock poisoned: {e}")))?;
105 let mut embeddings = model
106 .embed(vec![text], None)
107 .map_err(|e| Error::InvalidInput(format!("embedding failed: {e}")))?;
108 Ok(embeddings.pop().unwrap_or_default())
109 }
110
111 fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
112 let mut model = self
113 .model
114 .lock()
115 .map_err(|e| Error::InvalidInput(format!("lock poisoned: {e}")))?;
116 model
117 .embed(texts.to_vec(), None)
118 .map_err(|e| Error::InvalidInput(format!("embedding failed: {e}")))
119 }
120}
121
122unsafe impl Send for HuggingFaceTextVectorizer {}
124unsafe impl Sync for HuggingFaceTextVectorizer {}
125
126impl std::fmt::Debug for HuggingFaceTextVectorizer {
127 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128 f.debug_struct("HuggingFaceTextVectorizer")
129 .finish_non_exhaustive()
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration keeps progress output off and selects the
    /// `AllMiniLML6V2` model.
    #[test]
    fn default_config_uses_all_mini_lm() {
        let cfg = HuggingFaceConfig::default();
        assert!(!cfg.show_download_progress);

        // Compare via the `Debug` rendering of the model variant.
        let model_name = format!("{:?}", cfg.model);
        assert!(model_name.contains("AllMiniLML6V2"));
    }

    /// `new` followed by the builder method enables download progress.
    #[test]
    fn config_builder_chain() {
        let base = HuggingFaceConfig::new(EmbeddingModel::AllMiniLML6V2);
        let cfg = base.with_show_download_progress(true);
        assert!(cfg.show_download_progress);
    }

    /// Compile-time check: this test only builds if the vectorizer is both
    /// `Send` and `Sync`.
    #[test]
    fn vectorizer_is_send_sync() {
        fn assert_send_sync<T>()
        where
            T: Send + Sync,
        {
        }

        assert_send_sync::<HuggingFaceTextVectorizer>();
    }

    /// Formats the configuration with `Debug`. Only the config is exercised
    /// here; building a vectorizer would require loading a model through
    /// `TextEmbedding::try_new`.
    #[test]
    fn debug_impl_does_not_panic() {
        let cfg = HuggingFaceConfig::default();
        let rendered = format!("{cfg:?}");
        assert!(rendered.contains("HuggingFaceConfig"));
    }
}