use super::config::LocalConfig;
use super::provider::EmbeddingProvider;
use anyhow::{Context, Result};
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Embedding provider backed by a locally-hosted model.
///
/// Depending on the `local-embeddings` feature and whether the ONNX model can
/// actually be loaded, the inner model is either a real ONNX-backed
/// implementation or a deterministic mock fallback (see `load_model`).
pub struct LocalEmbeddingProvider {
    // Model name and embedding dimension supplied by the caller.
    config: LocalConfig,
    // Lazily-installed backend behind an async RwLock; `None` until
    // `load_model` has run. Boxed trait object so the real and mock backends
    // are interchangeable at runtime.
    model: Arc<RwLock<Option<Box<dyn LocalEmbeddingModel>>>>,
    // Directory for cached model files (resolved by `get_cache_dir`).
    cache_dir: std::path::PathBuf,
}
impl LocalEmbeddingProvider {
pub async fn new(config: LocalConfig) -> Result<Self> {
let cache_dir = Self::get_cache_dir()?;
let provider = Self {
config,
model: Arc::new(RwLock::new(None)),
cache_dir,
};
provider.load_model().await?;
Ok(provider)
}
async fn load_model(&self) -> Result<()> {
tracing::info!("Loading local embedding model: {}", self.config.model_name);
#[cfg(feature = "local-embeddings")]
{
match self.try_load_real_model().await {
Ok(real_model) => {
let fallback_model = Box::new(RealEmbeddingModelWithFallback::new(
self.config.model_name.clone(),
self.config.embedding_dimension,
Some(real_model),
));
let mut model_guard = self.model.write().await;
*model_guard = Some(fallback_model);
tracing::info!("Local embedding model loaded with real ONNX backend");
}
Err(e) => {
tracing::warn!("Failed to load real embedding model: {}", e);
tracing::warn!(
"Falling back to mock embeddings - semantic search will not work correctly"
);
let mock_fallback = Box::new(RealEmbeddingModelWithFallback::new(
self.config.model_name.clone(),
self.config.embedding_dimension,
None,
));
let mut model_guard = self.model.write().await;
*model_guard = Some(mock_fallback);
tracing::info!("Local embedding model loaded with mock fallback");
}
}
}
#[cfg(not(feature = "local-embeddings"))]
{
tracing::warn!(
"PRODUCTION WARNING: Using mock embeddings - semantic search will not work correctly"
);
tracing::warn!(
"To enable real embeddings, add 'local-embeddings' feature and ensure ONNX models are available"
);
let mock_fallback = Box::new(super::mock_model::MockLocalModel::new(
self.config.model_name.clone(),
self.config.embedding_dimension,
));
let mut model_guard = self.model.write().await;
*model_guard = Some(mock_fallback);
tracing::info!("Local embedding model loaded with mock implementation");
}
Ok(())
}
#[cfg(feature = "local-embeddings")]
async fn try_load_real_model(&self) -> Result<RealEmbeddingModel> {
RealEmbeddingModel::try_load_from_cache(&self.config, &self.cache_dir).await
}
fn get_cache_dir() -> Result<std::path::PathBuf> {
let home = std::env::var("HOME")
.or_else(|_| std::env::var("USERPROFILE"))
.context("Could not determine home directory")?;
let cache_dir = std::path::Path::new(&home)
.join(".cache")
.join("memory-core")
.join("embeddings");
std::fs::create_dir_all(&cache_dir).context("Failed to create cache directory")?;
Ok(cache_dir)
}
pub async fn is_loaded(&self) -> bool {
let model_guard = self.model.read().await;
model_guard.is_some()
}
#[must_use]
pub fn model_info(&self) -> serde_json::Value {
serde_json::json!({
"name": self.config.model_name,
"dimension": self.config.embedding_dimension,
"type": "local",
"cache_dir": self.cache_dir,
})
}
}
#[async_trait]
impl EmbeddingProvider for LocalEmbeddingProvider {
    /// Embeds a single text using the loaded backend.
    async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
        let guard = self.model.read().await;
        let backend = guard.as_ref().context("Model not loaded")?;
        backend.embed(text).await
    }

    /// Embeds a batch of texts, one vector per input.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let guard = self.model.read().await;
        guard
            .as_ref()
            .context("Model not loaded")?
            .embed_batch(texts)
            .await
    }

    /// Dimension of the vectors this provider produces.
    fn embedding_dimension(&self) -> usize {
        self.config.embedding_dimension
    }

    /// Configured model identifier.
    fn model_name(&self) -> &str {
        self.config.model_name.as_str()
    }

    /// Available exactly when a backend has been installed.
    async fn is_available(&self) -> bool {
        self.is_loaded().await
    }

    /// Primes the backend by embedding a tiny fixed input.
    async fn warmup(&self) -> Result<()> {
        self.embed_text("warmup test").await.map(|_| ())
    }

    /// Provider metadata for diagnostics/telemetry.
    fn metadata(&self) -> serde_json::Value {
        serde_json::json!({
            "model": self.model_name(),
            "dimension": self.embedding_dimension(),
            "type": "local",
            "provider": "sentence-transformers",
            "cache_dir": self.cache_dir
        })
    }
}
/// Abstraction over a locally-running embedding backend (real ONNX or mock).
#[async_trait]
pub trait LocalEmbeddingModel: Send + Sync {
    /// Produces an embedding vector for `text`.
    async fn embed(&self, text: &str) -> Result<Vec<f32>>;
    /// Produces one embedding vector per input text.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
    /// Human-readable model name.
    #[allow(dead_code)] // part of the backend contract even if currently unused
    fn name(&self) -> &str;
    /// Output vector dimension.
    #[allow(dead_code)] // part of the backend contract even if currently unused
    fn dimension(&self) -> usize;
}
#[cfg(feature = "local-embeddings")]
#[allow(unused)]
pub use crate::embeddings::real_model::RealEmbeddingModel;
#[cfg(feature = "local-embeddings")]
#[allow(unused)]
pub use crate::embeddings::mock_model::{MockLocalModel, RealEmbeddingModelWithFallback};
#[allow(unused)]
pub use crate::embeddings::utils::{
LocalModelUseCase, get_recommended_model, list_available_models,
};
#[cfg(test)]
mod tests {
    //! Tests exercise the provider through whatever backend `load_model`
    //! installs; several assertions (determinism, normalization) rely on that
    //! backend being deterministic — presumably the mock, when no real ONNX
    //! model is present. TODO confirm against the mock implementation.
    use super::*;

