1use std::collections::LinkedList;
2use std::default::Default;
3use std::fmt::Debug;
4use std::marker::PhantomData;
5
6use rand::seq::SliceRandom;
7use rand::Rng;
8
9#[cfg(feature = "serde")]
10use serde::{Deserialize, Serialize};
11
12use crate::error::Failed;
13use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1};
14use crate::numbers::basenum::Number;
15use crate::rand_custom::get_rng_impl;
16
/// Strategy used to choose the split threshold at each tree node.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Default)]
pub enum Splitter {
    /// Draw the threshold uniformly at random between the feature's
    /// minimum and maximum value among the node's samples.
    Random,
    /// Scan every candidate cut point and keep the one with the best gain.
    #[default]
    Best,
}
24
/// Hyperparameters controlling how a regression tree is grown.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct BaseTreeRegressorParameters {
    /// Maximum depth of the tree; `None` means growth is unbounded
    /// (capped internally at `u16::MAX`).
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_depth: Option<u16>,
    /// Minimum (weighted) number of samples required in each child leaf.
    #[cfg_attr(feature = "serde", serde(default))]
    pub min_samples_leaf: usize,
    /// Minimum (weighted) number of samples required to consider splitting a node.
    #[cfg_attr(feature = "serde", serde(default))]
    pub min_samples_split: usize,
    /// Seed for the random number generator; `None` uses the default RNG.
    #[cfg_attr(feature = "serde", serde(default))]
    pub seed: Option<u64>,
    /// Split-selection strategy (`Best` by default).
    #[cfg_attr(feature = "serde", serde(default))]
    pub splitter: Splitter,
}
45
/// A fitted regression tree.
///
/// Nodes are stored in a flat `Vec` and refer to their children by index,
/// with the root at index 0.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct BaseTreeRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
    // Flat node storage; children are referenced by index into this vector.
    nodes: Vec<Node>,
    // Parameters used during fitting; `Some` once the tree has been fitted.
    parameters: Option<BaseTreeRegressorParameters>,
    // Depth reached while growing the tree.
    depth: u16,
    _phantom_tx: PhantomData<TX>,
    _phantom_ty: PhantomData<TY>,
    _phantom_x: PhantomData<X>,
    _phantom_y: PhantomData<Y>,
}
58
59impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
60 BaseTreeRegressor<TX, TY, X, Y>
61{
62 fn nodes(&self) -> &Vec<Node> {
64 self.nodes.as_ref()
65 }
66 fn parameters(&self) -> &BaseTreeRegressorParameters {
68 self.parameters.as_ref().unwrap()
69 }
70 fn depth(&self) -> u16 {
72 self.depth
73 }
74}
75
/// A single tree node: a predicted output plus an optional split description.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
struct Node {
    // Predicted value for samples that reach this node.
    output: f64,
    // Index of the feature this node splits on (0 when unsplit).
    split_feature: usize,
    // Threshold of the split; `None` for a leaf.
    split_value: Option<f64>,
    // Gain score of the chosen split; `None` for a leaf.
    split_score: Option<f64>,
    // Child taken when `feature value <= split_value`.
    true_child: Option<usize>,
    // Child taken otherwise.
    false_child: Option<usize>,
}

