use crate::Regressor;
use crate::weak_learner::common::{
type_and_struct::*,
split_rule::*,
};
use crate::Sample;
use std::rc::Rc;
use std::cell::RefCell;
use std::fmt;
/// A node of a regression tree while it is being trained: either an
/// internal split (`Branch`) or a terminal node (`Leaf`).
pub enum TrainNode {
    /// Internal node that routes rows to one of two children.
    Branch(TrainBranchNode),
    /// Terminal node that outputs a constant prediction.
    Leaf(TrainLeafNode),
}
/// Internal node of a regression tree under construction.
///
/// Besides its split, the node keeps the prediction and loss it would have
/// as a leaf, so it can be collapsed with
/// `From<TrainBranchNode> for TrainLeafNode`.
pub struct TrainBranchNode {
    /// Rule deciding whether a sample row is routed to `left` or `right`.
    pub(super) rule: Splitter,
    /// Subtree for rows the rule sends left.
    pub(super) left: Rc<RefCell<TrainNode>>,
    /// Subtree for rows the rule sends right.
    pub(super) right: Rc<RefCell<TrainNode>>,
    /// Value this node would predict if it were turned into a leaf.
    pub(super) prediction: Prediction<f64>,
    /// Loss this node would incur as a leaf (printed as `r(t)` by `Debug`).
    // Private to this module; `pub(self)` was a redundant spelling of this.
    loss_as_leaf: LossValue,
    /// Number of leaves in this subtree, set by `TrainNode::branch` from the
    /// children's counts.
    leaves: usize,
}
/// Terminal node of a regression tree under construction.
pub struct TrainLeafNode {
    /// Constant value returned by `predict` for every row reaching this leaf.
    pub(super) prediction: Prediction<f64>,
    /// Loss recorded for this leaf (printed as `r(t)` by `Debug`).
    // Private to this module; `pub(self)` was a redundant spelling of this.
    loss_as_leaf: LossValue,
}
impl From<TrainBranchNode> for TrainLeafNode {
#[inline]
fn from(branch: TrainBranchNode) -> Self {
Self {
prediction: branch.prediction,
loss_as_leaf: branch.loss_as_leaf,
}
}
}
impl TrainNode {
    /// Creates a leaf and wraps it in the shared `Rc<RefCell<_>>` handle
    /// used for tree nodes.
    ///
    /// `prediction` is the value the leaf outputs; `loss_as_leaf` is the
    /// loss recorded for it.
    #[inline]
    pub(super) fn leaf(
        prediction: Prediction<f64>,
        loss_as_leaf: LossValue,
    ) -> Rc<RefCell<Self>> {
        let node = Self::Leaf(TrainLeafNode { prediction, loss_as_leaf });
        Rc::new(RefCell::new(node))
    }

    /// Creates a branch over `left` and `right` and wraps it in a shared handle.
    ///
    /// `prediction` and `loss_as_leaf` describe this node as if it were a
    /// leaf. The branch's leaf count is the sum of its children's counts.
    /// Each child is immutably borrowed while counting, so this panics if
    /// either child is currently mutably borrowed.
    #[inline]
    pub(super) fn branch(
        rule: Splitter,
        left: Rc<RefCell<TrainNode>>,
        right: Rc<RefCell<TrainNode>>,
        prediction: Prediction<f64>,
        loss_as_leaf: LossValue,
    ) -> Rc<RefCell<Self>> {
        let left_leaves = left.borrow().leaves();
        let right_leaves = right.borrow().leaves();
        let branch = TrainBranchNode {
            leaves: left_leaves + right_leaves,
            rule,
            left,
            right,
            prediction,
            loss_as_leaf,
        };
        Rc::new(RefCell::new(Self::Branch(branch)))
    }

    /// Number of leaves in the subtree rooted at this node.
    ///
    /// A leaf counts as 1; a branch reports its stored count.
    #[inline]
    pub(super) fn leaves(&self) -> usize {
        match self {
            Self::Leaf(_) => 1,
            Self::Branch(branch) => branch.leaves,
        }
    }
}
impl Regressor for TrainLeafNode {
    /// Returns the leaf's stored prediction; the sample and row are ignored.
    #[inline]
    fn predict(&self, _: &Sample, _: usize) -> f64 {
        let Self { prediction, .. } = self;
        prediction.0
    }
}
impl Regressor for TrainBranchNode {
    /// Routes `row` of `sample` with the split rule, then asks the chosen
    /// child for its prediction.
    ///
    /// The chosen child is immutably borrowed for the duration of the call.
    #[inline]
    fn predict(&self, sample: &Sample, row: usize) -> f64 {
        let child = match self.rule.split(sample, row) {
            LR::Left => &self.left,
            LR::Right => &self.right,
        };
        child.borrow().predict(sample, row)
    }
}
impl Regressor for TrainNode {
    /// Delegates to the wrapped leaf or branch.
    #[inline]
    fn predict(&self, sample: &Sample, row: usize) -> f64 {
        match self {
            Self::Leaf(leaf) => leaf.predict(sample, row),
            Self::Branch(branch) => branch.predict(sample, row),
        }
    }
}
impl fmt::Debug for TrainBranchNode {
    /// Prints the split rule (labelled `threshold`), the leaf count and the
    /// as-leaf loss (`r(t)`), followed by the two child subtrees.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("TrainBranchNode");
        out.field("threshold", &self.rule);
        out.field("leaves", &self.leaves);
        out.field("r(t)", &self.loss_as_leaf.0);
        out.field("left", &self.left);
        out.field("right", &self.right);
        out.finish()
    }
}
impl fmt::Debug for TrainLeafNode {
    /// Prints the leaf's prediction and its loss (`r(t)`).
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let prediction = &self.prediction.0;
        let loss = &self.loss_as_leaf.0;
        f.debug_struct("TrainLeafNode")
            .field("prediction", prediction)
            .field("r(t)", loss)
            .finish()
    }
}
impl fmt::Debug for TrainNode {
    /// Formats the wrapped branch or leaf exactly as its own `Debug` impl does.
    ///
    /// The formatter is forwarded unchanged instead of going through
    /// `write!(f, "{:?}", ..)`. `write!` restarts with a fresh `{:?}` spec and
    /// drops caller flags such as the alternate `{:#?}` pretty-print mode.
    /// Forwarding keeps those flags for this node and every nested child.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TrainNode::Branch(branch) => fmt::Debug::fmt(branch, f),
            TrainNode::Leaf(leaf) => fmt::Debug::fmt(leaf, f),
        }
    }
}