use rayon::prelude::*;
use crate::{Sample, WeakLearner};
use crate::weak_learner::common::{
split_rule::*,
};
use super::{
node::*,
train_node::*,
loss::LossType,
rtree_regressor::RTreeRegressor,
};
use std::rc::Rc;
use std::cell::RefCell;
/// Regression-tree weak learner.
///
/// Create one with [`RTree::init`], then optionally adjust it with the
/// builder-style [`RTree::max_depth`] and [`RTree::loss_type`] methods.
pub struct RTree {
    // Depth budget handed to the tree builder on every `produce` call.
    max_depth: usize,
    // Number of examples in the sample given to `init`; `produce` only
    // considers indices `0..n_sample`.
    n_sample: usize,
    // Loss used for node predictions, node losses and split selection.
    loss_type: LossType,
}

impl RTree {
    /// Creates a learner sized for `sample`.
    ///
    /// Defaults:
    /// - `max_depth = ceil(log2(n))`, where `n` is the number of examples in
    ///   `sample` (0 when `n <= 1`);
    /// - `loss_type = LossType::L2`.
    #[inline]
    pub fn init(sample: &Sample)
        -> Self
    {
        let n_sample = sample.shape().0;
        // Exact integer ceil(log2(n)): the smallest power of two >= n has
        // exactly ceil(log2(n)) trailing zeros. `0.next_power_of_two()` is 1,
        // so n = 0 and n = 1 both yield 0. This avoids the rounding error of
        // a float `log2().ceil()` round trip for very large n.
        let max_depth = n_sample.next_power_of_two().trailing_zeros() as usize;
        Self {
            max_depth,
            n_sample,
            loss_type: LossType::L2,
        }
    }

    /// Returns the learner with its maximum tree depth set to `depth`.
    #[inline]
    #[must_use]
    pub fn max_depth(mut self, depth: usize) -> Self {
        self.max_depth = depth;
        self
    }

    /// Returns the learner with its loss set to `loss`.
    #[inline]
    #[must_use]
    pub fn loss_type(mut self, loss: LossType) -> Self {
        self.loss_type = loss;
        self
    }
}
impl WeakLearner for RTree {
    type Hypothesis = RTreeRegressor;

    /// Grows a regression tree on the examples that have positive weight.
    ///
    /// Only indices `0..self.n_sample` (the sample size recorded by
    /// [`RTree::init`]) are considered, and index `i` is kept when
    /// `dist[i] > 0.0`. The tree is grown to at most `self.max_depth` split
    /// levels using `self.loss_type`. Redundant nodes are then removed, and
    /// the result is wrapped in an [`RTreeRegressor`].
    ///
    /// # Panics
    ///
    /// Panics if `dist` has fewer than `self.n_sample` entries, or if the
    /// trained root node is still shared after pruning.
    fn produce(&self, sample: &Sample, dist: &[f64])
        -> Self::Hypothesis
    {
        // Only examples with positive weight are handed to the tree builder.
        let indices = (0..self.n_sample)
            .filter(|&i| dist[i] > 0.0)
            .collect::<Vec<usize>>();

        let tree = full_tree(
            sample, dist, indices, self.max_depth, self.loss_type
        );
        tree.borrow_mut().remove_redundant_nodes();

        // `Rc::try_unwrap` succeeds only when `tree` holds the last strong
        // reference to the root, so failure means the root is still shared.
        let root = Rc::try_unwrap(tree)
            .expect("Root node is still shared (reference count > 1)")
            .into_inner();
        RTreeRegressor::from(Node::from(root))
    }
}
/// Recursively grows a regression tree over the examples in `indices`.
///
/// Each node records the prediction and loss computed by `loss_type`,
/// together with the summed weight `dist[i]` of its examples.
///
/// A node stays a leaf when either of these holds:
/// - its loss is exactly zero;
/// - the split chosen by `loss_type.best_split` sends every example to the
///   same side.
///
/// Otherwise the node branches. When `max_depth <= 1` both children are built
/// as leaves, so `max_depth` counts split levels and a value of `0` behaves
/// like `1`.
///
/// Returns the root of the grown subtree.
///
/// # Panics
///
/// Panics if an entry of `indices` is out of bounds for `dist`.
fn full_tree(
    sample: &Sample,
    dist: &[f64],
    indices: Vec<usize>,
    max_depth: usize,
    loss_type: LossType
) -> Rc<RefCell<TrainNode>>
{
    // Sequential sum. Rayon leaves the reduction order of a parallel float
    // sum unspecified, so the low bits of the weight could differ between
    // runs. This also matches how `construct_leaf` sums leaf weights.
    let total_weight = indices.iter()
        .map(|&i| dist[i])
        .sum::<f64>();

    let target = sample.target();
    let target = &target[..];
    let (pred, loss) = loss_type.prediction_and_loss(
        target, &indices, dist
    );
    // Nothing left to fit at this node: stop splitting.
    if loss == 0.0 {
        return TrainNode::leaf(pred, total_weight, loss);
    }

    let (feature, threshold) = loss_type.best_split(
        sample, dist, &indices[..],
    );
    let rule = Splitter::new(feature, threshold);

    // Route each example to the child chosen by `rule`, preserving the
    // relative order of indices on each side.
    let (lindices, rindices): (Vec<usize>, Vec<usize>) = indices
        .into_iter()
        .partition(|&i| matches!(rule.split(sample, i), LR::Left));

    // A split that leaves one side empty separates nothing.
    if lindices.is_empty() || rindices.is_empty() {
        return TrainNode::leaf(pred, total_weight, loss);
    }

    let (ltree, rtree) = if max_depth <= 1 {
        (
            construct_leaf(target, dist, lindices, loss_type),
            construct_leaf(target, dist, rindices, loss_type),
        )
    } else {
        let depth = max_depth - 1;
        (
            full_tree(sample, dist, lindices, depth, loss_type),
            full_tree(sample, dist, rindices, depth, loss_type),
        )
    };
    TrainNode::branch(rule, ltree, rtree, pred, total_weight, loss)
}
/// Builds a leaf node for the examples in `indices`.
///
/// The leaf stores the prediction and loss that `loss_type` reports for those
/// examples, together with their summed weight `dist[i]`.
///
/// # Panics
///
/// Panics if an entry of `indices` is out of bounds for `dist`.
#[inline]
fn construct_leaf(
    target: &[f64],
    dist: &[f64],
    indices: Vec<usize>,
    loss_type: LossType,
) -> Rc<RefCell<TrainNode>>
{
    let (prediction, leaf_loss) =
        loss_type.prediction_and_loss(target, &indices, dist);
    let weight: f64 = indices.iter().map(|&idx| dist[idx]).sum();
    TrainNode::leaf(prediction, weight, leaf_loss)
}