#[derive(Debug, Clone)]
pub struct EvaluationConfig {
pub cv_strategy: CrossValidationStrategy,
pub metrics: Vec<EvaluationMetric>,
pub test_size: f64,
pub validation_size: f64,
pub random_seed: Option<u64>,
}
#[derive(Debug, Clone)]
pub enum CrossValidationStrategy {
KFold { k: usize },
StratifiedKFold { k: usize },
TimeSeriesSplit { n_splits: usize },
LeaveOneOut,
Bootstrap { n_bootstrap: usize },
HoldOut { test_size: f64 },
}
#[derive(Debug, Clone)]
pub enum EvaluationMetric {
Accuracy,
Precision,
Recall,
F1Score,
AUC,
MeanSquaredError,
MeanAbsoluteError,
R2Score,
QuantumAdvantage,
ResourceEfficiency,
InferenceTime,
TrainingTime,
ModelComplexity,
Robustness,
}
impl Default for EvaluationConfig {
fn default() -> Self {
Self {
cv_strategy: CrossValidationStrategy::KFold { k: 5 },
metrics: vec![EvaluationMetric::Accuracy, EvaluationMetric::F1Score],
test_size: 0.2,
validation_size: 0.2,
random_seed: Some(42),
}
}
}
impl EvaluationConfig {
pub fn quick() -> Self {
Self {
cv_strategy: CrossValidationStrategy::HoldOut { test_size: 0.3 },
metrics: vec![EvaluationMetric::Accuracy],
test_size: 0.3,
validation_size: 0.2,
random_seed: Some(42),
}
}
pub fn rigorous() -> Self {
Self {
cv_strategy: CrossValidationStrategy::StratifiedKFold { k: 10 },
metrics: vec![
EvaluationMetric::Accuracy,
EvaluationMetric::Precision,
EvaluationMetric::Recall,
EvaluationMetric::F1Score,
EvaluationMetric::AUC,
EvaluationMetric::QuantumAdvantage,
EvaluationMetric::ResourceEfficiency,
],
test_size: 0.15,
validation_size: 0.15,
random_seed: Some(42),
}
}
pub fn production() -> Self {
Self {
cv_strategy: CrossValidationStrategy::StratifiedKFold { k: 5 },
metrics: vec![
EvaluationMetric::Accuracy,
EvaluationMetric::F1Score,
EvaluationMetric::InferenceTime,
EvaluationMetric::Robustness,
],
test_size: 0.2,
validation_size: 0.2,
random_seed: Some(42),
}
}
pub fn time_series() -> Self {
Self {
cv_strategy: CrossValidationStrategy::TimeSeriesSplit { n_splits: 5 },
metrics: vec![
EvaluationMetric::MeanAbsoluteError,
EvaluationMetric::MeanSquaredError,
EvaluationMetric::R2Score,
],
test_size: 0.2,
validation_size: 0.2,
random_seed: Some(42),
}
}
pub fn regression() -> Self {
Self {
cv_strategy: CrossValidationStrategy::KFold { k: 5 },
metrics: vec![
EvaluationMetric::MeanSquaredError,
EvaluationMetric::MeanAbsoluteError,
EvaluationMetric::R2Score,
],
test_size: 0.2,
validation_size: 0.2,
random_seed: Some(42),
}
}
pub fn quantum_focused() -> Self {
Self {
cv_strategy: CrossValidationStrategy::KFold { k: 5 },
metrics: vec![
EvaluationMetric::Accuracy,
EvaluationMetric::QuantumAdvantage,
EvaluationMetric::ResourceEfficiency,
EvaluationMetric::ModelComplexity,
],
test_size: 0.2,
validation_size: 0.2,
random_seed: Some(42),
}
}
}
impl EvaluationMetric {
pub fn is_maximization(&self) -> bool {
matches!(
self,
EvaluationMetric::Accuracy
| EvaluationMetric::Precision
| EvaluationMetric::Recall
| EvaluationMetric::F1Score
| EvaluationMetric::AUC
| EvaluationMetric::R2Score
| EvaluationMetric::QuantumAdvantage
| EvaluationMetric::ResourceEfficiency
| EvaluationMetric::Robustness
)
}
pub fn name(&self) -> &'static str {
match self {
EvaluationMetric::Accuracy => "accuracy",
EvaluationMetric::Precision => "precision",
EvaluationMetric::Recall => "recall",
EvaluationMetric::F1Score => "f1_score",
EvaluationMetric::AUC => "auc",
EvaluationMetric::MeanSquaredError => "mse",
EvaluationMetric::MeanAbsoluteError => "mae",
EvaluationMetric::R2Score => "r2_score",
EvaluationMetric::QuantumAdvantage => "quantum_advantage",
EvaluationMetric::ResourceEfficiency => "resource_efficiency",
EvaluationMetric::InferenceTime => "inference_time",
EvaluationMetric::TrainingTime => "training_time",
EvaluationMetric::ModelComplexity => "model_complexity",
EvaluationMetric::Robustness => "robustness",
}
}
}
impl CrossValidationStrategy {
pub fn n_splits(&self) -> usize {
match self {
CrossValidationStrategy::KFold { k } => *k,
CrossValidationStrategy::StratifiedKFold { k } => *k,
CrossValidationStrategy::TimeSeriesSplit { n_splits } => *n_splits,
CrossValidationStrategy::LeaveOneOut => 1, CrossValidationStrategy::Bootstrap { n_bootstrap } => *n_bootstrap,
CrossValidationStrategy::HoldOut { .. } => 1,
}
}
pub fn is_stratified(&self) -> bool {
matches!(self, CrossValidationStrategy::StratifiedKFold { .. })
}
pub fn is_time_series_safe(&self) -> bool {
matches!(
self,
CrossValidationStrategy::TimeSeriesSplit { .. }
| CrossValidationStrategy::HoldOut { .. }
)
}
}