#[cfg(all(feature = "mesalock_sgx", not(target_env = "sgx")))]
use std::prelude::v1::*;
use crate::decision_tree::ValueType;
use serde_derive::{Deserialize, Serialize};
/// Loss function used when training a gradient-boosted model.
///
/// The first three variants use this crate's native names; the remaining
/// variants mirror XGBoost-style objective strings (see `string2loss` /
/// `loss2string` for the exact string forms such as `"reg:linear"`).
///
/// NOTE(review): `LogLikelyhood` is a misspelling of "LogLikelihood", but the
/// name is part of the public API and of the serialized form — do not rename
/// without a compatibility plan.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum Loss {
/// Squared-error loss (also the `Default`).
SquaredError,
/// Log-likelihood loss (spelling kept for API/serde compatibility).
LogLikelyhood,
/// Least absolute deviation.
LAD,
/// Maps to/from the string `"reg:linear"`.
RegLinear,
/// Maps to/from the string `"reg:logistic"`.
RegLogistic,
/// Maps to/from the string `"binary:logistic"`.
BinaryLogistic,
/// Maps to/from the string `"binary:logitraw"`.
BinaryLogitraw,
/// Maps to/from the string `"multi:softprob"`.
MultiSoftprob,
/// Maps to/from the string `"multi:softmax"`.
MultiSoftmax,
/// Maps to/from the string `"rank:pairwise"`.
RankPairwise,
}
impl Default for Loss {
fn default() -> Self {
Loss::SquaredError
}
}
/// Parses a loss-name string into a [`Loss`] variant.
///
/// Accepts both the crate's native names (`"SquaredError"`, `"LogLikelyhood"`,
/// `"LAD"`) and XGBoost-style objective strings (`"reg:linear"`,
/// `"binary:logistic"`, ...). Any unrecognized string falls back to
/// `Loss::SquaredError` after printing a warning to stderr.
pub fn string2loss(s: &str) -> Loss {
    match s {
        "LogLikelyhood" => Loss::LogLikelyhood,
        "SquaredError" => Loss::SquaredError,
        "LAD" => Loss::LAD,
        "reg:linear" => Loss::RegLinear,
        "binary:logistic" => Loss::BinaryLogistic,
        "reg:logistic" => Loss::RegLogistic,
        "binary:logitraw" => Loss::BinaryLogitraw,
        "multi:softprob" => Loss::MultiSoftprob,
        "multi:softmax" => Loss::MultiSoftmax,
        "rank:pairwise" => Loss::RankPairwise,
        unknown => {
            // Fix: diagnostics belong on stderr, not stdout, and naming the
            // offending value makes the silent fallback debuggable.
            eprintln!("unsupported loss '{}', set to default (SquaredError)", unknown);
            Loss::SquaredError
        }
    }
}
/// Returns the canonical string name for a [`Loss`] variant — the inverse of
/// `string2loss` for every supported name.
pub fn loss2string(l: &Loss) -> String {
    // Resolve to a static name first; allocate exactly once at the end.
    let name: &str = match l {
        Loss::LogLikelyhood => "LogLikelyhood",
        Loss::SquaredError => "SquaredError",
        Loss::LAD => "LAD",
        Loss::RegLinear => "reg:linear",
        Loss::BinaryLogistic => "binary:logistic",
        Loss::RegLogistic => "reg:logistic",
        Loss::BinaryLogitraw => "binary:logitraw",
        Loss::MultiSoftprob => "multi:softprob",
        Loss::MultiSoftmax => "multi:softmax",
        Loss::RankPairwise => "rank:pairwise",
    };
    name.to_string()
}
/// Training configuration for a gradient-boosted model.
///
/// NOTE(review): the derived `Default` zero-initializes numeric fields, which
/// differs from `Config::new()` (feature_size 1, max_depth 2, ...); prefer
/// `Config::new()` for a usable starting point.
#[derive(Default, Clone, Serialize, Deserialize)]
pub struct Config {
/// Number of input features per sample.
pub feature_size: usize,
/// Maximum depth of each decision tree.
pub max_depth: u32,
/// Number of boosting iterations (trees to train).
pub iterations: usize,
/// Learning rate (eta) applied to each tree's contribution.
pub shrinkage: ValueType,
/// Fraction of features sampled per tree (1.0 = use all).
pub feature_sample_ratio: f64,
/// Fraction of training rows sampled per tree (1.0 = use all).
pub data_sample_ratio: f64,
/// Minimum number of samples required in a leaf.
pub min_leaf_size: usize,
/// Loss function used during training.
pub loss: Loss,
/// Enables debug output when true.
pub debug: bool,
/// When true, training starts from an initial guess.
pub initial_guess_enabled: bool,
/// Training optimization level; clamped to at most 2 by the setter.
pub training_optimization_level: u8,
}
impl Config {
    /// Returns a configuration pre-filled with usable defaults
    /// (1 feature, depth 2, 2 iterations, shrinkage 1.0, no sampling).
    ///
    /// Unlike the derived `Default`, the numeric fields here are non-zero.
    pub fn new() -> Config {
        Config {
            feature_size: 1,
            max_depth: 2,
            iterations: 2,
            shrinkage: 1.0,
            feature_sample_ratio: 1.0,
            data_sample_ratio: 1.0,
            min_leaf_size: 1,
            loss: Loss::SquaredError,
            debug: false,
            initial_guess_enabled: false,
            training_optimization_level: 2,
        }
    }

