rlm_rs/embedding/
fastembed_impl.rs1use crate::Result;
7use crate::embedding::{DEFAULT_DIMENSIONS, Embedder};
8use crate::error::StorageError;
9use std::panic::{AssertUnwindSafe, catch_unwind};
10use std::sync::OnceLock;
11
12static EMBEDDING_MODEL: OnceLock<std::sync::Mutex<fastembed::TextEmbedding>> = OnceLock::new();
15
/// Text embedder backed by fastembed's BGE-M3 model.
///
/// Instances are cheap handles: the model itself lives in a process-wide
/// static that is loaded on the first embedding request and shared by every
/// `FastEmbedEmbedder`, which is why the type can be `Copy`.
#[derive(Debug, Clone, Copy)]
pub struct FastEmbedEmbedder {
    /// Identifier reported by `model_name()`.
    model_name: &'static str,
}
39
40impl FastEmbedEmbedder {
41 #[allow(clippy::missing_const_for_fn)]
49 pub fn new() -> Result<Self> {
50 Ok(Self {
51 model_name: "BGE-M3",
52 })
53 }
54
55 fn get_model() -> Result<&'static std::sync::Mutex<fastembed::TextEmbedding>> {
60 if let Some(model) = EMBEDDING_MODEL.get() {
62 return Ok(model);
63 }
64
65 let options = fastembed::InitOptions::new(fastembed::EmbeddingModel::BGEM3)
67 .with_show_download_progress(false);
68
69 let model = fastembed::TextEmbedding::try_new(options)
70 .map_err(|e| StorageError::Embedding(format!("Failed to load embedding model: {e}")))?;
71
72 let _ = EMBEDDING_MODEL.set(std::sync::Mutex::new(model));
74
75 EMBEDDING_MODEL.get().ok_or_else(|| {
77 StorageError::Embedding("Model initialization race condition".to_string()).into()
78 })
79 }
80
81 #[must_use]
83 pub const fn model_name(&self) -> &'static str {
84 self.model_name
85 }
86}
87
88impl Embedder for FastEmbedEmbedder {
89 fn dimensions(&self) -> usize {
90 DEFAULT_DIMENSIONS
91 }
92
93 fn model_name(&self) -> &'static str {
94 self.model_name
95 }
96
97 fn embed(&self, text: &str) -> Result<Vec<f32>> {
98 if text.is_empty() {
99 return Err(crate::Error::Chunking(
100 crate::error::ChunkingError::InvalidConfig {
101 reason: "Cannot embed empty text".to_string(),
102 },
103 ));
104 }
105
106 let model = Self::get_model()?;
107 let mut model = model
108 .lock()
109 .map_err(|e| StorageError::Embedding(format!("Failed to lock embedding model: {e}")))?;
110
111 let texts = [text];
112
113 let result = catch_unwind(AssertUnwindSafe(|| model.embed(texts, None)));
116
117 let embeddings = result
118 .map_err(|panic_info| {
119 let panic_msg = panic_info
120 .downcast_ref::<&str>()
121 .map(|s| (*s).to_string())
122 .or_else(|| panic_info.downcast_ref::<String>().cloned())
123 .unwrap_or_else(|| "unknown panic".to_string());
124 StorageError::Embedding(format!("ONNX runtime panic: {panic_msg}"))
125 })?
126 .map_err(|e| StorageError::Embedding(format!("Embedding failed: {e}")))?;
127
128 embeddings.into_iter().next().ok_or_else(|| {
129 StorageError::Embedding("No embedding returned from model".to_string()).into()
130 })
131 }
132
133 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
134 if texts.is_empty() {
135 return Ok(Vec::new());
136 }
137
138 if texts.iter().any(|t| t.is_empty()) {
139 return Err(crate::Error::Chunking(
140 crate::error::ChunkingError::InvalidConfig {
141 reason: "Cannot embed empty text".to_string(),
142 },
143 ));
144 }
145
146 let model = Self::get_model()?;
147 let mut model = model
148 .lock()
149 .map_err(|e| StorageError::Embedding(format!("Failed to lock embedding model: {e}")))?;
150
151 let result = catch_unwind(AssertUnwindSafe(|| model.embed(texts, None)));
153
154 result
155 .map_err(|panic_info| {
156 let panic_msg = panic_info
157 .downcast_ref::<&str>()
158 .map(|s| (*s).to_string())
159 .or_else(|| panic_info.downcast_ref::<String>().cloned())
160 .unwrap_or_else(|| "unknown panic".to_string());
161 crate::Error::Storage(StorageError::Embedding(format!(
162 "ONNX runtime panic: {panic_msg}"
163 )))
164 })?
165 .map_err(|e| {
166 crate::Error::Storage(StorageError::Embedding(format!(
167 "Batch embedding failed: {e}"
168 )))
169 })
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `result` failed with the empty-text validation error.
    ///
    /// Matching the exact variant (rather than just `is_err()`) proves the
    /// input was rejected before any model load, which would otherwise also
    /// fail in an offline test environment and mask a regression.
    fn is_empty_text_error<T>(result: &Result<T>) -> bool {
        matches!(
            result,
            Err(crate::Error::Chunking(
                crate::error::ChunkingError::InvalidConfig { .. }
            ))
        )
    }

    #[test]
    fn test_embedder_creation() {
        let embedder = FastEmbedEmbedder::new();
        assert!(embedder.is_ok());
        assert_eq!(embedder.unwrap().dimensions(), DEFAULT_DIMENSIONS);
    }

    #[test]
    fn test_model_name() {
        let embedder = FastEmbedEmbedder::new().unwrap();
        assert_eq!(embedder.model_name(), "BGE-M3");
        // The trait method must agree with the inherent one.
        assert_eq!(Embedder::model_name(&embedder), "BGE-M3");
    }

    #[test]
    #[ignore = "requires fastembed model download"]
    fn test_embed_success() {
        let embedder = FastEmbedEmbedder::new().unwrap();
        let result = embedder.embed("Hello, world!");
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), DEFAULT_DIMENSIONS);
    }

    #[test]
    #[ignore = "requires fastembed model download"]
    fn test_embed_batch_success() {
        let embedder = FastEmbedEmbedder::new().unwrap();
        let texts = vec!["Hello", "World"];
        let result = embedder.embed_batch(&texts);
        assert!(result.is_ok());
        let embeddings = result.unwrap();
        assert_eq!(embeddings.len(), 2);
        assert!(embeddings.iter().all(|e| e.len() == DEFAULT_DIMENSIONS));
    }

    #[test]
    fn test_embed_empty_fails() {
        let embedder = FastEmbedEmbedder::new().unwrap();
        let result = embedder.embed("");
        assert!(is_empty_text_error(&result));
    }

    #[test]
    fn test_embed_batch_empty_list() {
        let embedder = FastEmbedEmbedder::new().unwrap();
        let result = embedder.embed_batch(&[]);
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn test_embed_batch_with_empty_fails() {
        let embedder = FastEmbedEmbedder::new().unwrap();
        let texts = vec!["Valid", "", "Also valid"];
        let result = embedder.embed_batch(&texts);
        assert!(is_empty_text_error(&result));
    }
}