1use std::default::Default;
47use std::fmt::Debug;
48
49#[cfg(feature = "serde")]
50use serde::{Deserialize, Serialize};
51
52use crate::api::{Predictor, SupervisedEstimator};
53use crate::ensemble::base_forest_regressor::{BaseForestRegressor, BaseForestRegressorParameters};
54use crate::error::Failed;
55use crate::linalg::basic::arrays::{Array1, Array2};
56use crate::numbers::basenum::Number;
57use crate::numbers::floatnum::FloatNumber;
58use crate::tree::base_tree_regressor::Splitter;
59
/// Hyper-parameters of [`RandomForestRegressor`].
///
/// Every field is forwarded to the underlying base forest regressor when the model is
/// fitted. Build one with `RandomForestRegressorParameters::default()` plus the `with_*`
/// setters, or with struct-update syntax.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
// Container-level `default`: a field missing from serialized input is taken from
// `RandomForestRegressorParameters::default()` (e.g. `n_trees = 10`,
// `min_samples_split = 2`) rather than from the field type's zero value, which would
// otherwise yield settings such as `n_trees = 0`.
#[cfg_attr(feature = "serde", serde(default))]
#[derive(Debug, Clone)]
pub struct RandomForestRegressorParameters {
    /// Maximum depth of each tree (default: `None`, i.e. no explicit limit is passed on).
    pub max_depth: Option<u16>,
    /// Minimum number of samples per leaf, forwarded to the base forest (default: `1`).
    pub min_samples_leaf: usize,
    /// Minimum number of samples needed to split a node, forwarded to the base forest
    /// (default: `2`).
    pub min_samples_split: usize,
    /// Number of trees in the forest (default: `10`).
    pub n_trees: usize,
    /// Number of features considered per split; `None` leaves the choice to the base
    /// forest (default: `None`).
    pub m: Option<usize>,
    /// Whether the fitted forest keeps its per-tree sample bookkeeping. NOTE(review):
    /// presumably required by `RandomForestRegressor::predict_oob` — confirm in the base
    /// forest (default: `false`).
    pub keep_samples: bool,
    /// Seed for the random number generation used while building the forest (default: `0`).
    pub seed: u64,
}
87
88#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
90#[derive(Debug)]
91pub struct RandomForestRegressor<
92 TX: Number + FloatNumber + PartialOrd,
93 TY: Number,
94 X: Array2<TX>,
95 Y: Array1<TY>,
96> {
97 forest_regressor: Option<BaseForestRegressor<TX, TY, X, Y>>,
98}
99
100impl RandomForestRegressorParameters {
101 pub fn with_max_depth(mut self, max_depth: u16) -> Self {
103 self.max_depth = Some(max_depth);
104 self
105 }
106 pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
108 self.min_samples_leaf = min_samples_leaf;
109 self
110 }
111 pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
113 self.min_samples_split = min_samples_split;
114 self
115 }
116 pub fn with_n_trees(mut self, n_trees: usize) -> Self {
118 self.n_trees = n_trees;
119 self
120 }
121 pub fn with_m(mut self, m: usize) -> Self {
123 self.m = Some(m);
124 self
125 }
126
127 pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
129 self.keep_samples = keep_samples;
130 self
131 }
132
133 pub fn with_seed(mut self, seed: u64) -> Self {
135 self.seed = seed;
136 self
137 }
138}
139impl Default for RandomForestRegressorParameters {
140 fn default() -> Self {
141 RandomForestRegressorParameters {
142 max_depth: Option::None,
143 min_samples_leaf: 1,
144 min_samples_split: 2,
145 n_trees: 10,
146 m: Option::None,
147 keep_samples: false,
148 seed: 0,
149 }
150 }
151}
152
153impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
154 for RandomForestRegressor<TX, TY, X, Y>
155{
156 fn eq(&self, other: &Self) -> bool {
157 self.forest_regressor == other.forest_regressor
158 }
159}
160
161impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
162 SupervisedEstimator<X, Y, RandomForestRegressorParameters>
163 for RandomForestRegressor<TX, TY, X, Y>
164{
165 fn new() -> Self {
166 Self {
167 forest_regressor: Option::None,
168 }
169 }
170
171 fn fit(x: &X, y: &Y, parameters: RandomForestRegressorParameters) -> Result<Self, Failed> {
172 RandomForestRegressor::fit(x, y, parameters)
173 }
174}
175
176impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
177 Predictor<X, Y> for RandomForestRegressor<TX, TY, X, Y>
178{
179 fn predict(&self, x: &X) -> Result<Y, Failed> {
180 self.predict(x)
181 }
182}
183
/// Grid of candidate values for [`RandomForestRegressorParameters`].
///
/// Converting it into an iterator yields one parameter set for every combination
/// of the listed candidate values.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
// Container-level `default`: a field missing from serialized input falls back to
// `RandomForestRegressorSearchParameters::default()` (one default candidate per
// parameter) instead of an empty `Vec`, which would leave that parameter with no
// candidate value at all.
#[cfg_attr(feature = "serde", serde(default))]
#[derive(Debug, Clone)]
pub struct RandomForestRegressorSearchParameters {
    /// Candidate values for `max_depth`.
    pub max_depth: Vec<Option<u16>>,
    /// Candidate values for `min_samples_leaf`.
    pub min_samples_leaf: Vec<usize>,
    /// Candidate values for `min_samples_split`.
    pub min_samples_split: Vec<usize>,
    /// Candidate values for `n_trees`.
    pub n_trees: Vec<usize>,
    /// Candidate values for `m`.
    pub m: Vec<Option<usize>>,
    /// Candidate values for `keep_samples`.
    pub keep_samples: Vec<bool>,
    /// Candidate values for `seed`.
    pub seed: Vec<u64>,
}
210
211pub struct RandomForestRegressorSearchParametersIterator {
213 random_forest_regressor_search_parameters: RandomForestRegressorSearchParameters,
214 current_max_depth: usize,
215 current_min_samples_leaf: usize,
216 current_min_samples_split: usize,
217 current_n_trees: usize,
218 current_m: usize,
219 current_keep_samples: usize,
220 current_seed: usize,
221}
222
223impl IntoIterator for RandomForestRegressorSearchParameters {
224 type Item = RandomForestRegressorParameters;
225 type IntoIter = RandomForestRegressorSearchParametersIterator;
226
227 fn into_iter(self) -> Self::IntoIter {
228 RandomForestRegressorSearchParametersIterator {
229 random_forest_regressor_search_parameters: self,
230 current_max_depth: 0,
231 current_min_samples_leaf: 0,
232 current_min_samples_split: 0,
233 current_n_trees: 0,
234 current_m: 0,
235 current_keep_samples: 0,
236 current_seed: 0,
237 }
238 }
239}
240
241impl Iterator for RandomForestRegressorSearchParametersIterator {
242 type Item = RandomForestRegressorParameters;
243
244 fn next(&mut self) -> Option<Self::Item> {
245 if self.current_max_depth
246 == self
247 .random_forest_regressor_search_parameters
248 .max_depth
249 .len()
250 && self.current_min_samples_leaf
251 == self
252 .random_forest_regressor_search_parameters
253 .min_samples_leaf
254 .len()
255 && self.current_min_samples_split
256 == self
257 .random_forest_regressor_search_parameters
258 .min_samples_split
259 .len()
260 && self.current_n_trees == self.random_forest_regressor_search_parameters.n_trees.len()
261 && self.current_m == self.random_forest_regressor_search_parameters.m.len()
262 && self.current_keep_samples
263 == self
264 .random_forest_regressor_search_parameters
265 .keep_samples
266 .len()
267 && self.current_seed == self.random_forest_regressor_search_parameters.seed.len()
268 {
269 return None;
270 }
271
272 let next = RandomForestRegressorParameters {
273 max_depth: self.random_forest_regressor_search_parameters.max_depth
274 [self.current_max_depth],
275 min_samples_leaf: self
276 .random_forest_regressor_search_parameters
277 .min_samples_leaf[self.current_min_samples_leaf],
278 min_samples_split: self
279 .random_forest_regressor_search_parameters
280 .min_samples_split[self.current_min_samples_split],
281 n_trees: self.random_forest_regressor_search_parameters.n_trees[self.current_n_trees],
282 m: self.random_forest_regressor_search_parameters.m[self.current_m],
283 keep_samples: self.random_forest_regressor_search_parameters.keep_samples
284 [self.current_keep_samples],
285 seed: self.random_forest_regressor_search_parameters.seed[self.current_seed],
286 };
287
288 if self.current_max_depth + 1
289 < self
290 .random_forest_regressor_search_parameters
291 .max_depth
292 .len()
293 {
294 self.current_max_depth += 1;
295 } else if self.current_min_samples_leaf + 1
296 < self
297 .random_forest_regressor_search_parameters
298 .min_samples_leaf
299 .len()
300 {
301 self.current_max_depth = 0;
302 self.current_min_samples_leaf += 1;
303 } else if self.current_min_samples_split + 1
304 < self
305 .random_forest_regressor_search_parameters
306 .min_samples_split
307 .len()
308 {
309 self.current_max_depth = 0;
310 self.current_min_samples_leaf = 0;
311 self.current_min_samples_split += 1;
312 } else if self.current_n_trees + 1
313 < self.random_forest_regressor_search_parameters.n_trees.len()
314 {
315 self.current_max_depth = 0;
316 self.current_min_samples_leaf = 0;
317 self.current_min_samples_split = 0;
318 self.current_n_trees += 1;
319 } else if self.current_m + 1 < self.random_forest_regressor_search_parameters.m.len() {
320 self.current_max_depth = 0;
321 self.current_min_samples_leaf = 0;
322 self.current_min_samples_split = 0;
323 self.current_n_trees = 0;
324 self.current_m += 1;
325 } else if self.current_keep_samples + 1
326 < self
327 .random_forest_regressor_search_parameters
328 .keep_samples
329 .len()
330 {
331 self.current_max_depth = 0;
332 self.current_min_samples_leaf = 0;
333 self.current_min_samples_split = 0;
334 self.current_n_trees = 0;
335 self.current_m = 0;
336 self.current_keep_samples += 1;
337 } else if self.current_seed + 1 < self.random_forest_regressor_search_parameters.seed.len()
338 {
339 self.current_max_depth = 0;
340 self.current_min_samples_leaf = 0;
341 self.current_min_samples_split = 0;
342 self.current_n_trees = 0;
343 self.current_m = 0;
344 self.current_keep_samples = 0;
345 self.current_seed += 1;
346 } else {
347 self.current_max_depth += 1;
348 self.current_min_samples_leaf += 1;
349 self.current_min_samples_split += 1;
350 self.current_n_trees += 1;
351 self.current_m += 1;
352 self.current_keep_samples += 1;
353 self.current_seed += 1;
354 }
355
356 Some(next)
357 }
358}
359
360impl Default for RandomForestRegressorSearchParameters {
361 fn default() -> Self {
362 let default_params = RandomForestRegressorParameters::default();
363
364 RandomForestRegressorSearchParameters {
365 max_depth: vec![default_params.max_depth],
366 min_samples_leaf: vec![default_params.min_samples_leaf],
367 min_samples_split: vec![default_params.min_samples_split],
368 n_trees: vec![default_params.n_trees],
369 m: vec![default_params.m],
370 keep_samples: vec![default_params.keep_samples],
371 seed: vec![default_params.seed],
372 }
373 }
374}
375
376impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
377 RandomForestRegressor<TX, TY, X, Y>
378{
379 pub fn fit(
383 x: &X,
384 y: &Y,
385 parameters: RandomForestRegressorParameters,
386 ) -> Result<RandomForestRegressor<TX, TY, X, Y>, Failed> {
387 let regressor_params = BaseForestRegressorParameters {
388 max_depth: parameters.max_depth,
389 min_samples_leaf: parameters.min_samples_leaf,
390 min_samples_split: parameters.min_samples_split,
391 n_trees: parameters.n_trees,
392 m: parameters.m,
393 keep_samples: parameters.keep_samples,
394 seed: parameters.seed,
395 bootstrap: true,
396 splitter: Splitter::Best,
397 };
398 let forest_regressor = BaseForestRegressor::fit(x, y, regressor_params)?;
399
400 Ok(RandomForestRegressor {
401 forest_regressor: Some(forest_regressor),
402 })
403 }
404
405 pub fn predict(&self, x: &X) -> Result<Y, Failed> {
408 let forest_regressor = self.forest_regressor.as_ref().unwrap();
409 forest_regressor.predict(x)
410 }
411
412 pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
414 let forest_regressor = self.forest_regressor.as_ref().unwrap();
415 forest_regressor.predict_oob(x)
416 }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::basic::matrix::DenseMatrix;
    use crate::metrics::mean_absolute_error;

    /// Predictor matrix of the Longley data set (16 rows x 6 columns).
    fn longley_x() -> DenseMatrix<f64> {
        DenseMatrix::from_2d_array(&[
            &[234.289, 235.6, 159., 107.608, 1947., 60.323],
            &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
            &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
            &[284.599, 335.1, 165., 110.929, 1950., 61.187],
            &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
            &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
            &[365.385, 187., 354.7, 115.094, 1953., 64.989],
            &[363.112, 357.8, 335., 116.219, 1954., 63.761],
            &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
            &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
            &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
            &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
            &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
            &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
            &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
            &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
        ])
        .unwrap()
    }

    /// Response values matching the 16 rows of `longley_x`.
    fn longley_y() -> Vec<f64> {
        vec![
            83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
            114.2, 115.7, 116.9,
        ]
    }

    /// Hyper-parameters shared by the fitting tests: 1000 trees, no explicit depth
    /// limit, seed 87.
    fn thousand_trees(keep_samples: bool) -> RandomForestRegressorParameters {
        RandomForestRegressorParameters {
            max_depth: None,
            min_samples_leaf: 1,
            min_samples_split: 2,
            n_trees: 1000,
            m: None,
            keep_samples,
            seed: 87,
        }
    }

    #[test]
    fn search_parameters() {
        let grid = RandomForestRegressorSearchParameters {
            n_trees: vec![10, 100],
            m: vec![None, Some(1)],
            ..Default::default()
        };
        // `n_trees` varies faster than `m`.
        let visited: Vec<(usize, Option<usize>)> = grid
            .into_iter()
            .map(|params| (params.n_trees, params.m))
            .collect();
        assert_eq!(
            visited,
            vec![(10, None), (100, None), (10, Some(1)), (100, Some(1))]
        );
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn fit_longley() {
        let x = longley_x();
        let y = longley_y();

        let y_hat = RandomForestRegressor::fit(&x, &y, thousand_trees(false))
            .and_then(|rf| rf.predict(&x))
            .unwrap();

        assert!(mean_absolute_error(&y, &y_hat) < 1.0);
    }

    #[test]
    fn test_random_matrix_with_wrong_rownum() {
        // 17 rows of features against 16 targets must be rejected.
        let x_rand: DenseMatrix<f64> = DenseMatrix::<f64>::rand(17, 200);
        let y = longley_y();

        let fail = RandomForestRegressor::fit(&x_rand, &y, thousand_trees(false));

        assert!(fail.is_err());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn fit_predict_longley_oob() {
        let x = longley_x();
        let y = longley_y();

        let regressor = RandomForestRegressor::fit(&x, &y, thousand_trees(true)).unwrap();

        let y_hat = regressor.predict(&x).unwrap();
        let y_hat_oob = regressor.predict_oob(&x).unwrap();

        let in_bag_error = mean_absolute_error(&y, &y_hat);
        let oob_error = mean_absolute_error(&y, &y_hat_oob);
        println!("{:?}", in_bag_error);
        println!("{:?}", oob_error);

        assert!(in_bag_error < oob_error);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    #[cfg(feature = "serde")]
    fn serde() {
        let x = longley_x();
        let y = longley_y();

        let forest = RandomForestRegressor::fit(&x, &y, Default::default()).unwrap();

        let bytes = bincode::serialize(&forest).unwrap();
        let deserialized_forest: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>> =
            bincode::deserialize(&bytes).unwrap();

        assert_eq!(forest, deserialized_forest);
    }
}