miniboosts/weak_learner/regression_tree/
node.rs

//! Defines the node types of a trained regression tree,
//! built by converting the training-time `TrainNode`s.
3use crate::Regressor;
4
5
6use crate::weak_learner::common::{
7    type_and_struct::*,
8    split_rule::*,
9};
10use crate::Sample;
11
12
13use super::train_node::*;
14
15
16use serde::{Serialize, Deserialize};
17
18use std::rc::Rc;
19
20
21/// Enumeration of `BranchNode` and `LeafNode`.
22#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
23pub enum Node {
24    /// A node that have two childrens.
25    Branch(BranchNode),
26
27
28    /// A node that have no child.
29    Leaf(LeafNode),
30}
31
32
33/// Represents the branch nodes of decision tree.
34/// Each `BranchNode` must have two childrens
35#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
36pub struct BranchNode {
37    pub(super) rule: Splitter,
38    pub(super) left: Box<Node>,
39    pub(super) right: Box<Node>,
40}
41
42
43impl BranchNode {
44    /// Returns the `BranchNode` from the given components.
45    /// Note that this function does not assign the impurity.
46    #[inline]
47    pub(super) fn from_raw(
48        rule: Splitter,
49        left: Box<Node>,
50        right: Box<Node>
51    ) -> Self
52    {
53        Self {
54            rule,
55            left,
56            right,
57        }
58    }
59}
60
61
62/// Represents the leaf nodes of decision tree.
63#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
64pub struct LeafNode {
65    pub(super) prediction: Prediction<f64>,
66}
67
68
69impl LeafNode {
70    /// Returns a `LeafNode` that predicts the label
71    /// given to this function.
72    /// Note that this function does not assign the impurity.
73    #[inline]
74    pub(crate) fn from_raw(prediction: Prediction<f64>) -> Self {
75        Self { prediction }
76    }
77}
78
79
80impl From<TrainBranchNode> for BranchNode {
81    #[inline]
82    fn from(branch: TrainBranchNode) -> Self {
83
84        let left = match Rc::try_unwrap(branch.left) {
85            Ok(l) => l.into_inner().into(),
86            Err(_) => panic!("Strong count is greater than 1")
87        };
88        let right = match Rc::try_unwrap(branch.right) {
89            Ok(r) => r.into_inner().into(),
90            Err(_) => panic!("Strong count is greater than 1")
91        };
92
93        Self::from_raw(
94            branch.rule,
95            Box::new(left),
96            Box::new(right),
97        )
98    }
99}
100
101
102impl From<TrainLeafNode> for LeafNode {
103    #[inline]
104    fn from(leaf: TrainLeafNode) -> Self {
105        Self::from_raw(leaf.prediction)
106    }
107}
108
109
110impl From<TrainNode> for Node {
111    #[inline]
112    fn from(train_node: TrainNode) -> Self {
113        match train_node {
114            TrainNode::Branch(node) => {
115                Node::Branch(node.into())
116            },
117            TrainNode::Leaf(node) => {
118                Node::Leaf(node.into())
119            }
120        }
121    }
122}
123
124
125impl Regressor for LeafNode {
126    #[inline]
127    fn predict(&self, _sample: &Sample, _row: usize) -> f64 {
128        self.prediction.0
129    }
130}
131
132
133impl Regressor for BranchNode {
134    #[inline]
135    fn predict(&self, sample: &Sample, row: usize) -> f64 {
136        match self.rule.split(sample, row) {
137            LR::Left => self.left.predict(sample, row),
138            LR::Right => self.right.predict(sample, row)
139        }
140    }
141}
142
143
144impl Regressor for Node {
145    #[inline]
146    fn predict(&self, sample: &Sample, row: usize) -> f64 {
147        match self {
148            Node::Branch(ref node) => node.predict(sample, row),
149            Node::Leaf(ref node) => node.predict(sample, row)
150        }
151    }
152}
153
154
155impl Node {
156    pub(super) fn to_dot_info(&self, id: usize) -> (Vec<String>, usize) {
157        match self {
158            Node::Branch(b) => {
159                let b_info = format!(
160                    "\tnode_{id} [ label = \"{feat} < {thr:.2} ?\" ];\n",
161                    feat = b.rule.feature,
162                    thr = b.rule.threshold.0
163                );
164
165                let (l_info, next_id) = b.left.to_dot_info(id + 1);
166                let (mut r_info, ret_id) = b.right.to_dot_info(next_id);
167
168                let mut info = l_info;
169                info.push(b_info);
170                info.append(&mut r_info);
171
172                let l_edge = format!(
173                    "\tnode_{id} -- node_{l_id} [ label = \"Yes\" ];\n",
174                    l_id = id + 1
175                );
176                let r_edge = format!(
177                    "\tnode_{id} -- node_{r_id} [ label = \"No\" ];\n",
178                    r_id = next_id
179                );
180
181                info.push(l_edge);
182                info.push(r_edge);
183
184                (info, ret_id)
185            },
186            Node::Leaf(l) => {
187                let info = format!(
188                    "\tnode_{id} [ \
189                     label = \"{p:.2}\", \
190                     shape = box, \
191                     ];\n",
192                    p = l.prediction.0
193                );
194
195                (vec![info], id + 1)
196            }
197        }
198    }
199}
200