//! smartcore/tree/base_tree_regressor.rs

1use std::collections::LinkedList;
2use std::default::Default;
3use std::fmt::Debug;
4use std::marker::PhantomData;
5
6use rand::seq::SliceRandom;
7use rand::Rng;
8
9#[cfg(feature = "serde")]
10use serde::{Deserialize, Serialize};
11
12use crate::error::Failed;
13use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1};
14use crate::numbers::basenum::Number;
15use crate::rand_custom::get_rng_impl;
16
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Default)]
/// Strategy used to choose the split threshold at each tree node.
pub enum Splitter {
    /// Draw a random threshold uniformly from the feature's value range.
    Random,
    /// Exhaustively scan candidate thresholds and keep the best gain (default).
    #[default]
    Best,
}
24
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of Regression base_tree
pub struct BaseTreeRegressorParameters {
    #[cfg_attr(feature = "serde", serde(default))]
    /// The maximum depth of the base_tree. `None` means the depth is not limited.
    pub max_depth: Option<u16>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to be at a leaf node.
    pub min_samples_leaf: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Controls the randomness of the estimator (feature shuffling and random splits).
    pub seed: Option<u64>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Determines the strategy used to choose the split at each node.
    pub splitter: Splitter,
}
45
/// Regression base_tree
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct BaseTreeRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
    /// Flat arena of tree nodes; index 0 is the root, children reference nodes by index.
    nodes: Vec<Node>,
    /// Hyperparameters the tree was fitted with (populated by `fit`).
    parameters: Option<BaseTreeRegressorParameters>,
    /// Depth of the deepest level actually created during fitting.
    depth: u16,
    _phantom_tx: PhantomData<TX>,
    _phantom_ty: PhantomData<TY>,
    _phantom_x: PhantomData<X>,
    _phantom_y: PhantomData<Y>,
}
58
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    BaseTreeRegressor<TX, TY, X, Y>
{
    /// Get nodes, return a shared reference
    fn nodes(&self) -> &Vec<Node> {
        self.nodes.as_ref()
    }
    /// Get parameters, return a shared reference.
    /// Panics if the tree was built without parameters; `fit`/`fit_weak_learner`
    /// always set them, so this cannot fire for a fitted tree.
    fn parameters(&self) -> &BaseTreeRegressorParameters {
        self.parameters.as_ref().unwrap()
    }
    /// Get the depth of the fitted tree, return value
    fn depth(&self) -> u16 {
        self.depth
    }
}
75
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// One node of the regression tree, stored in the tree's flat `nodes` arena.
struct Node {
    /// Predicted value: the (weighted) mean of targets routed to this node.
    output: f64,
    /// Index of the feature this node splits on (0 by default / for leaves).
    split_feature: usize,
    /// Threshold: rows with `x[split_feature] <= split_value` go to `true_child`.
    split_value: Option<f64>,
    /// Gain achieved by the chosen split; `None` while the node is a leaf.
    split_score: Option<f64>,
    /// Arena index of the child receiving rows that satisfy the split.
    true_child: Option<usize>,
    /// Arena index of the child receiving the remaining rows.
    false_child: Option<usize>,
}
86
87impl Node {
88    fn new(output: f64) -> Self {
89        Node {
90            output,
91            split_feature: 0,
92            split_value: Option::None,
93            split_score: Option::None,
94            true_child: Option::None,
95            false_child: Option::None,
96        }
97    }
98}
99
100impl PartialEq for Node {
101    fn eq(&self, other: &Self) -> bool {
102        (self.output - other.output).abs() < f64::EPSILON
103            && self.split_feature == other.split_feature
104            && match (self.split_value, other.split_value) {
105                (Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
106                (None, None) => true,
107                _ => false,
108            }
109            && match (self.split_score, other.split_score) {
110                (Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
111                (None, None) => true,
112                _ => false,
113            }
114    }
115}
116
117impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
118    for BaseTreeRegressor<TX, TY, X, Y>
119{
120    fn eq(&self, other: &Self) -> bool {
121        if self.depth != other.depth || self.nodes().len() != other.nodes().len() {
122            false
123        } else {
124            self.nodes()
125                .iter()
126                .zip(other.nodes().iter())
127                .all(|(a, b)| a == b)
128        }
129    }
130}
131
/// Work item used while growing the tree: identifies one node, the sample
/// weights that reach it, and the child outputs chosen by the split search.
struct NodeVisitor<'a, TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
    /// Training features.
    x: &'a X,
    /// Training targets.
    y: &'a Y,
    /// Index of the visited node in the tree's `nodes` arena.
    node: usize,
    /// Per-row sample weights; 0 means the row does not reach this node.
    samples: Vec<usize>,
    /// Per-feature argsort of the columns of `x` (computed once per fit).
    order: &'a [Vec<usize>],
    /// Mean target of rows routed to the true child (set by the split search).
    true_child_output: f64,
    /// Mean target of rows routed to the false child (set by the split search).
    false_child_output: f64,
    /// Tree level of the visited node; the root is level 1.
    level: u16,
    _phantom_tx: PhantomData<TX>,
    _phantom_ty: PhantomData<TY>,
}
144
145impl<'a, TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
146    NodeVisitor<'a, TX, TY, X, Y>
147{
148    fn new(
149        node_id: usize,
150        samples: Vec<usize>,
151        order: &'a [Vec<usize>],
152        x: &'a X,
153        y: &'a Y,
154        level: u16,
155    ) -> Self {
156        NodeVisitor {
157            x,
158            y,
159            node: node_id,
160            samples,
161            order,
162            true_child_output: 0f64,
163            false_child_output: 0f64,
164            level,
165            _phantom_tx: PhantomData,
166            _phantom_ty: PhantomData,
167        }
168    }
169}
170
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    BaseTreeRegressor<TX, TY, X, Y>
{
    /// Build a decision base_tree regressor from the training data.
    /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
    /// * `y` - the target values
    pub fn fit(
        x: &X,
        y: &Y,
        parameters: BaseTreeRegressorParameters,
    ) -> Result<BaseTreeRegressor<TX, TY, X, Y>, Failed> {
        let (x_nrows, num_attributes) = x.shape();
        if x_nrows != y.shape() {
            return Err(Failed::fit("Size of x should equal size of y"));
        }

        // Uniform sample weights: each row counts once. Ensemble code can
        // pass bootstrap counts via `fit_weak_learner` instead.
        let samples = vec![1; x_nrows];
        BaseTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
    }

    /// Fit a tree with per-row sample weights (`samples`) considering at most
    /// `mtry` features at each split; used directly by ensemble learners.
    pub(crate) fn fit_weak_learner(
        x: &X,
        y: &Y,
        samples: Vec<usize>,
        mtry: usize,
        parameters: BaseTreeRegressorParameters,
    ) -> Result<BaseTreeRegressor<TX, TY, X, Y>, Failed> {
        let y_m = y.clone();

        let y_ncols = y_m.shape();
        let (_, num_attributes) = x.shape();

        let mut nodes: Vec<Node> = Vec::new();
        let mut rng = get_rng_impl(parameters.seed);

        // Root prediction: weighted mean of the targets.
        let mut n = 0;
        let mut sum = 0f64;
        for (i, sample_i) in samples.iter().enumerate().take(y_ncols) {
            n += *sample_i;
            sum += *sample_i as f64 * y_m.get(i).to_f64().unwrap();
        }

        let root = Node::new(sum / (n as f64));
        nodes.push(root);
        // Argsort every feature column once; the split search walks rows in
        // ascending feature order through these index lists.
        let mut order: Vec<Vec<usize>> = Vec::new();

        for i in 0..num_attributes {
            let mut col_i: Vec<TX> = x.get_col(i).iterator(0).copied().collect();
            order.push(col_i.argsort_mut());
        }

        let mut base_tree = BaseTreeRegressor {
            nodes,
            parameters: Some(parameters),
            depth: 0u16,
            _phantom_tx: PhantomData,
            _phantom_ty: PhantomData,
            _phantom_x: PhantomData,
            _phantom_y: PhantomData,
        };

        // Root visitor starts at level 1 with the full sample-weight vector.
        let mut visitor = NodeVisitor::<TX, TY, X, Y>::new(0, samples, &order, x, &y_m, 1);

        let mut visitor_queue: LinkedList<NodeVisitor<'_, TX, TY, X, Y>> = LinkedList::new();

        if base_tree.find_best_cutoff(&mut visitor, mtry, &mut rng) {
            visitor_queue.push_back(visitor);
        }

        // Grow breadth-first until the queue drains or the depth limit is hit
        // (`None` max_depth behaves as unlimited via u16::MAX).
        while base_tree.depth() < base_tree.parameters().max_depth.unwrap_or(u16::MAX) {
            match visitor_queue.pop_front() {
                Some(node) => base_tree.split(node, mtry, &mut visitor_queue, &mut rng),
                None => break,
            };
        }

        Ok(base_tree)
    }

    /// Predict regression value for `x`.
    /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
        let mut result = Y::zeros(x.shape().0);

        let (n, _) = x.shape();

        for i in 0..n {
            result.set(i, self.predict_for_row(x, i));
        }

        Ok(result)
    }

    /// Route a single row from the root down the tree and return the output
    /// of the leaf it lands in.
    pub(crate) fn predict_for_row(&self, x: &X, row: usize) -> TY {
        let mut result = 0f64;
        let mut queue: LinkedList<usize> = LinkedList::new();

        queue.push_back(0);

        // The queue holds at most one node at a time, so this is an
        // iterative root-to-leaf walk.
        while !queue.is_empty() {
            match queue.pop_front() {
                Some(node_id) => {
                    let node = &self.nodes()[node_id];
                    if node.true_child.is_none() && node.false_child.is_none() {
                        // Leaf: record its prediction.
                        result = node.output;
                    } else if x.get((row, node.split_feature)).to_f64().unwrap()
                        <= node.split_value.unwrap_or(f64::NAN)
                    {
                        queue.push_back(node.true_child.unwrap());
                    } else {
                        // NaN compares false, so a missing split_value also
                        // routes the row to the false child.
                        queue.push_back(node.false_child.unwrap());
                    }
                }
                None => break,
            };
        }

        TY::from_f64(result).unwrap()
    }

    /// Search for a split of `visitor`'s node over at most `mtry` features
    /// (shuffled when mtry < number of features). Records the best split on
    /// the node and returns `true` when one was found.
    fn find_best_cutoff(
        &mut self,
        visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
        mtry: usize,
        rng: &mut impl Rng,
    ) -> bool {
        let (_, n_attr) = visitor.x.shape();

        // Total weight reaching this node.
        let n: usize = visitor.samples.iter().sum();

        if n < self.parameters().min_samples_split {
            return false;
        }

        // node.output is the weighted mean, so output * n recovers the target sum.
        let sum = self.nodes()[visitor.node].output * n as f64;

        let mut variables = (0..n_attr).collect::<Vec<_>>();

        if mtry < n_attr {
            variables.shuffle(rng);
        }

        // Parent's contribution (n * mean^2) that children must beat for gain > 0.
        let parent_gain =
            n as f64 * self.nodes()[visitor.node].output * self.nodes()[visitor.node].output;

        let splitter = self.parameters().splitter.clone();

        for variable in variables.iter().take(mtry) {
            match splitter {
                Splitter::Random => {
                    self.find_random_split(visitor, n, sum, parent_gain, *variable, rng);
                }
                Splitter::Best => {
                    self.find_best_split(visitor, n, sum, parent_gain, *variable);
                }
            }
        }

        self.nodes()[visitor.node].split_score.is_some()
    }

    /// `Splitter::Random` strategy: draw one random threshold inside feature
    /// `j`'s active value range and record it if it improves the node's score.
    fn find_random_split(
        &mut self,
        visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
        n: usize,
        sum: f64,
        parent_gain: f64,
        j: usize,
        rng: &mut impl Rng,
    ) {
        // Min/max of feature j among rows still active at this node, read off
        // the ends of the presorted index list.
        let (min_val, max_val) = {
            let mut min_opt = None;
            let mut max_opt = None;
            for &i in &visitor.order[j] {
                if visitor.samples[i] > 0 {
                    min_opt = Some(*visitor.x.get((i, j)));
                    break;
                }
            }
            for &i in visitor.order[j].iter().rev() {
                if visitor.samples[i] > 0 {
                    max_opt = Some(*visitor.x.get((i, j)));
                    break;
                }
            }
            if min_opt.is_none() {
                // No active rows: nothing to split.
                return;
            }
            (min_opt.unwrap(), max_opt.unwrap())
        };

        if min_val >= max_val {
            // Feature is constant on this node; no valid threshold exists.
            return;
        }

        let split_value = rng.gen_range(min_val.to_f64().unwrap()..max_val.to_f64().unwrap());

        // Accumulate target sum/weight of rows on the "true" (<=) side.
        // Rows are visited in ascending feature order, so the first active
        // row above the threshold ends the scan.
        let mut true_sum = 0f64;
        let mut true_count = 0;
        for &i in &visitor.order[j] {
            if visitor.samples[i] > 0 {
                if visitor.x.get((i, j)).to_f64().unwrap() <= split_value {
                    true_sum += visitor.samples[i] as f64 * visitor.y.get(i).to_f64().unwrap();
                    true_count += visitor.samples[i];
                } else {
                    break;
                }
            }
        }

        let false_count = n - true_count;

        // Reject splits producing an under-populated leaf.
        if true_count < self.parameters().min_samples_leaf
            || false_count < self.parameters().min_samples_leaf
        {
            return;
        }

        let true_mean = if true_count > 0 {
            true_sum / true_count as f64
        } else {
            0.0
        };
        let false_mean = if false_count > 0 {
            (sum - true_sum) / false_count as f64
        } else {
            0.0
        };
        // Variance-reduction gain: children's n*mean^2 minus the parent's.
        let gain = (true_count as f64 * true_mean * true_mean
            + false_count as f64 * false_mean * false_mean)
            - parent_gain;

        // Keep the split if it beats whatever is currently recorded.
        if self.nodes[visitor.node].split_score.is_none()
            || gain > self.nodes[visitor.node].split_score.unwrap()
        {
            self.nodes[visitor.node].split_feature = j;
            self.nodes[visitor.node].split_value = Some(split_value);
            self.nodes[visitor.node].split_score = Some(gain);
            visitor.true_child_output = true_mean;
            visitor.false_child_output = false_mean;
        }
    }

    /// `Splitter::Best` strategy: scan feature `j` in sorted order, evaluating
    /// a candidate threshold at the midpoint between each pair of consecutive
    /// distinct values, and record the best-scoring one on the node.
    fn find_best_split(
        &mut self,
        visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
        n: usize,
        sum: f64,
        parent_gain: f64,
        j: usize,
    ) {
        // Running target sum/weight of rows already passed (the "true" side).
        let mut true_sum = 0f64;
        let mut true_count = 0;
        let mut prevx = Option::None;

        for i in visitor.order[j].iter() {
            if visitor.samples[*i] > 0 {
                let x_ij = *visitor.x.get((*i, j));

                // No candidate between equal feature values (or before the
                // first row) — just accumulate and move on.
                if prevx.is_none() || x_ij == prevx.unwrap() {
                    prevx = Some(x_ij);
                    true_count += visitor.samples[*i];
                    true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
                    continue;
                }

                let false_count = n - true_count;

                // Skip candidates that would create an under-populated leaf,
                // still folding this row into the running totals.
                if true_count < self.parameters().min_samples_leaf
                    || false_count < self.parameters().min_samples_leaf
                {
                    prevx = Some(x_ij);
                    true_count += visitor.samples[*i];
                    true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
                    continue;
                }

                let true_mean = true_sum / true_count as f64;
                let false_mean = (sum - true_sum) / false_count as f64;

                // Variance-reduction gain relative to the parent.
                let gain = (true_count as f64 * true_mean * true_mean
                    + false_count as f64 * false_mean * false_mean)
                    - parent_gain;

                if self.nodes()[visitor.node].split_score.is_none()
                    || gain > self.nodes()[visitor.node].split_score.unwrap()
                {
                    self.nodes[visitor.node].split_feature = j;
                    // Threshold halfway between the two consecutive values.
                    self.nodes[visitor.node].split_value =
                        Option::Some((x_ij + prevx.unwrap()).to_f64().unwrap() / 2f64);
                    self.nodes[visitor.node].split_score = Option::Some(gain);

                    visitor.true_child_output = true_mean;
                    visitor.false_child_output = false_mean;
                }

                prevx = Some(x_ij);
                true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
                true_count += visitor.samples[*i];
            }
        }
    }

    /// Apply the split recorded on `visitor`'s node: partition its sample
    /// weights between two new child nodes and enqueue child visitors whose
    /// own splits were found. Reverts the node to a leaf (returning `false`)
    /// when either side would violate `min_samples_leaf`.
    fn split<'a>(
        &mut self,
        mut visitor: NodeVisitor<'a, TX, TY, X, Y>,
        mtry: usize,
        visitor_queue: &mut LinkedList<NodeVisitor<'a, TX, TY, X, Y>>,
        rng: &mut impl Rng,
    ) -> bool {
        let (n, _) = visitor.x.shape();
        let mut tc = 0;
        let mut fc = 0;
        let mut true_samples: Vec<usize> = vec![0; n];

        // Move weights of rows satisfying the split into `true_samples`,
        // zeroing them in `visitor.samples` (which then holds the false side).
        for (i, true_sample) in true_samples.iter_mut().enumerate().take(n) {
            if visitor.samples[i] > 0 {
                if visitor
                    .x
                    .get((i, self.nodes()[visitor.node].split_feature))
                    .to_f64()
                    .unwrap()
                    <= self.nodes()[visitor.node].split_value.unwrap_or(f64::NAN)
                {
                    *true_sample = visitor.samples[i];
                    tc += *true_sample;
                    visitor.samples[i] = 0;
                } else {
                    fc += visitor.samples[i];
                }
            }
        }

        // Split no longer satisfies leaf-size constraints: undo it and keep
        // the node as a leaf.
        if tc < self.parameters().min_samples_leaf || fc < self.parameters().min_samples_leaf {
            self.nodes[visitor.node].split_feature = 0;
            self.nodes[visitor.node].split_value = Option::None;
            self.nodes[visitor.node].split_score = Option::None;

            return false;
        }

        // Append both children to the arena and wire them to the parent.
        let true_child_idx = self.nodes().len();

        self.nodes.push(Node::new(visitor.true_child_output));
        let false_child_idx = self.nodes().len();
        self.nodes.push(Node::new(visitor.false_child_output));

        self.nodes[visitor.node].true_child = Some(true_child_idx);
        self.nodes[visitor.node].false_child = Some(false_child_idx);

        self.depth = u16::max(self.depth, visitor.level + 1);

        let mut true_visitor = NodeVisitor::<TX, TY, X, Y>::new(
            true_child_idx,
            true_samples,
            visitor.order,
            visitor.x,
            visitor.y,
            visitor.level + 1,
        );

        // Only enqueue children for which a further split exists.
        if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
            visitor_queue.push_back(true_visitor);
        }

        let mut false_visitor = NodeVisitor::<TX, TY, X, Y>::new(
            false_child_idx,
            visitor.samples,
            visitor.order,
            visitor.x,
            visitor.y,
            visitor.level + 1,
        );

        if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
            visitor_queue.push_back(false_visitor);
        }

        true
    }
}
551}