blazen_embed_fastembed/
provider.rs1use std::fmt;
4use std::sync::{Arc, Mutex};
5
6use crate::FastEmbedOptions;
7
/// Errors reported by the fastembed provider.
///
/// Every variant carries a human-readable detail string that is appended to a
/// fixed, variant-specific prefix when the error is displayed.
#[derive(Debug)]
pub enum FastEmbedError {
    /// The configured model name could not be parsed into a fastembed model.
    UnknownModel(String),
    /// Looking up model info or constructing the embedding model failed.
    Init(String),
    /// The embedding call itself returned an error.
    Embed(String),
    /// The mutex guarding the model was poisoned.
    MutexPoisoned(String),
    /// The blocking task running the embedding could not be joined.
    TaskPanicked(String),
}

impl fmt::Display for FastEmbedError {
    /// Writes the variant-specific prefix followed by the detail message.
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            FastEmbedError::UnknownModel(msg) => {
                out.write_fmt(format_args!("unknown fastembed model: {msg}"))
            }
            FastEmbedError::Init(msg) => {
                out.write_fmt(format_args!("fastembed init failed: {msg}"))
            }
            FastEmbedError::Embed(msg) => {
                out.write_fmt(format_args!("fastembed embed failed: {msg}"))
            }
            FastEmbedError::MutexPoisoned(msg) => {
                out.write_fmt(format_args!("fastembed mutex poisoned: {msg}"))
            }
            FastEmbedError::TaskPanicked(msg) => {
                out.write_fmt(format_args!("fastembed blocking task panicked: {msg}"))
            }
        }
    }
}

// No underlying source error is retained; only the message strings are kept.
impl std::error::Error for FastEmbedError {}
36
/// Result of an embedding request.
#[derive(Clone, Debug)]
pub struct FastEmbedResponse {
    /// Embedding vectors returned for the request (empty when no texts were
    /// supplied).
    pub embeddings: Vec<Vec<f32>>,
    /// Identifier of the model that handled the request.
    pub model: String,
}
45
46pub struct FastEmbedModel {
53 model: Arc<Mutex<fastembed::TextEmbedding>>,
57 model_id: String,
59 dims: usize,
61 batch_size: Option<usize>,
63}
64
65impl FastEmbedModel {
70 pub fn from_options(opts: FastEmbedOptions) -> Result<Self, FastEmbedError> {
82 let fe_model = if let Some(ref name) = opts.model_name {
84 name.parse::<fastembed::EmbeddingModel>()
85 .map_err(|e| FastEmbedError::UnknownModel(format!("\"{name}\": {e}")))?
86 } else {
87 fastembed::EmbeddingModel::default()
88 };
89
90 let model_info =
92 <fastembed::EmbeddingModel as fastembed::ModelTrait>::get_model_info(&fe_model)
93 .ok_or_else(|| {
94 FastEmbedError::Init(format!("no model info found for {fe_model:?}"))
95 })?;
96 let dims = model_info.dim;
97 let model_code = model_info.model_code.clone();
98
99 let mut init_opts = fastembed::TextInitOptions::new(fe_model);
101 if let Some(cache_dir) = opts.cache_dir {
102 init_opts = init_opts.with_cache_dir(cache_dir);
103 }
104 if let Some(show) = opts.show_download_progress {
105 init_opts = init_opts.with_show_download_progress(show);
106 }
107
108 let te = fastembed::TextEmbedding::try_new(init_opts)
109 .map_err(|e| FastEmbedError::Init(e.to_string()))?;
110
111 Ok(Self {
112 model: Arc::new(Mutex::new(te)),
113 model_id: model_code,
114 dims,
115 batch_size: opts.max_batch_size,
116 })
117 }
118
119 #[must_use]
121 pub fn model_id(&self) -> &str {
122 &self.model_id
123 }
124
125 #[must_use]
127 pub fn dimensions(&self) -> usize {
128 self.dims
129 }
130
131 pub async fn embed(&self, texts: &[String]) -> Result<FastEmbedResponse, FastEmbedError> {
142 if texts.is_empty() {
143 return Ok(FastEmbedResponse {
144 embeddings: vec![],
145 model: self.model_id.clone(),
146 });
147 }
148
149 let texts_owned: Vec<String> = texts.to_vec();
152 let batch_size = self.batch_size;
153 let model_id = self.model_id.clone();
154 let model_handle = Arc::clone(&self.model);
155
156 let embeddings = tokio::task::spawn_blocking(move || {
157 let mut model = model_handle
158 .lock()
159 .map_err(|e| FastEmbedError::MutexPoisoned(e.to_string()))?;
160 let result: Vec<Vec<f32>> = model
161 .embed(&texts_owned, batch_size)
162 .map_err(|e| FastEmbedError::Embed(e.to_string()))?;
163 Ok::<Vec<Vec<f32>>, FastEmbedError>(result)
164 })
165 .await
166 .map_err(|e| FastEmbedError::TaskPanicked(e.to_string()))??;
167
168 Ok(FastEmbedResponse {
169 embeddings,
170 model: model_id,
171 })
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Default options should produce a model with non-trivial metadata.
    #[test]
    #[ignore = "requires model download from HuggingFace"]
    fn from_options_default_loads_model() {
        let loaded = FastEmbedModel::from_options(FastEmbedOptions::default())
            .expect("should create model with default options");
        assert!(loaded.dimensions() > 0);
        assert!(!loaded.model_id().is_empty());
    }

    /// Two inputs should yield two vectors of the advertised dimension.
    #[tokio::test]
    #[ignore = "requires model download from HuggingFace"]
    async fn embed_returns_correct_count() {
        let loaded = FastEmbedModel::from_options(FastEmbedOptions::default())
            .expect("should create model with default options");
        let inputs: [String; 2] = ["hello".into(), "world".into()];
        let outcome = loaded
            .embed(&inputs)
            .await
            .expect("embedding should succeed");
        assert_eq!(outcome.embeddings.len(), 2);
        assert!(!outcome.embeddings[0].is_empty());
        assert_eq!(outcome.embeddings[0].len(), loaded.dimensions());
    }

    /// Empty input should succeed with no embeddings; the test is skipped
    /// when the model cannot be constructed.
    #[tokio::test]
    async fn embed_empty_input_returns_empty() {
        let loaded = match FastEmbedModel::from_options(FastEmbedOptions::default()) {
            Ok(loaded) => loaded,
            Err(_) => {
                eprintln!("skipping embed_empty_input_returns_empty: model not available");
                return;
            }
        };
        let no_texts: &[String] = &[];
        let outcome = loaded
            .embed(no_texts)
            .await
            .expect("empty embed should succeed");
        assert!(outcome.embeddings.is_empty());
    }
}