Skip to main content

oxirs_embed/
procrustes_alignment.rs

//! # Procrustes Alignment for Embedding Spaces
//!
//! Implements orthogonal Procrustes analysis for aligning embedding spaces from
//! different models or languages. Given a set of anchor (seed) pairs, this module
//! finds the orthogonal matrix that minimizes the sum of squared distances
//! between the aligned source embeddings and the target embeddings. No
//! determinant constraint is imposed, so the result may be a reflection rather
//! than a proper rotation.
//!
//! ## Features
//!
//! - **Orthogonal Procrustes**: Find the optimal orthogonal (rotation) matrix via SVD
//! - **Translation alignment**: Optionally center embeddings before rotation
//! - **Alignment quality metrics**: Mean squared error, cosine similarity
//! - **Nearest-neighbor evaluation**: Precision@k for alignment quality
//! - **Batch transformation**: Apply learned alignment to new embeddings

16// ─────────────────────────────────────────────
17// Types
18// ─────────────────────────────────────────────
19
/// A single correspondence ("seed pair") between the two embedding spaces.
///
/// `source_idx` selects a row of the source embedding matrix and `target_idx`
/// selects the row of the target embedding matrix it should be aligned with.
#[derive(Debug, Clone)]
pub struct AnchorPair {
    /// Row index in the source embedding matrix.
    pub source_idx: usize,
    /// Row index in the target embedding matrix.
    pub target_idx: usize,
    /// Optional human-readable tag for the pair (for example an entity name).
    pub label: Option<String>,
}

impl AnchorPair {
    /// Builds an unlabeled pair linking source row `source_idx` to target row `target_idx`.
    pub fn new(source_idx: usize, target_idx: usize) -> Self {
        let label = None;
        Self {
            source_idx,
            target_idx,
            label,
        }
    }

    /// Returns this pair with its label set to `label` (replacing any previous one).
    pub fn with_label(self, label: impl Into<String>) -> Self {
        Self {
            label: Some(label.into()),
            ..self
        }
    }
}
45
/// Tuning knobs for the Procrustes aligner.
#[derive(Debug, Clone)]
pub struct ProcrustesConfig {
    /// Subtract the anchor centroid of each space before fitting the rotation.
    pub center: bool,
    /// Scale every (centered) anchor vector to unit length before fitting.
    pub normalize: bool,
    /// Small positive constant intended for numerical stabilisation.
    ///
    /// NOTE(review): the alignment code in this module does not currently read
    /// this value; it is kept for configuration compatibility.
    pub regularization: f64,
}

impl Default for ProcrustesConfig {
    /// Centering enabled, normalization disabled, `regularization = 1e-10`.
    fn default() -> Self {
        let regularization = 1e-10;
        Self {
            regularization,
            normalize: false,
            center: true,
        }
    }
}
66
/// A fitted alignment mapping source vectors `x` to `R (x - source_centroid) + target_centroid`.
#[derive(Debug, Clone)]
pub struct ProcrustesResult {
    /// `dim x dim` matrix `R`, applied to column vectors (`y_i = sum_j R[i][j] * x_j`).
    pub rotation_matrix: Vec<Vec<f64>>,
    /// Mean of the source anchors (all zeros when centering is disabled).
    pub source_centroid: Vec<f64>,
    /// Mean of the target anchors (all zeros when centering is disabled).
    pub target_centroid: Vec<f64>,
    /// Mean squared error over the (prepared) anchor pairs after alignment.
    pub mse: f64,
    /// Mean cosine similarity over the (prepared) anchor pairs after alignment.
    pub mean_cosine_similarity: f64,
    /// Embedding dimensionality the alignment was fitted for.
    pub dim: usize,
}

impl ProcrustesResult {
    /// Maps one source-space vector into the target space.
    ///
    /// Only the first `dim` coordinates of `embedding` are used, and missing
    /// coordinates count as `0.0`. The output always has length `dim`.
    pub fn transform(&self, embedding: &[f64]) -> Vec<f64> {
        let dim = self.dim;

        // Move into the source frame (relative to the source centroid).
        let shifted: Vec<f64> = (0..dim)
            .map(|i| embedding.get(i).copied().unwrap_or(0.0) - self.source_centroid[i])
            .collect();

        // Row-by-row matrix/vector product R * shifted.
        let mut mapped: Vec<f64> = (0..dim)
            .map(|row| {
                shifted
                    .iter()
                    .enumerate()
                    .fold(0.0, |acc, (col, &x)| acc + self.rotation_matrix[row][col] * x)
            })
            .collect();

        // Move into the target frame.
        for (i, value) in mapped.iter_mut().enumerate() {
            *value += self.target_centroid[i];
        }
        mapped
    }

