1use anyhow::{anyhow, Result};
43use serde::{Deserialize, Serialize};
44use std::collections::HashMap;
45
/// A named model variant: metadata plus an opaque inference function
/// mapping a string key to an embedding vector.
pub struct ModelVariant {
    /// Human-readable variant identifier (used for routing and reporting).
    pub name: String,
    /// Free-form description; empty by default (see `new`).
    pub description: String,
    /// Version string; defaults to "0.1.0" in `new`.
    pub version: String,
    /// The wrapped inference function; boxed so heterogeneous closures
    /// share one type, `Send + Sync` so variants can cross threads.
    #[allow(clippy::type_complexity)]
    infer: Box<dyn Fn(&str) -> Vec<f64> + Send + Sync>,
}
65
66impl std::fmt::Debug for ModelVariant {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 f.debug_struct("ModelVariant")
69 .field("name", &self.name)
70 .field("description", &self.description)
71 .field("version", &self.version)
72 .finish()
73 }
74}
75
76impl ModelVariant {
77 pub fn new<F>(name: impl Into<String>, infer: F) -> Self
79 where
80 F: Fn(&str) -> Vec<f64> + Send + Sync + 'static,
81 {
82 Self {
83 name: name.into(),
84 description: String::new(),
85 version: String::from("0.1.0"),
86 infer: Box::new(infer),
87 }
88 }
89
90 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
92 self.description = desc.into();
93 self
94 }
95
96 pub fn with_version(mut self, version: impl Into<String>) -> Self {
98 self.version = version.into();
99 self
100 }
101
102 pub fn infer(&self, key: &str) -> Vec<f64> {
104 (self.infer)(key)
105 }
106}
107
/// Direction in which the observed metric counts as "better".
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum OptimizeMetric {
    /// Larger metric values win (e.g. accuracy, click-through rate).
    Maximize,
    /// Smaller metric values win (e.g. latency, loss).
    Minimize,
}
120
/// Configuration for an A/B test run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ABTestConfig {
    /// Fraction of requests routed to the treatment variant (`route`
    /// picks treatment when its pseudo-random draw falls below this
    /// value). Must lie in [0, 1]; validated by `ABTestRunner::new`.
    pub traffic_split: f64,
    /// Minimum observations required per variant before `analyze` runs.
    pub min_samples: usize,
    /// Alpha threshold for statistical significance; must lie strictly
    /// inside (0, 1); validated by `ABTestRunner::new`.
    pub significance_level: f64,
    /// Whether larger or smaller metric values count as better.
    pub optimize: OptimizeMetric,
    /// Seed for the deterministic routing PRNG.
    pub seed: u64,
    /// Optional hard cap on routed requests; `route` errors afterwards.
    pub max_requests: Option<usize>,
    /// Reserved: minimum effect size of interest. Not read anywhere in
    /// this file.
    pub min_effect_size: Option<f64>,
}
140
141impl Default for ABTestConfig {
142 fn default() -> Self {
143 Self {
144 traffic_split: 0.5,
145 min_samples: 50,
146 significance_level: 0.05,
147 optimize: OptimizeMetric::Maximize,
148 seed: 42,
149 max_requests: None,
150 min_effect_size: None,
151 }
152 }
153}
154
/// A single recorded outcome attributed to one variant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Observation {
    /// Name of the variant that served the request.
    pub variant_name: String,
    /// Request key (auto-generated as `req_<n>` by `record_metric`;
    /// caller-supplied when using `record_observation`).
    pub key: String,
    /// The recorded outcome value; must be finite (enforced on insert).
    pub metric: f64,
    /// Inference latency of the serving call, in microseconds.
    pub latency_us: u64,
}
171
/// Deterministic A/B test harness: routes requests between two model
/// variants, records metric observations, and produces a statistical
/// report comparing them.
pub struct ABTestRunner {
    /// Validated test configuration.
    config: ABTestConfig,
    /// Baseline variant.
    control: ModelVariant,
    /// Candidate variant.
    treatment: ModelVariant,
    /// All metric observations recorded so far, in insertion order.
    observations: Vec<Observation>,
    /// Number of successful `route` calls.
    total_requests: usize,
    /// State of the routing LCG (seeded from `config.seed + 1` in `new`).
    lcg_state: u64,
    /// Latency of the most recent `route` call, in microseconds.
    last_latency_us: u64,
}
188
189impl std::fmt::Debug for ABTestRunner {
190 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191 f.debug_struct("ABTestRunner")
192 .field("control", &self.control.name)
193 .field("treatment", &self.treatment.name)
194 .field("total_requests", &self.total_requests)
195 .field("observations", &self.observations.len())
196 .finish()
197 }
198}
199
impl ABTestRunner {
    /// Constructs a runner, validating the configuration.
    ///
    /// # Errors
    /// Fails when `traffic_split` lies outside [0, 1] or
    /// `significance_level` lies outside the open interval (0, 1).
    pub fn new(
        config: ABTestConfig,
        control: ModelVariant,
        treatment: ModelVariant,
    ) -> Result<Self> {
        if !(0.0..=1.0).contains(&config.traffic_split) {
            return Err(anyhow!(
                "traffic_split must be in [0, 1], got {}",
                config.traffic_split
            ));
        }
        if config.significance_level <= 0.0 || config.significance_level >= 1.0 {
            return Err(anyhow!(
                "significance_level must be in (0, 1), got {}",
                config.significance_level
            ));
        }
        // Derive the initial LCG state from the seed.
        let lcg_state = config.seed.wrapping_add(1);
        Ok(Self {
            config,
            control,
            treatment,
            observations: Vec::new(),
            total_requests: 0,
            lcg_state,
            last_latency_us: 0,
        })
    }

