/// A known correspondence between one source embedding and one target embedding,
/// identified by their positions in the respective embedding lists.
#[derive(Debug, Clone)]
pub struct AnchorPair {
    /// Index into the source embedding list.
    pub source_idx: usize,
    /// Index into the target embedding list.
    pub target_idx: usize,
    /// Optional descriptive label; not used by the alignment math.
    pub label: Option<String>,
}

impl AnchorPair {
    /// Creates an unlabelled pair linking `source_idx` to `target_idx`.
    pub fn new(source_idx: usize, target_idx: usize) -> Self {
        AnchorPair {
            label: None,
            source_idx,
            target_idx,
        }
    }

    /// Returns this pair with `label` attached, replacing any previous label.
    pub fn with_label(self, label: impl Into<String>) -> Self {
        let label = Some(label.into());
        Self { label, ..self }
    }
}
45
/// Options controlling how `ProcrustesAligner::align` preprocesses the anchor embeddings.
#[derive(Debug, Clone)]
pub struct ProcrustesConfig {
    /// Subtract each side's anchor centroid before fitting; when false, zero centroids are used.
    pub center: bool,
    /// L2-normalise every (centred) anchor row before fitting.
    pub normalize: bool,
    /// Regularisation strength.
    ///
    /// NOTE(review): the alignment code never reads this field, so it currently has no
    /// effect — confirm whether it was meant to be applied (e.g. as a singular-value floor).
    pub regularization: f64,
}

impl Default for ProcrustesConfig {
    /// Centring enabled, normalisation disabled, `regularization = 1e-10`.
    fn default() -> Self {
        ProcrustesConfig {
            regularization: 1e-10,
            normalize: false,
            center: true,
        }
    }
}
66
/// A fitted alignment, applied as `x ↦ R · (x − source_centroid) + target_centroid`.
#[derive(Debug, Clone)]
pub struct ProcrustesResult {
    /// The `dim × dim` matrix `R` learned by `align`; row `i` dotted with the centred input
    /// gives output coordinate `i`.
    pub rotation_matrix: Vec<Vec<f64>>,
    /// Subtracted from every input before applying `R` (all zeros when centring is off).
    pub source_centroid: Vec<f64>,
    /// Added to every output after applying `R` (all zeros when centring is off).
    pub target_centroid: Vec<f64>,
    /// Mean squared error on the anchors as preprocessed by `align`.
    pub mse: f64,
    /// Mean cosine similarity on the anchors as preprocessed by `align`.
    pub mean_cosine_similarity: f64,
    /// Dimensionality of the map.
    pub dim: usize,
}

impl ProcrustesResult {
    /// Maps a single source embedding into the target space.
    ///
    /// Only the first `dim` coordinates of `embedding` are read and missing ones count as 0,
    /// so the output always has exactly `dim` entries. No normalisation is applied here.
    ///
    /// # Panics
    /// If `source_centroid`, `target_centroid` or `rotation_matrix` (or one of its rows) is
    /// shorter than `dim`.
    pub fn transform(&self, embedding: &[f64]) -> Vec<f64> {
        let n = self.dim;
        let centered: Vec<f64> = (0..n)
            .map(|i| {
                let x = embedding.get(i).copied().unwrap_or(0.0);
                x - self.source_centroid[i]
            })
            .collect();

        (0..n)
            .map(|row| {
                let coeffs = &self.rotation_matrix[row];
                let projected = centered
                    .iter()
                    .enumerate()
                    .fold(0.0, |acc, (col, &c)| acc + coeffs[col] * c);
                projected + self.target_centroid[row]
            })
            .collect()
    }

