1use crate::hybrid_classifier::{ClassificationMethod, HybridResult};
31use crate::ml_features::InferredOwnership;
32use serde::{Deserialize, Serialize};
33use std::collections::HashMap;
34use std::time::{Duration, Instant};
35
36#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
38pub enum TestVariant {
39 Control,
41 Treatment,
43}
44
45impl std::fmt::Display for TestVariant {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 match self {
48 TestVariant::Control => write!(f, "control"),
49 TestVariant::Treatment => write!(f, "treatment"),
50 }
51 }
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct TestObservation {
57 pub variant: TestVariant,
59 pub variable: String,
61 pub predicted: InferredOwnership,
63 pub ground_truth: Option<InferredOwnership>,
65 pub confidence: f64,
67 pub method: ClassificationMethod,
69 pub latency: Duration,
71 pub correct: Option<bool>,
73}
74
75impl TestObservation {
76 pub fn from_result(
78 variant: TestVariant,
79 result: &HybridResult,
80 ground_truth: Option<InferredOwnership>,
81 latency: Duration,
82 ) -> Self {
83 let correct = ground_truth.as_ref().map(|gt| *gt == result.ownership);
84
85 Self {
86 variant,
87 variable: result.variable.clone(),
88 predicted: result.ownership,
89 ground_truth,
90 confidence: result.confidence,
91 method: result.method,
92 latency,
93 correct,
94 }
95 }
96}
97
98#[derive(Debug, Clone, Default, Serialize, Deserialize)]
100pub struct VariantMetrics {
101 pub count: u64,
103 pub correct: u64,
105 pub with_ground_truth: u64,
107 pub confidence_sum: f64,
109 pub latency_sum_us: u64,
111 pub by_ownership: HashMap<String, u64>,
113 pub by_method: HashMap<String, u64>,
115}
116
117impl VariantMetrics {
118 pub fn new() -> Self {
120 Self::default()
121 }
122
123 pub fn record(&mut self, obs: &TestObservation) {
125 self.count += 1;
126 self.confidence_sum += obs.confidence;
127 self.latency_sum_us += obs.latency.as_micros() as u64;
128
129 *self
131 .by_ownership
132 .entry(format!("{:?}", obs.predicted))
133 .or_insert(0) += 1;
134
135 *self.by_method.entry(obs.method.to_string()).or_insert(0) += 1;
137
138 if let Some(correct) = obs.correct {
140 self.with_ground_truth += 1;
141 if correct {
142 self.correct += 1;
143 }
144 }
145 }
146
147 pub fn accuracy(&self) -> f64 {
149 if self.with_ground_truth == 0 {
150 0.0
151 } else {
152 self.correct as f64 / self.with_ground_truth as f64
153 }
154 }
155
156 pub fn avg_confidence(&self) -> f64 {
158 if self.count == 0 {
159 0.0
160 } else {
161 self.confidence_sum / self.count as f64
162 }
163 }
164
165 pub fn avg_latency_us(&self) -> f64 {
167 if self.count == 0 {
168 0.0
169 } else {
170 self.latency_sum_us as f64 / self.count as f64
171 }
172 }
173}
174
175#[derive(Debug, Clone, Serialize, Deserialize)]
177pub struct ABExperiment {
178 pub name: String,
180 pub description: String,
182 pub control: VariantMetrics,
184 pub treatment: VariantMetrics,
186 pub started_at: u64,
188 pub ended_at: u64,
190}
191
192impl ABExperiment {
193 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
195 let now = std::time::SystemTime::now()
196 .duration_since(std::time::UNIX_EPOCH)
197 .unwrap_or_default()
198 .as_millis() as u64;
199
200 Self {
201 name: name.into(),
202 description: description.into(),
203 control: VariantMetrics::new(),
204 treatment: VariantMetrics::new(),
205 started_at: now,
206 ended_at: 0,
207 }
208 }
209
210 pub fn record(&mut self, obs: &TestObservation) {
212 match obs.variant {
213 TestVariant::Control => self.control.record(obs),
214 TestVariant::Treatment => self.treatment.record(obs),
215 }
216 }
217
218 pub fn end(&mut self) {
220 let now = std::time::SystemTime::now()
221 .duration_since(std::time::UNIX_EPOCH)
222 .unwrap_or_default()
223 .as_millis() as u64;
224 self.ended_at = now;
225 }
226
227 pub fn is_active(&self) -> bool {
229 self.ended_at == 0
230 }
231
232 pub fn total_observations(&self) -> u64 {
234 self.control.count + self.treatment.count
235 }
236
237 pub fn accuracy_lift(&self) -> f64 {
239 self.treatment.accuracy() - self.control.accuracy()
240 }
241
242 pub fn confidence_lift(&self) -> f64 {
244 self.treatment.avg_confidence() - self.control.avg_confidence()
245 }
246
247 pub fn latency_diff_us(&self) -> f64 {
249 self.treatment.avg_latency_us() - self.control.avg_latency_us()
250 }
251
252 pub fn is_treatment_better(&self) -> (bool, f64) {
256 let control_correct = self.control.correct as f64;
260 let control_wrong = (self.control.with_ground_truth - self.control.correct) as f64;
261 let treatment_correct = self.treatment.correct as f64;
262 let treatment_wrong = (self.treatment.with_ground_truth - self.treatment.correct) as f64;
263
264 if self.control.with_ground_truth < 30 || self.treatment.with_ground_truth < 30 {
266 return (false, 1.0);
267 }
268
269 let total = control_correct + control_wrong + treatment_correct + treatment_wrong;
271 if total == 0.0 {
272 return (false, 1.0);
273 }
274
275 let row_total_control = control_correct + control_wrong;
276 let row_total_treatment = treatment_correct + treatment_wrong;
277 let col_total_correct = control_correct + treatment_correct;
278 let col_total_wrong = control_wrong + treatment_wrong;
279
280 let e_cc = (row_total_control * col_total_correct) / total;
282 let e_cw = (row_total_control * col_total_wrong) / total;
283 let e_tc = (row_total_treatment * col_total_correct) / total;
284 let e_tw = (row_total_treatment * col_total_wrong) / total;
285
286 let chi_sq = if e_cc > 0.0 && e_cw > 0.0 && e_tc > 0.0 && e_tw > 0.0 {
288 ((control_correct - e_cc).powi(2) / e_cc)
289 + ((control_wrong - e_cw).powi(2) / e_cw)
290 + ((treatment_correct - e_tc).powi(2) / e_tc)
291 + ((treatment_wrong - e_tw).powi(2) / e_tw)
292 } else {
293 0.0
294 };
295
296 let p_value = if chi_sq > 6.63 {
300 0.01
301 } else if chi_sq > 3.84 {
302 0.05
303 } else {
304 0.5
305 };
306
307 let is_significant = p_value < 0.05 && self.treatment.accuracy() > self.control.accuracy();
308
309 (is_significant, p_value)
310 }
311
312 pub fn to_markdown(&self) -> String {
314 let (is_better, p_value) = self.is_treatment_better();
315 let status = if !self.is_active() {
316 "COMPLETED"
317 } else {
318 "ACTIVE"
319 };
320
321 let recommendation = if is_better {
322 "✅ ADOPT TREATMENT - Statistically significant improvement"
323 } else if self.total_observations() < 100 {
324 "⏳ INSUFFICIENT DATA - Need more observations"
325 } else {
326 "❌ KEEP CONTROL - No significant improvement"
327 };
328
329 format!(
330 r#"## A/B Test Report: {}
331
332**Status**: {} | **Description**: {}
333
334### Summary
335
336| Metric | Control | Treatment | Lift |
337|--------|---------|-----------|------|
338| Observations | {} | {} | - |
339| Accuracy | {:.1}% | {:.1}% | {:+.1}% |
340| Avg Confidence | {:.2} | {:.2} | {:+.2} |
341| Avg Latency | {:.0}μs | {:.0}μs | {:+.0}μs |
342
343### Statistical Analysis
344
345- **Chi-squared p-value**: {:.3}
346- **Treatment better?**: {}
347- **Recommendation**: {}
348
349### Control Group Distribution
350
351{}
352
353### Treatment Group Distribution
354
355{}
356"#,
357 self.name,
358 status,
359 self.description,
360 self.control.count,
361 self.treatment.count,
362 self.control.accuracy() * 100.0,
363 self.treatment.accuracy() * 100.0,
364 self.accuracy_lift() * 100.0,
365 self.control.avg_confidence(),
366 self.treatment.avg_confidence(),
367 self.confidence_lift(),
368 self.control.avg_latency_us(),
369 self.treatment.avg_latency_us(),
370 self.latency_diff_us(),
371 p_value,
372 if is_better { "Yes" } else { "No" },
373 recommendation,
374 self.format_distribution(&self.control),
375 self.format_distribution(&self.treatment),
376 )
377 }
378
379 fn format_distribution(&self, metrics: &VariantMetrics) -> String {
380 let mut lines = Vec::new();
381 for (ownership, count) in &metrics.by_ownership {
382 let pct = if metrics.count > 0 {
383 (*count as f64 / metrics.count as f64) * 100.0
384 } else {
385 0.0
386 };
387 lines.push(format!("- {}: {} ({:.1}%)", ownership, count, pct));
388 }
389 if lines.is_empty() {
390 "- No data".to_string()
391 } else {
392 lines.join("\n")
393 }
394 }
395}
396
397#[derive(Debug)]
399pub struct ABTestRunner {
400 experiment: ABExperiment,
402 strategy: AssignmentStrategy,
404 seed: u64,
406 counter: u64,
408}
409
410#[derive(Debug, Clone, Copy, PartialEq, Eq)]
412pub enum AssignmentStrategy {
413 RoundRobin,
415 Random,
417 AllControl,
419 AllTreatment,
421}
422
423impl ABTestRunner {
424 pub fn new(
426 name: impl Into<String>,
427 description: impl Into<String>,
428 strategy: AssignmentStrategy,
429 ) -> Self {
430 Self {
431 experiment: ABExperiment::new(name, description),
432 strategy,
433 seed: 42,
434 counter: 0,
435 }
436 }
437
438 pub fn with_seed(mut self, seed: u64) -> Self {
440 self.seed = seed;
441 self
442 }
443
444 pub fn assign(&mut self) -> TestVariant {
446 let variant = match self.strategy {
447 AssignmentStrategy::RoundRobin => {
448 if self.counter % 2 == 0 {
449 TestVariant::Control
450 } else {
451 TestVariant::Treatment
452 }
453 }
454 AssignmentStrategy::Random => {
455 self.seed = self.seed.wrapping_mul(6364136223846793005).wrapping_add(1);
457 if self.seed % 2 == 0 {
458 TestVariant::Control
459 } else {
460 TestVariant::Treatment
461 }
462 }
463 AssignmentStrategy::AllControl => TestVariant::Control,
464 AssignmentStrategy::AllTreatment => TestVariant::Treatment,
465 };
466 self.counter += 1;
467 variant
468 }
469
470 pub fn record(
472 &mut self,
473 variant: TestVariant,
474 result: &HybridResult,
475 ground_truth: Option<InferredOwnership>,
476 latency: Duration,
477 ) {
478 let obs = TestObservation::from_result(variant, result, ground_truth, latency);
479 self.experiment.record(&obs);
480 }
481
482 pub fn timed_record<F>(
484 &mut self,
485 variant: TestVariant,
486 classify_fn: F,
487 ground_truth: Option<InferredOwnership>,
488 ) -> HybridResult
489 where
490 F: FnOnce() -> HybridResult,
491 {
492 let start = Instant::now();
493 let result = classify_fn();
494 let latency = start.elapsed();
495
496 self.record(variant, &result, ground_truth, latency);
497 result
498 }
499
500 pub fn finish(&mut self) -> String {
502 self.experiment.end();
503 self.experiment.to_markdown()
504 }
505
506 pub fn experiment(&self) -> &ABExperiment {
508 &self.experiment
509 }
510
511 pub fn experiment_mut(&mut self) -> &mut ABExperiment {
513 &mut self.experiment
514 }
515}
516
#[cfg(test)]
mod tests {
    use super::*;

