use crate::error::{Result, TextError};
/// `(u, sigma, vt)` as produced by `svd_jacobi` for an `m × n` input with `k = min(m, n)`:
/// left singular vectors (`m × k`, stored as rows), singular values (length `k`, descending),
/// and transposed right singular vectors (`k × n`).
type SvdResult = (Vec<Vec<f64>>, Vec<f64>, Vec<Vec<f64>>);
/// Solver used by [`align_embeddings`] to fit the source → target map.
///
/// Marked `#[non_exhaustive]`, so matches on it outside this crate need a wildcard arm.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Default)]
pub enum AlignmentMethod {
    /// Orthogonal Procrustes on the anchor pairs (the default).
    #[default]
    Procrustes,
    /// Mean-centre both anchor sets, then solve Procrustes on the centred data
    /// (no covariance whitening is performed).
    CCA,
    /// Procrustes followed by `refinement_iterations` further Procrustes passes over the
    /// same anchor pairs.
    MUSE,
}
/// Configuration for [`align_embeddings`].
#[derive(Debug, Clone)]
pub struct CrossLingualConfig {
    /// Presumably the source embedding dimensionality. Not read by `align_embeddings`.
    pub source_dim: usize,
    /// Presumably the target embedding dimensionality. Not read by `align_embeddings`.
    pub target_dim: usize,
    /// Which solver `align_embeddings` dispatches to.
    pub alignment: AlignmentMethod,
    /// Number of refinement passes; only used when `alignment` is [`AlignmentMethod::MUSE`].
    pub refinement_iterations: usize,
    /// NOTE(review): not read by any code in this module — confirm intended use.
    pub learning_rate: f64,
}
impl Default for CrossLingualConfig {
fn default() -> Self {
Self {
source_dim: 0, target_dim: 0, alignment: AlignmentMethod::Procrustes,
refinement_iterations: 5,
learning_rate: 0.01,
}
}
}
/// A learned linear map from the source embedding space to the target space.
#[derive(Debug, Clone)]
pub struct AlignmentMatrix {
    /// `rows × cols` matrix stored as a `Vec` of rows; [`translate_embedding`] computes the
    /// row-vector product `embedding · w`.
    pub w: Vec<Vec<f64>>,
    /// Source dimensionality (number of rows of `w` that `translate_embedding` reads).
    pub rows: usize,
    /// Target dimensionality (length of every translated vector).
    pub cols: usize,
    /// Method recorded by the routine that produced this matrix.
    pub method: AlignmentMethod,
}
/// Returns the transpose of a row-major matrix.
///
/// The column count is taken from the first row; an empty input yields an empty matrix.
/// Rows longer than the first are truncated; shorter rows cause an index panic.
fn transpose(m: &[Vec<f64>]) -> Vec<Vec<f64>> {
    match m.first() {
        None => Vec::new(),
        Some(first) => (0..first.len())
            .map(|col| m.iter().map(|row| row[col]).collect())
            .collect(),
    }
}
/// Dense product `a · b` of two row-major matrices.
///
/// The inner dimension is the length of `a`'s first row and the output width is the length
/// of `b`'s first row. An empty `a` yields an empty matrix; an empty `b` (or one whose first
/// row is empty) yields `a.len()` empty rows.
fn matmul(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let inner = match a.first() {
        None => return Vec::new(),
        Some(row) => row.len(),
    };
    let width = match b.first() {
        Some(row) if !row.is_empty() => row.len(),
        _ => return vec![Vec::new(); a.len()],
    };
    a.iter()
        .map(|a_row| {
            (0..width)
                // Left-to-right accumulation from 0.0, matching a plain `s += …` loop.
                .map(|j| (0..inner).fold(0.0, |acc, p| acc + a_row[p] * b[p][j]))
                .collect()
        })
        .collect()
}
/// Thin singular value decomposition `A = U · diag(σ) · Vᵀ` of an `m × n` matrix `A`.
///
/// `V` and `σ²` come from the eigen-decomposition of the Gram matrix `AᵀA`. That
/// decomposition uses cyclic Jacobi sweeps, run until the off-diagonal mass is negligible
/// relative to the matrix scale (see `jacobi_eigen_symmetric`). Each column of `U` is then
/// recovered from `A·vⱼ`.
///
/// Returns `(u, sigma, vt)` with shapes `m × k`, `k` and `k × n`, where `k = min(m, n)`.
/// `sigma` is sorted in descending order.
///
/// Columns of `U` whose singular value is numerically zero are not left as zero vectors.
/// They are completed to an orthonormal set, so `U` always has orthonormal columns. This
/// keeps `U·Vᵀ` orthogonal in Procrustes alignment when the cross-covariance matrix is
/// rank-deficient (for example, fewer anchor pairs than dimensions).
///
/// Forming `AᵀA` squares the condition number. Singular values much smaller than roughly
/// `√ε · σ_max` (≈ `1.5e-8 · σ_max`) are therefore not resolved accurately.
///
/// # Errors
///
/// The current implementation never returns an error; the `Result` keeps the signature stable.
fn svd_jacobi(matrix: &[Vec<f64>]) -> Result<SvdResult> {
    // Singular values at or below this fraction of σ_max are treated as zero when building U.
    const NULL_TOL: f64 = 1e-12;

    let m = matrix.len();
    if m == 0 {
        return Ok((Vec::new(), Vec::new(), Vec::new()));
    }
    let n = matrix[0].len();
    if n == 0 {
        return Ok((vec![vec![]; m], Vec::new(), Vec::new()));
    }
    let k = m.min(n);

    let (eigvals, eigvecs) = jacobi_eigen_symmetric(matmul(&transpose(matrix), matrix));

    // Largest eigenvalue first; negative round-off is clamped to zero before the square root.
    let mut eig_pairs: Vec<(f64, usize)> = eigvals
        .iter()
        .enumerate()
        .map(|(i, &lambda)| (lambda.max(0.0), i))
        .collect();
    eig_pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));

    let sigma: Vec<f64> = eig_pairs
        .iter()
        .take(k)
        .map(|&(lambda, _)| lambda.sqrt())
        .collect();
    let vt: Vec<Vec<f64>> = eig_pairs
        .iter()
        .take(k)
        .map(|&(_, idx)| eigvecs.iter().map(|row| row[idx]).collect())
        .collect();

    let sigma_max = sigma[0];
    // Columns of U, stored as rows here and transposed at the end.
    let mut u_cols: Vec<Vec<f64>> = Vec::with_capacity(k);
    let mut basis_cursor = 0;
    for (sigma_j, v_j) in sigma.iter().zip(&vt) {
        let mut col: Vec<f64> = matrix
            .iter()
            .map(|row| row.iter().zip(v_j).map(|(a, b)| a * b).sum::<f64>())
            .collect();
        let norm = orthogonalize_against(&mut col, &u_cols);
        // For a genuine singular value ‖A·vⱼ‖ ≈ σⱼ. When σⱼ is numerically zero, A·vⱼ is
        // round-off noise and its direction is meaningless. Normalising by the measured norm
        // (rather than σⱼ) keeps accepted columns exactly unit length.
        let col: Vec<f64> = if *sigma_j > NULL_TOL * sigma_max && norm > 0.5 * sigma_j {
            col.iter().map(|x| x / norm).collect()
        } else {
            next_orthonormal_complement(&u_cols, m, &mut basis_cursor)
        };
        u_cols.push(col);
    }
    Ok((transpose(&u_cols), sigma, vt))
}

