// do_memory_core/embeddings/local.rs
1//! Local embedding provider using sentence transformers
2//!
3//! This provider runs embedding models locally using candle-transformers,
4//! providing offline capability with no external API dependencies.
5
6use super::config::LocalConfig;
7use super::provider::EmbeddingProvider;
8use anyhow::{Context, Result};
9use async_trait::async_trait;
10use std::sync::Arc;
11use tokio::sync::RwLock;
12
/// Local embedding provider using sentence transformers
///
/// Runs embedding models locally using candle-transformers or similar.
/// Provides offline embedding generation with no external dependencies.
///
/// The model slot starts empty and is populated by `load_model` during
/// construction; when the `local-embeddings` feature is disabled (or real
/// model loading fails) a mock implementation is installed instead, so a
/// successfully constructed provider is always usable but may not produce
/// semantically meaningful embeddings.
///
/// # Models Supported
/// - sentence-transformers/all-MiniLM-L6-v2 (384 dims, default)
/// - sentence-transformers/all-mpnet-base-v2 (768 dims, higher quality)
/// - sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 (384 dims, multilingual)
///
/// # Example
/// ```no_run
/// use do_memory_core::embeddings::{EmbeddingProvider, LocalEmbeddingProvider, LocalConfig};
///
/// #[tokio::main]
/// async fn main() -> anyhow::Result<()> {
///     let config = LocalConfig::new(
///         "sentence-transformers/all-MiniLM-L6-v2",
///         384
///     );
///     let provider = LocalEmbeddingProvider::new(config).await?;
///
///     let embedding = provider.embed_text("Hello world").await?;
///     println!("Generated embedding with {} dimensions", embedding.len());
///     Ok(())
/// }
/// ```
pub struct LocalEmbeddingProvider {
    /// Model configuration (model name and embedding dimension)
    config: LocalConfig,
    /// Embedding model backend, populated by `load_model`; `Arc<RwLock<..>>`
    /// so the (possibly lazily-replaced) model can be shared across tasks
    model: Arc<RwLock<Option<Box<dyn LocalEmbeddingModel>>>>,
    /// Directory where downloaded model files are cached
    /// (`~/.cache/memory-core/embeddings`)
    cache_dir: std::path::PathBuf,
}
48
impl LocalEmbeddingProvider {
    /// Create a new local embedding provider
    ///
    /// Resolves the on-disk model cache directory, constructs the provider
    /// with an empty model slot, then eagerly loads the model so callers get
    /// a ready-to-use provider (or an error) out of `new`.
    ///
    /// # Arguments
    /// * `config` - Model configuration specifying which model to use
    ///
    /// # Errors
    /// Returns an error if the home directory cannot be determined, the cache
    /// directory cannot be created, or model loading fails outright. Note
    /// that with the `local-embeddings` feature a failed real-model load
    /// falls back to a mock rather than erroring.
    pub async fn new(config: LocalConfig) -> Result<Self> {
        let cache_dir = Self::get_cache_dir()?;

        let provider = Self {
            config,
            model: Arc::new(RwLock::new(None)),
            cache_dir,
        };

        // Initialize/load the model eagerly so failures surface here, not on
        // the first embed call.
        provider.load_model().await?;

        Ok(provider)
    }

    /// Load the embedding model into `self.model`
    ///
    /// With the `local-embeddings` feature: tries to load a real ONNX model
    /// from the cache; on failure it installs a mock-backed fallback and logs
    /// a warning. Without the feature: always installs a mock model and warns
    /// that semantic search will not work correctly.
    async fn load_model(&self) -> Result<()> {
        tracing::info!("Loading local embedding model: {}", self.config.model_name);

        #[cfg(feature = "local-embeddings")]
        {
            // Try to load real ONNX model, fallback to mock if fails
            match self.try_load_real_model().await {
                Ok(real_model) => {
                    // Wrap the real model so runtime failures can still fall
                    // back to the mock path inside the wrapper.
                    let fallback_model = Box::new(RealEmbeddingModelWithFallback::new(
                        self.config.model_name.clone(),
                        self.config.embedding_dimension,
                        Some(real_model),
                    ));

                    let mut model_guard = self.model.write().await;
                    *model_guard = Some(fallback_model);

                    tracing::info!("Local embedding model loaded with real ONNX backend");
                }
                Err(e) => {
                    tracing::warn!("Failed to load real embedding model: {}", e);
                    tracing::warn!(
                        "Falling back to mock embeddings - semantic search will not work correctly"
                    );

                    // Same wrapper type, but with no real model inside: every
                    // call goes through the mock implementation.
                    let mock_fallback = Box::new(RealEmbeddingModelWithFallback::new(
                        self.config.model_name.clone(),
                        self.config.embedding_dimension,
                        None,
                    ));

                    let mut model_guard = self.model.write().await;
                    *model_guard = Some(mock_fallback);

                    tracing::info!("Local embedding model loaded with mock fallback");
                }
            }
        }

        #[cfg(not(feature = "local-embeddings"))]
        {
            tracing::warn!(
                "PRODUCTION WARNING: Using mock embeddings - semantic search will not work correctly"
            );
            tracing::warn!(
                "To enable real embeddings, add 'local-embeddings' feature and ensure ONNX models are available"
            );

            let mock_fallback = Box::new(super::mock_model::MockLocalModel::new(
                self.config.model_name.clone(),
                self.config.embedding_dimension,
            ));

            let mut model_guard = self.model.write().await;
            *model_guard = Some(mock_fallback);

            tracing::info!("Local embedding model loaded with mock implementation");
        }

        Ok(())
    }