    /// Labelled observation that always predicts `Owned`. Ground truth is
    /// `Owned` when `correct` is true and `Borrowed` otherwise, so the
    /// `correct` flag agrees with the ownership fields.
    fn labelled(
        variant: TestVariant,
        variable: String,
        correct: bool,
        confidence: f64,
        method: ClassificationMethod,
        latency_us: u64,
    ) -> TestObservation {
        let ground_truth = if correct {
            InferredOwnership::Owned
        } else {
            InferredOwnership::Borrowed
        };
        TestObservation {
            variant,
            variable,
            predicted: InferredOwnership::Owned,
            ground_truth: Some(ground_truth),
            confidence,
            method,
            latency: Duration::from_micros(latency_us),
            correct: Some(correct),
        }
    }

    /// Observation predicting and labelled `Owned`, with a caller-chosen
    /// (possibly absent) `correct` flag.
    fn make_obs(variant: TestVariant, correct: Option<bool>) -> TestObservation {
        TestObservation {
            variant,
            variable: "ptr".to_string(),
            predicted: InferredOwnership::Owned,
            ground_truth: Some(InferredOwnership::Owned),
            confidence: 0.8,
            method: ClassificationMethod::RuleBased,
            latency: Duration::from_micros(100),
            correct,
        }
    }

