//! Auto-generated module
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
use std::collections::{HashMap, VecDeque};
use std::time::{Duration, Instant};
use super::types::{Distribution, EnsembleConfig, EnsembleType, ExplanationMethod, ExplorationStrategy, ExtractionPerformance, ExtractorType, FairnessMetrics, FeatureExtractor, FeatureNormalization, FeatureSelectionCriteria, FeatureStatistics, MLModel, MLModelType, MLOptimizationEngine, ModelComplexityMetrics, ModelPerformance, OnlineLearningConfig, PerformanceTrend, PotentialFunction, RLAgentType, RLLearningParams, ReinforcementLearningAgent, RewardShaping, SearchSpace, StabilityMetrics, TrainingExample, TrainingStatus, ValidationMetrics, ValidationSplit, VotingStrategy};
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ml_engine_creation() {
let engine = MLOptimizationEngine::new();
assert!(engine.models.is_empty());
assert!(engine.training_data.is_empty());
assert!(engine.feature_extractors.is_empty());
}
#[test]
fn test_model_addition() {
let mut engine = MLOptimizationEngine::new();
let model = MLModel {
id: "test_model".to_string(),
model_type: MLModelType::LinearRegression,
parameters: HashMap::new(),
hyperparameters: HashMap::new(),
training_status: TrainingStatus::Untrained,
accuracy: 0.0,
confidence: 0.0,
last_training: Instant::now(),
prediction_history: Vec::new(),
complexity_metrics: ModelComplexityMetrics {
parameter_count: 100,
flops: 1000,
memory_footprint: 1024,
depth: 1,
branching_factor: 1.0,
effective_complexity: 0.5,
},
interpretability_score: 0.9,
training_duration: Duration::from_secs(60),
model_size: 1024,
validation_metrics: ValidationMetrics {
cross_validation_score: 0.8,
hold_out_score: 0.78,
bootstrap_score: 0.82,
time_series_score: 0.75,
consistency_score: 0.85,
},
};
engine.add_model(model);
assert_eq!(engine.models.len(), 1);
assert!(engine.models.contains_key("test_model"));
}
#[test]
fn test_training_data_addition() {
let mut engine = MLOptimizationEngine::new();
let example = TrainingExample {
features: {
let mut features = HashMap::new();
features.insert("memory_usage".to_string(), 0.7);
features.insert("allocation_frequency".to_string(), 10.0);
features
},
targets: {
let mut targets = HashMap::new();
targets.insert("performance".to_string(), 0.85);
targets
},
weight: 1.0,
timestamp: Instant::now(),
source: "test".to_string(),
quality_score: 0.9,
metadata: HashMap::new(),
feature_correlations: HashMap::new(),
difficulty_score: 0.5,
validation_split: ValidationSplit::Train,
};
engine.add_training_data(example);
assert_eq!(engine.training_data.len(), 1);
}
#[test]
fn test_feature_extractor_types() {
let extractors = vec![
ExtractorType::Statistical, ExtractorType::Temporal,
ExtractorType::Performance, ExtractorType::MemoryUsage,
];
for extractor_type in extractors {
let extractor = FeatureExtractor {
name: format!("{:?}_extractor", extractor_type),
extractor_type,
importance: 0.8,
parameters: HashMap::new(),
feature_stats: FeatureStatistics {
mean: 0.5,
std_dev: 0.2,
min: 0.0,
max: 1.0,
median: 0.5,
quartiles: (0.25, 0.5, 0.75),
skewness: 0.0,
kurtosis: 0.0,
missing_rate: 0.0,
unique_values: 100,
},
extraction_performance: ExtractionPerformance {
extraction_time: Duration::from_millis(10),
memory_usage: 1024,
success_rate: 0.99,
error_rate: 0.01,
throughput: 1000.0,
},
normalization: FeatureNormalization::ZScoreNormalization,
selection_criteria: FeatureSelectionCriteria {
min_importance: 0.1,
max_correlation: 0.9,
information_gain_threshold: 0.1,
variance_threshold: 0.01,
stability_score: 0.8,
},
dependencies: Vec::new(),
validation_rules: Vec::new(),
};
assert_eq!(extractor.extractor_type, extractor_type);
}
}
#[test]
fn test_model_types() {
let model_types = vec![
MLModelType::LinearRegression, MLModelType::DecisionTree,
MLModelType::RandomForest, MLModelType::NeuralNetwork,
MLModelType::ReinforcementLearning,
];
for model_type in model_types {
assert_ne!(model_type, MLModelType::EnsembleModel);
}
}
#[test]
fn test_online_learning_config() {
let config = OnlineLearningConfig::default();
assert!(config.enabled);
assert!(config.learning_rate > 0.0);
assert!(config.batch_size > 0);
assert!(config.min_examples > 0);
}
#[test]
fn test_model_performance_metrics() {
let performance = ModelPerformance {
accuracy: 0.85,
mse: 0.1,
mae: 0.08,
rmse: 0.316,
r_squared: 0.72,
adjusted_r_squared: 0.70,
precision: 0.87,
recall: 0.83,
f1_score: 0.85,
auc_roc: 0.89,
auc_pr: 0.86,
trend: PerformanceTrend::Improving,
cv_scores: vec![0.84, 0.86, 0.85],
training_time: Duration::from_secs(300),
inference_time: Duration::from_millis(10),
training_memory: 512 * 1024 * 1024,
inference_memory: 64 * 1024 * 1024,
stability_metrics: StabilityMetrics {
prediction_variance: 0.05,
feature_importance_stability: 0.92,
cross_validation_stability: 0.88,
temporal_stability: 0.90,
robustness_score: 0.87,
},
fairness_metrics: FairnessMetrics {
demographic_parity: 0.95,
equalized_odds: 0.93,
individual_fairness: 0.91,
group_fairness: 0.94,
bias_score: 0.1,
},
};
assert!(performance.accuracy > 0.8);
assert!(performance.f1_score > 0.8);
assert_eq!(performance.trend, PerformanceTrend::Improving);
}
#[test]
fn test_ensemble_configuration() {
let config = EnsembleConfig::default();
assert_eq!(config.ensemble_type, EnsembleType::Voting);
assert_eq!(config.voting_strategy, VotingStrategy::Weighted);
assert!(config.max_models > 0);
}
#[test]
fn test_explanation_methods() {
let methods = vec![
ExplanationMethod::SHAP, ExplanationMethod::LIME,
ExplanationMethod::PermutationImportance,
];
for method in methods {
assert_ne!(method, ExplanationMethod::Anchors);
}
}
#[test]
fn test_hyperparameter_search_space() {
let continuous_space = SearchSpace::Continuous {
min: 0.001,
max: 1.0,
distribution: Distribution::LogUniform,
};
let integer_space = SearchSpace::Integer {
min: 1,
max: 100,
};
let categorical_space = SearchSpace::Categorical {
choices: vec!["relu".to_string(), "tanh".to_string(), "sigmoid".to_string(),],
};
match continuous_space {
SearchSpace::Continuous { min, max, .. } => {
assert!(min < max);
}
_ => panic!("Expected continuous search space"),
}
match integer_space {
SearchSpace::Integer { min, max } => {
assert!(min < max);
}
_ => panic!("Expected integer search space"),
}
match categorical_space {
SearchSpace::Categorical { choices } => {
assert_eq!(choices.len(), 3);
}
_ => panic!("Expected categorical search space"),
}
}
#[test]
fn test_reinforcement_learning_agent() {
let agent = ReinforcementLearningAgent {
agent_type: RLAgentType::QLearning,
weights: HashMap::new(),
learning_params: RLLearningParams {
learning_rate: 0.1,
discount_factor: 0.99,
epsilon: 0.1,
epsilon_decay: 0.995,
epsilon_min: 0.01,
target_update_frequency: 100,
replay_buffer_size: 10000,
batch_size: 32,
tau: 0.005,
},
replay_buffer: VecDeque::new(),
exploration_strategy: ExplorationStrategy::EpsilonGreedy {
epsilon: 0.1,
decay_rate: 0.995,
},
reward_shaping: RewardShaping {
potential_function: PotentialFunction::Linear {
coefficients: vec![1.0, - 0.5, 0.3],
},
shaping_factor: 0.1,
intrinsic_motivation: true,
curiosity_driven: true,
},
policy_network: None,
value_network: None,
target_networks: None,
coordination: None,
};
assert_eq!(agent.agent_type, RLAgentType::QLearning);
assert!(agent.learning_params.learning_rate > 0.0);
assert!(agent.learning_params.discount_factor < 1.0);
}
}