    /// Sets the number of input features per sample.
    pub fn set_feature_size(&mut self, n: usize) {
        self.feature_size = n;
    }

    /// Sets the learning rate (eta).
    pub fn set_shrinkage(&mut self, eta: ValueType) {
        self.shrinkage = eta;
    }

    /// Sets the training optimization level, clamped to the supported
    /// maximum of 2 (any value >= 3 becomes 2).
    pub fn set_training_optimization_level(&mut self, level: u8) {
        self.training_optimization_level = level.min(2);
    }

    /// Sets the maximum tree depth.
    pub fn set_max_depth(&mut self, n: u32) {
        self.max_depth = n;
    }

    /// Sets the number of boosting iterations.
    pub fn set_iterations(&mut self, n: usize) {
        self.iterations = n;
    }

    /// Sets the per-tree feature sampling ratio (1.0 = all features).
    pub fn set_feature_sample_ratio(&mut self, n: f64) {
        self.feature_sample_ratio = n;
    }

    /// Sets the per-tree data sampling ratio (1.0 = all rows).
    pub fn set_data_sample_ratio(&mut self, n: f64) {
        self.data_sample_ratio = n;
    }

    /// Sets the minimum number of samples per leaf.
    pub fn set_min_leaf_size(&mut self, n: usize) {
        self.min_leaf_size = n;
    }

    /// Sets the loss by name; unknown names fall back to `SquaredError`
    /// (see `string2loss`).
    pub fn set_loss(&mut self, l: &str) {
        // Fix: `l` is already a `&str`; the previous `&l` was a needless
        // double borrow (clippy::needless_borrow).
        self.loss = string2loss(l);
    }

    /// Enables or disables debug output.
    pub fn set_debug(&mut self, option: bool) {
        self.debug = option;
    }

    /// Enables or disables starting training from an initial guess.
    pub fn enabled_initial_guess(&mut self, option: bool) {
        self.initial_guess_enabled = option;
    }

    /// Renders the configuration as a human-readable, multi-line summary.
    ///
    /// Kept as an inherent method for backward compatibility; implementing
    /// `fmt::Display` would be the more idiomatic alternative.
    pub fn to_string(&self) -> String {
        let mut s = String::new();
        s.push_str(&format!("number of features = {}\n", self.feature_size));
        s.push_str(&format!("min leaf size = {}\n", self.min_leaf_size));
        s.push_str(&format!("maximum depth = {}\n", self.max_depth));
        s.push_str(&format!("iterations = {}\n", self.iterations));
        s.push_str(&format!("shrinkage = {}\n", self.shrinkage));
        s.push_str(&format!(
            "feature sample ratio = {}\n",
            self.feature_sample_ratio
        ));
        s.push_str(&format!("data sample ratio = {}\n", self.data_sample_ratio));
        s.push_str(&format!("debug enabled = {}\n", self.debug));
        s.push_str(&format!("loss type = {}\n", loss2string(&self.loss)));
        s.push_str(&format!(
            "initial guess enabled = {}\n",
            self.initial_guess_enabled
        ));
        s
    }
}