impl Node {
    /// Create a leaf that predicts `output`, with no split attached yet.
    fn new(output: f64) -> Self {
        Node {
            output,
            split_feature: 0,
            split_value: None,
            split_score: None,
            true_child: None,
            false_child: None,
        }
    }
}
99
100impl PartialEq for Node {
101 fn eq(&self, other: &Self) -> bool {
102 (self.output - other.output).abs() < f64::EPSILON
103 && self.split_feature == other.split_feature
104 && match (self.split_value, other.split_value) {
105 (Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
106 (None, None) => true,
107 _ => false,
108 }
109 && match (self.split_score, other.split_score) {
110 (Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
111 (None, None) => true,
112 _ => false,
113 }
114 }
115}
116
117impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
118 for BaseTreeRegressor<TX, TY, X, Y>
119{
120 fn eq(&self, other: &Self) -> bool {
121 if self.depth != other.depth || self.nodes().len() != other.nodes().len() {
122 false
123 } else {
124 self.nodes()
125 .iter()
126 .zip(other.nodes().iter())
127 .all(|(a, b)| a == b)
128 }
129 }
130}
131
/// Work item used while growing the tree: one node together with the
/// (weighted) subset of samples that reached it.
struct NodeVisitor<'a, TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
    // Training features.
    x: &'a X,
    // Training targets.
    y: &'a Y,
    // Index of the node being processed, into `BaseTreeRegressor::nodes`.
    node: usize,
    // Per-row sample weights; 0 means the row does not belong to this node.
    samples: Vec<usize>,
    // Per-feature argsort of the training columns (built once in
    // `fit_weak_learner`), used to scan candidate cut points in order.
    order: &'a [Vec<usize>],
    // Mean target of the true (<=) side of the chosen split; filled in by
    // the split search, initially 0.
    true_child_output: f64,
    // Mean target of the false (>) side of the chosen split.
    false_child_output: f64,
    // Depth of this node (root visitor starts at level 1).
    level: u16,
    _phantom_tx: PhantomData<TX>,
    _phantom_ty: PhantomData<TY>,
}
144
145impl<'a, TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
146 NodeVisitor<'a, TX, TY, X, Y>
147{
148 fn new(
149 node_id: usize,
150 samples: Vec<usize>,
151 order: &'a [Vec<usize>],
152 x: &'a X,
153 y: &'a Y,
154 level: u16,
155 ) -> Self {
156 NodeVisitor {
157 x,
158 y,
159 node: node_id,
160 samples,
161 order,
162 true_child_output: 0f64,
163 false_child_output: 0f64,
164 level,
165 _phantom_tx: PhantomData,
166 _phantom_ty: PhantomData,
167 }
168 }
169}
170
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    BaseTreeRegressor<TX, TY, X, Y>
{
    /// Fit a regression tree to `x` (rows = samples) and `y`.
    ///
    /// Returns `Failed::fit` when the number of rows in `x` does not match
    /// the length of `y`. Every sample gets weight 1 and all attributes are
    /// considered at each split (mtry = number of attributes).
    pub fn fit(
        x: &X,
        y: &Y,
        parameters: BaseTreeRegressorParameters,
    ) -> Result<BaseTreeRegressor<TX, TY, X, Y>, Failed> {
        let (x_nrows, num_attributes) = x.shape();
        if x_nrows != y.shape() {
            return Err(Failed::fit("Size of x should equal size of y"));
        }

        // Unit weight for every training row.
        let samples = vec![1; x_nrows];
        BaseTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
    }

    /// Grow a tree using per-row integer weights (`samples`), considering at
    /// most `mtry` attributes per split. Used directly by ensemble learners.
    pub(crate) fn fit_weak_learner(
        x: &X,
        y: &Y,
        samples: Vec<usize>,
        mtry: usize,
        parameters: BaseTreeRegressorParameters,
    ) -> Result<BaseTreeRegressor<TX, TY, X, Y>, Failed> {
        let y_m = y.clone();

        let y_ncols = y_m.shape();
        let (_, num_attributes) = x.shape();

        let mut nodes: Vec<Node> = Vec::new();
        let mut rng = get_rng_impl(parameters.seed);

        // Weighted mean of y over all (weighted) samples -> root prediction.
        let mut n = 0;
        let mut sum = 0f64;
        for (i, sample_i) in samples.iter().enumerate().take(y_ncols) {
            n += *sample_i;
            sum += *sample_i as f64 * y_m.get(i).to_f64().unwrap();
        }

        let root = Node::new(sum / (n as f64));
        nodes.push(root);
        // Argsort every column once; split searches walk these orderings
        // instead of re-sorting per node.
        let mut order: Vec<Vec<usize>> = Vec::new();

        for i in 0..num_attributes {
            let mut col_i: Vec<TX> = x.get_col(i).iterator(0).copied().collect();
            order.push(col_i.argsort_mut());
        }

        let mut base_tree = BaseTreeRegressor {
            nodes,
            parameters: Some(parameters),
            depth: 0u16,
            _phantom_tx: PhantomData,
            _phantom_ty: PhantomData,
            _phantom_x: PhantomData,
            _phantom_y: PhantomData,
        };

        // Root visitor starts at level 1.
        let mut visitor = NodeVisitor::<TX, TY, X, Y>::new(0, samples, &order, x, &y_m, 1);

        let mut visitor_queue: LinkedList<NodeVisitor<'_, TX, TY, X, Y>> = LinkedList::new();

        // Only enqueue the root if a viable split was found for it.
        if base_tree.find_best_cutoff(&mut visitor, mtry, &mut rng) {
            visitor_queue.push_back(visitor);
        }

        // Breadth-first growth, capped by max_depth (unbounded => u16::MAX).
        while base_tree.depth() < base_tree.parameters().max_depth.unwrap_or(u16::MAX) {
            match visitor_queue.pop_front() {
                Some(node) => base_tree.split(node, mtry, &mut visitor_queue, &mut rng),
                None => break,
            };
        }

        Ok(base_tree)
    }

    /// Predict a target value for every row of `x`.
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
        let mut result = Y::zeros(x.shape().0);

        let (n, _) = x.shape();

        for i in 0..n {
            result.set(i, self.predict_for_row(x, i));
        }

        Ok(result)
    }

    /// Walk the tree from the root for a single row and return the output of
    /// the leaf reached. Values `<= split_value` go to the true child.
    /// A missing split value yields NaN, for which `<=` is always false, so
    /// traversal falls through to the false child in that case.
    pub(crate) fn predict_for_row(&self, x: &X, row: usize) -> TY {
        let mut result = 0f64;
        let mut queue: LinkedList<usize> = LinkedList::new();

        queue.push_back(0);

        while !queue.is_empty() {
            match queue.pop_front() {
                Some(node_id) => {
                    let node = &self.nodes()[node_id];
                    // A node with no children is a leaf: take its output.
                    if node.true_child.is_none() && node.false_child.is_none() {
                        result = node.output;
                    } else if x.get((row, node.split_feature)).to_f64().unwrap()
                        <= node.split_value.unwrap_or(f64::NAN)
                    {
                        queue.push_back(node.true_child.unwrap());
                    } else {
                        queue.push_back(node.false_child.unwrap());
                    }
                }
                None => break,
            };
        }

        TY::from_f64(result).unwrap()
    }

    /// Search for the best split of the visitor's node across (up to) `mtry`
    /// attributes, writing the winner into the node. Returns `true` when a
    /// split with a score was found.
    fn find_best_cutoff(
        &mut self,
        visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
        mtry: usize,
        rng: &mut impl Rng,
    ) -> bool {
        let (_, n_attr) = visitor.x.shape();

        // Total weighted sample count at this node.
        let n: usize = visitor.samples.iter().sum();

        if n < self.parameters().min_samples_split {
            return false;
        }

        // node.output is the weighted mean, so sum = mean * n.
        let sum = self.nodes()[visitor.node].output * n as f64;

        let mut variables = (0..n_attr).collect::<Vec<_>>();

        // Random feature subsampling (ensemble mode): shuffle, then take mtry.
        if mtry < n_attr {
            variables.shuffle(rng);
        }

        // Gain of the unsplit node: n * mean^2 (constant term of the SSE
        // reduction); candidate splits are scored relative to it.
        let parent_gain =
            n as f64 * self.nodes()[visitor.node].output * self.nodes()[visitor.node].output;

        let splitter = self.parameters().splitter.clone();

        for variable in variables.iter().take(mtry) {
            match splitter {
                Splitter::Random => {
                    self.find_random_split(visitor, n, sum, parent_gain, *variable, rng);
                }
                Splitter::Best => {
                    self.find_best_split(visitor, n, sum, parent_gain, *variable);
                }
            }
        }

        self.nodes()[visitor.node].split_score.is_some()
    }

    /// Randomized split search on feature `j`: draw the threshold uniformly
    /// between the feature's min and max over the node's active samples, then
    /// score it and keep it if it beats the node's current best.
    fn find_random_split(
        &mut self,
        visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
        n: usize,
        sum: f64,
        parent_gain: f64,
        j: usize,
        rng: &mut impl Rng,
    ) {
        // order[j] is sorted ascending by feature j, so the first/last active
        // sample give the min/max feature value at this node.
        let (min_val, max_val) = {
            let mut min_opt = None;
            let mut max_opt = None;
            for &i in &visitor.order[j] {
                if visitor.samples[i] > 0 {
                    min_opt = Some(*visitor.x.get((i, j)));
                    break;
                }
            }
            for &i in visitor.order[j].iter().rev() {
                if visitor.samples[i] > 0 {
                    max_opt = Some(*visitor.x.get((i, j)));
                    break;
                }
            }
            // No active sample at all: nothing to split.
            if min_opt.is_none() {
                return;
            }
            (min_opt.unwrap(), max_opt.unwrap())
        };

        // Constant feature on this node: no usable threshold.
        if min_val >= max_val {
            return;
        }

        let split_value = rng.gen_range(min_val.to_f64().unwrap()..max_val.to_f64().unwrap());

        // Accumulate the true (<= threshold) side; the sorted order lets us
        // stop at the first active sample above the threshold.
        let mut true_sum = 0f64;
        let mut true_count = 0;
        for &i in &visitor.order[j] {
            if visitor.samples[i] > 0 {
                if visitor.x.get((i, j)).to_f64().unwrap() <= split_value {
                    true_sum += visitor.samples[i] as f64 * visitor.y.get(i).to_f64().unwrap();
                    true_count += visitor.samples[i];
                } else {
                    break;
                }
            }
        }

        let false_count = n - true_count;

        // Enforce the minimum leaf size on both children.
        if true_count < self.parameters().min_samples_leaf
            || false_count < self.parameters().min_samples_leaf
        {
            return;
        }

        let true_mean = if true_count > 0 {
            true_sum / true_count as f64
        } else {
            0.0
        };
        let false_mean = if false_count > 0 {
            (sum - true_sum) / false_count as f64
        } else {
            0.0
        };
        // Variance-reduction style gain: sum of n_k * mean_k^2 minus the
        // parent's n * mean^2.
        let gain = (true_count as f64 * true_mean * true_mean
            + false_count as f64 * false_mean * false_mean)
            - parent_gain;

        // Keep this split if it beats the node's current best (if any).
        if self.nodes[visitor.node].split_score.is_none()
            || gain > self.nodes[visitor.node].split_score.unwrap()
        {
            self.nodes[visitor.node].split_feature = j;
            self.nodes[visitor.node].split_value = Some(split_value);
            self.nodes[visitor.node].split_score = Some(gain);
            visitor.true_child_output = true_mean;
            visitor.false_child_output = false_mean;
        }
    }

    /// Exhaustive split search on feature `j`: walk the samples in sorted
    /// order and evaluate a cut between every pair of distinct adjacent
    /// feature values, keeping the best-scoring one.
    fn find_best_split(
        &mut self,
        visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
        n: usize,
        sum: f64,
        parent_gain: f64,
        j: usize,
    ) {
        // Running totals of the true (left, <=) side as the scan advances.
        let mut true_sum = 0f64;
        let mut true_count = 0;
        let mut prevx = Option::None;

        for i in visitor.order[j].iter() {
            if visitor.samples[*i] > 0 {
                let x_ij = *visitor.x.get((*i, j));

                // Same feature value as before (or first sample): no cut
                // possible here, just accumulate and move on.
                if prevx.is_none() || x_ij == prevx.unwrap() {
                    prevx = Some(x_ij);
                    true_count += visitor.samples[*i];
                    true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
                    continue;
                }

                let false_count = n - true_count;

                // Cut would violate the minimum leaf size: skip it but keep
                // accumulating.
                if true_count < self.parameters().min_samples_leaf
                    || false_count < self.parameters().min_samples_leaf
                {
                    prevx = Some(x_ij);
                    true_count += visitor.samples[*i];
                    true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
                    continue;
                }

                let true_mean = true_sum / true_count as f64;
                let false_mean = (sum - true_sum) / false_count as f64;

                // Same gain formula as the random splitter.
                let gain = (true_count as f64 * true_mean * true_mean
                    + false_count as f64 * false_mean * false_mean)
                    - parent_gain;

                if self.nodes()[visitor.node].split_score.is_none()
                    || gain > self.nodes()[visitor.node].split_score.unwrap()
                {
                    self.nodes[visitor.node].split_feature = j;
                    // Threshold is the midpoint of the two distinct values.
                    self.nodes[visitor.node].split_value =
                        Option::Some((x_ij + prevx.unwrap()).to_f64().unwrap() / 2f64);
                    self.nodes[visitor.node].split_score = Option::Some(gain);

                    visitor.true_child_output = true_mean;
                    visitor.false_child_output = false_mean;
                }

                prevx = Some(x_ij);
                true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
                true_count += visitor.samples[*i];
            }
        }
    }

    /// Apply the split stored on the visitor's node: partition its samples,
    /// create the two children, and enqueue child visitors whose own splits
    /// were found. Rolls the node back to a leaf and returns `false` when the
    /// realized partition violates the minimum leaf size.
    fn split<'a>(
        &mut self,
        mut visitor: NodeVisitor<'a, TX, TY, X, Y>,
        mtry: usize,
        visitor_queue: &mut LinkedList<NodeVisitor<'a, TX, TY, X, Y>>,
        rng: &mut impl Rng,
    ) -> bool {
        let (n, _) = visitor.x.shape();
        let mut tc = 0;
        let mut fc = 0;
        let mut true_samples: Vec<usize> = vec![0; n];

        // Move each active sample's weight to the true side when it falls at
        // or below the threshold; weights left in visitor.samples become the
        // false side.
        for (i, true_sample) in true_samples.iter_mut().enumerate().take(n) {
            if visitor.samples[i] > 0 {
                if visitor
                    .x
                    .get((i, self.nodes()[visitor.node].split_feature))
                    .to_f64()
                    .unwrap()
                    <= self.nodes()[visitor.node].split_value.unwrap_or(f64::NAN)
                {
                    *true_sample = visitor.samples[i];
                    tc += *true_sample;
                    visitor.samples[i] = 0;
                } else {
                    fc += visitor.samples[i];
                }
            }
        }

        // Realized partition too small on either side: undo the split and
        // leave the node as a leaf.
        if tc < self.parameters().min_samples_leaf || fc < self.parameters().min_samples_leaf {
            self.nodes[visitor.node].split_feature = 0;
            self.nodes[visitor.node].split_value = Option::None;
            self.nodes[visitor.node].split_score = Option::None;

            return false;
        }

        // Append the two children and wire them to the parent by index.
        let true_child_idx = self.nodes().len();

        self.nodes.push(Node::new(visitor.true_child_output));
        let false_child_idx = self.nodes().len();
        self.nodes.push(Node::new(visitor.false_child_output));

        self.nodes[visitor.node].true_child = Some(true_child_idx);
        self.nodes[visitor.node].false_child = Some(false_child_idx);

        self.depth = u16::max(self.depth, visitor.level + 1);

        let mut true_visitor = NodeVisitor::<TX, TY, X, Y>::new(
            true_child_idx,
            true_samples,
            visitor.order,
            visitor.x,
            visitor.y,
            visitor.level + 1,
        );

        if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
            visitor_queue.push_back(true_visitor);
        }

        let mut false_visitor = NodeVisitor::<TX, TY, X, Y>::new(
            false_child_idx,
            visitor.samples,
            visitor.order,
            visitor.x,
            visitor.y,
            visitor.level + 1,
        );

        if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
            visitor_queue.push_back(false_visitor);
        }

        true
    }
}