1use anofox_ml_core::{Fit, Predict, Result, RustMlError};
8use ndarray::{Array1, Array2};
9
/// Upper bound on the number of histogram bins per feature.
///
/// Bin indices are stored as `u8`, so at most `u8::MAX + 1` (= 256) distinct bins are representable.
const MAX_BINS: usize = u8::MAX as usize + 1;
11
12#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
18struct FeatureBins {
19 edges: Vec<f64>,
21}
22
23fn compute_bins(x: &Array2<f64>, max_bins: usize) -> (Array2<u8>, Vec<FeatureBins>) {
25 let n = x.nrows();
26 let p = x.ncols();
27 let mut binned = Array2::zeros((n, p));
28 let mut all_bins = Vec::with_capacity(p);
29
30 for j in 0..p {
31 let mut col: Vec<f64> = (0..n).map(|i| x[[i, j]]).collect();
32 col.sort_by(|a, b| a.partial_cmp(b).unwrap());
33 col.dedup();
34
35 let n_edges = (col.len()).min(max_bins - 1);
37 let mut edges = Vec::with_capacity(n_edges);
38 for k in 1..=n_edges {
39 let idx = (k * col.len() / (n_edges + 1)).min(col.len() - 1);
40 let edge = col[idx];
41 if edges.last().map_or(true, |&last: &f64| edge > last) {
42 edges.push(edge);
43 }
44 }
45
46 for i in 0..n {
48 let v = x[[i, j]];
49 let bin = edges.partition_point(|&e| e < v) as u8;
50 binned[[i, j]] = bin;
51 }
52
53 all_bins.push(FeatureBins { edges });
54 }
55
56 (binned, all_bins)
57}
58
59fn bin_row(row: &[f64], all_bins: &[FeatureBins]) -> Vec<u8> {
61 row.iter()
62 .zip(all_bins.iter())
63 .map(|(&v, bins)| bins.edges.partition_point(|&e| e < v) as u8)
64 .collect()
65}
66
/// Per-bin accumulators for one feature: summed gradients, summed hessians, and sample counts.
#[derive(Clone)]
struct Histogram {
    /// Sum of gradients of the samples falling in each bin.
    grad_sum: Vec<f64>,
    /// Sum of hessians of the samples falling in each bin.
    hess_sum: Vec<f64>,
    /// Number of samples falling in each bin.
    count: Vec<u32>,
}

impl Histogram {
    /// Allocates a zeroed histogram with `n_bins` slots.
    fn new(n_bins: usize) -> Self {
        let zeros = vec![0.0; n_bins];
        Self {
            grad_sum: zeros.clone(),
            hess_sum: zeros,
            count: vec![0; n_bins],
        }
    }