    /// Routes one request: picks a variant pseudo-randomly (deterministic
    /// for a given seed and key sequence), runs its inference, and returns
    /// the embedding plus the chosen variant's name. The call's latency is
    /// remembered for the next `record_metric`.
    ///
    /// # Errors
    /// Fails once `config.max_requests` routed requests have been served.
    pub fn route(&mut self, key: &str) -> Result<(Vec<f64>, String)> {
        if let Some(max) = self.config.max_requests {
            if self.total_requests >= max {
                return Err(anyhow!("A/B test has reached max_requests {}", max));
            }
        }
        // Mix the key's hash into one LCG step so routing depends on both
        // the seed and the sequence of keys seen so far.
        let key_hash = fnv1a_hash(key);
        self.lcg_state = self
            .lcg_state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407)
            .wrapping_add(key_hash);
        // Top 53 bits of the state -> uniform f64 in [0, 1).
        let r = (self.lcg_state >> 11) as f64 / (1u64 << 53) as f64;

        let use_treatment = r < self.config.traffic_split;
        let variant = if use_treatment {
            &self.treatment
        } else {
            &self.control
        };
        let start = std::time::Instant::now();
        let embedding = variant.infer(key);
        let latency_us = start.elapsed().as_micros() as u64;

        self.last_latency_us = latency_us;
        self.total_requests += 1;
        Ok((embedding, variant.name.clone()))
    }

    /// Records a metric for `variant_name`, stamping it with a synthetic
    /// `req_<total_requests>` key and the latency of the most recent
    /// `route` call.
    ///
    /// # Errors
    /// Fails when `metric` is NaN or infinite.
    pub fn record_metric(&mut self, variant_name: &str, metric: f64) -> Result<()> {
        if !metric.is_finite() {
            return Err(anyhow!("metric must be finite, got {}", metric));
        }
        let key = format!("req_{}", self.total_requests);
        self.observations.push(Observation {
            variant_name: variant_name.to_string(),
            key,
            metric,
            latency_us: self.last_latency_us,
        });
        Ok(())
    }

    /// Records a fully caller-constructed observation.
    ///
    /// # Errors
    /// Fails when the observation's metric is NaN or infinite.
    pub fn record_observation(&mut self, obs: Observation) -> Result<()> {
        if !obs.metric.is_finite() {
            return Err(anyhow!("Observation metric must be finite"));
        }
        self.observations.push(obs);
        Ok(())
    }

    /// Splits observations by variant name and delegates to
    /// `ABTestAnalyzer` for the statistical comparison.
    ///
    /// # Errors
    /// Fails when either arm has fewer than `config.min_samples`
    /// observations, or when the underlying statistical tests fail.
    pub fn analyze(&self) -> Result<ABTestReport> {
        let ctrl_metrics: Vec<f64> = self
            .observations
            .iter()
            .filter(|o| o.variant_name == self.control.name)
            .map(|o| o.metric)
            .collect();
        let trt_metrics: Vec<f64> = self
            .observations
            .iter()
            .filter(|o| o.variant_name == self.treatment.name)
            .map(|o| o.metric)
            .collect();

        if ctrl_metrics.len() < self.config.min_samples {
            return Err(anyhow!(
                "Not enough control observations: {} < {}",
                ctrl_metrics.len(),
                self.config.min_samples
            ));
        }
        if trt_metrics.len() < self.config.min_samples {
            return Err(anyhow!(
                "Not enough treatment observations: {} < {}",
                trt_metrics.len(),
                self.config.min_samples
            ));
        }

        let analyzer = ABTestAnalyzer::new(&self.config);
        analyzer.analyze(
            &self.control.name,
            &ctrl_metrics,
            &self.treatment.name,
            &trt_metrics,
        )
    }

    /// Total number of successfully routed requests so far.
    pub fn total_requests(&self) -> usize {
        self.total_requests
    }

    /// All recorded observations, in insertion order.
    pub fn observations(&self) -> &[Observation] {
        &self.observations
    }

    /// Per-variant summary statistics over all recorded metrics.
    pub fn variant_stats(&self) -> HashMap<String, VariantStats> {
        let mut map: HashMap<String, Vec<f64>> = HashMap::new();
        for obs in &self.observations {
            map.entry(obs.variant_name.clone())
                .or_default()
                .push(obs.metric);
        }
        map.into_iter()
            .map(|(name, metrics)| {
                let stats = VariantStats::from_slice(&metrics);
                (name, stats)
            })
            .collect()
    }
}
357
/// Stateless statistical analysis over two metric samples, parameterized
/// by the shared test configuration (alpha and optimize direction).
pub struct ABTestAnalyzer<'a> {
    /// Borrowed configuration; only `significance_level` and `optimize`
    /// are read by the analysis methods.
    config: &'a ABTestConfig,
}
366
367impl<'a> ABTestAnalyzer<'a> {
368 pub fn new(config: &'a ABTestConfig) -> Self {
370 Self { config }
371 }
372
373 pub fn analyze(
375 &self,
376 control_name: &str,
377 control_metrics: &[f64],
378 treatment_name: &str,
379 treatment_metrics: &[f64],
380 ) -> Result<ABTestReport> {
381 if control_metrics.is_empty() || treatment_metrics.is_empty() {
382 return Err(anyhow!("Both metric slices must be non-empty"));
383 }
384
385 let ctrl_stats = VariantStats::from_slice(control_metrics);
386 let trt_stats = VariantStats::from_slice(treatment_metrics);
387
388 let ttest_result = self.welchs_ttest(control_metrics, treatment_metrics)?;
389 let mwu_result = self.mann_whitney_u(control_metrics, treatment_metrics)?;
390 let cohens_d = self.cohens_d(control_metrics, treatment_metrics);
391
392 let significant = ttest_result.p_value < self.config.significance_level
393 && mwu_result.p_value < self.config.significance_level;
394
395 let winner = if !significant {
397 Winner::NoSignificantDifference
398 } else {
399 let ctrl_better = ctrl_stats.mean > trt_stats.mean;
400 match self.config.optimize {
401 OptimizeMetric::Maximize => {
402 if ctrl_better {
403 Winner::Control(control_name.to_string())
404 } else {
405 Winner::Treatment(treatment_name.to_string())
406 }
407 }
408 OptimizeMetric::Minimize => {
409 if ctrl_better {
410 Winner::Treatment(treatment_name.to_string())
411 } else {
412 Winner::Control(control_name.to_string())
413 }
414 }
415 }
416 };
417
418 Ok(ABTestReport {
419 control_name: control_name.to_string(),
420 treatment_name: treatment_name.to_string(),
421 control_stats: ctrl_stats,
422 treatment_stats: trt_stats,
423 ttest: ttest_result,
424 mann_whitney: mwu_result,
425 cohens_d,
426 significant,
427 winner,
428 significance_level: self.config.significance_level,
429 })
430 }
431
432 pub fn welchs_ttest(&self, a: &[f64], b: &[f64]) -> Result<TTestResult> {
437 if a.len() < 2 || b.len() < 2 {
438 return Err(anyhow!(
439 "Welch's t-test requires >= 2 observations per group"
440 ));
441 }
442 let na = a.len() as f64;
443 let nb = b.len() as f64;
444 let mean_a = mean(a);
445 let mean_b = mean(b);
446 let var_a = variance(a, mean_a);
447 let var_b = variance(b, mean_b);
448
449 if var_a < 1e-15 && var_b < 1e-15 {
450 let t = if (mean_a - mean_b).abs() < 1e-12 {
452 0.0
453 } else {
454 f64::INFINITY
455 };
456 return Ok(TTestResult {
457 t_statistic: t,
458 degrees_of_freedom: 0.0,
459 p_value: if t == 0.0 { 1.0 } else { 0.0 },
460 mean_diff: mean_a - mean_b,
461 });
462 }
463
464 let se = (var_a / na + var_b / nb).sqrt();
465 if se < 1e-15 {
466 return Err(anyhow!("Standard error too small for t-test"));
467 }
468 let t = (mean_a - mean_b) / se;
469
470 let df_num = (var_a / na + var_b / nb).powi(2);
472 let df_den = (var_a / na).powi(2) / (na - 1.0) + (var_b / nb).powi(2) / (nb - 1.0);
473 let df = if df_den < 1e-15 { 1.0 } else { df_num / df_den };
474
475 let p_value = t_distribution_two_sided_p(t.abs(), df);
477
478 Ok(TTestResult {
479 t_statistic: t,
480 degrees_of_freedom: df,
481 p_value,
482 mean_diff: mean_a - mean_b,
483 })
484 }
485
486 pub fn mann_whitney_u(&self, a: &[f64], b: &[f64]) -> Result<MannWhitneyResult> {
491 if a.is_empty() || b.is_empty() {
492 return Err(anyhow!("Mann-Whitney U requires non-empty groups"));
493 }
494 let na = a.len() as f64;
495 let nb = b.len() as f64;
496
497 let mut combined: Vec<(f64, u8)> = a
499 .iter()
500 .map(|&v| (v, 0u8))
501 .chain(b.iter().map(|&v| (v, 1u8)))
502 .collect();
503 combined.sort_by(|x, y| x.0.partial_cmp(&y.0).unwrap_or(std::cmp::Ordering::Equal));
504
505 let n_total = combined.len();
507 let mut ranks = vec![0.0f64; n_total];
508 let mut i = 0;
509 while i < n_total {
510 let mut j = i;
511 while j < n_total && (combined[j].0 - combined[i].0).abs() < 1e-12 {
512 j += 1;
513 }
514 let avg_rank = (i + j + 1) as f64 / 2.0; for rank in ranks[i..j].iter_mut() {
517 *rank = avg_rank;
518 }
519 i = j;
520 }
521
522 let rank_sum_a: f64 = combined
524 .iter()
525 .zip(ranks.iter())
526 .filter(|(obs, _)| obs.1 == 0)
527 .map(|(_, &r)| r)
528 .sum();
529
530 let u_a = rank_sum_a - na * (na + 1.0) / 2.0;
531 let u_b = na * nb - u_a;
532 let u = u_a.min(u_b);
533
534 let mu_u = na * nb / 2.0;
536 let sigma_u = ((na * nb * (na + nb + 1.0)) / 12.0).sqrt();
537 let z = if sigma_u < 1e-12 {
538 0.0
539 } else {
540 (u - mu_u) / sigma_u
541 };
542 let p_value = 2.0 * standard_normal_sf(z.abs());
543
544 Ok(MannWhitneyResult {
545 u_statistic: u,
546 z_score: z,
547 p_value: p_value.clamp(0.0, 1.0),
548 rank_sum_a,
549 })
550 }
551
552 pub fn cohens_d(&self, a: &[f64], b: &[f64]) -> f64 {
554 if a.len() < 2 || b.len() < 2 {
555 return 0.0;
556 }
557 let mean_a = mean(a);
558 let mean_b = mean(b);
559 let var_a = variance(a, mean_a);
560 let var_b = variance(b, mean_b);
561 let na = a.len() as f64;
562 let nb = b.len() as f64;
563 let pooled_std = (((na - 1.0) * var_a + (nb - 1.0) * var_b) / (na + nb - 2.0)).sqrt();
565 if pooled_std < 1e-15 {
566 return 0.0;
567 }
568 (mean_a - mean_b) / pooled_std
569 }
570}
571
/// Outcome of Welch's unequal-variance t-test.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TTestResult {
    /// The t statistic (mean difference divided by standard error).
    pub t_statistic: f64,
    /// Welch–Satterthwaite effective degrees of freedom (0.0 in the
    /// degenerate zero-variance branch).
    pub degrees_of_freedom: f64,
    /// Two-sided p-value.
    pub p_value: f64,
    /// Difference of group means, first group minus second (control
    /// minus treatment when called from `analyze`).
    pub mean_diff: f64,
}
586
/// Outcome of the Mann-Whitney U test (normal approximation).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MannWhitneyResult {
    /// The smaller of the two U statistics, min(U_a, U_b).
    pub u_statistic: f64,
    /// Normal-approximation z score for the U statistic.
    pub z_score: f64,
    /// Two-sided p-value, clamped to [0, 1].
    pub p_value: f64,
    /// Sum of the ranks assigned to the first group.
    pub rank_sum_a: f64,
}
596
/// Summary statistics for one variant's recorded metrics.
/// All statistics are NaN when `n == 0` (see `from_slice`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VariantStats {
    /// Number of observations.
    pub n: usize,
    /// Arithmetic mean.
    pub mean: f64,
    /// Sample (Bessel-corrected) standard deviation; 0.0 when n < 2.
    pub std_dev: f64,
    /// Smallest observation.
    pub min: f64,
    /// Largest observation.
    pub max: f64,
    /// 25th percentile (linear interpolation).
    pub p25: f64,
    /// Median (50th percentile).
    pub p50: f64,
    /// 75th percentile.
    pub p75: f64,
    /// 95th percentile.
    pub p95: f64,
}
618
619impl VariantStats {
620 pub fn from_slice(data: &[f64]) -> Self {
622 if data.is_empty() {
623 return Self {
624 n: 0,
625 mean: f64::NAN,
626 std_dev: f64::NAN,
627 min: f64::NAN,
628 max: f64::NAN,
629 p25: f64::NAN,
630 p50: f64::NAN,
631 p75: f64::NAN,
632 p95: f64::NAN,
633 };
634 }
635 let n = data.len();
636 let m = mean(data);
637 let var = variance(data, m);
638 let mut sorted = data.to_vec();
639 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
640
641 Self {
642 n,
643 mean: m,
644 std_dev: var.sqrt(),
645 min: sorted[0],
646 max: sorted[n - 1],
647 p25: percentile_sorted(&sorted, 25.0),
648 p50: percentile_sorted(&sorted, 50.0),
649 p75: percentile_sorted(&sorted, 75.0),
650 p95: percentile_sorted(&sorted, 95.0),
651 }
652 }
653}
654
/// Which variant, if any, won the test at the configured alpha.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Winner {
    /// The control variant won; payload is its name.
    Control(String),
    /// The treatment variant won; payload is its name.
    Treatment(String),
    /// Neither statistical test pair reached significance.
    NoSignificantDifference,
}
666
/// Full result of an A/B test analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ABTestReport {
    /// Name of the control variant.
    pub control_name: String,
    /// Name of the treatment variant.
    pub treatment_name: String,
    /// Summary statistics for the control arm.
    pub control_stats: VariantStats,
    /// Summary statistics for the treatment arm.
    pub treatment_stats: VariantStats,
    /// Welch's t-test result.
    pub ttest: TTestResult,
    /// Mann-Whitney U result.
    pub mann_whitney: MannWhitneyResult,
    /// Cohen's d effect size (pooled standard deviation).
    pub cohens_d: f64,
    /// True when both tests fall below `significance_level`.
    pub significant: bool,
    /// Declared winner (or no significant difference).
    pub winner: Winner,
    /// Alpha used for the significance decision.
    pub significance_level: f64,
}
683
684impl ABTestReport {
685 pub fn summary(&self) -> String {
687 let ctrl = &self.control_stats;
688 let trt = &self.treatment_stats;
689 let mut lines = Vec::new();
690 lines.push("=== A/B Test Report ===".to_string());
691 lines.push(format!(
692 "Control ({:>20}): n={:4} mean={:.4} std={:.4} p50={:.4}",
693 self.control_name, ctrl.n, ctrl.mean, ctrl.std_dev, ctrl.p50
694 ));
695 lines.push(format!(
696 "Treatment ({:>20}): n={:4} mean={:.4} std={:.4} p50={:.4}",
697 self.treatment_name, trt.n, trt.mean, trt.std_dev, trt.p50
698 ));
699 lines.push(format!(
700 "Welch's t-test: t={:.4} df={:.1} p={:.4}",
701 self.ttest.t_statistic, self.ttest.degrees_of_freedom, self.ttest.p_value
702 ));
703 lines.push(format!(
704 "Mann-Whitney U: U={:.1} z={:.4} p={:.4}",
705 self.mann_whitney.u_statistic, self.mann_whitney.z_score, self.mann_whitney.p_value
706 ));
707 lines.push(format!("Cohen's d: {:.4}", self.cohens_d));
708 lines.push(format!(
709 "Significant (α={}): {}",
710 self.significance_level, self.significant
711 ));
712 lines.push(match &self.winner {
713 Winner::Control(n) => format!("Winner: CONTROL ({n})"),
714 Winner::Treatment(n) => format!("Winner: TREATMENT ({n})"),
715 Winner::NoSignificantDifference => "Winner: No significant difference".to_string(),
716 });
717 lines.join("\n")
718 }
719
720 pub fn treatment_wins(&self) -> bool {
722 matches!(&self.winner, Winner::Treatment(_))
723 }
724
725 pub fn control_wins(&self) -> bool {
727 matches!(&self.winner, Winner::Control(_))
728 }
729
730 pub fn relative_improvement(&self) -> f64 {
732 let ctrl_mean = self.control_stats.mean;
733 let trt_mean = self.treatment_stats.mean;
734 if ctrl_mean.abs() < 1e-12 {
735 return 0.0;
736 }
737 (trt_mean - ctrl_mean) / ctrl_mean.abs()
738 }
739}
740
/// Arithmetic mean of `data`; defined as 0.0 for an empty slice.
fn mean(data: &[f64]) -> f64 {
    match data.len() {
        0 => 0.0,
        n => data.iter().sum::<f64>() / n as f64,
    }
}
751
/// Sample (Bessel-corrected) variance of `data` around a precomputed
/// mean `m`; defined as 0.0 when there are fewer than two observations.
fn variance(data: &[f64], m: f64) -> f64 {
    let n = data.len();
    if n < 2 {
        return 0.0;
    }
    let sum_sq: f64 = data.iter().map(|&x| (x - m) * (x - m)).sum();
    sum_sq / (n - 1) as f64
}
759
/// Percentile of an already-sorted slice using linear interpolation
/// between closest ranks. NaN for an empty slice; the single element for
/// a slice of length one.
fn percentile_sorted(sorted: &[f64], p: f64) -> f64 {
    let n = sorted.len();
    if n == 0 {
        return f64::NAN;
    }
    if n == 1 {
        return sorted[0];
    }
    // Fractional position in [0, n-1], split into whole + fractional part.
    let position = p / 100.0 * (n - 1) as f64;
    let lower = position.floor() as usize;
    let upper = (lower + 1).min(n - 1);
    let weight = position - lower as f64;
    sorted[lower] + weight * (sorted[upper] - sorted[lower])
}
774
775fn t_distribution_two_sided_p(t_abs: f64, df: f64) -> f64 {
779 if df <= 0.0 {
780 return 1.0;
781 }
782 if df > 200.0 {
784 return 2.0 * standard_normal_sf(t_abs);
785 }
786 let x = df / (df + t_abs * t_abs);
788 let a = df / 2.0;
789 let b = 0.5f64;
790 let ibeta = regularized_incomplete_beta(x, a, b);
791 ibeta.clamp(0.0, 1.0)
792}
793
794fn regularized_incomplete_beta(x: f64, a: f64, b: f64) -> f64 {
796 if x <= 0.0 {
797 return 0.0;
798 }
799 if x >= 1.0 {
800 return 1.0;
801 }
802 let switch = (a + 1.0) / (a + b + 2.0);
804 if x > switch {
805 return 1.0 - regularized_incomplete_beta(1.0 - x, b, a);
806 }
807 let ln_front = a * x.ln() + b * (1.0 - x).ln() - ln_beta(a, b);
809 let front = ln_front.exp();
810 let cf = beta_continued_fraction(x, a, b);
812 (front * cf / a).clamp(0.0, 1.0)
813}
814
/// Continued-fraction part of the regularized incomplete beta function,
/// evaluated with the modified Lentz algorithm (cf. Numerical Recipes'
/// `betacf`). Each loop iteration applies one even and one odd term.
fn beta_continued_fraction(x: f64, a: f64, b: f64) -> f64 {
    let max_iter = 200;
    // Convergence tolerance on the per-iteration multiplier.
    let eps = 1e-14;
    // Floor that keeps denominators away from zero (Lentz trick).
    let tiny = 1e-300;
    let mut f = tiny;
    let mut c = f;
    let mut d = 1.0 - (a + b) * x / (a + 1.0);
    if d.abs() < tiny {
        d = tiny;
    }
    d = 1.0 / d;
    f = d;
    for m in 1..=max_iter {
        let m_f = m as f64;
        // Even-numbered term of the continued fraction.
        let aa = m_f * (b - m_f) * x / ((a + 2.0 * m_f - 1.0) * (a + 2.0 * m_f));
        d = 1.0 + aa * d;
        if d.abs() < tiny {
            d = tiny;
        }
        c = 1.0 + aa / c;
        if c.abs() < tiny {
            c = tiny;
        }
        d = 1.0 / d;
        f *= d * c;
        // Odd-numbered term (shadows the even-term coefficient).
        let aa = -(a + m_f) * (a + b + m_f) * x / ((a + 2.0 * m_f) * (a + 2.0 * m_f + 1.0));
        d = 1.0 + aa * d;
        if d.abs() < tiny {
            d = tiny;
        }
        c = 1.0 + aa / c;
        if c.abs() < tiny {
            c = tiny;
        }
        d = 1.0 / d;
        let del = d * c;
        f *= del;
        // Stop once the multiplicative update is within tolerance of 1.
        if (del - 1.0).abs() < eps {
            break;
        }
    }
    f
}
861
862fn ln_beta(a: f64, b: f64) -> f64 {
864 lgamma(a) + lgamma(b) - lgamma(a + b)
865}
866
/// ln Γ(a) via the Lanczos approximation (g = 7, nine coefficients).
/// Arguments below 0.5 go through the reflection formula so the series
/// is only evaluated in its accurate range.
fn lgamma(a: f64) -> f64 {
    const G: f64 = 7.0;
    const C: [f64; 9] = [
        0.999_999_999_999_809_9,
        676.520_368_121_885_1,
        -1259.139216722403,
        771.323_428_777_653,
        -176.615_029_162_141_9,
        12.507343278686905,
        -0.138571095265720,
        9.984369578019572e-6,
        1.505632735149312e-7,
    ];
    if a < 0.5 {
        // Reflection: ln Γ(a) = ln π − ln|sin(πa)| − ln Γ(1 − a).
        return std::f64::consts::PI.ln()
            - (std::f64::consts::PI * a).sin().abs().ln()
            - lgamma(1.0 - a);
    }
    let x = a - 1.0;
    let t = x + G + 0.5;
    // Lanczos series: C[0] plus the partial-fraction tail, accumulated in
    // coefficient order.
    let mut tail = 0.0f64;
    for (i, &coeff) in C[1..].iter().enumerate() {
        tail += coeff / (x + i as f64 + 1.0);
    }
    let ser = C[0] + tail;
    (2.0 * std::f64::consts::PI).sqrt().ln() + ser.abs().ln() + (x + 0.5) * t.ln() - t
}
896
897fn standard_normal_sf(z: f64) -> f64 {
899 0.5 * erfc(z / std::f64::consts::SQRT_2)
900}
901
/// Complementary error function via the Abramowitz & Stegun 7.1.26
/// rational approximation (|error| < 1.5e-7), extended to negative
/// arguments through erfc(-x) = 2 - erfc(x).
fn erfc(x: f64) -> f64 {
    if x < 0.0 {
        return 2.0 - erfc(-x);
    }
    let t = 1.0 / (1.0 + 0.3275911 * x);
    // Horner evaluation of the degree-5 polynomial in t.
    let horner = 0.254829592
        + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429)));
    let poly = t * horner;
    poly * (-x * x).exp()
}
914
/// 64-bit FNV-1a hash of a string's UTF-8 bytes, using the reference
/// offset basis and prime.
fn fnv1a_hash(s: &str) -> u64 {
    const OFFSET_BASIS: u64 = 14695981039346656037;
    const PRIME: u64 = 1099511628211;
    s.bytes()
        .fold(OFFSET_BASIS, |acc, byte| (acc ^ byte as u64).wrapping_mul(PRIME))
}
924
// Unit tests covering variant construction, deterministic routing,
// recording/validation, the statistical tests, and report helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // Helper: builds a runner with a tiny `min_samples` so analysis can
    // run on short test workloads.
    fn make_runner(split: f64) -> ABTestRunner {
        let control = ModelVariant::new("control", |_| vec![0.0f64; 4]);
        let treatment = ModelVariant::new("treatment", |_| vec![1.0f64; 4]);
        let config = ABTestConfig {
            traffic_split: split,
            min_samples: 5,
            ..Default::default()
        };
        ABTestRunner::new(config, control, treatment).expect("runner should construct")
    }

    // The wrapped closure receives the key and its output is returned as-is.
    #[test]
    fn test_model_variant_infer() {
        let v = ModelVariant::new("test", |key| vec![key.len() as f64]);
        let result = v.infer("hello");
        assert_eq!(result, vec![5.0]);
    }

    // Builder setters populate description and version.
    #[test]
    fn test_model_variant_metadata() {
        let v = ModelVariant::new("sage-v1", |_| vec![])
            .with_description("GraphSAGE v1")
            .with_version("1.2.3");
        assert_eq!(v.name, "sage-v1");
        assert_eq!(v.description, "GraphSAGE v1");
        assert_eq!(v.version, "1.2.3");
    }

    // Default config values match the documented defaults.
    #[test]
    fn test_abtest_config_default() {
        let cfg = ABTestConfig::default();
        assert!((cfg.traffic_split - 0.5).abs() < 1e-10);
        assert_eq!(cfg.min_samples, 50);
        assert!((cfg.significance_level - 0.05).abs() < 1e-10);
    }

    // A traffic split outside [0, 1] is rejected at construction time.
    #[test]
    fn test_runner_construction_invalid_split() {
        let ctrl = ModelVariant::new("c", |_| vec![]);
        let trt = ModelVariant::new("t", |_| vec![]);
        let cfg = ABTestConfig {
            traffic_split: 1.5,
            ..Default::default()
        };
        assert!(ABTestRunner::new(cfg, ctrl, trt).is_err());
    }

    // Routing returns an embedding, names a known variant, and counts requests.
    #[test]
    fn test_runner_route() {
        let mut runner = make_runner(0.5);
        for i in 0..20 {
            let key = format!("entity:{i}");
            let (emb, variant) = runner.route(&key).expect("route should succeed");
            assert!(!emb.is_empty());
            assert!(variant == "control" || variant == "treatment");
        }
        assert_eq!(runner.total_requests(), 20);
    }

    // Same seed + same key sequence => identical routing decisions.
    #[test]
    fn test_runner_traffic_split_deterministic() {
        let mut r1 = make_runner(0.3);
        let mut r2 = make_runner(0.3);
        for i in 0..50 {
            let key = format!("k{i}");
            let (_, v1) = r1.route(&key).expect("route 1 ok");
            let (_, v2) = r2.route(&key).expect("route 2 ok");
            assert_eq!(v1, v2, "routing should be deterministic");
        }
    }

    // Non-finite metrics (NaN / infinity) are rejected.
    #[test]
    fn test_runner_record_metric_invalid() {
        let mut runner = make_runner(0.5);
        assert!(runner.record_metric("control", f64::NAN).is_err());
        assert!(runner.record_metric("control", f64::INFINITY).is_err());
    }

    // Recorded metrics show up in per-variant stats with finite means.
    #[test]
    fn test_runner_record_and_stats() {
        let mut runner = make_runner(0.5);
        for i in 0..20 {
            let key = format!("e:{i}");
            let (_, variant) = runner.route(&key).expect("route ok");
            runner
                .record_metric(&variant, (i as f64) * 0.1)
                .expect("record ok");
        }
        let stats = runner.variant_stats();
        assert!(!stats.is_empty());
        for s in stats.values() {
            assert!(s.n > 0);
            assert!(s.mean.is_finite());
        }
    }

    // Identical groups: t ~ 0, p ~ 1, zero mean difference.
    #[test]
    fn test_welchs_ttest_identical_groups() {
        let cfg = ABTestConfig::default();
        let analyzer = ABTestAnalyzer::new(&cfg);
        let data: Vec<f64> = (0..30).map(|i| i as f64).collect();
        let result = analyzer
            .welchs_ttest(&data, &data)
            .expect("t-test should succeed");
        assert!(
            (result.t_statistic).abs() < 1e-10,
            "t should be 0 for identical groups"
        );
        assert!(
            (result.p_value - 1.0).abs() < 0.01,
            "p should be ~1 for identical groups, got {}",
            result.p_value
        );
        assert!((result.mean_diff).abs() < 1e-10);
    }

    // Constant groups far apart: hits the zero-variance branch, p ~ 0.
    #[test]
    fn test_welchs_ttest_clearly_different() {
        let cfg = ABTestConfig::default();
        let analyzer = ABTestAnalyzer::new(&cfg);
        let a: Vec<f64> = (0..50).map(|_| 0.0).collect();
        let b: Vec<f64> = (0..50).map(|_| 100.0).collect();
        let result = analyzer
            .welchs_ttest(&a, &b)
            .expect("t-test should succeed");
        assert!(
            result.p_value < 0.001,
            "p-value should be very small, got {}",
            result.p_value
        );
    }

    // Identical groups: equal rank sums, so z = 0 and p is high.
    #[test]
    fn test_mann_whitney_identical() {
        let cfg = ABTestConfig::default();
        let analyzer = ABTestAnalyzer::new(&cfg);
        let data: Vec<f64> = (0..30).map(|i| i as f64).collect();
        let result = analyzer
            .mann_whitney_u(&data, &data)
            .expect("MWU should succeed");
        assert!(
            result.p_value > 0.3,
            "p-value for identical groups should be high, got {}",
            result.p_value
        );
    }

    // Fully separated groups: U = 0 and a tiny p-value.
    #[test]
    fn test_mann_whitney_clearly_different() {
        let cfg = ABTestConfig::default();
        let analyzer = ABTestAnalyzer::new(&cfg);
        let a: Vec<f64> = (0..40).map(|_| 0.0).collect();
        let b: Vec<f64> = (0..40).map(|_| 10.0).collect();
        let result = analyzer.mann_whitney_u(&a, &b).expect("MWU should succeed");
        assert!(
            result.p_value < 0.001,
            "p-value for clearly different groups should be small, got {}",
            result.p_value
        );
    }

    // Identical groups have zero effect size.
    #[test]
    fn test_cohens_d_no_difference() {
        let cfg = ABTestConfig::default();
        let analyzer = ABTestAnalyzer::new(&cfg);
        let data: Vec<f64> = (0..20).map(|i| i as f64).collect();
        let d = analyzer.cohens_d(&data, &data);
        assert!(
            (d).abs() < 1e-10,
            "Cohen's d should be 0 for identical groups"
        );
    }

    // Constant groups have zero pooled std, so cohens_d degrades to 0.
    #[test]
    fn test_cohens_d_large_effect() {
        let cfg = ABTestConfig::default();
        let analyzer = ABTestAnalyzer::new(&cfg);
        let a: Vec<f64> = vec![0.0f64; 30];
        let b: Vec<f64> = vec![10.0f64; 30];
        let d = analyzer.cohens_d(&a, &b);
        assert!(d.is_finite() || d == 0.0);
    }

    // End-to-end: route, record a metric biased toward the treatment,
    // analyze, and render the summary.
    #[test]
    fn test_full_ab_test_workflow() {
        let control = ModelVariant::new("baseline", |_| vec![0.0f64; 4]);
        let treatment = ModelVariant::new("improved", |_| vec![1.0f64; 4]);
        let config = ABTestConfig {
            traffic_split: 0.5,
            min_samples: 10,
            significance_level: 0.05,
            optimize: OptimizeMetric::Maximize,
            seed: 99,
            max_requests: None,
            min_effect_size: None,
        };
        let mut runner =
            ABTestRunner::new(config, control, treatment).expect("runner should construct");

        // Local LCG generates noisy metrics; treatment gets a higher base.
        let mut rng_state: u64 = 12345;
        for i in 0..100 {
            let key = format!("entity:{i}");
            let (_, variant) = runner.route(&key).expect("route ok");
            rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
            let base = (rng_state >> 32) as f64 / u32::MAX as f64;
            let metric = if variant == "improved" {
                base * 0.3 + 0.7
            } else {
                base * 0.3 + 0.2
            };
            runner.record_metric(&variant, metric).expect("record ok");
        }

        let report = runner.analyze().expect("analysis should succeed");
        assert!(report.control_stats.n >= 10);
        assert!(report.treatment_stats.n >= 10);
        let wins = report.treatment_wins() || !report.significant;
        assert!(wins, "treatment should win or no significant difference");

        let summary = report.summary();
        assert!(summary.contains("A/B Test Report"));
    }

    // Once max_requests is reached, further routing fails.
    #[test]
    fn test_ab_test_max_requests() {
        let ctrl = ModelVariant::new("c", |_| vec![1.0]);
        let trt = ModelVariant::new("t", |_| vec![2.0]);
        let cfg = ABTestConfig {
            max_requests: Some(5),
            ..Default::default()
        };
        let mut runner = ABTestRunner::new(cfg, ctrl, trt).expect("runner ok");
        for i in 0..5 {
            runner.route(&format!("k{i}")).expect("route ok");
        }
        let err = runner.route("k5");
        assert!(err.is_err(), "should error after max_requests");
    }

    // Empty input yields n = 0 and NaN statistics.
    #[test]
    fn test_variant_stats_empty() {
        let stats = VariantStats::from_slice(&[]);
        assert_eq!(stats.n, 0);
        assert!(stats.mean.is_nan());
    }

    // Single observation: mean/min/max all equal the value.
    #[test]
    fn test_variant_stats_single() {
        let stats = VariantStats::from_slice(&[42.0]);
        assert_eq!(stats.n, 1);
        assert_eq!(stats.mean, 42.0);
        assert_eq!(stats.min, 42.0);
        assert_eq!(stats.max, 42.0);
    }

    // Spot-check mean, extremes, and the median on a known slice.
    #[test]
    fn test_variant_stats_known_values() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let stats = VariantStats::from_slice(&data);
        assert_eq!(stats.n, 5);
        assert!((stats.mean - 3.0).abs() < 1e-10);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert!((stats.p50 - 3.0).abs() < 1e-10);
    }

    // relative_improvement stays finite for a normal report.
    #[test]
    fn test_report_relative_improvement() {
        let ctrl = ModelVariant::new("c", |_| vec![0.5f64]);
        let trt = ModelVariant::new("t", |_| vec![0.6f64]);
        let cfg = ABTestConfig {
            min_samples: 5,
            ..Default::default()
        };
        let mut runner = ABTestRunner::new(cfg, ctrl, trt).expect("runner ok");
        for i in 0..30 {
            let (_, v) = runner.route(&format!("k{i}")).expect("route ok");
            let metric = if v == "c" { 0.5 } else { 0.6 };
            runner.record_metric(&v, metric).expect("record ok");
        }
        let report = runner.analyze().expect("analyze ok");
        let ri = report.relative_improvement();
        assert!(ri.is_finite());
    }
}