use scirs2_core::ndarray::{Array1, Array2, Axis};
use crate::error::TextError;
/// Result alias returned by the alignment routines in this file; the error type is [`TextError`].
pub type AlignResult<T> = Result<T, TextError>;
/// Computes the orthogonal Procrustes alignment matrix between two paired embedding sets.
///
/// Row `k` of `x` is treated as the counterpart of row `k` of `y`. The
/// cross-covariance `M = xᵀ·y` is accumulated in `f64`, decomposed with
/// `scirs2_linalg::svd` into `U·Σ·Vᵀ`, and `W = U·Vᵀ` is returned cast back to
/// `f32`. `W` is meant to be applied on the right, i.e. `x.dot(&w)` is the
/// aligned version of `x`.
///
/// # Errors
///
/// * [`TextError::InvalidInput`] if `x` and `y` have different row counts, if
///   they have no rows, or if either has zero columns.
/// * [`TextError::EmbeddingError`] if the SVD routine reports a failure.
pub fn procrustes_align(x: &Array2<f32>, y: &Array2<f32>) -> AlignResult<Array2<f32>> {
    let n_x = x.nrows();
    let n_y = y.nrows();
    if n_x != n_y {
        return Err(TextError::InvalidInput(format!(
            "procrustes_align: row counts must match (x={n_x}, y={n_y})"
        )));
    }
    if n_x == 0 {
        return Err(TextError::InvalidInput(
            "procrustes_align: input matrices must not be empty".to_string(),
        ));
    }

    let d_src = x.ncols();
    let d_tgt = y.ncols();
    // A matrix with rows but no columns is just as empty as one with no rows;
    // reject it here instead of handing a degenerate matrix to the SVD.
    if d_src == 0 || d_tgt == 0 {
        return Err(TextError::InvalidInput(format!(
            "procrustes_align: embedding dimensions must be non-zero \
             (x has {d_src} columns, y has {d_tgt} columns)"
        )));
    }

    // M[i, j] = Σ_k x[k, i] · y[k, j], i.e. xᵀ·y. Each entry is accumulated in
    // f64, walking the paired rows in order, to limit rounding error before
    // the decomposition.
    let m_f64 = Array2::<f64>::from_shape_fn((d_src, d_tgt), |(i, j)| {
        x.column(i)
            .into_iter()
            .zip(y.column(j))
            .fold(0.0f64, |acc, (&a, &b)| acc + f64::from(a) * f64::from(b))
    });

    // NOTE(review): the `false` flag is presumably `full_matrices`; with reduced
    // factors `u.dot(&vt)` has shape `(d_src, d_tgt)` — confirm against the
    // scirs2_linalg documentation.
    let (u, _s, vt) = scirs2_linalg::svd(&m_f64.view(), false, None)
        .map_err(|e| TextError::EmbeddingError(format!("procrustes_align: SVD failed: {e}")))?;

    let w_f64: Array2<f64> = u.dot(&vt);
    Ok(w_f64.mapv(|v| v as f32))
}
/// Linear map that projects source-space embeddings into the target space.
///
/// `CrossLingualAligner::fit` fills the fields from a pair of embedding
/// matrices; `transform` and `transform_batch` apply `alignment_matrix` by
/// right-multiplication.
#[derive(Debug, Clone)]
pub struct CrossLingualAligner {
/// Matrix applied on the right of source embeddings (`src.dot(&alignment_matrix)`).
pub alignment_matrix: Array2<f32>,
/// Column count of the source embeddings given to `fit`; `transform` reshapes its input to `(1, d_src)`.
pub d_src: usize,
/// Column count of the target embeddings given to `fit`.
pub d_tgt: usize,
}
impl CrossLingualAligner {
    /// Learns the alignment from paired source/target embeddings.
    ///
    /// Row `k` of `src_embeddings` is paired with row `k` of `tgt_embeddings`.
    /// The alignment matrix comes from [`procrustes_align`]; `d_src` and
    /// `d_tgt` record the column counts of the two inputs.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`procrustes_align`].
    pub fn fit(src_embeddings: &Array2<f32>, tgt_embeddings: &Array2<f32>) -> AlignResult<Self> {
        let alignment_matrix = procrustes_align(src_embeddings, tgt_embeddings)?;
        Ok(Self {
            alignment_matrix,
            d_src: src_embeddings.ncols(),
            d_tgt: tgt_embeddings.ncols(),
        })
    }

