use anyhow::{Result, anyhow};
use reqwest;
use serde::{Serialize, Deserialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::{RwLock, Mutex};
use tracing::{info, warn, debug, error};
use hf_hub::api::tokio::Api;
use tokio::fs;
use std::pin::Pin;
use std::future::Future;
use candle_core::{Device, Tensor, DType};
use candle_transformers::models::bert::{BertModel, Config};
use candle_nn::VarBuilder;
use tokenizers::Tokenizer;
use crate::types::EmbeddingConfig;
/// JSON body POSTed to `{service_url}/embeddings` by the HTTP embedding path.
#[derive(Debug, Serialize)]
struct EmbeddingRequest {
/// Texts to embed; the HTTP path hands back the response vectors in the order received.
texts: Vec<String>,
/// Model identifier copied from the service configuration.
model: String,
}
/// JSON reply expected from the HTTP embedding service.
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
/// Embedding vectors; returned to the caller as-is.
embeddings: Vec<Vec<f32>>,
/// Model name reported by the service (not read by the visible code).
model: String,
/// Vector length reported by the service; copied into the service's `dimensions` state.
dimensions: usize,
}
/// Static registry entry describing a model the service knows about.
#[derive(Debug, Clone, Serialize)]
pub struct ModelConfig {
/// Declared embedding size; reported by the info APIs.
pub dimensions: usize,
/// Model family label; only reported in info output (loading always builds a `BertModel`).
pub model_type: String,
/// Human-readable description surfaced by the info APIs.
pub description: String,
/// Hugging Face Hub repository id used when downloading the model.
pub hf_model_id: String, }
/// Selects where `EmbeddingService::generate_embeddings` gets its vectors from.
#[derive(Debug, Clone)]
pub enum ModelStrategy {
    /// Use the local model; on failure fall back to the HTTP service when one is configured.
    LocalFirst,
    /// Always call the HTTP embedding service.
    HttpOnly,
    /// Always use the local model (or its hash-based fallback when no model is loaded).
    LocalOnly,
}
/// A model resident in memory and ready for inference.
pub struct LoadedModel {
/// BERT encoder built from the checkpoint weights and its `config.json`.
pub model: BertModel,
/// Tokenizer read from the checkpoint's `tokenizer.json`.
pub tokenizer: Tokenizer,
/// Device the weights live on; inference builds its input tensors on this device.
pub device: Device,
/// Registry entry of the model these weights were loaded for.
pub config: ModelConfig,
}
/// Produces text embeddings from a local candle BERT model, an HTTP service, or a
/// deterministic hash-based fallback, depending on `strategy` and what is loaded.
pub struct EmbeddingService {
/// Configuration supplied to `new`.
config: EmbeddingConfig,
/// Shared HTTP client for the embedding service and its health endpoint.
client: reqwest::Client,
/// Current embedding size; seeded from config, overwritten by HTTP replies and `set_dimensions`.
dimensions: Arc<RwLock<Option<usize>>>,
/// Root directory holding `models--<org>--<name>` model folders.
models_path: PathBuf,
/// Registry of known models, keyed by Hugging Face id.
supported_models: HashMap<String, ModelConfig>,
/// Name of the model considered active (also set when only fallback embeddings are available).
current_model: Arc<RwLock<Option<String>>>,
/// Local/HTTP selection policy chosen in `new`.
strategy: ModelStrategy,
/// Set by `initialize`, cleared by `shutdown`.
initialized: Arc<RwLock<bool>>,
/// Model in memory; `None` means local generation uses hash-based fallback vectors.
loaded_model: Arc<RwLock<Option<LoadedModel>>>,
/// Text -> embedding cache (presumably managed by `get_cached_embedding`/`cache_embedding`,
/// defined outside this view — confirm).
embedding_cache: Arc<RwLock<HashMap<String, Vec<f32>>>>,
/// Forces CPU loading after a Metal/device error; reset by `shutdown` and `reset_state`.
force_cpu: Arc<RwLock<bool>>,
/// Held for the whole of each forward pass so inferences never overlap.
inference_lock: Arc<Mutex<()>>,
}
impl EmbeddingService {
pub async fn new(config: &EmbeddingConfig, models_path: &Path) -> Result<Self> {
let client = reqwest::Client::new();
let mut supported_models = HashMap::new();
supported_models.insert(
"BAAI/bge-m3".to_string(),
ModelConfig {
dimensions: 1024,
model_type: "bert".to_string(),
description: "BGE-M3 multilingual embedding model (1024 dimensions)".to_string(),
hf_model_id: "BAAI/bge-m3".to_string(),
},
);
supported_models.insert(
"embaas/sentence-transformers-e5-large-v2".to_string(),
ModelConfig {
dimensions: 1024,
model_type: "sentence-transformers".to_string(),
description: "High-quality E5-Large-v2 embedding model (1024 dimensions)".to_string(),
hf_model_id: "embaas/sentence-transformers-e5-large-v2".to_string(),
},
);
supported_models.insert(
"sentence-transformers/all-MiniLM-L6-v2".to_string(),
ModelConfig {
dimensions: 384,
model_type: "sentence-transformers".to_string(),
description: "Compact multilingual model (384 dimensions)".to_string(),
hf_model_id: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
},
);
supported_models.insert(
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2".to_string(),
ModelConfig {
dimensions: 384,
model_type: "sentence-transformers".to_string(),
description: "Multilingual paraphrase model (384 dimensions)".to_string(),
hf_model_id: "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2".to_string(),
},
);
let strategy = if config.service_url.is_some() {
ModelStrategy::LocalFirst } else {
ModelStrategy::LocalOnly };
Ok(Self {
config: config.clone(),
client,
dimensions: Arc::new(RwLock::new(Some(config.dimensions))),
models_path: models_path.to_path_buf(),
supported_models,
current_model: Arc::new(RwLock::new(None)),
strategy,
initialized: Arc::new(RwLock::new(false)),
loaded_model: Arc::new(RwLock::new(None)),
embedding_cache: Arc::new(RwLock::new(HashMap::new())),
force_cpu: Arc::new(RwLock::new(false)),
inference_lock: Arc::new(Mutex::new(())),
})
}
/// Chooses the inference device.
///
/// Returns CPU immediately when `force_cpu` is set. Otherwise it tries each GPU backend
/// compiled in (CUDA, then Metal, each behind its cargo feature) and falls back to CPU.
fn detect_best_device(force_cpu: bool) -> Device {
    if force_cpu {
        info!("💻 Force CPU mode enabled - using CPU for inference");
        return Device::Cpu;
    }

    #[cfg(feature = "cuda")]
    {
        match Device::new_cuda(0) {
            Err(err) => warn!("⚠️ CUDA GPU not available: {}", err),
            Ok(gpu) => {
                info!("✅ CUDA GPU detected and available");
                return gpu;
            }
        }
    }

    #[cfg(feature = "metal")]
    {
        match Device::new_metal(0) {
            Err(err) => warn!("⚠️ Metal GPU not available: {}", err),
            Ok(gpu) => {
                info!("🔄 Testing Metal GPU availability...");
                // A tiny allocation proves the device is actually usable before committing to it.
                match Tensor::zeros((2, 2), DType::F32, &gpu) {
                    Err(err) => {
                        warn!("⚠️ Metal GPU test failed: {}", err);
                        warn!(" Falling back to CPU for stability");
                    }
                    Ok(_probe) => {
                        info!("✅ Metal GPU verified and enabled");
                        info!(" 🚀 GPU acceleration active for embedding generation");
                        return gpu;
                    }
                }
            }
        }
    }

    info!("💻 Using CPU for inference");
    Device::Cpu
}
/// One-time startup: ensures the models directory exists and loads the configured model
/// if its files are present.
///
/// Idempotent: later calls return `Ok(())` immediately. A missing or unloadable model is
/// not an error. `current_model` is still set, and local generation then uses hash-based
/// fallback vectors.
///
/// # Errors
/// Only if the models directory cannot be created.
pub async fn initialize(&self) -> Result<()> {
    // The write guard is held for the whole body, so concurrent callers wait and then see
    // `initialized == true` instead of initializing twice.
    let mut initialized = self.initialized.write().await;
    if *initialized {
        return Ok(());
    }
    info!("Initializing Embedding Service");
    tokio::fs::create_dir_all(&self.models_path).await?;
    debug!("Models directory created: {:?}", self.models_path);

    use crate::utils::model_setup::ModelSetup;
    let model_setup = ModelSetup::new(self.models_path.clone());
    let model_name = self.config.model.clone();

    let model_ready = if model_name == "embaas/sentence-transformers-e5-large-v2" {
        match model_setup.check_model_exists() {
            Ok(true) => {
                info!("✅ embaas/sentence-transformers-e5-large-v2 model is ready");
                true
            }
            Ok(false) => {
                info!("📦 embaas/sentence-transformers-e5-large-v2 model not found in models directory");
                info!(" The application should provide model files in: {:?}", self.models_path);
                warn!("⚠️ Model files not found. RAG will use fallback embeddings.");
                false
            }
            Err(e) => {
                warn!("⚠️ Model check failed: {}. Using fallback embeddings.", e);
                false
            }
        }
    } else {
        // Any other configured model is checked on disk directly. Previously it was never
        // loaded and always fell back to hash embeddings.
        match self.is_model_downloaded(&model_name).await {
            Ok(true) => {
                info!("✅ {} model is ready", model_name);
                true
            }
            Ok(false) => {
                info!("📦 {} model not found in models directory", model_name);
                info!(" The application should provide model files in: {:?}", self.models_path);
                warn!("⚠️ Model files not found. RAG will use fallback embeddings.");
                false
            }
            Err(e) => {
                warn!("⚠️ Model check failed: {}. Using fallback embeddings.", e);
                false
            }
        }
    };

    if model_ready {
        info!("Loading {} model into memory...", model_name);
        match self.load_model(&model_name).await {
            Ok(()) => {
                info!("✅ Model loaded successfully and ready for embeddings");
                // Make sure an active model is recorded even if the load path did not set one.
                let mut current_model = self.current_model.write().await;
                if current_model.is_none() {
                    *current_model = Some(model_name);
                }
            }
            Err(e) => {
                warn!("⚠️ Model file loading failed: {}. Using fallback embeddings.", e);
                *self.current_model.write().await = Some(model_name);
            }
        }
    } else {
        info!("Using fallback embeddings (model not available)");
        *self.current_model.write().await = Some(model_name);
    }

    *initialized = true;
    Ok(())
}
/// Loads `model_name` into memory and records the model actually loaded as `current_model`.
///
/// Requesting `embaas/sentence-transformers-e5-large-v2` tries, in order: e5-large-v2,
/// `BAAI/bge-m3`, then `sentence-transformers/all-MiniLM-L6-v2`. The first that loads wins.
/// A fallback's dimensions may differ from the configured ones; that case is logged.
///
/// # Errors
/// The model is not in the registry, or no candidate could be loaded.
pub async fn load_model(&self, model_name: &str) -> Result<()> {
    const PRIMARY_MODEL: &str = "embaas/sentence-transformers-e5-large-v2";
    const PRIMARY_FALLBACK_CHAIN: [&str; 3] = [
        PRIMARY_MODEL,
        "BAAI/bge-m3",
        "sentence-transformers/all-MiniLM-L6-v2",
    ];

    info!("🔄 Loading embedding model: {}", model_name);
    if !self.supported_models.contains_key(model_name) {
        let supported: Vec<_> = self.supported_models.keys().collect();
        return Err(anyhow!(
            "Unsupported model: {}. Supported models: {:?}",
            model_name,
            supported
        ));
    }

    let loaded_name: &str = if model_name == PRIMARY_MODEL {
        info!(" 🔄 Loading embaas/sentence-transformers-e5-large-v2 model");
        let mut loaded = None;
        for &variant in PRIMARY_FALLBACK_CHAIN.iter() {
            info!(" 🔄 Attempting to load: {}", variant);
            match self.try_load_single_model(variant).await {
                Ok(()) => {
                    if variant == PRIMARY_MODEL {
                        info!(" ✅ Successfully loaded primary embaas e5-large-v2 model: {}", variant);
                    } else {
                        warn!(" ⚠️ {} not available, using fallback: {}", PRIMARY_MODEL, variant);
                    }
                    loaded = Some(variant);
                    break;
                }
                Err(e) => warn!(" ❌ Failed to load {}: {}", variant, e),
            }
        }
        loaded.ok_or_else(|| anyhow!("All embaas e5-large-v2 fallback models failed to load"))?
    } else {
        self.try_load_single_model(model_name).await?;
        model_name
    };

    let loaded_dims = self
        .supported_models
        .get(loaded_name)
        .map_or(0, |cfg| cfg.dimensions);
    if loaded_name != model_name && loaded_dims != self.config.dimensions {
        warn!(
            " 📏 Fallback {} produces {}D embeddings, but {}D are configured",
            loaded_name, loaded_dims, self.config.dimensions
        );
    }

    // Record what is really in memory, so CPU reloads (Metal recovery, reset_state) target
    // model files that exist. This is also set for non-primary models, which previously left
    // `current_model` empty.
    *self.current_model.write().await = Some(loaded_name.to_string());
    info!("Model loaded successfully: {} -> {} ({}D)", model_name, loaded_name, loaded_dims);
    Ok(())
}
/// Loads exactly one model from the local models directory; never downloads.
///
/// # Errors
/// The model's files are not present locally, or loading them for inference fails.
async fn try_load_single_model(&self, model_name: &str) -> Result<()> {
    let present_locally = self.is_model_downloaded(model_name).await?;
    if !present_locally {
        info!(" 📦 Model {} not found locally", model_name);
        info!(" ⚠️ Auto-download is disabled. The application must provide model files.");
        return Err(anyhow!("Model {} not found in models directory. Please ensure the application provides model files in the models directory.", model_name));
    }
    info!(" 📁 Model {} found locally", model_name);

    self.load_model_for_inference(model_name).await?;
    info!(" 🧠 Model {} loaded for inference", model_name);
    Ok(())
}
async fn load_model_for_inference(&self, model_name: &str) -> Result<()> {
let model_config = self.supported_models.get(model_name)
.ok_or_else(|| anyhow!("Model config not found for {}", model_name))?;
let model_cache_dir = self.models_path.join(format!("models--{}", model_name.replace('/', "--")));
let force_cpu = *self.force_cpu.read().await;
let device = Self::detect_best_device(force_cpu);
info!("🖥️ Using device: {:?}", device);
let tokenizer_path = model_cache_dir.join("tokenizer.json");
if !tokenizer_path.exists() {
error!(" ❌ Tokenizer file missing: {:?}", tokenizer_path);
if let Ok(mut entries) = fs::read_dir(&model_cache_dir).await {
let mut files = Vec::new();
while let Some(entry) = entries.next_entry().await? {
files.push(entry.file_name().to_string_lossy().to_string());
}
error!(" 📁 Model directory contents: {:?}", files);
}
return Err(anyhow!("Tokenizer not found at {:?}", tokenizer_path));
}
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow!("Failed to load tokenizer: {}", e))?;
let config_path = model_cache_dir.join("config.json");
let config_content = fs::read_to_string(&config_path).await?;
let bert_config: Config = serde_json::from_str(&config_content)?;
let safetensors_path = model_cache_dir.join("model.safetensors");
let pytorch_path = model_cache_dir.join("pytorch_model.bin");
let model_format = if safetensors_path.exists() {
"safetensors"
} else if pytorch_path.exists() {
"pytorch"
} else {
return Err(anyhow!("No model weights found. Checked for: {:?}, {:?}",
safetensors_path, pytorch_path));
};
info!(" 🔍 Using {} format model weights", model_format);
let vb = match model_format {
"pytorch" => {
info!(" 📥 Loading PyTorch weights from: {:?}", pytorch_path);
unsafe { VarBuilder::from_pth(&pytorch_path, DType::F32, &device) }
.map_err(|e| anyhow!("Failed to load PyTorch weights: {}", e))?
},
"safetensors" => {
info!(" 📥 Loading SafeTensors from: {:?}", safetensors_path);
unsafe { VarBuilder::from_pth(&safetensors_path, DType::F32, &device) }
.map_err(|e| anyhow!("Failed to load SafeTensors: {}", e))?
},
_ => {
return Err(anyhow!("Unsupported model format: {}", model_format));
}
};
let bert_model = BertModel::load(vb, &bert_config)
.map_err(|e| anyhow!("Failed to load BERT model: {}", e))?;
let loaded_model = LoadedModel {
model: bert_model,
tokenizer,
device,
config: model_config.clone(),
};
let mut model_lock = self.loaded_model.write().await;
*model_lock = Some(loaded_model);
Ok(())
}
/// Downloads `model_name` from the Hugging Face Hub into
/// `<models_path>/models--<org>--<name>/` and writes a `model_info.json` manifest there.
///
/// Files already present are skipped. `pytorch_model.bin` is also skipped once
/// `model.safetensors` exists, because both hold the same weights and the loader prefers
/// safetensors. Individual download failures are logged and tolerated. The call fails only
/// if a critical file is still missing afterwards.
///
/// # Errors
/// - Unknown model.
/// - Directory creation, copy, or HF API initialization failure.
/// - `config.json`, `tokenizer.json`, or all weight files missing after download.
/// - Failure writing the manifest.
async fn download_model_from_hf(&self, model_name: &str) -> Result<()> {
    const ESSENTIAL_FILES: [&str; 5] = [
        "config.json",
        "tokenizer.json",
        "tokenizer_config.json",
        "model.safetensors",
        "pytorch_model.bin",
    ];

    let model_config = self.supported_models.get(model_name)
        .ok_or_else(|| anyhow!("Model config not found for {}", model_name))?;
    let model_cache_dir = self.models_path.join(format!("models--{}", model_name.replace('/', "--")));

    if let Err(e) = fs::create_dir_all(&model_cache_dir).await {
        error!(" ❌ Failed to create model cache directory {:?}: {:?}", model_cache_dir, e);
        error!(" 🔍 Check permissions for directory creation");
        return Err(anyhow!("Failed to create model cache directory: {}", e));
    }
    info!(" 📁 Model cache directory ready: {:?}", model_cache_dir);

    info!(" 📥 Downloading {} essential files...", ESSENTIAL_FILES.len());
    let api = Api::new().map_err(|e| anyhow!("Failed to initialize HF API: {}", e))?;
    let repo = api.model(model_config.hf_model_id.clone());
    let safetensors_path = model_cache_dir.join("model.safetensors");

    for &file_name in ESSENTIAL_FILES.iter() {
        let file_path = model_cache_dir.join(file_name);
        if file_path.exists() {
            debug!(" ✓ {} already exists", file_name);
            continue;
        }
        if file_name == "pytorch_model.bin" && safetensors_path.exists() {
            debug!(" ⏭️ Skipping {} (model.safetensors already present)", file_name);
            continue;
        }

        info!(" ⬇️ Downloading {}...", file_name);
        let file_path_from_hf = match repo.get(file_name).await {
            Ok(path) => path,
            Err(e) => {
                // Optional files (or the unused weight format) may be absent upstream;
                // the critical-file check below decides whether this is fatal.
                warn!(" ⚠️ Failed to download {}: {}", file_name, e);
                continue;
            }
        };
        match tokio::fs::copy(&file_path_from_hf, &file_path).await {
            Ok(bytes_copied) => {
                debug!(" ✅ Downloaded {} ({} bytes)", file_name, bytes_copied);
                if !file_path.exists() {
                    error!(" ❌ File copy appeared to succeed but file doesn't exist: {:?}", file_path);
                    return Err(anyhow!("File copy verification failed for {}", file_name));
                }
            }
            Err(e) => {
                error!(" ❌ Failed to copy {} from HF cache to local directory: {:?}", file_name, e);
                error!(" 🔍 Source: {:?}", file_path_from_hf);
                error!(" 🔍 Target: {:?}", file_path);
                error!(" 📁 Target directory: {:?}", model_cache_dir);
                error!(" 🔍 Check permissions and disk space");
                return Err(anyhow!("Failed to copy {} from HF cache to local directory: {}", file_name, e));
            }
        }
    }

    if !model_cache_dir.join("config.json").exists() {
        return Err(anyhow!("Critical file missing: config.json was not downloaded"));
    }
    if !model_cache_dir.join("tokenizer.json").exists() {
        return Err(anyhow!("Critical file missing: tokenizer.json was not downloaded"));
    }
    let has_weights = safetensors_path.exists()
        || model_cache_dir.join("pytorch_model.bin").exists();
    if !has_weights {
        return Err(anyhow!("Critical file missing: no model weight files (model.safetensors or pytorch_model.bin) were downloaded"));
    }

    let model_info = serde_json::json!({
        "model_name": model_name,
        "hf_model_id": model_config.hf_model_id,
        "dimensions": model_config.dimensions,
        "model_type": model_config.model_type,
        "downloaded_at": chrono::Utc::now().to_rfc3339(),
        "cached_by": "rag-module-rust"
    });
    let info_content = serde_json::to_string_pretty(&model_info)?;
    fs::write(model_cache_dir.join("model_info.json"), info_content).await?;

    info!(" 🎉 Model {} successfully downloaded and cached", model_name);
    debug!(" 📁 Model files saved to: {:?}", model_cache_dir);
    Ok(())
}
/// Reports whether `model_name`'s directory holds everything needed to load it:
/// `config.json`, `tokenizer.json`, and at least one weight file
/// (`model.safetensors` or `pytorch_model.bin`).
pub async fn is_model_downloaded(&self, model_name: &str) -> Result<bool> {
    let model_dir = self
        .models_path
        .join(format!("models--{}", model_name.replace('/', "--")));
    if !model_dir.exists() {
        return Ok(false);
    }

    let present = |file: &str| model_dir.join(file).exists();
    let has_weights = present("model.safetensors") || present("pytorch_model.bin");
    Ok(has_weights && present("config.json") && present("tokenizer.json"))
}
/// Describes the active model as JSON.
///
/// Returns `{"name": null, "loaded": false, ...}` when there is no current model, or when the
/// current model is not in the registry. The reported dimensions prefer the live `dimensions`
/// state over the registry value.
pub async fn get_model_info(&self) -> Result<serde_json::Value> {
    let strategy = format!("{:?}", self.strategy);
    let current_model = self.current_model.read().await;
    let active = current_model
        .as_deref()
        .and_then(|name| self.supported_models.get(name).map(|cfg| (name, cfg)));

    match active {
        Some((name, cfg)) => {
            let live_dimensions = *self.dimensions.read().await;
            Ok(serde_json::json!({
                "name": name,
                "dimensions": live_dimensions.unwrap_or(cfg.dimensions),
                "type": cfg.model_type,
                "description": cfg.description,
                "loaded": true,
                "strategy": strategy
            }))
        }
        None => Ok(serde_json::json!({
            "name": null,
            "loaded": false,
            "strategy": strategy
        })),
    }
}
/// Borrows the registry of known models, keyed by Hugging Face id.
pub fn get_supported_models(&self) -> &HashMap<String, ModelConfig> {
    let registry = &self.supported_models;
    registry
}
/// Summarizes on-disk usage per registered model as JSON.
///
/// Models that are not fully downloaded report a size of 0. Per-model failures are treated
/// as "not downloaded" / 0 bytes rather than errors.
pub async fn get_storage_info(&self) -> Result<serde_json::Value> {
    let mut entries = Vec::with_capacity(self.supported_models.len());
    let mut grand_total: u64 = 0;

    for (name, cfg) in self.supported_models.iter() {
        let model_dir = self.models_path.join(format!("models--{}", name.replace('/', "--")));
        let downloaded = self.is_model_downloaded(name).await.unwrap_or(false);
        let bytes = if downloaded {
            Self::calculate_dir_size(&model_dir).await.unwrap_or(0)
        } else {
            0
        };
        grand_total += bytes;

        entries.push(serde_json::json!({
            "name": name,
            "dimensions": cfg.dimensions,
            "type": cfg.model_type,
            "description": cfg.description,
            "downloaded": downloaded,
            "size": bytes,
            "sizeFormatted": Self::format_file_size(bytes)
        }));
    }

    Ok(serde_json::json!({
        "models": entries,
        "totalSize": grand_total,
        "totalSizeFormatted": Self::format_file_size(grand_total),
        "modelsPath": self.models_path
    }))
}
/// Recursively sums the sizes of files under `dir`; returns 0 if `dir` is not a directory.
///
/// The future is boxed because an async function cannot recurse without indirection.
/// Symlinked directories are not followed: a link cycle would otherwise be walked
/// repeatedly and inflate the total. Symlinked files count at their target's size.
/// Unreadable entries and subdirectories contribute 0.
///
/// # Errors
/// Only if `dir` itself cannot be listed or iterating its entries fails.
fn calculate_dir_size(dir: &Path) -> Pin<Box<dyn Future<Output = Result<u64>> + Send>> {
    let dir = dir.to_path_buf();
    Box::pin(async move {
        let is_dir = tokio::fs::metadata(&dir)
            .await
            .map(|meta| meta.is_dir())
            .unwrap_or(false);
        if !is_dir {
            return Ok(0);
        }

        let mut total_size = 0u64;
        let mut entries = tokio::fs::read_dir(&dir).await?;
        while let Some(entry) = entries.next_entry().await? {
            let path = entry.path();
            // `DirEntry::file_type` does not traverse symlinks, unlike `Path::is_dir`.
            let file_type = match entry.file_type().await {
                Ok(file_type) => file_type,
                Err(_) => continue,
            };
            if file_type.is_dir() {
                total_size += Self::calculate_dir_size(&path).await.unwrap_or(0);
            } else if let Ok(metadata) = tokio::fs::metadata(&path).await {
                // Follows symlinks: a link to a file counts as that file, while a link to a
                // directory is skipped rather than walked.
                if !metadata.is_dir() {
                    total_size += metadata.len();
                }
            }
        }
        Ok(total_size)
    })
}
/// Renders a byte count with two decimals in the largest unit (B, KB, MB, GB) that keeps the
/// value at least 1, using 1024-based units. `0` renders as `"0 B"`; GB is the largest unit.
fn format_file_size(bytes: u64) -> String {
    const UNITS: [&str; 4] = ["B", "KB", "MB", "GB"];
    if bytes == 0 {
        return "0 B".to_string();
    }

    // Climb while the next unit (1024^(exponent+1) == 1 << 10*(exponent+1)) still fits.
    let mut exponent = 0usize;
    while exponent + 1 < UNITS.len() && bytes >= 1u64 << (10 * (exponent + 1)) {
        exponent += 1;
    }

    let divisor = (1u64 << (10 * exponent)) as f64;
    format!("{:.2} {}", bytes as f64 / divisor, UNITS[exponent])
}
/// Embeds a single text by delegating to `generate_embeddings` with a one-element batch.
///
/// # Errors
/// Propagates generation errors; fails if the batch comes back empty.
pub async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
    let mut batch = self.generate_embeddings(&[text]).await?;
    if batch.is_empty() {
        Err(anyhow!("No embedding returned"))
    } else {
        Ok(batch.swap_remove(0))
    }
}
/// Embeds `texts` according to the configured strategy.
///
/// - Empty input yields an empty result.
/// - A configured dimension of 1 short-circuits to `[0.1]` per text.
/// - `LocalFirst` falls back to the HTTP service when local generation fails and a service
///   URL is configured.
///
/// # Errors
/// Errors keep their underlying causes. If both backends fail, both messages are included.
pub async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
    if texts.is_empty() {
        return Ok(Vec::new());
    }
    if self.config.dimensions == 1 {
        return Ok(texts.iter().map(|_| vec![0.1]).collect());
    }

    match self.strategy {
        ModelStrategy::LocalOnly => self.generate_local_embeddings(texts).await,
        ModelStrategy::HttpOnly => self.generate_http_embeddings(texts).await,
        ModelStrategy::LocalFirst => match self.generate_local_embeddings(texts).await {
            Ok(embeddings) => {
                debug!("Generated embeddings using local model");
                Ok(embeddings)
            }
            Err(local_err) => {
                if self.config.service_url.is_some() {
                    warn!("Local model failed: {}, trying HTTP service", local_err);
                    self.generate_http_embeddings(texts).await.map_err(|http_err| {
                        anyhow!(
                            "Both local model and HTTP service failed (local: {}; http: {})",
                            local_err,
                            http_err
                        )
                    })
                } else {
                    warn!("Local model failed: {}", local_err);
                    Err(local_err.context("Local model failed and no HTTP embedding service is configured"))
                }
            }
        },
    }
}
/// Convenience wrapper over `generate_embeddings` for owned strings.
pub async fn generate_embeddings_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
    if texts.is_empty() {
        return Ok(Vec::new());
    }
    let borrowed: Vec<&str> = texts.iter().map(String::as_str).collect();
    self.generate_embeddings(&borrowed).await
}
/// Boxed entry point for local generation.
///
/// The explicit boxed return type lets `generate_local_embeddings_impl` call back into this
/// method (its CPU-retry path) without an infinitely sized future.
fn generate_local_embeddings<'a>(&'a self, texts: &'a [&'a str]) -> Pin<Box<dyn Future<Output = Result<Vec<Vec<f32>>>> + Send + 'a>> {
    let pending = self.generate_local_embeddings_impl(texts);
    Box::pin(pending)
}
async fn generate_local_embeddings_impl(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
let loaded_model_lock = self.loaded_model.read().await;
if let Some(loaded_model) = loaded_model_lock.as_ref() {
let mut cached_embeddings = Vec::new();
let mut uncached_texts = Vec::new();
let mut uncached_indices = Vec::new();
for (i, &text) in texts.iter().enumerate() {
if let Some(cached_embedding) = self.get_cached_embedding(text).await {
debug!("📋 Cache hit for text {} ({} chars)", i + 1, text.len());
cached_embeddings.push((i, cached_embedding));
} else {
debug!("🔍 Cache miss for text {} ({} chars)", i + 1, text.len());
uncached_texts.push(text);
uncached_indices.push(i);
}
}
info!("📊 Cache stats: {}/{} hits, {} texts need generation",
cached_embeddings.len(), texts.len(), uncached_texts.len());
let mut final_embeddings = vec![Vec::new(); texts.len()];
for (idx, embedding) in cached_embeddings {
final_embeddings[idx] = embedding;
}
if !uncached_texts.is_empty() {
match self.run_inference_with_model(loaded_model, &uncached_texts).await {
Ok(new_embeddings) => {
for (i, embedding) in new_embeddings.into_iter().enumerate() {
let original_idx = uncached_indices[i];
let text = uncached_texts[i];
self.cache_embedding(text, &embedding).await;
final_embeddings[original_idx] = embedding;
}
}
Err(e) => {
if format!("{:?}", e).contains("Metal") || format!("{:?}", e).contains("device mismatch") {
warn!("🔄 Metal error detected, reloading model on CPU...");
drop(loaded_model_lock);
{
let mut force_cpu = self.force_cpu.write().await;
*force_cpu = true;
}
let current_model_name = {
let current_model = self.current_model.read().await;
current_model.clone().ok_or_else(|| anyhow!("No model loaded"))?
};
info!("🔄 Reloading model on CPU...");
self.load_model_for_inference(¤t_model_name).await?;
info!("🔄 Retrying embedding generation with CPU model...");
return self.generate_local_embeddings(texts).await;
} else {
return Err(e);
}
}
}
}
return Ok(final_embeddings);
}
info!("Model not loaded yet, using embaas/sentence-transformers-e5-large-v2 compatible fallback embeddings");
let embeddings = texts
.iter()
.map(|text| {
self.generate_content_aware_embedding(text, 1024) })
.collect();
debug!("Generated {} embaas e5-large-v2 compatible fallback embeddings with 1024 dimensions", texts.len());
Ok(embeddings)
}
/// Deterministic pseudo-embedding derived only from hashing `text`.
///
/// Each component hashes `(text, component index)` with `DefaultHasher::new()` and maps the
/// hash into [-0.5, 0.5). The vector is then scaled to unit L2 norm (unless it is all zeros).
fn generate_content_aware_embedding(&self, text: &str, target_dimensions: usize) -> Vec<f32> {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut vector: Vec<f32> = (0..target_dimensions)
        .map(|component| {
            let mut hasher = DefaultHasher::new();
            text.hash(&mut hasher);
            component.hash(&mut hasher);
            (hasher.finish() % 10000) as f32 / 10000.0 - 0.5
        })
        .collect();

    let magnitude = vector.iter().map(|v| v * v).sum::<f32>().sqrt();
    if magnitude > 0.0 {
        vector.iter_mut().for_each(|v| *v /= magnitude);
    }
    vector
}
async fn run_inference_with_model(&self, loaded_model: &LoadedModel, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
let _inference_guard = self.inference_lock.lock().await;
info!("🔒 Acquired inference lock for {} texts", texts.len());
info!("🚀 Starting TRUE batch embaas e5-large-v2 inference for {} texts", texts.len());
let mut tokenized_texts = Vec::new();
let mut max_length = 0;
for (i, &text) in texts.iter().enumerate() {
info!("🔤 Tokenizing text {}/{}: {} chars", i + 1, texts.len(), text.len());
let encoding = loaded_model.tokenizer
.encode(text, true)
.map_err(|e| anyhow!("Tokenization failed for text {}: {}", i, e))?;
let tokens = encoding.get_ids().to_vec();
let token_type_ids = encoding.get_type_ids().to_vec();
let attention_mask = encoding.get_attention_mask().to_vec();
max_length = max_length.max(tokens.len());
tokenized_texts.push((tokens, token_type_ids, attention_mask));
}
info!("🔢 Max sequence length: {}", max_length);
let batch_size = texts.len();
let mut input_ids_batch = Vec::new();
let mut token_type_ids_batch = Vec::new();
let mut attention_mask_batch = Vec::new();
for (tokens, type_ids, attention_mask) in tokenized_texts {
let mut padded_tokens = tokens.to_vec();
let mut padded_type_ids = type_ids.to_vec();
let mut padded_attention = attention_mask.to_vec();
while padded_tokens.len() < max_length {
padded_tokens.push(0); padded_type_ids.push(0);
padded_attention.push(0);
}
input_ids_batch.extend(padded_tokens.iter().map(|&x| x as i64));
token_type_ids_batch.extend(padded_type_ids.iter().map(|&x| x as i64));
attention_mask_batch.extend(padded_attention.iter().map(|&x| x as f32));
}
let input_ids = Tensor::new(input_ids_batch.as_slice(), &loaded_model.device)?
.reshape(&[batch_size, max_length])?;
let token_type_ids = Tensor::new(token_type_ids_batch.as_slice(), &loaded_model.device)?
.reshape(&[batch_size, max_length])?;
let attention_mask = Tensor::new(attention_mask_batch.as_slice(), &loaded_model.device)?
.reshape(&[batch_size, max_length])?;
info!("🎛️ Created batch tensors - input_ids: {:?}, token_type_ids: {:?}, attention_mask: {:?}",
input_ids.shape(), token_type_ids.shape(), attention_mask.shape());
info!("🧠 MODEL FORWARD PASS INITIATED");
info!(" 📊 Batch Details:");
info!(" - Batch Size: {}", batch_size);
info!(" - Max Sequence Length: {}", max_length);
info!(" - Device: {:?}", loaded_model.device);
info!(" - Model Type: embaas/sentence-transformers-e5-large-v2");
info!(" 🎛️ Tensor Shapes:");
info!(" - input_ids: {:?}", input_ids.shape());
info!(" - token_type_ids: {:?}", token_type_ids.shape());
info!(" - attention_mask: {:?}", attention_mask.shape());
info!("🚀 Executing model.forward()...");
let start_time = std::time::Instant::now();
let sequence_output = loaded_model.model
.forward(&input_ids, &token_type_ids, Some(&attention_mask))
.map_err(|e| {
error!("❌ Model forward pass failed: {}", e);
error!("🔧 Debugging info:");
error!(" - Input tensor shapes: input_ids={:?}, token_type_ids={:?}, attention_mask={:?}",
input_ids.shape(), token_type_ids.shape(), attention_mask.shape());
error!(" - Device: {:?}", loaded_model.device);
error!(" - Error details: {:?}", e);
if format!("{:?}", e).contains("Metal") || format!("{:?}", e).contains("pipeline") {
error!("🔧 Metal GPU error detected! Model will be reloaded on CPU.");
}
anyhow!("Model inference failed: {}", e)
})?;
info!("✅ Model forward pass successful");
let forward_duration = start_time.elapsed();
info!("✅ Model forward pass completed in {:?}", forward_duration);
info!("📊 Model Output Analysis:");
info!(" - Output Shape: {:?}", sequence_output.shape());
info!(" - Expected Shape: [batch_size={}, seq_len={}, hidden_dim=1024]", batch_size, max_length);
info!("🔄 Starting post-processing (pooling & normalization)...");
let mut embeddings = Vec::new();
let pooling_start = std::time::Instant::now();
for i in 0..batch_size {
debug!(" 🔧 Processing sequence {}/{}", i + 1, batch_size);
let sequence_i = sequence_output.narrow(0, i, 1)?; let attention_mask_i = attention_mask.narrow(0, i, 1)?;
let pooled_output = self.mean_pool(&sequence_i, &attention_mask_i)?;
debug!(" Mean pooled shape: {:?}", pooled_output.shape());
let normalized_embedding = self.l2_normalize(&pooled_output)?;
debug!(" L2 normalized shape: {:?}", normalized_embedding.shape());
let embedding_vec = normalized_embedding.squeeze(0)?.to_vec1::<f32>()?;
debug!(" Final embedding dimensions: {}", embedding_vec.len());
embeddings.push(embedding_vec);
}
let pooling_duration = pooling_start.elapsed();
info!("✅ Post-processing completed in {:?}", pooling_duration);
let total_duration = start_time.elapsed();
info!("🎯 EMBEDDING GENERATION SUMMARY:");
info!(" 📊 Results: {} embeddings generated", embeddings.len());
info!(" 📐 Dimensions: {} per embedding", embeddings.first().map(|e| e.len()).unwrap_or(0));
info!(" ⏱️ Total Time: {:?}", total_duration);
info!(" ⚡ Forward Pass: {:?}", forward_duration);
info!(" 🔄 Post-processing: {:?}", pooling_duration);
info!(" 🖥️ Device Used: {:?}", loaded_model.device);
info!("🔓 Releasing inference lock");
Ok(embeddings)
}
/// Attention-masked mean over the sequence axis (axis 1).
///
/// The mask is cast to F32, given a trailing axis, and expanded to `sequence_output`'s shape.
/// Masked token vectors are summed and divided by the mask count (clamped to at least 1e-9),
/// and axis 1 is then squeezed away. Presumably the inputs are `[batch, seq, hidden]` and
/// `[batch, seq]` — confirm with callers.
fn mean_pool(&self, sequence_output: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
    let mask = attention_mask.to_dtype(candle_core::DType::F32)?;
    let mask = mask.unsqueeze(2)?;
    let mask = mask.expand(sequence_output.shape())?;

    let token_sum = (sequence_output * &mask)?.sum_keepdim(1)?;
    let token_count = mask.sum_keepdim(1)?.clamp(1e-9, f64::INFINITY)?;

    let averaged = (token_sum / token_count)?;
    Ok(averaged.squeeze(1)?)
}
fn l2_normalize(&self, embeddings: &Tensor) -> Result<Tensor> {
let shape = embeddings.shape();
info!("🔧 L2 normalize input shape: {:?}", shape);
let norm = if shape.dims().len() == 2 {
embeddings.sqr()?.sum_keepdim(1)?.sqrt()?
} else if shape.dims().len() == 1 {
embeddings.sqr()?.sum_all()?.sqrt()?
} else {
return Err(anyhow!("Unexpected tensor shape for L2 normalization: {:?}", shape));
};
info!("🔧 Norm shape: {:?}", norm.shape());
let clamp_norm = norm.clamp(1e-12, f64::INFINITY)?;
let normalized = embeddings.broadcast_div(&clamp_norm)?;
info!("🔧 L2 normalized shape: {:?}", normalized.shape());
Ok(normalized)
}
/// Requests embeddings from `{service_url}/embeddings`, with bearer auth when an API key is
/// configured.
///
/// Rejects replies whose vector count differs from `texts.len()`, because callers pair
/// results with inputs by position. On success, records the service-reported dimensions.
///
/// # Errors
/// No service URL configured, transport failure, non-2xx status, unparseable body, or a
/// vector-count mismatch.
async fn generate_http_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
    let service_url = self.config.service_url
        .as_ref()
        .ok_or_else(|| anyhow!("No embedding service URL configured"))?;
    let request = EmbeddingRequest {
        texts: texts.iter().map(|s| s.to_string()).collect(),
        model: self.config.model.clone(),
    };

    // Tolerate a configured base URL with a trailing slash (avoids `//embeddings`).
    let base_url = service_url.to_string();
    let endpoint = format!("{}/embeddings", base_url.trim_end_matches('/'));
    let mut request_builder = self.client.post(endpoint).json(&request);
    if let Some(api_key) = &self.config.api_key {
        request_builder = request_builder.bearer_auth(api_key);
    }

    let response = request_builder
        .send()
        .await
        .map_err(|e| anyhow!("Failed to call embedding service: {}", e))?;
    if !response.status().is_success() {
        let status = response.status();
        let error_text = response.text().await.unwrap_or_default();
        return Err(anyhow!("Embedding service error {}: {}", status, error_text));
    }
    let embedding_response: EmbeddingResponse = response
        .json()
        .await
        .map_err(|e| anyhow!("Failed to parse embedding response: {}", e))?;

    if embedding_response.embeddings.len() != texts.len() {
        return Err(anyhow!(
            "Embedding service returned {} embeddings for {} texts",
            embedding_response.embeddings.len(),
            texts.len()
        ));
    }

    {
        let mut dimensions = self.dimensions.write().await;
        if *dimensions != Some(embedding_response.dimensions) {
            *dimensions = Some(embedding_response.dimensions);
        }
    }
    debug!("Generated {} HTTP embeddings with {} dimensions", texts.len(), embedding_response.dimensions);
    Ok(embedding_response.embeddings)
}
/// Returns the current embedding size.
///
/// # Errors
/// If no dimension has been recorded.
pub async fn get_dimensions(&self) -> Result<usize> {
    match *self.dimensions.read().await {
        Some(dims) => Ok(dims),
        None => Err(anyhow!("Dimensions not set")),
    }
}
/// Overrides the recorded embedding size. Always succeeds.
pub async fn set_dimensions(&self, dims: usize) -> Result<()> {
    let mut slot = self.dimensions.write().await;
    slot.replace(dims);
    Ok(())
}
/// Reports readiness; never returns `Err`.
///
/// Returns false until `initialize` has run. Local strategies are healthy once a current model
/// is recorded. `HttpOnly` is healthy when `GET {service_url}/health` answers with a 2xx;
/// transport failures count as unhealthy.
pub async fn health_check(&self) -> Result<bool> {
    let initialized = self.initialized.read().await;
    if !*initialized {
        return Ok(false);
    }

    let healthy = match self.strategy {
        ModelStrategy::HttpOnly => match &self.config.service_url {
            None => false,
            Some(service_url) => self
                .client
                .get(format!("{}/health", service_url))
                .send()
                .await
                .map(|resp| resp.status().is_success())
                .unwrap_or(false),
        },
        ModelStrategy::LocalOnly | ModelStrategy::LocalFirst => {
            self.current_model.read().await.is_some()
        }
    };
    Ok(healthy)
}
/// Clears the cache, drops any loaded model, clears the force-CPU flag and marks the service
/// uninitialized. `current_model` is left as is. Always succeeds.
pub async fn shutdown(&self) -> Result<()> {
    info!("🔄 Shutting down EmbeddingService...");
    self.clear_cache().await;

    {
        let mut model_slot = self.loaded_model.write().await;
        if let Some(_unloading) = model_slot.take() {
            info!("🗑️ Unloading model from memory");
        }
    }

    {
        let mut cpu_only = self.force_cpu.write().await;
        let was_forced = std::mem::replace(&mut *cpu_only, false);
        if was_forced {
            info!("🔄 Resetting force_cpu flag");
        }
    }

    *self.initialized.write().await = false;

    info!("✅ EmbeddingService shutdown complete");
    Ok(())
}
/// Clears the cache, lets GPU detection run again, and reloads the current model from disk.
///
/// A failed reload is only logged: the model stays unloaded and local generation uses
/// fallback vectors until it is loaded again. Always succeeds.
pub async fn reset_state(&self) -> Result<()> {
    info!("🔄 Resetting EmbeddingService state...");
    self.clear_cache().await;

    {
        let mut cpu_only = self.force_cpu.write().await;
        if *cpu_only {
            info!("🔄 Resetting force_cpu flag to allow GPU retry");
            *cpu_only = false;
        }
    }

    let model_to_reload = self.current_model.read().await.clone();
    if let Some(model_name) = model_to_reload {
        info!("🔄 Reloading model to clear internal state: {}", model_name);
        *self.loaded_model.write().await = None;

        if let Err(e) = self.load_model_for_inference(&model_name).await {
            warn!("⚠️ Failed to reload model (will retry on next embedding request): {}", e);
        } else {
            info!("✅ Model reloaded successfully");
        }
    }

    info!("✅ EmbeddingService state reset complete");
    Ok(())
}
/// Configured model name (not necessarily the model currently in memory).
pub fn get_model(&self) -> &str {
    self.config.model.as_str()
}
/// Configured batch size, defaulting to 32.
pub fn get_batch_size(&self) -> usize {
    match self.config.batch_size {
        Some(size) => size,
        None => 32,
    }
}
/// Ensures `model_name`'s files are on disk without loading them into memory.
///
/// # Errors
/// Unknown model, or the download itself fails.
pub async fn download_model(&self, model_name: &str) -> Result<()> {
    info!("🔄 Downloading model without loading: {}", model_name);
    if self.supported_models.get(model_name).is_none() {
        let supported: Vec<_> = self.supported_models.keys().collect();
        return Err(anyhow!(
            "Unsupported model: {}. Supported models: {:?}",
            model_name,
            supported
        ));
    }

    if !self.is_model_downloaded(model_name).await? {
        self.download_model_from_hf(model_name).await?;
        info!("Model {} downloaded successfully", model_name);
    } else {
        info!("Model {} already downloaded", model_name);
    }
    Ok(())
}
/// Embeds `texts` in consecutive batches and returns one embedding per input,
/// in input order.
///
/// `batch_size` falls back to `get_batch_size()` when `None`. An explicit `0`
/// is clamped to `1`, because `slice::chunks(0)` panics.
///
/// After every batch, `on_progress(processed, total)` is called with the
/// number of texts embedded so far and the total number of texts.
///
/// # Errors
/// Returns the first error produced by `generate_embeddings`. Embeddings
/// already computed for earlier batches are discarded in that case.
pub async fn embed_batch<F>(&self, texts: Vec<&str>, batch_size: Option<usize>, mut on_progress: Option<F>) -> Result<Vec<Vec<f32>>>
where
    F: FnMut(usize, usize) + Send + Sync,
{
    let batch_size = batch_size.unwrap_or_else(|| self.get_batch_size()).max(1);
    let total = texts.len();
    let batch_count = (total + batch_size - 1) / batch_size;
    let mut all_embeddings = Vec::with_capacity(total);
    let mut processed = 0usize;
    info!("🔄 Processing {} texts in batches of {}", total, batch_size);
    for (batch_idx, chunk) in texts.chunks(batch_size).enumerate() {
        let batch_embeddings = self.generate_embeddings(chunk).await?;
        all_embeddings.extend(batch_embeddings);
        // Count real texts: the final chunk may be shorter than `batch_size`,
        // so `(batch_idx + 1) * chunk_len` would misreport progress.
        processed += chunk.len();
        if let Some(callback) = on_progress.as_mut() {
            callback(processed, total);
        }
        debug!("Processed batch {} of {} ({}/{} texts)",
            batch_idx + 1,
            batch_count,
            processed,
            total);
    }
    info!("✅ Batch processing completed: {} embeddings generated", all_embeddings.len());
    Ok(all_embeddings)
}
/// Cosine similarity of two equal-length embeddings, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` when either vector has zero magnitude.
///
/// # Errors
/// Fails when the two slices differ in length or are empty.
pub fn calculate_similarity(&self, embedding1: &[f32], embedding2: &[f32]) -> Result<f32> {
    let (len_a, len_b) = (embedding1.len(), embedding2.len());
    if len_a != len_b {
        return Err(anyhow!("Embedding dimensions don't match: {} vs {}", len_a, len_b));
    }
    if len_a == 0 {
        return Err(anyhow!("Embeddings cannot be empty"));
    }

    let euclidean_norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let dot: f32 = embedding1.iter().zip(embedding2).map(|(a, b)| a * b).sum();
    let norm_a = euclidean_norm(embedding1);
    let norm_b = euclidean_norm(embedding2);

    if norm_a == 0.0 || norm_b == 0.0 {
        Ok(0.0)
    } else {
        Ok((dot / (norm_a * norm_b)).clamp(-1.0, 1.0))
    }
}
/// Forces `embedding` to exactly `expected_dimensions` entries.
///
/// Shorter vectors are zero-padded at the end. Longer vectors are truncated.
/// Vectors of the right length are returned unchanged.
pub fn adjust_embedding_dimensions(&self, mut embedding: Vec<f32>, expected_dimensions: usize) -> Vec<f32> {
    let original_len = embedding.len();
    if original_len < expected_dimensions {
        embedding.resize(expected_dimensions, 0.0);
        debug!("Padded embedding from {} to {} dimensions",
            original_len, expected_dimensions);
    } else if original_len > expected_dimensions {
        embedding.truncate(expected_dimensions);
        debug!("Truncated embedding to {} dimensions", expected_dimensions);
    }
    embedding
}
/// Returns a JSON snapshot of the service state.
///
/// The snapshot contains the initialization flag, the strategy (`Debug`
/// form), the models directory, the number of supported models and, when a
/// model is selected, a `current_model` object with its name, load/download
/// status and configuration. Otherwise `current_model` is `null`.
///
/// The lock-protected values are copied out first, so no `RwLock` guard is
/// held while awaiting the filesystem check in `is_model_downloaded`. Holding
/// the guards there would block writers for the duration of that I/O.
pub async fn get_detailed_model_state(&self) -> Result<serde_json::Value> {
    let current_model: Option<String> = self.current_model.read().await.clone();
    let is_loaded_for_inference = self.loaded_model.read().await.is_some();
    let initialized = *self.initialized.read().await;

    let mut state = serde_json::json!({
        "initialized": initialized,
        "strategy": format!("{:?}", self.strategy),
        // `display()` cannot fail. Serializing the `PathBuf` directly errors on
        // non-UTF-8 paths, which makes `json!` panic.
        "models_path": self.models_path.display().to_string(),
        "supported_models": self.supported_models.len()
    });

    if let Some(model_name) = current_model.as_deref() {
        let model_config = self.supported_models.get(model_name);
        // Best-effort: a failed check is reported as "not downloaded".
        let is_downloaded = self.is_model_downloaded(model_name).await.unwrap_or(false);
        state["current_model"] = serde_json::json!({
            "name": model_name,
            "actual_model": model_name,
            "loaded": is_loaded_for_inference,
            "downloaded": is_downloaded,
            "is_fallback": false,
            "config": model_config
        });
    } else {
        state["current_model"] = serde_json::Value::Null;
    }
    Ok(state)
}
/// Returns the names of the supported models whose files are present locally.
///
/// # Errors
/// Propagates the first error returned by `is_model_downloaded`.
pub async fn list_downloaded_models(&self) -> Result<Vec<String>> {
    let mut present: Vec<String> = Vec::new();
    for name in self.supported_models.keys() {
        let is_present = self.is_model_downloaded(name).await?;
        if is_present {
            present.push(name.to_owned());
        }
    }
    Ok(present)
}
/// Deletes the local files of a supported model. If that model is the
/// currently selected one, it is also deselected and unloaded from memory.
///
/// A model whose directory does not exist is not an error. It is only logged
/// as "not downloaded".
///
/// # Errors
/// Fails if `model_name` is not a supported model, or if removing its
/// directory fails for any reason other than the directory being absent.
pub async fn remove_model(&self, model_name: &str) -> Result<()> {
    if !self.supported_models.contains_key(model_name) {
        return Err(anyhow!("Unknown model: {}", model_name));
    }
    let model_path = self.models_path.join(format!("models--{}", model_name.replace('/', "--")));

    // Remove directly instead of probing with `Path::exists()` first. That
    // probe is a blocking syscall on the async runtime, and it races with any
    // concurrent removal of the same directory.
    match tokio::fs::remove_dir_all(&model_path).await {
        Ok(()) => info!("Removed model: {} from {}", model_name, model_path.display()),
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
            info!("Model {} was not downloaded", model_name);
        }
        Err(e) => return Err(e.into()),
    }

    let mut current_model = self.current_model.write().await;
    if current_model.as_deref() == Some(model_name) {
        *current_model = None;
        *self.loaded_model.write().await = None;
    }
    Ok(())
}
/// Embeds `test_text` once and reports diagnostics as JSON.
///
/// The report includes the generation time in milliseconds, the vector
/// length, summary statistics (mean, population standard deviation, min,
/// max) and the first five values.
///
/// # Errors
/// Propagates any error from `generate_embedding`.
pub async fn test_embedding(&self, test_text: &str) -> Result<serde_json::Value> {
    let started = std::time::Instant::now();
    let embedding = self.generate_embedding(test_text).await?;
    let elapsed = started.elapsed();

    let count = embedding.len();
    let n = count as f32;
    let mean = embedding.iter().sum::<f32>() / n;
    let variance = embedding.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let (min, max) = embedding
        .iter()
        .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &v| (lo.min(v), hi.max(v)));
    let preview: Vec<f32> = embedding.iter().take(5).copied().collect();

    Ok(serde_json::json!({
        "test_text": test_text,
        "embedding_length": count,
        "generation_time_ms": elapsed.as_millis(),
        "statistics": {
            "mean": mean,
            "std_dev": variance.sqrt(),
            "min": min,
            "max": max
        },
        "sample_values": preview
    }))
}
async fn get_cached_embedding(&self, text: &str) -> Option<Vec<f32>> {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
text.hash(&mut hasher);
let cache_key = hasher.finish().to_string();
let cache = self.embedding_cache.read().await;
cache.get(&cache_key).cloned()
}
async fn cache_embedding(&self, text: &str, embedding: &[f32]) {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
text.hash(&mut hasher);
let cache_key = hasher.finish().to_string();
let mut cache = self.embedding_cache.write().await;
cache.insert(cache_key, embedding.to_vec());
debug!("🗃️ Cached embedding for text ({} chars) - cache size: {}",
text.len(), cache.len());
}
/// Empties the in-memory embedding cache and logs how many entries were dropped.
pub async fn clear_cache(&self) {
    let removed = {
        let mut entries = self.embedding_cache.write().await;
        let count = entries.len();
        entries.clear();
        count
    };
    info!("🧹 Cleared embedding cache ({} entries)", removed);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Configuration for the `dummy` model used by the offline tests.
    fn dummy_config() -> EmbeddingConfig {
        EmbeddingConfig {
            model: "dummy".to_string(),
            dimensions: 1,
            service_url: None,
            api_key: None,
            batch_size: None,
        }
    }

    #[tokio::test]
    async fn test_dummy_embeddings() {
        let temp_dir = TempDir::new().unwrap();
        let models_path = temp_dir.path().join("models");
        let service = EmbeddingService::new(&dummy_config(), &models_path).await.unwrap();
        let embedding = service.generate_embedding("test text").await.unwrap();
        assert_eq!(embedding, vec![0.1]);
        assert_eq!(service.get_dimensions().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_multiple_dummy_embeddings() {
        let temp_dir = TempDir::new().unwrap();
        let models_path = temp_dir.path().join("models");
        let service = EmbeddingService::new(&dummy_config(), &models_path).await.unwrap();
        let texts = vec!["text1", "text2", "text3"];
        let embeddings = service.generate_embeddings(&texts).await.unwrap();
        assert_eq!(embeddings.len(), 3);
        for embedding in embeddings {
            assert_eq!(embedding, vec![0.1]);
        }
    }

    /// Exercises the e5-large model. The assertions accept both outcomes of
    /// `initialize` (model loaded or not). Presumably loading needs the model
    /// files or network access; verify.
    #[tokio::test]
    async fn test_e5_large_model_loading() {
        let temp_dir = TempDir::new().unwrap();
        let models_path = temp_dir.path().join("models");
        let config = EmbeddingConfig {
            model: "embaas/sentence-transformers-e5-large-v2".to_string(),
            dimensions: 1024,
            service_url: None,
            api_key: None,
            batch_size: None,
        };
        let service = EmbeddingService::new(&config, &models_path).await.unwrap();
        service.initialize().await.unwrap();
        let model_info = service.get_model_info().await.unwrap();
        if model_info["loaded"].as_bool().unwrap_or(false) {
            assert_eq!(model_info["name"], "embaas/sentence-transformers-e5-large-v2");
            assert_eq!(model_info["dimensions"], 1024);
        } else {
            assert_eq!(model_info["name"], serde_json::Value::Null);
            assert_eq!(model_info["loaded"], false);
        }
    }

    #[tokio::test]
    async fn test_supported_models() {
        let temp_dir = TempDir::new().unwrap();
        let models_path = temp_dir.path().join("models");
        let config = EmbeddingConfig::default();
        let service = EmbeddingService::new(&config, &models_path).await.unwrap();
        let supported = service.get_supported_models();
        assert!(supported.contains_key("BAAI/bge-m3"));
        assert!(supported.contains_key("sentence-transformers/all-MiniLM-L6-v2"));
        assert!(supported.contains_key("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"));
        // Assert presence first so a missing key fails with a clear message
        // instead of an index panic.
        assert!(supported.contains_key("embaas/sentence-transformers-e5-large-v2"));
        let e5_config = &supported["embaas/sentence-transformers-e5-large-v2"];
        assert_eq!(e5_config.dimensions, 1024);
    }

    #[tokio::test]
    async fn test_calculate_similarity() {
        let temp_dir = TempDir::new().unwrap();
        let models_path = temp_dir.path().join("models");
        let service = EmbeddingService::new(&dummy_config(), &models_path).await.unwrap();

        let parallel = service.calculate_similarity(&[1.0, 0.0], &[2.0, 0.0]).unwrap();
        assert!((parallel - 1.0).abs() < 1e-6);
        let orthogonal = service.calculate_similarity(&[1.0, 0.0], &[0.0, 3.0]).unwrap();
        assert!(orthogonal.abs() < 1e-6);
        assert_eq!(service.calculate_similarity(&[0.0, 0.0], &[1.0, 1.0]).unwrap(), 0.0);
        assert!(service.calculate_similarity(&[1.0], &[1.0, 2.0]).is_err());
        assert!(service.calculate_similarity(&[], &[]).is_err());
    }

    #[tokio::test]
    async fn test_adjust_embedding_dimensions() {
        let temp_dir = TempDir::new().unwrap();
        let models_path = temp_dir.path().join("models");
        let service = EmbeddingService::new(&dummy_config(), &models_path).await.unwrap();

        assert_eq!(service.adjust_embedding_dimensions(vec![1.0, 2.0], 4), vec![1.0, 2.0, 0.0, 0.0]);
        assert_eq!(service.adjust_embedding_dimensions(vec![1.0, 2.0, 3.0], 2), vec![1.0, 2.0]);
        assert_eq!(service.adjust_embedding_dimensions(vec![1.0], 1), vec![1.0]);
    }
}