1use crate::error::{Result, LetheError};
2use crate::types::EmbeddingVector;
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use std::sync::Arc;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct EmbeddingConfig {
10 pub provider: EmbeddingProvider,
11 pub model_name: String,
12 pub dimension: usize,
13 pub batch_size: usize,
14 pub timeout_ms: u64,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub enum EmbeddingProvider {
20 TransformersJs { model_id: String },
21 Ollama { base_url: String, model: String },
22 Fallback,
23}
24
25impl Default for EmbeddingConfig {
26 fn default() -> Self {
27 Self {
28 provider: EmbeddingProvider::TransformersJs {
29 model_id: "Xenova/bge-small-en-v1.5".to_string(),
30 },
31 model_name: "bge-small-en-v1.5".to_string(),
32 dimension: 384,
33 batch_size: 32,
34 timeout_ms: 30000,
35 }
36 }
37}
38
39#[async_trait]
41pub trait EmbeddingService: Send + Sync {
42 fn name(&self) -> &str;
44
45 fn dimension(&self) -> usize;
47
48 async fn embed(&self, texts: &[String]) -> Result<Vec<EmbeddingVector>>;
50
51 async fn embed_single(&self, text: &str) -> Result<EmbeddingVector> {
53 let results = self.embed(&[text.to_string()]).await?;
54 results.into_iter().next()
55 .ok_or_else(|| LetheError::embedding("No embedding returned for single text"))
56 }
57}
58
/// Embedding backend that talks to an Ollama server over HTTP.
///
/// Only compiled when the `ollama` feature is enabled.
#[cfg(feature = "ollama")]
pub struct OllamaEmbeddingService {
    /// Server root that the `/api/...` endpoint paths are appended to.
    base_url: String,
    /// Model name sent in each embedding request.
    model: String,
    /// Vector length reported by `dimension()` and stamped on every result.
    dimension: usize,
    /// HTTP client used for all requests made by this service.
    client: reqwest::Client,
}
67
#[cfg(feature = "ollama")]
impl OllamaEmbeddingService {
    /// Creates a service for the Ollama server at `base_url` using `model`.
    ///
    /// Trailing `/` characters are stripped from `base_url`, because endpoint
    /// paths are appended as `"/api/..."` and would otherwise contain `//`.
    /// The HTTP client is built with a fixed 30 second request timeout.
    ///
    /// # Panics
    /// Panics if the HTTP client cannot be constructed.
    pub fn new(base_url: String, model: String, dimension: usize) -> Self {
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(30))
            .build()
            .expect("Failed to create HTTP client");

        // Normalise in place so "http://host:11434/" and "http://host:11434"
        // both yield "http://host:11434/api/..." request URLs.
        let mut base_url = base_url;
        let trimmed_len = base_url.trim_end_matches('/').len();
        base_url.truncate(trimmed_len);

        Self {
            base_url,
            model,
            dimension,
            client,
        }
    }

    /// Probes `GET {base_url}/api/version` with a 500 ms deadline.
    ///
    /// Returns `Ok(true)` only when a response arrives in time with a success
    /// status. A timeout or any request error is reported as `Ok(false)`; as
    /// written this method never returns `Err`.
    pub async fn test_connectivity(&self) -> Result<bool> {
        let url = format!("{}/api/version", self.base_url);
        let probe = self.client.get(&url).send();

        let reachable =
            match tokio::time::timeout(std::time::Duration::from_millis(500), probe).await {
                Ok(Ok(response)) => response.status().is_success(),
                // Request failed, or the deadline elapsed first.
                Ok(Err(_)) | Err(_) => false,
            };

        Ok(reachable)
    }
}
98
#[cfg(feature = "ollama")]
#[async_trait]
impl EmbeddingService for OllamaEmbeddingService {
    /// Always `"ollama"`.
    fn name(&self) -> &str {
        "ollama"
    }

    /// The dimension supplied at construction time.
    fn dimension(&self) -> usize {
        self.dimension
    }

