use scirs2_core::ndarray::Array1;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Top-level settings for the ML-driven optimization subsystem.
///
/// Bundles one sub-configuration per concern: model selection/training,
/// feature extraction, hardware prediction, online and transfer learning,
/// ensembling, optimization strategy, validation and monitoring.
///
/// Field declaration order is intentionally kept as-is: non-self-describing
/// serde formats (e.g. bincode) encode struct fields by position.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct MLOptimizationConfig {
    /// Whether ML-based optimization is turned on.
    /// NOTE(review): how consumers treat the remaining fields when this is
    /// `false` is not visible here — confirm.
    pub enable_optimization: bool,
    /// Algorithm choice, hyperparameters, training and regularization settings.
    pub model_config: MLModelConfig,
    /// Feature-extraction settings.
    pub feature_extraction: FeatureExtractionConfig,
    /// Hardware-prediction settings.
    pub hardware_prediction: HardwarePredictionConfig,
    /// Online-learning settings.
    pub online_learning: OnlineLearningConfig,
    /// Transfer-learning settings.
    pub transfer_learning: TransferLearningConfig,
    /// Ensemble settings.
    pub ensemble_config: EnsembleConfig,
    /// Optimization-strategy settings.
    pub optimization_strategy: OptimizationStrategyConfig,
    /// Model-validation settings.
    pub validation_config: MLValidationConfig,
    /// Monitoring settings.
    pub monitoring_config: MLMonitoringConfig,
}
/// Which ML models are used and how they are tuned and trained.
///
/// Field declaration order is intentionally kept as-is: non-self-describing
/// serde formats encode struct fields by position.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct MLModelConfig {
    /// Algorithms to use first.
    /// NOTE(review): whether these run together or are tried in sequence is
    /// decided by the consumer — confirm.
    pub primary_algorithms: Vec<MLAlgorithm>,
    /// Algorithms to fall back to; the trigger for falling back is not
    /// defined in this type.
    pub fallback_algorithms: Vec<MLAlgorithm>,
    /// Hyperparameters keyed by string (presumably the hyperparameter name —
    /// confirm against callers).
    pub hyperparameters: HashMap<String, MLHyperparameter>,
    /// Training settings.
    pub training_config: TrainingConfig,
    /// Strategy used to choose between candidate models.
    pub model_selection: ModelSelectionStrategy,
    /// Regularization settings.
    pub regularization: RegularizationConfig,
}
/// Machine-learning algorithm families that can be selected for optimization.
///
/// Derives `Hash` alongside `Eq` so an algorithm can be used directly as a
/// `HashMap`/`HashSet` key (for example, per-algorithm settings or results).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum MLAlgorithm {
    DeepNeuralNetwork,
    RandomForest,
    GradientBoosting,
    SupportVectorMachine,
    GaussianProcess,
    EnsembleMethods,
    ReinforcementLearning,
    QuantumNeuralNetwork,
    GraphNeuralNetwork,
    TransformerNetwork,
    BayesianNetwork,
    /// An algorithm outside the built-in set, identified by a free-form string.
    Custom(String),
}
/// A single hyperparameter: its declared type, current value, an optional
/// search space and an importance weight.
///
/// `parameter_type` and `value` are stored independently; nothing in this
/// type enforces that they agree (e.g. `HyperparameterType::Float` paired with
/// `HyperparameterValue::Float`).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct MLHyperparameter {
    /// Declared kind of this hyperparameter.
    pub parameter_type: HyperparameterType,
    /// Current value.
    pub value: HyperparameterValue,
    /// Optional search space for tuning; `None` when none is specified.
    pub search_space: Option<HyperparameterSearchSpace>,
    /// Relative importance weight. Its scale and valid range are not enforced
    /// here — NOTE(review): confirm the expected range with consumers.
    pub importance: f64,
}
/// Declared kind of a hyperparameter.
///
/// Each variant corresponds to the identically named variant of
/// [`HyperparameterValue`]. The enum is field-less, so it derives `Copy`
/// (cheap to pass by value) and `Hash` (usable as a map/set key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum HyperparameterType {
    Integer,
    Float,
    Categorical,
    Boolean,
    Array,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HyperparameterValue {
Integer(i64),
Float(f64),
Categorical(String),
Boolean(bool),
Array(Array1<f64>),
}
/// Space of candidate values explored when tuning a hyperparameter.
///
/// Derives `PartialEq` (not `Eq`, because of the `f64` bounds) so search
/// spaces can be compared.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum HyperparameterSearchSpace {
    /// Integer range given by two bounds (presumably `(low, high)`; whether the
    /// bounds are inclusive is not specified here — TODO confirm).
    IntegerRange(i64, i64),
    /// Float range given by two bounds (presumably `(low, high)`; inclusivity
    /// not specified here — TODO confirm).
    FloatRange(f64, f64),
    /// Explicit list of allowed categorical values.
    CategoricalOptions(Vec<String>),
    /// Boolean parameter; carries no data (presumably both `true` and `false`
    /// are candidates — confirm).
    BooleanOptions,
    /// List of bound pairs for array-valued parameters (presumably one
    /// `(low, high)` pair per array element — confirm).
    ArrayBounds(Vec<(f64, f64)>),
}
/// Strategy used to choose among candidate models.
///
/// The enum is field-less, so it derives `Copy` (cheap to pass by value) and
/// `Hash` (usable as a map/set key) in addition to `Eq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelSelectionStrategy {
    CrossValidation,
    HoldoutValidation,
    BootstrapValidation,
    TimeSeriesValidation,
    BayesianModelSelection,
    EnsembleSelection,
}
use super::{
ensemble::EnsembleConfig,
features::FeatureExtractionConfig,
hardware::HardwarePredictionConfig,
monitoring::MLMonitoringConfig,
online_learning::OnlineLearningConfig,
optimization::OptimizationStrategyConfig,
training::{RegularizationConfig, TrainingConfig},
transfer_learning::TransferLearningConfig,
validation::MLValidationConfig,
};