1use super::node::TreeNode;
2use super::params::TreeClassifierParams;
3use crate::data::dataset::{Dataset, Number, WholeNumber};
4use crate::metrics::confusion::ClassificationMetrics;
5use nalgebra::{DMatrix, DVector};
6use rayon::iter::{IntoParallelIterator, ParallelIterator};
7use std::collections::{HashMap, HashSet};
8use std::error::Error;
9use std::f64;
10use std::marker::PhantomData;
11
12struct SplitData<XT: Number, YT: WholeNumber> {
13 pub feature_index: usize,
14 pub threshold: XT,
15 pub left: Dataset<XT, YT>,
16 pub right: Dataset<XT, YT>,
17 information_gain: f64,
18}
19#[derive(Clone, Debug)]
59pub struct DecisionTreeClassifier<XT: Number, YT: WholeNumber> {
60 root: Option<Box<TreeNode<XT, YT>>>,
61 tree_params: TreeClassifierParams,
62
63 _marker: PhantomData<XT>,
64}
65
66impl<XT: Number, YT: WholeNumber> ClassificationMetrics<YT> for DecisionTreeClassifier<XT, YT> {}
67
68impl<XT: Number, YT: WholeNumber> Default for DecisionTreeClassifier<XT, YT> {
69 fn default() -> Self {
70 Self::new()
71 }
72}
73
74impl<XT: Number, YT: WholeNumber> DecisionTreeClassifier<XT, YT> {
75 pub fn new() -> Self {
76 Self {
77 root: None,
78 tree_params: TreeClassifierParams::new(),
79
80 _marker: PhantomData,
81 }
82 }
83
84 pub fn with_params(
100 criterion: Option<String>,
101 min_samples_split: Option<u16>,
102 max_depth: Option<u16>,
103 ) -> Result<Self, Box<dyn Error>> {
104 let mut tree = Self::new();
105 tree.set_criterion(criterion.unwrap_or("gini".to_string()))?;
106 tree.set_min_samples_split(min_samples_split.unwrap_or(2))?;
107 tree.set_max_depth(max_depth)?;
108 Ok(tree)
109 }
110
111 pub fn set_min_samples_split(&mut self, min_samples_split: u16) -> Result<(), Box<dyn Error>> {
121 self.tree_params.set_min_samples_split(min_samples_split)
122 }
123
124 pub fn set_max_depth(&mut self, max_depth: Option<u16>) -> Result<(), Box<dyn Error>> {
134 self.tree_params.set_max_depth(max_depth)
135 }
136
137 pub fn set_criterion(&mut self, criterion: String) -> Result<(), Box<dyn Error>> {
147 self.tree_params.set_criterion(criterion)
148 }
149
150 pub fn max_depth(&self) -> Option<u16> {
152 self.tree_params.max_depth()
153 }
154
155 pub fn min_samples_split(&self) -> u16 {
157 self.tree_params.min_samples_split()
158 }
159
160 pub fn criterion(&self) -> &str {
162 self.tree_params.criterion()
163 }
164
165 pub fn fit(&mut self, dataset: &Dataset<XT, YT>) -> Result<String, Box<dyn Error>> {
179 self.root = Some(Box::new(
180 self.build_tree(dataset, self.max_depth().map(|_| 0))?,
181 ));
182 Ok("Finished building the tree.".into())
183 }
184
185 pub fn predict(&self, features: &DMatrix<XT>) -> Result<DVector<YT>, Box<dyn Error>> {
199 if self.root.is_none() {
200 return Err("Tree wasn't built yet.".into());
201 }
202
203 let predictions: Vec<_> = features
204 .row_iter()
205 .map(|row| Self::make_prediction(row.transpose(), self.root.as_ref().unwrap()))
206 .collect();
207
208 Ok(DVector::from_vec(predictions))
209 }
210
211 fn make_prediction(features: DVector<XT>, node: &TreeNode<XT, YT>) -> YT {
212 if let Some(value) = &node.value {
213 return *value;
214 }
215 match &features[node.feature_index.unwrap()] {
216 x if x <= node.threshold.as_ref().unwrap() => {
217 return Self::make_prediction(features, node.left.as_ref().unwrap())
218 }
219 _ => return Self::make_prediction(features, node.right.as_ref().unwrap()),
220 }
221 }
222
223 fn build_tree(
224 &mut self,
225 dataset: &Dataset<XT, YT>,
226 current_depth: Option<u16>,
227 ) -> Result<TreeNode<XT, YT>, Box<dyn Error>> {
228 let (x, y) = &dataset.into_parts();
229 let (num_samples, num_features) = x.shape();
230 let is_data_homogenous = y.iter().all(|&val| val == y[0]);
231
232 if num_samples >= self.min_samples_split().into()
233 && current_depth <= self.max_depth()
234 && !is_data_homogenous
235 {
236 let splits = (0..num_features)
237 .into_par_iter()
238 .map(|feature_idx| {
239 self.get_split(dataset, feature_idx)
240 .map_err(|err| err.to_string())
241 })
242 .collect::<Vec<_>>();
243
244 let valid_splits = splits
245 .into_iter()
246 .filter_map(Result::ok)
247 .collect::<Vec<_>>();
248
249 if valid_splits.is_empty() {
250 return Ok(TreeNode::new(self.leaf_value(y.clone_owned())));
251 }
252
253 let best_split = match valid_splits.into_iter().max_by(|split1, split2| {
254 split1
255 .information_gain
256 .partial_cmp(&split2.information_gain)
257 .unwrap_or(std::cmp::Ordering::Equal)
258 }) {
259 Some(split) => split,
260 _ => {
261 return Err("No best split found.".into());
262 }
263 };
264
265 let left_child = best_split.left;
266 let right_child = best_split.right;
267 if best_split.information_gain > 0.0 {
268 let new_depth = current_depth.map(|depth| depth + 1);
269 let left_node = self.build_tree(&left_child, new_depth)?;
270 let right_node = self.build_tree(&right_child, new_depth)?;
271 return Ok(TreeNode {
272 feature_index: Some(best_split.feature_index),
273 threshold: Some(best_split.threshold),
274 left: Some(Box::new(left_node)),
275 right: Some(Box::new(right_node)),
276 value: None,
277 });
278 }
279 }
280
281 let leaf_value = self.leaf_value(y.clone_owned());
282 Ok(TreeNode::new(leaf_value))
283 }
284
285 fn leaf_value(&self, y: DVector<YT>) -> Option<YT> {
286 let mut class_counts = HashMap::new();
287 for item in y.iter() {
288 *class_counts.entry(item).or_insert(0) += 1;
289 }
290 class_counts
291 .into_iter()
292 .max_by_key(|&(_, count)| count)
293 .map(|(val, _)| *val)
294 }
295
296 fn get_split(
297 &self,
298 dataset: &Dataset<XT, YT>,
299 feature_index: usize,
300 ) -> Result<SplitData<XT, YT>, String> {
301 let mut best_split: Option<SplitData<XT, YT>> = None;
302 let mut best_information_gain = f64::NEG_INFINITY;
303
304 let mut unique_values: Vec<_> = dataset.x.column(feature_index).iter().cloned().collect();
305 unique_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
306 unique_values.dedup();
307
308 for value in &unique_values {
309 let (left_child, right_child) = dataset.split_on_threshold(feature_index, *value);
310
311 if left_child.is_not_empty() && right_child.is_not_empty() {
312 let current_information_gain =
313 self.calculate_information_gain(&dataset.y, &left_child.y, &right_child.y);
314
315 if current_information_gain > best_information_gain {
316 best_split = Some(SplitData {
317 feature_index,
318 threshold: *value,
319 left: left_child,
320 right: right_child,
321 information_gain: current_information_gain,
322 });
323 best_information_gain = current_information_gain;
324 }
325 }
326 }
327
328 best_split.ok_or(String::from("No split found."))
329 }
330
331 fn calculate_information_gain(
332 &self,
333 parent_y: &DVector<YT>,
334 left_y: &DVector<YT>,
335 right_y: &DVector<YT>,
336 ) -> f64 {
337 let weight_left = left_y.len() as f64 / parent_y.len() as f64;
338 let weight_right = right_y.len() as f64 / parent_y.len() as f64;
339
340 match self.criterion() {
341 "gini" => {
342 Self::gini_impurity(parent_y)
343 - weight_left * Self::gini_impurity(left_y)
344 - weight_right * Self::gini_impurity(right_y)
345 }
346 _ => {
347 Self::entropy(parent_y)
348 - weight_left * Self::entropy(left_y)
349 - weight_right * Self::entropy(right_y)
350 }
351 }
352 }
353
354 fn gini_impurity(y: &DVector<YT>) -> f64 {
355 let classes: HashSet<_> = y.iter().collect();
356 let mut impurity = 0.0;
357 for class in classes.into_iter() {
358 let p_class = y.iter().filter(|&x| x == class).count() as f64 / y.len() as f64;
359 impurity += p_class * p_class;
360 }
361 1.0 - impurity
362 }
363
364 fn entropy(y: &DVector<YT>) -> f64 {
365 let classes: HashSet<_> = y.iter().collect();
366 let mut entropy = 0.0;
367 for class in classes.into_iter() {
368 let p_class = y.iter().filter(|&x| x == class).count() as f64 / y.len() as f64;
369 entropy += p_class * p_class.log2();
370 }
371 -entropy
372 }
373}
374
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::DVector;

    #[test]
    fn test_default() {
        let tree = DecisionTreeClassifier::<f64, u8>::default();
        assert_eq!(tree.min_samples_split(), 2);
        assert_eq!(tree.max_depth(), None);
        assert_eq!(tree.criterion(), "gini");
    }

    #[test]
    fn test_with_params() {
        let tree = DecisionTreeClassifier::<f64, u8>::with_params(
            Some("entropy".to_string()),
            Some(3),
            Some(4),
        )
        .unwrap();
        assert_eq!(tree.criterion(), "entropy");
        assert_eq!(tree.min_samples_split(), 3);
        assert_eq!(tree.max_depth(), Some(4));
    }

    #[test]
    fn test_too_low_min_samples() {
        let tree = DecisionTreeClassifier::<f64, u8>::new().set_min_samples_split(0);
        assert!(tree.is_err());
        assert_eq!(
            tree.unwrap_err().to_string(),
            "The minimum number of samples to split must be greater than 1."
        );
    }

    #[test]
    fn test_too_low_depth() {
        let tree = DecisionTreeClassifier::<f64, u8>::new().set_max_depth(Some(0));
        assert!(tree.is_err());
        assert_eq!(
            tree.unwrap_err().to_string(),
            "The maximum depth must be greater than 0."
        );
    }

    #[test]
    fn test_calculate_information_gain() {
        let classifier = DecisionTreeClassifier::<f64, u8>::new();
        let parent_y = DVector::from_vec(vec![1, 1, 0, 0]);
        let left_y = DVector::from_vec(vec![1, 1]);
        let right_y = DVector::from_vec(vec![0, 0]);

        let result = classifier.calculate_information_gain(&parent_y, &left_y, &right_y);
        assert_eq!(result, 0.5);
    }

    #[test]
    fn test_gini_impurity_homogeneous() {
        let y = DVector::from_vec(vec![1, 1, 1, 1]);
        assert_eq!(DecisionTreeClassifier::<f64, u32>::gini_impurity(&y), 0.0);
    }

    #[test]
    fn test_gini_impurity_mixed() {
        let y = DVector::from_vec(vec![1, 0, 1, 0]);
        assert!((DecisionTreeClassifier::<f64, u32>::gini_impurity(&y) - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_gini_impurity_multiple_classes() {
        let y = DVector::from_vec(vec![1, 2, 1, 2, 3]);
        let expected_impurity =
            1.0 - (2.0 / 5.0) * (2.0 / 5.0) - (2.0 / 5.0) * (2.0 / 5.0) - (1.0 / 5.0) * (1.0 / 5.0);
        assert!(
            (DecisionTreeClassifier::<f64, u32>::gini_impurity(&y) - expected_impurity).abs()
                < f64::EPSILON
        );
    }

    #[test]
    fn test_entropy() {
        let y = DVector::from_vec(vec![1, 1, 0, 0]);
        assert_eq!(DecisionTreeClassifier::<f64, u32>::entropy(&y), 1.0);
    }

    #[test]
    fn test_entropy_homogeneous() {
        let y = DVector::from_vec(vec![1, 1, 1, 1]);
        assert_eq!(DecisionTreeClassifier::<f64, u32>::entropy(&y), 0.0);
    }

    #[test]
    fn test_information_gain_gini() {
        let classifier = DecisionTreeClassifier::<f64, u32>::new();
        let parent_y = DVector::from_vec(vec![1, 1, 1, 0, 0, 1]);
        let left_y = DVector::from_vec(vec![1, 1]);
        let right_y = DVector::from_vec(vec![1, 0, 0, 1]);

        let parent_impurity = DecisionTreeClassifier::<f64, u32>::gini_impurity(&parent_y);
        let left_impurity = DecisionTreeClassifier::<f64, u32>::gini_impurity(&left_y);
        let right_impurity = DecisionTreeClassifier::<f64, u32>::gini_impurity(&right_y);

        let weight_left = left_y.len() as f64 / parent_y.len() as f64;
        let weight_right = right_y.len() as f64 / parent_y.len() as f64;
        let expected_gain =
            parent_impurity - (weight_left * left_impurity + weight_right * right_impurity);

        let result = classifier.calculate_information_gain(&parent_y, &left_y, &right_y);
        assert!((result - expected_gain).abs() < f64::EPSILON);
    }

    #[test]
    fn test_information_gain_entropy() {
        let mut classifier = DecisionTreeClassifier::<f64, u32>::new();
        classifier.set_criterion("entropy".to_string()).unwrap();
        let parent_y = DVector::from_vec(vec![1, 1, 1, 0, 0, 1]);
        let left_y = DVector::from_vec(vec![1, 1]);
        let right_y = DVector::from_vec(vec![1, 0, 0, 1]);

        let parent_impurity = DecisionTreeClassifier::<f64, u32>::entropy(&parent_y);
        let left_impurity = DecisionTreeClassifier::<f64, u32>::entropy(&left_y);
        let right_impurity = DecisionTreeClassifier::<f64, u32>::entropy(&right_y);

        let weight_left = left_y.len() as f64 / parent_y.len() as f64;
        let weight_right = right_y.len() as f64 / parent_y.len() as f64;
        let expected_gain =
            parent_impurity - (weight_left * left_impurity + weight_right * right_impurity);

        let result = classifier.calculate_information_gain(&parent_y, &left_y, &right_y);

        assert!((result - expected_gain).abs() < f64::EPSILON);
    }

    #[test]
    fn test_tree_building() {
        let mut classifier = DecisionTreeClassifier::<f64, u32>::new();

        let x = DMatrix::from_row_slice(4, 2, &[1.0, 2.0, 1.1, 2.1, 2.0, 3.0, 2.1, 3.1]);
        let y = DVector::from_vec(vec![0, 0, 1, 1]);
        let dataset = Dataset::new(x, y);

        // A failed fit must fail the test rather than be silently ignored.
        classifier.fit(&dataset).unwrap();
        assert!(classifier.root.is_some());

        // These points lie strictly inside each class's range, so they are classified
        // correctly whichever class boundary the chosen threshold represents.
        let unseen = DMatrix::from_row_slice(2, 2, &[1.05, 2.05, 2.05, 3.05]);
        let predictions = classifier.predict(&unseen).unwrap();
        assert_eq!(predictions, DVector::from_vec(vec![0, 1]));
    }

    #[test]
    fn test_predict_before_fit() {
        let classifier = DecisionTreeClassifier::<f64, u32>::new();
        let features = DMatrix::from_row_slice(0, 0, &[]);
        let result = classifier.predict(&features);

        assert!(result.is_err());
        assert_eq!(result.unwrap_err().to_string(), "Tree wasn't built yet.");
    }
}