/// Eigen-decomposition of a real symmetric matrix by cyclic Jacobi rotations.
///
/// Takes ownership of `d`, which must be square and symmetric. Returns `(eigenvalues, v)`,
/// where column `i` of `v` is the unit eigenvector for `eigenvalues[i]`. The eigenvalues
/// are not sorted.
fn jacobi_eigen_symmetric(mut d: Vec<Vec<f64>>) -> (Vec<f64>, Vec<Vec<f64>>) {
    // Cyclic Jacobi converges quadratically and normally needs only a handful of sweeps; the
    // cap merely bounds the work on pathological input.
    const MAX_SWEEPS: usize = 64;
    // Convergence is judged relative to the matrix scale, so the test is unit-independent.
    const REL_TOL: f64 = 1e-14;

    let nn = d.len();
    let mut v: Vec<Vec<f64>> = (0..nn)
        .map(|i| (0..nn).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
        .collect();
    let scale = d.iter().flatten().map(|x| x * x).sum::<f64>().sqrt();
    let tol = REL_TOL * scale;

    for _sweep in 0..MAX_SWEEPS {
        // Frobenius norm of the strict upper triangle.
        let off = d
            .iter()
            .enumerate()
            .flat_map(|(i, row)| row.iter().skip(i + 1))
            .map(|x| x * x)
            .sum::<f64>()
            .sqrt();
        if off <= tol {
            break;
        }
        for p in 0..nn {
            for q in (p + 1)..nn {
                let dpq = d[p][q];
                if dpq.abs() < f64::MIN_POSITIVE {
                    continue;
                }
                let (dpp, dqq) = (d[p][p], d[q][q]);
                // Angle that zeroes d[p][q]: tan(2θ) = 2·d_pq / (d_pp − d_qq), with |θ| ≤ π/4.
                // When d_pp == d_qq the ratio is ±∞ and atan yields ±π/2, i.e. θ = ±π/4,
                // which is the correct limit. No absolute threshold is needed.
                let theta = 0.5 * (2.0 * dpq / (dpp - dqq)).atan();
                let (s, c) = theta.sin_cos();
                // D ← Jᵀ·D·J in place: only rows/columns p and q change, and each off-pivot
                // entry read below has not yet been written in this rotation.
                for i in 0..nn {
                    if i == p || i == q {
                        continue;
                    }
                    let (dip, diq) = (d[i][p], d[i][q]);
                    let new_ip = c * dip + s * diq;
                    let new_iq = c * diq - s * dip;
                    d[i][p] = new_ip;
                    d[p][i] = new_ip;
                    d[i][q] = new_iq;
                    d[q][i] = new_iq;
                }
                d[p][p] = c * c * dpp + 2.0 * s * c * dpq + s * s * dqq;
                d[q][q] = s * s * dpp - 2.0 * s * c * dpq + c * c * dqq;
                d[p][q] = 0.0;
                d[q][p] = 0.0;
                // Accumulate the eigenvectors: V ← V·J.
                for row in v.iter_mut() {
                    let (vp, vq) = (row[p], row[q]);
                    row[p] = c * vp + s * vq;
                    row[q] = c * vq - s * vp;
                }
            }
        }
    }
    let eigenvalues = (0..nn).map(|i| d[i][i]).collect();
    (eigenvalues, v)
}

/// Removes from `v` its components along every vector in `basis` and returns the Euclidean
/// norm of what remains.
///
/// `basis` is assumed to hold mutually orthogonal unit vectors of the same length as `v`.
fn orthogonalize_against(v: &mut [f64], basis: &[Vec<f64>]) -> f64 {
    // Two passes of modified Gram–Schmidt ("twice is enough") keep the result orthogonal to
    // working precision even when `v` starts almost inside span(basis).
    for _ in 0..2 {
        for b in basis {
            let proj: f64 = v.iter().zip(b).map(|(x, y)| x * y).sum();
            for (x, y) in v.iter_mut().zip(b) {
                *x -= proj * y;
            }
        }
    }
    v.iter().map(|x| x * x).sum::<f64>().sqrt()
}

/// Returns a unit vector of length `dim` orthogonal to every vector in `basis`.
///
/// The vector is built by Gram–Schmidt from the standard basis vectors `e_cursor`,
/// `e_cursor+1`, …, and `cursor` advances past every candidate consumed. Candidates whose
/// residual norm is at most `0.5/√dim` are skipped. A rejected candidate can only shrink
/// further as `basis` grows, so `cursor` never needs to move backwards.
///
/// While `basis` is orthonormal with fewer than `dim` vectors, the rejected candidates
/// account for less than `0.25` of the `dim − basis.len() ≥ 1` total residual mass. Hence
/// some remaining candidate passes the threshold. A zero vector is returned only if that
/// guarantee is broken by round-off.
fn next_orthonormal_complement(basis: &[Vec<f64>], dim: usize, cursor: &mut usize) -> Vec<f64> {
    let min_residual = 0.5 / (dim as f64).sqrt();
    while *cursor < dim {
        let mut candidate = vec![0.0; dim];
        candidate[*cursor] = 1.0;
        *cursor += 1;
        let norm = orthogonalize_against(&mut candidate, basis);
        if norm > min_residual {
            candidate.iter_mut().for_each(|x| *x /= norm);
            return candidate;
        }
    }
    vec![0.0; dim]
}
/// Orthogonal Procrustes alignment.
///
/// Let the rows of `X` be `source_anchors` and the rows of `Y` the matching
/// `target_anchors`. With the SVD `XᵀY = U·Σ·Vᵀ`, the returned map is `W = U·Vᵀ`, which is
/// applied as `x · W` by [`translate_embedding`].
///
/// # Errors
///
/// Returns [`TextError::InvalidInput`] when:
/// - either anchor set is empty;
/// - the source and target dimensionalities differ;
/// - the two sets contain different numbers of anchors;
/// - a set contains rows of inconsistent length.
///
/// Without these checks, the matrix products below would index out of bounds. Errors from
/// the SVD are propagated.
fn procrustes_align(
    source_anchors: &[Vec<f64>],
    target_anchors: &[Vec<f64>],
) -> Result<AlignmentMatrix> {
    if source_anchors.is_empty() || target_anchors.is_empty() {
        return Err(TextError::InvalidInput("Empty anchor sets".to_string()));
    }
    let dim_s = source_anchors[0].len();
    let dim_t = target_anchors[0].len();
    if dim_s != dim_t {
        return Err(TextError::InvalidInput(format!(
            "Procrustes requires same dimensionality, got {} vs {}",
            dim_s, dim_t
        )));
    }
    if source_anchors.len() != target_anchors.len() {
        return Err(TextError::InvalidInput(format!(
            "Anchor sets must have the same size, got {} vs {}",
            source_anchors.len(),
            target_anchors.len()
        )));
    }
    for (label, rows) in [("Source", source_anchors), ("Target", target_anchors)] {
        if let Some(i) = rows.iter().position(|r| r.len() != dim_s) {
            return Err(TextError::InvalidInput(format!(
                "{label} anchor {i} has dimension {}, expected {dim_s}",
                rows[i].len()
            )));
        }
    }

    // Cross-covariance XᵀY (dim × dim).
    let cross = matmul(&transpose(source_anchors), target_anchors);
    let (u, _sigma, vt) = svd_jacobi(&cross)?;
    let w = matmul(&u, &vt);
    Ok(AlignmentMatrix {
        w,
        rows: dim_s,
        cols: dim_t,
        method: AlignmentMethod::Procrustes,
    })
}
fn cca_align(source_anchors: &[Vec<f64>], target_anchors: &[Vec<f64>]) -> Result<AlignmentMatrix> {
let n = source_anchors.len();
if n == 0 {
return Err(TextError::InvalidInput("Empty anchor sets".to_string()));
}
let dim_s = source_anchors[0].len();
let dim_t = target_anchors[0].len();
let mut src_mean = vec![0.0; dim_s];
for v in source_anchors {
for (i, &x) in v.iter().enumerate() {
src_mean[i] += x;
}
}
let nf = n as f64;
for v in &mut src_mean {
*v /= nf;
}
let centered_src: Vec<Vec<f64>> = source_anchors
.iter()
.map(|v| v.iter().zip(src_mean.iter()).map(|(x, m)| x - m).collect())
.collect();
let mut tgt_mean = vec![0.0; dim_t];
for v in target_anchors {
for (i, &x) in v.iter().enumerate() {
tgt_mean[i] += x;
}
}
for v in &mut tgt_mean {
*v /= nf;
}
let centered_tgt: Vec<Vec<f64>> = target_anchors
.iter()
.map(|v| v.iter().zip(tgt_mean.iter()).map(|(x, m)| x - m).collect())
.collect();
procrustes_align(¢ered_src, ¢ered_tgt)
}
fn muse_align(
source_anchors: &[Vec<f64>],
target_anchors: &[Vec<f64>],
iterations: usize,
) -> Result<AlignmentMatrix> {
let mut alignment = procrustes_align(source_anchors, target_anchors)?;
for _iter in 0..iterations {
let aligned: Vec<Vec<f64>> = source_anchors
.iter()
.map(|s| translate_embedding(s, &alignment))
.collect();
alignment = procrustes_align(&aligned, target_anchors)?;
}
Ok(alignment)
}
/// Learns a linear map from the `source` embedding space to the `target` space.
///
/// `anchors` lists `(source_index, target_index)` pairs of known translations. Their
/// embeddings are gathered in order and passed to the solver selected by
/// `config.alignment`. [`AlignmentMethod::MUSE`] additionally uses
/// `config.refinement_iterations`.
///
/// # Errors
///
/// Returns [`TextError::InvalidInput`] if `anchors`, `source` or `target` is empty, or if any
/// anchor index is out of bounds. Solver errors are propagated.
pub fn align_embeddings(
    source: &[Vec<f64>],
    target: &[Vec<f64>],
    anchors: &[(usize, usize)],
    config: &CrossLingualConfig,
) -> Result<AlignmentMatrix> {
    if anchors.is_empty() {
        return Err(TextError::InvalidInput(
            "Need at least one anchor pair".to_string(),
        ));
    }
    if source.is_empty() || target.is_empty() {
        return Err(TextError::InvalidInput(
            "Source and target embeddings must be non-empty".to_string(),
        ));
    }
    let mut src_anchors = Vec::with_capacity(anchors.len());
    let mut tgt_anchors = Vec::with_capacity(anchors.len());
    for &(si, ti) in anchors {
        let src = source.get(si).ok_or_else(|| {
            TextError::InvalidInput(format!(
                "Source anchor index {si} out of bounds (len={})",
                source.len()
            ))
        })?;
        let tgt = target.get(ti).ok_or_else(|| {
            TextError::InvalidInput(format!(
                "Target anchor index {ti} out of bounds (len={})",
                target.len()
            ))
        })?;
        src_anchors.push(src.clone());
        tgt_anchors.push(tgt.clone());
    }
    // Exhaustive on purpose (the enum is defined in this crate): a new variant must force a
    // decision here instead of silently falling back to Procrustes.
    match &config.alignment {
        AlignmentMethod::Procrustes => procrustes_align(&src_anchors, &tgt_anchors),
        AlignmentMethod::CCA => cca_align(&src_anchors, &tgt_anchors),
        AlignmentMethod::MUSE => {
            muse_align(&src_anchors, &tgt_anchors, config.refinement_iterations)
        }
    }
}
/// Maps `embedding` into the target space by the row-vector product `embedding · w`.
///
/// The result always has `alignment.cols` entries. Only the first
/// `min(alignment.rows, embedding.len())` components of `embedding` contribute; missing
/// components are treated as zero and extra ones are ignored.
pub fn translate_embedding(embedding: &[f64], alignment: &AlignmentMatrix) -> Vec<f64> {
    let used = &embedding[..alignment.rows.min(embedding.len())];
    (0..alignment.cols)
        .map(|col| {
            used.iter()
                .enumerate()
                .fold(0.0, |acc, (row, &x)| acc + x * alignment.w[row][col])
        })
        .collect()
}
/// Applies [`translate_embedding`] to every embedding in `embeddings`, preserving order.
pub fn translate_batch(embeddings: &[Vec<f64>], alignment: &AlignmentMatrix) -> Vec<Vec<f64>> {
    let mut translated = Vec::with_capacity(embeddings.len());
    for embedding in embeddings {
        translated.push(translate_embedding(embedding, alignment));
    }
    translated
}
/// Mean cosine similarity between each translated source anchor and its target anchor.
///
/// Anchor pairs with an out-of-range index on either side are skipped. Returns `0.0` when no
/// pair is usable (including when `anchors` is empty).
pub fn alignment_quality(
    source: &[Vec<f64>],
    target: &[Vec<f64>],
    anchors: &[(usize, usize)],
    alignment: &AlignmentMatrix,
) -> f64 {
    let (total, count) = anchors
        .iter()
        .filter_map(|&(si, ti)| {
            let src = source.get(si)?;
            let tgt = target.get(ti)?;
            Some(cosine_sim_local(&translate_embedding(src, alignment), tgt))
        })
        .fold((0.0, 0usize), |(sum, n), sim| (sum + sim, n + 1));
    match count {
        0 => 0.0,
        n => total / n as f64,
    }
}
/// Cosine similarity of `a` and `b`.
///
/// The dot product runs over the common prefix, while each norm uses its full vector.
/// Returns `0.0` if either norm is below `1e-15`.
fn cosine_sim_local(a: &[f64], b: &[f64]) -> f64 {
    let l2 = |v: &[f64]| v.iter().map(|x| x * x).sum::<f64>().sqrt();
    let (norm_a, norm_b) = (l2(a), l2(b));
    if norm_a < 1e-15 || norm_b < 1e-15 {
        return 0.0;
    }
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f64>();
    dot / (norm_a * norm_b)
}
pub fn alignment_quality_score(
source: &[Vec<f64>],
target: &[Vec<f64>],
anchors: &[(usize, usize)],
alignment: &AlignmentMatrix,
) -> f64 {
if anchors.is_empty() {
return 0.0;
}
let mut total_sim = 0.0;
let mut count = 0;
for &(si, ti) in anchors {
if si < source.len() && ti < target.len() {
let aligned = translate_embedding(&source[si], alignment);
let sim = cosine_sim_local(&aligned, &target[ti]);
total_sim += sim;
count += 1;
}
}
if count == 0 {
0.0
} else {
total_sim / count as f64
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean distance over the common prefix of `a` and `b`.
    fn euclidean(a: &[f64], b: &[f64]) -> f64 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f64>()
            .sqrt()
    }

    /// Source/target pair related by a quarter turn: `[1,0] -> [0,1]`, `[0,1] -> [-1,0]`.
    fn quarter_turn() -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
        (
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![vec![0.0, 1.0], vec![-1.0, 0.0]],
        )
    }

    #[test]
    fn test_crosslingual_config_default() {
        let cfg = CrossLingualConfig::default();
        assert_eq!(cfg.alignment, AlignmentMethod::Procrustes);
        assert_eq!(cfg.refinement_iterations, 5);
    }

    #[test]
    fn test_procrustes_identity() {
        let source: Vec<Vec<f64>> = (0..3)
            .map(|i| (0..3).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
            .collect();
        let target = source.clone();
        let anchors: Vec<(usize, usize)> = (0..3).map(|i| (i, i)).collect();
        let result = align_embeddings(&source, &target, &anchors, &CrossLingualConfig::default());
        assert!(result.is_ok());
        let alignment = result.expect("should succeed");
        let dist = euclidean(&translate_embedding(&source[0], &alignment), &target[0]);
        assert!(
            dist < 0.1,
            "Identity alignment should preserve vectors, dist={dist}"
        );
    }

    #[test]
    fn test_procrustes_rotation() {
        let (source, target) = quarter_turn();
        let anchors = [(0, 0), (1, 1)];
        let alignment =
            align_embeddings(&source, &target, &anchors, &CrossLingualConfig::default())
                .expect("ok");
        let d0 = euclidean(&translate_embedding(&source[0], &alignment), &[0.0, 1.0]);
        assert!(d0 < 0.3, "Rotated [1,0] should be near [0,1], dist={d0}");
        let d1 = euclidean(&translate_embedding(&source[1], &alignment), &[-1.0, 0.0]);
        assert!(d1 < 0.3, "Rotated [0,1] should be near [-1,0], dist={d1}");
    }

    #[test]
    fn test_translation_preserves_relative_distances() {
        let source = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
        let target = vec![vec![0.0, 1.0], vec![-1.0, 0.0], vec![-1.0, 1.0]];
        let anchors = [(0, 0), (1, 1)];
        let alignment =
            align_embeddings(&source, &target, &anchors, &CrossLingualConfig::default())
                .expect("ok");
        let translated: Vec<Vec<f64>> = source
            .iter()
            .map(|s| translate_embedding(s, &alignment))
            .collect();

        let orig_dist_01 = euclidean(&source[0], &source[1]);
        let orig_dist_02 = euclidean(&source[0], &source[2]);
        let new_dist_01 = euclidean(&translated[0], &translated[1]);
        let new_dist_02 = euclidean(&translated[0], &translated[2]);
        assert!(
            (orig_dist_01 - new_dist_01).abs() < 0.3,
            "Distances should be preserved: {orig_dist_01} vs {new_dist_01}"
        );
        assert!(
            (orig_dist_02 - new_dist_02).abs() < 0.3,
            "Distances should be preserved: {orig_dist_02} vs {new_dist_02}"
        );
    }

    #[test]
    fn test_cca_alignment() {
        let (source, target) = quarter_turn();
        let anchors = [(0, 0), (1, 1)];
        let config = CrossLingualConfig {
            alignment: AlignmentMethod::CCA,
            ..Default::default()
        };
        assert!(align_embeddings(&source, &target, &anchors, &config).is_ok());
    }

    #[test]
    fn test_muse_alignment() {
        let (source, target) = quarter_turn();
        let anchors = [(0, 0), (1, 1)];
        let config = CrossLingualConfig {
            alignment: AlignmentMethod::MUSE,
            refinement_iterations: 3,
            ..Default::default()
        };
        assert!(align_embeddings(&source, &target, &anchors, &config).is_ok());
    }

    #[test]
    fn test_empty_anchors_error() {
        let source = vec![vec![1.0, 0.0]];
        let target = vec![vec![0.0, 1.0]];
        let result = align_embeddings(&source, &target, &[], &CrossLingualConfig::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_translate_batch() {
        let source = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let target = source.clone();
        let anchors = [(0, 0), (1, 1)];
        let alignment =
            align_embeddings(&source, &target, &anchors, &CrossLingualConfig::default())
                .expect("ok");
        let batch = translate_batch(&source, &alignment);
        assert_eq!(batch.len(), 2);
        assert_eq!(batch[0].len(), 2);
    }
}