use std;
use std::default::Default;
use super::Interval;
/// Learning objective used by XGBoost, passed as the `objective` parameter.
///
/// Each variant maps to one of XGBoost's objective strings (see `Display`).
/// `MultiSoftmax`/`MultiSoftprob` carry the number of classes (emitted as
/// `num_class`), and `RegTweedie` optionally carries the
/// `tweedie_variance_power` value.
#[derive(Copy, Clone, Debug)]
pub enum Objective {
    /// Linear regression (`reg:linear`).
    RegLinear,
    /// Logistic regression, probability output (`reg:logistic`).
    RegLogistic,
    /// Binary classification, probability output (`binary:logistic`).
    BinaryLogistic,
    /// Binary classification, raw score before logistic transform (`binary:logitraw`).
    BinaryLogisticRaw,
    /// GPU variant of linear regression (`gpu:reg:linear`).
    GpuRegLinear,
    /// GPU variant of logistic regression (`gpu:reg:logistic`).
    GpuRegLogistic,
    /// GPU variant of binary logistic classification (`gpu:binary:logistic`).
    GpuBinaryLogistic,
    /// GPU variant of raw binary logistic classification (`gpu:binary:logitraw`).
    GpuBinaryLogisticRaw,
    /// Poisson regression for count data (`count:poisson`).
    CountPoisson,
    /// Cox proportional-hazards survival regression (`survival:cox`).
    SurvivalCox,
    /// Multiclass classification returning the predicted class (`multi:softmax`);
    /// payload is the number of classes.
    MultiSoftmax(u32),
    /// Multiclass classification returning per-class probabilities (`multi:softprob`);
    /// payload is the number of classes.
    MultiSoftprob(u32),
    /// Pairwise learning-to-rank (`rank:pairwise`).
    RankPairwise,
    /// Gamma regression with log-link (`reg:gamma`).
    RegGamma,
    /// Tweedie regression (`reg:tweedie`); payload is an optional variance power.
    RegTweedie(Option<f32>),
}

// Implementing `Display` (rather than `ToString` directly) provides
// `to_string()` through the standard blanket impl and keeps clippy happy.
impl std::fmt::Display for Objective {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let name = match *self {
            Objective::RegLinear => "reg:linear",
            Objective::RegLogistic => "reg:logistic",
            Objective::BinaryLogistic => "binary:logistic",
            Objective::BinaryLogisticRaw => "binary:logitraw",
            Objective::GpuRegLinear => "gpu:reg:linear",
            Objective::GpuRegLogistic => "gpu:reg:logistic",
            Objective::GpuBinaryLogistic => "gpu:binary:logistic",
            Objective::GpuBinaryLogisticRaw => "gpu:binary:logitraw",
            Objective::CountPoisson => "count:poisson",
            Objective::SurvivalCox => "survival:cox",
            Objective::MultiSoftmax(_) => "multi:softmax",
            Objective::MultiSoftprob(_) => "multi:softprob",
            Objective::RankPairwise => "rank:pairwise",
            Objective::RegGamma => "reg:gamma",
            Objective::RegTweedie(_) => "reg:tweedie",
        };
        f.write_str(name)
    }
}

