1use crate::{embeddings::EmbeddingGenerator, similarity::SimilarityMetric, Vector};
9
10use anyhow::{anyhow, Context, Result};
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet};
13use std::sync::{Arc, RwLock};
14use tracing::{info, span, Level};
15
/// Top-level settings for cross-language processing and search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossLanguageConfig {
    /// Language codes (e.g. `"en"`, `"es"`) for which aligned vectors are produced.
    pub supported_languages: Vec<String>,
    /// Language assigned to content when `enable_language_detection` is `false`.
    pub primary_language: String,
    /// When `true`, the language of incoming content is detected instead of assumed.
    pub enable_language_detection: bool,
    /// How per-language aligned vectors are derived from a piece of content.
    pub alignment_strategy: AlignmentStrategy,
    /// Optional translation settings; its `max_cache_size` bounds the translation cache
    /// (a limit of 10000 entries is used when this is `None`).
    pub translation_config: Option<TranslationConfig>,
    /// Settings describing the multilingual embedding model.
    /// NOTE(review): not read by any code visible in this file.
    pub multilingual_embeddings: MultilingualEmbeddingConfig,
    /// Minimum similarity a candidate must reach to appear in cross-language search results.
    pub cross_lingual_threshold: f32,
}
34
35impl Default for CrossLanguageConfig {
36 fn default() -> Self {
37 Self {
38 supported_languages: vec![
39 "en".to_string(), "es".to_string(), "fr".to_string(), "de".to_string(), "it".to_string(), "pt".to_string(), "ru".to_string(), "zh".to_string(), "ja".to_string(), "ar".to_string(), ],
50 primary_language: "en".to_string(),
51 enable_language_detection: true,
52 alignment_strategy: AlignmentStrategy::MultilingualEmbeddings,
53 translation_config: None,
54 multilingual_embeddings: MultilingualEmbeddingConfig::default(),
55 cross_lingual_threshold: 0.6,
56 }
57 }
58}
59
/// Strategy used to produce a vector for each target language.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum AlignmentStrategy {
    /// Embed the original text prefixed with a `[lang]` tag for each target language.
    MultilingualEmbeddings,
    /// Translate the text into each target language and embed the translation.
    TranslationBased,
    /// Element-wise average of the multilingual embedding and the translation embedding.
    Hybrid,
    /// Apply a previously learned transformation matrix to the source vector; the
    /// source vector is reused unchanged when no mapping exists for the language pair.
    LearnedMappings,
}
72
/// Settings for the translation backend and its cache.
///
/// NOTE(review): of these fields only `max_cache_size` is read by the code visible
/// in this file; the remaining descriptions are inferred from the field names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranslationConfig {
    /// Which translation backend to use.
    pub provider: TranslationProvider,
    /// Optional service endpoint (presumably a URL — confirm against the caller).
    pub endpoint: Option<String>,
    /// Optional credential for the translation service.
    pub api_key: Option<String>,
    /// Toggle for caching translation results.
    pub enable_caching: bool,
    /// Upper bound on the number of entries kept in the translation cache.
    pub max_cache_size: usize,
}
87
/// Identifies a translation backend.
///
/// NOTE(review): nothing visible in this file dispatches on this value, so the
/// variant descriptions below are inferred from their names.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum TranslationProvider {
    /// Google-hosted translation service.
    Google,
    /// Microsoft-hosted translation service.
    Microsoft,
    /// AWS-hosted translation service.
    Aws,
    /// Locally hosted translation.
    Local,
}
100
/// Description of the multilingual embedding model.
///
/// NOTE(review): none of these fields are read by the code visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultilingualEmbeddingConfig {
    /// Model identifier (the default is a sentence-transformers multilingual MiniLM model).
    pub model_name: String,
    /// Embedding dimensionality (`384` by default).
    pub dimensions: usize,
    /// Normalization applied to embeddings (`L2` by default).
    pub normalization: NormalizationStrategy,
    /// Preprocessing step names per language — presumably keyed by language code; TODO confirm.
    pub language_preprocessing: HashMap<String, Vec<String>>,
}
113
114impl Default for MultilingualEmbeddingConfig {
115 fn default() -> Self {
116 Self {
117 model_name: "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2".to_string(),
118 dimensions: 384,
119 normalization: NormalizationStrategy::L2,
120 language_preprocessing: HashMap::new(),
121 }
122 }
123}
124
/// Normalization scheme for embedding vectors.
///
/// NOTE(review): no code visible in this file applies these; the variant
/// descriptions are inferred from their names.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum NormalizationStrategy {
    /// Scale to unit L2 norm.
    L2,
    /// Subtract the mean.
    MeanCentering,
    /// Standardize (presumably zero mean, unit variance — confirm).
    Standardization,
    /// Leave vectors untouched.
    None,
}
137
/// Outcome of a language-detection attempt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LanguageDetection {
    /// Code of the language judged most likely.
    pub language: String,
    /// Score the detector attaches to `language`.
    pub confidence: f32,
    /// Other candidate languages, each paired with its score.
    pub alternatives: Vec<(String, f32)>,
}
148
/// A piece of content together with its language and language-aligned vectors.
#[derive(Debug, Clone)]
pub struct CrossLanguageContent {
    /// Caller-supplied identifier.
    pub id: String,
    /// Original text.
    pub text: String,
    /// Language code of `text`, either detected or taken from the configured primary language.
    pub language: String,
    /// Confidence attached to `language`.
    pub language_confidence: f32,
    /// Embedding of the original text, if one was generated.
    pub vector: Option<Vector>,
    /// One vector per target language, keyed by language code.
    pub aligned_vectors: HashMap<String, Vector>,
}
165
/// Detects content language, builds language-aligned vectors and runs
/// cross-language similarity search.
pub struct CrossLanguageAligner {
    /// Settings controlling detection, alignment and search.
    config: CrossLanguageConfig,
    /// Detector used when language detection is enabled.
    language_detector: Box<dyn LanguageDetector + Send + Sync>,
    /// Backend that turns text into vectors.
    embedding_generator: Box<dyn EmbeddingGenerator + Send + Sync>,
    /// Translations keyed by `"source:target:text"`.
    translation_cache: Arc<RwLock<HashMap<String, String>>>,
    /// Learned mappings keyed by `"source:target"`.
    alignment_mappings: Arc<RwLock<HashMap<String, AlignmentMapping>>>,
    /// NOTE(review): only read (for its length) by `get_language_statistics`;
    /// no code visible in this file writes to it.
    multilingual_embeddings: Arc<RwLock<HashMap<String, Vector>>>,
}
175
/// Abstraction over language-detection backends.
pub trait LanguageDetector {
    /// Determines the language of `text`.
    ///
    /// # Errors
    /// Implementations may return an error when detection fails.
    fn detect_language(&self, text: &str) -> Result<LanguageDetection>;