    /// Maps every vector of `embeddings` with [`Self::transform`], preserving order.
    pub fn transform_batch(&self, embeddings: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let mut mapped = Vec::with_capacity(embeddings.len());
        for row in embeddings {
            mapped.push(self.transform(row));
        }
        mapped
    }
}
114
/// Summary statistics describing how well an alignment maps evaluation pairs.
#[derive(Debug, Clone)]
pub struct AlignmentMetrics {
    /// Average squared Euclidean distance between each mapped source vector and its target.
    pub mse: f64,
    /// Average cosine similarity between each mapped source vector and its target.
    pub mean_cosine_similarity: f64,
    /// Share of pairs whose true target is the single nearest target vector
    /// (Euclidean distance) to the mapped source vector.
    pub precision_at_1: f64,
    /// Share of pairs whose true target is among the 5 nearest target vectors.
    pub precision_at_5: f64,
    /// Share of pairs whose true target is among the 10 nearest target vectors.
    pub precision_at_10: f64,
    /// Number of evaluation pairs the averages were computed over.
    pub eval_pairs: usize,
}
132
133// ─────────────────────────────────────────────
134// ProcrustesAligner
135// ─────────────────────────────────────────────
136
137/// Procrustes alignment for embedding spaces.
138pub struct ProcrustesAligner {
139    config: ProcrustesConfig,
140}
141
142impl ProcrustesAligner {
143    /// Create a new aligner with default configuration.
144    pub fn new() -> Self {
145        Self {
146            config: ProcrustesConfig::default(),
147        }
148    }
149
150    /// Create a new aligner with the given configuration.
151    pub fn with_config(config: ProcrustesConfig) -> Self {
152        Self { config }
153    }
154
155    /// Compute the optimal alignment.
156    ///
157    /// `source_embeddings`: rows of source embedding matrix (n x dim)
158    /// `target_embeddings`: rows of target embedding matrix (m x dim)
159    /// `anchors`: pairs mapping source indices to target indices
160    pub fn align(
161        &self,
162        source_embeddings: &[Vec<f64>],
163        target_embeddings: &[Vec<f64>],
164        anchors: &[AnchorPair],
165    ) -> Result<ProcrustesResult, ProcrustesError> {
166        if anchors.is_empty() {
167            return Err(ProcrustesError::NoAnchors);
168        }
169
170        // Validate anchors
171        for anchor in anchors {
172            if anchor.source_idx >= source_embeddings.len() {
173                return Err(ProcrustesError::InvalidIndex {
174                    which: "source",
175                    idx: anchor.source_idx,
176                    len: source_embeddings.len(),
177                });
178            }
179            if anchor.target_idx >= target_embeddings.len() {
180                return Err(ProcrustesError::InvalidIndex {
181                    which: "target",
182                    idx: anchor.target_idx,
183                    len: target_embeddings.len(),
184                });
185            }
186        }
187
188        // Determine dimensionality
189        let dim = source_embeddings.first().map(|v| v.len()).unwrap_or(0);
190        if dim == 0 {
191            return Err(ProcrustesError::EmptyEmbeddings);
192        }
193
194        // Extract anchor subsets
195        let src_anchors: Vec<Vec<f64>> = anchors
196            .iter()
197            .map(|a| source_embeddings[a.source_idx].clone())
198            .collect();
199        let tgt_anchors: Vec<Vec<f64>> = anchors
200            .iter()
201            .map(|a| target_embeddings[a.target_idx].clone())
202            .collect();
203
204        // Compute centroids
205        let source_centroid = if self.config.center {
206            compute_centroid(&src_anchors, dim)
207        } else {
208            vec![0.0; dim]
209        };
210        let target_centroid = if self.config.center {
211            compute_centroid(&tgt_anchors, dim)
212        } else {
213            vec![0.0; dim]
214        };
215
216        // Center the anchor embeddings
217        let src_centered = center_embeddings(&src_anchors, &source_centroid);
218        let tgt_centered = center_embeddings(&tgt_anchors, &target_centroid);
219
220        // Optionally normalize
221        let src_final = if self.config.normalize {
222            normalize_rows(&src_centered)
223        } else {
224            src_centered
225        };
226        let tgt_final = if self.config.normalize {
227            normalize_rows(&tgt_centered)
228        } else {
229            tgt_centered
230        };
231
232        // Compute M = X^T Y (cross-covariance matrix)
233        let m_matrix = cross_covariance(&src_final, &tgt_final, dim);
234
235        // SVD of M: M = U S V^T => W = V U^T
236        let (u, _s, vt) = svd(&m_matrix, dim)?;
237
238        // Rotation W = V^T^T * U^T = V * U^T
239        // Actually: W = V * U^T, where V = Vt^T
240        let v = transpose(&vt, dim);
241        let ut = transpose(&u, dim);
242        let rotation = mat_mul(&v, &ut, dim);
243
244        // Compute MSE and cosine similarity on anchors
245        let mse = compute_mse(&src_final, &tgt_final, &rotation, dim);
246        let mean_cos = compute_mean_cosine(&src_final, &tgt_final, &rotation, dim);
247
248        Ok(ProcrustesResult {
249            rotation_matrix: rotation,
250            source_centroid,
251            target_centroid,
252            mse,
253            mean_cosine_similarity: mean_cos,
254            dim,
255        })
256    }
257
258    /// Evaluate alignment quality using held-out pairs.
259    pub fn evaluate(
260        &self,
261        result: &ProcrustesResult,
262        source_embeddings: &[Vec<f64>],
263        target_embeddings: &[Vec<f64>],
264        eval_pairs: &[AnchorPair],
265    ) -> AlignmentMetrics {
266        if eval_pairs.is_empty() {
267            return AlignmentMetrics {
268                mse: 0.0,
269                mean_cosine_similarity: 0.0,
270                precision_at_1: 0.0,
271                precision_at_5: 0.0,
272                precision_at_10: 0.0,
273                eval_pairs: 0,
274            };
275        }
276
277        let mut total_se = 0.0;
278        let mut total_cos = 0.0;
279        let mut correct_at_1 = 0usize;
280        let mut correct_at_5 = 0usize;
281        let mut correct_at_10 = 0usize;
282
283        for pair in eval_pairs {
284            if pair.source_idx >= source_embeddings.len()
285                || pair.target_idx >= target_embeddings.len()
286            {
287                continue;
288            }
289            let transformed = result.transform(&source_embeddings[pair.source_idx]);
290            let target = &target_embeddings[pair.target_idx];
291
292            // MSE
293            let se: f64 = transformed
294                .iter()
295                .zip(target.iter())
296                .map(|(a, b)| (a - b).powi(2))
297                .sum();
298            total_se += se;
299
300            // Cosine similarity
301            let cos = cosine_sim(&transformed, target);
302            total_cos += cos;
303
304            // Find nearest neighbors in target space
305            let neighbors = find_nearest_neighbors(&transformed, target_embeddings, 10);
306            if neighbors.first().copied() == Some(pair.target_idx) {
307                correct_at_1 += 1;
308            }
309            if neighbors.iter().take(5).any(|&idx| idx == pair.target_idx) {
310                correct_at_5 += 1;
311            }
312            if neighbors.iter().take(10).any(|&idx| idx == pair.target_idx) {
313                correct_at_10 += 1;
314            }
315        }
316
317        let n = eval_pairs.len() as f64;
318        AlignmentMetrics {
319            mse: total_se / n,
320            mean_cosine_similarity: total_cos / n,
321            precision_at_1: correct_at_1 as f64 / n,
322            precision_at_5: correct_at_5 as f64 / n,
323            precision_at_10: correct_at_10 as f64 / n,
324            eval_pairs: eval_pairs.len(),
325        }
326    }
327}
328
329impl Default for ProcrustesAligner {
330    fn default() -> Self {
331        Self::new()
332    }
333}
334
/// Ways in which fitting a Procrustes alignment can fail.
#[derive(Debug, Clone)]
pub enum ProcrustesError {
    /// The anchor list was empty.
    NoAnchors,
    /// The embeddings have no usable columns.
    EmptyEmbeddings,
    /// An anchor referenced a row outside the source or target matrix.
    InvalidIndex {
        /// Which side the bad index belongs to (the aligner uses `"source"` / `"target"`).
        which: &'static str,
        /// The offending index.
        idx: usize,
        /// Number of rows actually available on that side.
        len: usize,
    },
    /// The singular value decomposition could not be computed.
    SvdFailed(String),
}