    /// Embeds each text with one `POST {base_url}/api/embeddings` request,
    /// issued sequentially in input order.
    ///
    /// # Errors
    /// Returns an embedding error (aborting the remaining texts) when:
    /// - the request cannot be sent or the response body is not valid JSON,
    /// - the server answers with a non-success status,
    /// - the response has no `embedding` array,
    /// - an entry of that array is not a number, or
    /// - the array length differs from the configured dimension.
    async fn embed(&self, texts: &[String]) -> Result<Vec<EmbeddingVector>> {
        // The endpoint is the same for every text, so build it once.
        let url = format!("{}/api/embeddings", self.base_url);
        let mut embeddings = Vec::with_capacity(texts.len());

        for text in texts {
            let request_body = serde_json::json!({
                "model": self.model,
                "prompt": text,
            });

            let response = self
                .client
                .post(&url)
                .json(&request_body)
                .send()
                .await
                .map_err(|e| LetheError::embedding(format!("Ollama request failed: {}", e)))?;

            if !response.status().is_success() {
                return Err(LetheError::embedding(format!(
                    "Ollama API error: {}",
                    response.status()
                )));
            }

            let response_json: serde_json::Value = response.json().await.map_err(|e| {
                LetheError::embedding(format!("Failed to parse Ollama response: {}", e))
            })?;

            let embedding_data = response_json
                .get("embedding")
                .and_then(|e| e.as_array())
                .ok_or_else(|| LetheError::embedding("No embedding data in Ollama response"))?;

            // Reject malformed entries instead of silently substituting 0.0,
            // which would corrupt the vector without any signal to the caller.
            let data = embedding_data
                .iter()
                .map(|v| {
                    v.as_f64().map(|x| x as f32).ok_or_else(|| {
                        LetheError::embedding("Non-numeric value in Ollama embedding")
                    })
                })
                .collect::<std::result::Result<Vec<f32>, LetheError>>()?;

            // Keep `EmbeddingVector::dimension` truthful: it must match `data.len()`.
            if data.len() != self.dimension {
                return Err(LetheError::embedding(format!(
                    "Ollama returned an embedding of dimension {}, expected {}",
                    data.len(),
                    self.dimension
                )));
            }

            embeddings.push(EmbeddingVector {
                data,
                dimension: self.dimension,
            });
        }

        Ok(embeddings)
    }
}
158
/// Backend that produces all-zero vectors instead of real embeddings.
pub struct FallbackEmbeddingService {
    /// Length of every zero vector this service returns.
    dimension: usize,
}
163
164impl FallbackEmbeddingService {
165 pub fn new(dimension: usize) -> Self {
166 Self { dimension }
167 }
168}
169
170#[async_trait]
171impl EmbeddingService for FallbackEmbeddingService {
172 fn name(&self) -> &str {
173 "fallback"
174 }
175
176 fn dimension(&self) -> usize {
177 self.dimension
178 }
179
180 async fn embed(&self, texts: &[String]) -> Result<Vec<EmbeddingVector>> {
181 tracing::warn!(
182 "Using fallback zero-vector embeddings for {} texts - vector search will be disabled",
183 texts.len()
184 );
185
186 let embeddings = texts
187 .iter()
188 .map(|_| EmbeddingVector {
189 data: vec![0.0; self.dimension],
190 dimension: self.dimension,
191 })
192 .collect();
193
194 Ok(embeddings)
195 }
196}
197
/// Stateless namespace for the functions that build [`EmbeddingService`] instances.
pub struct EmbeddingServiceFactory;
200
201impl EmbeddingServiceFactory {
202 pub async fn create(config: &EmbeddingConfig) -> Result<Arc<dyn EmbeddingService>> {
204 match &config.provider {
205 EmbeddingProvider::Ollama { base_url, model } => {
206 #[cfg(feature = "ollama")]
207 {
208 let service = OllamaEmbeddingService::new(
209 base_url.clone(),
210 model.clone(),
211 config.dimension,
212 );
213
214 if service.test_connectivity().await? {
216 tracing::info!("Using Ollama embeddings with model: {}", model);
217 Ok(Arc::new(service))
218 } else {
219 tracing::warn!("Ollama not available, falling back to zero vectors");
220 Ok(Arc::new(FallbackEmbeddingService::new(config.dimension)))
221 }
222 }
223 #[cfg(not(feature = "ollama"))]
224 {
225 tracing::warn!("Ollama feature not enabled, falling back to zero vectors");
226 Ok(Arc::new(FallbackEmbeddingService::new(config.dimension)))
227 }
228 }
229 EmbeddingProvider::TransformersJs { model_id: _ } => {
230 tracing::info!("TransformersJS embeddings not implemented in Rust, using fallback");
231 Ok(Arc::new(FallbackEmbeddingService::new(config.dimension)))
232 }
233 EmbeddingProvider::Fallback => {
234 tracing::info!("Using fallback embedding service");
235 Ok(Arc::new(FallbackEmbeddingService::new(config.dimension)))
236 }
237 }
238 }
239
240 pub async fn create_with_preference(
242 preference: Option<&str>,
243 ) -> Result<Arc<dyn EmbeddingService>> {
244 let config = match preference {
245 Some("ollama") => EmbeddingConfig {
246 provider: EmbeddingProvider::Ollama {
247 base_url: "http://localhost:11434".to_string(),
248 model: "nomic-embed-text".to_string(),
249 },
250 model_name: "nomic-embed-text".to_string(),
251 dimension: 768,
252 ..Default::default()
253 },
254 Some("transformersjs") | _ => EmbeddingConfig::default(),
255 };
256
257 Self::create(&config).await
258 }
259}
260
#[cfg(test)]
mod tests {
    //! Unit tests for the embedding configuration types, the fallback and
    //! Ollama services, and the service factory.
    //!
    //! Tests that name `OllamaEmbeddingService` directly are gated on the
    //! `ollama` feature, because the type only exists when it is enabled.

