1use anyhow::{anyhow, Result};
14use serde::{Deserialize, Serialize};
15
/// Metric under which an embedding model's per-sample scores are computed.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum EmbedMetric {
    /// Mean Reciprocal Rank of positive pairs ranked against negatives
    /// (see `evaluate_link_prediction`).
    MeanReciprocalRank,
    /// 1.0 when a positive pair ranks within the top-k against negatives,
    /// else 0.0 (see `evaluate_hits_at_k`).
    HitsAtK(usize),
    /// Similarity-based score (not produced by this module's evaluators).
    SimilarityScore,
    /// Per-node silhouette coefficient in [-1, 1]
    /// (see `evaluate_silhouette`).
    SilhouetteScore,
    /// Downstream classification accuracy (not computed in this module).
    ClassificationAccuracy,
    /// User-defined metric identified by its display name.
    Custom(String),
}
32
33impl std::fmt::Display for EmbedMetric {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 match self {
36 EmbedMetric::MeanReciprocalRank => write!(f, "MRR"),
37 EmbedMetric::HitsAtK(k) => write!(f, "Hits@{}", k),
38 EmbedMetric::SimilarityScore => write!(f, "SimilarityScore"),
39 EmbedMetric::SilhouetteScore => write!(f, "SilhouetteScore"),
40 EmbedMetric::ClassificationAccuracy => write!(f, "ClassificationAccuracy"),
41 EmbedMetric::Custom(name) => write!(f, "{}", name),
42 }
43 }
44}
45
/// Per-sample evaluation scores for one model under one metric, together
/// with precomputed summary statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelEvalResult {
    /// Identifier of the evaluated model.
    pub model_id: String,
    /// Metric the scores were computed under.
    pub metric: EmbedMetric,
    /// Raw per-sample scores (non-empty when constructed via `new`).
    pub scores: Vec<f64>,
    /// Arithmetic mean of `scores`.
    pub mean: f64,
    /// Sample standard deviation of `scores` (n - 1 denominator).
    pub std_dev: f64,
    /// Number of scores (equals `scores.len()` when built via `new`).
    pub sample_count: usize,
}
62
impl ModelEvalResult {
    /// Builds a result from raw per-sample scores, precomputing the mean,
    /// sample standard deviation, and sample count.
    ///
    /// # Errors
    /// Fails when `scores` is empty (summary statistics would be undefined).
    pub fn new(model_id: String, metric: EmbedMetric, scores: Vec<f64>) -> Result<Self> {
        if scores.is_empty() {
            return Err(anyhow!("scores must not be empty for model '{}'", model_id));
        }

        let mean = stats_mean(&scores);
        let std_dev = stats_std_dev(&scores, mean);
        let sample_count = scores.len();

        Ok(Self {
            model_id,
            metric,
            scores,
            mean,
            std_dev,
            sample_count,
        })
    }