impl Default for Objective {
    /// Defaults to plain linear regression, matching XGBoost's own default.
    fn default() -> Self { Objective::RegLinear }
}
/// Selects which evaluation metrics are reported during training.
#[derive(Clone)]
pub enum Metrics {
// No `eval_metric` entries are emitted for `Auto` (see `as_string_pairs`),
// so the backing library falls back to its own default for the objective.
Auto,
// Emit one `eval_metric` parameter per listed metric, in order.
Custom(Vec<EvaluationMetric>),
}
/// An individual evaluation metric, passed to XGBoost as an `eval_metric`
/// parameter (see `Display` for the exact string each variant produces).
///
/// The `…Cut(n)` variants evaluate only the top `n` results (`metric@n`), and
/// the `…Negative` variants append `-` to flip the direction used when
/// deciding whether a score has improved.
#[derive(Clone, Debug)]
pub enum EvaluationMetric {
    /// Root mean squared error (`rmse`).
    RMSE,
    /// Mean absolute error (`mae`).
    MAE,
    /// Negative log-likelihood (`logloss`).
    LogLoss,
    /// Binary error rate with a classification threshold; `0.5` yields the
    /// plain `error` metric, anything else `error@t`.
    BinaryErrorRate(f32),
    /// Multiclass error rate (`merror`).
    MultiClassErrorRate,
    /// Multiclass log loss (`mlogloss`).
    MultiClassLogLoss,
    /// Area under the ROC curve (`auc`).
    AUC,
    /// Normalized discounted cumulative gain (`ndcg`).
    NDCG,
    /// NDCG over the top-n results (`ndcg@n`).
    NDCGCut(u32),
    /// NDCG with inverted improvement direction (`ndcg-`).
    NDCGNegative,
    /// Top-n NDCG with inverted improvement direction (`ndcg@n-`).
    NDCGCutNegative(u32),
    /// Mean average precision (`map`).
    MAP,
    /// MAP over the top-n results (`map@n`).
    MAPCut(u32),
    /// MAP with inverted improvement direction (`map-`).
    MAPNegative,
    /// Top-n MAP with inverted improvement direction (`map@n-`).
    MAPCutNegative(u32),
    /// Negative log-likelihood for Poisson regression (`poisson-nloglik`).
    PoissonLogLoss,
    /// Negative log-likelihood for gamma regression (`gamma-nloglik`).
    GammaLogLoss,
    /// Negative partial log-likelihood for Cox regression (`cox-nloglik`).
    CoxLogLoss,
    /// Residual deviance for gamma regression (`gamma-deviance`).
    GammaDeviance,
    /// Negative log-likelihood for Tweedie regression (`tweedie-nloglik`).
    TweedieLogLoss,
}

// Implementing `Display` (rather than `ToString` directly) provides
// `to_string()` through the standard blanket impl and keeps clippy happy.
impl std::fmt::Display for EvaluationMetric {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match *self {
            EvaluationMetric::RMSE => f.write_str("rmse"),
            EvaluationMetric::MAE => f.write_str("mae"),
            EvaluationMetric::LogLoss => f.write_str("logloss"),
            EvaluationMetric::BinaryErrorRate(t) => {
                // 0.5 is XGBoost's default threshold, spelled as plain "error".
                if (t - 0.5).abs() < std::f32::EPSILON {
                    f.write_str("error")
                } else {
                    write!(f, "error@{}", t)
                }
            },
            EvaluationMetric::MultiClassErrorRate => f.write_str("merror"),
            EvaluationMetric::MultiClassLogLoss => f.write_str("mlogloss"),
            EvaluationMetric::AUC => f.write_str("auc"),
            EvaluationMetric::NDCG => f.write_str("ndcg"),
            EvaluationMetric::NDCGCut(n) => write!(f, "ndcg@{}", n),
            EvaluationMetric::NDCGNegative => f.write_str("ndcg-"),
            EvaluationMetric::NDCGCutNegative(n) => write!(f, "ndcg@{}-", n),
            EvaluationMetric::MAP => f.write_str("map"),
            EvaluationMetric::MAPCut(n) => write!(f, "map@{}", n),
            EvaluationMetric::MAPNegative => f.write_str("map-"),
            EvaluationMetric::MAPCutNegative(n) => write!(f, "map@{}-", n),
            EvaluationMetric::PoissonLogLoss => f.write_str("poisson-nloglik"),
            EvaluationMetric::GammaLogLoss => f.write_str("gamma-nloglik"),
            EvaluationMetric::CoxLogLoss => f.write_str("cox-nloglik"),
            EvaluationMetric::GammaDeviance => f.write_str("gamma-deviance"),
            EvaluationMetric::TweedieLogLoss => f.write_str("tweedie-nloglik"),
        }
    }
}
/// Parameters controlling the learning task handed to XGBoost.
///
/// Construct via `LearningTaskParametersBuilder` (generated by the
/// `derive_builder` crate); `build()` runs `validate` before returning, and
/// unset fields take the values from `Default`.
#[derive(Builder, Clone)]
#[builder(build_fn(validate = "Self::validate"))]
#[builder(default)]
pub struct LearningTaskParameters {
// The learning objective; default is `Objective::RegLinear`.
pub(crate) objective: Objective,
// Initial prediction score for all instances; default 0.5.
base_score: f32,
// Which evaluation metrics to report; default `Metrics::Auto`.
pub(crate) eval_metrics: Metrics,
// Random seed; default 0.
seed: u64,
}
impl Default for LearningTaskParameters {
    /// Stock settings: linear-regression objective, base score of 0.5,
    /// automatic metric selection, and a seed of 0.
    fn default() -> Self {
        Self {
            objective: Objective::default(),
            base_score: 0.5,
            eval_metrics: Metrics::Auto,
            seed: 0,
        }
    }
}
impl LearningTaskParameters {
    /// The configured learning objective.
    pub fn objective(&self) -> &Objective {
        &self.objective
    }