impl std::fmt::Display for ProcrustesError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NoAnchors => f.write_str("no anchor pairs provided"),
            Self::EmptyEmbeddings => f.write_str("embeddings are empty"),
            Self::InvalidIndex { which, idx, len } => {
                write!(f, "invalid {which} index {idx} (length {len})")
            }
            Self::SvdFailed(msg) => write!(f, "SVD failed: {msg}"),
        }
    }
}

impl std::error::Error for ProcrustesError {}
362
363// ─────────────────────────────────────────────
364// Linear algebra helpers
365// ─────────────────────────────────────────────
366
/// Returns the coordinate-wise mean of `embeddings`, truncated/padded to `dim`.
///
/// Coordinates a row lacks contribute nothing to the sum but the row still
/// counts in the divisor. An empty input yields a zero vector.
fn compute_centroid(embeddings: &[Vec<f64>], dim: usize) -> Vec<f64> {
    if embeddings.is_empty() {
        return vec![0.0; dim];
    }
    let count = embeddings.len() as f64;
    let sums = embeddings.iter().fold(vec![0.0; dim], |mut acc, row| {
        for (slot, &x) in acc.iter_mut().zip(row) {
            *slot += x;
        }
        acc
    });
    sums.into_iter().map(|total| total / count).collect()
}
383
/// Subtracts `centroid` from every row of `embeddings`.
///
/// Each output row keeps the length of its input row; coordinates beyond the
/// centroid's length are left unshifted.
fn center_embeddings(embeddings: &[Vec<f64>], centroid: &[f64]) -> Vec<Vec<f64>> {
    let mut centered: Vec<Vec<f64>> = Vec::with_capacity(embeddings.len());
    for row in embeddings {
        let shifted: Vec<f64> = row
            .iter()
            .zip(0..)
            .map(|(&x, i)| x - centroid.get(i).copied().unwrap_or(0.0))
            .collect();
        centered.push(shifted);
    }
    centered
}
395
/// Scales each row to unit Euclidean length.
///
/// Rows whose norm is below `1e-12` are copied unchanged to avoid dividing by
/// (almost) zero.
fn normalize_rows(embeddings: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let mut unit_rows: Vec<Vec<f64>> = Vec::with_capacity(embeddings.len());
    for row in embeddings {
        let length = row.iter().map(|x| x * x).sum::<f64>().sqrt();
        let scaled: Vec<f64> = if length < 1e-12 {
            row.to_vec()
        } else {
            row.iter().map(|x| x / length).collect()
        };
        unit_rows.push(scaled);
    }
    unit_rows
}
409
/// Accumulates `M = X^T Y` for paired rows of `src` (X) and `tgt` (Y).
///
/// `M[i][j] = sum_k src[k][i] * tgt[k][j]` over the first `min(src.len(), tgt.len())`
/// rows. Missing coordinates count as `0.0` and coordinates past `dim` are ignored.
fn cross_covariance(src: &[Vec<f64>], tgt: &[Vec<f64>], dim: usize) -> Vec<Vec<f64>> {
    fn coord(row: &[f64], i: usize) -> f64 {
        row.get(i).copied().unwrap_or(0.0)
    }

    let mut m = vec![vec![0.0; dim]; dim];
    for (s_row, t_row) in src.iter().zip(tgt) {
        for (i, m_row) in m.iter_mut().enumerate() {
            let si = coord(s_row, i);
            for (j, cell) in m_row.iter_mut().enumerate() {
                *cell += si * coord(t_row, j);
            }
        }
    }
    m
}
424
425/// Result type for SVD decomposition: (U, singular_values, V^T).
426type SvdResult = (Vec<Vec<f64>>, Vec<f64>, Vec<Vec<f64>>);
427
428/// Simple Jacobi-style SVD for small matrices.
429/// Returns (U, singular_values, V^T) for a dim x dim matrix.
430fn svd(matrix: &[Vec<f64>], dim: usize) -> Result<SvdResult, ProcrustesError> {
431    // Compute A^T A
432    let ata = mat_mul(&transpose(matrix, dim), matrix, dim);
433
434    // Eigendecomposition of A^T A via Jacobi iteration
435    let (eigenvalues, eigenvectors) = jacobi_eigendecomposition(&ata, dim, 200)?;
436
437    // Singular values = sqrt(eigenvalues)
438    let mut singular_values: Vec<f64> = eigenvalues
439        .iter()
440        .map(|&ev| if ev > 0.0 { ev.sqrt() } else { 0.0 })
441        .collect();
442
443    // V = eigenvectors (columns), V^T = transpose
444    let vt = transpose(&eigenvectors, dim);
445
446    // U = A * V * S^{-1}
447    let av = mat_mul(matrix, &eigenvectors, dim);
448    let mut u = vec![vec![0.0; dim]; dim];
449    for i in 0..dim {
450        for j in 0..dim {
451            if singular_values[j].abs() > 1e-12 {
452                u[i][j] = av[i][j] / singular_values[j];
453            }
454        }
455    }
456
457    // Sort by descending singular value
458    let mut indices: Vec<usize> = (0..dim).collect();
459    indices.sort_by(|&a, &b| {
460        singular_values[b]
461            .partial_cmp(&singular_values[a])
462            .unwrap_or(std::cmp::Ordering::Equal)
463    });
464
465    let sorted_s: Vec<f64> = indices.iter().map(|&i| singular_values[i]).collect();
466    let sorted_u: Vec<Vec<f64>> = (0..dim)
467        .map(|row| indices.iter().map(|&col| u[row][col]).collect())
468        .collect();
469    let sorted_vt: Vec<Vec<f64>> = indices.iter().map(|&i| vt[i].clone()).collect();
470
471    singular_values = sorted_s;
472
473    Ok((sorted_u, singular_values, sorted_vt))
474}
475
476fn jacobi_eigendecomposition(
477    matrix: &[Vec<f64>],
478    dim: usize,
479    max_iter: usize,
480) -> Result<(Vec<f64>, Vec<Vec<f64>>), ProcrustesError> {
481    let mut a: Vec<Vec<f64>> = matrix.to_vec();
482    let mut v: Vec<Vec<f64>> = (0..dim)
483        .map(|i| (0..dim).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
484        .collect();
485
486    for _ in 0..max_iter {
487        // Find largest off-diagonal element
488        let mut max_val = 0.0f64;
489        let mut p = 0;
490        let mut q = 1;
491        for (i, a_row) in a.iter().enumerate().take(dim) {
492            for (j, a_val) in a_row.iter().enumerate().take(dim).skip(i + 1) {
493                if a_val.abs() > max_val {
494                    max_val = a_val.abs();
495                    p = i;
496                    q = j;
497                }
498            }
499        }
500
501        if max_val < 1e-12 {
502            break;
503        }
504
505        // Compute rotation angle
506        let theta = if (a[p][p] - a[q][q]).abs() < 1e-15 {
507            std::f64::consts::FRAC_PI_4
508        } else {
509            0.5 * (2.0 * a[p][q] / (a[p][p] - a[q][q])).atan()
510        };
511
512        let cos_t = theta.cos();
513        let sin_t = theta.sin();
514
515        // Apply Givens rotation
516        let mut new_a = a.clone();
517        for i in 0..dim {
518            new_a[i][p] = cos_t * a[i][p] + sin_t * a[i][q];
519            new_a[i][q] = -sin_t * a[i][p] + cos_t * a[i][q];
520        }
521        let a_tmp = new_a.clone();
522        for j in 0..dim {
523            new_a[p][j] = cos_t * a_tmp[p][j] + sin_t * a_tmp[q][j];
524            new_a[q][j] = -sin_t * a_tmp[p][j] + cos_t * a_tmp[q][j];
525        }
526        a = new_a;
527
528        // Update eigenvectors
529        let mut new_v = v.clone();
530        for i in 0..dim {
531            new_v[i][p] = cos_t * v[i][p] + sin_t * v[i][q];
532            new_v[i][q] = -sin_t * v[i][p] + cos_t * v[i][q];
533        }
534        v = new_v;
535    }
536
537    let eigenvalues: Vec<f64> = (0..dim).map(|i| a[i][i]).collect();
538    Ok((eigenvalues, v))
539}
540
/// Returns the `dim x dim` transpose of `matrix`.
///
/// Entries outside `matrix` (short rows or too few rows) read as `0.0`, and
/// anything beyond `dim` is ignored.
fn transpose(matrix: &[Vec<f64>], dim: usize) -> Vec<Vec<f64>> {
    (0..dim)
        .map(|col| {
            (0..dim)
                .map(|row| {
                    matrix
                        .get(row)
                        .and_then(|r| r.get(col))
                        .copied()
                        .unwrap_or(0.0)
                })
                .collect()
        })
        .collect()
}
550
/// Multiplies two `dim x dim` matrices, returning `a * b`.
///
/// Entries outside either matrix (short rows or too few rows) read as `0.0`.
fn mat_mul(a: &[Vec<f64>], b: &[Vec<f64>], dim: usize) -> Vec<Vec<f64>> {
    fn entry(m: &[Vec<f64>], r: usize, c: usize) -> f64 {
        m.get(r).and_then(|row| row.get(c)).copied().unwrap_or(0.0)
    }

    let mut result = vec![vec![0.0; dim]; dim];
    for (i, res_row) in result.iter_mut().enumerate() {
        for k in 0..dim {
            let aik = entry(a, i, k);
            // Only exact zeros can be skipped without changing the product. A
            // magnitude cut-off (e.g. |aik| < 1e-15) silently drops genuine
            // contributions from small-scale inputs such as covariances of tiny
            // embeddings.
            if aik == 0.0 {
                continue;
            }
            for (j, cell) in res_row.iter_mut().enumerate() {
                *cell += aik * entry(b, k, j);
            }
        }
    }
    result
}
567
/// Mean (over pairs) of the squared Euclidean error `|R s_k - t_k|^2`.
///
/// Uses the first `min(src.len(), tgt.len())` pairs. `rotation` is applied to
/// column vectors and missing coordinates read as `0.0`. Returns `0.0` when
/// there are no pairs.
fn compute_mse(src: &[Vec<f64>], tgt: &[Vec<f64>], rotation: &[Vec<f64>], dim: usize) -> f64 {
    let pairs = src.len().min(tgt.len());
    if pairs == 0 {
        return 0.0;
    }

    let total: f64 = src
        .iter()
        .zip(tgt)
        .map(|(s, t)| {
            let mut pair_error = 0.0;
            for i in 0..dim {
                let projected = rotation[i]
                    .iter()
                    .take(dim)
                    .enumerate()
                    .fold(0.0, |acc, (j, &r_ij)| {
                        acc + r_ij * s.get(j).copied().unwrap_or(0.0)
                    });
                pair_error += (projected - t.get(i).copied().unwrap_or(0.0)).powi(2);
            }
            pair_error
        })
        .sum();

    total / pairs as f64
}
590
591fn compute_mean_cosine(
592    src: &[Vec<f64>],
593    tgt: &[Vec<f64>],
594    rotation: &[Vec<f64>],
595    dim: usize,
596) -> f64 {
597    let n = src.len().min(tgt.len());
598    if n == 0 {
599        return 0.0;
600    }
601    let mut total = 0.0;
602    for k in 0..n {
603        let mut rotated = vec![0.0; dim];
604        for (i, rot_val) in rotated.iter_mut().enumerate().take(dim) {
605            for (j, &r_ij) in rotation[i].iter().enumerate().take(dim) {
606                *rot_val += r_ij * src[k].get(j).copied().unwrap_or(0.0);
607            }
608        }
609        total += cosine_sim(&rotated, &tgt[k]);
610    }
611    total / n as f64
612}
613
/// Cosine similarity of `a` and `b`.
///
/// The dot product runs over the shared prefix, while each norm uses the full
/// vector. Returns `0.0` if either norm is below `1e-12`.
fn cosine_sim(a: &[f64], b: &[f64]) -> f64 {
    let norm = |v: &[f64]| v.iter().map(|x| x * x).sum::<f64>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));
    if norm_a < 1e-12 || norm_b < 1e-12 {
        return 0.0;
    }
    let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot / (norm_a * norm_b)
}
624
/// Indices of the (at most) `k` candidates closest to `query` by squared
/// Euclidean distance over the shared prefix, nearest first.
///
/// The sort is stable, so equally distant candidates keep their input order.
/// Incomparable (NaN) distances are treated as equal.
fn find_nearest_neighbors(query: &[f64], candidates: &[Vec<f64>], k: usize) -> Vec<usize> {
    let mut scored: Vec<(usize, f64)> = Vec::with_capacity(candidates.len());
    for (idx, cand) in candidates.iter().enumerate() {
        let dist: f64 = query
            .iter()
            .zip(cand.iter())
            .map(|(q, c)| (q - c).powi(2))
            .sum();
        scored.push((idx, dist));
    }
    scored.sort_by(|x, y| x.1.partial_cmp(&y.1).unwrap_or(std::cmp::Ordering::Equal));
    scored.truncate(k);
    scored.into_iter().map(|(idx, _)| idx).collect()
}
641
642// ─────────────────────────────────────────────
643// Tests
644// ─────────────────────────────────────────────
645
/// Unit tests for the Procrustes aligner and its linear-algebra helpers.
#[cfg(test)]
mod tests {
    use super::*;