    /// Applies [`transform`](Self::transform) to every embedding, preserving order.
    pub fn transform_batch(&self, embeddings: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let mut mapped = Vec::with_capacity(embeddings.len());
        for embedding in embeddings {
            mapped.push(self.transform(embedding));
        }
        mapped
    }
}
114
/// Quality metrics produced by `ProcrustesAligner::evaluate` for a set of evaluation pairs.
///
/// `Default` yields all-zero metrics, the same values `evaluate` reports when it has nothing
/// to score.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct AlignmentMetrics {
    /// Mean (over pairs) of the squared Euclidean distance between mapped source and target.
    pub mse: f64,
    /// Mean (over pairs) cosine similarity between mapped source and target.
    pub mean_cosine_similarity: f64,
    /// Fraction of pairs whose target is the nearest target embedding to the mapped source.
    pub precision_at_1: f64,
    /// Fraction of pairs whose target is among the 5 nearest target embeddings.
    pub precision_at_5: f64,
    /// Fraction of pairs whose target is among the 10 nearest target embeddings.
    pub precision_at_10: f64,
    /// Number of evaluation pairs reported by `evaluate`.
    pub eval_pairs: usize,
}
132
133pub struct ProcrustesAligner {
139 config: ProcrustesConfig,
140}
141
142impl ProcrustesAligner {
143 pub fn new() -> Self {
145 Self {
146 config: ProcrustesConfig::default(),
147 }
148 }
149
150 pub fn with_config(config: ProcrustesConfig) -> Self {
152 Self { config }
153 }
154
155 pub fn align(
161 &self,
162 source_embeddings: &[Vec<f64>],
163 target_embeddings: &[Vec<f64>],
164 anchors: &[AnchorPair],
165 ) -> Result<ProcrustesResult, ProcrustesError> {
166 if anchors.is_empty() {
167 return Err(ProcrustesError::NoAnchors);
168 }
169
170 for anchor in anchors {
172 if anchor.source_idx >= source_embeddings.len() {
173 return Err(ProcrustesError::InvalidIndex {
174 which: "source",
175 idx: anchor.source_idx,
176 len: source_embeddings.len(),
177 });
178 }
179 if anchor.target_idx >= target_embeddings.len() {
180 return Err(ProcrustesError::InvalidIndex {
181 which: "target",
182 idx: anchor.target_idx,
183 len: target_embeddings.len(),
184 });
185 }
186 }
187
188 let dim = source_embeddings.first().map(|v| v.len()).unwrap_or(0);
190 if dim == 0 {
191 return Err(ProcrustesError::EmptyEmbeddings);
192 }
193
194 let src_anchors: Vec<Vec<f64>> = anchors
196 .iter()
197 .map(|a| source_embeddings[a.source_idx].clone())
198 .collect();
199 let tgt_anchors: Vec<Vec<f64>> = anchors
200 .iter()
201 .map(|a| target_embeddings[a.target_idx].clone())
202 .collect();
203
204 let source_centroid = if self.config.center {
206 compute_centroid(&src_anchors, dim)
207 } else {
208 vec![0.0; dim]
209 };
210 let target_centroid = if self.config.center {
211 compute_centroid(&tgt_anchors, dim)
212 } else {
213 vec![0.0; dim]
214 };
215
216 let src_centered = center_embeddings(&src_anchors, &source_centroid);
218 let tgt_centered = center_embeddings(&tgt_anchors, &target_centroid);
219
220 let src_final = if self.config.normalize {
222 normalize_rows(&src_centered)
223 } else {
224 src_centered
225 };
226 let tgt_final = if self.config.normalize {
227 normalize_rows(&tgt_centered)
228 } else {
229 tgt_centered
230 };
231
232 let m_matrix = cross_covariance(&src_final, &tgt_final, dim);
234
235 let (u, _s, vt) = svd(&m_matrix, dim)?;
237
238 let v = transpose(&vt, dim);
241 let ut = transpose(&u, dim);
242 let rotation = mat_mul(&v, &ut, dim);
243
244 let mse = compute_mse(&src_final, &tgt_final, &rotation, dim);
246 let mean_cos = compute_mean_cosine(&src_final, &tgt_final, &rotation, dim);
247
248 Ok(ProcrustesResult {
249 rotation_matrix: rotation,
250 source_centroid,
251 target_centroid,
252 mse,
253 mean_cosine_similarity: mean_cos,
254 dim,
255 })
256 }
257
258 pub fn evaluate(
260 &self,
261 result: &ProcrustesResult,
262 source_embeddings: &[Vec<f64>],
263 target_embeddings: &[Vec<f64>],
264 eval_pairs: &[AnchorPair],
265 ) -> AlignmentMetrics {
266 if eval_pairs.is_empty() {
267 return AlignmentMetrics {
268 mse: 0.0,
269 mean_cosine_similarity: 0.0,
270 precision_at_1: 0.0,
271 precision_at_5: 0.0,
272 precision_at_10: 0.0,
273 eval_pairs: 0,
274 };
275 }
276
277 let mut total_se = 0.0;
278 let mut total_cos = 0.0;
279 let mut correct_at_1 = 0usize;
280 let mut correct_at_5 = 0usize;
281 let mut correct_at_10 = 0usize;
282
283 for pair in eval_pairs {
284 if pair.source_idx >= source_embeddings.len()
285 || pair.target_idx >= target_embeddings.len()
286 {
287 continue;
288 }
289 let transformed = result.transform(&source_embeddings[pair.source_idx]);
290 let target = &target_embeddings[pair.target_idx];
291
292 let se: f64 = transformed
294 .iter()
295 .zip(target.iter())
296 .map(|(a, b)| (a - b).powi(2))
297 .sum();
298 total_se += se;
299
300 let cos = cosine_sim(&transformed, target);
302 total_cos += cos;
303
304 let neighbors = find_nearest_neighbors(&transformed, target_embeddings, 10);
306 if neighbors.first().copied() == Some(pair.target_idx) {
307 correct_at_1 += 1;
308 }
309 if neighbors.iter().take(5).any(|&idx| idx == pair.target_idx) {
310 correct_at_5 += 1;
311 }
312 if neighbors.iter().take(10).any(|&idx| idx == pair.target_idx) {
313 correct_at_10 += 1;
314 }
315 }
316
317 let n = eval_pairs.len() as f64;
318 AlignmentMetrics {
319 mse: total_se / n,
320 mean_cosine_similarity: total_cos / n,
321 precision_at_1: correct_at_1 as f64 / n,
322 precision_at_5: correct_at_5 as f64 / n,
323 precision_at_10: correct_at_10 as f64 / n,
324 eval_pairs: eval_pairs.len(),
325 }
326 }
327}
328
329impl Default for ProcrustesAligner {
330 fn default() -> Self {
331 Self::new()
332 }
333}
334
/// Errors returned by `ProcrustesAligner::align`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProcrustesError {
    /// The anchor list was empty.
    NoAnchors,
    /// The first source embedding has zero length.
    EmptyEmbeddings,
    /// An anchor index is out of range for the corresponding embedding list.
    InvalidIndex {
        /// Which list the index refers to: `"source"` or `"target"`.
        which: &'static str,
        /// The offending index.
        idx: usize,
        /// Length of the list the index was checked against.
        len: usize,
    },
    /// The SVD / eigendecomposition step could not produce a result; carries the reason.
    SvdFailed(String),
}

