rusty_ai/trees/classifier.rs

use super::node::TreeNode;
use super::params::TreeClassifierParams;
use crate::data::dataset::{Dataset, Number, WholeNumber};
use crate::metrics::confusion::ClassificationMetrics;
use nalgebra::{DMatrix, DVector};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use std::collections::{HashMap, HashSet};
use std::error::Error;
use std::f64;
use std::marker::PhantomData;

struct SplitData<XT: Number, YT: WholeNumber> {
    pub feature_index: usize,
    pub threshold: XT,
    pub left: Dataset<XT, YT>,
    pub right: Dataset<XT, YT>,
    information_gain: f64,
}
/// Implementation of a decision tree classifier.
///
/// This struct represents a decision tree classifier, which is a supervised machine learning algorithm
/// used for classification tasks. It can be used to build a decision tree from a dataset and make
/// predictions on new data.
///
/// # Type Parameters
///
/// - `XT`: The type of the features in the dataset.
/// - `YT`: The type of the labels in the dataset.
///
/// # Examples
///
/// ```
/// use rusty_ai::trees::classifier::DecisionTreeClassifier;
/// use rusty_ai::data::dataset::Dataset;
/// use nalgebra::{DMatrix, DVector};
///
/// // Create a new decision tree classifier
/// let mut tree = DecisionTreeClassifier::<f64, u8>::new();
///
/// // Set the minimum number of samples required to split an internal node
/// tree.set_min_samples_split(5).unwrap();
///
/// // Set the maximum depth of the tree
/// tree.set_max_depth(Some(10)).unwrap();
///
/// // Build a dataset and fit the tree
/// let x = DMatrix::from_row_slice(3, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
/// let y = DVector::from_vec(vec![0, 1, 0]);
/// let dataset = Dataset::new(x, y);
/// tree.fit(&dataset).unwrap();
///
/// // Make predictions on new data points
/// let x_test = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
/// let predictions = tree.predict(&x_test);
/// assert!(predictions.is_ok());
/// ```
#[derive(Clone, Debug)]
pub struct DecisionTreeClassifier<XT: Number, YT: WholeNumber> {
    root: Option<Box<TreeNode<XT, YT>>>,
    tree_params: TreeClassifierParams,

    _marker: PhantomData<XT>,
}

impl<XT: Number, YT: WholeNumber> ClassificationMetrics<YT> for DecisionTreeClassifier<XT, YT> {}

impl<XT: Number, YT: WholeNumber> Default for DecisionTreeClassifier<XT, YT> {
    fn default() -> Self {
        Self::new()
    }
}

impl<XT: Number, YT: WholeNumber> DecisionTreeClassifier<XT, YT> {
    pub fn new() -> Self {
        Self {
            root: None,
            tree_params: TreeClassifierParams::new(),

            _marker: PhantomData,
        }
    }

    /// Creates a new instance of the decision tree classifier with custom parameters.
    ///
    /// # Arguments
    ///
    /// * `criterion` - The criterion used for splitting nodes. Default is "gini".
    /// * `min_samples_split` - The minimum number of samples required to split an internal node. Default is 2.
    /// * `max_depth` - The maximum depth of the tree. Default is None (unlimited depth).
    ///
    /// # Returns
    ///
    /// A new instance of the decision tree classifier with the specified parameters.
    ///
    /// # Errors
    ///
    /// This method will return an error if the criterion is not supported, the minimum number of samples to split is less than 2, or the maximum depth is less than 1.
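    ///
    /// # Examples
    ///
    /// A minimal usage sketch; the parameter values here are only illustrative:
    ///
    /// ```
    /// use rusty_ai::trees::classifier::DecisionTreeClassifier;
    ///
    /// let tree = DecisionTreeClassifier::<f64, u8>::with_params(
    ///     Some("entropy".to_string()),
    ///     Some(4),
    ///     Some(8),
    /// )
    /// .unwrap();
    /// assert_eq!(tree.criterion(), "entropy");
    /// assert_eq!(tree.min_samples_split(), 4);
    /// assert_eq!(tree.max_depth(), Some(8));
    /// ```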
    pub fn with_params(
        criterion: Option<String>,
        min_samples_split: Option<u16>,
        max_depth: Option<u16>,
    ) -> Result<Self, Box<dyn Error>> {
        let mut tree = Self::new();
        tree.set_criterion(criterion.unwrap_or_else(|| "gini".to_string()))?;
        tree.set_min_samples_split(min_samples_split.unwrap_or(2))?;
        tree.set_max_depth(max_depth)?;
        Ok(tree)
    }

    /// Sets the minimum number of samples required to split an internal node.
    ///
    /// # Arguments
    ///
    /// * `min_samples_split` - The minimum number of samples required to split an internal node.
    ///
    /// # Errors
    ///
    /// This method will return an error if the minimum number of samples to split is less than 2.
    pub fn set_min_samples_split(&mut self, min_samples_split: u16) -> Result<(), Box<dyn Error>> {
        self.tree_params.set_min_samples_split(min_samples_split)
    }

    /// Sets the maximum depth of the tree.
    ///
    /// # Arguments
    ///
    /// * `max_depth` - The maximum depth of the tree.
    ///
    /// # Errors
    ///
    /// This method will return an error if the maximum depth is less than 1.
    pub fn set_max_depth(&mut self, max_depth: Option<u16>) -> Result<(), Box<dyn Error>> {
        self.tree_params.set_max_depth(max_depth)
    }

    /// Sets the criterion used for splitting nodes.
    ///
    /// # Arguments
    ///
    /// * `criterion` - The criterion used for splitting nodes.
    ///
    /// # Errors
    ///
    /// This method will return an error if the criterion is not supported.
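    ///
    /// # Examples
    ///
    /// A short sketch; `"gini"` (the default) and `"entropy"` are the criteria
    /// exercised by the tests in this module.
    ///
    /// ```
    /// use rusty_ai::trees::classifier::DecisionTreeClassifier;
    ///
    /// let mut tree = DecisionTreeClassifier::<f64, u8>::new();
    /// tree.set_criterion("entropy".to_string()).unwrap();
    /// assert_eq!(tree.criterion(), "entropy");
    /// ```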
    pub fn set_criterion(&mut self, criterion: String) -> Result<(), Box<dyn Error>> {
        self.tree_params.set_criterion(criterion)
    }

    /// Returns the maximum depth of the tree.
    pub fn max_depth(&self) -> Option<u16> {
        self.tree_params.max_depth()
    }

    /// Returns the minimum number of samples required to split an internal node.
    pub fn min_samples_split(&self) -> u16 {
        self.tree_params.min_samples_split()
    }

    /// Returns the criterion used for splitting nodes.
    pub fn criterion(&self) -> &str {
        self.tree_params.criterion()
    }

    /// Builds the decision tree from a dataset.
    ///
    /// # Arguments
    ///
    /// * `dataset` - The dataset containing features and labels.
    ///
    /// # Returns
    ///
    /// A string indicating that the tree was built successfully.
    ///
    /// # Errors
    ///
    /// This method will return an error if the tree couldn't be built.
    pub fn fit(&mut self, dataset: &Dataset<XT, YT>) -> Result<String, Box<dyn Error>> {
        // Depth is only tracked when a maximum depth is configured.
        self.root = Some(Box::new(
            self.build_tree(dataset, self.max_depth().map(|_| 0))?,
        ));
        Ok("Finished building the tree.".into())
    }

    /// Predicts the labels for new data.
    ///
    /// # Arguments
    ///
    /// * `features` - The matrix of features for the new data.
    ///
    /// # Returns
    ///
    /// A vector containing the predicted labels for the new data.
    ///
    /// # Errors
    ///
    /// This method will return an error if the tree wasn't built yet.
    pub fn predict(&self, features: &DMatrix<XT>) -> Result<DVector<YT>, Box<dyn Error>> {
        if self.root.is_none() {
            return Err("Tree wasn't built yet.".into());
        }

        let predictions: Vec<_> = features
            .row_iter()
            .map(|row| Self::make_prediction(row.transpose(), self.root.as_ref().unwrap()))
            .collect();

        Ok(DVector::from_vec(predictions))
    }

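    /// Traverses the tree for a single sample: at each internal node the sample
    /// goes left when its value for the split feature is less than or equal to
    /// the node's threshold, and right otherwise, until a leaf value is reached.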
    fn make_prediction(features: DVector<XT>, node: &TreeNode<XT, YT>) -> YT {
        if let Some(value) = &node.value {
            return *value;
        }
        match &features[node.feature_index.unwrap()] {
            x if x <= node.threshold.as_ref().unwrap() => {
                return Self::make_prediction(features, node.left.as_ref().unwrap())
            }
            _ => return Self::make_prediction(features, node.right.as_ref().unwrap()),
        }
    }

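    /// Recursively grows the tree: while a node has at least `min_samples_split`
    /// samples, the depth limit has not been reached, and the labels are not yet
    /// homogeneous, the split with the highest information gain over all features
    /// is applied; otherwise the node becomes a leaf holding the majority class.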
    fn build_tree(
        &mut self,
        dataset: &Dataset<XT, YT>,
        current_depth: Option<u16>,
    ) -> Result<TreeNode<XT, YT>, Box<dyn Error>> {
        let (x, y) = &dataset.into_parts();
        let (num_samples, num_features) = x.shape();
        let is_data_homogeneous = y.iter().all(|&val| val == y[0]);

        if num_samples >= self.min_samples_split().into()
            && current_depth <= self.max_depth()
            && !is_data_homogeneous
        {
            // Evaluate the best split of every feature in parallel.
            let splits = (0..num_features)
                .into_par_iter()
                .map(|feature_idx| {
                    self.get_split(dataset, feature_idx)
                        .map_err(|err| err.to_string())
                })
                .collect::<Vec<_>>();

            let valid_splits = splits
                .into_iter()
                .filter_map(Result::ok)
                .collect::<Vec<_>>();

            if valid_splits.is_empty() {
                return Ok(TreeNode::new(self.leaf_value(y.clone_owned())));
            }

            let best_split = match valid_splits.into_iter().max_by(|split1, split2| {
                split1
                    .information_gain
                    .partial_cmp(&split2.information_gain)
                    .unwrap_or(std::cmp::Ordering::Equal)
            }) {
                Some(split) => split,
                _ => {
                    return Err("No best split found.".into());
                }
            };

            let left_child = best_split.left;
            let right_child = best_split.right;
            if best_split.information_gain > 0.0 {
                let new_depth = current_depth.map(|depth| depth + 1);
                let left_node = self.build_tree(&left_child, new_depth)?;
                let right_node = self.build_tree(&right_child, new_depth)?;
                return Ok(TreeNode {
                    feature_index: Some(best_split.feature_index),
                    threshold: Some(best_split.threshold),
                    left: Some(Box::new(left_node)),
                    right: Some(Box::new(right_node)),
                    value: None,
                });
            }
        }

        // Otherwise the node becomes a leaf holding the majority class.
        let leaf_value = self.leaf_value(y.clone_owned());
        Ok(TreeNode::new(leaf_value))
    }

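    /// Returns the majority class of `y`, i.e. the label with the highest count.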
    fn leaf_value(&self, y: DVector<YT>) -> Option<YT> {
        let mut class_counts = HashMap::new();
        for item in y.iter() {
            *class_counts.entry(item).or_insert(0) += 1;
        }
        class_counts
            .into_iter()
            .max_by_key(|&(_, count)| count)
            .map(|(val, _)| *val)
    }

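    /// Finds the best split for a single feature by trying every unique value in
    /// that column as a threshold and keeping the split with the highest
    /// information gain.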
    fn get_split(
        &self,
        dataset: &Dataset<XT, YT>,
        feature_index: usize,
    ) -> Result<SplitData<XT, YT>, String> {
        let mut best_split: Option<SplitData<XT, YT>> = None;
        let mut best_information_gain = f64::NEG_INFINITY;

        let mut unique_values: Vec<_> = dataset.x.column(feature_index).iter().cloned().collect();
        unique_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        unique_values.dedup();

        for value in &unique_values {
            let (left_child, right_child) = dataset.split_on_threshold(feature_index, *value);

            if left_child.is_not_empty() && right_child.is_not_empty() {
                let current_information_gain =
                    self.calculate_information_gain(&dataset.y, &left_child.y, &right_child.y);

                if current_information_gain > best_information_gain {
                    best_split = Some(SplitData {
                        feature_index,
                        threshold: *value,
                        left: left_child,
                        right: right_child,
                        information_gain: current_information_gain,
                    });
                    best_information_gain = current_information_gain;
                }
            }
        }

        best_split.ok_or(String::from("No split found."))
    }

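    /// Computes the information gain of a split:
    /// `IG = I(parent) - w_left * I(left) - w_right * I(right)`,
    /// where `I` is the impurity measure selected by the criterion
    /// ("gini" uses Gini impurity, anything else falls back to entropy)
    /// and the weights are the fractions of parent samples in each child.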
    fn calculate_information_gain(
        &self,
        parent_y: &DVector<YT>,
        left_y: &DVector<YT>,
        right_y: &DVector<YT>,
    ) -> f64 {
        let weight_left = left_y.len() as f64 / parent_y.len() as f64;
        let weight_right = right_y.len() as f64 / parent_y.len() as f64;

        match self.criterion() {
            "gini" => {
                Self::gini_impurity(parent_y)
                    - weight_left * Self::gini_impurity(left_y)
                    - weight_right * Self::gini_impurity(right_y)
            }
            _ => {
                Self::entropy(parent_y)
                    - weight_left * Self::entropy(left_y)
                    - weight_right * Self::entropy(right_y)
            }
        }
    }

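    /// Gini impurity: `G(y) = 1 - sum_c p_c^2`, where `p_c` is the fraction of
    /// samples in `y` belonging to class `c`.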
    fn gini_impurity(y: &DVector<YT>) -> f64 {
        let classes: HashSet<_> = y.iter().collect();
        let mut impurity = 0.0;
        for class in classes.into_iter() {
            let p_class = y.iter().filter(|&x| x == class).count() as f64 / y.len() as f64;
            impurity += p_class * p_class;
        }
        1.0 - impurity
    }

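    /// Shannon entropy: `H(y) = -sum_c p_c * log2(p_c)`, where `p_c` is the
    /// fraction of samples in `y` belonging to class `c`.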
    fn entropy(y: &DVector<YT>) -> f64 {
        let classes: HashSet<_> = y.iter().collect();
        let mut entropy = 0.0;
        for class in classes.into_iter() {
            let p_class = y.iter().filter(|&x| x == class).count() as f64 / y.len() as f64;
            entropy += p_class * p_class.log2();
        }
        -entropy
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::DVector;

    #[test]
    fn test_default() {
        let tree = DecisionTreeClassifier::<f64, u8>::default();
        assert_eq!(tree.min_samples_split(), 2); // Default min_samples_split
        assert_eq!(tree.max_depth(), None); // Default max_depth
        assert_eq!(tree.criterion(), "gini"); // Default criterion
    }

    #[test]
    fn test_too_low_min_samples() {
        let tree = DecisionTreeClassifier::<f64, u8>::new().set_min_samples_split(0);
        assert!(tree.is_err());
        assert_eq!(
            tree.unwrap_err().to_string(),
            "The minimum number of samples to split must be greater than 1."
        );
    }

    #[test]
    fn test_too_low_depth() {
        let tree = DecisionTreeClassifier::<f64, u8>::new().set_max_depth(Some(0));
        assert!(tree.is_err());
        assert_eq!(
            tree.unwrap_err().to_string(),
            "The maximum depth must be greater than 0."
        );
    }

    #[test]
    fn test_calculate_information_gain() {
        let classifier = DecisionTreeClassifier::<f64, u8>::new();
        let parent_y = DVector::from_vec(vec![1, 1, 0, 0]);
        let left_y = DVector::from_vec(vec![1, 1]);
        let right_y = DVector::from_vec(vec![0, 0]);

        let result = classifier.calculate_information_gain(&parent_y, &left_y, &right_y);
        assert_eq!(result, 0.5); // Parent Gini is 0.5 and both children are pure, so the gain is 0.5.
    }

    #[test]
    fn test_gini_impurity_homogeneous() {
        let y = DVector::from_vec(vec![1, 1, 1, 1]);
        assert_eq!(DecisionTreeClassifier::<f64, u32>::gini_impurity(&y), 0.0);
    }

    #[test]
    fn test_gini_impurity_mixed() {
        let y = DVector::from_vec(vec![1, 0, 1, 0]);
        assert!((DecisionTreeClassifier::<f64, u32>::gini_impurity(&y) - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_gini_impurity_multiple_classes() {
        let y = DVector::from_vec(vec![1, 2, 1, 2, 3]);
        let expected_impurity =
            1.0 - (2.0 / 5.0) * (2.0 / 5.0) - (2.0 / 5.0) * (2.0 / 5.0) - (1.0 / 5.0) * (1.0 / 5.0);
        assert!(
            (DecisionTreeClassifier::<f64, u32>::gini_impurity(&y) - expected_impurity).abs()
                < f64::EPSILON
        );
    }

    #[test]
    fn test_entropy() {
        let y = DVector::from_vec(vec![1, 1, 0, 0]);
        assert_eq!(DecisionTreeClassifier::<f64, u32>::entropy(&y), 1.0);
    }

    #[test]
    fn test_entropy_homogeneous() {
        let y = DVector::from_vec(vec![1, 1, 1, 1]);
        assert_eq!(DecisionTreeClassifier::<f64, u32>::entropy(&y), 0.0);
    }

    #[test]
    fn test_information_gain_gini() {
        let classifier = DecisionTreeClassifier::<f64, u32>::new();
        let parent_y = DVector::from_vec(vec![1, 1, 1, 0, 0, 1]);
        let left_y = DVector::from_vec(vec![1, 1]);
        let right_y = DVector::from_vec(vec![1, 0, 0, 1]);

        let parent_impurity = DecisionTreeClassifier::<f64, u32>::gini_impurity(&parent_y);
        let left_impurity = DecisionTreeClassifier::<f64, u32>::gini_impurity(&left_y);
        let right_impurity = DecisionTreeClassifier::<f64, u32>::gini_impurity(&right_y);

        let weight_left = left_y.len() as f64 / parent_y.len() as f64;
        let weight_right = right_y.len() as f64 / parent_y.len() as f64;
        let expected_gain =
            parent_impurity - (weight_left * left_impurity + weight_right * right_impurity);

        let result = classifier.calculate_information_gain(&parent_y, &left_y, &right_y);
        assert!((result - expected_gain).abs() < f64::EPSILON);
    }

    #[test]
    fn test_information_gain_entropy() {
        let mut classifier = DecisionTreeClassifier::<f64, u32>::new();
        classifier.set_criterion("entropy".to_string()).unwrap();
        let parent_y = DVector::from_vec(vec![1, 1, 1, 0, 0, 1]);
        let left_y = DVector::from_vec(vec![1, 1]);
        let right_y = DVector::from_vec(vec![1, 0, 0, 1]);

        let parent_impurity = DecisionTreeClassifier::<f64, u32>::entropy(&parent_y);
        let left_impurity = DecisionTreeClassifier::<f64, u32>::entropy(&left_y);
        let right_impurity = DecisionTreeClassifier::<f64, u32>::entropy(&right_y);

        let weight_left = left_y.len() as f64 / parent_y.len() as f64;
        let weight_right = right_y.len() as f64 / parent_y.len() as f64;
        let expected_gain =
            parent_impurity - (weight_left * left_impurity + weight_right * right_impurity);

        let result = classifier.calculate_information_gain(&parent_y, &left_y, &right_y);

        assert!((result - expected_gain).abs() < f64::EPSILON);
    }

    #[test]
    fn test_tree_building() {
        let mut classifier = DecisionTreeClassifier::<f64, u32>::new();

        // A simple dataset with two features and two well-separated classes.
        let x = DMatrix::from_row_slice(
            4,
            2,
            &[
                1.0, 2.0, // Sample 1
                1.1, 2.1, // Sample 2
                2.0, 3.0, // Sample 3
                2.1, 3.1, // Sample 4
            ],
        );
        let y = DVector::from_vec(vec![0, 0, 1, 1]); // Target values
        let dataset = Dataset::new(x, y);

        assert!(classifier.fit(&dataset).is_ok());

        // Check that the root of the tree is set after fitting; deeper structural
        // checks would depend on the expected tree for this dataset.
        assert!(classifier.root.is_some());
    }

    #[test]
    fn test_empty_predict() {
        let classifier = DecisionTreeClassifier::<f64, u32>::new();
        let features = DMatrix::from_row_slice(0, 0, &[]);
        let result = classifier.predict(&features);

        assert!(result.is_err());
        assert_eq!(result.unwrap_err().to_string(), "Tree wasn't built yet.");
    }
}
531}