    /// Try to load a real ONNX model from the local cache directory
    ///
    /// Only compiled when the `local-embeddings` feature is enabled; the
    /// actual loading logic lives in `RealEmbeddingModel`.
    #[cfg(feature = "local-embeddings")]
    async fn try_load_real_model(&self) -> Result<RealEmbeddingModel> {
        RealEmbeddingModel::try_load_from_cache(&self.config, &self.cache_dir).await
    }

    /// Get (and create if necessary) the cache directory for models
    ///
    /// Resolves `$HOME` (or `%USERPROFILE%` on Windows) and ensures
    /// `~/.cache/memory-core/embeddings` exists.
    ///
    /// # Errors
    /// Fails if neither environment variable is set or the directory
    /// cannot be created.
    fn get_cache_dir() -> Result<std::path::PathBuf> {
        let home = std::env::var("HOME")
            .or_else(|_| std::env::var("USERPROFILE"))
            .context("Could not determine home directory")?;

        let cache_dir = std::path::Path::new(&home)
            .join(".cache")
            .join("memory-core")
            .join("embeddings");

        std::fs::create_dir_all(&cache_dir).context("Failed to create cache directory")?;

        Ok(cache_dir)
    }

    /// Check if a model (real or mock) is currently loaded
    pub async fn is_loaded(&self) -> bool {
        let model_guard = self.model.read().await;
        model_guard.is_some()
    }

    /// Get model information as JSON (name, dimension, type, cache dir)
    #[must_use]
    pub fn model_info(&self) -> serde_json::Value {
        serde_json::json!({
            "name": self.config.model_name,
            "dimension": self.config.embedding_dimension,
            "type": "local",
            "cache_dir": self.cache_dir,
        })
    }
}
174
175#[async_trait]
176impl EmbeddingProvider for LocalEmbeddingProvider {
177    async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
178        let model_guard = self.model.read().await;
179        let model = model_guard.as_ref().context("Model not loaded")?;
180
181        model.embed(text).await
182    }
183
184    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
185        let model_guard = self.model.read().await;
186        let model = model_guard.as_ref().context("Model not loaded")?;
187
188        model.embed_batch(texts).await
189    }
190
191    fn embedding_dimension(&self) -> usize {
192        self.config.embedding_dimension
193    }
194
195    fn model_name(&self) -> &str {
196        &self.config.model_name
197    }
198
199    async fn is_available(&self) -> bool {
200        self.is_loaded().await
201    }
202
203    async fn warmup(&self) -> Result<()> {
204        // Test embedding generation
205        let _embedding = self.embed_text("warmup test").await?;
206        Ok(())
207    }
208
209    fn metadata(&self) -> serde_json::Value {
210        serde_json::json!({
211            "model": self.model_name(),
212            "dimension": self.embedding_dimension(),
213            "type": "local",
214            "provider": "sentence-transformers",
215            "cache_dir": self.cache_dir
216        })
217    }
218}
219
/// Trait for local embedding models
///
/// Object-safe abstraction over the backend that actually produces
/// embeddings, so the provider can hold either a real ONNX-backed model or a
/// mock behind the same interface. Implementors must be `Send + Sync`
/// because the provider shares the boxed model across async tasks via
/// `Arc<RwLock<...>>`.
#[async_trait]
pub trait LocalEmbeddingModel: Send + Sync {
    /// Generate an embedding vector for a single text
    async fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Generate embedding vectors for a batch of texts, one per input,
    /// preserving input order
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;

    /// Get the model name
    // Currently unused by the provider itself; kept for introspection.
    #[allow(dead_code)]
    fn name(&self) -> &str;