impl std::fmt::Display for ProcrustesError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ProcrustesError::NoAnchors => write!(f, "no anchor pairs provided"),
            ProcrustesError::EmptyEmbeddings => write!(f, "embeddings are empty"),
            ProcrustesError::InvalidIndex { which, idx, len } => {
                write!(f, "invalid {which} index {idx} (length {len})")
            }
            ProcrustesError::SvdFailed(msg) => write!(f, "SVD failed: {msg}"),
        }
    }
}

impl std::error::Error for ProcrustesError {}
362
/// Coordinate-wise mean of `embeddings`, truncated or zero-padded to `dim` entries.
///
/// Coordinates an embedding lacks contribute nothing to the sum, but every embedding still
/// counts toward the divisor. An empty input yields a zero vector.
fn compute_centroid(embeddings: &[Vec<f64>], dim: usize) -> Vec<f64> {
    if embeddings.is_empty() {
        return vec![0.0; dim];
    }
    let count = embeddings.len() as f64;
    let mut sums = vec![0.0; dim];
    for emb in embeddings {
        for (acc, &x) in sums.iter_mut().zip(emb.iter()) {
            *acc += x;
        }
    }
    sums.into_iter().map(|s| s / count).collect()
}
383
/// Subtracts `centroid` from every embedding.
///
/// Each output row keeps its input row's length. Coordinates beyond the centroid's length
/// are left unshifted.
fn center_embeddings(embeddings: &[Vec<f64>], centroid: &[f64]) -> Vec<Vec<f64>> {
    let mut centered: Vec<Vec<f64>> = Vec::with_capacity(embeddings.len());
    for emb in embeddings {
        let offsets = centroid.iter().copied().chain(std::iter::repeat(0.0));
        centered.push(emb.iter().zip(offsets).map(|(&x, mu)| x - mu).collect());
    }
    centered
}
395
/// Scales every row to unit L2 norm.
///
/// Rows whose norm is below `1e-12` are copied unchanged rather than divided by ~0.
fn normalize_rows(embeddings: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let mut out: Vec<Vec<f64>> = Vec::with_capacity(embeddings.len());
    for row in embeddings {
        let length = row.iter().map(|v| v * v).sum::<f64>().sqrt();
        let unit = if length < 1e-12 {
            row.clone()
        } else {
            row.iter().map(|v| v / length).collect()
        };
        out.push(unit);
    }
    out
}
409
/// Cross-covariance `M = Σₖ srcₖ · tgtₖᵀ` as a `dim × dim` matrix (`M[i][j] = Σₖ srcₖ[i]·tgtₖ[j]`).
///
/// Rows are paired by position and only as many pairs as the shorter list are used. Missing
/// coordinates count as 0.
fn cross_covariance(src: &[Vec<f64>], tgt: &[Vec<f64>], dim: usize) -> Vec<Vec<f64>> {
    fn at(v: &[f64], i: usize) -> f64 {
        v.get(i).copied().unwrap_or(0.0)
    }

    let mut acc = vec![vec![0.0; dim]; dim];
    for (s, t) in src.iter().zip(tgt.iter()) {
        for (i, row) in acc.iter_mut().enumerate() {
            let si = at(s, i);
            for (j, cell) in row.iter_mut().enumerate() {
                *cell += si * at(t, j);
            }
        }
    }
    acc
}
424
425type SvdResult = (Vec<Vec<f64>>, Vec<f64>, Vec<Vec<f64>>);
427
428fn svd(matrix: &[Vec<f64>], dim: usize) -> Result<SvdResult, ProcrustesError> {
431 let ata = mat_mul(&transpose(matrix, dim), matrix, dim);
433
434 let (eigenvalues, eigenvectors) = jacobi_eigendecomposition(&ata, dim, 200)?;
436
437 let mut singular_values: Vec<f64> = eigenvalues
439 .iter()
440 .map(|&ev| if ev > 0.0 { ev.sqrt() } else { 0.0 })
441 .collect();
442
443 let vt = transpose(&eigenvectors, dim);
445
446 let av = mat_mul(matrix, &eigenvectors, dim);
448 let mut u = vec![vec![0.0; dim]; dim];
449 for i in 0..dim {
450 for j in 0..dim {
451 if singular_values[j].abs() > 1e-12 {
452 u[i][j] = av[i][j] / singular_values[j];
453 }
454 }
455 }
456
457 let mut indices: Vec<usize> = (0..dim).collect();
459 indices.sort_by(|&a, &b| {
460 singular_values[b]
461 .partial_cmp(&singular_values[a])
462 .unwrap_or(std::cmp::Ordering::Equal)
463 });
464
465 let sorted_s: Vec<f64> = indices.iter().map(|&i| singular_values[i]).collect();
466 let sorted_u: Vec<Vec<f64>> = (0..dim)
467 .map(|row| indices.iter().map(|&col| u[row][col]).collect())
468 .collect();
469 let sorted_vt: Vec<Vec<f64>> = indices.iter().map(|&i| vt[i].clone()).collect();
470
471 singular_values = sorted_s;
472
473 Ok((sorted_u, singular_values, sorted_vt))
474}
475
476fn jacobi_eigendecomposition(
477 matrix: &[Vec<f64>],
478 dim: usize,
479 max_iter: usize,
480) -> Result<(Vec<f64>, Vec<Vec<f64>>), ProcrustesError> {
481 let mut a: Vec<Vec<f64>> = matrix.to_vec();
482 let mut v: Vec<Vec<f64>> = (0..dim)
483 .map(|i| (0..dim).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
484 .collect();
485
486 for _ in 0..max_iter {
487 let mut max_val = 0.0f64;
489 let mut p = 0;
490 let mut q = 1;
491 for (i, a_row) in a.iter().enumerate().take(dim) {
492 for (j, a_val) in a_row.iter().enumerate().take(dim).skip(i + 1) {
493 if a_val.abs() > max_val {
494 max_val = a_val.abs();
495 p = i;
496 q = j;
497 }
498 }
499 }
500
501 if max_val < 1e-12 {
502 break;
503 }
504
505 let theta = if (a[p][p] - a[q][q]).abs() < 1e-15 {
507 std::f64::consts::FRAC_PI_4
508 } else {
509 0.5 * (2.0 * a[p][q] / (a[p][p] - a[q][q])).atan()
510 };
511
512 let cos_t = theta.cos();
513 let sin_t = theta.sin();
514
515 let mut new_a = a.clone();
517 for i in 0..dim {
518 new_a[i][p] = cos_t * a[i][p] + sin_t * a[i][q];
519 new_a[i][q] = -sin_t * a[i][p] + cos_t * a[i][q];
520 }
521 let a_tmp = new_a.clone();
522 for j in 0..dim {
523 new_a[p][j] = cos_t * a_tmp[p][j] + sin_t * a_tmp[q][j];
524 new_a[q][j] = -sin_t * a_tmp[p][j] + cos_t * a_tmp[q][j];
525 }
526 a = new_a;
527
528 let mut new_v = v.clone();
530 for i in 0..dim {
531 new_v[i][p] = cos_t * v[i][p] + sin_t * v[i][q];
532 new_v[i][q] = -sin_t * v[i][p] + cos_t * v[i][q];
533 }
534 v = new_v;
535 }
536
537 let eigenvalues: Vec<f64> = (0..dim).map(|i| a[i][i]).collect();
538 Ok((eigenvalues, v))
539}
540
/// Transpose of the leading `dim × dim` block of `matrix`.
///
/// Entries missing from `matrix` (short rows, or fewer than `dim` rows) become 0.
fn transpose(matrix: &[Vec<f64>], dim: usize) -> Vec<Vec<f64>> {
    (0..dim)
        .map(|col| {
            (0..dim)
                .map(|row| {
                    matrix
                        .get(row)
                        .and_then(|r| r.get(col))
                        .copied()
                        .unwrap_or(0.0)
                })
                .collect()
        })
        .collect()
}
550
/// Product `a · b` of the leading `dim × dim` blocks of two matrices.
///
/// Rows or entries missing from a short operand are treated as zero, and the result is always
/// `dim × dim`. Every product term is accumulated with no magnitude cut-off. Skipping
/// "small" entries of `a` would drop real contributions for small-scale data (e.g.
/// `1e-16 · 1e16`) and would make `Aᵀ·A` slightly asymmetric.
fn mat_mul(a: &[Vec<f64>], b: &[Vec<f64>], dim: usize) -> Vec<Vec<f64>> {
    let mut result = vec![vec![0.0; dim]; dim];
    for (out_row, a_row) in result.iter_mut().zip(a) {
        for (k, &aik) in a_row.iter().enumerate().take(dim) {
            if let Some(b_row) = b.get(k) {
                for (out, &bkj) in out_row.iter_mut().zip(b_row) {
                    *out += aik * bkj;
                }
            }
        }
    }
    result
}
567
/// Mean over paired rows of `‖R · srcₖ − tgtₖ‖²`, with `R = rotation` restricted to `dim`.
///
/// Rows are paired by position, up to the shorter list. Missing coordinates count as 0, and
/// no pairs yields 0.
///
/// # Panics
/// If `rotation` has fewer than `dim` rows.
fn compute_mse(src: &[Vec<f64>], tgt: &[Vec<f64>], rotation: &[Vec<f64>], dim: usize) -> f64 {
    let pairs = src.len().min(tgt.len());
    if pairs == 0 {
        return 0.0;
    }
    let total = src
        .iter()
        .zip(tgt)
        .map(|(s, t)| {
            (0..dim)
                .map(|i| {
                    let projected = rotation[i]
                        .iter()
                        .take(dim)
                        .enumerate()
                        .fold(0.0, |acc, (j, &r)| acc + r * s.get(j).copied().unwrap_or(0.0));
                    (projected - t.get(i).copied().unwrap_or(0.0)).powi(2)
                })
                .sum::<f64>()
        })
        .fold(0.0, |acc, se| acc + se);
    total / pairs as f64
}
590
591fn compute_mean_cosine(
592 src: &[Vec<f64>],
593 tgt: &[Vec<f64>],
594 rotation: &[Vec<f64>],
595 dim: usize,
596) -> f64 {
597 let n = src.len().min(tgt.len());
598 if n == 0 {
599 return 0.0;
600 }
601 let mut total = 0.0;
602 for k in 0..n {
603 let mut rotated = vec![0.0; dim];
604 for (i, rot_val) in rotated.iter_mut().enumerate().take(dim) {
605 for (j, &r_ij) in rotation[i].iter().enumerate().take(dim) {
606 *rot_val += r_ij * src[k].get(j).copied().unwrap_or(0.0);
607 }
608 }
609 total += cosine_sim(&rotated, &tgt[k]);
610 }
611 total / n as f64
612}
613
/// Cosine similarity of `a` and `b`.
///
/// The dot product runs over the common prefix, while each norm uses the whole vector.
/// Returns 0 when either norm is below `1e-12`.
fn cosine_sim(a: &[f64], b: &[f64]) -> f64 {
    let l2 = |v: &[f64]| v.iter().map(|x| x * x).sum::<f64>().sqrt();
    let na = l2(a);
    let nb = l2(b);
    if na < 1e-12 || nb < 1e-12 {
        return 0.0;
    }
    let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot / (na * nb)
}
624
/// Indices of the `k` candidates closest to `query` by squared Euclidean distance, nearest
/// first.
///
/// Distances are computed over the common prefix of `query` and each candidate. Ties are
/// broken by candidate index, and NaN distances rank after every finite or infinite
/// distance. The comparator is a total order, so sorting is well-defined even with NaNs.
/// Returns fewer than `k` indices when there are fewer candidates, and nothing when `k == 0`.
/// Only the `k` best are fully sorted (partial selection first).
fn find_nearest_neighbors(query: &[f64], candidates: &[Vec<f64>], k: usize) -> Vec<usize> {
    if k == 0 {
        return Vec::new();
    }
    let mut dists: Vec<(usize, f64)> = candidates
        .iter()
        .enumerate()
        .map(|(i, c)| {
            let dist: f64 = query
                .iter()
                .zip(c.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum();
            (i, dist)
        })
        .collect();

    // NaN last, then by distance, then by index (reproducing a stable sort's tie order).
    let by_distance = |x: &(usize, f64), y: &(usize, f64)| {
        x.1.is_nan()
            .cmp(&y.1.is_nan())
            .then_with(|| x.1.partial_cmp(&y.1).unwrap_or(std::cmp::Ordering::Equal))
            .then(x.0.cmp(&y.0))
    };
    if k < dists.len() {
        dists.select_nth_unstable_by(k - 1, by_distance);
        dists.truncate(k);
    }
    dists.sort_unstable_by(by_distance);
    dists.into_iter().map(|(idx, _)| idx).collect()
}
641
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random embeddings from a 64-bit LCG; every coordinate lies in
    /// `[-0.5, 0)` because `(state >> 33) / u32::MAX < 0.5`.
    fn make_embeddings(n: usize, dim: usize, seed: u64) -> Vec<Vec<f64>> {
        let mut state = seed;
        (0..n)
            .map(|_| {
                (0..dim)
                    .map(|_| {
                        state = state
                            .wrapping_mul(6364136223846793005)
                            .wrapping_add(1442695040888963407);
                        ((state >> 33) as f64) / (u32::MAX as f64) - 0.5
                    })
                    .collect()
            })
            .collect()
    }

    /// Rotates 2-D embeddings by +90°: `(x, y) ↦ (−y, x)`.
    fn rotate_90_2d(embeddings: &[Vec<f64>]) -> Vec<Vec<f64>> {
        embeddings.iter().map(|e| vec![-e[1], e[0]]).collect()
    }

    #[test]
    fn test_anchor_pair_creation() {
        let pair = AnchorPair::new(0, 1);
        assert_eq!(pair.source_idx, 0);
        assert_eq!(pair.target_idx, 1);
        assert!(pair.label.is_none());
    }

    #[test]
    fn test_anchor_pair_with_label() {
        let pair = AnchorPair::new(0, 1).with_label("cat");
        assert_eq!(pair.label, Some("cat".to_string()));
    }

    #[test]
    fn test_default_config() {
        let config = ProcrustesConfig::default();
        assert!(config.center);
        assert!(!config.normalize);
        assert!(config.regularization > 0.0);
    }

    #[test]
    fn test_no_anchors_error() {
        let aligner = ProcrustesAligner::new();
        let src = make_embeddings(10, 3, 42);
        let tgt = make_embeddings(10, 3, 99);
        let result = aligner.align(&src, &tgt, &[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_source_index() {
        let aligner = ProcrustesAligner::new();
        let src = make_embeddings(5, 3, 42);
        let tgt = make_embeddings(5, 3, 99);
        let anchors = vec![AnchorPair::new(10, 0)];
        let result = aligner.align(&src, &tgt, &anchors);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_target_index() {
        let aligner = ProcrustesAligner::new();
        let src = make_embeddings(5, 3, 42);
        let tgt = make_embeddings(5, 3, 99);
        let anchors = vec![AnchorPair::new(0, 10)];
        let result = aligner.align(&src, &tgt, &anchors);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_embeddings() {
        let aligner = ProcrustesAligner::new();
        let src: Vec<Vec<f64>> = Vec::new();
        let tgt: Vec<Vec<f64>> = Vec::new();
        let anchors = vec![AnchorPair::new(0, 0)];
        let result = aligner.align(&src, &tgt, &anchors);
        assert!(result.is_err());
    }

    #[test]
    fn test_error_display() {
        let err = ProcrustesError::NoAnchors;
        assert!(format!("{err}").contains("anchor"));
    }

    #[test]
    fn test_identity_alignment() {
        let aligner = ProcrustesAligner::new();
        let src = make_embeddings(20, 3, 42);
        let tgt = src.clone();
        let anchors: Vec<AnchorPair> = (0..10).map(|i| AnchorPair::new(i, i)).collect();
        let result = aligner.align(&src, &tgt, &anchors);
        assert!(result.is_ok());
        let res = result.expect("alignment should succeed");
        assert!(res.mse < 1e-6);
    }

    #[test]
    fn test_2d_rotation_alignment() {
        let src = make_embeddings(20, 2, 42);
        let tgt = rotate_90_2d(&src);
        let anchors: Vec<AnchorPair> = (0..10).map(|i| AnchorPair::new(i, i)).collect();

        let aligner = ProcrustesAligner::new();
        let result = aligner.align(&src, &tgt, &anchors);
        assert!(result.is_ok());
        let res = result.expect("alignment should succeed");

        assert!(res.mse < 0.5, "MSE too high: {}", res.mse);
        assert!(
            res.mean_cosine_similarity > 0.5,
            "Cosine too low: {}",
            res.mean_cosine_similarity
        );
    }

    #[test]
    fn test_alignment_dim() {
        let src = make_embeddings(10, 5, 42);
        let tgt = make_embeddings(10, 5, 99);
        let anchors: Vec<AnchorPair> = (0..5).map(|i| AnchorPair::new(i, i)).collect();

        let aligner = ProcrustesAligner::new();
        let result = aligner.align(&src, &tgt, &anchors).expect("should align");
        assert_eq!(result.dim, 5);
        assert_eq!(result.rotation_matrix.len(), 5);
        assert_eq!(result.rotation_matrix[0].len(), 5);
    }

    #[test]
    fn test_transform_preserves_dim() {
        let src = make_embeddings(10, 4, 42);
        let tgt = make_embeddings(10, 4, 99);
        let anchors: Vec<AnchorPair> = (0..5).map(|i| AnchorPair::new(i, i)).collect();
        let aligner = ProcrustesAligner::new();
        let result = aligner.align(&src, &tgt, &anchors).expect("should align");

        let transformed = result.transform(&src[0]);
        assert_eq!(transformed.len(), 4);
    }

    #[test]
    fn test_transform_batch() {
        let src = make_embeddings(10, 3, 42);
        let tgt = make_embeddings(10, 3, 99);
        let anchors: Vec<AnchorPair> = (0..5).map(|i| AnchorPair::new(i, i)).collect();
        let aligner = ProcrustesAligner::new();
        let result = aligner.align(&src, &tgt, &anchors).expect("should align");

        let batch = result.transform_batch(&src);
        assert_eq!(batch.len(), 10);
    }

    #[test]
    fn test_evaluate_identity() {
        let src = make_embeddings(20, 3, 42);
        let tgt = src.clone();
        let anchors: Vec<AnchorPair> = (0..10).map(|i| AnchorPair::new(i, i)).collect();
        let eval_pairs: Vec<AnchorPair> = (10..20).map(|i| AnchorPair::new(i, i)).collect();

        let aligner = ProcrustesAligner::new();
        let result = aligner.align(&src, &tgt, &anchors).expect("should align");
        let metrics = aligner.evaluate(&result, &src, &tgt, &eval_pairs);

        assert_eq!(metrics.eval_pairs, 10);
        assert!(metrics.mse < 1e-4);
        assert!(metrics.precision_at_1 > 0.8);
    }

    #[test]
    fn test_evaluate_empty() {
        let src = make_embeddings(10, 3, 42);
        let tgt = make_embeddings(10, 3, 99);
        let anchors: Vec<AnchorPair> = (0..5).map(|i| AnchorPair::new(i, i)).collect();

        let aligner = ProcrustesAligner::new();
        let result = aligner.align(&src, &tgt, &anchors).expect("should align");
        let metrics = aligner.evaluate(&result, &src, &tgt, &[]);
        assert_eq!(metrics.eval_pairs, 0);
    }

    #[test]
    fn test_cosine_sim_identical() {
        let a = vec![1.0, 2.0, 3.0];
        let sim = cosine_sim(&a, &a);
        assert!((sim - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_cosine_sim_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_sim(&a, &b);
        assert!(sim.abs() < 1e-10);
    }

    #[test]
    fn test_cosine_sim_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = cosine_sim(&a, &b);
        assert!((sim - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_cosine_sim_zero_vector() {
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 2.0];
        let sim = cosine_sim(&a, &b);
        assert!(sim.abs() < 1e-10);
    }

    #[test]
    fn test_centroid_computation() {
        let embeddings = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let centroid = compute_centroid(&embeddings, 2);
        assert!((centroid[0] - 2.0).abs() < 1e-10);
        assert!((centroid[1] - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_center_embeddings_fn() {
        let embeddings = vec![vec![2.0, 4.0], vec![4.0, 6.0]];
        let centroid = vec![3.0, 5.0];
        let centered = center_embeddings(&embeddings, &centroid);
        assert!((centered[0][0] - (-1.0)).abs() < 1e-10);
        assert!((centered[1][1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_normalize_rows_fn() {
        let embeddings = vec![vec![3.0, 4.0]];
        let normalized = normalize_rows(&embeddings);
        let norm: f64 = normalized[0].iter().map(|v| v * v).sum::<f64>().sqrt();
        assert!((norm - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_transpose_identity() {
        let m = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let t = transpose(&m, 2);
        assert!((t[0][0] - 1.0).abs() < 1e-10);
        assert!((t[1][1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_mat_mul_identity() {
        let a = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let identity = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let result = mat_mul(&a, &identity, 2);
        assert!((result[0][0] - 1.0).abs() < 1e-10);
        assert!((result[0][1] - 2.0).abs() < 1e-10);
        assert!((result[1][0] - 3.0).abs() < 1e-10);
        assert!((result[1][1] - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_find_nearest_neighbors() {
        let query = vec![0.0, 0.0];
        let candidates = vec![
            vec![10.0, 10.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![5.0, 5.0],
        ];
        let nn = find_nearest_neighbors(&query, &candidates, 2);
        assert_eq!(nn.len(), 2);
        assert!(nn[0] == 1 || nn[0] == 2);
    }

    #[test]
    fn test_alignment_with_normalization() {
        let config = ProcrustesConfig {
            center: true,
            normalize: true,
            regularization: 1e-10,
        };
        let aligner = ProcrustesAligner::with_config(config);
        let src = make_embeddings(20, 3, 42);
        let tgt = src.clone();
        let anchors: Vec<AnchorPair> = (0..10).map(|i| AnchorPair::new(i, i)).collect();
        let result = aligner.align(&src, &tgt, &anchors);
        assert!(result.is_ok());
    }

    #[test]
    fn test_default_aligner() {
        let aligner = ProcrustesAligner::default();
        let src = make_embeddings(10, 2, 1);
        let tgt = make_embeddings(10, 2, 2);
        let anchors = vec![AnchorPair::new(0, 0), AnchorPair::new(1, 1)];
        let result = aligner.align(&src, &tgt, &anchors);
        assert!(result.is_ok());
    }
}