    // Provider construction succeeds and reports the configured model.
    #[tokio::test]
    async fn test_local_provider_creation() {
        let config = LocalConfig::new("test-model", 384);
        let provider = LocalEmbeddingProvider::new(config).await.unwrap();
        assert!(provider.is_loaded().await);
        assert_eq!(provider.embedding_dimension(), 384);
        assert_eq!(provider.model_name(), "test-model");
    }

    // Same input -> same vector; different input -> different vector.
    #[tokio::test]
    async fn test_embed_text() {
        let config = LocalConfig::new("test-model", 384);
        let provider = LocalEmbeddingProvider::new(config).await.unwrap();
        let embedding = provider.embed_text("Hello world").await.unwrap();
        assert_eq!(embedding.len(), 384);
        let embedding2 = provider.embed_text("Hello world").await.unwrap();
        assert_eq!(embedding, embedding2);
        let embedding3 = provider.embed_text("Different text").await.unwrap();
        assert_ne!(embedding, embedding3);
    }

    // Batch embedding returns one correctly-sized vector per input.
    #[tokio::test]
    async fn test_embed_batch() {
        let config = LocalConfig::new("test-model", 384);
        let provider = LocalEmbeddingProvider::new(config).await.unwrap();
        let texts = vec![
            "First text".to_string(),
            "Second text".to_string(),
            "Third text".to_string(),
        ];
        let embeddings = provider.embed_batch(&texts).await.unwrap();
        assert_eq!(embeddings.len(), 3);
        for embedding in embeddings {
            assert_eq!(embedding.len(), 384);
        }
    }

    // `similarity` is not defined in this file — presumably a provided method
    // on `EmbeddingProvider`; verify against the trait definition.
    #[tokio::test]
    async fn test_similarity_calculation() {
        let config = LocalConfig::new("test-model", 384);
        let provider = LocalEmbeddingProvider::new(config).await.unwrap();
        // Identical texts should be maximally similar (cosine ~= 1.0).
        let similarity = provider
            .similarity("Hello world", "Hello world")
            .await
            .unwrap();
        assert!((similarity - 1.0).abs() < 0.001);
        let similarity = provider
            .similarity("Hello world", "Goodbye universe")
            .await
            .unwrap();
        assert!(similarity < 1.0);
    }

