do_memory_core/embeddings/
local.rs1use super::config::LocalConfig;
7use super::provider::EmbeddingProvider;
8use anyhow::{Context, Result};
9use async_trait::async_trait;
10use std::sync::Arc;
11use tokio::sync::RwLock;
12
13pub struct LocalEmbeddingProvider {
41 config: LocalConfig,
43 model: Arc<RwLock<Option<Box<dyn LocalEmbeddingModel>>>>,
45 cache_dir: std::path::PathBuf,
47}
48
49impl LocalEmbeddingProvider {
50 pub async fn new(config: LocalConfig) -> Result<Self> {
58 let cache_dir = Self::get_cache_dir()?;
59
60 let provider = Self {
61 config,
62 model: Arc::new(RwLock::new(None)),
63 cache_dir,
64 };
65
66 provider.load_model().await?;
68
69 Ok(provider)
70 }
71
72 async fn load_model(&self) -> Result<()> {
74 tracing::info!("Loading local embedding model: {}", self.config.model_name);
75
76 #[cfg(feature = "local-embeddings")]
77 {
78 match self.try_load_real_model().await {
80 Ok(real_model) => {
81 let fallback_model = Box::new(RealEmbeddingModelWithFallback::new(
82 self.config.model_name.clone(),
83 self.config.embedding_dimension,
84 Some(real_model),
85 ));
86
87 let mut model_guard = self.model.write().await;
88 *model_guard = Some(fallback_model);
89
90 tracing::info!("Local embedding model loaded with real ONNX backend");
91 }
92 Err(e) => {
93 tracing::warn!("Failed to load real embedding model: {}", e);
94 tracing::warn!(
95 "Falling back to mock embeddings - semantic search will not work correctly"
96 );
97
98 let mock_fallback = Box::new(RealEmbeddingModelWithFallback::new(
99 self.config.model_name.clone(),
100 self.config.embedding_dimension,
101 None,
102 ));
103
104 let mut model_guard = self.model.write().await;
105 *model_guard = Some(mock_fallback);
106
107 tracing::info!("Local embedding model loaded with mock fallback");
108 }
109 }
110 }
111
112 #[cfg(not(feature = "local-embeddings"))]
113 {
114 tracing::warn!(
115 "PRODUCTION WARNING: Using mock embeddings - semantic search will not work correctly"
116 );
117 tracing::warn!(
118 "To enable real embeddings, add 'local-embeddings' feature and ensure ONNX models are available"
119 );
120
121 let mock_fallback = Box::new(super::mock_model::MockLocalModel::new(
122 self.config.model_name.clone(),
123 self.config.embedding_dimension,
124 ));
125
126 let mut model_guard = self.model.write().await;
127 *model_guard = Some(mock_fallback);
128
129 tracing::info!("Local embedding model loaded with mock implementation");
130 }
131
132 Ok(())
133 }
134
135 #[cfg(feature = "local-embeddings")]
137 async fn try_load_real_model(&self) -> Result<RealEmbeddingModel> {
138 RealEmbeddingModel::try_load_from_cache(&self.config, &self.cache_dir).await
139 }
140
141 fn get_cache_dir() -> Result<std::path::PathBuf> {
143 let home = std::env::var("HOME")
144 .or_else(|_| std::env::var("USERPROFILE"))
145 .context("Could not determine home directory")?;
146
147 let cache_dir = std::path::Path::new(&home)
148 .join(".cache")
149 .join("memory-core")
150 .join("embeddings");
151
152 std::fs::create_dir_all(&cache_dir).context("Failed to create cache directory")?;
153
154 Ok(cache_dir)
155 }
156
157 pub async fn is_loaded(&self) -> bool {
159 let model_guard = self.model.read().await;
160 model_guard.is_some()
161 }
162
163 #[must_use]
165 pub fn model_info(&self) -> serde_json::Value {
166 serde_json::json!({
167 "name": self.config.model_name,
168 "dimension": self.config.embedding_dimension,
169 "type": "local",
170 "cache_dir": self.cache_dir,
171 })
172 }
173}
174
175#[async_trait]
176impl EmbeddingProvider for LocalEmbeddingProvider {
177 async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
178 let model_guard = self.model.read().await;
179 let model = model_guard.as_ref().context("Model not loaded")?;
180
181 model.embed(text).await
182 }
183
184 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
185 let model_guard = self.model.read().await;
186 let model = model_guard.as_ref().context("Model not loaded")?;
187
188 model.embed_batch(texts).await
189 }
190
191 fn embedding_dimension(&self) -> usize {
192 self.config.embedding_dimension
193 }
194
195 fn model_name(&self) -> &str {
196 &self.config.model_name
197 }
198
199 async fn is_available(&self) -> bool {
200 self.is_loaded().await
201 }
202
203 async fn warmup(&self) -> Result<()> {
204 let _embedding = self.embed_text("warmup test").await?;
206 Ok(())
207 }
208
209 fn metadata(&self) -> serde_json::Value {
210 serde_json::json!({
211 "model": self.model_name(),
212 "dimension": self.embedding_dimension(),
213 "type": "local",
214 "provider": "sentence-transformers",
215 "cache_dir": self.cache_dir
216 })
217 }
218}
219
220#[async_trait]
222#[allow(dead_code)] pub trait LocalEmbeddingModel: Send + Sync {
224 async fn embed(&self, text: &str) -> Result<Vec<f32>>;
226
227 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
229
230 fn name(&self) -> &str;
232
233 fn dimension(&self) -> usize;
235}
236
237#[cfg(feature = "local-embeddings")]
239#[allow(unused)]
240pub use crate::embeddings::real_model::RealEmbeddingModel;
241
242#[cfg(feature = "local-embeddings")]
244#[allow(unused)]
245pub use crate::embeddings::mock_model::{MockLocalModel, RealEmbeddingModelWithFallback};
246
247#[allow(unused)]
249pub use crate::embeddings::utils::{
250 LocalModelUseCase, get_recommended_model, list_available_models,
251};
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration shared by most tests: model "test-model", 384 dimensions.
    fn default_config() -> LocalConfig {
        LocalConfig::new("test-model", 384)
    }

    /// Builds a provider from `config`, panicking if construction fails.
    async fn provider_from(config: LocalConfig) -> LocalEmbeddingProvider {
        LocalEmbeddingProvider::new(config).await.unwrap()
    }

    /// A freshly built provider is loaded and echoes its configuration.
    #[tokio::test]
    async fn test_local_provider_creation() {
        let provider = provider_from(default_config()).await;

        assert!(provider.is_loaded().await);
        assert_eq!(provider.embedding_dimension(), 384);
        assert_eq!(provider.model_name(), "test-model");
    }

    /// Embeddings have the configured length, repeat for equal input, and
    /// differ for different input.
    #[tokio::test]
    async fn test_embed_text() {
        let provider = provider_from(default_config()).await;

        let first = provider.embed_text("Hello world").await.unwrap();
        assert_eq!(first.len(), 384);

        let repeated = provider.embed_text("Hello world").await.unwrap();
        assert_eq!(first, repeated);

        let other = provider.embed_text("Different text").await.unwrap();
        assert_ne!(first, other);
    }

    /// Batch embedding returns one vector of the configured length per input.
    #[tokio::test]
    async fn test_embed_batch() {
        let provider = provider_from(default_config()).await;

        let texts: Vec<String> = ["First text", "Second text", "Third text"]
            .iter()
            .map(|text| text.to_string())
            .collect();

        let embeddings = provider.embed_batch(&texts).await.unwrap();
        assert_eq!(embeddings.len(), 3);

        for vector in embeddings {
            assert_eq!(vector.len(), 384);
        }
    }

    /// Identical texts score ~1.0; different texts score below 1.0.
    #[tokio::test]
    async fn test_similarity_calculation() {
        let provider = provider_from(default_config()).await;

        let identical = provider
            .similarity("Hello world", "Hello world")
            .await
            .unwrap();
        assert!((identical - 1.0).abs() < 0.001);

        let different = provider
            .similarity("Hello world", "Goodbye universe")
            .await
            .unwrap();
        assert!(different < 1.0);
    }

    /// Checks that related texts score higher than unrelated ones.
    #[tokio::test]
    #[ignore = "Requires local-embeddings feature with ONNX models - blocked by ort crate Send trait issue"]
    #[cfg(feature = "local-embeddings")]
    async fn test_real_embedding_generation() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let cache_path = temp_dir.path().join("models");

        // NOTE(review): `cache_path` lives under a fresh temp dir and is not
        // passed to the provider, so in practice only the `CI` variable
        // appears to trigger this skip - confirm the intended condition.
        let running_in_ci = std::env::var("CI").is_ok();
        if cache_path.exists() || running_in_ci {
            tracing::info!("Skipping real embedding test - no model files available");
            return;
        }

        let provider =
            provider_from(LocalConfig::new("sentence-transformers/all-MiniLM-L6-v2", 384)).await;

        // Produce all embeddings first, then check their lengths.
        let mut embeddings = Vec::new();
        for text in [
            "machine learning algorithms",
            "artificial intelligence models",
            "cooking recipes for pasta",
        ] {
            embeddings.push(provider.embed_text(text).await.unwrap());
        }
        for vector in &embeddings {
            assert_eq!(vector.len(), 384);
        }

        let similarity_ai_ml = provider
            .similarity("machine learning", "artificial intelligence")
            .await
            .unwrap();
        let similarity_cooking = provider
            .similarity("machine learning", "cooking recipes")
            .await
            .unwrap();

        assert!(
            similarity_ai_ml > similarity_cooking,
            "AI/ML similarity ({similarity_ai_ml}) should be higher than ML/cooking ({similarity_cooking})"
        );

        assert!(similarity_ai_ml > 0.0);
        assert!(similarity_cooking > 0.0);
    }

    /// Embeddings are unit-length and every component lies in [-1, 1].
    #[tokio::test]
    async fn test_embedding_vector_properties() {
        let provider = provider_from(default_config()).await;

        let vector = provider.embed_text("test text").await.unwrap();

        let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001, "Embedding should be normalized");

        assert!(
            vector
                .iter()
                .all(|component| (-1.0..=1.0).contains(component)),
            "Embedding values should be in [-1, 1]"
        );
    }

    /// `metadata` and `model_info` both echo name, dimension and type.
    #[tokio::test]
    async fn test_model_metadata() {
        let provider =
            provider_from(LocalConfig::new("sentence-transformers/test-model", 768)).await;

        let metadata = provider.metadata();
        assert_eq!(metadata["model"], "sentence-transformers/test-model");
        assert_eq!(metadata["dimension"], 768);
        assert_eq!(metadata["type"], "local");

        let info = provider.model_info();
        assert_eq!(info["name"], "sentence-transformers/test-model");
        assert_eq!(info["dimension"], 768);
        assert_eq!(info["type"], "local");
    }

    /// An unknown model name either yields a working provider or an error
    /// whose text mentions the model or loading.
    #[tokio::test]
    async fn test_error_handling() {
        let outcome = LocalEmbeddingProvider::new(LocalConfig::new("nonexistent-model", 384)).await;

        match outcome {
            Ok(provider) => {
                assert!(provider.is_loaded().await);
                let vector = provider.embed_text("test").await.unwrap();
                assert_eq!(vector.len(), 384);
            }
            Err(error) => {
                let message = error.to_string();
                assert!(message.contains("model") || message.contains("load"));
            }
        }
    }

    /// Warmup completes without error.
    #[tokio::test]
    async fn test_warmup_functionality() {
        let provider = provider_from(default_config()).await;

        assert!(provider.warmup().await.is_ok(), "Warmup should succeed");
    }

    /// The model catalogue is non-empty and every entry is well-formed.
    #[test]
    fn test_utils_list_models() {
        let models = list_available_models();
        assert!(!models.is_empty());

        for entry in models {
            assert!(!entry.model_name.is_empty());
            assert!(entry.embedding_dimension > 0);
        }
    }

    /// Each use case maps to a model with the expected dimension.
    #[test]
    fn test_utils_recommended_models() {
        let expectations = [
            (LocalModelUseCase::Fast, 384),
            (LocalModelUseCase::Quality, 768),
            (LocalModelUseCase::Multilingual, 384),
        ];

        for (use_case, expected_dimension) in expectations {
            let recommended = get_recommended_model(use_case);
            assert_eq!(recommended.embedding_dimension, expected_dimension);
        }
    }

    /// Repeated embedding of the same text gives the same vector of the
    /// configured length.
    #[tokio::test]
    async fn test_production_warning_behavior() {
        let provider = provider_from(default_config()).await;

        let first = provider.embed_text("test").await.unwrap();
        let second = provider.embed_text("test").await.unwrap();

        assert_eq!(first, second);
        assert_eq!(first.len(), 384);
    }
}