// oxirs_embed/alignment/mod.rs

1//! Embedding Alignment: Aligning embeddings across different knowledge graph spaces.
2//!
3//! Supports:
4//! - Orthogonal Procrustes (SVD-based rotation)
5//! - Linear Transformation (general affine mapping)
6//! - Bidirectional Matching (mutual nearest neighbor)
7//! - Cross-lingual alignment via pivot language
8
9use std::collections::HashMap;
10
11// ─────────────────────────────────────────────
12// Core types
13// ─────────────────────────────────────────────
14
/// A seed pair linking a source entity index to a target entity index.
///
/// The indices refer to positions in the source and target embedding lists
/// respectively; `confidence` is a caller-supplied (or similarity-derived)
/// score attached to the pair.
#[derive(Debug, Clone, PartialEq)]
pub struct AlignmentPair {
    /// Index of the entity in the source embedding list.
    pub source_idx: usize,
    /// Index of the matched entity in the target embedding list.
    pub target_idx: usize,
    /// Confidence attached to this pairing.
    pub confidence: f64,
}
22
/// Available alignment strategies.
///
/// The enum is fieldless and cheap to copy; it also derives `Hash` so a method
/// can be used as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AlignmentMethod {
    /// SVD-based orthogonal (rotation) mapping.
    OrthogonalProcrustes,
    /// Unconstrained linear transformation.
    LinearTransformation,
    /// Bidirectional (mutual) nearest-neighbor matching.
    BidirectionalMatching,
}
33
/// Transformation applied to source embeddings to align them to target space.
#[derive(Debug, Clone)]
pub enum AlignmentTransform {
    /// Orthogonal rotation matrix (dim × dim), stored row-major.
    Orthogonal(Vec<Vec<f32>>),
    /// General linear transformation matrix (dim × dim), stored row-major.
    Linear(Vec<Vec<f32>>),
    /// No-op identity transform.
    Identity,
}

impl AlignmentTransform {
    /// Apply this transform to a single embedding vector.
    ///
    /// For the matrix variants the result is the matrix–vector product
    /// `mat · embedding`: output element `i` is the dot product of matrix row
    /// `i` with the embedding. Rows and columns that have no counterpart in
    /// the other operand are ignored, so the output has
    /// `min(embedding.len(), mat.len())` elements and a size mismatch never
    /// panics. `Identity` returns a copy of the input.
    pub fn apply(&self, embedding: &[f32]) -> Vec<f32> {
        match self {
            AlignmentTransform::Identity => embedding.to_vec(),
            AlignmentTransform::Orthogonal(mat) | AlignmentTransform::Linear(mat) => mat
                .iter()
                // At most one output per input component; iterating the rows
                // (instead of indexing `mat[i]`) avoids an out-of-bounds panic
                // when the embedding is longer than the matrix has rows.
                .take(embedding.len())
                .map(|row| {
                    // `zip` stops at the shorter of the row and the embedding.
                    row.iter()
                        .zip(embedding)
                        .map(|(weight, value)| weight * value)
                        .sum::<f32>()
                })
                .collect(),
        }
    }