    /// Generate `n` deterministic pseudo-random embeddings of width `dim`.
    ///
    /// Uses a 64-bit linear congruential generator and maps its top 31 bits
    /// uniformly onto `[-0.5, 0.5)`. Dividing the 31-bit value by `u32::MAX`
    /// would only reach `[0, 0.5)`, making every coordinate negative.
    fn make_embeddings(n: usize, dim: usize, seed: u64) -> Vec<Vec<f64>> {
        const SCALE: f64 = (1u64 << 31) as f64;
        let mut state = seed;
        (0..n)
            .map(|_| {
                (0..dim)
                    .map(|_| {
                        state = state
                            .wrapping_mul(6364136223846793005)
                            .wrapping_add(1442695040888963407);
                        ((state >> 33) as f64) / SCALE - 0.5
                    })
                    .collect()
            })
            .collect()
    }

    /// Apply a known rotation to embeddings (90-degree rotation in 2D).
    fn rotate_90_2d(embeddings: &[Vec<f64>]) -> Vec<Vec<f64>> {
        embeddings
            .iter()
            .map(|e| {
                // [x, y] -> [-y, x]
                vec![-e[1], e[0]]
            })
            .collect()
    }

    // ═══ AnchorPair tests ════════════════════════════════

    #[test]
    fn test_anchor_pair_creation() {
        let pair = AnchorPair::new(0, 1);
        assert_eq!(pair.source_idx, 0);
        assert_eq!(pair.target_idx, 1);
        assert!(pair.label.is_none());
    }

