use scirs2_core::ndarray::Array1;
/// Configuration of an ensemble classifier assembled from heterogeneous base models.
///
/// Bundles the base classifiers, an optional meta-learner, the ensemble training
/// strategy, and the cross-validation and feature-selection settings.
#[derive(Debug, Clone)]
pub struct BifurcationEnsembleClassifier {
    /// The individual models that make up the ensemble.
    pub base_classifiers: Vec<BaseClassifier>,
    /// Optional meta-learner; `None` when no meta-learner is configured.
    /// NOTE(review): presumably combines the base classifiers' outputs — confirm in the trainer.
    pub meta_learner: Option<MetaLearner>,
    /// Strategy used to train the ensemble (see [`EnsembleTrainingStrategy`]).
    pub training_strategy: EnsembleTrainingStrategy,
    /// Cross-validation settings (see [`CrossValidationConfig`]).
    pub cross_validation: CrossValidationConfig,
    /// Feature-selection settings (see [`FeatureSelectionConfig`]).
    pub feature_selection: FeatureSelectionConfig,
}
/// One member model of a [`BifurcationEnsembleClassifier`], together with its hyperparameters.
///
/// NOTE(review): hyperparameter names follow common ML-library conventions; the
/// meanings given on the fields below are assumed from those names — confirm
/// against the code that trains these models.
#[derive(Debug, Clone)]
pub enum BaseClassifier {
    /// A boxed network type defined in the sibling `neural_network` module.
    NeuralNetwork(Box<super::neural_network::BifurcationPredictionNetwork>),
    /// Random-forest hyperparameters.
    RandomForest {
        /// Number of trees in the forest.
        n_trees: usize,
        /// Maximum tree depth; `None` presumably means unlimited.
        max_depth: Option<usize>,
        /// Presumably the minimum sample count needed to split a node.
        min_samples_split: usize,
        /// Presumably the minimum sample count allowed in a leaf.
        min_samples_leaf: usize,
    },
    /// Support-vector-machine hyperparameters.
    SVM {
        /// Kernel function (see [`SVMKernel`]).
        kernel: SVMKernel,
        /// The SVM `C` parameter (presumably the soft-margin regularization strength).
        c_parameter: f64,
        /// Optional kernel coefficient; handling of `None` is up to the consumer.
        gamma: Option<f64>,
    },
    /// Gradient-boosting hyperparameters.
    GradientBoosting {
        /// Number of boosting estimators.
        n_estimators: usize,
        /// Learning rate (presumably the shrinkage applied to each estimator).
        learning_rate: f64,
        /// Maximum depth of each estimator.
        max_depth: usize,
        /// Presumably the fraction of samples drawn per estimator.
        subsample: f64,
    },
    /// k-nearest-neighbors hyperparameters.
    KNN {
        /// Number of neighbors consulted.
        n_neighbors: usize,
        /// Neighbor weighting scheme (see [`KNNWeights`]).
        weights: KNNWeights,
        /// Distance metric (see [`DistanceMetric`]).
        distance_metric: DistanceMetric,
    },
}
/// Kernel function for an SVM base classifier.
///
/// All payloads are plain integers, so the enum is `Copy`, comparable and hashable.
// The variant name `RBF` is public API; renaming it to satisfy the lint would break callers.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SVMKernel {
    /// Linear kernel.
    Linear,
    /// Radial-basis-function kernel.
    RBF,
    /// Polynomial kernel.
    /// NOTE(review): the payload is presumably the polynomial degree — confirm.
    Polynomial(usize),
    /// Sigmoid kernel.
    Sigmoid,
}
/// Neighbor weighting scheme for a k-nearest-neighbors base classifier.
///
/// NOTE(review): the meanings below are inferred from the variant names — confirm
/// in the KNN implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KNNWeights {
    /// Uniform weighting (presumably every neighbor counts equally).
    Uniform,
    /// Distance-based weighting (presumably closer neighbors count more).
    Distance,
}
/// Distance metric for a k-nearest-neighbors base classifier.
///
/// Only `PartialEq` is derived (not `Eq`/`Hash`) because the `Minkowski`
/// payload is an `f64`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DistanceMetric {
    /// Euclidean distance.
    Euclidean,
    /// Manhattan distance.
    Manhattan,
    /// Minkowski distance.
    /// NOTE(review): the payload is presumably the order `p` — confirm.
    Minkowski(f64),
    /// Cosine distance.
    Cosine,
    /// Hamming distance.
    Hamming,
}
/// Meta-learner variants that can be set on a [`BifurcationEnsembleClassifier`].
///
/// NOTE(review): presumably each variant combines base-classifier outputs into a final
/// prediction — confirm in the ensemble trainer.
#[derive(Debug, Clone)]
pub enum MetaLearner {
    /// Linear combination with explicit weights.
    /// NOTE(review): `weights` presumably has one entry per base classifier — confirm.
    LinearCombination { weights: Array1<f64> },
    /// Logistic regression with the given regularization strength.
    LogisticRegression { regularization: f64 },
    /// Neural network; `hidden_layers` presumably lists the width of each hidden layer.
    NeuralNetwork { hidden_layers: Vec<usize> },
    /// Decision tree; a `None` depth presumably means unlimited.
    DecisionTree { max_depth: Option<usize> },
}
/// How the base classifiers of an ensemble are trained.
///
/// `PartialEq` is derived so configurations can be compared; `Eq` is not
/// possible because `Stacking` carries an `f64`.
///
/// NOTE(review): field meanings below are inferred from their names — confirm in the trainer.
#[derive(Debug, Clone, PartialEq)]
pub enum EnsembleTrainingStrategy {
    /// Train on the full dataset.
    FullDataset,
    /// Bagging; `n_samples` per bag, `replacement` selects sampling with replacement.
    Bagging { n_samples: usize, replacement: bool },
    /// Cross-validation with `n_folds` folds, optionally stratified.
    CrossValidation { n_folds: usize, stratified: bool },
    /// Stacking; `holdout_ratio` is presumably the fraction of data held out.
    Stacking { holdout_ratio: f64 },
}
/// Cross-validation settings for a [`BifurcationEnsembleClassifier`].
///
/// All fields are integers/booleans, so full equality (`Eq`) is derived.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CrossValidationConfig {
    /// Number of folds.
    pub n_folds: usize,
    /// Whether folds are stratified.
    pub stratified: bool,
    /// Optional RNG seed.
    /// NOTE(review): `None` presumably means an unseeded (non-reproducible) split — confirm.
    pub random_seed: Option<u64>,
    /// Whether data is shuffled before splitting into folds.
    pub shuffle: bool,
}
/// Feature-selection settings for a [`BifurcationEnsembleClassifier`].
///
/// `PartialEq` is derived (relying on [`FeatureSelectionMethod`] and
/// [`ScoreFunction`] deriving it too); `Eq` is not possible because of the
/// `f64` fields.
#[derive(Debug, Clone, PartialEq)]
pub struct FeatureSelectionConfig {
    /// Selection methods to apply.
    /// NOTE(review): whether these run in sequence or are combined is not visible here.
    pub methods: Vec<FeatureSelectionMethod>,
    /// Optional target number of features.
    pub n_features: Option<usize>,
    /// Optional score threshold.
    pub threshold: Option<f64>,
    /// Whether the selection is cross-validated.
    pub cross_validate: bool,
}

