// miniboosts/weak_learner/decision_tree/decision_tree_algorithm.rs
use rayon::prelude::*;
2
3
4use crate::{Sample, WeakLearner};
5use super::bin::*;
6
7
8use crate::weak_learner::common::{
9 type_and_struct::*,
10 split_rule::*,
11};
12use super::{
13 node::*,
14 criterion::*,
15 train_node::*,
16 decision_tree_classifier::DecisionTreeClassifier,
17};
18
19
20use std::fmt;
21use std::rc::Rc;
22use std::collections::HashMap;
23
24
/// Decision-tree weak learner.
/// Grows a depth-limited tree over pre-discretized features and converts
/// it into a [`DecisionTreeClassifier`] hypothesis.
pub struct DecisionTree<'a> {
    // Pre-computed bin edges for each feature, keyed by the feature name
    // (borrowed from the training sample, hence the `'a` lifetime).
    bins: HashMap<&'a str, Bins>,
    // Impurity criterion used to choose the best split at each node.
    criterion: Criterion,
    // Depth budget for the grown tree.
    max_depth: Depth,
}
80
81
82impl<'a> DecisionTree<'a> {
83 #[inline]
86 pub(super) fn from_components(
87 bins: HashMap<&'a str, Bins>,
88 criterion: Criterion,
89 max_depth: Depth,
90 ) -> Self
91 {
92 Self { bins, criterion, max_depth, }
93 }
94
95
96 #[inline]
98 fn full_tree(
99 &self,
100 sample: &'a Sample,
101 dist: &[f64],
102 indices: Vec<usize>,
103 criterion: Criterion,
104 depth: Depth,
105 ) -> TrainNodePtr
106 {
107 let total_weight = indices.par_iter()
108 .copied()
109 .map(|i| dist[i])
110 .sum::<f64>();
111
112
113 let (conf, loss) = confidence_and_loss(sample, dist, &indices[..]);
116
117
118 if loss == 0f64 || depth < 1 {
120 return TrainNode::leaf(conf, total_weight, loss);
121 }
122
123
124 let (feature, threshold) = criterion.best_split(
127 &self.bins, sample, dist, &indices[..]
128 );
129
130
131 let rule = Splitter::new(feature, Threshold::from(threshold));
134
135
136 let mut lindices = Vec::new();
138 let mut rindices = Vec::new();
139 for i in indices {
140 match rule.split(sample, i) {
141 LR::Left => { lindices.push(i); },
142 LR::Right => { rindices.push(i); },
143 }
144 }
145
146
147 if lindices.is_empty() || rindices.is_empty() {
149 return TrainNode::leaf(conf, total_weight, loss);
150 }
151
152 let depth = depth - 1;
155 let ltree = self.full_tree(sample, dist, lindices, criterion, depth);
156 let rtree = self.full_tree(sample, dist, rindices, criterion, depth);
157
158
159 TrainNode::branch(rule, ltree, rtree, conf, total_weight, loss)
160 }
161}
162
163
164impl<'a> WeakLearner for DecisionTree<'a> {
165 type Hypothesis = DecisionTreeClassifier;
166
167
168 fn name(&self) -> &str {
169 "Decision Tree"
170 }
171
172
173 fn info(&self) -> Option<Vec<(&str, String)>> {
174 let n_bins = self.bins.values()
175 .map(|bin| bin.len())
176 .reduce(usize::max)
177 .unwrap_or(0);
178 let info = Vec::from([
179 ("# of bins (max)", format!("{n_bins}")),
180 ("Max depth", format!("{}", self.max_depth)),
181 ("Split criterion", format!("{}", self.criterion)),
182 ]);
183 Some(info)
184 }
185
186
187 #[inline]
192 fn produce(&self, sample: &Sample, dist: &[f64])
193 -> Self::Hypothesis
194 {
195 let n_sample = sample.shape().0;
196
197 let indices = (0..n_sample).filter(|&i| dist[i] > 0f64)
198 .collect::<Vec<usize>>();
199 assert_ne!(indices.len(), 0);
200
201 let criterion = self.criterion;
202
203 let tree = self.full_tree(
205 sample, dist, indices, criterion, self.max_depth
206 );
207
208
209 tree.borrow_mut().remove_redundant_nodes();
210
211
212 let root = Node::from(
213 Rc::try_unwrap(tree)
214 .expect("Root node has reference counter >= 1")
215 .into_inner()
216 );
217
218
219 DecisionTreeClassifier::from(root)
220 }
221}
222
223
224#[inline]
231fn confidence_and_loss(sample: &Sample, dist: &[f64], indices: &[usize])
232 -> (Confidence<f64>, LossValue)
233{
234
235 assert_ne!(indices.len(), 0);
236 let target = sample.target();
237 let mut counter: HashMap<i64, f64> = HashMap::new();
238
239 for &i in indices {
240 let l = target[i] as i64;
241 let cnt = counter.entry(l).or_insert(0f64);
242 *cnt += dist[i];
243 }
244
245
246 let total = counter.values().sum::<f64>();
247
248 let (label, p) = counter.into_par_iter()
250 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
251 .unwrap();
252
253
254 let loss = if total > 0f64 { total * (1f64 - (p / total)) } else { 0f64 };
257
258 let confidence = if total > 0f64 {
260 (label as f64 * (2f64 * (p / total) - 1f64)).clamp(-1f64, 1f64)
261 } else {
262 (label as f64).clamp(-1f64, 1f64)
263 };
264
265 let confidence = Confidence::from(confidence);
266 let loss = LossValue::from(loss);
267 (confidence, loss)
268}
269
270
271impl fmt::Display for DecisionTree<'_> {
272 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
273 writeln!(
274 f,
275 "\
276 ----------\n\
277 # Decision Tree Weak Learner\n\n\
278 - Max depth: {}\n\
279 - Splitting criterion: {}\n\
280 - Bins:\
281 ",
282 self.max_depth,
283 self.criterion,
284 )?;
285
286
287 let width = self.bins.keys()
288 .map(|key| key.len())
289 .max()
290 .expect("Tried to print bins, but no features are found");
291 let max_bin_width = self.bins.values()
292 .map(|bin| bin.len().ilog10() as usize)
293 .max()
294 .expect("Tried to print bins, but no features are found")
295 + 1;
296 for (feat_name, feat_bins) in self.bins.iter() {
297 let n_bins = feat_bins.len();
298 writeln!(
299 f,
300 "\
301 \t* [{feat_name: <width$} | \
302 {n_bins: >max_bin_width$} bins] \
303 {feat_bins}\
304 "
305 )?;
306 }
307
308 write!(f, "----------")
309 }
310}