// smartcore/ensemble/random_forest_regressor.rs

//! # Random Forest Regressor
//! A random forest is an ensemble estimator that fits multiple [decision trees](../../tree/index.html) to random subsets of the dataset and averages their predictions
//! to improve predictive accuracy and control over-fitting. See [ensemble models](../index.html) for more details.
//!
//! A larger number of estimators generally improves the predictive performance of the algorithm, at the cost of increased training time.
//! The size _m_ of the random sample of predictors considered at each split is typically set to \\(\sqrt{p}\\), where _p_ is the total number of predictors.
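//! With \\(p = 6\\) predictors, for instance (as in the Longley data below), this rule suggests \\(m = \lfloor\sqrt{6}\rfloor = 2\\).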
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::ensemble::random_forest_regressor::*;
//!
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
//! let x = DenseMatrix::from_2d_array(&[
//!             &[234.289, 235.6, 159., 107.608, 1947., 60.323],
//!             &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
//!             &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
//!             &[284.599, 335.1, 165., 110.929, 1950., 61.187],
//!             &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
//!             &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
//!             &[365.385, 187., 354.7, 115.094, 1953., 64.989],
//!             &[363.112, 357.8, 335., 116.219, 1954., 63.761],
//!             &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
//!             &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
//!             &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
//!             &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
//!             &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
//!             &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
//!             &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
//!             &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
//!         ]).unwrap();
//! let y = vec![
//!             83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2,
//!             104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9
//!         ];
//!
//! let regressor = RandomForestRegressor::fit(&x, &y, Default::default()).unwrap();
//!
//! let y_hat = regressor.predict(&x).unwrap(); // use the same data for prediction
//! ```
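//!
//! To tune the forest, replace `Default::default()` with customized parameters, e.g.
//! `RandomForestRegressorParameters::default().with_n_trees(100).with_seed(42)`; see the builder methods on [`RandomForestRegressorParameters`] below.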
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>

use std::default::Default;
use std::fmt::Debug;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::api::{Predictor, SupervisedEstimator};
use crate::ensemble::base_forest_regressor::{BaseForestRegressor, BaseForestRegressorParameters};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::tree::base_tree_regressor::Splitter;

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of the Random Forest Regressor.
/// Some parameters here are passed directly to the base estimator.
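///
/// Parameters can be tuned with the builder-style `with_*` methods defined below; a minimal sketch:
/// ```
/// use smartcore::ensemble::random_forest_regressor::RandomForestRegressorParameters;
///
/// let params = RandomForestRegressorParameters::default()
///     .with_n_trees(100)
///     .with_max_depth(8)
///     .with_seed(42);
/// assert_eq!(params.n_trees, 100);
/// ```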
pub struct RandomForestRegressorParameters {
    #[cfg_attr(feature = "serde", serde(default))]
    /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub max_depth: Option<u16>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub min_samples_leaf: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub min_samples_split: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The number of trees in the forest.
    pub n_trees: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The number of predictors to sample at random as split candidates.
    pub m: Option<usize>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Whether to keep the samples used for tree generation. This is required for OOB prediction.
    pub keep_samples: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Seed used for bootstrap sampling and feature selection for each tree.
    pub seed: u64,
}

/// Random Forest Regressor
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct RandomForestRegressor<
    TX: Number + FloatNumber + PartialOrd,
    TY: Number,
    X: Array2<TX>,
    Y: Array1<TY>,
> {
    forest_regressor: Option<BaseForestRegressor<TX, TY, X, Y>>,
}

impl RandomForestRegressorParameters {
    /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub fn with_max_depth(mut self, max_depth: u16) -> Self {
        self.max_depth = Some(max_depth);
        self
    }
    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }
    /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }
    /// The number of trees in the forest.
    pub fn with_n_trees(mut self, n_trees: usize) -> Self {
        self.n_trees = n_trees;
        self
    }
    /// The number of predictors to sample at random as split candidates.
    pub fn with_m(mut self, m: usize) -> Self {
        self.m = Some(m);
        self
    }

    /// Whether to keep the samples used for tree generation. This is required for OOB prediction.
    pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
        self.keep_samples = keep_samples;
        self
    }

    /// Seed used for bootstrap sampling and feature selection for each tree.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }
}

impl Default for RandomForestRegressorParameters {
    fn default() -> Self {
        RandomForestRegressorParameters {
            max_depth: None,
            min_samples_leaf: 1,
            min_samples_split: 2,
            n_trees: 10,
            m: None,
            keep_samples: false,
            seed: 0,
        }
    }
}

impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
    for RandomForestRegressor<TX, TY, X, Y>
{
    fn eq(&self, other: &Self) -> bool {
        self.forest_regressor == other.forest_regressor
    }
}

impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    SupervisedEstimator<X, Y, RandomForestRegressorParameters>
    for RandomForestRegressor<TX, TY, X, Y>
{
    fn new() -> Self {
        Self {
            forest_regressor: None,
        }
    }

    fn fit(x: &X, y: &Y, parameters: RandomForestRegressorParameters) -> Result<Self, Failed> {
        RandomForestRegressor::fit(x, y, parameters)
    }
}

impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    Predictor<X, Y> for RandomForestRegressor<TX, TY, X, Y>
{
    fn predict(&self, x: &X) -> Result<Y, Failed> {
        self.predict(x)
    }
}

/// RandomForestRegressor grid search parameters
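///
/// Iterating yields the cartesian product of all listed values; a minimal sketch:
/// ```
/// use smartcore::ensemble::random_forest_regressor::RandomForestRegressorSearchParameters;
///
/// let params = RandomForestRegressorSearchParameters {
///     n_trees: vec![10, 50],
///     ..Default::default()
/// };
/// assert_eq!(params.into_iter().count(), 2);
/// ```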
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RandomForestRegressorSearchParameters {
    #[cfg_attr(feature = "serde", serde(default))]
    /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub max_depth: Vec<Option<u16>>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub min_samples_leaf: Vec<usize>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub min_samples_split: Vec<usize>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The number of trees in the forest.
    pub n_trees: Vec<usize>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The number of predictors to sample at random as split candidates.
    pub m: Vec<Option<usize>>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Whether to keep the samples used for tree generation. This is required for OOB prediction.
    pub keep_samples: Vec<bool>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Seed used for bootstrap sampling and feature selection for each tree.
    pub seed: Vec<u64>,
}

/// RandomForestRegressor grid search iterator
pub struct RandomForestRegressorSearchParametersIterator {
    random_forest_regressor_search_parameters: RandomForestRegressorSearchParameters,
    current_max_depth: usize,
    current_min_samples_leaf: usize,
    current_min_samples_split: usize,
    current_n_trees: usize,
    current_m: usize,
    current_keep_samples: usize,
    current_seed: usize,
}

impl IntoIterator for RandomForestRegressorSearchParameters {
    type Item = RandomForestRegressorParameters;
    type IntoIter = RandomForestRegressorSearchParametersIterator;

    fn into_iter(self) -> Self::IntoIter {
        RandomForestRegressorSearchParametersIterator {
            random_forest_regressor_search_parameters: self,
            current_max_depth: 0,
            current_min_samples_leaf: 0,
            current_min_samples_split: 0,
            current_n_trees: 0,
            current_m: 0,
            current_keep_samples: 0,
            current_seed: 0,
        }
    }
}

impl Iterator for RandomForestRegressorSearchParametersIterator {
    type Item = RandomForestRegressorParameters;

    fn next(&mut self) -> Option<Self::Item> {
        // Iteration is complete once every index has been advanced past the
        // end of its list (see the final `else` branch below).
        if self.current_max_depth
            == self
                .random_forest_regressor_search_parameters
                .max_depth
                .len()
            && self.current_min_samples_leaf
                == self
                    .random_forest_regressor_search_parameters
                    .min_samples_leaf
                    .len()
            && self.current_min_samples_split
                == self
                    .random_forest_regressor_search_parameters
                    .min_samples_split
                    .len()
            && self.current_n_trees == self.random_forest_regressor_search_parameters.n_trees.len()
            && self.current_m == self.random_forest_regressor_search_parameters.m.len()
            && self.current_keep_samples
                == self
                    .random_forest_regressor_search_parameters
                    .keep_samples
                    .len()
            && self.current_seed == self.random_forest_regressor_search_parameters.seed.len()
        {
            return None;
        }

        let next = RandomForestRegressorParameters {
            max_depth: self.random_forest_regressor_search_parameters.max_depth
                [self.current_max_depth],
            min_samples_leaf: self
                .random_forest_regressor_search_parameters
                .min_samples_leaf[self.current_min_samples_leaf],
            min_samples_split: self
                .random_forest_regressor_search_parameters
                .min_samples_split[self.current_min_samples_split],
            n_trees: self.random_forest_regressor_search_parameters.n_trees[self.current_n_trees],
            m: self.random_forest_regressor_search_parameters.m[self.current_m],
            keep_samples: self.random_forest_regressor_search_parameters.keep_samples
                [self.current_keep_samples],
            seed: self.random_forest_regressor_search_parameters.seed[self.current_seed],
        };

        // Advance the indices like an odometer: increment the first index
        // that still has room, resetting every index before it.
        if self.current_max_depth + 1
            < self
                .random_forest_regressor_search_parameters
                .max_depth
                .len()
        {
            self.current_max_depth += 1;
        } else if self.current_min_samples_leaf + 1
            < self
                .random_forest_regressor_search_parameters
                .min_samples_leaf
                .len()
        {
            self.current_max_depth = 0;
            self.current_min_samples_leaf += 1;
        } else if self.current_min_samples_split + 1
            < self
                .random_forest_regressor_search_parameters
                .min_samples_split
                .len()
        {
            self.current_max_depth = 0;
            self.current_min_samples_leaf = 0;
            self.current_min_samples_split += 1;
        } else if self.current_n_trees + 1
            < self.random_forest_regressor_search_parameters.n_trees.len()
        {
            self.current_max_depth = 0;
            self.current_min_samples_leaf = 0;
            self.current_min_samples_split = 0;
            self.current_n_trees += 1;
        } else if self.current_m + 1 < self.random_forest_regressor_search_parameters.m.len() {
            self.current_max_depth = 0;
            self.current_min_samples_leaf = 0;
            self.current_min_samples_split = 0;
            self.current_n_trees = 0;
            self.current_m += 1;
        } else if self.current_keep_samples + 1
            < self
                .random_forest_regressor_search_parameters
                .keep_samples
                .len()
        {
            self.current_max_depth = 0;
            self.current_min_samples_leaf = 0;
            self.current_min_samples_split = 0;
            self.current_n_trees = 0;
            self.current_m = 0;
            self.current_keep_samples += 1;
        } else if self.current_seed + 1 < self.random_forest_regressor_search_parameters.seed.len()
        {
            self.current_max_depth = 0;
            self.current_min_samples_leaf = 0;
            self.current_min_samples_split = 0;
            self.current_n_trees = 0;
            self.current_m = 0;
            self.current_keep_samples = 0;
            self.current_seed += 1;
        } else {
            // Every list is exhausted; push all indices past the end so the
            // next call returns `None`.
            self.current_max_depth += 1;
            self.current_min_samples_leaf += 1;
            self.current_min_samples_split += 1;
            self.current_n_trees += 1;
            self.current_m += 1;
            self.current_keep_samples += 1;
            self.current_seed += 1;
        }

        Some(next)
    }
}

impl Default for RandomForestRegressorSearchParameters {
    fn default() -> Self {
        let default_params = RandomForestRegressorParameters::default();

        RandomForestRegressorSearchParameters {
            max_depth: vec![default_params.max_depth],
            min_samples_leaf: vec![default_params.min_samples_leaf],
            min_samples_split: vec![default_params.min_samples_split],
            n_trees: vec![default_params.n_trees],
            m: vec![default_params.m],
            keep_samples: vec![default_params.keep_samples],
            seed: vec![default_params.seed],
        }
    }
}

impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    RandomForestRegressor<TX, TY, X, Y>
{
    /// Build a forest of trees from the training set.
    /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
    /// * `y` - the target values
    pub fn fit(
        x: &X,
        y: &Y,
        parameters: RandomForestRegressorParameters,
    ) -> Result<RandomForestRegressor<TX, TY, X, Y>, Failed> {
        let regressor_params = BaseForestRegressorParameters {
            max_depth: parameters.max_depth,
            min_samples_leaf: parameters.min_samples_leaf,
            min_samples_split: parameters.min_samples_split,
            n_trees: parameters.n_trees,
            m: parameters.m,
            keep_samples: parameters.keep_samples,
            seed: parameters.seed,
            bootstrap: true,
            splitter: Splitter::Best,
        };
        let forest_regressor = BaseForestRegressor::fit(x, y, regressor_params)?;

        Ok(RandomForestRegressor {
            forest_regressor: Some(forest_regressor),
        })
    }

    /// Predict target values for `x`
    /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
        // `forest_regressor` is `None` only if the estimator has not been fitted.
        let forest_regressor = self
            .forest_regressor
            .as_ref()
            .ok_or_else(|| Failed::predict("RandomForestRegressor is not fitted"))?;
        forest_regressor.predict(x)
    }

    /// Predict out-of-bag (OOB) target values for `x`. `x` is expected to be equal to the dataset used in training.
    /// Requires the forest to have been fitted with `keep_samples: true`.
    pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
        let forest_regressor = self
            .forest_regressor
            .as_ref()
            .ok_or_else(|| Failed::predict("RandomForestRegressor is not fitted"))?;
        forest_regressor.predict_oob(x)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::basic::matrix::DenseMatrix;
    use crate::metrics::mean_absolute_error;

    #[test]
    fn search_parameters() {
        let parameters = RandomForestRegressorSearchParameters {
            n_trees: vec![10, 100],
            m: vec![None, Some(1)],
            ..Default::default()
        };
        let mut iter = parameters.into_iter();
        let next = iter.next().unwrap();
        assert_eq!(next.n_trees, 10);
        assert_eq!(next.m, None);
        let next = iter.next().unwrap();
        assert_eq!(next.n_trees, 100);
        assert_eq!(next.m, None);
        let next = iter.next().unwrap();
        assert_eq!(next.n_trees, 10);
        assert_eq!(next.m, Some(1));
        let next = iter.next().unwrap();
        assert_eq!(next.n_trees, 100);
        assert_eq!(next.m, Some(1));
        assert!(iter.next().is_none());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn fit_longley() {
        let x = DenseMatrix::from_2d_array(&[
            &[234.289, 235.6, 159., 107.608, 1947., 60.323],
            &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
            &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
            &[284.599, 335.1, 165., 110.929, 1950., 61.187],
            &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
            &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
            &[365.385, 187., 354.7, 115.094, 1953., 64.989],
            &[363.112, 357.8, 335., 116.219, 1954., 63.761],
            &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
            &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
            &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
            &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
            &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
            &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
            &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
            &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
        ])
        .unwrap();
        let y = vec![
            83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
            114.2, 115.7, 116.9,
        ];

        let y_hat = RandomForestRegressor::fit(
            &x,
            &y,
            RandomForestRegressorParameters {
                max_depth: None,
                min_samples_leaf: 1,
                min_samples_split: 2,
                n_trees: 1000,
                m: None,
                keep_samples: false,
                seed: 87,
            },
        )
        .and_then(|rf| rf.predict(&x))
        .unwrap();

        assert!(mean_absolute_error(&y, &y_hat) < 1.0);
    }

    #[test]
    fn test_random_matrix_with_wrong_rownum() {
        let x_rand: DenseMatrix<f64> = DenseMatrix::<f64>::rand(17, 200);

        let y = vec![
            83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
            114.2, 115.7, 116.9,
        ];

        let fail = RandomForestRegressor::fit(
            &x_rand,
            &y,
            RandomForestRegressorParameters {
                max_depth: None,
                min_samples_leaf: 1,
                min_samples_split: 2,
                n_trees: 1000,
                m: None,
                keep_samples: false,
                seed: 87,
            },
        );

        assert!(fail.is_err());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn fit_predict_longley_oob() {
        let x = DenseMatrix::from_2d_array(&[
            &[234.289, 235.6, 159., 107.608, 1947., 60.323],
            &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
            &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
            &[284.599, 335.1, 165., 110.929, 1950., 61.187],
            &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
            &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
            &[365.385, 187., 354.7, 115.094, 1953., 64.989],
            &[363.112, 357.8, 335., 116.219, 1954., 63.761],
            &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
            &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
            &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
            &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
            &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
            &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
            &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
            &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
        ])
        .unwrap();
        let y = vec![
            83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
            114.2, 115.7, 116.9,
        ];

        let regressor = RandomForestRegressor::fit(
            &x,
            &y,
            RandomForestRegressorParameters {
                max_depth: None,
                min_samples_leaf: 1,
                min_samples_split: 2,
                n_trees: 1000,
                m: None,
                keep_samples: true,
                seed: 87,
            },
        )
        .unwrap();

        let y_hat = regressor.predict(&x).unwrap();
        let y_hat_oob = regressor.predict_oob(&x).unwrap();

        println!("{:?}", mean_absolute_error(&y, &y_hat));
        println!("{:?}", mean_absolute_error(&y, &y_hat_oob));

        assert!(mean_absolute_error(&y, &y_hat) < mean_absolute_error(&y, &y_hat_oob));
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    #[cfg(feature = "serde")]
    fn serde() {
        let x = DenseMatrix::from_2d_array(&[
            &[234.289, 235.6, 159., 107.608, 1947., 60.323],
            &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
            &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
            &[284.599, 335.1, 165., 110.929, 1950., 61.187],
            &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
            &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
            &[365.385, 187., 354.7, 115.094, 1953., 64.989],
            &[363.112, 357.8, 335., 116.219, 1954., 63.761],
            &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
            &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
            &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
            &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
            &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
            &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
            &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
            &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
        ])
        .unwrap();
        let y = vec![
            83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
            114.2, 115.7, 116.9,
        ];

        let forest = RandomForestRegressor::fit(&x, &y, Default::default()).unwrap();

        let deserialized_forest: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>> =
            bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();

        assert_eq!(forest, deserialized_forest);
    }
}