    #[test]
    fn test_anchor_pair_with_label() {
        let pair = AnchorPair::new(0, 1).with_label("cat");
        assert_eq!(pair.label, Some("cat".to_string()));
    }

    // ═══ ProcrustesConfig tests ══════════════════════════

    #[test]
    fn test_default_config() {
        let config = ProcrustesConfig::default();
        assert!(config.center);
        assert!(!config.normalize);
        assert!(config.regularization > 0.0);
    }

    // ═══ Error tests ═════════════════════════════════════

    #[test]
    fn test_no_anchors_error() {
        let aligner = ProcrustesAligner::new();
        let src = make_embeddings(10, 3, 42);
        let tgt = make_embeddings(10, 3, 99);
        let result = aligner.align(&src, &tgt, &[]);
        assert!(matches!(result, Err(ProcrustesError::NoAnchors)));
    }

    #[test]
    fn test_invalid_source_index() {
        let aligner = ProcrustesAligner::new();
        let src = make_embeddings(5, 3, 42);
        let tgt = make_embeddings(5, 3, 99);
        let anchors = vec![AnchorPair::new(10, 0)]; // 10 >= 5
        let result = aligner.align(&src, &tgt, &anchors);
        assert!(matches!(
            result,
            Err(ProcrustesError::InvalidIndex {
                which: "source",
                idx: 10,
                len: 5
            })
        ));
    }