    /// Projects a single source embedding through the alignment matrix.
    ///
    /// The vector is treated as a `1 × d_src` row, multiplied by
    /// `alignment_matrix`, and the single resulting row is returned.
    ///
    /// # Panics
    ///
    /// Panics if `src_embedding.len() != self.d_src`, or if `alignment_matrix`
    /// does not have `d_src` rows (possible only when the public fields were
    /// set inconsistently by hand).
    pub fn transform(&self, src_embedding: &Array1<f32>) -> Array1<f32> {
        // Fail with a message that names the actual problem; the reshape below
        // would otherwise panic with a generic shape error.
        assert_eq!(
            src_embedding.len(),
            self.d_src,
            "CrossLingualAligner::transform: expected a source embedding of length {}, got {}",
            self.d_src,
            src_embedding.len()
        );
        src_embedding
            .to_shape((1, self.d_src))
            .expect("length was checked above, so the (1, d_src) reshape cannot fail")
            .dot(&self.alignment_matrix)
            // The product has exactly one row; take it without another copy.
            .index_axis_move(Axis(0), 0)
    }

    /// Projects every row of `src_embeddings` through the alignment matrix.
    ///
    /// # Panics
    ///
    /// Panics (inside ndarray's `dot`) if the column count of `src_embeddings`
    /// does not match the row count of `alignment_matrix`.
    pub fn transform_batch(&self, src_embeddings: &Array2<f32>) -> Array2<f32> {
        src_embeddings.dot(&self.alignment_matrix)
    }
}
/// Pairs a borrowed token encoder with a borrowed [`CrossLingualAligner`] so that
/// encoder output is passed through the aligner (see `encode`).
pub struct AlignedEncoder<'a, F>
where
F: Fn(&[usize]) -> Array1<f32>,
{
/// Callable that maps a slice of `usize` tokens to an embedding vector.
pub base_encoder: &'a F,
/// Aligner whose `transform` is applied to the base encoder's output.
pub aligner: &'a CrossLingualAligner,
/// When `true`, `encode` rescales the aligned vector to unit L2 norm (unless the norm is ≤ 1e-12).
pub normalize_output: bool,
}
impl<'a, F> AlignedEncoder<'a, F>
where
F: Fn(&[usize]) -> Array1<f32>,
{
pub fn new(
base_encoder: &'a F,
aligner: &'a CrossLingualAligner,
normalize_output: bool,
) -> Self {
AlignedEncoder {
base_encoder,
aligner,
normalize_output,
}
}
pub fn encode(&self, tokens: &[usize]) -> Array1<f32> {
let base = (self.base_encoder)(tokens);
let aligned = self.aligner.transform(&base);
if self.normalize_output {
let norm: f32 = aligned.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 1e-12 {
aligned.mapv(|x| x / norm)
} else {
aligned
}
} else {
aligned
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Deterministic pseudo-random value in roughly `[-1, 1]` derived from
    /// `seed + offset` with a single linear-congruential step.
    fn lcg_f32(seed: u64, offset: u64) -> f32 {
        const A: u64 = 6_364_136_223_846_793_005;
        const C: u64 = 1_442_695_040_888_963_407;
        let state = A.wrapping_mul(seed.wrapping_add(offset)).wrapping_add(C);
        (((state >> 12) as f64) / ((1u64 << 52) as f64)) as f32 * 2.0 - 1.0
    }

    /// Reproducible `rows × cols` matrix; each entry is keyed by its flat index.
    fn rand_matrix(rows: usize, cols: usize, seed: u64) -> Array2<f32> {
        Array2::from_shape_fn((rows, cols), |(i, j)| lcg_f32(seed, (i * cols + j) as u64))
    }

    /// Rotates every row of a two-column matrix by `angle` radians.
    fn rotate_2d(x: &Array2<f32>, angle: f32) -> Array2<f32> {
        let (cos, sin) = (angle.cos(), angle.sin());
        Array2::from_shape_fn((x.nrows(), 2), |(i, j)| {
            if j == 0 {
                x[[i, 0]] * cos - x[[i, 1]] * sin
            } else {
                x[[i, 0]] * sin + x[[i, 1]] * cos
            }
        })
    }

    /// Largest element-wise absolute difference between two equally shaped matrices.
    ///
    /// NaN is propagated rather than skipped, so a NaN-filled result can never
    /// sneak under a `< tolerance` assertion.
    fn max_abs_diff(a: &Array2<f32>, b: &Array2<f32>) -> f32 {
        assert_eq!(a.dim(), b.dim(), "max_abs_diff: shape mismatch");
        a.iter()
            .zip(b.iter())
            .map(|(p, q)| (p - q).abs())
            .fold(0.0f32, |worst, err| {
                if err.is_nan() || err > worst {
                    err
                } else {
                    worst
                }
            })
    }

    /// Euclidean norm of a vector.
    fn l2_norm(v: &Array1<f32>) -> f32 {
        v.iter().map(|c| c * c).sum::<f32>().sqrt()
    }

    /// Toy encoder: sums the rows of a two-column `table` selected by `token % table.nrows()`.
    fn row_sum_encoder(table: &Array2<f32>) -> impl Fn(&[usize]) -> Array1<f32> + '_ {
        move |tokens: &[usize]| {
            let mut v = Array1::<f32>::zeros(2);
            for &t in tokens {
                let row = t % table.nrows();
                v[0] += table[[row, 0]];
                v[1] += table[[row, 1]];
            }
            v
        }
    }

    #[test]
    fn procrustes_aligns_rotated_copies_exactly() {
        let x = rand_matrix(6, 2, 1);
        let y = rotate_2d(&x, std::f32::consts::FRAC_PI_2);
        let w = procrustes_align(&x, &y).expect("procrustes should succeed");
        let max_err = max_abs_diff(&x.dot(&w), &y);
        assert!(
            max_err < 1e-4,
            "max element-wise error = {max_err}, expected < 1e-4"
        );
    }

    #[test]
    fn procrustes_identity_when_src_equals_tgt() {
        let x = rand_matrix(5, 3, 99);
        let w = procrustes_align(&x, &x).expect("procrustes should succeed");
        let xw = x.dot(&w);
        for ((i, j), &got) in xw.indexed_iter() {
            let want = x[[i, j]];
            assert!(
                (got - want).abs() < 1e-4,
                "xw[{i},{j}] = {got} ≠ x[{i},{j}] = {want}"
            );
        }
    }

    #[test]
    fn procrustes_fit_reduces_frobenius_distance() {
        let x = rand_matrix(8, 3, 42);
        let y = rand_matrix(8, 3, 77);
        let frobenius = |a: &Array2<f32>, b: &Array2<f32>| -> f32 {
            a.iter()
                .zip(b.iter())
                .map(|(ai, bi)| (ai - bi).powi(2))
                .sum::<f32>()
                .sqrt()
        };
        let before = frobenius(&x, &y);
        let aligner = CrossLingualAligner::fit(&x, &y).expect("fit should succeed");
        let after = frobenius(&aligner.transform_batch(&x), &y);
        assert!(
            after <= before + 1e-4,
            "||X·W - Y||_F = {after} should be ≤ ||X - Y||_F = {before}"
        );
    }

    #[test]
    fn aligned_encoder_preserves_approximate_norm() {
        let x = rand_matrix(5, 2, 10);
        let y = rotate_2d(&x, 0.5);
        let aligner = CrossLingualAligner::fit(&x, &y).expect("fit should succeed");
        let encoder = row_sum_encoder(&x);
        let enc = AlignedEncoder::new(&encoder, &aligner, false);
        for start in 0..4usize {
            let tokens = [start, start + 1];
            let norm_base = l2_norm(&encoder(&tokens));
            let norm_aligned = l2_norm(&enc.encode(&tokens));
            assert!(
                (norm_base - norm_aligned).abs() < 1e-4,
                "norms differ: base={norm_base}, aligned={norm_aligned}"
            );
        }
    }

    #[test]
    fn cross_lingual_transform_batch_equals_individual() {
        let src = rand_matrix(6, 3, 55);
        let tgt = rand_matrix(6, 3, 66);
        let aligner = CrossLingualAligner::fit(&src, &tgt).expect("fit");
        let batch_out = aligner.transform_batch(&src);
        for i in 0..6 {
            let row = src.index_axis(Axis(0), i).to_owned();
            let individual = aligner.transform(&row);
            let batch_row = batch_out.index_axis(Axis(0), i);
            for j in 0..3 {
                assert!(
                    (individual[j] - batch_row[j]).abs() < 1e-6,
                    "row {i} col {j}: individual={} batch={}",
                    individual[j],
                    batch_row[j]
                );
            }
        }
    }

    #[test]
    fn procrustes_mismatched_rows_returns_error() {
        let x = rand_matrix(4, 2, 1);
        let y = rand_matrix(3, 2, 2);
        // Pin the variant so the test cannot pass because of an unrelated failure.
        assert!(
            matches!(procrustes_align(&x, &y), Err(TextError::InvalidInput(_))),
            "mismatched row counts should yield TextError::InvalidInput"
        );
    }

    #[test]
    fn procrustes_empty_input_returns_error() {
        let empty = Array2::<f32>::zeros((0, 3));
        assert!(
            matches!(
                procrustes_align(&empty, &empty),
                Err(TextError::InvalidInput(_))
            ),
            "zero-row inputs should yield TextError::InvalidInput"
        );
    }

    #[test]
    fn aligned_encoder_normalise_output_unit_norm() {
        let x = rand_matrix(4, 2, 7);
        let y = rotate_2d(&x, 0.3);
        let aligner = CrossLingualAligner::fit(&x, &y).expect("fit");
        let encoder = row_sum_encoder(&x);
        let enc = AlignedEncoder::new(&encoder, &aligner, true);
        let norm = l2_norm(&enc.encode(&[0, 1, 2]));
        assert!((norm - 1.0).abs() < 1e-5, "expected unit norm, got {norm}");
    }
}