Skip to main content

anofox_ml_ensemble/
hist_gradient_boosting.rs

1//! Histogram-based gradient boosting (classifier and regressor).
2//!
3//! Much faster than classical gradient boosting for medium-to-large datasets.
4//! Features are binned into 256 discrete bins, enabling O(n) split finding
5//! via histogram accumulation instead of O(n log n) sorting.
6
7use anofox_ml_core::{Fit, Predict, Result, RustMlError};
8use ndarray::{Array1, Array2};
9
/// Upper bound on the number of bins per feature. Bin indices are stored as `u8`,
/// so at most `u8::MAX + 1` (= 256) distinct bins are addressable.
const MAX_BINS: usize = u8::MAX as usize + 1;
11
12// ============================================================
13// Binning
14// ============================================================
15
16/// Bin edges for one feature.
17#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
18struct FeatureBins {
19    /// Sorted thresholds. Value v maps to bin i if edges[i-1] < v <= edges[i].
20    edges: Vec<f64>,
21}
22
23/// Bin all features into u8 indices.
24fn compute_bins(x: &Array2<f64>, max_bins: usize) -> (Array2<u8>, Vec<FeatureBins>) {
25    let n = x.nrows();
26    let p = x.ncols();
27    let mut binned = Array2::zeros((n, p));
28    let mut all_bins = Vec::with_capacity(p);
29
30    for j in 0..p {
31        let mut col: Vec<f64> = (0..n).map(|i| x[[i, j]]).collect();
32        col.sort_by(|a, b| a.partial_cmp(b).unwrap());
33        col.dedup();
34
35        // Compute quantile-based bin edges
36        let n_edges = (col.len()).min(max_bins - 1);
37        let mut edges = Vec::with_capacity(n_edges);
38        for k in 1..=n_edges {
39            let idx = (k * col.len() / (n_edges + 1)).min(col.len() - 1);
40            let edge = col[idx];
41            if edges.last().map_or(true, |&last: &f64| edge > last) {
42                edges.push(edge);
43            }
44        }
45
46        // Map values to bins
47        for i in 0..n {
48            let v = x[[i, j]];
49            let bin = edges.partition_point(|&e| e < v) as u8;
50            binned[[i, j]] = bin;
51        }
52
53        all_bins.push(FeatureBins { edges });
54    }
55
56    (binned, all_bins)
57}
58
59/// Map a new data point's features to bin indices.
60fn bin_row(row: &[f64], all_bins: &[FeatureBins]) -> Vec<u8> {
61    row.iter()
62        .zip(all_bins.iter())
63        .map(|(&v, bins)| bins.edges.partition_point(|&e| e < v) as u8)
64        .collect()
65}
66
67// ============================================================
68// Histogram tree node
69// ============================================================
70
/// Per-bin accumulators (gradient sum, hessian sum, sample count) for one feature.
#[derive(Clone)]
struct Histogram {
    /// Sum of the gradients of the samples in each bin; one entry per bin.
    grad_sum: Vec<f64>,
    /// Sum of the hessians of the samples in each bin; one entry per bin.
    hess_sum: Vec<f64>,
    /// Number of samples in each bin.
    count: Vec<u32>,
}

impl Histogram {
    /// Create an all-zero histogram with `n_bins` bins.
    fn new(n_bins: usize) -> Self {
        let zeros = vec![0.0_f64; n_bins];
        Self {
            grad_sum: zeros.clone(),
            hess_sum: zeros,
            count: vec![0_u32; n_bins],
        }
    }