    /// Two-sided (1 - alpha) confidence interval for the mean, using a
    /// Student's t critical value with n - 1 degrees of freedom.
    ///
    /// Returns the degenerate interval (mean, mean) when fewer than two
    /// samples are available.
    pub fn confidence_interval(&self, alpha: f64) -> (f64, f64) {
        let n = self.sample_count as f64;
        if n < 2.0 {
            return (self.mean, self.mean);
        }
        // Standard error of the mean.
        let se = self.std_dev / n.sqrt();
        // alpha / 2 probability mass in each tail for a two-sided interval.
        let t_crit = t_critical_value(n as usize - 1, alpha / 2.0);
        let margin = t_crit * se;
        (self.mean - margin, self.mean + margin)
    }
}
103
/// Statistical test used to compare two paired samples of scores.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StatTest {
    /// Paired Student's t-test on the per-pair differences.
    TTest,
    /// Wilcoxon signed-rank test (large-sample normal approximation).
    WilcoxonSignedRank,
    /// Sign-flip permutation test on paired observations.
    Bootstrap {
        /// Number of random permutations to draw (must be > 0).
        n_permutations: usize,
        /// Seed for the deterministic permutation RNG.
        seed: u64,
    },
}
119
/// Outcome of comparing two models' scores with a statistical test.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AbTestResult {
    /// Identifier of the first model.
    pub model_a: String,
    /// Identifier of the second model.
    pub model_b: String,
    /// Metric both models were scored under.
    pub metric: EmbedMetric,
    /// Two-sided p-value from the chosen test.
    pub p_value: f64,
    /// Cohen's d effect size (positive favors model A).
    pub effect_size: f64,
    /// True when `p_value < alpha`.
    pub is_significant: bool,
    /// Significance level the comparison was run at.
    pub alpha: f64,
    /// Model id of the winner; `None` when not significant.
    pub winner: Option<String>,
    /// Mean score of model A (over the paired, truncated samples).
    pub mean_a: f64,
    /// Mean score of model B (over the paired, truncated samples).
    pub mean_b: f64,
    /// Confidence interval for model A's mean (over its full scores).
    pub ci_a: (f64, f64),
    /// Confidence interval for model B's mean (over its full scores).
    pub ci_b: (f64, f64),
    /// Which statistical test produced `p_value`.
    pub test_used: StatTest,
}
150
151impl AbTestResult {
152 pub fn summary(&self) -> String {
154 let sig_str = if self.is_significant {
155 "significant"
156 } else {
157 "not significant"
158 };
159 let winner_str = match &self.winner {
160 Some(w) => format!("Winner: {}", w),
161 None => "No clear winner".to_string(),
162 };
163 format!(
164 "A/B Test [{}] ({}) -- {} vs {}: p={:.4}, effect={:.3} ({}). {}",
165 self.metric,
166 self.test_used.name(),
167 self.model_a,
168 self.model_b,
169 self.p_value,
170 self.effect_size,
171 sig_str,
172 winner_str
173 )
174 }
175}
176
177impl StatTest {
178 fn name(&self) -> &'static str {
179 match self {
180 StatTest::TTest => "paired t-test",
181 StatTest::WilcoxonSignedRank => "Wilcoxon signed-rank",
182 StatTest::Bootstrap { .. } => "bootstrap permutation",
183 }
184 }
185}
186
/// Compares embedding models' evaluation results using a configurable
/// significance level and statistical test.
pub struct AbTestRunner {
    /// Two-sided significance level (default 0.05).
    alpha: f64,
    /// Statistical test applied by `compare` (default: paired t-test).
    test: StatTest,
}
194
195impl Default for AbTestRunner {
196 fn default() -> Self {
197 Self {
198 alpha: 0.05,
199 test: StatTest::TTest,
200 }
201 }
202}
203
impl AbTestRunner {
    /// Creates a runner with the default configuration
    /// (alpha = 0.05, paired t-test).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the two-sided significance level used by `compare`.
    pub fn with_alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha;
        self
    }

    /// Sets the statistical test used by `compare`.
    pub fn with_test(mut self, test: StatTest) -> Self {
        self.test = test;
        self
    }

    /// Runs the configured paired test on two models' scores.
    ///
    /// The two score vectors are truncated to their common length and
    /// treated as paired observations. A winner is declared only when the
    /// p-value falls below `alpha`; an exact tie in means goes to model A.
    ///
    /// # Errors
    /// Fails when the metrics differ, either score vector is empty, or the
    /// underlying test cannot run (e.g. fewer than 2 paired samples).
    pub fn compare(
        &self,
        result_a: &ModelEvalResult,
        result_b: &ModelEvalResult,
    ) -> Result<AbTestResult> {
        if result_a.metric != result_b.metric {
            return Err(anyhow!(
                "Cannot compare models on different metrics: {:?} vs {:?}",
                result_a.metric,
                result_b.metric
            ));
        }
        if result_a.scores.is_empty() || result_b.scores.is_empty() {
            return Err(anyhow!("Both models must have non-empty scores"));
        }

        // Pair the scores index-by-index; trailing scores on the longer
        // side are silently dropped.
        let n = result_a.scores.len().min(result_b.scores.len());
        let scores_a = &result_a.scores[..n];
        let scores_b = &result_b.scores[..n];

        let p_value = match &self.test {
            StatTest::TTest => self.t_test_paired(scores_a, scores_b)?,
            StatTest::WilcoxonSignedRank => self.wilcoxon_signed_rank(scores_a, scores_b)?,
            StatTest::Bootstrap {
                n_permutations,
                seed,
            } => self.bootstrap_test(scores_a, scores_b, *n_permutations, *seed)?,
        };

        let effect_size = cohens_d(scores_a, scores_b);
        let is_significant = p_value < self.alpha;

        let mean_a = stats_mean(scores_a);
        let mean_b = stats_mean(scores_b);

        // A winner is declared only on significance; ties favor model A.
        let winner = if is_significant {
            if mean_a >= mean_b {
                Some(result_a.model_id.clone())
            } else {
                Some(result_b.model_id.clone())
            }
        } else {
            None
        };

        // Note: CIs are computed over each model's full (untruncated)
        // score vector, unlike the paired test above.
        let ci_a = result_a.confidence_interval(self.alpha);
        let ci_b = result_b.confidence_interval(self.alpha);

        Ok(AbTestResult {
            model_a: result_a.model_id.clone(),
            model_b: result_b.model_id.clone(),
            metric: result_a.metric.clone(),
            p_value,
            effect_size,
            is_significant,
            alpha: self.alpha,
            winner,
            mean_a,
            mean_b,
            ci_a,
            ci_b,
            test_used: self.test.clone(),
        })
    }

    /// Ranks models by mean score (descending) and flags whether the top
    /// model beats the runner-up significantly.
    ///
    /// Returns `(model_id, mean, is_significantly_best)` triples; only the
    /// first entry can carry the flag. A single model is trivially
    /// considered significantly best.
    ///
    /// # Errors
    /// Fails when `results` is empty or the models use different metrics.
    pub fn rank_models(&self, results: &[ModelEvalResult]) -> Result<Vec<(String, f64, bool)>> {
        if results.is_empty() {
            return Err(anyhow!("Need at least one model to rank"));
        }

        // All models must share a metric for their means to be comparable.
        let metric = &results[0].metric;
        for r in results.iter().skip(1) {
            if &r.metric != metric {
                return Err(anyhow!(
                    "All models must use the same metric for ranking, found {:?} and {:?}",
                    metric,
                    r.metric
                ));
            }
        }

        // Sort by mean score, best first; NaN means compare as equal.
        let mut ranked: Vec<(String, f64)> = results
            .iter()
            .map(|r| (r.model_id.clone(), r.mean))
            .collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Compare best vs. second-best to decide whether the lead is
        // statistically significant.
        let mut is_best_significant = false;
        if ranked.len() >= 2 {
            let best = results
                .iter()
                .find(|r| r.model_id == ranked[0].0)
                .expect("best model must be in results");
            let second = results
                .iter()
                .find(|r| r.model_id == ranked[1].0)
                .expect("second model must be in results");
            // A failed comparison (e.g. too few samples) leaves the flag
            // false rather than erroring the whole ranking.
            if let Ok(cmp) = self.compare(best, second) {
                is_best_significant = cmp.is_significant;
            }
        } else {
            is_best_significant = true;
        }

        let result: Vec<(String, f64, bool)> = ranked
            .into_iter()
            .enumerate()
            .map(|(i, (id, mean))| (id, mean, i == 0 && is_best_significant))
            .collect();

        Ok(result)
    }

    /// Two-sided paired Student's t-test on the per-pair differences,
    /// returning the p-value.
    ///
    /// # Errors
    /// Requires at least 2 paired samples.
    fn t_test_paired(&self, a: &[f64], b: &[f64]) -> Result<f64> {
        let n = a.len();
        if n < 2 {
            return Err(anyhow!("t-test requires at least 2 paired samples"));
        }

        let diffs: Vec<f64> = a.iter().zip(b.iter()).map(|(x, y)| x - y).collect();
        let mean_diff = stats_mean(&diffs);
        let std_diff = stats_std_dev(&diffs, mean_diff);

        // Zero variance: every pair differs by the same amount. p = 1 when
        // that common difference is zero (no effect), else 0 (a constant,
        // deterministic shift).
        if std_diff < 1e-15 {
            return Ok(if mean_diff.abs() < 1e-15 { 1.0 } else { 0.0 });
        }

        let se = std_diff / (n as f64).sqrt();
        let t_stat = mean_diff / se;
        let df = (n - 1) as f64;

        // Two-sided p-value from the upper tail of the t distribution.
        let p = 2.0 * (1.0 - t_distribution_cdf(t_stat.abs(), df));
        Ok(p.clamp(0.0, 1.0))
    }

    /// Wilcoxon signed-rank test using the large-sample normal
    /// approximation. Near-zero differences (|d| <= 1e-15) are dropped per
    /// the standard procedure; ties in |d| share average ranks.
    ///
    /// NOTE(review): the variance has no tie correction and no continuity
    /// correction is applied, so p-values are approximate for small or
    /// tie-heavy samples.
    ///
    /// # Errors
    /// Requires at least 2 paired samples.
    fn wilcoxon_signed_rank(&self, a: &[f64], b: &[f64]) -> Result<f64> {
        let n = a.len();
        if n < 2 {
            return Err(anyhow!(
                "Wilcoxon signed-rank test requires at least 2 paired samples"
            ));
        }

        // Per-pair differences, excluding (near-)zero ones.
        let diffs: Vec<f64> = a
            .iter()
            .zip(b.iter())
            .map(|(x, y)| x - y)
            .filter(|d| d.abs() > 1e-15)
            .collect();

        let m = diffs.len();
        if m == 0 {
            // All pairs identical: no evidence of any difference.
            return Ok(1.0);
        }

        // Rank the absolute differences, averaging ranks over ties.
        let abs_diffs: Vec<f64> = diffs.iter().map(|d| d.abs()).collect();
        let ranks = rank_with_ties(&abs_diffs);

        // Rank sums of the positive and negative differences.
        let w_plus: f64 = diffs
            .iter()
            .zip(ranks.iter())
            .filter(|(d, _)| **d > 0.0)
            .map(|(_, r)| r)
            .sum();

        let w_minus: f64 = diffs
            .iter()
            .zip(ranks.iter())
            .filter(|(d, _)| **d < 0.0)
            .map(|(_, r)| r)
            .sum();

        // Test statistic: the smaller of the two rank sums.
        let w_stat = w_plus.min(w_minus);

        // Mean and variance of W under the null hypothesis.
        let m_f = m as f64;
        let expected = m_f * (m_f + 1.0) / 4.0;
        let variance = m_f * (m_f + 1.0) * (2.0 * m_f + 1.0) / 24.0;

        if variance < 1e-15 {
            return Ok(1.0);
        }

        // Two-sided p-value via the normal approximation.
        let z = (w_stat - expected) / variance.sqrt();
        let p = 2.0 * standard_normal_cdf(-z.abs());
        Ok(p.clamp(0.0, 1.0))
    }

    /// Paired sign-flip permutation test: each permutation randomly swaps
    /// each (a_i, b_i) pair, and the p-value is the add-one-smoothed
    /// fraction of permutations whose |mean difference| is at least the
    /// observed one. Deterministic for a fixed `seed`.
    ///
    /// # Errors
    /// Requires `n_permutations > 0` and at least 2 paired samples.
    fn bootstrap_test(
        &self,
        a: &[f64],
        b: &[f64],
        n_permutations: usize,
        seed: u64,
    ) -> Result<f64> {
        if n_permutations == 0 {
            return Err(anyhow!("n_permutations must be > 0"));
        }
        let n = a.len();
        if n < 2 {
            return Err(anyhow!("Bootstrap test requires at least 2 paired samples"));
        }

        let observed_diff = (stats_mean(a) - stats_mean(b)).abs();

        let mut rng = BootstrapRng::new(seed);
        let mut extreme_count = 0usize;

        for _ in 0..n_permutations {
            // Randomly exchange each pair's labels to simulate the null
            // hypothesis of no difference between the models.
            let (perm_a, perm_b): (Vec<f64>, Vec<f64>) = a
                .iter()
                .zip(b.iter())
                .map(
                    |(&ai, &bi)| {
                        if rng.next_bool() {
                            (ai, bi)
                        } else {
                            (bi, ai)
                        }
                    },
                )
                .unzip();

            let perm_diff = (stats_mean(&perm_a) - stats_mean(&perm_b)).abs();
            if perm_diff >= observed_diff {
                extreme_count += 1;
            }
        }

        // Add-one smoothing keeps p strictly positive (the observed
        // labeling itself counts as one permutation).
        let p = (extreme_count as f64 + 1.0) / (n_permutations as f64 + 1.0);
        Ok(p.clamp(0.0, 1.0))
    }
}
490
491pub fn evaluate_link_prediction(
501 model_id: String,
502 embeddings: &[Vec<f64>],
503 positive_pairs: &[(usize, usize)],
504 negative_pairs: &[(usize, usize)],
505) -> Result<ModelEvalResult> {
506 if embeddings.is_empty() {
507 return Err(anyhow!("embeddings must not be empty"));
508 }
509 if positive_pairs.is_empty() {
510 return Err(anyhow!("positive_pairs must not be empty"));
511 }
512
513 let mut mrr_scores: Vec<f64> = Vec::with_capacity(positive_pairs.len());
515
516 for &(u, v) in positive_pairs {
517 let emb_u = match embeddings.get(u) {
518 Some(e) => e,
519 None => continue,
520 };
521 let emb_v = match embeddings.get(v) {
522 Some(e) => e,
523 None => continue,
524 };
525
526 let pos_score = cosine_similarity_slice(emb_u, emb_v);
527
528 let mut higher_count = 0usize;
530 for &(nu, nv) in negative_pairs {
531 let emb_nu = match embeddings.get(nu) {
532 Some(e) => e,
533 None => continue,
534 };
535 let emb_nv = match embeddings.get(nv) {
536 Some(e) => e,
537 None => continue,
538 };
539 let neg_score = cosine_similarity_slice(emb_nu, emb_nv);
540 if neg_score >= pos_score {
541 higher_count += 1;
542 }
543 }
544
545 let rank = (higher_count + 1) as f64;
546 mrr_scores.push(1.0 / rank);
547 }
548
549 if mrr_scores.is_empty() {
550 return Err(anyhow!("No valid positive pairs found in embeddings"));
551 }
552
553 ModelEvalResult::new(model_id, EmbedMetric::MeanReciprocalRank, mrr_scores)
554}
555
556pub fn evaluate_hits_at_k(
558 model_id: String,
559 embeddings: &[Vec<f64>],
560 positive_pairs: &[(usize, usize)],
561 negative_pairs: &[(usize, usize)],
562 k: usize,
563) -> Result<ModelEvalResult> {
564 if embeddings.is_empty() {
565 return Err(anyhow!("embeddings must not be empty"));
566 }
567 if positive_pairs.is_empty() {
568 return Err(anyhow!("positive_pairs must not be empty"));
569 }
570 if k == 0 {
571 return Err(anyhow!("k must be > 0"));
572 }
573
574 let mut hit_scores: Vec<f64> = Vec::with_capacity(positive_pairs.len());
575
576 for &(u, v) in positive_pairs {
577 let emb_u = match embeddings.get(u) {
578 Some(e) => e,
579 None => continue,
580 };
581 let emb_v = match embeddings.get(v) {
582 Some(e) => e,
583 None => continue,
584 };
585
586 let pos_score = cosine_similarity_slice(emb_u, emb_v);
587
588 let mut higher_count = 0usize;
589 for &(nu, nv) in negative_pairs {
590 let emb_nu = match embeddings.get(nu) {
591 Some(e) => e,
592 None => continue,
593 };
594 let emb_nv = match embeddings.get(nv) {
595 Some(e) => e,
596 None => continue,
597 };
598 let neg_score = cosine_similarity_slice(emb_nu, emb_nv);
599 if neg_score >= pos_score {
600 higher_count += 1;
601 }
602 }
603
604 let rank = higher_count + 1;
606 hit_scores.push(if rank <= k { 1.0 } else { 0.0 });
607 }
608
609 if hit_scores.is_empty() {
610 return Err(anyhow!("No valid positive pairs found in embeddings"));
611 }
612
613 ModelEvalResult::new(model_id, EmbedMetric::HitsAtK(k), hit_scores)
614}
615
/// Per-node silhouette scores of the embeddings under a cluster assignment.
///
/// For each node i: `a` is the mean distance to the other members of its
/// own cluster (0 for singletons), `b` is the smallest mean distance to any
/// other cluster, and the score is (b - a) / max(a, b), in [-1, 1]. The
/// per-node scores form the returned sample, so the caller gets summary
/// statistics for free.
///
/// # Errors
/// Fails with fewer than 2 nodes, a label/embedding length mismatch, or
/// fewer than 2 distinct clusters.
pub fn evaluate_silhouette(
    model_id: String,
    embeddings: &[Vec<f64>],
    cluster_labels: &[usize],
) -> Result<ModelEvalResult> {
    let n = embeddings.len();
    if n < 2 {
        return Err(anyhow!("Need at least 2 nodes for silhouette score"));
    }
    if cluster_labels.len() != n {
        return Err(anyhow!(
            "cluster_labels length {} != embeddings length {}",
            cluster_labels.len(),
            n
        ));
    }

    let unique_clusters: std::collections::HashSet<usize> =
        cluster_labels.iter().copied().collect();
    if unique_clusters.len() < 2 {
        return Err(anyhow!(
            "Need at least 2 distinct clusters for silhouette score"
        ));
    }

    let mut silhouette_scores: Vec<f64> = Vec::with_capacity(n);

    for i in 0..n {
        let my_cluster = cluster_labels[i];

        // The other members of node i's own cluster.
        let my_cluster_nodes: Vec<usize> = (0..n)
            .filter(|&j| j != i && cluster_labels[j] == my_cluster)
            .collect();

        // a: mean intra-cluster distance; 0 for a singleton cluster.
        let a = if my_cluster_nodes.is_empty() {
            0.0
        } else {
            my_cluster_nodes
                .iter()
                .map(|&j| euclidean_distance(&embeddings[i], &embeddings[j]))
                .sum::<f64>()
                / my_cluster_nodes.len() as f64
        };

        // b: mean distance to the nearest other cluster (INFINITY guards
        // keep empty candidates from ever winning the min-fold).
        let b = unique_clusters
            .iter()
            .filter(|&&c| c != my_cluster)
            .map(|&c| {
                let other_nodes: Vec<usize> = (0..n).filter(|&j| cluster_labels[j] == c).collect();
                if other_nodes.is_empty() {
                    f64::INFINITY
                } else {
                    other_nodes
                        .iter()
                        .map(|&j| euclidean_distance(&embeddings[i], &embeddings[j]))
                        .sum::<f64>()
                        / other_nodes.len() as f64
                }
            })
            .fold(f64::INFINITY, f64::min);

        // Silhouette coefficient, equivalent to (b - a) / max(a, b).
        let s = if a < b {
            1.0 - a / b
        } else if a > b {
            b / a - 1.0
        } else {
            0.0
        };

        silhouette_scores.push(s);
    }

    ModelEvalResult::new(model_id, EmbedMetric::SilhouetteScore, silhouette_scores)
}
698
/// Arithmetic mean of `v`; defined as 0.0 for an empty slice so callers
/// never divide by zero.
fn stats_mean(v: &[f64]) -> f64 {
    match v.len() {
        0 => 0.0,
        n => v.iter().sum::<f64>() / n as f64,
    }
}
710
/// Sample standard deviation (Bessel-corrected, n - 1 denominator) given a
/// precomputed mean; returns 0.0 when fewer than 2 samples are present.
fn stats_std_dev(v: &[f64], mean: f64) -> f64 {
    let n = v.len();
    if n < 2 {
        return 0.0;
    }
    let sum_sq: f64 = v
        .iter()
        .map(|x| {
            let d = x - mean;
            d * d
        })
        .sum();
    (sum_sq / (n - 1) as f64).sqrt()
}
720
721fn cohens_d(a: &[f64], b: &[f64]) -> f64 {
723 let mean_a = stats_mean(a);
724 let mean_b = stats_mean(b);
725 let std_a = stats_std_dev(a, mean_a);
726 let std_b = stats_std_dev(b, mean_b);
727
728 let n_a = a.len() as f64;
730 let n_b = b.len() as f64;
731 let pooled_var =
732 ((n_a - 1.0) * std_a * std_a + (n_b - 1.0) * std_b * std_b) / (n_a + n_b - 2.0);
733
734 if pooled_var < 1e-15 {
735 return 0.0;
736 }
737
738 (mean_a - mean_b) / pooled_var.sqrt()
739}
740
741fn t_distribution_cdf(t: f64, df: f64) -> f64 {
747 if df <= 0.0 {
748 return 0.5;
749 }
750 if t.abs() > 1e6 {
752 return if t >= 0.0 { 1.0 } else { 0.0 };
753 }
754 let x = df / (df + t * t);
757 let beta_inc = betai(df / 2.0, 0.5, x);
759 let cdf = 1.0 - 0.5 * beta_inc;
761 cdf.clamp(0.0, 1.0)
762}
763
/// Regularized incomplete beta function I_x(a, b), computed via the
/// continued-fraction expansion (Numerical Recipes' `betai`).
///
/// Defensively returns 0.0 for x outside [0, 1]; in this module callers
/// pass x = df / (df + t^2), which is always in range.
fn betai(a: f64, b: f64, x: f64) -> f64 {
    if !(0.0..=1.0).contains(&x) {
        return 0.0;
    }
    if x == 0.0 {
        return 0.0;
    }
    if x == 1.0 {
        return 1.0;
    }
    // Prefactor x^a (1-x)^b / B(a, b), computed in log space to avoid
    // overflow/underflow for large parameters.
    let bt =
        (log_gamma(a + b) - log_gamma(a) - log_gamma(b) + a * x.ln() + b * (1.0 - x).ln()).exp();

    // Use the continued fraction directly where it converges quickly, and
    // the symmetry I_x(a, b) = 1 - I_{1-x}(b, a) otherwise.
    if x < (a + 1.0) / (a + b + 2.0) {
        bt * betacf(a, b, x) / a
    } else {
        1.0 - bt * betacf(b, a, 1.0 - x) / b
    }
}
789
/// Continued-fraction kernel for the incomplete beta function, evaluated
/// with the modified Lentz algorithm (Numerical Recipes' `betacf`).
///
/// `FPMIN` guards keep intermediate values away from zero so divisions
/// never blow up; iteration stops once the running product converges to
/// within `EPS` or `MAX_ITER` terms have been used.
fn betacf(a: f64, b: f64, x: f64) -> f64 {
    const MAX_ITER: usize = 200;
    const EPS: f64 = 3.0e-10;
    const FPMIN: f64 = 1.0e-300;

    let qab = a + b;
    let qap = a + 1.0;
    let qam = a - 1.0;
    // Lentz's method state: c and d track the ratios of successive
    // numerators/denominators, h the accumulated fraction value.
    let mut c = 1.0f64;
    let mut d = 1.0 - qab * x / qap;
    if d.abs() < FPMIN {
        d = FPMIN;
    }
    d = 1.0 / d;
    let mut h = d;

    for m in 1..=MAX_ITER {
        let m2 = 2 * m;
        // Even step of the continued fraction.
        let aa = m as f64 * (b - m as f64) * x / ((qam + m2 as f64) * (a + m2 as f64));
        d = 1.0 + aa * d;
        if d.abs() < FPMIN {
            d = FPMIN;
        }
        c = 1.0 + aa / c;
        if c.abs() < FPMIN {
            c = FPMIN;
        }
        d = 1.0 / d;
        h *= d * c;
        // Odd step of the continued fraction.
        let aa = -(a + m as f64) * (qab + m as f64) * x / ((a + m2 as f64) * (qap + m2 as f64));
        d = 1.0 + aa * d;
        if d.abs() < FPMIN {
            d = FPMIN;
        }
        c = 1.0 + aa / c;
        if c.abs() < FPMIN {
            c = FPMIN;
        }
        d = 1.0 / d;
        let delta = d * c;
        h *= delta;
        // Converged once the multiplicative update is ~1.
        if (delta - 1.0).abs() < EPS {
            break;
        }
    }
    h
}
841
/// Natural log of the gamma function, ln Γ(z), via the Lanczos
/// approximation (g = 7, 9 coefficients).
///
/// Returns +INFINITY for z <= 0 (Γ has poles at non-positive integers; this
/// module only calls it with positive arguments).
fn log_gamma(z: f64) -> f64 {
    if z <= 0.0 {
        return f64::INFINITY;
    }
    const G: f64 = 7.0;
    // Lanczos coefficients for g = 7.
    const C: [f64; 9] = [
        0.999_999_999_999_809_9,
        676.5203681218851,
        -1259.1392167224028,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507343278686905,
        -0.13857109526572012,
        9.984_369_578_019_572e-6,
        1.5056327351493116e-7,
    ];
    if z < 0.5 {
        // Reflection formula: Γ(z) Γ(1-z) = π / sin(πz). The abs() keeps
        // the log argument positive (sin(πz) > 0 for 0 < z < 1 anyway).
        std::f64::consts::PI.ln() - (std::f64::consts::PI * z).sin().abs().ln() - log_gamma(1.0 - z)
    } else {
        let z = z - 1.0;
        // Lanczos partial-fraction sum.
        let mut x = C[0];
        for (i, &c) in C[1..].iter().enumerate() {
            x += c / (z + i as f64 + 1.0);
        }
        let t = z + G + 0.5;
        // ln Γ(z+1) = ln √(2π) + ln x + (z + 1/2) ln t - t.
        (std::f64::consts::TAU.sqrt()).ln() + x.ln() + (z + 0.5) * t.ln() - t
    }
}
874
875fn t_critical_value(df: usize, tail_prob: f64) -> f64 {
880 if df >= 30 {
882 return normal_quantile(1.0 - tail_prob);
883 }
884 match df {
887 1 => 12.706,
888 2 => 4.303,
889 3 => 3.182,
890 4 => 2.776,
891 5 => 2.571,
892 6 => 2.447,
893 7 => 2.365,
894 8 => 2.306,
895 9 => 2.262,
896 10 => 2.228,
897 11 => 2.201,
898 12 => 2.179,
899 13 => 2.160,
900 14 => 2.145,
901 15 => 2.131,
902 16 => 2.120,
903 17 => 2.110,
904 18 => 2.101,
905 19 => 2.093,
906 20 => 2.086,
907 21..=25 => 2.064,
908 26..=29 => 2.048,
909 _ => normal_quantile(1.0 - tail_prob),
910 }
911}
912
/// Standard normal CDF via the Abramowitz & Stegun 26.2.17 polynomial
/// approximation (absolute error below ~7.5e-8).
fn standard_normal_cdf(z: f64) -> f64 {
    // Polynomial coefficients b1..b5 of A&S 26.2.17.
    const B: [f64; 5] = [
        0.3193815306,
        -0.3565637813,
        1.7814779372,
        -1.8212559978,
        1.3302744929,
    ];
    let t = 1.0 / (1.0 + 0.2316419 * z.abs());
    // Horner form of b1*t + b2*t^2 + ... + b5*t^5.
    let poly = t * (B[0] + t * (B[1] + t * (B[2] + t * (B[3] + t * B[4]))));
    // Standard normal density at z (0.3989... = 1/sqrt(2*pi)).
    let pdf = 0.3989422820 * (-z * z / 2.0).exp();
    let upper_tail = pdf * poly;
    if z >= 0.0 {
        1.0 - upper_tail
    } else {
        // Symmetry: Phi(-z) = 1 - Phi(z).
        upper_tail
    }
}
926
/// Inverse CDF (quantile function) of the standard normal distribution.
///
/// Uses Peter Acklam's rational approximation (relative error below
/// ~1.2e-9 across the whole range).
///
/// BUG FIX: the previous implementation used truncated/garbled polynomial
/// coefficients (including stray `+ constant + 1.0` terms in the
/// denominators) and returned badly wrong values — e.g. ~1.63 instead of
/// 1.96 for p = 0.975 — which silently shrank every large-df confidence
/// interval produced by `t_critical_value`.
///
/// Returns -INFINITY for p <= 0 and +INFINITY for p >= 1.
fn normal_quantile(p: f64) -> f64 {
    if p <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if p >= 1.0 {
        return f64::INFINITY;
    }

    // Acklam's coefficients: central region rational approximation.
    const A: [f64; 6] = [
        -3.969683028665376e+01,
        2.209460984245205e+02,
        -2.759285104469687e+02,
        1.383577518672690e+02,
        -3.066479806614716e+01,
        2.506628277459239e+00,
    ];
    const B: [f64; 5] = [
        -5.447609879822406e+01,
        1.615858368580409e+02,
        -1.556989798598866e+02,
        6.680131188771972e+01,
        -1.328068155288572e+01,
    ];
    // Tail region rational approximation.
    const C: [f64; 6] = [
        -7.784894002430293e-03,
        -3.223964580411365e-01,
        -2.400758277161838e+00,
        -2.549732539343734e+00,
        4.374664141464968e+00,
        2.938163982698783e+00,
    ];
    const D: [f64; 4] = [
        7.784695709041462e-03,
        3.224671290700398e-01,
        2.445134137142996e+00,
        3.754408661907416e+00,
    ];
    // Break-point between the tail and central approximations.
    const P_LOW: f64 = 0.02425;

    if p < P_LOW {
        // Lower tail.
        let q = (-2.0 * p.ln()).sqrt();
        (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
    } else if p <= 1.0 - P_LOW {
        // Central region.
        let q = p - 0.5;
        let r = q * q;
        (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
            / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
    } else {
        // Upper tail: negated mirror image of the lower tail.
        let q = (-2.0 * (1.0 - p).ln()).sqrt();
        -((((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0))
    }
}
962
/// Assigns 1-based ranks to `values`, with tied entries (equal within
/// 1e-15) sharing the average of the ranks they span. Output is indexed
/// like the input, e.g. [3.0, 1.0, 1.0, 2.0] -> [4.0, 1.5, 1.5, 3.0].
fn rank_with_ties(values: &[f64]) -> Vec<f64> {
    let n = values.len();
    // Sort indices by value (stable, NaN compares as equal).
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by(|&a, &b| {
        values[a]
            .partial_cmp(&values[b])
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let mut ranks = vec![0.0f64; n];
    let mut start = 0;
    while start < n {
        // Grow the tie group: everything equal (within tolerance) to the
        // group's first element.
        let mut end = start + 1;
        while end < n && (values[order[end]] - values[order[start]]).abs() < 1e-15 {
            end += 1;
        }
        // Sorted positions start..end carry 1-based ranks start+1..end,
        // whose average is (start + end + 1) / 2.
        let shared_rank = (start + end + 1) as f64 / 2.0;
        for &original_idx in &order[start..end] {
            ranks[original_idx] = shared_rank;
        }
        start = end;
    }
    ranks
}
985
/// Cosine similarity of two equal-length vectors, clamped to [-1, 1].
/// Returns 0.0 when either vector has (near-)zero norm so degenerate
/// embeddings never dominate a ranking.
fn cosine_similarity_slice(a: &[f64], b: &[f64]) -> f64 {
    let mut dot = 0.0f64;
    let mut sq_a = 0.0f64;
    let mut sq_b = 0.0f64;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a < 1e-12 || norm_b < 1e-12 {
        return 0.0;
    }
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
996
/// Euclidean (L2) distance between two equal-length vectors.
fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
    let sum_sq: f64 = a
        .iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let d = x - y;
            d * d
        })
        .sum();
    sum_sq.sqrt()
}
1005
1006struct BootstrapRng {
1008 state: u64,
1009}
1010
1011impl BootstrapRng {
1012 fn new(seed: u64) -> Self {
1013 Self {
1014 state: seed.wrapping_add(1),
1015 }
1016 }
1017
1018 fn next_u64(&mut self) -> u64 {
1019 self.state = self
1020 .state
1021 .wrapping_mul(6364136223846793005)
1022 .wrapping_add(1442695040888963407);
1023 self.state
1024 }
1025
1026 fn next_bool(&mut self) -> bool {
1027 self.next_u64() & 1 == 0
1028 }
1029}
1030
#[cfg(test)]
mod tests {
    use super::*;

    // Convenience: builds a ModelEvalResult over MRR from raw scores.
    fn make_scores(values: &[f64], model_id: &str) -> ModelEvalResult {
        ModelEvalResult::new(
            model_id.to_string(),
            EmbedMetric::MeanReciprocalRank,
            values.to_vec(),
        )
        .expect("scores should be valid")
    }

    // Summary statistics are derived correctly on construction.
    #[test]
    fn test_model_eval_result_basic() {
        let scores = vec![0.8, 0.7, 0.9, 0.6, 0.85];
        let result = ModelEvalResult::new(
            "ModelA".to_string(),
            EmbedMetric::MeanReciprocalRank,
            scores.clone(),
        )
        .expect("should construct");

        assert_eq!(result.sample_count, 5);
        assert!((result.mean - 0.77).abs() < 0.01);
        assert!(result.std_dev > 0.0);
    }

    // Empty score vectors are rejected at construction time.
    #[test]
    fn test_model_eval_result_empty_scores() {
        let result =
            ModelEvalResult::new("Model".to_string(), EmbedMetric::SilhouetteScore, vec![]);
        assert!(result.is_err());
    }

    // The CI must straddle the mean and stay in a plausible range.
    #[test]
    fn test_confidence_interval() {
        let scores: Vec<f64> = (0..50).map(|i| i as f64 * 0.02).collect();
        let result =
            ModelEvalResult::new("M".to_string(), EmbedMetric::ClassificationAccuracy, scores)
                .expect("should construct");
        let (lo, hi) = result.confidence_interval(0.05);
        assert!(lo < result.mean);
        assert!(hi > result.mean);
        assert!(lo >= 0.0);
        assert!(hi <= 1.0 + 1e-6);
    }

    // Widely-separated score distributions must yield a significant
    // result, a clear winner, and a large effect size.
    #[test]
    fn test_ab_test_clearly_different() {
        let a_scores: Vec<f64> = vec![0.85, 0.90, 0.88, 0.92, 0.87, 0.91, 0.89, 0.93, 0.86, 0.94];
        let b_scores: Vec<f64> = vec![0.10, 0.12, 0.09, 0.11, 0.13, 0.08, 0.10, 0.12, 0.11, 0.09];

        let result_a = make_scores(&a_scores, "ModelA");
        let result_b = make_scores(&b_scores, "ModelB");

        let runner = AbTestRunner::new();
        let ab_result = runner
            .compare(&result_a, &result_b)
            .expect("comparison should succeed");

        assert!(ab_result.is_significant, "should be significant");
        assert_eq!(ab_result.winner, Some("ModelA".to_string()));
        assert!(ab_result.p_value < 0.001);
        assert!(ab_result.effect_size.abs() > 1.0);
    }

    // Nearly-identical distributions must not be declared different.
    #[test]
    fn test_ab_test_similar_models() {
        let a_scores: Vec<f64> = vec![0.500, 0.501, 0.499, 0.502, 0.498, 0.501, 0.500, 0.499];
        let b_scores: Vec<f64> = vec![0.499, 0.500, 0.501, 0.500, 0.501, 0.499, 0.500, 0.501];

        let result_a = make_scores(&a_scores, "ModelA");
        let result_b = make_scores(&b_scores, "ModelB");

        let runner = AbTestRunner::new();
        let ab_result = runner
            .compare(&result_a, &result_b)
            .expect("comparison should succeed");

        assert!(
            !ab_result.is_significant,
            "should not be significant for similar scores"
        );
        assert!(ab_result.winner.is_none());
        assert!(ab_result.p_value > 0.05);
    }

    // Comparing results computed under different metrics is an error.
    #[test]
    fn test_ab_test_different_metrics_error() {
        let result_a = ModelEvalResult::new(
            "A".to_string(),
            EmbedMetric::MeanReciprocalRank,
            vec![0.5; 10],
        )
        .expect("ok");
        let result_b =
            ModelEvalResult::new("B".to_string(), EmbedMetric::SilhouetteScore, vec![0.5; 10])
                .expect("ok");

        let runner = AbTestRunner::new();
        assert!(runner.compare(&result_a, &result_b).is_err());
    }

    // Models must come back ordered by descending mean score.
    #[test]
    fn test_rank_models() {
        let scores_a: Vec<f64> = vec![0.8; 20];
        let scores_b: Vec<f64> = vec![0.6; 20];
        let scores_c: Vec<f64> = vec![0.7; 20];

        let results = vec![
            ModelEvalResult::new("B".to_string(), EmbedMetric::HitsAtK(10), scores_b).expect("ok"),
            ModelEvalResult::new("A".to_string(), EmbedMetric::HitsAtK(10), scores_a).expect("ok"),
            ModelEvalResult::new("C".to_string(), EmbedMetric::HitsAtK(10), scores_c).expect("ok"),
        ];

        let runner = AbTestRunner::new();
        let ranking = runner
            .rank_models(&results)
            .expect("ranking should succeed");

        assert_eq!(ranking.len(), 3);
        assert_eq!(ranking[0].0, "A");
        assert_eq!(ranking[1].0, "C");
        assert_eq!(ranking[2].0, "B");
    }

    // The permutation test must also flag clearly-different models.
    #[test]
    fn test_bootstrap_test() {
        let a_scores: Vec<f64> = vec![0.85, 0.90, 0.88, 0.92, 0.87, 0.91, 0.89, 0.93, 0.86, 0.94];
        let b_scores: Vec<f64> = vec![0.10, 0.12, 0.09, 0.11, 0.13, 0.08, 0.10, 0.12, 0.11, 0.09];

        let result_a = make_scores(&a_scores, "A");
        let result_b = make_scores(&b_scores, "B");

        let runner = AbTestRunner::new().with_test(StatTest::Bootstrap {
            n_permutations: 999,
            seed: 42,
        });
        let ab_result = runner
            .compare(&result_a, &result_b)
            .expect("bootstrap should succeed");

        assert!(ab_result.is_significant);
        assert!(ab_result.p_value < 0.05);
    }

    // Same sanity check for the Wilcoxon signed-rank test.
    #[test]
    fn test_wilcoxon_test() {
        let a_scores: Vec<f64> = vec![0.9, 0.85, 0.92, 0.88, 0.91, 0.87, 0.93, 0.86, 0.94, 0.89];
        let b_scores: Vec<f64> = vec![0.1, 0.12, 0.09, 0.11, 0.13, 0.08, 0.1, 0.12, 0.11, 0.09];

        let result_a = make_scores(&a_scores, "A");
        let result_b = make_scores(&b_scores, "B");

        let runner = AbTestRunner::new().with_test(StatTest::WilcoxonSignedRank);
        let ab_result = runner
            .compare(&result_a, &result_b)
            .expect("wilcoxon should succeed");

        assert!(ab_result.is_significant);
    }

    // The positive pair (0, 1) is more similar than both negatives, so its
    // reciprocal rank must be exactly 1.0.
    #[test]
    fn test_evaluate_link_prediction() {
        let embeddings = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.9, 0.1, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
        ];
        let positive_pairs = vec![(0, 1)];
        let negative_pairs = vec![(0, 2), (0, 3)];

        let result = evaluate_link_prediction(
            "test_model".to_string(),
            &embeddings,
            &positive_pairs,
            &negative_pairs,
        )
        .expect("link prediction eval should succeed");

        assert_eq!(result.model_id, "test_model");
        assert_eq!(result.metric, EmbedMetric::MeanReciprocalRank);
        assert!(!result.scores.is_empty());
        assert!((result.scores[0] - 1.0).abs() < 1e-10, "MRR should be 1.0");
    }

    // The positive pair ranks first against the lone negative, so Hits@1
    // must score 1.0.
    #[test]
    fn test_evaluate_hits_at_k() {
        let embeddings = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.9, 0.1, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
        ];
        let positive_pairs = vec![(0, 1)];
        let negative_pairs = vec![(0, 2)];

        let result = evaluate_hits_at_k(
            "model".to_string(),
            &embeddings,
            &positive_pairs,
            &negative_pairs,
            1,
        )
        .expect("hits@k eval should succeed");

        assert_eq!(result.metric, EmbedMetric::HitsAtK(1));
        assert!(!result.scores.is_empty());
        assert!((result.scores[0] - 1.0).abs() < 1e-10);
    }

    // Two tight, well-separated clusters must get a positive mean
    // silhouette score.
    #[test]
    fn test_evaluate_silhouette() {
        let embeddings = vec![
            vec![1.0, 0.0],
            vec![0.9, 0.1],
            vec![0.95, 0.05],
            vec![0.0, 1.0],
            vec![0.1, 0.9],
            vec![0.05, 0.95],
        ];
        let labels = vec![0, 0, 0, 1, 1, 1];

        let result = evaluate_silhouette("model".to_string(), &embeddings, &labels)
            .expect("silhouette eval should succeed");

        assert_eq!(result.metric, EmbedMetric::SilhouetteScore);
        assert_eq!(result.sample_count, 6);
        assert!(
            result.mean > 0.0,
            "mean silhouette should be positive for well-separated clusters"
        );
    }

    // Identical samples give d = 0; a large mean shift gives |d| > 1.
    #[test]
    fn test_cohens_d_interpretation() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        assert!((cohens_d(&a, &b)).abs() < 1e-10);

        let c = vec![10.0, 11.0, 12.0, 13.0, 14.0];
        let d = cohens_d(&c, &a);
        assert!(
            d.abs() > 1.0,
            "Large difference should give large Cohen's d"
        );
    }

    // Tied values share the average of the ranks they span.
    #[test]
    fn test_rank_with_ties() {
        let values = vec![3.0, 1.0, 1.0, 2.0];
        let ranks = rank_with_ties(&values);
        assert_eq!(ranks.len(), 4);
        assert!((ranks[1] - 1.5).abs() < 1e-10);
        assert!((ranks[2] - 1.5).abs() < 1e-10);
        assert!((ranks[3] - 3.0).abs() < 1e-10);
        assert!((ranks[0] - 4.0).abs() < 1e-10);
    }

    // k = 0 is rejected before any evaluation happens.
    #[test]
    fn test_hits_at_k_zero_error() {
        let result = evaluate_hits_at_k(
            "m".to_string(),
            &[vec![1.0f64]],
            &[(0, 0)],
            &[],
            0,
        );
        assert!(result.is_err());
    }

    // Display labels used in summaries and logs.
    #[test]
    fn test_embed_metric_display() {
        assert_eq!(EmbedMetric::MeanReciprocalRank.to_string(), "MRR");
        assert_eq!(EmbedMetric::HitsAtK(10).to_string(), "Hits@10");
        assert_eq!(
            EmbedMetric::Custom("MyMetric".to_string()).to_string(),
            "MyMetric"
        );
    }

    // The textual summary must mention both models and the verdict.
    #[test]
    fn test_ab_test_summary() {
        let a_scores = vec![0.8f64; 10];
        let b_scores = vec![0.2f64; 10];
        let result_a = make_scores(&a_scores, "ModelA");
        let result_b = make_scores(&b_scores, "ModelB");

        let runner = AbTestRunner::new();
        let ab_result = runner.compare(&result_a, &result_b).expect("ok");
        let summary = ab_result.summary();
        assert!(summary.contains("ModelA"));
        assert!(summary.contains("ModelB"));
        assert!(summary.contains("significant"));
    }
}