do_memory_core/embeddings/
local.rs1use super::config::LocalConfig;
7use super::provider::EmbeddingProvider;
8use anyhow::{Context, Result};
9use async_trait::async_trait;
10use std::sync::Arc;
11use tokio::sync::RwLock;
12
13pub struct LocalEmbeddingProvider {
41 config: LocalConfig,
43 model: Arc<RwLock<Option<Box<dyn LocalEmbeddingModel>>>>,
45 cache_dir: std::path::PathBuf,
47}
48
49impl LocalEmbeddingProvider {
50 pub async fn new(config: LocalConfig) -> Result<Self> {
58 let cache_dir = Self::get_cache_dir()?;
59
60 let provider = Self {
61 config,
62 model: Arc::new(RwLock::new(None)),
63 cache_dir,
64 };
65
66 provider.load_model().await?;
68
69 Ok(provider)
70 }
71
72 async fn load_model(&self) -> Result<()> {
74 tracing::info!("Loading local embedding model: {}", self.config.model_name);
75
76 #[cfg(feature = "local-embeddings")]
77 {
78 match self.try_load_real_model().await {
80 Ok(real_model) => {
81 let fallback_model = Box::new(RealEmbeddingModelWithFallback::new(
82 self.config.model_name.clone(),
83 self.config.embedding_dimension,
84 Some(real_model),
85 ));
86
87 let mut model_guard = self.model.write().await;
88 *model_guard = Some(fallback_model);
89
90 tracing::info!("Local embedding model loaded with real ONNX backend");
91 }
92 Err(e) => {
93 tracing::warn!("Failed to load real embedding model: {}", e);
94 tracing::warn!(
95 "Falling back to mock embeddings - semantic search will not work correctly"
96 );
97
98 let mock_fallback = Box::new(RealEmbeddingModelWithFallback::new(
99 self.config.model_name.clone(),
100 self.config.embedding_dimension,
101 None,
102 ));
103
104 let mut model_guard = self.model.write().await;
105 *model_guard = Some(mock_fallback);
106
107 tracing::info!("Local embedding model loaded with mock fallback");
108 }
109 }
110 }
111
112 #[cfg(not(feature = "local-embeddings"))]
113 {
114 tracing::warn!(
115 "PRODUCTION WARNING: Using mock embeddings - semantic search will not work correctly"
116 );
117 tracing::warn!(
118 "To enable real embeddings, add 'local-embeddings' feature and ensure ONNX models are available"
119 );
120
121 let mock_fallback = Box::new(super::mock_model::MockLocalModel::new(
122 self.config.model_name.clone(),
123 self.config.embedding_dimension,
124 ));
125
126 let mut model_guard = self.model.write().await;
127 *model_guard = Some(mock_fallback);
128
129 tracing::info!("Local embedding model loaded with mock implementation");
130 }
131
132 Ok(())
133 }
134
135 #[cfg(feature = "local-embeddings")]
137 async fn try_load_real_model(&self) -> Result<RealEmbeddingModel> {
138 RealEmbeddingModel::try_load_from_cache(&self.config, &self.cache_dir).await
139 }
140
141 fn get_cache_dir() -> Result<std::path::PathBuf> {
143 let home = std::env::var("HOME")
144 .or_else(|_| std::env::var("USERPROFILE"))
145 .context("Could not determine home directory")?;
146
147 let cache_dir = std::path::Path::new(&home)
148 .join(".cache")
149 .join("memory-core")
150 .join("embeddings");
151
152 std::fs::create_dir_all(&cache_dir).context("Failed to create cache directory")?;
153
154 Ok(cache_dir)
155 }
156
157 pub async fn is_loaded(&self) -> bool {
159 let model_guard = self.model.read().await;
160 model_guard.is_some()
161 }
162
163 #[must_use]
165 pub fn model_info(&self) -> serde_json::Value {
166 serde_json::json!({
167 "name": self.config.model_name,
168 "dimension": self.config.embedding_dimension,
169 "type": "local",
170 "cache_dir": self.cache_dir,
171 })
172 }
173}
174
175#[async_trait]
176impl EmbeddingProvider for LocalEmbeddingProvider {
177 async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
178 let model_guard = self.model.read().await;
179 let model = model_guard.as_ref().context("Model not loaded")?;
180
181 model.embed(text).await
182 }
183
184 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
185 let model_guard = self.model.read().await;
186 let model = model_guard.as_ref().context("Model not loaded")?;
187
188 model.embed_batch(texts).await
189 }
190
191 fn embedding_dimension(&self) -> usize {
192 self.config.embedding_dimension
193 }
194
195 fn model_name(&self) -> &str {
196 &self.config.model_name
197 }
198
199 async fn is_available(&self) -> bool {
200 self.is_loaded().await
201 }
202
203 async fn warmup(&self) -> Result<()> {
204 let _embedding = self.embed_text("warmup test").await?;
206 Ok(())
207 }
208
209 fn metadata(&self) -> serde_json::Value {
210 serde_json::json!({
211 "model": self.model_name(),
212 "dimension": self.embedding_dimension(),
213 "type": "local",
214 "provider": "sentence-transformers",
215 "cache_dir": self.cache_dir
216 })
217 }
218}
219
220#[async_trait]
222pub trait LocalEmbeddingModel: Send + Sync {
223 async fn embed(&self, text: &str) -> Result<Vec<f32>>;
225
226 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
228
229 #[allow(dead_code)]
231 fn name(&self) -> &str;
232
233 #[allow(dead_code)]
235 fn dimension(&self) -> usize;
236}
237
238#[cfg(feature = "local-embeddings")]
240#[allow(unused)]
241pub use crate::embeddings::real_model::RealEmbeddingModel;
242
243#[cfg(feature = "local-embeddings")]
245#[allow(unused)]
246pub use crate::embeddings::mock_model::{MockLocalModel, RealEmbeddingModelWithFallback};
247
248#[allow(unused)]
250pub use crate::embeddings::utils::{
251 LocalModelUseCase, get_recommended_model, list_available_models,
252};
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a provider for the given model name and dimension, panicking
    /// if construction fails.
    async fn build_provider(model_name: &str, dimension: usize) -> LocalEmbeddingProvider {
        LocalEmbeddingProvider::new(LocalConfig::new(model_name, dimension))
            .await
            .unwrap()
    }

    /// A freshly built provider is loaded and echoes its configuration.
    #[tokio::test]
    async fn test_local_provider_creation() {
        let provider = build_provider("test-model", 384).await;

        assert!(provider.is_loaded().await);
        assert_eq!(provider.embedding_dimension(), 384);
        assert_eq!(provider.model_name(), "test-model");
    }

    /// Embeddings have the configured length, repeat for equal input and
    /// differ for different input.
    #[tokio::test]
    async fn test_embed_text() {
        let provider = build_provider("test-model", 384).await;

        let first = provider.embed_text("Hello world").await.unwrap();
        assert_eq!(first.len(), 384);

        let repeated = provider.embed_text("Hello world").await.unwrap();
        assert_eq!(first, repeated);

        let other = provider.embed_text("Different text").await.unwrap();
        assert_ne!(first, other);
    }

    /// A batch yields one vector of the configured length per input.
    #[tokio::test]
    async fn test_embed_batch() {
        let provider = build_provider("test-model", 384).await;

        let texts: Vec<String> = ["First text", "Second text", "Third text"]
            .iter()
            .map(|text| text.to_string())
            .collect();

        let vectors = provider.embed_batch(&texts).await.unwrap();
        assert_eq!(vectors.len(), 3);

        for vector in &vectors {
            assert_eq!(vector.len(), 384);
        }
    }

    /// Identical texts score (almost exactly) 1.0; different texts score less.
    #[tokio::test]
    async fn test_similarity_calculation() {
        let provider = build_provider("test-model", 384).await;

        let identical = provider
            .similarity("Hello world", "Hello world")
            .await
            .unwrap();
        assert!((identical - 1.0).abs() < 0.001);

        let different = provider
            .similarity("Hello world", "Goodbye universe")
            .await
            .unwrap();
        assert!(different < 1.0);
    }

    /// Semantic sanity check that only makes sense with a real model.
    #[tokio::test]
    #[ignore = "Requires local-embeddings feature with ONNX models - blocked by ort crate Send trait issue"]
    #[cfg(feature = "local-embeddings")]
    async fn test_real_embedding_generation() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let cache_path = temp_dir.path().join("models");

        // NOTE(review): `cache_path` lives in a fresh temp dir, so in practice
        // only the `CI` check can trigger this early return - confirm intent.
        if cache_path.exists() || std::env::var("CI").is_ok() {
            tracing::info!("Skipping real embedding test - no model files available");
            return;
        }

        let provider = build_provider("sentence-transformers/all-MiniLM-L6-v2", 384).await;

        let ml_vector = provider
            .embed_text("machine learning algorithms")
            .await
            .unwrap();
        let ai_vector = provider
            .embed_text("artificial intelligence models")
            .await
            .unwrap();
        let cooking_vector = provider
            .embed_text("cooking recipes for pasta")
            .await
            .unwrap();

        assert_eq!(ml_vector.len(), 384);
        assert_eq!(ai_vector.len(), 384);
        assert_eq!(cooking_vector.len(), 384);

        let similarity_ai_ml = provider
            .similarity("machine learning", "artificial intelligence")
            .await
            .unwrap();
        let similarity_cooking = provider
            .similarity("machine learning", "cooking recipes")
            .await
            .unwrap();

        assert!(
            similarity_ai_ml > similarity_cooking,
            "AI/ML similarity ({similarity_ai_ml}) should be higher than ML/cooking ({similarity_cooking})"
        );

        assert!(similarity_ai_ml > 0.0);
        assert!(similarity_cooking > 0.0);
    }

    /// Embeddings are unit-length and every component lies in [-1, 1].
    #[tokio::test]
    async fn test_embedding_vector_properties() {
        let provider = build_provider("test-model", 384).await;

        let vector = provider.embed_text("test text").await.unwrap();

        let squared_sum: f32 = vector.iter().map(|component| component * component).sum();
        let norm = squared_sum.sqrt();
        assert!((norm - 1.0).abs() < 0.001, "Embedding should be normalized");

        for component in &vector {
            assert!(
                (-1.0..=1.0).contains(component),
                "Embedding values should be in [-1, 1]"
            );
        }
    }

    /// `metadata` and `model_info` both reflect the configuration.
    #[tokio::test]
    async fn test_model_metadata() {
        let provider = build_provider("sentence-transformers/test-model", 768).await;

        let metadata = provider.metadata();
        assert_eq!(metadata["model"], "sentence-transformers/test-model");
        assert_eq!(metadata["dimension"], 768);
        assert_eq!(metadata["type"], "local");

        let info = provider.model_info();
        assert_eq!(info["name"], "sentence-transformers/test-model");
        assert_eq!(info["dimension"], 768);
        assert_eq!(info["type"], "local");
    }

    /// An unknown model either yields a working provider or an error that
    /// mentions the model or the load step.
    #[tokio::test]
    async fn test_error_handling() {
        let outcome = LocalEmbeddingProvider::new(LocalConfig::new("nonexistent-model", 384)).await;

        match outcome {
            Ok(provider) => {
                assert!(provider.is_loaded().await);
                let vector = provider.embed_text("test").await.unwrap();
                assert_eq!(vector.len(), 384);
            }
            Err(err) => {
                let message = err.to_string();
                assert!(message.contains("model") || message.contains("load"));
            }
        }
    }

    /// Warmup completes without error.
    #[tokio::test]
    async fn test_warmup_functionality() {
        let provider = build_provider("test-model", 384).await;

        let outcome = provider.warmup().await;
        assert!(outcome.is_ok(), "Warmup should succeed");
    }

    /// The model catalogue is non-empty and every entry is well-formed.
    #[test]
    fn test_utils_list_models() {
        let catalogue = list_available_models();
        assert!(!catalogue.is_empty());

        for entry in catalogue {
            assert!(!entry.model_name.is_empty());
            assert!(entry.embedding_dimension > 0);
        }
    }

    /// Each use case maps to a model with the expected dimension.
    #[test]
    fn test_utils_recommended_models() {
        let expectations = [
            (LocalModelUseCase::Fast, 384),
            (LocalModelUseCase::Quality, 768),
            (LocalModelUseCase::Multilingual, 384),
        ];

        for (use_case, expected_dimension) in expectations {
            let recommended = get_recommended_model(use_case);
            assert_eq!(recommended.embedding_dimension, expected_dimension);
        }
    }

    /// Repeated embeddings of the same text are equal and correctly sized.
    #[tokio::test]
    async fn test_production_warning_behavior() {
        let provider = build_provider("test-model", 384).await;

        let first = provider.embed_text("test").await.unwrap();
        let second = provider.embed_text("test").await.unwrap();

        assert_eq!(first, second);
        assert_eq!(first.len(), 384);
    }
}