Skip to main content

oxirs_embed/ab_testing/
mod.rs

1//! A/B Testing Framework for Embedding Models (v0.3.0)
2//!
3//! Production-ready framework for comparing embedding model variants with:
4//! - [`ModelVariant`]: encapsulates a model with metadata and metrics collection
5//! - [`ABTestConfig`]: configures traffic splits, metric targets, and test duration
6//! - [`ABTestRunner`]: routes inference requests between variants and records outcomes
7//! - [`ABTestAnalyzer`]: statistical significance testing (Welch's t-test, Mann-Whitney U)
8//! - [`ABTestReport`]: generates detailed comparison reports
9//!
10//! ## Design
11//!
12//! The framework is model-agnostic: any function from an input key to a
13//! `Vec<f64>` embedding qualifies as a "model variant".  This keeps the A/B
14//! framework decoupled from specific GNN or KGE implementations.
15//!
16//! ## Example
17//!
18//! ```rust,no_run
19//! use oxirs_embed::ab_testing::{ABTestConfig, ABTestRunner, ModelVariant};
20//!
21//! # fn main() -> anyhow::Result<()> {
22//! let control = ModelVariant::new("transe-v1", |_key: &str| vec![0.0f64; 64]);
23//! let treatment = ModelVariant::new("transe-v2", |_key: &str| vec![0.1f64; 64]);
24//!
25//! let config = ABTestConfig::default();
26//! let mut runner = ABTestRunner::new(config, control, treatment)?;
27//!
28//! // Simulate requests
29//! for i in 0..200 {
30//!     let key = format!("entity:{i}");
31//!     let (embedding, variant_name) = runner.route(&key)?;
32//!     // Record a business metric (e.g., link prediction hit@10)
33//!     runner.record_metric(&variant_name, 0.85)?;
34//! }
35//!
36//! let report = runner.analyze()?;
37//! println!("{}", report.summary());
38//! # Ok(())
39//! # }
40//! ```
41
42use anyhow::{anyhow, Result};
43use serde::{Deserialize, Serialize};
44use std::collections::HashMap;
45
46// ---------------------------------------------------------------------------
47// ModelVariant
48// ---------------------------------------------------------------------------
49
/// One arm of an A/B experiment: a named embedding function plus metadata.
///
/// The wrapped closure maps an entity key (`&str`) to a `Vec<f64>` embedding;
/// `name`, `description` and `version` describe the variant.
pub struct ModelVariant {
    /// Human-readable identifier, e.g. `"transe-v1"`.
    pub name: String,
    /// Free-form description; empty unless set via [`ModelVariant::with_description`].
    pub description: String,
    /// Version label; `"0.1.0"` unless set via [`ModelVariant::with_version`].
    pub version: String,
    /// Boxed inference closure mapping an entity key to its embedding.
    #[allow(clippy::type_complexity)]
    infer: Box<dyn Fn(&str) -> Vec<f64> + Send + Sync>,
}

impl std::fmt::Debug for ModelVariant {
    /// Prints the three metadata fields; the inference closure is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ModelVariant");
        out.field("name", &self.name);
        out.field("description", &self.description);
        out.field("version", &self.version);
        out.finish()
    }
}

impl ModelVariant {
    /// Build a variant from a name and an inference closure.
    ///
    /// The description starts empty and the version starts as `"0.1.0"`.
    pub fn new<F>(name: impl Into<String>, infer: F) -> Self
    where
        F: Fn(&str) -> Vec<f64> + Send + Sync + 'static,
    {
        let boxed: Box<dyn Fn(&str) -> Vec<f64> + Send + Sync> = Box::new(infer);
        Self {
            name: name.into(),
            description: String::new(),
            version: "0.1.0".to_owned(),
            infer: boxed,
        }
    }

    /// Builder-style setter that replaces the description.
    pub fn with_description(self, desc: impl Into<String>) -> Self {
        Self {
            description: desc.into(),
            ..self
        }
    }

    /// Builder-style setter that replaces the version string.
    pub fn with_version(self, version: impl Into<String>) -> Self {
        Self {
            version: version.into(),
            ..self
        }
    }

