use crate::Regressor;
use crate::weak_learner::common::{
type_and_struct::*,
split_rule::*,
};
use crate::Sample;
use super::train_node::*;
use serde::{Serialize, Deserialize};
use std::rc::Rc;
/// A node of a prediction-time regression tree.
///
/// Built from a training-time `TrainNode` through the `From<TrainNode>` impl
/// in this file; prediction dispatches to the wrapped branch or leaf.
// Variant order is part of the serde-derived serialized form; do not reorder.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Node {
    /// Internal node that routes a sample row to one of two subtrees.
    Branch(BranchNode),
    /// Terminal node that returns a stored constant prediction.
    Leaf(LeafNode),
}
/// Internal tree node: a split rule plus the two subtrees it routes to.
// Field order is part of the serde-derived serialized form; do not reorder.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct BranchNode {
    /// Rule deciding, per sample row, whether prediction continues left or right.
    pub(super) rule: Splitter,
    /// Subtree used when `rule.split` yields `LR::Left`.
    pub(super) left: Box<Node>,
    /// Subtree used when `rule.split` yields `LR::Right`.
    pub(super) right: Box<Node>,
}
impl BranchNode {
#[inline]
pub(super) fn from_raw(
rule: Splitter,
left: Box<Node>,
right: Box<Node>
) -> Self
{
Self {
rule,
left,
right,
}
}
}
/// Terminal tree node holding the value returned for every row that reaches it.
// Field order is part of the serde-derived serialized form; do not reorder.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LeafNode {
    /// Constant prediction; `predict` returns its inner `f64` (field `.0`).
    pub(super) prediction: Prediction<f64>,
}
impl LeafNode {
#[inline]
pub(crate) fn from_raw(prediction: Prediction<f64>) -> Self {
Self { prediction }
}
}
impl From<TrainBranchNode> for BranchNode {
#[inline]
fn from(branch: TrainBranchNode) -> Self {
let left = match Rc::try_unwrap(branch.left) {
Ok(l) => l.into_inner().into(),
Err(_) => panic!("Strong count is greater than 1")
};
let right = match Rc::try_unwrap(branch.right) {
Ok(r) => r.into_inner().into(),
Err(_) => panic!("Strong count is greater than 1")
};
Self::from_raw(
branch.rule,
Box::new(left),
Box::new(right),
)
}
}
impl From<TrainLeafNode> for LeafNode {
#[inline]
fn from(leaf: TrainLeafNode) -> Self {
Self::from_raw(leaf.prediction)
}
}
impl From<TrainNode> for Node {
#[inline]
fn from(train_node: TrainNode) -> Self {
match train_node {
TrainNode::Branch(node) => {
Node::Branch(node.into())
},
TrainNode::Leaf(node) => {
Node::Leaf(node.into())
}
}
}
}
impl Regressor for LeafNode {
    /// Returns the stored constant; the sample and row are not consulted.
    #[inline]
    fn predict(&self, _sample: &Sample, _row: usize) -> f64 {
        let Self { prediction } = self;
        prediction.0
    }
}
impl Regressor for BranchNode {
#[inline]
fn predict(&self, sample: &Sample, row: usize) -> f64 {
match self.rule.split(sample, row) {
LR::Left => self.left.predict(sample, row),
LR::Right => self.right.predict(sample, row)
}
}
}
impl Regressor for Node {
    /// Forwards the prediction to whichever concrete node kind this is.
    #[inline]
    fn predict(&self, sample: &Sample, row: usize) -> f64 {
        match self {
            Self::Leaf(leaf) => leaf.predict(sample, row),
            Self::Branch(branch) => branch.predict(sample, row),
        }
    }
}
impl Node {
    /// Renders this subtree as a list of Graphviz DOT statements.
    ///
    /// Nodes are numbered depth-first in pre-order:
    /// - this node gets `id`;
    /// - its left subtree starts at `id + 1`;
    /// - its right subtree starts at the first id the left subtree left unused.
    ///
    /// Statements are emitted in this order:
    /// 1. the left subtree;
    /// 2. this node's declaration;
    /// 3. the right subtree;
    /// 4. this node's two outgoing edges ("Yes" to the left child, "No" to the right).
    ///
    /// Edges use `--`, DOT's undirected edge operator. The returned lines
    /// therefore belong inside a `graph { … }` block, not a `digraph`.
    ///
    /// NOTE(review): the "Yes" edge goes to the left child, matching the
    /// `feature < threshold ?` label. This assumes `Splitter::split` sends rows
    /// with `feature < threshold` to `LR::Left` — confirm.
    ///
    /// Returns the statements together with the first id not used by this subtree.
    pub(super) fn to_dot_info(&self, id: usize) -> (Vec<String>, usize) {
        let mut info = Vec::new();
        let next_id = self.push_dot_info(id, &mut info);
        (info, next_id)
    }

    /// Appends the DOT statements for this subtree to `out` (same order and
    /// numbering as `to_dot_info`) and returns the first unused id.
    // A single shared buffer is used on purpose. Building one Vec per subtree
    // and appending it into the parent re-moves every string once per
    // ancestor level, which is quadratic on deep, skewed trees.
    fn push_dot_info(&self, id: usize, out: &mut Vec<String>) -> usize {
        match self {
            Node::Branch(b) => {
                let l_id = id + 1;
                let r_id = b.left.push_dot_info(l_id, out);
                out.push(format!(
                    "\tnode_{id} [ label = \"{feat} < {thr:.2} ?\" ];\n",
                    feat = b.rule.feature,
                    thr = b.rule.threshold.0
                ));
                let ret_id = b.right.push_dot_info(r_id, out);
                out.push(format!("\tnode_{id} -- node_{l_id} [ label = \"Yes\" ];\n"));
                out.push(format!("\tnode_{id} -- node_{r_id} [ label = \"No\" ];\n"));
                ret_id
            }
            Node::Leaf(l) => {
                out.push(format!(
                    "\tnode_{id} [ \
                     label = \"{p:.2}\", \
                     shape = box, \
                     ];\n",
                    p = l.prediction.0
                ));
                id + 1
            }
        }
    }
}