    /// Zeroes every slot in place, keeping the allocations for reuse.
    fn reset(&mut self) {
        for sums in [&mut self.grad_sum, &mut self.hess_sum] {
            sums.iter_mut().for_each(|s| *s = 0.0);
        }
        self.count.iter_mut().for_each(|c| *c = 0);
    }
}
97
/// Best split found for a node: "go left when `bin <= bin_threshold` on `feature`".
// The child values and counts are recorded for inspection but not read by the tree builder,
// hence the dead_code allowance.
#[allow(dead_code)]
struct HistSplit {
    feature: usize,
    bin_threshold: u8,
    /// Reduction in the regularized loss achieved by the split.
    gain: f64,
    left_count: usize,
    right_count: usize,
    /// Newton-step leaf value `-G / (H + lambda)` for each child.
    left_value: f64,
    right_value: f64,
}
109
110fn find_best_hist_split(
112 binned_x: &Array2<u8>,
113 gradients: &[f64],
114 hessians: &[f64],
115 indices: &[usize],
116 n_features: usize,
117 min_samples_leaf: usize,
118 l2_regularization: f64,
119) -> Option<HistSplit> {
120 let n_bins = MAX_BINS;
121 let mut best: Option<HistSplit> = None;
122 let mut hist = Histogram::new(n_bins);
123
124 let total_grad: f64 = indices.iter().map(|&i| gradients[i]).sum();
126 let total_hess: f64 = indices.iter().map(|&i| hessians[i]).sum();
127 let total_count = indices.len();
128
129 for feat in 0..n_features {
130 hist.reset();
131
132 for &i in indices {
134 let bin = binned_x[[i, feat]] as usize;
135 hist.grad_sum[bin] += gradients[i];
136 hist.hess_sum[bin] += hessians[i];
137 hist.count[bin] += 1;
138 }
139
140 let mut left_grad = 0.0;
142 let mut left_hess = 0.0;
143 let mut left_count: usize = 0;
144
145 for bin in 0..(n_bins - 1) {
146 left_grad += hist.grad_sum[bin];
147 left_hess += hist.hess_sum[bin];
148 left_count += hist.count[bin] as usize;
149
150 if left_count < min_samples_leaf {
151 continue;
152 }
153 let right_count = total_count - left_count;
154 if right_count < min_samples_leaf {
155 break;
156 }
157
158 let right_grad = total_grad - left_grad;
159 let right_hess = total_hess - left_hess;
160
161 let reg = l2_regularization;
164 let parent_term = total_grad * total_grad / (total_hess + reg);
165 let left_term = left_grad * left_grad / (left_hess + reg);
166 let right_term = right_grad * right_grad / (right_hess + reg);
167 let gain = 0.5 * (left_term + right_term - parent_term);
168
169 if gain > best.as_ref().map_or(0.0, |b| b.gain) {
170 best = Some(HistSplit {
171 feature: feat,
172 bin_threshold: bin as u8,
173 gain,
174 left_value: -left_grad / (left_hess + reg),
175 right_value: -right_grad / (right_hess + reg),
176 left_count,
177 right_count,
178 });
179 }
180 }
181 }
182
183 best
184}
185
186#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
191enum HistNode {
192 Leaf {
193 value: f64,
194 },
195 Internal {
196 feature: usize,
197 bin_threshold: u8,
198 left: Box<HistNode>,
199 right: Box<HistNode>,
200 },
201}
202
203impl HistNode {
204 fn predict_binned(&self, bins: &[u8]) -> f64 {
205 match self {
206 HistNode::Leaf { value } => *value,
207 HistNode::Internal {
208 feature,
209 bin_threshold,
210 left,
211 right,
212 } => {
213 if bins[*feature] <= *bin_threshold {
214 left.predict_binned(bins)
215 } else {
216 right.predict_binned(bins)
217 }
218 }
219 }
220 }
221}
222
223fn build_hist_tree(
224 binned_x: &Array2<u8>,
225 gradients: &[f64],
226 hessians: &[f64],
227 indices: &[usize],
228 max_depth: usize,
229 min_samples_leaf: usize,
230 l2_regularization: f64,
231 depth: usize,
232) -> HistNode {
233 if depth >= max_depth || indices.len() < 2 * min_samples_leaf {
235 let g: f64 = indices.iter().map(|&i| gradients[i]).sum();
236 let h: f64 = indices.iter().map(|&i| hessians[i]).sum();
237 return HistNode::Leaf {
238 value: -g / (h + l2_regularization),
239 };
240 }
241
242 let n_features = binned_x.ncols();
243 let split = find_best_hist_split(
244 binned_x,
245 gradients,
246 hessians,
247 indices,
248 n_features,
249 min_samples_leaf,
250 l2_regularization,
251 );
252
253 match split {
254 None => {
255 let g: f64 = indices.iter().map(|&i| gradients[i]).sum();
256 let h: f64 = indices.iter().map(|&i| hessians[i]).sum();
257 HistNode::Leaf {
258 value: -g / (h + l2_regularization),
259 }
260 }
261 Some(s) => {
262 let (left_idx, right_idx): (Vec<usize>, Vec<usize>) = indices
263 .iter()
264 .partition(|&&i| binned_x[[i, s.feature]] <= s.bin_threshold);
265
266 let left = build_hist_tree(
267 binned_x,
268 gradients,
269 hessians,
270 &left_idx,
271 max_depth,
272 min_samples_leaf,
273 l2_regularization,
274 depth + 1,
275 );
276 let right = build_hist_tree(
277 binned_x,
278 gradients,
279 hessians,
280 &right_idx,
281 max_depth,
282 min_samples_leaf,
283 l2_regularization,
284 depth + 1,
285 );
286
287 HistNode::Internal {
288 feature: s.feature,
289 bin_threshold: s.bin_threshold,
290 left: Box::new(left),
291 right: Box::new(right),
292 }
293 }
294 }
295}
296
297#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
306pub struct HistGradientBoostingRegressor {
307 pub n_estimators: usize,
308 pub learning_rate: f64,
309 pub max_depth: usize,
310 pub min_samples_leaf: usize,
311 pub l2_regularization: f64,
312 pub max_bins: usize,
313}
314
315impl HistGradientBoostingRegressor {
316 pub fn new() -> Self {
317 Self {
318 n_estimators: 100,
319 learning_rate: 0.1,
320 max_depth: 6,
321 min_samples_leaf: 20,
322 l2_regularization: 0.0,
323 max_bins: MAX_BINS,
324 }
325 }
326
327 pub fn with_n_estimators(mut self, n: usize) -> Self {
328 self.n_estimators = n;
329 self
330 }
331 pub fn with_learning_rate(mut self, lr: f64) -> Self {
332 self.learning_rate = lr;
333 self
334 }
335 pub fn with_max_depth(mut self, d: usize) -> Self {
336 self.max_depth = d;
337 self
338 }
339 pub fn with_min_samples_leaf(mut self, m: usize) -> Self {
340 self.min_samples_leaf = m;
341 self
342 }
343 pub fn with_l2_regularization(mut self, l2: f64) -> Self {
344 self.l2_regularization = l2;
345 self
346 }
347 pub fn with_max_bins(mut self, b: usize) -> Self {
348 self.max_bins = b;
349 self
350 }
351}
352
353impl Default for HistGradientBoostingRegressor {
354 fn default() -> Self {
355 Self::new()
356 }
357}
358
359#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
361pub struct FittedHistGradientBoostingRegressor {
362 trees: Vec<HistNode>,
363 bins: Vec<FeatureBins>,
364 baseline: f64,
365 learning_rate: f64,
366 n_features: usize,
367}
368
369impl FittedHistGradientBoostingRegressor {
370 pub fn n_estimators(&self) -> usize {
371 self.trees.len()
372 }
373}
374
375impl Fit<f64> for HistGradientBoostingRegressor {
376 type Fitted = FittedHistGradientBoostingRegressor;
377
378 fn fit(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Self::Fitted> {
379 if x.nrows() != y.len() {
380 return Err(RustMlError::ShapeMismatch(format!(
381 "X has {} rows but y has {} elements",
382 x.nrows(),
383 y.len()
384 )));
385 }
386 if x.is_empty() {
387 return Err(RustMlError::EmptyInput("training data is empty".into()));
388 }
389
390 let n = x.nrows();
391 let (binned_x, bins) = compute_bins(x, self.max_bins);
392
393 let baseline: f64 = y.iter().sum::<f64>() / n as f64;
395 let mut predictions = vec![baseline; n];
396 let mut trees = Vec::with_capacity(self.n_estimators);
397
398 let indices: Vec<usize> = (0..n).collect();
399
400 for _ in 0..self.n_estimators {
401 let gradients: Vec<f64> = (0..n).map(|i| predictions[i] - y[i]).collect();
403 let hessians = vec![1.0; n];
404
405 let tree = build_hist_tree(
406 &binned_x,
407 &gradients,
408 &hessians,
409 &indices,
410 self.max_depth,
411 self.min_samples_leaf,
412 self.l2_regularization,
413 0,
414 );
415
416 for i in 0..n {
418 let row_bins: Vec<u8> = (0..x.ncols()).map(|j| binned_x[[i, j]]).collect();
419 predictions[i] += self.learning_rate * tree.predict_binned(&row_bins);
420 }
421
422 trees.push(tree);
423 }
424
425 Ok(FittedHistGradientBoostingRegressor {
426 trees,
427 bins,
428 baseline,
429 learning_rate: self.learning_rate,
430 n_features: x.ncols(),
431 })
432 }
433}
434
435impl Predict<f64> for FittedHistGradientBoostingRegressor {
436 fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
437 if x.ncols() != self.n_features {
438 return Err(RustMlError::ShapeMismatch(format!(
439 "expected {} features, got {}",
440 self.n_features,
441 x.ncols()
442 )));
443 }
444
445 let n = x.nrows();
446 let mut preds = Array1::from_elem(n, self.baseline);
447
448 for i in 0..n {
449 let row: Vec<f64> = (0..self.n_features).map(|j| x[[i, j]]).collect();
450 let bins = bin_row(&row, &self.bins);
451 for tree in &self.trees {
452 preds[i] += self.learning_rate * tree.predict_binned(&bins);
453 }
454 }
455
456 Ok(preds)
457 }
458}
459
460#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
468pub struct HistGradientBoostingClassifier {
469 pub n_estimators: usize,
470 pub learning_rate: f64,
471 pub max_depth: usize,
472 pub min_samples_leaf: usize,
473 pub l2_regularization: f64,
474 pub max_bins: usize,
475}
476
477impl HistGradientBoostingClassifier {
478 pub fn new() -> Self {
479 Self {
480 n_estimators: 100,
481 learning_rate: 0.1,
482 max_depth: 6,
483 min_samples_leaf: 20,
484 l2_regularization: 0.0,
485 max_bins: MAX_BINS,
486 }
487 }
488
489 pub fn with_n_estimators(mut self, n: usize) -> Self {
490 self.n_estimators = n;
491 self
492 }
493 pub fn with_learning_rate(mut self, lr: f64) -> Self {
494 self.learning_rate = lr;
495 self
496 }
497 pub fn with_max_depth(mut self, d: usize) -> Self {
498 self.max_depth = d;
499 self
500 }
501 pub fn with_min_samples_leaf(mut self, m: usize) -> Self {
502 self.min_samples_leaf = m;
503 self
504 }
505 pub fn with_l2_regularization(mut self, l2: f64) -> Self {
506 self.l2_regularization = l2;
507 self
508 }
509 pub fn with_max_bins(mut self, b: usize) -> Self {
510 self.max_bins = b;
511 self
512 }
513}
514
515impl Default for HistGradientBoostingClassifier {
516 fn default() -> Self {
517 Self::new()
518 }
519}
520
521#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
523pub struct FittedHistGradientBoostingClassifier {
524 tree_sets: Vec<Vec<HistNode>>,
526 bins: Vec<FeatureBins>,
527 baselines: Vec<f64>,
528 classes: Vec<f64>,
529 learning_rate: f64,
530 n_features: usize,
531}
532
533impl FittedHistGradientBoostingClassifier {
534 pub fn classes(&self) -> &[f64] {
535 &self.classes
536 }
537 pub fn n_estimators(&self) -> usize {
538 self.tree_sets.first().map_or(0, |t| t.len())
539 }
540
541 pub fn predict_proba(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
543 if x.ncols() != self.n_features {
544 return Err(RustMlError::ShapeMismatch(format!(
545 "expected {} features, got {}",
546 self.n_features,
547 x.ncols()
548 )));
549 }
550
551 let n = x.nrows();
552 let n_classes = self.classes.len();
553
554 if n_classes == 2 {
555 let mut proba = Array2::zeros((n, 2));
557 for i in 0..n {
558 let row: Vec<f64> = (0..self.n_features).map(|j| x[[i, j]]).collect();
559 let bins = bin_row(&row, &self.bins);
560 let mut score = self.baselines[0];
561 for tree in &self.tree_sets[0] {
562 score += self.learning_rate * tree.predict_binned(&bins);
563 }
564 let p1 = 1.0 / (1.0 + (-score).exp());
565 proba[[i, 0]] = 1.0 - p1;
566 proba[[i, 1]] = p1;
567 }
568 Ok(proba)
569 } else {
570 let mut proba = Array2::zeros((n, n_classes));
572 for i in 0..n {
573 let row: Vec<f64> = (0..self.n_features).map(|j| x[[i, j]]).collect();
574 let bins = bin_row(&row, &self.bins);
575 let mut scores = vec![0.0; n_classes];
576 for (c, tree_set) in self.tree_sets.iter().enumerate() {
577 scores[c] = self.baselines[c];
578 for tree in tree_set {
579 scores[c] += self.learning_rate * tree.predict_binned(&bins);
580 }
581 }
582 let max_s = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
584 let exp_sum: f64 = scores.iter().map(|&s| (s - max_s).exp()).sum();
585 for c in 0..n_classes {
586 proba[[i, c]] = (scores[c] - max_s).exp() / exp_sum;
587 }
588 }
589 Ok(proba)
590 }
591 }
592}
593
594impl Fit<f64> for HistGradientBoostingClassifier {
595 type Fitted = FittedHistGradientBoostingClassifier;
596
597 fn fit(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Self::Fitted> {
598 if x.nrows() != y.len() {
599 return Err(RustMlError::ShapeMismatch(format!(
600 "X has {} rows but y has {} elements",
601 x.nrows(),
602 y.len()
603 )));
604 }
605 if x.is_empty() {
606 return Err(RustMlError::EmptyInput("training data is empty".into()));
607 }
608
609 let n = x.nrows();
610 let (binned_x, bins) = compute_bins(x, self.max_bins);
611
612 let mut classes: Vec<f64> = y.iter().copied().collect();
614 classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
615 classes.dedup();
616 let n_classes = classes.len();
617
618 if n_classes < 2 {
619 return Err(RustMlError::InvalidParameter(
620 "need at least 2 classes".into(),
621 ));
622 }
623
624 let indices: Vec<usize> = (0..n).collect();
625
626 if n_classes == 2 {
627 let pos_class = classes[1];
629 let labels: Vec<f64> = y
630 .iter()
631 .map(|&v| if v == pos_class { 1.0 } else { 0.0 })
632 .collect();
633 let pos_frac: f64 = labels.iter().sum::<f64>() / n as f64;
634 let baseline = (pos_frac / (1.0 - pos_frac + 1e-15)).ln();
635
636 let mut raw_scores = vec![baseline; n];
637 let mut trees = Vec::with_capacity(self.n_estimators);
638
639 for _ in 0..self.n_estimators {
640 let gradients: Vec<f64> = (0..n)
642 .map(|i| {
643 let p = 1.0 / (1.0 + (-raw_scores[i]).exp());
644 p - labels[i]
645 })
646 .collect();
647 let hessians: Vec<f64> = (0..n)
648 .map(|i| {
649 let p = 1.0 / (1.0 + (-raw_scores[i]).exp());
650 (p * (1.0 - p)).max(1e-12)
651 })
652 .collect();
653
654 let tree = build_hist_tree(
655 &binned_x,
656 &gradients,
657 &hessians,
658 &indices,
659 self.max_depth,
660 self.min_samples_leaf,
661 self.l2_regularization,
662 0,
663 );
664
665 for i in 0..n {
666 let row_bins: Vec<u8> = (0..x.ncols()).map(|j| binned_x[[i, j]]).collect();
667 raw_scores[i] += self.learning_rate * tree.predict_binned(&row_bins);
668 }
669 trees.push(tree);
670 }
671
672 Ok(FittedHistGradientBoostingClassifier {
673 tree_sets: vec![trees],
674 bins,
675 baselines: vec![baseline],
676 classes,
677 learning_rate: self.learning_rate,
678 n_features: x.ncols(),
679 })
680 } else {
681 let mut tree_sets = Vec::with_capacity(n_classes);
683 let mut baselines = Vec::with_capacity(n_classes);
684 let mut all_raw_scores = vec![vec![0.0; n]; n_classes];
685
686 for (c, &cls) in classes.iter().enumerate() {
688 let count = y.iter().filter(|&&v| v == cls).count() as f64;
689 let prior = count / n as f64;
690 let bl = prior.ln().max(-10.0);
691 baselines.push(bl);
692 all_raw_scores[c] = vec![bl; n];
693 }
694
695 for _ in 0..self.n_estimators {
697 let mut probas = vec![vec![0.0; n_classes]; n];
699 for i in 0..n {
700 let max_s = all_raw_scores
701 .iter()
702 .map(|s| s[i])
703 .fold(f64::NEG_INFINITY, f64::max);
704 let exp_sum: f64 = all_raw_scores.iter().map(|s| (s[i] - max_s).exp()).sum();
705 for c in 0..n_classes {
706 probas[i][c] = (all_raw_scores[c][i] - max_s).exp() / exp_sum;
707 }
708 }
709
710 let mut round_trees = Vec::with_capacity(n_classes);
711 for (c, &cls) in classes.iter().enumerate() {
712 let gradients: Vec<f64> = (0..n)
713 .map(|i| {
714 let label = if y[i] == cls { 1.0 } else { 0.0 };
715 probas[i][c] - label
716 })
717 .collect();
718 let hessians: Vec<f64> = (0..n)
719 .map(|i| (probas[i][c] * (1.0 - probas[i][c])).max(1e-12))
720 .collect();
721
722 let tree = build_hist_tree(
723 &binned_x,
724 &gradients,
725 &hessians,
726 &indices,
727 self.max_depth,
728 self.min_samples_leaf,
729 self.l2_regularization,
730 0,
731 );
732
733 for i in 0..n {
734 let row_bins: Vec<u8> = (0..x.ncols()).map(|j| binned_x[[i, j]]).collect();
735 all_raw_scores[c][i] += self.learning_rate * tree.predict_binned(&row_bins);
736 }
737 round_trees.push(tree);
738 }
739
740 if tree_sets.is_empty() {
742 for tree in round_trees {
743 tree_sets.push(vec![tree]);
744 }
745 } else {
746 for (c, tree) in round_trees.into_iter().enumerate() {
747 tree_sets[c].push(tree);
748 }
749 }
750 }
751
752 Ok(FittedHistGradientBoostingClassifier {
753 tree_sets,
754 bins,
755 baselines,
756 classes,
757 learning_rate: self.learning_rate,
758 n_features: x.ncols(),
759 })
760 }
761 }
762}
763
764impl Predict<f64> for FittedHistGradientBoostingClassifier {
765 fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
766 let proba = self.predict_proba(x)?;
767 let n = x.nrows();
768 let mut preds = Array1::zeros(n);
769
770 for i in 0..n {
771 let mut best_c = 0;
772 let mut best_p = proba[[i, 0]];
773 for c in 1..self.classes.len() {
774 if proba[[i, c]] > best_p {
775 best_p = proba[[i, c]];
776 best_c = c;
777 }
778 }
779 preds[i] = self.classes[best_c];
780 }
781
782 Ok(preds)
783 }
784}
785
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// A noiseless linear target (y = 2x) should be fit to within a loose tolerance of 2.0.
    #[test]
    fn test_hist_gb_regressor_basic() {
        let x = array![
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0],
            [10.0]
        ];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0];

        // min_samples_leaf(1) lets the trees split such a tiny dataset.
        let model = HistGradientBoostingRegressor::new()
            .with_n_estimators(50)
            .with_max_depth(3)
            .with_min_samples_leaf(1);

        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for (p, t) in preds.iter().zip(y.iter()) {
            assert_abs_diff_eq!(*p, *t, epsilon = 2.0);
        }
    }

    /// The fitted model keeps exactly one tree per requested boosting round.
    #[test]
    fn test_hist_gb_regressor_n_estimators() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let fitted = HistGradientBoostingRegressor::new()
            .with_n_estimators(10)
            .with_min_samples_leaf(1)
            .fit(&x, &y)
            .unwrap();

        assert_eq!(fitted.n_estimators(), 10);
    }

    /// Two well-separated groups: at least 8 of 10 training points must be classified correctly.
    #[test]
    fn test_hist_gb_classifier_binary() {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [4.0, 0.0],
            [5.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0],
            [13.0, 1.0],
            [14.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0];

        let model = HistGradientBoostingClassifier::new()
            .with_n_estimators(20)
            .with_max_depth(3)
            .with_min_samples_leaf(1);

        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let correct: usize = preds.iter().zip(y.iter()).filter(|(&p, &t)| p == t).count();
        assert!(
            correct >= 8,
            "should classify most correctly, got {}/10",
            correct
        );
    }

    /// Binary predict_proba returns two columns whose rows sum to 1.
    #[test]
    fn test_hist_gb_classifier_predict_proba() {
        let x = array![
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [10.0],
            [11.0],
            [12.0],
            [13.0],
            [14.0]
        ];
        let y = array![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0];

        let fitted = HistGradientBoostingClassifier::new()
            .with_n_estimators(20)
            .with_min_samples_leaf(1)
            .fit(&x, &y)
            .unwrap();

        let proba = fitted.predict_proba(&x).unwrap();
        assert_eq!(proba.ncols(), 2);

        for i in 0..x.nrows() {
            let row_sum: f64 = (0..proba.ncols()).map(|c| proba[[i, c]]).sum();
            assert_abs_diff_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    /// Three clusters exercise the softmax path: classes come back sorted and
    /// predict_proba has one column per class.
    #[test]
    fn test_hist_gb_classifier_multiclass() {
        let x = array![
            [0.0, 0.0],
            [1.0, 0.0],
            [2.0, 0.0],
            [5.0, 5.0],
            [6.0, 5.0],
            [7.0, 5.0],
            [0.0, 10.0],
            [1.0, 10.0],
            [2.0, 10.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0];

        let fitted = HistGradientBoostingClassifier::new()
            .with_n_estimators(30)
            .with_max_depth(3)
            .with_min_samples_leaf(1)
            .fit(&x, &y)
            .unwrap();

        assert_eq!(fitted.classes(), &[0.0, 1.0, 2.0]);

        let proba = fitted.predict_proba(&x).unwrap();
        assert_eq!(proba.ncols(), 3);
    }

    /// Fitting rejects X and y with different sample counts.
    #[test]
    fn test_hist_gb_regressor_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![1.0, 2.0];
        assert!(HistGradientBoostingRegressor::new().fit(&x, &y).is_err());
    }

    /// Fitting rejects an input with zero rows.
    #[test]
    fn test_hist_gb_regressor_empty() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);
        assert!(HistGradientBoostingRegressor::new().fit(&x, &y).is_err());
    }

    /// Classification needs at least two distinct labels.
    #[test]
    fn test_hist_gb_classifier_single_class() {
        let x = array![[1.0], [2.0], [3.0]];
        let y = array![0.0, 0.0, 0.0];
        assert!(HistGradientBoostingClassifier::new().fit(&x, &y).is_err());
    }

    /// Predicting with a different number of features than seen at fit time is an error.
    #[test]
    fn test_hist_gb_predict_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let fitted = HistGradientBoostingClassifier::new()
            .with_n_estimators(5)
            .with_min_samples_leaf(1)
            .fit(&x, &y)
            .unwrap();

        let x_bad = array![[1.0]];
        assert!(fitted.predict(&x_bad).is_err());
    }
}