miniboosts/weak_learner/decision_tree/decision_tree_algorithm.rs

1use rayon::prelude::*;
2
3
4use crate::{Sample, WeakLearner};
5use super::bin::*;
6
7
8use crate::weak_learner::common::{
9    type_and_struct::*,
10    split_rule::*,
11};
12use super::{
13    node::*,
14    criterion::*,
15    train_node::*,
16    decision_tree_classifier::DecisionTreeClassifier,
17};
18
19
20use std::fmt;
21use std::rc::Rc;
22use std::collections::HashMap;
23
24
25/// The Decision Tree algorithm.  
26/// Given a set of training examples for classification
27/// and a distribution over the set,
28/// [`DecisionTree`] outputs a decision tree classifier
29/// named [`DecisionTreeClassifier`]
30/// under the specified parameters.
31///
32/// The code is based on the book:  
33/// [Classification and Regression
34/// Trees](https://www.amazon.com/Classification-Regression-Wadsworth-Statistics-Probability/dp/0412048418)
35/// by Leo Breiman, Jerome H. Friedman, Richard A. Olshen, and Charles J. Stone.
36///
37/// [`DecisionTree`] is constructed 
38/// by [`DecisionTreeBuilder`](crate::weak_learner::DecisionTreeBuilder).
39/// 
40/// # Example
41/// ```no_run
42/// use miniboosts::prelude::*;
43/// 
44/// // Read the training data from the CSV file.
45/// let file = "/path/to/data/file.csv";
46/// let sample = SampleReader::new()
47///     .file(file)
48///     .has_header(true)
49///     .target_feature("class")
50///     .read()
51///     .unwrap();
52/// 
53/// 
54/// // Get an instance of decision tree weak learner.
55/// // In this example, the output tree is at most depth 2.
56/// let tree = DecisionTreeBuilder::new(&sample)
57///     .max_depth(2)
58///     .criterion(Criterion::Entropy)
59///     .build();
60///
61/// let n_sample = sample.shape()f64;
62/// let dist = vec![1f64 / n_sample as f64; n_sample];
63/// let f = tree.produce(&sample, &dist);
64/// 
65/// let predictions = f.predict_all(&sample);
66/// 
67/// let loss = sample.target()
68///     .into_iter()
69///     .zip(predictions)
70///     .map(|(ty, py)| if *ty == py as f64 { 0f64 } else { 1f64 })
71///     .sum::<f64>()
72///     / n_sample as f64;
73/// println!("loss (train) is: {loss}");
74/// ```
75pub struct DecisionTree<'a> {
76    bins: HashMap<&'a str, Bins>,
77    criterion: Criterion,
78    max_depth: Depth,
79}
80
81
82impl<'a> DecisionTree<'a> {
83    /// Initialize [`DecisionTree`].
84    /// This method is called only via `DecisionTreeBuilder::build`.
85    #[inline]
86    pub(super) fn from_components(
87        bins: HashMap<&'a str, Bins>,
88        criterion: Criterion,
89        max_depth: Depth,
90    ) -> Self
91    {
92        Self { bins, criterion, max_depth, }
93    }
94
95
96    /// Construct a full binary tree of depth `depth`.
97    #[inline]
98    fn full_tree(
99        &self,
100        sample: &'a Sample,
101        dist: &[f64],
102        indices: Vec<usize>,
103        criterion: Criterion,
104        depth: Depth,
105    ) -> TrainNodePtr
106    {
107        let total_weight = indices.par_iter()
108            .copied()
109            .map(|i| dist[i])
110            .sum::<f64>();
111
112
113        // Compute the best confidence that minimizes the training error
114        // on this node.
115        let (conf, loss) = confidence_and_loss(sample, dist, &indices[..]);
116
117
118        // If sum of `dist` over `train` is zero, construct a leaf node.
119        if loss == 0f64 || depth < 1 {
120            return TrainNode::leaf(conf, total_weight, loss);
121        }
122
123
124        // Find the best pair of feature name and threshold
125        // based on the `criterion`.
126        let (feature, threshold) = criterion.best_split(
127            &self.bins, sample, dist, &indices[..]
128        );
129
130
131        // Construct the splitting rule
132        // from the best feature and threshold.
133        let rule = Splitter::new(feature, Threshold::from(threshold));
134
135
136        // Split the train data for left/right childrens
137        let mut lindices = Vec::new();
138        let mut rindices = Vec::new();
139        for i in indices {
140            match rule.split(sample, i) {
141                LR::Left  => { lindices.push(i); },
142                LR::Right => { rindices.push(i); },
143            }
144        }
145
146
147        // If the split has no meaning, construct a leaf node.
148        if lindices.is_empty() || rindices.is_empty() {
149            return TrainNode::leaf(conf, total_weight, loss);
150        }
151
152        // At this point, `depth > 0` is guaranteed so that
153        // one can grow the tree.
154        let depth = depth - 1;
155        let ltree = self.full_tree(sample, dist, lindices, criterion, depth);
156        let rtree = self.full_tree(sample, dist, rindices, criterion, depth);
157
158
159        TrainNode::branch(rule, ltree, rtree, conf, total_weight, loss)
160    }
161}
162
163
164impl<'a> WeakLearner for DecisionTree<'a> {
165    type Hypothesis = DecisionTreeClassifier;
166
167
168    fn name(&self) -> &str {
169        "Decision Tree"
170    }
171
172
173    fn info(&self) -> Option<Vec<(&str, String)>> {
174        let n_bins = self.bins.values()
175            .map(|bin| bin.len())
176            .reduce(usize::max)
177            .unwrap_or(0);
178        let info = Vec::from([
179            ("# of bins (max)", format!("{n_bins}")),
180            ("Max depth", format!("{}", self.max_depth)),
181            ("Split criterion", format!("{}", self.criterion)),
182        ]);
183        Some(info)
184    }
185
186
187    /// This method computes as follows;
188    /// 1. construct a `TrainNode` which contains some information
189    ///     to grow a tree (e.g., impurity, total distribution mass, etc.)
190    /// 2. Convert `TrainNode` to `Node` that pares redundant information
191    #[inline]
192    fn produce(&self, sample: &Sample, dist: &[f64])
193        -> Self::Hypothesis
194    {
195        let n_sample = sample.shape().0;
196
197        let indices = (0..n_sample).filter(|&i| dist[i] > 0f64)
198            .collect::<Vec<usize>>();
199        assert_ne!(indices.len(), 0);
200
201        let criterion = self.criterion;
202
203        // Construct a large binary tree
204        let tree = self.full_tree(
205            sample, dist, indices, criterion, self.max_depth
206        );
207
208
209        tree.borrow_mut().remove_redundant_nodes();
210
211
212        let root = Node::from(
213            Rc::try_unwrap(tree)
214                .expect("Root node has reference counter >= 1")
215                .into_inner()
216        );
217
218
219        DecisionTreeClassifier::from(root)
220    }
221}
222
223
224/// This function returns a tuple `(c, l)` where
225/// - `c` is the **confidence** for some label `y`
226/// that minimizes the training loss.
227/// - `l` is the training loss when the confidence is `y`.
228/// 
229/// **Note that** this function assumes that the label is `+1` or `-1`.
230#[inline]
231fn confidence_and_loss(sample: &Sample, dist: &[f64], indices: &[usize])
232    -> (Confidence<f64>, LossValue)
233{
234
235    assert_ne!(indices.len(), 0);
236    let target = sample.target();
237    let mut counter: HashMap<i64, f64> = HashMap::new();
238
239    for &i in indices {
240        let l = target[i] as i64;
241        let cnt = counter.entry(l).or_insert(0f64);
242        *cnt += dist[i];
243    }
244
245
246    let total = counter.values().sum::<f64>();
247
248    // Compute the max (key, val) that has maximal p(j, t)
249    let (label, p) = counter.into_par_iter()
250        .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
251        .unwrap();
252
253
254    // From the update rule of boosting algorithm,
255    // the sum of `dist` over `indices` may become zero,
256    let loss = if total > 0f64 { total * (1f64 - (p / total)) } else { 0f64 };
257
258    // `label` takes value in `{-1, +1}`.
259    let confidence = if total > 0f64 {
260        (label as f64 * (2f64 * (p / total) - 1f64)).clamp(-1f64, 1f64)
261    } else {
262        (label as f64).clamp(-1f64, 1f64)
263    };
264
265    let confidence = Confidence::from(confidence);
266    let loss = LossValue::from(loss);
267    (confidence, loss)
268}
269
270
271impl fmt::Display for DecisionTree<'_> {
272    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
273        writeln!(
274            f,
275            "\
276            ----------\n\
277            # Decision Tree Weak Learner\n\n\
278            - Max depth: {}\n\
279            - Splitting criterion: {}\n\
280            - Bins:\
281            ",
282            self.max_depth,
283            self.criterion,
284        )?;
285
286
287        let width = self.bins.keys()
288            .map(|key| key.len())
289            .max()
290            .expect("Tried to print bins, but no features are found");
291        let max_bin_width = self.bins.values()
292            .map(|bin| bin.len().ilog10() as usize)
293            .max()
294            .expect("Tried to print bins, but no features are found")
295            + 1;
296        for (feat_name, feat_bins) in self.bins.iter() {
297            let n_bins = feat_bins.len();
298            writeln!(
299                f,
300                "\
301                \t* [{feat_name: <width$} | \
302                {n_bins: >max_bin_width$} bins]  \
303                {feat_bins}\
304                "
305            )?;
306        }
307
308        write!(f, "----------")
309    }
310}