use crate::{Sample, WeakLearner};
use super::bin::*;
use crate::weak_learner::common::{
split_rule::*,
};
use super::{
node::*,
train_node::*,
loss::LossType,
regression_tree_regressor::RegressionTreeRegressor,
};
use std::fmt;
use std::rc::Rc;
use std::cell::RefCell;
use std::collections::HashMap;
/// Weak learner that grows a regression tree from per-example
/// gradient/hessian statistics.
///
/// The lifetime `'a` is the lifetime of the string keys borrowed by `bins`.
pub struct RegressionTree<'a> {
/// Bins keyed by a borrowed string; the `Display` impl prints each key
/// as a feature name next to its bins.
bins: HashMap<&'a str, Bins>,
/// Depth budget handed to the recursive tree construction; a node whose
/// remaining budget is `<= 1` becomes a leaf.
max_depth: usize,
/// Upper bound of the example index range `0..n_sample` scanned by
/// `produce`.
n_sample: usize,
/// Value forwarded to the loss when computing node predictions/losses and
/// when searching for a split (reported by `info` as
/// "Regularization param.").
lambda_l2: f64,
/// Loss used to compute gradients/hessians, node predictions and splits.
loss_type: LossType,
}
impl<'a> RegressionTree<'a> {
    /// Assembles a `RegressionTree` from already-prepared components.
    ///
    /// Note the argument order: `n_sample` comes *before* `max_depth`.
    #[inline]
    pub(super) fn from_components(
        bins: HashMap<&'a str, Bins>,
        n_sample: usize,
        max_depth: usize,
        lambda_l2: f64,
        loss_type: LossType,
    ) -> Self
    {
        Self {
            bins,
            max_depth,
            n_sample,
            lambda_l2,
            loss_type,
        }
    }

    /// Recursively builds the training tree for the examples in `indices`.
    ///
    /// A node becomes a leaf when
    /// - its loss is exactly `0.0`,
    /// - the remaining depth budget `max_depth` is `<= 1`, or
    /// - the chosen split sends every example to the same side.
    ///
    /// Otherwise the examples are divided by the split rule and both
    /// children are built with the depth budget reduced by one
    /// (left child first, then right child).
    #[inline]
    fn full_tree(
        &self,
        sample: &Sample,
        gh: &[GradientHessian],
        indices: Vec<usize>,
        max_depth: usize,
    ) -> Rc<RefCell<TrainNode>>
    {
        let (pred, loss) = self
            .loss_type
            .prediction_and_loss(&indices, gh, self.lambda_l2);

        // Nothing left to improve, or no depth budget left.
        if loss == 0.0 || max_depth <= 1 {
            return TrainNode::leaf(pred, loss);
        }

        // Ask the loss for the (feature, threshold) pair to split on.
        let (feature, threshold) = self.loss_type.best_split(
            &self.bins,
            &sample,
            gh,
            indices.as_slice(),
            self.lambda_l2,
        );
        let rule = Splitter::new(feature, threshold);

        // Route every example to the left or right child, keeping the
        // original relative order within each side.
        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
            .into_iter()
            .partition(|&i| matches!(rule.split(sample, i), LR::Left));

        // A split that separates nothing is useless; stop here.
        if left_indices.is_empty() || right_indices.is_empty() {
            return TrainNode::leaf(pred, loss);
        }

        let child_depth = max_depth - 1;
        let left_child = self.full_tree(sample, gh, left_indices, child_depth);
        let right_child = self.full_tree(sample, gh, right_indices, child_depth);

        TrainNode::branch(rule, left_child, right_child, pred, loss)
    }
}
impl<'a> WeakLearner for RegressionTree<'a> {
    type Hypothesis = RegressionTreeRegressor;

    /// Human-readable name of this weak learner.
    fn name(&self) -> &str {
        "Regression Tree"
    }

    /// Key/value summary of the learner's configuration.
    ///
    /// The reported bin count is the largest count over all entries of
    /// `bins`, or `0` when there are none.
    fn info(&self) -> Option<Vec<(&str, String)>> {
        let largest_bin_count = self
            .bins
            .values()
            .map(|bin| bin.len())
            .max()
            .unwrap_or(0);

        Some(vec![
            ("# of bins (max)", largest_bin_count.to_string()),
            ("Max depth", self.max_depth.to_string()),
            ("Split criterion", self.loss_type.to_string()),
            ("Regularization param.", self.lambda_l2.to_string()),
        ])
    }

    /// Fits a regression tree against the current `predictions` and
    /// returns it as a `RegressionTreeRegressor`.
    ///
    /// # Panics
    ///
    /// Panics if the gradient/hessian collection cannot be indexed by
    /// every value in `0..n_sample`, or if the root node is still shared
    /// when it is unwrapped.
    fn produce(&self, sample: &Sample, predictions: &[f64])
        -> Self::Hypothesis
    {
        let target = sample.target();
        let gh = self.loss_type.gradient_and_hessian(target, predictions);

        // Skip examples whose gradient and hessian are both zero.
        let indices: Vec<usize> = (0..self.n_sample)
            .filter(|&i| {
                let entry = &gh[i];
                !(entry.grad == 0.0 && entry.hess == 0.0)
            })
            .collect();

        let tree = self.full_tree(sample, &gh, indices, self.max_depth);

        // Take sole ownership of the root before converting it.
        let root_cell = Rc::try_unwrap(tree)
            .expect("Root node has reference counter >= 1");
        let root = Node::from(root_cell.into_inner());

        RegressionTreeRegressor::from(root)
    }
}
impl fmt::Display for RegressionTree<'_> {
    /// Writes a multi-line summary: a header with the max depth and the
    /// loss function, then one aligned line per entry of `bins`, then a
    /// closing rule.
    ///
    /// Formatting never panics: a learner with no bins prints only the
    /// header and the closing rule, and a feature with zero bins is
    /// printed with a one-character-wide count.
    ///
    /// Bin lines follow the map's iteration order, which is not fixed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(
            f,
            "\
            ----------\n\
            # Decision Tree Weak Learner\n\n\
            - Max depth: {}\n\
            - Loss function: {}\n\
            - Bins:\
            ",
            self.max_depth,
            self.loss_type,
        )?;

        // Column width for feature names; 0 when there are no features
        // (the loop below then prints nothing).
        let width = self.bins.keys()
            .map(|key| key.len())
            .max()
            .unwrap_or(0);

        // Column width for bin counts = number of decimal digits of the
        // largest count. `checked_ilog10` returns `None` for 0, which has
        // one digit; plain `ilog10` would panic there.
        let max_bin_width = self.bins.values()
            .map(|bin| {
                bin.len().checked_ilog10().map_or(0, |digits| digits as usize) + 1
            })
            .max()
            .unwrap_or(1);

        for (feat_name, feat_bins) in self.bins.iter() {
            let n_bins = feat_bins.len();
            writeln!(
                f,
                "\
                \t* [{feat_name: <width$} | \
                {n_bins: >max_bin_width$} bins] \
                {feat_bins}\
                "
            )?;
        }

        write!(f, "----------")
    }
}