    #[test]
    fn test_variant_display() {
        let cases = [
            (TestVariant::Control, "control"),
            (TestVariant::Treatment, "treatment"),
        ];
        for (variant, text) in cases {
            assert_eq!(variant.to_string(), text);
        }
    }

    #[test]
    fn variant_metrics_default() {
        let fresh = VariantMetrics::new();
        assert_eq!(fresh.count, 0);
        assert_eq!(fresh.accuracy(), 0.0);
        assert_eq!(fresh.avg_confidence(), 0.0);
    }

    #[test]
    fn variant_metrics_record() {
        let mut metrics = VariantMetrics::new();
        metrics.record(&labelled(
            TestVariant::Control,
            "ptr".to_string(),
            true,
            0.9,
            ClassificationMethod::RuleBased,
            100,
        ));

        assert_eq!(metrics.count, 1);
        assert_eq!(metrics.with_ground_truth, 1);
        assert_eq!(metrics.correct, 1);
        assert!((metrics.accuracy() - 1.0).abs() < 0.001);
        assert!((metrics.avg_confidence() - 0.9).abs() < 0.001);
    }

    #[test]
    fn variant_metrics_accuracy() {
        let mut metrics = VariantMetrics::new();
        for correct in [true, true, true, false] {
            metrics.record(&labelled(
                TestVariant::Control,
                "x".to_string(),
                correct,
                0.8,
                ClassificationMethod::RuleBased,
                50,
            ));
        }

        assert_eq!(metrics.with_ground_truth, 4);
        assert_eq!(metrics.correct, 3);
        assert!((metrics.accuracy() - 0.75).abs() < 0.001);
    }

