project_rag/embedding/
fastembed_manager.rs

1use super::EmbeddingProvider;
2use anyhow::{Context, Result};
3use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
4use std::sync::RwLock;
5
6/// FastEmbed-based embedding provider using all-MiniLM-L6-v2
7///
8/// Uses RwLock for safe interior mutability since fastembed's embed() requires &mut self.
9pub struct FastEmbedManager {
10    model: RwLock<TextEmbedding>,
11    dimension: usize,
12}
13
14impl FastEmbedManager {
15    /// Create a new FastEmbedManager with the default model (all-MiniLM-L6-v2)
16    pub fn new() -> Result<Self> {
17        Self::with_model(EmbeddingModel::AllMiniLML6V2)
18    }
19
20    /// Create a new FastEmbedManager from a model name string
21    pub fn from_model_name(model_name: &str) -> Result<Self> {
22        let model = match model_name {
23            "all-MiniLM-L6-v2" => EmbeddingModel::AllMiniLML6V2,
24            "all-MiniLM-L12-v2" => EmbeddingModel::AllMiniLML12V2,
25            "BAAI/bge-base-en-v1.5" => EmbeddingModel::BGEBaseENV15,
26            "BAAI/bge-small-en-v1.5" => EmbeddingModel::BGESmallENV15,
27            _ => {
28                tracing::warn!(
29                    "Unknown model '{}', falling back to all-MiniLM-L6-v2",
30                    model_name
31                );
32                EmbeddingModel::AllMiniLML6V2
33            }
34        };
35        Self::with_model(model)
36    }
37
38    /// Create a new FastEmbedManager with a specific model
39    pub fn with_model(model: EmbeddingModel) -> Result<Self> {
40        tracing::info!("Initializing FastEmbed model: {:?}", model);
41
42        // all-MiniLM-L6-v2 has 384 dimensions
43        let dimension = match model {
44            EmbeddingModel::AllMiniLML6V2 => 384,
45            EmbeddingModel::AllMiniLML12V2 => 384,
46            EmbeddingModel::BGEBaseENV15 => 768,
47            EmbeddingModel::BGESmallENV15 => 384,
48            _ => 384, // Default to 384 for unknown models
49        };
50
51        let mut options = InitOptions::default();
52        options.model_name = model;
53        options.show_download_progress = true;
54
55        let embedding_model =
56            TextEmbedding::try_new(options).context("Failed to initialize FastEmbed model")?;
57
58        Ok(Self {
59            model: RwLock::new(embedding_model),
60            dimension,
61        })
62    }
63}
64
65impl EmbeddingProvider for FastEmbedManager {
66    fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
67        if texts.is_empty() {
68            return Ok(vec![]);
69        }
70
71        tracing::debug!("Generating embeddings for {} texts", texts.len());
72
73        // Acquire write lock safely. If the lock is poisoned (due to a panic while holding
74        // the lock), we recover by taking ownership of the inner value.
75        let mut model = self.model.write().unwrap_or_else(|poisoned| {
76            tracing::warn!("FastEmbed model lock was poisoned, recovering...");
77            poisoned.into_inner()
78        });
79
80        // Generate embeddings using the mutable reference
81        // Note: For timeout protection, wrap calls to this method in tokio::time::timeout
82        // at the async call site (e.g., in mcp_server/indexing.rs)
83        let embeddings = model
84            .embed(texts, None)
85            .context("Failed to generate embeddings")?;
86
87        Ok(embeddings)
88    }
89
90    fn dimension(&self) -> usize {
91        self.dimension
92    }
93
94    fn model_name(&self) -> &str {
95        "all-MiniLM-L6-v2"
96    }
97}
98
99impl Default for FastEmbedManager {
100    fn default() -> Self {
101        Self::new().expect("Failed to initialize default FastEmbed model")
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// Name reported by a manager built with the default model.
    const DEFAULT_MODEL_NAME: &str = "all-MiniLM-L6-v2";
    /// Vector width of the default model (shared by the other MiniLM / BGE-small models).
    const SMALL_DIMENSION: usize = 384;
    /// Vector width of BGE-base.
    const BGE_BASE_DIMENSION: usize = 768;

    /// Build a manager for the default model, panicking if initialization fails.
    fn default_manager() -> FastEmbedManager {
        FastEmbedManager::new().unwrap()
    }

    #[test]
    fn test_embedding_generation() {
        let snippets = vec![
            String::from("fn main() { println!(\"Hello, world!\"); }"),
            String::from("pub struct Vector { x: f32, y: f32 }"),
        ];

        let vectors = default_manager().embed_batch(snippets).unwrap();

        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors[0].len(), SMALL_DIMENSION);
        assert_eq!(vectors[1].len(), SMALL_DIMENSION);
    }

    #[test]
    fn test_empty_batch() {
        let vectors = default_manager().embed_batch(Vec::new()).unwrap();
        assert_eq!(vectors.len(), 0);
    }

    #[test]
    fn test_dimension() {
        assert_eq!(default_manager().dimension(), SMALL_DIMENSION);
    }

    #[test]
    fn test_model_name() {
        assert_eq!(default_manager().model_name(), DEFAULT_MODEL_NAME);
    }

    #[test]
    fn test_default() {
        let manager = FastEmbedManager::default();
        assert_eq!(manager.dimension(), SMALL_DIMENSION);
        assert_eq!(manager.model_name(), DEFAULT_MODEL_NAME);
    }

    #[test]
    fn test_single_text() {
        let input = vec![String::from("Hello world")];

        let vectors = default_manager().embed_batch(input).unwrap();

        assert_eq!(vectors.len(), 1);
        assert_eq!(vectors[0].len(), SMALL_DIMENSION);
    }

    #[test]
    fn test_large_batch() {
        let mut inputs: Vec<String> = Vec::with_capacity(10);
        for i in 0..10 {
            inputs.push(format!("Test text {}", i));
        }

        let vectors = default_manager().embed_batch(inputs).unwrap();

        assert_eq!(vectors.len(), 10);
        for vector in &vectors {
            assert_eq!(vector.len(), SMALL_DIMENSION);
        }
    }

    #[test]
    fn test_with_model_allminilm_l12() {
        let manager = FastEmbedManager::with_model(EmbeddingModel::AllMiniLML12V2).unwrap();
        assert_eq!(manager.dimension(), SMALL_DIMENSION);
    }

    #[test]
    fn test_with_model_bge_base() {
        let manager = FastEmbedManager::with_model(EmbeddingModel::BGEBaseENV15).unwrap();
        assert_eq!(manager.dimension(), BGE_BASE_DIMENSION);
    }

    #[test]
    fn test_with_model_bge_small() {
        let manager = FastEmbedManager::with_model(EmbeddingModel::BGESmallENV15).unwrap();
        assert_eq!(manager.dimension(), SMALL_DIMENSION);
    }
}