// miniboosts/weak_learner/regression_tree/node.rs
use crate::Regressor;
6use crate::weak_learner::common::{
7 type_and_struct::*,
8 split_rule::*,
9};
10use crate::Sample;
11
12
13use super::train_node::*;
14
15
16use serde::{Serialize, Deserialize};
17
18use std::rc::Rc;
19
20
21#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
23pub enum Node {
24 Branch(BranchNode),
26
27
28 Leaf(LeafNode),
30}
31
32
33#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
36pub struct BranchNode {
37 pub(super) rule: Splitter,
38 pub(super) left: Box<Node>,
39 pub(super) right: Box<Node>,
40}
41
42
43impl BranchNode {
44 #[inline]
47 pub(super) fn from_raw(
48 rule: Splitter,
49 left: Box<Node>,
50 right: Box<Node>
51 ) -> Self
52 {
53 Self {
54 rule,
55 left,
56 right,
57 }
58 }
59}
60
61
62#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
64pub struct LeafNode {
65 pub(super) prediction: Prediction<f64>,
66}
67
68
69impl LeafNode {
70 #[inline]
74 pub(crate) fn from_raw(prediction: Prediction<f64>) -> Self {
75 Self { prediction }
76 }
77}
78
79
80impl From<TrainBranchNode> for BranchNode {
81 #[inline]
82 fn from(branch: TrainBranchNode) -> Self {
83
84 let left = match Rc::try_unwrap(branch.left) {
85 Ok(l) => l.into_inner().into(),
86 Err(_) => panic!("Strong count is greater than 1")
87 };
88 let right = match Rc::try_unwrap(branch.right) {
89 Ok(r) => r.into_inner().into(),
90 Err(_) => panic!("Strong count is greater than 1")
91 };
92
93 Self::from_raw(
94 branch.rule,
95 Box::new(left),
96 Box::new(right),
97 )
98 }
99}
100
101
102impl From<TrainLeafNode> for LeafNode {
103 #[inline]
104 fn from(leaf: TrainLeafNode) -> Self {
105 Self::from_raw(leaf.prediction)
106 }
107}
108
109
110impl From<TrainNode> for Node {
111 #[inline]
112 fn from(train_node: TrainNode) -> Self {
113 match train_node {
114 TrainNode::Branch(node) => {
115 Node::Branch(node.into())
116 },
117 TrainNode::Leaf(node) => {
118 Node::Leaf(node.into())
119 }
120 }
121 }
122}
123
124
125impl Regressor for LeafNode {
126 #[inline]
127 fn predict(&self, _sample: &Sample, _row: usize) -> f64 {
128 self.prediction.0
129 }
130}
131
132
133impl Regressor for BranchNode {
134 #[inline]
135 fn predict(&self, sample: &Sample, row: usize) -> f64 {
136 match self.rule.split(sample, row) {
137 LR::Left => self.left.predict(sample, row),
138 LR::Right => self.right.predict(sample, row)
139 }
140 }
141}
142
143
144impl Regressor for Node {
145 #[inline]
146 fn predict(&self, sample: &Sample, row: usize) -> f64 {
147 match self {
148 Node::Branch(ref node) => node.predict(sample, row),
149 Node::Leaf(ref node) => node.predict(sample, row)
150 }
151 }
152}
153
154
155impl Node {
156 pub(super) fn to_dot_info(&self, id: usize) -> (Vec<String>, usize) {
157 match self {
158 Node::Branch(b) => {
159 let b_info = format!(
160 "\tnode_{id} [ label = \"{feat} < {thr:.2} ?\" ];\n",
161 feat = b.rule.feature,
162 thr = b.rule.threshold.0
163 );
164
165 let (l_info, next_id) = b.left.to_dot_info(id + 1);
166 let (mut r_info, ret_id) = b.right.to_dot_info(next_id);
167
168 let mut info = l_info;
169 info.push(b_info);
170 info.append(&mut r_info);
171
172 let l_edge = format!(
173 "\tnode_{id} -- node_{l_id} [ label = \"Yes\" ];\n",
174 l_id = id + 1
175 );
176 let r_edge = format!(
177 "\tnode_{id} -- node_{r_id} [ label = \"No\" ];\n",
178 r_id = next_id
179 );
180
181 info.push(l_edge);
182 info.push(r_edge);
183
184 (info, ret_id)
185 },
186 Node::Leaf(l) => {
187 let info = format!(
188 "\tnode_{id} [ \
189 label = \"{p:.2}\", \
190 shape = box, \
191 ];\n",
192 p = l.prediction.0
193 );
194
195 (vec![info], id + 1)
196 }
197 }
198 }
199}
200