1use std::collections::HashMap;
10
/// A correspondence between one source embedding and one target embedding.
///
/// Used both as input (seed pairs supplied by the caller) and as output
/// (pairs discovered by `EmbeddingAlignment::find_alignment`).
#[derive(Clone, Debug)]
pub struct AlignmentPair {
    /// Index into the source embedding list.
    pub source_idx: usize,
    /// Index into the target embedding list.
    pub target_idx: usize,
    /// Confidence attached to the pair. For discovered pairs this holds the
    /// cosine similarity of the match.
    pub confidence: f64,
}
22
/// Strategy used by `EmbeddingAlignment::find_alignment` to map the source
/// space onto the target space.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AlignmentMethod {
    /// Fit a matrix to the seed pairs via the polar factor of their
    /// cross-covariance.
    OrthogonalProcrustes,
    /// Fit an unconstrained matrix to the seed pairs by regularised least
    /// squares.
    LinearTransformation,
    /// Fit nothing: the identity transform is used and only mutual
    /// nearest-neighbour matching is performed.
    BidirectionalMatching,
}
33
/// A map from the source embedding space into the target embedding space.
#[derive(Debug, Clone)]
pub enum AlignmentTransform {
    /// Row-major matrix produced by the orthogonal Procrustes fit (or by
    /// [`AlignmentTransform::identity`]).
    Orthogonal(Vec<Vec<f32>>),
    /// Row-major matrix produced by the linear least-squares fit.
    Linear(Vec<Vec<f32>>),
    /// No-op transform: embeddings pass through unchanged.
    Identity,
}

impl AlignmentTransform {
    /// Applies the transform to `embedding` and returns the mapped vector.
    ///
    /// For the matrix variants, output component `i` is the dot product of
    /// matrix row `i` with `embedding`. The result always has
    /// `embedding.len()` entries. Entries missing from the matrix are treated
    /// as zero: a short row contributes only the columns it has, and a missing
    /// row yields `0.0`. Shape mismatches therefore never panic.
    pub fn apply(&self, embedding: &[f32]) -> Vec<f32> {
        match self {
            AlignmentTransform::Identity => embedding.to_vec(),
            AlignmentTransform::Orthogonal(mat) | AlignmentTransform::Linear(mat) => {
                (0..embedding.len())
                    .map(|i| {
                        // `get` instead of indexing: the matrix may have fewer
                        // rows than the input has components.
                        mat.get(i).map_or(0.0_f32, |row| {
                            row.iter()
                                .zip(embedding)
                                .map(|(m, x)| m * x)
                                .sum::<f32>()
                        })
                    })
                    .collect()
            }
        }
    }

    /// Builds a `dim × dim` identity matrix wrapped in the `Orthogonal`
    /// variant, so that [`AlignmentTransform::matrix`] returns `Some` for it.
    pub fn identity(dim: usize) -> Self {
        let mut mat = vec![vec![0.0_f32; dim]; dim];
        for (i, row) in mat.iter_mut().enumerate() {
            row[i] = 1.0;
        }
        AlignmentTransform::Orthogonal(mat)
    }