    /// Construct an identity transform for the given dimension.
    ///
    /// Note that this returns an explicit `dim × dim` identity matrix wrapped
    /// in [`AlignmentTransform::Orthogonal`], not the matrix-free
    /// [`AlignmentTransform::Identity`] variant, so [`Self::matrix`] yields
    /// `Some` for the result.
    pub fn identity(dim: usize) -> Self {
        let mat: Vec<Vec<f32>> = (0..dim)
            .map(|i| (0..dim).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
            .collect();
        AlignmentTransform::Orthogonal(mat)
    }

    /// Return the underlying matrix, or `None` for the `Identity` variant.
    pub fn matrix(&self) -> Option<&Vec<Vec<f32>>> {
        match self {
            AlignmentTransform::Orthogonal(m) | AlignmentTransform::Linear(m) => Some(m),
            AlignmentTransform::Identity => None,
        }
    }
}
79
80/// Result of an alignment operation.
81#[derive(Debug)]
82pub struct AlignmentResult {
83    /// The learned transformation.
84    pub transform: AlignmentTransform,
85    /// New pairs discovered beyond the seeds.
86    pub new_pairs: Vec<AlignmentPair>,
87    /// Mean cosine similarity of aligned seed pairs.
88    pub alignment_score: f64,
89}
90
91// ─────────────────────────────────────────────
92// EmbeddingAlignment
93// ─────────────────────────────────────────────
94
95/// Aligns embeddings from two different KG spaces.
96pub struct EmbeddingAlignment {
97    pub source_embeddings: Vec<Vec<f32>>,
98    pub target_embeddings: Vec<Vec<f32>>,
99    pub dim: usize,
100}
101
102impl EmbeddingAlignment {
103    /// Create a new alignment helper.
104    ///
105    /// Panics if source and target have different embedding dimensions.
106    pub fn new(source: Vec<Vec<f32>>, target: Vec<Vec<f32>>) -> Self {
107        let dim = source.first().map_or(0, |v| v.len());
108        Self {
109            source_embeddings: source,
110            target_embeddings: target,
111            dim,
112        }
113    }
114
115    /// Find an alignment between source and target using the given method and seed pairs.
116    pub fn find_alignment(
117        &self,
118        method: AlignmentMethod,
119        seed_pairs: &[AlignmentPair],
120    ) -> AlignmentResult {
121        let transform = match method {
122            AlignmentMethod::OrthogonalProcrustes => self.orthogonal_procrustes(seed_pairs),
123            AlignmentMethod::LinearTransformation => self.linear_transform(seed_pairs),
124            AlignmentMethod::BidirectionalMatching => {
125                // No transform, just find pairs
126                AlignmentTransform::Identity
127            }
128        };
129
130        // Apply transform to source embeddings and find new aligned pairs
131        let transformed_source = self.apply_transform(&transform);
132        let new_pairs =
133            self.bidirectional_nn(&transformed_source, &self.target_embeddings, seed_pairs);
134        let alignment_score = self.mean_cosine_similarity(seed_pairs, &transform);
135
136        AlignmentResult {
137            transform,
138            new_pairs,
139            alignment_score,
140        }
141    }
142
143    /// Apply the transform to all source embeddings, returning the transformed set.
144    pub fn apply_transform(&self, transform: &AlignmentTransform) -> Vec<Vec<f32>> {
145        self.source_embeddings
146            .iter()
147            .map(|e| transform.apply(e))
148            .collect()
149    }
150
151    // ── Private helpers ───────────────────────
152
153    /// Compute orthogonal Procrustes: W = V * U^T from SVD of Y^T * X.
154    /// Uses a power-iteration / simplified SVD for pure-Rust.
155    fn orthogonal_procrustes(&self, seed_pairs: &[AlignmentPair]) -> AlignmentTransform {
156        if seed_pairs.is_empty() || self.dim == 0 {
157            return AlignmentTransform::identity(self.dim);
158        }
159
160        // Build cross-covariance matrix M = Y^T * X  (dim × dim)
161        let dim = self.dim;
162        let mut m = vec![vec![0.0_f32; dim]; dim];
163
164        for sp in seed_pairs {
165            let src = &self.source_embeddings[sp.source_idx];
166            let tgt = &self.target_embeddings[sp.target_idx];
167            for i in 0..dim {
168                for j in 0..dim {
169                    m[i][j] += tgt[i] * src[j];
170                }
171            }
172        }
173
174        // Approximate orthogonal map via iterative polar decomposition (5 Newton steps)
175        let mat = polar_decomposition(&m, dim);
176        AlignmentTransform::Orthogonal(mat)
177    }
178
179    /// Compute a linear transformation W via least-squares: minimize ||X*W - Y||_F.
180    /// Closed form: W = (X^T X)^{-1} X^T Y.
181    fn linear_transform(&self, seed_pairs: &[AlignmentPair]) -> AlignmentTransform {
182        if seed_pairs.is_empty() || self.dim == 0 {
183            return AlignmentTransform::identity(self.dim);
184        }
185        let dim = self.dim;
186        let n = seed_pairs.len();
187
188        // Build X (n × dim) and Y (n × dim)
189        let mut xt_x = vec![vec![0.0_f32; dim]; dim]; // X^T X
190        let mut xt_y = vec![vec![0.0_f32; dim]; dim]; // X^T Y
191
192        for sp in seed_pairs {
193            let x = &self.source_embeddings[sp.source_idx];
194            let y = &self.target_embeddings[sp.target_idx];
195            for i in 0..dim {
196                for j in 0..dim {
197                    xt_x[i][j] += x[i] * x[j];
198                    xt_y[i][j] += x[i] * y[j];
199                }
200            }
201        }
202
203        // Regularize: (X^T X + λI)
204        let lambda = 1e-4_f32 * (n as f32);
205        for (i, row) in xt_x.iter_mut().enumerate() {
206            row[i] += lambda;
207        }
208
209        // Solve via Gauss-Jordan for each output column
210        let w = solve_linear_system(&xt_x, &xt_y, dim);
211        AlignmentTransform::Linear(w)
212    }
213
214    /// Bidirectional nearest-neighbor matching in the transformed source space.
215    fn bidirectional_nn(
216        &self,
217        transformed_source: &[Vec<f32>],
218        target: &[Vec<f32>],
219        seed_pairs: &[AlignmentPair],
220    ) -> Vec<AlignmentPair> {
221        // Build set of already-used indices
222        let used_src: std::collections::HashSet<usize> =
223            seed_pairs.iter().map(|p| p.source_idx).collect();
224        let used_tgt: std::collections::HashSet<usize> =
225            seed_pairs.iter().map(|p| p.target_idx).collect();
226
227        let mut pairs = Vec::new();
228
229        // For each non-seed source, find nearest target and check mutual
230        for (s_idx, s_emb) in transformed_source.iter().enumerate() {
231            if used_src.contains(&s_idx) {
232                continue;
233            }
234            // Find nearest target
235            let Some((best_t, best_sim)) = nearest_neighbor(s_emb, target, &used_tgt) else {
236                continue;
237            };
238            // Check mutual: from best_t, find nearest source
239            if let Some((mutual_s, _)) =
240                nearest_neighbor(&target[best_t], transformed_source, &used_src)
241            {
242                if mutual_s == s_idx {
243                    pairs.push(AlignmentPair {
244                        source_idx: s_idx,
245                        target_idx: best_t,
246                        confidence: best_sim as f64,
247                    });
248                }
249            }
250        }
251        pairs
252    }
253
254    /// Compute mean cosine similarity of seed pairs under a given transform.
255    fn mean_cosine_similarity(
256        &self,
257        seed_pairs: &[AlignmentPair],
258        transform: &AlignmentTransform,
259    ) -> f64 {
260        if seed_pairs.is_empty() {
261            return 0.0;
262        }
263        let total: f64 = seed_pairs
264            .iter()
265            .map(|sp| {
266                let src_t = transform.apply(&self.source_embeddings[sp.source_idx]);
267                let tgt = &self.target_embeddings[sp.target_idx];
268                cosine_similarity(&src_t, tgt) as f64
269            })
270            .sum();
271        total / seed_pairs.len() as f64
272    }
273}
274
275// ─────────────────────────────────────────────
276// CrossLingualAligner
277// ─────────────────────────────────────────────
278
279/// Aligns multiple language embedding spaces via a pivot language.
280pub struct CrossLingualAligner {
281    language_spaces: HashMap<String, Vec<Vec<f32>>>,
282    pivot_language: String,
283}
284
285impl CrossLingualAligner {
286    /// Create a new aligner with the given pivot language code.
287    pub fn new(pivot: &str) -> Self {
288        Self {
289            language_spaces: HashMap::new(),
290            pivot_language: pivot.to_string(),
291        }
292    }
293
294    /// Register an embedding space for a language.
295    pub fn add_language(&mut self, lang: &str, embeddings: Vec<Vec<f32>>) {
296        self.language_spaces.insert(lang.to_string(), embeddings);
297    }
298
299    /// Align the given language to the pivot using seed pairs.
300    pub fn align_to_pivot(
301        &self,
302        lang: &str,
303        seed_pairs: &[AlignmentPair],
304    ) -> Option<AlignmentResult> {
305        let source = self.language_spaces.get(lang)?.clone();
306        let target = self.language_spaces.get(&self.pivot_language)?.clone();
307        let aligner = EmbeddingAlignment::new(source, target);
308        Some(aligner.find_alignment(AlignmentMethod::OrthogonalProcrustes, seed_pairs))
309    }
310
311    /// Translate an embedding from one language space to another via the pivot.
312    pub fn translate(&self, embedding: &[f32], from_lang: &str, to_lang: &str) -> Option<Vec<f32>> {
313        // Build trivially: if either endpoint is the pivot, do a direct transform.
314        // For simplicity we compute the orthogonal map from first-available seed pair
315        // (or identity if no data) and compose from→pivot→to.
316
317        if from_lang == to_lang {
318            return Some(embedding.to_vec());
319        }
320
321        let _from_space = self.language_spaces.get(from_lang)?;
322        let _to_space = self.language_spaces.get(to_lang)?;
323
324        // Without explicit seed pairs here, return identity-transformed embedding.
325        // In real usage the caller would provide seed pairs per language pair.
326        Some(embedding.to_vec())
327    }
328
329    /// List registered languages.
330    pub fn languages(&self) -> Vec<&str> {
331        self.language_spaces.keys().map(|s| s.as_str()).collect()
332    }
333
334    /// Pivot language accessor.
335    pub fn pivot_language(&self) -> &str {
336        &self.pivot_language
337    }
338}
339
340// ─────────────────────────────────────────────
341// Math utilities
342// ─────────────────────────────────────────────
343
/// Cosine similarity between two vectors, clamped to `[-1, 1]`.
///
/// Yields 0.0 when either vector has a Euclidean norm below `1e-10`. The dot
/// product pairs elements position by position and stops at the shorter input,
/// while each norm is taken over its full vector.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    const MIN_NORM: f32 = 1e-10;
    let euclidean = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();

