use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Multiple-comparison correction used when adjusting the significance level.
///
/// `Bonferroni` is the value produced by the derived `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum CorrectionMethod {
/// Bonferroni correction (the default).
#[default]
Bonferroni,
/// Holm-Bonferroni correction.
HolmBonferroni,
/// Benjamini-Hochberg correction.
BenjaminiHochberg,
/// No correction: the significance level is used as given.
None,
}
/// Severity of the multiple-testing concern, from least to most severe.
///
/// The derived `PartialOrd`/`Ord` follow the explicit discriminants below, so
/// levels can be compared directly (e.g. `level >= WarningLevel::Caution`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum WarningLevel {
/// No multiple-testing concern.
None = 0,
/// Low risk.
Info = 1,
/// Moderate risk.
Caution = 2,
/// High risk.
Warning = 3,
/// Critical overfitting risk.
Critical = 4,
}
impl WarningLevel {
pub fn description(&self) -> &'static str {
match self {
WarningLevel::None => "No multiple testing concerns",
WarningLevel::Info => "Low risk - consider walk-forward validation",
WarningLevel::Caution => "Moderate risk - walk-forward analysis recommended",
WarningLevel::Warning => "High risk - walk-forward analysis strongly recommended",
WarningLevel::Critical => {
"Critical overfitting risk - results may be meaningless without validation"
}
}
}
}
/// Snapshot of a multiple-testing assessment, as produced by
/// `MultipleTestingGuard::get_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultipleTestingStats {
/// Number of tests (parameter combinations) recorded so far.
pub n_tests: usize,
/// Significance level before any correction.
pub alpha: f64,
/// Significance level after applying `method`.
pub adjusted_alpha: f64,
/// Correction method that produced `adjusted_alpha`.
pub method: CorrectionMethod,
/// Severity derived from the number of tests.
pub warning_level: WarningLevel,
/// Warning text, if any; `None` when the level carries no message or the
/// risk has been explicitly accepted.
pub warning_message: Option<String>,
/// Whether the overfitting risk was explicitly accepted.
pub risk_accepted: bool,
}
impl MultipleTestingStats {
    /// Packs these statistics into a typed object via
    /// `crate::type_schema::typed_object_from_pairs`.
    ///
    /// The object carries, in order: `n_tests`, `alpha`, `adjusted_alpha`
    /// (numbers), `method` and `warning_level` (their `Debug` names as
    /// strings), `warning_message` (string, or the none slot when absent) and
    /// `risk_accepted` (bool).
    pub fn to_value(&self) -> shape_value::KindedSlot {
        use shape_value::KindedSlot;

        // Single place where an owned string is wrapped into a slot.
        let string_slot = |text: String| KindedSlot::from_string_arc(Arc::new(text));

        let message_slot = match &self.warning_message {
            Some(message) => string_slot(message.clone()),
            None => KindedSlot::none(),
        };
        let method_slot = string_slot(format!("{:?}", self.method));
        let level_slot = string_slot(format!("{:?}", self.warning_level));

        crate::type_schema::typed_object_from_pairs(&[
            ("n_tests", KindedSlot::from_number(self.n_tests as f64)),
            ("alpha", KindedSlot::from_number(self.alpha)),
            ("adjusted_alpha", KindedSlot::from_number(self.adjusted_alpha)),
            ("method", method_slot),
            ("warning_level", level_slot),
            ("warning_message", message_slot),
            ("risk_accepted", KindedSlot::from_bool(self.risk_accepted)),
        ])
    }
}
/// Tracks how many tests (parameter combinations) have been run and derives
/// an adjusted significance level and a warning severity from that count.
#[derive(Debug, Clone)]
pub struct MultipleTestingGuard {
/// Running total of tests recorded via `record_tests`.
combinations_tested: usize,
/// Significance level before any correction.
alpha: f64,
/// Correction method used by `adjusted_alpha`.
method: CorrectionMethod,
/// When true, warning messages are suppressed.
accept_overfitting_risk: bool,
// NOTE(review): the three thresholds below are initialised by `new`
// (50 / 200 / 1000) but are not read by any code visible in this file;
// `warning_level` hard-codes its own band limits. The leading underscore
// presumably exists to silence the unused-field lint - confirm intent.
_caution_threshold: usize,
_warning_threshold: usize,
_critical_threshold: usize,
}
impl Default for MultipleTestingGuard {
fn default() -> Self {
Self::new(0.05)
}
}
impl MultipleTestingGuard {
    /// Creates a guard with no tests recorded.
    ///
    /// `alpha` is the significance level before correction. The guard starts
    /// with the Bonferroni method and with overfitting risk not accepted.
    pub fn new(alpha: f64) -> Self {
        Self {
            combinations_tested: 0,
            alpha,
            method: CorrectionMethod::Bonferroni,
            accept_overfitting_risk: false,
            _caution_threshold: 50,
            _warning_threshold: 200,
            _critical_threshold: 1000,
        }
    }