    #[test]
    fn ab_experiment_new() {
        let experiment = ABExperiment::new("test-001", "Testing hybrid vs rules");
        assert!(experiment.is_active());
        assert_eq!(experiment.name, "test-001");
        assert_eq!(experiment.total_observations(), 0);
    }

    #[test]
    fn ab_experiment_end() {
        let mut experiment = ABExperiment::new("test", "desc");
        assert!(experiment.is_active());
        experiment.end();
        assert!(!experiment.is_active());
    }

    #[test]
    fn ab_experiment_record() {
        let mut experiment = ABExperiment::new("test", "desc");

        let control_obs = labelled(
            TestVariant::Control,
            "a".to_string(),
            true,
            0.8,
            ClassificationMethod::RuleBased,
            100,
        );
        let treatment_obs = TestObservation {
            variant: TestVariant::Treatment,
            variable: "b".to_string(),
            predicted: InferredOwnership::Borrowed,
            ground_truth: Some(InferredOwnership::Borrowed),
            confidence: 0.9,
            method: ClassificationMethod::MachineLearning,
            latency: Duration::from_micros(150),
            correct: Some(true),
        };

        experiment.record(&control_obs);
        experiment.record(&treatment_obs);

        assert_eq!(experiment.control.count, 1);
        assert_eq!(experiment.treatment.count, 1);
        assert_eq!(experiment.total_observations(), 2);
    }

    #[test]
    fn ab_experiment_lift_calculation() {
        let mut experiment = ABExperiment::new("test", "desc");

        // Control: 7/10 correct at confidence 0.7.
        (0..10).for_each(|i| {
            experiment.record(&labelled(
                TestVariant::Control,
                format!("c{}", i),
                i < 7,
                0.7,
                ClassificationMethod::RuleBased,
                100,
            ))
        });
        // Treatment: 9/10 correct at confidence 0.9.
        (0..10).for_each(|i| {
            experiment.record(&labelled(
                TestVariant::Treatment,
                format!("t{}", i),
                i < 9,
                0.9,
                ClassificationMethod::MachineLearning,
                150,
            ))
        });

        assert!((experiment.accuracy_lift() - 0.2).abs() < 0.001);
        assert!((experiment.confidence_lift() - 0.2).abs() < 0.001);
    }

    #[test]
    fn ab_experiment_to_markdown() {
        let experiment = ABExperiment::new("test-001", "Hybrid vs Rules");
        let report = experiment.to_markdown();

        for needle in ["A/B Test Report: test-001", "ACTIVE", "Hybrid vs Rules"] {
            assert!(report.contains(needle));
        }
    }

    #[test]
    fn ab_runner_round_robin() {
        let mut runner = ABTestRunner::new("test", "desc", AssignmentStrategy::RoundRobin);
        let expected = [
            TestVariant::Control,
            TestVariant::Treatment,
            TestVariant::Control,
            TestVariant::Treatment,
        ];
        for want in expected {
            assert_eq!(runner.assign(), want);
        }
    }