    use super::*;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::Barrier;
    use tokio::time::timeout;

    // ---------------------------------------------------------------------
    // Basic fallback / factory / config behaviour
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_fallback_embedding_service() {
        let service = FallbackEmbeddingService::new(384);
        let texts = vec!["hello".to_string(), "world".to_string()];

        let embeddings = service.embed(&texts).await.unwrap();

        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].dimension, 384);
        assert_eq!(embeddings[0].data.len(), 384);
        assert!(embeddings[0].data.iter().all(|&x| x == 0.0));
    }

    #[tokio::test]
    async fn test_embedding_service_factory() {
        let config = EmbeddingConfig {
            provider: EmbeddingProvider::Fallback,
            dimension: 512,
            ..Default::default()
        };

        let service = EmbeddingServiceFactory::create(&config).await.unwrap();

        assert_eq!(service.name(), "fallback");
        assert_eq!(service.dimension(), 512);
    }

    #[test]
    fn test_embedding_config_serialization() {
        let config = EmbeddingConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: EmbeddingConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.dimension, deserialized.dimension);
        assert_eq!(config.batch_size, deserialized.batch_size);
    }

    #[tokio::test]
    async fn test_single_embedding() {
        let service = FallbackEmbeddingService::new(128);
        let embedding = service.embed_single("test text").await.unwrap();

        assert_eq!(embedding.dimension, 128);
        assert_eq!(embedding.data.len(), 128);
    }

    #[tokio::test]
    async fn test_empty_text_embedding() {
        let service = FallbackEmbeddingService::new(256);

        // Empty string.
        let embedding = service.embed_single("").await.unwrap();
        assert_eq!(embedding.dimension, 256);
        assert_eq!(embedding.data.len(), 256);

        // Whitespace-only string.
        let embedding = service.embed_single(" ").await.unwrap();
        assert_eq!(embedding.dimension, 256);
        assert!(embedding.data.iter().all(|&x| x == 0.0));
    }

    #[tokio::test]
    async fn test_large_batch_embedding() {
        let service = FallbackEmbeddingService::new(128);

        let texts: Vec<String> = (0..100).map(|i| format!("text {}", i)).collect();

        let embeddings = service.embed(&texts).await.unwrap();

        assert_eq!(embeddings.len(), 100);
        for (i, embedding) in embeddings.iter().enumerate() {
            assert_eq!(embedding.dimension, 128);
            assert_eq!(embedding.data.len(), 128);
            assert!(
                embedding.data.iter().all(|&x| x == 0.0),
                "Embedding {} should be zero vector",
                i
            );
        }
    }

    #[tokio::test]
    async fn test_embedding_vector_properties() {
        let service = FallbackEmbeddingService::new(512);
        let embedding = service.embed_single("sample text").await.unwrap();

        assert_eq!(embedding.dimension, 512);
        assert_eq!(embedding.data.len(), 512);

        assert!(embedding.data.iter().all(|&x| x.is_finite()));
        assert!(embedding.data.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_embedding_config_default_values() {
        let config = EmbeddingConfig::default();

        assert_eq!(config.dimension, 384);
        assert_eq!(config.batch_size, 32);
        assert_eq!(config.timeout_ms, 30000);
        assert_eq!(config.model_name, "bge-small-en-v1.5");

        match config.provider {
            EmbeddingProvider::TransformersJs { model_id } => {
                assert_eq!(model_id, "Xenova/bge-small-en-v1.5");
            }
            _ => panic!("Expected TransformersJs provider"),
        }
    }

    #[test]
    fn test_embedding_provider_variants() {
        let transformers_provider = EmbeddingProvider::TransformersJs {
            model_id: "test-model".to_string(),
        };

        let ollama_provider = EmbeddingProvider::Ollama {
            base_url: "http://localhost:11434".to_string(),
            model: "embeddings".to_string(),
        };

        let fallback_provider = EmbeddingProvider::Fallback;

        match transformers_provider {
            EmbeddingProvider::TransformersJs { model_id } => assert_eq!(model_id, "test-model"),
            _ => panic!("Expected TransformersJs variant"),
        }

        match ollama_provider {
            EmbeddingProvider::Ollama { base_url, model } => {
                assert_eq!(base_url, "http://localhost:11434");
                assert_eq!(model, "embeddings");
            }
            _ => panic!("Expected Ollama variant"),
        }

        assert!(
            matches!(fallback_provider, EmbeddingProvider::Fallback),
            "Expected Fallback variant"
        );
    }

    #[tokio::test]
    async fn test_embedding_service_interface() {
        let service = FallbackEmbeddingService::new(256);

        assert_eq!(service.name(), "fallback");

        assert_eq!(service.dimension(), 256);

        let texts = vec!["text1".to_string(), "text2".to_string()];
        let embeddings = service.embed(&texts).await.unwrap();
        assert_eq!(embeddings.len(), 2);

        let single_embedding = service.embed_single("single").await.unwrap();
        assert_eq!(single_embedding.dimension, 256);
    }

    #[test]
    fn test_embedding_config_clone_and_debug() {
        let config = EmbeddingConfig::default();

        let cloned_config = config.clone();
        assert_eq!(config.dimension, cloned_config.dimension);
        assert_eq!(config.batch_size, cloned_config.batch_size);

        let debug_str = format!("{:?}", config);
        assert!(debug_str.contains("EmbeddingConfig"));
        assert!(debug_str.contains("dimension"));
        assert!(debug_str.contains("batch_size"));
    }

    #[tokio::test]
    async fn test_embedding_error_scenarios() {
        let service = FallbackEmbeddingService::new(64);

        let long_text = "a".repeat(10000);
        let embedding = service.embed_single(&long_text).await.unwrap();
        assert_eq!(embedding.dimension, 64);

        let special_text = "!@#$%^&*()_+-=[]{}|;':\",./<>?`~";
        let embedding = service.embed_single(special_text).await.unwrap();
        assert_eq!(embedding.dimension, 64);

        let unicode_text = "Hello 世界 🌍 тест";
        let embedding = service.embed_single(unicode_text).await.unwrap();
        assert_eq!(embedding.dimension, 64);
    }

    // ---------------------------------------------------------------------
    // Ollama service (requires the `ollama` feature to compile)
    // ---------------------------------------------------------------------

    #[cfg(feature = "ollama")]
    #[test]
    fn test_ollama_embedding_service_creation() {
        let service = OllamaEmbeddingService::new(
            "http://localhost:11434".to_string(),
            "nomic-embed-text".to_string(),
            768,
        );

        assert_eq!(service.name(), "ollama");
        assert_eq!(service.dimension(), 768);
        assert_eq!(service.base_url, "http://localhost:11434");
        assert_eq!(service.model, "nomic-embed-text");
    }

    #[cfg(feature = "ollama")]
    #[tokio::test]
    async fn test_ollama_connectivity_timeout() {
        let service = OllamaEmbeddingService::new(
            "http://unreachable-host:11434".to_string(),
            "test-model".to_string(),
            768,
        );

        let start = std::time::Instant::now();
        let result = service.test_connectivity().await.unwrap();
        let duration = start.elapsed();

        assert!(!result);
        // The probe has a 500 ms deadline, so it must give up well within a second.
        assert!(duration < Duration::from_secs(1));
    }

    #[cfg(feature = "ollama")]
    #[tokio::test]
    async fn test_ollama_connectivity_invalid_url() {
        let service = OllamaEmbeddingService::new(
            "invalid-url".to_string(),
            "test-model".to_string(),
            768,
        );

        let result = service.test_connectivity().await.unwrap();
        assert!(!result);
    }

    #[cfg(feature = "ollama")]
    #[tokio::test]
    async fn test_ollama_embed_network_error() {
        let service = OllamaEmbeddingService::new(
            "http://unreachable-host:11434".to_string(),
            "test-model".to_string(),
            768,
        );

        let texts = vec!["test text".to_string()];
        let result = service.embed(&texts).await;

        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("Ollama request failed"));
    }

    #[cfg(feature = "ollama")]
    #[tokio::test]
    async fn test_ollama_embed_single_delegated() {
        let service = OllamaEmbeddingService::new(
            "http://unreachable-host:11434".to_string(),
            "test-model".to_string(),
            384,
        );

        // `embed_single` goes through `embed`, so it surfaces the same error.
        let result = service.embed_single("test").await;

        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("Ollama request failed"));
    }

    // NOTE(review): despite its name, this test only checks the values stored
    // by the constructor; no empty response is simulated.
    #[cfg(feature = "ollama")]
    #[tokio::test]
    async fn test_ollama_embed_empty_response_error() {
        let service = OllamaEmbeddingService::new(
            "http://localhost:11434".to_string(),
            "test-model".to_string(),
            384,
        );

        assert_eq!(service.name(), "ollama");
        assert_eq!(service.model, "test-model");
        assert_eq!(service.base_url, "http://localhost:11434");
    }

    // ---------------------------------------------------------------------
    // Factory
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_factory_create_ollama_with_connectivity_test() {
        let config = EmbeddingConfig {
            provider: EmbeddingProvider::Ollama {
                base_url: "http://unreachable-host:11434".to_string(),
                model: "test-model".to_string(),
            },
            dimension: 768,
            ..Default::default()
        };

        let service = EmbeddingServiceFactory::create(&config).await.unwrap();

        // The host is unreachable (or the feature is off), so the factory falls back.
        assert_eq!(service.name(), "fallback");
        assert_eq!(service.dimension(), 768);
    }

    #[tokio::test]
    async fn test_factory_create_transformers_js() {
        let config = EmbeddingConfig {
            provider: EmbeddingProvider::TransformersJs {
                model_id: "test-model".to_string(),
            },
            dimension: 384,
            ..Default::default()
        };

        let service = EmbeddingServiceFactory::create(&config).await.unwrap();

        assert_eq!(service.name(), "fallback");
        assert_eq!(service.dimension(), 384);
    }

    #[tokio::test]
    async fn test_factory_create_explicit_fallback() {
        let config = EmbeddingConfig {
            provider: EmbeddingProvider::Fallback,
            dimension: 1024,
            ..Default::default()
        };

        let service = EmbeddingServiceFactory::create(&config).await.unwrap();

        assert_eq!(service.name(), "fallback");
        assert_eq!(service.dimension(), 1024);
    }

    #[tokio::test]
    async fn test_factory_create_with_preference_ollama() {
        let service = EmbeddingServiceFactory::create_with_preference(Some("ollama"))
            .await
            .unwrap();

        // Either outcome is valid depending on whether a local Ollama is running;
        // both must carry the 768-dimension Ollama preset.
        match service.name() {
            "ollama" => {
                assert_eq!(service.dimension(), 768);
            }
            "fallback" => {
                assert_eq!(service.dimension(), 768);
            }
            _ => panic!("Unexpected service name: {}", service.name()),
        }
    }

    #[tokio::test]
    async fn test_factory_create_with_preference_transformers() {
        let service = EmbeddingServiceFactory::create_with_preference(Some("transformersjs"))
            .await
            .unwrap();

        assert_eq!(service.name(), "fallback");
        assert_eq!(service.dimension(), 384);
    }

    #[tokio::test]
    async fn test_factory_create_with_preference_none() {
        let service = EmbeddingServiceFactory::create_with_preference(None)
            .await
            .unwrap();

        assert_eq!(service.name(), "fallback");
        assert_eq!(service.dimension(), 384);
    }

    #[tokio::test]
    async fn test_factory_create_with_preference_unknown() {
        let service = EmbeddingServiceFactory::create_with_preference(Some("unknown"))
            .await
            .unwrap();

        assert_eq!(service.name(), "fallback");
        assert_eq!(service.dimension(), 384);
    }

    // ---------------------------------------------------------------------
    // Trait default method and edge cases
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_embed_single_empty_result_error() {
        /// Service whose `embed` always succeeds with no vectors.
        struct EmptyEmbeddingService;

        #[async_trait]
        impl EmbeddingService for EmptyEmbeddingService {
            fn name(&self) -> &str {
                "empty"
            }

            fn dimension(&self) -> usize {
                384
            }

            async fn embed(&self, _texts: &[String]) -> Result<Vec<EmbeddingVector>> {
                Ok(vec![])
            }
        }

        let service = EmptyEmbeddingService;
        let result = service.embed_single("test").await;

        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("No embedding returned for single text"));
    }

    #[tokio::test]
    async fn test_fallback_service_with_maximum_dimensions() {
        let service = FallbackEmbeddingService::new(4096);
        let embedding = service.embed_single("test").await.unwrap();

        assert_eq!(embedding.dimension, 4096);
        assert_eq!(embedding.data.len(), 4096);
        assert!(embedding.data.iter().all(|&x| x == 0.0));
    }

    #[tokio::test]
    async fn test_fallback_service_with_minimum_dimensions() {
        let service = FallbackEmbeddingService::new(1);
        let embedding = service.embed_single("test").await.unwrap();

        assert_eq!(embedding.dimension, 1);
        assert_eq!(embedding.data.len(), 1);
        assert_eq!(embedding.data[0], 0.0);
    }

    #[tokio::test]
    async fn test_batch_processing_edge_cases() {
        let service = FallbackEmbeddingService::new(256);

        let empty_texts: Vec<String> = vec![];
        let embeddings = service.embed(&empty_texts).await.unwrap();
        assert_eq!(embeddings.len(), 0);

        let single_text = vec!["solo".to_string()];
        let embeddings = service.embed(&single_text).await.unwrap();
        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].dimension, 256);
    }

    #[tokio::test]
    async fn test_concurrent_embedding_operations() {
        let service = Arc::new(FallbackEmbeddingService::new(128));
        let barrier = Arc::new(Barrier::new(10));

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let service = service.clone();
                let barrier = barrier.clone();

                tokio::spawn(async move {
                    // Release all ten tasks at the same time.
                    barrier.wait().await;
                    service
                        .embed_single(&format!("concurrent text {}", i))
                        .await
                })
            })
            .collect();

        for handle in handles {
            let result = handle.await.unwrap().unwrap();
            assert_eq!(result.dimension, 128);
        }
    }

    #[tokio::test]
    async fn test_embedding_operations_under_timeout() {
        let service = FallbackEmbeddingService::new(256);

        let result = timeout(Duration::from_millis(100), service.embed_single("test")).await;

        assert!(result.is_ok());
        let embedding = result.unwrap().unwrap();
        assert_eq!(embedding.dimension, 256);
    }

    #[tokio::test]
    async fn test_massive_text_processing() {
        let service = FallbackEmbeddingService::new(64);

        let massive_text = "word ".repeat(100_000);
        let embedding = service.embed_single(&massive_text).await.unwrap();

        assert_eq!(embedding.dimension, 64);
        assert!(embedding.data.iter().all(|&x| x == 0.0));
    }

    #[tokio::test]
    async fn test_mixed_content_batch() {
        let service = FallbackEmbeddingService::new(128);

        let mixed_texts = vec![
            "".to_string(),
            "Normal text".to_string(),
            "🚀🌟💻".to_string(),
            "Mixed 🎉 content!".to_string(),
            "Very long ".repeat(1000),
            "تجريب العربية".to_string(),
            "测试中文".to_string(),
            "Тест кириллицы".to_string(),
        ];

        let embeddings = service.embed(&mixed_texts).await.unwrap();

        assert_eq!(embeddings.len(), 8);
        for embedding in &embeddings {
            assert_eq!(embedding.dimension, 128);
            assert!(embedding.data.iter().all(|&x| x == 0.0));
        }
    }

    #[tokio::test]
    async fn test_stress_test_rapid_requests() {
        let service = Arc::new(FallbackEmbeddingService::new(64));

        for i in 0..100 {
            let embedding = service
                .embed_single(&format!("stress test {}", i))
                .await
                .unwrap();
            assert_eq!(embedding.dimension, 64);
        }
    }

    // ---------------------------------------------------------------------
    // Configuration values and serialization
    // ---------------------------------------------------------------------

    #[test]
    fn test_embedding_config_custom_values() {
        let config = EmbeddingConfig {
            provider: EmbeddingProvider::Ollama {
                base_url: "http://custom:8080".to_string(),
                model: "custom-model".to_string(),
            },
            model_name: "custom-model".to_string(),
            dimension: 1536,
            batch_size: 64,
            timeout_ms: 60000,
        };

        assert_eq!(config.dimension, 1536);
        assert_eq!(config.batch_size, 64);
        assert_eq!(config.timeout_ms, 60000);
        assert_eq!(config.model_name, "custom-model");

        match config.provider {
            EmbeddingProvider::Ollama { base_url, model } => {
                assert_eq!(base_url, "http://custom:8080");
                assert_eq!(model, "custom-model");
            }
            _ => panic!("Expected Ollama provider"),
        }
    }

    #[test]
    fn test_embedding_provider_serialization() {
        // TransformersJs round trip.
        let transformers = EmbeddingProvider::TransformersJs {
            model_id: "test-model".to_string(),
        };
        let json = serde_json::to_string(&transformers).unwrap();
        let deserialized: EmbeddingProvider = serde_json::from_str(&json).unwrap();

        match deserialized {
            EmbeddingProvider::TransformersJs { model_id } => {
                assert_eq!(model_id, "test-model");
            }
            _ => panic!("Expected TransformersJs provider"),
        }

        // Ollama round trip.
        let ollama = EmbeddingProvider::Ollama {
            base_url: "http://test:11434".to_string(),
            model: "test-model".to_string(),
        };
        let json = serde_json::to_string(&ollama).unwrap();
        let deserialized: EmbeddingProvider = serde_json::from_str(&json).unwrap();

        match deserialized {
            EmbeddingProvider::Ollama { base_url, model } => {
                assert_eq!(base_url, "http://test:11434");
                assert_eq!(model, "test-model");
            }
            _ => panic!("Expected Ollama provider"),
        }

        // Fallback round trip.
        let fallback = EmbeddingProvider::Fallback;
        let json = serde_json::to_string(&fallback).unwrap();
        let deserialized: EmbeddingProvider = serde_json::from_str(&json).unwrap();

        assert!(
            matches!(deserialized, EmbeddingProvider::Fallback),
            "Expected Fallback provider"
        );
    }

    #[test]
    fn test_embedding_config_complex_serialization() {
        let config = EmbeddingConfig {
            provider: EmbeddingProvider::Ollama {
                base_url: "http://production:11434".to_string(),
                model: "production-model".to_string(),
            },
            model_name: "production-model".to_string(),
            dimension: 2048,
            batch_size: 128,
            timeout_ms: 45000,
        };

        let json = serde_json::to_string_pretty(&config).unwrap();

        let deserialized: EmbeddingConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.dimension, deserialized.dimension);
        assert_eq!(config.batch_size, deserialized.batch_size);
        assert_eq!(config.timeout_ms, deserialized.timeout_ms);
        assert_eq!(config.model_name, deserialized.model_name);

        match (&config.provider, &deserialized.provider) {
            (
                EmbeddingProvider::Ollama {
                    base_url: url1,
                    model: model1,
                },
                EmbeddingProvider::Ollama {
                    base_url: url2,
                    model: model2,
                },
            ) => {
                assert_eq!(url1, url2);
                assert_eq!(model1, model2);
            }
            _ => panic!("Provider mismatch during serialization"),
        }
    }

    // ---------------------------------------------------------------------
    // Performance, memory and concurrency smoke tests
    // ---------------------------------------------------------------------

    // NOTE(review): every assertion after the length check depends on
    // wall-clock measurements and may be flaky on a heavily loaded machine.
    #[tokio::test]
    async fn test_embedding_performance_characteristics() {
        let service = FallbackEmbeddingService::new(384);

        let start = std::time::Instant::now();
        let _embedding = service.embed_single("performance test").await.unwrap();
        let single_duration = start.elapsed();

        let texts: Vec<String> = (0..100).map(|i| format!("batch text {}", i)).collect();
        let start = std::time::Instant::now();
        let embeddings = service.embed(&texts).await.unwrap();
        let batch_duration = start.elapsed();

        assert_eq!(embeddings.len(), 100);

        assert!(single_duration < Duration::from_millis(10));
        assert!(batch_duration < Duration::from_millis(100));

        let per_item_batch = batch_duration.as_nanos() / 100;
        let single_item = single_duration.as_nanos();

        assert!(per_item_batch <= single_item * 2);
    }

    #[tokio::test]
    async fn test_memory_efficiency_large_batches() {
        let service = FallbackEmbeddingService::new(1024);

        for chunk in 0..10 {
            let texts: Vec<String> = (0..50)
                .map(|i| format!("chunk {} item {}", chunk, i))
                .collect();

            let embeddings = service.embed(&texts).await.unwrap();
            assert_eq!(embeddings.len(), 50);

            for embedding in embeddings {
                assert_eq!(embedding.dimension, 1024);
                assert_eq!(embedding.data.len(), 1024);
            }
        }
    }

    #[tokio::test]
    async fn test_concurrent_service_usage() {
        let service = Arc::new(FallbackEmbeddingService::new(256));

        let tasks: Vec<_> = (0..20)
            .map(|task_id| {
                let service = service.clone();
                tokio::spawn(async move {
                    let mut results = Vec::new();

                    for i in 0..5 {
                        let text = format!("task {} iteration {}", task_id, i);
                        let embedding = service.embed_single(&text).await?;
                        results.push(embedding);
                    }

                    Ok::<Vec<EmbeddingVector>, LetheError>(results)
                })
            })
            .collect();

        for task in tasks {
            let results = task.await.unwrap().unwrap();
            assert_eq!(results.len(), 5);

            for embedding in results {
                assert_eq!(embedding.dimension, 256);
            }
        }
    }

    // ---------------------------------------------------------------------
    // Trait objects and broader factory coverage
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_service_trait_object_usage() {
        let services: Vec<Arc<dyn EmbeddingService>> = vec![
            Arc::new(FallbackEmbeddingService::new(128)),
            Arc::new(FallbackEmbeddingService::new(256)),
            Arc::new(FallbackEmbeddingService::new(512)),
        ];

        for (i, service) in services.iter().enumerate() {
            let expected_dim = match i {
                0 => 128,
                1 => 256,
                2 => 512,
                _ => unreachable!(),
            };

            assert_eq!(service.name(), "fallback");
            assert_eq!(service.dimension(), expected_dim);

            let embedding = service.embed_single("trait object test").await.unwrap();
            assert_eq!(embedding.dimension, expected_dim);
        }
    }

    #[tokio::test]
    async fn test_factory_pattern_comprehensive() {
        let configs = vec![
            EmbeddingConfig {
                provider: EmbeddingProvider::Fallback,
                dimension: 384,
                ..Default::default()
            },
            EmbeddingConfig {
                provider: EmbeddingProvider::TransformersJs {
                    model_id: "test-transformers".to_string(),
                },
                dimension: 768,
                ..Default::default()
            },
            EmbeddingConfig {
                provider: EmbeddingProvider::Ollama {
                    base_url: "http://test:11434".to_string(),
                    model: "test-ollama".to_string(),
                },
                dimension: 1024,
                ..Default::default()
            },
        ];

        for config in configs {
            let service = EmbeddingServiceFactory::create(&config).await.unwrap();

            // No Ollama server is expected at the test URL, so every config
            // resolves to the fallback service.
            assert_eq!(service.name(), "fallback");
            assert_eq!(service.dimension(), config.dimension);

            let embedding = service.embed_single("factory test").await.unwrap();
            assert_eq!(embedding.dimension, config.dimension);
        }
    }

    #[tokio::test]
    async fn test_preference_based_factory_all_options() {
        let preferences = vec![
            None,
            Some("ollama"),
            Some("transformersjs"),
            Some("transformers"),
            Some("unknown"),
            Some(""),
        ];

        for preference in preferences {
            let service = EmbeddingServiceFactory::create_with_preference(preference)
                .await
                .unwrap();

            match (preference, service.name()) {
                (Some("ollama"), "ollama") => {
                    assert_eq!(service.dimension(), 768);
                }
                (Some("ollama"), "fallback") => {
                    assert_eq!(service.dimension(), 768);
                }
                (_, "fallback") => {
                    assert!(service.dimension() > 0);
                }
                _ => {
                    panic!(
                        "Unexpected service name '{}' for preference '{:?}'",
                        service.name(),
                        preference
                    );
                }
            }

            let embedding_result = service.embed_single("preference test").await;
            match embedding_result {
                Ok(embedding) => {
                    assert!(embedding.dimension > 0);
                }
                Err(e) => {
                    // A reachable Ollama server may still reject the request
                    // (for example if the model is not installed); in that case
                    // verify that the fallback path works instead.
                    if service.name() == "ollama" && e.to_string().contains("Ollama API error") {
                        let fallback_service = FallbackEmbeddingService::new(service.dimension());
                        let embedding = fallback_service
                            .embed_single("preference test")
                            .await
                            .unwrap();
                        assert!(embedding.dimension > 0);
                    } else {
                        panic!("Unexpected error: {}", e);
                    }
                }
            }
        }
    }

    #[test]
    fn test_embedding_config_edge_cases() {
        let mut config = EmbeddingConfig::default();

        config.timeout_ms = 0;
        assert_eq!(config.timeout_ms, 0);

        config.dimension = 512;
        assert_eq!(config.dimension, 512);

        config.provider = EmbeddingProvider::Ollama {
            base_url: "http://localhost:11434".to_string(),
            model: "nomic-embed-text".to_string(),
        };
        if let EmbeddingProvider::Ollama { base_url, model } = &config.provider {
            assert_eq!(base_url, "http://localhost:11434");
            assert_eq!(model, "nomic-embed-text");
        } else {
            panic!("Expected Ollama provider");
        }

        config.batch_size = 64;
        assert_eq!(config.batch_size, 64);
    }
}