    #[tokio::test]
    #[ignore = "Requires local-embeddings feature with ONNX models - blocked by ort crate Send trait issue"]
    #[cfg(feature = "local-embeddings")]
    async fn test_real_embedding_generation() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        // NOTE(review): `cache_path` is inside a fresh TempDir and is never
        // passed to the provider (which uses its own `~/.cache` dir), so the
        // `exists()` check below can never be true — and as written it skips
        // when model files *do* exist, contradicting the "no model files
        // available" message. Confirm intent; likely `!cache_path.exists()`
        // or a check against the provider's real cache dir was meant.
        let cache_path = temp_dir.path().join("models");
        if cache_path.exists() || std::env::var("CI").is_ok() {
            tracing::info!("Skipping real embedding test - no model files available");
            return;
        }
        let config = LocalConfig::new("sentence-transformers/all-MiniLM-L6-v2", 384);
        let provider = LocalEmbeddingProvider::new(config).await.unwrap();
        let embedding1 = provider
            .embed_text("machine learning algorithms")
            .await
            .unwrap();
        let embedding2 = provider
            .embed_text("artificial intelligence models")
            .await
            .unwrap();
        let embedding3 = provider
            .embed_text("cooking recipes for pasta")
            .await
            .unwrap();
        assert_eq!(embedding1.len(), 384);
        assert_eq!(embedding2.len(), 384);
        assert_eq!(embedding3.len(), 384);
        // Semantically related texts should score higher than unrelated ones.
        let similarity_ai_ml = provider
            .similarity("machine learning", "artificial intelligence")
            .await
            .unwrap();
        let similarity_cooking = provider
            .similarity("machine learning", "cooking recipes")
            .await
            .unwrap();
        assert!(
            similarity_ai_ml > similarity_cooking,
            "AI/ML similarity ({similarity_ai_ml}) should be higher than ML/cooking ({similarity_cooking})"
        );
        assert!(similarity_ai_ml > 0.0);
        assert!(similarity_cooking > 0.0);
    }

    // Vectors are expected to be L2-normalized (unit length), which also
    // bounds every component to [-1, 1].
    #[tokio::test]
    async fn test_embedding_vector_properties() {
        let config = LocalConfig::new("test-model", 384);
        let provider = LocalEmbeddingProvider::new(config).await.unwrap();
        let embedding = provider.embed_text("test text").await.unwrap();
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001, "Embedding should be normalized");
        for &value in &embedding {
            assert!(
                (-1.0..=1.0).contains(&value),
                "Embedding values should be in [-1, 1]"
            );
        }
    }

    // `metadata()` (trait) and `model_info()` (inherent) should agree on the
    // configured name/dimension, under their respective key names.
    #[tokio::test]
    async fn test_model_metadata() {
        let config = LocalConfig::new("sentence-transformers/test-model", 768);
        let provider = LocalEmbeddingProvider::new(config).await.unwrap();
        let metadata = provider.metadata();
        assert_eq!(metadata["model"], "sentence-transformers/test-model");
        assert_eq!(metadata["dimension"], 768);
        assert_eq!(metadata["type"], "local");
        let model_info = provider.model_info();
        assert_eq!(model_info["name"], "sentence-transformers/test-model");
        assert_eq!(model_info["dimension"], 768);
        assert_eq!(model_info["type"], "local");
    }

    // A nonexistent model name may either succeed (mock fallback installed)
    // or fail with a load-related error — both outcomes are accepted here.
    #[tokio::test]
    async fn test_error_handling() {
        let config = LocalConfig::new("nonexistent-model", 384);
        let result = LocalEmbeddingProvider::new(config).await;
        match result {
            Ok(provider) => {
                assert!(provider.is_loaded().await);
                let embedding = provider.embed_text("test").await.unwrap();
                assert_eq!(embedding.len(), 384);
            }
            Err(e) => {
                assert!(e.to_string().contains("model") || e.to_string().contains("load"));
            }
        }
    }

    // `warmup` should succeed whenever a backend is installed.
    #[tokio::test]
    async fn test_warmup_functionality() {
        let config = LocalConfig::new("test-model", 384);
        let provider = LocalEmbeddingProvider::new(config).await.unwrap();
        let result = provider.warmup().await;
        assert!(result.is_ok(), "Warmup should succeed");
    }

    // Model catalog helpers from `embeddings::utils` (re-exported above).
    #[test]
    fn test_utils_list_models() {
        let models = list_available_models();
        assert!(!models.is_empty());
        for model in models {
            assert!(!model.model_name.is_empty());
            assert!(model.embedding_dimension > 0);
        }
    }

    // Recommended models per use case: dimensions asserted here mirror the
    // catalog in `embeddings::utils`.
    #[test]
    fn test_utils_recommended_models() {
        let fast_model = get_recommended_model(LocalModelUseCase::Fast);
        assert_eq!(fast_model.embedding_dimension, 384);
        let quality_model = get_recommended_model(LocalModelUseCase::Quality);
        assert_eq!(quality_model.embedding_dimension, 768);
        let multilingual_model = get_recommended_model(LocalModelUseCase::Multilingual);
        assert_eq!(multilingual_model.embedding_dimension, 384);
    }

    // Even on the warned mock path, embeddings must stay deterministic and
    // correctly sized so downstream code keeps working.
    #[tokio::test]
    async fn test_production_warning_behavior() {
        let config = LocalConfig::new("test-model", 384);
        let provider = LocalEmbeddingProvider::new(config).await.unwrap();
        let embedding1 = provider.embed_text("test").await.unwrap();
        let embedding2 = provider.embed_text("test").await.unwrap();
        assert_eq!(embedding1, embedding2);
        assert_eq!(embedding1.len(), 384);
    }
}