    #[test]
    fn test_invalid_target_index() {
        let aligner = ProcrustesAligner::new();
        let src = make_embeddings(5, 3, 42);
        let tgt = make_embeddings(5, 3, 99);
        let anchors = vec![AnchorPair::new(0, 10)]; // 10 >= 5
        let result = aligner.align(&src, &tgt, &anchors);
        assert!(matches!(
            result,
            Err(ProcrustesError::InvalidIndex {
                which: "target",
                idx: 10,
                len: 5
            })
        ));
    }

    #[test]
    fn test_empty_embeddings() {
        // With no rows at all, the anchor index check is what fails.
        let aligner = ProcrustesAligner::new();
        let src: Vec<Vec<f64>> = Vec::new();
        let tgt: Vec<Vec<f64>> = Vec::new();
        let anchors = vec![AnchorPair::new(0, 0)];
        let result = aligner.align(&src, &tgt, &anchors);
        assert!(matches!(
            result,
            Err(ProcrustesError::InvalidIndex {
                which: "source",
                ..
            })
        ));
    }

    #[test]
    fn test_zero_width_embeddings() {
        let aligner = ProcrustesAligner::new();
        let src: Vec<Vec<f64>> = vec![Vec::new(); 3];
        let tgt: Vec<Vec<f64>> = vec![Vec::new(); 3];
        let anchors = vec![AnchorPair::new(0, 0)];
        let result = aligner.align(&src, &tgt, &anchors);
        assert!(matches!(result, Err(ProcrustesError::EmptyEmbeddings)));
    }