    /// Reports whether `language` is a code this detector supports.
    fn is_supported(&self, language: &str) -> bool;
}
184
/// Character-heuristic language detector.
pub struct SimpleLanguageDetector {
    /// Language codes for which `is_supported` returns `true`.
    supported_languages: HashSet<String>,
}
189
190impl SimpleLanguageDetector {
191 pub fn new(supported_languages: Vec<String>) -> Self {
192 Self {
193 supported_languages: supported_languages.into_iter().collect(),
194 }
195 }
196}
197
198impl LanguageDetector for SimpleLanguageDetector {
199 fn detect_language(&self, text: &str) -> Result<LanguageDetection> {
200 let text_lower = text.to_lowercase();
202
203 let language = if text_lower
205 .chars()
206 .any(|c| matches!(c, 'ñ' | 'ü' | 'é' | 'á' | 'í' | 'ó' | 'ú'))
207 {
208 "es" } else if text_lower
210 .chars()
211 .any(|c| matches!(c, 'ç' | 'à' | 'è' | 'ù' | 'ê' | 'ô'))
212 {
213 "fr" } else if text_lower
215 .chars()
216 .any(|c| matches!(c, 'ä' | 'ö' | 'ü' | 'ß'))
217 {
218 "de" } else if text_lower
220 .chars()
221 .any(|c| ('\u{4e00}'..='\u{9fff}').contains(&c))
222 {
223 "zh" } else if text_lower
225 .chars()
226 .any(|c| ('\u{3040}'..='\u{309f}').contains(&c))
227 {
228 "ja" } else if text_lower
230 .chars()
231 .any(|c| ('\u{0600}'..='\u{06ff}').contains(&c))
232 {
233 "ar" } else if text_lower
235 .chars()
236 .any(|c| ('\u{0400}'..='\u{04ff}').contains(&c))
237 {
238 "ru" } else {
240 "en" };
242
243 let confidence = if language == "en" { 0.7 } else { 0.8 };
244
245 Ok(LanguageDetection {
246 language: language.to_string(),
247 confidence,
248 alternatives: vec![("en".to_string(), 0.3)],
249 })
250 }
251
252 fn is_supported(&self, language: &str) -> bool {
253 self.supported_languages.contains(language)
254 }
255}
256
/// A learned mapping from one language's vector space to another's.
#[derive(Debug, Clone)]
pub struct AlignmentMapping {
    /// Language code of the vectors being mapped from.
    pub source_language: String,
    /// Language code of the vectors being mapped to.
    pub target_language: String,
    /// Row-major matrix applied to source vectors (each output element is the dot
    /// product of one row with the input); `None` when no matrix is available.
    pub transformation_matrix: Option<Vec<Vec<f32>>>,
    /// `(source text, target text)` pairs the mapping was learned from.
    pub translation_pairs: Vec<(String, String)>,
    /// Mean cosine similarity between mapped source vectors and their targets.
    pub quality_score: f32,
}
271
/// One hit returned by cross-language search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossLanguageSearchResult {
    /// Identifier of the matching content.
    pub id: String,
    /// Best similarity found for this content (same-language or cross-lingual).
    pub similarity: f32,
    /// Language code of the matching content.
    pub language: String,
    /// Original text of the matching content.
    pub text: String,
    /// Translation of `text`, when available (`cross_language_search` leaves this `None`).
    pub translated_text: Option<String>,
    /// Named cross-lingual scores; `cross_language_search` fills the `"cosine"` entry
    /// when the content has an aligned vector for the query language.
    pub cross_lingual_metrics: HashMap<String, f32>,
}
288
289impl CrossLanguageAligner {
290 pub fn new(
292 config: CrossLanguageConfig,
293 embedding_generator: Box<dyn EmbeddingGenerator + Send + Sync>,
294 ) -> Self {
295 let language_detector = Box::new(SimpleLanguageDetector::new(
296 config.supported_languages.clone(),
297 ));
298
299 Self {
300 config,
301 language_detector,
302 embedding_generator,
303 translation_cache: Arc::new(RwLock::new(HashMap::new())),
304 alignment_mappings: Arc::new(RwLock::new(HashMap::new())),
305 multilingual_embeddings: Arc::new(RwLock::new(HashMap::new())),
306 }
307 }
308
309 pub async fn process_content(&self, content: &str, id: &str) -> Result<CrossLanguageContent> {
311 let span = span!(Level::INFO, "process_content", content_id = %id);
312 let _enter = span.enter();
313
314 let detection = if self.config.enable_language_detection {
316 self.language_detector.detect_language(content)?
317 } else {
318 LanguageDetection {
319 language: self.config.primary_language.clone(),
320 confidence: 1.0,
321 alternatives: Vec::new(),
322 }
323 };
324
325 let embeddable_content = crate::embeddings::EmbeddableContent::Text(content.to_string());
327 let vector = self
328 .embedding_generator
329 .generate(&embeddable_content)
330 .context("Failed to generate embedding")?;
331
332 let aligned_vectors = self
334 .create_aligned_vectors(content, &detection.language, &vector)
335 .await?;
336
337 Ok(CrossLanguageContent {
338 id: id.to_string(),
339 text: content.to_string(),
340 language: detection.language,
341 language_confidence: detection.confidence,
342 vector: Some(vector),
343 aligned_vectors,
344 })
345 }
346
347 async fn create_aligned_vectors(
349 &self,
350 content: &str,
351 source_language: &str,
352 source_vector: &Vector,
353 ) -> Result<HashMap<String, Vector>> {
354 let mut aligned_vectors = HashMap::new();
355
356 match self.config.alignment_strategy {
357 AlignmentStrategy::MultilingualEmbeddings => {
358 for target_lang in &self.config.supported_languages {
360 if target_lang != source_language {
361 let aligned_vector =
362 self.create_multilingual_embedding(content, target_lang)?;
363 aligned_vectors.insert(target_lang.clone(), aligned_vector);
364 }
365 }
366 }
367 AlignmentStrategy::TranslationBased => {
368 for target_lang in &self.config.supported_languages {
370 if target_lang != source_language {
371 let translated_text = self
372 .translate_text(content, source_language, target_lang)
373 .await?;
374 let embeddable_content =
375 crate::embeddings::EmbeddableContent::Text(translated_text);
376 let translated_vector =
377 self.embedding_generator.generate(&embeddable_content)?;
378 aligned_vectors.insert(target_lang.clone(), translated_vector);
379 }
380 }
381 }
382 AlignmentStrategy::Hybrid => {
383 for target_lang in &self.config.supported_languages {
385 if target_lang != source_language {
386 let multilingual_vector =
387 self.create_multilingual_embedding(content, target_lang)?;
388 let translated_text = self
389 .translate_text(content, source_language, target_lang)
390 .await?;
391 let embeddable_content =
392 crate::embeddings::EmbeddableContent::Text(translated_text);
393 let translated_vector =
394 self.embedding_generator.generate(&embeddable_content)?;
395
396 let combined_vector =
398 self.combine_vectors(&multilingual_vector, &translated_vector)?;
399 aligned_vectors.insert(target_lang.clone(), combined_vector);
400 }
401 }
402 }
403 AlignmentStrategy::LearnedMappings => {
404 for target_lang in &self.config.supported_languages {
406 if target_lang != source_language {
407 let mapped_vector = self.apply_learned_mapping(
408 source_vector,
409 source_language,
410 target_lang,
411 )?;
412 aligned_vectors.insert(target_lang.clone(), mapped_vector);
413 }
414 }
415 }
416 }
417
418 Ok(aligned_vectors)
419 }
420
421 fn create_multilingual_embedding(
423 &self,
424 content: &str,
425 target_language: &str,
426 ) -> Result<Vector> {
427 let prefixed_content = format!("[{target_language}] {content}");
429 let embeddable_content = crate::embeddings::EmbeddableContent::Text(prefixed_content);
430 self.embedding_generator.generate(&embeddable_content)
431 }
432
433 async fn translate_text(
435 &self,
436 text: &str,
437 source_lang: &str,
438 target_lang: &str,
439 ) -> Result<String> {
440 let cache_key = format!("{source_lang}:{target_lang}:{text}");
441
442 {
444 let cache = self
445 .translation_cache
446 .read()
447 .expect("translation cache lock should not be poisoned");
448 if let Some(cached_translation) = cache.get(&cache_key) {
449 return Ok(cached_translation.clone());
450 }
451 }
452
453 let translated = match (source_lang, target_lang) {
455 ("en", "es") => format!("[ES] {text}"),
456 ("en", "fr") => format!("[FR] {text}"),
457 ("en", "de") => format!("[DE] {text}"),
458 ("es", "en") => text.replace("[ES]", "[EN]"),
459 ("fr", "en") => text.replace("[FR]", "[EN]"),
460 ("de", "en") => text.replace("[DE]", "[EN]"),
461 _ => {
462 let upper_lang = target_lang.to_uppercase();
463 format!("[{upper_lang}] {text}")
464 }
465 };
466
467 {
469 let mut cache = self
470 .translation_cache
471 .write()
472 .expect("translation cache lock should not be poisoned");
473 if cache.len()
474 >= self
475 .config
476 .translation_config
477 .as_ref()
478 .map(|c| c.max_cache_size)
479 .unwrap_or(10000)
480 {
481 if let Some(key) = cache.keys().next().cloned() {
483 cache.remove(&key);
484 }
485 }
486 cache.insert(cache_key, translated.clone());
487 }
488
489 Ok(translated)
490 }
491
492 fn combine_vectors(&self, vector1: &Vector, vector2: &Vector) -> Result<Vector> {
494 let v1_f32 = vector1.as_f32();
495 let v2_f32 = vector2.as_f32();
496
497 if v1_f32.len() != v2_f32.len() {
498 return Err(anyhow!("Vector dimensions must match for combination"));
499 }
500
501 let combined: Vec<f32> = v1_f32
502 .iter()
503 .zip(v2_f32.iter())
504 .map(|(a, b)| (a + b) / 2.0)
505 .collect();
506
507 Ok(Vector::new(combined))
508 }
509
510 fn apply_learned_mapping(
512 &self,
513 source_vector: &Vector,
514 source_lang: &str,
515 target_lang: &str,
516 ) -> Result<Vector> {
517 let mapping_key = format!("{source_lang}:{target_lang}");
518 let mappings = self
519 .alignment_mappings
520 .read()
521 .expect("alignment mappings lock should not be poisoned");
522
523 if let Some(mapping) = mappings.get(&mapping_key) {
524 if let Some(ref matrix) = mapping.transformation_matrix {
525 return self.apply_matrix_transformation(source_vector, matrix);
526 }
527 }
528
529 Ok(source_vector.clone())
531 }
532
533 fn apply_matrix_transformation(&self, vector: &Vector, matrix: &[Vec<f32>]) -> Result<Vector> {
535 let v_f32 = vector.as_f32();
536
537 if matrix.is_empty() || matrix[0].len() != v_f32.len() {
538 return Err(anyhow!("Matrix dimensions incompatible with vector"));
539 }
540
541 let transformed: Vec<f32> = matrix
542 .iter()
543 .map(|row| row.iter().zip(v_f32.iter()).map(|(m, v)| m * v).sum())
544 .collect();
545
546 Ok(Vector::new(transformed))
547 }
548
549 pub fn cross_language_search(
551 &self,
552 query: &str,
553 query_language: &str,
554 content_items: &[CrossLanguageContent],
555 k: usize,
556 ) -> Result<Vec<CrossLanguageSearchResult>> {
557 let span = span!(Level::INFO, "cross_language_search", query_lang = %query_language);
558 let _enter = span.enter();
559
560 let embeddable_content = crate::embeddings::EmbeddableContent::Text(query.to_string());
562 let query_vector = self.embedding_generator.generate(&embeddable_content)?;
563
564 let mut results = Vec::new();
565
566 for content in content_items {
567 let primary_similarity = if content.language == query_language {
569 if let Some(ref content_vector) = content.vector {
570 SimilarityMetric::Cosine.compute(&query_vector, content_vector)?
571 } else {
572 0.0
573 }
574 } else {
575 0.0
576 };
577
578 let mut cross_lingual_similarities = HashMap::new();
580 if let Some(aligned_vector) = content.aligned_vectors.get(query_language) {
581 let cross_similarity =
582 SimilarityMetric::Cosine.compute(&query_vector, aligned_vector)?;
583 cross_lingual_similarities.insert("cosine".to_string(), cross_similarity);
584 }
585
586 let best_similarity = primary_similarity.max(
588 cross_lingual_similarities
589 .values()
590 .copied()
591 .fold(0.0, f32::max),
592 );
593
594 if best_similarity >= self.config.cross_lingual_threshold {
595 results.push(CrossLanguageSearchResult {
596 id: content.id.clone(),
597 similarity: best_similarity,
598 language: content.language.clone(),
599 text: content.text.clone(),
600 translated_text: None, cross_lingual_metrics: cross_lingual_similarities,
602 });
603 }
604 }
605
606 results.sort_by(|a, b| {
608 b.similarity
609 .partial_cmp(&a.similarity)
610 .unwrap_or(std::cmp::Ordering::Equal)
611 });
612 results.truncate(k);
613
614 Ok(results)
615 }
616
617 pub fn learn_alignment_mapping(
619 &mut self,
620 source_language: &str,
621 target_language: &str,
622 translation_pairs: Vec<(String, String)>,
623 ) -> Result<()> {
624 let span = span!(Level::INFO, "learn_alignment_mapping",
625 source = %source_language, target = %target_language);
626 let _enter = span.enter();
627
628 let mut source_vectors = Vec::new();
630 let mut target_vectors = Vec::new();
631
632 for (source_text, target_text) in &translation_pairs {
633 let source_embeddable = crate::embeddings::EmbeddableContent::Text(source_text.clone());
634 let target_embeddable = crate::embeddings::EmbeddableContent::Text(target_text.clone());
635 let source_vector = self.embedding_generator.generate(&source_embeddable)?;
636 let target_vector = self.embedding_generator.generate(&target_embeddable)?;
637
638 source_vectors.push(source_vector.as_f32());
639 target_vectors.push(target_vector.as_f32());
640 }
641
642 let transformation_matrix =
644 self.compute_transformation_matrix(&source_vectors, &target_vectors)?;
645
646 let quality_score = self.evaluate_mapping_quality(
648 &source_vectors,
649 &target_vectors,
650 &transformation_matrix,
651 )?;
652
653 let mapping = AlignmentMapping {
654 source_language: source_language.to_string(),
655 target_language: target_language.to_string(),
656 transformation_matrix: Some(transformation_matrix),
657 translation_pairs,
658 quality_score,
659 };
660
661 let mapping_key = format!("{source_language}:{target_language}");
662 let mut mappings = self
663 .alignment_mappings
664 .write()
665 .expect("alignment mappings lock should not be poisoned");
666 mappings.insert(mapping_key, mapping);
667
668 info!(
669 "Learned alignment mapping with quality score: {:.3}",
670 quality_score
671 );
672 Ok(())
673 }
674
675 fn compute_transformation_matrix(
677 &self,
678 source_vectors: &[Vec<f32>],
679 target_vectors: &[Vec<f32>],
680 ) -> Result<Vec<Vec<f32>>> {
681 if source_vectors.is_empty() || source_vectors.len() != target_vectors.len() {
682 return Err(anyhow!("Invalid vector sets for learning transformation"));
683 }
684
685 let dim = source_vectors[0].len();
686
687 let mut matrix = vec![vec![0.0; dim]; dim];
689 for (i, row) in matrix.iter_mut().enumerate().take(dim) {
690 row[i] = 1.0;
691 }
692
693 for (i, row) in matrix.iter_mut().enumerate().take(dim) {
695 for (j, row_val) in row.iter_mut().enumerate().take(dim) {
696 if i != j {
697 *row_val = (i as f32 * j as f32 * 0.001) % 0.1 - 0.05;
698 }
699 }
700 }
701
702 Ok(matrix)
703 }
704
705 fn evaluate_mapping_quality(
707 &self,
708 source_vectors: &[Vec<f32>],
709 target_vectors: &[Vec<f32>],
710 matrix: &[Vec<f32>],
711 ) -> Result<f32> {
712 let mut total_similarity = 0.0;
713 let mut count = 0;
714
715 for (source, target) in source_vectors.iter().zip(target_vectors) {
716 let transformed_vector = Vector::new(source.clone());
717 let transformed = self.apply_matrix_transformation(&transformed_vector, matrix)?;
718 let target_vector = Vector::new(target.clone());
719
720 let similarity = SimilarityMetric::Cosine.compute(&transformed, &target_vector)?;
721 total_similarity += similarity;
722 count += 1;
723 }
724
725 Ok(if count > 0 {
726 total_similarity / count as f32
727 } else {
728 0.0
729 })
730 }
731
732 pub fn get_language_statistics(&self) -> HashMap<String, usize> {
734 let embeddings = self
735 .multilingual_embeddings
736 .read()
737 .expect("multilingual embeddings lock should not be poisoned");
738 let mut stats = HashMap::new();
739
740 for lang in &self.config.supported_languages {
741 stats.insert(lang.clone(), embeddings.len());
742 }
743
744 stats
745 }
746
747 pub fn get_supported_languages(&self) -> &[String] {
749 &self.config.supported_languages
750 }
751}
752
/// Unit tests for the cross-language configuration, detector and aligner.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embeddings::MockEmbeddingGenerator;

    /// The default config lists languages, uses English as primary and enables detection.
    #[test]
    fn test_cross_language_config_creation() {
        let config = CrossLanguageConfig::default();
        assert!(!config.supported_languages.is_empty());
        assert_eq!(config.primary_language, "en");
        assert!(config.enable_language_detection);
    }

    /// `is_supported` reflects exactly the languages passed to the constructor.
    #[test]
    fn test_language_detector_creation() {
        let languages = vec!["en".to_string(), "es".to_string(), "fr".to_string()];
        let detector = SimpleLanguageDetector::new(languages.clone());

        assert!(detector.is_supported("en"));
        assert!(detector.is_supported("es"));
        assert!(!detector.is_supported("de"));
    }

    /// Text without any characteristic characters falls back to English.
    #[test]
    fn test_language_detection() {
        let detector = SimpleLanguageDetector::new(vec!["en".to_string(), "es".to_string()]);

        let detection = detector.detect_language("Hello world").unwrap();
        assert_eq!(detection.language, "en");
        assert!(detection.confidence > 0.0);

        // "Hola mundo" contains no accented characters, so the character-based
        // heuristic cannot tell it apart from English.
        let detection = detector.detect_language("Hola mundo").unwrap();
        assert_eq!(detection.language, "en");
    }

    /// Every alignment strategy can be stored in and read back from a config.
    #[test]
    fn test_alignment_strategy_variants() {
        let strategies = vec![
            AlignmentStrategy::MultilingualEmbeddings,
            AlignmentStrategy::TranslationBased,
            AlignmentStrategy::Hybrid,
            AlignmentStrategy::LearnedMappings,
        ];

        for strategy in strategies {
            let config = CrossLanguageConfig {
                alignment_strategy: strategy.clone(),
                ..Default::default()
            };
            assert_eq!(config.alignment_strategy, strategy);
        }
    }

    /// An aligner built from the default config exposes its ten languages.
    #[tokio::test]
    async fn test_cross_language_aligner_creation() {
        let config = CrossLanguageConfig::default();
        let embedding_generator = Box::new(MockEmbeddingGenerator::new());

        let aligner = CrossLanguageAligner::new(config, embedding_generator);
        assert_eq!(aligner.get_supported_languages().len(), 10);
    }

    /// Processing content keeps id and text and yields a vector plus aligned vectors.
    #[tokio::test]
    async fn test_content_processing() {
        let config = CrossLanguageConfig::default();
        let embedding_generator = Box::new(MockEmbeddingGenerator::new());

        let aligner = CrossLanguageAligner::new(config, embedding_generator);
        let content = aligner
            .process_content("Hello world", "test_id")
            .await
            .unwrap();

        assert_eq!(content.id, "test_id");
        assert_eq!(content.text, "Hello world");
        assert!(content.vector.is_some());
        assert!(!content.aligned_vectors.is_empty());
    }

    /// Combining two vectors yields their element-wise mean.
    #[test]
    fn test_vector_combination() {
        let config = CrossLanguageConfig::default();
        let embedding_generator = Box::new(MockEmbeddingGenerator::new());
        let aligner = CrossLanguageAligner::new(config, embedding_generator);

        let vector1 = Vector::new(vec![1.0, 2.0, 3.0]);
        let vector2 = Vector::new(vec![2.0, 4.0, 6.0]);

        let combined = aligner.combine_vectors(&vector1, &vector2).unwrap();
        let combined_f32 = combined.as_f32();

        assert_eq!(combined_f32, vec![1.5, 3.0, 4.5]);
    }

    /// A search result can be constructed directly and its fields read back.
    #[test]
    fn test_cross_language_search_result() {
        let result = CrossLanguageSearchResult {
            id: "test".to_string(),
            similarity: 0.8,
            language: "en".to_string(),
            text: "test content".to_string(),
            translated_text: Some("contenido de prueba".to_string()),
            cross_lingual_metrics: HashMap::new(),
        };

        assert_eq!(result.id, "test");
        assert_eq!(result.similarity, 0.8);
        assert_eq!(result.language, "en");
    }
}