    /// Get the embedding dimension this model produces
    // Currently unused by the provider itself; kept for introspection.
    #[allow(dead_code)]
    fn dimension(&self) -> usize;
}
237
238/// Import real model implementation
239#[cfg(feature = "local-embeddings")]
240#[allow(unused)]
241pub use crate::embeddings::real_model::RealEmbeddingModel;
242
243/// Import mock model implementations
244#[cfg(feature = "local-embeddings")]
245#[allow(unused)]
246pub use crate::embeddings::mock_model::{MockLocalModel, RealEmbeddingModelWithFallback};
247
248/// Re-export utilities from the utils module
249#[allow(unused)]
250pub use crate::embeddings::utils::{
251    LocalModelUseCase, get_recommended_model, list_available_models,
252};
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction should succeed and report config-derived metadata.
    #[tokio::test]
    async fn test_local_provider_creation() {
        let config = LocalConfig::new("test-model", 384);

        let provider = LocalEmbeddingProvider::new(config).await.unwrap();
        assert!(provider.is_loaded().await);
        assert_eq!(provider.embedding_dimension(), 384);
        assert_eq!(provider.model_name(), "test-model");
    }

    /// Embeddings must be the right size, deterministic, and text-dependent.
    #[tokio::test]
    async fn test_embed_text() {
        let config = LocalConfig::new("test-model", 384);

        let provider = LocalEmbeddingProvider::new(config).await.unwrap();

        let embedding = provider.embed_text("Hello world").await.unwrap();
        assert_eq!(embedding.len(), 384);

        // Test deterministic behavior
        let embedding2 = provider.embed_text("Hello world").await.unwrap();
        assert_eq!(embedding, embedding2);

        // Different text should produce different embedding
        let embedding3 = provider.embed_text("Different text").await.unwrap();
        assert_ne!(embedding, embedding3);
    }

    /// Batch embedding returns one correctly-sized vector per input.
    #[tokio::test]
    async fn test_embed_batch() {
        let config = LocalConfig::new("test-model", 384);

        let provider = LocalEmbeddingProvider::new(config).await.unwrap();

        let texts = vec![
            "First text".to_string(),
            "Second text".to_string(),
            "Third text".to_string(),
        ];

        let embeddings = provider.embed_batch(&texts).await.unwrap();
        assert_eq!(embeddings.len(), 3);

        for embedding in embeddings {
            assert_eq!(embedding.len(), 384);
        }
    }

    /// Cosine similarity: identical texts ~1.0, differing texts strictly lower.
    #[tokio::test]
    async fn test_similarity_calculation() {
        let config = LocalConfig::new("test-model", 384);

        let provider = LocalEmbeddingProvider::new(config).await.unwrap();

        // Identical texts should have high similarity
        let similarity = provider
            .similarity("Hello world", "Hello world")
            .await
            .unwrap();
        assert!((similarity - 1.0).abs() < 0.001);

        // Different texts should have lower similarity
        let similarity = provider
            .similarity("Hello world", "Goodbye universe")
            .await
            .unwrap();
        assert!(similarity < 1.0);
    }

    /// Semantic-quality test against a real ONNX model; only meaningful when
    /// model files are actually present on disk.
    #[tokio::test]
    #[ignore = "Requires local-embeddings feature with ONNX models - blocked by ort crate Send trait issue"]
    #[cfg(feature = "local-embeddings")]
    async fn test_real_embedding_generation() {
        // This test only runs when local-embeddings feature is enabled
        // and real ONNX models are available

        // Create a temporary directory for model cache
        let temp_dir = tempfile::TempDir::new().unwrap();
        let cache_path = temp_dir.path().join("models");

        // Try to load a real model if available.
        // FIX: previously this skipped when the model path *did* exist,
        // which contradicted the skip message; skip when model files are
        // absent or when running in CI (where models are not downloaded).
        if !cache_path.exists() || std::env::var("CI").is_ok() {
            tracing::info!("Skipping real embedding test - no model files available");
            return;
        }

        let config = LocalConfig::new("sentence-transformers/all-MiniLM-L6-v2", 384);

        let provider = LocalEmbeddingProvider::new(config).await.unwrap();

        // Generate embeddings for semantically similar texts
        let embedding1 = provider
            .embed_text("machine learning algorithms")
            .await
            .unwrap();
        let embedding2 = provider
            .embed_text("artificial intelligence models")
            .await
            .unwrap();
        let embedding3 = provider
            .embed_text("cooking recipes for pasta")
            .await
            .unwrap();

        assert_eq!(embedding1.len(), 384);
        assert_eq!(embedding2.len(), 384);
        assert_eq!(embedding3.len(), 384);

        // Calculate similarities
        let similarity_ai_ml = provider
            .similarity("machine learning", "artificial intelligence")
            .await
            .unwrap();
        let similarity_cooking = provider
            .similarity("machine learning", "cooking recipes")
            .await
            .unwrap();

        // Semantically similar texts should have higher similarity
        assert!(
            similarity_ai_ml > similarity_cooking,
            "AI/ML similarity ({similarity_ai_ml}) should be higher than ML/cooking ({similarity_cooking})"
        );

        // Both should be positive (cosine similarity range)
        assert!(similarity_ai_ml > 0.0);
        assert!(similarity_cooking > 0.0);
    }

    /// Embeddings are unit-normalized with components in [-1, 1].
    #[tokio::test]
    async fn test_embedding_vector_properties() {
        let config = LocalConfig::new("test-model", 384);

        let provider = LocalEmbeddingProvider::new(config).await.unwrap();

        let embedding = provider.embed_text("test text").await.unwrap();

        // Check that embedding is properly normalized (unit vector)
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001, "Embedding should be normalized");

        // Check that values are in reasonable range
        for &value in &embedding {
            assert!(
                (-1.0..=1.0).contains(&value),
                "Embedding values should be in [-1, 1]"
            );
        }
    }

    /// metadata() and model_info() must reflect the configured model.
    #[tokio::test]
    async fn test_model_metadata() {
        let config = LocalConfig::new("sentence-transformers/test-model", 768);

        let provider = LocalEmbeddingProvider::new(config).await.unwrap();

        let metadata = provider.metadata();
        assert_eq!(metadata["model"], "sentence-transformers/test-model");
        assert_eq!(metadata["dimension"], 768);
        assert_eq!(metadata["type"], "local");

        let model_info = provider.model_info();
        assert_eq!(model_info["name"], "sentence-transformers/test-model");
        assert_eq!(model_info["dimension"], 768);
        assert_eq!(model_info["type"], "local");
    }

    /// A nonexistent model must either fall back to a working mock or fail
    /// with an actionable error message.
    #[tokio::test]
    async fn test_error_handling() {
        let config = LocalConfig::new("nonexistent-model", 384);

        // Test with non-existent model - should fall back to mock or fail gracefully
        let result = LocalEmbeddingProvider::new(config).await;

        match result {
            Ok(provider) => {
                // If successful, it should be a mock implementation
                assert!(provider.is_loaded().await);
                let embedding = provider.embed_text("test").await.unwrap();
                assert_eq!(embedding.len(), 384);
            }
            Err(e) => {
                // Should provide meaningful error message
                assert!(e.to_string().contains("model") || e.to_string().contains("load"));
            }
        }
    }

    /// warmup() is a thin wrapper over embed_text and must succeed.
    #[tokio::test]
    async fn test_warmup_functionality() {
        let config = LocalConfig::new("test-model", 384);

        let provider = LocalEmbeddingProvider::new(config).await.unwrap();

        // Warmup should succeed
        let result = provider.warmup().await;
        assert!(result.is_ok(), "Warmup should succeed");
    }

    /// The model catalog must be non-empty with sane entries.
    #[test]
    fn test_utils_list_models() {
        let models = list_available_models();
        assert!(!models.is_empty());

        for model in models {
            assert!(!model.model_name.is_empty());
            assert!(model.embedding_dimension > 0);
        }
    }

    /// Recommended models per use case have the documented dimensions.
    #[test]
    fn test_utils_recommended_models() {
        let fast_model = get_recommended_model(LocalModelUseCase::Fast);
        assert_eq!(fast_model.embedding_dimension, 384);

        let quality_model = get_recommended_model(LocalModelUseCase::Quality);
        assert_eq!(quality_model.embedding_dimension, 768);

        let multilingual_model = get_recommended_model(LocalModelUseCase::Multilingual);
        assert_eq!(multilingual_model.embedding_dimension, 384);
    }

    /// Even the mock-backed provider must be deterministic and correctly sized.
    #[tokio::test]
    async fn test_production_warning_behavior() {
        let config = LocalConfig::new("test-model", 384);

        // This should emit a warning if not in test mode
        let provider = LocalEmbeddingProvider::new(config).await.unwrap();

        // Verify the provider works but may be using mock embeddings
        let embedding1 = provider.embed_text("test").await.unwrap();
        let embedding2 = provider.embed_text("test").await.unwrap();

        // In test mode, embeddings should be deterministic (same)
        assert_eq!(embedding1, embedding2);
        assert_eq!(embedding1.len(), 384);
    }
}