    #[test]
    fn test_error_display() {
        let err = ProcrustesError::NoAnchors;
        assert!(format!("{err}").contains("anchor"));
    }

    // ═══ Alignment tests (2D rotation) ═══════════════════

    #[test]
    fn test_identity_alignment() {
        let aligner = ProcrustesAligner::new();
        let src = make_embeddings(20, 3, 42);
        let tgt = src.clone(); // identical
        let anchors: Vec<AnchorPair> = (0..10).map(|i| AnchorPair::new(i, i)).collect();
        let res = aligner
            .align(&src, &tgt, &anchors)
            .expect("alignment should succeed");
        assert!(res.mse < 1e-10, "MSE too high: {}", res.mse);
    }

    #[test]
    fn test_2d_rotation_alignment() {
        let src = make_embeddings(20, 2, 42);
        let tgt = rotate_90_2d(&src);
        let anchors: Vec<AnchorPair> = (0..10).map(|i| AnchorPair::new(i, i)).collect();

        let aligner = ProcrustesAligner::new();
        let res = aligner
            .align(&src, &tgt, &anchors)
            .expect("alignment should succeed");

        // The mapping is an exact rotation, so the fit is exact up to rounding.
        assert!(res.mse < 1e-12, "MSE too high: {}", res.mse);
        assert!(
            res.mean_cosine_similarity > 1.0 - 1e-9,
            "Cosine too low: {}",
            res.mean_cosine_similarity
        );

        // Held-out points (not used as anchors) must be rotated the same way.
        for (i, (s, t)) in src.iter().zip(&tgt).enumerate().skip(10) {
            let mapped = res.transform(s);
            for (a, b) in mapped.iter().zip(t) {
                assert!(
                    (a - b).abs() < 1e-9,
                    "held-out point {i} mis-mapped: {a} vs {b}"
                );
            }
        }
    }

    #[test]
    fn test_alignment_dim() {
        let src = make_embeddings(10, 5, 42);
        let tgt = make_embeddings(10, 5, 99);
        let anchors: Vec<AnchorPair> = (0..5).map(|i| AnchorPair::new(i, i)).collect();

        let aligner = ProcrustesAligner::new();
        let result = aligner.align(&src, &tgt, &anchors).expect("should align");
        assert_eq!(result.dim, 5);
        assert_eq!(result.rotation_matrix.len(), 5);
        assert_eq!(result.rotation_matrix[0].len(), 5);
    }

    // ═══ Transform tests ═════════════════════════════════

    #[test]
    fn test_transform_preserves_dim() {
        let src = make_embeddings(10, 4, 42);
        let tgt = make_embeddings(10, 4, 99);
        let anchors: Vec<AnchorPair> = (0..5).map(|i| AnchorPair::new(i, i)).collect();
        let aligner = ProcrustesAligner::new();
        let result = aligner.align(&src, &tgt, &anchors).expect("should align");

        let transformed = result.transform(&src[0]);
        assert_eq!(transformed.len(), 4);
    }

    #[test]
    fn test_transform_batch() {
        let src = make_embeddings(10, 3, 42);
        let tgt = make_embeddings(10, 3, 99);
        let anchors: Vec<AnchorPair> = (0..5).map(|i| AnchorPair::new(i, i)).collect();
        let aligner = ProcrustesAligner::new();
        let result = aligner.align(&src, &tgt, &anchors).expect("should align");

        let batch = result.transform_batch(&src);
        assert_eq!(batch.len(), 10);
    }

    // ═══ Evaluation tests ════════════════════════════════