    /// Returns the underlying matrix, or `None` for the `Identity` variant.
    pub fn matrix(&self) -> Option<&Vec<Vec<f32>>> {
        match self {
            AlignmentTransform::Orthogonal(m) | AlignmentTransform::Linear(m) => Some(m),
            AlignmentTransform::Identity => None,
        }
    }
}
79
80#[derive(Debug)]
82pub struct AlignmentResult {
83 pub transform: AlignmentTransform,
85 pub new_pairs: Vec<AlignmentPair>,
87 pub alignment_score: f64,
89}
90
91pub struct EmbeddingAlignment {
97 pub source_embeddings: Vec<Vec<f32>>,
98 pub target_embeddings: Vec<Vec<f32>>,
99 pub dim: usize,
100}
101
102impl EmbeddingAlignment {
103 pub fn new(source: Vec<Vec<f32>>, target: Vec<Vec<f32>>) -> Self {
107 let dim = source.first().map_or(0, |v| v.len());
108 Self {
109 source_embeddings: source,
110 target_embeddings: target,
111 dim,
112 }
113 }
114
115 pub fn find_alignment(
117 &self,
118 method: AlignmentMethod,
119 seed_pairs: &[AlignmentPair],
120 ) -> AlignmentResult {
121 let transform = match method {
122 AlignmentMethod::OrthogonalProcrustes => self.orthogonal_procrustes(seed_pairs),
123 AlignmentMethod::LinearTransformation => self.linear_transform(seed_pairs),
124 AlignmentMethod::BidirectionalMatching => {
125 AlignmentTransform::Identity
127 }
128 };
129
130 let transformed_source = self.apply_transform(&transform);
132 let new_pairs =
133 self.bidirectional_nn(&transformed_source, &self.target_embeddings, seed_pairs);
134 let alignment_score = self.mean_cosine_similarity(seed_pairs, &transform);
135
136 AlignmentResult {
137 transform,
138 new_pairs,
139 alignment_score,
140 }
141 }
142
143 pub fn apply_transform(&self, transform: &AlignmentTransform) -> Vec<Vec<f32>> {
145 self.source_embeddings
146 .iter()
147 .map(|e| transform.apply(e))
148 .collect()
149 }
150
151 fn orthogonal_procrustes(&self, seed_pairs: &[AlignmentPair]) -> AlignmentTransform {
156 if seed_pairs.is_empty() || self.dim == 0 {
157 return AlignmentTransform::identity(self.dim);
158 }
159
160 let dim = self.dim;
162 let mut m = vec![vec![0.0_f32; dim]; dim];
163
164 for sp in seed_pairs {
165 let src = &self.source_embeddings[sp.source_idx];
166 let tgt = &self.target_embeddings[sp.target_idx];
167 for i in 0..dim {
168 for j in 0..dim {
169 m[i][j] += tgt[i] * src[j];
170 }
171 }
172 }
173
174 let mat = polar_decomposition(&m, dim);
176 AlignmentTransform::Orthogonal(mat)
177 }
178
179 fn linear_transform(&self, seed_pairs: &[AlignmentPair]) -> AlignmentTransform {
182 if seed_pairs.is_empty() || self.dim == 0 {
183 return AlignmentTransform::identity(self.dim);
184 }
185 let dim = self.dim;
186 let n = seed_pairs.len();
187
188 let mut xt_x = vec![vec![0.0_f32; dim]; dim]; let mut xt_y = vec![vec![0.0_f32; dim]; dim]; for sp in seed_pairs {
193 let x = &self.source_embeddings[sp.source_idx];
194 let y = &self.target_embeddings[sp.target_idx];
195 for i in 0..dim {
196 for j in 0..dim {
197 xt_x[i][j] += x[i] * x[j];
198 xt_y[i][j] += x[i] * y[j];
199 }
200 }
201 }
202
203 let lambda = 1e-4_f32 * (n as f32);
205 for (i, row) in xt_x.iter_mut().enumerate() {
206 row[i] += lambda;
207 }
208
209 let w = solve_linear_system(&xt_x, &xt_y, dim);
211 AlignmentTransform::Linear(w)
212 }
213
214 fn bidirectional_nn(
216 &self,
217 transformed_source: &[Vec<f32>],
218 target: &[Vec<f32>],
219 seed_pairs: &[AlignmentPair],
220 ) -> Vec<AlignmentPair> {
221 let used_src: std::collections::HashSet<usize> =
223 seed_pairs.iter().map(|p| p.source_idx).collect();
224 let used_tgt: std::collections::HashSet<usize> =
225 seed_pairs.iter().map(|p| p.target_idx).collect();
226
227 let mut pairs = Vec::new();
228
229 for (s_idx, s_emb) in transformed_source.iter().enumerate() {
231 if used_src.contains(&s_idx) {
232 continue;
233 }
234 let Some((best_t, best_sim)) = nearest_neighbor(s_emb, target, &used_tgt) else {
236 continue;
237 };
238 if let Some((mutual_s, _)) =
240 nearest_neighbor(&target[best_t], transformed_source, &used_src)
241 {
242 if mutual_s == s_idx {
243 pairs.push(AlignmentPair {
244 source_idx: s_idx,
245 target_idx: best_t,
246 confidence: best_sim as f64,
247 });
248 }
249 }
250 }
251 pairs
252 }
253
254 fn mean_cosine_similarity(
256 &self,
257 seed_pairs: &[AlignmentPair],
258 transform: &AlignmentTransform,
259 ) -> f64 {
260 if seed_pairs.is_empty() {
261 return 0.0;
262 }
263 let total: f64 = seed_pairs
264 .iter()
265 .map(|sp| {
266 let src_t = transform.apply(&self.source_embeddings[sp.source_idx]);
267 let tgt = &self.target_embeddings[sp.target_idx];
268 cosine_similarity(&src_t, tgt) as f64
269 })
270 .sum();
271 total / seed_pairs.len() as f64
272 }
273}
274
275pub struct CrossLingualAligner {
281 language_spaces: HashMap<String, Vec<Vec<f32>>>,
282 pivot_language: String,
283}
284
285impl CrossLingualAligner {
286 pub fn new(pivot: &str) -> Self {
288 Self {
289 language_spaces: HashMap::new(),
290 pivot_language: pivot.to_string(),
291 }
292 }
293
294 pub fn add_language(&mut self, lang: &str, embeddings: Vec<Vec<f32>>) {
296 self.language_spaces.insert(lang.to_string(), embeddings);
297 }
298
299 pub fn align_to_pivot(
301 &self,
302 lang: &str,
303 seed_pairs: &[AlignmentPair],
304 ) -> Option<AlignmentResult> {
305 let source = self.language_spaces.get(lang)?.clone();
306 let target = self.language_spaces.get(&self.pivot_language)?.clone();
307 let aligner = EmbeddingAlignment::new(source, target);
308 Some(aligner.find_alignment(AlignmentMethod::OrthogonalProcrustes, seed_pairs))
309 }
310
311 pub fn translate(&self, embedding: &[f32], from_lang: &str, to_lang: &str) -> Option<Vec<f32>> {
313 if from_lang == to_lang {
318 return Some(embedding.to_vec());
319 }
320
321 let _from_space = self.language_spaces.get(from_lang)?;
322 let _to_space = self.language_spaces.get(to_lang)?;
323
324 Some(embedding.to_vec())
327 }
328
329 pub fn languages(&self) -> Vec<&str> {
331 self.language_spaces.keys().map(|s| s.as_str()).collect()
332 }
333
334 pub fn pivot_language(&self) -> &str {
336 &self.pivot_language
337 }
338}
339
/// Cosine similarity of `a` and `b`, clamped to `[-1, 1]`.
///
/// Returns `0.0` when either vector has a norm below `1e-10`. The dot product
/// runs over the shorter of the two lengths, while each norm is taken over the
/// full vector.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));
    if norm_a < 1e-10 || norm_b < 1e-10 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
