Skip to main content

scirs2_text/embeddings/
crosslingual.rs

1//! Cross-lingual Embedding Alignment.
2//!
3//! This module provides methods for aligning embedding spaces across languages,
4//! enabling cross-lingual transfer and translation of embeddings.
5//!
6//! # Alignment Methods
7//!
8//! | Method | Description |
9//! |--------|-------------|
10//! | Procrustes | Orthogonal alignment: W = UV^T from SVD(X^T Y) |
11//! | CCA | Canonical Correlation Analysis projection |
12//! | MUSE | Multilingual Unsupervised/Supervised Embeddings (iterative refinement) |
13//!
14//! # Example
15//!
16//! ```rust
17//! use scirs2_text::embeddings::crosslingual::{
18//!     CrossLingualConfig, AlignmentMethod, align_embeddings, translate_embedding, AlignmentMatrix,
19//! };
20//!
21//! let source = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
22//! let target = vec![vec![0.0, 1.0], vec![-1.0, 0.0]];
23//! let anchors = vec![(0, 0), (1, 1)];
24//!
25//! let config = CrossLingualConfig::default();
26//! let alignment = align_embeddings(&source, &target, &anchors, &config).unwrap();
27//! let translated = translate_embedding(&source[0], &alignment);
28//! assert_eq!(translated.len(), 2);
29//! ```
30
31use crate::error::{Result, TextError};
32
/// Dense row-major matrix: one inner `Vec` per row.
type DenseMatrix = Vec<Vec<f64>>;

