Skip to main content

oxirs_embed/evaluation/
ab_test.rs

1//! A/B Testing Framework for Embedding Model Comparison
2//!
3//! Provides statistical significance testing to determine which embedding
4//! model performs better on a given evaluation metric.
5//!
6//! Supported statistical tests:
7//! - Student's paired t-test (parametric, assumes normality)
8//! - Bootstrap permutation test (non-parametric)
9//! - Wilcoxon signed-rank test (non-parametric, distribution-free)
10//!
11//! Effect size is measured using Cohen's d for interpretability.
12
13use anyhow::{anyhow, Result};
14use serde::{Deserialize, Serialize};
15
16/// Metric used to compare embedding models
17#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
18pub enum EmbedMetric {
19    /// Mean Reciprocal Rank on link prediction
20    MeanReciprocalRank,
21    /// Hits@K on knowledge graph completion
22    HitsAtK(usize),
23    /// Average cosine similarity of known-similar pairs
24    SimilarityScore,
25    /// Clustering quality via silhouette coefficient
26    SilhouetteScore,
27    /// Linear classification accuracy on node labels
28    ClassificationAccuracy,
29    /// Custom named metric
30    Custom(String),
31}
32
33impl std::fmt::Display for EmbedMetric {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        match self {
36            EmbedMetric::MeanReciprocalRank => write!(f, "MRR"),
37            EmbedMetric::HitsAtK(k) => write!(f, "Hits@{}", k),
38            EmbedMetric::SimilarityScore => write!(f, "SimilarityScore"),
39            EmbedMetric::SilhouetteScore => write!(f, "SilhouetteScore"),
40            EmbedMetric::ClassificationAccuracy => write!(f, "ClassificationAccuracy"),
41            EmbedMetric::Custom(name) => write!(f, "{}", name),
42        }
43    }
44}
45
46/// Evaluation result for a single model on a given metric
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct ModelEvalResult {
49    /// Identifier for the model being evaluated
50    pub model_id: String,
51    /// Metric being measured
52    pub metric: EmbedMetric,
53    /// Per-sample scores (e.g., per-query MRR, per-entity accuracy)
54    pub scores: Vec<f64>,
55    /// Arithmetic mean of scores
56    pub mean: f64,
57    /// Standard deviation of scores
58    pub std_dev: f64,
59    /// Number of samples
60    pub sample_count: usize,
61}
62
63impl ModelEvalResult {
64    /// Create a new evaluation result, computing statistics automatically
65    pub fn new(model_id: String, metric: EmbedMetric, scores: Vec<f64>) -> Result<Self> {
66        if scores.is_empty() {
67            return Err(anyhow!("scores must not be empty for model '{}'", model_id));
68        }
69
70        let mean = stats_mean(&scores);
71        let std_dev = stats_std_dev(&scores, mean);
72        let sample_count = scores.len();
73
74        Ok(Self {
75            model_id,
76            metric,
77            scores,
78            mean,
79            std_dev,
80            sample_count,
81        })
82    }
83
84    /// Compute a two-sided confidence interval for the mean.
85    ///
86    /// Uses the t-distribution approximation.
87    /// `alpha` is the significance level (e.g., 0.05 for 95% CI).
88    ///
89    /// Returns (lower_bound, upper_bound).
90    pub fn confidence_interval(&self, alpha: f64) -> (f64, f64) {
91        let n = self.sample_count as f64;
92        if n < 2.0 {
93            return (self.mean, self.mean);
94        }
95        let se = self.std_dev / n.sqrt();
96        // t critical value approximation using normal distribution for large n,
97        // or a simple lookup for common alpha values
98        let t_crit = t_critical_value(n as usize - 1, alpha / 2.0);
99        let margin = t_crit * se;
100        (self.mean - margin, self.mean + margin)
101    }
102}
103
104/// Statistical test to use for comparing two models
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub enum StatTest {
107    /// Student's paired t-test (assumes paired samples from same test set)
108    TTest,
109    /// Wilcoxon signed-rank test (non-parametric alternative to paired t-test)
110    WilcoxonSignedRank,
111    /// Bootstrap permutation test (model-free, most general)
112    Bootstrap {
113        /// Number of permutation iterations
114        n_permutations: usize,
115        /// Random seed for reproducibility
116        seed: u64,
117    },
118}
119
120/// Result of an A/B test comparing two embedding models
121#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct AbTestResult {
123    /// Identifier of model A
124    pub model_a: String,
125    /// Identifier of model B
126    pub model_b: String,
127    /// Metric being compared
128    pub metric: EmbedMetric,
129    /// p-value from the statistical test
130    pub p_value: f64,
131    /// Effect size (Cohen's d): magnitude of the difference
132    pub effect_size: f64,
133    /// Whether the result is statistically significant (p_value < alpha)
134    pub is_significant: bool,
135    /// Significance level used
136    pub alpha: f64,
137    /// The winning model (None if result is not significant)
138    pub winner: Option<String>,
139    /// Mean score for model A
140    pub mean_a: f64,
141    /// Mean score for model B
142    pub mean_b: f64,
143    /// 95% confidence interval for model A mean
144    pub ci_a: (f64, f64),
145    /// 95% confidence interval for model B mean
146    pub ci_b: (f64, f64),
147    /// Which statistical test was used
148    pub test_used: StatTest,
149}
150
151impl AbTestResult {
152    /// Returns a human-readable summary of the test result
153    pub fn summary(&self) -> String {
154        let sig_str = if self.is_significant {
155            "significant"
156        } else {
157            "not significant"
158        };
159        let winner_str = match &self.winner {
160            Some(w) => format!("Winner: {}", w),
161            None => "No clear winner".to_string(),
162        };
163        format!(
164            "A/B Test [{}] ({}) -- {} vs {}: p={:.4}, effect={:.3} ({}). {}",
165            self.metric,
166            self.test_used.name(),
167            self.model_a,
168            self.model_b,
169            self.p_value,
170            self.effect_size,
171            sig_str,
172            winner_str
173        )
174    }
175}
176
177impl StatTest {
178    fn name(&self) -> &'static str {
179        match self {
180            StatTest::TTest => "paired t-test",
181            StatTest::WilcoxonSignedRank => "Wilcoxon signed-rank",
182            StatTest::Bootstrap { .. } => "bootstrap permutation",
183        }
184    }
185}
186
187/// Runner for A/B testing experiments
188pub struct AbTestRunner {
189    /// Significance level (default: 0.05)
190    alpha: f64,
191    /// Statistical test to use
192    test: StatTest,
193}
194
195impl Default for AbTestRunner {
196    fn default() -> Self {
197        Self {
198            alpha: 0.05,
199            test: StatTest::TTest,
200        }
201    }
202}
203
204impl AbTestRunner {
205    /// Create a new A/B test runner with default settings (t-test, alpha=0.05)
206    pub fn new() -> Self {
207        Self::default()
208    }
209
210    /// Set the significance level (e.g., 0.05 for 95% confidence)
211    pub fn with_alpha(mut self, alpha: f64) -> Self {
212        self.alpha = alpha;
213        self
214    }
215
216    /// Set the statistical test to use
217    pub fn with_test(mut self, test: StatTest) -> Self {
218        self.test = test;
219        self
220    }
221
222    /// Compare two models on the same evaluation metric.
223    ///
224    /// The scores must be paired (same test samples evaluated by both models).
225    /// If scores have different lengths, only the common prefix is used.
226    pub fn compare(
227        &self,
228        result_a: &ModelEvalResult,
229        result_b: &ModelEvalResult,
230    ) -> Result<AbTestResult> {
231        if result_a.metric != result_b.metric {
232            return Err(anyhow!(
233                "Cannot compare models on different metrics: {:?} vs {:?}",
234                result_a.metric,
235                result_b.metric
236            ));
237        }
238        if result_a.scores.is_empty() || result_b.scores.is_empty() {
239            return Err(anyhow!("Both models must have non-empty scores"));
240        }
241
242        // Use the minimum length for paired comparison
243        let n = result_a.scores.len().min(result_b.scores.len());
244        let scores_a = &result_a.scores[..n];
245        let scores_b = &result_b.scores[..n];
246
247        let p_value = match &self.test {
248            StatTest::TTest => self.t_test_paired(scores_a, scores_b)?,
249            StatTest::WilcoxonSignedRank => self.wilcoxon_signed_rank(scores_a, scores_b)?,
250            StatTest::Bootstrap {
251                n_permutations,
252                seed,
253            } => self.bootstrap_test(scores_a, scores_b, *n_permutations, *seed)?,
254        };
255
256        let effect_size = cohens_d(scores_a, scores_b);
257        let is_significant = p_value < self.alpha;
258
259        let mean_a = stats_mean(scores_a);
260        let mean_b = stats_mean(scores_b);
261
262        // Determine winner (higher mean is better for most metrics)
263        let winner = if is_significant {
264            if mean_a >= mean_b {
265                Some(result_a.model_id.clone())
266            } else {
267                Some(result_b.model_id.clone())
268            }
269        } else {
270            None
271        };
272
273        let ci_a = result_a.confidence_interval(self.alpha);
274        let ci_b = result_b.confidence_interval(self.alpha);
275
276        Ok(AbTestResult {
277            model_a: result_a.model_id.clone(),
278            model_b: result_b.model_id.clone(),
279            metric: result_a.metric.clone(),
280            p_value,
281            effect_size,
282            is_significant,
283            alpha: self.alpha,
284            winner,
285            mean_a,
286            mean_b,
287            ci_a,
288            ci_b,
289            test_used: self.test.clone(),
290        })
291    }
292
293    /// Compare multiple models and produce a ranking by mean score.
294    ///
295    /// Returns a list of (model_id, mean_score, is_significantly_best) tuples,
296    /// sorted by mean score descending.
297    pub fn rank_models(&self, results: &[ModelEvalResult]) -> Result<Vec<(String, f64, bool)>> {
298        if results.is_empty() {
299            return Err(anyhow!("Need at least one model to rank"));
300        }
301
302        // Verify all metrics are the same
303        let metric = &results[0].metric;
304        for r in results.iter().skip(1) {
305            if &r.metric != metric {
306                return Err(anyhow!(
307                    "All models must use the same metric for ranking, found {:?} and {:?}",
308                    metric,
309                    r.metric
310                ));
311            }
312        }
313
314        // Sort by mean score descending
315        let mut ranked: Vec<(String, f64)> = results
316            .iter()
317            .map(|r| (r.model_id.clone(), r.mean))
318            .collect();
319        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
320
321        // Determine if best model is significantly better than runner-up
322        let mut is_best_significant = false;
323        if ranked.len() >= 2 {
324            let best = results
325                .iter()
326                .find(|r| r.model_id == ranked[0].0)
327                .expect("best model must be in results");
328            let second = results
329                .iter()
330                .find(|r| r.model_id == ranked[1].0)
331                .expect("second model must be in results");
332            if let Ok(cmp) = self.compare(best, second) {
333                is_best_significant = cmp.is_significant;
334            }
335        } else {
336            is_best_significant = true; // Only one model, trivially best
337        }
338
339        let result: Vec<(String, f64, bool)> = ranked
340            .into_iter()
341            .enumerate()
342            .map(|(i, (id, mean))| (id, mean, i == 0 && is_best_significant))
343            .collect();
344
345        Ok(result)
346    }
347
348    /// Paired t-test: tests whether mean difference is significantly != 0.
349    ///
350    /// H0: mean(a - b) = 0
351    /// Returns two-sided p-value.
352    fn t_test_paired(&self, a: &[f64], b: &[f64]) -> Result<f64> {
353        let n = a.len();
354        if n < 2 {
355            return Err(anyhow!("t-test requires at least 2 paired samples"));
356        }
357
358        // Compute differences
359        let diffs: Vec<f64> = a.iter().zip(b.iter()).map(|(x, y)| x - y).collect();
360        let mean_diff = stats_mean(&diffs);
361        let std_diff = stats_std_dev(&diffs, mean_diff);
362
363        if std_diff < 1e-15 {
364            // No variation in differences - perfectly tied or identical
365            return Ok(if mean_diff.abs() < 1e-15 { 1.0 } else { 0.0 });
366        }
367
368        let se = std_diff / (n as f64).sqrt();
369        let t_stat = mean_diff / se;
370        let df = (n - 1) as f64;
371
372        // Two-sided p-value: 2 * P(T > |t_stat|)
373        let p = 2.0 * (1.0 - t_distribution_cdf(t_stat.abs(), df));
374        Ok(p.clamp(0.0, 1.0))
375    }
376
377    /// Wilcoxon signed-rank test (non-parametric paired comparison).
378    ///
379    /// Ranks the absolute differences and tests whether positive and
380    /// negative ranks are symmetrically distributed.
381    fn wilcoxon_signed_rank(&self, a: &[f64], b: &[f64]) -> Result<f64> {
382        let n = a.len();
383        if n < 2 {
384            return Err(anyhow!(
385                "Wilcoxon signed-rank test requires at least 2 paired samples"
386            ));
387        }
388
389        // Compute differences, ignoring zeros
390        let diffs: Vec<f64> = a
391            .iter()
392            .zip(b.iter())
393            .map(|(x, y)| x - y)
394            .filter(|d| d.abs() > 1e-15)
395            .collect();
396
397        let m = diffs.len();
398        if m == 0 {
399            return Ok(1.0); // All tied
400        }
401
402        // Rank the absolute differences (with ties averaged)
403        let abs_diffs: Vec<f64> = diffs.iter().map(|d| d.abs()).collect();
404        let ranks = rank_with_ties(&abs_diffs);
405
406        // Sum of positive and negative ranks
407        let w_plus: f64 = diffs
408            .iter()
409            .zip(ranks.iter())
410            .filter(|(d, _)| **d > 0.0)
411            .map(|(_, r)| r)
412            .sum();
413
414        let w_minus: f64 = diffs
415            .iter()
416            .zip(ranks.iter())
417            .filter(|(d, _)| **d < 0.0)
418            .map(|(_, r)| r)
419            .sum();
420
421        let w_stat = w_plus.min(w_minus);
422
423        // Normal approximation for p-value (valid for m >= 10)
424        let m_f = m as f64;
425        let expected = m_f * (m_f + 1.0) / 4.0;
426        let variance = m_f * (m_f + 1.0) * (2.0 * m_f + 1.0) / 24.0;
427
428        if variance < 1e-15 {
429            return Ok(1.0);
430        }
431
432        let z = (w_stat - expected) / variance.sqrt();
433        // Two-sided p-value using normal approximation
434        let p = 2.0 * standard_normal_cdf(-z.abs());
435        Ok(p.clamp(0.0, 1.0))
436    }
437
438    /// Bootstrap permutation test: estimates p-value by random label swapping.
439    ///
440    /// Under H0 (no difference), swapping labels between A and B should
441    /// produce test statistics as extreme as the observed difference with
442    /// probability equal to the p-value.
443    fn bootstrap_test(
444        &self,
445        a: &[f64],
446        b: &[f64],
447        n_permutations: usize,
448        seed: u64,
449    ) -> Result<f64> {
450        if n_permutations == 0 {
451            return Err(anyhow!("n_permutations must be > 0"));
452        }
453        let n = a.len();
454        if n < 2 {
455            return Err(anyhow!("Bootstrap test requires at least 2 paired samples"));
456        }
457
458        let observed_diff = (stats_mean(a) - stats_mean(b)).abs();
459
460        // LCG for reproducible random permutations
461        let mut rng = BootstrapRng::new(seed);
462        let mut extreme_count = 0usize;
463
464        for _ in 0..n_permutations {
465            // For each pair (a_i, b_i), randomly swap with probability 0.5
466            let (perm_a, perm_b): (Vec<f64>, Vec<f64>) = a
467                .iter()
468                .zip(b.iter())
469                .map(
470                    |(&ai, &bi)| {
471                        if rng.next_bool() {
472                            (ai, bi)
473                        } else {
474                            (bi, ai)
475                        }
476                    },
477                )
478                .unzip();
479
480            let perm_diff = (stats_mean(&perm_a) - stats_mean(&perm_b)).abs();
481            if perm_diff >= observed_diff {
482                extreme_count += 1;
483            }
484        }
485
486        let p = (extreme_count as f64 + 1.0) / (n_permutations as f64 + 1.0);
487        Ok(p.clamp(0.0, 1.0))
488    }
489}
490
491/// Evaluate link prediction quality on a set of embeddings.
492///
493/// For each query node, computes the rank of the positive target
494/// among all candidates, then returns mean reciprocal rank scores.
495///
496/// # Arguments
497/// * `embeddings` - Node embedding matrix (indexed by node ID)
498/// * `positive_pairs` - Known positive (source, target) pairs
499/// * `negative_pairs` - Known negative (source, non-target) pairs
500pub fn evaluate_link_prediction(
501    model_id: String,
502    embeddings: &[Vec<f64>],
503    positive_pairs: &[(usize, usize)],
504    negative_pairs: &[(usize, usize)],
505) -> Result<ModelEvalResult> {
506    if embeddings.is_empty() {
507        return Err(anyhow!("embeddings must not be empty"));
508    }
509    if positive_pairs.is_empty() {
510        return Err(anyhow!("positive_pairs must not be empty"));
511    }
512
513    // For each positive pair (u, v), compute its rank among all negatives
514    let mut mrr_scores: Vec<f64> = Vec::with_capacity(positive_pairs.len());
515
516    for &(u, v) in positive_pairs {
517        let emb_u = match embeddings.get(u) {
518            Some(e) => e,
519            None => continue,
520        };
521        let emb_v = match embeddings.get(v) {
522            Some(e) => e,
523            None => continue,
524        };
525
526        let pos_score = cosine_similarity_slice(emb_u, emb_v);
527
528        // Count how many negatives score higher than this positive
529        let mut higher_count = 0usize;
530        for &(nu, nv) in negative_pairs {
531            let emb_nu = match embeddings.get(nu) {
532                Some(e) => e,
533                None => continue,
534            };
535            let emb_nv = match embeddings.get(nv) {
536                Some(e) => e,
537                None => continue,
538            };
539            let neg_score = cosine_similarity_slice(emb_nu, emb_nv);
540            if neg_score >= pos_score {
541                higher_count += 1;
542            }
543        }
544
545        let rank = (higher_count + 1) as f64;
546        mrr_scores.push(1.0 / rank);
547    }
548
549    if mrr_scores.is_empty() {
550        return Err(anyhow!("No valid positive pairs found in embeddings"));
551    }
552
553    ModelEvalResult::new(model_id, EmbedMetric::MeanReciprocalRank, mrr_scores)
554}
555
556/// Evaluate Hits@K: fraction of positive pairs ranked in top-K
557pub fn evaluate_hits_at_k(
558    model_id: String,
559    embeddings: &[Vec<f64>],
560    positive_pairs: &[(usize, usize)],
561    negative_pairs: &[(usize, usize)],
562    k: usize,
563) -> Result<ModelEvalResult> {
564    if embeddings.is_empty() {
565        return Err(anyhow!("embeddings must not be empty"));
566    }
567    if positive_pairs.is_empty() {
568        return Err(anyhow!("positive_pairs must not be empty"));
569    }
570    if k == 0 {
571        return Err(anyhow!("k must be > 0"));
572    }
573
574    let mut hit_scores: Vec<f64> = Vec::with_capacity(positive_pairs.len());
575
576    for &(u, v) in positive_pairs {
577        let emb_u = match embeddings.get(u) {
578            Some(e) => e,
579            None => continue,
580        };
581        let emb_v = match embeddings.get(v) {
582            Some(e) => e,
583            None => continue,
584        };
585
586        let pos_score = cosine_similarity_slice(emb_u, emb_v);
587
588        let mut higher_count = 0usize;
589        for &(nu, nv) in negative_pairs {
590            let emb_nu = match embeddings.get(nu) {
591                Some(e) => e,
592                None => continue,
593            };
594            let emb_nv = match embeddings.get(nv) {
595                Some(e) => e,
596                None => continue,
597            };
598            let neg_score = cosine_similarity_slice(emb_nu, emb_nv);
599            if neg_score >= pos_score {
600                higher_count += 1;
601            }
602        }
603
604        // Hit if rank <= k
605        let rank = higher_count + 1;
606        hit_scores.push(if rank <= k { 1.0 } else { 0.0 });
607    }
608
609    if hit_scores.is_empty() {
610        return Err(anyhow!("No valid positive pairs found in embeddings"));
611    }
612
613    ModelEvalResult::new(model_id, EmbedMetric::HitsAtK(k), hit_scores)
614}
615
616/// Evaluate silhouette score for node clustering quality.
617///
618/// The silhouette coefficient measures how similar each node is to its own cluster
619/// compared to other clusters. Range: [-1, 1], higher is better.
620pub fn evaluate_silhouette(
621    model_id: String,
622    embeddings: &[Vec<f64>],
623    cluster_labels: &[usize],
624) -> Result<ModelEvalResult> {
625    let n = embeddings.len();
626    if n < 2 {
627        return Err(anyhow!("Need at least 2 nodes for silhouette score"));
628    }
629    if cluster_labels.len() != n {
630        return Err(anyhow!(
631            "cluster_labels length {} != embeddings length {}",
632            cluster_labels.len(),
633            n
634        ));
635    }
636
637    // Collect unique clusters
638    let unique_clusters: std::collections::HashSet<usize> =
639        cluster_labels.iter().copied().collect();
640    if unique_clusters.len() < 2 {
641        return Err(anyhow!(
642            "Need at least 2 distinct clusters for silhouette score"
643        ));
644    }
645
646    let mut silhouette_scores: Vec<f64> = Vec::with_capacity(n);
647
648    for i in 0..n {
649        let my_cluster = cluster_labels[i];
650
651        // Intra-cluster: mean distance to all other nodes in same cluster
652        let my_cluster_nodes: Vec<usize> = (0..n)
653            .filter(|&j| j != i && cluster_labels[j] == my_cluster)
654            .collect();
655
656        let a = if my_cluster_nodes.is_empty() {
657            0.0
658        } else {
659            my_cluster_nodes
660                .iter()
661                .map(|&j| euclidean_distance(&embeddings[i], &embeddings[j]))
662                .sum::<f64>()
663                / my_cluster_nodes.len() as f64
664        };
665
666        // Inter-cluster: minimum mean distance to any other cluster
667        let b = unique_clusters
668            .iter()
669            .filter(|&&c| c != my_cluster)
670            .map(|&c| {
671                let other_nodes: Vec<usize> = (0..n).filter(|&j| cluster_labels[j] == c).collect();
672                if other_nodes.is_empty() {
673                    f64::INFINITY
674                } else {
675                    other_nodes
676                        .iter()
677                        .map(|&j| euclidean_distance(&embeddings[i], &embeddings[j]))
678                        .sum::<f64>()
679                        / other_nodes.len() as f64
680                }
681            })
682            .fold(f64::INFINITY, f64::min);
683
684        // Silhouette coefficient
685        let s = if a < b {
686            1.0 - a / b
687        } else if a > b {
688            b / a - 1.0
689        } else {
690            0.0
691        };
692
693        silhouette_scores.push(s);
694    }
695
696    ModelEvalResult::new(model_id, EmbedMetric::SilhouetteScore, silhouette_scores)
697}
698
699// ============================================================================
700// Statistical utility functions
701// ============================================================================
702
/// Arithmetic mean of `v`, or `0.0` when `v` is empty.
fn stats_mean(v: &[f64]) -> f64 {
    match v.len() {
        0 => 0.0,
        len => v.iter().sum::<f64>() / len as f64,
    }
}
710
/// Sample standard deviation of `v` around the supplied `mean`.
///
/// Uses Bessel's correction (divides by `n - 1`), so this is the sample
/// standard deviation, not the population one. Returns `0.0` when fewer than
/// two values are given.
fn stats_std_dev(v: &[f64], mean: f64) -> f64 {
    if v.len() < 2 {
        return 0.0;
    }
    let sum_sq: f64 = v.iter().map(|x| (x - mean) * (x - mean)).sum();
    (sum_sq / (v.len() - 1) as f64).sqrt()
}
720
721/// Cohen's d: standardized effect size between two samples
722fn cohens_d(a: &[f64], b: &[f64]) -> f64 {
723    let mean_a = stats_mean(a);
724    let mean_b = stats_mean(b);
725    let std_a = stats_std_dev(a, mean_a);
726    let std_b = stats_std_dev(b, mean_b);
727
728    // Pooled standard deviation
729    let n_a = a.len() as f64;
730    let n_b = b.len() as f64;
731    let pooled_var =
732        ((n_a - 1.0) * std_a * std_a + (n_b - 1.0) * std_b * std_b) / (n_a + n_b - 2.0);
733
734    if pooled_var < 1e-15 {
735        return 0.0;
736    }
737
738    (mean_a - mean_b) / pooled_var.sqrt()
739}
740
741/// CDF of the Student's t-distribution.
742///
743/// Uses a simple and robust normal approximation for large |t|,
744/// and the exact continued fraction for small |t|.
745/// Two-sided p-value computation: p = 2 * (1 - CDF(|t|)).
746fn t_distribution_cdf(t: f64, df: f64) -> f64 {
747    if df <= 0.0 {
748        return 0.5;
749    }
750    // For very large t, the CDF is essentially 1.0
751    if t.abs() > 1e6 {
752        return if t >= 0.0 { 1.0 } else { 0.0 };
753    }
754    // Regularized incomplete beta: P(T <= t | df) = 1 - 0.5 * I_x(df/2, 1/2)
755    // where x = df / (df + t^2)
756    let x = df / (df + t * t);
757    // I_x(df/2, 1/2) computed via the beta regularized function
758    let beta_inc = betai(df / 2.0, 0.5, x);
759    // CDF(t) = 1 - 0.5 * betai for t >= 0
760    let cdf = 1.0 - 0.5 * beta_inc;
761    cdf.clamp(0.0, 1.0)
762}
763
764/// Regularized incomplete beta function I_x(a, b).
765///
766/// Uses the continued fraction method from Numerical Recipes.
767/// Returns a value in [0, 1].
768fn betai(a: f64, b: f64, x: f64) -> f64 {
769    if !(0.0..=1.0).contains(&x) {
770        return 0.0;
771    }
772    if x == 0.0 {
773        return 0.0;
774    }
775    if x == 1.0 {
776        return 1.0;
777    }
778    // Use the symmetry relation for numerical stability
779    let bt =
780        (log_gamma(a + b) - log_gamma(a) - log_gamma(b) + a * x.ln() + b * (1.0 - x).ln()).exp();
781
782    // Evaluate on the smaller argument side for better CF convergence
783    if x < (a + 1.0) / (a + b + 2.0) {
784        bt * betacf(a, b, x) / a
785    } else {
786        1.0 - bt * betacf(b, a, 1.0 - x) / b
787    }
788}
789
/// Continued fraction used by `betai` to evaluate the incomplete beta function.
///
/// Uses the modified Lentz method (Numerical Recipes `betacf`). It stops after
/// 200 iterations, or earlier once a step changes the result by a relative
/// factor smaller than 3e-10.
fn betacf(a: f64, b: f64, x: f64) -> f64 {
    const MAX_ITER: usize = 200;
    const EPS: f64 = 3.0e-10;
    const FPMIN: f64 = 1.0e-300;

    // Lentz's method: replace near-zero values by FPMIN so that the
    // following reciprocals never divide by zero.
    let guard = |v: f64| if v.abs() < FPMIN { FPMIN } else { v };

    let qab = a + b;
    let qap = a + 1.0;
    let qam = a - 1.0;
    let mut c = 1.0f64;
    let mut d = 1.0 / guard(1.0 - qab * x / qap);
    let mut h = d;

    for m in 1..=MAX_ITER {
        let mf = m as f64;
        let m2 = 2.0 * mf;

        // Even step
        let aa = mf * (b - mf) * x / ((qam + m2) * (a + m2));
        d = 1.0 / guard(1.0 + aa * d);
        c = guard(1.0 + aa / c);
        h *= d * c;

        // Odd step
        let aa = -(a + mf) * (qab + mf) * x / ((a + m2) * (qap + m2));
        d = 1.0 / guard(1.0 + aa * d);
        c = guard(1.0 + aa / c);
        let delta = d * c;
        h *= delta;
        if (delta - 1.0).abs() < EPS {
            break;
        }
    }
    h
}
841
/// Natural logarithm of the Gamma function, `ln Γ(z)`.
///
/// For `z >= 0.5`, uses the Lanczos approximation with `g = 7` and nine
/// coefficients. For `0 < z < 0.5`, uses the reflection formula
/// `Γ(z) Γ(1 - z) = π / sin(π z)`. Non-positive `z` is unsupported and
/// returns `+inf`.
fn log_gamma(z: f64) -> f64 {
    const G: f64 = 7.0;
    const COEFFS: [f64; 9] = [
        0.999_999_999_999_809_9,
        676.5203681218851,
        -1259.1392167224028,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507343278686905,
        -0.13857109526572012,
        9.984_369_578_019_572e-6,
        1.5056327351493116e-7,
    ];

    if z <= 0.0 {
        return f64::INFINITY;
    }
    if z < 0.5 {
        // Reflection: ln Γ(z) = ln π - ln|sin(π z)| - ln Γ(1 - z).
        let pi = std::f64::consts::PI;
        return pi.ln() - (pi * z).sin().abs().ln() - log_gamma(1.0 - z);
    }

    let w = z - 1.0;
    let series = COEFFS[1..]
        .iter()
        .enumerate()
        .fold(COEFFS[0], |acc, (k, &coef)| acc + coef / (w + k as f64 + 1.0));
    let t = w + G + 0.5;
    std::f64::consts::TAU.sqrt().ln() + series.ln() + (w + 0.5) * t.ln() - t
}
874
875/// t critical value approximation using inverse t-distribution.
876///
877/// For common significance levels, returns the t-critical value
878/// for the given degrees of freedom and tail probability.
879fn t_critical_value(df: usize, tail_prob: f64) -> f64 {
880    // Use normal approximation for large df (>= 30)
881    if df >= 30 {
882        return normal_quantile(1.0 - tail_prob);
883    }
884    // Simple lookup for small df (common in practice)
885    // These are approximate values for two-sided alpha=0.05 (tail=0.025)
886    match df {
887        1 => 12.706,
888        2 => 4.303,
889        3 => 3.182,
890        4 => 2.776,
891        5 => 2.571,
892        6 => 2.447,
893        7 => 2.365,
894        8 => 2.306,
895        9 => 2.262,
896        10 => 2.228,
897        11 => 2.201,
898        12 => 2.179,
899        13 => 2.160,
900        14 => 2.145,
901        15 => 2.131,
902        16 => 2.120,
903        17 => 2.110,
904        18 => 2.101,
905        19 => 2.093,
906        20 => 2.086,
907        21..=25 => 2.064,
908        26..=29 => 2.048,
909        _ => normal_quantile(1.0 - tail_prob),
910    }
911}
912
/// Standard normal CDF `Φ(z)`, from Abramowitz & Stegun formula 26.2.17
/// (absolute error below 7.5e-8).
fn standard_normal_cdf(z: f64) -> f64 {
    const P: f64 = 0.2316419;
    const B: [f64; 5] = [
        0.3193815306,
        -0.3565637813,
        1.7814779372,
        -1.8212559978,
        1.3302744929,
    ];

    let k = 1.0 / (1.0 + P * z.abs());
    let density = 0.3989422820 * (-z * z / 2.0).exp();
    let series = k * (B[0] + k * (B[1] + k * (B[2] + k * (B[3] + k * B[4]))));
    // Probability mass above |z|.
    let upper_tail = density * series;
    if z < 0.0 {
        upper_tail
    } else {
        1.0 - upper_tail
    }
}
926
/// Quantile function (inverse CDF) of the standard normal distribution.
///
/// Implements Peter Acklam's rational approximation, whose relative error is
/// below about 1.15e-9 on `(0, 1)`. A central rational function covers
/// `0.02425 <= p <= 0.97575`; a tail formula in `sqrt(-2 ln q)` covers the rest.
/// Returns `-inf` for `p <= 0` and `+inf` for `p >= 1`.
fn normal_quantile(p: f64) -> f64 {
    const A: [f64; 6] = [
        -3.969683028665376e+01,
        2.209460984245205e+02,
        -2.759285104469687e+02,
        1.383577518672690e+02,
        -3.066479806614716e+01,
        2.506628277459239e+00,
    ];
    const B: [f64; 5] = [
        -5.447609879822406e+01,
        1.615858368580409e+02,
        -1.556989798598866e+02,
        6.680131188771972e+01,
        -1.328068155288572e+01,
    ];
    const C: [f64; 6] = [
        -7.784894002430293e-03,
        -3.223964580411365e-01,
        -2.400758277161838e+00,
        -2.549732539343734e+00,
        4.374664141464968e+00,
        2.938163982698783e+00,
    ];
    const D: [f64; 4] = [
        7.784695709041462e-03,
        3.224671290700398e-01,
        2.445134137142996e+00,
        3.754408661907416e+00,
    ];
    const P_LOW: f64 = 0.02425;
    const P_HIGH: f64 = 1.0 - P_LOW;

    if p <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if p >= 1.0 {
        return f64::INFINITY;
    }

    // Lower-tail quantile (a negative value) for a probability q in (0, P_LOW).
    let lower_tail = |q: f64| {
        let r = (-2.0 * q.ln()).sqrt();
        (((((C[0] * r + C[1]) * r + C[2]) * r + C[3]) * r + C[4]) * r + C[5])
            / ((((D[0] * r + D[1]) * r + D[2]) * r + D[3]) * r + 1.0)
    };

    if p < P_LOW {
        lower_tail(p)
    } else if p <= P_HIGH {
        // Rational approximation for central region
        let q = p - 0.5;
        let r = q * q;
        (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
            / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
    } else {
        // Upper tail by symmetry.
        -lower_tail(1.0 - p)
    }
}
962
/// Rank `values` in ascending order (1-based), giving tied values the average
/// of the ranks they occupy (the "mid-rank" convention).
///
/// A run of values counts as tied while each member differs from the run's
/// first value by less than `1e-15`. Sorting uses [`f64::total_cmp`], which
/// is a total order even when NaN is present. NaN inputs therefore get a
/// well-defined position instead of feeding an inconsistent comparator to the
/// sort; recent Rust versions may panic on such a comparator.
///
/// Returns a vector of the same length where `ranks[i]` is the rank of
/// `values[i]`.
fn rank_with_ties(values: &[f64]) -> Vec<f64> {
    let n = values.len();
    let mut indexed: Vec<(usize, f64)> = values.iter().copied().enumerate().collect();
    indexed.sort_by(|a, b| a.1.total_cmp(&b.1));

    let mut ranks = vec![0.0f64; n];
    let mut i = 0;
    while i < n {
        // Grow the tie run [i, j). `j` starts at i + 1 so the loop always
        // advances, even for values (NaN) that never compare as tied.
        let mut j = i + 1;
        while j < n && (indexed[j].1 - indexed[i].1).abs() < 1e-15 {
            j += 1;
        }
        // Sorted positions i..j hold ranks i+1..=j; their mean is (i + 1 + j) / 2.
        let avg_rank = (i + j + 1) as f64 / 2.0;
        for &(original_idx, _) in &indexed[i..j] {
            ranks[original_idx] = avg_rank;
        }
        i = j;
    }
    ranks
}
985
/// Cosine of the angle between `a` and `b`, clamped to `[-1, 1]`.
///
/// Returns `0.0` when either vector's Euclidean norm is below `1e-12`.
/// The dot product runs over the common prefix of the two slices, while each
/// norm is taken over its whole slice.
fn cosine_similarity_slice(a: &[f64], b: &[f64]) -> f64 {
    let l2_norm = |v: &[f64]| v.iter().map(|x| x * x).sum::<f64>().sqrt();
    let (len_a, len_b) = (l2_norm(a), l2_norm(b));
    if len_a < 1e-12 || len_b < 1e-12 {
        return 0.0;
    }
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f64>();
    (dot / (len_a * len_b)).clamp(-1.0, 1.0)
}
996
/// Straight-line (L2) distance between `a` and `b`.
///
/// Pairs are formed element-wise and iteration stops at the shorter slice,
/// so only the common prefix contributes.
fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
    let squared_sum: f64 = a
        .iter()
        .zip(b)
        .map(|(x, y)| {
            let delta = x - y;
            delta * delta
        })
        .sum();
    squared_sum.sqrt()
}
1005
/// Small deterministic PRNG for the bootstrap / permutation test.
///
/// A 64-bit linear congruential generator (multiplier 6364136223846793005,
/// increment 1442695040888963407, modulus 2^64). In an LCG with a
/// power-of-two modulus, bit `k` of the state has period `2^(k+1)`. The
/// lowest bit therefore just alternates, so random booleans are taken from
/// the most significant bit instead.
struct BootstrapRng {
    // Current LCG state.
    state: u64,
}