    let (len_a, len_b) = (euclidean(a), euclidean(b));
    if len_a < MIN_NORM || len_b < MIN_NORM {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    (dot / (len_a * len_b)).clamp(-1.0, 1.0)
}
354
355/// Find the nearest neighbor of `query` in `candidates`, skipping `excluded` indices.
356/// Returns (index, cosine_similarity).
357fn nearest_neighbor(
358    query: &[f32],
359    candidates: &[Vec<f32>],
360    excluded: &std::collections::HashSet<usize>,
361) -> Option<(usize, f32)> {
362    let mut best_idx = None;
363    let mut best_sim = f32::NEG_INFINITY;
364    for (idx, cand) in candidates.iter().enumerate() {
365        if excluded.contains(&idx) {
366            continue;
367        }
368        let sim = cosine_similarity(query, cand);
369        if sim > best_sim {
370            best_sim = sim;
371            best_idx = Some(idx);
372        }
373    }
374    best_idx.map(|idx| (idx, best_sim))
375}
376
377/// Approximate the orthogonal factor of M via iterative polar decomposition.
378/// U = lim_{t→∞} (3/2)U_{t-1} - (1/2)U_{t-1}(U_{t-1}^T U_{t-1})
379fn polar_decomposition(m: &[Vec<f32>], dim: usize) -> Vec<Vec<f32>> {
380    // Start with M / ||M||_F
381    let frob: f32 = m
382        .iter()
383        .flat_map(|r| r.iter())
384        .map(|v| v * v)
385        .sum::<f32>()
386        .sqrt();
387    if frob < 1e-10 {
388        return AlignmentTransform::identity(dim)
389            .matrix()
390            .cloned()
391            .unwrap_or_else(|| {
392                (0..dim)
393                    .map(|i| (0..dim).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
394                    .collect()
395            });
396    }
397
398    let mut u: Vec<Vec<f32>> = m
399        .iter()
400        .map(|r| r.iter().map(|v| v / frob).collect())
401        .collect();
402
403    // Newton-Schulz iterations: U_{k+1} = 1.5 * U_k - 0.5 * U_k * U_k^T * U_k
404    for _ in 0..10 {
405        let utu = mat_mul_transposed(&u, &u, dim); // U * U^T
406        let utu_u = mat_mul(&utu, &u, dim); // (U * U^T) * U
407        let mut new_u = vec![vec![0.0_f32; dim]; dim];
408        for i in 0..dim {
409            for j in 0..dim {
410                new_u[i][j] = 1.5 * u[i][j] - 0.5 * utu_u[i][j];
411            }
412        }
413        u = new_u;
414    }
415    u
416}
417
/// Matrix multiplication A * B for square `dim × dim` operands.
///
/// Every output cell is the running sum, for `k` ascending, of
/// `a[row][k] * b[k][col]`. Panics if either operand is smaller than
/// `dim × dim`.
fn mat_mul(a: &[Vec<f32>], b: &[Vec<f32>], dim: usize) -> Vec<Vec<f32>> {
    (0..dim)
        .map(|row| {
            (0..dim)
                .map(|col| (0..dim).fold(0.0_f32, |acc, k| acc + a[row][k] * b[k][col]))
                .collect()
        })
        .collect()
}
430
/// Compute A * A^T (dim × dim).
///
/// The second argument is ignored: both factors are taken from `a`. Cell
/// `(i, j)` is the dot product of row `i` with row `j`, running over the full
/// length of row `i`.
fn mat_mul_transposed(a: &[Vec<f32>], _b: &[Vec<f32>], dim: usize) -> Vec<Vec<f32>> {
    (0..dim)
        .map(|i| {
            let left = &a[i];
            (0..dim)
                .map(|j| {
                    left.iter()
                        .enumerate()
                        .fold(0.0_f32, |acc, (k, &value)| acc + value * a[j][k])
                })
                .collect()
        })
        .collect()
}
443
/// Solve A * W = B for W using Gauss-Jordan elimination with partial pivoting.
///
/// A is (dim × dim), B is (dim × dim), returns W (dim × dim).
///
/// A column whose best pivot has magnitude below `1e-10` is skipped, so for a
/// (near-)singular A the corresponding part of the result is not a solution.
///
/// # Panics
///
/// Panics if `a` or `b` has fewer than `dim` rows, or if a combined row
/// `[a[i] | b[i]]` is shorter than `2 * dim`.
fn solve_linear_system(a: &[Vec<f32>], b: &[Vec<f32>], dim: usize) -> Vec<Vec<f32>> {
    // Build augmented [A | B]
    let mut aug: Vec<Vec<f32>> = (0..dim)
        .map(|i| {
            let mut row = a[i].clone();
            row.extend_from_slice(&b[i]);
            row
        })
        .collect();

    let total_cols = 2 * dim;

    for col in 0..dim {
        // Partial pivoting: bring the row with the largest magnitude in this
        // column (at or below the diagonal) onto the diagonal.
        let mut max_row = col;
        let mut max_val = aug[col][col].abs();
        for (row, aug_row) in aug.iter().enumerate().skip(col + 1) {
            if aug_row[col].abs() > max_val {
                max_val = aug_row[col].abs();
                max_row = row;
            }
        }
        aug.swap(col, max_row);

        let pivot = aug[col][col];
        if pivot.abs() < 1e-10 {
            // Numerically singular column: leave it untouched.
            continue;
        }

        // Normalize the pivot row so the diagonal entry becomes 1.
        for val in &mut aug[col][..total_cols] {
            *val /= pivot;
        }

        // The pivot row is not modified during elimination, so one copy per
        // column suffices (instead of one per eliminated row).
        let pivot_row: Vec<f32> = aug[col][..total_cols].to_vec();
        for (row, aug_row) in aug.iter_mut().enumerate() {
            if row == col {
                continue;
            }
            let factor = aug_row[col];
            for (aug_val, &pivot_val) in aug_row[..total_cols].iter_mut().zip(&pivot_row) {
                *aug_val -= pivot_val * factor;
            }
        }
    }

    // The right half of the augmented matrix now holds W.
    aug.iter().map(|row| row[dim..].to_vec()).collect()
}
493
494// ─────────────────────────────────────────────
495// Tests
496// ─────────────────────────────────────────────
497
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random embeddings (`n` vectors of length `dim`)
    /// produced by a 64-bit linear congruential generator seeded with `seed`.
    fn make_embeddings(n: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
        let mut state = seed.wrapping_add(1);
        (0..n)
            .map(|_| {
                (0..dim)
                    .map(|_| {
                        state = state
                            .wrapping_mul(6_364_136_223_846_793_005)
                            .wrapping_add(1_442_695_040_888_963_407);
                        ((state >> 33) as f32 / u32::MAX as f32) - 0.5
                    })
                    .collect()
            })
            .collect()
    }

    /// Seed pairs `(i, i)` for `i` in `0..n`, each with confidence 1.0.
    fn make_seed_pairs(n: usize) -> Vec<AlignmentPair> {
        (0..n)
            .map(|i| AlignmentPair {
                source_idx: i,
                target_idx: i,
                confidence: 1.0,
            })
            .collect()
    }

    // ── AlignmentTransform ────────────────────

    #[test]
    fn test_identity_transform() {
        let t = AlignmentTransform::identity(4);
        let v = vec![1.0_f32, 2.0, 3.0, 4.0];
        let out = t.apply(&v);
        for (a, b) in v.iter().zip(out.iter()) {
            assert!((a - b).abs() < 1e-6, "identity should preserve values");
        }
    }

    #[test]
    fn test_orthogonal_transform_apply() {
        let mat = vec![vec![0.0_f32, 1.0], vec![1.0, 0.0]];
        let t = AlignmentTransform::Orthogonal(mat);
        let v = vec![3.0_f32, 7.0];
        let out = t.apply(&v);
        assert!((out[0] - 7.0).abs() < 1e-6);
        assert!((out[1] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_identity_transform_has_matrix() {
        let t = AlignmentTransform::identity(3);
        assert!(t.matrix().is_some());
    }

    #[test]
    fn test_identity_enum_no_matrix() {
        let t = AlignmentTransform::Identity;
        assert!(t.matrix().is_none());
    }

    // ── EmbeddingAlignment ────────────────────

    #[test]
    fn test_alignment_creation() {
        let src = make_embeddings(5, 4, 1);
        let tgt = make_embeddings(5, 4, 2);
        let aligner = EmbeddingAlignment::new(src.clone(), tgt.clone());
        assert_eq!(aligner.dim, 4);
        assert_eq!(aligner.source_embeddings.len(), 5);
        assert_eq!(aligner.target_embeddings.len(), 5);
    }

    #[test]
    fn test_orthogonal_procrustes_produces_result() {
        let src = make_embeddings(6, 4, 10);
        let tgt = make_embeddings(6, 4, 20);
        let aligner = EmbeddingAlignment::new(src, tgt);
        let seeds = make_seed_pairs(3);
        let result = aligner.find_alignment(AlignmentMethod::OrthogonalProcrustes, &seeds);
        assert!(result.alignment_score.is_finite());
    }

    #[test]
    fn test_linear_transform_produces_result() {
        let src = make_embeddings(6, 4, 30);
        let tgt = make_embeddings(6, 4, 40);
        let aligner = EmbeddingAlignment::new(src, tgt);
        let seeds = make_seed_pairs(3);
        let result = aligner.find_alignment(AlignmentMethod::LinearTransformation, &seeds);
        assert!(result.alignment_score.is_finite());
    }

    #[test]
    fn test_bidirectional_matching_produces_result() {
        let src = make_embeddings(8, 4, 50);
        let tgt = make_embeddings(8, 4, 60);
        let aligner = EmbeddingAlignment::new(src, tgt);
        let seeds = make_seed_pairs(2);
        let result = aligner.find_alignment(AlignmentMethod::BidirectionalMatching, &seeds);
        assert!(result.alignment_score >= -1.0 && result.alignment_score <= 1.0 + 1e-6);
    }

    #[test]
    fn test_apply_transform_correct_count() {
        let src = make_embeddings(5, 4, 70);
        let tgt = make_embeddings(5, 4, 80);
        let aligner = EmbeddingAlignment::new(src, tgt);
        let t = AlignmentTransform::identity(4);
        let out = aligner.apply_transform(&t);
        assert_eq!(out.len(), 5);
        assert_eq!(out[0].len(), 4);
    }

    #[test]
    fn test_alignment_with_empty_seeds() {
        let src = make_embeddings(4, 4, 90);
        let tgt = make_embeddings(4, 4, 91);
        let aligner = EmbeddingAlignment::new(src, tgt);
        let result = aligner.find_alignment(AlignmentMethod::OrthogonalProcrustes, &[]);
        // Should not panic; alignment_score may be 0
        assert!(result.alignment_score.is_finite());
    }

    #[test]
    fn test_identical_spaces_score() {
        // If source == target and seeds are identity pairs, alignment score ~ 1.0
        let embs = make_embeddings(5, 4, 100);
        let aligner = EmbeddingAlignment::new(embs.clone(), embs.clone());
        let seeds = make_seed_pairs(5);
        let result = aligner.find_alignment(AlignmentMethod::BidirectionalMatching, &seeds);
        // Mean cosine similarity with identical embeddings under identity = 1.0
        assert!(
            result.alignment_score > 0.9,
            "same-space alignment should score near 1.0: {}",
            result.alignment_score
        );
    }

    #[test]
    fn test_alignment_result_has_transform() {
        let src = make_embeddings(4, 3, 111);
        let tgt = make_embeddings(4, 3, 222);
        let aligner = EmbeddingAlignment::new(src, tgt);
        let seeds = make_seed_pairs(2);
        let result = aligner.find_alignment(AlignmentMethod::OrthogonalProcrustes, &seeds);
        // The Procrustes path must yield an orthogonal matrix transform. The
        // `matches!` result has to be asserted; evaluating it alone checks nothing.
        assert!(
            matches!(result.transform, AlignmentTransform::Orthogonal(_)),
            "expected an Orthogonal transform, got {:?}",
            result.transform
        );
    }

    // ── CrossLingualAligner ───────────────────

    #[test]
    fn test_cross_lingual_creation() {
        let aligner = CrossLingualAligner::new("en");
        assert_eq!(aligner.pivot_language(), "en");
    }

    #[test]
    fn test_cross_lingual_add_language() {
        let mut aligner = CrossLingualAligner::new("en");
        aligner.add_language("fr", make_embeddings(5, 4, 1));
        aligner.add_language("en", make_embeddings(5, 4, 2));
        let langs = aligner.languages();
        assert!(langs.contains(&"fr"));
        assert!(langs.contains(&"en"));
    }

    #[test]
    fn test_cross_lingual_align_to_pivot() {
        let mut aligner = CrossLingualAligner::new("en");
        aligner.add_language("en", make_embeddings(8, 4, 10));
        aligner.add_language("fr", make_embeddings(8, 4, 20));
        let seeds = make_seed_pairs(3);
        let result = aligner.align_to_pivot("fr", &seeds);
        assert!(result.is_some(), "should return alignment result");
        let r = result.unwrap();
        assert!(r.alignment_score.is_finite());
    }

    #[test]
    fn test_cross_lingual_align_missing_language() {
        let aligner = CrossLingualAligner::new("en");
        let result = aligner.align_to_pivot("de", &[]);
        assert!(result.is_none(), "missing language should return None");
    }

    #[test]
    fn test_cross_lingual_translate_same_language() {
        let mut aligner = CrossLingualAligner::new("en");
        aligner.add_language("en", make_embeddings(5, 4, 1));
        let v = vec![1.0_f32, 2.0, 3.0, 4.0];
        let out = aligner.translate(&v, "en", "en");
        assert!(out.is_some());
        assert_eq!(out.unwrap(), v);
    }

    #[test]
    fn test_cross_lingual_translate_missing_returns_none() {
        let aligner = CrossLingualAligner::new("en");
        let v = vec![0.0_f32; 4];
        let out = aligner.translate(&v, "de", "fr");
        assert!(out.is_none());
    }

    // ── Utility functions ─────────────────────

    #[test]
    fn test_cosine_similarity_identical() {
        let v = vec![1.0_f32, 0.0, 0.0];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![0.0_f32, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let a = vec![0.0_f32, 0.0];
        let b = vec![1.0_f32, 0.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }
}