    /// Zero every accumulator in place, keeping the bin count and allocations.
    fn reset(&mut self) {
        for slot in self.grad_sum.iter_mut().chain(self.hess_sum.iter_mut()) {
            *slot = 0.0;
        }
        self.count.iter_mut().for_each(|c| *c = 0);
    }
}
97
/// Best split found for a node, as produced by `find_best_hist_split`.
// Only `feature`, `bin_threshold` and `gain` are currently read; the child
// values/counts are computed but unused, hence the dead-code allowance.
#[allow(dead_code)]
struct HistSplit {
    /// Column of the binned matrix that the split tests.
    feature: usize,
    /// Samples whose bin is `<= bin_threshold` go to the left child.
    bin_threshold: u8,
    /// Second-order gain of the split.
    gain: f64,
    /// Number of samples sent to the left child.
    left_count: usize,
    /// Number of samples sent to the right child.
    right_count: usize,
    /// Newton leaf value `-G_L / (H_L + lambda)` of the left child.
    left_value: f64,
    /// Newton leaf value `-G_R / (H_R + lambda)` of the right child.
    right_value: f64,
}
109
110/// Find the best split across all features using histograms.
111fn find_best_hist_split(
112    binned_x: &Array2<u8>,
113    gradients: &[f64],
114    hessians: &[f64],
115    indices: &[usize],
116    n_features: usize,
117    min_samples_leaf: usize,
118    l2_regularization: f64,
119) -> Option<HistSplit> {
120    let n_bins = MAX_BINS;
121    let mut best: Option<HistSplit> = None;
122    let mut hist = Histogram::new(n_bins);
123
124    // Total gradient/hessian for this node
125    let total_grad: f64 = indices.iter().map(|&i| gradients[i]).sum();
126    let total_hess: f64 = indices.iter().map(|&i| hessians[i]).sum();
127    let total_count = indices.len();
128
129    for feat in 0..n_features {
130        hist.reset();
131
132        // Build histogram for this feature
133        for &i in indices {
134            let bin = binned_x[[i, feat]] as usize;
135            hist.grad_sum[bin] += gradients[i];
136            hist.hess_sum[bin] += hessians[i];
137            hist.count[bin] += 1;
138        }
139
140        // Scan bins to find best split
141        let mut left_grad = 0.0;
142        let mut left_hess = 0.0;
143        let mut left_count: usize = 0;
144
145        for bin in 0..(n_bins - 1) {
146            left_grad += hist.grad_sum[bin];
147            left_hess += hist.hess_sum[bin];
148            left_count += hist.count[bin] as usize;
149
150            if left_count < min_samples_leaf {
151                continue;
152            }
153            let right_count = total_count - left_count;
154            if right_count < min_samples_leaf {
155                break;
156            }
157
158            let right_grad = total_grad - left_grad;
159            let right_hess = total_hess - left_hess;
160
161            // Gain = left_term + right_term - parent_term
162            // where term = G^2 / (H + lambda)
163            let reg = l2_regularization;
164            let parent_term = total_grad * total_grad / (total_hess + reg);
165            let left_term = left_grad * left_grad / (left_hess + reg);
166            let right_term = right_grad * right_grad / (right_hess + reg);
167            let gain = 0.5 * (left_term + right_term - parent_term);
168
169            if gain > best.as_ref().map_or(0.0, |b| b.gain) {
170                best = Some(HistSplit {
171                    feature: feat,
172                    bin_threshold: bin as u8,
173                    gain,
174                    left_value: -left_grad / (left_hess + reg),
175                    right_value: -right_grad / (right_hess + reg),
176                    left_count,
177                    right_count,
178                });
179            }
180        }
181    }
182
183    best
184}
185
186// ============================================================
187// Hist tree structure
188// ============================================================
189
190#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
191enum HistNode {
192    Leaf {
193        value: f64,
194    },
195    Internal {
196        feature: usize,
197        bin_threshold: u8,
198        left: Box<HistNode>,
199        right: Box<HistNode>,
200    },
201}
202
203impl HistNode {
204    fn predict_binned(&self, bins: &[u8]) -> f64 {
205        match self {
206            HistNode::Leaf { value } => *value,
207            HistNode::Internal {
208                feature,
209                bin_threshold,
210                left,
211                right,
212            } => {
213                if bins[*feature] <= *bin_threshold {
214                    left.predict_binned(bins)
215                } else {
216                    right.predict_binned(bins)
217                }
218            }
219        }
220    }
221}
222
223fn build_hist_tree(
224    binned_x: &Array2<u8>,
225    gradients: &[f64],
226    hessians: &[f64],
227    indices: &[usize],
228    max_depth: usize,
229    min_samples_leaf: usize,
230    l2_regularization: f64,
231    depth: usize,
232) -> HistNode {
233    // Leaf conditions
234    if depth >= max_depth || indices.len() < 2 * min_samples_leaf {
235        let g: f64 = indices.iter().map(|&i| gradients[i]).sum();
236        let h: f64 = indices.iter().map(|&i| hessians[i]).sum();
237        return HistNode::Leaf {
238            value: -g / (h + l2_regularization),
239        };
240    }
241
242    let n_features = binned_x.ncols();
243    let split = find_best_hist_split(
244        binned_x,
245        gradients,
246        hessians,
247        indices,
248        n_features,
249        min_samples_leaf,
250        l2_regularization,
251    );
252
253    match split {
254        None => {
255            let g: f64 = indices.iter().map(|&i| gradients[i]).sum();
256            let h: f64 = indices.iter().map(|&i| hessians[i]).sum();
257            HistNode::Leaf {
258                value: -g / (h + l2_regularization),
259            }
260        }
261        Some(s) => {
262            let (left_idx, right_idx): (Vec<usize>, Vec<usize>) = indices
263                .iter()
264                .partition(|&&i| binned_x[[i, s.feature]] <= s.bin_threshold);
265
266            let left = build_hist_tree(
267                binned_x,
268                gradients,
269                hessians,
270                &left_idx,
271                max_depth,
272                min_samples_leaf,
273                l2_regularization,
274                depth + 1,
275            );
276            let right = build_hist_tree(
277                binned_x,
278                gradients,
279                hessians,
280                &right_idx,
281                max_depth,
282                min_samples_leaf,
283                l2_regularization,
284                depth + 1,
285            );
286
287            HistNode::Internal {
288                feature: s.feature,
289                bin_threshold: s.bin_threshold,
290                left: Box::new(left),
291                right: Box::new(right),
292            }
293        }
294    }
295}
296
297// ============================================================
298// HistGradientBoostingRegressor
299// ============================================================
300
301/// Histogram-based gradient boosting regressor.
302///
303/// Much faster than `GradientBoostingRegressor` for datasets with >1000 samples.
304/// Features are discretized into 256 bins for O(n) split finding.
305#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
306pub struct HistGradientBoostingRegressor {
307    pub n_estimators: usize,
308    pub learning_rate: f64,
309    pub max_depth: usize,
310    pub min_samples_leaf: usize,
311    pub l2_regularization: f64,
312    pub max_bins: usize,
313}
314
315impl HistGradientBoostingRegressor {
316    pub fn new() -> Self {
317        Self {
318            n_estimators: 100,
319            learning_rate: 0.1,
320            max_depth: 6,
321            min_samples_leaf: 20,
322            l2_regularization: 0.0,
323            max_bins: MAX_BINS,
324        }
325    }
326
327    pub fn with_n_estimators(mut self, n: usize) -> Self {
328        self.n_estimators = n;
329        self
330    }
331    pub fn with_learning_rate(mut self, lr: f64) -> Self {
332        self.learning_rate = lr;
333        self
334    }
335    pub fn with_max_depth(mut self, d: usize) -> Self {
336        self.max_depth = d;
337        self
338    }
339    pub fn with_min_samples_leaf(mut self, m: usize) -> Self {
340        self.min_samples_leaf = m;
341        self
342    }
343    pub fn with_l2_regularization(mut self, l2: f64) -> Self {
344        self.l2_regularization = l2;
345        self
346    }
347    pub fn with_max_bins(mut self, b: usize) -> Self {
348        self.max_bins = b;
349        self
350    }
351}
352
353impl Default for HistGradientBoostingRegressor {
354    fn default() -> Self {
355        Self::new()
356    }
357}
358
359/// Fitted histogram-based gradient boosting regressor.
360#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
361pub struct FittedHistGradientBoostingRegressor {
362    trees: Vec<HistNode>,
363    bins: Vec<FeatureBins>,
364    baseline: f64,
365    learning_rate: f64,
366    n_features: usize,
367}
368
369impl FittedHistGradientBoostingRegressor {
370    pub fn n_estimators(&self) -> usize {
371        self.trees.len()
372    }
373}
374
375impl Fit<f64> for HistGradientBoostingRegressor {
376    type Fitted = FittedHistGradientBoostingRegressor;
377
378    fn fit(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Self::Fitted> {
379        if x.nrows() != y.len() {
380            return Err(RustMlError::ShapeMismatch(format!(
381                "X has {} rows but y has {} elements",
382                x.nrows(),
383                y.len()
384            )));
385        }
386        if x.is_empty() {
387            return Err(RustMlError::EmptyInput("training data is empty".into()));
388        }
389
390        let n = x.nrows();
391        let (binned_x, bins) = compute_bins(x, self.max_bins);
392
393        // Initial prediction = mean(y)
394        let baseline: f64 = y.iter().sum::<f64>() / n as f64;
395        let mut predictions = vec![baseline; n];
396        let mut trees = Vec::with_capacity(self.n_estimators);
397
398        let indices: Vec<usize> = (0..n).collect();
399
400        for _ in 0..self.n_estimators {
401            // Squared error: gradient = prediction - y, hessian = 1
402            let gradients: Vec<f64> = (0..n).map(|i| predictions[i] - y[i]).collect();
403            let hessians = vec![1.0; n];
404
405            let tree = build_hist_tree(
406                &binned_x,
407                &gradients,
408                &hessians,
409                &indices,
410                self.max_depth,
411                self.min_samples_leaf,
412                self.l2_regularization,
413                0,
414            );
415
416            // Update predictions
417            for i in 0..n {
418                let row_bins: Vec<u8> = (0..x.ncols()).map(|j| binned_x[[i, j]]).collect();
419                predictions[i] += self.learning_rate * tree.predict_binned(&row_bins);
420            }
421
422            trees.push(tree);
423        }
424
425        Ok(FittedHistGradientBoostingRegressor {
426            trees,
427            bins,
428            baseline,
429            learning_rate: self.learning_rate,
430            n_features: x.ncols(),
431        })
432    }
433}
434
435impl Predict<f64> for FittedHistGradientBoostingRegressor {
436    fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
437        if x.ncols() != self.n_features {
438            return Err(RustMlError::ShapeMismatch(format!(
439                "expected {} features, got {}",
440                self.n_features,
441                x.ncols()
442            )));
443        }
444
445        let n = x.nrows();
446        let mut preds = Array1::from_elem(n, self.baseline);
447
448        for i in 0..n {
449            let row: Vec<f64> = (0..self.n_features).map(|j| x[[i, j]]).collect();
450            let bins = bin_row(&row, &self.bins);
451            for tree in &self.trees {
452                preds[i] += self.learning_rate * tree.predict_binned(&bins);
453            }
454        }
455
456        Ok(preds)
457    }
458}
459
460// ============================================================
461// HistGradientBoostingClassifier
462// ============================================================
463
464/// Histogram-based gradient boosting classifier.
465///
466/// Binary classification using log-loss (logistic). Multi-class uses OvR.
467#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
468pub struct HistGradientBoostingClassifier {
469    pub n_estimators: usize,
470    pub learning_rate: f64,
471    pub max_depth: usize,
472    pub min_samples_leaf: usize,
473    pub l2_regularization: f64,
474    pub max_bins: usize,
475}
476
477impl HistGradientBoostingClassifier {
478    pub fn new() -> Self {
479        Self {
480            n_estimators: 100,
481            learning_rate: 0.1,
482            max_depth: 6,
483            min_samples_leaf: 20,
484            l2_regularization: 0.0,
485            max_bins: MAX_BINS,
486        }
487    }
488
489    pub fn with_n_estimators(mut self, n: usize) -> Self {
490        self.n_estimators = n;
491        self
492    }
493    pub fn with_learning_rate(mut self, lr: f64) -> Self {
494        self.learning_rate = lr;
495        self
496    }
497    pub fn with_max_depth(mut self, d: usize) -> Self {
498        self.max_depth = d;
499        self
500    }
501    pub fn with_min_samples_leaf(mut self, m: usize) -> Self {
502        self.min_samples_leaf = m;
503        self
504    }
505    pub fn with_l2_regularization(mut self, l2: f64) -> Self {
506        self.l2_regularization = l2;
507        self
508    }
509    pub fn with_max_bins(mut self, b: usize) -> Self {
510        self.max_bins = b;
511        self
512    }
513}
514
515impl Default for HistGradientBoostingClassifier {
516    fn default() -> Self {
517        Self::new()
518    }
519}
520
521/// Fitted histogram-based gradient boosting classifier.
522#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
523pub struct FittedHistGradientBoostingClassifier {
524    /// For binary: single set of trees. For multi-class: one set per class (OvR).
525    tree_sets: Vec<Vec<HistNode>>,
526    bins: Vec<FeatureBins>,
527    baselines: Vec<f64>,
528    classes: Vec<f64>,
529    learning_rate: f64,
530    n_features: usize,
531}
532
533impl FittedHistGradientBoostingClassifier {
534    pub fn classes(&self) -> &[f64] {
535        &self.classes
536    }
537    pub fn n_estimators(&self) -> usize {
538        self.tree_sets.first().map_or(0, |t| t.len())
539    }
540
541    /// Predict class probabilities.
542    pub fn predict_proba(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
543        if x.ncols() != self.n_features {
544            return Err(RustMlError::ShapeMismatch(format!(
545                "expected {} features, got {}",
546                self.n_features,
547                x.ncols()
548            )));
549        }
550
551        let n = x.nrows();
552        let n_classes = self.classes.len();
553
554        if n_classes == 2 {
555            // Binary: sigmoid of raw scores
556            let mut proba = Array2::zeros((n, 2));
557            for i in 0..n {
558                let row: Vec<f64> = (0..self.n_features).map(|j| x[[i, j]]).collect();
559                let bins = bin_row(&row, &self.bins);
560                let mut score = self.baselines[0];
561                for tree in &self.tree_sets[0] {
562                    score += self.learning_rate * tree.predict_binned(&bins);
563                }
564                let p1 = 1.0 / (1.0 + (-score).exp());
565                proba[[i, 0]] = 1.0 - p1;
566                proba[[i, 1]] = p1;
567            }
568            Ok(proba)
569        } else {
570            // Multi-class: softmax of raw scores
571            let mut proba = Array2::zeros((n, n_classes));
572            for i in 0..n {
573                let row: Vec<f64> = (0..self.n_features).map(|j| x[[i, j]]).collect();
574                let bins = bin_row(&row, &self.bins);
575                let mut scores = vec![0.0; n_classes];
576                for (c, tree_set) in self.tree_sets.iter().enumerate() {
577                    scores[c] = self.baselines[c];
578                    for tree in tree_set {
579                        scores[c] += self.learning_rate * tree.predict_binned(&bins);
580                    }
581                }
582                // Softmax
583                let max_s = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
584                let exp_sum: f64 = scores.iter().map(|&s| (s - max_s).exp()).sum();
585                for c in 0..n_classes {
586                    proba[[i, c]] = (scores[c] - max_s).exp() / exp_sum;
587                }
588            }
589            Ok(proba)
590        }
591    }
592}
593
594impl Fit<f64> for HistGradientBoostingClassifier {
595    type Fitted = FittedHistGradientBoostingClassifier;
596
597    fn fit(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Self::Fitted> {
598        if x.nrows() != y.len() {
599            return Err(RustMlError::ShapeMismatch(format!(
600                "X has {} rows but y has {} elements",
601                x.nrows(),
602                y.len()
603            )));
604        }
605        if x.is_empty() {
606            return Err(RustMlError::EmptyInput("training data is empty".into()));
607        }
608
609        let n = x.nrows();
610        let (binned_x, bins) = compute_bins(x, self.max_bins);
611
612        // Collect classes
613        let mut classes: Vec<f64> = y.iter().copied().collect();
614        classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
615        classes.dedup();
616        let n_classes = classes.len();
617
618        if n_classes < 2 {
619            return Err(RustMlError::InvalidParameter(
620                "need at least 2 classes".into(),
621            ));
622        }
623
624        let indices: Vec<usize> = (0..n).collect();
625
626        if n_classes == 2 {
627            // Binary: log-loss with single set of trees
628            let pos_class = classes[1];
629            let labels: Vec<f64> = y
630                .iter()
631                .map(|&v| if v == pos_class { 1.0 } else { 0.0 })
632                .collect();
633            let pos_frac: f64 = labels.iter().sum::<f64>() / n as f64;
634            let baseline = (pos_frac / (1.0 - pos_frac + 1e-15)).ln();
635
636            let mut raw_scores = vec![baseline; n];
637            let mut trees = Vec::with_capacity(self.n_estimators);
638
639            for _ in 0..self.n_estimators {
640                // Log-loss gradients: p - y, hessians: p * (1 - p)
641                let gradients: Vec<f64> = (0..n)
642                    .map(|i| {
643                        let p = 1.0 / (1.0 + (-raw_scores[i]).exp());
644                        p - labels[i]
645                    })
646                    .collect();
647                let hessians: Vec<f64> = (0..n)
648                    .map(|i| {
649                        let p = 1.0 / (1.0 + (-raw_scores[i]).exp());
650                        (p * (1.0 - p)).max(1e-12)
651                    })
652                    .collect();
653
654                let tree = build_hist_tree(
655                    &binned_x,
656                    &gradients,
657                    &hessians,
658                    &indices,
659                    self.max_depth,
660                    self.min_samples_leaf,
661                    self.l2_regularization,
662                    0,
663                );
664
665                for i in 0..n {
666                    let row_bins: Vec<u8> = (0..x.ncols()).map(|j| binned_x[[i, j]]).collect();
667                    raw_scores[i] += self.learning_rate * tree.predict_binned(&row_bins);
668                }
669                trees.push(tree);
670            }
671
672            Ok(FittedHistGradientBoostingClassifier {
673                tree_sets: vec![trees],
674                bins,
675                baselines: vec![baseline],
676                classes,
677                learning_rate: self.learning_rate,
678                n_features: x.ncols(),
679            })
680        } else {
681            // Multi-class: one-vs-all with softmax
682            let mut tree_sets = Vec::with_capacity(n_classes);
683            let mut baselines = Vec::with_capacity(n_classes);
684            let mut all_raw_scores = vec![vec![0.0; n]; n_classes];
685
686            // Initial baselines: log(class_prior)
687            for (c, &cls) in classes.iter().enumerate() {
688                let count = y.iter().filter(|&&v| v == cls).count() as f64;
689                let prior = count / n as f64;
690                let bl = prior.ln().max(-10.0);
691                baselines.push(bl);
692                all_raw_scores[c] = vec![bl; n];
693            }
694
695            // Train trees for each class
696            for _ in 0..self.n_estimators {
697                // Compute softmax probabilities
698                let mut probas = vec![vec![0.0; n_classes]; n];
699                for i in 0..n {
700                    let max_s = all_raw_scores
701                        .iter()
702                        .map(|s| s[i])
703                        .fold(f64::NEG_INFINITY, f64::max);
704                    let exp_sum: f64 = all_raw_scores.iter().map(|s| (s[i] - max_s).exp()).sum();
705                    for c in 0..n_classes {
706                        probas[i][c] = (all_raw_scores[c][i] - max_s).exp() / exp_sum;
707                    }
708                }
709
710                let mut round_trees = Vec::with_capacity(n_classes);
711                for (c, &cls) in classes.iter().enumerate() {
712                    let gradients: Vec<f64> = (0..n)
713                        .map(|i| {
714                            let label = if y[i] == cls { 1.0 } else { 0.0 };
715                            probas[i][c] - label
716                        })
717                        .collect();
718                    let hessians: Vec<f64> = (0..n)
719                        .map(|i| (probas[i][c] * (1.0 - probas[i][c])).max(1e-12))
720                        .collect();
721
722                    let tree = build_hist_tree(
723                        &binned_x,
724                        &gradients,
725                        &hessians,
726                        &indices,
727                        self.max_depth,
728                        self.min_samples_leaf,
729                        self.l2_regularization,
730                        0,
731                    );
732
733                    for i in 0..n {
734                        let row_bins: Vec<u8> = (0..x.ncols()).map(|j| binned_x[[i, j]]).collect();
735                        all_raw_scores[c][i] += self.learning_rate * tree.predict_binned(&row_bins);
736                    }
737                    round_trees.push(tree);
738                }
739
740                // Distribute trees to per-class sets
741                if tree_sets.is_empty() {
742                    for tree in round_trees {
743                        tree_sets.push(vec![tree]);
744                    }
745                } else {
746                    for (c, tree) in round_trees.into_iter().enumerate() {
747                        tree_sets[c].push(tree);
748                    }
749                }
750            }
751
752            Ok(FittedHistGradientBoostingClassifier {
753                tree_sets,
754                bins,
755                baselines,
756                classes,
757                learning_rate: self.learning_rate,
758                n_features: x.ncols(),
759            })
760        }
761    }
762}
763
764impl Predict<f64> for FittedHistGradientBoostingClassifier {
765    fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
766        let proba = self.predict_proba(x)?;
767        let n = x.nrows();
768        let mut preds = Array1::zeros(n);
769
770        for i in 0..n {
771            let mut best_c = 0;
772            let mut best_p = proba[[i, 0]];
773            for c in 1..self.classes.len() {
774                if proba[[i, c]] > best_p {
775                    best_p = proba[[i, c]];
776                    best_c = c;
777                }
778            }
779            preds[i] = self.classes[best_c];
780        }
781
782        Ok(preds)
783    }
784}
785
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Build an `n x 1` design matrix from `values`.
    fn single_feature(values: &[f64]) -> Array2<f64> {
        Array2::from_shape_vec((values.len(), 1), values.to_vec()).expect("one value per row")
    }

    /// The integers `lo..=hi` as floats.
    fn float_range(lo: i32, hi: i32) -> Vec<f64> {
        (lo..=hi).map(f64::from).collect()
    }

    #[test]
    fn test_hist_gb_regressor_basic() {
        let xs = float_range(1, 10);
        let x = single_feature(&xs);
        let y: Array1<f64> = xs.iter().map(|&v| 2.0 * v).collect();

        let fitted = HistGradientBoostingRegressor::new()
            .with_n_estimators(50)
            .with_max_depth(3)
            .with_min_samples_leaf(1)
            .fit(&x, &y)
            .unwrap();
        let preds = fitted.predict(&x).unwrap();

        for (pred, target) in preds.iter().zip(&y) {
            assert_abs_diff_eq!(*pred, *target, epsilon = 2.0);
        }
    }

    #[test]
    fn test_hist_gb_regressor_n_estimators() {
        let xs = float_range(1, 6);
        let x = single_feature(&xs);
        let y = Array1::from(xs);

        let fitted = HistGradientBoostingRegressor::new()
            .with_n_estimators(10)
            .with_min_samples_leaf(1)
            .fit(&x, &y)
            .unwrap();

        assert_eq!(fitted.n_estimators(), 10);
    }

    #[test]
    fn test_hist_gb_classifier_binary() {
        // Rows 0..5: x0 = 1..=5, x1 = 0 (class 0). Rows 5..10: x0 = 10..=14, x1 = 1 (class 1).
        let x = Array2::from_shape_fn((10, 2), |(i, j)| match (j, i >= 5) {
            (0, false) => (i + 1) as f64,
            (0, true) => (i + 5) as f64,
            (_, false) => 0.0,
            (_, true) => 1.0,
        });
        let y: Array1<f64> = (0..10).map(|i| if i >= 5 { 1.0 } else { 0.0 }).collect();

        let fitted = HistGradientBoostingClassifier::new()
            .with_n_estimators(20)
            .with_max_depth(3)
            .with_min_samples_leaf(1)
            .fit(&x, &y)
            .unwrap();
        let preds = fitted.predict(&x).unwrap();

        let correct = preds.iter().zip(&y).filter(|(p, t)| p == t).count();
        assert!(
            correct >= 8,
            "should classify most correctly, got {}/10",
            correct
        );
    }

    #[test]
    fn test_hist_gb_classifier_predict_proba() {
        let xs: Vec<f64> = float_range(1, 5)
            .into_iter()
            .chain(float_range(10, 14))
            .collect();
        let x = single_feature(&xs);
        let y: Array1<f64> = xs.iter().map(|&v| if v >= 10.0 { 1.0 } else { 0.0 }).collect();

        let fitted = HistGradientBoostingClassifier::new()
            .with_n_estimators(20)
            .with_min_samples_leaf(1)
            .fit(&x, &y)
            .unwrap();

        let proba = fitted.predict_proba(&x).unwrap();
        assert_eq!(proba.ncols(), 2);

        for row in proba.outer_iter() {
            assert_abs_diff_eq!(row.sum(), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_hist_gb_classifier_multiclass() {
        // (x0, x1, label): three well-separated clusters of three points each.
        let points = [
            (0.0, 0.0, 0.0),
            (1.0, 0.0, 0.0),
            (2.0, 0.0, 0.0),
            (5.0, 5.0, 1.0),
            (6.0, 5.0, 1.0),
            (7.0, 5.0, 1.0),
            (0.0, 10.0, 2.0),
            (1.0, 10.0, 2.0),
            (2.0, 10.0, 2.0),
        ];
        let x = Array2::from_shape_fn((points.len(), 2), |(i, j)| {
            if j == 0 {
                points[i].0
            } else {
                points[i].1
            }
        });
        let y: Array1<f64> = points.iter().map(|&(_, _, label)| label).collect();

        let fitted = HistGradientBoostingClassifier::new()
            .with_n_estimators(30)
            .with_max_depth(3)
            .with_min_samples_leaf(1)
            .fit(&x, &y)
            .unwrap();

        assert_eq!(fitted.classes(), &[0.0, 1.0, 2.0]);

        let proba = fitted.predict_proba(&x).unwrap();
        assert_eq!(proba.ncols(), 3);
    }

    #[test]
    fn test_hist_gb_regressor_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![1.0, 2.0];
        let outcome = HistGradientBoostingRegressor::new().fit(&x, &y);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_hist_gb_regressor_empty() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);
        let outcome = HistGradientBoostingRegressor::new().fit(&x, &y);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_hist_gb_classifier_single_class() {
        let x = single_feature(&[1.0, 2.0, 3.0]);
        let y = Array1::<f64>::zeros(3);
        let outcome = HistGradientBoostingClassifier::new().fit(&x, &y);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_hist_gb_predict_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let fitted = HistGradientBoostingClassifier::new()
            .with_n_estimators(5)
            .with_min_samples_leaf(1)
            .fit(&x, &y)
            .unwrap();

        let x_bad = array![[1.0]];
        assert!(fitted.predict(&x_bad).is_err());
    }
}