impl BootstrapRng {
    /// Create a generator whose initial state is `seed + 1` (wrapping).
    fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }

    /// Advance the LCG one step and return the new state.
    fn next_u64(&mut self) -> u64 {
        self.state = self
            .state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        self.state
    }

    /// Draw a fair coin flip.
    fn next_bool(&mut self) -> bool {
        // The top bit has the full 2^64 period; the low bit would yield a
        // strictly alternating sequence, making every permutation of an
        // even-length sample identical.
        self.next_u64() >> 63 == 0
    }
}
1030
// Unit tests for the A/B testing framework: result construction, the three
// statistical tests (t-test, bootstrap, Wilcoxon), the metric evaluators,
// and the numeric helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // Build an MRR `ModelEvalResult` from raw per-sample scores; panics if
    // construction fails, since test inputs are expected to be valid.
    fn make_scores(values: &[f64], model_id: &str) -> ModelEvalResult {
        ModelEvalResult::new(
            model_id.to_string(),
            EmbedMetric::MeanReciprocalRank,
            values.to_vec(),
        )
        .expect("scores should be valid")
    }

    #[test]
    fn test_model_eval_result_basic() {
        // Sum of the five scores is 3.85, so the mean is 0.77.
        let scores = vec![0.8, 0.7, 0.9, 0.6, 0.85];
        let result = ModelEvalResult::new(
            "ModelA".to_string(),
            EmbedMetric::MeanReciprocalRank,
            scores.clone(),
        )
        .expect("should construct");

        assert_eq!(result.sample_count, 5);
        assert!((result.mean - 0.77).abs() < 0.01);
        assert!(result.std_dev > 0.0);
    }

    #[test]
    fn test_model_eval_result_empty_scores() {
        // An empty score vector must be rejected at construction time.
        let result =
            ModelEvalResult::new("Model".to_string(), EmbedMetric::SilhouetteScore, vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn test_confidence_interval() {
        // 50 evenly spaced scores in [0, 0.98]; the alpha = 0.05 interval must
        // bracket the mean and stay within [0, 1].
        let scores: Vec<f64> = (0..50).map(|i| i as f64 * 0.02).collect();
        let result =
            ModelEvalResult::new("M".to_string(), EmbedMetric::ClassificationAccuracy, scores)
                .expect("should construct");
        let (lo, hi) = result.confidence_interval(0.05);
        assert!(lo < result.mean);
        assert!(hi > result.mean);
        assert!(lo >= 0.0);
        assert!(hi <= 1.0 + 1e-6);
    }

    #[test]
    fn test_ab_test_clearly_different() {
        // Model A is clearly better (scores ~0.9 vs ~0.1)
        let a_scores: Vec<f64> = vec![0.85, 0.90, 0.88, 0.92, 0.87, 0.91, 0.89, 0.93, 0.86, 0.94];
        let b_scores: Vec<f64> = vec![0.10, 0.12, 0.09, 0.11, 0.13, 0.08, 0.10, 0.12, 0.11, 0.09];

        let result_a = make_scores(&a_scores, "ModelA");
        let result_b = make_scores(&b_scores, "ModelB");

        let runner = AbTestRunner::new();
        let ab_result = runner
            .compare(&result_a, &result_b)
            .expect("comparison should succeed");

        assert!(ab_result.is_significant, "should be significant");
        assert_eq!(ab_result.winner, Some("ModelA".to_string()));
        assert!(ab_result.p_value < 0.001);
        assert!(ab_result.effect_size.abs() > 1.0); // Large effect size
    }

    #[test]
    fn test_ab_test_similar_models() {
        // Models A and B with very similar scores (just noise)
        let a_scores: Vec<f64> = vec![0.500, 0.501, 0.499, 0.502, 0.498, 0.501, 0.500, 0.499];
        let b_scores: Vec<f64> = vec![0.499, 0.500, 0.501, 0.500, 0.501, 0.499, 0.500, 0.501];

        let result_a = make_scores(&a_scores, "ModelA");
        let result_b = make_scores(&b_scores, "ModelB");

        let runner = AbTestRunner::new();
        let ab_result = runner
            .compare(&result_a, &result_b)
            .expect("comparison should succeed");

        assert!(
            !ab_result.is_significant,
            "should not be significant for similar scores"
        );
        assert!(ab_result.winner.is_none());
        assert!(ab_result.p_value > 0.05);
    }

    #[test]
    fn test_ab_test_different_metrics_error() {
        // Results measured on different metrics cannot be compared.
        let result_a = ModelEvalResult::new(
            "A".to_string(),
            EmbedMetric::MeanReciprocalRank,
            vec![0.5; 10],
        )
        .expect("ok");
        let result_b =
            ModelEvalResult::new("B".to_string(), EmbedMetric::SilhouetteScore, vec![0.5; 10])
                .expect("ok");

        let runner = AbTestRunner::new();
        assert!(runner.compare(&result_a, &result_b).is_err());
    }

    #[test]
    fn test_rank_models() {
        // Inputs are deliberately not in mean order; ranking must sort them.
        let scores_a: Vec<f64> = vec![0.8; 20];
        let scores_b: Vec<f64> = vec![0.6; 20];
        let scores_c: Vec<f64> = vec![0.7; 20];

        let results = vec![
            ModelEvalResult::new("B".to_string(), EmbedMetric::HitsAtK(10), scores_b).expect("ok"),
            ModelEvalResult::new("A".to_string(), EmbedMetric::HitsAtK(10), scores_a).expect("ok"),
            ModelEvalResult::new("C".to_string(), EmbedMetric::HitsAtK(10), scores_c).expect("ok"),
        ];

        let runner = AbTestRunner::new();
        let ranking = runner
            .rank_models(&results)
            .expect("ranking should succeed");

        assert_eq!(ranking.len(), 3);
        assert_eq!(ranking[0].0, "A"); // Highest mean
        assert_eq!(ranking[1].0, "C");
        assert_eq!(ranking[2].0, "B");
    }

    #[test]
    fn test_bootstrap_test() {
        // Same clearly separated data as `test_ab_test_clearly_different`,
        // checked with the permutation test and a fixed seed.
        let a_scores: Vec<f64> = vec![0.85, 0.90, 0.88, 0.92, 0.87, 0.91, 0.89, 0.93, 0.86, 0.94];
        let b_scores: Vec<f64> = vec![0.10, 0.12, 0.09, 0.11, 0.13, 0.08, 0.10, 0.12, 0.11, 0.09];

        let result_a = make_scores(&a_scores, "A");
        let result_b = make_scores(&b_scores, "B");

        let runner = AbTestRunner::new().with_test(StatTest::Bootstrap {
            n_permutations: 999,
            seed: 42,
        });
        let ab_result = runner
            .compare(&result_a, &result_b)
            .expect("bootstrap should succeed");

        assert!(ab_result.is_significant);
        assert!(ab_result.p_value < 0.05);
    }

    #[test]
    fn test_wilcoxon_test() {
        // Every paired difference favours A, so the signed-rank test must reject.
        let a_scores: Vec<f64> = vec![0.9, 0.85, 0.92, 0.88, 0.91, 0.87, 0.93, 0.86, 0.94, 0.89];
        let b_scores: Vec<f64> = vec![0.1, 0.12, 0.09, 0.11, 0.13, 0.08, 0.1, 0.12, 0.11, 0.09];

        let result_a = make_scores(&a_scores, "A");
        let result_b = make_scores(&b_scores, "B");

        let runner = AbTestRunner::new().with_test(StatTest::WilcoxonSignedRank);
        let ab_result = runner
            .compare(&result_a, &result_b)
            .expect("wilcoxon should succeed");

        assert!(ab_result.is_significant);
    }

    #[test]
    fn test_evaluate_link_prediction() {
        // Create simple embeddings: nodes 0,1 are similar; 2,3 are different from 0
        let embeddings = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.9, 0.1, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
        ];
        let positive_pairs = vec![(0, 1)];
        let negative_pairs = vec![(0, 2), (0, 3)];

        let result = evaluate_link_prediction(
            "test_model".to_string(),
            &embeddings,
            &positive_pairs,
            &negative_pairs,
        )
        .expect("link prediction eval should succeed");

        assert_eq!(result.model_id, "test_model");
        assert_eq!(result.metric, EmbedMetric::MeanReciprocalRank);
        assert!(!result.scores.is_empty());
        // Rank 1 => MRR = 1.0, node 0-1 should score higher than 0-2
        assert!((result.scores[0] - 1.0).abs() < 1e-10, "MRR should be 1.0");
    }

    #[test]
    fn test_evaluate_hits_at_k() {
        let embeddings = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.9, 0.1, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
        ];
        let positive_pairs = vec![(0, 1)];
        let negative_pairs = vec![(0, 2)];

        let result = evaluate_hits_at_k(
            "model".to_string(),
            &embeddings,
            &positive_pairs,
            &negative_pairs,
            1,
        )
        .expect("hits@k eval should succeed");

        assert_eq!(result.metric, EmbedMetric::HitsAtK(1));
        assert!(!result.scores.is_empty());
        // Positive pair has higher cosine similarity => rank 1 => hit@1 = 1.0
        assert!((result.scores[0] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_evaluate_silhouette() {
        // Two well-separated clusters
        let embeddings = vec![
            vec![1.0, 0.0],
            vec![0.9, 0.1],
            vec![0.95, 0.05],
            vec![0.0, 1.0],
            vec![0.1, 0.9],
            vec![0.05, 0.95],
        ];
        let labels = vec![0, 0, 0, 1, 1, 1];

        let result = evaluate_silhouette("model".to_string(), &embeddings, &labels)
            .expect("silhouette eval should succeed");

        assert_eq!(result.metric, EmbedMetric::SilhouetteScore);
        assert_eq!(result.sample_count, 6);
        // Well-separated clusters should have positive mean silhouette
        assert!(
            result.mean > 0.0,
            "mean silhouette should be positive for well-separated clusters"
        );
    }

    #[test]
    fn test_cohens_d_interpretation() {
        // d > 0.8 is "large" effect; identical means => d = 0
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        assert!((cohens_d(&a, &b)).abs() < 1e-10);

        let c = vec![10.0, 11.0, 12.0, 13.0, 14.0];
        let d = cohens_d(&c, &a);
        assert!(
            d.abs() > 1.0,
            "Large difference should give large Cohen's d"
        );
    }

    #[test]
    fn test_rank_with_ties() {
        let values = vec![3.0, 1.0, 1.0, 2.0];
        let ranks = rank_with_ties(&values);
        assert_eq!(ranks.len(), 4);
        // value 1.0 appears at positions 1,2 => avg rank = (1+2)/2 = 1.5
        assert!((ranks[1] - 1.5).abs() < 1e-10);
        assert!((ranks[2] - 1.5).abs() < 1e-10);
        // value 2.0 at position 3 => rank 3
        assert!((ranks[3] - 3.0).abs() < 1e-10);
        // value 3.0 at position 0 => rank 4
        assert!((ranks[0] - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_hits_at_k_zero_error() {
        let result = evaluate_hits_at_k(
            "m".to_string(),
            &[vec![1.0f64]],
            &[(0, 0)],
            &[],
            0, // k=0 is invalid
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_embed_metric_display() {
        assert_eq!(EmbedMetric::MeanReciprocalRank.to_string(), "MRR");
        assert_eq!(EmbedMetric::HitsAtK(10).to_string(), "Hits@10");
        assert_eq!(
            EmbedMetric::Custom("MyMetric".to_string()).to_string(),
            "MyMetric"
        );
    }

    #[test]
    fn test_ab_test_summary() {
        // Each model has constant scores; the summary text must name both
        // models and mention significance.
        let a_scores = vec![0.8f64; 10];
        let b_scores = vec![0.2f64; 10];
        let result_a = make_scores(&a_scores, "ModelA");
        let result_b = make_scores(&b_scores, "ModelB");

        let runner = AbTestRunner::new();
        let ab_result = runner.compare(&result_a, &result_b).expect("ok");
        let summary = ab_result.summary();
        assert!(summary.contains("ModelA"));
        assert!(summary.contains("ModelB"));
        assert!(summary.contains("significant"));
    }
}