    /// Invoke the wrapped closure on `key` and return its embedding.
    pub fn infer(&self, key: &str) -> Vec<f64> {
        let model = &*self.infer;
        model(key)
    }
}
107
108// ---------------------------------------------------------------------------
109// ABTestConfig
110// ---------------------------------------------------------------------------
111
/// Direction in which the primary metric is optimized when comparing variants.
///
/// The analyzer uses this to decide which variant's mean counts as "better"
/// once a difference has been found significant.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum OptimizeMetric {
    /// Higher is better (e.g., recall, NDCG).
    Maximize,
    /// Lower is better (e.g., mean rank, latency).
    Minimize,
}
120
/// Configuration for an A/B test.
///
/// `traffic_split` and `significance_level` are validated by
/// [`ABTestRunner::new`]; the other fields are taken as given.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ABTestConfig {
    /// Fraction of requests routed to the treatment variant (0.0 – 1.0).
    /// The remainder goes to the control variant.
    pub traffic_split: f64,
    /// Minimum number of observations per variant before analysis;
    /// [`ABTestRunner::analyze`] returns an error while either variant has fewer.
    pub min_samples: usize,
    /// Significance level α for hypothesis tests (default 0.05).
    pub significance_level: f64,
    /// Direction of optimization for the primary metric.
    pub optimize: OptimizeMetric,
    /// Random seed for deterministic traffic splitting.
    pub seed: u64,
    /// Optional maximum number of requests; once this many requests have been
    /// routed, [`ABTestRunner::route`] returns an error.
    pub max_requests: Option<usize>,
    /// Optional minimum detectable effect size (Cohen's d).
    // NOTE(review): no code visible in this file reads this field — confirm
    // whether it is consumed elsewhere or is reserved for future use.
    pub min_effect_size: Option<f64>,
}
140
141impl Default for ABTestConfig {
142    fn default() -> Self {
143        Self {
144            traffic_split: 0.5,
145            min_samples: 50,
146            significance_level: 0.05,
147            optimize: OptimizeMetric::Maximize,
148            seed: 42,
149            max_requests: None,
150            min_effect_size: None,
151        }
152    }
153}
154
155// ---------------------------------------------------------------------------
156// Observation
157// ---------------------------------------------------------------------------
158
/// A single metric observation from a model variant.
///
/// Observations are grouped by `variant_name` during analysis, so the name
/// must match a variant's `name` exactly to be counted for it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Observation {
    /// The variant that produced this observation.
    pub variant_name: String,
    /// Request key (entity IRI, query string, etc.).
    pub key: String,
    /// Observed metric value.
    pub metric: f64,
    /// Wall-clock time taken for inference (microseconds).
    pub latency_us: u64,
}
171
172// ---------------------------------------------------------------------------
173// ABTestRunner
174// ---------------------------------------------------------------------------
175
/// Routes inference requests between two variants and records metrics.
///
/// Holds both variants, the accumulated observations, a request counter and
/// the pseudo-random state used to pick a variant for each request.
pub struct ABTestRunner {
    /// Test configuration (split, thresholds, seed).
    config: ABTestConfig,
    /// Baseline variant.
    control: ModelVariant,
    /// Candidate variant.
    treatment: ModelVariant,
    /// Every observation recorded so far, in insertion order.
    observations: Vec<Observation>,
    /// Count of successfully routed requests.
    total_requests: usize,
    /// Simple LCG for deterministic traffic assignment.
    lcg_state: u64,
    /// Latency of the most recent route() call (microseconds).
    last_latency_us: u64,
}
188
189impl std::fmt::Debug for ABTestRunner {
190    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191        f.debug_struct("ABTestRunner")
192            .field("control", &self.control.name)
193            .field("treatment", &self.treatment.name)
194            .field("total_requests", &self.total_requests)
195            .field("observations", &self.observations.len())
196            .finish()
197    }
198}
199
200impl ABTestRunner {
201    /// Create a new A/B test runner.
202    pub fn new(
203        config: ABTestConfig,
204        control: ModelVariant,
205        treatment: ModelVariant,
206    ) -> Result<Self> {
207        if !(0.0..=1.0).contains(&config.traffic_split) {
208            return Err(anyhow!(
209                "traffic_split must be in [0, 1], got {}",
210                config.traffic_split
211            ));
212        }
213        if config.significance_level <= 0.0 || config.significance_level >= 1.0 {
214            return Err(anyhow!(
215                "significance_level must be in (0, 1), got {}",
216                config.significance_level
217            ));
218        }
219        let lcg_state = config.seed.wrapping_add(1);
220        Ok(Self {
221            config,
222            control,
223            treatment,
224            observations: Vec::new(),
225            total_requests: 0,
226            lcg_state,
227            last_latency_us: 0,
228        })
229    }
230
231    /// Route a single request to a variant, returning `(embedding, variant_name)`.
232    ///
233    /// The variant is chosen deterministically based on `key` and the LCG state,
234    /// ensuring the configured `traffic_split` is respected on average.
235    pub fn route(&mut self, key: &str) -> Result<(Vec<f64>, String)> {
236        if let Some(max) = self.config.max_requests {
237            if self.total_requests >= max {
238                return Err(anyhow!("A/B test has reached max_requests {}", max));
239            }
240        }
241        // Mix key hash with LCG for per-request randomness
242        let key_hash = fnv1a_hash(key);
243        self.lcg_state = self
244            .lcg_state
245            .wrapping_mul(6364136223846793005)
246            .wrapping_add(1442695040888963407)
247            .wrapping_add(key_hash);
248        let r = (self.lcg_state >> 11) as f64 / (1u64 << 53) as f64;
249
250        let use_treatment = r < self.config.traffic_split;
251        let variant = if use_treatment {
252            &self.treatment
253        } else {
254            &self.control
255        };
256        let start = std::time::Instant::now();
257        let embedding = variant.infer(key);
258        let latency_us = start.elapsed().as_micros() as u64;
259
260        self.last_latency_us = latency_us;
261        self.total_requests += 1;
262        Ok((embedding, variant.name.clone()))
263    }
264
265    /// Record a metric observation for a variant.
266    ///
267    /// Call this after routing to register the quality signal for the request.
268    pub fn record_metric(&mut self, variant_name: &str, metric: f64) -> Result<()> {
269        if !metric.is_finite() {
270            return Err(anyhow!("metric must be finite, got {}", metric));
271        }
272        // Attribute to the last routed request
273        let key = format!("req_{}", self.total_requests);
274        self.observations.push(Observation {
275            variant_name: variant_name.to_string(),
276            key,
277            metric,
278            latency_us: self.last_latency_us,
279        });
280        Ok(())
281    }
282
283    /// Record a full observation (metric + latency + key).
284    pub fn record_observation(&mut self, obs: Observation) -> Result<()> {
285        if !obs.metric.is_finite() {
286            return Err(anyhow!("Observation metric must be finite"));
287        }
288        self.observations.push(obs);
289        Ok(())
290    }
291
292    /// Analyze the collected observations and return a report.
293    pub fn analyze(&self) -> Result<ABTestReport> {
294        let ctrl_metrics: Vec<f64> = self
295            .observations
296            .iter()
297            .filter(|o| o.variant_name == self.control.name)
298            .map(|o| o.metric)
299            .collect();
300        let trt_metrics: Vec<f64> = self
301            .observations
302            .iter()
303            .filter(|o| o.variant_name == self.treatment.name)
304            .map(|o| o.metric)
305            .collect();
306
307        if ctrl_metrics.len() < self.config.min_samples {
308            return Err(anyhow!(
309                "Not enough control observations: {} < {}",
310                ctrl_metrics.len(),
311                self.config.min_samples
312            ));
313        }
314        if trt_metrics.len() < self.config.min_samples {
315            return Err(anyhow!(
316                "Not enough treatment observations: {} < {}",
317                trt_metrics.len(),
318                self.config.min_samples
319            ));
320        }
321
322        let analyzer = ABTestAnalyzer::new(&self.config);
323        analyzer.analyze(
324            &self.control.name,
325            &ctrl_metrics,
326            &self.treatment.name,
327            &trt_metrics,
328        )
329    }
330
    /// Number of requests routed so far.
    ///
    /// Incremented once per successful `route` call; recording metrics does
    /// not change it.
    pub fn total_requests(&self) -> usize {
        self.total_requests
    }
335
    /// All recorded observations, in the order they were recorded.
    pub fn observations(&self) -> &[Observation] {
        &self.observations
    }
340
341    /// Get per-variant summary statistics without a full report.
342    pub fn variant_stats(&self) -> HashMap<String, VariantStats> {
343        let mut map: HashMap<String, Vec<f64>> = HashMap::new();
344        for obs in &self.observations {
345            map.entry(obs.variant_name.clone())
346                .or_default()
347                .push(obs.metric);
348        }
349        map.into_iter()
350            .map(|(name, metrics)| {
351                let stats = VariantStats::from_slice(&metrics);
352                (name, stats)
353            })
354            .collect()
355    }
356}
357
358// ---------------------------------------------------------------------------
359// ABTestAnalyzer
360// ---------------------------------------------------------------------------
361
/// Statistical significance testing for A/B experiment results.
///
/// Borrows the test configuration; it reads `significance_level` and
/// `optimize` from it when building a report.
pub struct ABTestAnalyzer<'a> {
    /// Configuration of the test being analyzed.
    config: &'a ABTestConfig,
}
366
367impl<'a> ABTestAnalyzer<'a> {
    /// Create an analyzer that borrows the given test configuration.
    pub fn new(config: &'a ABTestConfig) -> Self {
        Self { config }
    }
372
373    /// Run Welch's t-test and Mann-Whitney U test, then build a report.
374    pub fn analyze(
375        &self,
376        control_name: &str,
377        control_metrics: &[f64],
378        treatment_name: &str,
379        treatment_metrics: &[f64],
380    ) -> Result<ABTestReport> {
381        if control_metrics.is_empty() || treatment_metrics.is_empty() {
382            return Err(anyhow!("Both metric slices must be non-empty"));
383        }
384
385        let ctrl_stats = VariantStats::from_slice(control_metrics);
386        let trt_stats = VariantStats::from_slice(treatment_metrics);
387
388        let ttest_result = self.welchs_ttest(control_metrics, treatment_metrics)?;
389        let mwu_result = self.mann_whitney_u(control_metrics, treatment_metrics)?;
390        let cohens_d = self.cohens_d(control_metrics, treatment_metrics);
391
392        let significant = ttest_result.p_value < self.config.significance_level
393            && mwu_result.p_value < self.config.significance_level;
394
395        // Determine winner
396        let winner = if !significant {
397            Winner::NoSignificantDifference
398        } else {
399            let ctrl_better = ctrl_stats.mean > trt_stats.mean;
400            match self.config.optimize {
401                OptimizeMetric::Maximize => {
402                    if ctrl_better {
403                        Winner::Control(control_name.to_string())
404                    } else {
405                        Winner::Treatment(treatment_name.to_string())
406                    }
407                }
408                OptimizeMetric::Minimize => {
409                    if ctrl_better {
410                        Winner::Treatment(treatment_name.to_string())
411                    } else {
412                        Winner::Control(control_name.to_string())
413                    }
414                }
415            }
416        };
417
418        Ok(ABTestReport {
419            control_name: control_name.to_string(),
420            treatment_name: treatment_name.to_string(),
421            control_stats: ctrl_stats,
422            treatment_stats: trt_stats,
423            ttest: ttest_result,
424            mann_whitney: mwu_result,
425            cohens_d,
426            significant,
427            winner,
428            significance_level: self.config.significance_level,
429        })
430    }
431
432    /// Welch's t-test (unequal variance, two-sided).
433    ///
434    /// Returns t-statistic, degrees of freedom (Welch-Satterthwaite), and
435    /// a p-value approximated via Student's t CDF.
436    pub fn welchs_ttest(&self, a: &[f64], b: &[f64]) -> Result<TTestResult> {
437        if a.len() < 2 || b.len() < 2 {
438            return Err(anyhow!(
439                "Welch's t-test requires >= 2 observations per group"
440            ));
441        }
442        let na = a.len() as f64;
443        let nb = b.len() as f64;
444        let mean_a = mean(a);
445        let mean_b = mean(b);
446        let var_a = variance(a, mean_a);
447        let var_b = variance(b, mean_b);
448
449        if var_a < 1e-15 && var_b < 1e-15 {
450            // Both groups are constant
451            let t = if (mean_a - mean_b).abs() < 1e-12 {
452                0.0
453            } else {
454                f64::INFINITY
455            };
456            return Ok(TTestResult {
457                t_statistic: t,
458                degrees_of_freedom: 0.0,
459                p_value: if t == 0.0 { 1.0 } else { 0.0 },
460                mean_diff: mean_a - mean_b,
461            });
462        }
463
464        let se = (var_a / na + var_b / nb).sqrt();
465        if se < 1e-15 {
466            return Err(anyhow!("Standard error too small for t-test"));
467        }
468        let t = (mean_a - mean_b) / se;
469
470        // Welch-Satterthwaite degrees of freedom
471        let df_num = (var_a / na + var_b / nb).powi(2);
472        let df_den = (var_a / na).powi(2) / (na - 1.0) + (var_b / nb).powi(2) / (nb - 1.0);
473        let df = if df_den < 1e-15 { 1.0 } else { df_num / df_den };
474
475        // Two-sided p-value via t distribution CDF approximation
476        let p_value = t_distribution_two_sided_p(t.abs(), df);
477
478        Ok(TTestResult {
479            t_statistic: t,
480            degrees_of_freedom: df,
481            p_value,
482            mean_diff: mean_a - mean_b,
483        })
484    }
485
486    /// Mann-Whitney U test (Wilcoxon rank-sum, two-sided).
487    ///
488    /// Suitable for non-normally distributed metrics.
489    /// Uses the normal approximation for p-value when n > 20.
490    pub fn mann_whitney_u(&self, a: &[f64], b: &[f64]) -> Result<MannWhitneyResult> {
491        if a.is_empty() || b.is_empty() {
492            return Err(anyhow!("Mann-Whitney U requires non-empty groups"));
493        }
494        let na = a.len() as f64;
495        let nb = b.len() as f64;
496
497        // Rank all observations pooled
498        let mut combined: Vec<(f64, u8)> = a
499            .iter()
500            .map(|&v| (v, 0u8))
501            .chain(b.iter().map(|&v| (v, 1u8)))
502            .collect();
503        combined.sort_by(|x, y| x.0.partial_cmp(&y.0).unwrap_or(std::cmp::Ordering::Equal));
504
505        // Assign average ranks (handle ties)
506        let n_total = combined.len();
507        let mut ranks = vec![0.0f64; n_total];
508        let mut i = 0;
509        while i < n_total {
510            let mut j = i;
511            while j < n_total && (combined[j].0 - combined[i].0).abs() < 1e-12 {
512                j += 1;
513            }
514            // Average rank for tied group: 1-based
515            let avg_rank = (i + j + 1) as f64 / 2.0; // 1-based average
516            for rank in ranks[i..j].iter_mut() {
517                *rank = avg_rank;
518            }
519            i = j;
520        }
521
522        // Sum of ranks for group a
523        let rank_sum_a: f64 = combined
524            .iter()
525            .zip(ranks.iter())
526            .filter(|(obs, _)| obs.1 == 0)
527            .map(|(_, &r)| r)
528            .sum();
529
530        let u_a = rank_sum_a - na * (na + 1.0) / 2.0;
531        let u_b = na * nb - u_a;
532        let u = u_a.min(u_b);
533
534        // Normal approximation
535        let mu_u = na * nb / 2.0;
536        let sigma_u = ((na * nb * (na + nb + 1.0)) / 12.0).sqrt();
537        let z = if sigma_u < 1e-12 {
538            0.0
539        } else {
540            (u - mu_u) / sigma_u
541        };
542        let p_value = 2.0 * standard_normal_sf(z.abs());
543
544        Ok(MannWhitneyResult {
545            u_statistic: u,
546            z_score: z,
547            p_value: p_value.clamp(0.0, 1.0),
548            rank_sum_a,
549        })
550    }
551
552    /// Cohen's d effect size: (mean_a - mean_b) / pooled_std.
553    pub fn cohens_d(&self, a: &[f64], b: &[f64]) -> f64 {
554        if a.len() < 2 || b.len() < 2 {
555            return 0.0;
556        }
557        let mean_a = mean(a);
558        let mean_b = mean(b);
559        let var_a = variance(a, mean_a);
560        let var_b = variance(b, mean_b);
561        let na = a.len() as f64;
562        let nb = b.len() as f64;
563        // Pooled standard deviation (Hedges' g denominator)
564        let pooled_std = (((na - 1.0) * var_a + (nb - 1.0) * var_b) / (na + nb - 2.0)).sqrt();
565        if pooled_std < 1e-15 {
566            return 0.0;
567        }
568        (mean_a - mean_b) / pooled_std
569    }
570}
571
572// ---------------------------------------------------------------------------
573// Statistical test results
574// ---------------------------------------------------------------------------
575
/// Result of Welch's two-sample t-test.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TTestResult {
    /// t-statistic for `mean(a) - mean(b)`; `+inf` when both groups are
    /// constant but have different means.
    pub t_statistic: f64,
    /// Welch-Satterthwaite degrees of freedom; `0.0` in the degenerate case
    /// where both groups are constant.
    pub degrees_of_freedom: f64,
    /// Two-sided p-value.
    pub p_value: f64,
    /// `mean(a) - mean(b)`.
    pub mean_diff: f64,
}
586
/// Result of Mann-Whitney U test.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MannWhitneyResult {
    /// The smaller of the two U statistics, `min(U_a, U_b)`.
    pub u_statistic: f64,
    /// Standardized `u_statistic` under the normal approximation; `0.0` when
    /// the standard deviation of U is effectively zero.
    pub z_score: f64,
    /// Two-sided p-value (normal approximation).
    pub p_value: f64,
    /// Sum of the pooled (tie-averaged, 1-based) ranks of group `a`.
    pub rank_sum_a: f64,
}
596
597// ---------------------------------------------------------------------------
598// VariantStats
599// ---------------------------------------------------------------------------
600
/// Summary statistics for one variant's metric observations.
///
/// When built from an empty slice, `n` is 0 and every other field is NaN.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VariantStats {
    /// Number of observations.
    pub n: usize,
    /// Arithmetic mean.
    pub mean: f64,
    /// Sample standard deviation (n − 1 denominator); `0.0` for a single observation.
    pub std_dev: f64,
    /// Smallest observation.
    pub min: f64,
    /// Largest observation.
    pub max: f64,
    /// 25th percentile (linear interpolation).
    pub p25: f64,
    /// 50th percentile (median).
    pub p50: f64,
    /// 75th percentile.
    pub p75: f64,
    /// 95th percentile.
    pub p95: f64,
}
618
619impl VariantStats {
620    /// Compute from a slice of observations.
621    pub fn from_slice(data: &[f64]) -> Self {
622        if data.is_empty() {
623            return Self {
624                n: 0,
625                mean: f64::NAN,
626                std_dev: f64::NAN,
627                min: f64::NAN,
628                max: f64::NAN,
629                p25: f64::NAN,
630                p50: f64::NAN,
631                p75: f64::NAN,
632                p95: f64::NAN,
633            };
634        }
635        let n = data.len();
636        let m = mean(data);
637        let var = variance(data, m);
638        let mut sorted = data.to_vec();
639        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
640
641        Self {
642            n,
643            mean: m,
644            std_dev: var.sqrt(),
645            min: sorted[0],
646            max: sorted[n - 1],
647            p25: percentile_sorted(&sorted, 25.0),
648            p50: percentile_sorted(&sorted, 50.0),
649            p75: percentile_sorted(&sorted, 75.0),
650            p95: percentile_sorted(&sorted, 95.0),
651        }
652    }
653}
654
655// ---------------------------------------------------------------------------
656// ABTestReport
657// ---------------------------------------------------------------------------
658
/// Which variant won the A/B test.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Winner {
    /// The control variant won; carries the control's name.
    Control(String),
    /// The treatment variant won; carries the treatment's name.
    Treatment(String),
    /// The difference was not statistically significant.
    NoSignificantDifference,
}
666
/// Full A/B test analysis report.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ABTestReport {
    /// Name of the control variant.
    pub control_name: String,
    /// Name of the treatment variant.
    pub treatment_name: String,
    /// Summary statistics of the control metrics.
    pub control_stats: VariantStats,
    /// Summary statistics of the treatment metrics.
    pub treatment_stats: VariantStats,
    /// Welch's t-test result (control as group `a`, treatment as group `b`).
    pub ttest: TTestResult,
    /// Mann-Whitney U result (control as group `a`, treatment as group `b`).
    pub mann_whitney: MannWhitneyResult,
    /// Cohen's d effect size.
    pub cohens_d: f64,
    /// True iff both tests are below `significance_level`.
    pub significant: bool,
    /// Outcome of the comparison.
    pub winner: Winner,
    /// The significance level α the tests were compared against.
    pub significance_level: f64,
}
683
684impl ABTestReport {
685    /// Render a human-readable summary.
686    pub fn summary(&self) -> String {
687        let ctrl = &self.control_stats;
688        let trt = &self.treatment_stats;
689        let mut lines = Vec::new();
690        lines.push("=== A/B Test Report ===".to_string());
691        lines.push(format!(
692            "Control   ({:>20}): n={:4} mean={:.4} std={:.4} p50={:.4}",
693            self.control_name, ctrl.n, ctrl.mean, ctrl.std_dev, ctrl.p50
694        ));
695        lines.push(format!(
696            "Treatment ({:>20}): n={:4} mean={:.4} std={:.4} p50={:.4}",
697            self.treatment_name, trt.n, trt.mean, trt.std_dev, trt.p50
698        ));
699        lines.push(format!(
700            "Welch's t-test:  t={:.4}  df={:.1}  p={:.4}",
701            self.ttest.t_statistic, self.ttest.degrees_of_freedom, self.ttest.p_value
702        ));
703        lines.push(format!(
704            "Mann-Whitney U:  U={:.1}  z={:.4}  p={:.4}",
705            self.mann_whitney.u_statistic, self.mann_whitney.z_score, self.mann_whitney.p_value
706        ));
707        lines.push(format!("Cohen's d: {:.4}", self.cohens_d));
708        lines.push(format!(
709            "Significant (α={}): {}",
710            self.significance_level, self.significant
711        ));
712        lines.push(match &self.winner {
713            Winner::Control(n) => format!("Winner: CONTROL ({n})"),
714            Winner::Treatment(n) => format!("Winner: TREATMENT ({n})"),
715            Winner::NoSignificantDifference => "Winner: No significant difference".to_string(),
716        });
717        lines.join("\n")
718    }
719
720    /// Return true if the treatment is the statistically significant winner.
721    pub fn treatment_wins(&self) -> bool {
722        matches!(&self.winner, Winner::Treatment(_))
723    }
724
725    /// Return true if the control is the statistically significant winner.
726    pub fn control_wins(&self) -> bool {
727        matches!(&self.winner, Winner::Control(_))
728    }
729
730    /// Relative improvement of treatment over control: `(trt - ctrl) / |ctrl|`.
731    pub fn relative_improvement(&self) -> f64 {
732        let ctrl_mean = self.control_stats.mean;
733        let trt_mean = self.treatment_stats.mean;
734        if ctrl_mean.abs() < 1e-12 {
735            return 0.0;
736        }
737        (trt_mean - ctrl_mean) / ctrl_mean.abs()
738    }
739}
740
741// ---------------------------------------------------------------------------
742// Statistical utilities
743// ---------------------------------------------------------------------------
744
/// Arithmetic mean of `data`; an empty slice yields `0.0`.
fn mean(data: &[f64]) -> f64 {
    match data.len() {
        0 => 0.0,
        count => data.iter().sum::<f64>() / count as f64,
    }
}
751
/// Sample variance (n − 1 denominator) of `data` around the supplied mean `m`.
///
/// Slices with fewer than two elements yield `0.0`.
fn variance(data: &[f64], m: f64) -> f64 {
    let count = data.len();
    if count < 2 {
        return 0.0;
    }
    let deviations = data.iter().map(|&x| x - m);
    let sum_sq: f64 = deviations.map(|d| d.powi(2)).sum();
    sum_sq / (count - 1) as f64
}
759
/// Linearly interpolated percentile of an ascending-sorted slice.
///
/// `p` is a percentage; values outside `[0, 100]` are clamped so that they
/// map to the minimum / maximum instead of indexing out of bounds or
/// extrapolating. An empty slice (or a NaN `p`) yields NaN; a single-element
/// slice yields that element.
fn percentile_sorted(sorted: &[f64], p: f64) -> f64 {
    let n = sorted.len();
    match n {
        0 => return f64::NAN,
        1 => return sorted[0],
        _ => {}
    }
    // `clamp` keeps NaN as NaN, which then propagates to the result.
    let p = p.clamp(0.0, 100.0);
    // Fractional 0-based position of the requested percentile.
    let rank = p / 100.0 * (n - 1) as f64;
    let lo = (rank.floor() as usize).min(n - 1);
    let hi = (lo + 1).min(n - 1);
    let frac = rank - lo as f64;
    sorted[lo] + frac * (sorted[hi] - sorted[lo])
}
774
775/// Two-sided p-value for t distribution (approximation via regularized incomplete beta).
776///
777/// Uses Abramowitz & Stegun approximation for the t-distribution CDF.
778fn t_distribution_two_sided_p(t_abs: f64, df: f64) -> f64 {
779    if df <= 0.0 {
780        return 1.0;
781    }
782    // Use normal approximation for large df
783    if df > 200.0 {
784        return 2.0 * standard_normal_sf(t_abs);
785    }
786    // Regularized incomplete beta function I_x(a, b) where x = df/(df+t^2)
787    let x = df / (df + t_abs * t_abs);
788    let a = df / 2.0;
789    let b = 0.5f64;
790    let ibeta = regularized_incomplete_beta(x, a, b);
791    ibeta.clamp(0.0, 1.0)
792}
793
794/// Regularized incomplete beta function via continued fraction (Lentz's method).
795fn regularized_incomplete_beta(x: f64, a: f64, b: f64) -> f64 {
796    if x <= 0.0 {
797        return 0.0;
798    }
799    if x >= 1.0 {
800        return 1.0;
801    }
802    // Use symmetry if x > (a+1)/(a+b+2)
803    let switch = (a + 1.0) / (a + b + 2.0);
804    if x > switch {
805        return 1.0 - regularized_incomplete_beta(1.0 - x, b, a);
806    }
807    // Front factor
808    let ln_front = a * x.ln() + b * (1.0 - x).ln() - ln_beta(a, b);
809    let front = ln_front.exp();
810    // Continued fraction
811    let cf = beta_continued_fraction(x, a, b);
812    (front * cf / a).clamp(0.0, 1.0)
813}
814
/// Continued-fraction factor of the incomplete beta function, evaluated with
/// the modified Lentz algorithm.
///
/// The returned value `cf` satisfies
/// `I_x(a, b) = x^a · (1 − x)^b / (a · B(a, b)) · cf`.
/// Iteration stops after 200 steps or when the per-step multiplier is within
/// `1e-14` of 1.
fn beta_continued_fraction(x: f64, a: f64, b: f64) -> f64 {
    const MAX_ITER: u32 = 200;
    const EPS: f64 = 1e-14;
    const TINY: f64 = 1e-300;

    // Keep intermediate values away from zero so divisions stay finite.
    let guard = |v: f64| if v.abs() < TINY { TINY } else { v };

    // `c` must start at 1.0. Starting it at TINY makes the first update
    // `1 + aa / c` explode by ~1e300 and corrupts the whole product.
    let mut c = 1.0;
    let mut d = 1.0 / guard(1.0 - (a + b) * x / (a + 1.0));
    let mut f = d;

    for m in 1..=MAX_ITER {
        let m_f = f64::from(m);
        let two_m = 2.0 * m_f;

        // Even step
        let aa = m_f * (b - m_f) * x / ((a + two_m - 1.0) * (a + two_m));
        d = 1.0 / guard(1.0 + aa * d);
        c = guard(1.0 + aa / c);
        f *= d * c;

        // Odd step
        let aa = -(a + m_f) * (a + b + m_f) * x / ((a + two_m) * (a + two_m + 1.0));
        d = 1.0 / guard(1.0 + aa * d);
        c = guard(1.0 + aa / c);
        let del = d * c;
        f *= del;

        if (del - 1.0).abs() < EPS {
            break;
        }
    }
    f
}
861
862/// Natural log of the Beta function via lgamma.
863fn ln_beta(a: f64, b: f64) -> f64 {
864    lgamma(a) + lgamma(b) - lgamma(a + b)
865}
866
/// Natural logarithm of `|Γ(a)|` via the Lanczos approximation (g = 7, 9 coefficients).
///
/// Arguments below 0.5 are handled with the reflection formula
/// `lnΓ(a) = ln π − ln|sin(π a)| − lnΓ(1 − a)`.
fn lgamma(a: f64) -> f64 {
    // Lanczos approximation coefficients (g=7, n=9)
    const G: f64 = 7.0;
    const C: [f64; 9] = [
        0.999_999_999_999_809_9,
        676.520_368_121_885_1,
        -1259.139216722403,
        771.323_428_777_653,
        -176.615_029_162_141_9,
        12.507343278686905,
        -0.138571095265720,
        9.984369578019572e-6,
        1.505632735149312e-7,
    ];

    if a < 0.5 {
        let pi = std::f64::consts::PI;
        return pi.ln() - (pi * a).sin().abs().ln() - lgamma(1.0 - a);
    }

    let x = a - 1.0;
    let t = x + G + 0.5;
    let (leading, tail) = (C[0], &C[1..]);
    // Partial-fraction series: c0 + Σ c_k / (x + k) for k = 1..=8.
    let ser: f64 = leading
        + tail
            .iter()
            .enumerate()
            .map(|(i, &c)| c / (x + i as f64 + 1.0))
            .sum::<f64>();
    let half_ln_two_pi = (2.0 * std::f64::consts::PI).sqrt().ln();
    half_ln_two_pi + ser.abs().ln() + (x + 0.5) * t.ln() - t
}
896
897/// Survival function of standard normal: P(Z > z).
898fn standard_normal_sf(z: f64) -> f64 {
899    0.5 * erfc(z / std::f64::consts::SQRT_2)
900}
901
/// Complementary error function `erfc(x) = 1 - erf(x)`.
///
/// Implements the Abramowitz & Stegun 7.1.26 rational approximation: a degree-5
/// polynomial in `t = 1 / (1 + p·|x|)` (evaluated with Horner's scheme) multiplied
/// by `exp(-x²)`. Negative arguments use the symmetry `erfc(-x) = 2 - erfc(x)`.
fn erfc(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    // Polynomial coefficients a1..a5, lowest order first.
    const COEFFS: [f64; 5] = [
        0.254829592,
        -0.284496736,
        1.421413741,
        -1.453152027,
        1.061405429,
    ];

    // Work on the magnitude and undo the sign at the end.
    let negative = x < 0.0;
    let m = if negative { -x } else { x };

    let t = 1.0 / (1.0 + P * m);
    // Horner evaluation from the highest coefficient down; yields
    // t * (a1 + t * (a2 + t * (a3 + t * (a4 + t * a5)))).
    let poly = COEFFS.iter().rev().fold(0.0, |acc, &c| t * (c + acc));
    let tail = poly * (-m * m).exp();

    if negative {
        2.0 - tail
    } else {
        tail
    }
}
914
/// 64-bit FNV-1a hash of the UTF-8 bytes of `s`.
///
/// Used for deterministic traffic assignment: the same string always produces
/// the same hash value.
fn fnv1a_hash(s: &str) -> u64 {
    // Standard 64-bit FNV parameters.
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;

    // FNV-1a: XOR the byte in first, then multiply by the prime (wrapping).
    s.bytes().fold(OFFSET_BASIS, |hash, byte| {
        (hash ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}
924
925// ---------------------------------------------------------------------------
926// Tests
927// ---------------------------------------------------------------------------
928
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a runner over two constant-output variants named "control" and
    /// "treatment", using the given traffic split and `min_samples = 5`.
    fn make_runner(split: f64) -> ABTestRunner {
        let config = ABTestConfig {
            traffic_split: split,
            min_samples: 5,
            ..Default::default()
        };
        ABTestRunner::new(
            config,
            ModelVariant::new("control", |_| vec![0.0f64; 4]),
            ModelVariant::new("treatment", |_| vec![1.0f64; 4]),
        )
        .expect("runner should construct")
    }

    /// `infer` forwards the key to the wrapped closure.
    #[test]
    fn test_model_variant_infer() {
        let variant = ModelVariant::new("test", |key| vec![key.len() as f64]);
        assert_eq!(variant.infer("hello"), vec![5.0]);
    }

    /// Builder-style setters populate description and version.
    #[test]
    fn test_model_variant_metadata() {
        let variant = ModelVariant::new("sage-v1", |_| vec![])
            .with_description("GraphSAGE v1")
            .with_version("1.2.3");
        assert_eq!(variant.name, "sage-v1");
        assert_eq!(variant.description, "GraphSAGE v1");
        assert_eq!(variant.version, "1.2.3");
    }

    /// Default configuration: 50/50 split, 50 minimum samples, alpha = 0.05.
    #[test]
    fn test_abtest_config_default() {
        let defaults = ABTestConfig::default();
        assert!((defaults.traffic_split - 0.5).abs() < 1e-10);
        assert_eq!(defaults.min_samples, 50);
        assert!((defaults.significance_level - 0.05).abs() < 1e-10);
    }

    /// A traffic split of 1.5 must be rejected at construction.
    #[test]
    fn test_runner_construction_invalid_split() {
        let config = ABTestConfig {
            traffic_split: 1.5,
            ..Default::default()
        };
        let outcome = ABTestRunner::new(
            config,
            ModelVariant::new("c", |_| vec![]),
            ModelVariant::new("t", |_| vec![]),
        );
        assert!(outcome.is_err());
    }

    /// Every routed request yields a non-empty embedding from one of the two
    /// variants and is counted in `total_requests`.
    #[test]
    fn test_runner_route() {
        let mut runner = make_runner(0.5);
        for key in (0..20).map(|i| format!("entity:{i}")) {
            let (emb, variant) = runner.route(&key).expect("route should succeed");
            assert!(!emb.is_empty());
            assert!(variant == "control" || variant == "treatment");
        }
        assert_eq!(runner.total_requests(), 20);
    }

    /// Two identically configured runners route the same keys identically.
    #[test]
    fn test_runner_traffic_split_deterministic() {
        // Same seed => same routing sequence
        let mut first = make_runner(0.3);
        let mut second = make_runner(0.3);
        for key in (0..50).map(|i| format!("k{i}")) {
            let (_, from_first) = first.route(&key).expect("route 1 ok");
            let (_, from_second) = second.route(&key).expect("route 2 ok");
            assert_eq!(from_first, from_second, "routing should be deterministic");
        }
    }

    /// NaN and infinite metric values are rejected.
    #[test]
    fn test_runner_record_metric_invalid() {
        let mut runner = make_runner(0.5);
        for bad in [f64::NAN, f64::INFINITY] {
            assert!(runner.record_metric("control", bad).is_err());
        }
    }

    /// Recorded metrics show up in per-variant statistics.
    #[test]
    fn test_runner_record_and_stats() {
        let mut runner = make_runner(0.5);
        for i in 0..20_i32 {
            let key = format!("e:{i}");
            let (_, variant) = runner.route(&key).expect("route ok");
            let value = f64::from(i) * 0.1;
            runner.record_metric(&variant, value).expect("record ok");
        }

        let stats = runner.variant_stats();
        assert!(!stats.is_empty());
        for per_variant in stats.values() {
            assert!(per_variant.n > 0);
            assert!(per_variant.mean.is_finite());
        }
    }

    /// Welch's t-test on a sample against itself: t = 0, p ≈ 1, no mean difference.
    #[test]
    fn test_welchs_ttest_identical_groups() {
        let cfg = ABTestConfig::default();
        let analyzer = ABTestAnalyzer::new(&cfg);
        let sample: Vec<f64> = (0..30_i32).map(f64::from).collect();

        let outcome = analyzer
            .welchs_ttest(&sample, &sample)
            .expect("t-test should succeed");

        assert!(
            (outcome.t_statistic).abs() < 1e-10,
            "t should be 0 for identical groups"
        );
        assert!(
            (outcome.p_value - 1.0).abs() < 0.01,
            "p should be ~1 for identical groups, got {}",
            outcome.p_value
        );
        assert!((outcome.mean_diff).abs() < 1e-10);
    }

    /// Welch's t-test on widely separated constant samples is highly significant.
    #[test]
    fn test_welchs_ttest_clearly_different() {
        let cfg = ABTestConfig::default();
        let analyzer = ABTestAnalyzer::new(&cfg);
        let low: Vec<f64> = vec![0.0; 50];
        let high: Vec<f64> = vec![100.0; 50];

        let outcome = analyzer
            .welchs_ttest(&low, &high)
            .expect("t-test should succeed");

        // Very different means => very significant
        assert!(
            outcome.p_value < 0.001,
            "p-value should be very small, got {}",
            outcome.p_value
        );
    }

    /// Mann-Whitney U on a sample against itself gives a high p-value.
    #[test]
    fn test_mann_whitney_identical() {
        let cfg = ABTestConfig::default();
        let analyzer = ABTestAnalyzer::new(&cfg);
        let sample: Vec<f64> = (0..30_i32).map(f64::from).collect();

        let outcome = analyzer
            .mann_whitney_u(&sample, &sample)
            .expect("MWU should succeed");

        // Identical groups: p should be high
        assert!(
            outcome.p_value > 0.3,
            "p-value for identical groups should be high, got {}",
            outcome.p_value
        );
    }

    /// Mann-Whitney U on fully separated samples gives a very small p-value.
    #[test]
    fn test_mann_whitney_clearly_different() {
        let cfg = ABTestConfig::default();
        let analyzer = ABTestAnalyzer::new(&cfg);
        let low: Vec<f64> = vec![0.0; 40];
        let high: Vec<f64> = vec![10.0; 40];

        let outcome = analyzer
            .mann_whitney_u(&low, &high)
            .expect("MWU should succeed");

        assert!(
            outcome.p_value < 0.001,
            "p-value for clearly different groups should be small, got {}",
            outcome.p_value
        );
    }

    /// Cohen's d of a sample against itself is zero.
    #[test]
    fn test_cohens_d_no_difference() {
        let cfg = ABTestConfig::default();
        let analyzer = ABTestAnalyzer::new(&cfg);
        let sample: Vec<f64> = (0..20_i32).map(f64::from).collect();

        let effect = analyzer.cohens_d(&sample, &sample);

        assert!(
            (effect).abs() < 1e-10,
            "Cohen's d should be 0 for identical groups"
        );
    }

    /// Cohen's d must stay well-defined when both samples have zero variance.
    #[test]
    fn test_cohens_d_large_effect() {
        let cfg = ABTestConfig::default();
        let analyzer = ABTestAnalyzer::new(&cfg);
        let low: Vec<f64> = vec![0.0f64; 30];
        let high: Vec<f64> = vec![10.0f64; 30];

        let d = analyzer.cohens_d(&low, &high);

        // Both have std=0, should handle gracefully
        assert!(d.is_finite() || d == 0.0);
    }

    /// End-to-end run: route 100 requests, record pseudo-random metrics where
    /// the treatment scores higher, analyze, and render the summary.
    #[test]
    fn test_full_ab_test_workflow() {
        let config = ABTestConfig {
            traffic_split: 0.5,
            min_samples: 10,
            significance_level: 0.05,
            optimize: OptimizeMetric::Maximize,
            seed: 99,
            max_requests: None,
            min_effect_size: None,
        };
        let mut runner = ABTestRunner::new(
            config,
            ModelVariant::new("baseline", |_| vec![0.0f64; 4]),
            ModelVariant::new("improved", |_| vec![1.0f64; 4]),
        )
        .expect("runner should construct");

        // Route requests and record metrics (treatment gets higher scores)
        let mut lcg: u64 = 12345;
        for i in 0..100 {
            let key = format!("entity:{i}");
            let (_, variant) = runner.route(&key).expect("route ok");

            // Simple pseudo-random metric: treatment gets +0.5
            lcg = lcg.wrapping_mul(6364136223846793005).wrapping_add(1);
            let base = (lcg >> 32) as f64 / u32::MAX as f64;
            let offset = if variant == "improved" { 0.7 } else { 0.2 };
            let metric = base * 0.3 + offset;

            runner.record_metric(&variant, metric).expect("record ok");
        }

        let report = runner.analyze().expect("analysis should succeed");
        assert!(report.control_stats.n >= 10);
        assert!(report.treatment_stats.n >= 10);
        // Treatment has higher scores, so it should win (or no significant diff)
        assert!(
            report.treatment_wins() || !report.significant,
            "treatment should win or no significant difference"
        );

        assert!(report.summary().contains("A/B Test Report"));
    }

    /// Routing fails once `max_requests` requests have been served.
    #[test]
    fn test_ab_test_max_requests() {
        let config = ABTestConfig {
            max_requests: Some(5),
            ..Default::default()
        };
        let mut runner = ABTestRunner::new(
            config,
            ModelVariant::new("c", |_| vec![1.0]),
            ModelVariant::new("t", |_| vec![2.0]),
        )
        .expect("runner ok");

        for key in (0..5).map(|i| format!("k{i}")) {
            runner.route(&key).expect("route ok");
        }
        assert!(runner.route("k5").is_err(), "should error after max_requests");
    }

    /// Statistics of an empty slice: zero count and NaN mean.
    #[test]
    fn test_variant_stats_empty() {
        let empty = VariantStats::from_slice(&[]);
        assert_eq!(empty.n, 0);
        assert!(empty.mean.is_nan());
    }

    /// Statistics of a single observation collapse onto that value.
    #[test]
    fn test_variant_stats_single() {
        let single = VariantStats::from_slice(&[42.0]);
        assert_eq!(single.n, 1);
        assert_eq!(single.mean, 42.0);
        assert_eq!(single.min, 42.0);
        assert_eq!(single.max, 42.0);
    }

    /// Statistics of 1..=5: mean 3, min 1, max 5, median 3.
    #[test]
    fn test_variant_stats_known_values() {
        let observations = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let summary = VariantStats::from_slice(&observations);
        assert_eq!(summary.n, 5);
        assert!((summary.mean - 3.0).abs() < 1e-10);
        assert_eq!(summary.min, 1.0);
        assert_eq!(summary.max, 5.0);
        assert!((summary.p50 - 3.0).abs() < 1e-10);
    }

    /// The relative improvement reported after a complete run is finite.
    #[test]
    fn test_report_relative_improvement() {
        let config = ABTestConfig {
            min_samples: 5,
            ..Default::default()
        };
        let mut runner = ABTestRunner::new(
            config,
            ModelVariant::new("c", |_| vec![0.5f64]),
            ModelVariant::new("t", |_| vec![0.6f64]),
        )
        .expect("runner ok");

        for key in (0..30).map(|i| format!("k{i}")) {
            let (_, v) = runner.route(&key).expect("route ok");
            let metric = if v == "c" { 0.5 } else { 0.6 };
            runner.record_metric(&v, metric).expect("record ok");
        }

        let report = runner.analyze().expect("analyze ok");
        assert!(report.relative_improvement().is_finite());
    }
}