    /// Replaces the learning objective.
    pub fn set_objective<T: Into<Objective>>(&mut self, objective: T) {
        self.objective = objective.into();
    }

    /// The initial prediction score applied to all instances.
    pub fn base_score(&self) -> f32 {
        self.base_score
    }

    /// Replaces the base score.
    pub fn set_base_score(&mut self, base_score: f32) {
        self.base_score = base_score;
    }

    /// The configured evaluation-metric selection.
    pub fn eval_metrics(&self) -> &Metrics {
        &self.eval_metrics
    }

    /// Replaces the evaluation-metric selection.
    pub fn set_eval_metrics<T: Into<Metrics>>(&mut self, eval_metrics: T) {
        self.eval_metrics = eval_metrics.into();
    }

    /// The configured random seed.
    pub fn seed(&self) -> u64 {
        self.seed
    }

    /// Replaces the random seed.
    pub fn set_seed(&mut self, seed: u64) {
        self.seed = seed;
    }

    /// Flattens these parameters into `(key, value)` string pairs in the form
    /// XGBoost expects. Objective-specific extras (`num_class`,
    /// `tweedie_variance_power`) are emitted first, then the objective itself,
    /// base score, seed, and finally any explicitly requested metrics.
    pub(crate) fn as_string_pairs(&self) -> Vec<(String, String)> {
        let mut pairs = Vec::new();

        // Both multiclass objectives carry a class count under the same key,
        // so a single or-pattern covers them.
        match self.objective {
            Objective::MultiSoftmax(classes) | Objective::MultiSoftprob(classes) => {
                pairs.push(("num_class".to_owned(), classes.to_string()));
            }
            Objective::RegTweedie(Some(power)) => {
                pairs.push(("tweedie_variance_power".to_owned(), power.to_string()));
            }
            _ => {}
        }

        pairs.push(("objective".to_owned(), self.objective.to_string()));
        pairs.push(("base_score".to_owned(), self.base_score.to_string()));
        pairs.push(("seed".to_owned(), self.seed.to_string()));

        // `Metrics::Auto` emits nothing, leaving metric choice to XGBoost.
        if let Metrics::Custom(ref metrics) = self.eval_metrics {
            for metric in metrics {
                pairs.push(("eval_metric".to_owned(), metric.to_string()));
            }
        }

        pairs
    }
}
impl LearningTaskParametersBuilder {
    /// Pre-build validation hook invoked by `build()` (see the
    /// `build_fn(validate = …)` attribute on `LearningTaskParameters`).
    /// Rejects a Tweedie variance power outside the closed range [1.0, 2.0].
    fn validate(&self) -> Result<(), String> {
        match self.objective {
            Some(Objective::RegTweedie(power)) => Interval::new_closed_closed(1.0, 2.0)
                .validate(&power, "tweedie_variance_power"),
            _ => Ok(()),
        }
    }
}