    /// Builder-style setter for the correction method.
    pub fn with_method(mut self, method: CorrectionMethod) -> Self {
        self.method = method;
        self
    }

    /// Adds `n` to the running total of tests.
    ///
    /// The total saturates at `usize::MAX` instead of overflowing, so a huge
    /// count can neither panic (debug builds) nor wrap around to a small
    /// number and silently drop the warning level (release builds).
    pub fn record_tests(&mut self, n: usize) {
        self.combinations_tested = self.combinations_tested.saturating_add(n);
    }

    /// Returns the number of tests recorded so far.
    pub fn combinations_tested(&self) -> usize {
        self.combinations_tested
    }

    /// Marks the overfitting risk as explicitly accepted, which suppresses
    /// warning messages from this guard.
    pub fn accept_risk(&mut self) {
        self.accept_overfitting_risk = true;
    }

    /// Returns whether the overfitting risk has been explicitly accepted.
    pub fn is_risk_accepted(&self) -> bool {
        self.accept_overfitting_risk
    }

    /// Returns the significance level after applying the configured method.
    ///
    /// With no tests recorded, or with `CorrectionMethod::None`, this is the
    /// unadjusted `alpha`.
    pub fn adjusted_alpha(&self) -> f64 {
        if self.combinations_tested == 0 {
            return self.alpha;
        }
        let n_tests = self.combinations_tested as f64;
        match self.method {
            // Holm-Bonferroni is reported with the same single threshold as
            // plain Bonferroni (alpha / n).
            CorrectionMethod::Bonferroni | CorrectionMethod::HolmBonferroni => {
                self.alpha / n_tests
            }
            // NOTE(review): this is half the Bonferroni threshold, i.e.
            // stricter than Bonferroni. Confirm this is the intended
            // single-number summary for Benjamini-Hochberg.
            CorrectionMethod::BenjaminiHochberg => self.alpha * 0.5 / n_tests,
            CorrectionMethod::None => self.alpha,
        }
    }

    /// Maps the number of recorded tests to a severity band.
    ///
    /// Bands: 0-49 `None`, 50-199 `Info`, 200-999 `Caution`, 1000 and above
    /// `Critical`.
    ///
    /// NOTE(review): `WarningLevel::Warning` is never returned here - confirm
    /// whether a band for it is intended.
    pub fn warning_level(&self) -> WarningLevel {
        match self.combinations_tested {
            0..=49 => WarningLevel::None,
            50..=199 => WarningLevel::Info,
            200..=999 => WarningLevel::Caution,
            _ => WarningLevel::Critical,
        }
    }

    /// Builds a snapshot of the current assessment.
    ///
    /// `warning_message` is `None` when the risk has been accepted or when the
    /// current level carries no message.
    pub fn get_stats(&self) -> MultipleTestingStats {
        let warning_level = self.warning_level();
        let adjusted_alpha = self.adjusted_alpha();
        let warning_message = if self.accept_overfitting_risk {
            None
        } else {
            self.generate_warning_message(warning_level, adjusted_alpha)
        };
        MultipleTestingStats {
            n_tests: self.combinations_tested,
            alpha: self.alpha,
            adjusted_alpha,
            method: self.method,
            warning_level,
            warning_message,
            risk_accepted: self.accept_overfitting_risk,
        }
    }