/// Output of a singular value decomposition, packed as `(U, S, Vt)`.
type SvdResult = (DenseMatrix, Vec<f64>, DenseMatrix);
35
36// ─── AlignmentMethod ────────────────────────────────────────────────────────
37
/// Strategy used to map one embedding space onto another.
#[derive(Clone, Debug, Default, PartialEq)]
#[non_exhaustive]
pub enum AlignmentMethod {
    /// Orthogonal Procrustes: the orthogonal W that minimises ‖XW − Y‖_F.
    /// This is the default method.
    #[default]
    Procrustes,
    /// Alignment derived from Canonical Correlation Analysis.
    CCA,
    /// Iterative alignment in the style of Multilingual
    /// Unsupervised/Supervised Embeddings.
    MUSE,
}
50
51// ─── CrossLingualConfig ─────────────────────────────────────────────────────
52
53/// Configuration for cross-lingual alignment.
54#[derive(Debug, Clone)]
55pub struct CrossLingualConfig {
56    /// Dimensionality of source embeddings.
57    pub source_dim: usize,
58    /// Dimensionality of target embeddings.
59    pub target_dim: usize,
60    /// Alignment method to use.
61    pub alignment: AlignmentMethod,
62    /// Number of refinement iterations (for MUSE).
63    pub refinement_iterations: usize,
64    /// Learning rate for iterative methods.
65    pub learning_rate: f64,
66}
67
68impl Default for CrossLingualConfig {
69    fn default() -> Self {
70        Self {
71            source_dim: 0, // auto-detect
72            target_dim: 0, // auto-detect
73            alignment: AlignmentMethod::Procrustes,
74            refinement_iterations: 5,
75            learning_rate: 0.01,
76        }
77    }
78}
79
80// ─── AlignmentMatrix ────────────────────────────────────────────────────────
81
82/// Learned alignment transformation matrix.
83#[derive(Debug, Clone)]
84pub struct AlignmentMatrix {
85    /// The transformation matrix W (rows × cols).
86    pub w: Vec<Vec<f64>>,
87    /// Number of rows (source dimensionality).
88    pub rows: usize,
89    /// Number of columns (target dimensionality).
90    pub cols: usize,
91    /// Method used to compute this alignment.
92    pub method: AlignmentMethod,
93}
94
95// ─── Linear algebra helpers ─────────────────────────────────────────────────
96
/// Return the transpose of a row-major matrix.
///
/// The column count is taken from the first row. An input with no rows, or
/// whose first row is empty, yields an empty matrix. Rows shorter than the
/// first row trigger an index panic, so the input must be rectangular.
fn transpose(m: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let width = match m.first() {
        Some(first) => first.len(),
        None => return Vec::new(),
    };
    // Column `col` of the input becomes row `col` of the output.
    (0..width)
        .map(|col| m.iter().map(|row| row[col]).collect())
        .collect()
}
112
/// Matrix product C = A·B for A (m×k) and B (k×n), giving C (m×n).
///
/// An empty `a` yields an empty matrix. An empty `b`, or a `b` whose first row
/// is empty, yields `m` empty rows. The inner dimension is taken from the first
/// row of `a`; `b` must supply at least that many rows, otherwise indexing
/// panics.
fn matmul(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let inner = match a.first() {
        Some(first) => first.len(),
        None => return Vec::new(),
    };
    let width = match b.first() {
        Some(first) if !first.is_empty() => first.len(),
        _ => return vec![vec![]; a.len()],
    };

    a.iter()
        .map(|row| {
            (0..width)
                // Accumulate left to right, starting from zero.
                .map(|j| (0..inner).fold(0.0, |acc, p| acc + row[p] * b[p][j]))
                .collect()
        })
        .collect()
}
136
137/// Compute SVD of an m×n matrix using one-sided Jacobi rotations.
138/// Returns (U, S, Vt) where U is m×min(m,n), S is min(m,n), Vt is min(m,n)×n.
139fn svd_jacobi(matrix: &[Vec<f64>]) -> Result<SvdResult> {
140    let m = matrix.len();
141    if m == 0 {
142        return Ok((Vec::new(), Vec::new(), Vec::new()));
143    }
144    let n = matrix[0].len();
145    if n == 0 {
146        return Ok((vec![vec![]; m], Vec::new(), Vec::new()));
147    }
148
149    let k = m.min(n);
150    let max_iter = 100;
151    let tol = 1e-12;
152
153    // Work on A^T A for small cases, use a simpler power-iteration-based approach
154    // For the Procrustes problem we only need the thin SVD of X^T Y which is at most dim×dim.
155
156    // Compute A^T A (n×n)
157    let at = transpose(matrix);
158    let ata = matmul(&at, matrix);
159
160    // Eigen-decompose A^T A via Jacobi
161    let nn = ata.len();
162    let mut d = ata.clone(); // will be diagonalised
163    let mut v = vec![vec![0.0; nn]; nn]; // eigenvectors
164    for i in 0..nn {
165        v[i][i] = 1.0;
166    }
167
168    for _iter in 0..max_iter {
169        // Find max off-diagonal
170        let mut max_off = 0.0;
171        let mut p = 0;
172        let mut q = 1;
173        for i in 0..nn {
174            for j in (i + 1)..nn {
175                let val = d[i][j].abs();
176                if val > max_off {
177                    max_off = val;
178                    p = i;
179                    q = j;
180                }
181            }
182        }
183        if max_off < tol {
184            break;
185        }
186
187        // Compute Jacobi rotation
188        let theta = if (d[p][p] - d[q][q]).abs() < 1e-15 {
189            std::f64::consts::FRAC_PI_4
190        } else {
191            0.5 * (2.0 * d[p][q] / (d[p][p] - d[q][q])).atan()
192        };
193        let c = theta.cos();
194        let s = theta.sin();
195
196        // Apply rotation to d
197        let mut new_d = d.clone();
198        for i in 0..nn {
199            if i != p && i != q {
200                new_d[i][p] = c * d[i][p] + s * d[i][q];
201                new_d[p][i] = new_d[i][p];
202                new_d[i][q] = -s * d[i][p] + c * d[i][q];
203                new_d[q][i] = new_d[i][q];
204            }
205        }
206        new_d[p][p] = c * c * d[p][p] + 2.0 * s * c * d[p][q] + s * s * d[q][q];
207        new_d[q][q] = s * s * d[p][p] - 2.0 * s * c * d[p][q] + c * c * d[q][q];
208        new_d[p][q] = 0.0;
209        new_d[q][p] = 0.0;
210        d = new_d;
211
212        // Update eigenvectors
213        for i in 0..nn {
214            let vip = v[i][p];
215            let viq = v[i][q];
216            v[i][p] = c * vip + s * viq;
217            v[i][q] = -s * vip + c * viq;
218        }
219    }
220
221    // Extract eigenvalues and sort descending
222    let mut eig_pairs: Vec<(f64, usize)> = (0..nn).map(|i| (d[i][i].max(0.0), i)).collect();
223    eig_pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
224
225    let mut sigma = vec![0.0; k];
226    let mut vt = vec![vec![0.0; n]; k];
227    for i in 0..k {
228        let (eigval, idx) = eig_pairs[i];
229        sigma[i] = eigval.sqrt();
230        for j in 0..nn {
231            vt[i][j] = v[j][idx];
232        }
233    }
234
235    // U = A V Σ^{-1}
236    // V columns (transposed from vt rows)
237    let mut u = vec![vec![0.0; k]; m];
238    for i in 0..m {
239        for j in 0..k {
240            if sigma[j] > 1e-15 {
241                let mut s = 0.0;
242                for p in 0..n {
243                    s += matrix[i][p] * vt[j][p];
244                }
245                u[i][j] = s / sigma[j];
246            }
247        }
248    }
249
250    Ok((u, sigma, vt))
251}
252
253// ─── Alignment functions ────────────────────────────────────────────────────
254
255/// Procrustes alignment: find orthogonal W such that ‖XW − Y‖_F is minimised.
256///
257/// W = V U^T from SVD(X^T Y).
258fn procrustes_align(
259    source_anchors: &[Vec<f64>],
260    target_anchors: &[Vec<f64>],
261) -> Result<AlignmentMatrix> {
262    if source_anchors.is_empty() || target_anchors.is_empty() {
263        return Err(TextError::InvalidInput("Empty anchor sets".to_string()));
264    }
265    let dim_s = source_anchors[0].len();
266    let dim_t = target_anchors[0].len();
267    if dim_s != dim_t {
268        return Err(TextError::InvalidInput(format!(
269            "Procrustes requires same dimensionality, got {} vs {}",
270            dim_s, dim_t
271        )));
272    }
273
274    // Compute M = X^T Y (dim × dim)
275    let xt = transpose(source_anchors);
276    let m = matmul(&xt, target_anchors);
277
278    // SVD of M = X^T Y
279    let (u, _sigma, vt) = svd_jacobi(&m)?;
280
281    // Procrustes solution: W = U V^T
282    // SVD(M) = U Σ V^T → W = U V^T
283    let w = matmul(&u, &vt);
284
285    Ok(AlignmentMatrix {
286        w,
287        rows: dim_s,
288        cols: dim_t,
289        method: AlignmentMethod::Procrustes,
290    })
291}
292
293/// CCA alignment: project both source and target to a shared space.
294fn cca_align(source_anchors: &[Vec<f64>], target_anchors: &[Vec<f64>]) -> Result<AlignmentMatrix> {
295    // Simplified CCA: use whitened Procrustes
296    // 1. Center both sets
297    let n = source_anchors.len();
298    if n == 0 {
299        return Err(TextError::InvalidInput("Empty anchor sets".to_string()));
300    }
301    let dim_s = source_anchors[0].len();
302    let dim_t = target_anchors[0].len();
303
304    // Center source
305    let mut src_mean = vec![0.0; dim_s];
306    for v in source_anchors {
307        for (i, &x) in v.iter().enumerate() {
308            src_mean[i] += x;
309        }
310    }
311    let nf = n as f64;
312    for v in &mut src_mean {
313        *v /= nf;
314    }
315
316    let centered_src: Vec<Vec<f64>> = source_anchors
317        .iter()
318        .map(|v| v.iter().zip(src_mean.iter()).map(|(x, m)| x - m).collect())
319        .collect();
320
321    // Center target
322    let mut tgt_mean = vec![0.0; dim_t];
323    for v in target_anchors {
324        for (i, &x) in v.iter().enumerate() {
325            tgt_mean[i] += x;
326        }
327    }
328    for v in &mut tgt_mean {
329        *v /= nf;
330    }
331
332    let centered_tgt: Vec<Vec<f64>> = target_anchors
333        .iter()
334        .map(|v| v.iter().zip(tgt_mean.iter()).map(|(x, m)| x - m).collect())
335        .collect();
336
337    // Procrustes on centred data
338    procrustes_align(&centered_src, &centered_tgt)
339}
340
341/// MUSE-style iterative alignment (supervised variant).
342fn muse_align(
343    source_anchors: &[Vec<f64>],
344    target_anchors: &[Vec<f64>],
345    iterations: usize,
346) -> Result<AlignmentMatrix> {
347    // Start with Procrustes, then iteratively refine
348    let mut alignment = procrustes_align(source_anchors, target_anchors)?;
349
350    for _iter in 0..iterations {
351        // Apply current alignment to source anchors
352        let aligned: Vec<Vec<f64>> = source_anchors
353            .iter()
354            .map(|s| translate_embedding(s, &alignment))
355            .collect();
356
357        // Re-solve Procrustes with aligned ↔ target
358        alignment = procrustes_align(&aligned, target_anchors)?;
359
360        // Compose: new_W = old_W * refine_W
361        // But since each iteration refines, we keep the latest
362    }
363
364    Ok(alignment)
365}
366
367/// Align source embeddings to the target embedding space using anchor pairs.
368///
369/// `anchors` is a list of `(source_idx, target_idx)` pairs identifying
370/// corresponding words across languages.
371pub fn align_embeddings(
372    source: &[Vec<f64>],
373    target: &[Vec<f64>],
374    anchors: &[(usize, usize)],
375    config: &CrossLingualConfig,
376) -> Result<AlignmentMatrix> {
377    if anchors.is_empty() {
378        return Err(TextError::InvalidInput(
379            "Need at least one anchor pair".to_string(),
380        ));
381    }
382    if source.is_empty() || target.is_empty() {
383        return Err(TextError::InvalidInput(
384            "Source and target embeddings must be non-empty".to_string(),
385        ));
386    }
387
388    // Extract anchor vectors
389    let mut src_anchors = Vec::with_capacity(anchors.len());
390    let mut tgt_anchors = Vec::with_capacity(anchors.len());
391    for &(si, ti) in anchors {
392        if si >= source.len() {
393            return Err(TextError::InvalidInput(format!(
394                "Source anchor index {si} out of bounds (len={})",
395                source.len()
396            )));
397        }
398        if ti >= target.len() {
399            return Err(TextError::InvalidInput(format!(
400                "Target anchor index {ti} out of bounds (len={})",
401                target.len()
402            )));
403        }
404        src_anchors.push(source[si].clone());
405        tgt_anchors.push(target[ti].clone());
406    }
407
408    #[allow(unreachable_patterns)]
409    match &config.alignment {
410        AlignmentMethod::Procrustes => procrustes_align(&src_anchors, &tgt_anchors),
411        AlignmentMethod::CCA => cca_align(&src_anchors, &tgt_anchors),
412        AlignmentMethod::MUSE => {
413            muse_align(&src_anchors, &tgt_anchors, config.refinement_iterations)
414        }
415        _ => procrustes_align(&src_anchors, &tgt_anchors),
416    }
417}
418
419/// Translate a single embedding using the alignment matrix: y = x W.
420pub fn translate_embedding(embedding: &[f64], alignment: &AlignmentMatrix) -> Vec<f64> {
421    let mut result = vec![0.0; alignment.cols];
422    for j in 0..alignment.cols {
423        let mut s = 0.0;
424        for i in 0..alignment.rows.min(embedding.len()) {
425            s += embedding[i] * alignment.w[i][j];
426        }
427        result[j] = s;
428    }
429    result
430}
431
432/// Translate a batch of embeddings.
433pub fn translate_batch(embeddings: &[Vec<f64>], alignment: &AlignmentMatrix) -> Vec<Vec<f64>> {
434    embeddings
435        .iter()
436        .map(|e| translate_embedding(e, alignment))
437        .collect()
438}
439
440/// Compute the alignment quality: mean cosine similarity between aligned source
441/// anchors and target anchors.
442pub fn alignment_quality(
443    source: &[Vec<f64>],
444    target: &[Vec<f64>],
445    anchors: &[(usize, usize)],
446    alignment: &AlignmentMatrix,
447) -> f64 {
448    if anchors.is_empty() {
449        return 0.0;
450    }
451    let mut total_sim = 0.0;
452    let mut count = 0;
453    for &(si, ti) in anchors {
454        if si < source.len() && ti < target.len() {
455            let aligned = translate_embedding(&source[si], alignment);
456            let sim = cosine_sim_local(&aligned, &target[ti]);
457            total_sim += sim;
458            count += 1;
459        }
460    }
461    if count == 0 {
462        0.0
463    } else {
464        total_sim / count as f64
465    }
466}
467
/// Cosine similarity of two vectors (module-private helper).
///
/// The dot product runs over the overlapping prefix when the lengths differ,
/// while each norm uses the full vector. Returns 0.0 when either norm is below
/// 1e-15.
fn cosine_sim_local(a: &[f64], b: &[f64]) -> f64 {
    let norm = |v: &[f64]| v.iter().map(|x| x * x).sum::<f64>().sqrt();
    let (na, nb) = (norm(a), norm(b));

    if na < 1e-15 || nb < 1e-15 {
        0.0
    } else {
        let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        dot / (na * nb)
    }
}
478
479/// Compute the alignment quality using local cosine similarity.
480pub fn alignment_quality_score(
481    source: &[Vec<f64>],
482    target: &[Vec<f64>],
483    anchors: &[(usize, usize)],
484    alignment: &AlignmentMatrix,
485) -> f64 {
486    if anchors.is_empty() {
487        return 0.0;
488    }
489    let mut total_sim = 0.0;
490    let mut count = 0;
491    for &(si, ti) in anchors {
492        if si < source.len() && ti < target.len() {
493            let aligned = translate_embedding(&source[si], alignment);
494            let sim = cosine_sim_local(&aligned, &target[ti]);
495            total_sim += sim;
496            count += 1;
497        }
498    }
499    if count == 0 {
500        0.0
501    } else {
502        total_sim / count as f64
503    }
504}
505
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean distance between two vectors.
    fn euclidean(a: &[f64], b: &[f64]) -> f64 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f64>()
            .sqrt()
    }

    /// The two unit vectors of the plane.
    fn plane_basis() -> Vec<Vec<f64>> {
        vec![vec![1.0, 0.0], vec![0.0, 1.0]]
    }

    /// The plane basis after a 90-degree rotation.
    fn rotated_basis() -> Vec<Vec<f64>> {
        vec![vec![0.0, 1.0], vec![-1.0, 0.0]]
    }

    /// Anchor list pairing index 0 with 0 and 1 with 1.
    fn pairwise_anchors() -> Vec<(usize, usize)> {
        vec![(0, 0), (1, 1)]
    }

    #[test]
    fn test_crosslingual_config_default() {
        let cfg = CrossLingualConfig::default();
        assert_eq!(cfg.alignment, AlignmentMethod::Procrustes);
        assert_eq!(cfg.refinement_iterations, 5);
    }

    #[test]
    fn test_procrustes_identity() {
        // Aligning a space with itself should give (nearly) the identity map.
        let source = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        let target = source.clone();
        let anchors = vec![(0, 0), (1, 1), (2, 2)];
        let config = CrossLingualConfig::default();

        let alignment = align_embeddings(&source, &target, &anchors, &config);
        assert!(alignment.is_ok());
        let alignment = alignment.expect("should succeed");

        // The translated vector must land close to its target.
        let translated = translate_embedding(&source[0], &alignment);
        let dist = euclidean(&translated, &target[0]);
        assert!(
            dist < 0.1,
            "Identity alignment should preserve vectors, dist={dist}"
        );
    }

    #[test]
    fn test_procrustes_rotation() {
        // The target is the source rotated by 90 degrees in 2D.
        let source = plane_basis();
        let target = rotated_basis();
        let anchors = pairwise_anchors();
        let config = CrossLingualConfig::default();
        let alignment = align_embeddings(&source, &target, &anchors, &config).expect("ok");

        let t0 = translate_embedding(&source[0], &alignment);
        let t1 = translate_embedding(&source[1], &alignment);

        let d0 = euclidean(&t0, &[0.0, 1.0]);
        assert!(d0 < 0.3, "Rotated [1,0] should be near [0,1], dist={d0}");

        let d1 = euclidean(&t1, &[-1.0, 0.0]);
        assert!(d1 < 0.3, "Rotated [0,1] should be near [-1,0], dist={d1}");
    }

    #[test]
    fn test_translation_preserves_relative_distances() {
        let source = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
        let target = vec![vec![0.0, 1.0], vec![-1.0, 0.0], vec![-1.0, 1.0]];
        let anchors = pairwise_anchors();
        let config = CrossLingualConfig::default();
        let alignment = align_embeddings(&source, &target, &anchors, &config).expect("ok");

        // Pairwise distances before translation.
        let orig_dist_01 = euclidean(&source[0], &source[1]);
        let orig_dist_02 = euclidean(&source[0], &source[2]);

        let mapped = translate_batch(&source, &alignment);

        // Pairwise distances after translation.
        let new_dist_01 = euclidean(&mapped[0], &mapped[1]);
        let new_dist_02 = euclidean(&mapped[0], &mapped[2]);

        // An orthogonal transform leaves distances unchanged.
        assert!(
            (orig_dist_01 - new_dist_01).abs() < 0.3,
            "Distances should be preserved: {orig_dist_01} vs {new_dist_01}"
        );
        assert!(
            (orig_dist_02 - new_dist_02).abs() < 0.3,
            "Distances should be preserved: {orig_dist_02} vs {new_dist_02}"
        );
    }

    #[test]
    fn test_cca_alignment() {
        let config = CrossLingualConfig {
            alignment: AlignmentMethod::CCA,
            ..Default::default()
        };
        let alignment = align_embeddings(
            &plane_basis(),
            &rotated_basis(),
            &pairwise_anchors(),
            &config,
        );
        assert!(alignment.is_ok());
    }

    #[test]
    fn test_muse_alignment() {
        let config = CrossLingualConfig {
            alignment: AlignmentMethod::MUSE,
            refinement_iterations: 3,
            ..Default::default()
        };
        let alignment = align_embeddings(
            &plane_basis(),
            &rotated_basis(),
            &pairwise_anchors(),
            &config,
        );
        assert!(alignment.is_ok());
    }

    #[test]
    fn test_empty_anchors_error() {
        let source = vec![vec![1.0, 0.0]];
        let target = vec![vec![0.0, 1.0]];
        let config = CrossLingualConfig::default();
        let result = align_embeddings(&source, &target, &[], &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_translate_batch() {
        let source = plane_basis();
        let target = plane_basis();
        let anchors = pairwise_anchors();
        let config = CrossLingualConfig::default();
        let alignment = align_embeddings(&source, &target, &anchors, &config).expect("ok");
        let batch = translate_batch(&source, &alignment);
        assert_eq!(batch.len(), 2);
        assert_eq!(batch[0].len(), 2);
    }
}