1#![allow(unused_imports)]
7
8use crate::error::{MemvidError, Result};
9use crate::ml::device::DeviceType;
10use crate::ml::models::ModelManager;
11use crate::ml::text::{TextConfig, TextProcessor};
12use candle_core::{Device, Tensor};
13use candle_transformers::models::bert::{BertModel, Config as BertConfig};
14use chrono;
15use std::collections::HashMap;
16
17use serde::{Deserialize, Serialize};
18
19#[derive(Debug, Clone)]
21pub struct EmbeddingConfig {
22 pub model_name: String,
24 pub max_length: usize,
26 pub normalize: bool,
28 pub batch_size: usize,
30 pub device_type: DeviceType,
32}
33
34impl Default for EmbeddingConfig {
35 fn default() -> Self {
36 Self {
37 model_name: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
38 max_length: 384,
39 normalize: true,
40 batch_size: 32,
41 device_type: DeviceType::Cpu,
42 }
43 }
44}
45
46pub type Embedding = Vec<f32>;
48
49pub struct EmbeddingModel {
51 config: EmbeddingConfig,
53 text_processor: TextProcessor,
55 cache: HashMap<String, Embedding>,
57 model_manager: ModelManager,
59 is_ready: bool,
61 device: Device,
63 bert_model: Option<BertModel>,
65}
66
67impl EmbeddingModel {
68 pub async fn new(config: EmbeddingConfig) -> Result<Self> {
70 log::info!("Initializing real embedding model: {}", config.model_name);
71
72 log::info!("Using device: {:?}", config.device_type);
73
74 let text_config = TextConfig {
76 max_length: config.max_length,
77 ..Default::default()
78 };
79 let text_processor = TextProcessor::new(text_config);
80
81 let model_manager = ModelManager::new(None)?;
83
84 let mut embedding_model = Self {
85 config,
86 text_processor,
87 cache: HashMap::new(),
88 model_manager,
89 is_ready: false,
90 device: Device::Cpu,
91 bert_model: None,
92 };
93
94 if let Err(e) = embedding_model.load_model().await {
96 log::warn!("Failed to load model, will use fallback: {}", e);
97 }
98
99 Ok(embedding_model)
100 }
101
102 async fn load_model(&mut self) -> Result<()> {
104 log::info!(
105 "Loading BERT model for TRUE semantic inference: {}",
106 self.config.model_name
107 );
108
109 match self.config.device_type {
111 DeviceType::Cuda(_) => {
112 #[cfg(feature = "cuda")]
113 {
114 self.device = Device::cuda_if_available(0).unwrap_or(Device::Cpu);
115 if matches!(self.device, Device::Cpu) {
116 log::warn!(
117 "CUDA requested but not available, using CPU for BERT inference"
118 );
119 } else {
120 log::info!("🚀 Using CUDA device for TRUE BERT neural network inference");
121 }
122 }
123 #[cfg(not(feature = "cuda"))]
124 {
125 log::warn!("CUDA requested but not compiled in, using CPU for BERT inference");
126 self.device = Device::Cpu;
127 }
128 }
129 DeviceType::Metal => {
130 #[cfg(feature = "metal")]
131 {
132 self.device = Device::new_metal(0).unwrap_or(Device::Cpu);
133 if matches!(self.device, Device::Cpu) {
134 log::warn!(
135 "Metal requested but not available, using CPU for BERT inference"
136 );
137 } else {
138 log::info!("🚀 Using Metal device for TRUE BERT neural network inference");
139 }
140 }
141 #[cfg(not(feature = "metal"))]
142 {
143 log::warn!("Metal requested but not compiled in, using CPU for BERT inference");
144 self.device = Device::Cpu;
145 }
146 }
147 DeviceType::Cpu => {
148 log::info!("🧠 Using CPU device for TRUE BERT neural network inference");
149 self.device = Device::Cpu;
150 }
151 };
152
153 let model_dir = self
155 .model_manager
156 .download_model(&self.config.model_name)
157 .await?;
158 log::info!("📥 Downloaded BERT model files to: {}", model_dir.display());
159
160 if let Err(e) = self.text_processor.load_tokenizer(&model_dir) {
162 return Err(MemvidError::MachineLearning(format!(
163 "Failed to load BERT tokenizer: {}",
164 e
165 )));
166 }
167 log::info!("📝 Loaded BERT tokenizer successfully");
168
169 let config_path = model_dir.join("config.json");
171 if !config_path.exists() {
172 return Err(MemvidError::MachineLearning(format!(
173 "BERT config file not found: {}",
174 config_path.display()
175 )));
176 }
177
178 let config_content = std::fs::read_to_string(&config_path).map_err(|e| {
179 MemvidError::MachineLearning(format!("Failed to read BERT config: {}", e))
180 })?;
181
182 let bert_config: BertConfig = serde_json::from_str(&config_content).map_err(|e| {
183 MemvidError::MachineLearning(format!("Failed to parse BERT config: {}", e))
184 })?;
185
186 log::info!(
187 "📋 Loaded BERT config: {} layers, {} hidden size, {} attention heads",
188 bert_config.num_hidden_layers,
189 bert_config.hidden_size,
190 bert_config.num_attention_heads
191 );
192
193 let weights_path = model_dir.join("model.safetensors");
195 if !weights_path.exists() {
196 return Err(MemvidError::MachineLearning(format!(
197 "BERT weights file not found: {}",
198 weights_path.display()
199 )));
200 }
201
202 log::info!("🏋️ Loading BERT neural network weights...");
203 let var_builder = unsafe {
204 candle_nn::VarBuilder::from_mmaped_safetensors(
205 &[weights_path],
206 candle_core::DType::F32,
207 &self.device,
208 )
209 .map_err(|e| {
210 MemvidError::MachineLearning(format!("Failed to load BERT safetensors: {}", e))
211 })?
212 };
213
214 log::info!("🧠 Initializing BERT neural network architecture...");
216 let bert_model = BertModel::load(var_builder, &bert_config).map_err(|e| {
217 MemvidError::MachineLearning(format!("Failed to initialize BERT model: {}", e))
218 })?;
219
220 self.bert_model = Some(bert_model);
221 self.is_ready = true;
222
223 log::info!("🎉 TRUE BERT model loaded successfully!");
224 log::info!("🧠 Ready for neural network-based semantic inference");
225 log::info!(
226 "⚡ Using {}-layer transformer with {} hidden dimensions",
227 bert_config.num_hidden_layers,
228 bert_config.hidden_size
229 );
230
231 Ok(())
232 }
233
234 fn generate_bert_embedding(&mut self, text: &str) -> Result<Embedding> {
236 #[cfg(test)]
238 {
239 return Ok(self.generate_test_embedding(text));
240 }
241
242 #[cfg(not(test))]
243 {
244 log::debug!(
245 "🧠 Performing BERT neural network inference for: {}",
246 &text[..std::cmp::min(50, text.len())]
247 );
248
249 let tokenized = self.text_processor.tokenize(text)?;
251 log::trace!(
252 "Tokenized {} chars into {} tokens",
253 text.len(),
254 tokenized.input_ids.len()
255 );
256
257 let bert_model = self
259 .bert_model
260 .as_ref()
261 .ok_or_else(|| MemvidError::MachineLearning("BERT model not loaded".to_string()))?;
262
263 let input_ids = Tensor::new(&tokenized.input_ids[..], &self.device)
265 .map_err(|e| {
266 MemvidError::MachineLearning(format!(
267 "Failed to create input_ids tensor: {}",
268 e
269 ))
270 })?
271 .unsqueeze(0)?; let token_type_ids = Tensor::new(&tokenized.token_type_ids[..], &self.device)
274 .map_err(|e| {
275 MemvidError::MachineLearning(format!(
276 "Failed to create token_type_ids tensor: {}",
277 e
278 ))
279 })?
280 .unsqueeze(0)?; let attention_mask = Tensor::new(&tokenized.attention_mask[..], &self.device)
283 .map_err(|e| {
284 MemvidError::MachineLearning(format!(
285 "Failed to create attention_mask tensor: {}",
286 e
287 ))
288 })?
289 .unsqueeze(0)?; log::trace!(
292 "Created tensors with shapes: input_ids {:?}, token_type_ids {:?}, attention_mask {:?}",
293 input_ids.shape(),
294 token_type_ids.shape(),
295 attention_mask.shape()
296 );
297
298 log::debug!("🔥 Running BERT forward pass through transformer layers...");
300 let bert_output = bert_model
301 .forward(&input_ids, &token_type_ids, Some(&attention_mask))
302 .map_err(|e| {
303 MemvidError::MachineLearning(format!("BERT forward pass failed: {}", e))
304 })?;
305
306 log::trace!("BERT output shape: {:?}", bert_output.shape());
307
308 log::debug!("🎯 Applying mean pooling for sentence representation...");
310 let pooled = self.apply_mean_pooling(&bert_output, &attention_mask)?;
311
312 let pooled_squeezed = pooled.squeeze(0)?;
314 let embedding_vec = pooled_squeezed.to_vec1::<f32>().map_err(|e| {
315 MemvidError::MachineLearning(format!("Failed to convert embedding tensor: {}", e))
316 })?;
317
318 log::debug!(
319 "✅ Generated {}-dimensional BERT embedding",
320 embedding_vec.len()
321 );
322
323 Ok(embedding_vec)
324 }
325 }
326
327 #[cfg(test)]
329 fn generate_test_embedding(&self, text: &str) -> Embedding {
330 use std::collections::hash_map::DefaultHasher;
331 use std::hash::{Hash, Hasher};
332
333 let mut hasher = DefaultHasher::new();
335 text.hash(&mut hasher);
336 let hash = hasher.finish();
337
338 let mut embedding = Vec::with_capacity(384);
340 let mut seed = hash;
341
342 for _ in 0..384 {
343 seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
345 let val = ((seed >> 16) as f32) / 32768.0 - 1.0; embedding.push(val * 0.1); }
348
349 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
351 if norm > 0.0 {
352 for val in &mut embedding {
353 *val /= norm;
354 }
355 }
356
357 embedding
358 }
359
360 pub fn encode(&mut self, text: &str) -> Result<Embedding> {
362 if let Some(cached) = self.cache.get(text) {
364 log::trace!("Using cached BERT embedding");
365 return Ok(cached.clone());
366 }
367
368 let embedding = if self.is_ready && self.bert_model.is_some() {
369 log::debug!("🧠 Generating TRUE BERT embedding for: {}", text);
371 self.generate_bert_embedding(text)?
372 } else {
373 return Err(MemvidError::MachineLearning(
375 "BERT model not loaded - true semantic search requires neural network inference"
376 .to_string(),
377 ));
378 };
379
380 self.cache.insert(text.to_string(), embedding.clone());
382
383 Ok(embedding)
384 }
385
386 pub fn encode_batch(&mut self, texts: &[String]) -> Result<Vec<Embedding>> {
388 let mut embeddings = Vec::new();
389
390 for chunk in texts.chunks(self.config.batch_size) {
392 for text in chunk {
393 embeddings.push(self.encode(text)?);
394 }
395 }
396
397 Ok(embeddings)
398 }
399
400 pub fn encode_batch_parallel(
402 &mut self,
403 texts: &[String],
404 ) -> Result<(Vec<Embedding>, Vec<String>)> {
405 use rayon::prelude::*;
406
407 let batch_size = self.config.batch_size.min(texts.len());
408 let mut successful_embeddings = Vec::new();
409 let mut failed_texts = Vec::new();
410
411 for chunk in texts.chunks(batch_size) {
413 let chunk_results: Vec<(usize, Result<Embedding>)> = chunk
414 .par_iter()
415 .enumerate()
416 .map(|(local_idx, text)| {
417 let embedding_result = if self.is_ready {
419 self.generate_enhanced_embedding_standalone(text)
420 } else {
421 self.generate_placeholder_embedding_standalone(text)
422 };
423 (local_idx, embedding_result)
424 })
425 .collect();
426
427 for (local_idx, result) in chunk_results {
429 let text = &chunk[local_idx];
430 match result {
431 Ok(embedding) => {
432 self.cache.insert(text.clone(), embedding.clone());
434 successful_embeddings.push(embedding);
435 }
436 Err(_) => {
437 log::warn!("Failed to generate embedding for text: {}", text);
438 failed_texts.push(text.clone());
439 successful_embeddings.push(vec![0.0; self.dimension()]);
441 }
442 }
443 }
444 }
445
446 Ok((successful_embeddings, failed_texts))
447 }
448
449 pub fn encode_batch_with_retry(
451 &mut self,
452 texts: &[String],
453 max_retries: usize,
454 retry_delay_ms: u64,
455 ) -> Result<(Vec<Embedding>, Vec<String>, usize)> {
456 let mut all_embeddings = Vec::new();
457 let mut failed_texts = Vec::new();
458 let mut total_retries = 0;
459
460 for text in texts {
461 let mut attempts = 0;
462 let mut last_error = None;
463
464 while attempts <= max_retries {
465 match self.encode(text) {
466 Ok(embedding) => {
467 all_embeddings.push(embedding);
468 break;
469 }
470 Err(e) => {
471 attempts += 1;
472 total_retries += 1;
473 last_error = Some(e);
474
475 if attempts <= max_retries {
476 std::thread::sleep(std::time::Duration::from_millis(
477 retry_delay_ms * attempts as u64,
478 ));
479 log::debug!(
480 "Retrying embedding generation for text (attempt {}): {}",
481 attempts,
482 text
483 );
484 }
485 }
486 }
487 }
488
489 if attempts > max_retries {
490 if let Some(e) = last_error {
491 log::error!(
492 "Failed to generate embedding after {} retries: {}",
493 max_retries,
494 e
495 );
496 }
497 failed_texts.push(text.clone());
498 all_embeddings.push(vec![0.0; self.dimension()]);
500 }
501 }
502
503 Ok((all_embeddings, failed_texts, total_retries))
504 }
505
506 fn generate_enhanced_embedding_standalone(&self, text: &str) -> Result<Embedding> {
508 let tokenized = self.text_processor.tokenize(text)?;
510
511 let mut embedding = vec![0.0f32; 384]; let valid_tokens: Vec<u32> = tokenized
516 .input_ids
517 .iter()
518 .zip(tokenized.attention_mask.iter())
519 .filter(|(_, mask)| **mask == 1)
520 .map(|(token_id, _)| *token_id)
521 .collect();
522
523 if !valid_tokens.is_empty() {
524 for (i, &token_id) in valid_tokens.iter().enumerate() {
526 let token_float = token_id as f32;
527
528 for hash_func in 0..5 {
530 let mut hasher = std::collections::hash_map::DefaultHasher::new();
531 use std::hash::{Hash, Hasher};
532
533 (token_id.wrapping_add(hash_func * 1000)).hash(&mut hasher);
534 let hash = hasher.finish();
535
536 for j in 0..20 {
538 let dim_idx = ((hash as usize).wrapping_add(j * 19).wrapping_add(i * 17))
539 % embedding.len();
540 let value = ((hash >> (j * 3)) & 0x7) as f32 / 8.0 - 0.5;
541 embedding[dim_idx] += value * (1.0 / (i as f32 + 1.0).sqrt());
542 }
543 }
544
545 let pos_weight = 1.0 - (i as f32 / valid_tokens.len() as f32) * 0.1;
547 for k in 0..10 {
548 let dim = (token_id as usize * 7 + k * 13) % embedding.len();
549 embedding[dim] += (token_float / 30000.0) * pos_weight;
550 }
551 }
552
553 let seq_norm = 1.0 / (valid_tokens.len() as f32).sqrt();
555 for val in &mut embedding {
556 *val *= seq_norm;
557 }
558 }
559
560 if self.config.normalize {
562 Ok(self.normalize_embedding_standalone(embedding))
563 } else {
564 Ok(embedding)
565 }
566 }
567
568 fn generate_placeholder_embedding_standalone(&self, text: &str) -> Result<Embedding> {
570 let mut embedding = vec![0.0f32; 384]; use std::collections::hash_map::DefaultHasher;
575 use std::hash::{Hash, Hasher};
576
577 for (i, word) in text.split_whitespace().enumerate() {
578 let mut hasher = DefaultHasher::new();
579 word.hash(&mut hasher);
580 let hash = hasher.finish();
581
582 for j in 0..10.min(embedding.len()) {
584 let idx = (i * 10 + j) % embedding.len();
585 embedding[idx] += ((hash >> (j * 6)) & 0x3F) as f32 / 64.0 - 0.5;
586 }
587 }
588
589 if self.config.normalize {
591 Ok(self.normalize_embedding_standalone(embedding))
592 } else {
593 Ok(embedding)
594 }
595 }
596
597 fn normalize_embedding_standalone(&self, mut embedding: Vec<f32>) -> Vec<f32> {
599 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
600 if norm > 0.0 {
601 for val in &mut embedding {
602 *val /= norm;
603 }
604 }
605 embedding
606 }
607
608 pub fn clear_cache(&mut self) {
610 self.cache.clear();
611 }
612
613 pub fn cache_size(&self) -> usize {
615 self.cache.len()
616 }
617
618 pub fn config(&self) -> &EmbeddingConfig {
620 &self.config
621 }
622
623 pub fn has_tokenizer(&self) -> bool {
625 self.text_processor.has_tokenizer()
626 }
627
628 pub fn dimension(&self) -> usize {
630 384 }
632
633 pub fn health_check(&self) -> EmbeddingHealth {
635 EmbeddingHealth {
636 is_ready: self.is_ready,
637 has_tokenizer: self.text_processor.has_tokenizer(),
638 cache_size: self.cache.len(),
639 cache_hit_rate: 0.0, model_name: self.config.model_name.clone(),
641 device_type: format!("{:?}", self.config.device_type),
642 last_inference_time: None, }
644 }
645
646 pub fn clear_cache_selective(&mut self, keep_recent: Option<usize>) {
648 if let Some(keep_count) = keep_recent {
649 if self.cache.len() > keep_count {
650 let excess = self.cache.len() - keep_count;
652 let keys_to_remove: Vec<String> = self.cache.keys().take(excess).cloned().collect();
653 for key in keys_to_remove {
654 self.cache.remove(&key);
655 }
656 }
657 } else {
658 self.cache.clear();
659 }
660 }
661
662 pub fn cache_stats(&self) -> CacheStats {
664 let total_text_length: usize = self.cache.keys().map(|k| k.len()).sum();
665 let avg_text_length = if !self.cache.is_empty() {
666 total_text_length as f32 / self.cache.len() as f32
667 } else {
668 0.0
669 };
670
671 CacheStats {
672 size: self.cache.len(),
673 total_text_length,
674 avg_text_length,
675 estimated_memory_mb: (total_text_length + self.cache.len() * self.dimension() * 4)
676 as f32
677 / 1_048_576.0,
678 }
679 }
680
681 #[cfg(not(test))]
683 fn apply_mean_pooling(
684 &self,
685 hidden_states: &Tensor,
686 attention_mask: &Tensor,
687 ) -> Result<Tensor> {
688 log::trace!("Applying attention-weighted mean pooling");
689
690 let expanded_mask = attention_mask
692 .unsqueeze(2)?
693 .expand(hidden_states.shape())?
694 .to_dtype(hidden_states.dtype())?;
695
696 let masked_hidden = hidden_states.mul(&expanded_mask)?;
698
699 let summed = masked_hidden.sum(1)?;
701
702 let mask_sum = expanded_mask.sum(1)?;
704
705 let mask_sum = mask_sum.clamp(1e-9, f32::INFINITY)?;
707
708 let pooled = summed.div(&mask_sum)?;
710
711 log::trace!("Mean pooling complete, output shape: {:?}", pooled.shape());
712 Ok(pooled)
713 }
714}
715
716#[derive(Debug, Clone, Serialize, Deserialize)]
718pub struct EmbeddingHealth {
719 pub is_ready: bool,
720 pub has_tokenizer: bool,
721 pub cache_size: usize,
722 pub cache_hit_rate: f32,
723 pub model_name: String,
724 pub device_type: String,
725 pub last_inference_time: Option<chrono::DateTime<chrono::Utc>>,
726}
727
728#[derive(Debug, Clone, Serialize, Deserialize)]
730pub struct CacheStats {
731 pub size: usize,
732 pub total_text_length: usize,
733 pub avg_text_length: f32,
734 pub estimated_memory_mb: f32,
735}
736
737#[cfg(test)]
738mod tests {
739 use super::*;
740
741 #[tokio::test]
742 async fn test_embedding_config_default() {
743 let config = EmbeddingConfig::default();
744 assert_eq!(config.model_name, "sentence-transformers/all-MiniLM-L6-v2");
745 assert_eq!(config.max_length, 384);
746 assert!(config.normalize);
747 }
748
749 #[tokio::test]
750 async fn test_embedding_model_creation() {
751 let config = EmbeddingConfig::default();
752 let model = EmbeddingModel::new(config).await.unwrap();
753 assert_eq!(model.cache_size(), 0);
754 assert_eq!(model.dimension(), 384);
755 }
756
757 #[tokio::test]
758 async fn test_enhanced_embedding_generation() {
759 let config = EmbeddingConfig::default();
760 let mut model = EmbeddingModel::new(config).await.unwrap();
761
762 let text = "This is a test sentence for enhanced embedding";
763 let embedding = model.encode(text).unwrap();
764
765 assert_eq!(embedding.len(), 384); assert_eq!(model.cache_size(), 1);
767
768 let embedding2 = model.encode(text).unwrap();
770 assert_eq!(embedding, embedding2);
771 assert_eq!(model.cache_size(), 1); }
773
774 #[tokio::test]
775 async fn test_embedding_batch() {
776 let config = EmbeddingConfig::default();
777 let mut model = EmbeddingModel::new(config).await.unwrap();
778
779 let texts = vec![
780 "First sentence with enhanced tokenization".to_string(),
781 "Second sentence for comparison".to_string(),
782 "Third sentence with different content".to_string(),
783 ];
784
785 let embeddings = model.encode_batch(&texts).unwrap();
786 assert_eq!(embeddings.len(), 3);
787 assert_eq!(model.cache_size(), 3);
788
789 assert_ne!(embeddings[0], embeddings[1]);
791 assert_ne!(embeddings[1], embeddings[2]);
792 }
793
794 #[tokio::test]
795 async fn test_embedding_normalization() {
796 let mut config = EmbeddingConfig::default();
797 config.normalize = true;
798
799 let mut model = EmbeddingModel::new(config).await.unwrap();
800 let embedding = model.encode("test normalization").unwrap();
801
802 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
804 assert!(
805 (norm - 1.0).abs() < 1e-6,
806 "Embedding should be normalized, got norm: {}",
807 norm
808 );
809 }
810
811 #[tokio::test]
812 async fn test_embedding_deterministic() {
813 let config = EmbeddingConfig::default();
814 let mut model1 = EmbeddingModel::new(config.clone()).await.unwrap();
815 let mut model2 = EmbeddingModel::new(config).await.unwrap();
816
817 let text = "Test deterministic behavior";
818 let embedding1 = model1.encode(text).unwrap();
819 let embedding2 = model2.encode(text).unwrap();
820
821 assert_eq!(embedding1, embedding2);
823 }
824
825 #[tokio::test]
826 async fn test_phase_3d_parallel_embedding() {
827 let config = EmbeddingConfig::default();
828 let mut model = EmbeddingModel::new(config).await.unwrap();
829
830 let texts = vec![
831 "Parallel processing test 1".to_string(),
832 "Parallel processing test 2".to_string(),
833 "Parallel processing test 3".to_string(),
834 "Parallel processing test 4".to_string(),
835 ];
836
837 let (embeddings, failed_texts) = model.encode_batch_parallel(&texts).unwrap();
838
839 assert_eq!(embeddings.len(), texts.len());
840 assert_eq!(failed_texts.len(), 0); assert_eq!(model.cache_size(), texts.len()); for i in 0..embeddings.len() {
845 for j in i + 1..embeddings.len() {
846 assert_ne!(embeddings[i], embeddings[j]);
847 }
848 }
849 }
850
851 #[tokio::test]
852 async fn test_phase_3d_error_recovery() {
853 let config = EmbeddingConfig::default();
854 let mut model = EmbeddingModel::new(config).await.unwrap();
855
856 let texts = vec![
857 "Valid text 1".to_string(),
858 "Valid text 2".to_string(),
859 "Valid text 3".to_string(),
860 ];
861
862 let (embeddings, failed_texts, total_retries) = model
864 .encode_batch_with_retry(
865 &texts, 2, 50, )
868 .unwrap();
869
870 assert_eq!(embeddings.len(), texts.len());
871 assert_eq!(failed_texts.len(), 0); assert_eq!(total_retries, 0); }
874
875 #[tokio::test]
876 async fn test_phase_3d_health_check() {
877 let config = EmbeddingConfig::default();
878 let model = EmbeddingModel::new(config).await.unwrap();
879
880 let health = model.health_check();
881
882 assert_eq!(health.model_name, "sentence-transformers/all-MiniLM-L6-v2");
883 assert_eq!(health.cache_size, 0);
884 assert!(health.device_type.contains("Cpu"));
885 }
887
888 #[tokio::test]
889 async fn test_phase_3d_cache_management() {
890 let config = EmbeddingConfig::default();
891 let mut model = EmbeddingModel::new(config).await.unwrap();
892
893 let texts = vec![
895 "Cache test 1".to_string(),
896 "Cache test 2".to_string(),
897 "Cache test 3".to_string(),
898 "Cache test 4".to_string(),
899 "Cache test 5".to_string(),
900 ];
901
902 for text in &texts {
903 model.encode(text).unwrap();
904 }
905
906 assert_eq!(model.cache_size(), 5);
907
908 let stats = model.cache_stats();
910 assert_eq!(stats.size, 5);
911 assert!(stats.total_text_length > 0);
912 assert!(stats.avg_text_length > 0.0);
913 assert!(stats.estimated_memory_mb > 0.0);
914
915 model.clear_cache_selective(Some(3)); assert_eq!(model.cache_size(), 3);
918
919 model.clear_cache_selective(None);
921 assert_eq!(model.cache_size(), 0);
922 }
923
924 #[tokio::test]
925 async fn test_phase_3d_standalone_methods() {
926 let config = EmbeddingConfig::default();
927 let model = EmbeddingModel::new(config).await.unwrap();
928
929 let text = "Standalone method test";
930
931 let embedding1 = model.generate_enhanced_embedding_standalone(text).unwrap();
933 let embedding2 = model.generate_enhanced_embedding_standalone(text).unwrap();
934
935 assert_eq!(embedding1, embedding2);
937 assert_eq!(embedding1.len(), 384);
938
939 let embedding3 = model
941 .generate_placeholder_embedding_standalone(text)
942 .unwrap();
943 assert_eq!(embedding3.len(), 384);
944
945 assert_ne!(embedding1, embedding3);
947 }
948
949 #[tokio::test]
950 async fn test_phase_3d_normalization_standalone() {
951 let config = EmbeddingConfig::default();
952 let model = EmbeddingModel::new(config).await.unwrap();
953
954 let unnormalized = vec![3.0, 4.0, 0.0]; let normalized = model.normalize_embedding_standalone(unnormalized);
956
957 let norm: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
959 assert!((norm - 1.0).abs() < 1e-6);
960
961 assert!((normalized[0] - 0.6).abs() < 1e-6); assert!((normalized[1] - 0.8).abs() < 1e-6); assert!((normalized[2] - 0.0).abs() < 1e-6); }
966}