// do_memory_core/embeddings/local.rs
1//! Local embedding provider using sentence transformers
2//!
3//! This provider runs embedding models locally using candle-transformers,
4//! providing offline capability with no external API dependencies.
5
6use super::config::LocalConfig;
7use super::provider::EmbeddingProvider;
8use anyhow::{Context, Result};
9use async_trait::async_trait;
10use std::sync::Arc;
11use tokio::sync::RwLock;
12
13/// Local embedding provider using sentence transformers
14///
15/// Runs embedding models locally using candle-transformers or similar.
16/// Provides offline embedding generation with no external dependencies.
17///
18/// # Models Supported
19/// - sentence-transformers/all-MiniLM-L6-v2 (384 dims, default)
20/// - sentence-transformers/all-mpnet-base-v2 (768 dims, higher quality)
21/// - sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 (384 dims, multilingual)
22///
23/// # Example
24/// ```no_run
25/// use do_memory_core::embeddings::{EmbeddingProvider, LocalEmbeddingProvider, LocalConfig};
26///
27/// #[tokio::main]
28/// async fn main() -> anyhow::Result<()> {
29///     let config = LocalConfig::new(
30///         "sentence-transformers/all-MiniLM-L6-v2",
31///         384
32///     );
33///     let provider = LocalEmbeddingProvider::new(config).await?;
34///
35///     let embedding = provider.embed_text("Hello world").await?;
36///     println!("Generated embedding with {} dimensions", embedding.len());
37///     Ok(())
38/// }
39/// ```
40pub struct LocalEmbeddingProvider {
41    /// Model configuration
42    config: LocalConfig,
43    /// Embedding model (placeholder for actual model implementation)
44    model: Arc<RwLock<Option<Box<dyn LocalEmbeddingModel>>>>,
45    /// Model cache directory
46    cache_dir: std::path::PathBuf,
47}
48
49impl LocalEmbeddingProvider {
50    /// Create a new local embedding provider
51    ///
52    /// # Arguments
53    /// * `config` - Model configuration specifying which model to use
54    ///
55    /// # Returns
56    /// Configured local embedding provider
57    pub async fn new(config: LocalConfig) -> Result<Self> {
58        let cache_dir = Self::get_cache_dir()?;
59
60        let provider = Self {
61            config,
62            model: Arc::new(RwLock::new(None)),
63            cache_dir,
64        };
65
66        // Initialize/load the model
67        provider.load_model().await?;
68
69        Ok(provider)
70    }
71
72    /// Load the embedding model
73    async fn load_model(&self) -> Result<()> {
74        tracing::info!("Loading local embedding model: {}", self.config.model_name);
75
76        #[cfg(feature = "local-embeddings")]
77        {
78            // Try to load real ONNX model, fallback to mock if fails
79            match self.try_load_real_model().await {
80                Ok(real_model) => {
81                    let fallback_model = Box::new(RealEmbeddingModelWithFallback::new(
82                        self.config.model_name.clone(),
83                        self.config.embedding_dimension,
84                        Some(real_model),
85                    ));
86
87                    let mut model_guard = self.model.write().await;
88                    *model_guard = Some(fallback_model);
89
90                    tracing::info!("Local embedding model loaded with real ONNX backend");
91                }
92                Err(e) => {
93                    tracing::warn!("Failed to load real embedding model: {}", e);
94                    tracing::warn!(
95                        "Falling back to mock embeddings - semantic search will not work correctly"
96                    );
97
98                    let mock_fallback = Box::new(RealEmbeddingModelWithFallback::new(
99                        self.config.model_name.clone(),
100                        self.config.embedding_dimension,
101                        None,
102                    ));
103
104                    let mut model_guard = self.model.write().await;
105                    *model_guard = Some(mock_fallback);
106
107                    tracing::info!("Local embedding model loaded with mock fallback");
108                }
109            }
110        }
111
112        #[cfg(not(feature = "local-embeddings"))]
113        {
114            tracing::warn!(
115                "PRODUCTION WARNING: Using mock embeddings - semantic search will not work correctly"
116            );
117            tracing::warn!(
118                "To enable real embeddings, add 'local-embeddings' feature and ensure ONNX models are available"
119            );
120
121            let mock_fallback = Box::new(super::mock_model::MockLocalModel::new(
122                self.config.model_name.clone(),
123                self.config.embedding_dimension,
124            ));
125
126            let mut model_guard = self.model.write().await;
127            *model_guard = Some(mock_fallback);
128
129            tracing::info!("Local embedding model loaded with mock implementation");
130        }
131
132        Ok(())
133    }
134
135    /// Try to load real ONNX model
136    #[cfg(feature = "local-embeddings")]
137    async fn try_load_real_model(&self) -> Result<RealEmbeddingModel> {
138        RealEmbeddingModel::try_load_from_cache(&self.config, &self.cache_dir).await
139    }
140
141    /// Get the cache directory for models
142    fn get_cache_dir() -> Result<std::path::PathBuf> {
143        let home = std::env::var("HOME")
144            .or_else(|_| std::env::var("USERPROFILE"))
145            .context("Could not determine home directory")?;
146
147        let cache_dir = std::path::Path::new(&home)
148            .join(".cache")
149            .join("memory-core")
150            .join("embeddings");
151
152        std::fs::create_dir_all(&cache_dir).context("Failed to create cache directory")?;
153
154        Ok(cache_dir)
155    }
156
157    /// Check if model is loaded
158    pub async fn is_loaded(&self) -> bool {
159        let model_guard = self.model.read().await;
160        model_guard.is_some()
161    }
162
163    /// Get model information
164    #[must_use]
165    pub fn model_info(&self) -> serde_json::Value {
166        serde_json::json!({
167            "name": self.config.model_name,
168            "dimension": self.config.embedding_dimension,
169            "type": "local",
170            "cache_dir": self.cache_dir,
171        })
172    }
173}
174
175#[async_trait]
176impl EmbeddingProvider for LocalEmbeddingProvider {
177    async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
178        let model_guard = self.model.read().await;
179        let model = model_guard.as_ref().context("Model not loaded")?;
180
181        model.embed(text).await
182    }
183
184    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
185        let model_guard = self.model.read().await;
186        let model = model_guard.as_ref().context("Model not loaded")?;
187
188        model.embed_batch(texts).await
189    }
190
191    fn embedding_dimension(&self) -> usize {
192        self.config.embedding_dimension
193    }
194
195    fn model_name(&self) -> &str {
196        &self.config.model_name
197    }
198
199    async fn is_available(&self) -> bool {
200        self.is_loaded().await
201    }
202
203    async fn warmup(&self) -> Result<()> {
204        // Test embedding generation
205        let _embedding = self.embed_text("warmup test").await?;
206        Ok(())
207    }
208
209    fn metadata(&self) -> serde_json::Value {
210        serde_json::json!({
211            "model": self.model_name(),
212            "dimension": self.embedding_dimension(),
213            "type": "local",
214            "provider": "sentence-transformers",
215            "cache_dir": self.cache_dir
216        })
217    }
218}
219
220/// Trait for local embedding models
221#[async_trait]
222#[allow(dead_code)] // Trait methods used by implementations, not called directly in this crate
223pub trait LocalEmbeddingModel: Send + Sync {
224    /// Generate embedding for single text
225    async fn embed(&self, text: &str) -> Result<Vec<f32>>;
226
227    /// Generate embeddings for batch of texts
228    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
229
230    /// Get model name
231    fn name(&self) -> &str;
232
233    /// Get embedding dimension
234    fn dimension(&self) -> usize;
235}
236
237/// Import real model implementation
238#[cfg(feature = "local-embeddings")]
239#[allow(unused)]
240pub use crate::embeddings::real_model::RealEmbeddingModel;
241
242/// Import mock model implementations
243#[cfg(feature = "local-embeddings")]
244#[allow(unused)]
245pub use crate::embeddings::mock_model::{MockLocalModel, RealEmbeddingModelWithFallback};
246
247/// Re-export utilities from the utils module
248#[allow(unused)]
249pub use crate::embeddings::utils::{
250    LocalModelUseCase, get_recommended_model, list_available_models,
251};
252
253#[cfg(test)]
254mod tests {
255    use super::*;
256
257    #[tokio::test]
258    async fn test_local_provider_creation() {
259        let config = LocalConfig::new("test-model", 384);
260
261        let provider = LocalEmbeddingProvider::new(config).await.unwrap();
262        assert!(provider.is_loaded().await);
263        assert_eq!(provider.embedding_dimension(), 384);
264        assert_eq!(provider.model_name(), "test-model");
265    }
266
267    #[tokio::test]
268    async fn test_embed_text() {
269        let config = LocalConfig::new("test-model", 384);
270
271        let provider = LocalEmbeddingProvider::new(config).await.unwrap();
272
273        let embedding = provider.embed_text("Hello world").await.unwrap();
274        assert_eq!(embedding.len(), 384);
275
276        // Test deterministic behavior
277        let embedding2 = provider.embed_text("Hello world").await.unwrap();
278        assert_eq!(embedding, embedding2);
279
280        // Different text should produce different embedding
281        let embedding3 = provider.embed_text("Different text").await.unwrap();
282        assert_ne!(embedding, embedding3);
283    }
284
285    #[tokio::test]
286    async fn test_embed_batch() {
287        let config = LocalConfig::new("test-model", 384);
288
289        let provider = LocalEmbeddingProvider::new(config).await.unwrap();
290
291        let texts = vec![
292            "First text".to_string(),
293            "Second text".to_string(),
294            "Third text".to_string(),
295        ];
296
297        let embeddings = provider.embed_batch(&texts).await.unwrap();
298        assert_eq!(embeddings.len(), 3);
299
300        for embedding in embeddings {
301            assert_eq!(embedding.len(), 384);
302        }
303    }
304
305    #[tokio::test]
306    async fn test_similarity_calculation() {
307        let config = LocalConfig::new("test-model", 384);
308
309        let provider = LocalEmbeddingProvider::new(config).await.unwrap();
310
311        // Identical texts should have high similarity
312        let similarity = provider
313            .similarity("Hello world", "Hello world")
314            .await
315            .unwrap();
316        assert!((similarity - 1.0).abs() < 0.001);
317
318        // Different texts should have lower similarity
319        let similarity = provider
320            .similarity("Hello world", "Goodbye universe")
321            .await
322            .unwrap();
323        assert!(similarity < 1.0);
324    }
325
326    #[tokio::test]
327    #[ignore = "Requires local-embeddings feature with ONNX models - blocked by ort crate Send trait issue"]
328    #[cfg(feature = "local-embeddings")]
329    async fn test_real_embedding_generation() {
330        // This test only runs when local-embeddings feature is enabled
331        // and real ONNX models are available
332
333        // Create a temporary directory for model cache
334        let temp_dir = tempfile::TempDir::new().unwrap();
335        let cache_path = temp_dir.path().join("models");
336
337        // Try to load a real model if available
338        // In CI, this might not have actual model files
339        if cache_path.exists() || std::env::var("CI").is_ok() {
340            tracing::info!("Skipping real embedding test - no model files available");
341            return;
342        }
343
344        let config = LocalConfig::new("sentence-transformers/all-MiniLM-L6-v2", 384);
345
346        let provider = LocalEmbeddingProvider::new(config).await.unwrap();
347
348        // Generate embeddings for semantically similar texts
349        let embedding1 = provider
350            .embed_text("machine learning algorithms")
351            .await
352            .unwrap();
353        let embedding2 = provider
354            .embed_text("artificial intelligence models")
355            .await
356            .unwrap();
357        let embedding3 = provider
358            .embed_text("cooking recipes for pasta")
359            .await
360            .unwrap();
361
362        assert_eq!(embedding1.len(), 384);
363        assert_eq!(embedding2.len(), 384);
364        assert_eq!(embedding3.len(), 384);
365
366        // Calculate similarities
367        let similarity_ai_ml = provider
368            .similarity("machine learning", "artificial intelligence")
369            .await
370            .unwrap();
371        let similarity_cooking = provider
372            .similarity("machine learning", "cooking recipes")
373            .await
374            .unwrap();
375
376        // Semantically similar texts should have higher similarity
377        assert!(
378            similarity_ai_ml > similarity_cooking,
379            "AI/ML similarity ({similarity_ai_ml}) should be higher than ML/cooking ({similarity_cooking})"
380        );
381
382        // Both should be positive (cosine similarity range)
383        assert!(similarity_ai_ml > 0.0);
384        assert!(similarity_cooking > 0.0);
385    }
386
387    #[tokio::test]
388    async fn test_embedding_vector_properties() {
389        let config = LocalConfig::new("test-model", 384);
390
391        let provider = LocalEmbeddingProvider::new(config).await.unwrap();
392
393        let embedding = provider.embed_text("test text").await.unwrap();
394
395        // Check that embedding is properly normalized (unit vector)
396        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
397        assert!((norm - 1.0).abs() < 0.001, "Embedding should be normalized");
398
399        // Check that values are in reasonable range
400        for &value in &embedding {
401            assert!(
402                (-1.0..=1.0).contains(&value),
403                "Embedding values should be in [-1, 1]"
404            );
405        }
406    }
407
408    #[tokio::test]
409    async fn test_model_metadata() {
410        let config = LocalConfig::new("sentence-transformers/test-model", 768);
411
412        let provider = LocalEmbeddingProvider::new(config).await.unwrap();
413
414        let metadata = provider.metadata();
415        assert_eq!(metadata["model"], "sentence-transformers/test-model");
416        assert_eq!(metadata["dimension"], 768);
417        assert_eq!(metadata["type"], "local");
418
419        let model_info = provider.model_info();
420        assert_eq!(model_info["name"], "sentence-transformers/test-model");
421        assert_eq!(model_info["dimension"], 768);
422        assert_eq!(model_info["type"], "local");
423    }
424
425    #[tokio::test]
426    async fn test_error_handling() {
427        let config = LocalConfig::new("nonexistent-model", 384);
428
429        // Test with non-existent model - should fall back to mock or fail gracefully
430        let result = LocalEmbeddingProvider::new(config).await;
431
432        match result {
433            Ok(provider) => {
434                // If successful, it should be a mock implementation
435                assert!(provider.is_loaded().await);
436                let embedding = provider.embed_text("test").await.unwrap();
437                assert_eq!(embedding.len(), 384);
438            }
439            Err(e) => {
440                // Should provide meaningful error message
441                assert!(e.to_string().contains("model") || e.to_string().contains("load"));
442            }
443        }
444    }
445
446    #[tokio::test]
447    async fn test_warmup_functionality() {
448        let config = LocalConfig::new("test-model", 384);
449
450        let provider = LocalEmbeddingProvider::new(config).await.unwrap();
451
452        // Warmup should succeed
453        let result = provider.warmup().await;
454        assert!(result.is_ok(), "Warmup should succeed");
455    }
456
457    #[test]
458    fn test_utils_list_models() {
459        let models = list_available_models();
460        assert!(!models.is_empty());
461
462        for model in models {
463            assert!(!model.model_name.is_empty());
464            assert!(model.embedding_dimension > 0);
465        }
466    }
467
468    #[test]
469    fn test_utils_recommended_models() {
470        let fast_model = get_recommended_model(LocalModelUseCase::Fast);
471        assert_eq!(fast_model.embedding_dimension, 384);
472
473        let quality_model = get_recommended_model(LocalModelUseCase::Quality);
474        assert_eq!(quality_model.embedding_dimension, 768);
475
476        let multilingual_model = get_recommended_model(LocalModelUseCase::Multilingual);
477        assert_eq!(multilingual_model.embedding_dimension, 384);
478    }
479
480    #[tokio::test]
481    async fn test_production_warning_behavior() {
482        let config = LocalConfig::new("test-model", 384);
483
484        // This should emit a warning if not in test mode
485        let provider = LocalEmbeddingProvider::new(config).await.unwrap();
486
487        // Verify the provider works but may be using mock embeddings
488        let embedding1 = provider.embed_text("test").await.unwrap();
489        let embedding2 = provider.embed_text("test").await.unwrap();
490
491        // In test mode, embeddings should be deterministic (same)
492        assert_eq!(embedding1, embedding2);
493        assert_eq!(embedding1.len(), 384);
494    }
495}