    /// Label for the adjusted alpha in warning text, matching the configured
    /// method so the message never claims a correction that was not applied.
    fn alpha_label(&self) -> &'static str {
        match self.method {
            CorrectionMethod::Bonferroni => "Bonferroni-adjusted alpha",
            CorrectionMethod::HolmBonferroni => "Holm-Bonferroni-adjusted alpha",
            CorrectionMethod::BenjaminiHochberg => "Benjamini-Hochberg-adjusted alpha",
            CorrectionMethod::None => "Unadjusted alpha",
        }
    }

    /// Formats the warning text for `level`, or `None` when the level is
    /// `WarningLevel::None`.
    fn generate_warning_message(&self, level: WarningLevel, adjusted_alpha: f64) -> Option<String> {
        let alpha_label = self.alpha_label();
        match level {
            WarningLevel::None => None,
            WarningLevel::Info => Some(format!(
                "INFO: {} parameter combinations tested. Consider walk-forward validation.",
                self.combinations_tested
            )),
            WarningLevel::Caution => Some(format!(
                "CAUTION: {} parameter combinations tested. \
                 {}: {:.6}. \
                 Walk-forward analysis recommended.",
                self.combinations_tested, alpha_label, adjusted_alpha
            )),
            WarningLevel::Warning | WarningLevel::Critical => Some(format!(
                "WARNING: {} parameter combinations tested without walk-forward analysis.\n\
                 {}: {:.2e}\n\
                 This many tests dramatically increases false discovery risk.\n\n\
                 To address this:\n\
                 1. Use walk-forward analysis: `walk_forward: {{ ... }}`\n\
                 2. Or explicitly accept risk: `accept_overfitting_risk: true`",
                self.combinations_tested, alpha_label, adjusted_alpha
            )),
        }
    }

    /// Prints the warning to stderr when the level is `Caution` or higher and
    /// the risk has not been accepted; otherwise does nothing.
    pub fn emit_warning_if_needed(&self) {
        if self.accept_overfitting_risk {
            return;
        }
        let stats = self.get_stats();
        if stats.warning_level < WarningLevel::Caution {
            return;
        }
        if let Some(msg) = &stats.warning_message {
            eprintln!("\n{}\n", msg);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks the running total across each band boundary and checks the level
    /// reported at every step.
    #[test]
    fn test_warning_levels() {
        let mut guard = MultipleTestingGuard::new(0.05);
        assert_eq!(guard.warning_level(), WarningLevel::None);

        // (tests added in this step, level expected at the new running total)
        let steps = [
            (50, WarningLevel::Info),
            (150, WarningLevel::Caution),
            (800, WarningLevel::Critical),
        ];
        for (added, expected) in steps {
            guard.record_tests(added);
            assert_eq!(guard.warning_level(), expected);
        }
    }

    /// With the default method, 100 tests at alpha = 0.05 give 0.05 / 100.
    #[test]
    fn test_bonferroni_correction() {
        let mut guard = MultipleTestingGuard::new(0.05);
        guard.record_tests(100);

        let expected = 0.0005;
        let delta = (guard.adjusted_alpha() - expected).abs();
        assert!(delta < 1e-10);
    }

    /// Accepting the risk removes the message and is reflected in the stats.
    #[test]
    fn test_accept_risk_suppresses_warning() {
        let mut guard = MultipleTestingGuard::new(0.05);
        guard.record_tests(500);
        guard.accept_risk();

        let MultipleTestingStats {
            warning_message,
            risk_accepted,
            ..
        } = guard.get_stats();
        assert!(warning_message.is_none());
        assert!(risk_accepted);
    }
}