Skip to main content

rlm_rs/embedding/
fastembed_impl.rs

1//! `FastEmbed`-based semantic embedder.
2//!
3//! Provides real semantic embeddings using the BGE-M3 model via fastembed-rs.
4//! Only available when the `fastembed-embeddings` feature is enabled.
5
6use crate::Result;
7use crate::embedding::{DEFAULT_DIMENSIONS, Embedder};
8use crate::error::StorageError;
9use std::panic::{AssertUnwindSafe, catch_unwind};
10use std::sync::OnceLock;
11
/// Process-wide singleton holding the embedding model.
///
/// The `OnceLock` starts empty and is filled by `FastEmbedEmbedder::get_model`
/// the first time a model is requested, so nothing is loaded until an embed
/// call actually needs it. The inner `Mutex` serializes access to the single
/// shared instance; the embed calls in this file go through a mutable guard.
static EMBEDDING_MODEL: OnceLock<std::sync::Mutex<fastembed::TextEmbedding>> = OnceLock::new();
15
/// `FastEmbed` embedder using BGE-M3.
///
/// Uses the fastembed-rs library for real semantic embeddings.
/// The value itself is a lightweight handle: the model lives in a
/// process-wide singleton and is lazily loaded on the first embed call to
/// preserve cold start time.
///
/// BGE-M3 provides:
/// - 1024 dimensions (vs 384 for `MiniLM`)
/// - 8192 token context (vs ~512 for `MiniLM`)
/// - Better multilingual support
///
/// # Examples
///
/// ```ignore
/// // `embed` is a method of the `Embedder` trait, so the trait must be in scope.
/// use rlm_rs::embedding::{Embedder, FastEmbedEmbedder};
///
/// let embedder = FastEmbedEmbedder::new()?;
/// let embedding = embedder.embed("Hello, world!")?;
/// assert_eq!(embedding.len(), 1024);
/// ```
#[derive(Debug, Clone)]
pub struct FastEmbedEmbedder {
    /// Model name for debugging.
    model_name: &'static str,
}
39
40impl FastEmbedEmbedder {
41    /// Creates a new `FastEmbed` embedder.
42    ///
43    /// Note: Model is lazily loaded on first `embed()` call.
44    ///
45    /// # Errors
46    ///
47    /// Returns an error if model initialization fails.
48    #[allow(clippy::missing_const_for_fn)]
49    pub fn new() -> Result<Self> {
50        Ok(Self {
51            model_name: "BGE-M3",
52        })
53    }
54
55    /// Gets or initializes the embedding model (thread-safe).
56    ///
57    /// The model is loaded lazily on first use to preserve cold start time.
58    /// Subsequent calls return the cached instance.
59    fn get_model() -> Result<&'static std::sync::Mutex<fastembed::TextEmbedding>> {
60        // Check if already initialized
61        if let Some(model) = EMBEDDING_MODEL.get() {
62            return Ok(model);
63        }
64
65        // Initialize the model
66        let options = fastembed::InitOptions::new(fastembed::EmbeddingModel::BGEM3)
67            .with_show_download_progress(false);
68
69        let model = fastembed::TextEmbedding::try_new(options)
70            .map_err(|e| StorageError::Embedding(format!("Failed to load embedding model: {e}")))?;
71
72        // Store the model, ignoring if another thread beat us to it
73        let _ = EMBEDDING_MODEL.set(std::sync::Mutex::new(model));
74
75        // Return the (possibly other thread's) model
76        EMBEDDING_MODEL.get().ok_or_else(|| {
77            StorageError::Embedding("Model initialization race condition".to_string()).into()
78        })
79    }
80
81    /// Returns the model name.
82    #[must_use]
83    pub const fn model_name(&self) -> &'static str {
84        self.model_name
85    }
86}
87
88impl Embedder for FastEmbedEmbedder {
89    fn dimensions(&self) -> usize {
90        DEFAULT_DIMENSIONS
91    }
92
93    fn model_name(&self) -> &'static str {
94        self.model_name
95    }
96
97    fn embed(&self, text: &str) -> Result<Vec<f32>> {
98        if text.is_empty() {
99            return Err(crate::Error::Chunking(
100                crate::error::ChunkingError::InvalidConfig {
101                    reason: "Cannot embed empty text".to_string(),
102                },
103            ));
104        }
105
106        let model = Self::get_model()?;
107        let mut model = model
108            .lock()
109            .map_err(|e| StorageError::Embedding(format!("Failed to lock embedding model: {e}")))?;
110
111        let texts = [text];
112
113        // Wrap ONNX runtime call in catch_unwind for graceful degradation.
114        // ONNX runtime can panic on malformed inputs or internal errors.
115        let result = catch_unwind(AssertUnwindSafe(|| model.embed(texts, None)));
116
117        let embeddings = result
118            .map_err(|panic_info| {
119                let panic_msg = panic_info
120                    .downcast_ref::<&str>()
121                    .map(|s| (*s).to_string())
122                    .or_else(|| panic_info.downcast_ref::<String>().cloned())
123                    .unwrap_or_else(|| "unknown panic".to_string());
124                StorageError::Embedding(format!("ONNX runtime panic: {panic_msg}"))
125            })?
126            .map_err(|e| StorageError::Embedding(format!("Embedding failed: {e}")))?;
127
128        embeddings.into_iter().next().ok_or_else(|| {
129            StorageError::Embedding("No embedding returned from model".to_string()).into()
130        })
131    }
132
133    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
134        if texts.is_empty() {
135            return Ok(Vec::new());
136        }
137
138        if texts.iter().any(|t| t.is_empty()) {
139            return Err(crate::Error::Chunking(
140                crate::error::ChunkingError::InvalidConfig {
141                    reason: "Cannot embed empty text".to_string(),
142                },
143            ));
144        }
145
146        let model = Self::get_model()?;
147        let mut model = model
148            .lock()
149            .map_err(|e| StorageError::Embedding(format!("Failed to lock embedding model: {e}")))?;
150
151        // Wrap ONNX runtime call in catch_unwind for graceful degradation.
152        let result = catch_unwind(AssertUnwindSafe(|| model.embed(texts, None)));
153
154        result
155            .map_err(|panic_info| {
156                let panic_msg = panic_info
157                    .downcast_ref::<&str>()
158                    .map(|s| (*s).to_string())
159                    .or_else(|| panic_info.downcast_ref::<String>().cloned())
160                    .unwrap_or_else(|| "unknown panic".to_string());
161                crate::Error::Storage(StorageError::Embedding(format!(
162                    "ONNX runtime panic: {panic_msg}"
163                )))
164            })?
165            .map_err(|e| {
166                crate::Error::Storage(StorageError::Embedding(format!(
167                    "Batch embedding failed: {e}"
168                )))
169            })
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an embedder; construction does not load the model.
    fn make_embedder() -> FastEmbedEmbedder {
        FastEmbedEmbedder::new().unwrap()
    }

    #[test]
    fn test_embedder_creation() {
        let created = FastEmbedEmbedder::new();
        assert!(created.is_ok());

        let embedder = created.unwrap();
        assert_eq!(embedder.dimensions(), DEFAULT_DIMENSIONS);
    }

    #[test]
    fn test_model_name() {
        assert_eq!(make_embedder().model_name(), "BGE-M3");
    }

    // The two tests below need the real model, so they are #[ignore]d.
    // Run with: cargo test --features fastembed-embeddings -- --ignored

    #[test]
    #[ignore = "requires fastembed model download"]
    fn test_embed_success() {
        let outcome = make_embedder().embed("Hello, world!");
        assert!(outcome.is_ok());

        let vector = outcome.unwrap();
        assert_eq!(vector.len(), DEFAULT_DIMENSIONS);
    }

    #[test]
    #[ignore = "requires fastembed model download"]
    fn test_embed_batch_success() {
        let inputs = ["Hello", "World"];
        let outcome = make_embedder().embed_batch(&inputs);
        assert!(outcome.is_ok());

        let vectors = outcome.unwrap();
        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors[0].len(), DEFAULT_DIMENSIONS);
    }

    // The remaining cases are rejected or answered before the model is
    // requested, so they run without a download.

    #[test]
    fn test_embed_empty_fails() {
        assert!(make_embedder().embed("").is_err());
    }

    #[test]
    fn test_embed_batch_empty_list() {
        let outcome = make_embedder().embed_batch(&[]);
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().is_empty());
    }

    #[test]
    fn test_embed_batch_with_empty_fails() {
        let inputs = ["Valid", "", "Also valid"];
        assert!(make_embedder().embed_batch(&inputs).is_err());
    }
}