use std::fmt;
use std::hash::Hash;
/// A typed key that identifies one named parameter.
///
/// Every implementor must be a plain value type: copyable, comparable for
/// equality, hashable and debug-printable (these are the bounds `HashMap`
/// and `HashSet` place on their keys, plus `Copy` and `Debug`).
pub trait ParamKey: fmt::Debug + Copy + Clone + Eq + Hash {
    /// Returns the fixed string name associated with this key.
    ///
    /// The returned slice is `'static`, so it can be stored or compared
    /// without tying it to the lifetime of the key itself.
    fn name(&self) -> &'static str;
}
/// Typed keys for the random-forest parameter set.
///
/// Each variant corresponds to one snake_case parameter name, available
/// through this type's `ParamKey::name` implementation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum RandomForestParam {
    /// Key for the `n_estimators` parameter.
    NEstimators,
    /// Key for the `max_depth` parameter.
    MaxDepth,
    /// Key for the `min_samples_split` parameter.
    MinSamplesSplit,
    /// Key for the `min_samples_leaf` parameter.
    MinSamplesLeaf,
    /// Key for the `max_features` parameter.
    MaxFeatures,
    /// Key for the `bootstrap` parameter.
    Bootstrap,
}
impl ParamKey for RandomForestParam {
fn name(&self) -> &'static str {
match self {
Self::NEstimators => "n_estimators",
Self::MaxDepth => "max_depth",
Self::MinSamplesSplit => "min_samples_split",
Self::MinSamplesLeaf => "min_samples_leaf",
Self::MaxFeatures => "max_features",
Self::Bootstrap => "bootstrap",
}
}
}
impl fmt::Display for RandomForestParam {
    /// Writes the key's parameter name (the same string `name()` returns).
    ///
    /// Uses `Formatter::pad` so that width, fill and alignment flags given by
    /// the caller (e.g. `{:>20}`) are applied, and a precision acts as a
    /// maximum length, exactly as for a plain `&str`. With a bare `{}` the
    /// output is just the name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.name())
    }
}
/// Typed keys for the gradient-boosting parameter set.
///
/// Each variant corresponds to one snake_case parameter name, available
/// through this type's `ParamKey::name` implementation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GradientBoostingParam {
    /// Key for the `n_estimators` parameter.
    NEstimators,
    /// Key for the `learning_rate` parameter.
    LearningRate,
    /// Key for the `max_depth` parameter.
    MaxDepth,
    /// Key for the `subsample` parameter.
    Subsample,
    /// Key for the `min_samples_leaf` parameter.
    MinSamplesLeaf,
}
impl ParamKey for GradientBoostingParam {
fn name(&self) -> &'static str {
match self {
Self::NEstimators => "n_estimators",
Self::LearningRate => "learning_rate",
Self::MaxDepth => "max_depth",
Self::Subsample => "subsample",
Self::MinSamplesLeaf => "min_samples_leaf",
}
}
}
impl fmt::Display for GradientBoostingParam {
    /// Writes the key's parameter name (the same string `name()` returns).
    ///
    /// Uses `Formatter::pad` so that width, fill and alignment flags given by
    /// the caller (e.g. `{:>20}`) are applied, and a precision acts as a
    /// maximum length, exactly as for a plain `&str`. With a bare `{}` the
    /// output is just the name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.name())
    }
}
/// Typed keys for the k-nearest-neighbors parameter set.
///
/// Each variant corresponds to one snake_case parameter name, available
/// through this type's `ParamKey::name` implementation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum KNNParam {
    /// Key for the `n_neighbors` parameter.
    NNeighbors,
    /// Key for the `weights` parameter.
    Weights,
    /// Key for the `metric` parameter.
    Metric,
    /// Key for the `leaf_size` parameter.
    LeafSize,
}
impl ParamKey for KNNParam {
fn name(&self) -> &'static str {
match self {
Self::NNeighbors => "n_neighbors",
Self::Weights => "weights",
Self::Metric => "metric",
Self::LeafSize => "leaf_size",
}
}
}
impl fmt::Display for KNNParam {
    /// Writes the key's parameter name (the same string `name()` returns).
    ///
    /// Uses `Formatter::pad` so that width, fill and alignment flags given by
    /// the caller (e.g. `{:>20}`) are applied, and a precision acts as a
    /// maximum length, exactly as for a plain `&str`. With a bare `{}` the
    /// output is just the name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.name())
    }
}
/// Typed keys for the linear-model parameter set.
///
/// Each variant corresponds to one snake_case parameter name, available
/// through this type's `ParamKey::name` implementation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LinearParam {
    /// Key for the `alpha` parameter.
    Alpha,
    /// Key for the `l1_ratio` parameter.
    L1Ratio,
    /// Key for the `fit_intercept` parameter.
    FitIntercept,
    /// Key for the `max_iter` parameter.
    MaxIter,
    /// Key for the `tol` parameter.
    Tol,
}
impl ParamKey for LinearParam {
fn name(&self) -> &'static str {
match self {
Self::Alpha => "alpha",
Self::L1Ratio => "l1_ratio",
Self::FitIntercept => "fit_intercept",
Self::MaxIter => "max_iter",
Self::Tol => "tol",
}
}
}
impl fmt::Display for LinearParam {
    /// Writes the key's parameter name (the same string `name()` returns).
    ///
    /// Uses `Formatter::pad` so that width, fill and alignment flags given by
    /// the caller (e.g. `{:>20}`) are applied, and a precision acts as a
    /// maximum length, exactly as for a plain `&str`. With a bare `{}` the
    /// output is just the name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.name())
    }
}
/// Typed keys for the decision-tree parameter set.
///
/// Each variant corresponds to one snake_case parameter name, available
/// through this type's `ParamKey::name` implementation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DecisionTreeParam {
    /// Key for the `max_depth` parameter.
    MaxDepth,
    /// Key for the `min_samples_split` parameter.
    MinSamplesSplit,
    /// Key for the `min_samples_leaf` parameter.
    MinSamplesLeaf,
    /// Key for the `criterion` parameter.
    Criterion,
    /// Key for the `splitter` parameter.
    Splitter,
}
impl ParamKey for DecisionTreeParam {
fn name(&self) -> &'static str {
match self {
Self::MaxDepth => "max_depth",
Self::MinSamplesSplit => "min_samples_split",
Self::MinSamplesLeaf => "min_samples_leaf",
Self::Criterion => "criterion",
Self::Splitter => "splitter",
}
}
}
impl fmt::Display for DecisionTreeParam {
    /// Writes the key's parameter name (the same string `name()` returns).
    ///
    /// Uses `Formatter::pad` so that width, fill and alignment flags given by
    /// the caller (e.g. `{:>20}`) are applied, and a precision acts as a
    /// maximum length, exactly as for a plain `&str`. With a bare `{}` the
    /// output is just the name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.name())
    }
}
/// Typed keys for the k-means parameter set.
///
/// Each variant corresponds to one snake_case parameter name, available
/// through this type's `ParamKey::name` implementation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum KMeansParam {
    /// Key for the `n_clusters` parameter.
    NClusters,
    /// Key for the `max_iter` parameter.
    MaxIter,
    /// Key for the `n_init` parameter.
    NInit,
    /// Key for the `init` parameter.
    Init,
    /// Key for the `tol` parameter.
    Tol,
}
impl ParamKey for KMeansParam {
fn name(&self) -> &'static str {
match self {
Self::NClusters => "n_clusters",
Self::MaxIter => "max_iter",
Self::NInit => "n_init",
Self::Init => "init",
Self::Tol => "tol",
}
}
}
impl fmt::Display for KMeansParam {
    /// Writes the key's parameter name (the same string `name()` returns).
    ///
    /// Uses `Formatter::pad` so that width, fill and alignment flags given by
    /// the caller (e.g. `{:>20}`) are applied, and a precision acts as a
    /// maximum length, exactly as for a plain `&str`. With a bare `{}` the
    /// output is just the name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.name())
    }
}
// The test module is compiled only for test builds (`cfg(test)`) and its
// source is loaded from `params_tests.rs` via the `#[path]` attribute rather
// than from the default `tests.rs` / `tests/mod.rs` location.
#[cfg(test)]
#[path = "params_tests.rs"]
mod tests;