    #[test]
    fn test_evaluate_identity() {
        let src = make_embeddings(20, 3, 42);
        let tgt = src.clone();
        let anchors: Vec<AnchorPair> = (0..10).map(|i| AnchorPair::new(i, i)).collect();
        let eval_pairs: Vec<AnchorPair> = (10..20).map(|i| AnchorPair::new(i, i)).collect();

        let aligner = ProcrustesAligner::new();
        let result = aligner.align(&src, &tgt, &anchors).expect("should align");
        let metrics = aligner.evaluate(&result, &src, &tgt, &eval_pairs);

        assert_eq!(metrics.eval_pairs, 10);
        assert!(metrics.mse < 1e-4);
        // Every held-out point must map onto itself.
        assert!((metrics.precision_at_1 - 1.0).abs() < 1e-12);
        assert!((metrics.precision_at_10 - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_evaluate_empty() {
        let src = make_embeddings(10, 3, 42);
        let tgt = make_embeddings(10, 3, 99);
        let anchors: Vec<AnchorPair> = (0..5).map(|i| AnchorPair::new(i, i)).collect();

        let aligner = ProcrustesAligner::new();
        let result = aligner.align(&src, &tgt, &anchors).expect("should align");
        let metrics = aligner.evaluate(&result, &src, &tgt, &[]);
        assert_eq!(metrics.eval_pairs, 0);
    }

    // ═══ Cosine similarity helper tests ══════════════════

    #[test]
    fn test_cosine_sim_identical() {
        let a = vec![1.0, 2.0, 3.0];
        let sim = cosine_sim(&a, &a);
        assert!((sim - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_cosine_sim_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_sim(&a, &b);
        assert!(sim.abs() < 1e-10);
    }

    #[test]
    fn test_cosine_sim_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = cosine_sim(&a, &b);
        assert!((sim - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_cosine_sim_zero_vector() {
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 2.0];
        let sim = cosine_sim(&a, &b);
        assert!(sim.abs() < 1e-10);
    }

    // ═══ Linear algebra helper tests ═════════════════════

    #[test]
    fn test_centroid_computation() {
        let embeddings = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let centroid = compute_centroid(&embeddings, 2);
        assert!((centroid[0] - 2.0).abs() < 1e-10);
        assert!((centroid[1] - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_center_embeddings_fn() {
        let embeddings = vec![vec![2.0, 4.0], vec![4.0, 6.0]];
        let centroid = vec![3.0, 5.0];
        let centered = center_embeddings(&embeddings, &centroid);
        assert!((centered[0][0] - (-1.0)).abs() < 1e-10);
        assert!((centered[1][1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_normalize_rows_fn() {
        let embeddings = vec![vec![3.0, 4.0]];
        let normalized = normalize_rows(&embeddings);
        let norm: f64 = normalized[0].iter().map(|v| v * v).sum::<f64>().sqrt();
        assert!((norm - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_transpose_identity() {
        let m = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let t = transpose(&m, 2);
        assert!((t[0][0] - 1.0).abs() < 1e-10);
        assert!((t[1][1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_mat_mul_identity() {
        let a = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let identity = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let result = mat_mul(&a, &identity, 2);
        assert!((result[0][0] - 1.0).abs() < 1e-10);
        assert!((result[0][1] - 2.0).abs() < 1e-10);
        assert!((result[1][0] - 3.0).abs() < 1e-10);
        assert!((result[1][1] - 4.0).abs() < 1e-10);
    }

    // ═══ Nearest neighbor tests ══════════════════════════

    #[test]
    fn test_find_nearest_neighbors() {
        let query = vec![0.0, 0.0];
        let candidates = vec![
            vec![10.0, 10.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![5.0, 5.0],
        ];
        let nn = find_nearest_neighbors(&query, &candidates, 2);
        assert_eq!(nn.len(), 2);
        // Closest should be [1,0] (idx=1) or [0,1] (idx=2)
        assert!(nn[0] == 1 || nn[0] == 2);
    }

    // ═══ Config with normalize ═══════════════════════════

    #[test]
    fn test_alignment_with_normalization() {
        let config = ProcrustesConfig {
            center: true,
            normalize: true,
            regularization: 1e-10,
        };
        let aligner = ProcrustesAligner::with_config(config);
        let src = make_embeddings(20, 3, 42);
        let tgt = src.clone();
        let anchors: Vec<AnchorPair> = (0..10).map(|i| AnchorPair::new(i, i)).collect();
        let result = aligner.align(&src, &tgt, &anchors);
        assert!(result.is_ok());
    }

    // ═══ Default aligner test ════════════════════════════

    #[test]
    fn test_default_aligner() {
        let aligner = ProcrustesAligner::default();
        let src = make_embeddings(10, 2, 1);
        let tgt = make_embeddings(10, 2, 2);
        let anchors = vec![AnchorPair::new(0, 0), AnchorPair::new(1, 1)];
        let result = aligner.align(&src, &tgt, &anchors);
        assert!(result.is_ok());
    }
}