    #[test]
    fn ab_runner_all_control() {
        let mut runner = ABTestRunner::new("test", "desc", AssignmentStrategy::AllControl);
        (0..10).for_each(|_| assert_eq!(runner.assign(), TestVariant::Control));
    }

    #[test]
    fn ab_runner_all_treatment() {
        let mut runner = ABTestRunner::new("test", "desc", AssignmentStrategy::AllTreatment);
        (0..10).for_each(|_| assert_eq!(runner.assign(), TestVariant::Treatment));
    }

    #[test]
    fn ab_runner_random_deterministic() {
        let build = || ABTestRunner::new("test", "desc", AssignmentStrategy::Random).with_seed(42);
        let (mut first, mut second) = (build(), build());

        for _ in 0..10 {
            assert_eq!(first.assign(), second.assign());
        }
    }

    #[test]
    fn ab_runner_finish_generates_report() {
        let mut runner = ABTestRunner::new("exp-001", "Testing", AssignmentStrategy::RoundRobin);

        let result = HybridResult {
            variable: "ptr".to_string(),
            ownership: InferredOwnership::Owned,
            confidence: 0.9,
            method: ClassificationMethod::MachineLearning,
            rule_result: Some(InferredOwnership::Owned),
            ml_result: None,
            reasoning: "test".to_string(),
        };

        runner.record(
            TestVariant::Control,
            &result,
            Some(InferredOwnership::Owned),
            Duration::from_micros(100),
        );

        let report = runner.finish();
        assert!(report.contains("exp-001"));
        assert!(report.contains("COMPLETED"));
    }

    #[test]
    fn ab_experiment_insufficient_data() {
        let mut experiment = ABExperiment::new("test", "desc");

        for i in 0..5 {
            experiment.record(&labelled(
                TestVariant::Control,
                format!("c{}", i),
                true,
                0.8,
                ClassificationMethod::RuleBased,
                100,
            ));
        }

        let (is_better, p_value) = experiment.is_treatment_better();
        assert!(!is_better);
        assert!((p_value - 1.0).abs() < 0.001);
    }

    #[test]
    fn ab_experiment_significant_improvement() {
        let mut experiment = ABExperiment::new("test", "desc");

        // Control: 15/30 correct; treatment: 27/30 correct.
        for i in 0..30 {
            experiment.record(&labelled(
                TestVariant::Control,
                format!("c{}", i),
                i < 15,
                0.5,
                ClassificationMethod::RuleBased,
                100,
            ));
        }
        for i in 0..30 {
            experiment.record(&labelled(
                TestVariant::Treatment,
                format!("t{}", i),
                i < 27,
                0.9,
                ClassificationMethod::MachineLearning,
                150,
            ));
        }

        let (is_better, p_value) = experiment.is_treatment_better();
        assert!(is_better);
        assert!(p_value < 0.05);
    }

    #[test]
    fn ab_test_zero_observations_returns_not_significant() {
        let experiment = ABExperiment::new("empty_test", "No data");
        let (is_better, p_value) = experiment.is_treatment_better();
        assert!(!is_better);
        assert!((p_value - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn ab_test_variant_metrics_default_trait() {
        let fresh = VariantMetrics::default();
        assert_eq!(fresh.count, 0);
        assert_eq!(fresh.accuracy(), 0.0);
        assert_eq!(fresh.avg_confidence(), 0.0);
    }

    #[test]
    fn ab_test_sufficient_data_both_correct() {
        let mut experiment = ABExperiment::new("equal", "Both equally good");
        for _ in 0..40 {
            experiment.record(&make_obs(TestVariant::Control, Some(true)));
            experiment.record(&make_obs(TestVariant::Treatment, Some(true)));
        }
        let (is_better, _p_value) = experiment.is_treatment_better();
        assert!(!is_better, "Equal groups should not show treatment as better");
    }

    #[test]
    fn ab_test_total_zero_early_return() {
        let mut experiment = ABExperiment::new("zero_total", "Zero data case");
        // Plenty of observations, but none labelled: the test must not run.
        for variant in [TestVariant::Control, TestVariant::Treatment] {
            for _ in 0..30 {
                let mut obs = make_obs(variant, None);
                obs.ground_truth = None;
                experiment.record(&obs);
            }
        }
        let (is_better, p_value) = experiment.is_treatment_better();
        assert!(!is_better);
        assert!((p_value - 1.0).abs() < f64::EPSILON);
    }
}