miniboosts 0.2.1

MiniBoosts: A collection of boosting algorithms written in Rust 🦀
Documentation
//! Defines the internal node representation
//! of the decision tree.
use crate::Regressor;


use crate::weak_learner::common::{
    type_and_struct::*,
    split_rule::*,
};
use crate::Sample;


use super::train_node::*;


use serde::{Serialize, Deserialize};

use std::rc::Rc;


/// A node of a decision tree: either a [`BranchNode`] or a [`LeafNode`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Node {
    /// An internal node that has exactly two children.
    Branch(BranchNode),


    /// A terminal node that has no children and holds the prediction.
    Leaf(LeafNode),
}


/// Represents an internal (branch) node of the decision tree.
/// Every `BranchNode` has exactly two children.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct BranchNode {
    /// Splitting rule; its `split` result decides which child a row goes to.
    pub(super) rule: Splitter,
    /// Subtree used for rows for which `rule.split` yields `LR::Left`.
    pub(super) left: Box<Node>,
    /// Subtree used for rows for which `rule.split` yields `LR::Right`.
    pub(super) right: Box<Node>,
}


impl BranchNode {
    /// Returns the `BranchNode` from the given components.
    /// Note that this function does not assign the impurity.
    #[inline]
    pub(super) fn from_raw(
        rule: Splitter,
        left: Box<Node>,
        right: Box<Node>
    ) -> Self
    {
        Self {
            rule,
            left,
            right,
        }
    }
}


/// Represents a terminal (leaf) node of the decision tree.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LeafNode {
    /// Stored prediction; its inner value (`.0`) is what `predict` returns
    /// for every row that reaches this leaf.
    pub(super) prediction: Prediction<f64>,
}


impl LeafNode {
    /// Returns a `LeafNode` that predicts the label
    /// given to this function.
    /// Note that this function does not assign the impurity.
    #[inline]
    pub(crate) fn from_raw(prediction: Prediction<f64>) -> Self {
        Self { prediction }
    }
}


impl From<TrainBranchNode> for BranchNode {
    #[inline]
    fn from(branch: TrainBranchNode) -> Self {

        let left = match Rc::try_unwrap(branch.left) {
            Ok(l) => l.into_inner().into(),
            Err(_) => panic!("Strong count is greater than 1")
        };
        let right = match Rc::try_unwrap(branch.right) {
            Ok(r) => r.into_inner().into(),
            Err(_) => panic!("Strong count is greater than 1")
        };

        Self::from_raw(
            branch.rule,
            Box::new(left),
            Box::new(right),
        )
    }
}


impl From<TrainLeafNode> for LeafNode {
    #[inline]
    fn from(leaf: TrainLeafNode) -> Self {
        Self::from_raw(leaf.prediction)
    }
}


impl From<TrainNode> for Node {
    #[inline]
    fn from(train_node: TrainNode) -> Self {
        match train_node {
            TrainNode::Branch(node) => {
                Node::Branch(node.into())
            },
            TrainNode::Leaf(node) => {
                Node::Leaf(node.into())
            }
        }
    }
}


impl Regressor for LeafNode {
    /// A leaf ignores the sample and row and returns its stored value.
    #[inline]
    fn predict(&self, _sample: &Sample, _row: usize) -> f64 {
        let Self { prediction } = self;
        prediction.0
    }
}


impl Regressor for BranchNode {
    /// Routes the row to one child according to the splitting rule and
    /// delegates the prediction to that child.
    #[inline]
    fn predict(&self, sample: &Sample, row: usize) -> f64 {
        let child = match self.rule.split(sample, row) {
            LR::Left => &self.left,
            LR::Right => &self.right,
        };
        child.predict(sample, row)
    }
}


impl Regressor for Node {
    /// Dispatches to the `predict` implementation of the wrapped node kind.
    #[inline]
    fn predict(&self, sample: &Sample, row: usize) -> f64 {
        match *self {
            Self::Leaf(ref leaf) => leaf.predict(sample, row),
            Self::Branch(ref branch) => branch.predict(sample, row),
        }
    }
}


impl Node {
    /// Renders the subtree rooted at `self` as Graphviz DOT statements.
    ///
    /// Node ids are assigned in pre-order starting at `id`. A branch takes
    /// `id`, its left subtree is numbered from `id + 1`, and its right subtree
    /// is numbered from the first id left unused by the left subtree.
    ///
    /// Returns the generated lines (each ending in `\n`) together with the
    /// next unused id. Edges are emitted with `--`, which DOT accepts only
    /// inside an undirected `graph { ... }`. Wrapping the lines is left to the
    /// caller.
    pub(super) fn to_dot_info(&self, id: usize) -> (Vec<String>, usize) {
        match self {
            Node::Branch(b) => {
                // The feature is rendered through `Display` and may be
                // arbitrary text. Escape `"` so it cannot close the quoted
                // label early and corrupt the DOT output.
                let feat = escape_dot_quotes(&b.rule.feature.to_string());
                let b_info = format!(
                    "\tnode_{id} [ label = \"{feat} < {thr:.2} ?\" ];\n",
                    thr = b.rule.threshold.0
                );

                let l_id = id + 1;
                let (l_info, r_id) = b.left.to_dot_info(l_id);
                let (mut r_info, ret_id) = b.right.to_dot_info(r_id);

                // Line order: left subtree, this node, right subtree, edges.
                let mut info = l_info;
                info.push(b_info);
                info.append(&mut r_info);

                let l_edge = format!(
                    "\tnode_{id} -- node_{l_id} [ label = \"Yes\" ];\n"
                );
                let r_edge = format!(
                    "\tnode_{id} -- node_{r_id} [ label = \"No\" ];\n"
                );

                info.push(l_edge);
                info.push(r_edge);

                (info, ret_id)
            },
            Node::Leaf(l) => {
                let info = format!(
                    "\tnode_{id} [ \
                     label = \"{p:.2}\", \
                     shape = box, \
                     ];\n",
                    p = l.prediction.0
                );

                (vec![info], id + 1)
            }
        }
    }
}


/// Escapes `s` so it can be placed inside a double-quoted DOT string.
///
/// In DOT quoted strings the only escaped character is `"` (written `\"`),
/// so escaping every `"` is enough to keep the string well-formed.
fn escape_dot_quotes(s: &str) -> String {
    s.replace('"', "\\\"")
}