/// A single feature-selection technique and its parameters.
///
/// NOTE(review): parameter meanings below are inferred from their names — confirm
/// in the feature-selection implementation.
// The variant name `PCA` is public API; renaming it to satisfy the lint would break callers.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, PartialEq)]
pub enum FeatureSelectionMethod {
    /// Univariate selection scored with `score_func`.
    UnivariateSelection { score_func: ScoreFunction },
    /// Recursive feature elimination.
    /// NOTE(review): `estimator` is stringly-typed; the accepted names are not visible here.
    RecursiveElimination { estimator: String },
    /// L1-based selection with regularization strength `alpha`.
    L1BasedSelection { alpha: f64 },
    /// Tree-based selection keeping features above `importance_threshold`.
    TreeBasedSelection { importance_threshold: f64 },
    /// Mutual-information-based selection.
    MutualInformation,
    /// Principal component analysis with `n_components` components.
    PCA { n_components: usize },
}

/// Scoring function for [`FeatureSelectionMethod::UnivariateSelection`].
///
/// NOTE(review): the names appear to mirror scikit-learn's `f_classif`, `chi2`,
/// `mutual_info_classif`, `f_regression` and `mutual_info_regression`; that mapping
/// is assumed from the names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScoreFunction {
    /// F-statistic for classification targets.
    FClassif,
    /// Chi-squared statistic.
    Chi2,
    /// Mutual information for classification targets.
    MutualInfoClassif,
    /// F-statistic for regression targets.
    FRegression,
    /// Mutual information for regression targets.
    MutualInfoRegression,
}