#[cfg(all(feature = "mesalock_sgx", not(target_env = "sgx")))]
use std::prelude::v1::*;
use crate::binary_tree::BinaryTree;
use crate::binary_tree::BinaryTreeNode;
use crate::binary_tree::TreeIndex;
use crate::config::Loss;
#[cfg(feature = "enable_training")]
use crate::fitness::almost_equal;
use std::error::Error;
#[cfg(feature = "enable_training")]
use rand::prelude::SliceRandom;
#[cfg(feature = "enable_training")]
use rand::thread_rng;
use serde_derive::{Deserialize, Serialize};
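/// Defines `ValueType` as the given numeric type, together with
/// `VALUE_TYPE_MAX`, `VALUE_TYPE_MIN` and `VALUE_TYPE_UNKNOWN` (the minimum
/// value, used as the sentinel for missing feature values).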
macro_rules! def_value_type {
($t: tt) => {
pub type ValueType = $t;
pub const VALUE_TYPE_MAX: ValueType = std::$t::MAX;
pub const VALUE_TYPE_MIN: ValueType = std::$t::MIN;
pub const VALUE_TYPE_UNKNOWN: ValueType = VALUE_TYPE_MIN;
};
}
def_value_type!(f32);
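/// A single sample. `feature` is the feature vector, `label` the ground-truth
/// value, `weight` the sample weight and `initial_guess` an optional starting
/// prediction; `target` and `residual` are working copies used during
/// training (both start out equal to the label).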
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Data {
pub feature: Vec<ValueType>,
pub target: ValueType,
pub weight: ValueType,
pub label: ValueType,
pub residual: ValueType,
pub initial_guess: ValueType,
}
impl Data {
pub fn new_training_data(
feature: Vec<ValueType>,
weight: ValueType,
label: ValueType,
initial_guess: Option<ValueType>,
) -> Self {
Data {
feature,
target: label,
weight,
label,
residual: label,
initial_guess: initial_guess.unwrap_or(0.0),
}
}
pub fn new_test_data(feature: Vec<ValueType>, label: Option<ValueType>) -> Self {
Data {
feature,
target: 0.0,
weight: 1.0,
label: label.unwrap_or(0.0),
residual: 0.0,
initial_guess: 0.0,
}
}
}
pub type DataVec = Vec<Data>;
pub type PredVec = Vec<ValueType>;
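/// Per-node scratch space for the impurity scan: cached sums of `s`, `ss` and
/// `c` over the node's samples, plus a boolean membership mask indexed by
/// sample id. `cached` marks whether the sums have been filled in yet.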
#[cfg(feature = "enable_training")]
struct ImpurityCache {
sum_s: f64,
sum_ss: f64,
sum_c: f64,
cached: bool,
bool_vec: Vec<bool>,
}
#[cfg(feature = "enable_training")]
impl ImpurityCache {
fn new(sample_size: usize, train_data: &[usize]) -> Self {
let mut bool_vec: Vec<bool> = vec![false; sample_size];
for index in train_data.iter() {
bool_vec[*index] = true;
}
ImpurityCache {
sum_s: 0.0,
sum_ss: 0.0,
sum_c: 0.0,
cached: false,
bool_vec,
}
}
}
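/// Per-sample statistics kept in `f64` for accumulation:
/// `s = target * weight`, `ss = target^2 * weight`, `c = weight`.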
#[cfg(feature = "enable_training")]
struct CacheValue {
s: f64,
ss: f64,
c: f64,
}
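/// Training-time cache shared across iterations. `cache_level` (values above
/// 2 are treated as 2) trades memory for speed: level 0 keeps only the
/// globally sorted feature lists and filters them by a membership mask at
/// every node, level 1 re-sorts the data once per tree inside `SubCache`,
/// and level 2 keeps the global sort and also materializes per-node copies
/// when splitting.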
#[cfg(feature = "enable_training")]
pub struct TrainingCache {
ordered_features: Vec<Vec<(usize, ValueType)>>,
ordered_residual: Vec<(usize, ValueType)>,
cache_value: Vec<CacheValue>,
cache_target: Vec<ValueType>,
logit_c: Vec<ValueType>,
sample_size: usize,
feature_size: usize,
preds: Vec<ValueType>,
cache_level: u8,
}
#[cfg(feature = "enable_training")]
impl TrainingCache {
pub fn get_cache(feature_size: usize, data: &DataVec, cache_level: u8) -> Self {
let level = if cache_level >= 3 { 2 } else { cache_level };
let sample_size = data.len();
let logit_c = vec![0.0; data.len()];
let preds = vec![VALUE_TYPE_UNKNOWN; sample_size];
let mut cache_value = Vec::with_capacity(data.len());
for elem in data {
let item = CacheValue {
s: 0.0,
ss: 0.0,
c: f64::from(elem.weight),
};
cache_value.push(item);
}
let ordered_features: Vec<Vec<(usize, ValueType)>> = if (level == 0) || (level == 2) {
TrainingCache::cache_features(data, feature_size)
} else {
Vec::new()
};
let ordered_residual: Vec<(usize, ValueType)> = Vec::new();
let cache_target: Vec<ValueType> = vec![0.0; data.len()];
TrainingCache {
ordered_features,
ordered_residual,
cache_value,
cache_target,
logit_c,
sample_size,
feature_size,
preds,
cache_level: level,
}
}
pub fn get_preds(&self) -> Vec<ValueType> {
self.preds.to_vec()
}
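/// Refreshes the per-sample statistics from the current targets at the start
/// of an iteration: recomputes `s`/`ss`, the log-likelihood denominator
/// `logit_c`, and (for LAD) the residuals sorted by value.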
fn init_one_iteration(&mut self, whole_data: &[Data], loss: &Loss) {
for (index, data) in whole_data.iter().enumerate() {
let target = data.target;
self.cache_target[index] = target;
let weight = f64::from(data.weight);
let target = f64::from(target);
let s = target * weight;
self.cache_value[index].s = s;
self.cache_value[index].ss = target * s;
if let Loss::LogLikelyhood = loss {
let y = target.abs();
let c = y * (2.0 - y) * weight;
self.logit_c[index] = c as ValueType;
}
}
if let Loss::LAD = loss {
self.ordered_residual = TrainingCache::cache_residual(whole_data);
}
}
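/// For each feature, builds a list of `(sample_index, feature_value)` pairs
/// sorted ascending by value.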
fn cache_features(whole_data: &[Data], feature_size: usize) -> Vec<Vec<(usize, ValueType)>> {
let mut ordered_features = Vec::with_capacity(feature_size);
for _index in 0..feature_size {
let nv: Vec<(usize, ValueType)> = Vec::with_capacity(whole_data.len());
ordered_features.push(nv);
}
for (i, item) in whole_data.iter().enumerate() {
for (index, ordered_item) in ordered_features.iter_mut().enumerate().take(feature_size)
{
ordered_item.push((i, item.feature[index]));
}
}
for item in ordered_features.iter_mut().take(feature_size) {
item.sort_unstable_by(|a, b| {
let v1 = a.1;
let v2 = b.1;
v1.partial_cmp(&v2).unwrap()
});
}
ordered_features
}
fn cache_residual(whole_data: &[Data]) -> Vec<(usize, ValueType)> {
let mut ordered_residual = Vec::with_capacity(whole_data.len());
for (index, elem) in whole_data.iter().enumerate() {
ordered_residual.push((index, elem.residual));
}
ordered_residual.sort_unstable_by(|a, b| {
let v1: ValueType = a.1;
let v2: ValueType = b.1;
v1.partial_cmp(&v2).unwrap()
});
ordered_residual
}
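/// Filters a pre-sorted `(index, value)` list (features or residuals) down to
/// the samples flagged in `to_sort`, preserving the sort order. Depending on
/// the cache level, the list comes from the global cache or from `sub_cache`.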
fn sort_with_bool_vec(
&self,
feature_index: usize,
is_residual: bool,
to_sort: &[bool],
to_sort_size: usize,
sub_cache: &SubCache,
) -> Vec<(usize, ValueType)> {
let whole_data_sorted_index = if is_residual {
if (self.cache_level == 0) || sub_cache.lazy {
&self.ordered_residual
} else {
&sub_cache.ordered_residual
}
} else if (self.cache_level == 0) || sub_cache.lazy {
&self.ordered_features[feature_index]
} else {
&sub_cache.ordered_features[feature_index]
};
if whole_data_sorted_index.len() == to_sort_size {
return whole_data_sorted_index.to_vec();
}
let mut ret = Vec::with_capacity(to_sort_size);
for item in whole_data_sorted_index.iter() {
let (index, value) = *item;
if to_sort[index] {
ret.push((index, value));
}
}
ret
}
fn sort_with_cache(
&self,
feature_index: usize,
is_residual: bool,
to_sort: &[usize],
sub_cache: &SubCache,
) -> Vec<(usize, ValueType)> {
let whole_data_sorted_index = if is_residual {
&self.ordered_residual
} else {
&self.ordered_features[feature_index]
};
let mut index_exists: Vec<bool> = vec![false; whole_data_sorted_index.len()];
for index in to_sort.iter() {
index_exists[*index] = true;
}
self.sort_with_bool_vec(
feature_index,
is_residual,
&index_exists,
to_sort.len(),
sub_cache,
)
}
}
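/// Per-node view of the sorted caches. When `lazy` is true the node reuses
/// the global ordering stored in `TrainingCache` instead of holding its own
/// copy (cache level 2 at the root of a tree).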
#[cfg(feature = "enable_training")]
struct SubCache {
ordered_features: Vec<Vec<(usize, ValueType)>>,
ordered_residual: Vec<(usize, ValueType)>,
lazy: bool,
}
#[cfg(feature = "enable_training")]
impl SubCache {
fn get_cache_from_training_cache(cache: &TrainingCache, data: &[Data], loss: &Loss) -> Self {
let level = cache.cache_level;
if level == 2 {
return SubCache {
ordered_features: Vec::new(),
ordered_residual: Vec::new(),
lazy: true,
};
}
let ordered_features = if level == 0 {
Vec::new()
} else if level == 1 {
TrainingCache::cache_features(data, cache.feature_size)
} else {
let mut ordered_features: Vec<Vec<(usize, ValueType)>> =
Vec::with_capacity(cache.feature_size);
for index in 0..cache.feature_size {
ordered_features.push(cache.ordered_features[index].to_vec());
}
ordered_features
};
let ordered_residual = if level == 0 {
Vec::new()
} else if level == 1 {
if let Loss::LAD = loss {
TrainingCache::cache_residual(data)
} else {
Vec::new()
}
} else if let Loss::LAD = loss {
cache.ordered_residual.to_vec()
} else {
Vec::new()
};
SubCache {
ordered_features,
ordered_residual,
lazy: false,
}
}
fn get_empty() -> Self {
SubCache {
ordered_features: Vec::new(),
ordered_residual: Vec::new(),
lazy: false,
}
}
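/// Splits this node's sorted caches into caches for the left and right
/// children in a single pass, using boolean membership masks built from the
/// index sets. Returns empty caches when `cache_level` is 0.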
fn split_cache(
mut self,
left_set: &[usize],
right_set: &[usize],
cache: &TrainingCache,
) -> (Self, Self) {
if cache.cache_level == 0 {
return (SubCache::get_empty(), SubCache::get_empty());
}
let mut left_ordered_features: Vec<Vec<(usize, ValueType)>> =
Vec::with_capacity(cache.feature_size);
let mut right_ordered_features: Vec<Vec<(usize, ValueType)>> =
Vec::with_capacity(cache.feature_size);
let mut left_ordered_residual = Vec::with_capacity(left_set.len());
let mut right_ordered_residual = Vec::with_capacity(right_set.len());
for _ in 0..cache.feature_size {
left_ordered_features.push(Vec::with_capacity(left_set.len()));
right_ordered_features.push(Vec::with_capacity(right_set.len()));
}
let mut left_bool = vec![false; cache.sample_size];
let mut right_bool = vec![false; cache.sample_size];
for index in left_set.iter() {
left_bool[*index] = true;
}
for index in right_set.iter() {
right_bool[*index] = true;
}
if self.lazy {
for (feature_index, feature_vec) in cache.ordered_features.iter().enumerate() {
for pair in feature_vec.iter() {
let (index, value) = *pair;
if left_bool[index] {
left_ordered_features[feature_index].push((index, value));
continue;
}
if right_bool[index] {
right_ordered_features[feature_index].push((index, value));
}
}
}
} else {
for feature_index in 0..self.ordered_features.len() {
let feature_vec = &mut self.ordered_features[feature_index];
for pair in feature_vec.iter() {
let (index, value) = *pair;
if left_bool[index] {
left_ordered_features[feature_index].push((index, value));
continue;
}
if right_bool[index] {
right_ordered_features[feature_index].push((index, value));
}
}
feature_vec.clear();
feature_vec.shrink_to_fit();
}
self.ordered_features.clear();
self.ordered_features.shrink_to_fit();
}
if self.lazy {
for pair in cache.ordered_residual.iter() {
let (index, value) = *pair;
if left_bool[index] {
left_ordered_residual.push((index, value));
continue;
}
if right_bool[index] {
right_ordered_residual.push((index, value));
}
}
} else {
for pair in self.ordered_residual.into_iter() {
let (index, value) = pair;
if left_bool[index] {
left_ordered_residual.push((index, value));
continue;
}
if right_bool[index] {
right_ordered_residual.push((index, value));
}
}
}
(
SubCache {
ordered_features: left_ordered_features,
ordered_residual: left_ordered_residual,
lazy: false,
},
SubCache {
ordered_features: right_ordered_features,
ordered_residual: right_ordered_residual,
lazy: false,
},
)
}
}
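/// Leaf prediction for a set of samples: the weighted mean of the targets for
/// squared error, a single Newton step (`s / c`) for log-likelihood, and the
/// weighted median of the residuals for LAD. Other loss variants fall back to
/// the weighted mean.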
#[cfg(feature = "enable_training")]
fn calculate_pred(
data: &[usize],
loss: &Loss,
cache: &TrainingCache,
sub_cache: &SubCache,
) -> ValueType {
match loss {
Loss::SquaredError => average(data, cache),
Loss::LogLikelyhood => logit_optimal_value(data, cache),
Loss::LAD => lad_optimal_value(data, cache, sub_cache),
_ => average(data, cache),
}
}
#[cfg(feature = "enable_training")]
fn average(data: &[usize], cache: &TrainingCache) -> ValueType {
let mut sum: f64 = 0.0;
let mut weight: f64 = 0.0;
for index in data.iter() {
let cv: &CacheValue = &cache.cache_value[*index];
sum += cv.s;
weight += cv.c;
}
if weight.abs() < 1e-10 {
0.0
} else {
(sum / weight) as ValueType
}
}
#[cfg(feature = "enable_training")]
fn logit_optimal_value(data: &[usize], cache: &TrainingCache) -> ValueType {
let mut s: f64 = 0.0;
let mut c: f64 = 0.0;
for index in data.iter() {
s += cache.cache_value[*index].s;
c += f64::from(cache.logit_c[*index]);
}
if c.abs() < 1e-10 {
0.0
} else {
(s / c) as ValueType
}
}
#[cfg(feature = "enable_training")]
fn lad_optimal_value(data: &[usize], cache: &TrainingCache, sub_cache: &SubCache) -> ValueType {
let sorted_data = cache.sort_with_cache(0, true, data, sub_cache);
let all_weight = sorted_data
.iter()
.fold(0.0f64, |acc, x| acc + cache.cache_value[x.0].c);
let mut weighted_median: f64 = 0.0;
let mut weight: f64 = 0.0;
for (i, pair) in sorted_data.iter().enumerate() {
weight += cache.cache_value[pair.0].c;
if (weight * 2.0) > all_weight {
if i >= 1 {
weighted_median = f64::from((pair.1 + sorted_data[i - 1].1) / 2.0);
} else {
weighted_median = f64::from(pair.1);
}
break;
}
}
weighted_median as ValueType
}
#[allow(unused)]
#[cfg(feature = "enable_training")]
fn same(iv: &[usize], cache: &TrainingCache) -> bool {
if iv.is_empty() {
return false;
}
let t: ValueType = cache.cache_target[iv[0]];
for i in iv.iter().skip(1) {
if !(almost_equal(t, cache.cache_target[*i])) {
return false;
}
}
true
}
#[derive(Debug, Serialize, Deserialize)]
struct DTNode {
feature_index: usize,
feature_value: ValueType,
pred: ValueType,
missing: i8,
is_leaf: bool,
}
impl DTNode {
pub fn new() -> Self {
DTNode {
feature_index: 0,
feature_value: 0.0,
pred: 0.0,
missing: 0,
is_leaf: false,
}
}
}
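/// A regression base learner trained by recursive binary splitting.
///
/// A minimal usage sketch (paths assume this module is exposed as
/// `crate::decision_tree` and the `enable_training` feature is enabled):
///
/// ```ignore
/// use crate::decision_tree::{Data, DecisionTree, TrainingCache};
///
/// let train = vec![
///     Data::new_training_data(vec![1.0, 2.0, 3.0], 1.0, 2.0, None),
///     Data::new_training_data(vec![1.1, 2.1, 3.1], 1.0, 1.0, None),
/// ];
/// let mut tree = DecisionTree::new();
/// tree.set_feature_size(3);
/// tree.set_max_depth(2);
/// tree.set_min_leaf_size(1);
/// let mut cache = TrainingCache::get_cache(3, &train, 2);
/// tree.fit(&train, &mut cache);
/// let preds = tree.predict(&train);
/// ```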
#[derive(Debug, Serialize, Deserialize)]
pub struct DecisionTree {
tree: BinaryTree<DTNode>,
feature_size: usize,
max_depth: u32,
min_leaf_size: usize,
loss: Loss,
feature_sample_ratio: f64,
}
impl Default for DecisionTree {
fn default() -> Self {
Self::new()
}
}
impl DecisionTree {
pub fn new() -> Self {
DecisionTree {
tree: BinaryTree::new(),
feature_size: 1,
max_depth: 2,
min_leaf_size: 1,
loss: Loss::SquaredError,
feature_sample_ratio: 1.0,
}
}
pub fn set_feature_size(&mut self, size: usize) {
self.feature_size = size;
}
pub fn set_max_depth(&mut self, max_depth: u32) {
self.max_depth = max_depth;
}
pub fn set_min_leaf_size(&mut self, min_leaf_size: usize) {
self.min_leaf_size = min_leaf_size;
}
pub fn set_loss(&mut self, loss: Loss) {
self.loss = loss;
}
pub fn set_feature_sample_ratio(&mut self, feature_sample_ratio: f64) {
self.feature_sample_ratio = feature_sample_ratio;
}
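/// Fits the tree on the subset of `train_data` given by `subset` (a slice of
/// sample indices). The cache must have been built with the same feature
/// size; each trained sample's prediction is also written into the cache and
/// can be read back with `get_preds`.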
#[cfg(feature = "enable_training")]
pub fn fit_n(&mut self, train_data: &DataVec, subset: &[usize], cache: &mut TrainingCache) {
assert!(
self.feature_size == cache.feature_size,
"Decision_tree and TrainingCache should have same feature size"
);
cache.init_one_iteration(train_data, &self.loss);
let root_index = self.tree.add_root(BinaryTreeNode::new(DTNode::new()));
let sub_cache = SubCache::get_cache_from_training_cache(cache, train_data, &self.loss);
self.fit_node(root_index, 0, subset, cache, sub_cache);
}
#[cfg(feature = "enable_training")]
pub fn fit(&mut self, train_data: &DataVec, cache: &mut TrainingCache) {
assert!(
self.feature_size == cache.feature_size,
"Decision_tree and TrainingCache should have same feature size"
);
let data_collection: Vec<usize> = (0..train_data.len()).collect();
cache.init_one_iteration(train_data, &self.loss);
let root_index = self.tree.add_root(BinaryTreeNode::new(DTNode::new()));
let sub_cache = SubCache::get_cache_from_training_cache(cache, train_data, &self.loss);
self.fit_node(root_index, 0, &data_collection, cache, sub_cache);
}
#[cfg(feature = "enable_training")]
fn fit_node(
&mut self,
node: TreeIndex,
depth: u32,
train_data: &[usize],
cache: &mut TrainingCache,
sub_cache: SubCache,
) {
{
let node_ref = self
.tree
.get_node_mut(node)
.expect("node should not be empty!");
node_ref.value.pred = calculate_pred(train_data, &self.loss, cache, &sub_cache);
if (depth >= self.max_depth)
|| same(train_data, cache)
|| (train_data.len() <= self.min_leaf_size)
{
node_ref.value.is_leaf = true;
for index in train_data.iter() {
cache.preds[*index] = node_ref.value.pred;
}
return;
}
}
let (splited_data, feature_index, feature_value) = DecisionTree::split(
train_data,
self.feature_size,
self.feature_sample_ratio,
cache,
&sub_cache,
);
{
let node_ref = self
.tree
.get_node_mut(node)
.expect("node should not be empty");
if splited_data.is_none() {
node_ref.value.is_leaf = true;
node_ref.value.pred = calculate_pred(train_data, &self.loss, cache, &sub_cache);
for index in train_data.iter() {
cache.preds[*index] = node_ref.value.pred;
}
return;
} else {
node_ref.value.feature_index = feature_index;
node_ref.value.feature_value = feature_value;
}
}
if let Some((left_data, right_data, _unknown_data)) = splited_data {
let (left_sub_cache, right_sub_cache) =
sub_cache.split_cache(&left_data, &right_data, cache);
let left_index = self
.tree
.add_left_node(node, BinaryTreeNode::new(DTNode::new()));
self.fit_node(left_index, depth + 1, &left_data, cache, left_sub_cache);
let right_index = self
.tree
.add_right_node(node, BinaryTreeNode::new(DTNode::new()));
self.fit_node(right_index, depth + 1, &right_data, cache, right_sub_cache);
}
}
pub fn predict_n(&self, test_data: &DataVec, subset: &[usize]) -> PredVec {
let root = self
.tree
.get_node(self.tree.get_root_index())
.expect("Decision tree should have root node");
let mut ret = vec![0.0; test_data.len()];
for index in subset {
ret[*index] = self.predict_one(root, &test_data[*index]);
}
ret
}
pub fn predict(&self, test_data: &DataVec) -> PredVec {
let root = self
.tree
.get_node(self.tree.get_root_index())
.expect("Decision tree should have root node");
test_data
.iter()
.map(|x| self.predict_one(root, x))
.collect()
}
fn predict_one(&self, node: &BinaryTreeNode<DTNode>, sample: &Data) -> ValueType {
let mut is_node_value = false;
let mut is_left_child = false;
let mut _is_right_child = false;
if node.value.is_leaf {
is_node_value = true;
} else {
assert!(
sample.feature.len() > node.value.feature_index,
"sample doesn't have the feature"
);
if sample.feature[node.value.feature_index] == VALUE_TYPE_UNKNOWN {
if node.value.missing == -1 {
is_left_child = true;
} else if node.value.missing == 0 {
is_node_value = true;
} else {
_is_right_child = true;
}
} else if sample.feature[node.value.feature_index] < node.value.feature_value {
is_left_child = true;
} else {
_is_right_child = true;
}
}
if is_node_value {
node.value.pred
} else if is_left_child {
let left = self
.tree
.get_left_child(node)
.expect("Left child should not be None");
self.predict_one(left, sample)
} else {
let right = self
.tree
.get_right_child(node)
.expect("Right child should not be None");
self.predict_one(right, sample)
}
}
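/// Finds the best split over a (possibly sampled) subset of features.
/// Returns the partition of sample indices into (left, right, unknown)
/// together with the chosen feature index and threshold, or `None` when no
/// candidate split was found or at most one of the three partitions would be
/// non-empty.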
#[cfg(feature = "enable_training")]
fn split(
train_data: &[usize],
feature_size: usize,
feature_sample_ratio: f64,
cache: &TrainingCache,
sub_cache: &SubCache,
) -> (
Option<(Vec<usize>, Vec<usize>, Vec<usize>)>,
usize,
ValueType,
) {
let mut fs = feature_size;
let mut fv: Vec<usize> = (0..feature_size).collect();
let mut rng = thread_rng();
if feature_sample_ratio < 1.0 {
fs = (feature_sample_ratio * (feature_size as f64)) as usize;
fv.shuffle(&mut rng);
}
let mut v: ValueType = 0.0;
let mut impurity: f64 = 0.0;
let mut best_fitness: f64 = std::f64::MAX;
let mut index: usize = 0;
let mut value: ValueType = 0.0;
let mut impurity_cache = ImpurityCache::new(cache.sample_size, train_data);
let mut find: bool = false;
let mut data_to_split: Vec<(usize, ValueType)> = Vec::new();
for i in fv.iter().take(fs) {
let sorted_data = DecisionTree::get_impurity(
train_data,
*i,
&mut v,
&mut impurity,
cache,
&mut impurity_cache,
sub_cache,
);
if best_fitness > impurity {
find = true;
best_fitness = impurity;
index = *i;
value = v;
data_to_split = sorted_data;
}
}
if find {
let mut left: Vec<usize> = Vec::new();
let mut right: Vec<usize> = Vec::new();
let mut unknown: Vec<usize> = Vec::new();
for pair in data_to_split.iter() {
let (item_index, feature_value) = *pair;
if feature_value == VALUE_TYPE_UNKNOWN {
unknown.push(item_index);
} else if feature_value < value {
left.push(item_index);
} else {
right.push(item_index);
}
}
let mut count: u8 = 0;
if left.is_empty() {
count += 1;
}
if right.is_empty() {
count += 1;
}
if unknown.is_empty() {
count += 1;
}
if count >= 2 {
(None, 0, 0.0)
} else {
(Some((left, right, unknown)), index, value)
}
} else {
(None, 0, 0.0)
}
}
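/// Scans one feature in sorted order, writing the best threshold into `value`
/// and its score (the summed within-branch squared deviation) into
/// `impurity`. Samples with an unknown feature value are accumulated
/// separately into the fixed `fitness0` term. Returns the sorted
/// `(index, value)` pairs so the winning feature does not have to be
/// re-sorted by the caller.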
#[cfg(feature = "enable_training")]
fn get_impurity(
train_data: &[usize],
feature_index: usize,
value: &mut ValueType,
impurity: &mut f64,
cache: &TrainingCache,
impurity_cache: &mut ImpurityCache,
sub_cache: &SubCache,
) -> Vec<(usize, ValueType)> {
*impurity = std::f64::MAX;
*value = VALUE_TYPE_UNKNOWN;
let sorted_data = cache.sort_with_bool_vec(
feature_index,
false,
&impurity_cache.bool_vec,
train_data.len(),
sub_cache,
);
let mut unknown: usize = 0;
let mut s: f64 = 0.0;
let mut ss: f64 = 0.0;
let mut c: f64 = 0.0;
for pair in sorted_data.iter() {
let (index, feature_value) = *pair;
if feature_value == VALUE_TYPE_UNKNOWN {
let cv: &CacheValue = &cache.cache_value[index];
s += cv.s;
ss += cv.ss;
c += cv.c;
unknown += 1;
} else {
break;
}
}
if unknown == sorted_data.len() {
return sorted_data;
}
let mut fitness0 = if c > 1.0 { ss - s * s / c } else { 0.0 };
if fitness0 < 0.0 {
fitness0 = 0.0;
}
if !impurity_cache.cached {
impurity_cache.sum_s = 0.0;
impurity_cache.sum_ss = 0.0;
impurity_cache.sum_c = 0.0;
for index in train_data.iter() {
let cv: &CacheValue = &cache.cache_value[*index];
impurity_cache.sum_s += cv.s;
impurity_cache.sum_ss += cv.ss;
impurity_cache.sum_c += cv.c;
}
}
s = impurity_cache.sum_s - s;
ss = impurity_cache.sum_ss - ss;
c = impurity_cache.sum_c - c;
let _fitness00: f64 = if c > 1.0 { ss - s * s / c } else { 0.0 };
let mut ls: f64 = 0.0;
let mut lss: f64 = 0.0;
let mut lc: f64 = 0.0;
let mut rs: f64 = s;
let mut rss: f64 = ss;
let mut rc: f64 = c;
for i in unknown..(sorted_data.len() - 1) {
let (index, feature_value) = sorted_data[i];
let (_next_index, next_value) = sorted_data[i + 1];
let cv: &CacheValue = &cache.cache_value[index];
s = cv.s;
ss = cv.ss;
c = cv.c;
ls += s;
lss += ss;
lc += c;
rs -= s;
rss -= ss;
rc -= c;
let f1: ValueType = feature_value;
let f2: ValueType = next_value;
if almost_equal(f1, f2) {
continue;
}
let mut fitness1: f64 = if lc > 1.0 { lss - ls * ls / lc } else { 0.0 };
if fitness1 < 0.0 {
fitness1 = 0.0;
}
let mut fitness2: f64 = if rc > 1.0 { rss - rs * rs / rc } else { 0.0 };
if fitness2 < 0.0 {
fitness2 = 0.0;
}
let fitness: f64 = fitness0 + fitness1 + fitness2;
if *impurity > fitness {
*impurity = fitness;
*value = (f1 + f2) / 2.0;
}
}
sorted_data
}
pub fn print(&self) {
self.tree.print();
}
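/// Builds a tree from one node of an xgboost model dumped as JSON, reading
/// the `split`, `split_condition`, `yes`, `no`, `missing`, `children`,
/// `nodeid` and `leaf` fields.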
pub fn get_from_xgboost(node: &serde_json::Value) -> Result<Self, Box<dyn Error>> {
let mut tree = DecisionTree::new();
let index = tree.tree.add_root(BinaryTreeNode::new(DTNode::new()));
tree.add_node_from_json(index, node)?;
Ok(tree)
}
fn add_node_from_json(
&mut self,
index: TreeIndex,
node: &serde_json::Value,
) -> Result<(), Box<dyn Error>> {
{
let node_ref = self
.tree
.get_node_mut(index)
.expect("node should not be empty!");
if let serde_json::Value::Number(pred) = &node["leaf"] {
let leaf_value = pred.as_f64().ok_or("parse 'leaf' error")?;
node_ref.value.pred = leaf_value as ValueType;
node_ref.value.is_leaf = true;
return Ok(());
} else {
let feature_value = node["split_condition"]
.as_f64()
.ok_or("parse 'split condition' error")?;
node_ref.value.feature_value = feature_value as ValueType;
let feature_index = match node["split"].as_i64() {
Some(v) => v,
None => {
let feature_name = node["split"].as_str().ok_or("parse 'split' error")?;
let feature_str: String = feature_name.chars().skip(3).collect();
feature_str.parse::<i64>()?
}
};
node_ref.value.feature_index = feature_index as usize;
let missing = node["missing"].as_i64().ok_or("parse 'missing' error")?;
let left_child = node["yes"].as_i64().ok_or("parse 'yes' error")?;
let right_child = node["no"].as_i64().ok_or("parse 'no' error")?;
if missing == left_child {
node_ref.value.missing = -1;
} else if missing == right_child {
node_ref.value.missing = 1;
} else {
let err: Box<Error> = From::from("not support extra missing node".to_string());
return Err(err);
}
}
}
let left_child = node["yes"].as_i64().ok_or("parse 'yes' error")?;
let right_child = node["no"].as_i64().ok_or("parse 'no' error")?;
let children = node["children"]
.as_array()
.ok_or("parse 'children' error")?;
let mut find_left = false;
let mut find_right = false;
for child in children.iter() {
let node_id = child["nodeid"].as_i64().ok_or("parse 'nodeid' error")?;
if node_id == left_child {
find_left = true;
let left_index = self
.tree
.add_left_node(index, BinaryTreeNode::new(DTNode::new()));
self.add_node_from_json(left_index, child)?;
}
if node_id == right_child {
find_right = true;
let right_index = self
.tree
.add_right_node(index, BinaryTreeNode::new(DTNode::new()));
self.add_node_from_json(right_index, child)?;
}
}
if (!find_left) || (!find_right) {
let err: Box<Error> = From::from("children not found".to_string());
return Err(err);
}
Ok(())
}
pub fn len(&self) -> usize {
self.tree.len()
}
pub fn is_empty(&self) -> bool {
self.tree.is_empty()
}
}