354
355fn nearest_neighbor(
358 query: &[f32],
359 candidates: &[Vec<f32>],
360 excluded: &std::collections::HashSet<usize>,
361) -> Option<(usize, f32)> {
362 let mut best_idx = None;
363 let mut best_sim = f32::NEG_INFINITY;
364 for (idx, cand) in candidates.iter().enumerate() {
365 if excluded.contains(&idx) {
366 continue;
367 }
368 let sim = cosine_similarity(query, cand);
369 if sim > best_sim {
370 best_sim = sim;
371 best_idx = Some(idx);
372 }
373 }
374 best_idx.map(|idx| (idx, best_sim))
375}
376
377fn polar_decomposition(m: &[Vec<f32>], dim: usize) -> Vec<Vec<f32>> {
380 let frob: f32 = m
382 .iter()
383 .flat_map(|r| r.iter())
384 .map(|v| v * v)
385 .sum::<f32>()
386 .sqrt();
387 if frob < 1e-10 {
388 return AlignmentTransform::identity(dim)
389 .matrix()
390 .cloned()
391 .unwrap_or_else(|| {
392 (0..dim)
393 .map(|i| (0..dim).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
394 .collect()
395 });
396 }
397
398 let mut u: Vec<Vec<f32>> = m
399 .iter()
400 .map(|r| r.iter().map(|v| v / frob).collect())
401 .collect();
402
403 for _ in 0..10 {
405 let utu = mat_mul_transposed(&u, &u, dim); let utu_u = mat_mul(&utu, &u, dim); let mut new_u = vec![vec![0.0_f32; dim]; dim];
408 for i in 0..dim {
409 for j in 0..dim {
410 new_u[i][j] = 1.5 * u[i][j] - 0.5 * utu_u[i][j];
411 }
412 }
413 u = new_u;
414 }
415 u
416}
417
/// Multiplies the leading `dim × dim` parts of `a` and `b` and returns
/// `a · b` as a new row-major matrix.
///
/// Panics if either operand has fewer than `dim` rows or columns.
fn mat_mul(a: &[Vec<f32>], b: &[Vec<f32>], dim: usize) -> Vec<Vec<f32>> {
    let mut product = vec![vec![0.0_f32; dim]; dim];
    for (i, out_row) in product.iter_mut().enumerate() {
        // Row-times-matrix form: add a[i][k] * (row k of b) into output row i.
        for (k, &a_ik) in a[i][..dim].iter().enumerate() {
            for (cell, &b_kj) in out_row.iter_mut().zip(&b[k][..dim]) {
                *cell += a_ik * b_kj;
            }
        }
    }
    product
}
430
/// Returns the `dim × dim` product `a · bᵀ`, i.e. entry `(i, j)` is the dot
/// product of row `i` of `a` with row `j` of `b`.
///
/// The second operand used to be ignored, so the function always computed
/// `a · aᵀ`. Calls that pass the same matrix twice are unaffected. Each dot
/// product runs over the shorter of the two rows.
///
/// Panics if either operand has fewer than `dim` rows.
fn mat_mul_transposed(a: &[Vec<f32>], b: &[Vec<f32>], dim: usize) -> Vec<Vec<f32>> {
    (0..dim)
        .map(|i| {
            (0..dim)
                .map(|j| {
                    a[i].iter()
                        .zip(&b[j])
                        .fold(0.0_f32, |acc, (x, y)| acc + x * y)
                })
                .collect()
        })
        .collect()
}
443
/// Solves `A · X = B` for the `dim × dim` matrix `X` by Gauss–Jordan
/// elimination with partial pivoting on the augmented matrix `[A | B]`.
///
/// Only the leading `dim` columns of each row of `a` and `b` are used. A
/// column whose best pivot has magnitude below `1e-10` is skipped, so a
/// (numerically) singular `A` produces an unspecified result, not an error.
///
/// Panics if `a` or `b` has fewer than `dim` rows or columns.
fn solve_linear_system(a: &[Vec<f32>], b: &[Vec<f32>], dim: usize) -> Vec<Vec<f32>> {
    // Build [A | B] with rows of exactly 2 * dim entries.
    let mut aug: Vec<Vec<f32>> = (0..dim)
        .map(|i| {
            let mut row = Vec::with_capacity(2 * dim);
            row.extend_from_slice(&a[i][..dim]);
            row.extend_from_slice(&b[i][..dim]);
            row
        })
        .collect();

    for col in 0..dim {
        // Partial pivoting: bring the row with the largest magnitude in this
        // column (at or below the diagonal) onto the diagonal.
        let mut max_row = col;
        let mut max_val = aug[col][col].abs();
        for (row, aug_row) in aug.iter().enumerate().skip(col + 1) {
            let candidate = aug_row[col].abs();
            if candidate > max_val {
                max_val = candidate;
                max_row = row;
            }
        }
        aug.swap(col, max_row);

        let pivot = aug[col][col];
        if pivot.abs() < 1e-10 {
            continue;
        }
        for val in &mut aug[col] {
            *val /= pivot;
        }

        // One copy of the normalised pivot row serves every elimination below.
        let pivot_row = aug[col].clone();
        for (row, aug_row) in aug.iter_mut().enumerate() {
            if row == col {
                continue;
            }
            let factor = aug_row[col];
            for (val, &pivot_val) in aug_row.iter_mut().zip(&pivot_row) {
                *val -= pivot_val * factor;
            }
        }
    }

    // The right half of the reduced augmented matrix is the solution.
    aug.into_iter().map(|row| row[dim..].to_vec()).collect()
}
493
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random embeddings in roughly `[-0.5, 0.5)`,
    /// generated with a 64-bit linear congruential generator.
    fn make_embeddings(n: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
        let mut state = seed.wrapping_add(1);
        (0..n)
            .map(|_| {
                (0..dim)
                    .map(|_| {
                        state = state
                            .wrapping_mul(6_364_136_223_846_793_005)
                            .wrapping_add(1_442_695_040_888_963_407);
                        ((state >> 33) as f32 / u32::MAX as f32) - 0.5
                    })
                    .collect()
            })
            .collect()
    }

    /// Seed pairs `(i, i)` for `i` in `0..n`, each with confidence `1.0`.
    fn make_seed_pairs(n: usize) -> Vec<AlignmentPair> {
        (0..n)
            .map(|i| AlignmentPair {
                source_idx: i,
                target_idx: i,
                confidence: 1.0,
            })
            .collect()
    }

    // ---- AlignmentTransform -------------------------------------------------

    #[test]
    fn test_identity_transform() {
        let t = AlignmentTransform::identity(4);
        let v = vec![1.0_f32, 2.0, 3.0, 4.0];
        let out = t.apply(&v);
        for (a, b) in v.iter().zip(out.iter()) {
            assert!((a - b).abs() < 1e-6, "identity should preserve values");
        }
    }

    #[test]
    fn test_orthogonal_transform_apply() {
        let mat = vec![vec![0.0_f32, 1.0], vec![1.0, 0.0]];
        let t = AlignmentTransform::Orthogonal(mat);
        let v = vec![3.0_f32, 7.0];
        let out = t.apply(&v);
        assert!((out[0] - 7.0).abs() < 1e-6);
        assert!((out[1] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_identity_transform_has_matrix() {
        let t = AlignmentTransform::identity(3);
        assert!(t.matrix().is_some());
    }

    #[test]
    fn test_identity_enum_no_matrix() {
        let t = AlignmentTransform::Identity;
        assert!(t.matrix().is_none());
    }

    // ---- EmbeddingAlignment -------------------------------------------------

    #[test]
    fn test_alignment_creation() {
        let src = make_embeddings(5, 4, 1);
        let tgt = make_embeddings(5, 4, 2);
        let aligner = EmbeddingAlignment::new(src, tgt);
        assert_eq!(aligner.dim, 4);
        assert_eq!(aligner.source_embeddings.len(), 5);
        assert_eq!(aligner.target_embeddings.len(), 5);
    }

    #[test]
    fn test_orthogonal_procrustes_produces_result() {
        let src = make_embeddings(6, 4, 10);
        let tgt = make_embeddings(6, 4, 20);
        let aligner = EmbeddingAlignment::new(src, tgt);
        let seeds = make_seed_pairs(3);
        let result = aligner.find_alignment(AlignmentMethod::OrthogonalProcrustes, &seeds);
        assert!(result.alignment_score.is_finite());
    }

    #[test]
    fn test_linear_transform_produces_result() {
        let src = make_embeddings(6, 4, 30);
        let tgt = make_embeddings(6, 4, 40);
        let aligner = EmbeddingAlignment::new(src, tgt);
        let seeds = make_seed_pairs(3);
        let result = aligner.find_alignment(AlignmentMethod::LinearTransformation, &seeds);
        assert!(result.alignment_score.is_finite());
    }

    #[test]
    fn test_bidirectional_matching_produces_result() {
        let src = make_embeddings(8, 4, 50);
        let tgt = make_embeddings(8, 4, 60);
        let aligner = EmbeddingAlignment::new(src, tgt);
        let seeds = make_seed_pairs(2);
        let result = aligner.find_alignment(AlignmentMethod::BidirectionalMatching, &seeds);
        assert!(result.alignment_score >= -1.0 && result.alignment_score <= 1.0 + 1e-6);
    }

    #[test]
    fn test_apply_transform_correct_count() {
        let src = make_embeddings(5, 4, 70);
        let tgt = make_embeddings(5, 4, 80);
        let aligner = EmbeddingAlignment::new(src, tgt);
        let t = AlignmentTransform::identity(4);
        let out = aligner.apply_transform(&t);
        assert_eq!(out.len(), 5);
        assert_eq!(out[0].len(), 4);
    }

    #[test]
    fn test_alignment_with_empty_seeds() {
        let src = make_embeddings(4, 4, 90);
        let tgt = make_embeddings(4, 4, 91);
        let aligner = EmbeddingAlignment::new(src, tgt);
        let result = aligner.find_alignment(AlignmentMethod::OrthogonalProcrustes, &[]);
        assert!(result.alignment_score.is_finite());
    }

    #[test]
    fn test_identical_spaces_score() {
        let embs = make_embeddings(5, 4, 100);
        let aligner = EmbeddingAlignment::new(embs.clone(), embs);
        let seeds = make_seed_pairs(5);
        let result = aligner.find_alignment(AlignmentMethod::BidirectionalMatching, &seeds);
        assert!(
            result.alignment_score > 0.9,
            "same-space alignment should score near 1.0: {}",
            result.alignment_score
        );
    }

    #[test]
    fn test_alignment_result_has_transform() {
        let src = make_embeddings(4, 3, 111);
        let tgt = make_embeddings(4, 3, 222);
        let aligner = EmbeddingAlignment::new(src, tgt);
        let seeds = make_seed_pairs(2);
        let result = aligner.find_alignment(AlignmentMethod::OrthogonalProcrustes, &seeds);
        // The bare `matches!` used previously discarded its result and so
        // checked nothing; it has to be asserted.
        assert!(matches!(
            result.transform,
            AlignmentTransform::Orthogonal(_)
        ));
    }

    // ---- CrossLingualAligner ------------------------------------------------

    #[test]
    fn test_cross_lingual_creation() {
        let aligner = CrossLingualAligner::new("en");
        assert_eq!(aligner.pivot_language(), "en");
    }

    #[test]
    fn test_cross_lingual_add_language() {
        let mut aligner = CrossLingualAligner::new("en");
        aligner.add_language("fr", make_embeddings(5, 4, 1));
        aligner.add_language("en", make_embeddings(5, 4, 2));
        let langs = aligner.languages();
        assert!(langs.contains(&"fr"));
        assert!(langs.contains(&"en"));
    }

    #[test]
    fn test_cross_lingual_align_to_pivot() {
        let mut aligner = CrossLingualAligner::new("en");
        aligner.add_language("en", make_embeddings(8, 4, 10));
        aligner.add_language("fr", make_embeddings(8, 4, 20));
        let seeds = make_seed_pairs(3);
        let result = aligner.align_to_pivot("fr", &seeds);
        assert!(result.is_some(), "should return alignment result");
        let r = result.unwrap();
        assert!(r.alignment_score.is_finite());
    }

    #[test]
    fn test_cross_lingual_align_missing_language() {
        let aligner = CrossLingualAligner::new("en");
        let result = aligner.align_to_pivot("de", &[]);
        assert!(result.is_none(), "missing language should return None");
    }

    #[test]
    fn test_cross_lingual_translate_same_language() {
        let mut aligner = CrossLingualAligner::new("en");
        aligner.add_language("en", make_embeddings(5, 4, 1));
        let v = vec![1.0_f32, 2.0, 3.0, 4.0];
        let out = aligner.translate(&v, "en", "en");
        assert!(out.is_some());
        assert_eq!(out.unwrap(), v);
    }

    #[test]
    fn test_cross_lingual_translate_missing_returns_none() {
        let aligner = CrossLingualAligner::new("en");
        let v = vec![0.0_f32; 4];
        let out = aligner.translate(&v, "de", "fr");
        assert!(out.is_none());
    }

    // ---- cosine_similarity --------------------------------------------------

    #[test]
    fn test_cosine_similarity_identical() {
        let v = vec![1.0_f32, 0.0, 0.0];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![0.0_f32, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let a = vec![0.0_f32, 0.0];
        let b = vec![1.0_f32, 0.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }
}